openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

120 lines
4.9 KiB

#!/usr/bin/env python3
# compare kernels created by HEAD against master
import os, multiprocessing, logging, pickle, sqlite3, difflib, functools, warnings
from typing import Callable, cast
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
from tinygrad.engine.schedule import create_schedule_with_vars
from tinygrad.codegen.kernel import Kernel, Opt
from tinygrad.renderer import Renderer
from tinygrad.ops import UOp
# *** process replay settings
# internal
# suffix shared by every replay table queried below
TABLE_NAME = f"process_replay_{VERSION}"
# rows fetched by one diff() call, and the offset stride used by _pmap()
PAGE_SIZE = getenv("PAGE_SIZE", 100)
# diff() stops a page once its changed-row count exceeds this value
MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20)
REF = os.getenv("GITHUB_REF_NAME", "")
# force both replay flags to "0" for this process
# NOTE(review): presumably keeps the replay run itself from being captured -- confirm
os.environ.update({"RUN_PROCESS_REPLAY": "0", "CAPTURE_PROCESS_REPLAY": "0"})
logging.basicConfig(level=logging.INFO, format="%(message)s")
# set by diff() once too many rows of a page have changed
early_stop = multiprocessing.Event()
# trunc_log() cuts any repr that is longer than this many lines
MAX_LINES = 500
def trunc_log(x):
    """Log ``repr(x)`` at INFO level, keeping at most ``MAX_LINES`` lines.

    When the repr is longer than that, the kept lines are followed by a single
    warning line that states the original line count.
    """
    repr_lines = repr(x).splitlines()
    if len(repr_lines) > MAX_LINES:
        notice = f"WARN: truncated string with {len(repr_lines)} lines"
        repr_lines = repr_lines[:MAX_LINES] + [notice]
    logging.info("\n".join(repr_lines))
# user config
flag = "[pr]"
# an unset variable falls back to the flag itself, so asserting is the default when neither is provided
_asserts_requested = flag in os.getenv("COMMIT_MESSAGE", flag) or flag in os.getenv("PR_TITLE", flag)
# ASSERT_PROCESS_REPLAY=0 switches the hard failure mode off regardless of the flag
ASSERT_DIFF = int(_asserts_requested) if getenv("ASSERT_PROCESS_REPLAY", 1) else 0
k = "[skip_process_replay]"
# master is always skipped; other refs are skipped only when the commit message or PR title asks for it
SKIP_PROCESS_REPLAY = REF == "master" or k in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")
class ProcessReplayWarning(Warning):
    """Category for the warnings emitted when a replay row fails or differs.

    diff() can escalate this category to an error through the warnings filter.
    """
# *** recreators
def recreate_sched(big_sink:UOp) -> list[UOp]:
    """Re-run scheduling for ``big_sink`` and return the ``ast`` of every resulting schedule item."""
    # the call returns three values; only the first one is used here
    schedule, _unused_a, _unused_b = create_schedule_with_vars(big_sink)
    return [item.ast for item in schedule]
def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:list[Opt], name:str) -> str:
    """Rebuild a kernel from ``ast``, re-apply ``applied_opts`` in order, and return what the renderer produces."""
    kernel = Kernel(ast, opts=opts)
    for applied in applied_opts:
        kernel.apply_opt(applied)
    # NOTE: replay with the captured renderer, not the one in master
    render = kernel.opts.render
    program = kernel.to_program(name)
    return render(cast(list, program.uops))
