You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							33 lines
						
					
					
						
							1.3 KiB
						
					
					
				
			
		
		
	
	
							33 lines
						
					
					
						
							1.3 KiB
						
					
					
				from tinygrad.helpers import getenv
 | 
						|
from tinygrad import dtypes, Tensor
 | 
						|
 | 
						|
dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
 | 
						|
acc_dtype = dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else None
 | 
						|
 | 
						|
CNT = getenv("CNT", 8)
 | 
						|
BS = getenv("BS", 16)
 | 
						|
CIN = getenv("CIN", 128)
 | 
						|
COUT = getenv("COUT", 128)
 | 
						|
HW = getenv("HW", 128)
 | 
						|
K = getenv("K", 3)
 | 
						|
PADDING = getenv("PADDING", 1)
 | 
						|
COMP = getenv("COMP", 0)
 | 
						|
ATOL = getenv("ATOL", 1e-4)
 | 
						|
RTOL = getenv("RTOL", 3e-2)
 | 
						|
 | 
						|
FLOPS = BS*K*K*CIN*HW*HW*COUT*2
 | 
						|
def rand_input(): return Tensor.rand(BS, CIN, HW, HW, dtype=dtype_in).realize(), Tensor.rand(COUT, CIN, K, K, dtype=dtype_in).realize()
 | 
						|
 | 
						|
if __name__ == "__main__":
 | 
						|
  a, b = rand_input()
 | 
						|
  for i in range(CNT):
 | 
						|
    if i > 0 and getenv("RAND", 0) != 0:
 | 
						|
      a, b = rand_input()
 | 
						|
    c = a.conv2d(b, padding=PADDING, dtype=acc_dtype).realize()
 | 
						|
 | 
						|
  if COMP:
 | 
						|
    import numpy as np, time, torch
 | 
						|
    torch_device = "cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu")
 | 
						|
    ta, tb = torch.from_numpy(a.numpy()).to(torch_device), torch.from_numpy(b.numpy()).to(torch_device)
 | 
						|
    tc = torch.nn.functional.conv2d(ta, tb, padding=PADDING)
 | 
						|
    np.testing.assert_allclose(c.numpy(), tc.cpu(), atol=ATOL, rtol=RTOL)
 | 
						|
 |