You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
217 lines
11 KiB
217 lines
11 KiB
#!/usr/bin/env python3
|
|
import multiprocessing, pickle, functools, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, decimal
|
|
from http.server import HTTPServer, BaseHTTPRequestHandler
|
|
from urllib.parse import parse_qs, urlparse
|
|
from dataclasses import asdict, dataclass
|
|
from typing import Any, Callable, Optional
|
|
from tinygrad.helpers import colored, getenv, to_function_name, tqdm, unwrap, word_wrap
|
|
from tinygrad.ops import TrackedGraphRewrite, UOp, Ops, lines, GroupOp
|
|
from tinygrad.codegen.kernel import Kernel
|
|
from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent
|
|
|
|
# Color string attached to each op's node by uop_to_json; ops missing from this table get "#ffffff" there.
# Entry order matters: a later key overrides an earlier one if an op appears in more than one group.
uops_colors = {
  Ops.LOAD: "#ffc0c0",
  Ops.PRELOAD: "#ffc0c0",
  Ops.STORE: "#87CEEB",
  Ops.CONST: "#e0e0e0",
  Ops.VCONST: "#e0e0e0",
  Ops.DEFINE_GLOBAL: "#ffe0b0",
  Ops.DEFINE_LOCAL: "#ffe0d0",
  Ops.DEFINE_ACC: "#f0ffe0",
  Ops.REDUCE_AXIS: "#FF6B6B",
  Ops.RANGE: "#c8a0e0",
  Ops.ASSIGN: "#e0ffc0",
  Ops.BARRIER: "#ff8080",
  Ops.IF: "#c8b0c0",
  Ops.SPECIAL: "#c0c0ff",
  Ops.INDEX: "#e8ffa0",
  Ops.WMMA: "#efefc0",
  Ops.VIEW: "#C8F9D4",
  # whole op groups share one color
  **{x:"#D8F9E4" for x in GroupOp.Movement},
  **{x:"#ffffc0" for x in GroupOp.ALU},
  Ops.THREEFRY: "#ffff80",
  Ops.BLOCK: "#C4A484",
  Ops.BLOCKEND: "#C4A4A4",
  Ops.BUFFER: "#B0BDFF",
  Ops.COPY: "#a040a0",
}
|
|
|
|
# ** API spec
|
|
|
|
@dataclass
class GraphRewriteMetadata:
  """Summary of one tracked graph_rewrite call, used to fill the viz sidebar."""
  loc: tuple[str, int]
  """(file path, line number) of the call site"""
  code_line: str
  """Source text of the Python line that called graph_rewrite"""
  kernel_name: str
  """Name of the kernel the graph_rewrite call belongs to"""
  upats: list[tuple[tuple[str, int], str, float]]
  """One (UPat location, printable UPat, tm) entry per applied UPat"""
|
|
|
|
@dataclass
class GraphRewriteDetails(GraphRewriteMetadata):
  """Everything viz shows for a single graph_rewrite call: the metadata plus per-step data."""
  graphs: list[UOp]
  """The sink before any rewrite, followed by the sink after each rewrite step"""
  diffs: list[list[str]]
  """Per step, unified-diff lines between the matched UOp and its replacement"""
  changed_nodes: list[list[int]]
  """Per step, ids of the nodes that changed"""
  kernel_code: Optional[str]
  """Program source after all rewrites, or None when there is no kernel"""
|
|
|
|
# ** API functions
|
|
|
|
# NOTE: extra rendering in VIZ is best effort, a failure is reported as text instead of crashing
def pcall(fxn:Callable[..., str], *args, **kwargs) -> str:
  """Call `fxn(*args, **kwargs)` and return its result, or "ERROR: <exception>" if it raises an Exception."""
  try:
    result = fxn(*args, **kwargs)
  except Exception as err:
    return f"ERROR: {err}"
  return result
|
|
|
|
def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[list[tuple[Any, TrackedGraphRewrite, GraphRewriteMetadata]]]:
  """Build the sidebar metadata for every tracked rewrite, grouped by kernel name.

  `keys` and `contexts` are walked in lockstep. A key that is a Kernel is named via
  to_function_name(key.name); any other key is named str(key). Rewrites whose initial sink is a
  CONST are skipped. Returns one list per name, in first-seen order, of (key, rewrite, metadata).
  """
  grouped: dict[str, list[tuple[Any, TrackedGraphRewrite, GraphRewriteMetadata]]] = {}
  for key, rewrites in tqdm(zip(keys, contexts), desc="preparing kernels"):
    name = to_function_name(key.name) if isinstance(key, Kernel) else str(key)
    for rewrite in rewrites:
      if pickle.loads(rewrite.sink).op is Ops.CONST: continue
      # matches recorded without a UPat are left out of the list
      applied = [(upat.location, upat.printable(), tm) for _,_,upat,tm in rewrite.matches if upat is not None]
      # loc[1] is a 1-based line number into the file at loc[0]
      code_line = lines(rewrite.loc[0])[rewrite.loc[1]-1].strip()
      meta = GraphRewriteMetadata(rewrite.loc, code_line, name, applied)
      grouped.setdefault(name, []).append((key, rewrite, meta))
  return list(grouped.values())
