from __future__ import annotations
import math, itertools
from collections import defaultdict
from typing import cast, Final
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, can_pad, GroupOp
from tinygrad.device import Buffer
from tinygrad.dtype import AddrSpace, dtypes, ImageDType
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod
from tinygrad.codegen.opt import axis_colors, Opt, OptOps, KernelOptError, check, axis_letters
from tinygrad.codegen.simplify import pm_flatten_range
from tinygrad.renderer import Renderer

# clears the `tag` of every UOp that has one. SWAP tags the two swapped ranges temporarily and runs this afterwards.
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])

# sort priority of each axis type; Scheduler.rngs is ordered by this value first (lower comes first)
# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
               AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}

class Scheduler:
  # Applies optimization opts (Opt) to a kernel AST by rewriting its RANGE uops.
  # A RANGE's arg is a tuple: arg[0] is a numeric id, arg[-1] is its AxisType, anything in between is kept as-is.
  def __init__(self, ast:UOp, opts:Renderer):
    self.ast, self.opts = ast, opts
    # carry over state from a previously optimized ast (its KernelInfo arg), if any
    self.dont_use_locals = self.ast.arg.dont_use_locals if self.ast.arg is not None else False
    self.applied_opts = list(self.ast.arg.applied_opts) if self.ast.arg is not None else []

  @property
  def rngs(self):
    # all RANGEs with more than one iteration (vmax > 0), always in order by axistype, then by the rest of the arg
    return sorted([u for u in self.ast.parents if u.op is Ops.RANGE and u.vmax > 0], key=lambda x: (axis_to_pos[x.arg[-1]],) + x.arg[0:-1])
  @property
  def shape_len(self): return len(self.rngs)
  @property
  def full_shape(self): return [ssimplify(x.src[0]) for x in self.rngs]
  @property
  def axis_types(self): return [x.arg[-1] for x in self.rngs]
  @property
  def maxarg(self): return max([x.arg[0] for x in self.rngs], default=0)

  # strings like ['g0', 'g1', 'l0', 'l1', 'l2', 'l3', 'l4', 'l5', 'R0', 'r0', 'r1', 'r2', 'u0', 'u1', 'u2']
  # (axis letter + per-type counter, in rngs order)
  def shape_str(self) -> list[str]:
    ret: list[str] = []
    cnt: dict[AxisType, int] = {}
    for x in self.axis_types:
      cnt[x] = (cnt[x] + 1) if x in cnt else 0
      ret.append(f"{axis_letters[x]}{cnt[x]}")
    return ret
  # map shape_str names back to their positions in rngs (raises ValueError if a name is missing)
  def shape_str_to_axis(self, nms:list[str]) -> tuple[int, ...]: return tuple([self.shape_str().index(x) for x in nms])

  def copy(self):
    # new Scheduler on the same (immutable) ast with its own copy of the mutable opt state
    ret = Scheduler(self.ast, self.opts)
    ret.dont_use_locals = self.dont_use_locals
    ret.applied_opts = self.applied_opts[:]
    return ret

  # class-level counter shared by all Scheduler instances, used to suffix repeated kernel names with n1, n2, ...
  kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)
  def get_optimized_ast(self, name_override:str|None=None):
    # build the kernel name ("r" if there is a reduce, else "E", followed by the colored range sizes) unless overridden
    if name_override is not None: name = name_override
    else:
      kernel_type = "r" if self.reduceop is not None else "E"
      name = kernel_type + colored('_', 'BLACK').join(['']+[colored(x.src[0].render(), color) for x,color in zip(self.rngs, self.colors())])
      Scheduler.kernel_cnt[(function_name := to_function_name(name))] += 1
      num = f"n{Scheduler.kernel_cnt[function_name]-1}" if Scheduler.kernel_cnt[function_name] > 1 else ""
      name += colored(num, 'BLACK')
    self.ast = graph_rewrite(self.ast, pm_flatten_range, name="flatten range")
    # tag=1 marks the result as processed; apply_opts skips any ast whose tag is set
    return self.ast.replace(arg=KernelInfo(name=name, applied_opts=tuple(self.applied_opts), dont_use_locals=self.dont_use_locals), tag=1)

  def _globalizable_rngs(self) -> list[UOp]:
    # candidate ranges: src[2:] of the sink's first source (presumably the output store's ranges -- TODO confirm)
    store_rngs = self.ast.src[0].src[2:]

    # filter any not in local stores
    local_store_rngs = [x.ranges for x in self.ast.toposort() if (x.op is Ops.STORE and x.src[0].ptrdtype.addrspace == AddrSpace.LOCAL) \
                        or (x.op is Ops.BUFFERIZE and x.arg == AddrSpace.LOCAL)]
    for ls in local_store_rngs: store_rngs = tuple([x for x in store_rngs if x in ls])

    # only LOOP ranges can be globalized. The axis type is always the LAST arg element (see rngs/axis_types);
    # arg[1] would be wrong whenever the arg carries extra middle elements.
    return [x for x in UOp.sink(*store_rngs).toposort() if x.op is Ops.RANGE and x.arg[-1] == AxisType.LOOP] if store_rngs else []

  def convert_loop_to_global(self):
    # rewrite every globalizable LOOP range into a GLOBAL range; skipped when the renderer has no locals
    if not self.opts.has_local: return None

    globalizible_rngs = self._globalizable_rngs()
    # keep every arg element except the axis type (as SWAP/PADTO do), so ranges that differ only in their
    # middle arg elements do not become identical after the rewrite
    rng = [x.replace(arg=x.arg[0:-1]+(AxisType.GLOBAL,)) if x in globalizible_rngs else x for x in self.rngs]

    self.ast = self.ast.substitute(dict(zip(self.rngs, rng)))

  # display colors per axis; GLOBAL axes are shown BLUE when locals are disabled
  def colors(self) -> list[str]: return [axis_colors[x] if not self.dont_use_locals or not x == AxisType.GLOBAL else "BLUE" for x in self.axis_types]
  def colored_shape(self) -> str: return ' '.join([colored(f'{x.src[0].render():>4s}', color) for x,color in zip(self.rngs, self.colors())])

  def shift_to(self, rng:UOp, amount:int, new_type:AxisType, top:bool=False, input_new_rng=None):
    # Split `rng` (size N) into `replaced_rng` (size N/amount) and a new range of size `amount` with axis type `new_type`.
    # top=False: rng -> replaced_rng*amount + new_rng (new range is the inner index)
    # top=True:  rng -> new_rng*old_sz + replaced_rng (new range is the outer index)
    # input_new_rng lets the caller supply the new index expression instead of a fresh RANGE.
    # Returns (replaced_rng, new_rng); raises KernelOptError if amount does not divide the size.
    if (old_sz:=rng.src[0].divides(amount)) is None:
      raise KernelOptError(f"{amount} can't divide {rng.src[0]} in {self.colored_shape()}")
    new_rng = UOp.range(amount, self.maxarg+1, new_type) if input_new_rng is None else input_new_rng
    replaced_rng = rng.replace(src=(UOp.const(dtypes.int, old_sz),))
    sub_axis = (new_rng * old_sz + replaced_rng) if top else (replaced_rng * amount + new_rng)
    self.ast = self.ast.substitute({rng:sub_axis}, name=f"shift {rng.arg[0]} {amount} {str(new_type).split('.')[1].lower()}")
    return replaced_rng, new_rng

  def ranges_of(self, *axis_type:AxisType) -> list[UOp]: return [r for r in self.rngs if r.arg[-1] in axis_type]
  def axes_of(self, *axis_type:AxisType) -> list[int]: return [i for i,t in enumerate(self.axis_types) if t in axis_type]

  # copied from kernel.py
  # indices (into rngs) of GLOBAL/LOCAL/LOOP axes with a concrete int size > 1
  @property
  def upcastable_dims(self) -> list[int]: return [i for i in self.axes_of(AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP) \
                                                  if isinstance(s:=self.full_shape[i], int) and s > 1]
  # indices (into rngs) of GROUP_REDUCE/REDUCE axes with a concrete int size > 1
  @property
  def unrollable_dims(self) -> list[int]: return [i for i in self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE) \
                                                  if isinstance(s:=self.full_shape[i], int) and s > 1]

  def real_axis(self, op:OptOps, axis:int|None):
    # translate an opt's axis into an index into rngs (-1 = no specific range)
    # UNROLL counts among unrollable dims, GROUP/GROUPTOP among REDUCE axes, everything else indexes rngs directly
    try:
      if axis is None or op is OptOps.TC: return -1
      if op is OptOps.UNROLL: return self.unrollable_dims[axis]
      if op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.axes_of(AxisType.REDUCE)[axis]
      check(axis < self.shape_len, f"invalid axis on {axis=} {op=} {self.shape_len=}")
      return axis
    except IndexError as e: raise KernelOptError from e

  def apply_opt(self, opt:Opt, append_opt:bool=True):
    # Apply one Opt to self.ast. Invalid opts raise KernelOptError (check(...) is assumed to raise it on failure).
    # Returns the shift_to result for range-splitting opts, the chosen axes for TC, otherwise None.
    if opt.op is OptOps.NOLOCALS:
      check(all(x not in {AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE} for x in self.axis_types), "no locals can't have locals")
      if append_opt: self.applied_opts.append(opt)
      self.dont_use_locals = True
      return

    if opt.op in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}:
      check(self.opts.has_local, "locals needed for opt")

    rng = self.rngs[real_axis] if (real_axis:=self.real_axis(opt.op, opt.axis)) >= 0 else UOp(Ops.NOOP)

    # opts that split a range, and the axis type of the newly created range
    opt_to_at = {
      OptOps.LOCAL: AxisType.LOCAL, OptOps.UPCAST: AxisType.UPCAST,
      OptOps.UNROLL: AxisType.UNROLL, OptOps.GROUP: AxisType.GROUP_REDUCE,
      OptOps.GROUPTOP: AxisType.GROUP_REDUCE, OptOps.THREAD: AxisType.THREAD}

    ret = None
    if opt.op in opt_to_at:
      # arg == 0 means "use the whole range"
      amt:int = int(rng.vmax+1) if opt.arg == 0 else cast(int, opt.arg)

      # copied from kernel.py. prevents METAL compiler hangs
      if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
                                        (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
        upcast_local_sz = prod([self.full_shape[a] for a in self.axes_of(AxisType.UPCAST, AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE)])
        smem_sz = amt*upcast_local_sz*self.reduceop.dtype.itemsize
        check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")

      if opt.op is OptOps.UNROLL:
        check(amt <= 32, "don't unroll more than 32")
        check(rng.arg[-1] in {AxisType.GROUP_REDUCE, AxisType.REDUCE}, "unroll is for GROUP_REDUCE/REDUCE")
      if opt.op is OptOps.UPCAST:
        check((self.opts is not None and self.opts.device == "DSP") or amt <= 16, "don't upcast more than 16")
        check(rng.arg[-1] in {AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP}, f"upcast is for GLOBAL/LOCAL/LOOP, not {rng.arg[-1]}")
      if opt.op is OptOps.LOCAL:
        check(not self.dont_use_locals, "can't use locals")
        check(rng.arg[-1] in {AxisType.GLOBAL, AxisType.LOOP}, "local is for globals")
      if opt.op is OptOps.THREAD:
        check(self.opts is not None and self.opts.has_threads, "target does not support threads")
        check(self.opts is not None and self.opts.global_max is not None and amt <= self.opts.global_max[0], "too many threads")
        check(all(x is not AxisType.THREAD for x in self.axis_types), "already threaded")
        check(rng in self._globalizable_rngs(), "can't apply range to this dim")
      if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}:
        check(all(x.op is not OptOps.TC for x in self.applied_opts), "no grouping with tensor cores")  # TODO: why is this wrong?
        check(not self.dont_use_locals, "can't use locals")
        check(rng.arg[-1] == AxisType.REDUCE, "group is for reduce")
      # GROUPTOP and THREAD put the new range on the outer index, the others on the inner index
      ret = self.shift_to(rng, amt, opt_to_at[opt.op], top=opt.op in {OptOps.GROUPTOP, OptOps.THREAD})
    elif opt.op is OptOps.TC:
      check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: remove the need for this by having warps
      check(opt.axis is not None, "tensor core opts must have an axis")
      check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 3, "tensor core opts must have valid arg")
      check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.opts.tensor_cores), "tensor core opts must have valid tc_select")
      check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt")
      check(0 < (use_tensor_cores:=cast(tuple, opt.arg)[2]) <= 2, "use_tensor_cores value is not valid")
      # ValueError can come from e.g. self.rngs.index(...) inside _apply_tc_opt; keep it as the cause
      try: ret = self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt)
      except ValueError as e: raise KernelOptError(str(e)) from e
      check(ret is not None, "no tensor core available")
    elif opt.op is OptOps.PADTO:
      check(rng.src[0].op is Ops.CONST, "only pad const axes")
      check(rng.arg[-1] not in {AxisType.UPCAST, AxisType.UNROLL}, "cannot pad upcasted") # TODO: why is this wrong?
      check(rng.arg[-1] is not AxisType.THREAD, "cannot pad thread")
      # ok to pad SUM if all parent ALU ops have f(0) = 0
      if (r:=self.reduceop) is not None and rng.arg[-1] in (AxisType.GROUP_REDUCE, AxisType.REDUCE):
        check(r.arg[0] is Ops.ADD and can_pad(r, {}), f"cannot pad {r}")
      new_sz = round_up(int(rng.vmax+1), cast(int, opt.arg))
      check(rng.vmax+1 > new_sz//4, "pad adds more than quadruple the work")
      replaced_rng = UOp.range(new_sz, *rng.arg)
      replaces = {rng:replaced_rng}
      # indices that depend on the padded range are masked to invalid past the original size
      valid = replaced_rng < rng.vmax+1
      for b in self.bufs:
        if rng in (i:=b.src[1].get_idx()).sparents:
          replaces[b] = b.replace(src=(b.src[0],(valid&b.src[1].get_valid()).where(i, UOp.invalid())))
      self.ast = self.ast.substitute(replaces, f"padto {rng.arg[:-1]} {opt.arg}")
    elif opt.op is OptOps.SWAP:
      try:
        altrng = self.rngs[opt.arg]
      except IndexError as e:
        raise KernelOptError from e
      check(rng.arg[-1] == AxisType.GLOBAL and altrng.arg[-1] == AxisType.GLOBAL, "swap only for globals")
      # exchange everything but the axis type; tag=1 keeps the replacements distinct until remove_tags clears it
      self.ast = self.ast.substitute({rng:rng.replace(arg=(*altrng.arg[0:-1], rng.arg[-1]), tag=1),
                                      altrng:altrng.replace(arg=(*rng.arg[0:-1], altrng.arg[-1]), tag=1)})
      self.ast = graph_rewrite(self.ast, remove_tags)
    else:
      raise KernelOptError(f"unsupported opt {opt.op}")

    if append_opt: self.applied_opts.append(opt)
    return ret

  def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> None|list[UOp]:
    # Try to map the first REDUCE (an ADD over a MUL, optionally behind a CAST) onto a tensor core.
    # Returns the chosen [in1-range, in0-range, reduce-range] axes on success, None if no tensor core matched.
    # use_tensor_cores == 2 only applies the range shaping; any other value also substitutes the WMMA op.
    reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE]
    if not len(reduceops): raise KernelOptError("no reduce ops for TensorCore")
    reduceop = reduceops[0]
    if use_tensor_cores and reduceop is not None and reduceop.arg is Ops.ADD:
      mul = reduceop.src[0] if reduceop.src[0].op is not Ops.CAST else reduceop.src[0].src[0]
      if mul.op is not Ops.MUL: return None
      in0, in1 = mul.src
      try:
        tensor_cores = self.opts.tensor_cores if tc_select == -1 else [self.opts.tensor_cores[tc_select]]
      except IndexError as e:
        raise KernelOptError(f"invalid tensor core choice {tc_select}") from e
      for tc in tensor_cores:
        if tc.dtype_in == in0.dtype.scalar() and tc.dtype_in == in1.dtype.scalar() and tc.dtype_out == reduceop.dtype.scalar():
          # tensor cores have three ranges. X, Y, and REDUCE
          in0_ranges = sorted([u for u in in0.ranges if u not in in1.ranges], key=lambda x: -x.arg[0])
          in1_ranges = sorted([u for u in in1.ranges if u not in in0.ranges], key=lambda x: -x.arg[0])
          red_ranges = sorted(reduceop.src[1:], key=lambda x: -x.arg[0])
          if DEBUG >= 3:
            print(f"TC({axis}): {[(x.arg[0],x.vmax+1) for x in in0_ranges]}",
                  f"{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}")
          if not len(in0_ranges) or not len(in1_ranges) or not len(red_ranges): continue

          # pick ranges
          # NOTE: why are in1 and in0 switched?
          axis_choices = list(itertools.product(in1_ranges, in0_ranges, red_ranges))
          if not (axis < len(axis_choices)): continue
          axes = list(axis_choices[axis])

          # do optimizations and save the ranges
          try:
            for i,a in enumerate(axes):
              idx = self.rngs.index(a)
              if (a.vmax+1) % tc.dims[i] != 0:
                if opt_level < 2: raise KernelOptError("tc padding requires opt_level >= 2")
                # apply_opt should return the updated range?
                self.apply_opt(Opt(OptOps.PADTO, idx, tc.dims[i]), append_opt=False) # PADTO might fail
                axes[i] = self.rngs[idx]
          except KernelOptError: continue

          # we create the warp as a whole thing, in case some of these ranges are moved/removed later
          warp = UOp.range(tc.threads, -1, AxisType.WARP)
          ne: list[UOp] = []
          for opt in tc.opts:
            if opt[0] == "l":
              # each "l" step consumes one factor of 2 from the warp index
              axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.LOCAL, input_new_rng=warp%2)
              warp //= 2
            elif opt[0] == "u":
              axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.UPCAST)
            else: raise RuntimeError(f"unsupported opt {opt[0]} in tensor cores")
            ne.append(new_range)

          for _, amt in tc.get_reduce_axes():
            axes[2], new_range = self.shift_to(axes[2], amt, AxisType.UNROLL)
            ne.append(new_range)

          if use_tensor_cores != 2:
            # fix the srcs
            reduceop = [x for x in self.ast.toposort() if x.op is Ops.REDUCE][0]
            tne = [x.replace(tag=1) for x in ne]
            ret = reduceop.substitute(dict(zip(ne, tne)))
            srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
            srcs = [x.substitute(dict(zip(tne, [ne[i] for i in argsort(p)]))) for x,p in zip(srcs, tc.permutes_for_shape_str(tc.base_shape_str()))]

            # get reduce/upcast axes for the tensor cores
            tc_reduce_axes = self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))])
            base_upcast_axes = tuple([(s,2) for s in self.shape_str_to_axis(tc.base_upcast_axes())])
            tc_upcast_axes = tuple([base_upcast_axes[:int(math.log2(tc.elements_per_thread[i]))] for i in range(3)])

            # axes to range number (was done in lowerer)
            tc_upcast_axes = tuple([tuple([(self.rngs[a].arg[0], sz) for a,sz in v]) for v in tc_upcast_axes])
            tc_reduce_axes = tuple([self.rngs[a].arg[0] for a in tc_reduce_axes])

            # construct the op
            # TODO: remove tc_upcast_axes from the arg
            # do the reduce_axes always disappear? i think they don't
            # they need to be moved into the WMMA srcs
            wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, tc.threads, tc_upcast_axes, ()) #, tc_reduce_axes)
            wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=(
              UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0], tag=1),
              UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1], tag=1),
              UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg, tag=1)
            tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2], tag=1)

            # preserve extra reduces
            reduce_ranges = [x for x in UOp.sink(*reduceop.src[1:]).toposort() if x.op is Ops.RANGE and x.arg[0] not in tc_reduce_axes]
            if len(reduce_ranges): tc_uop = UOp(Ops.REDUCE, tc_uop.dtype, (tc_uop,)+tuple(reduce_ranges), Ops.ADD)
            self.ast = self.ast.substitute({reduceop: tc_uop})
          return axes
    return None

  # helpers for hand_coded_optimizations
  @property
  def reduceop(self) -> UOp|None:
    # the first REDUCE in the ast, re-wrapped as a REDUCE_AXIS whose arg is (reduce alu op, ())
    red = [x for x in self.ast.parents if x.op is Ops.REDUCE]
    if not len(red): return None
    return UOp(Ops.REDUCE_AXIS, red[0].dtype, red[0].src, (red[0].arg, ()))
  @property
  def bufs(self) -> list[UOp]: return [x for x in self.ast.toposort() if x.op is Ops.INDEX][::-1]
  @property
  def output_shape(self):
    # full_shape with every reduce-like axis (REDUCE/UNROLL/GROUP_REDUCE) collapsed to 1
    return [s if at not in {AxisType.REDUCE, AxisType.UNROLL, AxisType.GROUP_REDUCE} else 1 for s,at in zip(self.full_shape, self.axis_types)]
  @property
  def upcasted(self) -> int: return len(self.axes_of(AxisType.UPCAST, AxisType.UNROLL))
  @property
  def group_for_reduces(self) -> int: return len(self.axes_of(AxisType.GROUP_REDUCE))

