# openpilot is an open-source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

# 60 lines
# 1.9 KiB

import numpy as np
from tinygrad import Tensor, Device, GlobalCounters
from tinygrad.helpers import Timing
# The two devices W is sharded across: instances 1 and 2 of the default backend
# (presumably separate from the default instance, where inputs are created and the
# final result is gathered).
d0, d1 = (f"{Device.DEFAULT}:{idx}" for idx in (1, 2))
# side length of the square X and W matrices used by the benchmark below
N = 256
# floating-point ops for one N x N @ N x N matmul: N**3 multiply-adds, 2 ops each
FLOPS = 2 * N**3
# LazyBuffer should make three fields lists: self.st (all must have the same shape), self.realized, and self.device
def explicit_shard_W_axis_1(X, W):
  """Compute X @ W with W sharded column-wise (axis 1) across devices d0 and d1.

  X is copied whole to both devices. W is split into a left shard (its first
  W.shape[1]//2 columns) and a right shard (the remaining columns). Each shard is
  copied to its own device and zero-padded back to W's full shape. Because the two
  padded shards sum to W, X @ W == X @ Ws[0] + X @ Ws[1]. Each device computes one
  partial product; both are copied to Device.DEFAULT and added there.

  The split sizes come from W's own shape, so odd column counts and shapes other
  than N x N are handled.

  Args:
    X: left operand; presumably 2-D with X.shape[-1] == W.shape[0] (TODO confirm).
    W: 2-D right operand (sliced as W[:, a:b] below).

  Returns:
    Tensor on Device.DEFAULT equal to X @ W.
  """
  cols = W.shape[1]
  half = cols // 2
  Xs = [X.to(d0), X.to(d1)]
  Ws = [W[:, :half].to(d0), W[:, half:].to(d1)]  # TODO: these shouldn't make copies on the original device
  # Zero-pad each shard back to W's full width. The left shard holds `half` columns
  # and needs `cols - half` more on the right. The right shard holds `cols - half`
  # and needs `half` on the left. The two amounts differ when cols is odd.
  Ws = [Ws[0].pad((None, (0, cols - half))), Ws[1].pad((None, (half, 0)))]
  for x in Xs: assert x.shape == X.shape
  for w in Ws: assert w.shape == W.shape
  # TODO: it shouldn't be faster with these realize
  for x in Xs+Ws: x.realize()

  def lm(x:Tensor, w:Tensor):
    # Manual matmul: broadcast x as (rows, 1, inner) against w.T as (1, out, inner),
    # multiply elementwise, then reduce over the inner axis -> (rows, out).
    rows, inner = x.shape
    out = w.shape[1]
    # these are movement ops on the local device
    x = x.reshape(rows, 1, inner).expand(rows, out, inner)
    w = w.T.reshape(1, out, inner).expand(rows, out, inner)
    m = x*w
    # presumably the zero padding of w survives as a mask on the view (tinygrad internals)
    assert m.lazydata.st.views[0].mask is not None
    return m.sum(2)

  # alternative path using the manual lm kernel above:
  #Os = [lm(Xs[0], Ws[0]), lm(Xs[1], Ws[1])]
  Os = [Xs[0] @ Ws[0], Xs[1] @ Ws[1]]
  for x in Os: x.realize()
  return Os[0].to(Device.DEFAULT) + Os[1].to(Device.DEFAULT)
  #return Tensor.cat(*[x.to(Device.DEFAULT) for x in Os], dim=1)  # TODO: someday we can remove this copy too
def matmul(X, W):
  """Return X @ W computed via the explicit two-device, column-sharded path.

  This is a single switch point for the benchmark's multiply strategy. The plain
  single-device alternative would be `return X @ W`.
  """
  sharded_result = explicit_shard_W_axis_1(X, W)
  return sharded_result
if __name__ == "__main__":
  # Look up each sharding device once, up front, so any device setup (presumably
  # lazy on first Device[...] access) is timed here rather than in the multiply.
  with Timing("init devices: "):
    for dev in (d0, d1):
      Device[dev]

  # Two random N x N inputs, realized now so their creation isn't in the multiply timing.
  with Timing("create tensors: "):
    X = Tensor.kaiming_uniform(N, N).realize()
    W = Tensor.kaiming_uniform(N, N).realize()

  #with Timing("warmup: "):
  #  O = matmul(X, W).numpy()

  def gflops_suffix(elapsed):
    # Timing hands its exit callback the elapsed time (presumably nanoseconds, so
    # FLOPS / elapsed reads directly as GFLOP/s).
    return f" {FLOPS/elapsed:.2f} GFLOPS"

  GlobalCounters.reset()
  print("******** multiply start")
  with Timing("******** multiply done: ", gflops_suffix):
    O = matmul(X, W).realize()
    # synchronize the default device before the timer stops
    Device[Device.DEFAULT].synchronize()

  # Check the sharded result against a plain NumPy matmul.
  with Timing("testing: "):
    reference = X.numpy() @ W.numpy()
    np.testing.assert_allclose(reference, O.numpy(), atol=1e-5)