import unittest
import numpy as np
from tinygrad import Tensor, GlobalCounters, Context, Device
from tinygrad.dtype import DTypeLike, dtypes
from tinygrad.helpers import DEBUG, get_single_element
from tinygrad.engine.realize import lower_schedule_item
from tinygrad.device import is_dtype_supported
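
# single_kernel_softmax computes a numerically stable softmax without kernel breaks:
#   softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x))
# a plain numpy reference of what it should match (a sketch, not used by the tests):
#   def np_softmax(x, axis=-1):
#     e = np.exp(x - x.max(axis=axis, keepdims=True))
#     return e / e.sum(axis=axis, keepdims=True)
# the trick below recomputes both reduces per output element (O(r_dim^2) work), so the
# scheduler never has to materialize the max or the sum in a separate kernel.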
def single_kernel_softmax(x_in:Tensor, axis=-1, dtype:DTypeLike|None=None) -> Tensor:
  # only supports axis=-1
  x = x_in.reshape(-1, x_in.shape[-1])
  nr_dim, r_dim = x.shape

  # first pass: the denominator. expand over r_dim so the sum of exps is recomputed per row element
  inp = x.reshape(nr_dim, 1, 1, r_dim).expand(nr_dim, r_dim, 1, r_dim)
  imx = x.reshape(nr_dim, 1, r_dim, 1).expand(nr_dim, r_dim, r_dim, r_dim).max(axis=-2, keepdim=True)
  m = inp - imx.detach()
  if dtype is not None: m = m.cast(dtype)
  e = m.exp()
  ss = e.sum(axis=-1, keepdim=True)

  # second pass: the numerator. exp(x - max) per element, divided by the shared denominator
  inp = x.reshape(nr_dim, r_dim, 1, 1)
  imx = x.reshape(nr_dim, 1, r_dim, 1).expand(nr_dim, r_dim, r_dim, 1).max(axis=-2, keepdim=True)
  m = inp - imx.detach()
  if dtype is not None: m = m.cast(dtype)
  e = m.exp()

  out = e.div(ss).reshape(x_in.shape)
  return out
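
# helper: lower and run a schedule that is expected to contain exactly one item.
# get_single_element raises unless the schedule has exactly one kernel, so calling this
# doubles as an assertion that fusion actually produced a single kernel.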
def run_one_schedule_item(out): lower_schedule_item(get_single_element(out.schedule())).run()

class TestFuse(unittest.TestCase):
  def _test_fuse(self, fxn, *args, atol=1e-7, allow_multiple=False, **kwargs):
    # run the fused version (asserting one kernel unless allow_multiple), then compare against the unfused result
    GlobalCounters.reset()
    out_single = fxn(*args, **kwargs).fuse()
    if not allow_multiple: run_one_schedule_item(out_single)
    np_single = out_single.numpy()
    GlobalCounters.reset()
    np_multi = fxn(*args, **kwargs).numpy()
    np.testing.assert_allclose(np_single, np_multi, atol=atol)

  def test_fuse_norm(self):
    a = Tensor.rand(50,50).realize()
    self._test_fuse(lambda a: a / a.mean(axis=1), a)

  def test_fuse_argmax(self):
    a = Tensor.rand(50,50).realize()
    self._test_fuse(lambda a: a.argmax(axis=-1), a)

  def test_fuse_softmax(self):
    a = Tensor.rand(50,50).realize()
    self._test_fuse(lambda a: a.softmax(axis=-1), a)
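
  # the contiguous() in the middle forces a kernel boundary, so this can't collapse into a
  # single kernel; allow_multiple=True skips the one-schedule-item check and only compares values.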
  def test_fuse_gemm_softmax(self):
    a = Tensor.rand(50,50).realize()
    b = Tensor.rand(50,50).realize()
    self._test_fuse(lambda a,b: ((a@b).relu()+a).contiguous().softmax(axis=-1), a,b, allow_multiple=True)

  @unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
  def test_fuse_softmax_dtype(self):
    a = Tensor.rand(50,50).realize()
    self._test_fuse(lambda a: a.softmax(axis=-1, dtype='half'), a, atol=3e-4)
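
  # building eye(10) from two aranges: a column arange compared against a row arange is
  # true exactly on the diagonal. both aranges and the compare should fuse into one kernel.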
  def test_fuse_arange_eye(self):
    self._test_fuse(lambda: Tensor.arange(10).reshape(10,1).expand(10,10) == Tensor.arange(10).reshape(1,10).expand(10,10))

  def test_double_gemm(self):
    N = 32
    with Context(TRACK_MATCH_STATS=0, DEBUG=0):
      a = (Tensor.rand(N,N)-0.5).realize()
      b = (Tensor.rand(N,N)-0.5).realize()
      c = (Tensor.rand(N,N)-0.5).realize()
    self._test_fuse(lambda a,b,c: a@b@c, a, b, c, atol=1e-5)

  def test_embedding(self):
    with Context(TRACK_MATCH_STATS=0, DEBUG=0):
      vocab_sz = 123
      embed_sz = 16
      weight = (Tensor.rand(vocab_sz, embed_sz)-0.5).realize()
      a = Tensor([1, 1, 2, 3]).realize()
    def embedding(idx:Tensor):
      # gather as a one-hot matmul: compare a vocab arange against idx, mask the weights, reduce
      arange = Tensor.arange(vocab_sz).unsqueeze(-1)
      big_shp = idx.shape + (vocab_sz, embed_sz)
      arange, vals = arange.expand(big_shp), weight.expand(big_shp)
      idx = idx.reshape(idx.shape+(1, 1)).expand(big_shp)
      return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype)
    self._test_fuse(embedding, a, atol=1e-5)

  @unittest.skip("still broken")
  def test_flash_attention(self):
    BS = 4
    HEADS = 2
    MATDIM = 16
    EMB = 8
    with Context(TRACK_MATCH_STATS=0, DEBUG=0):
      q = Tensor.randn(BS, HEADS, MATDIM, EMB).realize()
      k = Tensor.randn(BS, HEADS, MATDIM, EMB).realize()
      v = Tensor.randn(BS, HEADS, MATDIM, EMB).realize()
    # TODO: OPT is breaking things. NOOPT isn't linearizing
    with Context(NOOPT=1):
      self._test_fuse(Tensor.scaled_dot_product_attention, q, k, v)
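
# these tests compare the stock (multi-kernel) implementations against hand-built
# single-kernel versions, printing the generated kernels at DEBUG >= 2.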
class TestSoftmaxFusion(unittest.TestCase):
  @classmethod
  def setUpClass(cls):
    with Context(TRACK_MATCH_STATS=0): cls.test = Tensor.rand(32, 10).contiguous().realize()

  def setUp(self):
    GlobalCounters.reset()

  def test_norm(self):
    print("*** norm ***")
    with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
      # NOTE: there's an implied expand on the mean here
      sout = self.test / self.test.mean(-1, keepdim=True)
      sout.realize()

    print("*** single kernel norm ***")
    with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
      inp = self.test.reshape(32, 10, 1)
      div = self.test.reshape(32, 1, 10).expand(32, 10, 10).mean(axis=-1, keepdim=True)
      out = (inp / div).reshape(32, 10)
      out.realize()

    np.testing.assert_allclose(sout.numpy(), out.numpy())

  def test_softmax(self):
    # this is the softmax from scaled_dot_product_attention
    # it becomes 3 kernels
    print("*** softmax ***")
    with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
      sout = self.test.softmax(-1)
      sout.realize()

    print("*** single kernel softmax ***")
    # NOTE: DONT_GROUP_REDUCES is required here
    with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2), DONT_GROUP_REDUCES=1):
      out = single_kernel_softmax(self.test)
      out.realize()

    np.testing.assert_allclose(sout.numpy(), out.numpy())

  def test_auto_softmax(self):
    print("*** softmax ***")
    with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
      sout = self.test.softmax(-1)
      sout.realize()

    print("*** auto single kernel softmax ***")
    with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
      out = self.test.contiguous().softmax(-1).fuse()
      run_one_schedule_item(out)

    np.testing.assert_allclose(sout.numpy(), out.numpy())

  def test_softmax_bw(self):
    print("*** softmax bw ***")
    self.test.requires_grad_()
    with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
      self.test.softmax(-1).sum().backward()
      sg = self.test.grad.realize()
    self.test.grad = None

    print("*** single kernel softmax bw ***")
    # NOTE: DONT_GROUP_REDUCES is required here
    # TODO: fix RecursionError with DONT_GROUP_REDUCES
    with self.assertRaises(RecursionError):
      with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2), DONT_GROUP_REDUCES=1):
        single_kernel_softmax(self.test).sum().backward()
        g = self.test.grad.realize()
        np.testing.assert_allclose(sg.numpy(), g.numpy(), atol=1e-7)

if __name__ == '__main__':
  unittest.main()