def bufs_from_ast(ast:UOp, dname:str) -> list[Buffer]:
  # one Buffer on device `dname` per DEFINE_GLOBAL, ordered by the DEFINE_GLOBAL arg (image dtypes are kept whole)
  glbls = sorted([x for x in ast.parents if x.op is Ops.DEFINE_GLOBAL], key=lambda x: x.arg)
  return [Buffer(dname, x.ptrdtype.size, x.dtype.base if not isinstance(x.dtype, ImageDType) else x.dtype) for x in glbls]

def apply_opts(ctx:Renderer, ast:UOp):
  # rewrite rule for a kernel SINK: optimize it once (results come back with tag set, so they are skipped next time)
  # priority: explicit opts_to_apply > BEAM search > hand-coded heuristics (unless NOOPT or opts were already applied)
  if ast.tag is not None: return None
  k = Scheduler(ast, ctx)
  k.convert_loop_to_global()
  if ast.arg is not None and ast.arg.opts_to_apply is not None:
    for opt in ast.arg.opts_to_apply: k.apply_opt(opt)
  elif BEAM >= 1:
    from tinygrad.codegen.opt.search import beam_search
    rawbufs = bufs_from_ast(ast, ctx.device)
    k = beam_search(k, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
  elif not NOOPT and (ast.arg is None or ast.arg.applied_opts == ()):
    from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
    # NOTE: hand_coded_optimizations doesn't support multiblock opts yet
    if all(len(u.src) == 1 for u in ast.parents if u.op is Ops.LOAD):
      k = hand_coded_optimizations(k)
  return k.get_optimized_ast(name_override=ast.arg.name if ast.arg is not None and ast.arg.name != "test" else None)

pm_postrange_opt = PatternMatcher([
  (UPat(Ops.SINK, name="ast"), apply_opts),
])