import numpy as np
import unittest
from dataclasses import replace
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.codegen.gpudims import get_grouped_dims
from tinygrad.uop.ops import UOp, Ops, GroupOp
from tinygrad.device import Device, Buffer, is_dtype_supported
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner, get_program
from tinygrad.helpers import Context, flatten, dedup, TC_SELECT, TC_OPT
from tinygrad.dtype import DType, dtypes, PtrDType, AddrSpace
from tinygrad.codegen import apply_rewrites, rewrites_for_views
from tinygrad.renderer.ptx import PTXRenderer
class TestLinearizer(unittest.TestCase):
def test_arg_dedup(self):
# NOTE: this realize exists because Tensor.numpy calls .contiguous() internally
# without contiguous folding, rand.to("CPU") and rand.contiguous().to("CPU") are different UOps.
# this test asserts they resolve to the identical Buffer
# having different buffers would still be fine for correctness, because the outputs match.
a, b = Tensor.randn(4).realize(), Tensor.randn(4).realize()
np_a, np_b = a.numpy(), b.numpy()
c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),))))
lowered = [x[1] for x in lower_schedule(c.schedule())]
for ei in lowered: ei.run()
rawbufs = lowered[-1].bufs
assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.uop.base.realized, b.uop.base.realized}
np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:])
np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4)
def test_load_removed(self):
a = Tensor.rand(1).realize()
b = Tensor.rand(1).realize()
ta = Tensor.where(Tensor(True), a, b).numpy()
tb = Tensor.where(Tensor(False), a, b).numpy()
np.testing.assert_equal(a.numpy(), ta)
np.testing.assert_equal(b.numpy(), tb)
def test_multioutput(self):
dtype, st = dtypes.int, ShapeTracker.from_shape((8,))
g0, g1, g2, g3 = [UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), arg=i) for i in range(4)]
a = UOp(Ops.LOAD, dtype, src=(g2.view(st),))
b = UOp(Ops.LOAD, dtype, src=(g3.view(st),))
out0 = UOp(Ops.STORE, dtypes.void, src=(g0.view(st), a + b))
out1 = UOp(Ops.STORE, dtypes.void, src=(g1.view(st), a * b))
sink = UOp(Ops.SINK, src=(out0, out1))
a_t = Tensor.full(st.shape, 2).contiguous().realize()
b_t = Tensor.full(st.shape, 3).contiguous().realize()
helper_linearizer_ast(sink, [a_t, b_t], wanna_output=[a_t.numpy()+b_t.numpy(), a_t.numpy()*b_t.numpy()])
uops = get_program(sink, opts=[]).uops
stores = [u for u in uops if u.op is Ops.STORE]
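# each STORE should write through its own DEFINE_GLOBAL; collect the distinct output buffers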
mutable_bufs = dedup(flatten([[x for x in u.src[0].toposort() if x.op is Ops.DEFINE_GLOBAL] for u in stores]))
assert len(mutable_bufs) == len(stores) == 2
self.assertSetEqual(set([u.arg for u in mutable_bufs]), set([0,1]))
def _test_no_nested_ranges(self, lins, skip=None):
for l in lins:
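# RANGE uops that are direct sources of an accumulator register (DEFINE_REG)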
range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.DEFINE_REG])
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u in range_in_acc) or (u.op is Ops.ENDRANGE and u.src[0] in range_in_acc)]
for i,u in enumerate(ranges):
if skip and i in skip: continue
assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1]}, {u}"
def test_two_nested_range(self):
a = Tensor.randn(2, ).realize()
out = a.reshape(2, 1).expand(2, 3).sum()
ast = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)).sum()])
uops = get_program(ast, opts=[]).uops
ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
assert len(ranges) == 1 # NOTE: it collapses now
def test_three_nested_range(self):
a = Tensor.randn(2, ).realize()
out = a.reshape(2, 1).expand(2, 3).expand(2, 2, 3).sum()
ast = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)), (2, 2, 3)).sum()])
uops = get_program(ast, opts=[]).uops
ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
assert len(ranges) == 1 # NOTE: it collapses now
def test_two_nested_range_alt_indexing(self):
a = Tensor([2, 2]).realize()
out = a.reshape(2, 1).pad(((1, 1), (1, 1)), value=2).sum()
ast = helper_linearizer_opt(out, wanna_output=[24])
uops = get_program(ast, opts=[]).uops
ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
# RANGE -> ALU -> RANGE -> ALU + LOAD -> STORE
assert any(x.op in GroupOp.ALU for x in uops[ranges[0]:ranges[1]])
assert not any(x.op is Ops.LOAD for x in uops[ranges[0]:ranges[1]])
assert any(x.op in {*GroupOp.ALU, Ops.LOAD} for x in uops[ranges[1]:])
def test_range_outer_op_before_phi(self):
a = Tensor.randn(4, 1).realize()
b = Tensor.randn(1, 1).realize()
out = (a + b[0]).sum() + b[0]
ast = helper_linearizer_opt(out, wanna_output=[(a.numpy()+b.numpy()[0]).sum()+b.numpy()])
uops = get_program(ast, opts=[]).uops
ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
# LOAD -> RANGE -> LOAD -> STORE
assert len([x for x in uops[:ranges[0]] if x.op is Ops.LOAD]) == 1
def test_range_outer_op_before_phi_nested_range(self):
a = Tensor.randn(2, ).realize()
b = Tensor.randn(1, 1).realize()
out = (a.reshape(2, 1).expand(2, 3) + b[0]).sum() + b[0]
ast = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)) + b.numpy()[0]).sum() + b.numpy()])
uops = get_program(ast, opts=[]).uops
ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
assert len(ranges) == 1 # NOTE: it collapses now
def test_load_dedup(self):
# for different leaves in the AST, the same loads may occur.
a = Tensor.randn(4).realize()
# these are of size 3 to avoid float4 coalesce
r = a[:-1] + a[1:]
uops = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
num_loads = len([uop for uop in uops if uop.op is Ops.LOAD])
assert num_loads <= 4, "more load uops than needed"
assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?"
def test_upcast_cse(self):
# when upcasting, within a subtree, there may be common expressions.
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = a.expand([2]) + b.expand([2])
uops = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
num_ops = len([uop for uop in uops if uop.op in GroupOp.ALU])
assert num_ops <= 1, "more alu uops than needed"
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_reduce_upcast(self):
x, w = Tensor.randn((1,1,3)).realize(), Tensor.randn((1,1,2)).realize()
r = Tensor.conv2d(x,w,padding=1).relu()
uops = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]).uops
accs = [u for u in uops if u.op is Ops.DEFINE_REG]
stores = [u for u in uops if u.op is Ops.STORE]
assert len(accs) == 0 # it's removed now
assert len(stores) == 1
assert stores[0].src[1].dtype == dtypes.float.vec(4)
# NOTE: can reenable, it does work. it just makes BEAM slow
@unittest.expectedFailure
@unittest.skipUnless(Device.DEFAULT == "CPU", "test only for CPU")
def test_upcast_with_locals_cpu(self):
out = Tensor.ones(64,64).contiguous() @ Tensor.ones(64,64).contiguous()
prg = get_program(out.schedule()[-1].ast, opts=[Opt(OptOps.LOCAL, axis=0, arg=4)])
self.assertEqual(len(prg.src.split("for")), 5)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "broken on ptx for some reason")
def test_upcast_with_locals(self):
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
r = (x@y).relu()
opts_to_apply = [Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)]
program = get_program(r.schedule()[-1].ast, opts=opts_to_apply)
stores = [u for u in program.uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG]
# the first store is to lds and can be upcasted
assert stores[0].src[1].dtype == dtypes.float.vec(4)
assert any(x.op is Ops.DEFINE_LOCAL for x in stores[0].toposort())
# the second store is to gds with no upcasts
assert stores[1].src[1].dtype == dtypes.float
assert any(x.op is Ops.DEFINE_GLOBAL for x in stores[1].toposort())
def test_zero_fold(self):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = Tensor.stack(a, b)
uops = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
num_ops = len([uop for uop in uops if uop.op in GroupOp.ALU])
assert num_ops == 0, "more alu uops than needed"
def test_sum_acc_dtype(self):
for tensor_dtype, acc_dtype in (
(dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)):
if is_dtype_supported(tensor_dtype) and is_dtype_supported(acc_dtype):
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
realized_ast = a.schedule()[-1].ast
program = get_program(realized_ast, opts=[])
local = [uop for uop in program.uops if uop.op is Ops.DEFINE_REG]
assert local[0].dtype.base == acc_dtype
def test_arg_acc_dtype(self):
def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
realized_ast = c.schedule()[-1].ast
program = get_program(realized_ast, opts=[])
local = [uop for uop in program.uops if uop.op is Ops.DEFINE_REG]
self.assertEqual(local[0].dtype.base, expected_dtype)
tests = (
(dtypes.float16, None, dtypes.float),
(dtypes.bfloat16, None, dtypes.float),
(dtypes.float, None, dtypes.float),
(dtypes.float16, dtypes.float16, dtypes.float16),
(dtypes.bfloat16, dtypes.bfloat16, dtypes.bfloat16),
(dtypes.float, dtypes.float16, dtypes.float16),
)
for tensor_dtype, acc_dtype, expected_dtype in tests:
if is_dtype_supported(tensor_dtype) and is_dtype_supported(acc_dtype) and is_dtype_supported(expected_dtype):
a, b = Tensor.rand(8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, dtype=tensor_dtype)
helper_arg_acc_dtype(a.sum(dtype=acc_dtype), expected_dtype)
helper_arg_acc_dtype(a.matmul(b, dtype=acc_dtype), expected_dtype)
helper_arg_acc_dtype(Tensor.einsum("ki,ij->kj", a, b, dtype=acc_dtype), expected_dtype)
d, w = Tensor.rand(4, 8, 8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, 2, 2, dtype=tensor_dtype)
helper_arg_acc_dtype(d.conv2d(w, dtype=acc_dtype), expected_dtype)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_simple_unroll_no_between_phi_dependencies(self):
x, y = Tensor.rand(128, 128), Tensor.rand(128, 128)
r = (x@y).relu()
opt = [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4)]
ast = helper_linearizer_opt(r, [opt])
# the uops graph is DEFINE_REG -> 4x STORE 0.0 -> RANGE -> 4x ALU -> 4x STORE -> ENDRANGE
uops = get_program(ast, opts=opt).uops
begin_range = [i for i, x in enumerate(uops) if x.op is Ops.RANGE][-1]
end_range = [i for i, x in enumerate(uops) if x.op is Ops.ENDRANGE][0]
for i,u in enumerate(uops): print(i, u.op, [uops.index(s) for s in u.src], u.arg, u.dtype)
for u in uops:
if u.op is Ops.STORE and isinstance(dt:=u.src[0].dtype, PtrDType) and dt.addrspace is AddrSpace.REG:
if uops.index(u) < begin_range:
assert u.src[1].op is Ops.CONST
else:
assert u.src[1].op in GroupOp.ALU
assert begin_range < uops.index(u) < end_range
# children of STORE are placed after ENDRANGE
if any(x.op is Ops.STORE and x.src[1].op in GroupOp.ALU for x in u.src):
assert end_range < uops.index(u)
def test_grouped_dims(self):
def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes, assert_same_length = True):
idxs = get_grouped_dims(prefix, dims, max_sizes, reverse_dims)
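# the launch sizes live on the SPECIAL uops (e.g. gidx0) that back the returned index expressions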
loop_idxs = dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs]))
loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg)
sizes = [x.src[0].arg for x in loop_idxs]
assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
if assert_same_length:
assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
assert sizes == expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}"
# TODO: add these back after uop symbolic
# for i in range(len(dims)):
# assert idxs[i].max+1 == dims[i], f"idxs[{i}] should have max {dims[i]-1}"
# for i in range(len(loop_idxs)):
# assert loop_idxs[i].expr.startswith(prefix), f"loop_idxs[{i}] must start with {prefix}"
# assert loop_idxs[i].max+1 == sizes[i], f"loop_idxs[{i}] should have max {sizes[i]-1}"
# no-op
_assert_grouped_dims("gidx", (2,), (16,16,16), False, [2])
_assert_grouped_dims("gidx", (2,3), (16,16,16), False, [2,3])
# check reverse dims
_assert_grouped_dims("gidx", (2,3), (16,16,16), True, [3,2])
_assert_grouped_dims("gidx", (2,3,4), (16,16,16), False, [2,3,4])
# test splitting globals: len(dims) == len(max)
_assert_grouped_dims("gidx", (64,3,4), (16,16,16), False, [16,12,4])
_assert_grouped_dims("gidx", (64,3,4), (16,4,16), False, [16,3,16])
_assert_grouped_dims("gidx", (64,3,4), (16,16,16), True, [16,3,16])
_assert_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,3,32])
_assert_grouped_dims("gidx", (4,4,512), (16,4,256), False, [8,4,256])
# prefer group_dim strategy when possible
_assert_grouped_dims("gidx", (512,4,2), (8192,2,2), False, [2048,2])
# test splitting globals: len(dims) < len(max)
# len(dim) -> len(limited)
# 1 -> 2
_assert_grouped_dims("gidx", (128,), (16,16,256), False, [16,8], False)
# 1 -> 3
_assert_grouped_dims("gidx", (65536,), (16,16,256), False, [16,16,256], False)
# 2 -> 3
_assert_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False)
# test when the only divisor is the square root of dim
_assert_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False)
# collapse onto the left-most axis
_assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), False, [6,4,5])
_assert_grouped_dims("gidx", (2,3,4,5), (32,16,16), True, [20,3,2])
# _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (32,16,16), True, [20,3,Variable("start_pos",1,2)])
# collapse onto the left-most available axis (the left-most max is too small)
_assert_grouped_dims("gidx", (2,3,4,5), (4,16,16), False, [2,12,5])
_assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), True, [5,12,2])
# _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (16,16,16), False, [Variable("start_pos",1,2)*3,4,5])
# dim too large and not factorable
with self.assertRaises(RuntimeError):
get_grouped_dims("gidx", (23,), (16,16,16), False,)
with self.assertRaises(RuntimeError):
get_grouped_dims("gidx", (128,3,4), (16,2,2), False,)
# too large for sizes
with self.assertRaises(RuntimeError):
get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16))
# # variable too large
# with self.assertRaises(AssertionError):
# get_grouped_dims("gidx", (Variable("start_pos",0,16),3,4), (16,16,16), False,)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
def test_default_global_reversed(self):
# shrink so that the dims do not collapse
t = Tensor.ones(5, 6, 7).contiguous().realize().shrink(((0, 4), (0, 5), (0, 6)))
ast = helper_linearizer_opt(t+1)
uops = get_program(ast, opts=[]).uops
idxs = dedup([uop for uop in uops if uop.op is Ops.SPECIAL])
idxs = sorted(idxs, key=lambda uop: uop.arg)
assert (idxs[0].arg, idxs[0].src[0].arg) == ('gidx0', 6), idxs[0]
assert (idxs[1].arg, idxs[1].src[0].arg) == ('gidx1', 5), idxs[1].arg
assert (idxs[2].arg, idxs[2].src[0].arg) == ('gidx2', 4), idxs[2].arg
def test_sum_collapse(self):
t = Tensor([2]).reshape(1, 1).expand(256, 256).sum()
sched = [si for si in t.schedule() if si.ast.op is Ops.SINK]
# sum_collapse is a full collapse now
assert len(sched) == 1
assert not any(u.op is Ops.REDUCE_AXIS for u in sched[0].ast.toposort()), "found reduce in sum collapse"
#lin = Kernel(sched[0].ast)
#assert not any(u.op is Ops.RANGE for u in lin.linearize().uops), "found loop in sum collapse"
def test_assign_fold(self):
a = Tensor.ones(4, 4).contiguous().realize()
m = Tensor.ones(4, 4).shrink(((1, 2), None)).pad(((1, 2), None))
a.assign(a+m)
a.realize()
np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
def test_where_fold(self):
a = Tensor.ones(4, 4).contiguous().realize()
b = a.shrink(((1, 2), None)).pad(((1, 2), None))
a.assign(b.where(2, a))
sched = a.schedule()
assert len(sched) == 1
sched_copy = sched[:]
run_schedule(sched)
np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
program = get_program(sched_copy[-1].ast, opts=())
assert not any(u.op is Ops.WHERE for u in program.uops), "found a WHERE that should have been folded"
def test_phi_simplification(self):
def helper(t, max_ops=0):
ast = helper_linearizer_opt(t)
uops = get_program(ast).uops
# ignore kernel optimized IF statements for now
if if_op:=next((u for u in uops if u.op is Ops.IF), None):
uops = uops[:uops.index(if_op)]
assert len(set([u.op for u in uops if u.op in {Ops.RANGE, Ops.SPECIAL}])) == 1, "has either specials or ranges, not both"
reg_stores = [u for u in uops if u.op is Ops.STORE and isinstance(dt:=u.src[0].dtype, PtrDType) and dt.addrspace == AddrSpace.REG]
assert len(reg_stores) == 0, "STORE to reg should have been simplified"
# TODO: once uops track min/max this will be fixed
#assert len([u for u in uops if u.op is Ops.MAX]) <= max_ops, "no unnecessary MAX ops"
helper(Tensor.arange(5.5, (3.5*300), 3.5), max_ops=2)
helper(Tensor.arange(-1, -100, -5), max_ops=2)
# NOTE: both of these split the reduce (this just wasn't tracked before)
#helper(Tensor.arange(-3.2, 6.7, 0.64), max_ops=2)
#helper(Tensor.arange(256), max_ops=2)
helper(Tensor.arange(255), max_ops=2)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "broken on ptx for some reason")
def test_grouped_store_phis(self):
"""
float4 acc0 = float4(0.0,0.0,0.0,0.0);
{
acc0 = // ...
}
*((device float4*)(data0+alu2)) = float4(acc0.x,acc0.y,acc0.z,acc0.w);
simplifies to:
*((device float4*)(data0+alu2)) = acc0;
"""
x, y = Tensor.randn(64,64), Tensor.randn(64,64)
out = x.matmul(y)
with Context(TC=0):
ast = helper_linearizer_opt(out)
uops = get_program(ast).uops
# check that the float4 cast collapses
store_vals = [u.src[1] for u in uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG]
for val in store_vals:
assert val.dtype == dtypes.float.vec(4) # and val.op is not Ops.VECTORIZE
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_grouped_store_values(self):
x = Tensor.randn((4,3,6,6)).realize()
out = x.flip((0,1)).contiguous()
ast = helper_linearizer_opt(out)
store_val = [u.src[1] for u in get_program(ast).uops if u.op is Ops.STORE][0]
assert store_val.dtype == dtypes.float.vec(4) and store_val.op is not Ops.VECTORIZE
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_grouped_store_locals_and_globals(self):
x, y = Tensor.rand(128, 128), Tensor.rand(128, 128)
out = x@y
opt = [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8),
Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 2)] # upcast accs in both reduces
ast = helper_linearizer_opt(out, opts=[opt])
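# get_recursive returns a uop together with all of its transitive sources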
def get_recursive(uop): return set.union(set(uop.src), [uop], *[get_recursive(v) for v in uop.src])
uops = get_program(ast, opts=opt).uops
local_stores = [u for u in uops if u.op is Ops.STORE and any(x.op is Ops.DEFINE_LOCAL for x in get_recursive(u.src[0]))]
global_stores = [u for u in uops if u.op is Ops.STORE and any(x.op is Ops.DEFINE_GLOBAL for x in get_recursive(u.src[0]))]
barrier = [u for u in uops if u.op is Ops.BARRIER][0]
# check that the float4 cast collapses for all stores
for store in local_stores+global_stores:
assert store.src[1].dtype.count > 1 # and store.src[2].op is not Ops.VECTORIZE
# # check the children's vins
# TODO: the src ALUs are not the same, should they be?
# assert barrier.src == tuple(local_stores)
assert len([u for u in uops if u.op is Ops.IF and u.src[-1] == barrier]) == 1
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "broken on ptx for some reason")
def test_grouped_store_local_only(self):
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
r = (x@y).relu()
ast = helper_linearizer_opt(r)
uops = get_program(ast).uops
stores = [u for u in uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG]
# the float4 value is stored directly in lds and the upcast is skipped
self.assertEqual(stores[0].src[1].dtype, dtypes.float.vec(4))
#assert stores[0].src[-1].op is not Ops.VECTORIZE
# the global store doesn't change
assert stores[1].src[1].dtype == dtypes.float
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_skip_unmatching_upcasts(self):
Tensor.manual_seed(0)
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9600), arg=0, src=())
c1 = c0.view(ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),)))
c2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9600), arg=1, src=())
c3 = c2.view(ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)))
c4 = c3.load()
c5 = c1.store(c4)
ast = c5.sink()
opt = [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16),
Opt(op=OptOps.LOCAL, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=3, arg=2)]
helper_linearizer_ast(ast, [Tensor.randn(240*40).realize()], opts=[opt])
out = [u for u in get_program(ast, opts=opt).uops if u.op is Ops.STORE][0]
assert out.src[1].op is Ops.VECTORIZE and out.src[1].dtype == dtypes.float.vec(4)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_skip_unmatching_upcasts_with_gep(self):
Tensor.manual_seed(0)
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=0, src=())
c1 = c0.view(ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)))
c2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=1, src=())
c3 = c2.view(ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),)))
c4 = c3.load()
c5 = c1.store(c4)
ast = c5.sink()
opt = [Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=2, arg=2), Opt(op=OptOps.LOCAL, axis=1, arg=8),
Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8),
Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=2)]
helper_linearizer_ast(ast, [Tensor.randn(8*32).realize()], opts=[opt])
out = [u for u in get_program(ast).uops if u.op is Ops.STORE][0]
assert out.src[1].op is Ops.VECTORIZE and out.src[1].dtype.count != 1
# *** helpers ***
def push_views(ast): return apply_rewrites(ast, rewrites_for_views)
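# realizes everything except the last kernel, then returns that kernel's (view-pushed) AST and its buffers,
# with fresh output buffers allocated so tests can run the kernel themselves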
def helper_realized_ast(r:Tensor|list[Tensor]) -> tuple[UOp, list[Buffer]]:
if isinstance(r, Tensor): r = [r]
s = Tensor.schedule(*r)
run_schedule(s[:-1]) # run all kernels except the last one
assert s[-1].ast.op is Ops.SINK, f"helper_realized_ast expects a SINK {s[-1]}"
# now all input buffers in s[-1] should be realized
# create fresh buffers for the outputs
bufs = [Buffer(x.device, x.size, x.dtype).allocate() if i < len(s[-1].ast.src) else x for i,x in enumerate(s[-1].bufs)]
return push_views(s[-1].ast), bufs
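# runs a hand-built AST: allocates one output buffer per SINK src and reuses the realized input buffers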
def helper_linearizer_ast(ast:UOp, inputs:list[Tensor], *args, **kwargs):
assert isinstance(ast, UOp), "ast must be UOp"
inbufs = [x.uop.base.buffer for x in inputs]
outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.st_arg.size, out.src[1].dtype).allocate() for out in ast.src]
_helper_linearizer_opt_ast(ast, outbufs+inbufs, *args, **kwargs)
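# realizes a Tensor (or list of Tensors) up to its last kernel, checks that kernel under the given opts, and returns its AST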
def helper_linearizer_opt(r:Tensor|list[Tensor], *args, **kwargs):
realized_ast, real_bufs = helper_realized_ast(r)
_helper_linearizer_opt_ast(realized_ast, real_bufs, *args, **kwargs)
return realized_ast
def copyout_outputs(outbufs:list[Buffer]) -> list[np.ndarray]:
return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs]
def reset_bufs(bufs:list[Buffer]):
for buf in bufs: buf.copyin(np.zeros((buf.size, ), dtype=_to_np_dtype(buf.dtype)).data) # Zero to check that all values are filled
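# core checker: run the unoptimized kernel as a baseline, then the hand-coded optimizations,
# then every opt list in `opts`, comparing each output buffer against wanna_output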
def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[],
apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]):
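# the first len(realized_ast.src) buffers are the kernel outputs (one per STORE under the SINK)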
outbufs = real_bufs[:len(realized_ast.src)]
device = real_bufs[0].device
wanna_output = [np.array(x).flatten() for x in wanna_output]
def get_prg(opts): return CompiledRunner(replace(get_program(realized_ast, opts=opts), device=device))
def check_opt(opts):
prg = get_prg(opts=opts)
reset_bufs(outbufs)
prg.exec(real_bufs)
for x,want in zip(copyout_outputs(outbufs), wanna_output): np.testing.assert_allclose(x, want, atol=atol, rtol=rtol)
# Get the baseline output (not optimized at all) if one was not provided.
prg = get_prg(opts=())
prg.exec(real_bufs)
if len(wanna_output) == 0: wanna_output = copyout_outputs(outbufs)
else:
for buf,want in zip(copyout_outputs(outbufs), wanna_output): np.testing.assert_allclose(buf, want, atol=atol, rtol=rtol)
# Check correctness of hand-coded optimizations.
prg = get_prg(opts=None)
reset_bufs(outbufs)
prg.exec(real_bufs)
for buf,want in zip(copyout_outputs(outbufs), wanna_output): np.testing.assert_allclose(buf, want, atol=atol, rtol=rtol)
for x in opts: # Check custom transformations if any.
check_opt(([Opt(OptOps.TC, 0, (TC_SELECT.value, TC_OPT.value, 1))] if apply_tc else [])+x)
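# minimal usage sketch (assumes a kernel where an UPCAST by 4 on axis 0 is valid):
#   r = (Tensor.rand(16) + 1).contiguous()
#   helper_linearizer_opt(r, opts=[[Opt(OptOps.UPCAST, 0, 4)]])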
if __name__ == '__main__':
unittest.main()