openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

2115 lines
87 KiB

# this will be the new test_ops for the next level
# schedule confirms the right things are capable of fusing
# NOTE: this has overlap with external_test_opt.py
# ruff: noqa: E501
import unittest
import numpy as np
import functools
from typing import List, Optional, Union, cast
from tinygrad import nn, dtypes, Device, Tensor
from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, view_supported_devices
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleContext, ScheduleItem, create_schedule, view_right, view_left, do_realize, remove_movement_ops
from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule
from extra.models.llama import precompute_freqs_cis
class KernelCountException(Exception):
  """Raised by check_schedule when the number of schedule items differs from the expected count."""
def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_sink=True):
  """Build the schedule for `t` and require that it holds exactly `allowed` items.

  t: a Tensor, a list of Tensors, or a single UOp to schedule.
  allowed: the exact number of schedule items expected.
  to_prerealize: tensors whose .schedule() is called before `t` is scheduled.
  filter_sink: when true, only items whose ast op is Ops.SINK are kept (and counted).

  Returns the (possibly filtered) list of schedule items.
  Raises KernelCountException when the count differs from `allowed`.
  """
  # collect the scheduling roots for whichever input form was given
  if isinstance(t, Tensor): roots = t.lazydata.lbs
  elif isinstance(t, List): roots = flatten([r.lazydata.lbs for r in t])
  else: roots = [t]
  for pre in to_prerealize or (): pre.schedule()
  sched = create_schedule(roots)
  if filter_sink: sched = [si for si in sched if si.ast.op is Ops.SINK]
  if len(sched) != allowed:
    print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
    if DEBUG >= 3:
      for idx, si in enumerate(sched, start=1):
        print("kernel", idx)
        print(si.ast)
    raise KernelCountException(f"{len(sched)=} != {allowed}")
  # test the (sink) ops linearize
  for si in sched:
    if si.ast.op is Ops.SINK: get_runner(si.bufs[0].device, si.ast)
  return sched
def _realize_weights(m):
  """Call .realize() on every parameter that nn.state.get_parameters returns for `m`."""
  params = nn.state.get_parameters(m)
  for weight in params: weight.realize()
def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
  """Schedule and run a conv2d -> relu -> mean forward/backward and assert the kernel count.

  allowed: exact number of schedule items with an Ops.SINK ast that are expected.
  dtype: installed as dtypes.default_float while the tensors and the backward graph are built.
  kwargs: forwarded to Context(...) around create_schedule.

  When getenv("CHECK", 1) is truthy the image and weight gradients are compared against torch.
  """
  old_default_float, dtypes.default_float = dtypes.default_float, dtype
  try:
    Tensor.manual_seed(0)
    BS, CIN = 2, 3
    img = Tensor.randn(BS, CIN, 64, 64, requires_grad=True).realize()
    w = Tensor.uniform(16, CIN, 3, 3, requires_grad=True).realize()
    ret = Tensor.conv2d(img, w).relu().mean().backward()
  finally:
    # always put the global default back, so a failure above cannot leak `dtype` into later tests
    dtypes.default_float = old_default_float
  with Context(**kwargs): s = create_schedule([ret.lazydata, img.grad.lazydata, w.grad.lazydata])
  # run a copy so `s` is still intact for the count below
  run_schedule(s.copy())
  cnt = len([si for si in s if si.ast.op is Ops.SINK])
  assert cnt == allowed, f"expected {allowed} kernels, got {cnt}"
  if getenv("CHECK", 1):
    import torch
    ref_img = torch.tensor(img.numpy(), requires_grad=True)
    ref_w = torch.tensor(w.numpy(), requires_grad=True)
    torch.nn.functional.conv2d(ref_img, ref_w).relu().mean().backward()
    assert ref_img.grad is not None and ref_w.grad is not None and img.grad is not None and w.grad is not None
    # tight tolerance for dtypes.float, looser for any other dtype
    atol = 1e-6 if dtype == dtypes.float else 1e-2
    np.testing.assert_allclose(img.grad.numpy(), ref_img.grad.detach().numpy(), atol=atol)
    np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=atol)
class TestSchedule(unittest.TestCase):
# a+b+c over three (10,) tensors is expected to schedule as exactly one kernel.
def test_basic_binop_fusion(self):
a = Tensor.empty(10)
b = Tensor.empty(10)
c = Tensor.empty(10)
d = a+b+c
check_schedule(d, 1)
# Same as the three-operand case but with a four-tensor add chain; still expects one kernel.
def test_basic_binop_fusion_deep(self):
a = Tensor.empty(10)
b = Tensor.empty(10)
c = Tensor.empty(10)
d = Tensor.empty(10)
e = a+b+c+d
check_schedule(e, 1)
# An elementwise multiply feeding a full sum is expected to be a single kernel.
def test_mulacc_fusion(self):
a = Tensor.empty(10)
b = Tensor.empty(10)
c = (a*b).sum()
check_schedule(c, 1)
# Multiply, sum, then relu on the reduced value: still expected to be a single kernel.
def test_mulacc_relu_fusion(self):
a = Tensor.empty(10)
b = Tensor.empty(10)
c = (a*b).sum().relu()
check_schedule(c, 1)
# A reshape between two adds is expected not to split the work: one kernel.
def test_binop_reshape_fusion(self):
a = Tensor.empty(10)
b = Tensor.empty(10)
c = Tensor.empty(5,2)
d = (a+b).reshape(5,2)+c
check_schedule(d, 1)
# A permute between two adds is expected not to split the work: one kernel.
def test_binop_permute_fusion(self):
a = Tensor.empty(2,5)
b = Tensor.empty(2,5)
c = Tensor.empty(5,2)
d = (a+b).permute(1,0)+c
check_schedule(d, 1)
# Multiplying an empty (3,3) tensor by the constant 2 is expected to give 2 schedule items in total;
# filter_sink=False means non-SINK items are counted too.
def test_constants_are_embedded(self):
a = Tensor.empty(3,3) * 2
check_schedule(a, 2, filter_sink=False)
# A bare Tensor(2) is expected to produce no SINK schedule items.
# NOTE(review): the name begins with `tests_` rather than `test_` like its siblings; unittest's default
# `test` method prefix still matches it, so it is collected.
def tests_constants_are_folded(self):
a = Tensor(2)
check_schedule(a, 0)
# Tensor(2).contiguous() is expected to be one kernel; the schedule is run and the value checked to be 2.
def test_constants_can_store(self):
a = Tensor(2).contiguous()
run_schedule(check_schedule(a, 1))
np.testing.assert_equal(a.numpy(), 2)
# elu() applied to a single tensor is expected to schedule as one kernel.
def test_binop_elu_fusion(self):
a = Tensor.empty(10)
b = a.elu()
check_schedule(b, 1)
# add -> reshape to (10,10) -> sum over axis 0 (keepdim) is expected to be one kernel.
def test_binop_reshape_reduce_fusion(self):
a = Tensor.empty(100)
b = Tensor.empty(100)
c = (a+b).reshape(10, 10).sum(axis=0, keepdim=True)
check_schedule(c, 1)
# A sum over axis 0 followed by an add with a (10,) tensor is expected to be one kernel.
def test_reduce_reshape_binop_fusion(self):
a = Tensor.empty(10,10)
b = Tensor.empty(10)
c = a.sum(axis=0) + b
check_schedule(c, 1)
# not pushing permutes through reduces
# Asserts that requesting exactly 1 kernel FAILS when a permute sits between the reduce and the add.
def test_reduce_permute_binop_fusion(self):
a = Tensor.empty(10,10,10)
b = Tensor.empty(10,10,1)
c = a.sum(axis=0, keepdim=True).permute(2,1,0) + b
with self.assertRaises(KernelCountException): check_schedule(c, 1)
# add -> reshape -> add -> sum over axis 0 is expected to be one kernel.
def test_binop_early_reshape_reduce_fusion(self):
a = Tensor.empty(100)
b = Tensor.empty(100)
c = Tensor.empty(10,10)
d = ((a+b).reshape(10,10) + c).sum(axis=0)
check_schedule(d, 1)
# Diamond-shaped graph: `ab` feeds both branches of `e`; the whole thing is expected to be one kernel.
def test_diamond_folded(self):
a = Tensor.empty(10)
b = Tensor.empty(10)
c = Tensor.empty(10)
d = Tensor.empty(10)
ab = a+b
e = (ab+c) + (ab+d)
check_schedule(e, 1)
# `c` is scheduled first (via to_prerealize); the identical expression `d` is then expected to need 0 kernels.
def test_cache_binaryop(self):
a = Tensor.empty(10)
b = Tensor.empty(10)
c = a+b
d = a+b
check_schedule(d, 0, [c])
# failing in new lazy
# After pre-scheduling `c`, the reshaped-operand variant `d` is asserted NOT to schedule to 0 kernels.
def test_cache_binaryop_reshaped(self):
a = Tensor.empty(10)
b = Tensor.empty(10)
c = a+b
d = a.reshape(10,1)+b.reshape(10,1)
with self.assertRaises(KernelCountException): check_schedule(d, 0, [c])
# failing in new lazy
# After pre-scheduling the transposed form `c`, the plain a*b is asserted NOT to schedule to 0 kernels.
def test_cache_binaryop_transpose(self):
a = Tensor.empty(10,10)
b = Tensor.empty(10,10)
c = (a.T*b.T).T #.contiguous()
d = a*b
with self.assertRaises(KernelCountException): check_schedule(d, 0, [c])
# Two identical a.sum() expressions added together are expected to be one kernel.
def test_cache_two_reduceops(self):
a = Tensor.empty(10)
b = a.sum()
c = a.sum()
bc = b+c
check_schedule(bc, 1)
# Expects 2 kernels, and exactly two REDUCE_AXIS uops in total across the scheduled ASTs.
def test_cache_reduce_parent(self):
x = Tensor.empty(32)
r0 = x.mean(axis=0, keepdim=True)
r1 = (x - r0).sum(axis=0).div(2)
out = r0 + r1
schedule = check_schedule(out, 2)
reduceops = [x for si in schedule for x in si.ast.toposort if x.op is Ops.REDUCE_AXIS]
assert len(reduceops) == 2
# Two outputs built from r0 and r1: expects 4 kernels but still only two REDUCE_AXIS uops across all ASTs.
def test_cache_reduce_multiple_children(self):
x = Tensor.empty(32)
y = Tensor.empty(4, 4)
r0 = x.mean(axis=0, keepdim=True)
r1 = (x - r0).sum(axis=0).div(2)
out0 = r0 + y
out1 = r1 + y
schedule = check_schedule([out0, out1], 4)
reduceops = [x for si in schedule for x in si.ast.toposort if x.op is Ops.REDUCE_AXIS]
assert len(reduceops) == 2
# Assigning the same `b` to the same realized `a` twice is expected to schedule as one kernel.
def test_dedup_assign(self):
a = Tensor.ones(4).contiguous().realize()
b = Tensor.full((4,), 2.).contiguous()
first = a.assign(b)
second = a.assign(b)
check_schedule([first, second], 1)
# NOTE: this is causing "LAZYCACHE=1 incorrectly reuses contiguous const" #4562
# should contiguous dedup?
# Two separately built Tensor.ones(4).contiguous() are expected to share one kernel and, after running,
# the very same `realized` object.
def test_dedup_contiguous(self):
a = Tensor.ones(4).contiguous()
b = Tensor.ones(4).contiguous()
sched = check_schedule([a, b], 1)
run_schedule(sched)
# a and b share the same underlying device memory
self.assertIs(a.lazydata.realized, b.lazydata.realized)
# EMPTY and COPY are assigned to unique device Buffers
# Two clones of the same realized source: expects 2 schedule items (non-SINK items counted) and distinct `realized` objects.
def test_no_dedup_copy(self):
src = Tensor.ones(4).contiguous().realize()
a = src.clone()
b = src.clone()
sched = check_schedule([a, b], 2, filter_sink=False)
run_schedule(sched)
# a and b are assigned to different device Buffers
self.assertIsNot(a.lazydata.realized, b.lazydata.realized)
# Two Tensor.empty((4,)) calls: expects 2 schedule items (non-SINK items counted) and distinct `realized` objects.
def test_no_dedup_empty(self):
a = Tensor.empty((4,))
b = Tensor.empty((4,))
sched = check_schedule([a, b], 2, filter_sink=False)
run_schedule(sched)
self.assertIsNot(a.lazydata.realized, b.lazydata.realized)
# sum -> sqrt -> negate is expected to schedule as one kernel.
def test_fold_double_unary(self):
y = Tensor.empty(2)
out = y.sum(keepdim=True).sqrt().__neg__()
check_schedule(out, 1)
#@unittest.skip("may want to reconsider this")
# BatchNorm2d without running stats, built under Tensor.train(): expects 3 kernels.
def test_fold_batchnorm(self):
with Tensor.train():
img = Tensor.empty(1,32,4,4)
bn = nn.BatchNorm2d(32, track_running_stats=False)
out = bn(img)
check_schedule(out, 3)
# conv -> batchnorm (with running stats) -> relu under Tensor.train(False): expects 1 kernel
# once the conv weight and bias have been pre-scheduled.
def test_fold_conv_batchnorm_notrain(self):
with Tensor.train(False):
img = Tensor.empty(1,3,8,8)
c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=True)
out = bn(c1(img)).relu()
check_schedule(out, 1, [c1.weight, c1.bias])
# Same as the notrain case but with track_running_stats=False: expects 4 kernels.
def test_fold_conv_batchnorm_notrain_no_running_stats(self):
with Tensor.train(False):
img = Tensor.empty(1,3,8,8)
c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False)
out = bn(c1(img)).relu()
check_schedule(out, 4, [c1.weight, c1.bias])
# conv -> batchnorm (no running stats) -> relu under Tensor.train(): expects 4 kernels.
def test_fold_conv_batchnorm(self):
with Tensor.train():
img = Tensor.empty(1,3,8,8)
c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False)
out = bn(c1(img)).relu()
check_schedule(out, 4, [c1.weight, c1.bias])
# Schedules one optimizer step after a conv+batchnorm backward, once per optimizer as a subTest:
# Adam is expected to take 18 kernels and SGD 15.
def test_fold_conv_batchnorm_optim(self):
# this is too high
for optim, cnt in [(nn.optim.Adam, 18), (nn.optim.SGD, 15)]:
with self.subTest(optim=optim.__name__):
with Tensor.train():
img = Tensor.ones(1,3,4,4)
c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False)
_realize_weights([c1, bn])
opt = optim(nn.state.get_parameters([c1, bn]))
img_bn = bn(c1(img)).elu().sum()
opt.zero_grad()
img_bn.backward()
check_schedule(opt.schedule_step(), cnt)
# Batchnorm forward + backward scheduled together under Context(FUSE_CONV_BW=1): expects 10 kernels
# for the three gradients plus the forward output.
def test_fold_batchnorm_backward(self):
with Context(FUSE_CONV_BW=1):
with Tensor.train():
x = Tensor.empty((2, 16, 8, 8)).contiguous()
bn = nn.BatchNorm2d(16)
bn.weight.requires_grad = bn.bias.requires_grad = x.requires_grad = True
fw = bn(x).contiguous_backward().relu().contiguous()
fw.sum().backward()
# TODO: this is too many
check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 10)
# conv -> relu is expected to be one kernel once the conv weight and bias are pre-scheduled.
def test_fold_conv_relu(self):
c1 = nn.Conv2d(3,16,3)
# run
img = Tensor.ones(2,3,64,64)
out = c1(img).relu()
check_schedule(out, 1, [c1.weight, c1.bias])
# Two conv+relu stages in sequence: expects 2 kernels with all conv parameters and img pre-scheduled.
def test_fold_conv_relu_alt(self):
img = Tensor.ones(1,4,8,8)
c1 = nn.Conv2d(4, 4, kernel_size=3)
c2 = nn.Conv2d(4, 4, kernel_size=3)
img_conv = img.sequential([c1, Tensor.relu, c2, Tensor.relu])
check_schedule(img_conv, 2, [*nn.state.get_parameters(c1), *nn.state.get_parameters(c2), img])
# Same two-stage conv+relu chain with bias=False convs: expects 2 kernels.
def test_fold_conv_relu_nobias(self):
img = Tensor.ones(1,4,8,8)
c1 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
c2 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
out = img.sequential([c1, Tensor.relu, c2, Tensor.relu])
check_schedule(out, 2, [c1.weight, c2.weight, img])
# conv -> elu is expected to be one kernel once weight, bias and the random image are pre-scheduled.
def test_fold_conv_elu(self):
c1 = nn.Conv2d(3,16,3)
# run
img = Tensor.rand(2,3,64,64)
out = c1(img).elu()
check_schedule(out, 1, [c1.weight, c1.bias, img])
# Two conv+elu stages in sequence on a contiguous image: expects 2 kernels.
def test_fold_conv_elu_alt(self):
img = Tensor.ones(1,4,8,8).contiguous()
c1 = nn.Conv2d(4, 4, kernel_size=3)
c2 = nn.Conv2d(4, 4, kernel_size=3)
img_conv = img.sequential([c1, Tensor.elu, c2, Tensor.elu])
check_schedule(img_conv, 2, [*nn.state.get_parameters(c1), *nn.state.get_parameters(c2), img])
# Sums over two different axes of the same tensor, added and relu'd: expects 2 kernels.
def test_two_sum(self):
img = Tensor.empty(64,64)
x = (img.sum(0) + img.sum(1))
out = x.relu()
check_schedule(out, 2)
#@unittest.skip("failing in old lazy")
# add -> reshape -> permute -> contiguous is expected to be one kernel.
def test_push_permute_through_reshape(self):
a = Tensor.empty(16,16)
b = Tensor.empty(16,16)
c = (a+b).reshape(4,4,4,4).permute(2,3,0,1).contiguous()
check_schedule(c, 1)
#@unittest.skip("failing in old lazy")
# The reverse direction (4-D inputs reshaped to 2-D, then transposed): still expects one kernel.
def test_push_permute_through_reshape_alt(self):
a = Tensor.empty(4,4,4,4)
b = Tensor.empty(4,4,4,4)
c = (a+b).reshape(16,16).permute(1,0).contiguous()
check_schedule(c, 1)
# With `c` pre-scheduled, a reshape of the same a+b expression is expected to need 0 kernels.
def test_no_binop_rerun(self):
a = Tensor.empty(16)
b = Tensor.empty(16)
c = a+b
d = (a+b).reshape(16,1)
check_schedule(d, 0, [c])
# Skipped unless dtypes.half is supported. A reduce followed by a cast and a chain of permutes/reshapes,
# then an add, is expected to be one kernel.
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_multi_permute_should_collapse(self):
a = Tensor.empty(4,4,4,4)
b = Tensor.empty(16)
c = a.sum((0,1)).cast(dtypes.float16).permute(1,0).reshape(4,4,1).permute(1,0,2).reshape(16) + b
check_schedule(c, 1)
# Summing a+b in both its flat and reshaped forms is asserted NOT to schedule as exactly one kernel.
def test_fancy_reshape_fusion(self):
a = Tensor.empty(10)
b = Tensor.empty(10)
c = a+b
d = a.reshape(10,1)+b.reshape(10,1)
out = c.sum() + d.sum()
with self.assertRaises(KernelCountException): check_schedule(out, 1)
# a+b is consumed through an expand and through a permute, then combined: expects 2 kernels.
def test_children_dont_push(self):
a = Tensor.empty(10, 10, 1)
b = Tensor.empty(10, 10, 1)
d = (a+b).expand(10, 10, 10)
e = (a+b).permute(2,1,0)
f = d+e
check_schedule(f, 2)
# failing in new lazy
# `keep_me` is given an extra consumer (`e`) that is otherwise unused. Both expectations are negative:
# `d` is asserted not to take 2 kernels, and `keep_me` is asserted not to take 0 after `d` is pre-scheduled.
def test_dont_fuse_binops_with_children(self):
a = Tensor.empty(10)
b = Tensor.empty(10)
c = Tensor.empty(10)
keep_me = a+b
e = keep_me.sum() # noqa: F841 give keep_me a child (NOTE: BinaryOps won't be a child since it will instant fuse)
d = keep_me+c
with self.assertRaises(KernelCountException): check_schedule(d, 2)
with self.assertRaises(KernelCountException): check_schedule(keep_me, 0, [d])
#@unittest.skip("failing in old lazy")
# Permuting the reduce+add result and then permuting it back is expected to leave one kernel.
def test_permute_breaks_fusion(self):
a = Tensor.empty(10, 10, 10)
b = Tensor.empty(10, 10)
c = (a.sum(axis=2) + b).permute(1,0)
d = c.permute(1,0)
check_schedule(d, 1)
# The same add written directly (`c`) and via transposes that are undone (`e`): each expects one kernel.
def test_some_permute_fusion(self):
a = Tensor.empty(8192, 16)
b = Tensor.empty(1, 16)
d = (a.T + b.expand(8192, 16).T)
c = a + b.expand(8192, 16)
e = d.T
check_schedule(c, 1)
check_schedule(e, 1)
# Indexing row 0 of a large product and multiplying it by `d` is expected to be one kernel.
def test_shrink_fuse(self):
a = Tensor.empty(8192, 16)
b = Tensor.empty(8192, 16)
c = a * b
d = Tensor.empty(1, 16)
e = c[0] * d
check_schedule(e, 1)
# A small (1,16) product multiplied into a large (8192,16) tensor: expects 2 kernels.
def test_expand_nofuse(self):
a = Tensor.empty(1, 16)
b = Tensor.empty(1, 16)
c = a * b
d = Tensor.empty(8192, 16)
e = c * d
check_schedule(e, 2)
# this is the failing case in openpilot...it's very simple like this
# Three chained 1x1 image_conv2d calls with a residual add of the first result.
# The schedule is asserted NOT to have exactly 5 kernels.
def test_image_conv_fusion(self):
w1 = Tensor.empty(16, 16, 1, 1)
b1 = Tensor.empty(16)
w2 = Tensor.empty(16, 16, 1, 1)
b2 = Tensor.empty(16)
w3 = Tensor.empty(16, 16, 1, 1)
b3 = Tensor.empty(16)
x = Tensor.empty(1, 16, 32, 32)
x = base = x.image_conv2d(w1, b1)
x = x.image_conv2d(w2, b2) + base
x = x.image_conv2d(w3, b3)
# NOOP, 3 convs, contiguous
with self.assertRaises(KernelCountException): check_schedule(x, 5)
# Reduced form of the image conv case: `p` is a permute/contiguous/reshape/expand/sum/permute pipeline
# applied three times with bias adds and a residual add of `base`. Expects 4 kernels.
def test_image_conv_fusion_minimal(self):
b1 = Tensor.empty(16)
b2 = Tensor.empty(16)
def p(x): return x.permute(1,0).contiguous().reshape(32,16,1).expand(32,16,16).sum(axis=2).permute(1,0)
x = Tensor.empty(16, 32)
x = base = p(x) + b1.reshape(16,1)
x = p(x)
x = x + b2.reshape(16,1)
x = x + base
# the `base` name is dropped before the last stage is built and scheduled
# (presumably so the extra reference cannot influence the schedule; TODO confirm)
del base
x = p(x)
check_schedule(x, 4)
# Further reduced form: two applications of `p` with one bias add between them. Expects 3 kernels.
def test_image_conv_fusion_more_minimal(self):
b1 = Tensor.empty(16)
def p(x): return x.permute(1,0).contiguous().reshape(32,16,1).expand(32,16,16).sum(axis=2).permute(1,0)
x = Tensor.empty(16, 32)
x = base = p(x) + b1.reshape(16,1)
x = p(x)
# the `base` name is dropped before scheduling
del base
check_schedule(x, 3)
def test_resnet_block(self):
  """A conv-bn-relu-conv-bn residual block in non-training mode is expected to schedule as 2 kernels.

  Only the two conv weights are pre-scheduled. Tensor.train(False) is used, as in the
  batchnorm tests in this file, so Tensor.training is restored even when check_schedule raises.
  """
  with Tensor.train(False):
    in_planes, planes = 64, 64
    conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
    bn1 = nn.BatchNorm2d(planes)
    conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False)
    bn2 = nn.BatchNorm2d(planes)
    x = Tensor.empty(1, 64, 32, 32)
    out = bn1(conv1(x)).relu()
    out = bn2(conv2(out))
    # residual connection back to the block input
    out = (out + x).relu()
    check_schedule(out, 2, [conv1.weight, conv2.weight])
