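# Benchmark a float32 matvec (a 16384 vector times a 16384x4096 matrix) three ways:
# torch on MPS, a handwritten Metal kernel, and tinygrad with TinyJit.
# Each variant reports its best-of-200 wall-clock time and the implied GFLOPS.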
import numpy as np
import time, torch, torch.mps
from tinygrad import Tensor, TinyJit, Device
from tinygrad.helpers import flat_mv
from tinygrad.runtime.ops_metal import MetalAllocator, MetalDevice, MetalProgram, MetalCompiler
N = 16384
M = 4096
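# 2*N*M ops: each of the M outputs needs N multiplies and N adds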
FLOPS = N*M*2
nb = np.random.default_rng().standard_normal(size=(N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
nc = np.random.default_rng().standard_normal(size=(N,M), dtype=np.float32) #.astype(np.int32).astype(np.float32)
b = torch.from_numpy(nb).to('mps')
c = torch.from_numpy(nc).to('mps')
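# time one torch matvec; synchronize so the GPU work is included in the measurement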
def torch_prog(b, c):
  st = time.perf_counter()
  a = b@c
  torch.mps.synchronize()
  return time.perf_counter() - st
tm = min([torch_prog(b, c) for _ in range(200)])
print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in torch")
torch_a = (b@c).cpu()
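# same matvec as a handwritten Metal kernel, driven through tinygrad's Metal runtime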
device = MetalDevice("METAL")
metalalloc = MetalAllocator(device)
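# threadgroups are 32x1x16: the 32 x/y threads each produce 4 output columns (one
# float4), and the 16 z threads split the dot-product reduction over the input vector;
# GLOBAL_SIZE then covers M in blocks of 32*1*4 = 128 columns per threadgroup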
WORKSIZE_ROW = 16
WORKSIZE_COL = 1
LOCAL_SIZE = [32, WORKSIZE_COL, WORKSIZE_ROW]
GLOBAL_SIZE = [M//(LOCAL_SIZE[0]*LOCAL_SIZE[1]*4), 1, 1]
prog = MetalProgram(device, "test", MetalCompiler(device).compile(f"""
#include <metal_stdlib>
using namespace metal;
kernel void test(device float* data0, const device float* data1, const device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {{
  int gidx0 = gid.x; /* {GLOBAL_SIZE[0]} */
  int lidx1 = lid.x; /* {LOCAL_SIZE[0]} */
  int lidx2 = lid.y; /* {LOCAL_SIZE[1]} */
  int lidx3 = lid.z; /* {LOCAL_SIZE[2]} */
  // 4 rows per thread
  threadgroup float4 acc0[{LOCAL_SIZE[0]*LOCAL_SIZE[1]*LOCAL_SIZE[2]}];
  int acc0_index = ((lidx1*{LOCAL_SIZE[1]})+lidx2)+({LOCAL_SIZE[0]*LOCAL_SIZE[1]}*lidx3);
  acc0[acc0_index] = float4(0.0f,0.0f,0.0f,0.0f);
  threadgroup float4 val1[{LOCAL_SIZE[0]*LOCAL_SIZE[1]*LOCAL_SIZE[2]}];
  // iterate over the columns
  for (int ridx2 = 0; ridx2 < {N//(4*LOCAL_SIZE[0]*LOCAL_SIZE[1]*(LOCAL_SIZE[2]))}; ++ridx2) {{
    // load 4*threadgroup_size columns into shared memory
    int col_1 = (((lidx3*{N//(4*LOCAL_SIZE[0]*LOCAL_SIZE[1]*(LOCAL_SIZE[2]))})+ridx2)*{LOCAL_SIZE[0]*LOCAL_SIZE[1]})+(lidx1*{LOCAL_SIZE[1]})+lidx2;
    val1[(lidx3*{LOCAL_SIZE[1]*LOCAL_SIZE[0]})+((lidx1*{LOCAL_SIZE[1]})+lidx2)] = *((device float4*)(data1+(col_1*4)));
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (int ridx3 = 0; ridx3 < {LOCAL_SIZE[0]*LOCAL_SIZE[1]}; ++ridx3) {{
      int col = ((((lidx3*{N//(4*LOCAL_SIZE[0]*LOCAL_SIZE[1]*(LOCAL_SIZE[2]))})+ridx2)*{LOCAL_SIZE[0]*LOCAL_SIZE[1]})+ridx3);
      float4 val1_0 = val1[(lidx3*{LOCAL_SIZE[1]*LOCAL_SIZE[0]})+ridx3];
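      // each inner step consumes a 4x4 tile of the matrix: rows 4*col..4*col+3 and this thread's 4 output columns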
      float4 val2_0 = (float4)(*((device float4*)(data2+(gidx0*{M//GLOBAL_SIZE[0]})+(((lidx1*{LOCAL_SIZE[1]})+lidx2)*4)+(col*{M*4})+{M*0})));
      float4 val2_1 = (float4)(*((device float4*)(data2+(gidx0*{M//GLOBAL_SIZE[0]})+(((lidx1*{LOCAL_SIZE[1]})+lidx2)*4)+(col*{M*4})+{M*1})));
      float4 val2_2 = (float4)(*((device float4*)(data2+(gidx0*{M//GLOBAL_SIZE[0]})+(((lidx1*{LOCAL_SIZE[1]})+lidx2)*4)+(col*{M*4})+{M*2})));
      float4 val2_3 = (float4)(*((device float4*)(data2+(gidx0*{M//GLOBAL_SIZE[0]})+(((lidx1*{LOCAL_SIZE[1]})+lidx2)*4)+(col*{M*4})+{M*3})));
      acc0[acc0_index] = ((val1_0.x*val2_0)+acc0[acc0_index]);
      acc0[acc0_index] = ((val1_0.y*val2_1)+acc0[acc0_index]);
      acc0[acc0_index] = ((val1_0.z*val2_2)+acc0[acc0_index]);
      acc0[acc0_index] = ((val1_0.w*val2_3)+acc0[acc0_index]);
    }}
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }} /* reduce */
  if (lidx3 == 0) {{
    float4 out = float4(0.0f,0.0f,0.0f,0.0f);
    for (int n = 0; n < {LOCAL_SIZE[2]}; n++) {{
      out += acc0[((lidx1*{LOCAL_SIZE[1]})+lidx2)+({LOCAL_SIZE[0]*LOCAL_SIZE[1]}*n)];
    }}
    *((device float4*)(data0+(gidx0*{M//GLOBAL_SIZE[0]})+(((lidx1*{LOCAL_SIZE[1]})+lidx2)*4))) = out;
  }}
}}
"""))
a = metalalloc.alloc(M*4)
b = metalalloc.alloc(N*4)
c = metalalloc.alloc(N*M*4)
metalalloc._copyin(b,nb.tobytes())
metalalloc._copyin(c,nc.tobytes())
def metalrun():
  prog(a, b, c, global_size=GLOBAL_SIZE, local_size=LOCAL_SIZE, wait=True)
  return a
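# wall-clock timing; with wait=True the kernel has completed before prog returns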
def timeit(fxn):
  st = time.perf_counter()
  et = fxn()
  # NOTE: et doesn't contain the launch overhead
  return time.perf_counter() - st
tm = min([timeit(metalrun) for _ in range(200)])
print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in metal")
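# read the Metal result back and check it against torch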
metal_a = np.zeros(M, dtype=np.float32)
metalalloc._copyout(flat_mv(metal_a.data), a)
np.testing.assert_allclose(metal_a, torch_a, atol=5e-3)
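# same matvec in tinygrad, JITted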
b = Tensor(nb)
c = Tensor(nc)
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
@TinyJit
def tiny_jit(b, c):
  return (b@c).realize()
def tiny_prog(b, c):
  st = time.perf_counter()
  a = tiny_jit(b, c)
  Device["METAL"].synchronize()
  return time.perf_counter() - st
tm = min([tiny_prog(b, c) for _ in range(200)])
print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in tinygrad")
tiny_a = tiny_jit(b, c).numpy()
np.testing.assert_allclose(tiny_a, torch_a, atol=5e-3)