openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

97 lines
4.1 KiB

from typing import cast
from dataclasses import dataclass, field
from collections import deque, defaultdict
from tinygrad.ops import UOp, Variable, Ops, UPat, PatternMatcher, graph_rewrite, buffers
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import Metadata, unwrap, merge_dicts
# **** ScheduleItem return type
@dataclass(frozen=True)
class ScheduleItem:
  """One entry of the linearized schedule: an AST together with the buffers it is scheduled on.

  The dataclass is frozen, so fields cannot be reassigned after construction.
  """
  # the UOp AST of the scheduled kernel
  ast: UOp
  # the buffers this item is scheduled with, one per kernel source
  bufs: tuple[Buffer, ...]
  # metadata forwarded from the kernel that produced this item; empty when none is given
  metadata: tuple[Metadata, ...] = ()
  # Variable values pinned for this item only; every instance gets its own fresh dict by default
  fixedvars: dict[Variable, int] = field(default_factory=dict)
# **** unbind Variables
def unbind_view(ctx:list[dict[Variable, int]], x:UOp):
  """Simplify the `st` of a VIEW UOp and strip any bound variables out of it.

  When one of the simplified st's vars is a BIND, the st is unbound and the resulting
  variable values are appended to `ctx`.

  Returns `x` with the new st as its arg, or None when the st is unchanged.
  NOTE(review): None presumably tells the PatternMatcher that nothing was rewritten -- confirm.
  """
  simplified = unwrap(x.st).simplify()
  has_bound_var = any(v.op is Ops.BIND for v in simplified.vars())
  if has_bound_var:
    simplified, bound_vals = simplified.unbind()
    ctx.append(bound_vals)
  # only emit a replacement when simplify/unbind actually changed something
  if simplified != x.st: return x.replace(arg=simplified)
  return None
def unbind_bind(ctx:list[dict[Variable, int]], x:UOp):
  """Split a BIND UOp into its variable and value, recording the value in `ctx`.

  A single-entry dict is appended to `ctx`; its key is the variable with `src` cleared.
  Returns the variable as produced by `x.unbind()` (not the src-cleared copy used as the key).
  """
  variable, value = x.unbind()
  key = variable.replace(src=())
  ctx.append({key: value})
  return variable
