# all of symbolic lives here now
from typing import cast
import math, operator, struct, functools
from collections import defaultdict
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
from tinygrad.dtype import ConstType, dtypes, PtrDType, AddrSpace, can_safe_cast, Invalid
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, cdiv, cmod, CORRECT_DIVMOD_FOLDING, unwrap
from tinygrad.uop.decompositions import xpow

# ******** phase 1 of symbolic used to live in ops, it's the most generic folding rules ********

def simplify_pow(x:UOp, c:UOp) -> UOp|None:
  # rewrite x**c for a scalar const exponent c, returns None if no rewrite applies
  # negative exponent -> pow of the reciprocal
  if c.arg < 0: return x.reciprocal().pow(-c)
  if c.arg == 0: return x.const_like(1)
  # half-integer exponent -> peel off one sqrt
  if int(c.arg-0.5)+0.5 == c.arg: return x.pow(c.const_like(c.arg-0.5)) * x.sqrt()
  # integer exponent -> square the half power, times x when the exponent is odd
  if int(c.arg) == c.arg: return (y := x.pow(c.const_like(c.arg//2))) * y * (x if c.arg%2 == 1 else 1)
  return None

def fold_bitcast(root:UOp, c:UOp) -> UOp|None:
  # constant fold BITCAST(const) by packing with the source struct format and unpacking with the destination one
  # gives up when either scalar dtype has no struct format or when the itemsizes differ
  if (from_fmt:=c.dtype.scalar().fmt) is None or (to_fmt:=root.dtype.scalar().fmt) is None: return None
  if c.dtype.itemsize != root.dtype.itemsize: return None
  def convert(v:ConstType): return struct.unpack(to_fmt, struct.pack(from_fmt, v))[0]
  return root.const_like(convert(c.arg) if root.dtype.count == 1 else tuple(map(convert, c.arg)))

invalid_pat = UPat.const(dtypes.index, Invalid).named("i")
invalid_gate = UPat.var("cond").where(UPat.var("x",dtype=dtypes.index), invalid_pat)

propagate_invalid = PatternMatcher([
  # this needs to be before symbolic so that 0*something_that_might_be_invalid doesnt become 0
  # propagate invalid, push it past children
  *((invalid_gate.alu(op, UPat.var("y")).named("alu"), lambda cond,x,y,alu,i: cond.where(x.alu(alu.op,y), i))
    for op in GroupOp.Binary-GroupOp.Comparison),
  *((invalid_gate.alu(op, UPat.var("y")).named("alu"), lambda cond,x,y,alu,i: x.alu(alu.op,y)) for op in GroupOp.Comparison),
  # invalid + y -> invalid, same for other ops
  *((invalid_pat.alu(op, UPat(dtype=dtypes.index)).named("alu"), lambda alu,i: i) for op in GroupOp.Binary-GroupOp.Comparison),
  # i < y -> a_bool_value_that_will_never_be_used: we choose a random bool const
  *((invalid_pat.alu(op, UPat(dtype=dtypes.index)), lambda i: UOp.const(dtypes.bool, True)) for op in GroupOp.Comparison),
  # a.where(b.where(c, d), d) -> (a & b).where(c, d)
  (UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)),
  # order of gate&!cond matters!, and-clauses are only simplified left to right and we need to gate to be used to fold cond
  (UPat.var("gate").where(invalid_gate, UPat.var("y")),
   lambda gate,cond,x,y,i: ((gate&cond.logical_not()).logical_not()).where(gate.where(x,y), i)),
  # unswap the branches for the rule above
  (UPat.var("gate").where(UPat.var("y"), invalid_gate).named("where"),
   lambda gate,cond,x,y,i: gate.logical_not().where(cond.where(x,i), y))
])

symbolic_simple = propagate_invalid + PatternMatcher([
  # ** self folding **
  (UPat.var("x") + 0, lambda x: x), # x+0 -> x
  (UPat.var("x") * 1, lambda x: x), # x*1 -> x
  (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)) ^ 0, lambda x: x), # x^0 -> x
  (UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1
  (UPat.var("x") // 1, lambda x: x), # x//1 -> x
  (UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
  (UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1
  ((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
  ((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y -> x%y (rewritten with base for speed)
  # 4 variations of (x%c)+(x//c)*c = x TODO: add sorting to remove some variations
  (UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
  ((UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),
   lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3
  ((UPat.var("y")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"))+UPat.var("x")%UPat.cvar("c"), lambda y,x,c: y+x),
  ((UPat.var("y")+UPat.var("x")%UPat.cvar("c"))+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda y,x,c: y+x),
  ((UPat.var("y")+(UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3"))+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),
   lambda y,x,c1,c2,c3: y+x*c2 if c1.arg*c2.arg==c3.arg else None),
  ((UPat.var("y")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"))+(UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3"),
   lambda y,x,c1,c2,c3: y+x*c2 if c1.arg*c2.arg==c3.arg else None),
  (UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
  (UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
  (UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x),
  (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
  (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
  (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, False), UPat.const(dtypes.bool, True)), lambda x: x.logical_not()),
  (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)).trunc(), lambda x: x),
  # ** zero folding **
  (UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False
  (UPat.var("x") % UPat.var("x"), lambda x: x.const_like(0)), # x%x -> 0
  (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)) != UPat.var("x"),
   lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
  # x*0 -> 0 or 0*x -> 0
  # if x is nan or inf it should render the nan value.
  # NOTE: this can be wrong for loaded NaN
  (UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
  # ** constant folding **
  # TODO: add const folding for Ops.THREEFRY
  (UPat(GroupOp.Unary, src=(UPat((Ops.VCONST, Ops.CONST)),), name="a"), lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg], False))),
  (UPat(GroupOp.Binary-{Ops.THREEFRY}, src=(UPat((Ops.VCONST, Ops.CONST)),)*2, name="a"),
   lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg, a.src[1].arg], False))),
  (UPat(GroupOp.Ternary, src=(UPat((Ops.VCONST, Ops.CONST)),)*3, name="a"),
   lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg, a.src[1].arg, a.src[2].arg], False))),
  # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
  (UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
  (UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
  (UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y', dtype=dtypes.bool)), lambda x,y: x|y),
  # *** cast/bitcast ***
  (UPat(Ops.CAST, name="root", src=(UPat.cvar("c"),)), lambda root, c: root.const_like(c.arg)),
  (UPat((Ops.CAST, Ops.BITCAST), name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
  (UPat(Ops.BITCAST, name="root", src=(UPat.cvar("c"),)), fold_bitcast),
  # b.cast(a).cast(b) -> b if a preserves all values in b
  (UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x if x.dtype == b.dtype and can_safe_cast(b.dtype, a.dtype) else None),
  # ** pow **
  (UPat.var("x").alu(Ops.POW, UPat.cvar("c", vec=False)), simplify_pow),
  # positive const ** x
  (UPat.cvar("c", vec=False).alu(Ops.POW, UPat.var("x")),
   lambda c,x: c if c.arg == 1 else (x*math.log2(c.arg)).exp2() if c.arg > 0 else None),
  # rules for threefry
  ((UPat.var('x', dtypes.uint64)&0xFFFFFFFF).cast(dtypes.uint32), lambda x: x.cast(dtypes.uint32)&0xFFFFFFFF), # TODO: why is the and needed?
  (((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
  (((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
  # hacks for threefry long removal when padded (TODO: genericize)
  (UPat.var('x', dtypes.uint32).cast(dtypes.uint64) * UPat.var('y').where(UPat.const(dtypes.uint64, 1<<32), UPat.const(dtypes.uint64, 0)),
   lambda x,y: y.where(x, 0).cast(dtypes.uint64) * (1<<32)),
  ((UPat.var('x', dtypes.uint64)&(UPat.var('y').where(UPat.const(dtypes.uint64, 0xFFFFFFFF), UPat.const(dtypes.uint64, 0)))).cast(dtypes.uint32),
   lambda x,y: y.where(x.cast(dtypes.uint32), 0)),
  # new decomp rules for threefry
  (((UPat.var(None, dtypes.uint64)<<32) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
  (((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))>>32, lambda x: x),
  (UPat.var('b').where(UPat.var('x', dtypes.uint32).cast(dtypes.uint64), UPat.const(dtypes.uint64, 0)).cast(dtypes.uint32),
   lambda b,x: b.where(x,0))
])

# ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ********

def lt_folding(x:UOp, c:int) -> UOp|None:
  # fold x<c by dividing through by a common gcd d:
  # p are the ADD terms with const factor 1, np are the others. if gcd(np factors, c) > 1 and sum(p) stays in [0, d), p can be dropped
  p, np = partition(x.split_uop(Ops.ADD), lambda u: u.const_factor() == 1)
  if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d:
    return cast(UOp, functools.reduce(operator.add, np).divides(d))<(c//d)
  return None

def canonicalize_simplex(X:UOp) -> UOp|None:
  # (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
  # returns x0 + x1 + ... in such case, or None if not
  changed, ret = False, []
  for u in X.split_uop(Ops.ADD):
    # assumed the const is the last src of MUL
    if u.op is Ops.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
      changed = True
      u = u.src[0]
    if not (u.op in GroupOp.Irreducible and u.vmin >= 0): return None
    ret.append(u)
  return functools.reduce(operator.add, ret) if changed else None

def cancel_divmod(d: UOp, x: UOp, y: UOp) -> UOp|None:
  # simple cancel div/mod case when the range of the numerator lies within a single denominator interval
  x_min, x_max, y_min, y_max = x.vmin, x.vmax, y.vmin, y.vmax
  assert isinstance(x_min, int) and isinstance(x_max, int) and isinstance(y_min, int) and isinstance(y_max, int)
  if y_min==y_max==0: raise ZeroDivisionError(f"{'Division' if d.op is Ops.IDIV else 'Mod'} by zero trying to rewrite {x.alu(d.op, y)}")
  # y_min*y_max > 0 means the denominator range does not include 0; all four corner quotients must agree
  if y_min*y_max > 0 and (q:=cdiv(x_min,y_min)) == cdiv(x_min,y_max) == cdiv(x_max,y_min) == cdiv(x_max,y_max):
    return x - q*y if d.op is Ops.MOD else d.const_like(q)
  return None

def remove_nested_mod(m: UOp, x: UOp, y: UOp) -> UOp|None:
  # remove nested mod in case the inner mod is a multiple of the outer mod
  # example: (a%4 + b)%2 -> (a+b)%2
  if ((c := y.arg) < 0) or x.vmin<0: return None
  new_xs = []
  something_changed = False
  for u in x.split_uop(Ops.ADD):
    if u.op is Ops.MOD:
      if u.src[1].divides(c) is not None:
        something_changed = True
        u = u.src[0]
    new_xs.append(u)
  new_x: UOp = functools.reduce(operator.add, new_xs)
  if something_changed and new_x.vmin>=0: return new_x % y
  return None

def fold_binary_numerator(d: UOp, x: UOp, y: UOp) -> UOp|None:
  # we can fold if the expression has only one non-constant term and this term can only take on two values
  if ((c := y.arg) < 0): return None
  x,const = x.pop_const()
  terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in x.split_uop(Ops.ADD)])
  if len(terms)==1 and (v:=terms[0]).vmax-v.vmin == 1:
    # y1/y2 are the results at v.vmin/v.vmax, the output interpolates linearly between them
    y1 = cmod(factors[0]*v.vmin+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmin+const, c)  # type: ignore
    y2 = cmod(factors[0]*v.vmax+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmax+const, c)  # type: ignore
    return (y2-y1)*(v-v.vmin) + y1
  return None

def fold_divmod_congruence(d: UOp, x: UOp, y: UOp) -> UOp|None:
  # within a mod we can freely subtract multiples of c, we use this to see if a is congruent to an expression whose vmin/vmax are between 0 and c
  if (x.vmin<0 and CORRECT_DIVMOD_FOLDING) or ((c := y.arg) < 0): return None
  x,const = x.pop_const()
  terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in x.split_uop(Ops.ADD)])
  # a//c = (a-a%c)/c, if we can fold a%c, we can fold a//c
  # for each factor pick the representative of f mod c with the smallest absolute value
  rems = [min((r:=f%c), r-c, key=abs) for f in factors]
  if (rem:=sum(r*v for r,v in zip(rems,terms))+const%c).vmin//c!=rem.vmax//c: return None
  if d.op is Ops.MOD: return rem - rem.vmin//c*c
  return sum((f-r)//c * v for f,r,v in zip(factors,rems,terms)) + (const-const%c+rem.vmin//c*c)//c

def divide_by_gcd(d: UOp, x: UOp, y: UOp) -> UOp|None:
  # x//y -> (x//gcd)//(y//gcd) or x%y -> gcd*(x//gcd)%(y//gcd)
  gcd = UOp.gcd(*x.split_uop(Ops.ADD), y).simplify()
  if gcd.op is Ops.CONST and gcd.arg==1: return None
  ret = unwrap(x.divide_exact(gcd)).alu(d.op, unwrap(y.divide_exact(gcd)))
  return ret*gcd if d.op is Ops.MOD else ret

def gcd_with_remainder(d: UOp, x: UOp, y: UOp):
  # (gcd*x+r)//(gcd*d) -> (x+(r%d)//gcd)//d + r//(gcd*d)
  # (gcd*x+r)%(gcd*d) -> gcd*(x+(r%d)//gcd)%d + r%gcd
  # These only work for floordiv (and the corresponding remainder)! Thats why we check the sign of x,y and new_x
  if ((c := y.arg) < 0) or x.vmin<0: return None
  x_no_const, const = x.pop_const()
  gcd = UOp.gcd(*x_no_const.split_uop(Ops.ADD), y).simplify()
  assert gcd.op is Ops.CONST
  if gcd.arg==1: return None
  new_x = unwrap(x_no_const.divide_exact(gcd)).simplify() + (const%c)//gcd
  if new_x.vmin<0: return None
  ret = new_x.alu(d.op, x.ufix(c//gcd.arg))
  return ret*gcd + const%gcd.arg if d.op is Ops.MOD else ret+const//c

def nest_div_by_smallest_factor(d: UOp, x: UOp, y: UOp) -> UOp|None:
  # we try and nest the div and see if it allows the numerator to be simplified
  if ((c := y.arg) < 0): return None
  factors = [u.const_factor() for u in x.pop_const()[0].split_uop(Ops.ADD)]
  # div is the smallest factor of the denominator (greater than 1) out of all "factors"
  # TODO: there are better ways to pick `div`, this sometimes adds extra divisions
  # TODO: add same optimization for mod
  div = min([y.arg]+[abs(f) for f in factors if abs(f) > 1 and (c%f)==0])
  # only rewrite when simplifying x//div actually produced a different UOp
  if (1 < div < c) and (newxs:=(newx:=(x//div)).simplify()) is not newx and x.vmin>=0 and newx.vmin>=0: return newxs//(c//div)
  return None

def factor_remainder(d: UOp, x: UOp, y: UOp) -> UOp|None:
  # (d*x+y)//d -> x+y//d  or  (d*x+y)%d
  # for mod we go further and take the remainder of all factors to reduce their size
  # These only work for floordiv (and the corresponding remainder)! Thats why we check the sign of x,y and new_x
  if y.vmin<0 or x.vmin<0: return None
  quo, rem = [], []
  for u in x.split_uop(Ops.ADD):
    if (q:=u.divide_exact(y)) is not None: quo.append(q)
    # if this is mod and y is a const, we can make the remainder factor smaller
    elif d.op is Ops.MOD and y.op is Ops.CONST and (c:=u.const_factor())%y.arg!=c:
      rem.append(u.divides(c)*(c%y.arg))
      quo.append(u.const_like(0))  # we append this so we can check if something changed
    else: rem.append(u)
  new_x = sum(rem)+x.const_like(0)
  if len(quo)==0 or new_x.vmin<0: return None
  return new_x%y if d.op is Ops.MOD else new_x//y+sum(quo)

def gep_through_wmma(gep:UOp, wmma:UOp):
  # push a GEP through a WMMA by GEPing the WMMA sources instead, returns None when the GEP doesn't qualify
  # out_sz is the product of the second entries of the last element of wmma.arg[6]
  out_sz = prod(x[1] for x in wmma.arg[6][-1])
  wmma_idxs = gep.arg[::out_sz]
  # every strided slice of gep.arg must equal the first one shifted by its offset i
  for i in range(out_sz):
    if tuple(x-i for x in gep.arg[i::out_sz]) != wmma_idxs: return None
  tsrcs = []
  for s,sz in zip(wmma.src, wmma.arg[6]):
    src_args = []
    ssz = prod(x[1] for x in sz)
    for w in wmma_idxs: src_args += list(range((w//out_sz)*ssz, (w//out_sz)*ssz + ssz))
    tsrcs.append(s.gep(tuple(src_args)))
  return UOp(Ops.WMMA, gep.dtype, tuple(tsrcs), wmma.arg)

gep_pushing = PatternMatcher([
  # GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
  (UPat(Ops.GEP, name='g2').f(Ops.GEP, name='g1'),
   lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(len(g1.arg))))),
  (UPat(Ops.VECTORIZE, name='vec').f(Ops.GEP, name='gep'),
   lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
  (UPat.cvar("c", vec=False).f(Ops.GEP, name="gep"), lambda gep, c: gep.const_like(c.arg)),
  (UPat(Ops.VCONST, name="c").f(Ops.GEP, name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
  # GEP on void is skipped
  (UPat(Ops.GEP, src=(UPat(dtype=dtypes.void, name="x"),)), lambda x: x),
  # GEP in order is removed
  (UPat(Ops.GEP, name="g"),
   lambda g: g.src[0] if not isinstance(g.dtype, PtrDType) and g.arg == tuple(range(g.src[0].dtype.count)) else None),
  # push all GEPs through ALUs (fix arange stuff)
  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu').f(Ops.GEP, name='gep'),
   lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg)
     if not isinstance(gep.dtype, PtrDType) else None),
  # CAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later)
  (UPat(Ops.CAT, name="x"),
   lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count)))
     if not isinstance(x.dtype, PtrDType) else None),
  # VECTORIZE on same GEP
  (UPat(Ops.VECTORIZE, name="v", src=UPat(Ops.GEP, src=(UPat.var("x"),))),
   lambda v,x: x.gep(tuple(get_single_element(i.arg) for i in v.src))),
  # push some GEPs through WMMAs
  (UPat(Ops.WMMA, name="wmma").f(Ops.GEP, name="gep"), gep_through_wmma),
])

commutative = PatternMatcher([
  # ** COMMUTATIVE flipping (only for index) **
  # NOTE: this can break merging vector math by only flipping some of them
  (UPat(GroupOp.Commutative, dtype=dtypes.index, name='x'),
   lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
])

symbolic = symbolic_simple+commutative+PatternMatcher([
  # ** boolean algebra **
  (UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
  # TODO: make a more general or folder like simplify_valid
  (UPat.var("x", dtype=dtypes.bool) | UPat.var("x").logical_not(), lambda x: x.const_like(True)), # x|!x -> True
  # ** combine terms **
  (UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
  ((UPat.var("y") + UPat.var("x") * UPat.cvar("c0")) + UPat.var("x") * UPat.cvar("c1"), lambda x,y,c0,c1: y+x*(c0+c1)),
  (UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
  ((UPat.var("y") + UPat.var("x")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: y+x*(c+1)),
  ((UPat.var("y") + UPat.var("x") * UPat.cvar("c")) + UPat.var("x"), lambda x,y,c: y+x*(c+1)),
  (UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
  ((UPat.var("y") + UPat.var("x")) + UPat.var("x"), lambda y,x: y+x*2),
  ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3) if x2 is not x3 else None), # (x/x2)/x3 -> x/(x2*x3)
  (-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c
  # a conditional with the same results either way is a noop, also fold const conditionals
  (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
  (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
  (UPat.var("cond", dtype=dtypes.bool).logical_not().where(UPat.var("t"), UPat.var("f")),
   lambda cond, t, f: cond.where(f,t) if f.arg is not Invalid else None),
  # alu of two where with same conds can combine, only do if true branch or false branch is const
  (UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))),
   lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
  # if its a plus we add the associative variation too
  ((UPat.var("y")+UPat.var("c").where(UPat.var("t"), UPat.var("f"))) + UPat.var("c").where(UPat.var("tt"), UPat.var("ff")),
   lambda y,c,t,tt,f,ff: y+c.where(t+tt, f+ff) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
  # ALU/variable min==max -> CONST (slow!)
  (UPat(GroupOp.ALU|{Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
  # max folding
  (UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
  # TODO: why does this rule break beautiful_mnist?
  #((UPat.var("x")+UPat.var("z")).maximum(UPat.var("y")+UPat.var("z")), lambda x,y,z: x.maximum(y) + z),
  #((UPat.var("x")*UPat.cvar("c1")).maximum(UPat.var("x")*UPat.cvar("c2")), max_var_const),
  # ** two stage ALU folding **
  *((UPat.var("x").alu(op, UPat.cvar("c1")).alu(op, UPat.cvar("c2")).named("f"),
     lambda f,x,c1,c2: x.alu(f.op,c1.alu(f.op,c2))) for op in GroupOp.Associative),
  ((UPat.cvar("c0") + UPat.var("x")) < UPat.cvar("c1"), lambda x,c0,c1: x<(c1-c0)), # c0 + x < c1 -> x < c1 - c0
  ((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2)
  # ** lt **
  # NOTE(review): the next three rules were truncated in the mangled source (everything between a '<' and the following '>' was lost).
  # only the tails "0 and c1.arg > 0 else None)" and "0 else x<(c.arg*d.arg-(d.arg-1))) if d.arg > 0 else None)" survived.
  # the patterns and lambda heads are reconstructed, verify against version control before merging
  # c0*x<c1 for positive int c0,c1
  ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.index))<UPat.cvar("c1", vec=False),
   lambda x,c0,c1: x<math.ceil(c1.arg/c0.arg) if c0.arg > 0 and c1.arg > 0 else None),
  # c0*x<c1 for negative int c0 and non-positive c1
  ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.index))<UPat.cvar("c1", vec=False),
   lambda x,c0,c1: (-x)<(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
  # x//d<c
  ((UPat.var("x", dtype=dtypes.index)//UPat.cvar("d", vec=False))<UPat.cvar("c", vec=False),
   lambda x,d,c: (x<(c.arg*d.arg) if c.arg > 0 else x<(c.arg*d.arg-(d.arg-1))) if d.arg > 0 else None),
  # ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
  ((UPat.var("x") + UPat.cvar("c1")) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
  ((UPat.var("x") * UPat.cvar("c1")) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
  # *** rules from symbolic ***
  # generic lt folding
  # NOTE(review): the next two rules and the "canonicalize" comment were truncated in the mangled source, which kept only
  # '(UPat.var("x", dtypes.index)' before the gap. reconstructed so that lt_folding has its call site, verify against version control
  (UPat.var("x", dtypes.index)<UPat.cvar("c", vec=False), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
  (UPat.var("x", dtypes.index)*-1 < UPat.var("y")*-1, lambda x,y: y<x),
  # canonicalize a simplex with positive coefficients > 0
  # not x < 1 -> X > 0
  ((UPat.var("x", dtypes.index)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
  # ** div **
  # div folding
  ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"),
   lambda x,c,a,d: (x+a*c)//(c*d)
     if c.vmin>0 and d.vmin>0 and ((x.vmin>=0 and a.vmin>=0) or (x.vmax<=0 and a.vmax<=0)) else None), # (x//c+a)//d -> (x+a*c)//(c*d)
  # a range mod its own upper bound is just the range
  (UPat(Ops.RANGE, src=UPat.var("end"), name="r")%UPat.var("end"), lambda r,end: r),
  (UPat(Ops.RANGE, src=UPat.var("end"), name="r")//UPat.var("end"), lambda r,end: r.const_like(0)),
  (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.var("y"))), cancel_divmod),
  (UPat.var("x", dtypes.index) // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax < 0 else None),
  (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_binary_numerator),
  (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_divmod_congruence),
  (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.var("y"))), divide_by_gcd),
  (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), gcd_with_remainder),
  (UPat(Ops.MOD, dtypes.index, name="m", src=(UPat.var("x"), UPat.cvar("y", vec=False))), remove_nested_mod),
  (UPat((Ops.IDIV), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), nest_div_by_smallest_factor),
  (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.var("y"))), factor_remainder),
  (UPat.var("x", dtypes.index) // UPat.var("d"), lambda x,d: -((-x)//d) if x.vmax<=0 else None),
  ((UPat.var("x", dtypes.index)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
   lambda x,c,n,d: ((x+c.arg%d.arg)//d + c.arg//d.arg) if c.arg%d.arg!=c.arg and x.vmin>=0 and n.vmin>=0 and d.arg>0 else None),
  ((UPat.var("x", dtypes.index)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
   lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None),
  # ** mod **
  # mod folding
  (UPat.var("x", dtypes.index) % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None),
  (UPat.var("x", dtypes.index) % UPat.var("d"), lambda x,d: (x%(-d)) if d.vmax < 0 else None),
  # cast/long folding
  # if the intermediate cast doesnt narrow we can do it in one cast
  (UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x.cast(b.dtype) if can_safe_cast(x.dtype, a.dtype) else None),
  (UPat.var('x', dtypes.ints+(dtypes.index,)).cast(dtypes.ints+(dtypes.index,), name="a").cast(name="b"),
   lambda x,a,b: x.cast(b.dtype) if a.dtype.min<=x.vmin and x.vmax<=a.dtype.max else None),
  # try to do math in int instead of long
  (UPat(GroupOp.Binary, src=(UPat.var("x", dtypes.long), UPat.var("y", dtypes.long)), name="u"),
   lambda u,x,y: x.cast(dtypes.int).alu(u.op, y.cast(dtypes.int)).cast(u.dtype) if not any(v.overflows(dtypes.int) for v in (u,x,y)) else None),
  ((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
])+gep_pushing

symbolic_flat = symbolic+PatternMatcher([
  # ** combine terms (opinionated) **
  (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
  # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
  ((UPat.var("x", dtypes.index) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
])

# ******** we take a small aside to "simplify_valid" to rewrite valids ********

def parse_valid(valid:UOp) -> tuple[UOp, bool, int]:
  # if it's X <= c, returns X, True, c
  # if it's X >= c, returns X, False, c
  # raises ValueError for any other form

  # (X < c).ne(True) -> X >= c
  if (valid.op is Ops.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and
      (s0:=valid.src[0]).op is Ops.CMPLT and dtypes.is_int(s0.src[0].dtype)): return s0.src[0], False, int(s0.src[1].vmin)
  # X < c -> X <= c-1
  if valid.op is Ops.CMPLT and dtypes.is_int(valid.src[0].dtype): return valid.src[0], True, int((valid.src[1]).vmax)-1
  raise ValueError(f"not able to parse {valid=}")

def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
  # return None if valid is always False, otherwise the simplified uop (might be the same as input)

  # first, parse valid into {expr: (lower_bound, upper_bound)}
  bounds:defaultdict[UOp, list[ConstType|None]] = defaultdict(lambda: [None, None])
  for stmt in valid.split_uop(Ops.AND):
    try: expr, is_upper, c = parse_valid(stmt)
    except ValueError: continue  # give up if we cannot parse the valid
    bounds[expr][int(is_upper)] = c

  # don't simplify any other gates, can lead to OOB, we substitute them back later
  uop = uop.substitute((load_subs:={u: UOp(Ops.NOOP, arg=u) for u in uop.toposort() if u.op is Ops.INDEX}))

  # simplify uop given that valid is True
  for expr,v in bounds.items():
    v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1])
    expr = expr.substitute(load_subs)  # make sure expr appears in same form in the uop
    # some expr has lower bound > upper bound -> valid is an empty set and we return None
    if v0 > v1: return None
    # whole node became a const
    if v0 == v1:
      uop = uop.substitute({expr:expr.const_like(v0)}).simplify()
      continue
    # every candidate is a set of constrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
    candidates = []
    if expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in expr.split_uop(Ops.ADD)):
      # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
      candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in expr.split_uop(Ops.ADD)])
    # try checking the whole clause
    if expr in uop.toposort(): candidates.append([(expr, UOp.variable("fake", v0, v1, expr.dtype))])

    for candidate in candidates:
      # if every branch in candidate gives the same simplified uop, we can rewrite the uop
      newuops = [uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify() for X,newX in candidate]
      if uop.op is Ops.VECTORIZE and len(uop.src) == 2:
        # a 2-element VECTORIZE is rewritten per source
        if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
        if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
      elif all_same(newuops): uop = newuops[0]

  # put the loads back in
  uop = uop.substitute({v:k for k,v in load_subs.items()})
  return uop

def _valid_priority(v: UOp, valids:list[UOp]):
  # we want valid that's in other valids' parents to be first, so it's more likely the other valids get simplified
  try: return sum(-1 if parse_valid(v)[0] in other.toposort() else 0 for other in valids)
  except ValueError: return 0

def simplify_valid(valid:UOp) -> UOp|None:
  # simplify each AND clause of valid given the clauses already kept, returns None if nothing changed
  ret:list[UOp] = []
  something_changed = False
  valids = list(valid.split_uop(Ops.AND))
  for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)):
    # TODO: root cause this and test_simplify_valid_from_div
    if stmt.op is Ops.CAST: return None
    ret.append(newstmt if ret and (newstmt:=uop_given_valid(functools.reduce(operator.and_, ret), stmt)) is not None else stmt)
    if ret[-1] is not stmt: something_changed = True
  return functools.reduce(operator.and_, ret) if something_changed else None

# ******** phase 3 is the complete symbolic, and deals with very complex things like loop rewriting and threefry transform ********

def reduce_mul_chain(r:UOp):
  # move MUL factors of the reduce input that don't have any of r.src[1:] in their toposort to after the reduce
  # only for ADD/MAX reduces with matching dtype, for MAX a factor is only moved if its vmin >= 0
  if r.arg not in {Ops.ADD, Ops.MAX}: return None
  if r.dtype != r.src[0].dtype: return None
  inside, outside = [], []
  for m in r.src[0].split_uop(Ops.MUL):
    m_parents = m.toposort()
    if all(s not in m_parents for s in r.src[1:]) and (r.arg != Ops.MAX or m.vmin >= 0): outside.append(m)
    else: inside.append(m)
  if len(outside) == 0: return None
  return r.replace(src=(prod(inside) if len(inside) else r.src[0].const_like(1),)+r.src[1:])*prod(outside)

# this is symbolic 2.0
REMOVE_FROM_SINK = {Ops.SINK, Ops.UNROLL, Ops.PTRCAT, Ops.CAT, Ops.NOOP}
REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT, Ops.NOOP}
sym = symbolic_flat+PatternMatcher([
  # simplify valid
  (UPat(Ops.AND, name="valid"), simplify_valid),
  (UPat.var("cond").where(UPat.var("x", dtype=dtypes.index), invalid_pat),
   lambda cond,x,i: cond.where(newx, i) if (newx:=uop_given_valid(cond, x)) is not x else None),
  # LOAD/STORE -> NOOP
  (UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
  (UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),
  # VECTORIZE/CONST, VECTORIZE/GEP
  (UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
  (UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat.var("x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
  # reorder ALU/VECTORIZE
  (UPat(GroupOp.ALU, src=(UPat(Ops.VECTORIZE, src=UPat(name='x')), UPat(Ops.VECTORIZE, src=UPat(name='y'))), name='alu'),
   lambda x,y,alu: UOp(Ops.VECTORIZE, alu.dtype, (UOp(alu.op, alu.dtype.scalar(), (x,y)),)*alu.dtype.count)),
  # VECTORIZE of a single element is just that element
  (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
  # VECTORIZE void is SINK
  (UPat(Ops.VECTORIZE, dtype=dtypes.void, src=UPat(Ops.BARRIER, name='b')), lambda b: b),
  (UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)),
  # tensor core with a 0 input is acc
  (UPat(Ops.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
  (UPat(Ops.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),
  # ** self folding **
  # x!=0 -> (bool)x
  (UPat.var("x")!=0, lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
  # ** where **
  # push cast to branches
  (UPat.var("s").where(UPat.var("a"), UPat.var("b")).cast().named("cast"), lambda s,a,b,cast: s.where(a.cast(cast.dtype), b.cast(cast.dtype))),
  # ** pow **
  ((UPat(Ops.POW, name="p"), lambda p: xpow(*p.src))),
  # ** load/store folding **
  (UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)),
  (UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat(Ops.INDEX, name="index"))),
              allow_any_len=True, name="store"),
   lambda index, gate, alt, store: UOp.store(index.src[0].index(gate.where(index.src[1], UOp.invalid())), alt, *store.src[2:])),
  # fold gated LOAD/STORE
  (UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat.const(dtypes.index, Invalid)).or_casted(),), allow_any_len=True, name="x"),
   lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)), # invalid store does nothing. invalid load produces 0
  (UPat.var("c").where(UPat(Ops.LOAD, src=(UPat().index(UPat.var("idx"), UPat.var("c")).or_casted(),), allow_any_len=True, name="l"),
                       UPat.var("a")),
   lambda c,idx,l,a: l.replace(src=(l.src[0], a)+l.src[1:])),
  (UPat.var("c").where(UPat.var("a"),
                       UPat(Ops.LOAD, src=(UPat().index(UPat.var("idx"), UPat.var("c").logical_not()).or_casted(),),
                            allow_any_len=True, name="l")),
   lambda c,idx,l,a: l.replace(src=(l.src[0], a)+l.src[1:])),
  # remove VECTORIZE from SINK/BARRIER. TODO: SINK/BARRIER are really the same thing at GLOBAL/LOCAL levels
  (UPat(Ops.BARRIER, name="root"),
   lambda root: UOp(Ops.BARRIER, root.dtype, tuple(flatten(x.src if x.op in REMOVE_FROM_BARRIER else (x,) for x in root.src)), root.arg)
     if any(x.op in REMOVE_FROM_BARRIER for x in root.src) else None),
  (UPat(Ops.SINK, name="root"),
   lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in REMOVE_FROM_SINK else (x,) for x in root.src)), root.arg)
     if any(x.op in REMOVE_FROM_SINK for x in root.src) else None),
  ((UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()), # 1/(x^c) -> (1/x)^c
  ((UPat.var("x") * UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()*x.reciprocal()),
  ((UPat.var("x") * UPat.cvar("c")).reciprocal(), lambda x,c: x.reciprocal()*c.reciprocal()), # 1/(x*c) -> (1/c)*(1/x)
  (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")), lambda x,d: 1-d), # x/(1+x) -> 1-1/(1+x)
  (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")*UPat.var("y")), lambda x,y,d: y*(1-d)),
  (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")+UPat.var("y")), lambda x,y,d: (1-d)+x*y),
  # move const multiply after REDUCE (NOTE: the mul chain can do this, but only if it's a same dtype reduce)
  ((UPat.var("x")*UPat.cvar("c", vec=False)).reduce(arg=Ops.ADD, name="r", allow_any_len=True),
   lambda x,c,r: r.replace(src=(x,)+r.src[1:])*c.arg),
  # reduce mul chain, move muls after the reduce
  (UPat(Ops.MUL).reduce(name="r", allow_any_len=True), reduce_mul_chain),
])