You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
207 lines
11 KiB
207 lines
11 KiB
#!/usr/bin/env python3
|
|
import multiprocessing, pickle, functools, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, decimal, socketserver
|
|
from http.server import BaseHTTPRequestHandler
|
|
from urllib.parse import parse_qs, urlparse
|
|
from typing import Any, Callable, TypedDict, Generator
|
|
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap
|
|
from tinygrad.ops import TrackedGraphRewrite, UOp, Ops, lines, GroupOp
|
|
from tinygrad.codegen.kernel import Kernel
|
|
from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent
|
|
from tinygrad.dtype import dtypes
|
|
|
|
# node color per Ops for the VIZ graph (hex RGB); ops not listed here fall back to white in uop_to_json
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
               Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
               Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
               Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", Ops.IGNORE: "#00C000",
               **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
               Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.NAME:"#808080"}
|
|
|
|
# VIZ API
|
|
|
|
# NOTE: if any extra rendering in VIZ fails, we don't crash
def pcall(fxn:Callable[..., str], *args, **kwargs) -> str:
  """Best-effort call: return fxn(*args, **kwargs), or an error message string if it raises."""
  try:
    result = fxn(*args, **kwargs)
  except Exception as e:
    result = f"ERROR in {fxn.__name__}: {e}"
  return result
|
|
|
|
# ** Metadata for a track_rewrites scope
|
|
|
|
class GraphRewriteMetadata(TypedDict):
  """Summary of one tracked graph_rewrite call, sent to the frontend by get_metadata."""
  loc: tuple[str, int] # [path, lineno] calling graph_rewrite
  match_count: int # total match count in this context
  code_line: str # source code calling graph_rewrite
  kernel_code: str|None # optionally render the final kernel code
  name: str|None # optional name of the rewrite
|
|
|
|
@functools.lru_cache(None)
def render_program(k:Kernel):
  """Render the program source for a Kernel, memoized so repeated requests don't re-render."""
  return k.opts.render(k.uops)
|
|
def to_metadata(k:Any, v:TrackedGraphRewrite) -> GraphRewriteMetadata:
  """Build the GraphRewriteMetadata summary for one tracked rewrite v in scope k."""
  path, lineno = v.loc
  # only Kernel scopes have renderable program code; pcall keeps a render failure from crashing VIZ
  kernel_code = pcall(render_program, k) if isinstance(k, Kernel) else None
  return {"loc":v.loc, "match_count":len(v.matches), "code_line":lines(path)[lineno-1].strip(),
          "kernel_code":kernel_code, "name":v.name}
|
|
def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[tuple[str, list[GraphRewriteMetadata]]]:
  """Pair a display name for each tracked scope key with the metadata of its rewrites."""
  ret: list[tuple[str, list[GraphRewriteMetadata]]] = []
  for key, rewrites in zip(keys, contexts):
    display_name = key.name if isinstance(key, Kernel) else str(key)
    ret.append((display_name, [to_metadata(key, r) for r in rewrites]))
  return ret
|
|
|
|
# ** Complete rewrite details for a graph_rewrite call
|
|
|
|
class GraphRewriteDetails(TypedDict):
  """Full state of a single rewrite step, streamed to the frontend by get_details."""
  graph: dict # JSON serialized UOp for this rewrite step
  uop: str # stringified UOp for this rewrite step
  diff: list[str]|None # string diff of the single UOp that changed
  changed_nodes: list[int]|None # the changed UOp id + all its parents ids
  upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat
