openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

90 lines
2.4 KiB

import numpy as np
import halide as hl
from tinygrad.helpers import Timing, getenv
# Tip: run with HL_DEBUG_CODEGEN=1 in the environment to get Halide's debug/codegen output.
# Side length of the square N x N float32 matrices multiplied below. Read through
# tinygrad's getenv, so presumably overridable via the N environment variable (default 1024).
N = getenv("N", 1024)
def gemm_pipeline(gpu=False):
  """Build a Halide pipeline that multiplies two N x N float32 matrices.

  The product is accumulated in the intermediate Func ``partial`` (a pure
  definition initialised to 0.0 plus an update that sums over the RDom ``k``),
  and the output Func ``C`` copies it out:

      C[i, j] = sum_k A[i, k] * B[k, j]

  Indices are in Halide coordinate order (dimension 0 first). NumPy arrays
  wrapped with ``hl.Buffer(ndarray)`` have their axis order reversed, so bind
  operands accordingly.

  Args:
    gpu: if True, schedule ``C`` with ``gpu_tile``; otherwise use a tiled,
      parallel, vectorised CPU schedule.

  Returns:
    ``(C, A, B)``: the output Func and the two ``hl.InputBuffer`` parameters.
    The caller must bind data to ``A`` and ``B`` (e.g. ``A.set(buf)``) before
    realizing ``C``.
  """
  # ---------------- Vars & Parameters ----------------
  i, j = hl.Var("i"), hl.Var("j")  # output coordinates; i is dimension 0
  A = hl.InputBuffer(hl.Float(32), 2)  # indexed A[i, k]
  B = hl.InputBuffer(hl.Float(32), 2)  # indexed B[k, j]
  # Fixed bounds on both dimensions of both inputs: square N x N operands.
  for buf in (A, B):
    for d in range(2):
      buf.dim(d).set_bounds(0, N)

  # ---------------- Definition ----------------
  k = hl.RDom([(0, N)])  # reduction over the shared dimension
  partial = hl.Func("partial")
  partial[i, j] = 0.0
  partial[i, j] += A[i, k] * B[k, j]
  C = hl.Func("C")
  C[i, j] = partial[i, j]

  if not gpu:
    # ---------------- CPU schedule ----------------
    VEC = 16
    TILE_I = 64
    TILE_J = 64
    io, jo, ii, ji = hl.Var("io"), hl.Var("jo"), hl.Var("ii"), hl.Var("ji")
    tile = hl.Var("tile")
    # C has only a pure definition, so C.update() is invalid. Schedule C's pure
    # stage instead: TILE_I x TILE_J tiles, with the tile loops fused into one
    # parallel loop. Vectorise ii, which walks dimension 0 (unit stride for
    # densely allocated Halide buffers).
    C.tile(i, j, io, jo, ii, ji, TILE_I, TILE_J).fuse(io, jo, tile).parallel(tile).vectorize(ii, VEC)
    # Evaluate the reduction one output tile at a time, inside each parallel task.
    partial.compute_at(C, tile).vectorize(i, VEC)
    partial.update().vectorize(i, VEC)
  else:
    # ---------------- GPU schedule ----------------
    GRP_I = 8  # output tile size handled by one GPU block
    GRP_J = 16
    # Experiments kept for reference:
    #partial.store_in(hl.MemoryType.Register)
    #partial.update().unroll(k, 4)
    io, jo, ii, ji = hl.Var("io"), hl.Var("jo"), hl.Var("ii"), hl.Var("ji")
    C.gpu_tile(i, j, io, jo, ii, ji, GRP_I, GRP_J, hl.TailStrategy.RoundUp)
  return C, A, B
if __name__ == "__main__":
  pipe, A, B = gemm_pipeline(gpu=True)
  # NOTE: metal does nothing
  target = hl.get_host_target().with_feature(hl.TargetFeature.Metal)
  a_np = np.random.randn(N, N).astype(np.float32)
  b_np = np.random.randn(N, N).astype(np.float32)
  # Reverse order is correct! Halide's Python Buffer <-> NumPy conversion
  # reverses axis order in both directions. Binding the operands swapped makes
  # C[i, j] = (a_np @ b_np)[j, i], and np.array(c_hal) below transposes it back
  # to a_np @ b_np.
  a_hal = hl.Buffer(b_np)
  b_hal = hl.Buffer(a_np)
  A.set(a_hal)
  B.set(b_hal)
  # Dump the lowered statement as HTML for inspection.
  pipe.compile_to_lowered_stmt("/tmp/my_function.html", [A, B], hl.StmtOutputFormat.HTML, target=target)
  #exit(0)
  c_hal = hl.Buffer(hl.Float(32), [N,N])
  # First realize presumably includes JIT compilation of the pipeline.
  with Timing("halide gemm (cold) "):
    pipe.realize(c_hal, target)
    c_hal.copy_to_host()
  c_out = np.array(c_hal)  # host copy of the result, unaffected by the rerun below
  print(c_out)
  # tinygrad gets 60 ms with no BEAM, 20 ms with BEAM on CPU
  with Timing("halide gemm (warm) "):
    pipe.realize(c_hal, target)
    c_hal.copy_to_host()
  # Check correctness
  with Timing("numpy gemm "):
    ref = a_np @ b_np
  max_err = np.abs(ref - c_out).max()
  print("Max absolute error:", max_err)
  # Explicit check rather than `assert`: it survives `python -O`, and the
  # `not (x < tol)` form also fails on NaN.
  # NOTE(review): this absolute float32 tolerance does not scale with N;
  # large N may need a looser bound.
  if not max_err < 1e-4:
    raise AssertionError("GEMM result incorrect!")
  print("Pipeline ran on", target)
  print("Success - GEMM Halide-Python output matches NumPy.")