# Rewrite rules passed to graph_rewrite by create_schedule_with_vars, which supplies a list as ctx.
# Each entry pairs a UPat with the handler to run on a match; both handlers append {Variable: int} dicts to ctx.
# NOTE(review): name="x" presumably binds the matched UOp to the handler's `x` argument -- confirm against PatternMatcher.
pm_unbind = PatternMatcher([
# VIEW: simplify the view's st and unbind any bound variables inside it
(UPat(Ops.VIEW, name="x"), unbind_view),
# BIND: replace the bind with its variable and record the bound value
(UPat(Ops.BIND, name="x"), unbind_bind),
])
# **** schedule linearizer
def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
  """Linearize the KERNEL UOps reachable from `sched_sink` into an ordered list of ScheduleItems.

  Every ASSIGN in the toposort contributes its `src[1]` (the KERNEL). A kernel depends on the
  kernels behind the ASSIGNs found in its own `src`, and kernels are emitted breadth-first
  starting from those without such dependencies.

  Returns a tuple `(schedule, var_vals, becomes_map)`:
    schedule:    the ScheduleItems in emission order; a kernel touching a MultiBuffer may expand to several items
    var_vals:    the values collected while unbinding each kernel's AST, combined with `merge_dicts`
    becomes_map: every ASSIGN UOp in the toposort mapped to its `buf_uop`

  Side effect: for a BUFFER_VIEW AST, a view of the base buffer is stored in the global `buffers` mapping.

  Raises AssertionError when the base of a BUFFER_VIEW is not a Buffer, when a non-COPY kernel mixes
  MultiBuffer and non-MultiBuffer sources, or when an ASSIGN's `buf_uop` is neither BUFFER nor BUFFER_VIEW.
  """
  # dependency graph: children[p] lists the kernels that have an ASSIGN of kernel p among their sources
  toposort = sched_sink.toposort()
  children: defaultdict[UOp, list[UOp]] = defaultdict(list)
  in_degree: dict[UOp, int] = {}
  for assign in toposort:
    if assign.op is not Ops.ASSIGN: continue
    kernel = assign.src[1]
    in_degree.setdefault(kernel, 0)
    for dep in kernel.src:
      if dep.op is Ops.ASSIGN:
        children[dep.src[1]].append(kernel)
        in_degree[kernel] += 1

  # breadth-first emission: a kernel becomes ready once all of its dependencies have been emitted
  ready = deque(kernel for kernel, degree in in_degree.items() if degree == 0)
  schedule: list[ScheduleItem] = []
  var_vals: dict[Variable, int] = {}
  while ready:
    kernel = ready.popleft()

    # rewriting with pm_unbind fills `unbound` with the values of the variables stripped from the AST
    unbound: list[dict[Variable, int]] = []
    ast = graph_rewrite(kernel.arg.ast, pm_unbind, ctx=unbound, name="unbind vars")
    var_vals = merge_dicts([var_vals, *unbound])

    # a BUFFER_VIEW registers a view of the base buffer under the kernel's first source
    if ast.op is Ops.BUFFER_VIEW:
      base = kernel.src[1].buf_uop.buffer
      assert isinstance(base, Buffer), "base can't be MultiBuffer"
      # NOTE(review): ast.arg[1] scaled by the itemsize looks like an element offset turned into bytes -- confirm
      offset = ast.arg[1]*base.dtype.itemsize
      buffers[kernel.src[0]] = base.view(kernel.size, ast.dtype, offset)

    ubufs = tuple(s.buf_uop.buffer for s in kernel.src)
    metadata = kernel.arg.metadata
    if not any(isinstance(b, MultiBuffer) for b in ubufs):
      # no MultiBuffer involved: the kernel maps to exactly one ScheduleItem
      schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), metadata))
    elif ast.op is not Ops.COPY:
      # non-COPY kernel on MultiBuffers: one ScheduleItem per position across all the MultiBuffers
      assert all(isinstance(b, MultiBuffer) for b in ubufs), "kernel must all be multibuffer"
      dnums = [v for v in ast.variables() if v.arg[0] == '_device_num']
      per_device = zip(*[mb.bufs for mb in cast(tuple[MultiBuffer, ...], ubufs)])
      for i, bufs in enumerate(per_device):
        # when the AST has a '_device_num' variable, pin it to this position's index
        schedule.append(ScheduleItem(ast, bufs, metadata, {dnums[0]: i} if dnums else {}))
    else:
      dst, src = ubufs[0], ubufs[1]
      dst_is_multi = isinstance(dst, MultiBuffer)
      if isinstance(src, MultiBuffer) and ast.arg is None:
        # src is multiple buffers, none selected
        if dst_is_multi: pairs = list(zip(dst.bufs, src.bufs))  # COPY ALL -> ALL
        else: pairs = [(dst, src.bufs[0])]  # COPY ANY -> ONE. Currently we just select the first
      else:
        src_buf = src.bufs[ast.arg] if isinstance(src, MultiBuffer) else src
        if dst_is_multi: pairs = [(b, src_buf) for b in dst.bufs]  # COPY ONE -> ALL (BROADCAST)
        else: pairs = [(dst, src_buf)]  # COPY ONE -> ONE
      schedule.extend(ScheduleItem(ast, pair, metadata) for pair in pairs)

    # release the kernels that were waiting on this one
    for child in children[kernel]:
      in_degree[child] -= 1
      if in_degree[child] == 0: ready.append(child)

  # map ASSIGN to BUFFER after ScheduleItems are constructed
  becomes_map = {assign: assign.buf_uop for assign in toposort if assign.op is Ops.ASSIGN}
  assert all(b.op in {Ops.BUFFER, Ops.BUFFER_VIEW} for b in becomes_map.values()), f"Schedule didn't end with BUFFER {becomes_map.values()}"
  return schedule, var_vals, becomes_map