openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

204 lines
9.7 KiB

from typing import List, Tuple, Dict, Any
from tinygrad.helpers import ImageDType, prod, IMAGE, getenv, dtypes, DEBUG, flatten
# *** image Tensor function replacements ***
from tinygrad.lazy import get_single_root
def image_dot(self, w):
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
n1, n2 = len(self.shape), len(w.shape)
assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})"
bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
cin, cout = w.shape[-2], w.shape[-1]
out_shape_t = self.shape[0:-2] + (cout,-1)
if len(self.shape) > 1:
order = tuple(range(len(self.shape)-2)) + (len(self.shape)-1, len(self.shape)-2)
else:
order, out_shape_t = (0,), (cout, )
worder = tuple(range(len(w.shape)-2)) + (len(w.shape)-1, len(w.shape)-2)
# NOTE: with NHWC we can remove the transposes
# bs x groups*cin x H x W
cx = self.permute(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))
# groups*cout x cin x H, W
cw = w.permute(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
return image_conv2d(cx, cw, groups=groups).reshape(shape=out_shape_t).permute(order=order)
def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, padding=0):
base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
(bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
rcout = cout//groups
x, w = self, weight.reshape(groups, rcout, cin, H, W)
# hack for non multiples of 4 on cin
if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
x = x.reshape(bs, groups, cin, iy, ix) # do this always?
added_input_channels = 4 - (cin % 4)
w = w.pad(tuple((0, added_input_channels) if i == 2 else (0, 0) for i in range(len(w.shape))))
x = x.pad(tuple((0, added_input_channels) if i == 2 else (0, 0) for i in range(len(x.shape))))
cin = cin + added_input_channels
x = x.reshape(bs, groups*cin, iy, ix)
# hack for non multiples of 4 on rcout
added_output_channels = 0
if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
added_output_channels = 4 - (rcout % 4)
rcout += added_output_channels
cout = groups * rcout
w = w.slice(tuple((0, rcout) if i == 1 else (0, s) for i,s in enumerate(w.shape)))
# packed (note: flipping bs and iy would make the auto-padding work)
x = x.permute(0,2,3,1)
cin_last = iy == 1 and ix == 1
if cin == 1: w = w.reshape(cout//4,4,H,W).permute(0,2,3,1)
elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3)
else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1)
# contiguous creates the image, and early realize static weights (TODO: test for the static weight)
if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
x, w = x.contiguous(), w.contiguous()
if getenv("PREREALIZE", 1) and get_single_root(w.lazydata).realized: w.realize()
# expand out
rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
# padding
padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
x = x.slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
# prepare input
x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
oy, ox = x.shape[4:6]
x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, oy, ox, *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
x = x.expand(bs, oy, ox, *cout_expand, rcin_hi, rcin_lo, H, W)
# prepare weights
w = w.permute(0,4,2,5,1,3)
w = w.reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W)).expand(x.shape)
# the conv! (+ the bias)
ret = x*w
if IMAGE >= 2: ret = ret.cast(base_image_type((bs*oy, ox*cout//4, 4)))
ret = ret.sum((-4, -3, -2, -1))
# undo hack for non multiples of 4 on C.rcout
if added_output_channels != 0:
ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
rcout -= added_output_channels
cout = groups * rcout
# NCHW output
ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
# *** schedules with images need to be fixed to be valid ***
import dataclasses
from tinygrad.ops import ScheduleItem, BufferOps, LazyOp, UnaryOps, LoadOps, MemBuffer, get_lazyop_info
def fix_schedule_for_images(schedule:List[ScheduleItem]):
# this is the fundamental fix, find unwritable or unreadable images and convert them to normal float32 (TODO: should it be float16?)
replace_inputs = {}
for i, si in enumerate(schedule):
if isinstance(si.out.dtype, ImageDType) and (prod(si.out.shape) != prod(si.out.dtype.shape) or not any(si.out.shape[x]%4 == 0 for x in si.out.st.unit_stride_axes())):
if DEBUG >= 1: print(f"{i:3d}: rewrite output, output shape {prod(si.out.shape)}, image dtype {si.out.dtype} prod {prod(si.out.dtype.shape)}")
si.out.dtype = dtypes.float32
for b in si.ast.get_lazyops():
if b.op != BufferOps.MEM: continue
# TODO: unit_stride axes will fail if there's a mask, even if the mask is divisble by four. this is too aggressive
if isinstance(si.inputs[b.arg.idx-1].dtype, ImageDType) and (b.arg.st.real_offset() % 4 != 0 or not any(b.arg.st.shape[x]%4 == 0 for x in b.arg.st.unit_stride_axes())):
if DEBUG >= 1: print(f"{i:3d}: rewrite input, image dtype {si.inputs[b.arg.idx-1].dtype}, {b.arg.st.views}")
if si.inputs[b.arg.idx-1].realized:
# have to copy it
replace_inputs[si.inputs[b.arg.idx-1]] = si.inputs[b.arg.idx-1].cast(dtypes.float32)
else:
# change it before it's created
si.inputs[b.arg.idx-1].dtype = dtypes.float32
# now fix up the schedule to reflect the new dtypes
fixed_schedule:List[ScheduleItem] = []
for i,si in enumerate(schedule):
ast = si.ast
inputs = si.inputs
# replace inputs with casted versions
if any(x in replace_inputs for x in inputs):
fixed_schedule += flatten([replace_inputs[x].schedule() for x in inputs if x in replace_inputs])
inputs = tuple(replace_inputs.get(x, x) for x in inputs)
# fix input dtypes to match what they actually are
replacements = {}
for b in si.ast.get_lazyops():
if b.op != BufferOps.MEM: continue
if b.arg.dtype != inputs[b.arg.idx-1].dtype:
replacements[b] = LazyOp(BufferOps.MEM, (), MemBuffer(b.arg.idx, inputs[b.arg.idx-1].dtype, b.arg.st))
if replacements: ast = ast.map_buffers(replacements)
# fix the ops to create the output dtype
if ast.op not in LoadOps:
info = get_lazyop_info(ast)
if info.dtype != si.out.dtype:
if DEBUG >= 3: print(f"{i:3d}: info.dtype {info.dtype} != {si.out.dtype} -> {si.out.dtype}")
ast = LazyOp(UnaryOps.CAST, (ast,), (si.out.dtype, False))
# put this in the fixed schedule
fixed_schedule.append(dataclasses.replace(si, ast=ast, inputs=inputs))
return fixed_schedule
# *** images have weird indexing requirements ***
from tinygrad.shape.symbolic import Node, AndNode, Variable, NumNode, SumNode, LtNode
def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]:
idx = (idxy // 4) % base_shape[1]
idy = (idxy // (4 * base_shape[1]))
if valid.min == 0 and isinstance(idxy, SumNode):
nodes = valid.nodes if isinstance(valid, AndNode) else [valid]
val_dict: Dict[Node, Any] = {}
idxy_flat_var = [(i, i.vars()[0]) for i in idxy.flat_components if not isinstance(i, NumNode)]
for node in nodes:
assert isinstance(node, LtNode)
node_flat, node_vars = node.a.flat_components if isinstance(node.a, SumNode) else [node.a], node.vars()
same_sym = [i for (i, var) in idxy_flat_var if var in node_vars]
if len(same_sym) == 0: continue
first, second = sorted(same_sym)[0], sorted(node_flat)[0]
f_b = 1 if isinstance(first, Variable) else first.b
s_b = 1 if isinstance(second, Variable) else second.b
sig = -1 if s_b < 0 else 1
key_node = sig*node.a
if key_node not in val_dict: val_dict[key_node] = [key_node.min, key_node.max, abs(f_b//s_b)]
val_dict[key_node][(sig + 1)//2] = sig*(node.b - 1)
fakes = {}
for cnt, (key_node, (mnn, mxn, multip)) in enumerate(val_dict.items()):
fake_var = Variable("fake_" + str(cnt), mnn, mxn)
fakes[fake_var] = key_node
idxy += multip*(fake_var - key_node)
idx = (idxy // 4) % base_shape[1]
idy = (idxy // (4 * base_shape[1]))
fake_rep = {fake: node for fake, node in fakes.items()}
idx = idx.substitute(fake_rep)
idy = idy.substitute(fake_rep)
idy_vars, idx_vars, ones = set(idy.vars()), set(idx.vars()), []
for node in nodes:
node_vars = set(node.vars())
if not node_vars & (idx_vars | idy_vars): continue #There is simplified NumNode which can not go outside the bounds
# NOTE: Why does only idy is problematic? and not the idx
if idy_vars == node_vars or idy_vars & node_vars == set(): ones.append(node)
valid = Variable.ands([i for i in nodes if i not in ones])
if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy, valid)
return (idx, idy), valid