|
|
|
|
def uop_to_json(x:UOp) -> dict[int, tuple[str, str, list[int], str, str]]:
  """Flatten the UOp graph rooted at `x` into a JSON-friendly dict keyed by id(node).

  Each value is (label, dtype string, ids of the displayed sources, arg string, color).
  CONST and DEVICE nodes get no entry of their own; they are written into the label of
  every node that has them as a source.
  """
  assert isinstance(x, UOp)
  nodes: dict[int, tuple[str, str, list[int], str, str]] = {}
  hidden = set()
  for u in x.toposort:
    if u.op in {Ops.CONST, Ops.DEVICE}:
      hidden.add(u)
      continue
    if u.op is Ops.VIEW:
      # one line per view: "shape / strides", plus " / offset" when the offset is truthy
      view_lines = []
      for v in u.arg.views:
        line = f"{v.shape} / {v.strides}"
        if v.offset: line += f" / {v.offset}"
        view_lines.append(line)
      argst = "\n".join(view_lines)
    else:
      argst = str(u.arg)
    op_name = str(u.op).split('.')[1]
    arg_text = ' '+word_wrap(argst.replace(':', '')) if u.arg is not None else ''
    label = f"{op_name}{arg_text}\n{str(u.dtype)}"
    for idx, src in enumerate(u.src):
      if src.op is Ops.CONST: label += f"\nCONST{idx} {src.arg:g}"
      elif src.op is Ops.DEVICE: label += f"\nDEVICE{idx} {src.arg}"
    shown_srcs = [id(src) for src in u.src if src not in hidden]
    nodes[id(u)] = (label, str(u.dtype), shown_srcs, str(u.arg), uops_colors.get(u.op, "#ffffff"))
  return nodes
|
|
def _replace_uop(base:UOp, replaces:dict[UOp, UOp]) -> UOp:
  """Return `base` with the substitutions in `replaces` applied to it and, recursively, to its sources.

  `replaces` also serves as a memo: the rebuilt node is stored under `base` unless the rebuilt
  node itself has an entry, in which case that entry is returned instead.
  """
  known = replaces.get(base)
  if known is not None: return known
  new_src = tuple(_replace_uop(src, replaces) for src in base.src)
  rebuilt = base.replace(src=new_src)
  # the rebuilt node may itself be the target of a recorded replacement
  final = replaces.get(rebuilt)
  if final is not None: return final
  replaces[base] = rebuilt
  return rebuilt
|
|
@functools.lru_cache(maxsize=None)
def _prg(k:Kernel):
  """Return the `.src` of `k.to_program()`; results are cached per kernel with no size limit."""
  program = k.to_program()
  return program.src
|
|
def get_details(k:Any, ctx:TrackedGraphRewrite, metadata:GraphRewriteMetadata) -> GraphRewriteDetails:
  """Replay the matches recorded in `ctx` and collect the sink, diff and changed nodes for every rewrite.

  `kernel_code` is rendered through pcall only when `k` is a Kernel; otherwise it is None.
  Raises AssertionError if a recorded rewrite leaves the sink unchanged.
  """
  code = pcall(_prg, k) if isinstance(k, Kernel) else None
  sink = pickle.loads(ctx.sink)
  g = GraphRewriteDetails(**asdict(metadata), graphs=[sink], diffs=[], changed_nodes=[], kernel_code=code)
  replaces: dict[UOp, UOp] = {}
  for i,(u0_b,u1_b,upat,_) in enumerate(ctx.matches):
    u0 = pickle.loads(u0_b)
    # a match that produced no rewrite maps to itself and adds no step
    if u1_b is None:
      replaces[u0] = u0
      continue
    u1 = pickle.loads(u1_b)
    replaces[u0] = u1
    # apply this rewrite plus every earlier one; a copy is passed so entries memoized by _replace_uop are discarded
    new_sink = _replace_uop(sink, dict(replaces))
    # sanity check
    if new_sink is sink: raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}")
    # record this step
    g.changed_nodes.append([id(x) for x in u1.toposort if x.op is not Ops.CONST])
    before, after = pcall(str, u0).splitlines(), pcall(str, u1).splitlines()
    g.diffs.append(list(difflib.unified_diff(before, after)))
    g.graphs.append(new_sink)
    sink = new_sink
  return g
