You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
117 lines
5.6 KiB
117 lines
5.6 KiB
import os, atexit, functools
|
|
try:
|
|
import networkx as nx # type: ignore
|
|
except ImportError:
|
|
nx = None # graph won't work
|
|
from collections import defaultdict
|
|
from typing import Dict, List
|
|
from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, OpType, LazyOp
|
|
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters, getenv, dedup
|
|
from tinygrad.codegen.linearizer import UOps
|
|
|
|
# **** debugging and graphing ****
|
|
|
|
# Module-wide graph that schedule items are drawn into; stays None when networkx could not be imported.
G = None if nx is None else nx.DiGraph()
# Running tally of logged schedule items, keyed by the op type each one was classified as.
cnts: Dict[OpType, int] = defaultdict(int)
|
|
if DEBUG >= 2:
  def print_globalcounters():
    """Print average throughput and run totals taken from GlobalCounters.

    Registered with atexit below, so it runs when the interpreter exits.
    """
    elapsed = GlobalCounters.time_sum_s
    # nothing was timed: skip the report (this also avoids dividing by zero)
    if elapsed == 0: return
    gops = GlobalCounters.global_ops*1e-9
    gbytes = GlobalCounters.global_mem*1e-9
    avg_part = f"avg: {gops/elapsed:8.2f} GFLOPS {gbytes/elapsed:8.2f} GB/s"
    total_part = f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {gops:8.2f} GOPS {gbytes:8.2f} GB {elapsed*1e3:8.2f} ms"
    print(avg_part, total_part)
  atexit.register(print_globalcounters)
|
|
if GRAPH:
  def save_graph_exit():
    """Print the per-op-type counters, write the graph to GRAPHPATH.dot and render GRAPHPATH.svg.

    Registered with atexit below, so it runs when the interpreter exits.
    Rendering is best-effort: if the `dot` executable cannot be started, a
    message is printed and the .dot file is left in place.
    """
    import subprocess  # local import: only needed when graphing is enabled
    for k,v in cnts.items(): print(k, v)
    print("saving", G)
    dot_path, svg_path = f'{GRAPHPATH}.dot', f'{GRAPHPATH}.svg'
    nx.drawing.nx_pydot.write_dot(G, dot_path)
    # -Gnslimit=100 can make it finish, but you won't like results
    # An argument list (no shell) keeps a GRAPHPATH containing spaces or shell metacharacters intact.
    try:
      subprocess.run(['dot', '-Tsvg', dot_path, '-o', svg_path], check=False)
    except OSError as e:
      # e.g. graphviz is not installed; do not raise while the interpreter is shutting down
      print(f"could not run dot to render {svg_path}: {e}")
  atexit.register(save_graph_exit)
|
|
|
|
# Next unused graph node id; advanced by every function that hands out an id.
node_count = 0
def nm(x):
  """Return the graph node id of x, assigning one on first use.

  The id is stored on the object itself as `x.node_id`, so later calls with
  the same object return the same number.
  """
  global node_count
  if hasattr(x, 'node_id'): return x.node_id
  x.node_id = node_count
  node_count += 1
  return x.node_id
|
|
|
|
def get_sop(op: List[Op]):
  """Build a short edge label summarising a list of ops.

  Members of BufferOps are ignored. The remaining ops are listed in reverse
  order, using the part of `str(op)` after the first '.':
    - up to 2 ops: full names joined with '.'
    - up to 6 ops: names cut to their first three characters, joined with '.'
    - more than 6: only the number of ops, as a string
  """
  kept = [o for o in op if o not in BufferOps]
  count = len(kept)
  if count > 6: return str(count)
  names = (str(o).split(".")[1] for o in reversed(kept))
  if count <= 2: return '.'.join(names)
  return '.'.join(n[0:3] for n in names)
|
|
|
|
def str_dtype(dtyp):
  """Return a label suffix for a dtype: empty for 'float', otherwise a newline plus the dtype name.

  The name is `str(dtyp)` with its first seven characters removed
  (presumably a "dtypes." prefix).
  """
  name = str(dtyp)[7:]
  if name == 'float': return ""
  return f"\n{name}"
|
|
|
|
# Cached so that a repeated (nmx, nmo, label, st) combination adds its intermediate node only once.
@functools.lru_cache(maxsize=None)
def add_st_node(nmx, nmo, label, st):
  """Insert a node describing `st` between graph nodes `nmx` and `nmo`.

  The new node shows the shape and real strides of `st` (plus the real offset
  when it is non-zero). It is linked as nmx -> new node -> nmo, and the
  second edge carries `label`.
  """
  global node_count
  st_node = node_count
  node_count += 1
  st_text = f"{st.shape}\n{st.real_strides()}"
  offset = st.real_offset()
  if offset != 0: st_text += f"\n{offset}"
  G.add_node(st_node, style='filled', fillcolor="#80ff8080", color="black", label=st_text)
  G.add_edge(nmx, st_node, color='#00000060')
  G.add_edge(st_node, nmo, label=label, color='#00000060')
