You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
98 lines
4.9 KiB
98 lines
4.9 KiB
from __future__ import annotations
|
|
import unittest
|
|
|
|
from tinygrad import Tensor
|
|
from tinygrad.codegen.kernel import Kernel
|
|
from tinygrad.helpers import DEBUG
|
|
from tinygrad.ops import UOp, Ops, print_uops
|
|
from tinygrad.codegen.kernel import verify_ast
|
|
from tinygrad.shape.shapetracker import ShapeTracker
|
|
from tinygrad import dtypes
|
|
from tinygrad.shape.view import View
|
|
|
|
class InvalidASTException(Exception):
  """Error raised by `helper_test_verify_ast` when `verify_ast` fails with an AssertionError."""
|
|
def helper_test_verify_ast(*stores:UOp) -> Kernel:
  """Verify and linearize a kernel built from the given store UOps.

  The stores are wrapped in a single SINK UOp, which is passed to `verify_ast`.
  When verification succeeds, a `Kernel` is built from the sink and linearized.

  Args:
    stores: UOps used as the sources of the SINK.

  Returns:
    The `Kernel` built from the sink, after `linearize()` has been called on it.

  Raises:
    InvalidASTException: if `verify_ast` raises an AssertionError. The original
      error's `args` tuple is passed as the single argument and the
      AssertionError is attached as `__cause__`.
  """
  sink = UOp(Ops.SINK, dtypes.void, stores)
  if DEBUG >= 3:
    for op in stores: print(op)
  try:
    verify_ast(sink)
  except AssertionError as e:
    # chain explicitly so the traceback reports the failed assertion as the direct cause
    raise InvalidASTException(e.args) from e
  k = Kernel(sink)
  k.linearize()
  if DEBUG >= 6: print_uops(k.uops)
  if DEBUG >= 4: print(k.to_program().src)
  return k
