# stuff needed to unpack a kernel
# NOTE: several of the names below look unused, but the serialized kernel strings
# handed to eval() further down are evaluated in this module's namespace, so
# (presumably) they reference these names -- do not remove them.
from tinygrad import Variable
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.ops import UOp, Ops, KernelInfo
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
inf, nan = float('inf'), float('nan')
UOps = Ops

# kernel unpacker
from tinygrad.codegen.kernel import Kernel


def ast_str_to_ast(ast_str: str) -> UOp:
    """Evaluate a serialized AST string and return the resulting object.

    SECURITY: this calls ``eval`` on ``ast_str``; only pass strings that come
    from a trusted source.
    """
    return eval(ast_str)


def ast_str_to_lin(ast_str: str, opts=None):
    """Build a ``Kernel`` from a serialized AST string, forwarding ``opts``."""
    return Kernel(ast_str_to_ast(ast_str), opts=opts)


def kern_str_to_lin(kern_str: str, opts=None):
    """Rebuild a ``Kernel`` from a serialized ``(ast, applied_opts)`` pair.

    ``kern_str`` must evaluate to an iterable of exactly two items: the AST and
    the iterable of optimizations, which are re-applied in order.

    SECURITY: this calls ``eval`` on ``kern_str``; only pass strings that come
    from a trusted source.
    """
    (ast, applied_opts,) = eval(kern_str)
    k = Kernel(ast, opts=opts)
    for opt in applied_opts:
        k.apply_opt(opt)
    return k


# load worlds, a dataset of about 12k kernels
import gzip
from pathlib import Path
import random
from tinygrad.helpers import dedup, DEBUG


def load_worlds(filter_reduce=True, filter_noimage=True, filter_novariable=True):
    """Load, filter and shuffle the kernel AST strings from ``datasets/sops.gz``.

    The dataset path is resolved relative to the grandparent directory of this
    file. The file is read as UTF-8, split on newlines (one AST string per
    line) and passed through ``dedup``.

    Args:
        filter_reduce: keep only strings containing ``"REDUCE_AXIS"``.
        filter_noimage: drop strings containing ``"dtypes.image"``.
        filter_novariable: drop strings containing ``"DEFINE_VAR"``.

    Returns:
        The filtered list of AST strings, shuffled with a fixed seed.

    Side effects:
        Re-seeds the global ``random`` module state with 1337.
    """
    fn = Path(__file__).parent.parent / "datasets/sops.gz"
    # use a context manager so the gzip handle is closed deterministically
    with gzip.open(fn) as f:
        raw = f.read()
    ast_strs = dedup(raw.decode('utf-8').strip().split("\n"))
    assert len(ast_strs) > 5000, f"dataset size = {len(ast_strs)} is too small"
    if DEBUG >= 1:
        print(f"loaded {len(ast_strs)=} before filters")
    if filter_reduce:
        ast_strs = [x for x in ast_strs if "REDUCE_AXIS" in x]
    if filter_noimage:
        ast_strs = [x for x in ast_strs if "dtypes.image" not in x]
    if filter_novariable:
        ast_strs = [x for x in ast_strs if "DEFINE_VAR" not in x]
    if DEBUG >= 1:
        print(f"loaded {len(ast_strs)=} after filters {filter_reduce=}, {filter_noimage=}, {filter_novariable=}")
    # fixed seed so every caller sees the same ordering
    random.seed(1337)
    random.shuffle(ast_strs)
    return ast_strs


def assert_same_lin(l1, l2):
    """Assert that two kernels have the same colored shape and equal ``sts``.

    The ``sts`` entries are compared pairwise with ``zip``, so the comparison
    stops at the shorter of the two sequences.
    """
    assert l1.colored_shape() == l2.colored_shape()
    assert all(x == y for x, y in zip(l1.sts, l2.sts))


# get features
import math

MAX_DIMS = 16
MAX_BUFS = 9