|
|
|
|
# Optional op log: when getenv("LOGOPS", "") is non-empty it is used as a file path opened in append mode.
# NOTE(review): this handle is never explicitly closed anywhere in this block.
logops = open(getenv("LOGOPS", ""),"a") if getenv("LOGOPS", "") else None
def log_schedule_item(si: ScheduleItem) -> None:
  """Record one schedule item in the op log, the op-type counters and (if GRAPH is set) the graph G.

  Steps, in order:
    1. If the op log is open and the AST's op is not a LoadOps member, append str(si.ast) as one line.
    2. Return early when neither DEBUG nor GRAPH is set.
    3. CONTIGUOUS items make the output share the node id of their first input's base;
       CONST and CONTIGUOUS items are then skipped.
    4. Classify the item by the highest-priority op family found in its AST and bump cnts.
    5. With GRAPH set: add edges from every input buffer to the output node (through an
       intermediate shapetracker node when the view is not contiguous) and style the output node.
  """
  if logops and si.ast.op not in LoadOps: logops.write(str(si.ast)+"\n")
  show_graph = bool(GRAPH)
  if not DEBUG and not show_graph: return
  # a CONTIGUOUS output reuses the node id of its first input's base, so both map to the same graph node
  if si.ast.op == LoadOps.CONTIGUOUS: setattr(si.out, 'node_id', nm(si.inputs[0].base))
  if si.ast.op in {LoadOps.CONST, LoadOps.CONTIGUOUS}: return

  op: List[Op] = [x.op for x in si.ast.get_lazyops()]
  # priority list: the item is classified as the earliest-listed op family that appears in its AST
  oporder = [LoadOps, TernaryOps, ReduceOps, BinaryOps, UnaryOps, MovementOps, BufferOps]
  optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0])
  cnts[optype] += 1
  if show_graph:
    assert si.out.base == si.out, "all outputs based"
    # fill colour of the output node, chosen by the classified op family
    top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#FF8080'}

    # get inputs for shapetrackers
    input_to_st = defaultdict(list)
    for lo in si.ast.get_lazyops():
      if lo.op != BufferOps.MEM: continue
      # NOTE(review): idx-1 suggests buffer index 0 is reserved (presumably for the output) — confirm
      input_to_st[si.inputs[lo.arg.idx-1]].append(lo.arg.st)

    # add them to the graph, potentially with a movement op separating them
    # (nm() hands out ids on first use, so the order of the nm() calls below determines node numbering)
    for x in input_to_st:
      for st in dedup(input_to_st[x]):
        if st.contiguous:
          G.add_edge(nm(x), nm(si.out), label=get_sop(op), color='#00000060')
        else:
          add_st_node(nm(x), nm(si.out), get_sop(op), st)
      # inputs that were never labelled as an output get a default label here
      # NOTE(review): the label uses the output's dtype (si.out.dtype), not the input's — confirm this is intended
      if 'label' not in G.nodes[nm(x)]:
        G.nodes[nm(x)]['label'] = str(x.shape)+str_dtype(si.out.dtype)

    # an item without MEM inputs has added no edge, so the output node may not exist yet
    if nm(si.out) not in G.nodes: G.add_node(nm(si.out))

    # reduce items show the set of input shapes above the output shape; LoadOps items also show the op
    G.nodes[nm(si.out)]['label'] = (str(set(x.shape for x in si.inputs))+"\n"+str(si.out.shape) if optype == ReduceOps else str(si.out.shape))+str_dtype(si.out.dtype)+(f"\n{si.ast.op}" if si.ast.op in LoadOps else "")
    G.nodes[nm(si.out)]['fillcolor'] = top_colors[optype]
    G.nodes[nm(si.out)]['color'] = 'black'
    G.nodes[nm(si.out)]['style'] = 'filled'
|
|
|
|
def _tree(lazydata, prefix=""):
|
|
if type(lazydata).__name__ == "LazyBuffer": return [f"━━ realized {lazydata.dtype.name} {lazydata.shape}"] if (lazydata.realized) else _tree(lazydata.op, "LB ")
|
|
if len(lazydata.src) == 0: return [f"━━ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"]
|
|
lines = [f"━┳ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"]
|
|
childs = [_tree(c) for c in lazydata.src[:]]
|
|
for c in childs[:-1]: lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]]
|
|
return lines + [" ┗"+childs[-1][0]] + [" "+l for l in childs[-1][1:]]
|
|
|
|
def print_tree(lazydata:LazyOp):
  """Print the tree produced by _tree(lazydata), each line preceded by its index right-aligned to width 3."""
  numbered = (f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(lazydata)))
  print("\n".join(numbered))
|
|
|
|
def graph_uops(uops):
  """Draw a list of uops as a directed graph: writes /tmp/uops.dot and renders /tmp/uops.svg with `dot`.

  Every uop becomes a node keyed by `u.num`, labelled with its uop name, its
  arg (when not None) and its dtype. Each entry of `u.vin` adds an edge from
  that uop to `u`.
  """
  fill_by_uop = {
    UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
    UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0",
    UOps.LOOP: "#c8a0e0", UOps.PHI: "#e0ffc0",
  }
  # a fresh graph local to this call, separate from the module-level G
  uop_graph = nx.DiGraph()
  for u in uops:
    # drop the first five characters of str(u.uop) (presumably the "UOps." prefix)
    uop_text = str(u.uop)[5:]
    arg_text = '' if u.arg is None else ' '+str(u.arg)
    # uop kinds without an entry in the table are drawn white
    uop_graph.add_node(u.num, label=f"{uop_text}{arg_text}\n{str(u.dtype)}", style="filled", fillcolor=fill_by_uop.get(u.uop, "#ffffff"))
    for v in u.vin: uop_graph.add_edge(v.num, u.num)
  out_base = "/tmp/uops"
  nx.drawing.nx_pydot.write_dot(uop_graph, f'{out_base}.dot')
  os.system(f'dot -Grankdir=LR -Tsvg {out_base}.dot -o {out_base}.svg')
|
|
|