from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes
from tinygrad.uop.ops import UOp, KernelInfo, sint, AxisType
from tinygrad.engine.realize import ExecItem, get_runner
from tinygrad.dtype import AddrSpace
from tinygrad.helpers import getenv
N = 4096
M = K = N
run_count = 5
# ---------------------------
# launch/config constants
# ---------------------------
WARP_SIZE = 32
# Threadblock tile sizes (block-level tile of C that a block computes)
BLOCK_N = 128 # columns of C (N-dim) per block
BLOCK_M = 128 # rows of C (M-dim) per block
BLOCK_K = 8 # K-slice per block iteration
# Register tile sizes (per-thread accumulator tile of C)
TN = 4 # columns per thread
TM = 4 # rows per thread
is_kernel5 = getenv("K5", 0)
THREADS_PER_BLOCK = 128 if is_kernel5 else 256
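# K5=1 switches to the kernel-5 variant: 128 threads/block, block-wide 128-lane
# wave tiles, and a padded As stride (BM_As_stride below), presumably to avoid
# shared-memory bank conflicts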
assert THREADS_PER_BLOCK % BLOCK_N == 0, "THREADS_PER_BLOCK must be divisible by BLOCK_N"
assert THREADS_PER_BLOCK % BLOCK_K == 0, "THREADS_PER_BLOCK must be divisible by BLOCK_K"
assert (BLOCK_N * BLOCK_K) % THREADS_PER_BLOCK == 0
assert (BLOCK_M * BLOCK_K) % THREADS_PER_BLOCK == 0
WARPS_PER_BLOCK = THREADS_PER_BLOCK // WARP_SIZE
WAVE_TILE_N = 128 if is_kernel5 else 64
WAVE_TILE_M = BLOCK_N * BLOCK_M // WARPS_PER_BLOCK // WAVE_TILE_N
assert BLOCK_N % WAVE_TILE_N == 0, "BN must be a multiple of WN"
assert BLOCK_M % WAVE_TILE_M == 0, "BM must be a multiple of WM"
WAVES_IN_BLOCK_X = BLOCK_N // WAVE_TILE_N
WAVES_IN_BLOCK_Y = BLOCK_M // WAVE_TILE_M
assert WAVES_IN_BLOCK_X * WAVES_IN_BLOCK_Y == WARPS_PER_BLOCK, "wave grid must match warps/block"
LANES_PER_WAVE_X = 8
LANES_PER_WAVE_Y = 4
ITERS_PER_WAVE_N = WAVE_TILE_N // (LANES_PER_WAVE_X * TN)
ITERS_PER_WAVE_M = WAVE_TILE_M // (LANES_PER_WAVE_Y * TM)
assert WAVE_TILE_N % (LANES_PER_WAVE_X * TN) == 0, "WAVE_TILE_N must be divisible by LANES_PER_WAVE_X*TN"
assert WAVE_TILE_M % (LANES_PER_WAVE_Y * TM) == 0, "WAVE_TILE_M must be divisible by LANES_PER_WAVE_Y*TM"
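# Worked example with the defaults (K5=0): 256 threads = 8 warps per block.
# WAVE_TILE_M = 128*128/8/64 = 32, so waves tile the 128x128 block 2 wide x 4 high.
# Each warp's 32 lanes form an 8x4 grid and each lane owns 2x2 iterations of 4x4
# values: 64 accumulator floats per thread (128*128 C elements / 256 threads).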
def rngs_for_shape(shape:tuple[sint, ...], rng:int, axis_type=AxisType.LOOP): return [UOp.range(s, rng+i, axis_type) for i,s in enumerate(shape)]
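# copy emits an element-wise dest[rngs] = src[rngs] loop over fresh ranges;
# upcast=True marks the ranges UPCAST so they unroll, and set=True returns dest
# sequenced .after() the store so later reads observe the copied data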
def copy(dest:UOp, src:UOp, rng:int, set=False, upcast=False):
  assert dest.shape == src.shape
  rngs = rngs_for_shape(src.shape, rng, AxisType.UPCAST if upcast else AxisType.LOOP)
  copy = dest[*rngs].store(src[*rngs]).end(*rngs)
  return dest.after(copy) if set else copy
def hand_spec_kernel3():
  # ---------------------------
  # block indices & placeholders
  # ---------------------------
  blockIdx_x = UOp.special(N // BLOCK_N, "gidx0")
  blockIdx_y = UOp.special(N // BLOCK_M, "gidx1")
  a = UOp.placeholder((N, N), dtypes.float, slot=1)
  b = UOp.placeholder((N, N), dtypes.float, slot=2)
  c = UOp.placeholder((N, N), dtypes.float, slot=0)
  # index the output with the globals
  c = c.reshape(M // BLOCK_M, BLOCK_M, N // BLOCK_N, BLOCK_N)[blockIdx_y, :, blockIdx_x, :]
  # open the main reduction range
  k_tile_range = UOp.range(N // BLOCK_K, 0, AxisType.REDUCE)
  a = a.reshape(M // BLOCK_M, BLOCK_M, N // BLOCK_K, BLOCK_K)[blockIdx_y, :, k_tile_range, :]
  b = b.reshape(N // BLOCK_K, BLOCK_K, N // BLOCK_N, BLOCK_N)[k_tile_range, :, blockIdx_x, :]
  # globals are no longer used, they are already in the indexes
  del blockIdx_y, blockIdx_x
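  # a is now this block's (BLOCK_M, BLOCK_K) tile of A for the current k-tile,
  # and b the matching (BLOCK_K, BLOCK_N) tile of B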
  # ---------------------------
  # GLOBAL -> LOCAL (As, Bs)
  # ---------------------------
  tid = UOp.special(THREADS_PER_BLOCK, "lidx0")
  # A: read BM x BK tiles (permute on store into locals)
  BM_As_stride = (BLOCK_M + 4) if is_kernel5 else BLOCK_M
  As = UOp.placeholder((BLOCK_K, BM_As_stride), dtypes.float, slot=0, addrspace=AddrSpace.LOCAL).shrink_to((BLOCK_K, BLOCK_M))
  As_store = copy(As.permute((1,0)).reshape(-1, THREADS_PER_BLOCK)[:, tid], a.reshape(-1, THREADS_PER_BLOCK)[:, tid], rng=100)
  # B: read BK x BN tiles
  Bs = UOp.placeholder((BLOCK_K, BLOCK_N), dtypes.float, slot=1, addrspace=AddrSpace.LOCAL)
  Bs_store = copy(Bs.reshape(-1, THREADS_PER_BLOCK)[:, tid], b.reshape(-1, THREADS_PER_BLOCK)[:, tid], rng=200)
  # TODO: can we automate barrier?
  barrier = UOp.barrier(As_store, Bs_store)
  As, Bs = As.after(barrier), Bs.after(barrier)
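  # the barrier makes the complete As/Bs tiles visible to every thread in the
  # block; .after(barrier) orders all later reads of the locals behind it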
  # open inner k range
  k = UOp.range(BLOCK_K, 3, AxisType.REDUCE)
  # ---------------------------
  # LOCAL -> REG (per-wave tiles)
  # ---------------------------
  waveIdx = (tid // WARP_SIZE) % WAVES_IN_BLOCK_X
  waveIdy = (tid // WARP_SIZE) // WAVES_IN_BLOCK_X
  assert waveIdy.vmax+1 == WAVES_IN_BLOCK_Y
  laneIdx = (tid % WARP_SIZE) % LANES_PER_WAVE_X
  laneIdy = (tid % WARP_SIZE) // LANES_PER_WAVE_X
  assert laneIdy.vmax+1 == LANES_PER_WAVE_Y
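  # each warp covers one WAVE_TILE_M x WAVE_TILE_N patch of the block tile; its
  # 32 lanes form an 8 wide x 4 tall grid (LANES_PER_WAVE_X x LANES_PER_WAVE_Y)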
  A_col = UOp.placeholder((ITERS_PER_WAVE_M, TM), dtypes.float, slot=0, addrspace=AddrSpace.REG)
  A_col = copy(A_col, As[k, :].reshape(WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM)[waveIdy, :, laneIdy, :], 300, set=True, upcast=True)
  B_row = UOp.placeholder((ITERS_PER_WAVE_N, TN), dtypes.float, slot=1, addrspace=AddrSpace.REG)
  B_row = copy(B_row, Bs[k, :].reshape(WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN)[waveIdx, :, laneIdx, :], 400, set=True, upcast=True)
  # ---------------------------
  # FMA: c_regs += A_col * B_row
  # ---------------------------
  c_regs = UOp.placeholder((ITERS_PER_WAVE_M, TM, ITERS_PER_WAVE_N, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
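  # zero-initialize the per-thread accumulator registers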
  i = UOp.range(c_regs.size, 16)
  c_regs = c_regs.after(c_regs.flatten()[i].store(0.0).end(i))
  # TODO: why don't these work as upcast?
  # why if the ranges merge is it slow?!? (if you change the order on end, they will merge. big slowdown on METAL)
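  # each k step accumulates the outer product of A_col and B_row: every
  # (iterWaveM, iterWaveN) iteration updates one TM x TN sub-tile of c_regs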
  iterWaveM, yt, iterWaveN, xt = rngs = rngs_for_shape(c_regs.shape, 500)
  sink = c_regs[*rngs].store(c_regs.after(k)[*rngs] + A_col[iterWaveM, yt] * B_row[iterWaveN, xt]).end(iterWaveM, iterWaveN, yt, xt)
  # Close k, sync, and close K tiles
  sink = sink.end(k).barrier().end(k_tile_range)
  # ---------------------------
  # REG -> GLOBAL (epilogue)
  # ---------------------------
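  # invert the tiling on C: after this reshape/index, c has shape
  # (ITERS_PER_WAVE_M, TM, ITERS_PER_WAVE_N, TN), matching c_regs exactly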
  c = c.reshape(WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM,
                WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN)
  c = c[waveIdy, :, laneIdy, :,
        waveIdx, :, laneIdx, :]
  sink = copy(c, c_regs.after(sink), rng=600)
  return sink.sink(arg=KernelInfo(opts_to_apply=())).simplify()
def test_matmul(sink:UOp, N=N):
  with Context(DEBUG=0):
    a = Tensor.randn(N, N)
    b = Tensor.randn(N, N)
    hc = Tensor.empty(N, N)
    Tensor.realize(a, b, hc)
  ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in [hc, a, b]])
  GlobalCounters.reset()
  ets = []
  with Context(DEBUG=2):
    for _ in range(run_count):
      ets.append(ei.run(wait=True))
  print(f"REAL TFLOPS {N * N * N * 2 / min(ets) * 1e-12:.2f}")
  GlobalCounters.reset()
  with Context(DEBUG=2):
    tc = (a @ b).realize()
  with Context(DEBUG=0):
    err = (hc - tc).square().mean().item()
  print(f"mean squared error {err}")
  if err > 1e-06:
    raise RuntimeError("matmul is wrong!")
if __name__ == "__main__":
  test_matmul(hand_spec_kernel3(), N=N)
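  # e.g. run with K5=1 in the environment to benchmark the kernel-5 variant instead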