import functools, itertools, operator
from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv
from tinygrad.ops import Ops, UOp, sint
def all_reduce(bop: Ops, lbs: list[UOp]) -> list[UOp]:
  assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
  assert all_same([lb.shape[0] for lb in lbs]), "allreduce with uneven shards is undefined"
  n_lbs, shape, numel = len(lbs), lbs[0].shape, prod(lbs[0].shape)
  # ring allreduce doesn't provide a benefit with only 2 nodes or when the number of elements is less than 256k (empirically)
  # fall back to naive allreduce to save on kernel dispatch, chunking, and reassembling chunks.
  use_ring = (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
  if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{numel} | {lbs[0].dtype}")
  if not use_ring: return [functools.reduce(lambda x,y: x.alu(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
  factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
  base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
  chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
  chunks = list(itertools.pairwise(itertools.accumulate(chunk_sizes, initial=0)))
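  # illustrative arithmetic (not from the source): numel=10, n_lbs=4 gives factor=2, base=1, left=1,
  # so chunk_sizes=[4, 2, 2, 2] and chunks=[(0,4), (4,6), (6,8), (8,10)]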
  chunked = [[lb.reshape((numel,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]
  # scatter-reduce
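  # each step passes chunk i one hop around the ring: device (i+step)%n_lbs sends its partial result to
  # device (i+step+1)%n_lbs, so after n_lbs-1 steps device (i-1)%n_lbs holds chunk i fully reduced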
  for step in range(n_lbs-1):
    for i in range(len(chunks)):
      src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
      chunked[dest][i] = chunked[dest][i].alu(bop, chunked[src][i].copy_to_device(chunked[dest][i].device))
  # allgather
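  # the fully reduced chunk i then circulates around the ring for n_lbs-1 more steps until every device holds every chunk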
  for step in range(n_lbs-1):
    for i in range(len(chunks)):
      src, dest = (i+step-1)%n_lbs, (i+step)%n_lbs
      chunked[dest][i] = chunked[src][i].copy_to_device(chunked[dest][i].device)
  # assemble chunks back
  pads = [((s,numel-e),) for s,e in chunks]
  return [functools.reduce(operator.add, [c.pad(pad) for pad,c in zip(pads,lb_c)]).reshape(shape) for lb_c in chunked]

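# to_sharded shrinks each device's full copy down to its slice of `axis`, using that device's entry in `bounds`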
def to_sharded(lbs:list[UOp], axis:int, bounds: tuple[tuple[int, int], ...]) -> list[UOp]:
  if lbs[0].shape[axis] % len(lbs) != 0: raise RuntimeError(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}")
  return [lb.shrink(tuple((0,s) if a != axis else bound for a,s in enumerate(lb.shape))) for i, (bound, lb) in enumerate(zip(bounds, lbs))]

# ***** multi functions *****
from tinygrad.ops import PatternMatcher, UPat, GroupOp, graph_rewrite_map, track_rewrites
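
# elementwise op over MULTI sources: re-shard any source whose axis/bounds disagree so the shards line up, then apply the op per shard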
def alu_multi(root:UOp):
  msrcs = root.src
  assert all(x.op is Ops.MULTI for x in msrcs), f"all buffers must be MultiLazyBuffer {[x.op for x in msrcs]}"
  assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
  axis = root.axis
  bounds = dedup([x.bounds for x in root.src if x.axis == axis])[-1] if axis is not None else None
  srcs:list[list[UOp]] = []
  not_all_real = not all(all(mlb.real) for mlb in msrcs)
  new_real = tuple(all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])) if not_all_real else msrcs[0].real
  for mlb in msrcs:
    if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(list(mlb.src))
    else:
      assert axis is not None and bounds is not None
      if mlb.axis is None: srcs.append(to_sharded(list(mlb.src), axis, bounds))
      else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.src], axis, bounds))
  new_lbs = [lsrcs[0].alu(root.op, *lsrcs[1:]) for lsrcs in zip(*srcs)]
  new_lbs = [x if r else x.const_like(0) for r,x in zip(new_real, new_lbs)] # TODO: is this needed?
  return UOp.multi(*new_lbs, axis=axis, real=new_real)

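# reduce over MULTI: if the reduce covers the sharded axis, reduce each shard locally and all_reduce the partials; otherwise reduce per shard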
def reduce_multi(root:UOp, multi:UOp):
  op, axis = root.arg
  if multi.axis is not None and multi.axis in axis:
    # all-reduce on sharded axes
    reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(multi.src, multi.real)]
    # if all partitions are real, do all_reduce
    if all(multi.real): return UOp.multi(*all_reduce(op, reduced_parts), axis=root.axis)
    # only one partition is real, keep it
    return UOp.multi(*reduced_parts, axis=root.axis, real=multi.real)
  # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
  return UOp.multi(*[x.r(op, axis) for x in multi.src], axis=root.axis, real=multi.real)

def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]:
  return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape))

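# reshape of a sharded tensor: only legal when no element would cross a shard boundary, so each shard reshapes to its slice of the new shape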
def reshape_multi(root:UOp, multi:UOp):
  arg = root.arg
  if (new_axis:=root.axis) is None: return UOp.multi(*[x.reshape(arg) for x in multi.src], axis=new_axis, real=multi.real)
  assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)"
  assert all(prod(lb.shape[multi.axis:])%prod(arg[new_axis+1:])==0 for lb in multi.src), \
    f"reshape cannot move items between shards {multi.shape} -> {root.arg=}"
  lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[multi.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in multi.src]
  return UOp.multi(*lbs, axis=new_axis, real=multi.real)

def expand_multi(root:UOp, multi:UOp):
  # NOTE: this assert isn't needed, sharded axis can have dim 1
  assert multi.axis is None or root.arg[multi.axis] == multi.shape[multi.axis], f"expand not supported on sharded axis {root.arg=}"
  return UOp.multi(*[x.expand(_shape_to_single_shard(multi.axis, root.arg, x)) for x in multi.src], axis=multi.axis, real=multi.real)

def pad_multi(root:UOp, multi:UOp):
  assert multi.axis is None or root.arg[multi.axis] == (0,0) or not all(multi.real), f"padding not supported for {root.arg=}"
  # pad on shard axis -> fill others with zeros and set real to all True
  if multi.axis is not None and root.arg[multi.axis] != (0,0):
    # pad back to whole axis, remove real mask
    assert all(root.arg[i] == (0, 0) for i in range(len(multi.shape)) if i != multi.axis), "cannot pad sharded and non-sharded axis at the same time"
    dim, bound = sum(lb.shape[multi.axis] for lb in multi.src), multi.bounds[multi.real.index(True)]
    assert root.arg[multi.axis] == (bound[0], dim-bound[1]), "can only pad to whole axis"
    return UOp.multi(*[x if r else x.const_like(0) for x,r in zip(multi.src, multi.real)], axis=multi.axis)
  return UOp.multi(*[x.pad(root.arg) for x in multi.src], axis=multi.axis, real=multi.real)

def permute_multi(root:UOp, multi:UOp):
  # all permutes supported!
  return UOp.multi(*[x.permute(root.arg) for x in multi.src], axis=root.axis, real=multi.real)

def shrink_multi(root:UOp, multi:UOp):
  assert multi.axis is None or root.arg[multi.axis] == (0, multi.shape[multi.axis]) or root.arg[multi.axis] in multi.bounds, \
    f"shrinking not supported for {root.arg=}"
  if multi.axis is not None and root.arg[multi.axis] in multi.bounds and root.arg[multi.axis] != (0, multi.shape[multi.axis]):
    assert all(root.arg[i] == (0, s) or i == multi.axis for i,s in enumerate(multi.shape)), \
      "cannot shrink sharded and non-sharded axis at the same time"
    # NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
    idx = multi.bounds.index(root.arg[multi.axis])
    # zero out other lbs to not create lb reference
    return UOp.multi(*[lb if i==idx else lb.const_like(0) for i,lb in enumerate(multi.src)],
                     axis=multi.axis, real=tuple(i==idx for i in range(len(multi.src))))
  return UOp.multi(*[x.shrink(tuple((0, x.shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.arg))) for x in multi.src],
                   axis=multi.axis, real=multi.real)

def flip_multi(root:UOp, multi:UOp):
  assert multi.axis is None or not root.arg[multi.axis], "flipping not supported on sharded axis"
  return UOp.multi(*[x.flip(root.arg) for x in multi.src], axis=multi.axis, real=multi.real)

def copy_multi(multi:UOp, device:UOp):
  # if we already have a copy on the device, return that
  if multi.axis is None: return next((lb for lb in multi.real_lbs if lb.device == device.arg), multi.real_lbs[0].copy_to_device(device.arg))
  # copy lbs to device, pad to final shape, and sum
  llbs:list[UOp] = []
  for lb,real,(start,end) in zip(multi.src, multi.real, multi.bounds):
    if not real: continue
    pad_arg = tuple((0,0) if a != multi.axis else (start, multi.bounds[-1][1]-end) for a in range(len(lb.shape)))
    llbs.append(lb.copy_to_device(device.arg).pad(pad_arg))
  return functools.reduce(operator.add, llbs)

def assign_multi(dest:UOp, src:UOp):
  assert dest.axis == src.axis and dest.real == src.real, f"axis/real must match in assign {dest.axis} != {src.axis} or {dest.real} != {src.real}"
  return UOp.multi(*[x.assign(y) for x,y in zip(dest.src, src.src)], axis=src.axis, real=src.real)

def passthrough_multi(root:UOp, multi:UOp): return UOp.multi(*[root.replace(src=(m,)) for m in multi.src], axis=multi.axis, real=multi.real)
# NOTE: this is the same pattern as Ops.UNROLL
multi_pm = PatternMatcher([
  (UPat(GroupOp.ALU, name="root", custom_early_reject=set([Ops.MULTI])), alu_multi),
  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reduce_multi),
  (UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reshape_multi),
  (UPat(Ops.EXPAND, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), expand_multi),
  (UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), pad_multi),
  (UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi),
  (UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), shrink_multi),
  (UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi),
  (UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
  (UPat(Ops.COPY, src=(UPat(Ops.DEVICE, name="device"), UPat(Ops.MULTI, name="multi"), )), copy_multi),
  (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
    src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
])
@track_rewrites(named=True)
def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]: return {k:v for k,v in graph_rewrite_map(big_sink, multi_pm).items() if k is not v}
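
# minimal usage sketch (illustrative, assuming a caller elsewhere in tinygrad; not part of this file):
#   multi_map = get_multi_map(big_sink)        # maps MULTI-consuming UOps to their per-device rewrites
#   big_sink = big_sink.substitute(multi_map)  # UOp.substitute applies a replacement map like this one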