|
|
|
|
class TestVerifyAST(unittest.TestCase):
  """Exercises `verify_ast` (mostly through `helper_test_verify_ast`) on hand-built and scheduled ASTs."""

  def test_tiny_add(self):
    """An int add of two (32, 1) loads stored to a (32, 1) buffer goes through the helper without raising."""
    int_t = dtypes.int
    out_buf = UOp(Ops.DEFINE_GLOBAL, int_t.ptr(), (), 0)
    lhs_buf = UOp(Ops.DEFINE_GLOBAL, int_t.ptr(), (), 1)
    rhs_buf = UOp(Ops.DEFINE_GLOBAL, int_t.ptr(), (), 2)
    lhs = UOp(Ops.LOAD, int_t, (lhs_buf, ShapeTracker.from_shape((32, 1)).to_uop()))
    rhs = UOp(Ops.LOAD, int_t, (rhs_buf, ShapeTracker.from_shape((32, 1)).to_uop()))
    out_st = ShapeTracker.from_shape((32, 1)).to_uop()
    add_store = UOp(Ops.STORE, dtypes.void, (out_buf, out_st, lhs+rhs))
    helper_test_verify_ast(add_store)

  def test_exactly_one_full_shape(self):
    """A sink holding one (32, 1) store and one (32, 32) store is rejected."""
    int_t = dtypes.int
    globals_ = [UOp(Ops.DEFINE_GLOBAL, int_t.ptr(), (), idx) for idx in range(6)]
    # first store: everything shaped (32, 1)
    small_lhs = UOp(Ops.LOAD, int_t, (globals_[2], ShapeTracker.from_shape((32, 1)).to_uop()))
    small_rhs = UOp(Ops.LOAD, int_t, (globals_[3], ShapeTracker.from_shape((32, 1)).to_uop()))
    small_store = UOp.store(globals_[0], ShapeTracker.from_shape((32, 1)).to_uop(), small_lhs+small_rhs)
    # second store: everything shaped (32, 32)
    big_lhs = UOp(Ops.LOAD, int_t, (globals_[4], ShapeTracker.from_shape((32, 32)).to_uop()))
    big_rhs = UOp(Ops.LOAD, int_t, (globals_[5], ShapeTracker.from_shape((32, 32)).to_uop()))
    big_store = UOp.store(globals_[1], ShapeTracker.from_shape((32, 32)).to_uop(), big_lhs+big_rhs)
    with self.assertRaises(InvalidASTException):
      helper_test_verify_ast(small_store, big_store)

  def test_no_implicit_broadcasting(self):
    """Adding a (4, 32) load to a REDUCE_AXIS (Ops.MAX, (1,)) of itself and storing as (4, 32) is rejected."""
    globals_ = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), idx) for idx in range(2)]
    loaded = UOp(Ops.LOAD, dtypes.float, (globals_[1], ShapeTracker.from_shape((4, 32)).to_uop()))
    reduced = UOp(Ops.REDUCE_AXIS, dtypes.float, (loaded,), (Ops.MAX, (1,)))
    summed = loaded + reduced
    out_st = ShapeTracker.from_shape((4, 32)).to_uop()
    bad_store = UOp(Ops.STORE, dtypes.void, (globals_[0], out_st, summed))
    with self.assertRaises(InvalidASTException):
      helper_test_verify_ast(bad_store)

  def test_shrink_ok(self):
    """Two (32, 32) loads of one buffer through explicit Views (strides (32, 1) and (0, 1)) are accepted."""
    globals_ = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), idx) for idx in range(2)]
    dense_view = View((32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True)
    dense_load = UOp(Ops.LOAD, dtypes.float, (globals_[1], ShapeTracker((dense_view,)).to_uop()))
    strided_view = View((32, 32), strides=(0, 1), offset=0, mask=None, contiguous=False)
    strided_load = UOp(Ops.LOAD, dtypes.float, (globals_[1], ShapeTracker((strided_view,)).to_uop()))
    ok_store = UOp.store(globals_[0], ShapeTracker.from_shape((32, 32)).to_uop(), dense_load+strided_load)
    helper_test_verify_ast(ok_store)

  def test_reduce_store(self):
    """Storing a REDUCE_AXIS (Ops.ADD, (0,)) of a (32, 1) load with a (32, 1) store fails with "implicit expand"."""
    globals_ = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), idx) for idx in range(2)]
    loaded = UOp(Ops.LOAD, dtypes.float, (globals_[1], ShapeTracker.from_shape((32, 1)).to_uop()))
    reduced = UOp(Ops.REDUCE_AXIS, dtypes.float, (loaded,), (Ops.ADD, (0,)))
    bad_store = UOp.store(globals_[0], ShapeTracker.from_shape((32, 1)).to_uop(), reduced)
    with self.assertRaisesRegex(InvalidASTException, "implicit expand"):
      helper_test_verify_ast(bad_store)

  def test_reduce_add_store(self):
    """Storing the reduction plus the unreduced (32, 1) load fails with "implicit expand"."""
    globals_ = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), idx) for idx in range(2)]
    loaded = UOp(Ops.LOAD, dtypes.float, (globals_[1], ShapeTracker.from_shape((32, 1)).to_uop()))
    reduced = UOp(Ops.REDUCE_AXIS, dtypes.float, (loaded,), (Ops.ADD, (0,)))
    bad_store = UOp.store(globals_[0], ShapeTracker.from_shape((32, 1)).to_uop(), reduced+loaded)
    with self.assertRaisesRegex(InvalidASTException, "implicit expand"):
      helper_test_verify_ast(bad_store)

  def test_buffer_uops_st(self):
    """Checks the ShapeTrackers `verify_ast` reports for the STORE and VALID UOps of `randn(4, 4)+2`."""
    tensor = Tensor.randn(4, 4)+2
    # verify_ast is called directly here; its result is iterated as (UOp, ShapeTracker) pairs
    sts_by_uop = verify_ast(tensor.schedule()[-1].ast)
    store_trackers = [tracker for uop,tracker in sts_by_uop.items() if uop.op is Ops.STORE]
    self.assertEqual(store_trackers[0], ShapeTracker.from_shape((4, 4)))
    valid_trackers = [tracker for uop,tracker in sts_by_uop.items() if uop.op is Ops.VALID]
    self.assertEqual(valid_trackers[0], ShapeTracker.from_shape((1, 1)).expand((4, 4)))

  def test_assert_swizzle(self):
    """A `.view(...)` of the reduction's expanded ShapeTracker inside the store value fails with "swizzle"."""
    glob = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
    loaded = UOp(Ops.LOAD, dtypes.float, (glob, ShapeTracker.from_shape((32, 1)).to_uop()))
    reduced = UOp(Ops.REDUCE_AXIS, dtypes.float, (loaded,), (Ops.ADD, (0,)))
    out_st = ShapeTracker.from_shape((32, 1)).to_uop()
    bad_store = UOp.store(glob, out_st, reduced.view(reduced.st.expand((32, 1)))+loaded)
    with self.assertRaisesRegex(InvalidASTException, "swizzle"):
      helper_test_verify_ast(bad_store)

  def test_flat_const_always_valid(self):
    """An int const 0 cast to float and stored with a `from_shape(())` ShapeTracker is accepted."""
    glob = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
    const_val = UOp.const(dtypes.int, 0).cast(dtypes.float)
    ok_store = UOp.store(glob, ShapeTracker.from_shape(()).to_uop(), const_val)
    helper_test_verify_ast(ok_store)
|
|
# run the unittest runner only when this file is executed as a script
if __name__ == '__main__':
  unittest.main()
|
|
|