"""Beam search over kernel optimizations (`Opt` actions), plus helpers to build scratch buffers and time candidate programs."""
from typing import cast, Optional, Callable
import itertools, functools, random, math, time, multiprocessing, traceback, signal, atexit
from collections import defaultdict
from dataclasses import replace
from tinygrad.ops import UOp, Ops, Variable, sym_infer
from tinygrad.device import Device, Buffer, Compiler
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, time_to_str
from tinygrad.helpers import IGNORE_BEAM_CACHE, TC_SEARCH_OVER_SHAPE
from tinygrad.dtype import ImageDType, PtrDType
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
from tinygrad.tensor import Tensor
from tinygrad.engine.realize import CompiledRunner
from tinygrad.renderer import ProgramSpec

# Candidate optimizations the search may try; get_kernel_actions applies each one to a copy of the kernel via apply_opt.
actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
actions += [Opt(op=OptOps.UNROLL, axis=axis, arg=amt) for amt in [0,4,7] for axis in range(5)]
actions += [Opt(op=OptOps.LOCAL, axis=axis, arg=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)]
actions += [Opt(op=OptOps.GROUPTOP, axis=axis, arg=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
actions += [Opt(op=OptOps.GROUP, axis=axis, arg=amt) for amt in [0,4,8,16] for axis in range(3)]
if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, arg=amt) for amt in [32] for axis in range(7)]
actions += [Opt(op=OptOps.LOCAL, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=6, arg=2)]
# a TC arg whose first element is -1 is a placeholder; get_kernel_actions may expand it into one action per tensor core
actions += [Opt(op=OptOps.TC, axis=0, arg=(-1, 0))]
actions += [Opt(op=OptOps.TC, axis=axis, arg=(-1, getenv("TC_OPT", 2))) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
actions += [Opt(op=OptOps.SWAP, axis=axis_0, arg=axis_1) for axis_0 in range(5) for axis_1 in range(axis_0+1, 5)]
if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]

def _get_test_global_size(global_size, max_global_size, var_vals):
  """Shrink a (possibly symbolic) global size until its product is at most max_global_size.

  Each step halves (floor division) the last dimension that is larger than 16.
  Returns (test_global_size, factor), where factor is 2**(number of halvings); _time_program multiplies
  measured times by it. If no dimension can be halved any more, the loop stops early and the returned size may
  still exceed max_global_size (previously this case spun forever).
  """
  test_global_size, factor = [sym_infer(sz, var_vals) for sz in global_size], 1
  while prod(test_global_size) > max_global_size:
    for j in reversed(range(len(test_global_size))):
      if test_global_size[j] > 16:
        test_global_size[j] //= 2
        factor *= 2
        break
    else: break  # every dimension is <= 16: nothing left to shrink, so stop instead of looping forever
  return test_global_size, factor

def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[Variable, int], rawbufs:list[Buffer], early_stop:Optional[float]=None,
                  max_global_size:Optional[int]=65536, clear_l2=False, cnt=3, name="test") -> list[float]:
  """Run the precompiled program `lib` up to `cnt` times and return the measured times.

  If the program has a global size and max_global_size is set, the launch is shrunk with _get_test_global_size and
  each time is scaled back up by the returned factor. Returns [math.inf]*cnt if building the CompiledRunner raises
  AssertionError. Stops early once the fastest time so far exceeds `early_stop`. With clear_l2, the device's
  invalidate_caches() is called before each run if it exists; otherwise a throwaway 1024x1024 tensor is realized.
  `name` is not used in this body.
  """
  factor = 1
  if p.global_size is not None and max_global_size is not None:
    global_size, factor = _get_test_global_size(p.global_size, max_global_size, var_vals)
    p = replace(p, global_size=global_size)
  try: car = CompiledRunner(p, precompiled=lib)
  except AssertionError: return [math.inf] * cnt
  tms = []
  input_bufs = [rawbufs[i] for i in car.p.globals]
  for _ in range(cnt):
    if clear_l2:
      if hasattr(dev:=Device[p.device], 'invalidate_caches'): dev.invalidate_caches()
      else:
        with Context(DEBUG=0, BEAM=0, CAPTURING=0, TRACK_MATCH_STATS=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
    tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
    if early_stop is not None and early_stop < min(tms): break
  return tms

class TimeoutException(Exception): pass
def timeout_handler(signum, frame):
  """SIGALRM handler installed by _try_compile_linearized_w_idx; raises TimeoutException."""
  raise TimeoutException()

def _try_compile_linearized_w_idx(x:tuple[int,Kernel], compiler:Compiler) -> tuple[int, Optional[tuple[ProgramSpec, bytes, float]]]:
  """Render and compile the kernel in x = (index, kernel).

  Returns (index, (program, compiled_lib, compile_seconds)) on success, or (index, None) on failure. Programs with at
  least BEAM_UOPS_MAX uops are rejected, but only when that limit is positive. Where signal.alarm exists, a
  BEAM_TIMEOUT_SEC alarm is armed for the duration and always cleared in `finally`. Non-RuntimeError exceptions are
  swallowed unless BEAM_STRICT_MODE is set.
  """
  if hasattr(signal, "alarm"):
    signal.signal(getattr(signal, 'SIGALRM'), timeout_handler)
    # set timeout
    signal.alarm(getenv("BEAM_TIMEOUT_SEC", 10))
  ret = None
  try:
    p = x[1].to_program(name_override="test")
    assert p.uops is not None, "uop list wasn't generated?"
    if len(p.uops) >= (uops_max:=getenv("BEAM_UOPS_MAX", 3000)) > 0:
      if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many uops. {len(p.uops)=}, {uops_max=}")
      raise RuntimeError("too many uops")
    st = time.perf_counter()
    prog = compiler.compile(p.src)
    et = time.perf_counter() - st
    ret = (p, prog, et)
  except RuntimeError:
    if DEBUG >= 4: traceback.print_exc()
  except Exception as e:
    if getenv("BEAM_STRICT_MODE"): raise e
  finally:
    if hasattr(signal, "alarm"): signal.alarm(0)
  return x[0], ret

# workers should ignore ctrl c
def _init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)

