You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
63 lines
2.5 KiB
63 lines
2.5 KiB
1 day ago
|
import unittest
|
||
|
from tinygrad import Tensor
|
||
|
from tinygrad.ops import PatternMatcher, Ops, UPat, graph_rewrite, RewriteContext, UOp, merge_views
|
||
|
from tinygrad.engine.schedule import sym
|
||
|
|
||
|
class TestRewriteTrackedChildren(unittest.TestCase):
  """Tests for the children (consumer) map of UOps across `graph_rewrite`.

  The first two tests build `Tensor(2) + Tensor(3)`, rewrite CONST 2 into CONST 4, and check
  which UOps are reported as children of the VIEW under the CONSTs before and after the rewrite.
  """

  def test_children_in_context(self):
    # Number of times the SINK callback ran. All checks in this test live inside the callback,
    # so without this counter the test would pass silently if the SINK pattern never matched.
    sink_visits = 0

    def print_children(ctx:RewriteContext, sink:UOp):
      # Pattern callback on the SINK: inspects the children map held by the RewriteContext.
      # It returns None, so it never replaces the SINK.
      nonlocal sink_visits
      sink_visits += 1
      # sink -> ADD -> CONST -> VIEW; the test expects this VIEW to be shared by both CONSTs
      view_w_child = sink.src[0].src[0].src[0]
      self.assertIs(view_w_child.op, Ops.VIEW)
      # before update_children() the map is expected to still list the original CONST 2
      self.assertSetEqual({x.arg for x in ctx.children[view_w_child]}, {2, 3})
      ctx.update_children()
      # after update_children() the rewritten CONST 4 is expected in place of CONST 2
      self.assertSetEqual({x.arg for x in ctx.children[view_w_child]}, {3, 4})
      # CONST 3 (second source of the ADD) has exactly one child: the ADD
      self.assertEqual(len(ctx.children[sink.src[0].src[1]]), 1)
      self.assertIs(next(iter(ctx.children[sink.src[0].src[1]])).op, Ops.ADD)
      # CONST 4 (first source of the ADD) has exactly one child: the ADD
      self.assertEqual(len(ctx.children[sink.src[0].src[0]]), 1)
      self.assertIs(next(iter(ctx.children[sink.src[0].src[0]])).op, Ops.ADD)

    rewrite = PatternMatcher([
      (UPat(Ops.CONST, arg=2, name="x"), lambda x: x.replace(arg=4)),
      (UPat(Ops.SINK, name="sink"), print_children),
    ])
    a = Tensor(2)
    b = Tensor(3)
    c = a + b
    root = c.lazydata.sink()
    graph_rewrite(root, rewrite, track_children=True)
    self.assertTrue(sink_visits, "SINK pattern never matched, children were not checked")

  def test_simple_child(self):
    rewrite = PatternMatcher([
      (UPat(Ops.CONST, arg=2, name="x"), lambda x: x.replace(arg=4)),
    ])
    a = Tensor(2)
    b = Tensor(3)
    c = a + b
    # root of the graph (the ADD), not a SINK
    root = c.lazydata
    # the VIEW under CONST 2; the test expects CONST 3 to share it
    view_w_child = a.lazydata.src[0]
    self.assertSetEqual({x.arg for x in root.get_children_map()[view_w_child]}, {2, 3})
    # graph_rewrite can both add children to and remove children from the map.
    # Additions are easy to detect: hook the UOp constructor.
    # When are children removed?
    # * when a rewrite rule returns a UOp, the matched node is removed from the graph
    root = graph_rewrite(root, rewrite)
    self.assertSetEqual({x.arg for x in root.get_children_map()[view_w_child]}, {3, 4})

  @unittest.expectedFailure
  def test_child_after_parent_update(self):
    # Marked expectedFailure: looking up a REDUCE_AXIS in ctx.children after update_children(),
    # while merge_views+sym also rewrite the graph, does not work yet.
    # NOTE(review): the exact failure mode is not visible here (likely a missing key) — confirm.
    def print_children(ctx, r):
      ctx.update_children()
      print(ctx.children[r])
    extra = PatternMatcher([(UPat(Ops.REDUCE_AXIS, name="r"), print_children)])
    a = Tensor.empty(3, 3)
    r = (a+0).sum()
    graph_rewrite(r.lazydata, merge_views+sym+extra, track_children=True)
|
||
|
|
||
|
if __name__ == "__main__":
  # run this module's test cases when the file is executed directly
  unittest.main()
|