import os, atexit, functools
import subprocess
try:
    import networkx as nx  # type: ignore
except ImportError:
    nx = None  # graph won't work
from collections import defaultdict
from typing import Dict, List
from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, OpType, LazyOp
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters, getenv, dedup
from tinygrad.codegen.linearizer import UOps

# **** debugging and graphing ****

# module-wide schedule graph; None when networkx could not be imported
G = nx.DiGraph() if nx is not None else None
# number of logged schedule items, keyed by their dominant op type
cnts: Dict[OpType, int] = defaultdict(int)


def _render_svg(path_stem, *dot_args):
    """Run Graphviz ``dot`` to render ``{path_stem}.dot`` into ``{path_stem}.svg``.

    The command is passed as an argument list (no shell), so a path containing
    spaces or shell metacharacters is handed to ``dot`` unchanged. Rendering is
    best-effort: a failure to launch ``dot`` is reported and otherwise ignored,
    and a non-zero exit status of ``dot`` is not treated as an error.

    Args:
        path_stem: output path without extension.
        *dot_args: extra options placed before ``-Tsvg`` on the command line.
    """
    cmd = ["dot", *dot_args, "-Tsvg", f"{path_stem}.dot", "-o", f"{path_stem}.svg"]
    try:
        subprocess.run(cmd, check=False)
    except OSError as e:
        print(f"could not run {cmd[0]!r}: {e}")


if DEBUG >= 2:
    def print_globalcounters():
        """Print average throughput and totals from GlobalCounters at exit.

        Prints nothing when no time has been accumulated, which also avoids
        dividing by zero.
        """
        if GlobalCounters.time_sum_s == 0: return
        print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s",
              f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms")
    atexit.register(print_globalcounters)

if GRAPH:
    def save_graph_exit():
        """Print the per-op-type counts, then write the graph to GRAPHPATH as .dot and .svg."""
        for k, v in cnts.items(): print(k, v)
        print("saving", G)
        nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot')
        # -Gnslimit=100 can make it finish, but you won't like results
        _render_svg(GRAPHPATH)
    atexit.register(save_graph_exit)

# next unused graph node id; shared by nm() and add_st_node()
node_count = 0


def nm(x):
    """Return the graph node id of ``x``, assigning the next free id on first use.

    The id is stored on the object itself as the ``node_id`` attribute.
    """
    global node_count
    if not hasattr(x, 'node_id'):
        setattr(x, 'node_id', node_count)
        node_count += 1
    return x.node_id


def get_sop(op: List[Op]):
    """Build a short edge label from a list of ops.

    BufferOps entries are dropped first. With at most two ops left, the label
    is their names (the part of ``str(op)`` after the first dot) joined by
    ``.`` in reverse order; with at most six, the names are cut to three
    characters; otherwise the label is just the number of ops.
    """
    op = [x for x in op if x not in BufferOps]
    if len(op) <= 2: return '.'.join([str(y).split(".")[1] for y in op][::-1])
    if len(op) <= 6: return '.'.join([str(y).split(".")[1][0:3] for y in op][::-1])
    return str(len(op))


def str_dtype(dtyp):
    """Return a newline-prefixed dtype suffix for a node label, or "" for float.

    The first seven characters of ``str(dtyp)`` are dropped
    (presumably a ``dtypes.`` prefix -- TODO confirm against the dtype repr).
    """
    ret = str(dtyp)[7:]
    return "" if ret == 'float' else f"\n{ret}"


@functools.lru_cache(None)
def add_st_node(nmx, nmo, label, st):
    """Insert an intermediate node describing ``st`` between nodes ``nmx`` and ``nmo``.

    The intermediate node is labelled with the shape and real strides of
    ``st``, plus its real offset when that is non-zero. ``label`` is put on the
    edge that reaches ``nmo``. The lru_cache means a repeated call with equal
    arguments does not add a second node.
    """
    global node_count
    inter_node = node_count
    node_count += 1
    offset = st.real_offset()
    node_label = f"{st.shape}\n{st.real_strides()}" + (f"\n{offset}" if offset != 0 else "")
    G.add_node(inter_node, style='filled', fillcolor="#80ff8080", color="black", label=node_label)
    G.add_edge(nmx, inter_node, color='#00000060')
    G.add_edge(inter_node, nmo, label=label, color='#00000060')


# append-mode log of scheduled ASTs, enabled by the LOGOPS env var (holds the file path)
logops = open(getenv("LOGOPS", ""),"a") if getenv("LOGOPS", "") else None


