You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
70 lines
2.5 KiB
70 lines
2.5 KiB
3 days ago
|
import unittest
|
||
|
from dataclasses import dataclass, field
|
||
|
from tinygrad.uop.ops import PatternMatcher, UOp, graph_rewrite, Ops, UPat, GroupOp, RewriteNotReady
|
||
|
|
||
|
# we could insert a CHILDREN node
|
||
|
|
||
|
@dataclass
class ChildrenContext:
  """Mutable state threaded through the rewrite passes in this file as `ctx`."""
  # filled in once by extract_children: maps a UOp with more than one child to the list of those children
  children: dict[UOp, list[UOp]] | None = None
  # filled in by visit_child: maps the source of a CHILD node to the set of child indices (arg[0]) seen so far
  seen_children: dict[UOp, set[int]] = field(default_factory=dict)
  # running total of the const args accumulated by see_const
  seen_consts: int = 0
  # copy of seen_consts taken by save_seen_consts
  saved_seen_consts: int = 0
|
||
|
|
||
|
# this is a generic child labeller
|
||
|
def extract_children(ctx:ChildrenContext, x:UOp):
  """Populate ctx.children from x's children map, keeping only UOps that have more than one child.

  Only the first call does any work: once ctx.children is set, later calls leave it untouched.
  Always returns None.
  """
  if ctx.children is None:
    ctx.children = {parent:list(kids.keys()) for parent,kids in x.get_children_map().items() if len(kids) > 1}
|
||
|
|
||
|
def mark_children(ctx:ChildrenContext, x:UOp):
  """Return x with every source that has an entry in ctx.children wrapped in a CHILD UOp.

  The CHILD's arg is (position of x in that source's child list, total number of children of that source).
  Sources without an entry in ctx.children are passed through unchanged.
  """
  wrapped = []
  for s in x.src:
    if s in ctx.children:
      siblings = ctx.children[s]
      wrapped.append(UOp(Ops.CHILD, s.dtype, src=(s,), arg=(siblings.index(x), len(siblings))))
    else:
      wrapped.append(s)
  return x.replace(src=tuple(wrapped))
|
||
|
|
||
|
# rules that label the graph with CHILD nodes; ctx is a ChildrenContext
pm_children = PatternMatcher([
  # on a SINK: build ctx.children (a no-op once it has been built)
  (UPat(Ops.SINK, name="x"), extract_children),
  # on any op other than CHILD: wrap the sources recorded in ctx.children in CHILD nodes
  (UPat(GroupOp.All-{Ops.CHILD}, name="x"), mark_children),
])
|
||
|
|
||
|
# this is a generic pattern
|
||
|
def visit_child(ctx:ChildrenContext, x:UOp):
  """Record that CHILD node x was visited; raise RewriteNotReady until all of its siblings were visited too.

  x.arg is (index of this child, total child count). The index is added to the set kept in
  ctx.seen_children for x.src[0]; the function returns None only once that set's size equals the total.
  """
  seen = ctx.seen_children.setdefault(x.src[0], set())
  seen.add(x.arg[0])
  if len(seen) == x.arg[1]:
    print(f"visit CHILD {x.arg} bottom up -- READY {ctx.seen_children[x.src[0]]}")
    return
  # NOTE(review): presumably the rewriter revisits this node later when RewriteNotReady is raised -- confirm
  print(f"visit CHILD {x.arg} bottom up -- not ready {ctx.seen_children[x.src[0]]}")
  raise RewriteNotReady
|
||
|
|
||
|
# single rule: every CHILD node goes through visit_child, which may raise RewriteNotReady
pm_child_visitor = PatternMatcher([
  (UPat(Ops.CHILD, name="x"), visit_child),
])
|
||
|
|
||
|
# this is for the test
|
||
|
def see_const(ctx:ChildrenContext, c:UOp):
  """Add the arg of the matched const c to the running total ctx.seen_consts. Returns None."""
  ctx.seen_consts += c.arg
|
||
|
def save_seen_consts(ctx:ChildrenContext, x:UOp):
  """Copy the current ctx.seen_consts into ctx.saved_seen_consts; x is not used. Returns None."""
  ctx.saved_seen_consts = ctx.seen_consts
|
||
|
# test-only rules that record, in ctx, how much const had been summed when the DEFINE_VAR is reached
pm_consts = PatternMatcher([
  # on a DEFINE_VAR: copy the running total into ctx.saved_seen_consts
  (UPat(Ops.DEFINE_VAR, name="x"), save_seen_consts),
  # on <anything> + <const bound to "c">: add c.arg to ctx.seen_consts
  (UPat()+UPat.cvar("c"), see_const),
])
|
||
|
|
||
|
class TestChildrenRewrite(unittest.TestCase):
  """Checks how CHILD labelling plus RewriteNotReady changes what a bottom-up rewrite has seen at the DEFINE_VAR."""

  def test_not_ready(self):
    # one shared value (exp2 of a variable) feeding two adds, which are then added together
    shared = UOp.variable("a", 0, 10).exp2()
    plus_two = shared+2
    plus_three = shared+3
    total = plus_two+plus_three
    graph = total.sink()

    # without children and not ready, we don't see both adds before the DEFINE_VAR
    plain_ctx = ChildrenContext()
    graph = graph_rewrite(graph, pm_consts, ctx=plain_ctx, bottom_up=True)
    self.assertNotEqual(plain_ctx.seen_consts, plain_ctx.saved_seen_consts)

    # with children and not ready we do
    child_ctx = ChildrenContext()
    # first pass labels the graph with CHILD nodes, second pass visits them together with the const rules
    graph = graph_rewrite(graph, pm_children, ctx=child_ctx, bottom_up=True)
    graph = graph_rewrite(graph, pm_child_visitor+pm_consts, ctx=child_ctx, bottom_up=True)
    self.assertEqual(child_ctx.seen_consts, child_ctx.saved_seen_consts)
|
||
|
|
||
|
# run the test suite when this file is executed directly
if __name__ == "__main__":
  unittest.main()
|