def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_allocated() for buf in bufs]

# *** external API ***

# get (scrap) buffers for timing the linearizer
def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]:
  """Build one scratch Buffer per DEFINE_GLOBAL argument index of `lin`.

  Image buffers are sized to prod(dtype.shape). Other buffers use the largest st_arg.real_size() among their
  accesses, with a minimum of 1. The buffers are allocated only when `allocate` is true.
  """
  bufsts: defaultdict[int, list[UOp]] = defaultdict(list)
  for x in lin.bufs:
    if x.src[0].op is Ops.DEFINE_GLOBAL: bufsts[x.src[0].arg].append(x)
  rawbufs: list[Optional[Buffer]] = [None]*len(bufsts)
  for k,lx in bufsts.items():
    buf_size = prod(dtype.shape) if isinstance(dtype:=lx[0].src[0].dtype, ImageDType) else max(y.st_arg.real_size() for y in lx)
    assert isinstance(dtype, (PtrDType, ImageDType))
    if buf_size == 0: buf_size = 1  # create a size 1 buffer if no cell is accessed in kernel. # TODO: remove from kernel input in this case.
    buf_dtype = dtype if isinstance(dtype, ImageDType) else dtype.base
    rawbufs[k] = Buffer(lin.opts.device, buf_size, buf_dtype).allocate() if allocate else Buffer(lin.opts.device, buf_size, buf_dtype)
  # the argument indices are expected to cover 0..len-1 exactly
  assert all(r is not None for r in rawbufs)
  return cast(list[Buffer], rawbufs)