|
|
|
|
def uop_to_json(x:UOp) -> dict[int, dict]:
  """Serialize the UOp graph rooted at x as {id(uop): {"label", "src", "color"}} for the VIZ frontend.

  DEVICE/CONST/UNIQUE nodes are excluded from the graph; excluded sources are folded into the
  consuming node's label instead of being drawn as edges.
  """
  assert isinstance(x, UOp)
  graph: dict[int, dict] = {}
  excluded: set[UOp] = set()
  for u in (toposort:=x.toposort):
    # always exclude DEVICE/CONST/UNIQUE
    if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE}: excluded.add(u)
    # only exclude CONST VIEW source if it has no other children in the graph
    if u.op is Ops.CONST and len(u.src) != 0 and all(cr.op is Ops.CONST for c in u.src[0].children if (cr:=c()) is not None and cr in toposort):
      excluded.update(u.src)
  for u in toposort:
    if u in excluded: continue
    argst = str(u.arg)
    if u.op is Ops.VIEW:
      # one line per view: shape / strides, plus optional MASK and non-zero offset
      argst = ("\n".join([f"{v.shape} / {v.strides}"+(f"\nMASK {v.mask}" if v.mask is not None else "")+
                          ("" if v.offset == 0 else f" / {v.offset}") for v in unwrap(u.st).views]))
    label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}"
    if u.dtype != dtypes.void: label += f"\n{u.dtype}"
    # fix: don't shadow the function parameter x with the source loop variable
    for idx,s in enumerate(u.src):
      if s in excluded:
        # excluded sources get summarized in the label instead of a graph edge
        if s.op is Ops.CONST and dtypes.is_float(u.dtype): label += f"\nCONST{idx} {s.arg:g}"
        else: label += f"\n{s.op.name}{idx} {s.arg}"
    graph[id(u)] = {"label":label, "src":[id(v) for v in u.src if v not in excluded], "color":uops_colors.get(u.op, "#ffffff")}
  return graph
|
|
|
|
def get_details(k:Any, ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
  """Stream the rewrite states for one graph_rewrite call: the initial sink first, then one state per match."""
  yield {"graph":uop_to_json(next_sink:=ctx.sink), "uop":str(ctx.sink), "changed_nodes":None, "diff":None, "upat":None}
  replaces: dict[UOp, UOp] = {}
  for u0,u1,upat in tqdm(ctx.matches):
    # accumulate every match; each step re-applies the full substitution map to next_sink
    replaces[u0] = u1
    new_sink = next_sink.substitute(replaces)
    yield {"graph": (sink_json:=uop_to_json(new_sink)), "uop":str(new_sink), "changed_nodes":[id(x) for x in u1.toposort if id(x) in sink_json],
           "diff":list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines())), "upat":(upat.location, upat.printable())}
    # for bottom_up rewrites next_sink stays at ctx.sink, so the growing map is replayed on the original each step
    if not ctx.bottom_up: next_sink = new_sink
|
|
|
|
# Profiler API

# device name -> (compute clock offset, copy clock offset, perfetto pid); filled by dev_ev_to_perfetto_json
devices:dict[str, tuple[decimal.Decimal, decimal.Decimal, int]] = {}

def prep_ts(device:str, ts:decimal.Decimal, is_copy):
  """Apply the device's clock offset (copy offset when is_copy is truthy) to timestamp ts."""
  offset = devices[device][is_copy]
  return int(decimal.Decimal(ts) + offset)

def dev_to_pid(device:str, is_copy=False):
  """Map a device to its perfetto pid; tid 0 is the compute track, tid 1 the copy track."""
  pid = devices[device][2]
  return {"pid": pid, "tid": int(is_copy)}
|
|
def dev_ev_to_perfetto_json(ev:ProfileDeviceEvent):
  """Register ev's device in the global devices table and emit its perfetto metadata events."""
  # when no separate copy clock offset exists, the compute offset is used for both tracks
  copy_tdiff = ev.copy_tdiff if ev.copy_tdiff is not None else ev.comp_tdiff
  devices[ev.device] = (ev.comp_tdiff, copy_tdiff, len(devices))
  pid = dev_to_pid(ev.device)['pid']
  meta = [{"name": "process_name", "ph": "M", "pid": pid, "args": {"name": ev.device}}]
  meta.append({"name": "thread_name", "ph": "M", "pid": pid, "tid": 0, "args": {"name": "COMPUTE"}})
  meta.append({"name": "thread_name", "ph": "M", "pid": pid, "tid": 1, "args": {"name": "COPY"}})
  return meta
