You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
215 lines
9.8 KiB
215 lines
9.8 KiB
from typing import Dict, List, Optional
|
|
import unittest, decimal, json
|
|
from tinygrad.dtype import dtypes
|
|
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic
|
|
from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys
|
|
from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
|
|
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json, to_perfetto
|
|
|
|
@track_rewrites(named=True)
def rewrite(sink:UOp, pm:PatternMatcher, **kwargs):
  # Tracked entry point used by the tests: forwards everything to graph_rewrite
  # from inside a function decorated with track_rewrites(named=True).
  return graph_rewrite(sink, pm, **kwargs)
|
|
|
|
def helper_test_viz(sink:UOp, pm:PatternMatcher, **kwargs) -> List[UOp]:
  # Run a single tracked rewrite of `sink` with `pm` (extra kwargs go to graph_rewrite)
  # and return the recorded graphs with the first entry dropped.
  # NOTE(review): presumably graphs[0] is the input graph, so the result is one UOp per rewrite step -- confirm.
  rewrite(sink, pm, **kwargs)
  # exactly one tracked call must have been recorded...
  assert len(contexts) == 1
  # ...and it must hold exactly one graph_rewrite entry
  assert len(contexts[0]) == 1
  # the first metadata entry of the first tracked call is unpacked straight into get_details
  return get_details(*get_metadata(keys, contexts)[0][0]).graphs[1:]