|
|
|
|
# ** Profiler API
|
|
# device name -> (compute time offset, copy time offset, perfetto pid), filled by dev_ev_to_perfetto_json
devices:dict[str, tuple[decimal.Decimal, decimal.Decimal, int]] = {}

def prep_ts(device:str, ts:decimal.Decimal, is_copy):
  """Return `ts` shifted by the device's offset, truncated to int; `is_copy` picks the copy offset over the compute one."""
  shifted = decimal.Decimal(ts) + devices[device][is_copy]
  return int(shifted)

def dev_to_pid(device:str, is_copy=False):
  """Return the perfetto {"pid", "tid"} pair for a device; tid is 0 for compute and 1 for copy."""
  pid = devices[device][2]
  return {"pid": pid, "tid": int(is_copy)}
|
|
def dev_ev_to_perfetto_json(ev:ProfileDeviceEvent):
  """Register the device in `devices` and return its perfetto metadata events (process name plus two thread names)."""
  # without a dedicated copy offset, copies use the compute offset
  copy_tdiff = ev.comp_tdiff if ev.copy_tdiff is None else ev.copy_tdiff
  # the pid is the number of devices registered before this one
  devices[ev.device] = (ev.comp_tdiff, copy_tdiff, len(devices))
  pid = dev_to_pid(ev.device)['pid']
  return [
    {"name": "process_name", "ph": "M", "pid": pid, "args": {"name": ev.device}},
    {"name": "thread_name", "ph": "M", "pid": pid, "tid": 0, "args": {"name": "COMPUTE"}},
    {"name": "thread_name", "ph": "M", "pid": pid, "tid": 1, "args": {"name": "COPY"}},
  ]
|
|
def range_ev_to_perfetto_json(ev:ProfileRangeEvent):
  """Return a one-element list holding the perfetto "X" event for a range event."""
  event = {"name": ev.name, "ph": "X", "ts": prep_ts(ev.device, ev.st, ev.is_copy), "dur": float(ev.en-ev.st)}
  event.update(dev_to_pid(ev.device, ev.is_copy))
  return [event]
|
|
def graph_ev_to_perfetto_json(ev:ProfileGraphEvent, reccnt):
  """Return perfetto events for a graph event: one "X" event per entry plus an "s"/"f" pair per dependency.

  `reccnt` is added to the position of each "s" event in the returned list to form the id
  shared by that "s"/"f" pair.
  """
  out = []
  for i, ent in enumerate(ev.ents):
    st, en = ev.sigs[ent.st_id], ev.sigs[ent.en_id]
    out.append({"name": ent.name, "ph": "X", "ts": prep_ts(ent.device, st, ent.is_copy), "dur": float(en-st),
                **dev_to_pid(ent.device, ent.is_copy)})
    for dep_idx in ev.deps[i]:
      dep = ev.ents[dep_idx]
      flow_id = reccnt+len(out)
      # "s" is placed at the end of the dependency, "f" at the start of this entry
      out.append({"ph": "s", **dev_to_pid(dep.device, dep.is_copy), "id": flow_id,
                  "ts": prep_ts(dep.device, ev.sigs[dep.en_id], dep.is_copy), "bp": "e"})
      out.append({"ph": "f", **dev_to_pid(ent.device, ent.is_copy), "id": flow_id,
                  "ts": prep_ts(ent.device, st, ent.is_copy), "bp": "e"})
  return out
|
|
def to_perfetto(profile:list[ProfileEvent]):
  """Convert profile events to an encoded perfetto JSON document, or None when no events were produced."""
  prof_json = []
  # device events go first: they fill `devices`, which the other converters read
  for ev in profile:
    if isinstance(ev, ProfileDeviceEvent): prof_json.extend(dev_ev_to_perfetto_json(ev))
  for ev in tqdm(profile, desc="preparing profile"):
    if isinstance(ev, ProfileRangeEvent): prof_json.extend(range_ev_to_perfetto_json(ev))
    elif isinstance(ev, ProfileGraphEvent): prof_json.extend(graph_ev_to_perfetto_json(ev, reccnt=len(prof_json)))
  if len(prof_json) == 0: return None
  return json.dumps({"traceEvents": prof_json}).encode()
