openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

410 lines
22 KiB

#!/usr/bin/env python3
import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, socketserver, functools, codecs, io, struct
import subprocess, ctypes, pathlib, traceback
from contextlib import redirect_stdout, redirect_stderr
from decimal import Decimal
from http.server import BaseHTTPRequestHandler
from urllib.parse import parse_qs, urlparse
from typing import Any, TypedDict, TypeVar, Generator, Callable
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
from tinygrad.helpers import printable
from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, GroupOp, srender, sint, sym_infer, range_str, pyrender
from tinygrad.uop.ops import print_uops, range_start, multirange_str
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device
from tinygrad.renderer import ProgramSpec
from tinygrad.dtype import dtypes
# color palette for graph nodes in the viz UI, keyed by op; ops missing here render with the default color (see uop_to_json)
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
               Ops.DEFINE_GLOBAL:"#cb9037", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
               Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
               Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
               **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
               Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0",
               Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
               Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"}
# VIZ API
# ** list all saved rewrites
# maps a trace key (e.g. a kernel AST or TracingKey key) to the index of its context in ctxs; filled by get_rewrites
ref_map:dict[Any, int] = {}
def get_rewrites(t:RewriteTrace) -> list[dict]:
  """Build the JSON-serializable list of rewrite contexts shown in the viz sidebar.

  Side effect: registers every key of each context into the module-level ref_map.
  """
  out:list[dict] = []
  for i,(key,rewrites) in enumerate(zip(t.keys, t.rewrites)):
    steps:list[dict] = []
    for j,step in enumerate(rewrites):
      # only the first step carries the creation traceback
      steps.append({"name":step.name, "loc":step.loc, "match_count":len(step.matches), "code_line":printable(step.loc),
                    "trace":key.tb if j == 0 else None, "query":f"/ctxs?ctx={i}&idx={j}", "depth":step.depth})
    # contexts that produced a compiled program get extra render views
    if isinstance(key.ret, ProgramSpec):
      for view_name,fmt in (("View UOp List","uops"), ("View Program","src"), ("View Disassembly","asm")):
        steps.append({"name":view_name, "query":f"/render?ctx={i}&fmt={fmt}", "depth":0})
    for k in key.keys: ref_map[k] = i
    out.append({"name":key.display_name, "steps":steps})
  return out
# ** get the complete UOp graphs for one rewrite
class GraphRewriteDetails(TypedDict):
  """Payload streamed (as one SSE message) to the client for a single rewrite step."""
  graph: dict # JSON serialized UOp for this rewrite step
  uop: str # stringified UOp for this rewrite step
  diff: list[str]|None # diff of the single UOp that changed
  changed_nodes: list[int]|None # the changed UOp id + all its parents ids
  upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat
def shape_to_str(s:tuple[sint, ...]) -> str:
  # render a (possibly symbolic) shape tuple as "(a,b,c)"
  rendered = [srender(dim) for dim in s]
  return f"({','.join(rendered)})"
def mask_to_str(s:tuple[tuple[sint, sint], ...]) -> str:
  # render a mask (tuple of (begin,end) pairs), reusing shape_to_str for each pair
  return "(" + ','.join(map(shape_to_str, s)) + ")"
def pystr(u:UOp, i:int) -> str:
  # best-effort python rendering of a UOp, falling back to plain str() when pyrender fails
  # NOTE: i (the context index) is currently unused
  try:
    ret = pyrender(u)
  except Exception:
    ret = str(u)
  return ret
