openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

52 lines
1.9 KiB

#!/usr/bin/env python
import unittest
import torch
import numpy as np
from tinygrad.helpers import getenv, CI
from tinygrad.tensor import Tensor
from tinygrad.device import Device
from tinygrad.dtype import _from_torch_dtype, _to_torch_dtype
# Value of the "MOCKGPU" setting as returned by getenv; when truthy, the skipIf
# decorator on TestInterop skips the whole suite.
MOCKGPU = getenv("MOCKGPU")
@unittest.skipIf(Device.DEFAULT not in ("METAL", "CUDA") or MOCKGPU, f"no support on {Device.DEFAULT}")
class TestInterop(unittest.TestCase):
    """Tests that view torch-allocated memory as tinygrad tensors via ``Tensor.from_blob``.

    The suite is skipped unless ``Device.DEFAULT`` is "METAL" or "CUDA", and is
    also skipped whenever ``MOCKGPU`` is truthy.
    """

    def setUp(self):
        """Pick the torch device name that matches ``Device.DEFAULT``."""
        torch_name = {"CUDA": "cuda", "METAL": "mps"}.get(Device.DEFAULT)
        # Any other backend leaves the attribute unset; the class-level skipIf
        # already excludes those backends.
        if torch_name is not None:
            self.torch_device = torch_name

    @staticmethod
    def _weighted_channel_sum(t):
        """Return the fixed-weight sum of the three slices along the last axis of *t*.

        Used with both tinygrad and torch tensors so the two results come from
        the same expression.
        NOTE(review): the weights look like the usual RGB-to-grayscale
        coefficients -- confirm before relying on that reading.
        """
        return t[:, :, 0] * 0.2989 + t[:, :, 1] * 0.5870 + t[:, :, 2] * 0.1140

    def test_torch_interop(self):
        """Read torch-owned memory through tinygrad and compare against torch's own result."""
        src = torch.rand(2, 2, 3, device=torch.device(self.torch_device))

        # Synchronize the torch backend before the raw pointer is handed to
        # tinygrad (presumably so the random fill has completed -- confirm).
        on_mps = self.torch_device == "mps"
        if on_mps:
            torch.mps.synchronize()
        else:
            torch.cuda.synchronize()

        blob_view = Tensor.from_blob(src.data_ptr(), src.shape, dtype=_from_torch_dtype(src.dtype))
        tg_res = self._weighted_channel_sum(blob_view).numpy()

        if on_mps and CI:
            # MPS backend out of memory: https://discuss.pytorch.org/t/mps-back-end-out-of-memory-on-github-action/189773
            # Calculate expected value on cpu.
            src = src.cpu()
        expected = self._weighted_channel_sum(src)

        np.testing.assert_allclose(tg_res, expected.cpu().numpy(), atol=1e-5, rtol=1e-5)

    def test_torch_interop_write(self):
        """Assign tinygrad data into torch-owned memory and read it back through torch."""
        source = Tensor.randn((4, 4), device=Device.DEFAULT)

        target_device = torch.device(self.torch_device)
        target_dtype = _to_torch_dtype(source.dtype)
        dest = torch.empty(4, 4, device=target_device, dtype=target_dtype)

        # tinygrad tensor built on top of the torch buffer's data pointer.
        dest_view = Tensor.from_blob(dest.data_ptr(), dest.shape, dtype=_from_torch_dtype(dest.dtype))
        dest_view.assign(source).realize()
        # Synchronize the tinygrad device before torch reads the buffer back.
        Device[Device.DEFAULT].synchronize()

        written = dest.cpu().numpy()
        np.testing.assert_allclose(source.numpy(), written, atol=1e-5, rtol=1e-5)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()