import numpy as np
from tinygrad import Tensor, Device, GlobalCounters
from tinygrad.helpers import Timing

d0, d1 = f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2"
N = 256
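# 2*N^3 FLOPs for an N x N matmul: N*N output elements, each a length-N dot
# product costing N multiplies and N adds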
FLOPS = N*N*N*2

# LazyBuffer should make three fields lists: self.st (all must have the same shape), self.realized, and self.device
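# One possible shape of that change (hypothetical sketch, not tinygrad's
# actual internals):
#
#   class ShardedLazyBuffer:
#     def __init__(self, sts, realized, devices):
#       assert len(sts) == len(realized) == len(devices)
#       assert all(st.shape == sts[0].shape for st in sts)
#       self.st, self.realized, self.device = sts, realized, devices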

def explicit_shard_W_axis_1(X, W):
  Xs = [X.to(d0), X.to(d1)]
  Ws = [W[:, :N//2].to(d0), W[:, N//2:].to(d1)]   # TODO: these shouldn't make copies on the original device
  # pad them to form the correct size
  Ws = [Ws[0].pad((None, (0,N//2))), Ws[1].pad((None, (N//2,0)))]
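  # each shard now holds W's real values in its half of the columns and zeros
  # elsewhere, so X @ Ws[k] is exact in that half and zero in the other;
  # summing the two partial products at the end reconstructs X @ W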
  for x in Xs: assert x.shape == X.shape
  for w in Ws: assert w.shape == W.shape

  # TODO: it shouldn't be faster with these realize
  for x in Xs+Ws: x.realize()
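  # lm spells out the matmul with movement ops only: expand x as x[i,k] and
  # w.T as w[k,j] to (N,N,N), multiply elementwise, and reduce over k, giving
  # out[i,j] = sum_k x[i,k]*w[k,j] without any cross-device operations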
  def lm(x:Tensor, w:Tensor):
    # these are movement ops on the local device
    x = x.reshape(N, 1, N).expand(N, N, N)
    w = w.T.reshape(1, N, N).expand(N, N, N)
    m = x*w
    assert m.uop.st.views[0].mask is not None
    ret = m.sum(2)
    return ret
  #Os = [lm(Xs[0], Ws[0]), lm(Xs[1], Ws[1])]
  Os = [Xs[0] @ Ws[0], Xs[1] @ Ws[1]]
  for x in Os: x.realize()
  return Os[0].to(Device.DEFAULT) + Os[1].to(Device.DEFAULT)

  #return Tensor.cat(*[x.to(Device.DEFAULT) for x in Os], dim=1)   # TODO: someday we can remove this copy too

def matmul(X, W):
  return explicit_shard_W_axis_1(X, W)
  #return X@W

if __name__ == "__main__":
  with Timing("init devices: "):
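    # indexing Device instantiates and caches each backend, so device startup
    # cost isn't charged to the timed matmul below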
    Device[d0], Device[d1]

  with Timing("create tensors: "):
    X = Tensor.kaiming_uniform(N, N).realize()
    W = Tensor.kaiming_uniform(N, N).realize()

  #with Timing("warmup: "):
  #  O = matmul(X, W).numpy()

  GlobalCounters.reset()
  print("******** multiply start")
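  # Timing hands the callback the elapsed time in nanoseconds, so FLOPS/ns
  # comes out directly in GFLOPS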
  with Timing("******** multiply done: ", lambda x: f"  {FLOPS/x:.2f} GFLOPS"):
    O = matmul(X, W).realize()
    Device[Device.DEFAULT].synchronize()

  with Timing("testing: "):
    val = X.numpy() @ W.numpy()
    np.testing.assert_allclose(val, O.numpy(), atol=1e-5)