# contiguous() on a freshly created empty tensor: expects 1 schedule item in total (non-SINK items counted).
def test_contiguous_while_contiguous(self):
x = Tensor.empty(1, 64, 32, 32)
out = x.contiguous()
check_schedule(out, 1, filter_sink=False)
# contiguous() after a permute: expects 2 schedule items in total (non-SINK items counted).
def test_contiguous_while_not_contiguous(self):
x = Tensor.empty(1, 64, 32, 32)
out = x.permute(0,2,3,1).contiguous()
check_schedule(out, 2, filter_sink=False)
# A contiguous() placed right after the reduce, and another on the final add: expects 2 kernels.
def test_fold_with_contiguous(self):
a = Tensor.randn(16, 16, 16).realize()
b = Tensor.randn(16, 16).realize()
c = (a.sum(2).contiguous() + b).contiguous()
check_schedule(c, 2)
# Moving a list-backed tensor with .to('python') is expected to produce no schedule items at all.
def test_double_from(self):
x = Tensor([1,2,3,4])
out = x.to('python')
check_schedule(out, 0, filter_sink=False)
# Raising a tensor to the constant tensor 2 is expected to be one kernel.
def test_pow_const_tensor_simplified(self):
x = Tensor([1,2,3,4])
# NOTE: this does not test ** Tensor(2) is simpler in ast than ** Tensor(2.5)
out = x ** Tensor(2)
check_schedule(out, 1)
# Raising a tensor to the constant tensor 0 is expected to need no kernels.
def test_pow_const_tensor_to_zero(self):
x = Tensor([1,2,3,4])
out = x ** Tensor(0)
# NOTE: this is ConstBuffer 0 + ConstBuffer 1
check_schedule(out, 0)
# An op on a tensor with a zero-length dimension is expected to produce no schedule items at all.
def test_zero_size(self):
x = Tensor.empty(2, 3, 0)
out = x + 1
check_schedule(out, 0, filter_sink=False)
# A reduce whose result is transposed before being added to `y`: expects 2 kernels.
def test_reduce_permute_nofuse(self):
x = Tensor.empty(32, 32, 32)
y = Tensor.empty(32, 32)
out = x.sum(axis=2).T+y
check_schedule(out, 2)
# Two independent sum -> relu -> elu chains added together: expects 2 kernels.
def test_two_elus_sum(self):
x = Tensor.empty(32, 32)
y = Tensor.empty(32, 32)
out = x.sum(1).relu().elu() + y.sum(1).relu().elu()
check_schedule(out, 2)
# multireduce spec
# Skipped unless SPLIT_REDUCEOP is truthy. Uses an input of REDUCEOP_SPLIT_THRESHOLD (default 32768) elements,
# expects 4 kernels for max(x - max(x)), then checks the value against numpy.
@unittest.skipUnless(SPLIT_REDUCEOP, "Testing split reducop requires SPLIT_REDUCEOP")
def test_preserve_multistage_reduce(self):
big_enough = getenv("REDUCEOP_SPLIT_THRESHOLD", 32768)
x = Tensor.randn(big_enough).realize()
out = (x - x.max(keepdim=True)).max()
run_schedule(check_schedule(out, 4))
np.testing.assert_allclose(out.numpy(), (x.numpy() - x.numpy().max(keepdims=True)).max())
# sum -> relu -> sum over a different axis: expects 2 kernels.
def test_multistage_reduce(self):
x = Tensor.empty(32, 32, 32)
out = x.sum(2).relu().sum(1)
check_schedule(out, 2)
# The first reduce feeds both a second reduce and an indexed elementwise branch: expects 2 kernels.
def test_multistage_reduce_fork(self):
x = Tensor.empty(32, 32, 32)
x = x.sum(2)
out2 = x + 1
out = x.relu().sum(1) + out2[0]
check_schedule(out, 2)
# multireduce spec
# Gradient of sum(y @ x) with respect to x for 64x64 identity matrices: expects 2 kernels and an all-ones result.
def test_example_matmul(self):
x = Tensor.eye(64, requires_grad=True)
y = Tensor.eye(64, requires_grad=True)
z = y.matmul(x).sum()
z.backward()
out = x.grad.contiguous()
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), np.ones((64,64)))
# A contiguous() between two adds: expects 2 kernels.
def test_contiguous_add(self):
x = Tensor.empty(32)
y = Tensor.empty(32)
z = Tensor.empty(32)
out = (x+y).contiguous()+z
check_schedule(out, 2)
# The reduce result is used both whole and through a column slice: expects 2 kernels.
def test_double_sum_ref(self):
x = Tensor.empty(32, 32, 32)
x = x.sum(2)
out = x + x[:, 4]
check_schedule(out, 2)
# A reduce, then a slice of the first 16 results, then an add: currently pinned at 2 kernels.
def test_reduce_shrink(self):
x = Tensor.empty(32, 32)
y = Tensor.empty(16)
x = x.sum(1)
x = x[:16]
out = x + y
check_schedule(out, 2) # TODO: this should be 1
# multireduce spec
# Two sliced reduces plus `c`: pinned at 3 kernels (the commented-out call records a target of 2),
# then the result is checked against numpy.
def test_multireduce_shrink(self):
Tensor.manual_seed(0)
a = Tensor.randn(32, 32).realize()
b = Tensor.randn(32, 32).realize()
c = Tensor.randn(16).realize()
a_out = a.sum(1)
a_out = a_out[:16]
b_out = b.sum(1)
b_out = b_out[:16]
out = a_out + b_out + c
# run_schedule(check_schedule(out, 2)) # TODO: this should be 1 (can we make it 1 with the new linearizer?)
run_schedule(check_schedule(out, 3))
np.testing.assert_allclose(out.numpy(), a.numpy().sum(axis=1)[:16] + b.numpy().sum(axis=1)[:16] + c.numpy(), atol=1e-4, rtol=1e-4)
# broken due to const folding and two contiguous are different kernels
# NOTE: passes after delete_lazy
# Two identical constant sums, each made contiguous, then added: expects 2 schedule items in total (non-SINK items counted).
def test_const_no_recompute(self):
x = Tensor(2) + Tensor(2)
y = Tensor(2) + Tensor(2)
out = x.contiguous() + y.contiguous()
check_schedule(out, 2, filter_sink=False)
# multireduce spec
# Three outputs derived from the same a.sum(): expects 1 kernel, then all three are checked against numpy.
def test_reduce_same_size(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4).realize()
out0 = a.sum() + 2
out1 = a.sum() + 4
out2 = out0 * out1
run_schedule(check_schedule([out0, out1, out2], 1))
np.testing.assert_allclose(out0.numpy(), out0_np:=a.numpy().sum()+2, atol=1e-4, rtol=1e-6)
np.testing.assert_allclose(out1.numpy(), out1_np:=a.numpy().sum()+4, atol=1e-4, rtol=1e-6)
np.testing.assert_allclose(out2.numpy(), out0_np*out1_np, atol=1e-4, rtol=1e-6)
# multireduce spec
# out1 reaches a.sum() both directly and through out0: expects 1 kernel, then both outputs are checked against numpy.
def test_reduce_multiple_paths(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4).realize()
out0 = a.sum().exp2()
# out1 has two paths to a.sum()
out1 = a.sum() + out0
run_schedule(check_schedule([out0, out1], 1))
np.testing.assert_allclose(out0.numpy(), out0_np:=np.exp2(a.numpy().sum()), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+out0_np, atol=1e-4, rtol=1e-6)
# multireduce spec
# Two stacked multi-path reduce patterns: pinned at 2 kernels (the commented-out call records a target of 1),
# then all four outputs are checked against numpy.
def test_multireduce_reduce_multiple_paths(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4).realize()
out0 = a.sum().exp2()
out1 = a.sum() + out0
b = (a + out0 + out1)
out2 = b.sum().exp2()
out3 = b.sum() + out2
# run_schedule(check_schedule([out0, out1, out2, out3], 1))
run_schedule(check_schedule([out0, out1, out2, out3], 2))
np.testing.assert_allclose(out0.numpy(), np_out0:=np.exp2(a.numpy().sum()), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out1.numpy(), np_out1:=a.numpy().sum()+np_out0, atol=1e-4, rtol=1e-4)
np_b = (a.numpy() + np_out0 + np_out1)
np.testing.assert_allclose(out2.numpy(), np_out2:=np.exp2(np_b.sum()), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out3.numpy(), np_b.sum()+np_out2, atol=1e-4, rtol=1e-4)
# multireduce spec
# Two outputs that each add a.sum() and b.sum(): pinned at 4 kernels (commented-out target: 1), then checked against numpy.
def test_reduce_ext_reduce_child(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4).realize()
b = Tensor.randn(4, 4).realize()
# b.sum() is not a descendant of the fused nodes
out0 = a.sum() + b.sum() + 2
out1 = a.sum() + b.sum() + 4
# run_schedule(check_schedule([out0, out1], 1))
run_schedule(check_schedule([out0, out1], 4))
np.testing.assert_allclose(out0.numpy(), a.numpy().sum()+b.numpy().sum()+2, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+b.numpy().sum()+4, atol=1e-4, rtol=1e-4)
# multireduce spec
# A second reduce (max) sits on the indirect path from `r` to `out2`: pinned at 4 kernels
# (commented-out target: 1), then all four tensors are checked against numpy.
def test_reduce_multiple_paths_midreduce(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4).realize()
r = a.sum()
out0 = r.exp2()
# reduce node in the indirect path from r to out2
out1 = (a - out0).max()
out2 = r + out1
# run_schedule(check_schedule([r, out0, out1, out2], 1))
run_schedule(check_schedule([r, out0, out1, out2], 4))
np.testing.assert_allclose(r.numpy(), r_np:=a.numpy().sum(), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out0.numpy(), out0_np:=np.exp2(r_np), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out1.numpy(), out1_np:=(a.numpy() - out0_np).max(), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out2.numpy(), r_np + out1_np, atol=1e-4, rtol=1e-4)
# multireduce spec
# Like the midreduce case but the middle reduce is over a separate tensor `b`: pinned at 4 kernels
# (commented-out target: 1), then checked against numpy.
def test_reduce_multiple_paths_midreduce_fused(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4).realize()
b = Tensor.randn(4, 4).realize()
out0 = a.sum() + 4
out1 = b.max() + out0*2
out2 = a.sum() + out1
# run_schedule(check_schedule([out0, out1, out2], 1))
run_schedule(check_schedule([out0, out1, out2], 4))
np.testing.assert_allclose(out0.numpy(), out0_np:=a.numpy().sum()+4, atol=1e-4, rtol=1e-6)
np.testing.assert_allclose(out1.numpy(), out1_np:=b.numpy().max() + out0_np*2, atol=1e-4, rtol=1e-6)
np.testing.assert_allclose(out2.numpy(), a.numpy().sum() + out1_np, atol=1e-4, rtol=1e-6)
# multireduce spec
# The scalar out0 is broadcast into the larger `e`, which feeds out1 through an index: pinned at 4 kernels
# (commented-out target: 3), then checked against numpy.
def test_reduce_multiple_paths_midexpand(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4).realize()
b = Tensor.randn(4, 4, 4).realize()
r = a.sum()
out0 = r.exp2()
# e1 is in the indirect path from a.sum() to out1
e = b + out0
out1 = r + e[0][0][0]
# run_schedule(check_schedule([r, out0, out1, e], 3)) # 1 or 2 or 3? should be 1 (one reduce) but the different outputs might make it 3
run_schedule(check_schedule([r, out0, out1, e], 4))
np.testing.assert_allclose(r.numpy(), r_np:=a.numpy().sum(), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out0.numpy(), out0_np:=np.exp2(r_np), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(e.numpy(), e_np:=b.numpy() + out0_np, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out1.numpy(), r_np + e_np[0][0][0], atol=1e-4, rtol=1e-4)
# changed by multireduce
# The same full reduce is used with a scalar constant and with a (1,16) tensor: pinned at 4 kernels
# (commented-out target: 2), then checked against numpy.
def test_reduce_expand_child(self):
Tensor.manual_seed(0)
a = Tensor.randn((32, 32, 32)).realize()
b = Tensor.randn((1, 16)).realize()
out0 = a.sum() + 2
out1 = a.sum() + b
# run_schedule(check_schedule([out0, out1], 2))
run_schedule(check_schedule([out0, out1], 4))
np.testing.assert_allclose(out0.numpy(), a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+b.numpy(), atol=1e-4, rtol=1e-4)
# Two outputs sharing a.sum(), one adding an indexed element of `b` and one a constant: expects 1 kernel.
def test_reduce_shrink_child(self):
a = Tensor.empty(100, 100)
b = Tensor.empty(10,)
c = a.sum() + b[0]
d = a.sum() + 2
check_schedule([c, d], 1)
# out1 uses out0 both whole and through an index; scheduling r, out0 and out1 together expects 3 kernels.
def test_reduce_multiple_paths_midshrink(self):
a = Tensor.empty(4, 4)
r = a.sum(axis=1)
out0 = r.exp2()
out1 = out0[0] + out0
check_schedule([r, out0, out1], 3)
# out1 indexes out0 and adds a fresh empty tensor; scheduling r, out0 and out1 together expects 3 kernels.
def test_reduce_shrink_output(self):
a = Tensor.empty(4, 4)
r = a.sum(keepdim=True)
out0 = r.exp2()
out1 = out0[0] + Tensor.empty(1, )
check_schedule([r, out0, out1], 3)
# multireduce spec
# std over the last axis: expects 2 kernels, then the result is compared with numpy's std using ddof=1.
def test_std_multireduce_fusion(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
out = x.std(-1)
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), x.numpy().std(axis=-1, ddof=1), atol=1e-4, rtol=1e-4)
# multireduce spec
# argmin over the last axis: expects 3 kernels and exact agreement with numpy.
def test_argmin_multireduce_fusion(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
out = x.argmin(-1)
run_schedule(check_schedule(out, 3))
np.testing.assert_equal(out.numpy(), x.numpy().argmin(axis=-1))
# multireduce spec
# argmax over the last axis: expects 3 kernels and exact agreement with numpy.
def test_argmax_multireduce_fusion(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
out = x.argmax(-1)
run_schedule(check_schedule(out, 3))
np.testing.assert_equal(out.numpy(), x.numpy().argmax(axis=-1))
# scaled_dot_product_attention on realized random q/k/v: expects 5 kernels.
# When getenv("CHECK", 1) is truthy, the result is compared with torch's implementation.
def test_scaled_dot_product_attention_multireduce_fusion(self):
Tensor.manual_seed(0)
q = Tensor.randn(32,8,16,64).realize()
k = Tensor.randn(32,8,16,64).realize()
v = Tensor.randn(32,8,16,64).realize()
out = Tensor.scaled_dot_product_attention(q,k,v)
run_schedule(check_schedule(out, 5))
if getenv("CHECK", 1):
import torch
compare = torch.nn.functional.scaled_dot_product_attention(torch.tensor(q.numpy()),torch.tensor(k.numpy()),torch.tensor(v.numpy()))
np.testing.assert_allclose(out.numpy(), compare.numpy(), atol=1e-6, rtol=1e-3)
# multireduce spec
# a.sum(-1) is consumed by two separate multiply-then-reduce branches: pinned at 3 kernels
# (commented-out target: 1), then checked against numpy.
def test_ugly_reduceop_pairing(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 32).realize()
b = Tensor.randn(4, 32).realize()
c = Tensor.randn(4, 32).realize()
out = (c * a.sum(-1, keepdim=True)).sum(-1) + (b * a.sum(-1, keepdim=True)).sum(-1) # a.sum has >1 children but should still fuse
# run_schedule(check_schedule(out, 1))
run_schedule(check_schedule(out, 3))
np.testing.assert_allclose(out.numpy(), \
(c.numpy()*a.numpy().sum(axis=-1,keepdims=True)).sum(-1) + (b.numpy()*a.numpy().sum(axis=-1,keepdims=True)).sum(-1), atol=1e-4, rtol=1e-4)
# multireduce spec
# reduce -> broadcast add -> reduce: pinned at 2 kernels (commented-out target: 1), then checked against numpy.
def test_reduce_expand_reduce_fusion(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 32).realize()
out = (a+a.sum(-1, keepdim=True)).sum(-1)
# run_schedule(check_schedule(out, 1))
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), (a.numpy()+a.numpy().sum(axis=-1,keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)
# multireduce spec
# reduce -> broadcast add -> reduce -> broadcast add: pinned at 3 kernels (commented-out target: 2), then checked against numpy.
def test_reduce_expand_reduce_expand_fusion(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 32).realize()
out = a+(a+a.sum(-1,keepdim=True)).sum(-1, keepdim=True)
# run_schedule(check_schedule(out, 2))
run_schedule(check_schedule(out, 3))
np.testing.assert_allclose(out.numpy(), \
a.numpy()+(a.numpy()+a.numpy().sum(axis=-1,keepdims=True)).sum(axis=-1,keepdims=True), atol=1e-4, rtol=1e-4)
# multireduce spec
# Both the broadcast-add result (out0) and its reduce (out1) are outputs: pinned at 3 kernels, then checked against numpy.
# NOTE(review): the commented-out call below refers to `out`, which is not defined in this test.
def test_branching_reduces_and_expands_fusion(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 32).realize()
out0 = a+a.sum(-1, keepdim=True)
out1 = out0.sum(-1)
# run_schedule(check_schedule(out, 2))
run_schedule(check_schedule([out0, out1], 3))
np.testing.assert_allclose(out0.numpy(), a.numpy()+a.numpy().sum(axis=-1,keepdims=True), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out1.numpy(), (a.numpy()+a.numpy().sum(axis=-1,keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)
# multireduce spec
# A reduce of `x` feeds a second reduce over `y` plus that result: pinned at 2 kernels
# (commented-out target: 1), then checked against numpy.
def test_multireduce_fusion_simple_sequential(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
y = Tensor.randn(4, 32).realize()
out = (y + x.sum(axis=-1, keepdim=True)).sum(axis=-1)
# run_schedule(check_schedule(out, 1))
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), (y.numpy() + x.numpy().sum(axis=-1, keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)
# multireduce spec
# Two independent reduces added together: pinned at 2 kernels (commented-out target: 1), then checked against numpy.
def test_multireduce_fusion_simple_parallel(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
y = Tensor.randn(4, 32).realize()
out = y.sum(axis=-1) + x.sum(axis=-1)
# run_schedule(check_schedule(out, 1))
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), y.numpy().sum(axis=-1) + x.numpy().sum(axis=-1), atol=1e-4, rtol=1e-4)
# multireduce spec
# std over the last axis: pinned at 2 kernels (commented-out target: 1), then compared with numpy's std using ddof=1.
def test_multireduce_fusion_sequential(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
out = x.std(-1)
# run_schedule(check_schedule(out, 1))
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), x.numpy().std(axis=-1, ddof=1), atol=1e-4, rtol=1e-4)
# multireduce spec
# The sum of two independent std computations: pinned at 4 kernels (commented-out target: 1), then checked against numpy.
def test_multireduce_fusion_parallel(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
y = Tensor.randn(4, 32).realize()
out = x.std(-1) + y.std(-1)
# run_schedule(check_schedule(out, 1))
run_schedule(check_schedule(out, 4))
np.testing.assert_allclose(out.numpy(), x.numpy().std(axis=-1, ddof=1) + y.numpy().std(axis=-1, ddof=1), atol=1e-4, rtol=1e-4)
# multireduce spec
# A max reduce feeding a sum reduce: pinned at 2 kernels (commented-out target: 1), then checked against numpy.
def test_multireduce_diffops_sequential(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
out = (x - x.max(-1, keepdim=True)).sum(-1)
# run_schedule(check_schedule(out, 1))
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), (x.numpy() - x.numpy().max(axis=-1, keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)
# multireduce spec
# An independent sum reduce and max reduce added together: pinned at 2 kernels (commented-out target: 1), then checked against numpy.
def test_multireduce_fusion_diffops_parallel(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
y = Tensor.randn(4, 32).realize()
out = x.sum(-1) + y.max(-1)
# run_schedule(check_schedule(out, 1))
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), x.numpy().sum(axis=-1) + y.numpy().max(axis=-1), atol=1e-4, rtol=1e-4)
# multireduce spec
# `mu` combines max-shifted means of x and y; each output is a root-mean-square deviation from `mu`.
# Pinned at 6 kernels (commented-out target: 1); both outputs are checked against the numpy equivalent.
def test_multireduce_fusion_sequential_and_parallel(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
y = Tensor.randn(4, 32).realize()
mu = (x - x.max(axis=-1, keepdim=True)).mean(axis=-1, keepdim=True) + (y - y.max(axis=-1, keepdim=True)).mean(axis=-1, keepdim=True)
out = [((x - mu).square().sum(-1)/x.shape[-1]).sqrt(), ((y - mu).square().sum(-1)/y.shape[-1]).sqrt()]
np_mu = (x.numpy() - x.numpy().max(axis=-1, keepdims=True)).mean(axis=-1, keepdims=True) + \
(y.numpy() - y.numpy().max(axis=-1, keepdims=True)).mean(axis=-1, keepdims=True)
# run_schedule(check_schedule(out, 1))
run_schedule(check_schedule(out, 6))
np.testing.assert_allclose(out[0].numpy(), np.sqrt(np.square(x.numpy() - np_mu).sum(-1)/x.shape[-1]), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out[1].numpy(), np.sqrt(np.square(y.numpy() - np_mu).sum(-1)/y.shape[-1]), atol=1e-4, rtol=1e-4)
# multireduce spec
# The sum of two independent matmuls: pinned at 2 kernels (commented-out target: 1), then checked against numpy.
def test_multimatmul_fusion(self):
Tensor.manual_seed(0)
a,b = Tensor.randn(4, 64).realize(), Tensor.rand(64,8).realize()
c,d = Tensor.randn(4, 64).realize(), Tensor.rand(64,8).realize()
out = a@b + c@d
# run_schedule(check_schedule(out, 1))
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), a.numpy()@b.numpy() + c.numpy()@d.numpy(), atol=1e-4, rtol=1e-4)
# softmax with default arguments: expects 3 kernels; the reference is a max-shifted exp normalized over the last axis in numpy.
def test_softmax_fusion(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 12, 64, 64).realize()
out = x.softmax()
run_schedule(check_schedule(out, 3))
expected = (x_exp:=np.exp(x.numpy()-x.numpy().max(-1, keepdims=True)))/x_exp.sum(-1, keepdims=True)
np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4)
# Gradient of softmax(x).sum() with respect to x: expects 4 kernels. The schedule is run but no values are checked.
def test_softmax_backward(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 12, 64, 64, requires_grad=True).realize()
x.softmax().sum().backward()
run_schedule(check_schedule(x.grad, 4))
# changed by: multireduce spec
# LayerNorm([10, 10]) with realized random weight and bias: pinned at 3 kernels (commented-out target: 2).
# The reference is rebuilt in numpy from layer.axis and layer.eps.
def test_layernorm_onelayer_fusion(self):
Tensor.manual_seed(0)
layer = nn.LayerNorm([10, 10])
layer.weight = Tensor.randn(10,10).realize()
layer.bias = Tensor.randn(10,10).realize()
x = Tensor.randn(20, 5, 10, 10).realize()
out = layer(x)
# run_schedule(check_schedule(out, 2))
run_schedule(check_schedule(out, 3))
y = (x.numpy() - x.numpy().mean(layer.axis, keepdims=True))
expected = y / np.sqrt((y*y).mean(layer.axis, keepdims=True) + layer.eps)
np.testing.assert_allclose(out.numpy(), expected * layer.weight.numpy() + layer.bias.numpy(), atol=1e-4, rtol=1e-4)
# scaled_dot_product_attention with an explicit attn_mask on empty tensors: expects 5 kernels.
def test_scaled_dot_product_attention_fusion(self):
x, y, z, m = (Tensor.empty(32, 8, 16, 16) for _ in range(4))
out = Tensor.scaled_dot_product_attention(x, y, z, attn_mask=m)
check_schedule(out, 5)
def test_scaled_dot_product_attention_causal_fusion(self):
  # Same as the masked attention case but with is_causal=True: 6 kernels expected (the schedule is only counted, not run).
  query, key, value, mask = [Tensor.empty(32, 8, 16, 16) for _ in range(4)]
  attn = Tensor.scaled_dot_product_attention(query, key, value, attn_mask=mask, is_causal=True)
  check_schedule(attn, 6)
def test_adam_step_fusion(self):
  # One Adam step for a Linear(768, 768*4) layer after a relu/sum backward pass: 10 kernels expected.
  with Tensor.train():
    inp = Tensor.empty(4, 64, 768)
    linear = nn.Linear(768, 768*4)
    _realize_weights(linear)
    optimizer = nn.optim.Adam(nn.state.get_parameters(linear), lr=1e-4)
    loss = linear(inp).relu().sum()
    loss.backward()
    check_schedule(optimizer.schedule_step(), 10)
def test_adam_conv_fuse(self):
  # One Adam step for a single Conv2d(3, 32, 3) after a relu/sum backward pass: 10 kernels expected.
  with Tensor.train():
    image = Tensor.empty(2, 3, 4, 4)
    conv = nn.Conv2d(3, 32, 3)
    _realize_weights(conv)
    optimizer = nn.optim.Adam(nn.state.get_parameters(conv), lr=1e-4)
    optimizer.zero_grad()
    loss = conv(image).relu().sum()
    loss.backward()
    check_schedule(optimizer.schedule_step(), 10)
def test_adam_2convs_fuse(self):
  # One Adam step for two stacked bias-free convs (relu after each): 13 kernels expected.
  with Tensor.train():
    image = Tensor.empty(2, 3, 4, 4)
    conv1 = nn.Conv2d(3, 16, 3, bias=False)
    conv2 = nn.Conv2d(16, 32, 2, bias=False)
    _realize_weights([conv1, conv2])
    optimizer = nn.optim.Adam(nn.state.get_parameters([conv1, conv2]), lr=1e-4)
    optimizer.zero_grad()
    hidden = conv1(image).relu()
    loss = conv2(hidden).relu().sum()
    loss.backward()
    check_schedule(optimizer.schedule_step(), 13)
def test_sgd_conv_fuse(self):
  # One SGD step for a single Conv2d(3, 32, 3) after a relu/sum backward pass: 5 kernels expected.
  with Tensor.train():
    image = Tensor.empty(2, 3, 4, 4)
    conv = nn.Conv2d(3, 32, 3)
    _realize_weights(conv)
    optimizer = nn.optim.SGD(nn.state.get_parameters(conv))
    optimizer.zero_grad()
    loss = conv(image).relu().sum()
    loss.backward()
    check_schedule(optimizer.schedule_step(), 5)
def test_sgd_2convs_fuse(self):
  # One SGD step for two stacked bias-free convs (relu after each): 8 kernels expected.
  with Tensor.train():
    image = Tensor.empty(2, 3, 4, 4)
    conv1 = nn.Conv2d(3, 16, 3, bias=False)
    conv2 = nn.Conv2d(16, 32, 2, bias=False)
    _realize_weights([conv1, conv2])
    optimizer = nn.optim.SGD(nn.state.get_parameters([conv1, conv2]))
    optimizer.zero_grad()
    hidden = conv1(image).relu()
    loss = conv2(hidden).relu().sum()
    loss.backward()
    check_schedule(optimizer.schedule_step(), 8)
def test_fold_2convs_sgd_nesterov_momentum_wd(self):
  # Two stacked bias-free convs trained with SGD(nesterov, momentum=0.9, weight_decay=0.1): 10 kernels expected.
  with Tensor.train():
    image = Tensor.empty(2, 3, 4, 4)
    conv1 = nn.Conv2d(3, 16, 3, bias=False)
    conv2 = nn.Conv2d(16, 32, 2, bias=False)
    _realize_weights([conv1, conv2])
    params = nn.state.get_parameters([conv1, conv2])
    optimizer = nn.optim.SGD(params, nesterov=True, momentum=0.9, weight_decay=0.1)
    optimizer.zero_grad()
    hidden = conv1(image).relu()
    loss = conv2(hidden).relu().sum()
    loss.backward()
    check_schedule(optimizer.schedule_step(), 10)
def test_sgd_4convs_fuse(self):
  # One SGD step for four stacked bias-free 3x3 convs (relu after each): 18 kernels expected.
  with Tensor.train():
    image = Tensor.empty(2, 3, 64, 64)
    conv1 = nn.Conv2d(3, 4, 3, bias=False)
    conv2 = nn.Conv2d(4, 8, 3, bias=False)
    conv3 = nn.Conv2d(8, 16, 3, bias=False)
    conv4 = nn.Conv2d(16, 32, 3, bias=False)
    convs = [conv1, conv2, conv3, conv4]
    _realize_weights(convs)
    optimizer = nn.optim.SGD(nn.state.get_parameters(convs))
    optimizer.zero_grad()
    hidden = image
    for conv in convs: hidden = conv(hidden).relu()
    hidden.sum().backward()
    check_schedule(optimizer.schedule_step(), 18)
def test_sgd_4convs_fuse_conv_bw(self):
  # Same four-conv SGD setup as test_sgd_4convs_fuse, but the step is scheduled under FUSE_CONV_BW=1: 15 kernels expected.
  with Tensor.train():
    image = Tensor.empty(2, 3, 64, 64)
    conv1 = nn.Conv2d(3, 4, 3, bias=False)
    conv2 = nn.Conv2d(4, 8, 3, bias=False)
    conv3 = nn.Conv2d(8, 16, 3, bias=False)
    conv4 = nn.Conv2d(16, 32, 3, bias=False)
    convs = [conv1, conv2, conv3, conv4]
    _realize_weights(convs)
    optimizer = nn.optim.SGD(nn.state.get_parameters(convs))
    optimizer.zero_grad()
    hidden = image
    for conv in convs: hidden = conv(hidden).relu()
    hidden.sum().backward()
    # both the step construction and the schedule check stay inside the FUSE_CONV_BW context
    with Context(FUSE_CONV_BW=1):
      check_schedule(optimizer.schedule_step(), 15)
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_prefer_half_buffer(self):
  # Kernel counts (and output dtypes of the non-final kernels) when a reduce result round-trips through half.
  vec = Tensor.ones(4).contiguous().realize()
  # y = Tensor.ones(4).contiguous().realize()
  mat = Tensor.ones(4, 4).contiguous().realize()

  # should not create extra kernel if output will be realized anyways
  dummy = vec.sum().half().float()
  check_schedule(dummy, 1)
  dummy = vec.sum().half().float().contiguous() + 1
  check_schedule(dummy, 2)

  # shared between two outputs
  shared = vec.sum().half().float()
  doubled = shared * 2
  tripled = shared * 3
  sched = check_schedule([doubled, tripled], 1)
  # NOTE(review): with a 1-kernel schedule `sched[:-2]` looks empty, so this loop would check nothing - confirm intent
  for si in sched[:-2]:
    for out in si.outputs: assert out.dtype == dtypes.half

  # reduce
  reduced = mat.sum(axis=0).half().float().sum(axis=0)
  sched = check_schedule(reduced, 2)
  for si in sched[:-1]:
    for out in si.outputs: assert out.dtype == dtypes.half

  # expand
  # expand will realize just after the .float(), so requires change to realize-before-expand
  # normal = (x.sum().half().float().reshape(1) * y).sum()
  # sched = check_schedule(normal, 2)
  # for si in sched[:-1]: assert all(out.dtype == dtypes.half for out in si.outputs[:-1])

  # parallel reduce
  # a = x.sum().half().float() * y.sum().half().float()
  # b = a + 1
  # c = a + 2
  # sched = check_schedule([b, c], 4)
  # doesn't store either in half because it doesn't chase
def test_reduce_simple_chase(self):
  # Two different reduces over one shared `sum(0) + 6` intermediate: 3 kernels, and the first kernel stores an ADD.
  src = Tensor.empty(4, 4, 4)
  shared = src.sum(0) + 6
  over_rows = shared.sum(0) * 4
  over_cols = shared.sum(1) * 2
  schedule = check_schedule([over_rows, over_cols], 3)
  stored_value = schedule[0].ast.src[0].src[2]
  self.assertIs(stored_value.op, Ops.ADD)
# multireduce spec
def test_multireduce_simple_chase(self):
  # A reduce that feeds a second reduce, then two further reduces of that result: 4 kernels, values checked with numpy.
  Tensor.manual_seed(0)
  src = Tensor.randn(4, 4, 4).realize()
  shared = (src + (src.sum(0, keepdim=True) + 6)).sum(0) * 2
  over_rows = shared.sum(0) + 8
  over_cols = shared.sum(1) + 12
  shared_np = (src.numpy() + (src.numpy().sum(0) + 6)).sum(0) * 2
  # schedule = check_schedule([b,c], 3)
  # self.assertIs(schedule[0].ast[0].src[0].arg, Ops.MUL)
  schedule = check_schedule([over_rows, over_cols], 4)
  run_schedule(schedule)
  np.testing.assert_allclose(over_rows.numpy(), shared_np.sum(0) + 8, atol=1e-4, rtol=1e-4)
  np.testing.assert_allclose(over_cols.numpy(), shared_np.sum(1) + 12, atol=1e-4, rtol=1e-4)
def test_push_permute_chase(self):
  # A reduce+add result used both transposed and untransposed: 3 kernels, and the first kernel stores an ADD.
  cube = Tensor.empty(4, 4, 4)
  bias = Tensor.empty(4, 4)
  reduced = cube.sum(2) + bias
  transposed = reduced.T * 4
  product = reduced * transposed
  schedule = check_schedule([transposed, product], 3)
  stored_value = schedule[0].ast.src[0].src[2]
  self.assertIs(stored_value.op, Ops.ADD)
# multireduce spec
def test_multireduce_push_permute_chase(self):
  # Like test_push_permute_chase, with a second reduce consuming the transposed result: 3 kernels, values checked with numpy.
  Tensor.manual_seed(0)
  cube = Tensor.randn(4, 4, 4).realize()
  bias = Tensor.randn(4, 4).realize()
  reduced = cube.sum(2) + bias
  transposed = reduced.T * 4
  product = reduced * (transposed + cube).sum(2)
  schedule = check_schedule([transposed, product], 3) # make sure it doesn't fuse
  stored_value = schedule[0].ast.src[0].src[2]
  self.assertIs(stored_value.op, Ops.ADD)
  run_schedule(schedule)
  reduced_np = cube.numpy().sum(2) + bias.numpy()
  np.testing.assert_allclose(transposed.numpy(), reduced_np.T * 4, atol=1e-4, rtol=1e-4)
  product_np = (cube.numpy().sum(2) + bias.numpy()) * (transposed.numpy() + cube.numpy()).sum(2)
  np.testing.assert_allclose(product.numpy(), product_np, atol=1e-4, rtol=1e-4)
def test_push_shrink_chase(self):
  # A reduce+add result that is sliced to its first 4 entries before a multiply: 2 kernels, first kernel stores an ADD.
  mat = Tensor.empty(16, 16)
  scale = Tensor.empty(4)
  bias = Tensor.empty(16, )
  reduced = mat.sum(1) + bias
  out = reduced[:4] * scale
  schedule = check_schedule(out, 2)
  stored_value = schedule[0].ast.src[0].src[2]
  self.assertIs(stored_value.op, Ops.ADD)
# multireduce spec
def test_multireduce_push_shrink_chase(self):
  # Sliced reduce+add result combined with a second sliced reduce: 3 kernels, first kernel stores an ADD, values checked.
  Tensor.manual_seed(0)
  mat = Tensor.randn(16, 16).realize()
  scale = Tensor.randn(4).realize()
  bias = Tensor.randn(16, ).realize()
  other = Tensor.randn(16, 16).realize()
  reduced = mat.sum(1) + bias
  out = reduced[:4] * scale + other.sum(1)[:4]
  # schedule = check_schedule(out, 2)
  schedule = check_schedule(out, 3)
  stored_value = schedule[0].ast.src[0].src[2]
  self.assertIs(stored_value.op, Ops.ADD)
  run_schedule(schedule)
  expected = (mat.numpy().sum(1) + bias.numpy())[:4] * scale.numpy() + other.numpy().sum(1)[:4]
  np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4)
def test_midreduce_nochase(self):
  # sum(0) + max(1) of the same input: 2 kernels, and the first kernel stores a REDUCE_AXIS directly.
  mat = Tensor.empty(16, 16)
  out = (mat.sum(0) + mat.max(1)) + 2
  schedule = check_schedule(out, 2)
  stored_value = schedule[0].ast.src[0].src[2]
  self.assertIs(stored_value.op, Ops.REDUCE_AXIS)
# multireduce spec
def test_multireduce_midreduce_nochase(self):
  # Four reduces of one input (sum/max over both axes) added together: 4 kernels, first stores a REDUCE_AXIS, values checked.
  Tensor.manual_seed(0)
  mat = Tensor.randn(16, 16).realize()
  out = (mat.sum(0)+mat.max(0) + mat.max(1)+mat.sum(1)) + 2
  # schedule = check_schedule(b, 2)
  schedule = check_schedule(out, 4)
  stored_value = schedule[0].ast.src[0].src[2]
  self.assertIs(stored_value.op, Ops.REDUCE_AXIS)
  run_schedule(schedule)
  expected = mat.numpy().sum(0)+mat.numpy().max(0) + mat.numpy().max(1)+mat.numpy().sum(1)+2
  np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4)