|
|
|
|
class TestViz(unittest.TestCase):
  # Tests for the rewrite tracking consumed by the viz server helpers (get_metadata / get_details / uop_to_json).

  def setUp(self):
    # every test starts from empty tracking state
    contexts.clear()
    keys.clear()
    # remember the current TRACK_MATCH_STATS value so tearDown can restore it, then force it to 2 for the test
    self.tms = TRACK_MATCH_STATS.value
    TRACK_MATCH_STATS.value = 2

  def tearDown(self): TRACK_MATCH_STATS.value = self.tms

  def test_viz_simple(self):
    # x*1 -> x: one rewrite step is recorded and its result is the bare load
    pm = PatternMatcher([
      (UPat.var("x")*1, lambda x:x),
    ])
    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
    uops = helper_test_viz(a*1, pm)
    self.assertEqual(len(uops), 1)
    self.assertEqual(uops[0], a)

  def test_rewrite_twice(self):
    # a+a -> a*2 -> a<<1: two steps are recorded, the last one equals the fully rewritten graph
    pm = PatternMatcher([
      (UPat.var("x")+UPat.var("x"), lambda x:x*2),
      (UPat.var("x", dtypes.int)*2, lambda x:x.alu(Ops.SHL, UOp.const(dtypes.int, 1))),
    ])
    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
    uops = helper_test_viz(a+a, pm)
    self.assertEqual(len(uops), 2)
    self.assertEqual(uops[0], a*2)
    self.assertEqual(uops[1], graph_rewrite(a+a, pm))

  def test_rewrite_with_ctx(self):
    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
    b = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1), UOp.const(dtypes.int, 0)))
    def store_load(ctx:Dict[UOp, None], x:UOp) -> Optional[UOp]:
      # ctx records the loads already handled so each one is rewritten only once
      if x in ctx: return None
      ctx[x] = None
      return UOp.store(*x.src, x)
    pm = PatternMatcher([
      (UPat(Ops.LOAD, name="x"), store_load),
    ])
    uops = helper_test_viz(a+b, pm, ctx={})
    # one recorded step per load; the last step equals an untracked-helper rewrite run with a fresh ctx
    self.assertEqual(len(uops), 2)
    self.assertEqual(uops[-1], graph_rewrite(a+b, pm, {}))

  def test_track_rewrites(self):
    simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)])
    # NOTE: the function name is asserted on below through the "do_rewrite_N" keys
    @track_rewrites(named=True)
    def do_rewrite(x:UOp): return graph_rewrite(x, simple)
    ld = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
    do_rewrite(ld*1)
    do_rewrite(ld*2)
    ret = get_metadata(keys, contexts)
    self.assertEqual(len(ret), 2)
    # first call: x*1 matches, so one upat is reported
    key, _, m = ret[0][0]
    self.assertEqual(key, "do_rewrite_1")
    self.assertEqual(len(m.upats), 1)
    # second call: x*2 does not match the only pattern, so no upats are reported
    key, _, m = ret[1][0]
    self.assertEqual(key, "do_rewrite_2")
    self.assertEqual(len(m.upats), 0)

  def test_track_rewrites_with_exception(self):
    simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)])
    @track_rewrites()
    def do_rewrite(x:UOp):
      graph_rewrite(x, simple) # NOTE: viz tracks this
      raise Exception("test")
    ld = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
    with self.assertRaises(Exception): do_rewrite(ld*1)
    # the tracked call is still reported even though the function raised
    ret = get_metadata(keys, contexts)
    self.assertEqual(len(ret), 1)

  def test_fold_const(self):
    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
    graph = uop_to_json(a)
    # the first field of every entry is its label string
    labels = [v[0] for v in graph.values()]
    # no entry is labelled as a CONST on its own...
    self.assertFalse(any(label.startswith("CONST") for label in labels))
    # ...but exactly one label mentions CONST (presumably the const is folded into its user's label)
    self.assertEqual(len([label for label in labels if "CONST" in label]), 1)

  @unittest.skip("TODO: bring this back with better testing")
  def test_bottom_up_rewrite(self):
    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
    n1 = a.sin()
    uop = n1.sin()
    # the pattern matches every op and looks the replacement up in ctx (None means no rewrite)
    pm = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
    ret = helper_test_viz(uop, pm, ctx={a.sin():a.sqrt(), n1.sin():n1.sqrt()}, bottom_up=True)
    self.assertEqual(len(ret), 2)
    self.assertIs(ret[0], a.sin().sqrt()) # first rewrite
    self.assertIs(ret[1], a.sqrt().sqrt()) # second one

  def test_top_down_rewrite(self):
    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
    n1 = a.sin()
    uop = n1.sin()
    # the pattern matches every op and looks the replacement up in ctx (None means no rewrite)
    pm = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
    # if it wasn't bottom_up, it's rewritten once
    ret = helper_test_viz(uop, pm, ctx={a.sin():a.sqrt(), n1.sin():n1.sqrt()}, bottom_up=False)
    self.assertEqual(len(ret), 1)
    self.assertIs(ret[0], a.sqrt().sin()) # only rewrite

  # NOTE: calling graph_rewrite when the function isn't decorated with track_rewrites should not VIZ
  def test_rewrite_without_context(self):
    def untracked_graph_rewrite(sink): return graph_rewrite(sink, symbolic)
    @track_rewrites(named=True)
    def tracked_graph_rewrite(sink): return graph_rewrite(sink, symbolic)
    # test
    add = UOp.const(dtypes.int, 2) + UOp.const(dtypes.int, 1)
    untracked_graph_rewrite(add)
    self.assertEqual(len(contexts), 0)
    tracked_graph_rewrite(add)
    self.assertEqual(len(contexts), 1)

  def test_inner_rewrite_location(self):
    # inner rewrite gets tracked in another context
    # NOTE: keep inner_rewrite on a single line, the assertion below compares the recorded line
    # with co_firstlineno (presumably the recorded line is that of the graph_rewrite call)
    def inner_rewrite(sink): return graph_rewrite(sink, symbolic)
    @track_rewrites(named=True)
    def tracked_graph_rewrite(sink): return inner_rewrite(sink)
    # test
    add = UOp.const(dtypes.int, 2) + UOp.const(dtypes.int, 1)
    tracked_graph_rewrite(add)
    self.assertEqual(len(contexts), 1)
    # location of context is inner_rewrite
    fp, lineno = contexts[0][0].loc
    self.assertEqual(lineno, inner_rewrite.__code__.co_firstlineno)
    self.assertEqual(fp, inner_rewrite.__code__.co_filename)
