# benchmark a 3x3 conv on Apple Metal: torch (MPS) vs tinygrad
import os
#os.environ["METAL"] = "1"
import numpy as np
# time is used by both timing loops; import it here instead of inside the torch
# try block so the tinygrad half works even when torch is missing
import time

BS = 64
CIN = 256
COUT = 256
HW = 32
K = 3
PADDING = 0
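# with PADDING=0 and K=3 the output is (BS, COUT, HW-K+1, HW-K+1) = (64, 256, 30, 30)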

# TODO: this is doing some trick, since with CIN=256 COUT=256 it's over 10.4 TFLOPS.
# are winograd convs less flops? it appears so if they are batched
# https://www.cse.ust.hk/~weiwa/papers/yan-ppopp20.pdf
FLOPS = BS*K*K*CIN*HW*HW*COUT*2
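# for these sizes FLOPS = 64*3*3*256*32*32*256*2 ~= 77.3e9, i.e. ~77.3 GFLOP per conv.
# the count uses the input size HW*HW, so with PADDING=0 (30x30 output) it slightly
# overstates the work a direct conv actually does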

nb = np.random.default_rng().standard_normal(size=(BS,CIN,HW,HW), dtype=np.float32)
nc = np.random.default_rng().standard_normal(size=(COUT,CIN,K,K), dtype=np.float32)
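
# both benchmarks below run the conv repeatedly and report the min, so one-time
# costs (MPS kernel compilation, tinygrad JIT capture) drop out of the number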
try:
  import torch, torch.mps
  b = torch.from_numpy(nb).to('mps')
  c = torch.from_numpy(nc).to('mps')

  def torch_prog(b, c):
    st = time.perf_counter()
    a = torch.nn.functional.conv2d(b, c, padding=PADDING)
    torch.mps.synchronize()
    return time.perf_counter() - st
  tm = min([torch_prog(b, c) for _ in range(20)])
  print(f"{tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS conv in torch")
# also catch ImportError so the tinygrad half still runs without torch installed
except (ImportError, RuntimeError):
  print("no torch metal conv")

from tinygrad.tensor import Tensor
from tinygrad.engine.jit import TinyJit
from tinygrad import Device
b = Tensor(nb)
c = Tensor(nc)
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
@TinyJit
def tiny_jit(b, c):
  return b.conv2d(c, padding=PADDING).realize()
def tiny_prog(b, c):
  st = time.perf_counter()
  a = tiny_jit(b, c)
  Device[a.device].synchronize()
  return time.perf_counter() - st
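# the first couple of tiny_jit calls run eagerly while TinyJit captures its
# kernels; later calls replay the capture, so min() reports the replayed best case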
tm = min([tiny_prog(b, c) for _ in range(5)])
print(f"{tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS conv in tinygrad")
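
# optional sanity check, a sketch that is not part of the original benchmark:
# compare the JIT'd tinygrad conv against a torch CPU reference when torch is
# importable; the float32 tolerances here are a guess
try:
  import torch
  ref = torch.nn.functional.conv2d(torch.from_numpy(nb), torch.from_numpy(nc), padding=PADDING).numpy()
  assert np.allclose(tiny_jit(b, c).numpy(), ref, atol=1e-2, rtol=1e-3), "tinygrad conv mismatch vs torch CPU"
  print("tinygrad conv matches torch CPU reference")
except ImportError:
  pass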