# changed by: multireduce spec
# pattern in test_transformer
def test_partial_fuse1(self):
  # `a.sum() + 2` and `(a.sum() - b.sum()) * 4` scheduled together: 3 kernels, values checked with numpy.
  Tensor.manual_seed(0)
  a = Tensor.randn(16, 16).realize()
  b = Tensor.randn(16, 16).realize()
  shifted = a.sum() + 2
  scaled_diff = (a.sum() - b.sum()) * 4
  # run_schedule(check_schedule([c, d], 1))
  sched = check_schedule([shifted, scaled_diff], 3)
  run_schedule(sched)
  np.testing.assert_allclose(shifted.numpy(), a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
  np.testing.assert_allclose(scaled_diff.numpy(), (a.numpy().sum() - b.numpy().sum()) * 4, atol=1e-4, rtol=1e-4)
# changed by: multireduce spec
# pattern in conv
def test_partial_fuse2(self):
  # `a.sum() + 2` and `b.sum() - (a.sum() + 2)` scheduled together: 2 kernels, values checked with numpy.
  Tensor.manual_seed(0)
  a = Tensor.randn(16, 16).realize()
  b = Tensor.randn(16, 16).realize()
  shifted = a.sum() + 2
  diff = b.sum() - shifted
  # run_schedule(check_schedule([c, d], 1))
  sched = check_schedule([shifted, diff], 2)
  run_schedule(sched)
  np.testing.assert_allclose(shifted.numpy(), a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
  np.testing.assert_allclose(diff.numpy(), b.numpy().sum()-(a.numpy().sum()+2), atol=1e-4, rtol=1e-4)
# changed by: multireduce spec
# pattern in adam
def test_partial_fuse3(self):
  # Four outputs built from a.sum() and b.sum() (c, d, e=c*d, f=b.sum()-e): 2 kernels, every output checked with numpy.
  Tensor.manual_seed(0)
  a = Tensor.randn(16, 16).realize()
  b = Tensor.randn(16, 16).realize()
  c = a.sum() + 2
  d = a.sum() * 2
  e = c * d
  f = b.sum() - e
  # run_schedule(check_schedule([c, d, e, f], 1))
  sched = check_schedule([c, d, e, f], 2)
  run_schedule(sched)
  tol = dict(atol=1e-4, rtol=1e-4)
  c_ref = a.numpy().sum()+2
  np.testing.assert_allclose(c.numpy(), c_ref, **tol)
  d_ref = a.numpy().sum()*2
  np.testing.assert_allclose(d.numpy(), d_ref, **tol)
  e_ref = c_ref*d_ref
  np.testing.assert_allclose(e.numpy(), e_ref, **tol)
  np.testing.assert_allclose(f.numpy(), b.numpy().sum() - e_ref, **tol)
# changed by: multireduce spec
def test_partial_fuse4(self):
  # Like test_partial_fuse3, but f reduces `b - d` instead of b alone: 3 kernels, every output checked with numpy.
  Tensor.manual_seed(0)
  a = Tensor.randn(16, 16).realize()
  b = Tensor.randn(16, 16).realize()
  c = a.sum() + 2
  d = a.sum() * 2
  e = c * d
  f = (b - d).sum() - e
  # run_schedule(check_schedule([c, d, e, f], 1))
  sched = check_schedule([c, d, e, f], 3)
  run_schedule(sched)
  tol = dict(atol=1e-4, rtol=1e-4)
  c_ref = a.numpy().sum()+2
  np.testing.assert_allclose(c.numpy(), c_ref, **tol)
  d_ref = a.numpy().sum()*2
  np.testing.assert_allclose(d.numpy(), d_ref, **tol)
  e_ref = c_ref*d_ref
  np.testing.assert_allclose(e.numpy(), e_ref, **tol)
  np.testing.assert_allclose(f.numpy(), (b.numpy()-d_ref).sum()-e_ref, **tol)
def test_pad_reduce_safe(self):
  # An elementwise add padded with value 1.0 and then summed: expected to schedule as a single kernel.
  Tensor.manual_seed(0)
  a = Tensor.rand(3, 4, 5).realize()
  b = Tensor.rand(3, 4, 5).realize()
  pad_spec = ((0, 1), (0, 1), (0, 1))
  out = (a + b).pad(pad_spec, value=1.0).sum().contiguous()
  run_schedule(check_schedule(out, 1))
  expected = np.pad(a.numpy()+b.numpy(), pad_spec, constant_values=1.0).sum()
  np.testing.assert_allclose(out.numpy(), expected, atol=1e-5, rtol=1e-6)
# multireduce spec
def test_multireduce_pad_reduce_safe(self):
  # Two padded-then-summed inputs added together: 2 kernels expected, value checked with numpy.
  Tensor.manual_seed(0)
  a = Tensor.randn(3, 4, 5).realize()
  b = Tensor.randn(3, 4, 5).realize()
  pad_spec = ((0, 1), (0, 1), (0, 1))
  out = (a.pad(pad_spec, value=1.0).sum(keepdim=True)+b.pad(pad_spec, value=1.0).sum()).contiguous()
  # run_schedule(check_schedule(out, 1))
  run_schedule(check_schedule(out, 2))
  a_ref = np.pad(a.numpy(), pad_spec, constant_values=1.0).sum(keepdims=True)
  b_ref = np.pad(b.numpy(), pad_spec, constant_values=1.0).sum()
  np.testing.assert_allclose(out.numpy(), a_ref + b_ref, atol=1e-4, rtol=1e-4)
def test_pad_reduce_unsafe(self):
  # log2 followed by a pad with value 1.0 and a sum: 2 kernels expected, value checked with numpy.
  Tensor.manual_seed(0)
  a = Tensor.rand(3, 4, 5).realize()
  pad_spec = ((0, 1), (0, 1), (0, 1))
  out = a.log2().pad(pad_spec, value=1.0).sum().contiguous()
  run_schedule(check_schedule(out, 2))
  expected = np.pad(np.log2(a.numpy()), pad_spec, constant_values=1.0).sum()
  np.testing.assert_allclose(out.numpy(), expected, atol=1e-5, rtol=1e-6)
# multireduce spec
def test_multireduce_pad_reduce_unsafe(self):
  # Two chained log2 -> pad(value=1.0) -> sum stages: 4 kernels expected, value checked with numpy.
  Tensor.manual_seed(0)
  a = Tensor.randn(3, 4, 5).abs().realize()
  b = Tensor.randn(3, 4, 5).abs().realize()
  pad_spec = ((0, 1), (0, 1), (0, 1))
  out = (a.log2().pad(pad_spec, value=1.0).sum()+b).abs().log2().pad(pad_spec, value=1.0).sum().contiguous()
  # run_schedule(check_schedule(out, 1))
  run_schedule(check_schedule(out, 4))
  inner_ref = np.pad(np.log2(a.numpy()), pad_spec, constant_values=1.0).sum() + b.numpy()
  expected = np.pad(np.log2(np.abs(inner_ref)), pad_spec, constant_values=1.0).sum()
  np.testing.assert_allclose(out.numpy(), expected, atol=3e-4, rtol=1e-6)
def test_shrink_pad_safe(self):
  # An add that is shrunk to one element and padded back to two: 1 kernel expected, result is [2, 0].
  ones_a = Tensor.ones((3, )).contiguous().realize()
  ones_b = Tensor.ones((3, )).contiguous().realize()
  total = ones_a + ones_b
  out = total.shrink(((0, 1),)).pad(((0, 1),)).contiguous()
  run_schedule(check_schedule(out, 1))
  np.testing.assert_equal(out.numpy(), [2, 0])
def test_shrink_pad_unsafe(self):
  # exp2 that is shrunk to one element and padded back to two: 2 kernels expected, result is [2, 0].
  ones = Tensor.ones((3, )).contiguous().realize()
  powered = ones.exp2()
  out = powered.shrink(((0, 1),)).pad(((0, 1),)).contiguous()
  run_schedule(check_schedule(out, 2))
  np.testing.assert_equal(out.numpy(), [2, 0])
def test_base_change_shrink_pad(self):
  # exp2, slice off the last row/column, pad them back and double: 2 kernels expected, value checked with numpy.
  ones = Tensor.ones(3, 3).contiguous().realize()
  powered = ones.exp2()
  cropped = powered[:-1, :-1]
  out = cropped.pad(((0, 1), (0, 1))) * 2
  run_schedule(check_schedule(out, 2))
  expected = np.pad(np.exp2(ones.numpy())[:-1, :-1], ((0, 1), (0, 1)))*2
  np.testing.assert_equal(out.numpy(), expected)
def test_base_change_expand_pad(self):
  # exp2, insert a middle axis, pad that axis on both sides and double: 2 kernels expected, value checked with numpy.
  ones = Tensor.ones(3, 3).contiguous().realize()
  powered = ones.exp2()
  with_axis = powered[:, None, :]
  out = with_axis.pad(((0, 0), (1, 1), (0, 0))) * 2
  run_schedule(check_schedule(out, 2))
  expected = np.pad(np.exp2(ones.numpy())[:, None, :], ((0, 0), (1, 1), (0, 0)))*2
  np.testing.assert_equal(out.numpy(), expected)
# TODO like openpilot with imagef
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_base_change_expand_expand(self):
  # cast to half + expand, then cast to int + expand again: 2 kernels expected, result is all ones (int32).
  ones = Tensor.ones(4, 4).contiguous().realize()
  as_half = ones.cast(dtypes.half).expand(2, 4, 4)
  as_int = as_half.cast(dtypes.int).expand(2, 2, 4, 4)
  run_schedule(check_schedule(as_int, 2))
  expected = np.ones(((2, 2, 4, 4)), dtype=np.int32)
  np.testing.assert_equal(as_int.numpy(), expected)
def test_base_change_pad_expand(self):
  # Pad an elementwise add, cast it to int, expand and scale by 4: 2 kernels expected, values compared with numpy.
  a = Tensor.full((4, 4), 1.).contiguous().realize()
  b = Tensor.full((4, 4), 2.).contiguous().realize()
  c = (a + b).pad(((1, 1), (1, 1)))
  d = c.cast(dtypes.int).expand((2, 6, 6)) * 4
  run_schedule(check_schedule(d, 2))
  c_np = np.pad((np.full((4, 4), 2., dtype=np.float32) + np.full((4, 4), 1., dtype=np.float32)), ((1, 1), (1, 1)), constant_values=0.0)
  # The tensor is cast to dtypes.int, so the reference uses an integer dtype as well (it previously used np.half,
  # which only passed because assert_equal compares values and not dtypes).
  expected = np.broadcast_to(c_np.astype(np.int32), (2, *c_np.shape)) * 4
  np.testing.assert_equal(d.numpy(), expected)
def test_pad_reduce_unsafe_multiview_st(self):
  # Row-normalize a ones matrix in place, take row 0, left-pad it by one and repeat it twice: 3 kernels expected.
  probs = Tensor.ones(3, 3).contiguous()
  row_sums = probs.sum(axis=1, keepdim=True)
  probs /= row_sums
  row = probs[0].pad(((1, 0), )).repeat([2])
  run_schedule(check_schedule(row, 3))
  tiny_ret = row.numpy()

  # the same computation in numpy
  probs_np = np.ones((3, 3), dtype=np.float32)
  row_sums_np = probs_np.sum(axis=1, keepdims=True)
  probs_np /= row_sums_np
  row_np = np.tile(np.pad(probs_np[0], (1, 0), 'constant'), 2)
  np.testing.assert_allclose(tiny_ret, row_np)
@unittest.skipIf(Device.DEFAULT not in view_supported_devices, "subbuffer not supported")
def test_bitcast_subbufer(self):
  # Adds two int32 bitcasts (one of exp2(x), one of x) built with allow_buffer_view=True: 2 schedule items expected.
  base = cast(UOp, Tensor.empty(1, dtype=dtypes.float32).realize().lazydata)
  exp_bits = base.alu(Ops.EXP2).cast(dtypes.int32, True, allow_buffer_view=True)
  raw_bits = base.cast(dtypes.int32, True, allow_buffer_view=True)
  total = exp_bits.alu(Ops.ADD, raw_bits)
  check_schedule(total, 2) # this should fuse when it makes sense
def test_bitcast_disable_subbufer(self):
  # Same graph as test_bitcast_subbufer but with allow_buffer_view=False: a single schedule item is expected.
  base = cast(UOp, Tensor.empty(1, dtype=dtypes.float32).realize().lazydata)
  exp_bits = base.alu(Ops.EXP2).cast(dtypes.int32, True, allow_buffer_view=False)
  raw_bits = base.cast(dtypes.int32, True, allow_buffer_view=False)
  total = exp_bits.alu(Ops.ADD, raw_bits)
  check_schedule(total, 1)
def test_reduceop_reshape_dont_push(self):
  # argmax over axis 1 currently schedules as 3 kernels.
  Tensor.manual_seed(0)
  inp = Tensor.randn(10, 20).realize()
  best_idx = inp.argmax(1)
  sched = check_schedule(best_idx, 3) # TODO: push a reduceop through a reshape
  run_schedule(sched)
def test_conv2d(self):
  # Runs the shared conv2d helper with 7 as its first argument (presumably the allowed kernel count).
  _test_conv2d(7)
def test_conv2d_fused(self):
  # Runs the shared conv2d helper with FUSE_CONV_BW=1 and 6 as its first argument (presumably the allowed kernel count).
  _test_conv2d(6, FUSE_CONV_BW=1)
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_conv2d_half(self):
  # Half-precision variant of test_conv2d; skipped when the device does not support half.
  _test_conv2d(7, dtype=dtypes.half)
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
@unittest.expectedFailure
def test_conv2d_fused_half(self):
  # Half-precision conv2d with 5 as the helper's first argument; marked as an expected failure.
  _test_conv2d(5, dtype=dtypes.half)
def _test_buf_cnt(self, cnt:int, allowed:int):
  # Adds up `cnt-1` realized 1x1 ones tensors, checks that the schedule has exactly `allowed` items,
  # runs it and compares the value with the same reduction in numpy.
  # Skipped unless BUF_LIMIT for the default device is 32 (a missing entry also skips).
  if BUF_LIMIT.get(Device.DEFAULT) != 32: self.skipTest(f"test needs a buf_max of 32 {Device.DEFAULT}")
  alu = functools.reduce(lambda x,y: x+y, [Tensor.ones((1, 1)).contiguous().realize() for _ in range(cnt-1)])
  s = alu.schedule()
  # assertEqual instead of a bare assert: it is not stripped under `python -O` and reports both values on failure
  self.assertEqual(len(s), allowed)
  run_schedule(s)
  expected = functools.reduce(lambda x,y: x+y, [np.ones((1, 1)) for _ in range(cnt-1)])
  np.testing.assert_equal(alu.numpy(), expected)
def test_buf_cnt_at_limit(self):
  # cnt=31 is expected to fit in a single schedule item.
  self._test_buf_cnt(31, allowed=1)
@unittest.expectedFailure
def test_buf_cnt_over_limit(self):
  # cnt=32 with 2 allowed schedule items; marked as an expected failure.
  self._test_buf_cnt(32, allowed=2)
@unittest.expectedFailure
def test_buf_cnt_over_limit_alt(self):
  # cnt=63 with 3 allowed schedule items; marked as an expected failure.
  self._test_buf_cnt(63, allowed=3)
def test_schedule_mem_used(self):
  # GlobalCounters.mem_used must be unchanged after realizing one temporary tensor and scheduling another.
  mem_before = GlobalCounters.mem_used
  # the tensors are deliberately not bound to names, so nothing here keeps them alive
  Tensor.ones(256).contiguous().realize()
  Tensor.ones(5, 5).contiguous().schedule()
  mem_delta = GlobalCounters.mem_used-mem_before
  self.assertEqual(mem_delta, 0)
@unittest.skip("TODO: this is consistently creating non reproducible failures")
def test_schedule_mem_used_with_inputs(self):
  # With a realized 256-element input kept alive, scheduling an add on top of it should leave mem_used 1024 higher.
  mem_before = GlobalCounters.mem_used
  x = Tensor.ones(256).contiguous().realize()
  (x+Tensor.ones(256).contiguous()).schedule()
  mem_delta = GlobalCounters.mem_used-mem_before
  self.assertEqual(mem_delta, 1024)
def test_const_schedule(self):
  # A bare const_like(10) produces an empty schedule.
  const_uop = Tensor.empty(2, 2).lazydata.const_like(10)
  sched = create_schedule([const_uop])
  self.assertEqual(len(sched), 0)
def test_const_schedule_contig(self):
  # Making the const contiguous produces a schedule with exactly one item.
  const_uop = Tensor.empty(2, 2).lazydata.const_like(10).contiguous()
  sched = create_schedule([const_uop])
  self.assertEqual(len(sched), 1)
@unittest.skipIf(Device.DEFAULT != "GPU", "image only supported on GPU")
def test_image_matmul(self):
# Multiplies two realized 9x9 random matrices under Context(IMAGE=2), expects 4 kernels in the schedule,
# runs it and compares the result with the numpy matmul.
# NOTE(review): indentation was lost in this copy, so which of the statements below sit inside the
# `with Context(IMAGE=2)` block cannot be read off here - confirm against the upstream file.
with Context(IMAGE=2):
x = Tensor.randn((9, 9)).realize()
y = Tensor.randn((9, 9)).realize()
out = x@y
run_schedule(check_schedule(out, 4))
np.testing.assert_allclose(out.numpy(), x.numpy()@y.numpy(), atol=1e-4, rtol=1e-4)
def _test_fusion(self, shapes, f, cnt):
  # Builds one realized random tensor per entry of `shapes`, applies `f`, expects `cnt` kernels and runs the schedule.
  # Unless COMPARE=0 is set, the result is also compared with `f` applied to torch copies of the same inputs.
  with Context(DEBUG=0, TRACK_MATCH_STATS=0):
    args = [Tensor.randn(s).realize() for s in shapes]
  compare = f(*args)
  run_schedule(check_schedule(compare, cnt))
  if getenv("COMPARE", 1):
    import torch
    torch_args = [torch.tensor(x.numpy()) for x in args]
    good = f(*torch_args)
    np.testing.assert_allclose(compare.numpy(), good.numpy(), atol=1e-4, rtol=1e-4)
def test_late_fusion_simple(self):
  # keepdim reduce followed by an add: 1 kernel expected.
  def fxn(a, b): return a.sum(1, keepdim=True)+b
  self._test_fusion([(4, 4), (4, 1)], fxn, 1)
def test_late_fusion_post_reshape(self):
  # reduce, reshape to the other operand's shape, then add: 1 kernel expected.
  def fxn(a, b): return a.sum(1).reshape(b.shape)+b
  self._test_fusion([(4, 4), (1, 4)], fxn, 1)
def test_late_fusion_post_permute(self):
  # keepdim reduce, permute, then add: 2 kernels expected.
  def fxn(a, b): return a.sum(1, keepdim=True).permute((2, 0, 1))+b
  self._test_fusion([(4, 6, 4), (4, 4, 1)], fxn, 2)
def test_late_fusion_double_transpose(self):
  # expand + keepdim reduce, permuted, shifted by 2 and permuted back before contiguous: 1 kernel expected.
  def fxn(a):
    reduced = a.expand(32, 16, 16).sum((2,), keepdim=True)
    return (reduced.permute((1, 0, 2))+2).permute((1, 0, 2)).contiguous()
  self._test_fusion([(32, 16, 1)], fxn, 1)
def test_late_fusion_post_expand(self):
  # subtracting a row reduce from its own input: 2 kernels expected.
  def fxn(a): return a-a.sum(1)
  self._test_fusion([(32, 32)], fxn, 2)
class TestIndexing(unittest.TestCase):
# Schedule-level tests for indexing/arange patterns, followed by a few tests that build ASTs by hand and rewrite them.
# NOTE(review): indentation was lost in this copy of the file; statement nesting (in particular the extent of
# `with` blocks) cannot be read off here and should be confirmed against the upstream file.

# Helper: schedules `xt` (one Tensor or a list) with FUSE_ARANGE taken from the environment (default 1), verifies
# the ast of every SINK schedule item, runs a copy of the schedule, compares the number of SINK items with `cnt`
# when FUSE_ARANGE is truthy, and returns the schedule.
def check_schedule(self, xt:Union[Tensor,List[Tensor]], cnt:int):
with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):
lst = [xt] if isinstance(xt, Tensor) else xt
s = Tensor.schedule(*lst)
kernels = [si for si in s if si.ast.op is Ops.SINK]
for si in kernels: verify_ast(si.ast)
run_schedule(s.copy())
if FUSE_ARANGE: self.assertEqual(len(kernels), cnt)
return s
# Indexing rows of a realized matrix with a realized index tensor: 2 kernels, values match numpy fancy indexing.
def test_simple_indexing(self):
X = Tensor.randn(10, 10).realize()
idxs = Tensor([0, 2]).realize()
xt = X[idxs]
self.check_schedule(xt, 2)
np.testing.assert_equal(xt.numpy(), X.numpy()[idxs.numpy()])
# Skipped: indexing an arange-based matrix with two Python lists (3 kernels when enabled).
@unittest.skip("TODO: support pads in graph_rewrite")
def test_simple_indexing_alt(self):
X = Tensor.arange(16).reshape(4, 4)
xt = X[[1, 2], [1, 2]]
self.check_schedule(xt, 3)
np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [1, 2]])
# List indexing into `arange(10)+1`: 2 kernels, values match numpy.
def test_advanced_indexing(self):
X = Tensor.arange(10)+1
xt = X[[0]]
self.check_schedule(xt, 2)
np.testing.assert_equal(xt.numpy(), (np.arange(10)+1)[[0]])
# Expected failure: indexing with a list of Tensors.
@unittest.expectedFailure
def test_advanced_indexing_alt(self):
X = Tensor.arange(6).reshape(3, 2)+1
xt = X[[Tensor([2]), Tensor([1])]]
self.check_schedule(xt, 6)
np.testing.assert_equal(xt.numpy(), 6)
# Skipped: slice combined with list indexing (2 kernels when enabled; no value check).
@unittest.skip("TODO: support pads in graph_rewrite")
def test_advanced_simple_indexing_combined(self):
X = Tensor.arange(16).reshape(4, 4)
xt = X[1:2, [1, 2]]
self.check_schedule(xt, 2)
# argmax over axis 1: 2 kernels here, result matches np.argmax.
def test_push_through_reshape(self):
Tensor.manual_seed(0)
x = Tensor.randn(10, 20).realize()
out = x.argmax(1)
self.check_schedule(out, 2)
np.testing.assert_allclose(out.numpy(), np.argmax(x.numpy(), 1))
# arange(4) added to a realized 4x4 matrix: 1 kernel.
def test_arange_push_through_expand(self):
Tensor.manual_seed(0)
a = Tensor.arange(4,)
b = Tensor.randn(4, 4).realize()
out = a+b
self.check_schedule(out, 1)
np.testing.assert_allclose(out.numpy(), np.arange(4)+b.numpy())
# argmin over the last axis: 2 kernels, result matches numpy.
def test_argmin(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
out = x.argmin(-1)
self.check_schedule(out, 2)
np.testing.assert_equal(out.numpy(), x.numpy().argmin(axis=-1))
# argmax over the last axis: 2 kernels, result matches numpy.
def test_argmax(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
out = x.argmax(-1)
self.check_schedule(out, 2)
np.testing.assert_equal(out.numpy(), x.numpy().argmax(axis=-1))
# Transposed product of arange(4) and a realized randint tensor: 1 kernel.
def test_arange_transposed(self):
Tensor.manual_seed(0)
x = Tensor.randint(4, 1).realize()
a = (Tensor.arange(4,)*x).T
self.check_schedule(a, 1)
np.testing.assert_equal(a.numpy(), (np.arange(4)*x.numpy()).T)
# Same transposed arange product with a further add on top: still 1 kernel.
def test_arange_transposed_descendants(self):
Tensor.manual_seed(0)
x = Tensor.randint(4, 1).realize()
a = (Tensor.arange(4,)*x).T
b = Tensor.randint(4, 4).realize()
out = a+b
self.check_schedule(out, 1)
np.testing.assert_equal(out.numpy(), (np.arange(4)*x.numpy()).T+b.numpy())
# Adding a single arange element to a tensor and summing: 1 kernel.
def test_arange_index(self):
Tensor.manual_seed(0)
x = Tensor.randn(5, 2).realize()
a = Tensor.arange(10)
out = (x + a[2]).sum()
self.check_schedule(out, 1)
np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)
# Same as test_arange_index but the arange is made contiguous first: 2 kernels.
def test_arange_index_contiguous(self):
Tensor.manual_seed(0)
x = Tensor.randn(5, 2).realize()
a = Tensor.arange(10).contiguous()
out = (x + a[2]).sum()
self.check_schedule(out, 2)
np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)
# Indexing into `arange(10)+1`: 1 kernel.
def test_arange_index_child(self):
Tensor.manual_seed(0)
x = Tensor.randn(5, 2).realize()
a = Tensor.arange(10)+1
out = (x + a[2]).sum()
self.check_schedule(out, 1)
np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6)
# Indexing into a contiguous `arange(10)+1`: 2 kernels.
def test_arange_index_contiguous_child(self):
Tensor.manual_seed(0)
x = Tensor.randn(5, 2).realize()
a = (Tensor.arange(10)+1).contiguous()
out = (x + a[2]).sum()
self.check_schedule(out, 2)
np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6)
# A bare arange with no consumers: 1 kernel.
def test_arange_childless_base(self):
a = Tensor.arange(4)
self.check_schedule(a, 1)
np.testing.assert_equal(a.numpy(), np.arange(4))
# Assigning into a row of a reshaped arange; only the resulting values are checked.
def test_arange_childless_view(self):
a = Tensor.arange(4).reshape(2, 2)
a[0] = 4
np.testing.assert_equal(a.numpy(), [[4, 4], [2, 3]])
# arange plus a realized tensor with no further consumers: 1 kernel.
def test_arange_group_childless_base(self):
Tensor.manual_seed(0)
x = Tensor.randint(4).realize()
a = Tensor.arange(4)+x
self.check_schedule(a, 1)
np.testing.assert_equal(a.numpy(), np.arange(4)+x.numpy())
# Element assignment into `arange(4)+ones`; only the resulting values are checked.
def test_arange_group_childless_view(self):
Tensor.manual_seed(0)
x = Tensor.ones(4).contiguous().realize()
a = Tensor.arange(4)+x
a[0] = 6
np.testing.assert_equal(a.numpy(), [6., 2., 3., 4.])
# Contiguous shrink of a reshaped arange: 1 kernel, and the second schedule item is a BUFFER_VIEW.
@unittest.skipUnless(Device.DEFAULT in view_supported_devices, "need view")
def test_arange_view_op(self):
a = Tensor.arange(12).reshape(4, 3).shrink(((1, 2), (1, 3))).contiguous()
sched = self.check_schedule(a, 1)
self.assertIs(sched[1].ast.op, Ops.BUFFER_VIEW)
np.testing.assert_equal(a.numpy(), [[4, 5]])
# Shrunk arange moved to the CLANG device: 1 kernel, and the last schedule item is a COPY.
@unittest.skipIf(Device.DEFAULT == "CLANG", "tests copy from ext device")
def test_arange_shrink_copy(self):
a = Tensor.arange(12).reshape(4, 3).shrink(((1, 2), (1, 3))).to("CLANG")
sched = self.check_schedule(a, 1)
self.assertIs(sched[-1].ast.op, Ops.COPY)
np.testing.assert_equal(a.numpy(), [[4, 5]])
# Expanded contiguous arange moved to CLANG: 1 kernel whose stored value is an ADD, followed by a COPY item.
@unittest.skipIf(Device.DEFAULT == "CLANG", "tests copy from ext device")
def test_arange_expand_copy(self):
a = Tensor.arange(4).reshape(2, 2, 1).expand(2, 2, 2).contiguous().to("CLANG")
sched = self.check_schedule(a, 1)
self.assertIs(sched[1].ast.op, Ops.COPY)
self.assertIs(sched[0].ast.src[0].src[2].op, Ops.ADD)
np.testing.assert_equal(a.numpy(), [[[0, 0], [1, 1]], [[2, 2], [3, 3]]])
# Skipped: precompute_freqs_cis in half; 1 kernel through this helper, compared (unless CHECK=0) with a second
# call that goes through the module-level check_schedule with 3 kernels.
@unittest.skip("TODO: support pads in graph_rewrite")
#@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_precompute_freqs_cis(self):
args = {"dim":32 if CI else 128, "end":2048 if CI else 8192, "theta":10000, "dtype":dtypes.half}
fused = precompute_freqs_cis(**args)
self.check_schedule(fused, 1)
if getenv("CHECK", 1):
ref = precompute_freqs_cis(**args)
run_schedule(check_schedule(ref, 3))
np.testing.assert_equal(fused.numpy(), ref.numpy())
# Assigning a contiguous reshaped arange into the first two columns of a realized int matrix: 2 kernels.
def test_fuse_assign_contiguous(self):
x = Tensor.zeros(4, 4, dtype=dtypes.int).contiguous().realize()
a = Tensor.arange(8).reshape(4, 2)
self.check_schedule(x.shrink((None, (0, 2))).assign(a.contiguous()), 2)
np.testing.assert_equal(x.numpy(), [[0, 1, 0, 0], [2, 3, 0, 0], [4, 5, 0, 0], [6, 7, 0, 0]])
# Assigning a non-contiguous `arange + randint` into the first two columns; only the values are checked.
def test_assign_non_contiguous(self):
x = Tensor.zeros(4, 4, dtype=dtypes.int).contiguous().realize()
y = Tensor.randint(4, 2)
a = Tensor.arange(8).reshape(4, 2)+y
x.shrink((None, (0, 2))).assign(a).realize()
xref = np.zeros((4, 4), dtype=int)
xref[:, :2] = np.arange(8).reshape(4, 2)+y.numpy()
np.testing.assert_equal(x.numpy(), xref)
# sparse_categorical_crossentropy on a tiny fixed input: 4 kernels, loss compared with a hard-coded value.
def test_sparse_categorical_crossentropy_simple(self):
X = Tensor([[0, 2, 3], [1, 2, 3]]).realize()
Y = Tensor([1, 2]).realize()
loss = X.sparse_categorical_crossentropy(Y)
self.check_schedule(loss, 4)
np.testing.assert_allclose(loss.item(), 0.878309, atol=1e-5, rtol=1e-6)
# Cross-entropy against randomly sampled MNIST labels (batch size from BS, default 512) under SPLIT_REDUCEOP=0:
# 6 kernels, loss compared with torch.nn.CrossEntropyLoss.
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Validation error on WebGPU")
def test_mnist_val(self):
from tinygrad.nn.datasets import mnist
import torch
_, Y_train, _, _ = mnist()
samples = Tensor.randint(BS:=getenv("BS", 512), high=cast(int,Y_train.shape[-1])).realize()
yt = Tensor.randn(BS, 10).realize()
with Context(SPLIT_REDUCEOP=0):
loss = yt.sparse_categorical_crossentropy(Y_train[samples])
self.check_schedule(loss, 6)
loss_fused = loss.numpy()
loss_ref = torch.nn.CrossEntropyLoss()(torch.tensor(yt.numpy()), torch.tensor(Y_train.numpy())[torch.tensor(samples.numpy())])
np.testing.assert_allclose(loss_fused, loss_ref.numpy(), atol=1e-6, rtol=1e-6)
# Two outputs built on one `(X + arange).sum()`: 1 kernel for both.
def test_arange_fuse_grouped_children(self):
X = Tensor.randn(4, 4).realize()
r = (X+Tensor.arange(16).reshape(4, 4)).sum()
out0 = r+2
out1 = r+3
self.check_schedule([out0, out1], 1)
r_ref = (X.numpy()+np.arange(16).reshape(4, 4)).sum()
np.testing.assert_allclose(out0.numpy(), r_ref+2, rtol=2e-7)
np.testing.assert_allclose(out1.numpy(), r_ref+3, rtol=2e-7)
# Expected failure: a contiguous reshaped arange added to X and reduced, checked against 1 kernel.
@unittest.expectedFailure
def test_fold_arange_view(self):
X = Tensor.randn(4, 4).realize()
r = (X+Tensor.arange(16).reshape(4, 4).contiguous()).sum(1, keepdim=True)
self.check_schedule([r], 1)
np.testing.assert_allclose(r.numpy(), (X.numpy()+np.arange(16).reshape(4, 4)).sum(1, keepdims=True))
# Expected failure: linear interpolate computed with FUSE_ARANGE=1 is compared with the FUSE_ARANGE=0 result.
@unittest.expectedFailure
def test_multiview_arange_children(self):
X = Tensor.randn(2,3,4,4).numpy()
with Context(FUSE_ARANGE=1):
compare = Tensor(X).interpolate(size=(2, 2), mode="linear").numpy()
with Context(FUSE_ARANGE=0, TRACK_MATCH_STATS=0):
ref = Tensor(X).interpolate(size=(2, 2), mode="linear").numpy()
np.testing.assert_allclose(ref, compare, atol=1e-5, rtol=1e-6)
# Reshapes the stored value of a kernel made of 24 repeated self-adds and runs swizzle_rewrite on it: the result
# must have the reshaped (4, 1) ShapeTracker and contain no VIEW with sources.
def test_recursive_swizzle(self):
a = Tensor([1,2,3,4]).realize()
for _ in range(24): a = a + a
ast = a.schedule()[0].ast
swizzle = ast.src[0].src[2].reshape((4, 1))
new_uop = swizzle_rewrite(swizzle)
self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1)))
self.assertEqual(swizzle_cnt(new_uop), 0)
# A purely elementwise STORE of ld1+ld2 must come out of the view_right rewrite with the same key.
def test_no_rewrite_elementwise(self):
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop()))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),))
rsink = graph_rewrite(sink, view_right)
self.assertEqual(rsink.key, sink.key)
# A reduce wrapped in a VIEW to shape () and stored: the original sink fails verify_ast with a "swizzle"
# assertion, the view_right-rewritten sink verifies.
def test_simple_store_reshape(self):
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1)))
r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
r = r + 2
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
rsink = graph_rewrite(sink, view_right)
# this AST first needs to swizzle, but it doesn't have implicit movementops
with self.assertRaisesRegex(AssertionError, "swizzle"): verify_ast(sink)
verify_ast(rsink)
# A reduce stored directly with a (1, 1) ShapeTracker already verifies and is left unchanged by view_right.
def test_no_reshape_reduceop(self):
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1)))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
rsink = graph_rewrite(sink, view_right)
verify_ast(sink)
self.assertEqual(sink.key, rsink.key)
@track_rewrites(named=True)
def swizzle_rewrite(u:UOp) -> UOp:
  # Rewrite `u` with the view_left rules first, then rewrite that result with view_right.
  after_left = graph_rewrite(u, view_left)
  return graph_rewrite(after_left, view_right)