|
|
|
|
class TextVizProfiler(unittest.TestCase):
  # Checks the JSON emitted by to_perfetto for range, copy and graph profile events.
  # NOTE(review): the class name reads like a typo for "TestVizProfiler"; kept as is so selecting the tests by name still works.

  def test_perfetto_node(self):
    events = [ProfileRangeEvent(device='NV', name='E_2', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=False),
              ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))]

    trace = json.loads(to_perfetto(events))
    ev = trace['traceEvents']

    # Device regs always first
    self.assertEqual(ev[0]['name'], 'process_name')
    self.assertEqual(ev[0]['ph'], 'M')
    self.assertEqual(ev[0]['args']['name'], 'NV')

    # thread 0 of the device process is COMPUTE
    self.assertEqual(ev[1]['name'], 'thread_name')
    self.assertEqual(ev[1]['ph'], 'M')
    self.assertEqual(ev[1]['pid'], ev[0]['pid'])
    self.assertEqual(ev[1]['tid'], 0)
    self.assertEqual(ev[1]['args']['name'], 'COMPUTE')

    # thread 1 of the device process is COPY
    self.assertEqual(ev[2]['name'], 'thread_name')
    self.assertEqual(ev[2]['ph'], 'M')
    self.assertEqual(ev[2]['pid'], ev[0]['pid'])
    self.assertEqual(ev[2]['tid'], 1)
    self.assertEqual(ev[2]['args']['name'], 'COPY')

    # the kernel itself: st=1000 with comp_tdiff=-1000 comes out as ts=0, on the COMPUTE thread
    self.assertEqual(ev[3]['name'], 'E_2')
    self.assertEqual(ev[3]['ts'], 0)
    self.assertEqual(ev[3]['dur'], 10)
    self.assertEqual(ev[3]['ph'], 'X')
    self.assertEqual(ev[3]['pid'], ev[0]['pid'])
    self.assertEqual(ev[3]['tid'], 0)

  def test_perfetto_copy_node(self):
    events = [ProfileRangeEvent(device='NV', name='COPYxx', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=True),
              ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))]

    trace = json.loads(to_perfetto(events))
    copy_ev = trace['traceEvents'][3]

    self.assertEqual(copy_ev['name'], 'COPYxx')
    self.assertEqual(copy_ev['ts'], 900) # diff clock
    self.assertEqual(copy_ev['dur'], 10)
    self.assertEqual(copy_ev['ph'], 'X')
    self.assertEqual(copy_ev['tid'], 1)

  def test_perfetto_graph(self):
    graph_entries = [ProfileGraphEntry(device='NV', name='E_25_4n2', st_id=0, en_id=1, is_copy=False),
                     ProfileGraphEntry(device='NV:1', name='NV -> NV:1', st_id=2, en_id=3, is_copy=True)]
    events = [ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100)),
              ProfileDeviceEvent(device='NV:1', comp_tdiff=decimal.Decimal(-500), copy_tdiff=decimal.Decimal(-50)),
              ProfileGraphEvent(ents=graph_entries, deps=[[], [0]],
                                sigs=[decimal.Decimal(1000), decimal.Decimal(1002), decimal.Decimal(1004), decimal.Decimal(1008)])]

    trace = json.loads(to_perfetto(events))
    ev = trace['traceEvents']

    # Device regs always first
    expected_names = ('NV', 'COMPUTE', 'COPY', 'NV:1', 'COMPUTE', 'COPY')
    for idx, expected in enumerate(expected_names):
      self.assertEqual(ev[idx]['args']['name'], expected)

    # compute entry on NV: sigs[0]..sigs[1]
    self.assertEqual(ev[6]['name'], 'E_25_4n2')
    self.assertEqual(ev[6]['ts'], 0)
    self.assertEqual(ev[6]['dur'], 2)
    self.assertEqual(ev[6]['pid'], ev[0]['pid'])

    # copy entry on NV:1: sigs[2]..sigs[3]
    self.assertEqual(ev[7]['name'], 'NV -> NV:1')
    self.assertEqual(ev[7]['ts'], 954)
    self.assertEqual(ev[7]['dur'], 4)
    self.assertEqual(ev[7]['pid'], ev[3]['pid'])
|
|
|
|
|
|
# Run the tests in this module when it is executed directly as a script.
if __name__ == "__main__": unittest.main()
|
|
|