openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.

58 lines
2.2 KiB

#!/usr/bin/env python
import unittest
from tinygrad import Tensor
from tinygrad.device import Device, Compiler
from tinygrad.helpers import diskcache_get, diskcache_put, getenv, Context
class TestDevice(unittest.TestCase):
  """Tests for device-name canonicalization and lookup on the global `Device` object."""

  def test_canonicalize(self):
    # None is expected to resolve to whatever the current default device is.
    self.assertEqual(Device.canonicalize(None), Device.DEFAULT)

    # (input, expected) pairs. The expectations below encode that the device
    # prefix is upper-cased, a ":0" suffix is dropped, other numeric suffixes
    # are kept, and the path after "disk:" is left unchanged.
    cases = (
      ("CPU", "CPU"),
      ("cpu", "CPU"),
      ("GPU", "GPU"),
      ("GPU:0", "GPU"),
      ("gpu:0", "GPU"),
      ("GPU:1", "GPU:1"),
      ("gpu:1", "GPU:1"),
      ("GPU:2", "GPU:2"),
      ("disk:/dev/shm/test", "DISK:/dev/shm/test"),
      ("disk:000.txt", "DISK:000.txt"),
    )
    for name, expected in cases:
      # subTest reports each failing input separately instead of stopping at the first.
      with self.subTest(name=name):
        self.assertEqual(Device.canonicalize(name), expected)

  def test_getitem_not_exist(self):
    # Indexing Device with an unknown name is expected to raise ModuleNotFoundError.
    with self.assertRaises(ModuleNotFoundError):
      Device["TYPO"]

  def test_lowercase_canonicalizes(self):
    # Temporarily replace the default device with its lower-case spelling and
    # check that canonicalize(None) still yields the original spelling.
    original = Device.DEFAULT
    Device.DEFAULT = original.lower()
    try:
      self.assertEqual(Device.canonicalize(None), original)
    finally:
      # Always restore the global default, even when the assertion fails, so a
      # failure here cannot leak the modified default into other tests.
      Device.DEFAULT = original
class MockCompiler(Compiler):
  """Stand-in `Compiler` for tests: "compiling" simply encodes the source text to bytes."""

  def __init__(self, key):
    # Forward the key unchanged to the Compiler base class.
    super().__init__(key)

  def compile(self, src) -> bytes:
    # No real compilation: the result is the source string's encoded bytes,
    # which gives the tests a predictable value to compare against.
    encoded = src.encode()
    return encoded
class TestCompiler(unittest.TestCase):
  """Tests `compile_cached` with DISABLE_COMPILER_CACHE set to 0 and to 1."""

  def test_compile_cached(self):
    # With DISABLE_COMPILER_CACHE=0, the compiled bytes are expected to be
    # readable back through diskcache_get afterwards.
    cache_key, source = "key", "123"
    expected = source.encode()
    diskcache_put(cache_key, source, None)  # clear cache
    getenv.cache_clear()
    with Context(DISABLE_COMPILER_CACHE=0):
      compiled = MockCompiler(cache_key).compile_cached(source)
      self.assertEqual(compiled, expected)
      stored = diskcache_get(cache_key, source)
      self.assertEqual(stored, expected)

  def test_compile_cached_disabled(self):
    # With DISABLE_COMPILER_CACHE=1, compilation must still return the bytes,
    # but diskcache_get is expected to find nothing for this key.
    cache_key, source = "disabled_key", "123"
    expected = source.encode()
    diskcache_put(cache_key, source, None)  # clear cache
    getenv.cache_clear()
    with Context(DISABLE_COMPILER_CACHE=1):
      compiled = MockCompiler(cache_key).compile_cached(source)
      self.assertEqual(compiled, expected)
      stored = diskcache_get(cache_key, source)
      self.assertIsNone(stored)

  def test_device_compile(self):
    # Smoke test with no assertions: it passes as long as realizing a tensor
    # and a derived tensor on the default device does not raise while
    # DISABLE_COMPILER_CACHE=1.
    getenv.cache_clear()
    with Context(DISABLE_COMPILER_CACHE=1):
      base = Tensor([0.,1.], device=Device.DEFAULT).realize()
      shifted = base + 1
      shifted.realize()
# Run this module's test cases when the file is executed directly as a script.
if __name__ == "__main__":
  unittest.main()