def swizzle_cnt(u:UOp) -> int:
  # Number of VIEW uops in u's toposort that have at least one source.
  return sum(1 for x in u.toposort if x.op is Ops.VIEW and len(x.src) != 0)
class TestSwizzle(unittest.TestCase):
def test_swizzle_simple(self):
# Hand-built AST: a full reduce of a 32x32 load is viewed back to (32, 32) with zero strides, added to the same
# load and reduced again. After swizzle_rewrite the kernel is compiled and run, and the output is compared with
# (a + a.sum()).sum() computed in numpy.
sink = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (0, 1)), src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.VIEW, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa E501
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (0, 1)), src=(
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
x8:=UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), # noqa E501
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
x8,
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) # noqa E501
sink = swizzle_rewrite(sink)
# compile the rewritten AST and execute it on real buffers: output buffer first, then the input
k = Kernel(sink)
p = k.to_program()
a = Tensor.randint(32, 32).realize()
b = Tensor.empty((), dtype=dtypes.int).realize()
CompiledRunner(p).exec([b.lazydata.buffer, a.lazydata.buffer])
expected_out = (a.numpy() + a.numpy().sum()).sum()
np.testing.assert_equal(b.numpy(), expected_out)
def test_single_swizzle(self):
  # One reduce wrapped in a reshaping VIEW, plus 1: after swizzle_rewrite the compiled kernel must equal sum(a)+1.
  # ast in tensor style
  inp = Tensor.randint(4,).realize()
  expected_out = inp.numpy().sum(0)+1
  # LazyBuffer to pre-rewrite AST
  bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
  load = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
  reduce = UOp(Ops.REDUCE_AXIS, dtypes.int, (load,), (Ops.ADD, (0,)))
  swizzled = UOp(Ops.VIEW, dtypes.int, (reduce,), unwrap(reduce.st).reshape(()))
  alu = swizzled+1
  store = UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu,),)
  sink = UOp(Ops.SINK, dtypes.void, (store,))
  # graph rewrite
  sink = swizzle_rewrite(sink)
  # verify output
  prg = Kernel(sink).to_program()
  out = Tensor.empty((1,), dtype=dtypes.int).realize()
  CompiledRunner(prg).exec([out.lazydata.buffer, inp.lazydata.buffer])
  np.testing.assert_equal(out.numpy(), expected_out)
