#!/usr/bin/env python3
# compare kernels created by HEAD against master
import os, multiprocessing, logging, pickle, sqlite3, difflib, warnings, itertools
from typing import Callable, Any
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm, to_function_name
from tinygrad.engine.grouper import get_kernelize_map
from tinygrad.codegen.kernel import Kernel
from tinygrad.uop.ops import UOp, Ops

# *** process replay settings

# internal
PAGE_SIZE = getenv("PAGE_SIZE", 100)
REF = os.getenv("GITHUB_REF_NAME", "")
MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20)
TABLE_NAME = f"process_replay_{VERSION}"
os.environ["CAPTURE_PROCESS_REPLAY"] = "0"
early_stop = multiprocessing.Event()
logging.basicConfig(level=logging.INFO, format="%(message)s")
MAX_LINES = 500
def trunc_log(x):
  if len(lines:=repr(x).splitlines()) > MAX_LINES: lines = lines[:MAX_LINES]+[f"WARN: truncated string with {len(lines)} lines"]
  logging.info("\n".join(lines))

# user config
ASSERT_DIFF = int((flag:="[pr]") in os.getenv("COMMIT_MESSAGE", flag) or flag in os.getenv("PR_TITLE", flag))
if not getenv("ASSERT_PROCESS_REPLAY", 1): ASSERT_DIFF = 0
SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")
if REF == "master": SKIP_PROCESS_REPLAY = True
class ProcessReplayWarning(Warning): pass

# *** replay the function and convert return values to string

def replay_kernelize(ret:dict[UOp, UOp], big_sink:UOp) -> tuple[str, str, tuple[Any, ...]]:
  UOp.unique_num = itertools.count(max([u.arg for u in big_sink.toposort() if u.op is Ops.UNIQUE], default=0)+1)
  new_sink = big_sink.substitute(get_kernelize_map(big_sink))
  def to_str(ret:UOp) -> str:
    asts = [repr(u.arg.ast) for u in ret.toposort() if u.op is Ops.KERNEL]
    return "\n".join([f"{len(asts)} kernels", *asts])
  return to_str(new_sink), to_str(ret[big_sink]), (big_sink,)

def replay_linearize(k:Kernel, _:Kernel, name_override=None, ast_transform=None) -> tuple[str, str, tuple[Any, ...]]:
  # create a copy because the Kernel class contains optimization parameters (other than applied_opts) in its state
  # this should be made fully functional. It's fine for process replay since copy returns a fresh instance
  k2 = k.copy()
  k2.linearize(name_override=name_override or to_function_name(k.name), ast_transform=ast_transform)
  def to_str(ret:Kernel) -> str:
    try: return ret.opts.render(ret.uops)
    except NotImplementedError: return "" # NULL backend doesn't have a renderer, this is okay
  return to_str(k2), to_str(k), (k.ast, k.opts, k.applied_opts)

replayers: dict[str, Callable[..., tuple[str, str, tuple[Any, ...]]]] = {"get_kernelize_map":replay_kernelize, "linearize":replay_linearize}

# *** run replayers on captured rows and print diffs

def diff(offset:int) -> None:
  if ASSERT_DIFF: warnings.filterwarnings("error", category=ProcessReplayWarning)
  if early_stop.is_set(): return None
  conn = db_connection()
  cur = conn.cursor()
  cur.execute(f"SELECT val FROM '{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
  changed = 0
  for row in cur.fetchall():
    if changed > MAX_DIFF_PCT:
      warnings.warn(f"detected changes in over {MAX_DIFF_PCT}%. skipping further diff generation.", ProcessReplayWarning)
      early_stop.set()
      break
    try:
      name, args, kwargs, ctx_vals, loc, ret = pickle.loads(row[0])
      ctx_vars = {k:v.value for k,v in ctx_vals.items() if k != "DEBUG" and (var:=ContextVar._cache.get(k)) is not None and var.value != v.value}
      if (replayer:=replayers.get(name)) is None: continue
      with Context(**ctx_vars):
        good, compare, metadata = replayer(ret, *args, **kwargs)
        if good != compare:
          for m in metadata: trunc_log(m)
          logging.info(loc)
          for line in difflib.unified_diff(good.splitlines(), compare.splitlines()):
            logging.info(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None))
          if ctx_vars: logging.info(ctx_vars)
          warnings.warn("PROCESS REPLAY DETECTED CHANGE", ProcessReplayWarning)
    except Exception as e:
      changed += 1
      warnings.warn(e, ProcessReplayWarning)
  conn.commit()
  cur.close()

# *** main loop

if __name__ == "__main__":
  if SKIP_PROCESS_REPLAY:
    logging.info("skipping process replay.")
    exit(0)

  conn = db_connection()
  cur = conn.cursor()
  try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
  except sqlite3.OperationalError:
    warnings.warn(f"{TABLE_NAME} isn't accessible in master, did DB_VERSION change?", ProcessReplayWarning)
    exit(int(ASSERT_DIFF))
  finally:
    conn.commit()
    cur.close()

  logging.info(f"running process replay with {ASSERT_DIFF=}")

  with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count()) as pool:
    inputs = list(range(0, row_count, PAGE_SIZE))
    list(tqdm(pool.imap_unordered(diff, inputs), total=len(inputs)))
    pool.close()
    pool.join()
    pool.terminate()
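
# --- illustrative sketch, kept as comments so nothing below runs with the script ---
# Each entry in `replayers` maps a captured function name to a callable that re-runs the
# function on HEAD and returns (output rendered on HEAD, output rendered from the capture,
# metadata to log on mismatch). A hypothetical replayer for some captured pass named
# "my_pass" could follow the same shape; `my_pass` and `replay_my_pass` are assumptions
# for illustration, not part of this file:
#
#   def replay_my_pass(ret, *args, **kwargs) -> tuple[str, str, tuple[Any, ...]]:
#     new_ret = my_pass(*args, **kwargs)        # re-run the captured call on HEAD
#     return repr(new_ret), repr(ret), (args,)  # two strings to diff, plus metadata
#
#   replayers["my_pass"] = replay_my_pass
#
# diff() then unpickles each captured row, looks the name up in `replayers`, replays it under
# the captured context vars, and logs a colored unified diff when the two strings disagree.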