def uop_to_json(x:UOp) -> dict[int, dict]:
  """Serialize the UOp graph rooted at x into {id(u): node-dict} for the client renderer.

  Small leaf ops (DEVICE/CONST/UNIQUE and index VCONST) are folded into their
  consumer's label instead of getting their own node.
  """
  assert isinstance(x, UOp)
  graph: dict[int, dict] = {}
  excluded: set[UOp] = set()
  for u in (toposort:=x.toposort()):
    # always exclude DEVICE/CONST/UNIQUE
    if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE} and u is not x: excluded.add(u)
    if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.index and u is not x: excluded.add(u)
  for u in toposort:
    if u in excluded: continue
    argst = codecs.decode(str(u.arg), "unicode_escape")
    # movement ops show their marg (mask for SHRINK/PAD, shape otherwise) instead of raw arg
    if u.op in GroupOp.Movement: argst = (mask_to_str if u.op in {Ops.SHRINK, Ops.PAD} else shape_to_str)(u.marg)
    if u.op is Ops.KERNEL:
      ast_str = f"SINK{tuple(s.op for s in u.arg.ast.src)}" if u.arg.ast.op is Ops.SINK else repr(u.arg.ast.op)
      argst = f"<Kernel {len(list(u.arg.ast.toposort()))} {ast_str} {[str(m) for m in u.arg.metadata]}>"
    # first label line is the op name; ':' is stripped from args for the client's parser
    label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}"
    if u.dtype != dtypes.void: label += f"\n{u.dtype}"
    # inline the excluded leaf sources into this node's label (NOTE: x is rebound here, shadowing the root)
    for idx,x in enumerate(u.src[:1] if u.op in {Ops.BUFFERIZE, Ops.INDEX} else (u.src if u.op is not Ops.END else [])):
      if x in excluded:
        arg = f"{x.arg:g}" if x.op is Ops.CONST and dtypes.is_float(x.dtype) else f"{x.arg}"
        label += f"\n{x.op.name}{idx} {arg}" + (f" {x.src[0].op}" if len(x.src) else "")
    # label enrichment below is best-effort; any failure is shown instead of crashing the server
    try:
      if len(rngs:=u.ranges):
        label += f"\n({multirange_str(rngs, color=True)})"
      if u._shape is not None:
        label += f"\n{shape_to_str(u.shape)}"
      if u.op in {Ops.INDEX, Ops.BUFFERIZE}:
        label += f"\n{u.render()}"
        # collect the ranges/specials used by the index expressions (srcs after the first)
        ranges: list[UOp] = []
        for us in u.src[1:]: ranges += [s for s in us.toposort() if s.op in {Ops.RANGE, Ops.SPECIAL}]
        if ranges: label += "\n"+' '.join([f"{s.render()}={s.vmax+1}" for s in ranges])
      if u.op in {Ops.END, Ops.REDUCE} and len(trngs:=list(UOp.sink(*u.src[range_start[u.op]:]).ranges)):
        label += "\n"+' '.join([f"{range_str(s, color=True)}({s.vmax+1})" for s in trngs])
    except Exception:
      label += "\n<ISSUE GETTING LABEL>"
    # kernels link back to their codegen context when the AST was traced
    if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"
    # NOTE: kernel already has metadata in arg
    if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.KERNEL: label += "\n"+str(u.metadata)
    graph[id(u)] = {"label":label, "src":[(i,id(x)) for i,x in enumerate(u.src) if x not in excluded], "color":uops_colors.get(u.op, "#ffffff"),
                    "ref":ref, "tag":repr(u.tag) if u.tag is not None else None}
  return graph
@functools.cache
def _reconstruct(a:int):
  """Rebuild the UOp with serialized id a from the loaded trace (memoized, recursive)."""
  op, dtype, src, arg, *rest = trace.uop_fields[a]
  # KERNEL args embed an AST id that must itself be reconstructed
  arg = type(arg)(_reconstruct(arg.ast), arg.metadata) if op is Ops.KERNEL else arg
  return UOp(op, dtype, tuple(_reconstruct(s) for s in src), arg, *rest)