def test_double_swizzle_possible(self):
  # Two independent reduces, each behind its own VIEW, added together plus 2: the rewritten kernel must equal
  # sum(a)+sum(b)+2.
  # ast in tensor style
  Tensor.manual_seed(0)
  inp0 = Tensor.randint(4,).realize()
  inp1 = Tensor.randint(4,).realize()
  expected_out = inp0.numpy().sum(0)+inp1.numpy().sum(0)+2
  # LazyBuffer to pre-rewrite AST
  bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
  load0 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
  reduce0 = UOp(Ops.REDUCE_AXIS, dtypes.int, (load0,), (Ops.ADD, (0,)))
  load1 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((4,)).to_uop()))
  reduce1 = UOp(Ops.REDUCE_AXIS, dtypes.int, (load1,), (Ops.ADD, (0,)))
  view0 = UOp(Ops.VIEW, reduce0.dtype, (reduce0,), ShapeTracker.from_shape(()))
  view1 = UOp(Ops.VIEW, reduce1.dtype, (reduce1,), ShapeTracker.from_shape(()))
  alu = view0+view1
  store = UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu+2,),)
  sink = UOp(Ops.SINK, dtypes.void, (store,))
  # graph rewrite
  sink = swizzle_rewrite(sink)
  # verify output
  prg = Kernel(sink).to_program()
  out = Tensor.empty((1,), dtype=dtypes.int).realize()
  CompiledRunner(prg).exec([out.lazydata.buffer, inp0.lazydata.buffer, inp1.lazydata.buffer])
  np.testing.assert_equal(out.numpy(), expected_out)
