openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

215 lines
9.8 KiB

from typing import Dict, List, Optional
import unittest, decimal, json
from tinygrad.dtype import dtypes
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic
from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys
from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json, to_perfetto
@track_rewrites(named=True)
def rewrite(sink:UOp, pm:PatternMatcher, **kwargs):
  """Apply `pm` to `sink` via graph_rewrite inside a named @track_rewrites call.

  Extra keyword arguments (e.g. ctx, bottom_up) are forwarded to graph_rewrite unchanged.
  Returns whatever graph_rewrite returns.
  """
  rewritten = graph_rewrite(sink, pm, **kwargs)
  return rewritten
def helper_test_viz(sink:UOp, pm:PatternMatcher, **kwargs) -> List[UOp]:
  """Rewrite `sink` with `pm` under tracking and return the graphs VIZ recorded for that rewrite.

  The tracked contexts/keys must be empty before calling this: it asserts that exactly one tracked call
  containing exactly one graph_rewrite was recorded. Keyword arguments are forwarded to graph_rewrite.
  """
  rewrite(sink, pm, **kwargs)
  # exactly one @track_rewrites call, holding exactly one graph_rewrite
  assert len(contexts) == 1
  assert len(contexts[0]) == 1
  first_key = get_metadata(keys, contexts)[0][0]
  details = get_details(*first_key)
  # graphs[0] is skipped (presumably the input sink — TODO confirm); callers treat the rest as one graph per rewrite step
  return details.graphs[1:]
class TestViz(unittest.TestCase):
  """Tests for VIZ rewrite tracking.

  Covers which graph_rewrite calls @track_rewrites records, and what get_metadata, get_details
  and uop_to_json report for them.
  """
  def setUp(self):
    # every test starts with empty tracking state; these tests rely on level 2 for @track_rewrites calls to be recorded
    contexts.clear()
    keys.clear()
    self.tms = TRACK_MATCH_STATS.value
    TRACK_MATCH_STATS.value = 2
  # restore whatever TRACK_MATCH_STATS level was active before the test
  def tearDown(self): TRACK_MATCH_STATS.value = self.tms

  def test_viz_simple(self):
    pm = PatternMatcher([
      (UPat.var("x")*1, lambda x:x),
    ])
    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
    # a*1 -> a is a single recorded rewrite step
    uops = helper_test_viz(a*1, pm)
    self.assertEqual(len(uops), 1)
    self.assertEqual(uops[0], a)

  def test_rewrite_twice(self):
    pm = PatternMatcher([
      (UPat.var("x")+UPat.var("x"), lambda x:x*2),
      (UPat.var("x", dtypes.int)*2, lambda x:x.alu(Ops.SHL, UOp.const(dtypes.int, 1))),
    ])
    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
    # a+a -> a*2 -> SHL(a, 1): the intermediate a*2 must be recorded as its own step
    uops = helper_test_viz(a+a, pm)
    self.assertEqual(len(uops), 2)
    self.assertEqual(uops[0], a*2)
    self.assertEqual(uops[1], graph_rewrite(a+a, pm))

  def test_rewrite_with_ctx(self):
    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
    b = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1), UOp.const(dtypes.int, 0)))
    # rewrite each LOAD into a STORE of itself exactly once; ctx remembers which loads were already handled
    def store_load(ctx:Dict[UOp, None], x:UOp) -> Optional[UOp]:
      if x in ctx: return None
      ctx[x] = None
      return UOp.store(*x.src, x)
    pm = PatternMatcher([
      (UPat(Ops.LOAD, name="x"), store_load),
    ])
    uops = helper_test_viz(a+b, pm, ctx={})
    self.assertEqual(len(uops), 2)
    self.assertEqual(uops[-1], graph_rewrite(a+b, pm, {}))

  def test_track_rewrites(self):
    simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)])
    @track_rewrites(named=True)
    def do_rewrite(x:UOp): return graph_rewrite(x, simple)
    ld = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
    do_rewrite(ld*1)
    do_rewrite(ld*2)
    ret = get_metadata(keys, contexts)
    self.assertEqual(len(ret), 2)
    # named tracking keys each call as <func name>_<call count>
    key, _, m = ret[0][0]
    self.assertEqual(key, "do_rewrite_1")
    self.assertEqual(len(m.upats), 1)
    # ld*2 matches no pattern, so no upats are recorded
    key, _, m = ret[1][0]
    self.assertEqual(key, "do_rewrite_2")
    self.assertEqual(len(m.upats), 0)

  def test_track_rewrites_with_exception(self):
    # a dedicated exception type, so an unrelated failure (UOp construction, graph_rewrite itself) can't make this pass
    class _RewriteError(Exception): pass
    simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)])
    @track_rewrites()
    def do_rewrite(x:UOp):
      graph_rewrite(x, simple) # NOTE: viz tracks this
      raise _RewriteError("test")
    ld = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
    with self.assertRaises(_RewriteError): do_rewrite(ld*1)
    # the rewrite performed before the raise is still recorded
    ret = get_metadata(keys, contexts)
    self.assertEqual(len(ret), 1)

  def test_fold_const(self):
    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
    graph = uop_to_json(a)
    labels = [v[0] for v in graph.values()]
    # no node is a standalone CONST, yet exactly one node's label mentions CONST (the const is folded into it).
    # self.assert* instead of bare assert so the checks still run under `python -O`
    self.assertFalse(any(label.startswith("CONST") for label in labels))
    self.assertEqual(len([label for label in labels if "CONST" in label]), 1)

  @unittest.skip("TODO: bring this back with better testing")
  def test_bottom_up_rewrite(self):
    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
    n1 = a.sin()
    uop = n1.sin()
    pm = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
    ret = helper_test_viz(uop, pm, ctx={a.sin():a.sqrt(), n1.sin():n1.sqrt()}, bottom_up=True)
    self.assertEqual(len(ret), 2)
    self.assertIs(ret[0], a.sin().sqrt()) # first rewrite
    self.assertIs(ret[1], a.sqrt().sqrt()) # second one

  def test_top_down_rewrite(self):
    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
    n1 = a.sin()
    uop = n1.sin()
    pm = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
    # if it wasn't bottom_up, it's rewritten once
    ret = helper_test_viz(uop, pm, ctx={a.sin():a.sqrt(), n1.sin():n1.sqrt()}, bottom_up=False)
    self.assertEqual(len(ret), 1)
    self.assertIs(ret[0], a.sqrt().sin()) # only rewrite

  # NOTE: calling graph_rewrite when the function isn't decorated with track_rewrites should not VIZ
  def test_rewrite_without_context(self):
    def untracked_graph_rewrite(sink): return graph_rewrite(sink, symbolic)
    @track_rewrites(named=True)
    def tracked_graph_rewrite(sink): return graph_rewrite(sink, symbolic)
    # test
    add = UOp.const(dtypes.int, 2) + UOp.const(dtypes.int, 1)
    untracked_graph_rewrite(add)
    self.assertEqual(len(contexts), 0)
    tracked_graph_rewrite(add)
    self.assertEqual(len(contexts), 1)

  def test_inner_rewrite_location(self):
    # inner rewrite gets tracked in another context
    def inner_rewrite(sink): return graph_rewrite(sink, symbolic)
    @track_rewrites(named=True)
    def tracked_graph_rewrite(sink): return inner_rewrite(sink)
    # test
    add = UOp.const(dtypes.int, 2) + UOp.const(dtypes.int, 1)
    tracked_graph_rewrite(add)
    self.assertEqual(len(contexts), 1)
    # location of context is inner_rewrite
    fp, lineno = contexts[0][0].loc
    self.assertEqual(lineno, inner_rewrite.__code__.co_firstlineno)
    self.assertEqual(fp, inner_rewrite.__code__.co_filename)
