You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							90 lines
						
					
					
						
							2.4 KiB
						
					
					
				
			
		
		
	
	
							90 lines
						
					
					
						
							2.4 KiB
						
					
					
				| import numpy as np
 | |
| import halide as hl
 | |
| from tinygrad.helpers import Timing, getenv
 | |
| 
 | |
# HL_DEBUG_CODEGEN=1
# Problem size: the script multiplies N x N matrices. The value is read
# through getenv with the key "N" and falls back to 1024.
N = getenv("N", 1024)
 | |
| 
 | |
def gemm_pipeline(gpu=False):
  """Build a Halide pipeline computing C[i, j] = sum over k of A[i, k] * B[k, j].

  Both inputs are declared as 2-D float32 input buffers, and each of their
  dimensions is bounded to start at 0 with extent N (the module-level size).

  Args:
    gpu: if False, the output is tiled, the tiles are run in parallel and the
      inner column loop is vectorized. If True, the output is mapped onto GPU
      tiles with gpu_tile.

  Returns:
    A tuple (C, A, B): the output Func followed by the two input buffer
    parameters, which the caller must bind before realizing C.
  """
  # ---------------- Vars & Parameters ----------------
  i, j = hl.Var("i"), hl.Var("j")  # output tile coordinates

  A = hl.InputBuffer(hl.Float(32), 2)  # [M, K]
  B = hl.InputBuffer(hl.Float(32), 2)  # [K, N]

  # Pin both dimensions of both inputs to [0, N).
  for operand in (A, B):
    operand.dim(0).set_bounds(0, N)
    operand.dim(1).set_bounds(0, N)

  # ---------------- Definition ----------------

  k = hl.RDom([(0, N)])  # reduction over the shared dimension

  # Accumulator: zero-initialised, then one update stage summing the products.
  partial = hl.Func("partial")
  partial[i, j] = 0.0
  partial[i, j] += A[i, k] * B[k, j]

  # C is a pure wrapper around the accumulator; it has no update stage.
  C = hl.Func("C")
  C[i, j] = partial[i, j]

  if not gpu:
    # ---------------- Schedule ----------------
    VEC = 16
    TILE_I = 64
    TILE_J = 64

    io, jo, ii, ji = hl.Var("io"), hl.Var("jo"), hl.Var("ii"), hl.Var("ji")
    tile = hl.Var("tile")  # fused (io, jo) index, one value per output tile
    # Schedule the pure stage of C. C has no update definition, so asking for
    # C.update() is rejected by Halide before any scheduling happens.
    C.tile(i, j, io, jo, ii, ji, TILE_I, TILE_J).fuse(io, jo, tile).parallel(tile).vectorize(ji, VEC)
  else:
    # ---------------- Schedule ----------------
    GRP_I    = 8     # output tile size
    GRP_J    = 16

    #partial.store_in(hl.MemoryType.Register)
    #partial.update().unroll(k, 4)

    io, jo, ii, ji = hl.Var(), hl.Var(), hl.Var(), hl.Var()
    C.gpu_tile(i, j, io, jo, ii, ji, GRP_I, GRP_J, hl.TailStrategy.RoundUp)

  return C, A, B
 | |
| 
 | |
if __name__ == "__main__":
  # Build the GPU-scheduled pipeline together with its two input parameters.
  pipe, A, B = gemm_pipeline(gpu=True)

  # NOTE: metal does nothing
  target = hl.get_host_target().with_feature(hl.TargetFeature.Metal)

  # Random float32 operands, N x N each.
  lhs_np = np.random.randn(N, N).astype(np.float32)
  rhs_np = np.random.randn(N, N).astype(np.float32)

  # reverse order is correct!
  # NOTE(review): presumably hl.Buffer exposes a numpy array with its axes
  # reversed, so feeding the operands swapped yields lhs @ rhs once the
  # result is read back through numpy - confirm against the Halide docs.
  first_input = hl.Buffer(rhs_np)
  second_input = hl.Buffer(lhs_np)
  A.set(first_input)
  B.set(second_input)

  # Dump the lowered statement as HTML for inspection.
  pipe.compile_to_lowered_stmt("/tmp/my_function.html", [A, B], hl.StmtOutputFormat.HTML, target=target)
  #exit(0)

  out_buf = hl.Buffer(hl.Float(32), [N, N])

  # First timed run; the result is also copied out to numpy here.
  with Timing("halide gemm "):
    pipe.realize(out_buf, target)
    out_buf.copy_to_host()
    result = np.array(out_buf)
  print(result)

  # tinygrad gets 60 ms with no BEAM, 20 ms with BEAM on CPU
  # Second timed run of the same pipeline into the same buffer.
  with Timing("halide gemm "):
    pipe.realize(out_buf, target)
    out_buf.copy_to_host()

  # Check correctness against numpy, using the result of the first run.
  with Timing("numpy gemm "):
    expected = lhs_np @ rhs_np
  max_err = np.max(np.abs(expected - result))
  print("Max absolute error:", max_err)
  assert max_err < 1e-4, "GEMM result incorrect!"

  print("Pipeline ran on", target)
  print("Success - GEMM Halide-Python output matches NumPy.")
 | |
| 
 |