def get_full_rewrite(ctx:TrackedGraphRewrite, i:int=0) -> Generator[GraphRewriteDetails, None, None]:
  """Yield the initial graph, then one GraphRewriteDetails per applied match in ctx."""
  next_sink = _reconstruct(ctx.sink)
  # in the schedule graph we don't show indexing ops (unless it's in a kernel AST or rewriting dtypes.index sink)
  yield {"graph":uop_to_json(next_sink), "uop":pystr(next_sink,i), "changed_nodes":None, "diff":None, "upat":None}
  replaces: dict[UOp, UOp] = {}
  for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches):
    # accumulate all substitutions so each step shows the cumulative result
    replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num)
    # substitute can fail (e.g. cycle); surface the error as a NOOP node instead of crashing
    try: new_sink = next_sink.substitute(replaces)
    except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
    match_repr = f"# {dur*1e6:.2f} us\n"+printable(upat_loc)
    yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":pystr(new_sink,i),
           "changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json],
           "diff":list(difflib.unified_diff(pystr(u0,i).splitlines(),pystr(u1,i).splitlines())), "upat":(upat_loc, match_repr)}
    # bottom_up rewrites re-apply from the original sink every step; top-down ones advance
    if not ctx.bottom_up: next_sink = new_sink
# encoder helpers
def enum_str(s, cache:dict[str, int]) -> int:
  """Intern s into cache, returning its stable index (assigned in insertion order)."""
  if s in cache: return cache[s]
  idx = len(cache)
  cache[s] = idx
  return idx
def option(s:int|None) -> int: return 0 if s is None else s+1
# Profiler API
# per-device (compute clock, copy clock) offsets used to align device timestamps to CPU time
device_ts_diffs:dict[str, tuple[Decimal, Decimal]] = {}
def cpu_ts_diff(device:str, thread=0) -> Decimal:
  """Return the device->CPU clock offset; thread selects compute (0) or copy (1).

  Fix: the fallback for an unknown device must be a 2-tuple — the old 1-tuple
  default raised IndexError for copy events (thread=1) on devices with no
  recorded ProfileDeviceEvent.
  """
  return device_ts_diffs.get(device, (Decimal(0), Decimal(0)))[thread]
# any single event that can appear on a device's timeline row
DevEvent = ProfileRangeEvent|ProfileGraphEntry|ProfilePointEvent
def flatten_events(profile:list[ProfileEvent]) -> Generator[tuple[Decimal, Decimal, DevEvent], None, None]:
  """Yield (start_ts, end_ts, event) in CPU time for every device-level event in profile."""
  for e in profile:
    # range events: shift by the device clock offset (copy clock when it's a copy)
    if isinstance(e, ProfileRangeEvent): yield (e.st+(diff:=cpu_ts_diff(e.device, e.is_copy)), (e.en if e.en is not None else e.st)+diff, e)
    elif isinstance(e, ProfilePointEvent): yield (e.ts, e.ts, e)
    elif isinstance(e, ProfileGraphEvent):
      # graph events: resolve each entry's signal timestamps, emit one umbrella range then each entry
      cpu_ts = []
      for ent in e.ents: cpu_ts += [e.sigs[ent.st_id]+(diff:=cpu_ts_diff(ent.device, ent.is_copy)), e.sigs[ent.en_id]+diff]
      yield (st:=min(cpu_ts)), (et:=max(cpu_ts)), ProfileRangeEvent(f"{e.ents[0].device.split(':')[0]} Graph", f"batched {len(e.ents)}", st, et)
      for i,ent in enumerate(e.ents): yield (cpu_ts[i*2], cpu_ts[i*2+1], ent)