def lin_to_feats(lin: Kernel, use_sts=True):
    """Turn a ``Kernel`` into a fixed-length numeric feature vector.

    Layout:
      * 2 generic values: ``lin.upcasted`` and ``lin.local_dims``.
      * ``MAX_DIMS`` slots of 17 values each (10 shape features plus a 7-way
        one-hot color), zero-padded for missing dimensions.
      * if ``use_sts``: ``MAX_BUFS`` slots of ``3 + 5*MAX_DIMS`` values each,
        describing the deduplicated entries of ``lin.sts``, zero-padded.

    Args:
        lin: the kernel to featurize; must have ``shape_len < MAX_DIMS``.
        use_sts: whether to append the per-``sts`` features.

    Returns:
        A list of length 1021 when ``use_sts`` is true, otherwise 274. The
        shape part is converted to ``float``; values appended for the ``sts``
        part are left as the bools/ints/floats they were computed as.

    Raises:
        AssertionError: on too many dims, too many distinct ``sts`` entries,
            or an unexpected final length.
    """
    assert lin.shape_len < MAX_DIMS, "too many dims"

    all_colors = ["blue", "cyan", "white", "green", "red", "magenta", "yellow"]
    lc = [all_colors.index(x) for x in lin.colors()]

    ret = []
    # before, some generic linearizer stuff
    ret.append(lin.upcasted)
    ret.append(lin.local_dims)

    # first, the full shape, including the colors
    for s, os, c in zip(lin.full_shape, lin.output_shape, lc):
        if isinstance(s, UOp):
            # non-constant dimension: flag it and zero the 9 numeric features
            ret.append(False)
            ret += [0]*9
        else:
            ret.append(True)
            ret.append(math.log2(s))
            ret.append(min(33, s))
            ret.append(math.log2(os))
            ret.append(min(33, os))
            ret.append(s % 2 == 0)
            ret.append(s % 3 == 0)
            ret.append(s % 4 == 0)
            ret.append(s % 8 == 0)
            ret.append(s % 16 == 0)
        # one-hot encoding of the dimension's color
        cc = [0]*7
        cc[c] = 1
        ret += cc
    # pad unused dimension slots (10 shape features + 7 colors = 17 each)
    ret += [0] * (17*(MAX_DIMS-len(lin.full_shape)))
    ret = [float(x) for x in ret]

    if use_sts:
        my_sts = dedup([(x.shape == lin.full_shape, x.real_strides(),
                         any(v.mask is not None for v in x.views), len(x.views)) for x in lin.sts])
        assert len(my_sts) < MAX_BUFS
        sts_len = 3 + 5*MAX_DIMS
        for s in my_sts:
            ret.append(s[0])  # reduce
            ret.append(s[2])  # has mask
            ret.append(s[3])  # len views
            for d in s[1]:
                ret.append(d is None)
                ret.append(d == 0)
                ret.append(d == 1)
                ret.append(min(33, d) if d is not None else -1)
                if d is not None and d >= 1:
                    ret.append(math.log2(d))
                else:
                    ret.append(-1)
            # pad unused stride slots (5 features per dimension)
            ret += [0] * (5*(MAX_DIMS - len(s[1])))
        # pad unused buffer slots
        ret += [0] * (sts_len*(MAX_BUFS - len(my_sts)))
        # 274 + (3 + 5*16)*9 == 1021
        assert len(ret) == 1021, f"wrong len {len(ret)}"
    else:
        # 2 + 17*16 == 274
        assert len(ret) == 274, f"wrong len {len(ret)}"
    return ret


from tinygrad.device import Device, Buffer
from tinygrad.engine.search import _ensure_buffer_alloc, _time_program
from tinygrad.helpers import to_function_name, CACHELEVEL, diskcache_get, diskcache_put


def time_linearizer(lin: Kernel, rawbufs: list[Buffer], allow_test_size=True, max_global_size=65536,
                    cnt=3, disable_cache=False, clear_l2=False) -> float:
    """Compile and time a kernel, returning the minimum of the measured timings.

    Results are looked up in the disk cache (table ``"time_linearizer"``) when
    ``disable_cache`` is false and ``CACHELEVEL >= 2``. Note that
    ``disable_cache`` only skips the cache *read*: fresh timings are still
    written back whenever ``CACHELEVEL >= 2``.

    Args:
        lin: the kernel to compile and time.
        rawbufs: buffers handed to the program (passed through
            ``_ensure_buffer_alloc`` first).
        allow_test_size: when true, ``max_global_size`` is forwarded to
            ``_time_program``; otherwise ``None`` is forwarded.
        max_global_size: forwarded to ``_time_program`` (see above).
        cnt: forwarded to ``_time_program``.
        disable_cache: skip the disk-cache lookup.
        clear_l2: forwarded to ``_time_program``.

    Returns:
        ``min`` of the timings returned by ``_time_program`` (or of the cached
        timings on a cache hit).

    Raises:
        AssertionError: if the kernel's device has no compiler.
    """
    key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
           "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}
    if not disable_cache and CACHELEVEL >= 2 and (val := diskcache_get("time_linearizer", key)) is not None:
        return min(val)

    dev = Device[lin.opts.device]
    assert dev.compiler is not None

    rawbufs = _ensure_buffer_alloc(rawbufs)
    # bind every symbolic variable to the midpoint of its [vmin, vmax] range
    var_vals: dict[Variable, int] = {k: int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
    p = lin.to_program()
    tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
                        max_global_size=max_global_size if allow_test_size else None,
                        clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))

    if CACHELEVEL >= 2:
        diskcache_put("time_linearizer", key, tms)
    return min(tms)