# get dictionary of all possible actions
def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
  """Apply each candidate action to a copy of `lin` and return the kernels that were accepted.

  Keys are (index in the possibly TC-expanded action list) + 1; key 0 maps to `lin` itself when include_0 is true.
  Actions that raise KernelOptError are dropped. So are kernels whose upcast product (divided by the tensor core's
  dims/threads) exceeds BEAM_UPCAST_MAX, or whose local product exceeds BEAM_LOCAL_MAX.
  """
  acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
  kernel_actions = actions.copy()

  if TC_SEARCH_OVER_SHAPE and len(lin.applied_opts) == 0: # tensor core opts must be first
    # Replace every default-TC action (tc_select == -1) with one action per available tensor core. A new list is
    # built instead of slice-assigning into the list being iterated, which skipped elements when tensor_cores was empty.
    expanded: list[Opt] = []
    for action in kernel_actions:
      if action.op == OptOps.TC and (tc_arg := cast(tuple, action.arg))[0] == -1:
        expanded += [Opt(op=OptOps.TC, axis=action.axis, arg=(tc_select, tc_arg[1])) for tc_select,_ in enumerate(lin.opts.tensor_cores)]
      else: expanded.append(action)
    kernel_actions = expanded

  for i,a in enumerate(kernel_actions):
    if a.axis is not None and a.op is not OptOps.TC:
      # skip axes past the shape, and amounts equal to the full dim when the equivalent amount-0 action is also present
      if ((ax:=lin.real_axis(a)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions): continue
    lin2 = lin.copy()
    try:
      lin2.apply_opt(a)
      up, lcl, tc_up = 1, 1, prod(tc.dims)//tc.threads if (tc:=lin2.tensor_core) else 1
      # magenta/yellow dims count toward the upcast product, cyan/green/white dims toward the local product
      for s,c in zip(lin2.full_shape, lin2.colors()):
        if c in {"magenta", "yellow"}: up *= s
        elif c in {"cyan", "green", "white"}: lcl *= s
      if up//tc_up > max_up or lcl > max_lcl:
        if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many upcast/local. {up//tc_up=}, {max_up=}, {lcl=}, {max_lcl=}")
        continue
      acted_lins[i+1] = lin2
    except KernelOptError: pass
  return acted_lins

beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
def beam_search(lin:Kernel, rawbufs:list[Buffer], amt:int, allow_test_size=True, disable_cache=IGNORE_BEAM_CACHE.value) -> Kernel:
  """Beam-search over get_kernel_actions and return the fastest kernel found.

  Each round expands every kernel in the beam, compiles the candidates (in a module-level "spawn" pool when PARALLEL
  or the device default is non-zero), times them with _time_program, and keeps the `amt` fastest. The search stops
  when no candidates remain, or when the best time, or its improvement over the previous best, falls below
  BEAM_MIN_PROGRESS/1e6 (presumably microseconds converted to seconds).
  Results are read from and written to the "beam_search" disk cache when CACHELEVEL >= 1 (reads are skipped if
  disable_cache). In this body, allow_test_size only affects the cache key.
  """
  global beam_pool
  key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
  if not disable_cache and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None:
    ret = lin.copy()
    # the cached list holds all applied opts; replay only those that `lin` does not already have
    for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
    return ret

  beam: list[tuple[Kernel, float]] = [(lin, float("inf"))]
  seen_libs = set()

  default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV", "METAL", "HIP"} else 0
  if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
    beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))
    @atexit.register
    def close_pool():
      # beam_pool may have been reset to None after a KeyboardInterrupt
      if beam_pool is not None: beam_pool.close()

  min_progress = getenv("BEAM_MIN_PROGRESS", 0.01)/1e6
  if BEAM_DEBUG: print(f"BEAM_SEARCH:\n{lin.ast}")
  if DEBUG >= 2: print(f" 0.00s: from 1 -> 1 actions {lin.colored_shape()}")

  try:
    rawbufs = _ensure_buffer_alloc(rawbufs)
    # each symbolic variable is bound to the midpoint of its range for timing
    var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
    exiting, st = False, time.perf_counter()
    dev = Device[lin.opts.device]
    while not exiting:
      acted_lins: list[Kernel] = flatten([get_kernel_actions(beam_lin, include_0=False).values() for beam_lin,_ in beam])
      timed_lins: list[tuple[Kernel, float]] = []
      _compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
      least_compute_ops = math.inf
      for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
        if proc is None: continue
        p, lib, compile_et = proc
        if lib in seen_libs: continue
        # filter out kernels that use 1000x more compute than the smallest
        least_compute_ops = min(this_compute_ops:=sym_infer(p.estimates.ops, var_vals), least_compute_ops)
        if least_compute_ops*1000 < this_compute_ops: continue
        seen_libs.add(lib)
        try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, clear_l2=hasattr(dev, 'invalidate_caches'))
        except RuntimeError: continue # for runtime issues
        timed_lins.append((acted_lins[i], min(tms)))
        if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(list, p.uops)):5d} uops {time_to_str(compile_et, w=12)} compile/{time_to_str(timed_lins[-1][1], w=12)} run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}")  # noqa: E501
        elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {time_to_str(timed_lins[-1][1], w=12)} {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="")  # noqa: E501

      # done
      opts = sorted(timed_lins, key=lambda x: x[1])
      exiting = len(opts) == 0 or (opts[0][1] < min_progress) or (len(beam) > 0 and ((beam[0][1]-opts[0][1]) < min_progress))
      if not exiting: beam = opts[:amt]
      elif len(opts) > 0 and opts[0][1] < beam[0][1]: beam = opts[:1]
      if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(time_to_str(beam[0][1], w=12), "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape())  # noqa: E501
  except KeyboardInterrupt as e:
    if beam_pool is not None:
      beam_pool.terminate()
      # a terminated pool cannot run new tasks; drop it so the next beam_search builds a fresh one
      beam_pool = None
    raise e

  if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
  if BEAM_DEBUG: print(f"BEAM_SEARCH: final tm={time_to_str(beam[0][1], w=0)}, applied_opts={beam[0][0].applied_opts}")
  return beam[0][0]

def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffer]) -> list[int]:
  """Try candidate local sizes for `_prg` and return the fastest one.

  Per dimension, the candidates are the powers of two up to 1024 plus the dimension itself, each no larger than the
  dimension, with a total product of at most 1024. Every combination is run twice, in random order. If the output
  buffer (rawbufs[0]) also appears among the inputs, a fresh output buffer is allocated so runs don't clobber an input.
  Raises AssertionError if every run failed.
  """
  test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype).allocate(), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
  MAX_WORKGROUP = 1024
  local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
  local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2  # try each valid size twice
  def try_exec(local_size):
    # any failure counts as an infinitely slow run
    try: return _prg(*[x._buf for x in test_rawbuffers], global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True)  # noqa: E501
    except Exception: return float('inf')
  ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])
  assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
  return ret[1]