# Beam search over kernel optimizations, plus kernel timing and local-size tuning utilities.
import itertools, functools, random, math, time, multiprocessing, traceback, signal, atexit
from collections import defaultdict
from dataclasses import replace
from typing import cast, Optional, Callable

from tinygrad.ops import UOp, Ops, Variable, sym_infer
from tinygrad.device import Device, Buffer, Compiler
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, time_to_str
from tinygrad.helpers import IGNORE_BEAM_CACHE, TC_SEARCH_OVER_SHAPE
from tinygrad.dtype import ImageDType, PtrDType
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
from tinygrad.tensor import Tensor
from tinygrad.engine.realize import CompiledRunner
from tinygrad.renderer import ProgramSpec
# The fixed search space of single-kernel optimizations tried by beam search.
# Each Opt is applied on top of a candidate Kernel; infeasible ones raise KernelOptError
# and are skipped in get_kernel_actions. NOTE(review): arg=0 appears to mean "use the full
# axis size" — see the dedup check against Opt(a.op, ax, 0) in get_kernel_actions.
actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
actions += [Opt(op=OptOps.UNROLL, axis=axis, arg=amt) for amt in [0,4,7] for axis in range(5)]
actions += [Opt(op=OptOps.LOCAL, axis=axis, arg=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)]
actions += [Opt(op=OptOps.GROUPTOP, axis=axis, arg=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
actions += [Opt(op=OptOps.GROUP, axis=axis, arg=amt) for amt in [0,4,8,16] for axis in range(3)]
# PADTO actions can be disabled with BEAM_PADTO=0
if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, arg=amt) for amt in [32] for axis in range(7)]
actions += [Opt(op=OptOps.LOCAL, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=6, arg=2)]
actions += [Opt(op=OptOps.TC, axis=0, arg=(-1, 0))]
actions += [Opt(op=OptOps.TC, axis=axis, arg=(-1, getenv("TC_OPT", 2))) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
actions += [Opt(op=OptOps.SWAP, axis=axis_0, arg=axis_1) for axis_0 in range(5) for axis_1 in range(axis_0+1, 5)]
if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
def _get_test_global_size(global_size, max_global_size, var_vals):
  """Shrink a (possibly symbolic) global size until its product fits under max_global_size.

  Returns the concrete shrunken size and the total reduction factor, so a runtime measured
  at the test size can be scaled back up (time * factor) to estimate the full-size launch.
  """
  shrunk, factor = [sym_infer(dim, var_vals) for dim in global_size], 1
  while prod(shrunk) > max_global_size:
    # halve the innermost axis that still has room (> 16), then re-check the product
    for axis in reversed(range(len(shrunk))):
      if shrunk[axis] > 16:
        shrunk[axis] //= 2
        factor *= 2
        break
  return shrunk, factor
def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[Variable, int], rawbufs:list[Buffer], early_stop:Optional[float]=None,
                  max_global_size:Optional[int]=65536, clear_l2=False, cnt=3, name="test") -> list[float]:
  """Run a compiled program `cnt` times and return the list of measured wall times in seconds.

  If the program's global size exceeds max_global_size it is shrunk for timing and each
  measured time is multiplied back up by the shrink factor. Returns [inf]*cnt if the
  CompiledRunner cannot be constructed. `early_stop` aborts remaining runs once the best
  time so far is below it. NOTE(review): `name` is not used in this body.
  """
  factor = 1
  if p.global_size is not None and max_global_size is not None:
    # shrink oversized launches; `factor` rescales the measured time back to full size
    global_size, factor = _get_test_global_size(p.global_size, max_global_size, var_vals)
    p = replace(p, global_size=global_size)
  try: car = CompiledRunner(p, precompiled=lib)
  except AssertionError: return [math.inf] * cnt
  tms = []
  # only pass the buffers the program actually declares as globals
  input_bufs = [rawbufs[i] for i in car.p.globals]
  for _ in range(cnt):
    if clear_l2:
      # evict caches between runs: use the device hook if present, otherwise thrash with a big realize
      if hasattr(dev:=Device[p.device], 'invalidate_caches'): dev.invalidate_caches()
      else:
        with Context(DEBUG=0, BEAM=0, CAPTURING=0, TRACK_MATCH_STATS=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
    tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
    if early_stop is not None and early_stop < min(tms): break
  return tms
# raised by timeout_handler (SIGALRM) to abort compiles that exceed BEAM_TIMEOUT_SEC
class TimeoutException(Exception): pass
def timeout_handler(signum, frame):
  """SIGALRM handler: abort the in-progress compile by raising TimeoutException."""
  raise TimeoutException()
def _try_compile_linearized_w_idx(x:tuple[int,Kernel], compiler:Compiler) -> tuple[int, Optional[tuple[ProgramSpec, bytes, float]]]:
  """Lower and compile one candidate kernel.

  Returns (index, (program, binary, compile_seconds)) on success or (index, None) on failure.
  The index is threaded through unchanged so results coming back from an unordered worker
  pool can be matched to their kernel. Where SIGALRM exists, a compile is aborted after
  BEAM_TIMEOUT_SEC seconds via timeout_handler.
  """
  if hasattr(signal, "alarm"):
    signal.signal(getattr(signal, 'SIGALRM'), timeout_handler)
    # set timeout
    signal.alarm(getenv("BEAM_TIMEOUT_SEC", 10))
  ret = None
  try:
    p = x[1].to_program(name_override="test")
    assert p.uops is not None, "uop list wasn't generated?"
    # reject kernels with too many uops; BEAM_UOPS_MAX <= 0 disables the check
    if len(p.uops) >= (uops_max:=getenv("BEAM_UOPS_MAX", 3000)) > 0:
      if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many uops. {len(p.uops)=}, {uops_max=}")
      raise RuntimeError("too many uops")
    st = time.perf_counter()
    prog = compiler.compile(p.src)
    et = time.perf_counter() - st
    ret = (p, prog, et)
  except RuntimeError:
    # expected failures (uop limit, compiler errors surfaced as RuntimeError) just drop the candidate
    if DEBUG >= 4: traceback.print_exc()
  except Exception as e:
    # any other failure also drops the candidate unless BEAM_STRICT_MODE re-raises it
    if getenv("BEAM_STRICT_MODE"): raise e
  finally:
    # always cancel a pending alarm so it can't fire after we return
    if hasattr(signal, "alarm"): signal.alarm(0)
  return x[0], ret
# workers should ignore ctrl c
|
|
def _init_worker():
  """Pool-worker initializer: ignore SIGINT so only the parent process handles ctrl-c."""
  signal.signal(signal.SIGINT, signal.SIG_IGN)
|
def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]:
  """Make sure every buffer is backed by allocated memory before timing; returns the buffers."""
  allocated = []
  for buf in bufs:
    allocated.append(buf.ensure_allocated())
  return allocated