def log_schedule_item(si: ScheduleItem):
    """Record one schedule item in the op log, the op counters and the graph.

    - When LOGOPS is set and the AST root is not a LoadOp, the AST is appended
      to the log file.
    - Nothing else happens unless DEBUG or GRAPH is set.
    - CONTIGUOUS makes the output share the node id of its first input's base;
      CONST and CONTIGUOUS items are neither counted nor drawn.
    - Otherwise the item is counted under its dominant op type and, when GRAPH
      is set, added to the module-level graph ``G``.
    """
    if logops and si.ast.op not in LoadOps: logops.write(str(si.ast)+"\n")
    show_graph = bool(GRAPH)
    if not DEBUG and not show_graph: return
    if si.ast.op == LoadOps.CONTIGUOUS: setattr(si.out, 'node_id', nm(si.inputs[0].base))
    if si.ast.op in {LoadOps.CONST, LoadOps.CONTIGUOUS}: return

    op: List[Op] = [x.op for x in si.ast.get_lazyops()]
    # the dominant op type is the one that appears earliest in this priority list
    oporder = [LoadOps, TernaryOps, ReduceOps, BinaryOps, UnaryOps, MovementOps, BufferOps]
    optype = type(min(op, key=lambda x: oporder.index(type(x))))
    cnts[optype] += 1
    if not show_graph: return

    assert si.out.base == si.out, "all outputs based"
    top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0",
                  MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#FF8080'}

    # get inputs for shapetrackers
    input_to_st = defaultdict(list)
    for lo in si.ast.get_lazyops():
        if lo.op != BufferOps.MEM: continue
        input_to_st[si.inputs[lo.arg.idx-1]].append(lo.arg.st)

    # add them to the graph, potentially with a movement op separating them
    for x, sts in input_to_st.items():
        for st in dedup(sts):
            # nm(x) is evaluated before nm(si.out) so node ids are handed out in the same order as before
            if st.contiguous:
                G.add_edge(nm(x), nm(si.out), label=get_sop(op), color='#00000060')
            else:
                add_st_node(nm(x), nm(si.out), get_sop(op), st)
        if 'label' not in G.nodes[nm(x)]:
            # NOTE(review): the input node is labelled with the *output* dtype; x.dtype may be intended -- confirm
            G.nodes[nm(x)]['label'] = str(x.shape)+str_dtype(si.out.dtype)

    out_id = nm(si.out)
    if out_id not in G.nodes: G.add_node(out_id)

    if optype == ReduceOps:
        shape_label = str(set(x.shape for x in si.inputs))+"\n"+str(si.out.shape)
    else:
        shape_label = str(si.out.shape)
    out_attrs = G.nodes[out_id]
    out_attrs['label'] = shape_label+str_dtype(si.out.dtype)+(f"\n{si.ast.op}" if si.ast.op in LoadOps else "")
    out_attrs['fillcolor'] = top_colors[optype]
    out_attrs['color'] = 'black'
    out_attrs['style'] = 'filled'


def _tree(lazydata, prefix=""):
    """Return the lines of a box-drawing tree for ``lazydata``.

    An object whose type is named ``LazyBuffer`` renders as a single
    "realized" leaf when its ``realized`` attribute is truthy; otherwise its
    ``op`` is rendered with the prefix ``"LB "``. Any other node renders its
    op name and (truthy) arg, followed by the subtrees of ``src``.
    """
    if type(lazydata).__name__ == "LazyBuffer":
        if lazydata.realized: return [f"━━ realized {lazydata.dtype.name} {lazydata.shape}"]
        return _tree(lazydata.op, "LB ")
    arg = lazydata.arg if lazydata.arg else ''
    if len(lazydata.src) == 0: return [f"━━ {prefix}{lazydata.op.name} {arg}"]
    lines = [f"━┳ {prefix}{lazydata.op.name} {arg}"]
    childs = [_tree(c) for c in lazydata.src]
    for c in childs[:-1]: lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]]
    # continuation lines under the last child get two spaces so they line up with
    # the two-column " ┣" / " ┃" / " ┗" prefixes used everywhere else
    return lines + [" ┗"+childs[-1][0]] + ["  "+l for l in childs[-1][1:]]


def print_tree(lazydata:LazyOp):
    """Print the tree produced by _tree, each line prefixed with its right-aligned index."""
    print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(lazydata))]))


def graph_uops(uops):
    """Write a left-to-right graph of ``uops`` to /tmp/uops.dot and /tmp/uops.svg.

    Each uop becomes a node keyed by ``u.num``, coloured by its ``uop`` kind
    (white when the kind has no entry in the colour table), with one edge from
    every element of ``u.vin``.
    """
    colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
              UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0",
              UOps.LOOP: "#c8a0e0", UOps.PHI: "#e0ffc0"}
    # a fresh graph, separate from the module-level schedule graph G
    uop_graph = nx.DiGraph()
    for u in uops:
        uop_graph.add_node(u.num, label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}",
                           style="filled", fillcolor=colors.get(u.uop, "#ffffff"))
        for v in u.vin: uop_graph.add_edge(v.num, u.num)
    out_stem = "/tmp/uops"
    nx.drawing.nx_pydot.write_dot(uop_graph, f'{out_stem}.dot')
    _render_svg(out_stem, "-Grankdir=LR")