import functools, itertools, operator
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv
from tinygrad.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, graph_rewrite_map, track_rewrites
# *** allreduce implementation ***
def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
  if not isinstance(buf.device, tuple): return None
  assert all_int(buf.shape), f"does not support symbolic shape {buf.shape}"
  n_lbs, shape, numel = len(buf.device), buf.shape, prod(buf.shape)
  # ring allreduce doesn't provide a benefit with only 2 nodes or where the number of elements is less than 256k (empirically)
  # fall back to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
  use_ring = (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
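  # i.e. RING=0 always takes the naive path, RING>=2 always takes the ring path, and RING=1 defers to the size heuristic above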
  if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{numel} | {buf.dtype}")

  # make contiguous before we copy it
  buf = buf.contiguous()

  # naive allreduce: copy to all devices and combine. if you shrink later, that'll be handled
  if not use_ring: return functools.reduce(lambda x,y: x.alu(red.arg, y),
                                           [UOp(Ops.COPY, buf.dtype, (buf, red.src[1]), arg=i) for i in range(len(buf.device))])

  # ring reduce: split the flattened buffer into one chunk per device
  factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
  base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
  chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
  chunks = list(itertools.pairwise(itertools.accumulate(chunk_sizes, initial=0)))
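  # e.g. (illustrative) numel=100, n_lbs=4: factor=4, chunk_sizes=[28,24,24,24], chunks=[(0,28),(28,52),(52,76),(76,100)]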
  # extract chunks and scatter-reduce
  reduced_chunks = []
  for i,(s,e) in enumerate(chunks):
    chunk = buf.reshape((numel,)).shrink(((s,e),))
    reduced_chunk = chunk
    for step in range(n_lbs-1):
      src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
      # copy the chunk from the src device to the dest (operating device), and select the chunk on the dest device
      reduced_chunk = reduced_chunk.copy_to_device(buf.device[dest], src).alu(red.arg, chunk.copy_to_device(buf.device[dest], dest))
    reduced_chunks.append(reduced_chunk)

  # allgather + reassemble
  pads = [((s,numel-e),) for s,e in chunks]
  return functools.reduce(operator.add, [c.copy_to_device(buf.device).pad(pad) for pad,c in zip(pads, reduced_chunks)]).reshape(shape)

replace_allreduce = PatternMatcher([(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),])
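
# Illustrative sketch (not part of the module) of what handle_allreduce produces.
# Naive path on two devices with an ADD reduce:
#   ALLREDUCE(buf, ADD)  ->  ADD(COPY(buf, device, arg=0), COPY(buf, device, arg=1))
# Ring path: the flattened buffer is split into one chunk per device, each chunk is reduced
# step by step around the ring (scatter-reduce), then every reduced chunk is copied to all
# devices, padded back to the full length, and summed together (allgather + reassemble).
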
# ***** multi functions *****
def alu_multi(root:UOp):
  msrcs = root.src
  assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
  axis = root.axis
  assert axis is not None

  srcs = []
  for mlb in msrcs:
    if mlb.axis == axis:
      # same axis, just copy through
      assert mlb.op is Ops.MULTI
      srcs.append(mlb.src[0])
    elif mlb.axis is None:
      # no axis, shard it
      assert mlb.op is not Ops.MULTI
      srcs.append(mlb._shard(axis))
    else:
      # axis mismatch, unshard it, send it to all devices, and shard it correctly
      assert mlb.op is Ops.MULTI
      srcs.append(mlb.src[0]._unshard(mlb.axis).allreduce(Ops.ADD, mlb.device)._shard(axis))
  return srcs[0].alu(root.op, *srcs[1:]).multi(axis)

def reduce_multi(root:UOp, multi:UOp):
  op, axis = root.arg
  if multi.axis is not None and multi.axis in axis:
    # all-reduce on sharded axes
    return multi.src[0].r(op, axis).allreduce(op, multi.device)
  # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
  return multi.src[0].r(op, axis).multi(axis=multi.axis)

def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]:
  return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape))

def reshape_multi(root:UOp, multi:UOp):
  arg = root.arg
  if (new_axis:=root.axis) is None: return multi.src[0].reshape(arg).multi(new_axis)
  assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)"
  assert prod(multi.src[0].shape[multi.axis:])%prod(arg[new_axis+1:]) == 0, f"reshape cannot move items between shards {multi.shape} -> {root.arg=}"
  new_shape_axis = prod(multi.src[0].shape[multi.axis:]) // prod(arg[new_axis+1:])
  return multi.src[0].reshape(tuple(s if a!=new_axis else new_shape_axis for a,s in enumerate(arg))).multi(new_axis)

def expand_multi(root:UOp, multi:UOp):
  # NOTE: this assert isn't needed, sharded axis can have dim 1
  assert multi.axis is None or root.arg[multi.axis] == multi.shape[multi.axis], f"expand not supported on sharded axis {root.arg=}"
  return multi.src[0].expand(_shape_to_single_shard(multi.axis, root.arg, multi.src[0])).multi(multi.axis)

def pad_multi(root:UOp, multi:UOp):
  assert multi.axis is None or root.arg[multi.axis] == (0,0), f"padding not supported for {root.arg=}"
  return multi.src[0].pad(root.arg).multi(multi.axis)

def permute_multi(root:UOp, multi:UOp):
  # all permutes supported!
  return multi.src[0].permute(root.arg).multi(root.axis)

def shrink_multi(root:UOp, multi:UOp):
  assert multi.axis is None or root.arg[multi.axis] == (0, multi.shape[multi.axis]) or root.arg[multi.axis] in multi.bounds, \
    f"shrinking not supported for {root.arg=}"
  if multi.axis is not None and root.arg[multi.axis] in multi.bounds and root.arg[multi.axis] != (0, multi.shape[multi.axis]):
    assert all(root.arg[i] == (0, s) or i == multi.axis for i,s in enumerate(multi.shape)), \
      "cannot shrink sharded and non-sharded axis at the same time"
    # NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
    # we just copy it to all the devices, no real. this will be optimized out later
    return multi.src[0].copy_to_device(multi.device, arg=multi.bounds.index(root.arg[multi.axis]))
  return multi.src[0].shrink(tuple((0, multi.src[0].shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.arg))).multi(multi.axis)

def flip_multi(root:UOp, multi:UOp):
  assert multi.axis is None or not root.arg[multi.axis], "flipping not supported on sharded axis"
  return multi.src[0].flip(root.arg).multi(multi.axis)

# from multiple devices -> one
def copy_multi(multi:UOp, device:UOp):
  assert multi.axis is not None, "all multi ops have axis"
  return multi.src[0]._unshard(multi.axis).allreduce(Ops.ADD, device)

def assign_multi(dest:UOp, src:UOp):
  if dest.axis != src.axis: raise RuntimeError(f"axis must match in assign {dest.axis} != {src.axis}")
  return dest.src[0].assign(src.src[0]).multi(src.axis)

def passthrough_multi(root:UOp, multi:UOp):
  return root.replace(src=(multi.src[0],)).multi(multi.axis)

# NOTE: this is the same pattern as Ops.UNROLL
multi_pm = PatternMatcher([
  (UPat(GroupOp.ALU, name="root", custom_early_reject=set([Ops.MULTI])), alu_multi),
  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reduce_multi),
  (UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reshape_multi),
  (UPat(Ops.EXPAND, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), expand_multi),
  (UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), pad_multi),
  (UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi),
  (UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), shrink_multi),
  (UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi),
  (UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
  (UPat(Ops.COPY, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device"))), copy_multi),
  (UPat(Ops.ALLREDUCE, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device")), name="red"),
   lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)),
  (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE),
        src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
])

@track_rewrites(named=True)
def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]: return {k:v for k,v in graph_rewrite_map(big_sink, multi_pm).items() if k is not v}
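
# --- illustrative usage sketch, not part of the module ---
# A rough idea of how these rewrites get exercised from the Tensor API; the device names
# below are placeholders and the exact realization path is an assumption:
#
#   from tinygrad import Tensor
#   t = Tensor.ones(256, 256).shard(("DEV:0", "DEV:1"), axis=0)   # creates a MULTI UOp sharded on axis 0
#   out = (t * 2).sum()                                           # alu_multi keeps the shard, reduce_multi inserts an ALLREDUCE
#   out.realize()                                                 # multi_pm / get_multi_map rewrite the MULTI ops away, and
#                                                                 # replace_allreduce lowers ALLREDUCE to copies + ALU ops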