class TextVizProfiler(unittest.TestCase):
  """Checks the trace JSON that to_perfetto builds from profiler events."""
  # NOTE(review): "Text" looks like a typo for "Test"; the class name is kept unchanged here

  def _assert_event(self, ev, args_name=None, **fields):
    # compare each expected top-level field of one trace event in order, then (optionally) its args['name']
    for field, expected in fields.items():
      self.assertEqual(ev[field], expected)
    if args_name is not None:
      self.assertEqual(ev['args']['name'], args_name)

  def test_perfetto_node(self):
    prof = [ProfileRangeEvent(device='NV', name='E_2', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=False),
            ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))]
    events = json.loads(to_perfetto(prof))['traceEvents']
    # Device regs always first: the process, then its COMPUTE (tid 0) and COPY (tid 1) threads
    self._assert_event(events[0], 'NV', name='process_name', ph='M')
    nv_pid = events[0]['pid']
    self._assert_event(events[1], 'COMPUTE', name='thread_name', ph='M', pid=nv_pid, tid=0)
    self._assert_event(events[2], 'COPY', name='thread_name', ph='M', pid=nv_pid, tid=1)
    # expected ts = st + comp_tdiff = 1000 - 1000; dur = en - st
    self._assert_event(events[3], name='E_2', ts=0, dur=10, ph='X', pid=nv_pid, tid=0)

  def test_perfetto_copy_node(self):
    prof = [ProfileRangeEvent(device='NV', name='COPYxx', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=True),
            ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))]
    events = json.loads(to_perfetto(prof))['traceEvents']
    # copies go on the COPY thread and use the copy clock: expected ts = 1000 + copy_tdiff(-100)
    self._assert_event(events[3], name='COPYxx', ts=900, dur=10, ph='X', tid=1)

  def test_perfetto_graph(self):
    prof = [ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100)),
            ProfileDeviceEvent(device='NV:1', comp_tdiff=decimal.Decimal(-500), copy_tdiff=decimal.Decimal(-50)),
            ProfileGraphEvent(ents=[ProfileGraphEntry(device='NV', name='E_25_4n2', st_id=0, en_id=1, is_copy=False),
                                    ProfileGraphEntry(device='NV:1', name='NV -> NV:1', st_id=2, en_id=3, is_copy=True)],
                              deps=[[], [0]],
                              sigs=[decimal.Decimal(1000), decimal.Decimal(1002), decimal.Decimal(1004), decimal.Decimal(1008)])]
    events = json.loads(to_perfetto(prof))['traceEvents']
    # Device regs always first: NV's process + threads, then NV:1's
    for idx, reg_name in enumerate(['NV', 'COMPUTE', 'COPY', 'NV:1', 'COMPUTE', 'COPY']):
      self.assertEqual(events[idx]['args']['name'], reg_name)
    # expected values: ts = sigs[st_id] + device tdiff, dur = sigs[en_id] - sigs[st_id]
    self._assert_event(events[6], name='E_25_4n2', ts=0, dur=2, pid=events[0]['pid'])
    self._assert_event(events[7], name='NV -> NV:1', ts=954, dur=4, pid=events[3]['pid'])
# run every TestCase in this module when executed directly as a script
if __name__ == "__main__": unittest.main()