from __future__ import annotations
import itertools, functools, math
from dataclasses import dataclass
from collections import defaultdict
from typing import Optional, cast, Final, Callable, Sequence

from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, track_rewrites, view_left, print_uops
from tinygrad.ops import PatternMatcher, UPat
from tinygrad.spec import type_verify, shape_spec
from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore, ProgramSpec, Opt, OptOps
from tinygrad.dtype import ImageDType
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap, ContextVar
from tinygrad.helpers import DEBUG, TC_SELECT, TC_OPT, USE_TC, AMX, CAPTURE_PROCESS_REPLAY
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import strides_for_shape
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.devectorizer import full_graph_rewrite
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index, get_contraction

class KernelOptError(Exception): pass

def check(cond:bool, msg:str=""):
  if not cond: raise KernelOptError(msg)

@dataclass
class TensorCoreOptions:
  axes: tuple[int, ...] # the location of the original N and M axes if still in the shape
  axes_exist: tuple[bool, ...] # true if the original N and M axes are still in the shape
  axis_pads: tuple[tuple[int, int], ...]
  def fix_axes(self, removed_axis:int): # adjust the TC axes if necessary when a dimension is removed
    axes, axes_exist = list(self.axes), list(self.axes_exist)
    for tc_dim in [i for i in range(2) if axes_exist[i]]:
      if removed_axis < axes[tc_dim]: axes[tc_dim] -= 1
      elif removed_axis == axes[tc_dim]: axes_exist[tc_dim] = False
    self.axes, self.axes_exist = tuple(axes), tuple(axes_exist)
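
# Illustrative sketch (editorial, not part of the original source; values hypothetical): how
# fix_axes tracks the N/M axes as dimensions are removed by simplify_ones.
#   opts = TensorCoreOptions(axes=(1, 3), axes_exist=(True, True), axis_pads=())
#   opts.fix_axes(0)   # a dim before both axes was removed -> axes become (0, 2)
#   opts.fix_axes(2)   # the dim that held M was removed    -> axes_exist becomes (True, False)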

class Kernel:
  def __init__(self, ast:UOp, opts:Optional[Renderer]=None):
    assert ast.op is Ops.SINK, ast.op
    self.ast = ast

    self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
    # verify AST matches the spec
    if __debug__: type_verify(list(self.ast.toposort), shape_spec)

    self.reduceops = [x for x in self.ast.toposort if x.op is Ops.REDUCE_AXIS]

    self.vars: list[Variable] = self.ast.variables()
    # NOTE: this requires a specific order with the [::-1], this is likely a bug
    self.bufs: list[UOp] = [x for x in self.ast.toposort if x.op in GroupOp.Buffer][::-1]

    # get earlybufs, before any reduceops
    earlybufs: list[UOp] = [x for reduceop in self.reduceops for x in reduceop.src[0].toposort if x.op in GroupOp.Buffer]
    self.full_buf_index: int = self.bufs.index(earlybufs[0]) if earlybufs else 0
    # NOTE: full_shape can be wrong if there's a tree of reduces

    # create new shapetrackers inside this kernel, we will permute them
    self.sts: list[ShapeTracker] = [x.st_arg for x in self.bufs]

    # add the shapetrackers for each reduce
    # we use this to track which axes are reduced in each reduce
    for x in self.reduceops:
      self.sts.append(unwrap(x.st))
      self.sts.append(unwrap(x.src[0].st))

    # move all reduce axes to the end
    reduce = list(enumerate(zip(self.full_shape, self.output_shape)))
    permute = tuple([i for i,(s,n) in reduce if not resolve(s != n)] + [i for i,(s,n) in reduce if resolve(s != n)])
    self.reshape_and_permute(None, permute)

    # parameters for optimization
    self.applied_opts: list[Opt] = []
    self.group_for_reduces: int = 0
    self.upcasted: int = 0
    self.local_dims: int = 0
    self.tensor_core: Optional[TensorCore] = None
    self.tensor_core_opts: Optional[TensorCoreOptions] = None
    self.use_tensor_cores: int = 0
    self.dont_use_locals: bool = False
    self.lds: list[bool] = [False] * len(self.bufs)

    # group simplifies
    self.simplify_ones()
    self.simplify_merge_adjacent()

  def copy(self):
    ret = type(self).__new__(type(self))

    # base linearizer params
    ret.opts, ret.ast = self.opts, self.ast

    # things downstream of the AST
    ret.reduceops, ret.vars, ret.bufs, ret.full_buf_index = self.reduceops, self.vars, self.bufs, self.full_buf_index
    ret.sts = self.sts[:len(ret.bufs)+len(ret.reduceops)*2] # NOTE: must redo the local buffers with TC in beam

    # parameters for optimizations
    ret.applied_opts, ret.group_for_reduces, ret.upcasted, ret.local_dims, ret.dont_use_locals, ret.lds = \
      self.applied_opts[:], self.group_for_reduces, self.upcasted, self.local_dims, self.dont_use_locals, self.lds
    ret.tensor_core, ret.tensor_core_opts, ret.use_tensor_cores = self.tensor_core, self.tensor_core_opts, self.use_tensor_cores

    return ret

  @property
  def membufs(self) -> list[UOp]: return dedup([x.src[0] for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])

  def upcasted_axis(self, i:int) -> list[tuple[int, Optional[sint], bool]]:
    upcasted_shape, upcasted_stride = self.sts[i].shape[self.first_upcast:], self.sts[i].real_strides()[self.first_upcast:]
    assert all_int(upcasted_shape), f"cannot upcast a symbolic amount {upcasted_shape=}"
    return list(zip(upcasted_shape, upcasted_stride,
                    [x!=y for x,y in zip(self.sts[0].shape[self.first_upcast:], self.full_shape[self.first_upcast:])]))

  @property
  def first_reduce(self) -> int:
    return [resolve(x!=y) for x,y in zip(self.sts[0].shape[:self.first_upcast]+(0,), self.full_shape[:self.first_upcast]+(1,))].index(True)

  @property
  def first_upcast(self) -> int: return self.shape_len-self.upcasted

  @property
  def reduceop(self) -> UOp|None: return self.reduceops[0] if len(self.reduceops) > 0 else None

  @property
  def output_shape(self) -> tuple[sint, ...]: return self.sts[0].shape

  @property
  def full_shape(self) -> tuple[sint, ...]: return self.sts[self.full_buf_index].shape

  @property
  def full_unupcasted_shape(self) -> tuple[sint, ...]: return self.full_shape[:self.first_upcast]

  @property
  def shape_len(self) -> int: return len(self.sts[0].shape)

  @property
  def global_dims(self) -> int: return self.first_reduce-self.local_dims

  # there are eight chunks of the shape
  # blue -- global dims
  # cyan -- local dims (warp ones first)
  # *** self.first_reduce
  # green -- reduce-local dims
  # red -- reduce loops
  # *** self.upcasted
  # purple -- reduce upcasted
  # yellow -- normal upcasted dimensions
  def colors(self) -> list[str]:
    # first non local non reduce dims are global (blue)
    colors = ["blue"] * self.global_dims if not self.dont_use_locals else ["BLUE"] * self.global_dims
    # after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan)
    colors += ["cyan"] * self.local_dims
    # between first_reduce and first_reduce + group_for_reduces, they are late upcasted (green)
    colors += ["green"] * self.group_for_reduces
    # between first_reduce + group_for_reduces and upcasted, they are reduce (red)
    colors += ["red"] * (self.first_upcast - (self.first_reduce + self.group_for_reduces))
    # upcasted dimensions are reduce (magenta) or normal (yellow)
    colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.first_upcast, self.shape_len)]
    assert len(colors) == self.shape_len, "colors size mismatch"
    return colors

  def colored_shape(self, pad:Optional[int]=None, dense=False) -> str:
    shape_strs = [(s if dense else f"{s:4d}") if isinstance(s, int) else s.render() for s in self.full_shape]
    ret = ' '.join(colored(s, color) for s,color in zip(shape_strs, self.colors()))
    if pad: ret += ' '*(pad-ansilen(ret))
    return ret
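
  # Illustrative sketch (editorial, not part of the original source; values hypothetical): for a
  # GEMM-like kernel with full_shape (64, 64, 256) and output_shape (64, 64, 1), first_reduce is 2
  # and colors() returns ["blue", "blue", "red"]; on a renderer with locals, applying
  # Opt(OptOps.LOCAL, 1, 16) reshapes to (64, 4, 16, 256) and colors() becomes
  # ["blue", "blue", "cyan", "red"].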

  # ******************** base simplifiers ********************

  # apply reshape and permute to all shapetrackers
  def reshape_and_permute(self, new_shape_fxn:Optional[Callable[[tuple[sint, ...]], Sequence[sint]]], axis:Optional[Sequence[int]]):
    def reshape(st:ShapeTracker): return st.reshape(tuple(new_shape_fxn(st.shape))) if new_shape_fxn is not None else st
    def permute(st:ShapeTracker): return st.permute(tuple(axis)) if axis is not None else st
    self.sts = [permute(reshape(st)) for st in self.sts]

  # drops the final dimension
  def upcast(self):
    check(self.full_shape[-1] != 1, "can't upcast a dimension with size 1")
    self.upcasted += 1

  # axis : the axis to pull from
  # amount : the amount to take
  # top : if you want to pull that amount from the top
  # insert_before : place to insert the new stuff
  def shift_to(self, axis, amount, top=False, insert_before=None):
    if insert_before is None: insert_before = self.shape_len
    move_axis = axis if top else axis+1
    if move_axis < insert_before: insert_before += 1
    self.reshape_and_permute(
      lambda x: x[0:axis] + (((amount, x[axis]//amount) if top else (x[axis]//amount, amount)) if x[axis] > 1 else (1,1)) + x[axis+1:],
      [i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis])
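
  # Illustrative sketch (editorial, not part of the original source; shapes hypothetical): with a
  # shape (8, 6, 32), shift_to(axis=2, amount=4) splits the last axis into (32//4, 4) and moves the
  # new size-4 axis to the end, giving (8, 6, 8, 4); shift_to(axis=0, amount=2, insert_before=1)
  # instead splits axis 0 into (4, 2) and keeps the size-2 axis right after it, giving (4, 2, 6, 32).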

  # ******************** complex simplifiers ********************

  def simplify_ones(self) -> bool:
    # remove places where the shape is all ones
    # TODO: this should be factored in to multi shape stride
    if self.shape_len == 0: return False
    all_ones = [s==1 for s in self.full_shape]
    self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce])
    self.upcasted -= sum(all_ones[self.first_upcast:]) # TODO: not necessary since upcasted axes can't be un-upcasted
    self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
    return any(all_ones)

  def simplify_merge_adjacent(self):
    if self.shape_len == 0: return
    shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]

    # if it's an image, insert fake strides such that this fusion doesn't happen across image axes
    if isinstance(self.membufs[0].dtype, ImageDType):
      base_shape = self.membufs[0].dtype.shape
      if shape_idx_groups := get_contraction(self.output_shape, base_shape):
        special_strides: tuple[sint, ...] = tuple()
        for i,g in enumerate(shape_idx_groups):
          shape_piece = tuple(self.output_shape[x] for x in g)
          assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}"
          special_strides += strides_for_shape(shape_piece)
        # adding the fake image shape
        shapes.append(self.output_shape)
        strides.append(special_strides)

    # merge dimensions if we can, multi _merge_dims
    # NOTE: this does not always preserve the reduce dimension
    # TODO: move this into shapetracker, with tests!
    # TODO: how does this work with multi-reduce?
    rets = [[(s[0], st[0])] for s,st in zip(shapes, strides)]
    for i in range(1, len(shapes[0])):
      can_merge = []
      for s,st,ret in zip(shapes, strides, rets):
        # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
        si, sti, last_st = s[i], st[i], ret[-1][1]
        can_merge.append((sti is not None) and ((sti != 0 and last_st == si*sti) or (sti == 0 and last_st == 0)))
      # more can merge than this
      mergeable = all(can_merge) and i != self.first_reduce
      for j,(s,st) in enumerate(zip(shapes, strides)):
        if mergeable: rets[j][-1] = (rets[j][-1][0] * s[i], st[i])
        else: rets[j].append((s[i], st[i]))

    # do the reshapes
    for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
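
  # Illustrative sketch (editorial, not part of the original source; values hypothetical): two
  # adjacent dims merge only when every buffer views them contiguously. With shape (4, 8) and
  # strides (8, 1), last_st == si*sti (8 == 8*1) holds for axis 1, so the pair collapses into a
  # single dim of 32 with stride 1; if any buffer had strides like (1, 4), the merge is skipped.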

  # ******************** high level optimizers ********************

  def _create_tc_opts(self, reduceop:UOp, tc:TensorCore, axis:int, opt_level:int) -> Optional[TensorCoreOptions]:
    has_cast = tc.dtype_in != tc.dtype_out
    if has_cast and not (reduceop.src[0].op is Ops.CAST and reduceop.src[0].dtype == tc.dtype_out): return None

    mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
    if mul_op.op is not Ops.MUL: return None

    def buf_index(src:UOp) -> Optional[int]:
      # TODO: apply tc even if the sources are not from LOAD
      if src.op is Ops.LOAD and src.dtype == tc.dtype_in: return self.bufs.index(src)
      try:
        if opt_level >= 1 and src.op is Ops.CAST and src.dtype == tc.dtype_in: return self.bufs.index(src.src[0])
      except ValueError: return None
      return None
    if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None

    buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
    axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
    axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
    if not (axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): return None

    axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
    if not (axis < len(axis_choices)): return None

    s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
    axis_pads = tuple((x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if resolve(self.full_shape[x]%tc.dims[i] != 0))
    if axis_pads and (opt_level < 2): return None
    if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
    return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads)

  def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> bool:
    if use_tensor_cores and self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD:
      tensor_cores = self.opts.tensor_cores if tc_select == -1 else [self.opts.tensor_cores[tc_select]]
      for tc in tensor_cores:
        tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops]
        # can only fuse reduces with the same tc options
        assert all_same(tensor_core_opts)
        if tensor_core_opts[0] is None: continue
        self.tensor_core_opts = tc_opts = tensor_core_opts[0]

        # attempt to pad the tensor axes that require it
        try:
          for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
        except KernelOptError: continue
        # tensor core -- unroll the reduce dim (K), upcast and local the inner and outer dims (N, M)
        for dim, amt in tc.get_reduce_axes(): self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, amt), append_opt=False)
        for opt in tc.opts: self.apply_opt(Opt({"u":OptOps.UPCAST, "l":OptOps.LOCAL}[opt[0]], tc_opts.axes[int(opt[1])], 2), append_opt=False)
        self.tensor_core = tc
        self.use_tensor_cores = use_tensor_cores # TC=2 will do the shape ops without the WMMA
        return True
    return False

  def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[list[Opt]]=None, axis:int=0, tc_select:Optional[int]=None,
                         tc_opt:Optional[int]=None) -> bool:
    """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
    Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).

    Keyword arguments:
    use_tensor_cores -- controls how tensor cores are applied (default 1)
      0: will disable any tensor core matching
      1: enable tensor cores
      2: apply tensor core shape but don't use UOp.WMMA
    extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
    tc_select -- specifies which tensor core(s) to use for optimization (default -1)
      -1: iterates through all available tensor cores in order and uses the first one that matches the requirements (dims and dtypes)
      [0-N]: uses only the n'th tensor core available; useful for search
    tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
      0: applies to only kernels with a single reduce axis and direct Ops.LOAD into Ops.MUL
      1: allows kernels with multiple reduce axes and also multiplication of Ops.CAST'd buffers
      2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by padding those axes as needed
    """
    if tc_select is None: tc_select = TC_SELECT.value
    if tc_opt is None: tc_opt = TC_OPT.value
    if not self.opts.tensor_cores and use_tensor_cores != 2: return False
    try: # check TC first and apply hand-coded opts if successful
      self.apply_opt(Opt(OptOps.TC, axis, (tc_select, tc_opt)))

      if (tc_opts:=self.tensor_core_opts) is not None:
        if extra_opts is not None:
          for opt in extra_opts: self.apply_opt(opt)
        else:
          if AMX: return True # skip hand-coded TC opts if AMX, upcasting will make kernel slower
          # hand-coded TC opts
          for tc_dim in [tc_dim for tc_dim in [1,0] if tc_opts.axes_exist[tc_dim]]: # attempt to upcast M and N
            szs = [sz for sz in [5,4,3,2] if self.full_shape[tc_opts.axes[tc_dim]] % sz == 0]
            if szs: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[tc_dim], szs[0]))

          if tc_opts.axes_exist[0] and (szs := [sz for sz in [4,2] if self.full_shape[tc_opts.axes[0]] % sz == 0]): # attempt to local N
            self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], szs[0]))
      return True
    except KernelOptError:
      return False
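
  # Illustrative sketch (editorial, not part of the original source): typical driver-side usage,
  # assuming `ast` is a SINK UOp for a matmul and the device name is hypothetical.
  #   k = Kernel(ast, opts=Device["CUDA"].renderer)
  #   if not k.apply_tensor_cores(use_tensor_cores=USE_TC.value, tc_select=-1, tc_opt=2):
  #     k.hand_coded_optimizations()   # fall back to the generic heuristics
  #   prg = k.to_program()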

  def real_axis(self, opt:Opt):
    if opt.axis is None: return -1
    if opt.op is OptOps.UNROLL: return self.first_reduce+opt.axis
    if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.first_reduce+self.group_for_reduces+opt.axis
    return opt.axis
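
  # Illustrative sketch (editorial, not part of the original source): UNROLL and GROUP axes are
  # counted relative to the reduce dims, so with first_reduce == 2,
  # real_axis(Opt(OptOps.UNROLL, 0, 4)) is 2, while real_axis(Opt(OptOps.UPCAST, 0, 4)) stays 0.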

  def apply_opt(self, opt:Opt, append_opt:bool=True):
    if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}, "not using locals")

    if opt.op is OptOps.TC:
      check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
      check((use_tensor_cores:=USE_TC.value) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
      check(opt.axis is not None, "tensor core opts must have an axis")
      check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 2, "tensor core opts must have tc_select and tc_opt")
      check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.opts.tensor_cores), "tensor core opts must have valid tc_select")
      check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt")
      check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt), "no tensor core available")
      self.applied_opts.append(opt)
      return

    axis = self.real_axis(opt)
    if opt.op != OptOps.LDS: check(axis < len(self.full_shape), "invalid axis")

    if opt.op is OptOps.SWAP: amt = cast(int, opt.arg) # arg is an axis in the SWAPs
    elif opt.arg is not None:
      check(isinstance(opt.arg, int), "arg should be int")
      amt = arg if (arg:=cast(int, opt.arg)) != 0 else self.full_shape[axis]
      check(isinstance(amt, int) and amt != 1, f"shift/padto of {amt=}, 1 or symbolic amount is meaningless")
      if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, f"no longer valid shift {self.full_shape[axis]=}, {amt=}")
    else: amt = -1

    if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
        (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
      acc_sz = self.reduceop.dtype.itemsize
      upcast_sz = prod([a for a,b in zip(self.full_shape[self.first_upcast:], self.sts[0].shape[self.first_upcast:]) if a == b])
      local_sz = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce+self.group_for_reduces])
      smem_sz = amt*acc_sz*upcast_sz*local_sz
      check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")

    if opt.op is OptOps.LOCAL: # cyan
      # NOTE: LLVM/CPU can use locals too, but they are treated the same as globals (still helpful for L1 cache)
      # it's disabled for now since it makes BEAM slow for little gain
      check(self.opts.has_local, "target does not support local")
      check(axis < self.global_dims, "local is for globals")
      self.shift_to(axis, amt, insert_before=self.first_reduce)
      self.local_dims += 1
    elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: # green
      check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
      check(self.first_reduce + self.group_for_reduces <= axis < self.first_upcast, "must be reduce axis to group")
      check(not self.tensor_core, "can't group with tensor cores")
      check(len(reduce_axes:=[i for r in self.reduceops for i in r.axis_arg]) == len(set(reduce_axes)), "can't group with parallel reduces")
      self.shift_to(axis, amt, top=(opt.op is OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces)
      self.group_for_reduces += 1
    elif opt.op is OptOps.UNROLL: # purple
      check(axis < self.first_upcast, "can't unroll an already upcasted axis")
      check(amt <= 32, "don't unroll more than 32")
      # TODO: fix upcast_count to put purples before yellows. broken because of METAL tensor cores
      #upcast_count = sum(x == y for x,y in zip(self.full_shape[-self.upcasted:], self.output_shape[-self.upcasted:])) if self.upcasted else 0
      #self.shift_to(axis, amt, insert_before=None if upcast_count == 0 else self.shape_len-upcast_count)
      if self.full_shape[axis] == amt and axis == self.first_reduce: self.local_dims += 1 # first_reduce will ++, so offset loss in simplify_ones
      if self.full_shape[axis] == amt and axis < self.first_reduce+self.group_for_reduces: self.group_for_reduces -= 1 # fully unrolling a GROUP
      self.shift_to(axis, amt, insert_before=None)
      self.upcast()
    elif opt.op is OptOps.UPCAST: # yellow
      check(axis < self.first_reduce, "upcast is for non-reduce")
      check(not (self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.get_local_axes())), "can't upcast TC locals")
      check((self.opts is not None and self.opts.device == "DSP") or amt <= 16, "don't upcast more than 16")
      self.shift_to(axis, amt, insert_before=None)
      self.upcast()
    elif opt.op is OptOps.NOLOCALS:
      check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
      check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
      self.dont_use_locals = True
    elif opt.op is OptOps.SWAP:
      check(axis < amt < self.global_dims, f"swap is only for globals with axis < amt, getting {amt=}, {axis=}, {self.global_dims=}")
      permute = list(range(self.shape_len))
      permute[axis], permute[amt] = permute[amt], permute[axis]
      self.reshape_and_permute(None, tuple(permute))
    elif opt.op is OptOps.PADTO:
      check(not self.vars, "does not work with symbolic shape")
      check(axis < self.first_upcast, "cannot pad upcasted")
      # ok to pad SUM if all parent ALU ops have f(0) = 0
      if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r, {}, cache={}), f"cannot pad {r}")
      padded = False
      for i,st in enumerate(self.sts):
        if (s:=st.shape[axis]) == 1: continue # reduced
        check(s > amt//4, f"pad adds more than quadruple the work {st.shape[axis]=} > {amt//4=}")
        if (ru := round_up(cast(int, s), amt) - s):
          # pad right seems to be faster
          self.sts[i] = st.pad(((0,0),) * axis + ((0,ru),) + ((0,0),) * (len(st.shape)-axis-1))
          padded = True
      check(padded, "nothing was padded")
    elif opt.op is OptOps.LDS:
      check(0 <= axis < len(self.bufs), f"invalid buffer {axis}")
      self.lds = self.lds[:axis] + [True] + self.lds[axis+1:]

    if append_opt: self.applied_opts.append(opt)
    if self.simplify_ones() and self.tensor_core_opts:
      self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones()
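
  # Illustrative sketch (editorial, not part of the original source): applying a hand-picked
  # schedule to a hypothetical kernel on a GPU-like renderer. Each Opt is validated by check()
  # and raises KernelOptError if it no longer fits the current shape.
  #   k = Kernel(ast)
  #   for opt in [Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)]:
  #     k.apply_opt(opt)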

  def required_optimizations(self) -> Kernel:
    if isinstance(self.membufs[0].dtype, ImageDType):
      unit_stride_axes_mul_4 = [i for i in self.sts[0].unit_stride_axes(ignore_valid=True) if self.sts[0].shape[i]%4 == 0]
      assert unit_stride_axes_mul_4, f"needs a unit stride axis in {self.bufs[0]}"
      if all(x < self.first_upcast for x in unit_stride_axes_mul_4): self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
    return self

  def hand_coded_optimizations(self) -> Kernel:
    self.required_optimizations()

    # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
    MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
    if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
        self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD and len(self.full_shape) >= 2 and self.opts.has_shared and \
        (mulop:=self.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
      st0, st1 = self.sts[self.bufs.index(mulop.src[0])], self.sts[self.bufs.index(mulop.src[1])]
      strides0, strides1 = st0.real_strides(), st1.real_strides()
      def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides))
      if strides0[self.first_reduce] == 1 and not (has_expanded_axis(st0.shape, strides0) and has_expanded_axis(st1.shape, strides1)):
        for global_idx in range(self.global_dims):
          if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
            if DEBUG >= 3:
              print(f"MATVEC: {self.full_shape=} {self.first_reduce=} {strides0=} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}")
            if MV_THREADS_PER_ROW > 1: self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
            if MV_BLOCKSIZE > 1: self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
            if MV_ROWS_PER_THREAD > 1: self.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
            return self

    if self.opts.has_local and self.opts.has_shared and all_int(self.sts[0].shape[:self.first_reduce]):
      # are we grouping? (requires local shape support)
      if not [x for x in self.sts[0].unit_stride_axes() if x >= self.first_upcast and self.sts[0].shape[x]%4 == 0] and \
          self.first_reduce <= 2 and self.first_reduce < self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
        # TODO: use 1024 if it's allowed in a smarter way
        for sz in ([256, 16] if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
          if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
            try: # may fail due to excessive smem usage
              self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
              break
            except KernelOptError: pass

    # upcast float4 images
    for buf_index,buf in enumerate(self.bufs):
      unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
      if buf.src[0].dtype.__class__ is ImageDType:
        #assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
        if len(unit_stride_axes_mul_4) and all(x < self.first_upcast for x in unit_stride_axes_mul_4):
          if unit_stride_axes_mul_4[0] < self.first_reduce:
            self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
          else:
            self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))

    # no more opt if we are grouping
    if self.group_for_reduces: return self

    # **** below this line need to be optional and benchmarked ****

    # TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx)
    # to trigger the above bug, remove prod(self.full_shape[self.first_upcast:]) from the below
    # expression and run test/test_ops.py with IMAGE=2
    # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
    # this can be made much smarter
    to_upcast: list[int] = []
    # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
    for axis in range(self.first_reduce):
      # we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
      # for now skip upcasting here if there is a symbolic axis
      if isinstance(self.full_shape[axis], int) and self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and \
          prod(self.full_shape[self.first_upcast:]) * prod(self.full_shape[j] for j in to_upcast) * self.full_shape[axis] <= 7 * 7:
        if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
        to_upcast.append(axis)
    for axis in to_upcast[::-1]: self.apply_opt(Opt(OptOps.UPCAST, axis, 0))

    # potentially do more upcasts of non reduce axes based on a heuristic
    is_dsp = self.opts is not None and self.opts.device == "DSP"
    upcasted_axis: set[int] = set()
    while resolve(prod(self.sts[0].shape[:self.first_reduce]) >= 1024):
      xb_choices = []
      # consider all the non reduce axes, and a 3 or 4 reduce. (128 on the DSP)
      for axis, upcast_amount in itertools.product(range(self.first_reduce), ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]):
        # if we haven't upcasted it, it's not symbolic, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
        if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)): # noqa: E501
          xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount)) # noqa: E501
      if xb_choices:
        xb_choices = sorted(xb_choices)
        if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}")
        self.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3]))
        upcasted_axis.add(xb_choices[0][2])
      else: break

    # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast.
    if self.first_reduce < self.first_upcast and (prod(self.full_shape[self.first_upcast:]) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64): # noqa: E501
      if isinstance(s:=self.full_unupcasted_shape[-1], int) and s <= 32: # NOTE: cannot loop unroll symbolic axis
        self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
        # if it's small, upcast a second reduce dimension too
        if self.first_reduce < self.first_upcast and s <= 3 and isinstance(s2:=self.full_unupcasted_shape[-1], int) and s2 <= 3:
          self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
      else:
        for splits in [4]:
          if self.full_unupcasted_shape[-1]%splits == 0:
            self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, splits))
            break

    # if nothing at all is upcasted and it's easy to, do an upcast
    # TODO: this is breaking the tests
    for splits in [4]:
      if self.upcasted == 0 and self.full_unupcasted_shape and self.full_unupcasted_shape[-1] % splits == 0:
        self.apply_opt(Opt(OptOps.UPCAST, len(self.full_unupcasted_shape)-1, splits))

    # **** local groups ****

    if self.opts.has_local:
      if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduces:
        self.apply_opt(Opt(OptOps.NOLOCALS))
      else:
        # prioritize making expand axes local
        local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))] # noqa: E501
        to_local: list[tuple[int, int]] = []
        for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
          local_size = prod(sz for _, sz in to_local)
          local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None) # noqa: E501
          if local_sz is not None: to_local.append((axis, local_sz))
        deleted_shape = 0
        for axis, local_sz in sorted(to_local[:3]):
          axis = axis - deleted_shape
          will_delete_shape = local_sz == self.full_shape[axis]
          self.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
          if will_delete_shape: deleted_shape += 1

    return self
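
  # Illustrative sketch (editorial, not part of the original source; outcome hypothetical): for a
  # matvec-like kernel with full_shape (1024, 4096) and a contiguous reduce axis, the MV path above
  # might leave
  #   k.applied_opts == [Opt(OptOps.GROUP, 0, 8), Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.UPCAST, 0, 4)]
  # and then return immediately, skipping the remaining heuristics.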

  # **** kernel outputs ****

  kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)
  @functools.cached_property
  def name(self) -> str:
    # kernel name (before late upcast)
    kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op is Ops.SINK or x.op in GroupOp.Buffer for x in self.ast.toposort) else "E")
    suffix = colored('_', 'BLACK').join([colored(x.render() if isinstance(x, UOp) else str(x), c) for x,c in zip(self.full_shape, self.colors())])
    name = kernel_type + (f"{len(self.ast.src)}" if len(self.ast.src) > 1 else "") + "_" + suffix

    # name the function something unique
    Kernel.kernel_cnt[(function_name := to_function_name(name))] += 1
    num = f"n{Kernel.kernel_cnt[function_name]-1}" if Kernel.kernel_cnt[function_name] > 1 else ""
    return name + colored(num, 'BLACK')

  def get_optimized_ast(self, name_override:Optional[str]=None) -> UOp:
    @functools.lru_cache(None)
    def fixup_ast(op:UOp) -> UOp:
      ret = op.replace(src=tuple(fixup_ast(x) for x in op.src))
      if op.op in GroupOp.Buffer and op in self.bufs:
        st_uop = self.sts[self.bufs.index(op)].to_uop()
        # NOTE: if CONST got masked after applying opts, we create a new VALID
        if op.op is Ops.CONST and any(v.mask is not None for v in unwrap(st_uop.st).views): return op.valid(unwrap(st_uop.st))
        # otherwise we just replace the VIEW source
        return ret.replace(src=(st_uop,)) if len(op.src) == 1 else ret.replace(src=(ret.src[0], st_uop, *ret.src[2:]))
      if op.op is Ops.SINK:
        return ret.replace(arg = KernelInfo(to_function_name(self.name) if name_override is None else name_override,
                                            self.local_dims, self.upcasted, self.dont_use_locals))
      if op.op is Ops.REDUCE_AXIS:
        reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2

        def reduced_axes(start, stop):
          return tuple(i for i in range(start, stop) if resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx + 1].shape[i]))
        axes = reduced_axes(self.first_reduce + self.group_for_reduces, self.shape_len)
        grouped_axes = reduced_axes(self.first_reduce, self.first_reduce + self.group_for_reduces)

        if (tc := self.tensor_core) and (self.use_tensor_cores == 1 or self.use_tensor_cores == 3):
          wd, tcd = self.global_dims, self.first_upcast
          def get_upcast_axes(buf): # upcast along non-zero dimensions of (tc_reduce + tc_upcast)
            upcast_axes = int(math.log2(tc.elements_per_thread[buf]))
            return tuple((tcd + len(tc.get_reduce_axes()) + len(tc.get_upcast_axes()) - (i+1), 2) for i in range(upcast_axes))
          def get_tc_swizzle_st(shape, local_perm, upcast_perm):
            offset = (tcd - (wd + len(local_perm)))
            permaxis = list(range(wd)) \
              + [wd + x + (offset if x >= len(local_perm) else 0) for x in local_perm] + list(range(wd + len(local_perm), tcd)) \
              + [wd + x + (offset if x >= len(local_perm) else 0) for x in upcast_perm] + list(range(tcd + len(upcast_perm), len(shape)))
            return ShapeTracker.from_shape(shape).permute(tuple(permaxis))

          srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
          for i, (src, swizzle) in enumerate(zip(srcs, tc.swizzle)):
            src_st = (src if src.op is Ops.LOAD else src.src[0]).st_arg
            if swizzle: srcs[i] = src.view(get_tc_swizzle_st(src_st.shape, *swizzle))

            if self.use_tensor_cores == 3: # for TC=3, emulate the warp addressing with locals
              local_shape = tuple(1 if st == 0 or i < wd or (i >= self.first_reduce and i < tcd) else src_st.shape[i] \
                for i,st in enumerate(src_st.real_strides()))
              st = store_st = ShapeTracker.from_shape(local_shape)
              local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(size=st.real_size(), local=True), (), f"temp{i}")
              if swizzle: store_st = get_tc_swizzle_st(store_st.shape, *swizzle)
              local_store = UOp.store(local_buffer, store_st.to_uop(), srcs[i])
              srcs[i] = UOp(Ops.LOAD, tc.dtype_in, (local_buffer, st.to_uop(), local_store))

          tc_reduce_axes = tuple(tcd + ax for ax, _ in tc.get_reduce_axes())
          if self.use_tensor_cores == 1: # real WMMA, use CONTRACT/UNROLL to get the vectorization right
            tc_upcast_axes = (get_upcast_axes(0), get_upcast_axes(1), get_upcast_axes(2))
            wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, tc.threads, tc_upcast_axes, tc_reduce_axes)
            wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=(
              UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0]),
              UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1]),
              UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg)
            tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2])

          else: # for TC=3 MUL/SUM instead of WMMA
            tc_uop = UOp(Ops.REDUCE_AXIS, tc.dtype_out, ((srcs[0] * srcs[1]).cast(tc.dtype_out),), (Ops.ADD, tc_reduce_axes))

          return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_axes)) if (new_axes := tuple(i for i in axes if i not in tc_reduce_axes)) else tc_uop

        ret = ret.replace(arg = (op.arg[0], axes))
        if self.group_for_reduces and grouped_axes:
          local_shape = (1,) * self.global_dims + self.full_shape[self.global_dims:self.global_dims+self.local_dims] + \
            tuple([self.full_shape[i] if self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i] else 1 \
              for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
            (1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
          st_uop = ShapeTracker.from_shape(local_shape).to_uop()
          local_size = st_uop.arg.real_size()
          local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)}")
          local_load = UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, ret)))
          grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
          if op is self.reduceops[-1]: return grouped_reduce
          st_uop = ShapeTracker.from_shape(tuple([1 if i in grouped_axes else a for i,a in enumerate(local_shape)])).to_uop()
          return UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, grouped_reduce)))

      return ret

    return graph_rewrite(fixup_ast(self.ast), view_left)

  def apply_lds(self, ast) -> UOp:
    def transform(ctx:tuple[Kernel, set[UOp]], global_access:UOp): return None

    return graph_rewrite(ast, PatternMatcher([(UPat((Ops.LOAD, Ops.STORE), name="global_access"), transform)]), ctx=(self, set()))
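
  # Note (editorial, not part of the original source): the PatternMatcher above visits every LOAD
  # and STORE in the ast; since transform currently returns None for each match, graph_rewrite
  # leaves them unchanged, so apply_lds is effectively a placeholder for the LDS optimization.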

  # **** this is the lowerer ****

  @track_rewrites()
  def linearize(self, name_override:Optional[str]=None, ast_transform:Optional[Callable]=None) -> Kernel:
    # display the AST
    if getenv("VIZ"): graph_rewrite(self.ast, PatternMatcher([]), name="View Base AST")

    modified_ast = self.get_optimized_ast(name_override)
    modified_ast = self.apply_lds(modified_ast)
    if ast_transform is not None: modified_ast = ast_transform(self, modified_ast)

    if DEBUG >= 3:
      print(self.name)
      if DEBUG >= 5: print(self.ast)
      for i,(buf,st) in enumerate([(buf,st) for buf,st in zip(self.bufs, self.sts) if buf.op not in {Ops.CONST, Ops.VALID}]):
        print(f"{i:2d}: {str(st.shape):25s} {str(buf.src[0].dtype).replace('dtypes.',''):20s} {str(st.real_strides()):30s}",
              str(st) if DEBUG >= 4 else "")
      print(self.applied_opts)
    if DEBUG >= 5: print(modified_ast)
    # verify AST matches the spec after applying opts
    if __debug__: type_verify(list(modified_ast.toposort))
    # TODO: sadly modified_ast doesn't pass the shape spec because of how group_for_reduces constructs UOps, there's probably a way to fix this
    #if __debug__: type_verify(list(modified_ast.toposort), shape_spec)

    self.uops:list[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts))
    if DEBUG >= 6: print_uops(self.uops)
    return self

  def to_program(self, name_override:Optional[str]=None, ast_transform:Optional[Callable]=None) -> ProgramSpec:
    self.linearize(name_override, ast_transform)
    assert self.uops[0].op is Ops.NAME, "first uop must be name"
    src = self.opts.render(self.uops)

    if CAPTURE_PROCESS_REPLAY:
      diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, self.uops[0].arg, ContextVar._cache, src))

    # group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
    # TODO: these max and min don't work on symbolic, and results are very wrong.
    mem_bytes = sum(max(x.src[0].dtype.itemsize * x.st_arg.real_size() for x in group)
      for _, group in itertools.groupby([x for x in self.ast.toposort if x.op in GroupOp.Buffer and x.src[0].op is Ops.DEFINE_GLOBAL],
                                        key=lambda x: (x.op, x.src[0].arg)))
    return ProgramSpec(self.name if not name_override else name_override, src, self.opts.device, self.ast, self.uops, self.applied_opts, mem_bytes,
                       global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
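
# Illustrative sketch (editorial, not part of the original source): end-to-end use of this class,
# assuming `ast` is an Ops.SINK UOp produced for the default device's renderer.
#   k = Kernel(ast)
#   k.hand_coded_optimizations()       # or k.apply_tensor_cores(...) / search-provided opts
#   prg = k.to_program()               # linearize to UOps, render source, build a ProgramSpec
#   print(k.colored_shape(), k.applied_opts)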