from typing import Optional, Any import unittest, math import numpy as np from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View # noqa F401 from tinygrad.tensor import Tensor, _to_np_dtype from tinygrad.helpers import CI, DEBUG, getenv, Context, Timing from tinygrad.dtype import dtypes, DType from tinygrad.device import Buffer, Device from tinygrad.ops import Ops, UOp, UPat, KernelInfo, exec_alu # noqa F401 from tinygrad.spec import spec from tinygrad.renderer import ProgramSpec from tinygrad.engine.schedule import fix_kernel_ops from tinygrad.engine.realize import CompiledRunner, get_kernel from tinygrad.codegen.linearize import linearize_uop from tinygrad.codegen.devectorizer import full_graph_rewrite from tinygrad.codegen.symbolic import sym from tinygrad.device import is_dtype_supported from tinygrad.codegen.kernel import Kernel, Opt, OptOps def to_uops_list(u:list[UOp], opts=None, skip_check=False) -> list[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u), opts), skip_check) def _uops_to_prg(uops_list): uops = linearize_uop(full_graph_rewrite(ast:=UOp.sink(*uops_list), opts=Device[Device.DEFAULT].renderer)) src = Device[Device.DEFAULT].renderer.render(uops) has_local = Device[Device.DEFAULT].renderer.has_local return CompiledRunner(ProgramSpec("test", src, Device.DEFAULT, ast, uops=uops, global_size=[1,1,1] if has_local else None, local_size=[1,1,1] if has_local else None)) def uop(uops:list[UOp], uop:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg:Any=None) -> UOp: uops.append(UOp(uop, dtype, tuple(src), arg)) return uops[-1] def _test_single_value(vals, op, dts): uops = [] output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1] buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0) buf_loads = [uop(uops, Ops.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)] loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0))]) for i, 
dtype in enumerate(dts)) alu = uop(uops, op, output_dtype, loads) out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), alu)) buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate() buf2 = [Buffer(Device.DEFAULT, 1, dtype).allocate().copyin(np.array([a], dtype=_to_np_dtype(dtype)).data) for a,dtype in zip(vals, dts)] prg = _uops_to_prg([out]) prg.exec([buf]+buf2) ret = np.empty(1, _to_np_dtype(output_dtype)) buf.copyout(ret.data) return ret[0] def _test_single_value_const(vals, op, dts): uops = [] output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1] buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0) loads = (uop(uops, Ops.CONST, dtype, [], a) for a,dtype in zip(vals, dts)) alu = uop(uops, op, output_dtype, loads) out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), alu)) buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate() prg = _uops_to_prg([out]) prg.exec([buf]) ret = np.empty(1, _to_np_dtype(output_dtype)) buf.copyout(ret.data) return ret[0] def _test_uops_result(output_dtype, uops, res): # uops = [] buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0) # res = output_fn(uops) out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), res)) buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate() prg = _uops_to_prg([out]) prg.exec([buf]) ret = np.empty(1, _to_np_dtype(output_dtype)) buf.copyout(ret.data) return ret[0] class TestUOps(unittest.TestCase): def _equal(self, v1, v2): assert isinstance(v2, (float, int, bool)) if isinstance(v2, float): np.testing.assert_allclose(v1, v2, rtol=2e-7) else: np.testing.assert_equal(v1, v2) def _test_uop_fxn(self, op, fxn, dts=(dtypes.float32, )): for f in [_test_single_value, _test_single_value_const]: for a in [-2.0, 0.0, 1.0]: a = dtypes.as_const(a, dts[0]) self._equal(f([a], op, dts), fxn(a)) def _test_bop_fxn(self, op, fxn, 
# NOTE(review): this span was corrupted during extraction — text between "int(a)<" and "= 4:" (and at the
# later "gidx0= 4:" / "gate = gidx0= 4:" spots) appears to have been swallowed, likely because "<...>"
# sequences were eaten as markup, taking whole test definitions with them. The code is left byte-identical
# here; recover the missing text from version control rather than guessing at it.
dts=(dtypes.float32, )*2, no_b_zero=False, no_b_neg=False): for f in [_test_single_value, _test_single_value_const]: for a in [-2.0, 0.0, 1.0]: for b in [-3.0, 1.0] + ([] if no_b_zero else [0.0]): a = dtypes.as_const(a, dts[0]) b = dtypes.as_const(abs(b) if no_b_neg else b, dts[1]) self._equal(f([a,b], op, dts), fxn(a,b)) def _test_top_fxn(self, op, fxn, dts=(dtypes.float32, )*3): for f in [_test_single_value, _test_single_value_const]: for a in [-2.0, 0, 1]: for b in [-3.0, 3.0]: for c in [-4.0, 4.0]: a = dtypes.as_const(a, dts[0]) b = dtypes.as_const(b, dts[1]) c = dtypes.as_const(c, dts[2]) self._equal(f([a,b,c], op, dts), fxn(a,b,c)) class TestFloatUOps(TestUOps): @unittest.skipIf(Device.DEFAULT == "CPU", 'not supported as uop') def test_exp2(self): self._test_uop_fxn(Ops.EXP2, lambda a: np.exp2(a)) @unittest.skipIf(Device.DEFAULT == "CPU", 'not supported as uop') def test_log2(self): self._test_uop_fxn(Ops.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan')) @unittest.skipIf(Device.DEFAULT == "CPU", 'not supported as uop') def test_sin(self): self._test_uop_fxn(Ops.SIN, lambda a: math.sin(a)) def test_recip(self): self._test_uop_fxn(Ops.RECIP, lambda a: 1/a if a != 0 else float('inf')) def test_sqrt(self): self._test_uop_fxn(Ops.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan')) def test_add(self): self._test_bop_fxn(Ops.ADD, lambda a,b: a+b) def test_mul(self): self._test_bop_fxn(Ops.MUL, lambda a,b: a*b) def test_max(self): self._test_bop_fxn(Ops.MAX, lambda a,b: max(a,b)) def test_cmplt(self): self._test_bop_fxn(Ops.CMPLT, lambda a,b: a>int(b), (dtypes.int32, dtypes.int32), no_b_neg=True) @unittest.skipUnless(getenv("PTX"), "only ptx uses bitshifts") def test_shl_int32(self): self._test_bop_fxn(Ops.SHL, lambda a,b: int(a)<= 4: print(Device[Device.DEFAULT].renderer.render("test", uops)) if_uop = next(u for u in uops if u.op is Ops.IF) endif = next(u for u in uops if u.op is Ops.ENDIF) assert endif.src[0] is if_uop gated_uops = 
tuple(uops[uops.index(if_uop)+1:uops.index(endif)]) self.assertEqual(len(gated_uops), 1) self.assertIs(gated_uops[-1].op, Ops.STORE) def test_gate_some_stores(self): gmem0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1) gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4)) idx = gidx0 * UOp.const(dtypes.int, 2) idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx, gidx0= 4: print(Device[Device.DEFAULT].renderer.render("test", uops)) if_uop = next(u for u in uops if u.op is Ops.IF) endif = next(u for u in uops if u.op is Ops.ENDIF) assert endif.src[0] is if_uop gated_uops = tuple(uops[uops.index(if_uop)+1:uops.index(endif)]) self.assertEqual(len(gated_uops), 1) self.assertIs(gated_uops[-1].op, Ops.STORE) # scaled down version of TestLinearizerDumb.test_unmerged_ifs def test_merge_ifs_alt(self): gmem0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1) gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4)) idx = gidx0*UOp.const(dtypes.int, 2) gate = gidx0= 4: print(Device[Device.DEFAULT].renderer.render("test", uops)) ifs = [u for u in uops if u.op is Ops.IF] endifs = [u for u in uops if u.op is Ops.ENDIF] self.assertEqual(len(ifs), 1) self.assertEqual(len(endifs), 1) gated_uops = tuple(uops[uops.index(ifs[0])+1:uops.index(endifs[0])]) self.assertEqual(len(gated_uops), 2) for x in gated_uops: self.assertIs(x.op, Ops.STORE) class TestLocalAccess(unittest.TestCase): # NOTE: this is failing on METAL CI, no idea why. Works locally.
@unittest.skipIf(Device.DEFAULT == "METAL" and CI, "failing only in CI") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory") def test_local_basic(self): uops = [] smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.float32.ptr(size=16, local=True), (), 'smem') st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.float32, (), 42.0))) barr = uop(uops, Ops.BARRIER, dtypes.void, (st,)) sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr)) self.assertEqual(_test_uops_result(dtypes.float32, uops, sres), 42) # NOTE: webgpu specific, since only webgpu performs bitpacking for uchar @unittest.skipUnless(Device.DEFAULT == "WEBGPU", "Test local access with packed data type") def test_local_packed(self): uops = [] smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=16, local=True), (), 'smem') st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.uint8, (), 42))) barr = uop(uops, Ops.BARRIER, dtypes.void, (st,)) sres = uop(uops, Ops.LOAD, dtypes.uint8, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr)) self.assertEqual(_test_uops_result(dtypes.uint8, uops, sres), 42) @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory") @unittest.skip("tinygrad doesn't support this behavior") def test_local_indirect(self): uops = [] smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.int32.ptr(size=16, local=True), (), 'smem') st1 = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 1)), uop(uops, Ops.CONST, dtypes.int32, (), 2))) st2 = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 2)), uop(uops, Ops.CONST, dtypes.int32, (), 42))) barr = uop(uops, Ops.BARRIER, dtypes.void, (st1,st2)) ofs = uop(uops, Ops.LOAD, dtypes.int32, (smem.index(uop(uops, 
Ops.CONST, dtypes.int32, (), 1)), barr)) sres = uop(uops, Ops.LOAD, dtypes.int32, (smem.index(ofs),)) self.assertEqual(_test_uops_result(dtypes.int32, uops, sres), 42) @unittest.skipUnless(getenv("PTX"), "This only tests assembly backends") class TestAssembly(unittest.TestCase): def test_bitshift_left(self): g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0) c1 = UOp(Ops.CONST, dtypes.int, (), 2) c2 = UOp(Ops.CONST, dtypes.int, (), 3) l1 = UOp(Ops.LOAD, dtypes.int, (g1.index(c1),)) a1 = UOp(Ops.MUL, dtypes.int, (l1, c1)) a2 = UOp(Ops.MUL, dtypes.int, (l1, c2)) uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer) Device[Device.DEFAULT].renderer.render(uops) ops = [x.op for x in uops] self.assertIn(Ops.SHL, ops) self.assertIn(Ops.MUL, ops) def test_bitshift_right(self): g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(), (), 0) c1 = UOp(Ops.CONST, dtypes.uint, (), 2) c2 = UOp(Ops.CONST, dtypes.uint, (), 3) l1 = UOp(Ops.LOAD, dtypes.uint, (g1.index(c1),)) a1 = UOp(Ops.IDIV, dtypes.uint, (l1, c1)) a2 = UOp(Ops.IDIV, dtypes.uint, (l1, c2)) uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer) Device[Device.DEFAULT].renderer.render(uops) ops = [x.op for x in uops] self.assertIn(Ops.SHR, ops) self.assertIn(Ops.IDIV, ops) def test_mulacc_unrolled(self): # test that acc = acc + a0*b0 + a1*b1 + a2*b2 + a3*b3 # is not acc = acc + (a0*b0 + a1*b1 + a2*b2 + a3*b3) a = Tensor.empty(1024) b = Tensor.empty(1024) c = (a*b).sum() k = Kernel(c.schedule()[-1].ast) k.apply_opt(Opt(OptOps.UNROLL, 0, 4)) uops = k.linearize().uops self.assertEqual(len([x.op for x in uops if x.op is Ops.MULACC]), 4) class TestUOpMethod(unittest.TestCase): @unittest.skip("uops lt no longer ordered") def test_compare_alu_same_src_different_arg(self): a = UOp(Ops.CONST, dtypes.float, (), 2.0) b = UOp(Ops.CONST, dtypes.float, (), 3.0) add = UOp(Ops.ADD, dtypes.float, (a, b)) mul = UOp(Ops.MUL, dtypes.float, (a, b)) assert (add < mul) or (mul < add), "add and mul with same src should 
have an order" def test_uop_variables(self): a = UOp.variable("a", 1, 10) uop_var = Tensor(a.bind(1)) st_var = Tensor.empty((2, 1)).reshape((2, a.bind(1))) _, var_vals = (uop_var+st_var).schedule_with_vars() self.assertEqual(len(var_vals), 1) self.assertEqual(list(var_vals)[0], a) def test_const_factor(self): gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 8)) self.assertEqual(UOp(Ops.CONST, dtypes.int, (), 17).const_factor(), 17) self.assertEqual(gidx0.const_factor(), 1) self.assertEqual((gidx0*3).const_factor(), 3) self.assertEqual((gidx0*3+6).const_factor(), 3) self.assertEqual((gidx0*3+1).const_factor(), 1) def test_replace(self): x = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) self.assertIs(x.replace(arg=None).arg, None) with self.assertRaises(AssertionError): x.replace(field="a") def test_device(self): x = UOp(Ops.VIEW, dtypes.int, (UOp.new_buffer(Device.DEFAULT, 1, dtypes.int), UOp.const(dtypes.int, 1)), ShapeTracker.from_shape(())) self.assertEqual(x.device, Device.DEFAULT) # NOTE: CONST doesn't have device buffer, const = x.src self.assertEqual(buffer.device, Device.DEFAULT) self.assertEqual(const._device, None) with self.assertRaises(AssertionError): const.device class TestUOpStr(unittest.TestCase): def test_uop_str(self): a = UOp(Ops.CONST, dtypes.float, (), 2.0) + UOp(Ops.CONST, dtypes.float, (), 3.0) for _ in range(20): a = a + a assert len(str(a)) < 10_000, "exponential string growth" assert str(eval(str(a))) == str(a) t = Tensor.arange(10) t = t + t * Tensor.rand(10) # nice big complicated uop with Context(NOOPT=1): sink = UOp(Ops.SINK, dtypes.void, (get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops[-1],)) self.assertEqual(sink, eval(str(sink))) def test_vectorized_str(self): vec = UOp(Ops.VECTORIZE, dtypes.int.vec(4), tuple(UOp.const(dtypes.int, x) for x in range(4))) assert str(eval(str(vec))) == str(vec) def test_device_arg(self): device = UOp(Ops.DEVICE, arg="GPU") assert str(eval(str(device))) == 
str(device) def test_reduceop_arg(self): sum_uop = Tensor.empty(32, 32).sum().lazydata assert str(eval(str(sum_uop))) == str(sum_uop) @unittest.skip("uop no longer has order like this") class TestIndexingOrdering(unittest.TestCase): # NOTE: these tests skip type_verify since they add dtype to STORE @unittest.expectedFailure def test_simple_order(self): buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) st0 = UOp(Ops.STORE, dtypes.float.vec(4), (buf, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42))) st1 = UOp(Ops.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10))) uops = to_uops_list([st1, st0], skip_check=True) stores = [st for st in uops if st.op is Ops.STORE] assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}" @unittest.expectedFailure def test_ordering_multi_output(self): buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1) st0_0 = UOp(Ops.STORE, dtypes.float.vec(4), (buf0, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42))) st1_0 = UOp(Ops.STORE, dtypes.float, (buf0, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10))) st0_1 = UOp(Ops.STORE, dtypes.float.vec(4), (buf1, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42))) st1_1 = UOp(Ops.STORE, dtypes.float, (buf1, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10))) uops = to_uops_list([st0_0, st1_0, st0_1, st1_1], skip_check=True) stores = [st for st in uops if st.op is Ops.STORE] print("\n".join(map(str, stores))) # buf0 stores come first self.assertEqual(stores[0].src[0].arg, stores[1].src[0].arg) # buf1 stores come next self.assertEqual(stores[2].src[0].arg, stores[3].src[0].arg) # both stores are aligned based on idx assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}" assert stores[2].src[1] < stores[3].src[1], f"stored at idx {stores[1].src[1].arg} 
AFTER {stores[0].src[1].arg}" def test_simple_order_with_special(self): buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4)) st0 = UOp(Ops.STORE, dtypes.float.vec(4), (buf, gidx0+UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42))) st1 = UOp(Ops.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10))) uops = linearize_uop(UOp.sink(st1, st0), skip_check=True) stores = [st for st in uops if st.op is Ops.STORE] assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}" class TestUPatHelpers(unittest.TestCase): def test_location(self): self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "symbolic.py") self.assertEqual(fix_kernel_ops.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "schedule.py") self.assertEqual(spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "ops.py") with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*? 
test_upat = UPat(Ops.CONST, dtypes.bool) self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1]) class TestUopsObject(unittest.TestCase): # LOL, running this test breaks all instances of "4" """ @unittest.expectedFailure def test_immutable(self): const_4 = UOp.const(dtypes.int, 4) with self.assertRaises(Exception): const_4.arg = 5 """ def test_timing(self): with Timing("create 10k uops:"): ret = [UOp(Ops.CONST, dtypes.int, arg=10000000+i) for i in range(10000)] assert len(ret) == 10000 class TestShapeSpec(unittest.TestCase): # ** CONST is CONST(VIEW(DEVICE)) -> RESHPAE -> EXPAND def test_expanded_const(self): a = Tensor(1).lazydata self.assertEqual(a.st, ShapeTracker.from_shape(())) a = Tensor.ones((4, 4)).lazydata self.assertEqual(a.st, ShapeTracker.from_shape(()).reshape((1,1)).expand((4,4))) def test_padded_const(self): a = Tensor.ones((1, 1)).pad(((1, 1), (1, 1))) ast = a.contiguous().schedule()[0].ast valid_pattern = UPat(Ops.WHERE, src=(UPat(Ops.VALID), UPat.cvar(), UPat.cvar())) valid_ternary = [x for x in ast.toposort if valid_pattern.match(x, {})][0] # the WHERE outputs a contiguous (3, 3) self.assertEqual(valid_ternary.st, ShapeTracker.from_shape((3, 3))) valid, x, y = valid_ternary.src # very notably, only the first source is padded self.assertIsNotNone(valid.st.views[-1].mask) assert x.st.views[-1].mask is y.st.views[-1].mask is None assert all(s.shape == (3, 3) for s in valid_ternary.src) # NOTE: CONST ShapeTracker comes from its source def test_scalar_const(self): a = Tensor(0).lazydata self.assertEqual(a.st, ShapeTracker.from_shape(())) def test_scalar_var(self): vv = UOp.variable("a", 1, 4).bind(2) t = Tensor(vv).lazydata self.assertEqual(t.st, ShapeTracker.from_shape(())) # ** ASSIGN is ASSIGN(VIEW(BUFFER), new_val) def test_assign_flat(self): buffer = Tensor.arange(4).realize() a = buffer.assign(Tensor.zeros((4,), dtype=dtypes.int)) assign_pattern = UPat(Ops.ASSIGN, src=(UPat(Ops.BUFFER), UPat())) assert 
assign_pattern.match(a.lazydata, {}) a.realize() self.assertEqual(buffer.tolist(), [0, 0, 0, 0]) def test_assign_permuted(self): buffer = Tensor.arange(4).reshape(2, 1, 2).contiguous().realize() a = buffer.permute((1, 2, 0)).assign(Tensor.arange(4).reshape(1, 2, 2).contiguous()) a.realize() self.assertEqual(buffer.tolist(), [[[0, 2]], [[1, 3]]]) def test_assign_reshaped(self): buffer = Tensor.ones((4,)).contiguous().realize() a = buffer.reshape((2, 2)).assign(Tensor.zeros((2, 2))) assign_pattern = UPat(Ops.ASSIGN, src=(UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER))), UPat())) assert assign_pattern.match(a.lazydata, {}) a.realize() self.assertEqual(buffer.tolist(), [0, 0, 0, 0]) # setitem is a partial assign def test_setitem(self): a = Tensor.ones((4,)).contiguous().realize() assign = a.shrink(((1, 2),)).assign(Tensor.zeros((1,))) # the ASSIGN UOp has size=1 self.assertEqual(assign.lazydata.size, 1) # the ASSIGN views the buffer with a shrunk st self.assertEqual(assign.lazydata.src[0].st, ShapeTracker.from_shape((4,)).shrink(((1, 2),))) # the underlying BUFFER has a size=4 self.assertEqual(assign.lazydata.buf_uop.size, 4) # NOTE: output shape is different from the BUFFER shape self.assertNotEqual(assign.lazydata.shape, a.lazydata.shape) assign.realize() self.assertEqual(a.tolist(), [1, 0, 1, 1]) def test_buffer_st(self): a = UOp.new_buffer(Device.DEFAULT, 10, dtypes.float) self.assertEqual(a.st, ShapeTracker.from_shape((10,))) def test_ops_st(self): # view / mop a = Tensor.empty(4, 2, 1).permute((1, 2, 0)).lazydata self.assertEqual(a.st, ShapeTracker.from_shape((4, 2, 1)).permute((1, 2, 0))) # alu / reduce alu = a*2 self.assertEqual(alu.st, ShapeTracker.from_shape((2, 1, 4))) r = Tensor.empty(4, 4).sum(axis=1) self.assertEqual(r.lazydata.st, ShapeTracker.from_shape((4,))) def test_st_wmma_none(self): A = UOp(Ops.DEFINE_VAR, dtypes.float.vec(16), arg=('a', UOp.const(dtypes.float, 0), UOp.const(dtypes.float, 1))) B = UOp(Ops.DEFINE_VAR, dtypes.float.vec(16), arg=('b', 
UOp.const(dtypes.float, 0), UOp.const(dtypes.float, 2))) C = UOp(Ops.DEFINE_VAR, dtypes.float.vec(16), arg=('c', UOp.const(dtypes.float, 0), UOp.const(dtypes.float, 3))) wmma = UOp(Ops.WMMA, dtypes.float.vec(16), (A, B, C)) assert wmma.st is None class TestUOpChildren(unittest.TestCase): def test_children_exist(self): a = UOp.variable("weird_name_234", 0, 10) b = a*a self.assertEqual(len(a.children), 1) self.assertIs(list(a.children)[0](), b) def test_children_cleaned_up(self): a = UOp.variable("weird_name_235", 0, 10) b = a*a self.assertEqual(len(a.children), 1) del b self.assertEqual(len(a.children), 0) def test_children_cleaned_up_two(self): a = UOp.variable("weird_name_236", 0, 10) b = a*a c = a*2 self.assertEqual(len(a.children), 2) del b self.assertEqual(len(a.children), 1) del c self.assertEqual(len(a.children), 0) if __name__ == '__main__': unittest.main(verbosity=2)