# *** diff a "good" recreation against the generated version
def diff(offset:int, name:str, fxn:Callable) -> None:
    """Replay one page of captured rows and report every row whose recreation differs.

    Reads up to ``PAGE_SIZE`` rows starting at ``offset`` from the ``{name}_{TABLE_NAME}`` table. Each row is
    unpickled into ``(*fxn_args, ctx_vars, captured_output)``; ``fxn`` is re-run on the arguments under the
    captured context variables and its result is compared, as a string, with the captured output.

    Args:
      offset: index of the first row of the page.
      name: table prefix; "schedule" is currently exempt from the hard failure mode.
      fxn: recreator called with the unpickled arguments; a ``None`` result skips the row.

    Raises:
      ProcessReplayWarning: when ``ASSERT_DIFF`` is set and ``name`` is not "schedule", the first failure or
        detected change is raised instead of only being warned about.

    NOTE(review): ``early_stop`` is a module-level Event and ``_pmap`` runs this in spawned workers; confirm
    that a ``set()`` in one worker is actually observed by the others.
    """
    # TODO: add this assert back for schedule
    if ASSERT_DIFF and name != "schedule": warnings.filterwarnings("error", category=ProcessReplayWarning)
    if early_stop.is_set(): return None
    conn = db_connection()
    cur = conn.cursor()
    try:
        cur.execute(f"SELECT val FROM '{name}_{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
        changed = 0
        for row in cur.fetchall():
            # NOTE(review): `changed` is a per-page row count compared against a percentage; the two only
            # coincide when PAGE_SIZE is 100.
            if changed > MAX_DIFF_PCT:
                warnings.warn(f"detected changes in over {MAX_DIFF_PCT}% of {name}s. skipping further diff generation.")
                early_stop.set()
                break
            # try unpickle
            # NOTE: pickle.loads can execute arbitrary code; the table must only ever hold trusted captures
            try: args = pickle.loads(row[0])
            except Exception as e:
                changed += 1
                warnings.warn(f"FAILED TO UNPICKLE OBJECTS {e}", ProcessReplayWarning)
                continue
            # try recreate
            # bound before the try so the handler below never sees an unbound or stale value
            ctx_vars = {}
            try:
                ctx_vars = {k:v.value for k,v in args[-2].items() if k != "DEBUG" and (var:=ContextVar._cache.get(k)) is not None and var.value != v.value}
                with Context(**ctx_vars): good = fxn(*args[:-2])
                if good is None: continue
            except Exception as e:
                changed += 1
                # log the inputs before warning: the warning may be escalated to an exception
                if ctx_vars: logging.info(ctx_vars)
                for x in args[:-2]: trunc_log(x)
                warnings.warn(f"FAILED TO RECREATE KERNEL {e}", ProcessReplayWarning)
                continue
            # diff kernels
            # explicit comparison instead of `assert`, which is stripped when Python runs with -O
            captured_str, good_str = str(args[-1]), str(good)
            if captured_str == good_str: continue
            changed += 1
            if ctx_vars: logging.info(ctx_vars)
            for x in args[:-2]: trunc_log(x)
            changes = list(difflib.unified_diff(good_str.splitlines(), captured_str.splitlines()))
            logging.info("\n".join(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None) for line in changes))
            warnings.warn("PROCESS REPLAY DETECTED CHANGE", ProcessReplayWarning)
        conn.commit()
    finally:
        # also reached when an escalated warning propagates out of the loop
        cur.close()
# *** generic runner for executing fxn across all rows of a table in parallel
def _pmap(name:str, fxn:Callable, maxtasksperchild:int=16) -> None:
    """Run ``diff`` over every page of the ``{name}_{TABLE_NAME}`` table with a pool of spawned workers.

    The table is split into pages of ``PAGE_SIZE`` rows and each page offset becomes one ``diff`` task. If the
    table cannot be queried, a ``ProcessReplayWarning`` is emitted and nothing is replayed.

    Args:
      name: table prefix, forwarded to ``diff``.
      fxn: recreator forwarded to ``diff``.
      maxtasksperchild: passed straight to the pool constructor.
    """
    conn = db_connection()
    cur = conn.cursor()
    try:
        row_count = cur.execute(f"select count(*) from '{name}_{TABLE_NAME}'").fetchone()[0]
    except sqlite3.OperationalError:
        warnings.warn(f"{name}_{TABLE_NAME} isn't accessible in master, did DB_VERSION change?", ProcessReplayWarning)
        return None
    else:
        conn.commit()
    finally:
        # runs on the early-return path too, which previously left the cursor open
        cur.close()
    inputs = list(range(0, row_count, PAGE_SIZE))
    with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=maxtasksperchild) as pool:
        # drain the iterator so every page is processed; results are all None
        list(tqdm(pool.imap_unordered(functools.partial(diff, name=name, fxn=fxn), inputs), total=len(inputs)))
        pool.close()
        pool.join()
        # leaving the with-block calls pool.terminate(), so no explicit call is needed
# *** main loop
if __name__ == "__main__":
    if SKIP_PROCESS_REPLAY:
        logging.info("skipping process replay.")
        # the `exit` builtin is injected by the site module and is not meant for programs
        raise SystemExit(0)
    print(f"running process replay with {ASSERT_DIFF=}")
    # replay the schedule table first, then the kernel table
    for name,fxn in [("schedule", recreate_sched), ("kernel", recreate_kernel)]:
        logging.info(f"***** {name} diff")
        try: _pmap(name, fxn)
        except Exception as e:
            # in assert mode any failure aborts the run; a bare raise keeps the original traceback
            if ASSERT_DIFF: raise
            logging.error(f"{name} diff err {e}")