|
# *** external API ***
|
|
|
|
# get (scrap) buffers for timing the linearizer
|
|
def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]:
  """Create scratch Buffers matching the kernel's DEFINE_GLOBAL slots, for timing runs."""
  # group the kernel's buffer UOps by the global slot index they access through
  by_slot: defaultdict[int, list[UOp]] = defaultdict(list)
  for buf_uop in lin.bufs:
    if buf_uop.src[0].op is Ops.DEFINE_GLOBAL: by_slot[buf_uop.src[0].arg].append(buf_uop)
  scratch: list[Optional[Buffer]] = [None]*len(by_slot)
  for slot, uops in by_slot.items():
    dtype = uops[0].src[0].dtype
    # images have a fixed shape; otherwise size the buffer to the largest access among its uses
    size = prod(dtype.shape) if isinstance(dtype, ImageDType) else max(u.st_arg.real_size() for u in uops)
    assert isinstance(dtype, (PtrDType, ImageDType))
    if size == 0: size = 1 # create a size 1 buffer if no cell is accessed in kernel. # TODO: remove from kernel input in this case.
    base_dtype = dtype if isinstance(dtype, ImageDType) else dtype.base
    buf = Buffer(lin.opts.device, size, base_dtype)
    scratch[slot] = buf.allocate() if allocate else buf
  assert all(b is not None for b in scratch)
  return cast(list[Buffer], scratch)