# normalize event timestamps and attach kernel metadata
def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int, scache:dict[str, int]) -> bytes|None:
  """Binary-encode one device's timeline: header <u8 tag=0, u32 count> then one
  <u32 name, u32 ref, u32 key, u32 relative start, f32 duration, u32 info string> per event.

  Strings are interned into scache; returns None when the device had no nonzero-duration events.
  """
  events:list[bytes] = []
  exec_points:dict[str, ProfilePointEvent] = {}
  for st,et,dur,e in dev_events:
    # remember exec point events so later range events can pick up their var_vals/metadata
    if isinstance(e, ProfilePointEvent) and e.name == "exec": exec_points[e.arg["name"]] = e
    if dur == 0: continue
    name, fmt, key = e.name, [], None
    if (ref:=ref_map.get(name)) is not None:
      name = ctxs[ref]["name"]
      if isinstance(p:=trace.keys[ref].ret, ProgramSpec) and (ei:=exec_points.get(p.name)) is not None:
        # compute achieved FLOPS / memory / lds bandwidth from the spec's estimates and measured duration (us)
        flops = sym_infer(p.estimates.ops, var_vals:=ei.arg['var_vals'])/(t:=dur*1e-6)
        membw, ldsbw = sym_infer(p.estimates.mem, var_vals)/t, sym_infer(p.estimates.lds, var_vals)/t
        fmt = [f"{flops*1e-9:.0f} GFLOPS" if flops < 1e14 else f"{flops*1e-12:.0f} TFLOPS",
               (f"{membw*1e-9:.0f} GB/s" if membw < 1e13 else f"{membw*1e-12:.0f} TB/s")+" mem",
               (f"{ldsbw*1e-9:.0f} GB/s" if ldsbw < 1e15 else f"{ldsbw*1e-12:.0f} TB/s")+" lds"]
        if (metadata_str:=",".join([str(m) for m in (ei.arg['metadata'] or ())])): fmt.append(metadata_str)
        if isinstance(e, ProfileGraphEntry): fmt.append("(batched)")
        key = ei.key
    elif isinstance(e.name, TracingKey):
      name = e.name.display_name
      ref = next((v for k in e.name.keys if (v:=ref_map.get(k)) is not None), None)
    events.append(struct.pack("<IIIIfI", enum_str(name, scache), option(ref), option(key), st-start_ts, dur, enum_str("\n".join(fmt), scache)))
  return struct.pack("<BI", 0, len(events))+b"".join(events) if events else None
def encode_mem_free(key:int, ts:int, execs:list[ProfilePointEvent], scache:dict) -> bytes:
  """Binary-encode a buffer free: header <u8 tag=0, u32 ts, u32 key, u32 count>, then one
  <u32 run id, u32 display name, u8 buffer number, u8 mode> record per exec that used the
  buffer (mode: 2 = read/write, 1 = write-only, 0 = read-only)."""
  records = bytearray()
  for e in execs:
    # position of this buffer in the exec's buffer list
    num = next(i for i,b in enumerate(e.arg["bufs"]) if b == key)
    is_out, is_in = num in e.arg["outputs"], num in e.arg["inputs"]
    mode = 2 if (is_in and is_out) else (1 if is_out else 0)
    records += struct.pack("<IIBB", e.key, enum_str(e.arg["name"], scache), num, mode)
  return struct.pack("<BIII", 0, ts, key, len(execs)) + bytes(records)
def mem_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int, end_ts:int, peaks:list[int], dtype_size:dict[str, int],
               scache:dict[str, int]) -> bytes|None:
  """Binary-encode one device's memory timeline (alloc/free events) and its peak usage.

  Side effects: appends the device's peak to peaks and records dtype itemsizes into dtype_size.
  Returns None when the device had no point events.
  """
  peak, mem = 0, 0
  # live bytes per buffer key (renamed from `temp`, which shadowed tinygrad.helpers.temp imported at module level)
  live_bytes:dict[int, int] = {}
  events:list[bytes] = []
  buf_ei:dict[int, list[ProfilePointEvent]] = {}  # buffer key -> execs that touched it, drained on free
  for st,_,_,e in dev_events:
    if not isinstance(e, ProfilePointEvent): continue
    if e.name == "alloc":
      # clamp absurd element counts so the u64 encoding and UI stay sane
      safe_sz = min(1_000_000_000_000, e.arg["sz"])
      events.append(struct.pack("<BIIIQ", 1, int(e.ts)-start_ts, e.key, enum_str(e.arg["dtype"].name, scache), safe_sz))
      dtype_size.setdefault(e.arg["dtype"].name, e.arg["dtype"].itemsize)
      live_bytes[e.key] = nbytes = safe_sz*e.arg["dtype"].itemsize
      mem += nbytes
      if mem > peak: peak = mem
    if e.name == "exec" and e.arg["bufs"]:
      for b in e.arg["bufs"]: buf_ei.setdefault(b, []).append(e)
    if e.name == "free":
      events.append(encode_mem_free(e.key, int(e.ts) - start_ts, buf_ei.pop(e.key, []), scache))
      mem -= live_bytes.pop(e.key)
  # buffers never freed get a synthetic free at the end of the trace
  for k in live_bytes: events.append(encode_mem_free(k, end_ts-start_ts, buf_ei.pop(k, []), scache))
  peaks.append(peak)
  return struct.pack("<BIQ", 1, len(events), peak)+b"".join(events) if events else None
