openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

232 lines
12 KiB

import numpy as np, os
from tinygrad.helpers import getenv, flat_mv
from tinygrad import dtypes
from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Self
# for copied uops
from tinygrad.codegen.kernel import Kernel, KernelOptError
from tinygrad.ops import UOp, Ops, BinaryOps, UnaryOps, TernaryOps, KernelInfo
from tinygrad.engine.search import Opt, OptOps
from tinygrad import Device, dtypes, Tensor
from tinygrad.dtype import PtrDType, DType, DTYPES_DICT
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
# Directory containing this script; kernel source files are located relative to it.
script_dir = os.path.dirname(os.path.abspath(__file__))

# problem variations
# Element types of the inputs, the output and the accumulator, looked up by name.
DTYPE_IN = DTYPES_DICT[getenv("DTYPE_IN", "half")]
DTYPE_OUT = DTYPES_DICT[getenv("DTYPE_OUT", "half")]
DTYPE_ACC = DTYPES_DICT[getenv("DTYPE_ACC", "float")]
# Problem size: A is (M, K), B is (K, N), C is (M, N). M and K default to N.
N = getenv("N", 4096)
M = getenv("M", N)
K = getenv("K", N)
# Number of timed kernel launches.
CNT = getenv("CNT", 10)
# Tolerances for the comparison against the NumPy reference; looser when inputs are not float32.
ATOL = getenv("ATOL", 5e-3 if DTYPE_IN == dtypes.float else 1e-2)
RTOL = getenv("RTOL", 1e-4 if DTYPE_IN == dtypes.float else 1e-3)
# One multiply and one add per (m, n, k) triple.
FLOPS = M * N * K * 2
# Bytes touched when A and B are read once and C is written once.  Sized from the
# actual element widths so the figure stays right when DTYPE_OUT (or DTYPE_IN) is
# not a 2-byte type; for half/half this equals 2 * (M*K + K*N + M*N).
BW = (M*K + K*N) * DTYPE_IN.itemsize + (M*N) * DTYPE_OUT.itemsize

# algorithm variations
# INPUT selects how the operands are filled (see randoms()).
INPUT = getenv("INPUT", "RAND")
# GEMM_VARIATION selects which hand-written kernel the __main__ section launches.
GEMM_VARIATION = getenv("GEMM_VARIATION", "nv_hcopt")
def randoms():
    """Build the host-side NumPy operands and output buffer for the GEMM.

    The fill pattern is chosen by the module-level INPUT setting:
      * "RAND": A and B are drawn from a standard normal distribution.
      * "IDENTITY" (only when M == N == K): A and B are identity matrices.
      * "OUTPUTONES" (only when M == K): A is the identity and B is all ones.
      * anything else, or a shape precondition not met: A and B are all ones.

    Returns:
      (na, nb, nc): na has shape (M, K) and nb has shape (K, N), both converted
      to the NumPy type matching DTYPE_IN; nc is a flat zero array of M*N
      elements converted to the NumPy type matching DTYPE_OUT.
    """
    def np_type_for(dt):
        # Only called for non-float32 dtypes; anything that is not bfloat16 maps to float16.
        # NOTE(review): stock NumPy does not appear to define `np.bfloat16`, so the
        # bfloat16 path is expected to raise AttributeError unless NumPy has been
        # extended - confirm before relying on it.
        return np.bfloat16 if dt == dtypes.bfloat16 else np.float16

    if INPUT == "RAND":
        na = np.random.default_rng().normal(scale=1.0, size=(M,K)).astype(dtype=np.float32)
        nb = np.random.default_rng().normal(scale=1.0, size=(K,N)).astype(dtype=np.float32)
    elif INPUT == "IDENTITY" and M==N==K:
        na = np.identity(K, dtype=np.float32)
        nb = np.identity(K, dtype=np.float32)
    elif INPUT == "OUTPUTONES" and M==K:
        na = np.identity(K, dtype=np.float32)
        nb = np.ones((K,N), dtype=np.float32)
    else:
        na = np.ones((M,K), dtype=np.float32)
        nb = np.ones((K,N), dtype=np.float32)
    nc = np.zeros(M*N, np.float32)

    if DTYPE_IN != dtypes.float:
        in_type = np_type_for(DTYPE_IN)
        na = na.astype(in_type)
        nb = nb.astype(in_type)
    if DTYPE_OUT != dtypes.float:
        # The output buffer follows DTYPE_OUT; it was previously derived from DTYPE_IN.
        nc = nc.astype(np_type_for(DTYPE_OUT))
    return na, nb, nc
