openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

110 lines
4.9 KiB

#!/usr/bin/env python3
# compare kernels created by HEAD against master
import os, multiprocessing, logging, pickle, sqlite3, difflib, warnings, itertools
from typing import Callable, Any
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm, to_function_name
from tinygrad.engine.grouper import get_kernelize_map
from tinygrad.codegen.kernel import Kernel
from tinygrad.uop.ops import UOp, Ops
# *** process replay settings
# internal
REF = os.getenv("GITHUB_REF_NAME", "")  # git ref under test; the value "master" turns replay off further down
PAGE_SIZE = getenv("PAGE_SIZE", 100)  # number of captured rows each diff() call reads (LIMIT of its SELECT)
MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20)  # change threshold at which diff() stops generating diffs
MAX_LINES = 500  # trunc_log keeps at most this many lines of a repr
TABLE_NAME = f"process_replay_{VERSION}"  # captured rows live in a table whose name includes tinygrad's VERSION
# NOTE(review): presumably stops this run (and spawned workers, which inherit the environment) from capturing new rows — confirm
os.environ["CAPTURE_PROCESS_REPLAY"] = "0"
logging.basicConfig(level=logging.INFO, format="%(message)s")
# set by diff() to skip remaining work; NOTE: spawned workers re-import this module and each get their own Event here
early_stop = multiprocessing.Event()
def trunc_log(x):
  """Log repr(x) at INFO level, keeping only the first MAX_LINES lines plus a note with the original line count."""
  rendered = repr(x).splitlines()
  n_lines = len(rendered)
  if n_lines > MAX_LINES:
    rendered = rendered[:MAX_LINES] + [f"WARN: truncated string with {n_lines} lines"]
  logging.info("\n".join(rendered))
# user config
# ASSERT_DIFF is 1 when "[pr]" appears in COMMIT_MESSAGE or PR_TITLE (an unset variable defaults to the flag itself,
# so it counts as a match); ASSERT_PROCESS_REPLAY=0 forces it back to 0
flag = "[pr]"
ASSERT_DIFF = int(any(flag in os.getenv(var, flag) for var in ("COMMIT_MESSAGE", "PR_TITLE")))
if not getenv("ASSERT_PROCESS_REPLAY", 1): ASSERT_DIFF = 0
# replay is skipped on master, or when "[skip_process_replay]" appears in COMMIT_MESSAGE or PR_TITLE
k = "[skip_process_replay]"
SKIP_PROCESS_REPLAY = REF == "master" or any(k in os.getenv(var, "") for var in ("COMMIT_MESSAGE", "PR_TITLE"))
class ProcessReplayWarning(Warning):
  """Warning category used for every process replay diagnostic; diff() turns it into an error when ASSERT_DIFF is set."""
# *** replay the function and convert return values to string
def replay_kernelize(ret:dict[UOp, UOp], big_sink:UOp) -> tuple[str, str, tuple[Any, ...]]:
  """Re-run get_kernelize_map on a captured big_sink and stringify the kernels it produces.

  Returns (kernels recomputed now, kernels from the captured ret[big_sink], (big_sink,)). Each side is rendered as
  a "<n> kernels" header line followed by the repr of every KERNEL uop's ast, in toposort() order.
  """
  def kernels_str(sink:UOp) -> str:
    kernel_asts = [repr(u.arg.ast) for u in sink.toposort() if u.op is Ops.KERNEL]
    return "\n".join([f"{len(kernel_asts)} kernels"] + kernel_asts)
  # restart UNIQUE numbering just past the largest id already present in big_sink
  # (presumably so uops created during the replay don't reuse existing ids — confirm)
  highest_unique = max([u.arg for u in big_sink.toposort() if u.op is Ops.UNIQUE], default=0)
  UOp.unique_num = itertools.count(highest_unique + 1)
  recomputed = big_sink.substitute(get_kernelize_map(big_sink))
  return kernels_str(recomputed), kernels_str(ret[big_sink]), (big_sink,)
def replay_linearize(k:Kernel, _:Kernel, name_override=None, ast_transform=None) -> tuple[str, str, tuple[Any, ...]]:
  """Linearize a fresh copy of the captured Kernel `k` and render both it and `k`.

  Returns (render of the re-linearized copy, render of `k` as captured, (k.ast, k.opts, k.applied_opts)).
  A renderer raising NotImplementedError yields "" for that side.
  """
  def render(kern:Kernel) -> str:
    try:
      return kern.opts.render(kern.uops)
    except NotImplementedError:
      return ""  # NULL backend doesn't have a renderer, this is okay
  # Kernel keeps optimization parameters (beyond applied_opts) in its state, so linearize a copy rather than `k`
  # itself; copy() returns a fresh instance, which is sufficient for process replay
  fresh = k.copy()
  fresh.linearize(name_override=name_override or to_function_name(k.name), ast_transform=ast_transform)
  return render(fresh), render(k), (k.ast, k.opts, k.applied_opts)