def load_sqtt(profile:list[ProfileEvent]) -> None:
  """Decode AMD SQTT (thread trace) events and append a "Counters" context to ctxs.

  On any decoder failure a single error step is appended instead, so the UI can
  always show why counters are missing.
  """
  from tinygrad.runtime.ops_amd import ProfileSQTTEvent
  if not (sqtt_events:=[e for e in profile if isinstance(e, ProfileSQTTEvent)]): return None
  def err(name:str, msg:str|None=None) -> None:
    step = {"name":name, "data":{"src":msg or traceback.format_exc()}, "depth":0, "query":f"/render?ctx={len(ctxs)}&step=0&fmt=counters"}
    return ctxs.append({"name":"Counters", "steps":[step]})
  try: from extra.sqtt.roc import decode
  except Exception: return err("DECODER IMPORT ISSUE")
  try: rctx = decode(profile)
  except Exception: return err("DECODER ERROR")
  if not rctx.inst_execs: return err("EMPTY SQTT OUTPUT", f"{len(sqtt_events)} SQTT events recorded, none got decoded")
  steps:list[dict] = []
  for name,waves in rctx.inst_execs.items():
    # BUGFIX: compare against None — ref index 0 is falsy, so the old `if (r:=...)` treated the first context as a miss
    if (r:=ref_map.get(name)) is not None: name = ctxs[r]["name"]
    steps.append({"name":name, "depth":0, "query":f"/render?ctx={len(ctxs)}&step={len(steps)}&fmt=counters",
                  "data":{"src":trace.keys[r].ret.src if r is not None else name, "lang":"cpp"}})
    # Idle: The total time gap between the completion of previous instruction and the beginning of the current instruction.
    # The idle time can be caused by:
    # * Arbiter loss
    # * Source or destination register dependency
    # * Instruction cache miss
    # Stall: The total number of cycles the hardware pipe couldn't issue an instruction.
    # Duration: Total latency in cycles, defined as "Stall time + Issue time" for gfx9 or "Stall time + Execute time" for gfx10+.
    for w in waves:
      rows, prev_instr = [], w.begin_time
      for e in w.insts:
        rows.append((e.inst, e.time, max(0, e.time-prev_instr), e.dur, e.stall, str(e.typ).split("_")[-1]))
        prev_instr = max(prev_instr, e.time + e.dur)
      summary = [{"label":"Total Cycles", "value":w.end_time-w.begin_time}, {"label":"CU", "value":w.cu},
                 {"label":"SIMD", "value":w.simd}]
      steps.append({"name":f"Wave {w.wave_id}", "depth":1, "query":f"/render?ctx={len(ctxs)}&step={len(steps)}&fmt=counters",
                    "data":{"rows":rows, "cols":["Instruction", "Clk", "Idle", "Duration", "Stall", "Type"], "summary":summary}})
  ctxs.append({"name":"Counters", "steps":steps})
