from __future__ import annotations
from typing import Optional, Callable
import functools
from dataclasses import dataclass, field, replace
from tinygrad.helpers import to_function_name, dedup, prod
from tinygrad.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher
from tinygrad.dtype import DType
@dataclass(frozen=True)
class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
  dims: tuple[int,int,int] # N, M, K
  dtype_in: DType # dtype for A and B
  dtype_out: DType # dtype for C and D
  threads: list[tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
  reduce_axes: list[tuple[int,int]] # list of (TC dim,amt) that constructs the shape of the reduce dim
  @property
  def early_upcast_axes(self) -> list[tuple[int,int]]: # list of (TC dim,amt) that upcasts the threads remainders of dims [0,1]
    return [(d,self.dims[d]//sz) for d,sz in [(dim,prod(sz for d,sz in self.threads if d==dim)) for dim in range(2)] if self.dims[d]>sz]
  upcast_axes: tuple[list[tuple[int,int]], list[tuple[int,int]], list[tuple[int,int]]] # list of (TC dim,amt) that upcast A, B and C
  st1_pattern: Optional[tuple[tuple[tuple[int,int], ...], tuple[tuple[int,int], ...]]] = None # pattern to fix shapetracker for A
  st2_pattern: Optional[tuple[tuple[tuple[int,int], ...], tuple[tuple[int,int], ...]]] = None # pattern to fix shapetracker for B
  expanded_shape: Optional[tuple[int, ...]] = None
  opts_seq: tuple[str,str] = ("UP","LC") # upcast input, local the thread pattern
  def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
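  # Illustrative only (hypothetical values, not a shipped backend config): a 16x16x16 half->float
  # spec could be constructed as
  #   TensorCore(dims=(16,16,16), dtype_in=dtypes.half, dtype_out=dtypes.float,
  #              threads=[(0,2),(0,2),(1,2),(1,2),(1,2)], reduce_axes=[(2,2),(2,2)],
  #              upcast_axes=([(0,2)],[(1,2)],[(1,2),(0,2)]))
  # assuming dtypes is imported from tinygrad.dtype; str() of that instance is "WMMA_16_16_16_half_float".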
@dataclass(frozen=True)
class Estimates:
  # number of FLOPS used in the Kernel
  ops:sint = 0
  # bytes accessed in loads and stores
  lds:sint = 0
  # total bytes accessed, counting only once for bytes that are accessed multiple times
  mem:sint = 0
  def __add__(self, o:Estimates): return Estimates(self.ops + o.ops, self.lds + o.lds, self.mem + o.mem)
  def simplify(self): return Estimates(ssimplify(self.ops), ssimplify(self.lds), ssimplify(self.mem))
  @staticmethod
  def from_uops(uops:list[UOp], ignore_indexing=False) -> Estimates:
    flops: sint = 0
    lds: sint = 0
    mults: sint = 1
    mult_stack: list[sint] = []
    dont_count: set[UOp] = set()
    if ignore_indexing:
      for u in uops:
        if u.op in {Ops.LOAD, Ops.STORE}:
          dont_count = dont_count.union(u.src[0].toposort)
          if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort)
        elif u.op is Ops.IF:
          dont_count = dont_count.union(u.src[0].toposort)
    for u in uops:
      if u.op is Ops.RANGE:
        mult_stack.append(mults)
        mults *= (u.src[1] - u.src[0]).ssimplify()
      elif u.op is Ops.ENDRANGE: mults = mult_stack.pop(-1)
      elif u.op is Ops.SPECIAL: mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
      elif u.op is Ops.LOAD: lds += u.dtype.itemsize * mults
      elif u.op is Ops.STORE: lds += u.src[1].dtype.itemsize * mults
      elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
      elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
    return Estimates(flops, lds, lds) # TODO: properly track memory, lds is always a high estimate
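  # How from_uops counts: mults is the product of enclosing trip counts. RANGE pushes the old value and
  # multiplies in the loop length, ENDRANGE pops it, and SPECIAL multiplies in a grid dimension permanently
  # (there is no matching end). Each ALU op then adds mults * dtype.count FLOPS (2x for MULACC) and each
  # LOAD/STORE adds mults * itemsize bytes to lds. Worked example on a hypothetical kernel: a RANGE of 16
  # iterations containing one scalar float MUL and one scalar float LOAD yields ops=16 and lds=64 bytes;
  # with ignore_indexing=True the address-computation UOps feeding that LOAD are skipped via dont_count.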
@dataclass
class ProgramSpec:
  name:str
  src:str
  device:str
  uops:Optional[list[UOp]]=None
  mem_estimate:sint=0 # TODO: get this from the load/store uops once min/max are good

  # filled in from uops (if we have uops)
  global_size:Optional[list[int]]=None
  local_size:Optional[list[int]]=None
  vars:list[Variable]=field(default_factory=list)
  globals:list[int]=field(default_factory=list)
  outs:list[int]=field(default_factory=list)
  _ran_post_init:bool=False # NOTE: this is needed if you call replace on the Program

  def __post_init__(self):
    if not self._ran_post_init and self.uops is not None:
      # single pass through the uops
      for u in self.uops:
        if u.op is Ops.DEFINE_VAR: self.vars.append(u)
        if u.op is Ops.DEFINE_GLOBAL: self.globals.append(u.arg)
        if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL])
        if u.op is Ops.SPECIAL:
          # NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
          if u.arg[0][0] == 'i': self.local_size = None
          special_size = self.local_size if u.arg[0][0] == 'l' else self.global_size
          assert special_size is not None
          special_size[int(u.arg[0][-1])] = u.arg[1]
      self.vars = sorted(self.vars, key=lambda v: v.arg)
      self.outs = sorted(dedup(self.outs))
      self._ran_post_init = True

  @functools.cached_property
  def estimates(self) -> Estimates:
    return replace(Estimates() if self.uops is None else Estimates.from_uops(self.uops, ignore_indexing=True), mem=self.mem_estimate)

  @functools.cached_property
  def function_name(self) -> str: return to_function_name(self.name)

  def launch_dims(self, var_vals:dict[Variable, int]):
    global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
    local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
    return global_size, local_size
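  # Example with hypothetical values: SPECIAL args ("gidx0", 256) and ("lidx0", 32) fill global_size[0]=256
  # and local_size[0]=32 in __post_init__ (the caller must have initialized both to [1,1,1] first, see the
  # NOTE above). launch_dims({}) then returns ([256, 1, 1], [32, 1, 1]); symbolic sizes are resolved to ints
  # through sym_infer using var_vals.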
class Renderer:
  device: str = ""
  suffix: str = ""
  # TODO: make this generic with a list of supported types
  supports_float4: bool = True
  has_local: bool = True
  has_shared: bool = True
  # NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims
  global_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
  local_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
  shared_max: int = 32768
  tensor_cores: list[TensorCore] = []
  extra_matcher: Optional[PatternMatcher] = None
  code_for_op: dict[Ops, Callable] = {}

  def __reduce__(self): return self.__class__, ()
  def render(self, name:str, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer")
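# Each backend subclasses Renderer and implements render(). Minimal sketch (hypothetical, not a real backend):
#   class FakeRenderer(Renderer):
#     device = "FAKE"
#     has_local = False
#     def render(self, name:str, uops:list[UOp]) -> str:
#       return f"// {to_function_name(name)}: {len(uops)} uops"
# Real renderers (e.g. the CStyleLanguage subclasses) walk the uops to emit device source and override
# tensor_cores, global_max, local_max, etc. to describe the target hardware.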