def test_swizzle_rewrite_alt(self):
  # shapetracker of the LOAD that feeds the reduce
  ld_st = ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))
  # three-view shapetracker carried by the VIEW placed on top of the reduce
  swizzle_st = ShapeTracker(views=(
    View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False),
    View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False),
    View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True)))
  load = UOp(Ops.LOAD, dtypes.float, arg=None, src=(
    UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
    UOp(Ops.VIEW, dtypes.void, arg=ld_st, src=())))
  red = UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(load,))
  swizzle = UOp(Ops.VIEW, dtypes.float, arg=swizzle_st, src=(red,))
  # before the rewrite the VIEW has more elements than the reduce it wraps
  self.assertGreater(prod(swizzle.st.shape), prod(swizzle.src[0].st.shape))
  rewritten = swizzle_rewrite(swizzle)
  # after the rewrite the root and its source hold the same number of elements
  self.assertEqual(prod(rewritten.st.shape), prod(rewritten.src[0].st.shape))
  # the first VIEW in toposort order now has a bigger shapetracker than the original load one
  views = [u for u in rewritten.toposort if u.op is Ops.VIEW]
  new_load_st = unwrap(views[0].st)
  self.assertGreater(prod(new_load_st.shape), prod(ld_st.shape))
  self.assertEqual(new_load_st.views[0].strides, (0, 9, 3, 0, 1, 0, 27))