def get_profile(profile:list[ProfileEvent]) -> bytes|None:
  """Encode the whole profile into the binary blob served at /get_profile.

  Layout: <u32 total duration, u64 max peak memory, u32 index length, u32 group count>
  + JSON index + length-prefixed per-group payloads. Returns None with no events.
  """
  # start by getting the time diffs
  for ev in profile:
    if isinstance(ev,ProfileDeviceEvent): device_ts_diffs[ev.device] = (ev.comp_tdiff, ev.copy_tdiff if ev.copy_tdiff is not None else ev.comp_tdiff)
  # load device specific counters
  device_decoders:dict[str, Callable[[list[ProfileEvent]], None]] = {}
  for device in device_ts_diffs:
    d = device.split(":")[0]
    if d == "AMD": device_decoders[d] = load_sqtt
  for fxn in device_decoders.values(): fxn(profile)
  # map events per device
  dev_events:dict[str, list[tuple[int, int, float, DevEvent]]] = {}
  markers:list[ProfilePointEvent] = []
  start_ts:int|None = None
  end_ts:int|None = None
  for ts,en,e in flatten_events(profile):
    dev_events.setdefault(e.device,[]).append((st:=int(ts), et:=int(en), float(en-ts), e))
    # track the global time window across all devices
    if start_ts is None or st < start_ts: start_ts = st
    if end_ts is None or et > end_ts: end_ts = et
    if isinstance(e, ProfilePointEvent) and e.name == "marker": markers.append(e)
  if start_ts is None: return None
  # return layout of per device events
  layout:dict[str, bytes|None] = {}
  scache:dict[str, int] = {}
  peaks:list[int] = []
  dtype_size:dict[str, int] = {}
  for k,v in dev_events.items():
    # events must be sorted by start time before encoding
    v.sort(key=lambda e:e[0])
    layout[k] = timeline_layout(v, start_ts, scache)
    layout[f"{k} Memory"] = mem_layout(v, start_ts, unwrap(end_ts), peaks, dtype_size, scache)
  # sort by the second word of the group name ('' for plain device rows), so timeline rows come before "Memory" rows
  groups = sorted(layout.items(), key=lambda x: '' if len(ss:=x[0].split(" ")) == 1 else ss[1])
  ret = [b"".join([struct.pack("<B", len(k)), k.encode(), v]) for k,v in groups if v is not None]
  index = json.dumps({"strings":list(scache), "dtypeSize":dtype_size, "markers":[{"ts":int(e.ts-start_ts), **e.arg} for e in markers]}).encode()
  return struct.pack("<IQII", unwrap(end_ts)-start_ts, max(peaks,default=0), len(index), len(ret))+index+b"".join(ret)
# ** Assembly analyzers
def get_llvm_mca(asm:str, mtriple:str, mcpu:str) -> dict:
  """Run llvm-mca on asm and return a table of per-instruction latency and HW resource pressure."""
  cmd = ["llvm-mca", "-skip-unsupported-instructions=parse-failure", "--json", "-", f"-mtriple={mtriple}", f"-mcpu={mcpu}"]
  # disassembly output can include headers / metadata, skip if llvm-mca can't parse those lines
  data = json.loads(subprocess.check_output(cmd, input=asm.encode()))
  region = data["CodeRegions"][0]
  resource_labels = [repr(lbl)[1:-1] for lbl in data["TargetInfo"]["Resources"]]
  rows:list = [[opcode] for opcode in region["Instructions"]]
  # column 2: scheduler latency estimate per instruction
  for info in region["InstructionInfoView"]["InstructionList"]: rows[info["Instruction"]].append(info["Latency"])
  # accumulate resource usage per (instruction index, resource index)
  instr_usage:dict[int, dict[int, int]] = {}
  for entry in region["ResourcePressureView"]["ResourcePressureInfo"]:
    per_instr = instr_usage.setdefault(entry["InstructionIndex"], {})
    per_instr[entry["ResourceIndex"]] = per_instr.get(entry["ResourceIndex"], 0) + entry["ResourceUsage"]
  # llvm-mca reports the totals row under index len(rows); pop it out as the summary
  summary = [{"idx":k, "label":resource_labels[k], "value":v} for k,v in instr_usage.pop(len(rows), {}).items()]
  max_usage = max([sum(v.values()) for i,v in instr_usage.items() if i<len(rows)], default=0)
  # column 3: [resource, usage, percent-of-max] triples for the pressure bars
  for i,usage in instr_usage.items(): rows[i].append([[k, v, (v/max_usage)*100] for k,v in usage.items()])
  return {"rows":rows, "cols":["Opcode", "Latency", {"title":"HW Resources", "labels":resource_labels}], "summary":summary}
