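# Hand-written NxN float32 matmul (a = b @ c) on Apple GPUs using Metal simdgroup
# matrices, driven through tinygrad's Metal runtime. The result is checked against
# numpy, then the same matmul is timed in torch (MPS) and tinygrad for comparison.
# N defaults to 2048 and can be overridden via the N environment variable.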
import os
os.environ["METAL"] = "1"
import time
import numpy as np
from tinygrad import Device, dtypes
from tinygrad.helpers import getenv, flat_mv
from tinygrad.runtime.ops_metal import MetalAllocator, MetalDevice, MetalProgram, MetalCompiler
N = getenv("N", 2048)
LID = 2  # simdgroups per threadgroup along the output columns; local_size = [32, LID, 1]
device = MetalDevice("METAL")
metalalloc = MetalAllocator(device)
a = metalalloc.alloc(N*N*4)
b = metalalloc.alloc(N*N*4)
c = metalalloc.alloc(N*N*4)
na = np.zeros((N,N),dtype=np.float32)
nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
nc = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
metalalloc._copyin(b,nb.tobytes())
metalalloc._copyin(c,nc.tobytes())
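# Work and traffic for one NxN matmul: 2*N^3 FLOPs, and 3*N*N float32 values moved
# (read b and c, write a) for the bandwidth figure.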
FLOPS = N*N*N*2
BW = N*N*3*4
prog = MetalProgram(device, "test", MetalCompiler(device).compile(f"""
#include <metal_stdlib>
#include <metal_simdgroup_matrix> // Available from Metal 2.3, released with macOS 11.0+
using namespace metal;
kernel void test(device float *a, device const float *data1, device const float *data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {{
  a += gid.x * 32 * {N} + (gid.y * {LID} + lid.y) * 32;
  data1 += gid.x * 32 * {N};
  data2 += (gid.y * {LID} + lid.y) * 32;
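  // Each simdgroup (fixed lid.y, 32 threads) owns one 32x32 output tile, kept as a
  // 4x4 grid of 8x8 simdgroup matrices; gid.x selects the 32-row block and
  // gid.y*{LID}+lid.y the 32-column block.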
  simdgroup_float8x8 acc[4][4];
  for (uint i = 0; i < 4; i++) {{
    for (uint j = 0; j < 4; j++) {{
      acc[i][j] = simdgroup_float8x8(0);
    }}
  }}
  simdgroup_float8x8 A[4];
  simdgroup_float8x8 B[4];
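  // Sweep the shared K dimension 8 columns at a time: load four 8x8 row tiles of
  // data1 and four 8x8 column tiles of data2, then do all 16 tile multiply-accumulates.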
  for (uint k = 0; k < {N}; k+=8) {{
    threadgroup_barrier(mem_flags::mem_threadgroup);
    simdgroup_load(A[0], data1+k+{0*N}, {N}, ulong2(0, 0));
    simdgroup_load(A[1], data1+k+{8*N}, {N}, ulong2(0, 0));
    simdgroup_load(A[2], data1+k+{16*N}, {N}, ulong2(0, 0));
    simdgroup_load(A[3], data1+k+{24*N}, {N}, ulong2(0, 0));
    simdgroup_load(B[0], data2+0+k*{N}, {N}, ulong2(0, 0));
    simdgroup_load(B[1], data2+8+k*{N}, {N}, ulong2(0, 0));
    simdgroup_load(B[2], data2+16+k*{N}, {N}, ulong2(0, 0));
    simdgroup_load(B[3], data2+24+k*{N}, {N}, ulong2(0, 0));
    simdgroup_multiply_accumulate(acc[0][0], A[0], B[0], acc[0][0]);
    simdgroup_multiply_accumulate(acc[0][1], A[1], B[0], acc[0][1]);
    simdgroup_multiply_accumulate(acc[0][2], A[2], B[0], acc[0][2]);
    simdgroup_multiply_accumulate(acc[0][3], A[3], B[0], acc[0][3]);
    simdgroup_multiply_accumulate(acc[1][0], A[0], B[1], acc[1][0]);
    simdgroup_multiply_accumulate(acc[1][1], A[1], B[1], acc[1][1]);
    simdgroup_multiply_accumulate(acc[1][2], A[2], B[1], acc[1][2]);
    simdgroup_multiply_accumulate(acc[1][3], A[3], B[1], acc[1][3]);
    simdgroup_multiply_accumulate(acc[2][0], A[0], B[2], acc[2][0]);
    simdgroup_multiply_accumulate(acc[2][1], A[1], B[2], acc[2][1]);
    simdgroup_multiply_accumulate(acc[2][2], A[2], B[2], acc[2][2]);
    simdgroup_multiply_accumulate(acc[2][3], A[3], B[2], acc[2][3]);
    simdgroup_multiply_accumulate(acc[3][0], A[0], B[3], acc[3][0]);
    simdgroup_multiply_accumulate(acc[3][1], A[1], B[3], acc[3][1]);
    simdgroup_multiply_accumulate(acc[3][2], A[2], B[3], acc[3][2]);
    simdgroup_multiply_accumulate(acc[3][3], A[3], B[3], acc[3][3]);
  }}
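  // Write the 16 accumulated 8x8 tiles back to this simdgroup's 32x32 block of a.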
  simdgroup_store(acc[0][0], a+{0+0*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[1][0], a+{8+0*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[2][0], a+{16+0*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[3][0], a+{24+0*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[0][1], a+{0+8*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[1][1], a+{8+8*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[2][1], a+{16+8*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[3][1], a+{24+8*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[0][2], a+{0+16*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[1][2], a+{8+16*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[2][2], a+{16+16*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[3][2], a+{24+16*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[0][3], a+{0+24*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[1][3], a+{8+24*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[2][3], a+{16+24*N}, {N}, ulong2(0, 0));
  simdgroup_store(acc[3][3], a+{24+24*N}, {N}, ulong2(0, 0));
}}"""))
def timeit(fxn):
  st = time.perf_counter()
  et = fxn()
  # NOTE: et doesn't contain the launch overhead
  return time.perf_counter() - st
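# One threadgroup per 32-row by (32*LID)-column block of the output; each of its
# LID simdgroups computes a 32x32 tile. Best of 20 runs.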
tm = min([timeit(lambda: prog(a, b, c, global_size=[N//(8*4), N//(8*4*LID), 1], local_size=[32, LID, 1], wait=True)) for _ in range(20)])
comp = nb@nc
metalalloc._copyout(flat_mv(na.data), a)
if N <= 32:
  print(na)
  print(comp)
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
np.testing.assert_allclose(na, comp, atol=1e-3)
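# Baseline 1: the same matmul through torch on the MPS backend, timed the same way.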
import torch, torch.mps
b = torch.from_numpy(nb).to('mps')
c = torch.from_numpy(nc).to('mps')
def torch_prog(b, c):
  st = time.perf_counter()
  a = b@c
  torch.mps.synchronize()
  return time.perf_counter() - st
tm = min([torch_prog(b, c) for _ in range(20)])
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in torch")
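# Baseline 2: the same matmul through tinygrad; TinyJit captures the kernels on the
# first calls and replays them, avoiding per-call scheduling overhead.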
from tinygrad.tensor import Tensor
from tinygrad.engine.jit import TinyJit
b = Tensor(nb)
c = Tensor(nc)
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
@TinyJit
def tiny_jit(b, c):
  return (b@c).realize()
def tiny_prog(b, c):
  st = time.perf_counter()
  a = tiny_jit(b, c)
  Device["METAL"].synchronize()
  return time.perf_counter() - st
tm = min([tiny_prog(b, c) for _ in range(20)])
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in tinygrad")