|
|
def range_ev_to_perfetto_json(ev:ProfileRangeEvent):
  """Convert one profile range event into a single perfetto complete ("X") event."""
  start_ts = prep_ts(ev.device, ev.st, ev.is_copy)
  duration = float(ev.en-ev.st)
  return [{"name": ev.name, "ph": "X", "ts": start_ts, "dur": duration, **dev_to_pid(ev.device, ev.is_copy)}]
|
|
def graph_ev_to_perfetto_json(ev:ProfileGraphEvent, reccnt):
  """Convert a graph profile event into perfetto "X" slices plus "s"/"f" flow arrows for its deps.

  reccnt is the number of events already emitted before this call; it offsets the flow ids so
  they stay unique across the whole trace.
  """
  ret = []
  for i,e in enumerate(ev.ents):
    st, en = ev.sigs[e.st_id], ev.sigs[e.en_id]
    # one complete ("X") slice per graph entry
    ret += [{"name": e.name, "ph": "X", "ts": prep_ts(e.device, st, e.is_copy), "dur": float(en-st), **dev_to_pid(e.device, e.is_copy)}]
    for dep in ev.deps[i]:
      d = ev.ents[dep]
      # flow start ("s") at the dependency's end; the matching finish ("f") below reuses this
      # event's id via reccnt+len(ret)-1, since ret has grown by one when it is computed
      ret += [{"ph": "s", **dev_to_pid(d.device, d.is_copy), "id": reccnt+len(ret), "ts": prep_ts(d.device, ev.sigs[d.en_id], d.is_copy), "bp": "e"}]
      ret += [{"ph": "f", **dev_to_pid(e.device, e.is_copy), "id": reccnt+len(ret)-1, "ts": prep_ts(e.device, st, e.is_copy), "bp": "e"}]
  return ret
|
|
def to_perfetto(profile:list[ProfileEvent]):
  """Serialize profile events to perfetto JSON bytes, or None when there is nothing to show."""
  # device metadata events go first so every pid exists before range/graph events reference it
  prof_json: list[dict] = []
  for ev in profile:
    if isinstance(ev, ProfileDeviceEvent): prof_json += dev_ev_to_perfetto_json(ev)
  for ev in tqdm(profile, desc="preparing profile"):
    if isinstance(ev, ProfileRangeEvent): prof_json += range_ev_to_perfetto_json(ev)
    elif isinstance(ev, ProfileGraphEvent): prof_json += graph_ev_to_perfetto_json(ev, reccnt=len(prof_json))
  if len(prof_json) == 0: return None
  return json.dumps({"traceEvents": prof_json}).encode()