def get_stdout(f: Callable) -> str:
  """Run f and return everything it wrote to stdout/stderr; a raised exception's traceback is captured too."""
  sink = io.StringIO()
  with redirect_stdout(sink), redirect_stderr(sink):
    try:
      f()
    except Exception:
      traceback.print_exc(file=sink)
  return sink.getvalue()
def get_render(i:int, j:int, fmt:str) -> dict|None:
  """Return the render payload for context i, step j in format fmt ("counters"/"uops"/"src"/anything else = asm).

  Returns None when the context didn't produce a ProgramSpec (nothing to render).
  """
  if fmt == "counters": return ctxs[i]["steps"][j]["data"]
  if not isinstance(prg:=trace.keys[i].ret, ProgramSpec): return None
  if fmt == "uops": return {"src":get_stdout(lambda: print_uops(prg.uops or [])), "lang":"txt"}
  if fmt == "src": return {"src":prg.src, "lang":"cpp"}
  # fallthrough: disassemble the compiled program
  compiler = Device[prg.device].compiler
  disasm_str = get_stdout(lambda: compiler.disassemble(compiler.compile(prg.src)))
  from tinygrad.runtime.support.compiler_cpu import llvm, LLVMCompiler
  if isinstance(compiler, LLVMCompiler):
    # LLVM targets additionally get per-instruction scheduling analysis via llvm-mca
    mtriple = ctypes.string_at(llvm.LLVMGetTargetMachineTriple(tm:=compiler.target_machine)).decode()
    mcpu = ctypes.string_at(llvm.LLVMGetTargetMachineCPU(tm)).decode()
    ret = get_llvm_mca(disasm_str, mtriple, mcpu)
  else: ret = {"src":disasm_str, "lang":"x86asm"}
  return ret
# ** HTTP server
def get_int(query:dict[str, list[str]], k:str) -> int:
  # parse_qs values are lists; take the first entry, defaulting to 0 when the key is absent
  return int(query.get(k, ["0"])[0])
class Handler(BaseHTTPRequestHandler):
  """HTTP handler for the viz UI: serves static assets, JSON APIs, the profile blob, and SSE rewrite streams."""
  def do_GET(self):
    ret, status_code, content_type = b"", 200, "text/html"
    if (url:=urlparse(self.path)).path == "/":
      with open(os.path.join(os.path.dirname(__file__), "index.html"), "rb") as f: ret = f.read()
    # static assets; the '/..' check blocks path traversal out of the server directory
    elif self.path.startswith(("/assets/", "/js/")) and '/..' not in self.path:
      try:
        with open(os.path.join(os.path.dirname(__file__), self.path.strip('/')), "rb") as f: ret = f.read()
        if url.path.endswith(".js"): content_type = "application/javascript"
        if url.path.endswith(".css"): content_type = "text/css"
      except FileNotFoundError: status_code = 404
    elif (query:=parse_qs(url.query)):
      if url.path == "/render":
        render_src = get_render(get_int(query, "ctx"), get_int(query, "step"), query["fmt"][0])
        ret, content_type = json.dumps(render_src).encode(), "application/json"
      else:
        # any other path with a query streams the rewrite steps for (ctx, idx) as server-sent events
        try: return self.stream_json(get_full_rewrite(trace.rewrites[i:=get_int(query, "ctx")][get_int(query, "idx")], i))
        except (KeyError, IndexError): status_code = 404
    elif url.path == "/ctxs": ret, content_type = json.dumps(ctxs).encode(), "application/json"
    elif url.path == "/get_profile" and profile_ret: ret, content_type = profile_ret, "application/octet-stream"
    else: status_code = 404
    # send response
    self.send_response(status_code)
    self.send_header('Content-Type', content_type)
    self.send_header('Content-Length', str(len(ret)))
    self.end_headers()
    return self.wfile.write(ret)
  def stream_json(self, source:Generator):
    """Stream each item of source as an SSE "data:" message, ending with a literal END message."""
    try:
      self.send_response(200)
      self.send_header("Content-Type", "text/event-stream")
      self.send_header("Cache-Control", "no-cache")
      self.end_headers()
      for r in source:
        self.wfile.write(f"data: {json.dumps(r)}\n\n".encode("utf-8"))
        # flush per message so the client renders steps as they're computed
        self.wfile.flush()
      self.wfile.write("data: END\n\n".encode("utf-8"))
    # pass if client closed connection
    except (BrokenPipeError, ConnectionResetError): return
