openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

53 lines
2.2 KiB

#!/usr/bin/env python
import unittest
from unittest.mock import patch
import os
from tinygrad import Tensor
from tinygrad.device import Device, Compiler
from tinygrad.helpers import diskcache_get, diskcache_put, getenv
class TestDevice(unittest.TestCase):
  """Checks Device name canonicalization and lookup of unknown device names."""

  def test_canonicalize(self):
    # None resolves to whatever the default device is.
    self.assertEqual(Device.canonicalize(None), Device.DEFAULT)
    # (raw spec, expected canonical form): the device name is upper-cased while the text after ':' keeps
    # its case, and an explicit ':0' index is dropped whereas other indices are preserved.
    expectations = (
      ("CPU", "CPU"),
      ("cpu", "CPU"),
      ("GPU", "GPU"),
      ("GPU:0", "GPU"),
      ("gpu:0", "GPU"),
      ("GPU:1", "GPU:1"),
      ("gpu:1", "GPU:1"),
      ("GPU:2", "GPU:2"),
      ("disk:/dev/shm/test", "DISK:/dev/shm/test"),
      ("disk:000.txt", "DISK:000.txt"),
    )
    for raw_spec, canonical in expectations:
      self.assertEqual(Device.canonicalize(raw_spec), canonical)

  def test_getitem_not_exist(self):
    # Looking up a nonexistent device name is expected to raise ModuleNotFoundError.
    self.assertRaises(ModuleNotFoundError, lambda: Device["TYPO"])
class MockCompiler(Compiler):
  """Stub Compiler whose "compiled" output is simply the source string encoded to bytes."""

  def __init__(self, key):
    # Pass the cache key straight through to the base Compiler.
    super().__init__(key)

  def compile(self, src) -> bytes:
    compiled: bytes = src.encode()
    return compiled
class TestCompiler(unittest.TestCase):
  """Tests Compiler.compile_cached and device compilation with DISABLE_COMPILER_CACHE toggled."""

  def tearDown(self):
    # getenv memoizes its results (it exposes cache_clear). Values read while os.environ was patched
    # (or emptied via clear=True) would otherwise survive patch.dict restoring the environment and
    # leak into later tests, so drop them once the real environment is back.
    getenv.cache_clear()

  def test_compile_cached(self):
    diskcache_put("key", "123", None) # clear cache
    getenv.cache_clear()  # make DISABLE_COMPILER_CACHE be re-read under the patched environment
    with patch.dict(os.environ, {"DISABLE_COMPILER_CACHE": "0"}, clear=True):
      self.assertEqual(MockCompiler("key").compile_cached("123"), b"123")
      # With caching enabled, the compiled bytes are expected to have been written to the disk cache.
      self.assertEqual(diskcache_get("key", "123"), b"123")

  def test_compile_cached_disabled(self):
    diskcache_put("disabled_key", "123", None) # clear cache
    getenv.cache_clear()  # make DISABLE_COMPILER_CACHE be re-read under the patched environment
    with patch.dict(os.environ, {"DISABLE_COMPILER_CACHE": "1"}, clear=True):
      self.assertEqual(MockCompiler("disabled_key").compile_cached("123"), b"123")
      # With caching disabled, nothing should have been written: the cleared (None) entry remains.
      self.assertIsNone(diskcache_get("disabled_key", "123"))

  def test_device_compile(self):
    getenv.cache_clear()
    with patch.dict(os.environ, {"DISABLE_COMPILER_CACHE": "1"}):
      a = Tensor([0.,1.], device=Device.DEFAULT).realize()
      b = (a + 1).realize()
      # Check the freshly compiled (uncached) kernel actually produced the right values.
      self.assertEqual(b.numpy().tolist(), [1.0, 2.0])
# Run this module's TestCase classes when the file is executed directly.
if __name__ == "__main__": unittest.main()