import itertools
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError
from tinygrad.helpers import getenv, DEBUG, prod, NOLOCALS, TC_OPT, TC_SELECT, USE_TC, AMX
from tinygrad.dtype import ImageDType
from tinygrad.uop.ops import Ops, resolve, AxisType
from tinygrad.codegen.opt.postrange import Scheduler

def hand_coded_optimizations(k:Scheduler) -> Scheduler:
  # first try the tensor cores
  """ Attempts to apply a tensor core optimization to the kernel. If one matches and applies cleanly, the tensor-core-optimized
  kernel is returned immediately; otherwise the hand-coded heuristics below are used instead.
  Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).

  Keyword arguments:
  use_tensor_cores -- controls how tensor cores are applied (default 1)
    0: will disable any tensor core matching
    1: enable tensor cores
    2: apply tensor core shape but don't use UOp.WMMA
  extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
  tc_select -- specifies which tensor core(s) to use for optimization (default -1)
    -1: iterates through all available tensor cores in order and uses the first one that matches the requirements (dims and dtypes)
    [0-N]: uses only the n'th tensor core available; useful for search
  tc_opt -- controls which kinds of kernels may be eligible for tensor core application (default 2 during BEAM, 0 otherwise)
    0: applies to only kernels with a single reduce axis and direct Ops.LOAD into Ops.MUL
    1: allows kernels with multiple reduce axes and also multiplication of Ops.CAST'd buffers
    2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by padding those axes as needed
  """
  # NOTE: unless TC_OPT is > 0, we only trigger tensor cores if there's only one reduce axis
  if USE_TC > 0 and (len(k.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE)) == 1 or (TC_OPT.value >= 1)):
    good_tc_opt = False
    try: # check TC first and apply hand-coded opts if successful
      tk = k.copy()
      rngs = tk.apply_opt(Opt(OptOps.TC, 0, (TC_SELECT.value, TC_OPT.value, USE_TC.value)))
      good_tc_opt = True
    except KernelOptError:
      pass
    if good_tc_opt:
      # skip hand-coded TC opts if AMX, upcasting will make kernel slower
      if rngs is not None and not AMX:
        for tc_dim in [1,0]: # attempt to upcast M and N
          szs = [sz for sz in [5,4,3,2] if rngs[tc_dim].src[0].divides(sz) is not None]
          if szs:
            # set it to the replaced range
            rngs[tc_dim] = tk.apply_opt(Opt(OptOps.UPCAST, tk.rngs.index(rngs[tc_dim]), szs[0]))[0]
        if (szs := [sz for sz in [4,2] if rngs[0].src[0].divides(sz) is not None]): # attempt to local N
          tk.apply_opt(Opt(OptOps.LOCAL, tk.rngs.index(rngs[0]), szs[0]))
      return tk
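  # If the tensor core path does not apply, the rest of this function falls through a series of
  # standalone heuristics: a matvec special case, reduce grouping (GROUPTOP), float4 image upcasts,
  # masked-axis upcasts, heuristic upcasts, reduce unrolling, local groups, and threading.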

  # make a copy so it does not mutate the input
  k = k.copy()

  # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
  MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
  if k.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
      k.reduceop is not None and k.reduceop.arg[0] is Ops.ADD and len(k.full_shape) >= 2 and k.opts.has_shared and \
      (mulop:=k.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
    idx0, idx1 = mulop.src[0].src[0].src[1].get_idx(), mulop.src[1].src[0].src[1].get_idx()
    first_reduce_rng = k.ranges_of(AxisType.REDUCE)[0]
    if any(u is first_reduce_rng for u in idx0.split_uop(Ops.ADD)) and all(r in idx1.ranges for r in idx0.ranges):
      for global_idx in k.axes_of(AxisType.GLOBAL):
        if first_reduce_rng.src[0].divides(MV_THREADS_PER_ROW) is not None and k.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
          if DEBUG >= 3:
            print(f"MATVEC: {k.full_shape=} {first_reduce_rng.render()} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}")
          if MV_THREADS_PER_ROW > 1: k.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
          if MV_BLOCKSIZE > 1: k.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
          if MV_ROWS_PER_THREAD > 1: k.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
          return k
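  # Illustrative outcome with the default env values above (MV_BLOCKSIZE=4, MV_THREADS_PER_ROW=8,
  # MV_ROWS_PER_THREAD=4): a matching matvec kernel gets Opt(OptOps.GROUP, 0, 8), then
  # Opt(OptOps.LOCAL, global_idx, 4) and Opt(OptOps.UPCAST, global_idx, 4) before returning.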

  # are we grouping? (requires local shape support)
  if resolve(prod(k.output_shape[i] for i in k.upcastable_dims) <= 2048, False):
    for sz in [16]:
      try:
        k.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
        break
      except KernelOptError: pass
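  # Note: only a single group size (16) is tried, and only when the upcastable output is small
  # (at most 2048 elements); a KernelOptError from the backend simply skips grouping.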

  # upcast float4 images
  for buf_index,buf in enumerate(k.bufs):
    if isinstance(buf.src[0].dtype, ImageDType):
      # part of real_strides
      unit_stride_axes_mul_4 = [k.rngs.index(c) for c in k.bufs[buf_index].src[1].get_idx().split_uop(Ops.ADD) if
                                c.op is Ops.RANGE and (c.vmax+1)%4 == 0]
      if len(unit_stride_axes_mul_4):
        if (axis:=unit_stride_axes_mul_4[0]) in k.upcastable_dims:
          k.apply_opt(Opt(OptOps.UPCAST, axis, 4))
        elif axis in k.unrollable_dims:
          k.apply_opt(Opt(OptOps.UNROLL, k.unrollable_dims.index(axis), 4))
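  # The hard-coded factor of 4 matches the 4-wide (float4) access pattern image dtypes are loaded
  # with; an axis is only considered here if its range length is already a multiple of 4.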

  # no more opt if we are grouping
  if k.group_for_reduces: return k

  # **** below this line needs to be optional and benchmarked ****

  # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
  to_upcast: list[int] = []
  # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
  for axis in k.upcastable_dims:
    # for Schedule, we check if the range is used in INDEX gates or WHERE gates
    is_masked = any(any(o is k.rngs[axis] for o in u.src[0].parents) for u in k.ast.parents if u.op is Ops.WHERE)
    if k.full_shape[axis] <= 7 and is_masked and prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
      if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
      to_upcast.append(axis)
  for axis in to_upcast[::-1]: k.apply_opt(Opt(OptOps.UPCAST, axis, 0))
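  # Example of the cap above: two masked axes of size 7 can both be upcast (7*7 == 49), but a third
  # masked axis of size >= 2 would push the product past 49 and is left alone.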

  # potentially do more upcasts of non reduce axes based on a heuristic
  is_dsp = k.opts is not None and k.opts.device == "DSP"
  upcasted_axis: set[int] = set()
  while resolve(prod(k.output_shape[i] for i in k.upcastable_dims) >= 1024):
    xb_choices = []
    # consider all upcastable axes with 3 or 4 upcast (128 on the DSP)
    for axis, upcast_amount in itertools.product(k.upcastable_dims, ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]):
      # if we haven't upcasted it, it divides evenly, and some buffer has stride 0 on this axis while having no stride 0 in the already-upcasted axes
      if axis in upcasted_axis or k.full_shape[axis]%upcast_amount != 0: continue
      rng = k.rngs[axis]
      if any(rng not in b.src[1].get_idx().parents and all(r2 in b.src[1].get_idx().parents
             for r2 in k.ranges_of(AxisType.UPCAST, AxisType.UNROLL)) for b in k.bufs):
        num_strides, sum_strides = 0, 0
        for b in k.bufs:
          idx = b.src[1].get_idx()
          if rng in idx.parents: num_strides += 1
          for c in idx.split_uop(Ops.ADD):
            if c is rng: sum_strides += 1
            if c.op is Ops.MUL and c.src[0] is rng and c.src[1].op is Ops.CONST: sum_strides += c.src[1].arg
            if c.op is Ops.MUL and c.src[1] is rng and c.src[0].op is Ops.CONST: sum_strides += c.src[0].arg
        xb_choices.append((num_strides, sum_strides, axis, upcast_amount))
    if xb_choices:
      xb_choices = sorted(xb_choices)
      if DEBUG >= 4: print(f"more upcast axis : {xb_choices}")
      k.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3]))
      upcasted_axis.add(xb_choices[0][2])
    else: break
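  # The sorted() above orders candidates by (num_strides, sum_strides), so the axis appearing in the
  # fewest buffer indices (and with the smallest summed strides) is upcast first; ties are then
  # broken by axis number and upcast amount.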

  # if last reduce dim is small(ish), loop unroll the reduce
  # NOTE: this can fail on multireduce with mismatching dimensions, this is okay
  try:
    upcast_size = prod(k.full_shape[a] for a in k.axes_of(AxisType.UPCAST, AxisType.UNROLL))
    if k.unrollable_dims and (upcast_size <= 4 or not k.axes_of(AxisType.UNROLL)) and (upcast_size < 64):
      if (s:=k.full_shape[k.unrollable_dims[-1]]) <= 32:
        k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, 0))
        # if it's small, upcast a second reduce dimension too
        if k.unrollable_dims and s <= 3 and k.full_shape[k.unrollable_dims[-1]] <= 3:
          k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, 0))
      else:
        for splits in [4]:
          if k.full_shape[axis:=k.unrollable_dims[-1]]%splits == 0:
            k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, splits))
            break
  except KernelOptError: pass
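  # Restating the thresholds above: a last unrollable dim of size <= 32 is fully unrolled (arg 0),
  # and if it was <= 3 a second reduce dim of size <= 3 is unrolled as well; larger dims are only
  # split by 4 when evenly divisible.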

  # if nothing at all is upcasted and it's easy to, do an upcast
  for splits in [4]:
    # TODO: somehow this never hits a reduce
    if not k.upcasted and k.upcastable_dims and k.full_shape[k.upcastable_dims[-1]] % splits == 0:
      k.apply_opt(Opt(OptOps.UPCAST, k.upcastable_dims[-1], splits))

  # **** local groups ****

  if k.opts.has_local:
    if NOLOCALS:
      k.apply_opt(Opt(OptOps.NOLOCALS))
    else:
      # prioritize making expand axes local
      local_axis_ranking = [(any(k.rngs[axis] not in b.src[1].get_idx().parents for b in k.bufs), axis) \
        for axis in k.axes_of(AxisType.GLOBAL, AxisType.LOOP) if k.rngs[axis].src[0].op is Ops.CONST]
      to_local: list[tuple[int, int]] = []
      for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
        local_size = prod(sz for _, sz in to_local)
        local_sz: int|None = next((x for x in ([32] * (axis == 0) + [16,8,4,3,2]) if k.full_shape[axis] % x == 0 and local_size * x <= 128), None)
        if local_sz is not None: to_local.append((axis, local_sz))
      deleted_shape = 0
      for axis, local_sz in sorted(to_local[:3]):
        axis = axis - deleted_shape
        will_delete_shape = local_sz == k.full_shape[axis]
        k.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
        if will_delete_shape: deleted_shape += 1
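  # Local sizes are capped at 128 threads total and only the first three ranked axes are made
  # local; axis 0 may additionally try a local size of 32 before the generic [16,8,4,3,2] ladder.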

  # **** threading ****

  if k.opts.has_threads and k.opts.global_max is not None:
    for threads in [32,16,12,8,6,5,4,3,2]:
      # Skip if too many threads. Heuristic: use about 128K ops per thread
      if threads > k.opts.global_max[0] or resolve(prod(k.full_shape) // (128 << 10) < threads): continue
      for axis in k.axes_of(AxisType.LOOP):
        if k.full_shape[axis] % threads == 0:
          k.apply_opt(Opt(OptOps.THREAD, axis, threads))
          break
      if k.applied_opts and k.applied_opts[-1].op is OptOps.THREAD: break
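  # Worked example of the ~128K-ops-per-thread heuristic (assuming a concrete shape): with
  # prod(full_shape) == 1_048_576, 1_048_576 // 131_072 == 8, so thread counts above 8 are skipped
  # and the first accepted candidate from the list is 8 (provided a LOOP axis divides by it).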

  return k