from typing import Any, Callable, cast
import functools, operator, itertools
from collections import defaultdict
from dataclasses import dataclass
from tinygrad.device import is_dtype_supported
from tinygrad.dtype import dtypes, ImageDType, PtrDType, promo_lattice, DType
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve, graph_rewrite, GroupOp, identity_element
from tinygrad.codegen.symbolic import split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
from tinygrad.helpers import getenv, flatten, AMX, prod, partition
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
from tinygrad.renderer import Renderer

# ***** image load valid simplification *****

def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
  if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.const_like(0)
  if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx, valid)

  # wait for it to be image indexed before running simplification
  if start_idx.dtype.count != 2: return None

  # can drop valid if idx is out of bound when valid is False
  drop_stmt = []
  for stmt in split_uop(valid, Ops.AND):
    try: X, is_upper_bound, c = parse_valid(stmt)
    except ValueError: return None

    # for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
    if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in split_uop(X, Ops.ADD)):
      testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), split_uop(X, Ops.ADD), idx)
      testidx = testidx.simplify()
      if testidx.gep(0).vmax < 0 or testidx.gep(1).vmax < 0:
        drop_stmt.append(stmt)
        continue

    # if X <= c, check if it's out of bound when X = c+1
    # if X >= c, check if it's out of bound when X = c-1
    test_value = c + 1 if is_upper_bound else c - 1
    for i,b in zip(idx.src, (buf.dtype.shape[1], buf.dtype.shape[0])):
      if i.is_increasing():
        rw = i.substitute({X:X.const_like(test_value)}).simplify()
        if rw.vmin >= b or rw.vmax < 0:
          drop_stmt.append(stmt)
          break

  if not drop_stmt and idx is start_idx: return None
  new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, Ops.AND) if s not in drop_stmt]) else None
  return buf.index(idx, new_valid)

def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
  if store_gate not in [gate.src[0] for gate in val.toposort() if gate.op is Ops.IF]: return None
  # remove the gate from the index
  return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val)

load_store_indexing = PatternMatcher([
  # simplify valid
  (UPat(Ops.AND, name="valid"), simplify_valid),
  # image load valid idx simplification
  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
  # index True is just Index
  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat(Ops.CONST, arg=True))), lambda buf,start_idx: buf.index(start_idx)),
  # delete_redundant_gates (after expand)
  (UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
    UPat.var("val"))), delete_redundant_gates),
])

# ***** load/store grouping *****

def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
  if getenv("UNSAFE_DISABLE_MASK", 0): mask = None
  # generate the individual indexes
  midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), mask.gep(i) if mask is not None else None) for i in range(vec.dtype.count)]),
    symbolic_flat+load_store_indexing, name=f"index_buf_{buf.arg}")
  # extract all the relevant offsets
  offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
  for i in range(vec.dtype.count):
    idx: Any = midx.src[i].src[1]
    if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
    elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
    elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
    else: root_src, arg = idx, 0
    if len(midx.src[i].src) == 3: root_src = (midx.src[i].src[2], root_src)
    offsets_rootsrc[root_src].setdefault(arg, []).append(i)

  # the buf.dtype is always a pointer
  ptrdtype = cast(PtrDType, buf.dtype)

  # then rewrite everything we can into groups
  ret = []
  idxs: list[int|None] = [None]*vec.dtype.count
  global_offset = 0
  for offsets in offsets_rootsrc.values():
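    # group runs of consecutive offsets so each run can become one wide access,
    # e.g. (illustrative) sorted offsets [0,1,2,4,5] -> grouped_offsets [[0,1,2],[4,5]]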
    grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
    for grp in grouped_offsets:
      # get the index offset for this element. using [0] is okay, because they are the same
      lidx = midx.src[offsets[grp[0]][0]]
      if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, local=ptrdtype.local))
      # set the idxs of the output
      for i,g in enumerate(grp):
        for oo in offsets[g]: idxs[oo] = global_offset+i
      # add this lidx to the CAT
      ret.append(lidx)
      global_offset += len(grp)
  assert None not in idxs, f"some idxs are missing {idxs}"
  # this base thing is for image, we want the CAT to be a normal pointer
  post_cat = UOp(Ops.PTRCAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret))
  return post_cat.gep(tuple(cast(list[int], idxs)))

def cat_after_store(cat:UOp, data:UOp):
  # TODO: this is written in many places
  offset = 0
  ret = []
  for s in cat.src:
    ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count)))))
    offset += s.dtype.count
  return UOp.sink(ret[0], *ret[1:])

def gep_on_store(gep:UOp, st:UOp):
  # NOTE: we need to invert the gep here, but it may be an expanding gep
  # fake argsort. TODO: handle duplicates
  a = {}
  for i,x in enumerate(gep.arg): a[x] = i
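  # e.g. (illustrative) gep.arg == (2,0,1) gives new_arg == (1,2,0), the inverse permutation, so the store sees the data in source order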
  new_arg = tuple(x[1] for x in sorted(a.items()))
  return UOp(Ops.STORE, src=(gep.src[0], st.gep(new_arg)))

load_store_folding = PatternMatcher([
  (UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL), name="buf")), UPat.var("vec"))), expand_index),
  (UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL), name="buf")), UPat.var("vec"),
    UPat.var("mask"))), expand_index),
  # GEP after LOAD
  (UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True),
   lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
  # GEP on data of STORE
  (UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st"))), gep_on_store),
  # put PTRCAT after LOAD
  (UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True),
   lambda cat,ld: UOp(Ops.CAT, ld.dtype, tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
  # put PTRCAT after STORE
  (UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data"))), cat_after_store),
])

# ***** optional patterns *****

@functools.lru_cache(None)
def magicgu(vmax:int, d:int) -> tuple[int,int]:
  # calculate m,s such that x//d == (x*m) >> s for all 0 <= x <= vmax, d>0; adapted from Hacker's Delight, Chapter 10
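  # illustrative example: magicgu(255, 3) == (171, 9), i.e. x//3 == (x*171) >> 9 for every 0 <= x <= 255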
  nc = (vmax+1)//(d) * d - 1
  nbits = vmax.bit_length()
  for s in range(0, 2*nbits + 1):
    if 2**s > nc*(d - 1 - (2**s - 1) % d):
      m = (2**s + d - 1 - (2**s - 1) % d)//d
      return m, s
  assert False

def fast_idiv(ctx: Renderer|None, x: UOp, d: int) -> UOp|None:
  # idiv is truncated division, but arithmetic shift is floored division, so we can only do this for non-negative numbers!
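  # (for negative x the two disagree, e.g. -7//2 is -4 floored but -3 truncated, hence the vmin check below)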
  if x.vmin<0: return None
  sign = 1 if d > 0 else -1
  m,s = magicgu(vmax := min(x.vmax, dtypes.max(x.dtype)), abs(d))
  if m * vmax <= dtypes.max(x.dtype): return sign * ((x*m) >> s)
  # promo_lattice needs to return an unsigned type
  if ctx is not None and dtypes.is_int(next_dtype := promo_lattice[x.dtype][-1]) and is_dtype_supported(next_dtype, ctx.device):
    if m * vmax <= dtypes.max(next_dtype): return sign * ((x.cast(next_dtype)*m) >> s).cast(x.dtype)
  return None

powers_of_two = {2**i:i for i in range(64)}
@functools.cache
def get_late_rewrite_patterns(ops, force_transcendental=False):
  pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
    ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
  # rewrite SQRT to xpow 0.5
  if Ops.SQRT not in ops: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
  # rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
  if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
  # rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
  if Ops.SHL in ops: pat += [(UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << v if (v:=powers_of_two.get(c.arg, 0)) else None)]
  if Ops.SHR in ops:
    # no reason to check x>=0 for uints
    pat += [(UPat.var("x", dtypes.uints)//UPat.cvar("c"), lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) else None)]
    pat += [(UPat.var("x", dtypes.sints)//UPat.cvar("c"), lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) and resolve(x>=0,False) else None)]
    if not getenv("DISABLE_FAST_IDIV"):
      pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("d"), lambda ctx, x, d: fast_idiv(ctx, x, d.arg))]
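      # for non-negative x, x % d == x - d*(x//d), so MOD can reuse fast_idiv too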
      pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("d"), lambda ctx, x, d: x - d*f if (f:=fast_idiv(ctx, x, d.arg)) is not None else None)]
  if Ops.NEG in ops:
    pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
    if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))]
  if Ops.MULACC in ops: pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))]
  return PatternMatcher(pat)

# *** correct load/store ***

def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
  # this splits loads and stores into multiple chunks
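  # e.g. (illustrative) an 8-wide load on a backend that only folds to float4 may become two float4 loads CAT'd back together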
  # if there's only one element to load/store, no splitting needed
  if (sz:=ls.src[0].dtype.count) == 1: return None
  buf = idx.src[0]

  # determine fold lengths
  lengths = []
  must_divide = True
  if ctx is not None and ctx.device == "DSP":
    lengths = [128,64,32,16,8,4]
    must_divide = False
  elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
    pass
  elif isinstance(buf.dtype, ImageDType):
    lengths = [4]
  elif ctx is not None and ctx.supports_float4:
    # TODO: a better way to get this than ctx
    lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2])
  lengths.append(1) # worst case, it's not folded

  # filter fold lengths that don't divide
  if must_divide: lengths = [x for x in lengths if idx.src[1].divides(x) is not None]

  # split based on the fold lengths
  global_offset = 0
  ret = []
  ptrdtype = cast(PtrDType, buf.dtype)
  while global_offset < sz:
    # with 1 at the end of the lengths list, this will always hit
    for fold_length in lengths:
      if global_offset+fold_length > sz: continue
      lidx = buf.index(idx.src[1] + global_offset, idx.src[2] if len(idx.src) > 2 else None)
      if fold_length > 1: lidx = lidx.cast(ptrdtype.base.vec(fold_length).ptr(size=ptrdtype.size, local=ptrdtype.local))
      if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:]))
      else: ret.append(ls.replace(src=(lidx,)+ls.src[1:], dtype=ls.dtype.scalar().vec(fold_length)))
      global_offset += fold_length
      break

  # if it wasn't split, we return None. otherwise we CAT them
  return UOp(Ops.CAT, ls.dtype, tuple(ret)) if len(ret) > 1 else None

def image_fixup(ls:UOp):
  # normal image load or store, with the CAST from expand_index
  if ls.src[0].op is Ops.CAST and isinstance(image_dtype:=ls.src[0].src[0].dtype, ImageDType):
    assert ls.src[0].dtype.count == 4, "image must be casted to 4"
    idx = ls.src[0].src[0]
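    # unpack the flat index into 2D image coordinates: x = (i//4) % shape[1], y = i // (4*shape[1]), with 4 values per texel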
    oidx = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((idx.src[1] // 4) % image_dtype.shape[1], (idx.src[1] // (4*image_dtype.shape[1]))))
    idx = idx.replace(src=(idx.src[0], oidx)+idx.src[2:])
    return ls.replace(src=(idx,)+ls.src[1:])

  # this is an unprocessed image without a cast, aka unfoldable image load. this doesn't work for stores
  if isinstance(image_dtype:=ls.src[0].dtype, ImageDType) and ls.src[0].src[1].dtype != dtypes.int.vec(2):
    assert ls.op is Ops.LOAD, "if an image store isn't upcasted to 4, we can't store it"
    idx = ls.src[0]
    id4 = idx.src[1] % 4
    oidx = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((idx.src[1] // 4) % image_dtype.shape[1], (idx.src[1] // (4*image_dtype.shape[1]))))
    idx = idx.replace(src=(idx.src[0], oidx)+idx.src[2:])
    vec_load = ls.replace(dtype=ls.dtype.vec(4), src=(idx,)+ls.src[1:])
    return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), ls.const_like(float('nan')))

  return None

correct_load_store = PatternMatcher([
  # split LOAD/STORE
  (UPat((Ops.LOAD, Ops.STORE), src=(UPat(Ops.INDEX, name="idx").cast(),), name="ls", allow_any_len=True), split_load_store),
  # image indexing, including unfoldable images
  (UPat((Ops.LOAD, Ops.STORE), name="ls"), image_fixup),
])

# *** uop expander ***

# TODO: there's a lot shared with gep_through_wmma here
def no_vectorized_wmma(wmma:UOp):
  out_sz = prod(x[1] for x in wmma.arg[6][-1])
  if wmma.dtype.count == out_sz: return None
  tsrcs = []
  for s,sz in zip(wmma.src, wmma.arg[6]):
    ssz = prod(x[1] for x in sz)
    tsrcs.append([s.gep(tuple(range(grp, grp+ssz))) for grp in range(0, s.dtype.count, ssz)])
  wmmas = [UOp(Ops.WMMA, wmma.dtype.scalar().vec(out_sz), tsrc, wmma.arg) for tsrc in zip(*tsrcs)]
  wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
  return UOp(Ops.VECTORIZE, wmma.dtype, tuple(wmma_ex))

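# e.g. (illustrative) an ADD on float4 becomes VECTORIZE(a.gep(0)+b.gep(0), ..., a.gep(3)+b.gep(3))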
def no_vectorized_alu(alu:UOp):
  if alu.dtype.vcount == 1: return None
  alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
  return UOp(Ops.VECTORIZE, alu.dtype, alus)

def no_vectorized_acc(acc:UOp):
  if acc.dtype.count == 1: return None
  alus = tuple(UOp(acc.op, acc.dtype.scalar(),
    tuple(s.gep(i) if j == 0 else s for j,s in enumerate(acc.src)), acc.arg+(i,)) for i in range(acc.dtype.count))
  return UOp(Ops.VECTORIZE, acc.dtype, alus)

devectorize = PatternMatcher([
  # no ALU on vectorized dtypes
  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN), name="alu"), no_vectorized_alu),
  (UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
  (UPat(Ops.DEFINE_ACC, name="acc"), no_vectorized_acc),
])

pm_render = PatternMatcher([
  # for rendering, we use explicit VECTORIZE
  (UPat(Ops.CONST, name='c'),
   lambda c: UOp(Ops.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
  (UPat(Ops.VCONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
  (UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
  (UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
  (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
  # give any loads that are masked an alt value
  (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"),
   lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:]) if len(x.src) == 1 or x.src[1].op is Ops.CUSTOM else None),
  # gate any stores that aren't gated with ifs
  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store"),
   lambda store,idx: UOp(Ops.STORE, src=store.src+(UOp(Ops.IF, src=(idx.src[2],)),))),
])

# *** Ops.REDUCE -> Ops.DEFINE_ACC+Ops.ASSIGN ***

@dataclass
class ReduceContext:
  acc_num: int = 0

def horizontal_reduce(inp:UOp, out_dtype:DType) -> list[UOp]:
  # if this has a horizontal reduction component, do that first
  if inp.dtype != out_dtype:
    # NOTE: [0 1 2 3 4 5 6 7] -> [0+4, 1+5, 2+6, 3+7]
    horizontal_amount = inp.dtype.count//out_dtype.count
    return [inp.gep(tuple(range(i, inp.dtype.count, horizontal_amount))) for i in range(0, horizontal_amount)]
  return [inp]

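# e.g. (illustrative) x.reduce(r, arg=Ops.ADD) lowers to acc = DEFINE_ACC(0, r) followed by acc.assign(acc + x)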
def reduce_to_acc(ctx:ReduceContext, red:UOp):
  inp, reduce_range = red.src[0], red.src[1:]
  lst = horizontal_reduce(inp, red.dtype)
  assert all(x.dtype == red.dtype for x in lst), f"horizontal reduction mismatch {lst[0].dtype} != {red.dtype}"
  # if we have a range
  if len(reduce_range) != 0:
    acc = UOp(Ops.DEFINE_ACC, red.dtype, (red.const_like(identity_element(red.arg, red.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,))
    lst = [acc] + lst # put acc as the first element
    ctx.acc_num += 1
  ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
  return acc.assign(ret) if len(reduce_range) != 0 else ret

def no_vectorized_reduce(inp:UOp, red:UOp):
  if inp.dtype != red.dtype:
    red = red.replace(src=(functools.reduce(lambda x,y: x.alu(red.arg, y), horizontal_reduce(inp, red.dtype)),)+red.src[1:])
    if red.dtype.vcount == 1: return red
  # no_vectorized_alu ignoring ranges
  if red.dtype.vcount == 1: return None
  alus = tuple(UOp(red.op, red.dtype.scalar(), (red.src[0].gep(i),)+red.src[1:], red.arg) for i in range(red.dtype.vcount))
  return UOp(Ops.VECTORIZE, red.dtype, alus)

def reduce_rangeless(red:UOp):
  # TODO: share code with reduce_unparented
  if red.arg not in {Ops.ADD, Ops.MAX}: return None
  if red.src[0].dtype != red.dtype: return None
  if any(x.op in {Ops.RANGE} for x in red.src[0].toposort()): return None
  ret = red.src[0]
  if red.arg is Ops.ADD:
    for r in red.src[1:]:
      ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
  return ret

def no_range(u:UOp) -> bool: return not any(x.op is Ops.RANGE for x in u.sparents)

pm_reduce_collapse = PatternMatcher([
  # lift x+y out of reduce on lt
  ((UPat.var("x")+UPat.var("y")) < UPat.var("c"), lambda x,y,c: (x < (c-y)) if no_range(y) and no_range(c) else None),
  # lift x*y out of reduce
  ((UPat.var("x")*UPat.var("y")) < UPat.var("c"),
   lambda x,y,c: (x < ((c+y-1) // y)) if no_range(y) and no_range(c) and y.vmin > 0 else None),
  # lift x+y out of reduce on ne
  ((UPat.var("x")+UPat.var("y")) != UPat.var("c"), lambda x,y,c: (x != (c-y)) if no_range(y) and no_range(c) else None),
  # fold the range
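  # e.g. (illustrative) sum over r in [0,n) of (0 if r < cut else val) == clamp(n-cut, 0, n) * val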
  ((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(0, UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True),
   lambda r,cut,val: (r.src[0]-cut).maximum(0).minimum(r.src[0]).cast(val.dtype) * val),
  ((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(UPat.cvar("val"), 0).reduce(arg=Ops.ADD, allow_any_len=True),
   lambda r,cut,val: cut.maximum(0).minimum(r.src[0]).cast(val.dtype) * val),
  # REDUCE on ADD
  ((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
   lambda x,y,r: x.reduce(*r.src[1:], arg=Ops.ADD) + y.reduce(*r.src[1:], arg=Ops.ADD)),
  # MUL casted bool
  ((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast().or_broadcasted(name="b")),
   lambda x,gate,b=None: gate.broadcast(x.dtype.count).where(x, 0) if b is not None else gate.where(x, 0)),
  # WHERE on LOAD (works on max too)
  (UPat.var("gate").where(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load(), 0).reduce(arg=Ops.ADD, allow_any_len=True),
   lambda buf,idx,gate: buf.index(idx, gate).load()),
  (UPat.var("gate").where(0, UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD, allow_any_len=True),
   lambda buf,idx,gate: buf.index(idx, gate.logical_not()).load()),
  # INDEX on RANGE / gated RANGE
  (UPat.var("buf").index(UPat.var("expr"), UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted())),
   lambda buf,r,idx,expr: buf.index(expr.substitute({r:idx.cast(r.dtype)}), (idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0]))),
  # AND on WHERE
  ((UPat.any(UPat(Ops.DEFINE_VAR, name="x"), UPat(Ops.DEFINE_VAR).gep(name="x")) & UPat.var("y")) \
    .where(UPat.cvar("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
   lambda x,y,c,r: y.where(c, 0).reduce(*r.src[1:], arg=Ops.ADD)*x.cast(c.dtype)),
  # remove REDUCEs that no longer have a RANGE in the src
  (UPat(Ops.REDUCE, name="red"), reduce_rangeless),
  # devectorize REDUCE
  (UPat(Ops.VECTORIZE, name="inp").reduce(name="red", allow_any_len=True), no_vectorized_reduce),
  # index/load/where. TODO: this is more aggressive than needed
  (UPat((Ops.INDEX, Ops.LOAD, Ops.WHERE), name="alu"), no_vectorized_alu),
])+sym

def reduce_collapse(red:UOp):
  included, not_included = partition(red.parents, lambda x: any(y in x.sparents for y in red.src[1:]))
  if any(x.op in {Ops.STORE, Ops.REDUCE} for x in included): return None
  replaces: dict[UOp, UOp] = {}
  for u in included:
    for s in u.src:
      if s in not_included and s not in replaces and s.op not in {Ops.CONST, Ops.VCONST, Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR}:
        replaces[s] = UOp(Ops.DEFINE_VAR, dtype=s.dtype, arg=(f'in{len(replaces)}', s.vmin, s.vmax))
  collapse_fxn = red.substitute(replaces)
  sink = graph_rewrite(collapse_fxn, pm_reduce_collapse, name="reduce_collapse")
  # TODO: why is REDUCE needed here and just RANGE isn't enough?
  if any(x.op in {Ops.REDUCE, Ops.RANGE} for x in sink.toposort()): return None
  return sink.substitute({v:k for k,v in replaces.items()})

def reduce_unparented(red:UOp):
  if red.arg not in {Ops.ADD, Ops.MAX}: return None
  reduce_parented, reduce_unparented = partition(red.src[1:], lambda x: x in red.src[0].sparents)
  if len(reduce_unparented) == 0: return None
  ret = red.replace(src=(red.src[0],)+tuple(reduce_parented)) if len(reduce_parented) or red.dtype != red.src[0].dtype else red.src[0]
  if red.arg is Ops.ADD:
    for r in reduce_unparented: ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
  return ret

pm_reduce = PatternMatcher([
  # remove any ranges from a REDUCE that aren't referenced in the reduce source
  (UPat(Ops.REDUCE, name="red"), reduce_unparented),
  # remove REDUCE without loads (generic arange opt / indexing). TODO: support multi range
  (UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_collapse),
  # REDUCE -> DEFINE_ACC+ASSIGN
  (UPat(Ops.REDUCE, name="red"), reduce_to_acc),
  # tensor core built in accumulate
  (UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
   lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
])+sym