def test_permute_rewrite(self):
# Hand-dumped AST rooted at a STORE (bound to the name `sink`) into BUFFER 1.
# The stored value is a CONTIGUOUS over a VIEW with permuting strides (0, 32, 1, 1024) whose source is an ADD tree:
#   ADD(ADD(VIEW(REDUCE_AXIS(MUL(VIEW(LOAD), VIEW(LOAD)))), VIEW(LOAD)), VIEW(x11))
# x1, x2 and x11 are walrus-bound so the same nodes are reused further down the expression.
sink = UOp(Ops.STORE, dtypes.void, arg=None, src=(
x1:=UOp(Ops.BUFFER, dtypes.float, arg=(1, ('METAL', 16384, dtypes.float)), src=()),
x2:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CONTIGUOUS, dtypes.float, arg=None, src=(
x1,
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 32, 1, 1024), offset=0, mask=None, contiguous=False),)), src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 8)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
x11:=UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(2, ('METAL', 16384, dtypes.float)), src=()),
x2,)),)),
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(8, ('METAL', 256, dtypes.float)), src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(10, ('METAL', 16, dtypes.float)), src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
x11,)),)),)),)),))
# after the rewrite swizzle_cnt (defined earlier in this file) must report zero for the whole graph
ret = swizzle_rewrite(sink)
self.assertEqual(swizzle_cnt(ret), 0)
# marked expectedFailure: the body below is currently expected to fail
@unittest.expectedFailure
def test_fuse_conv2_relu_bw(self):
# fuse (relu bw, conv2d, conv2d bw, relu)
# Hand-dumped AST: SINK(STORE(BUFFER 10, view, MUL(lhs, rhs))) where
#   lhs = CAST(CMPLT(x6, MAX(VIEW(REDUCE_AXIS(MUL(LOAD, PRELOAD))), x6)))
#   rhs = VIEW(REDUCE_AXIS(LOAD through a two-view, masked shapetracker))
# x6 is WHERE(VALID, x9, x9) with x9 = CONST 0.0; both are walrus-bound and reused.
sink = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(10, ('METAL', 128, dtypes.float)), src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 16, 2, 2), strides=(64, 4, 2, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
x6:=UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.VALID, dtypes.bool, arg=None, src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 16, 2, 2), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
x9:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=()),
x9,)),
UOp(Ops.MAX, dtypes.float, arg=None, src=(
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 16, 2, 2), strides=(64, 4, 2, 1), offset=0, mask=None, contiguous=True),)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(9, ('METAL', 96, dtypes.float)), src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 16, 2, 2, 3, 3, 3), strides=(48, 0, 0, 4, 1, 16, 4, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.PRELOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(16, ('METAL', 432, dtypes.float)), src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 16, 2, 2, 3, 3, 3), strides=(0, 0, 27, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),
x6,)),)),)),
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 16, 2, 2), strides=(64, 4, 2, 1), offset=0, mask=None, contiguous=True),)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (4, 6)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(18, ('METAL', 128, dtypes.float)), src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 16, 2, 3, 2, 3), strides=(64, 4, 2, 0, 1, 0), offset=0, mask=((0, 2), (0, 16), (0, 2), (0, 1), (0, 2), (0, 1)), contiguous=False), View(shape=(1, 2, 1, 16, 3, 2, 3, 2), strides=(0, 576, 0, 36, 12, 6, 2, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),))
# the goal state is a rewritten graph for which swizzle_cnt reports zero
ret = swizzle_rewrite(sink)
self.assertEqual(swizzle_cnt(ret), 0)
# marked expectedFailure: the body below is currently expected to fail
@unittest.expectedFailure
def test_swizzle_failure_permute(self):
# Hand-dumped AST: SINK(STORE(BUFFER 20, view (1, 65), ADD(lhs, rhs))) where
#   lhs = REDUCE_AXIS(ADD(x6, x6)) with x6 = MUL(ADD(PRELOAD, WHERE), WHERE)
#   rhs = REDUCE_AXIS(MUL(WHERE, MUL(PRELOAD, VIEW(REDUCE_AXIS(LOAD)))))
# The VIEW inside rhs (strides (1, 89), offset 44) sits directly on top of a REDUCE_AXIS.
# x6, x10, x12 and x15 are walrus-bound so the same nodes are reused further down the expression.
sink = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(20, ('METAL', 65, dtypes.float)), src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 65), strides=(0, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
x6:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.PRELOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(8, ('METAL', 2925, dtypes.float)), src=()),
x10:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(65, 1), offset=0, mask=None, contiguous=True),)), src=()),)),
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
x12:=UOp(Ops.VALID, dtypes.bool, arg=None, src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.CONST, dtypes.float, arg=1.0, src=()),
x15:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=()),)),)),
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
x12,
UOp(Ops.CONST, dtypes.float, arg=0.0003418803389649838, src=()),
x15,)),)),
x6,)),)),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
x12,
UOp(Ops.CONST, dtypes.float, arg=-1.0, src=()),
x15,)),
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.PRELOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(2, ('METAL', 2925, dtypes.float)), src=()),
x10,)),
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(1, 89), offset=44, mask=None, contiguous=False),)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(4, ('METAL', 2925, dtypes.float)), src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(65, 45, 90), strides=(1, 0, 65), offset=0, mask=((0, 65), (0, 45), (0, 45)), contiguous=False), View(shape=(65, 4094), strides=(4050, 1), offset=0, mask=((0, 65), (0, 4050)), contiguous=False), View(shape=(1, 65, 46, 89), strides=(0, 4094, 89, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)),)),)),))
# the goal state is a rewritten graph for which swizzle_cnt reports zero
ret = swizzle_rewrite(sink)
self.assertEqual(swizzle_cnt(ret), 0)
def test_non_contiguous_view_simplify(self):
  # one non-contiguous 2048x2048 shapetracker (strides (1, 2048)) is used twice:
  # once for the LOAD itself and once for the view stacked on top of it
  st = ShapeTracker(views=(View(shape=(2048, 2048), strides=(1, 2048), offset=0, mask=None, contiguous=False),))
  buf = UOp.new_buffer(Device.DEFAULT, 4194304, dtypes.char)
  load = UOp(Ops.LOAD, dtypes.char, (buf, st.to_uop()))
  rewritten = swizzle_rewrite(load.view(st))
  # the rewritten node carries the sum of the two shapetrackers
  self.assertEqual(rewritten.st_arg, st+st)