|
|
|
|
# ** HTTP server
|
|
|
|
class Handler(BaseHTTPRequestHandler):
  """Serves the VIZ frontend (index/perfetto pages, static assets) and the kernel/profile JSON APIs.

  Reads the module-level globals contexts, kernels and perfetto_profile set in __main__.
  """
  def do_GET(self):
    ret, status_code, content_type = b"", 200, "text/html"

    # fix: parse self.path once; url is always bound after the first condition evaluates
    if (url:=urlparse(self.path)).path == "/":
      with open(os.path.join(os.path.dirname(__file__), "index.html"), "rb") as f: ret = f.read()
    elif url.path == "/profiler":
      with open(os.path.join(os.path.dirname(__file__), "perfetto.html"), "rb") as f: ret = f.read()
    elif self.path.startswith(("/assets/", "/js/")) and '/..' not in self.path:
      try:
        with open(os.path.join(os.path.dirname(__file__), self.path.strip('/')), "rb") as f: ret = f.read()
        if url.path.endswith(".js"): content_type = "application/javascript"
        if url.path.endswith(".css"): content_type = "text/css"
      except FileNotFoundError: status_code = 404
    elif url.path == "/kernels":
      if "kernel" in (query:=parse_qs(url.query)):
        def getarg(k:str,default=0): return int(query[k][0]) if k in query else default
        kidx, ridx = getarg("kernel"), getarg("idx")
        try:
          # stream details as server-sent events, one rewrite step per "data:" frame
          self.send_response(200)
          self.send_header("Content-Type", "text/event-stream")
          self.send_header("Cache-Control", "no-cache")
          self.end_headers()
          for r in get_details(contexts[0][kidx], contexts[1][kidx][ridx]):
            self.wfile.write(f"data: {json.dumps(r)}\n\n".encode("utf-8"))
            self.wfile.flush()
          self.wfile.write("data: END\n\n".encode("utf-8"))
          return self.wfile.flush()
        # pass if client closed connection
        except (BrokenPipeError, ConnectionResetError): return
      # no "kernel" query arg: return the full kernel metadata list
      ret, content_type = json.dumps(kernels).encode(), "application/json"
    elif url.path == "/get_profile" and perfetto_profile is not None: ret, content_type = perfetto_profile, "application/json"
    else: status_code = 404

    # send response
    self.send_response(status_code)
    self.send_header('Content-Type', content_type)
    self.send_header('Content-Length', str(len(ret)))
    self.end_headers()
    return self.wfile.write(ret)
|
|
|
|
# ** main loop
|
|
|
|
def reloader():
  """Background watcher: re-exec this process when the server source file's mtime changes."""
  start_mtime = os.stat(__file__).st_mtime
  while not stop_reloader.is_set():
    if os.stat(__file__).st_mtime != start_mtime:
      print("reloading server...")
      os.execv(sys.executable, [sys.executable] + sys.argv)
    time.sleep(0.1)
|
|
|
|
def load_pickle(path:str):
|
|
if path is None or not os.path.exists(path): return None
|
|
with open(path, "rb") as f: return pickle.load(f)
|
|
|
|
# NOTE: using HTTPServer forces a potentially slow socket.getfqdn
class TCPServerWithReuse(socketserver.TCPServer):
  """TCPServer with SO_REUSEADDR set, so the port can be rebound right after a restart."""
  allow_reuse_address = True
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument('--kernels', type=str, help='Path to kernels', default=None)
|
|
parser.add_argument('--profile', type=str, help='Path profile', default=None)
|
|
args = parser.parse_args()
|
|
|
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
if s.connect_ex(((HOST:="http://127.0.0.1").replace("http://", ""), PORT:=getenv("PORT", 8000))) == 0:
|
|
raise RuntimeError(f"{HOST}:{PORT} is occupied! use PORT= to change.")
|
|
stop_reloader = threading.Event()
|
|
multiprocessing.current_process().name = "VizProcess" # disallow opening of devices
|
|
st = time.perf_counter()
|
|
print("*** viz is starting")
|
|
|
|
contexts, profile = load_pickle(args.kernels), load_pickle(args.profile)
|
|
|
|
# NOTE: this context is a tuple of list[keys] and list[values]
|
|
kernels = get_metadata(*contexts) if contexts is not None else []
|
|
|
|
perfetto_profile = to_perfetto(profile) if profile is not None else None
|
|
|
|
server = TCPServerWithReuse(('', PORT), Handler)
|
|
reloader_thread = threading.Thread(target=reloader)
|
|
reloader_thread.start()
|
|
print(f"*** started viz on {HOST}:{PORT}")
|
|
print(colored(f"*** ready in {(time.perf_counter()-st)*1e3:4.2f}ms", "green"))
|
|
if len(getenv("BROWSER", "")) > 0: webbrowser.open(f"{HOST}:{PORT}{'/profiler' if contexts is None else ''}")
|
|
try: server.serve_forever()
|
|
except KeyboardInterrupt:
|
|
print("*** viz is shutting down...")
|
|
stop_reloader.set()
|
|
|