|
|
|
|
# ** HTTP server
|
|
|
|
class Handler(BaseHTTPRequestHandler):
  """Request handler for the viz server.

  Routes, matched on the parsed URL path (the query string is ignored except for /kernels):
    /             index.html from this file's directory
    /profiler     perfetto.html from this file's directory
    /assets/...   static files below this file's directory (404 if missing or a directory)
    /kernels      JSON metadata for all rewrites, or the details of one rewrite when the
                  `kernel` and `idx` query parameters are given (400 if they do not select one)
    /get_profile  the prepared perfetto profile, when there is one
  Every other path is answered with 404.
  """
  def do_GET(self):
    ret, status_code, content_type = b"", 200, "text/html"
    # parse once; every route below works on the path without its query string
    url = urlparse(self.path)
    root = os.path.dirname(__file__)

    if url.path == "/":
      with open(os.path.join(root, "index.html"), "rb") as f: ret = f.read()
    elif url.path == "/profiler":
      with open(os.path.join(root, "perfetto.html"), "rb") as f: ret = f.read()
    elif url.path.startswith("/assets/") and '/..' not in url.path:
      try:
        # url.path, not self.path: a query string such as "?v=2" must not become part of the file name
        with open(os.path.join(root, url.path.strip('/')), "rb") as f: ret = f.read()
        if url.path.endswith(".js"): content_type = "application/javascript"
        if url.path.endswith(".css"): content_type = "text/css"
      except (FileNotFoundError, IsADirectoryError, NotADirectoryError): status_code = 404
    elif url.path == "/kernels":
      query = parse_qs(url.query)
      jret: Any = None
      if (qkernel:=query.get("kernel")) is None:
        jret = [[asdict(x[2]) for x in v] for v in kernels]
      else:
        # only the lookup is guarded, so errors raised inside get_details are not reported as a bad request
        try: selected = kernels[int(qkernel[0])][int(query["idx"][0])]
        except (KeyError, IndexError, ValueError): status_code = 400
        else:
          g = get_details(*selected)
          jret = {**asdict(g), "graphs": [uop_to_json(x) for x in g.graphs], "uops": [pcall(str,x) for x in g.graphs]}
      if jret is not None: ret, content_type = json.dumps(jret).encode(), "application/json"
    elif url.path == "/get_profile" and perfetto_profile is not None: ret, content_type = perfetto_profile, "application/json"
    else: status_code = 404

    # send response
    self.send_response(status_code)
    self.send_header('Content-Type', content_type)
    self.send_header('Content-Length', str(len(ret)))
    self.end_headers()
    return self.wfile.write(ret)
|
|
|
|
# ** main loop
|
|
|
|
def reloader():
  """Poll this file's mtime every 0.1s and re-exec the process when it changes, until stop_reloader is set."""
  started_mtime = os.stat(__file__).st_mtime
  while not stop_reloader.is_set():
    modified = os.stat(__file__).st_mtime != started_mtime
    if modified:
      print("reloading server...")
      # replaces the current process with a fresh run of the same command line
      os.execv(sys.executable, [sys.executable, *sys.argv])
    time.sleep(0.1)
|
|
|
|
def load_pickle(path:Optional[str]):
  """Return the object unpickled from `path`, or None when `path` is None or does not exist."""
  if path is None or not os.path.exists(path): return None
  # NOTE(review): pickle.load can run arbitrary code from the file; only load files written by a trusted run
  with open(path, "rb") as f: return pickle.load(f)
|
|
|
|
if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument('--kernels', type=str, help='Path to kernels', default=None)
  parser.add_argument('--profile', type=str, help='Path profile', default=None)
  args = parser.parse_args()

  # refuse to start when something already accepts connections on the target port
  HOST, PORT = "http://127.0.0.1", getenv("PORT", 8000)
  with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    if s.connect_ex((HOST.replace("http://", ""), PORT)) == 0:
      raise RuntimeError(f"{HOST}:{PORT} is occupied! use PORT= to change.")
  stop_reloader = threading.Event()
  multiprocessing.current_process().name = "VizProcess"    # disallow opening of devices
  st = time.perf_counter()
  print("*** viz is starting")

  contexts, profile = load_pickle(args.kernels), load_pickle(args.profile)

  # module-level on purpose: Handler reads `kernels` and `perfetto_profile`, reloader reads `stop_reloader`
  kernels = get_metadata(*contexts) if contexts is not None else []

  if getenv("FUZZ_VIZ"):
    # build the details of every rewrite once, so failures show up before the server starts
    ret = [get_details(*rewrite) for v in tqdm(kernels) for rewrite in v]
    print(f"fuzzed {len(ret)} rewrite details")

  perfetto_profile = to_perfetto(profile) if profile is not None else None

  server = HTTPServer(('', PORT), Handler)
  reloader_thread = threading.Thread(target=reloader)
  reloader_thread.start()
  try:
    print(f"*** started viz on {HOST}:{PORT}")
    print(colored(f"*** ready in {(time.perf_counter()-st)*1e3:4.2f}ms", "green"))
    if len(getenv("BROWSER", "")) > 0: webbrowser.open(f"{HOST}:{PORT}{'/profiler' if contexts is None else ''}")
    server.serve_forever()
  except KeyboardInterrupt:
    print("*** viz is shutting down...")
  finally:
    # the reloader thread is not a daemon: unless it is told to stop, any exit path would leave the process hanging
    stop_reloader.set()
    server.server_close()
|
|
|