import unittest
from tinygrad import Tensor, UOp, Context
from tinygrad.dtype import AddrSpace
from tinygrad.uop.ops import KernelInfo, AxisType
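# Tests for Tensor.custom_kernel: each function below builds a kernel directly from
# UOps by indexing buffer UOps inside UOp.range loops, storing results, closing the
# loops with .end, and wrapping everything in a .sink tagged with a KernelInfo name.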
# **** kernels ****
def custom_arange_kernel(C:UOp) -> UOp:
  i = UOp.range(C.size, 0)  # one flat loop over every element
  return C[i].store(i.cast(C.dtype.base)).end(i).sink(arg=KernelInfo(name=f"custom_arange_{C.size}"))
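# A hedged minimal sketch of the same pattern (not exercised by the tests below):
# an identity copy, using only the range/store/end/sink calls shown above.
def custom_copy_kernel(B:UOp, A:UOp) -> UOp:
  B, A = B.flatten(), A.flatten()
  assert B.size == A.size
  i = UOp.range(B.size, 0)
  return B[i].store(A[i]).end(i).sink(arg=KernelInfo(name=f"custom_copy_{B.size}"))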
def custom_eye_kernel(C:UOp) -> UOp:
  i = UOp.range(C.shape[0], 0)
  j = UOp.range(C.shape[1], 1)
  # store 1 on the diagonal (i == j), 0 elsewhere
  return C[i, j].store((i.eq(j)).cast(C.dtype.base)).end(i, j).sink(arg=KernelInfo(name=f"custom_eye_{C.size}"))
def custom_add_one_kernel(B:UOp, A:UOp) -> UOp:
  A,B = A.flatten(), B.flatten()
  assert B.size == A.size
  i = UOp.range(A.size, 0)
  return B[i].store(A[i] + 1).end(i).sink(arg=KernelInfo(name=f"add_one_{A.size}"))
def custom_elementwise_add_kernel(C:UOp, A:UOp, B:UOp) -> UOp:
  C,A,B = C.flatten(), A.flatten(), B.flatten()
  i = UOp.range(C.size, 0)
  return C[i].store(A[i]+B[i]).end(i).sink(arg=KernelInfo(name=f"custom_add_kernel_{C.size}")).simplify()
def custom_elementwise_addmul_kernel(C:UOp, D:UOp, A:UOp, B:UOp) -> UOp:
  C,D,A,B = C.flatten(), D.flatten(), A.flatten(), B.flatten()
  assert C.size == D.size
  i = UOp.range(C.size, 0)
  store_c = C[i].store(A[i]+B[i])
  store_d = D[i].store(A[i]*B[i])
  # group the two stores so both outputs are written inside the same loop/kernel
  return UOp.group(store_c, store_d).end(i).sink(arg=KernelInfo(name=f"custom_addmul_kernel_{C.size}")).simplify()
def custom_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
  assert A.shape[1] == B.shape[0]
  # i, j index the output; k is the reduction over the shared inner dimension
  i, j, k = UOp.range(C.shape[0], 0), UOp.range(C.shape[1], 1), UOp.range(A.shape[1], 2, axis_type=AxisType.REDUCE)
  C = C[i, j].set(0.0)  # zero-init the accumulator
  C = C[i, j].set(C.after(k)[i, j] + A[i, k] * B[k, j], end=k)  # accumulate over k
  prog = C.end(i, j)
  return prog.sink(arg=KernelInfo(name=f"custom_gemm_{C.shape[0]}_{C.shape[1]}_{A.shape[1]}", opts_to_apply=()))
def custom_sum(B:UOp, A:UOp) -> UOp:
  i = UOp.range(A.shape[0], 0, axis_type=AxisType.REDUCE)
  B = B[0].set(0.0)
  B = B[0].set(B.after(i)[0] + A[i], end=i)
  return B.sink(arg=KernelInfo(name=f"custom_sum_{A.shape[0]}", opts_to_apply=()))
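# Hedged sketch (not exercised below): the same set/after/end accumulator idiom
# computes a product; only the init value and the combine op change.
def custom_prod(B:UOp, A:UOp) -> UOp:
  i = UOp.range(A.shape[0], 0, axis_type=AxisType.REDUCE)
  B = B[0].set(1.0)                          # multiplicative identity
  B = B[0].set(B.after(i)[0] * A[i], end=i)  # accumulate the product over i
  return B.sink(arg=KernelInfo(name=f"custom_prod_{A.shape[0]}", opts_to_apply=()))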
def flip_contract_kernel(dest:UOp, src:UOp) -> UOp:
  i = UOp.range(dest.shape[0], 0)
  j = UOp.range(dest.shape[1], 1, AxisType.UPCAST)
  vec = src[i, j].contract(j)  # pack the UPCAST axis into a 4-wide vector
  # gep(3-k) reads the vector lanes in reverse order, flipping each row
  store = UOp.group(*[dest[i, k].store(vec.gep(3-k)) for k in range(4)])
  return store.end(i).sink(arg=KernelInfo(name=f"flip_contract_{dest.size}", opts_to_apply=()))
def slice_sum_kernel(dest:UOp, src:UOp) -> UOp:
  G = UOp.range(src.shape[0], 0)
  slice_src = src[G, :]
  reg = UOp.placeholder((1,), dest.dtype.base, 0, addrspace=AddrSpace.REG)
  reg = reg.after(G)[0].set(0)
  R = UOp.range(src.shape[1], 1, AxisType.REDUCE)
  reg = reg[0].set(reg.after(R)[0] + slice_src[R], end=R)
  ast = dest[G].set(reg[0], end=G)
  return ast.sink(arg=KernelInfo(name=f"slice_sum_{src.shape[0]}_{src.shape[1]}", opts_to_apply=()))
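# Note on the pattern above: UOp.placeholder with AddrSpace.REG declares a register
# (thread-local) accumulator, and .after(G) pins its initialization inside the outer
# loop so each row's reduction starts from zero.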
def simple_qkv_kernel(O:UOp, Q:UOp, K:UOp, V:UOp) -> UOp:
  # attention without softmax
  N, d = Q.shape[0], Q.shape[1]
  i = UOp.range(N, 0)      # output row
  d_out = UOp.range(d, 1)  # output column
  j = UOp.range(N, 2, axis_type=AxisType.REDUCE)
  k_inner = UOp.range(d, 3, axis_type=AxisType.REDUCE)
  # inner product Q[i] . K[j], accumulated in a register
  qk_acc = UOp.placeholder((1,), Q.dtype.base, 0, addrspace=AddrSpace.REG)
  qk_acc = qk_acc.after(i, j)[0].set(0.0)
  qk_acc = qk_acc[0].set(qk_acc.after(k_inner)[0] + Q[i, k_inner] * K[j, k_inner], end=k_inner)
  qk_score = qk_acc[0] / (d ** 0.5)  # scale by sqrt(d)
  # weighted sum of V rows by the (unnormalized) scores
  out_acc = UOp.placeholder((1,), Q.dtype.base, 1, addrspace=AddrSpace.REG)
  out_acc = out_acc.after(i, d_out)[0].set(0.0)
  out_acc = out_acc[0].set(out_acc.after(j)[0] + qk_score * V[j, d_out], end=j)
  store = O[i, d_out].store(out_acc[0])
  return store.end(d_out).end(i).sink(arg=KernelInfo(name=f"simple_qkv_{N}_{d}", opts_to_apply=()))
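# Full scaled-dot-product attention would apply a row-wise softmax to qk_score before
# the V accumulation; that needs extra reduction passes over j (max and sum) and is
# deliberately omitted here.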
# **** backward callbacks ****
def backward_gemm(gradient:UOp, kernel:UOp) -> tuple[None, UOp, UOp]:
  out, a, b = kernel.src
  # d(A@B)/dA = grad @ B^T, d(A@B)/dB = A^T @ grad; None means no grad for the output buffer
  grad_a = (Tensor(gradient) @ Tensor(b).T).uop
  grad_b = (Tensor(a).T @ Tensor(gradient)).uop
  return (None, grad_a, grad_b)
def backward_gemm_custom(gradient:UOp, kernel:UOp) -> tuple[None, UOp, UOp]:
  out, a, b = kernel.src
  # same gradients as backward_gemm, but computed with the custom_gemm kernel itself
  grad_a = Tensor.empty_like(Tensor(a)).custom_kernel(Tensor(gradient), Tensor(b).T, fxn=custom_gemm)[0].uop
  grad_b = Tensor.empty_like(Tensor(b)).custom_kernel(Tensor(a).T, Tensor(gradient), fxn=custom_gemm)[0].uop
  return (None, grad_a, grad_b)
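# Hedged sketch (not wired into any test): a grad_fxn for custom_elementwise_add_kernel.
# For C = A + B the gradient passes through unchanged to both inputs; None fills the
# output slot, matching the convention in backward_gemm above.
def backward_add(gradient:UOp, kernel:UOp) -> tuple[None, UOp, UOp]:
  out, a, b = kernel.src
  return (None, gradient, gradient)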
# **** tests ****
class TestCustomKernel(unittest.TestCase):
  def test_simple(self):
    a = Tensor.ones(16, 16).contiguous()
    b = Tensor.ones(16, 16).contiguous()
    c = Tensor.empty(16, 16)
    c = Tensor.custom_kernel(c,a,b, fxn=custom_elementwise_add_kernel)[0]
    out = c.flatten().tolist()
    assert all(x == 2 for x in out), "all 2"
  def test_multioutput(self):
    a = Tensor.full((16, 16), 3.).contiguous()
    b = Tensor.full((16, 16), 3.).contiguous()
    c = Tensor.empty(16, 16)
    d = Tensor.empty(16, 16)
    c,d = Tensor.custom_kernel(c,d,a,b, fxn=custom_elementwise_addmul_kernel)[:2]
    Tensor.realize(c,d)
    assert all(x == 6 for x in c.flatten().tolist()), "all 6"
    assert all(x == 9 for x in d.flatten().tolist()), "all 9"
  def test_arange(self):
    ref = Tensor.arange(100)
    tst = Tensor.empty_like(ref)
    tst = tst.custom_kernel(fxn=custom_arange_kernel)[0]
    self.assertTrue((ref == tst).all().item())
  def test_eye(self):
    ref = Tensor.eye(1024).contiguous().realize()
    tst = Tensor.empty_like(ref)
    tst = tst.custom_kernel(fxn=custom_eye_kernel)[0]
    self.assertTrue((ref == tst).all().item())
  def test_flip_contract(self):
    a = Tensor.randn(10,4)
    b = Tensor.empty_like(a)
    b = b.custom_kernel(a, fxn=flip_contract_kernel)[0]
    self.assertTrue((a.flip(1) == b).all().item())
  def test_noncontig(self):
    a = Tensor.ones(16, 16).contiguous()
    tst = Tensor.empty_like(a)
    b = a+1
    b_p1 = Tensor.custom_kernel(tst, b, fxn=custom_add_one_kernel)[0]
    self.assertTrue((b_p1 == 3).all().item())
  def test_sum(self):
    # TODO: this only works for float, and silently fails with int
    a = Tensor([1.0, 2, 3, 4, 5])
    tst = Tensor.empty(1)
    b = Tensor.custom_kernel(tst, a, fxn=custom_sum)[0]
    self.assertEqual(b.item(), 15)
  def test_slice_sum(self):
    A = Tensor.randn(16, 16).contiguous()
    B = Tensor.empty(16)
    B = Tensor.custom_kernel(B, A, fxn=slice_sum_kernel)[0]
    self.assertTrue(B.allclose(A.sum(1)))
  def test_gemm(self):
    N = 16
    a = Tensor.randn(N, N)
    b = Tensor.randn(N, N)
    c = Tensor.empty(N, N)
    tst = Tensor.custom_kernel(c, a, b, fxn=custom_gemm)[0]
    err = (tst - (a@b)).square().max()
    self.assertLess(err.item(), 1e-6)
  def test_gemm_backward_custom(self): self.test_gemm_backward(True)
  # NOTE: grad_fxn doesn't work with pyrender
  @Context(SPEC=1)
  def test_gemm_backward(self, custom_backward_gemm=False):
    N = 4
    a_rand = Tensor.randn(N, 8)
    b_rand = Tensor.randn(8, N)
    Tensor.realize(a_rand, b_rand)
    a, b = Tensor(a_rand.numpy(), requires_grad=True), Tensor(b_rand.numpy(), requires_grad=True)
    c = Tensor.empty(N, N)
    tst = Tensor.custom_kernel(c, a, b, fxn=custom_gemm, grad_fxn=backward_gemm_custom if custom_backward_gemm else backward_gemm)[0]
    tst.sum().backward()
    grad_a, grad_b = a.grad, b.grad
    Tensor.realize(tst, grad_a, grad_b)
    a, b = Tensor(a_rand.numpy(), requires_grad=True), Tensor(b_rand.numpy(), requires_grad=True)
    ref = (a@b)
    ref.sum().backward()
    real_grad_a, real_grad_b = a.grad, b.grad
    Tensor.realize(ref, real_grad_a, real_grad_b)
    err = (tst - ref).square().max()
    self.assertLess(err.item(), 1e-6)
    err = (grad_a - real_grad_a).square().max()
    self.assertLess(err.item(), 1e-6)
    err = (grad_b - real_grad_b).square().max()
    self.assertLess(err.item(), 1e-6)
  def test_simple_qkv(self):
    N, d = 8, 4
    Q = Tensor.randn(N, d)
    K = Tensor.randn(N, d)
    V = Tensor.randn(N, d)
    O = Tensor.empty(N, d)
    O_custom = Tensor.custom_kernel(O, Q, K, V, fxn=simple_qkv_kernel)[0]
    O_ref = ((Q @ K.T) / (d ** 0.5)) @ V
    Tensor.realize(O_custom, O_ref)
    err = (O_custom - O_ref).square().max()
    self.assertLess(err.item(), 1e-6)
if __name__ == '__main__':
  unittest.main()