# ** main loop
def reloader():
  """Poll this file's mtime and re-exec the whole server process when it changes (dev hot-reload)."""
  mtime = os.stat(__file__).st_mtime
  while not stop_reloader.is_set():
    if mtime != os.stat(__file__).st_mtime:
      print("reloading server...")
      # execv replaces the current process image; nothing after this runs
      os.execv(sys.executable, [sys.executable] + sys.argv)
    time.sleep(0.1)
T = TypeVar("T")
def load_pickle(path:pathlib.Path, default:T) -> T:
  """Unpickle path, returning default when the file doesn't exist.

  NOTE: pickle is only safe here because the file is a locally-written trace.
  """
  if path.exists():
    with path.open("rb") as f:
      return pickle.load(f)
  return default
# NOTE: using HTTPServer forces a potentially slow socket.getfqdn
# allow_reuse_address sets SO_REUSEADDR so a quick restart doesn't fail with "address already in use"
class TCPServerWithReuse(socketserver.TCPServer): allow_reuse_address = True
if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument('--kernels', type=pathlib.Path, help='Path to kernels', default=pathlib.Path(temp("rewrites.pkl", append_user=True)))
  parser.add_argument('--profile', type=pathlib.Path, help='Path to profile', default=pathlib.Path(temp("profile.pkl", append_user=True)))
  args = parser.parse_args()
  # fail fast if something is already listening on the viz port
  with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    if s.connect_ex(((HOST:="http://127.0.0.1").replace("http://", ""), PORT:=getenv("PORT", 8000))) == 0:
      raise RuntimeError(f"{HOST}:{PORT} is occupied! use PORT= to change.")
  stop_reloader = threading.Event()
  multiprocessing.current_process().name = "VizProcess" # disallow opening of devices
  st = time.perf_counter()
  print("*** viz is starting")
  # load the saved rewrite trace and profile; these populate the module-level ctxs/trace/profile_ret used by Handler
  ctxs = get_rewrites(trace:=load_pickle(args.kernels, default=RewriteTrace([], [], {})))
  profile_ret = get_profile(load_pickle(args.profile, default=[]))
  server = TCPServerWithReuse(('', PORT), Handler)
  reloader_thread = threading.Thread(target=reloader)
  reloader_thread.start()
  print(f"*** started viz on {HOST}:{PORT}")
  print(colored(f"*** ready in {(time.perf_counter()-st)*1e3:4.2f}ms", "green"), flush=True)
  if len(getenv("BROWSER", "")) > 0: webbrowser.open(f"{HOST}:{PORT}")
  try: server.serve_forever()
  except KeyboardInterrupt:
    print("*** viz is shutting down...")
    stop_reloader.set()