openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

69 lines
2.5 KiB

import unittest
from dataclasses import dataclass, field
from tinygrad.uop.ops import PatternMatcher, UOp, graph_rewrite, Ops, UPat, GroupOp, RewriteNotReady
# we could insert CHILDREN node
@dataclass
class ChildrenContext:
children: dict[UOp, list[UOp]]|None = None
seen_children: dict[UOp, set[int]] = field(default_factory=dict)
seen_consts:int = 0
saved_seen_consts:int = 0
# this is a generic child labeller
def extract_children(ctx:ChildrenContext, x:UOp):
if ctx.children is not None: return
ctx.children = {k:list(v.keys()) for k,v in x.get_children_map().items() if len(v) > 1}
def mark_children(ctx:ChildrenContext, x:UOp):
new_srcs = [(UOp(Ops.CHILD, s.dtype, src=(s,), arg=(ctx.children[s].index(x), len(ctx.children[s]))) if s in ctx.children else s) for s in x.src]
return x.replace(src=tuple(new_srcs))
pm_children = PatternMatcher([
(UPat(Ops.SINK, name="x"), extract_children),
(UPat(GroupOp.All-{Ops.CHILD}, name="x"), mark_children),
])
# this is a generic pattern
def visit_child(ctx:ChildrenContext, x:UOp):
if x.src[0] not in ctx.seen_children: ctx.seen_children[x.src[0]] = set()
ctx.seen_children[x.src[0]].add(x.arg[0])
if len(ctx.seen_children[x.src[0]]) != x.arg[1]:
print(f"visit CHILD {x.arg} bottom up -- not ready {ctx.seen_children[x.src[0]]}")
raise RewriteNotReady
print(f"visit CHILD {x.arg} bottom up -- READY {ctx.seen_children[x.src[0]]}")
pm_child_visitor = PatternMatcher([
(UPat(Ops.CHILD, name="x"), visit_child),
])
# this is for the test
def see_const(ctx:ChildrenContext, c:UOp): ctx.seen_consts += c.arg
def save_seen_consts(ctx:ChildrenContext, x:UOp): ctx.saved_seen_consts = ctx.seen_consts
pm_consts = PatternMatcher([
(UPat(Ops.DEFINE_VAR, name="x"), save_seen_consts),
(UPat()+UPat.cvar("c"), see_const),
])
class TestChildrenRewrite(unittest.TestCase):
def test_not_ready(self):
a = UOp.variable("a", 0, 10).exp2()
b = a+2
c = a+3
d = b+c
sink = d.sink()
# without children and not ready, we don't see both adds before the DEFINE_VAR
ctx = ChildrenContext()
sink = graph_rewrite(sink, pm_consts, ctx=ctx, bottom_up=True)
self.assertNotEqual(ctx.seen_consts, ctx.saved_seen_consts)
# with children and not ready we do
ctx = ChildrenContext()
sink = graph_rewrite(sink, pm_children, ctx=ctx, bottom_up=True)
sink = graph_rewrite(sink, pm_child_visitor+pm_consts, ctx=ctx, bottom_up=True)
self.assertEqual(ctx.seen_consts, ctx.saved_seen_consts)
if __name__ == '__main__':
unittest.main()