def test_contiguous_view_simplify(self):
  base_st = ShapeTracker.from_shape((32, 32))
  buf = UOp.new_buffer(Device.DEFAULT, base_st.size, dtypes.char)
  load = UOp(Ops.LOAD, dtypes.char, (buf, base_st.to_uop()))
  target = (64, 16)
  # a single reshape: one swizzle is counted once the movement ops are removed
  single = graph_rewrite(load.reshape(target), remove_movement_ops)
  self.assertEqual(swizzle_cnt(single), 1)
  self.assertEqual(swizzle_rewrite(single).st_arg, base_st.reshape(target)) # late rewrite
  # reshaping back to the base shape: nothing is counted and the rewrite returns the very same object
  roundtrip = load.reshape(target).reshape((32, 32))
  self.assertEqual(swizzle_cnt(roundtrip), 0) # instant rule
  self.assertIs(swizzle_rewrite(roundtrip), roundtrip)
def test_late_fusion_post_permute_simpler(self):
  in_shape, mid_shape, perm = (32, 16, 1), (16, 32, 1), (1, 0, 2)
  base_st = ShapeTracker.from_shape(in_shape)
  buf = UOp.new_buffer(Device.DEFAULT, base_st.size, dtypes.char)
  load = UOp(Ops.LOAD, dtypes.char, (buf, base_st.to_uop()))
  # expand the last axis to 16 and reduce it again with ADD
  red = load.expand((32, 16, 16)).r(Ops.ADD, (2,))
  # reshape the reduce and add a zero constant of the same shape
  added = red.reshape(mid_shape) + UOp.const_with_shape(red.dtype, 0, mid_shape)
  self.assertEqual(added.st, ShapeTracker.from_shape(mid_shape))
  # permute + contiguous, then strip the movement ops: the permute is left as a VIEW under the root
  stored = graph_rewrite(added.permute(perm).contiguous(), remove_movement_ops)
  self.assertEqual(stored.st, ShapeTracker.from_shape(in_shape))
  self.assertEqual(stored.src[0].st, added.st.permute(perm))
  self.assertIs(stored.src[0].op, Ops.VIEW)
  # after view_left exactly one swizzle is still counted
  pushed = graph_rewrite(stored, view_left)
  self.assertEqual(swizzle_cnt(pushed), 1)
def store_val(si:ScheduleItem):
  # take the first source of the item's AST and return that node's source at index 2
  # (for the STOREs built in this file index 2 is the stored value)
  first_src = si.ast.src[0]
  return first_src.src[2]
class TestView(unittest.TestCase):
  # Schedules of pad-then-shrink views: a fully masked view must end up storing a constant 0,
  # while a partially masked one must still LOAD from its source.
  def test_all_masked_out(self):
    # start with non CONST Ops
    a = Tensor.rand(10, 10).realize()
    # all masked out, degrades to const 0
    b = a.pad(((0, 10), None))[10:]
    sched = check_schedule(b.contiguous(), 1)
    # TODO: this VALID can clean up, where do we need st?
    self.assertIs(store_val(sched[-1]), UOp.const_with_shape(b.dtype, 0, b.lazydata.st.shape))
    run_schedule(sched)
    np.testing.assert_equal(b.numpy(), 0)

  def test_mask_dim_1(self):
    # mask out dim = 1 works too
    a = Tensor.rand(10, 10).realize()
    b = a.pad((None, (0, 10)))[:, 10:]
    # use the TestCase assertion instead of a bare assert: it is not stripped under -O and reports both values
    self.assertEqual(b.shape, (10, 10))
    sched = check_schedule(b.contiguous(), 1)
    self.assertEqual(sched[-1].ast.full_shape, (10, 10))
    self.assertIs(store_val(sched[-1]), UOp.const_with_shape(b.dtype, 0, b.lazydata.st.shape))
    run_schedule(sched)
    np.testing.assert_equal(b.numpy(), 0)

  def test_zero_size_alt(self):
    # a VIEW with a zero-sized axis over (buffer, EMPTY); padding it must produce the padded shapetracker
    st = ShapeTracker.from_shape((135, 0, 9))
    a = UOp(Ops.VIEW, dtypes.float, (UOp.new_buffer(Device.DEFAULT, 121, dtypes.float), UOp(Ops.EMPTY, dtypes.float)), st)
    b = a.pad(pad_arg:=((0, 0), (0, 0), (18, 0)))
    self.assertEqual(b.st, st.pad(pad_arg))
    # TODO: why does this help?
    b = graph_rewrite(b, remove_movement_ops)
    # after removing movement ops the second source of the base is the float constant 0
    self.assertIs(b.base.src[1], UOp.const(dtypes.float, 0))

  def test_partial_mask(self):
    # partial masked out does not degrade into CONST
    a = Tensor.rand(10, 10).realize()
    b = a.pad(((0, 5), None))[5:]
    # use the TestCase assertion instead of a bare assert: it is not stripped under -O and reports both values
    self.assertEqual(b.shape, (10, 10))
    sched = check_schedule(b.contiguous(), 1)
    self.assertEqual(store_val(sched[-1]).op, Ops.LOAD)
    self.assertEqual(store_val(sched[-1]).st_arg, b.lazydata.st)
    run_schedule(sched)
    np.testing.assert_allclose(b.numpy(), np.pad(a.numpy(), ((0, 5), (0, 0)))[5:])
@track_rewrites(named=True)
def big_graph_rewrite(big_graph:UOp, ctx) -> UOp:
  # run the do_realize rules over the graph, handing ctx to graph_rewrite as its context argument
  rewritten = graph_rewrite(big_graph, do_realize, ctx)
  return rewritten
class TestBigGraph(unittest.TestCase):
  # Runs big_graph_rewrite over small hand-built graphs and inspects both the result and ctx.realizes.
  def test_sink_childless_const(self):
    # a sink over a bare CONST becomes a NOOP and nothing is recorded in ctx.realizes
    ctx = ScheduleContext()
    zero = UOp.const(dtypes.int, 0)
    rewritten = big_graph_rewrite(zero.sink(), ctx)
    self.assertIs(rewritten, UOp(Ops.NOOP))
    self.assertEqual(len(ctx.realizes), 0)

  def test_sink_childless_const_alt(self):
    # same expectation when the sink also holds a VIEW over (BUFFER, CONST)
    ctx = ScheduleContext()
    zero = UOp.const(dtypes.int, 0)
    const_view = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int, (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(()))
    rewritten = big_graph_rewrite(UOp.sink(zero, const_view), ctx)
    self.assertIs(rewritten, UOp(Ops.NOOP))
    self.assertEqual(len(ctx.realizes), 0)

  def test_sink_childless_const_alt_expanded(self):
    # this is a real STORE of CONST (post expand)
    const_view = UOp(Ops.VIEW, dtypes.int, (UOp.new_buffer(Device.DEFAULT, 1, dtypes.int), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(()))
    out_buf = UOp.new_buffer(Device.DEFAULT, 2, dtypes.int)
    expanded = const_view.reshape((1,)).expand((2,)).contiguous()
    out = UOp(Ops.VIEW, dtypes.int, (out_buf, expanded), ShapeTracker.from_shape((2,)))
    ctx = ScheduleContext()
    rewritten = big_graph_rewrite(out.sink(), ctx)
    # the graph comes back unchanged and exactly one entry is recorded in ctx.realizes
    self.assertIs(rewritten, out.sink())
    self.assertEqual(len(ctx.realizes), 1)
# Pattern table describing how a constant tensor looks as a UOp: a VIEW whose sources are a BUFFER and
# either a CONST with no sources or a BIND of (DEFINE_VAR, CONST). Both callbacks take no arguments and return True;
# the TestConst cases assert that rewrite() on a tensor's lazydata is truthy.
tensor_const_pm = PatternMatcher([
(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.CONST, src=()))), lambda: True),
(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST))))), lambda: True),
])
# Tests for tensors created from constants and bound variables: their UOp form, how many kernels they schedule,
# the AST of the scheduled kernel, and const folding done by the scheduler.
class TestConst(unittest.TestCase):
# ** part 1: basic functionality of a tensor directly created from CONST
# a Python scalar tensor must match tensor_const_pm
def test_tensor_const(self):
a = Tensor(1)
print(a.lazydata)
self.assertTrue(tensor_const_pm.rewrite(a.lazydata))
# a tensor built from a bound variable must match tensor_const_pm as well
def test_tensor_variable(self):
vv = UOp.variable("a", 0, 10).bind(1)
a = Tensor(vv)
print(a.lazydata)
self.assertTrue(tensor_const_pm.rewrite(a.lazydata))
# scalar, ones and padded ones are all unrealized consts; only the padded one is not an unmasked const
def test_uop_methods(self):
a = Tensor(1)
self.assertTrue(a.lazydata.is_unrealized_const())
self.assertTrue(a.lazydata.is_unrealized_unmasked_const())
a = Tensor.ones((4, 4))
self.assertTrue(a.lazydata.is_unrealized_const())
self.assertTrue(a.lazydata.is_unrealized_unmasked_const())
a = Tensor.ones((4, 4)).pad((1, 1),)
self.assertTrue(a.lazydata.is_unrealized_const())
self.assertFalse(a.lazydata.is_unrealized_unmasked_const())
# a plain ones tensor schedules nothing
def test_const_schedule(self):
a = Tensor.ones((4, 4))
sched = a.schedule()
self.assertEqual(len(sched), 0)
# forcing contiguous on a const schedules exactly one item
def test_const_contiguous_schedule(self):
# this ends up in the big graph
a = Tensor.ones((4,)).contiguous()
sched = a.schedule()
self.assertEqual(len(sched), 1)
# padded const: the scheduled AST is SINK(STORE(_, _, WHERE(VALID, <const>, CONST 0))) and the output is zero on both ends
def test_const_ast(self):
a = Tensor.ones((4,)).pad((1, 1)).contiguous()
sched = a.schedule()
print(sched[0].ast)
const_ast_pattern = UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(), UPat(Ops.WHERE, src=(UPat(Ops.VALID), UPat.cvar("x"), UPat(Ops.CONST, arg=0)))),))
self.assertEqual(len(const_ast_pattern.match(sched[0].ast, {})), 1)
run_schedule(sched)
self.assertListEqual(a.tolist(), [0, 1, 1, 1, 1, 0])
# TODO: currently even unmasked constants are VALID until codegen
# unpadded const: the scheduled AST still matches the same WHERE(VALID, ...) pattern
def test_unmasked_const_ast(self):
a = Tensor.ones((4,)).contiguous()
sched = a.schedule()
print(sched[0].ast)
const_ast_pattern = UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(), UPat(Ops.WHERE, src=(UPat(Ops.VALID), UPat.cvar("x"), UPat(Ops.CONST, arg=0)))),))
self.assertEqual(len(const_ast_pattern.match(sched[0].ast, {})), 1)
run_schedule(sched)
self.assertListEqual(a.tolist(), [1, 1, 1, 1])
# ** part 2: scheduler behavior when const folding happens later
# multiplying by 0 schedules nothing
def test_const_folding_no_realize(self):
a = Tensor([1, 2, 3, 4])*0
sched = a.schedule()
self.assertEqual(len(sched), 0)
# b + (a*0): nothing is scheduled and the result reuses b's realized buffer
def test_src_const_folding(self):
# TRACK_MATCH_STATS is set to 0 inside this Context
with Context(TRACK_MATCH_STATS=0):
a = Tensor.full((4,), 1).contiguous().realize()
b = Tensor.full((4,), 2).contiguous().realize()
mul0 = a*0
add = b+mul0
sched = add.schedule()
self.assertEqual(len(sched), 0)
# b+0 and b share the same underlying device memory
self.assertIs(add.lazydata.realized, b.lazydata.realized)
self.assertListEqual(add.tolist(), [2, 2, 2, 2])
# b + (a*0 padded with value 2): one item is scheduled and the result lives in its own buffer
def test_src_masked_const_folding(self):
# TRACK_MATCH_STATS is set to 0 inside this Context
with Context(TRACK_MATCH_STATS=0):
a = Tensor.full((4,), 1).contiguous().realize()
b = Tensor.full((6,), 2).contiguous().realize()
mul0 = a*0
add = b+mul0.pad((1, 1), value=2)
sched = add.schedule()
self.assertEqual(len(sched), 1)
run_schedule(sched)
# add gets assigned to a new buffer
self.assertIsNot(add.lazydata.realized, b.lazydata.realized)
self.assertListEqual(add.tolist(), [4, 2, 2, 2, 2, 4])
# ** part 3: Tensor variable bindings
# scheduling a tensor that is only a bound variable; currently expected to fail
@unittest.expectedFailure # TODO: should schedule assert if you try to realize a Variable?
def test_var_schedule(self):
vv = UOp.variable("a", 0, 10).bind(1)
a = Tensor(vv)
sched = a.schedule()
self.assertEqual(len(sched), 0)
# variable bound to 1 plus 2: one item is scheduled, run with the returned var_vals, and the result is 3
def test_add_tvar(self):
vv = UOp.variable("a", 0, 10).bind(1)
a = Tensor(vv)+2
sched, var_vals = a.schedule_with_vars()
self.assertEqual(len(sched), 1)
run_schedule(sched, var_vals)
self.assertEqual(a.tolist(), 3)
# run the whole suite with verbose output when the file is executed as a script
if __name__ == "__main__":
  unittest.main(verbosity=2)