import time
import triton
import triton.language as tl
from triton.compiler import AttrsDescriptor, ASTSource, compile as triton_compile
import numpy as np
from tinygrad import Tensor, dtypes, Device
from tinygrad.engine.realize import CompiledRunner, ExecItem, ProgramSpec
from tinygrad.helpers import getenv
np.set_printoptions(suppress=True)
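
# Triton GEMM kernel: each program computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C = A @ B.
# Shapes and row-major strides are hardcoded below for the 4096x4096 fp16 case.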
@triton.jit
def matmul_kernel(c_ptr, a_ptr, b_ptr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):
  pid_m = tl.program_id(axis=0)
  pid_n = tl.program_id(axis=1)

  M, N, K = 4096, 4096, 4096
  stride_am = 4096
  stride_ak = 1
  stride_bk = 4096
  stride_bn = 1
  stride_cm = 4096
  stride_cn = 1

  # pointers to this program's first K-tile of A and B
  offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
  offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
  offs_k = tl.arange(0, BLOCK_SIZE_K)
  a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
  b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

  # accumulate in fp32 over the K dimension, one BLOCK_SIZE_K tile at a time
  accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
  for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
    a = tl.load(a_ptrs)
    b = tl.load(b_ptrs)
    accumulator = tl.dot(a, b, accumulator)
    a_ptrs += BLOCK_SIZE_K * stride_ak
    b_ptrs += BLOCK_SIZE_K * stride_bk
  c = tl.cast(accumulator, tl.float16)

  # write the fp16 output tile of C
  offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
  tl.store(c_ptrs, c)

# CUDA=1 PTX=1 python3 extra/gemm/triton_nv_matmul.py
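# TORCH=1 also times the kernel through the regular torch/Triton launch path; VERIFY=1 checks the result against tinygrad's own matmul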
if __name__ == "__main__":
  BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 64, 128, 64
  M, N, K = 4096, 4096, 4096

  # **** torch test ****

  if getenv("TORCH"):
    import torch
    c = torch.empty((M, N), device='cuda:0', dtype=torch.float16)
    a = torch.empty((M, K), device='cuda:0', dtype=torch.float16)
    b = torch.empty((K, N), device='cuda:0', dtype=torch.float16)
    for i in range(5):
      st = time.perf_counter()
      matmul_kernel[triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(N, BLOCK_SIZE_N)](
        c, a, b, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K)
      torch.cuda.synchronize()
      et = time.perf_counter() - st
      print(f"TFLOPS {2*M*N*K*1e-12/et:.2f}")

  # **** tinygrad test ****

  compiled = triton_compile(ASTSource(matmul_kernel, "*fp16,*fp16,*fp16",
    attrs=AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=()),
    constants={"BLOCK_SIZE_M": BLOCK_SIZE_M, "BLOCK_SIZE_N": BLOCK_SIZE_N, "BLOCK_SIZE_K": BLOCK_SIZE_K}))
  print(compiled.metadata)
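
  # the PTX and metadata (shared memory size, num_warps) from this AOT compile are reused below to launch the kernel from tinygrad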
  A, B = Tensor.normal(M, K, std=1e-1, dtype=dtypes.float16).realize(), Tensor.normal(K, N, std=1e-1, dtype=dtypes.float16).realize()
  C = A.matmul(B)
  sched = C.schedule()
  si = sched[-1]
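  # sched[-1] is the matmul kernel's schedule item; its buffers (the output plus the two inputs) are handed to the Triton program below
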
  src = compiled.asm["ptx"]
  # specify the shared memory here so we don't need to do it dynamically
  src = src.replace(".extern .shared .align 16 .b8 global_smem[];", f".shared .align 16 .b8 global_smem[{compiled.metadata.shared}];")
  # useless comment spam
  src = src.replace("\t// begin inline asm\n", "")
  src = src.replace("\t// end inline asm\n", "")
  # remove debug sections
  src = src.split("\t.file")[0]
  assert '.extern .shared' not in src
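
  # wrap the patched PTX as a tinygrad program: one block per BLOCK_SIZE_M x BLOCK_SIZE_N output tile, 32 threads per warp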
  prg = ProgramSpec("matmul_kernel", src, device=Device.DEFAULT,
                    global_size=[M//BLOCK_SIZE_M, N//BLOCK_SIZE_N, 1], local_size=[32*compiled.metadata.num_warps, 1, 1],
                    mem_estimate=A.nbytes() + B.nbytes() + C.nbytes())
  ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
  tflops = []
  for i in range(5):
    tm = ei.run(wait=True)
    tflops.append((2*M*K*N/tm)*1e-12)
  print(f"TFLOPS: {max(tflops):.2f}")

  # check correctness
  if getenv("VERIFY"):
    from tinygrad.engine.realize import run_schedule
    triton_buf = np.frombuffer(si.bufs[0].as_buffer(), np.float16).reshape(M,N)
    print(triton_buf)
    run_schedule(sched)
    tinygrad_buf = np.frombuffer(si.bufs[0].as_buffer(), np.float16).reshape(M,N)
    print(tinygrad_buf)
    np.testing.assert_allclose(triton_buf, tinygrad_buf)
    print("correct!")