import unittest
from tinygrad import Device
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, CI

def multidevice_test(fxn):
  """Wrap a ``(self, device)`` test method so it runs once per tinygrad device.

  Every name in ``Device._devices`` is tried in order, each inside its own
  ``self.subTest(device=...)`` so failures are reported per device. A device is
  skipped when it is in the built-in skip list below, when it is listed in the
  comma-separated ``EXCLUDE_DEVICES`` environment variable, or when
  ``Device[device]`` raises any exception.

  Args:
    fxn: test method taking ``(self, device)`` where ``device`` is a device name string.

  Returns:
    A ``(self)`` test method suitable for ``unittest`` discovery.
  """
  from functools import wraps
  # Parsed once at decoration (import) time. Entries are stripped so "CPU, GPU" works,
  # and empty entries (from an unset variable or trailing commas) are dropped.
  exclude_devices = {d.strip() for d in getenv("EXCLUDE_DEVICES", "").split(",") if d.strip()}
  # Devices this smoke test never targets.
  always_skipped = ("REMOTE", "DISK", "NPY", "FAKE", "DSP", "NULL")

  @wraps(fxn)  # keep the test's name/docstring so unittest output stays readable
  def ret(self):
    for device in Device._devices:
      if device in always_skipped: continue
      if not CI: print(device)
      if device in exclude_devices:
        if not CI: print(f"WARNING: {device} test is excluded")
        continue
      with self.subTest(device=device):
        try:
          # Opening the device is the availability probe; any failure means "not usable here".
          Device[device]
        except Exception:
          if not CI: print(f"WARNING: {device} test isn't running")
          continue
        fxn(self, device)
  return ret

class TestExample(unittest.TestCase):
  """Smoke tests that run small tensor programs on every usable device.

  Each test is wrapped by ``multidevice_test``, which handles device discovery,
  exclusion and availability, and passes the device name in as ``device``.
  """

  @multidevice_test
  def test_convert_to_cpu(self, device):
    """Create a 2x2 tensor on ``device``, copy it to CPU, and check the shape survives."""
    a = Tensor([[1,2],[3,4]], device=device)
    self.assertEqual(a.numpy().shape, (2,2))
    b = a.to("CPU")
    self.assertEqual(b.numpy().shape, (2,2))

  @multidevice_test
  def test_2_plus_3(self, device):
    """Add two one-element tensors on ``device`` and check the sum is 5."""
    a = Tensor([2], device=device)
    b = Tensor([3], device=device)
    # Realize the sum once and reuse it for both the log line and the check.
    result = (a + b).numpy()
    print(f"{a.numpy()} + {b.numpy()} = {result}")
    self.assertEqual(result[0], 5.)

  @multidevice_test
  def test_example_readme(self, device):
    """Run a small matmul + backward pass and check that gradients stay on ``device``."""
    x = Tensor.eye(3, device=device, requires_grad=True)
    y = Tensor([[2.0,0,-2.0]], device=device, requires_grad=True)
    z = y.matmul(x).sum()
    z.backward()

    # Realize the gradients so the backward computation actually executes.
    x.grad.numpy()  # dz/dx
    y.grad.numpy()  # dz/dy

    self.assertEqual(x.grad.device, device)
    self.assertEqual(y.grad.device, device)

  @multidevice_test
  def test_example_matmul(self, device):
    """Same as the README example with 64x64 identity matrices.

    Device availability is already checked by ``multidevice_test``, so this test
    does not probe ``Device[device]`` itself.
    """
    x = Tensor.eye(64, device=device, requires_grad=True)
    y = Tensor.eye(64, device=device, requires_grad=True)
    z = y.matmul(x).sum()
    z.backward()

    # Realize the gradients so the backward computation actually executes.
    x.grad.numpy()  # dz/dx
    y.grad.numpy()  # dz/dy

    self.assertEqual(x.grad.device, device)
    self.assertEqual(y.grad.device, device)

if __name__ == "__main__":
  # Allow running this file directly in addition to discovery by a test runner.
  unittest.main()