import unittest
import time
import numpy as np
from tinygrad import Tensor, dtypes
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import lower_schedule_item, run_schedule

class TestFusionOp(unittest.TestCase):
  def test_contiguous_add(self):
    def test(contig=False):
      bt = Tensor(np.arange(16), dtype=dtypes.float32).reshape(4,4)
      x = bt.permute(1,0)
      if contig: x = x.contiguous()
      return (x.permute(1,0) + bt).data()
    assert test() == test(True)

  def test_expand_fuse(self):
    bt = Tensor(np.ones((10, 1)), dtype=dtypes.float32)
    out = (bt*2).expand(10,10).sum(1)
    sched = create_schedule([out.lazydata])
    run_schedule(sched)
    outd = out.tolist()
    assert all(x == 20.0 for x in outd)

  def test_recursive_add(self):
    st = time.perf_counter()
    a = Tensor([1,2,3,4])
    for _ in range(24): a = a + a
    sched = create_schedule([a.lazydata])
    ei = lower_schedule_item(sched[-1])
    self.assertLess(time.perf_counter()-st, 2.0)
    assert len(ei.prg.p.src.splitlines()) < 250

  def test_recursive_add_cmp(self):
    st = time.perf_counter()
    a = Tensor([1,2,3,4])
    for _ in range(24): a = a + a
    sched1 = create_schedule([a.lazydata])
    b = Tensor([1,2,3,4])
    for _ in range(24): b = b + b
    sched2 = create_schedule([b.lazydata])
    c = Tensor([1,2,3,4])
    for _ in range(23): c = c + c
    sched3 = create_schedule([c.lazydata])
    self.assertEqual(sched1[-1].ast, sched2[-1].ast)
    with self.assertRaises(AssertionError): self.assertEqual(sched1[-1].ast, sched3[-1].ast)
    self.assertLess(time.perf_counter()-st, 2.0)

  def test_recursive_pad(self):
    st = time.perf_counter()
    val = 1.0
    a = Tensor(val).realize()
    for _ in range(24): a = Tensor.stack(a, a)[0]
    r = a.item()
    self.assertEqual(r, val)
    self.assertLess(time.perf_counter()-st, 2.0)

  def test_recursive_reshape(self):
    st = time.perf_counter()
    a = Tensor.empty(32, 32).realize()
    b = Tensor.empty(16, 2).realize()
    r = a.sum(1)
    for _ in range(24): r = r.reshape(16, 2) + b
    sched = r.schedule()
    self.assertEqual(len(sched), 1)
    self.assertLess(time.perf_counter()-st, 2.0)

if __name__ == '__main__':
  unittest.main(verbosity=2)