You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
32 lines
1.4 KiB
32 lines
1.4 KiB
import unittest
|
|
from tinygrad import dtypes, Device, Tensor, Context
|
|
from tinygrad.dtype import AddrSpace
|
|
from tinygrad.helpers import getenv
|
|
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
|
|
from tinygrad.shape.shapetracker import ShapeTracker
|
|
from tinygrad.engine.realize import get_program, ExecItem, CompiledRunner
|
|
|
|
class TestDefineReg(unittest.TestCase):
  """Tests for hand-built kernels that stage data through a DEFINE_REG buffer."""

  def test_simple(self, at=AxisType.UPCAST):
    """Copy an N x N float buffer through an N-element REG buffer and check the output equals the input.

    `at` is the AxisType used for the second kernel axis; the first axis is always AxisType.LOOP.
    """
    N = 16
    # output (arg=0) and input (arg=1) global buffers, both viewed as (N, N)
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0).view(ShapeTracker.from_shape((N,N)))
    in_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1).view(ShapeTracker.from_shape((N,N)))
    # N-element buffer in AddrSpace.REG, viewed as (N, N) with strides (0, 1): every row maps onto the same N elements
    a_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(N, AddrSpace.REG), arg=0).view(ShapeTracker.from_shape((N,N), (0,1)))

    # store the input into the REG buffer, then load it back; the store is passed as a src of the load
    # (presumably so the load is ordered after the store -- NOTE(review): confirm against UOp.load semantics)
    out = a_col.load(a_col.store(in_buf.load()))
    sink = out_buf.store(out).sink(arg=KernelInfo(name="regcopy", axis_types=(AxisType.LOOP, at)))
    prg = get_program(sink, Device.default.renderer)

    with Context(DEBUG=0):
      a = Tensor.randn(N, N).realize()
      b = Tensor.empty(N, N).realize()
    hrunner = CompiledRunner(prg)
    # buffer order follows the DEFINE_GLOBAL args: 0 -> b (output), 1 -> a (input)
    ExecItem(hrunner, [b.uop.buffer, a.uop.buffer]).run(wait=True)
    with Context(DEBUG=0):
      # require every element to match exactly; a mean of differences can be 0 even when elements differ
      self.assertEqual((b-a).abs().max().item(), 0.0)

  @unittest.skipIf(getenv("PTX"), "ptx needs regs to be unrolled")
  def test_simple_loop(self):
    """Same register copy as test_simple, but with the second axis as a LOOP instead of UPCAST."""
    self.test_simple(AxisType.LOOP)
|
|
|
|
if __name__ == "__main__":
  # Run the test suite through unittest's command-line runner when executed as a script.
  unittest.main()
|
|
|