|
# get dictionary of all possible actions
|
|
def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
  """Apply every feasible action from `actions` to `lin`.

  Returns a dict mapping (action index + 1) -> new Kernel; index 0 is the unmodified
  kernel when include_0 is set. Actions that raise KernelOptError, exceed the
  BEAM_UPCAST_MAX/BEAM_LOCAL_MAX limits, or are no-ops are omitted.
  """
  acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
  kernel_actions = actions.copy()

  if TC_SEARCH_OVER_SHAPE and len(lin.applied_opts) == 0: # tensor core opts must be first
    for i, action in enumerate(kernel_actions):
      if action.op == OptOps.TC and (tc_arg := cast(tuple, action.arg))[0] == -1:
        # replace every tc_action with default tc with one tc_action for each available tc
        kernel_actions[i:i+1] = [Opt(op=OptOps.TC, axis=action.axis, arg=(tc_select, tc_arg[1])) for tc_select,_ in enumerate(lin.opts.tensor_cores)]

  for i,a in enumerate(kernel_actions):
    if a.axis is not None and a.op is not OptOps.TC:
      # skip actions on axes that don't exist, and full-size resizes already covered by the arg=0 variant
      if ((ax:=lin.real_axis(a)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions): continue
    lin2 = lin.copy()
    try:
      lin2.apply_opt(a)
      up, lcl, tc_up = 1, 1, prod(tc.dims)//tc.threads if (tc:=lin2.tensor_core) else 1
      # colors classify each shape dim: magenta/yellow count toward upcast, cyan/green/white toward local
      for s,c in zip(lin2.full_shape, lin2.colors()):
        if c in {"magenta", "yellow"}: up *= s
        elif c in {"cyan", "green", "white"}: lcl *= s
      if up//tc_up > max_up or lcl > max_lcl:
        if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many upcast/local. {up//tc_up=}, {max_up=}, {lcl=}, {max_lcl=}")
        continue
      acted_lins[i+1] = lin2
    except KernelOptError: pass
  return acted_lins
|
# beam_pool: lazily-created multiprocessing pool shared across beam_search calls
beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
def beam_search(lin:Kernel, rawbufs:list[Buffer], amt:int, allow_test_size=True, disable_cache=IGNORE_BEAM_CACHE.value) -> Kernel:
  """Beam-search over kernel optimizations, returning the fastest Kernel found.

  `amt` is the beam width. Results are cached on disk (CACHELEVEL >= 1) keyed by the AST
  and search parameters, and replayed by re-applying the stored opts unless disable_cache.
  Candidate kernels are compiled (optionally in a worker pool) and timed on `rawbufs`;
  the search stops when the best candidate stops improving by BEAM_MIN_PROGRESS.
  """
  global beam_pool
  key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
  if not disable_cache and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None:
    # cache hit: replay only the opts beyond those already applied to `lin`
    ret = lin.copy()
    for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
    return ret

  beam: list[tuple[Kernel, float]] = [(lin, float("inf"))]
  seen_libs = set()

  # compile in parallel by default only on devices where it's known to help
  default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV", "METAL", "HIP"} else 0
  if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
    beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))
    @atexit.register
    def close_pool(): beam_pool.close()

  min_progress = getenv("BEAM_MIN_PROGRESS", 0.01)/1e6
  if BEAM_DEBUG: print(f"BEAM_SEARCH:\n{lin.ast}")
  if DEBUG >= 2: print(f"  0.00s: from 1 -> 1 actions {lin.colored_shape()}")

  try:
    rawbufs = _ensure_buffer_alloc(rawbufs)
    # bind each symbolic variable to the midpoint of its range for timing
    var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
    exiting, st = False, time.perf_counter()
    dev = Device[lin.opts.device]
    while not exiting:
      # expand: every action applied to every kernel currently in the beam
      acted_lins: list[Kernel] = flatten([get_kernel_actions(lin, include_0=False).values() for lin,_ in beam])
      timed_lins: list[tuple[Kernel, float]] = []
      _compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
      least_compute_ops = math.inf
      for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
        if proc is None: continue
        p, lib, compile_et = proc
        if lib in seen_libs: continue
        # filter out kernels that use 1000x more compute than the smallest
        least_compute_ops = min(this_compute_ops:=sym_infer(p.estimates.ops, var_vals), least_compute_ops)
        if least_compute_ops*1000 < this_compute_ops: continue
        seen_libs.add(lib)
        try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, clear_l2=hasattr(dev, 'invalidate_caches'))
        except RuntimeError: continue # for runtime issues
        timed_lins.append((acted_lins[i], min(tms)))
        if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(list, p.uops)):5d} uops {time_to_str(compile_et, w=12)} compile/{time_to_str(timed_lins[-1][1], w=12)} run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
        elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {time_to_str(timed_lins[-1][1], w=12)} {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501

      # done
      opts = sorted(timed_lins, key=lambda x: x[1])
      # stop when nothing new timed, or improvement over the current best is below min_progress
      exiting = len(opts) == 0 or (opts[0][1] < min_progress) or (len(beam) > 0 and ((beam[0][1]-opts[0][1]) < min_progress))
      if not exiting: beam = opts[:amt]
      elif len(opts) > 0 and opts[0][1] < beam[0][1]: beam = opts[:1]
      if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(time_to_str(beam[0][1], w=12), "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) # noqa: E501
  except KeyboardInterrupt as e:
    # kill in-flight compiles before propagating ctrl-c
    if beam_pool is not None: beam_pool.terminate()
    raise e

  if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
  if BEAM_DEBUG: print(f"BEAM_SEARCH: final tm={time_to_str(beam[0][1], w=0)}, applied_opts={beam[0][0].applied_opts}")
  return beam[0][0]
|
|
def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffer]) -> list[int]:
  """Empirically pick the fastest local (workgroup) size for a program at a given global size."""
  # if the output buffer is also an input, run trials on a scratch copy so they don't corrupt it
  if rawbufs[0] in rawbufs[1:]:
    test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype).allocate(), *rawbufs[1:]]
  else:
    test_rawbuffers = rawbufs
  MAX_WORKGROUP = 1024
  # per-dimension candidates: powers of two (plus the full size) that fit within the dimension
  local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
  local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
  def try_exec(local_size):
    # time a single launch; any failure (invalid size, runtime error) counts as infinitely slow
    try:
      resized_global = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
      return _prg(*[x._buf for x in test_rawbuffers], global_size=resized_global, local_size=local_size, wait=True)
    except Exception: return float('inf')
  timings = [(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))]
  best_time, best_local = min(timings)
  assert not math.isinf(best_time), "all optimize_local_size exec failed"
  return best_local