You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
74 lines
3.2 KiB
74 lines
3.2 KiB
import os, atexit, itertools
|
|
try:
|
|
import networkx as nx # type: ignore
|
|
except ImportError:
|
|
nx = None # graph won't work
|
|
from collections import defaultdict
|
|
from typing import Dict, List, Optional
|
|
from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, FusedOps, Op, OpType, LazyOp, get_buffers, get_lazyops
|
|
from tinygrad.helpers import getenv, DEBUG
|
|
|
|
GRAPH, PRUNEGRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("PRUNEGRAPH", 0), getenv("GRAPHPATH", "/tmp/net")
|
|
|
|
# **** debugging and graphing ****
|
|
|
|
G = nx.DiGraph() if nx is not None else None
|
|
cnts : Dict[OpType, int] = defaultdict(int)
|
|
if GRAPH:
|
|
def save_graph_exit():
|
|
for k,v in cnts.items(): print(k, v)
|
|
if PRUNEGRAPH: prune_graph()
|
|
print("saving", G)
|
|
nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot')
|
|
# -Gnslimit=100 can make it finish, but you won't like results
|
|
os.system(f'dot -Tsvg {GRAPHPATH}.dot -o {GRAPHPATH}.svg')
|
|
atexit.register(save_graph_exit)
|
|
|
|
node_count = 0
|
|
def nm(x):
|
|
global node_count
|
|
if not hasattr(x, 'node_id'):
|
|
setattr(x, 'node_id', node_count)
|
|
node_count += 1
|
|
return x.node_id
|
|
|
|
def get_sop(op : List[Op]):
|
|
if len(op) <= 2: return '.'.join([str(y).split(".")[1] for y in op][::-1])
|
|
if len(op) <= 4: return '.'.join([str(y).split(".")[1][0:2] for y in op][::-1])
|
|
return str(len(op))
|
|
|
|
def log_op(ret : DeviceBuffer, ast : LazyOp, show_graph : Optional[bool] = None):
|
|
if show_graph is None: show_graph = GRAPH
|
|
if not DEBUG and not show_graph: return
|
|
op : List[Op] = [x.op for x in get_lazyops(ast)]
|
|
inp : List[DeviceBuffer] = get_buffers(ast)
|
|
if len(inp) == 1 and inp[0] == ret:
|
|
if show_graph and nm(ret) in G.nodes: G.nodes[nm(ret)]['style'] += ', bold'
|
|
return # don't log self loops
|
|
oporder = [LoadOps, FusedOps, ReduceOps, BinaryOps, UnaryOps, MovementOps]
|
|
optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0])
|
|
cnts[optype] += 1
|
|
if DEBUG >= 4: print(f"{op} : {', '.join([f'{x.shape}-<{nm(x)}>' for x in inp])} -> {ret.shape}-<{nm(ret)}>")
|
|
if show_graph:
|
|
top_colors = {LoadOps: '#FFFF80', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", FusedOps: "#ff8080"}
|
|
dashed = (optype == LoadOps and hasattr(ret, "_backing")) or (hasattr(ret, "st") and not ret.st.contiguous) # type: ignore
|
|
|
|
for x in inp:
|
|
G.add_edge(nm(x), nm(ret), label=get_sop(op))
|
|
if 'label' not in G.nodes[nm(x)]:
|
|
G.nodes[nm(x)]['label'] = str(x.shape)
|
|
if nm(ret) not in G.nodes: G.add_node(nm(ret))
|
|
|
|
G.nodes[nm(ret)]['label'] = str(set(x.shape for x in inp))+"\n"+str(ret.shape) if optype == ReduceOps else str(ret.shape)
|
|
G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('80' if dashed else str())) if optype in top_colors else "#ffffff"
|
|
G.nodes[nm(ret)]['style'] = 'filled, dashed' if dashed else 'filled'
|
|
G.nodes[nm(ret)]['prunable'] = optype in [LoadOps, MovementOps]
|
|
|
|
# prune movementops and loadops
|
|
def prune_graph():
|
|
dead_nodes = []
|
|
for n in G.nodes:
|
|
if 'prunable' in G.nodes[n] and G.nodes[n]['prunable']:
|
|
G.add_edges_from([(x, y) for (x,_),(_,y) in itertools.product(G.in_edges(n), G.out_edges(n))])
|
|
dead_nodes.append(n)
|
|
G.remove_nodes_from(dead_nodes)
|
|
|