# captured function name -> replayer; each replayer returns (recomputed string, captured string, metadata to log on a diff)
replayers: dict[str, Callable[..., tuple[str, str, tuple[Any, ...]]]] = dict(
  get_kernelize_map=replay_kernelize,
  linearize=replay_linearize,
)
# *** run replayers on captured rows and print diffs
def diff(offset:int) -> None:
  """Replay one page of captured rows and log every mismatch.

  Reads up to PAGE_SIZE rows of TABLE_NAME starting at `offset`. Each row is re-run through its entry in
  `replayers`, and a unified diff is logged whenever the recomputed string differs from the captured one.
  Rows whose name has no replayer are skipped.

  With ASSERT_DIFF set, ProcessReplayWarning is escalated to an error, so the first detected change or replay
  failure propagates out of this call. Otherwise, once changes plus failures exceed MAX_DIFF_PCT percent of a
  page, early_stop is set and the rest of the page is skipped. Any later call that sees early_stop set returns
  immediately.
  """
  if ASSERT_DIFF: warnings.filterwarnings("error", category=ProcessReplayWarning)
  if early_stop.is_set(): return None
  conn = db_connection()
  cur = conn.cursor()
  try:
    cur.execute(f"SELECT val FROM '{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
    changed = 0  # rows in this page whose output differed or whose replay raised
    for row in cur.fetchall():
      # MAX_DIFF_PCT is a percentage of the page; scale by PAGE_SIZE instead of comparing the raw count
      # (identical to the old `changed > MAX_DIFF_PCT` check when PAGE_SIZE == 100)
      if changed * 100 > MAX_DIFF_PCT * PAGE_SIZE:
        warnings.warn(f"detected changes in over {MAX_DIFF_PCT}%. skipping further diff generation.", ProcessReplayWarning)
        early_stop.set()
        break
      try:
        name, args, kwargs, ctx_vals, loc, ret = pickle.loads(row[0])
        # replay under the captured ContextVar values that differ from the current ones (DEBUG excluded)
        ctx_vars = {k:v.value for k,v in ctx_vals.items() if k != "DEBUG" and (var:=ContextVar._cache.get(k)) is not None and var.value != v.value}
        if (replayer:=replayers.get(name)) is None: continue
        with Context(**ctx_vars): good, compare, metadata = replayer(ret, *args, **kwargs)
        if good != compare:
          # count detected changes (not only replay failures) so the MAX_DIFF_PCT cutoff applies to them
          changed += 1
          for m in metadata: trunc_log(m)
          logging.info(loc)
          for line in difflib.unified_diff(good.splitlines(), compare.splitlines()):
            logging.info(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None))
          if ctx_vars: logging.info(ctx_vars)
          warnings.warn("PROCESS REPLAY DETECTED CHANGE", ProcessReplayWarning)
      except Exception as e:
        # with ASSERT_DIFF the warning below re-raises (this also covers the DETECTED CHANGE warning above),
        # so a double count here is irrelevant; otherwise it only reports the failure
        changed += 1
        warnings.warn(e, ProcessReplayWarning)
  finally:
    # runs even when an escalated warning propagates, so the cursor is never leaked
    conn.commit()
    cur.close()
# *** main loop
def _init_worker(stop_event) -> None:
  """Pool initializer: make `stop_event` this worker's module-level early_stop.

  The pool uses the "spawn" start method, so each worker re-imports this module and would otherwise get its own
  private early_stop Event. Installing one shared Event lets diff() in any worker stop the others. This must be
  defined at module level, outside the __main__ guard, so spawned workers can unpickle it by name.
  """
  global early_stop
  early_stop = stop_event

if __name__ == "__main__":
  import sys
  if SKIP_PROCESS_REPLAY:
    logging.info("skipping process replay.")
    sys.exit(0)
  conn = db_connection()
  cur = conn.cursor()
  try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
  except sqlite3.OperationalError:
    warnings.warn(f"{TABLE_NAME} isn't accessible in master, did DB_VERSION change?", ProcessReplayWarning)
    sys.exit(int(ASSERT_DIFF))
  finally:
    conn.commit()
    cur.close()
  logging.info(f"running process replay with {ASSERT_DIFF=}")
  ctx = multiprocessing.get_context("spawn")
  # the shared Event must come from the same (spawn) context as the pool it is handed to
  stop_event = ctx.Event()
  inputs = list(range(0, row_count, PAGE_SIZE))
  with ctx.Pool(multiprocessing.cpu_count(), initializer=_init_worker, initargs=(stop_event,)) as pool:
    list(tqdm(pool.imap_unordered(diff, inputs), total=len(inputs)))
    pool.close()
    pool.join()