def ast_to_cuda_prog(compiler, ast, opts):
    """Turn a kernel AST into a CUDAProgram, applying `opts` in order.

    Args:
      compiler: object whose `compile(src)` result is handed to CUDAProgram.
      ast: AST passed to the `Kernel` constructor.
      opts: iterable of optimizations, each applied with `Kernel.apply_opt`.

    Returns:
      The CUDAProgram built from the generated function name and compiled source.

    Note:
      `CUDAProgram` and `device` are looked up as module globals. In this file
      they are only bound inside the CUDA branch of the `__main__` section, so
      calling this function before that branch has run raises NameError.
    """
    kernel = Kernel(ast)
    kernel.required_optimizations()
    for optimization in opts:
        kernel.apply_opt(optimization)
    program = kernel.to_program()
    return CUDAProgram(device, program.function_name, compiler.compile(program.src))
def _read_kernel_source(rel_path):
    """Return the text of the kernel source at `rel_path`, relative to this script's directory."""
    # Read under a context manager so the file handle is closed promptly.
    with open(os.path.join(script_dir, rel_path)) as src_file:
        return src_file.read()


def _find_kernel_variant(variation, m, n, k, dtype_in, dtype_out, dtype_acc):
    """Pick the hand-written CUDA kernel matching the requested configuration.

    Entries are tried in order; the first one whose variation name, shape
    constraint and (input, output, accumulator) dtypes all match is used.

    Returns:
      (message, kernel_path, extra_prog_args, launch_kwargs), or None when no
      entry matches. `extra_prog_args` are appended to the CUDAProgram
      constructor call and `launch_kwargs` are passed to the program call.
    """
    half, single = dtypes.half, dtypes.float
    requested = (dtype_in, dtype_out, dtype_acc)

    def tiles(tile_m, tile_n, tile_k):
        # True when the problem divides evenly into (tile_m, tile_n, tile_k) tiles.
        return (m % tile_m) == 0 and (n % tile_n) == 0 and (k % tile_k) == 0

    # 4 warpgroups == (T_M:=2) * (T_N:=2)
    wg4_launch = {'global_size': [m//64, n//128, 1], 'local_size': [128, 1, 1], 'wait': True, 'vals': (n, k)}
    # 8 warpgroups, WG_M=4 and WG_N=2
    wg8_launch = {'global_size': [m//256, n//128, 1], 'local_size': [32, 4, 2], 'wait': True, 'vals': (n, k)}
    # 16,2 are warp, 4 workgroups upcasted to axis=1
    hcopt_launch = {'global_size': [32, 64, 1], 'local_size': [16, 2, 4], 'wait': True}
    # Extra positional CUDAProgram argument used by the 3-stage kernels.
    # NOTE(review): presumably the shared-memory size in bytes - confirm against CUDAProgram.
    smem_73728 = (73728,)

    fp32_out = (half, single, single)
    fp16_out_fp32_acc = (half, half, single)
    fp16_all = (half, half, half)

    # (variation, dtypes, shape_ok, message, kernel_path, extra_prog_args, launch_kwargs)
    variants = [
        # See nv_triton_gemm.annotated.ptx for PTX code which was generated from `PYTHONPATH=. DEBUG=6 CUDA=1 PTX=1 python3 extra/gemm/triton_nv_matmul.py`
        # this kernel with M=N=K=4096 does 162TFLOPS, vs torch at 144TFLOPS and BEAM=8 tinygrad at 138TFLOPS. theo max is 165TFLOPS.
        # WMMA element size is (M, N, K) = (16, 8, 16)
        # warpgroup size in WMMA tiles is (B_M, B_N, B_K) = (2, 8, 4) so 64 HMMA calls per threadgroup reduce iteration
        # thread block size is (T_M, T_N, T_K) = (2, 2, 1), i.e. macro blocks in M and N, so 256 HMMA calls per kernel reduce iteration
        # kernel reduce iteration size in elements = (64, 128, 64)
        # single iteration SMEM_A = (64 * 64) * (2 bytes / half) = 8192 bytes, SMEM_B = (128 * 64) * (2 bytes / half) = 16384 bytes
        # double-buffer smem = (8192 + 16384) * 2 = 49152 bytes
        # reduce for_loop size = [1, 1, (4096 // 16 // 4)==64]
        # NOTE: T_K > 0 would be group_for_reduce
        ("max", fp32_out, tiles(64, 128, 64),
         "Using CUDA and triton-generated kernel",
         'max_kernels/nv.fp16_fp32_fp32.max.cu', (), wg4_launch),
        ("2_stage_swizzled_smem_input", fp32_out, tiles(64, 128, 64),
         "Using CUDA, 2-stage reduce pipeline, swizzled SMEM inputs",
         'max_kernels/nv.fp16_fp32_fp32.2_stage_swizzled_smem_input.cu', (), wg4_launch),
        ("swizzled_smem_input", fp32_out, tiles(64, 128, 64),
         "Using CUDA, swizzled SMEM inputs",
         'max_kernels/nv.fp16_fp32_fp32.swizzled_smem_input.cu', (), wg4_launch),
        ("flat_smem_input", fp32_out, tiles(64, 128, 64),
         "Using CUDA, flat SMEM inputs",
         'max_kernels/nv.fp16_fp32_fp32.flat_smem_input.cu', (), wg4_launch),
        # [Opt(op=OptOps.TC, axis=0, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=1, amt=4)]
        ("hcopt", fp16_out_fp32_acc, m == n == k == 4096,
         "Using CUDA and generated hcopt",
         'max_kernels/nv.fp16_fp32_fp16.hcopt.cu', (), hcopt_launch),
        ("2_stage", fp16_all, tiles(64, 128, 64),
         "Using CUDA and un-optimized 2-stage, swizzled SMEM inputs and direct acc to output kernel",
         'max_kernels/nv.fp16_fp16_fp16.2_stage.cu', (), wg4_launch),
        ("3_stage", fp16_all, tiles(256, 128, 32),
         "Using CUDA and 3-stage (interleave global copies and ldmatrix)",
         'max_kernels/nv.fp16_fp16_fp16.3_stage.cu', smem_73728, wg8_launch),
        ("3_stage_swizzled", fp16_all, tiles(256, 128, 32),
         "Using CUDA and 3-stage (interleave global copies and ldmatrix) and swizzled SMEM inputs",
         'max_kernels/nv.fp16_fp16_fp16.3_stage_swizzled.cu', smem_73728, wg8_launch),
        ("max", fp16_all, tiles(256, 128, 32),
         "Using CUDA and 3-stage (interleave global copies and ldmatrix), swizzled SMEM inputs and epilogue",
         'max_kernels/nv.fp16_fp16_fp16.max.cu', smem_73728, wg8_launch),
        ("no_xor", fp16_all, tiles(256, 128, 32),
         "Using CUDA and 3-stage (interleave global copies and ldmatrix), swizzled SMEM inputs and epilogue",
         'max_kernels/nv.fp16_fp16_fp16.no_xor.cu', smem_73728, wg8_launch),
    ]
    for name, wanted, shape_ok, message, kernel_path, extra_prog_args, launch_kwargs in variants:
        if variation == name and shape_ok and requested == wanted:
            return message, kernel_path, extra_prog_args, dict(launch_kwargs)
    return None


if __name__ == "__main__":
    print(f"gemm variation: {GEMM_VARIATION=} {M=} {N=} {K=} {DTYPE_IN=} {DTYPE_OUT=} {DTYPE_ACC=}")
    if getenv("CUDA") == 1:
        from tinygrad.runtime.ops_cuda import CUDAAllocator, CUDADevice, CUDAProgram, CUDACompiler
        device = CUDADevice("cuda:0")
        compiler = CUDACompiler(device.arch)
        cudaalloc = CUDAAllocator(device)

        # Device buffers for A (M x K), B (K x N) and C (M x N), sized in bytes.
        a = cudaalloc.alloc(M*K*DTYPE_IN.itemsize)
        b = cudaalloc.alloc(K*N*DTYPE_IN.itemsize)
        c = cudaalloc.alloc(M*N*DTYPE_OUT.itemsize)

        variant = _find_kernel_variant(GEMM_VARIATION, M, N, K, DTYPE_IN, DTYPE_OUT, DTYPE_ACC)
        if variant is None:
            raise RuntimeError(f"invalid gemm variation: {GEMM_VARIATION=} {M=} {N=} {K=} {DTYPE_IN=} {DTYPE_OUT=} {DTYPE_ACC=}")
        message, kernel_path, extra_prog_args, kwargs = variant
        print(message)
        prog = CUDAProgram(device, "wmma_example", compiler.compile(_read_kernel_source(kernel_path)), *extra_prog_args)
        args = (c, a, b)

        # Upload the operands, then launch the kernel CNT times and keep each call's return value.
        na, nb, nc = randoms()
        cudaalloc.copyin(a, bytearray(na))
        cudaalloc.copyin(b, bytearray(nb))
        # NOTE(review): the values returned by prog(...) are treated as seconds below - confirm.
        tms = [prog(*args, **kwargs) for _ in range(CNT)]
        cudaalloc.copyout(flat_mv(nc.data), c)

        # Reference result computed on the host in float32.
        comp = na.astype(np.float32) @ nb.astype(np.float32)
        result = nc.reshape(M, N).astype(np.float32)

        best_time = min(tms)
        print(f"{N*N:10d} {best_time*1e6:9.2f} us, would be {FLOPS*1e-9/best_time:9.2f} GFLOPS matmul, {BW*1e-9/best_time:.2f} GB/s")

        debug_values = getenv("DEBUG_VALUES") > 0
        try:
            np.testing.assert_allclose(result, comp, atol=ATOL, rtol=RTOL)
        except AssertionError:
            if debug_values:
                # Show where the result and the reference disagree before re-raising.
                close = np.isclose(result, comp, rtol=RTOL, atol=ATOL)
                indices = np.where(~close)
                print("valid :", np.where(close))
                print("invalid :", indices)
                print("result :", result[indices])
                print("ground truth:", comp[indices])
                print("result sum :", np.sum(result))
                print("ground sum :", np.sum(comp))
            raise
        if debug_values:
            print(comp)
            print("ground sum :", np.sum(comp))
            print(result)
            print("result sum :", np.sum(result))
    elif getenv("AMD") == 1:
        # note: https://hipfft.readthedocs.io/en/rocm-6.1.2/how-to/fine-tuning-llms/optimizing-triton-kernel.html
        # also this is different than the rocblas/tensile approach to GEMM
        # see: https://github.com/ROCm/Tensile/blob/develop/Tensile/KernelWriterAssembly.py
        raise RuntimeError("invalid max_matmul device")
    else:
        raise RuntimeError("invalid max_matmul device")