openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

557 lines
37 KiB

# all of symbolic lives here now
import math, operator, struct, functools
from collections import defaultdict
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
from tinygrad.dtype import ConstType, dtypes, PtrDType, AddrSpace, can_safe_cast, Invalid
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, cdiv, cmod, CORRECT_DIVMOD_FOLDING, unwrap
from tinygrad.uop.decompositions import xpow
# ******** phase 1 of symbolic used to live in ops, it's the most generic folding rules ********
def simplify_pow(x:UOp, c:UOp) -> UOp|None:
  """Rewrite `x ** c` for a scalar constant exponent `c` into cheaper ops.

  - c < 0          -> (1/x) ** -c
  - c == 0         -> 1
  - c == k + 0.5   -> x**k * sqrt(x)
  - c integral     -> y*y (times x if c is odd), with y = x ** (c//2)
  Returns None (no rewrite) for any other exponent, including float inf and nan.
  """
  # int() raises on inf/nan, so non-finite float exponents must bail out before the integrality checks below
  # (a -inf exponent would otherwise be turned into pow(inf) by the c < 0 branch and crash on the next rewrite)
  if isinstance(c.arg, float) and not math.isfinite(c.arg): return None
  if c.arg < 0: return x.reciprocal().pow(-c)
  if c.arg == 0: return x.const_like(1)
  if int(c.arg-0.5)+0.5 == c.arg: return x.pow(c.const_like(c.arg-0.5)) * x.sqrt()
  # square-and-multiply: x**c = (x**(c//2))**2 * (x if c is odd)
  if int(c.arg) == c.arg: return (y := x.pow(c.const_like(c.arg//2))) * y * (x if c.arg%2 == 1 else 1)
  return None
def fold_bitcast(root:UOp, c:UOp) -> UOp|None:
  """Constant-fold a BITCAST of a constant by reinterpreting its bytes with `struct`.

  All source lanes are packed with the source scalar format and unpacked as `root.dtype.count` lanes of the
  target scalar format, so bitcasts that change the lane count (same total bytes) fold too.
  Returns None when either scalar dtype has no struct format or the byte sizes differ.
  """
  if (from_fmt:=c.dtype.scalar().fmt) is None or (to_fmt:=root.dtype.scalar().fmt) is None: return None
  if c.dtype.itemsize != root.dtype.itemsize: return None
  # a scalar arg is replicated across all c.dtype.count lanes so every lane gets packed
  src_vals = c.arg if isinstance(c.arg, tuple) else (c.arg,)*c.dtype.count
  src_fmt, dst_fmt = f"{len(src_vals)}{from_fmt}", f"{root.dtype.count}{to_fmt}"
  # guard against any mismatch between the dtype itemsize and the struct format sizes
  if struct.calcsize(src_fmt) != struct.calcsize(dst_fmt): return None
  out = struct.unpack(dst_fmt, struct.pack(src_fmt, *src_vals))
  return root.const_like(out[0] if root.dtype.count == 1 else out)
# matches the Invalid constant
invalid_pat = UPat(Ops.CONST, arg=Invalid, name="i")
# matches cond.where(x, Invalid): x gated to Invalid when cond is False
invalid_gate = UPat.var("cond").where(UPat.var("x"), invalid_pat)
propagate_invalid = PatternMatcher([
  # this needs to be before symbolic so that 0*something_that_might_be_invalid doesn't become 0
  # propagate invalid, push it past children
  # cast of a gated value to a non-index dtype: the gate is dropped and x is cast directly
  (invalid_gate.cast(name="cast"), lambda i,x,cond,cast: x.cast(cast.dtype) if cast.dtype is not dtypes.index else None),
  # non-comparison binary op on a gated value: apply the op to x and keep the gate on the outside
  *((invalid_gate.alu(op, UPat.var("y")).named("alu"), lambda cond,x,y,alu,i: cond.where(x.alu(alu.op,y), i))
    for op in GroupOp.Binary-GroupOp.Comparison),
  # comparison on a gated value: the gate is dropped and the comparison is done on x
  *((invalid_gate.alu(op, UPat.var("y")).named("alu"), lambda cond,x,y,alu,i: x.alu(alu.op,y)) for op in GroupOp.Comparison),
  # invalid <op> y -> invalid, for every non-comparison binary op with an index-typed y
  *((invalid_pat.alu(op, UPat(dtype=dtypes.index)).named("alu"), lambda alu,i: i) for op in GroupOp.Binary-GroupOp.Comparison),
  # invalid < y -> a bool value that will never be used: we arbitrarily pick the constant True
  *((invalid_pat.alu(op, UPat(dtype=dtypes.index)), lambda i: UOp.const(dtypes.bool, True)) for op in GroupOp.Comparison),
  # a.where(b.where(c, d), d) -> (a & b).where(c, d)
  (UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)),
])
# generic local folding rules (phase 1); rule order matters, propagate_invalid runs first
symbolic_simple = propagate_invalid + PatternMatcher([
  # ** self folding **
  (UPat.var("x") + 0, lambda x: x), # x+0 -> x
  (UPat.var("x") * 1, lambda x: x), # x*1 -> x
  (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)) ^ 0, lambda x: x), # x^0 -> x
  (UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1 (NOTE: wrong if x is 0)
  (UPat.var("x") // 1, lambda x: x), # x//1 -> x
  (UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
  ((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y -> x%y (rewritten with base for speed)
  # variations of (x%c)+(x//c)*c = x TODO: add sorting to remove some variations
  (UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
  ((UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c"),
    lambda x,a,b,c: x//a if a.arg*c.arg==b.arg else None), # ((x//a)%c)+(x//b)*c = x//a if a*c==b. Note if a = 1 it degenerates to the one above
  ((UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),
    lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3
  # the same identities with an extra addend y on the left
  ((UPat.var("y")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"))+UPat.var("x")%UPat.cvar("c"), lambda y,x,c: y+x),
  ((UPat.var("y")+UPat.var("x")%UPat.cvar("c"))+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda y,x,c: y+x),
  ((UPat.var("y")+(UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3"))+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),
    lambda y,x,c1,c2,c3: y+x*c2 if c1.arg*c2.arg==c3.arg else None),
  ((UPat.var("y")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"))+(UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3"),
    lambda y,x,c1,c2,c3: y+x*c2 if c1.arg*c2.arg==c3.arg else None),
  # x & True -> x, x & False -> False
  (UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
  # x | True -> True, x | False -> x
  (UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
  # op(x, x) -> x for idempotent ops
  (UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x),
  (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
  (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
  (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, False), UPat.const(dtypes.bool, True)), lambda x: x.logical_not()),
  # trunc of an int/bool/index value is a noop
  (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)).trunc(), lambda x: x),
  # ** zero folding **
  (UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False
  (UPat.var("x") % UPat.var("x"), lambda x: x.const_like(0)), # x%x -> 0
  (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)) != UPat.var("x"),
    lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
  # ** constant folding **
  # TODO: add const folding for Ops.THREEFRY
  (UPat(GroupOp.Unary, src=(UPat((Ops.VCONST, Ops.CONST)),), name="a"), lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg], False))),
  (UPat(GroupOp.Binary-{Ops.THREEFRY}, src=(UPat((Ops.VCONST, Ops.CONST)),)*2, name="a"),
    lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg, a.src[1].arg], False))),
  (UPat(GroupOp.Ternary, src=(UPat((Ops.VCONST, Ops.CONST)),)*3, name="a"),
    lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg, a.src[1].arg, a.src[2].arg], False))),
  # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
  (UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
  (UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
  (UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y', dtype=dtypes.bool)), lambda x,y: x|y),
  # *** div rules ***
  (UPat.cvar('x', arg=0) / 0, lambda x: x.const_like(float('nan'))), # 0/0 -> nan
  ((UPat.var("x") * 0) / 0, lambda x: x.const_like(float('nan'))), # (x*0)/0 -> nan
  # can be wrong if x or x2 is 0
  (UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1
  ((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
  # x*0 -> 0 or 0*x -> 0
  # if x is nan or inf it should render the nan value.
  # NOTE: this can be wrong for loaded NaN
  (UPat.var("x") * 0, lambda x: x.const_like(float("nan") if x.op is Ops.CONST
    and isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
  # *** cast/bitcast ***
  # cast of a constant -> constant of the cast dtype
  (UPat(Ops.CAST, name="root", src=(UPat.cvar("c"),)), lambda root, c: root.const_like(c.arg)),
  # cast/bitcast to the same dtype is a noop
  (UPat((Ops.CAST, Ops.BITCAST), name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
  (UPat(Ops.BITCAST, name="root", src=(UPat.cvar("c"),)), fold_bitcast),
  # b.cast(a).cast(b) -> b if a preserves all values in b
  (UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x if x.dtype == b.dtype and can_safe_cast(b.dtype, a.dtype) else None),
  # ** pow **
  (UPat.var("x").alu(Ops.POW, UPat.cvar("c", vec=False)), simplify_pow),
  # positive const ** x: 1**x -> 1, otherwise exp2(x*log2(c))
  (UPat.cvar("c", vec=False).alu(Ops.POW, UPat.var("x")), lambda c,x: c if c.arg == 1 else (x*math.log2(c.arg)).exp2() if c.arg > 0 else None),
  # rules for threefry
  ((UPat.var('x', dtypes.uint64)&0xFFFFFFFF).cast(dtypes.uint32), lambda x: x.cast(dtypes.uint32)&0xFFFFFFFF), # TODO: why is the and needed?
  (((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
  (((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
  # hacks for threefry long removal when padded (TODO: genericize)
  (UPat.var('x', dtypes.uint32).cast(dtypes.uint64) * UPat.var('y').where(UPat.const(dtypes.uint64, 1<<32), UPat.const(dtypes.uint64, 0)),
    lambda x,y: y.where(x, 0).cast(dtypes.uint64) * (1<<32)),
  ((UPat.var('x', dtypes.uint64)&(UPat.var('y').where(UPat.const(dtypes.uint64, 0xFFFFFFFF), UPat.const(dtypes.uint64, 0)))).cast(dtypes.uint32),
    lambda x,y: y.where(x.cast(dtypes.uint32), 0)),
  # new decomp rules for threefry
  (((UPat.var(None, dtypes.uint64)<<32) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
  (((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))>>32, lambda x: x),
  (UPat.var('b').where(UPat.var('x', dtypes.uint32).cast(dtypes.uint64), UPat.const(dtypes.uint64, 0)).cast(dtypes.uint32), lambda b,x: b.where(x,0)),
  # ** simple where folding **
  # a conditional with the same results either way is a noop, also fold const conditionals
  (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
  (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
])
# ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ********
def lt_folding(x:UOp, c:int) -> UOp|None:
  """Fold `x < c` by dividing both sides by a shared gcd.

  The summands of x are split into unit-coefficient terms and scaled terms. With d = gcd(scaled coefficients, c) > 1
  and the unit terms summing into [0, d), `x < c` is equivalent to `(scaled terms)/d < c//d`.
  Returns None when that precondition does not hold.
  """
  unit_terms, scaled_terms = partition(x.split_uop(Ops.ADD), lambda u: u.const_factor() == 1)
  if not scaled_terms: return None
  d = math.gcd(*[u.const_factor() for u in scaled_terms], c)
  if d <= 1 or sum(u.vmin for u in unit_terms) < 0 or sum(u.vmax for u in unit_terms) >= d: return None
  return unwrap(UOp.sum(*scaled_terms).divides(d)) < (c//d)
def canonicalize_simplex(X:UOp) -> UOp|None:
  """Strip positive constant coefficients from a sum of non-negative irreducible terms.

  For ints with xi >= 0 and ai > 0, (a0*x0 + a1*x1 + ...) > 0 is equivalent to (x0 + x1 + ...) > 0.
  Returns the coefficient-free sum, or None if some term is not a non-negative irreducible
  or if no coefficient was stripped.
  """
  stripped_any, terms = False, []
  for term in X.split_uop(Ops.ADD):
    # the constant is assumed to be the last src of a MUL
    if term.op is Ops.MUL and term.src[1].op is Ops.CONST and term.src[1].arg > 0:
      term, stripped_any = term.src[0], True
    if term.op not in GroupOp.Irreducible or term.vmin < 0: return None
    terms.append(term)
  return UOp.sum(*terms) if stripped_any else None
def cancel_divmod(d: UOp, x: UOp, y: UOp) -> UOp|None:
  """Fold x//y or x%y when the quotient is the same at every corner of the (x, y) bound box.

  If y's range does not touch 0 and cdiv gives the same q for all four (x bound, y bound) pairs,
  x//y becomes the constant q and x%y becomes x - q*y. Raises ZeroDivisionError if y is provably 0.
  """
  lo, hi, ylo, yhi = x.vmin, x.vmax, y.vmin, y.vmax
  assert all(isinstance(bound, int) for bound in (lo, hi, ylo, yhi))
  if ylo == yhi == 0: raise ZeroDivisionError(f"{'Division' if d.op is Ops.IDIV else 'Mod'} by zero trying to rewrite {x.alu(d.op, y)}")
  # a denominator range that reaches or crosses 0 can't be cancelled
  if ylo*yhi <= 0: return None
  quotients = {cdiv(num, den) for num in (lo, hi) for den in (ylo, yhi)}
  if len(quotients) != 1: return None
  q = quotients.pop()
  return x - q*y if d.op is Ops.MOD else d.const_like(q)
def remove_nested_mod(m: UOp, x: UOp, y: UOp) -> UOp|None:
  """Drop inner mods whose modulus is a multiple of the outer one, e.g. (a%4 + b)%2 -> (a+b)%2.

  Only applies when y's constant is non-negative and both x and the rewritten sum are non-negative.
  """
  if (c := y.arg) < 0 or x.vmin < 0: return None
  changed, kept = False, []
  for term in x.split_uop(Ops.ADD):
    # inner modulus divisible by c: the inner mod can't change the result mod c
    if term.op is Ops.MOD and term.src[1].divides(c) is not None:
      term, changed = term.src[0], True
    kept.append(term)
  new_x: UOp = UOp.sum(*kept)
  return new_x % y if changed and new_x.vmin >= 0 else None
def fold_binary_numerator(d: UOp, x: UOp, y: UOp) -> UOp|None:
  """Fold x//c or x%c when the non-constant part of x is a single term k*v with v taking exactly two values.

  The two possible results at v.vmin and v.vmax are computed with cdiv/cmod, and the expression is replaced
  by the line through them: (hi - lo)*(v - v.vmin) + lo. Returns None for c < 0 or any other shape of x.
  """
  if (c := y.arg) < 0: return None
  x, const = x.pop_const()
  terms, factors = zip(*[(u.divides(f:=u.const_factor()), f) for u in x.split_uop(Ops.ADD)])
  if len(terms) != 1 or (v:=terms[0]).vmax - v.vmin != 1: return None
  fold = cmod if d.op is Ops.MOD else cdiv
  lo = fold(factors[0]*v.vmin+const, c)
  hi = fold(factors[0]*v.vmax+const, c)
  return (hi-lo)*(v-v.vmin) + lo
def fold_divmod_congruence(d: UOp, x: UOp, y: UOp) -> UOp|None:
  """Fold x%c or x//c by replacing every coefficient of x with its smallest-magnitude residue mod c.

  x is split into const + sum(f_i * v_i), and each f_i becomes r_i, whichever of f_i%c and f_i%c - c has the
  smaller absolute value. If rem = sum(r_i*v_i) + const%c lies entirely within one interval [k*c, (k+1)*c),
  then x%c == rem - k*c, and x//c is rebuilt from the (f_i - r_i)//c parts and the constant.
  Returns None if c < 0, if x can be negative while CORRECT_DIVMOD_FOLDING is set, or if rem's range spans
  more than one such interval.
  """
  # within a mod we can freely subtract multiples of c, we use this to see if a is congruent to an expression whose vmin/vmax are between 0 and c
  if (x.vmin<0 and CORRECT_DIVMOD_FOLDING) or ((c := y.arg) < 0): return None
  # NOTE(review): pop_const presumably returns (x without its constant summand, that constant) — confirm
  x,const = x.pop_const()
  # terms[i] is the i-th summand divided by its constant coefficient factors[i]
  terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in x.split_uop(Ops.ADD)])
  # a//c = (a-a%c)/c, if we can fold a%c, we can fold a//c
  rems = [min((r:=f%c), r-c, key=abs) for f in factors]
  # bail out unless rem's whole range falls into a single interval of width c
  if (rem:=sum(r*v for r,v in zip(rems,terms))+const%c).vmin//c!=rem.vmax//c: return None
  if d.op is Ops.MOD: return rem - rem.vmin//c*c
  return sum((f-r)//c * v for f,r,v in zip(factors,rems,terms)) + (const-const%c+rem.vmin//c*c)//c
def divide_by_gcd(d: UOp, x: UOp, y: UOp) -> UOp|None:
  """Cancel the gcd of all summands of x and y: x//y -> (x/g)//(y/g), x%y -> g*((x/g)%(y/g)).

  Returns None when the gcd simplifies to the constant 1.
  """
  g = UOp.gcd(*x.split_uop(Ops.ADD), y).simplify()
  if g.op is Ops.CONST and g.arg == 1: return None
  num = unwrap(x.divide_exact(g))
  den = unwrap(y.divide_exact(g))
  reduced = num.alu(d.op, den)
  if d.op is Ops.MOD: return reduced*g
  return reduced
def gcd_with_remainder(d: UOp, x: UOp, y: UOp):
  """Fold x//c or x%c using g, the gcd of c and the non-constant summands of x.

  Writing x = g*X + r (r the constant part of x) and c = g*k:
    (g*X+r)//(g*k) -> (X+(r%c)//g)//k + r//c
    (g*X+r)%(g*k)  -> g*((X+(r%c)//g)%k) + r%g
  Returns None when g == 1, or when c, x, or the new numerator may be negative.
  """
  # These only work for floordiv (and the corresponding remainder)! Thats why we check the sign of x,y and new_x
  if ((c := y.arg) < 0) or x.vmin<0: return None
  x_no_const, const = x.pop_const()
  gcd = UOp.gcd(*x_no_const.split_uop(Ops.ADD), y).simplify()
  # y is a constant here, so the gcd is expected to fold to a constant
  assert gcd.op is Ops.CONST
  if gcd.arg==1: return None
  new_x = unwrap(x_no_const.divide_exact(gcd)).simplify() + (const%c)//gcd
  if new_x.vmin<0: return None
  ret = new_x.alu(d.op, x.ufix(c//gcd.arg))
  return ret*gcd + const%gcd.arg if d.op is Ops.MOD else ret+const//c
def factor_remainder(d: UOp, x: UOp, y: UOp) -> UOp|None:
  """Pull summands of x that are exact multiples of y out of x//y, or drop them from x%y.

  (y*q + rest)//y -> rest//y + q and (y*q + rest)%y -> rest%y. For a mod by a constant y, the coefficient c of
  any other summand is also reduced to c%y. Returns None if no summand changed, or if x, y, or the remaining
  numerator may be negative.
  """
  # These only work for floordiv (and the corresponding remainder)! That's why we check the sign of x,y and new_x
  if y.vmin<0 or x.vmin<0: return None
  quo, rem = [], []
  for u in x.split_uop(Ops.ADD):
    if (q:=u.divide_exact(y)) is not None: quo.append(q)
    # if this is mod and y is a const, we can make the remainder factor smaller
    elif d.op is Ops.MOD and y.op is Ops.CONST and (c:=u.const_factor())%y.arg!=c:
      rem.append(u.divides(c)*(c%y.arg))
      quo.append(u.const_like(0)) # we append this so we can check if something changed
    else: rem.append(u)
  # adding a const 0 keeps new_x a UOp even when rem is empty
  new_x = sum(rem)+x.const_like(0)
  if len(quo)==0 or new_x.vmin<0: return None
  return new_x%y if d.op is Ops.MOD else new_x//y+sum(quo)
def nest_div_by_smallest_factor(d: UOp, x: UOp, y: UOp) -> UOp|None:
  """Rewrite x//c as (x//div)//(c//div), where div is the smallest |coefficient| > 1 of x that divides c.

  The split is only worth doing if the inner x//div simplifies through fold_divmod_congruence or
  factor_remainder. Returns None for negative c or x, when no such div smaller than c exists, or when
  neither helper simplifies the inner division.
  """
  # cheap guards first: they make the result None no matter what the helpers would return
  if (c := y.arg) < 0 or x.vmin < 0: return None
  factors = [u.const_factor() for u in x.split_uop(Ops.ADD) if u.op not in (Ops.CONST, Ops.VCONST)]
  div = min([c]+[abs(f) for f in factors if abs(f) > 1 and (c%f)==0])
  # nesting by c itself would just reproduce x//c (this also avoids dividing by a 0 divisor)
  if div == c: return None
  newx = x//div
  if newx.vmin < 0: return None
  # try to simplify the inner division with the other div rules
  newxs = fold_divmod_congruence(newx, x, y.const_like(div))
  if newxs is None: newxs = factor_remainder(newx, x, y.const_like(div))
  if newxs is None: return None
  return newxs//(c//div)
def gep_through_wmma(gep:UOp, wmma:UOp):
  """Push a GEP through a WMMA when it selects whole output chunks, by GEP'ing the WMMA sources instead.

  wmma.arg[6] holds one entry per source whose items carry a size at index [1]. The last entry gives the
  output chunk size. NOTE(review): presumably these are upcast axes (axis, size) — confirm.
  Returns None if any out_sz-sized slice of the GEP indices is not the chunk start shifted by its offset.
  """
  out_sz = prod(axis[1] for axis in wmma.arg[6][-1])
  chunk_starts = gep.arg[::out_sz]
  if any(tuple(idx-off for idx in gep.arg[off::out_sz]) != chunk_starts for off in range(out_sz)): return None
  new_srcs = []
  for src, axes in zip(wmma.src, wmma.arg[6]):
    src_sz = prod(axis[1] for axis in axes)
    # each selected output chunk maps to src_sz consecutive elements of this source
    picks = [j for w in chunk_starts for j in range((w//out_sz)*src_sz, (w//out_sz)*src_sz + src_sz)]
    new_srcs.append(src.gep(tuple(picks)))
  return UOp(Ops.WMMA, gep.dtype, tuple(new_srcs), wmma.arg)
# rules that fold GEPs into their sources or push them toward the leaves
gep_pushing = PatternMatcher([
  # GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
  # GEP of GEP -> a single GEP with composed indices
  (UPat(Ops.GEP, name='g2').f(Ops.GEP, name='g1'),
    lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(len(g1.arg))))),
  # GEP of VECTORIZE -> pick the selected srcs directly
  (UPat(Ops.VECTORIZE, name='vec').f(Ops.GEP, name='gep'),
    lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
  (UPat.cvar("c", vec=False).f(Ops.GEP, name="gep"), lambda gep, c: gep.const_like(c.arg)),
  (UPat(Ops.VCONST, name="c").f(Ops.GEP, name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
  # GEP on void is skipped
  (UPat(Ops.GEP, src=(UPat(dtype=dtypes.void, name="x"),)), lambda x: x),
  # GEP in order is removed
  (UPat(Ops.GEP, name="g"), lambda g: g.src[0] if not isinstance(g.dtype, PtrDType) and g.arg == tuple(range(g.src[0].dtype.count)) else None),
  # push all GEPs through ALUs (fix arange stuff)
  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu').f(Ops.GEP, name='gep'),
    lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
      if not isinstance(gep.dtype, PtrDType) and not isinstance(alu.dtype, PtrDType) else None),
  # CAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later)
  (UPat(Ops.CAT, name="x"), lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count))) \
    if not isinstance(x.dtype, PtrDType) else None),
  # VECTORIZE on same GEP
  (UPat(Ops.VECTORIZE, name="v", src=UPat(Ops.GEP, src=(UPat.var("x"),))), lambda v,x: x.gep(tuple(get_single_element(i.arg) for i in v.src))),
  # push some GEPs through WMMAs
  (UPat(Ops.WMMA, name="wmma").f(Ops.GEP, name="gep"), gep_through_wmma),
])
commutative = PatternMatcher([
  # ** COMMUTATIVE flipping (only for index) **
  # swap the srcs so they are ordered by .tuplize, giving commutative index ops one canonical form
  # NOTE: this can break merging vector math by only flipping some of them
  (UPat(GroupOp.Commutative, dtype=dtypes.index, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
])
# phase 2: symbolic_simple + commutative canonicalization + deeper rules, followed by gep_pushing
symbolic = symbolic_simple+commutative+PatternMatcher([
  # ** boolean algebra **
  (UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
  # TODO: make a more general or folder like simplify_valid
  (UPat.var("x", dtype=dtypes.bool) | UPat.var("x").logical_not(), lambda x: x.const_like(True)), # x|!x -> True
  # ** combine terms **
  (UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
  ((UPat.var("y") + UPat.var("x") * UPat.cvar("c0")) + UPat.var("x") * UPat.cvar("c1"), lambda x,y,c0,c1: y+x*(c0+c1)),
  (UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
  ((UPat.var("y") + UPat.var("x")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: y+x*(c+1)),
  ((UPat.var("y") + UPat.var("x") * UPat.cvar("c")) + UPat.var("x"), lambda x,y,c: y+x*(c+1)),
  (UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
  ((UPat.var("y") + UPat.var("x")) + UPat.var("x"), lambda y,x: y+x*2),
  ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3) if x2 is not x3 else None), # (x/x2)/x3 -> x/(x2*x3)
  (-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c
  (UPat.cvar("y") * (UPat.var("x", dtype=dtypes.index) + UPat.cvar("c")), lambda x,y,c: (y*x)+(y*c)), # y*(x+c) -> y*x + y*c (index only)
  # ** where folding **
  # (!cond).where(t, f) -> cond.where(f, t), unless f is the Invalid constant
  (UPat.var("cond", dtype=dtypes.bool).logical_not().where(UPat.var("t"), UPat.var("f")), lambda cond, t, f: cond.where(f,t)
    if f.arg is not Invalid else None),
  # alu of two where with same conds can combine, only do if true branch or false branch is const
  (UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))), \
    lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
  # if it's a plus we add the associative variation too
  ((UPat.var("y")+UPat.var("c").where(UPat.var("t"), UPat.var("f"))) + UPat.var("c").where(UPat.var("tt"), UPat.var("ff")), \
    lambda y,c,t,tt,f,ff: y+c.where(t+tt, f+ff) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
  # ALU/variable min==max -> CONST (slow!)
  (UPat(GroupOp.ALU|{Ops.DEFINE_VAR, Ops.SPECIAL}, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
  (UPat(Ops.RANGE, src=(UPat(Ops.CONST,)), name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
  # max folding: pick one side when the ranges don't overlap
  (UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
  # TODO: why does this rule break beautiful_mnist?
  #((UPat.var("x")+UPat.var("z")).maximum(UPat.var("y")+UPat.var("z")), lambda x,y,z: x.maximum(y) + z),
  #((UPat.var("x")*UPat.cvar("c1")).maximum(UPat.var("x")*UPat.cvar("c2")), max_var_const),
  # ** two stage ALU folding **
  # (x op c1) op c2 -> x op (c1 op c2) for associative ops
  *((UPat.var("x").alu(op, UPat.cvar("c1")).alu(op, UPat.cvar("c2")).named("f"),
     lambda f,x,c1,c2: x.alu(f.op,c1.alu(f.op,c2))) for op in GroupOp.Associative),
  ((UPat.cvar("c0") + UPat.var("x")) < UPat.cvar("c1"), lambda x,c0,c1: x<(c1-c0)), # c0 + x < c1 -> x < c1 - c0
  ((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2)
  # ** lt **
  # c0*x<c1 for positive int c0,c1
  ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.index))<UPat.cvar("c1", vec=False),
    lambda x,c0,c1: x<math.ceil(c1.arg/c0.arg) if c0.arg > 0 and c1.arg > 0 else None),
  # c0*x<c1 for negative int c0 and non-positive c1
  ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.index))<UPat.cvar("c1", vec=False),
    lambda x,c0,c1: (-x)<(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
  # x//d<c
  ((UPat.var("x", dtype=dtypes.index)//UPat.cvar("d", vec=False))<UPat.cvar("c", vec=False),
    lambda x,d,c: (x<(c.arg*d.arg) if c.arg > 0 else x<(c.arg*d.arg-(d.arg-1))) if d.arg > 0 else None),
  # ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
  ((UPat.var("x") + UPat.cvar("c1")) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
  ((UPat.var("x") * UPat.cvar("c1")) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
  # *** rules from symbolic ***
  # generic lt folding
  (UPat.var("x", dtypes.index)<UPat.cvar("c", vec=False), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
  (UPat.var("x", dtypes.index)*-1 < UPat.var("y")*-1, lambda x,y: y<x),
  # canonicalize a simplex with positive coefficients > 0
  # not x < 1 -> X > 0
  ((UPat.var("x", dtypes.index)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
  # ** div **
  # div folding
  ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)
    if c.vmin>0 and d.vmin>0 and ((x.vmin>=0 and a.vmin>=0) or (x.vmax<=0 and a.vmax<=0)) else None), # (x//c+a)//d -> (x+a*c)//(c*d)
  # a range mod its own upper bound is just the range
  (UPat(Ops.RANGE, src=UPat.var("end"), name="r")%UPat.var("end"), lambda r,end: r),
  # and a range divided by its own upper bound is 0
  (UPat(Ops.RANGE, src=UPat.var("end"), name="r")//UPat.var("end"), lambda r,end: r.const_like(0)),
  (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.var("y"))), cancel_divmod),
  # x//d with d always negative -> -(x//-d)
  (UPat.var("x", dtypes.index) // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax < 0 else None),
  (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_binary_numerator),
  (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_divmod_congruence),
  (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.var("y"))), divide_by_gcd),
  (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), gcd_with_remainder),
  (UPat(Ops.MOD, dtypes.index, name="m", src=(UPat.var("x"), UPat.cvar("y", vec=False))), remove_nested_mod),
  (UPat((Ops.IDIV), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), nest_div_by_smallest_factor),
  (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.var("y"))), factor_remainder),
  # x//d with x never positive -> -((-x)//d)
  (UPat.var("x", dtypes.index) // UPat.var("d"), lambda x,d: -((-x)//d) if x.vmax<=0 else None),
  # (x+c)//d -> (x+c%d)//d + c//d, for non-negative x and x+c
  ((UPat.var("x", dtypes.index)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
    lambda x,c,n,d: ((x+c.arg%d.arg)//d + c.arg//d.arg) if c.arg%d.arg!=c.arg and x.vmin>=0 and n.vmin>=0 and d.arg>0 else None),
  # same split for non-positive x with non-negative x+c
  ((UPat.var("x", dtypes.index)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
    lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None),
  # ** mod **
  # mod folding
  (UPat.var("x", dtypes.index) % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None),
  (UPat.var("x", dtypes.index) % UPat.var("d"), lambda x,d: (x%(-d)) if d.vmax < 0 else None),
  # cast/long folding
  # if the intermediate cast doesn't narrow we can do it in one cast
  (UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x.cast(b.dtype) if can_safe_cast(x.dtype, a.dtype) else None),
  # int->int->* casts collapse when x's value range fits the intermediate dtype
  (UPat.var('x', dtypes.ints+(dtypes.index,)).cast(dtypes.ints+(dtypes.index,), name="a").cast(name="b"),
    lambda x,a,b: x.cast(b.dtype) if a.dtype.min<=x.vmin and x.vmax<=a.dtype.max else None),
  # try to do math in int instead of long
  (UPat(GroupOp.Binary, src=(UPat.var("x", dtypes.long), UPat.var("y", dtypes.long)), name="u"), lambda u,x,y:
    x.cast(dtypes.int).alu(u.op, y.cast(dtypes.int)).cast(u.dtype) if not any(v.overflows(dtypes.int) for v in (u,x,y)) else None),
  # (x+c).cast(sint) -> x.cast(sint) + c.cast(sint) for index x
  ((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
  # AFTER deps: keep side-effecting srcs (RANGE/STORE/KERNEL/BARRIER/END/UNROLL) and replace any other src by its own srcs
  (UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+
    tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.KERNEL, Ops.BARRIER, Ops.END, Ops.UNROLL} else y.src for y in x.src[1:]])))),
  # after with 1 src is just src[0]
  (UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
  # VECTORIZE/CONST
  (UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
])+gep_pushing
# ******** we take a small aside to "simplify_valid" to rewrite valids ********
def parse_valid(valid:UOp) -> tuple[UOp, bool, int]|None:
  """Parse one valid clause on an int-typed X into (X, is_upper, bound).

  - `X < c`             -> (X, True, c.vmax - 1), i.e. X <= c-1
  - `(X < c).ne(True)`  -> (X, False, c.vmin),    i.e. X >= c
  Any other clause returns None.
  """
  if valid.op is Ops.CMPLT:
    return (valid.src[0], True, int((valid.src[1]).vmax)-1) if dtypes.is_int(valid.src[0].dtype) else None
  if valid.op is not Ops.CMPNE: return None
  lhs, rhs = valid.src[0], valid.src[1]
  if rhs.op is Ops.CONST and rhs.arg == 1 and lhs.op is Ops.CMPLT and dtypes.is_int(lhs.src[0].dtype):
    return lhs.src[0], False, int(lhs.src[1].vmin)
  return None
def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp:
  """Simplify `uop` assuming the AND-chain `valid` holds.

  Each parsed clause bounds an expression; the expression is replaced by a fresh variable with those bounds,
  `uop` is simplified, and the expression is substituted back. When try_simplex is set, a clause X0+X1+... >= 1
  over irreducible Xi is also tried as "each Xi >= 1", accepted only if every branch simplifies to the same result.
  INDEX nodes are hidden behind NOOPs during this and restored at the end.
  """
  # return simplified uop (might be the same as input)
  # first, parse valid into {expr: (lower_bound, upper_bound)}
  bounds:defaultdict[UOp, list[ConstType|None]] = defaultdict(lambda: [None, None])
  for stmt in valid.split_uop(Ops.AND):
    if (res:=parse_valid(stmt)) is None: continue
    expr, is_upper, c = res
    bounds[expr][int(is_upper)] = c
  # don't simplify any other gates, can lead to OOB, we substitute them back later
  uop = uop.substitute((load_subs:={u: UOp(Ops.NOOP, dtype=u.dtype, arg=u) for u in uop.toposort() if u.op is Ops.INDEX}))
  # simplify uop given that valid is True
  all_candidates = []
  for i,(expr,v) in enumerate(bounds.items()):
    # missing bounds fall back to the expression's own vmin/vmax
    v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1])
    expr = expr.substitute(load_subs) # make sure expr appears in same form in the uop
    # try checking the whole clause
    all_candidates.append((expr, UOp.variable(f"fake{i}", v0, v1, expr.dtype)))
    if try_simplex:
      # every candidate is a set of constrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
      candidates = [[all_candidates[-1]]]
      if expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in expr.split_uop(Ops.ADD)):
        # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
        candidates.append([(Xi, UOp.variable(f"fake{i}", 1, Xi.vmax, Xi.dtype)) for Xi in expr.split_uop(Ops.ADD)])
      for candidate in candidates:
        # if every branch in candidate gives the same simplified uop, we can rewrite the uop
        newuops = [uop.substitute({X:newX}) for X,newX in candidate]
        if any(u is uop for u in newuops): continue # if any branch doesnt appear in uop, skip
        newuops = [u.simplify().substitute({newX:X}).simplify(full_symbolic=False) for (X,newX),u in zip(candidate,newuops)]
        if all_same(newuops): uop = newuops[0]
        # for a 2-element VECTORIZE, each element can be accepted independently
        elif uop.op is Ops.VECTORIZE and len(uop.src) == 2:
          if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
          if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
  # try all the valids together (but only the whole expressions)
  if (s_uop:=uop.substitute(sub_dict:=dict(all_candidates))) is not uop:
    uop = s_uop.simplify().substitute({newX:X for X,newX in sub_dict.items()}).simplify(full_symbolic=False)
  # put the loads back in
  uop = uop.substitute({v:k for k,v in load_subs.items()})
  return uop
def _valid_priority(v: UOp, valids:list[UOp]):
  """Sort key for valid clauses: minus the number of clauses whose graph contains v's bounded expression.

  A valid that appears in other valids' parents sorts first, so it is more likely the other valids get simplified.
  Clauses parse_valid can't handle get 0.
  """
  parsed = parse_valid(v)
  if parsed is None: return 0
  bounded_expr = parsed[0]
  return -sum(1 for other in valids if bounded_expr in other.toposort())
def simplify_valid(valid:UOp) -> UOp|None:
  """Simplify each clause of an AND-chain given the conjunction of the clauses handled before it.

  Returns the new conjunction, or None if no clause changed or if the valid depends on an INDEX.
  """
  # this should only be for indexing, skip if there's a INDEX
  if valid.op_in_backward_slice_with_self(Ops.INDEX): return None
  clauses = list(valid.split_uop(Ops.AND))
  simplified:list[UOp] = []
  changed = False
  for clause in sorted(clauses, key=lambda v: _valid_priority(v, clauses)):
    new_clause = uop_given_valid(UOp.prod(*simplified), clause) if simplified else clause
    simplified.append(new_clause)
    if new_clause is not clause: changed = True
  return UOp.prod(*simplified) if changed else None
# ******** phase 3 is the complete symbolic, and deals with very complex things like loop rewriting and threefry transform ********
def reduce_mul_chain(r:UOp):
  """Move multiplicative factors of r.src[0] that don't depend on r.src[1:] outside of r.

  Only applies when r.arg is Ops.ADD or Ops.MAX and r keeps src[0]'s dtype. For MAX, only factors with vmin >= 0
  may move out. NOTE(review): r is presumably a REDUCE whose src[1:] are its ranges — confirm.
  """
  if r.arg not in {Ops.ADD, Ops.MAX} or r.dtype != r.src[0].dtype: return None
  reduce_srcs = r.src[1:]
  inside, outside = [], []
  for factor in r.src[0].split_uop(Ops.MUL):
    parents = factor.toposort()
    independent = all(rs not in parents for rs in reduce_srcs)
    if independent and (r.arg != Ops.MAX or factor.vmin >= 0): outside.append(factor)
    else: inside.append(factor)
  if not outside: return None
  new_body = prod(inside) if inside else r.src[0].const_like(1)
  return r.replace(src=(new_body,)+r.src[1:])*prod(outside)
def drop_and_clauses(cond:UOp, x:UOp, i:UOp) -> UOp|None:
  """Drop clauses of cond.where(x, i) that share no range with x (clauses with no ranges at all included).

  Returns the rebuilt where with the remaining clauses, or None if no clause can be dropped.
  """
  clauses = list(cond.split_uop(Ops.AND))
  dropped = [c for c in clauses if not any(r in x.ranges for r in c.ranges)]
  if not dropped: return None
  kept = [c for c in clauses if c not in dropped]
  return UOp.const(dtypes.bool, True).prod(*kept).where(x, i)
# for an index gated to Invalid, drop gate clauses whose ranges don't appear in the index (see drop_and_clauses)
pm_drop_and_clauses = PatternMatcher([(UPat.var("cond").where(UPat.var("x", dtype=dtypes.index), invalid_pat), drop_and_clauses)])
def where_on_load(c1, buf, x):
  """Move clauses of the gate in c1.where(buf.index(x), 0) into the valid of the index.

  Clauses of c1 that already appear in x's valid (x.get_valid()) are simply dropped from the where. The other clauses
  are moved into the index valid only if all their ranges are ranges of x and every INDEX in their backward slice is
  also in x's backward slice. Returns None if no clause can be removed from the where.
  """
  c1_clauses = list(c1.split_uop(Ops.AND))
  c2 = x.get_valid()
  # split c2 once up front instead of once per clause of c1
  c2_clauses = set(c2.split_uop(Ops.AND))
  duplicate_clauses = [c for c in c1_clauses if c in c2_clauses]
  # we move the condition from the where to the load _as long as_ the condition doesn't have some range that would place it
  # inside of a new range. also no data dependent loads!
  moved_clauses = [c for c in c1_clauses if c not in duplicate_clauses and all(r in x.ranges for r in c.ranges)
                   and all(u in x.backward_slice_with_self for u in c.backward_slice_with_self if u.op is Ops.INDEX)]
  if not (removed:=moved_clauses+duplicate_clauses): return None
  # additionally we can drop the clause on the where if it already exists in the load
  remaining_clause = UOp.const(dtypes.bool, True).prod(*[c for c in c1_clauses if c not in removed])
  return remaining_clause.where(buf.index(x.get_idx().valid(functools.reduce(operator.and_, moved_clauses, c2))), 0)
# push the gate of a where wrapped around an index into the index's own valid, via where_on_load
pm_move_where_on_load = PatternMatcher([
  # c1.where(buf.index(x), 0)
  (UPat.var("c1").where(UPat.var("buf").index(UPat.var("x")), 0), where_on_load),
  # c1.where(0, buf.index(x)) is handled as (!c1).where(buf.index(x), 0)
  (UPat.var("c1").where(0, UPat.var("buf").index(UPat.var("x"))), lambda c1,buf,x: where_on_load(c1.logical_not(),buf,x)),
])
pm_simplify_valid = PatternMatcher([
  # simplify valid: every AND is handed to simplify_valid, which simplifies each clause given the clauses before it
  (UPat(Ops.AND, name="valid"), simplify_valid),
  # TODO: this regressed openpilot, not having this regressed cifar
  # (disabled) simplify an index x under its own where-gate c using uop_given_valid without simplex
  # (UPat.var("c").where(UPat.var("x", dtype=dtypes.index), invalid_pat), lambda c,x,i: c.where(uop_given_valid(c, x, try_simplex=False), i)),
])
# this is symbolic 2.0
# ops whose sources are spliced directly into a parent SINK/GROUP by the "clean up GROUP/SINK" rule in sym
REMOVE_FROM_SINK_LIKE = {Ops.UNROLL, Ops.NOOP, Ops.VECTORIZE, Ops.SINK}
# sym = symbolic + pm_simplify_valid + the rules below
sym = symbolic+pm_simplify_valid+PatternMatcher([
  # LOAD/STORE -> NOOP
  # storing back a value loaded from the same x: declined (None) unless x is in the REG address space,
  # in which case it is replaced by x.src[0].src[0]
  (UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
  # a LOAD whose source is a constant is that constant
  # NOTE(review): src=(UPat.cvar('c')) has no trailing comma, so it is a bare UPat, not a 1-tuple like the other
  # src=(...,) patterns here; confirm whether (UPat.cvar('c'),) was intended
  (UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),
  # VECTORIZE/GEP
  # VECTORIZE whose sources are all GEPs of x -> one GEP of x collecting each source's first index
  (UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat.var("x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
  # reorder ALU/VECTORIZE
  # ALU(VECTORIZE(x...), VECTORIZE(y...)) -> VECTORIZE of the scalar ALU(x, y), repeated alu.dtype.count times
  (UPat(GroupOp.ALU, src=(UPat(Ops.VECTORIZE, src=UPat(name='x')), UPat(Ops.VECTORIZE, src=UPat(name='y'))), name='alu'),
    lambda x,y,alu: UOp(Ops.VECTORIZE, alu.dtype, (UOp(alu.op, alu.dtype.scalar(), (x,y)),)*alu.dtype.count)),
  # VECTORIZE of a single element is just that element
  (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
  # VECTORIZE void is GROUP
  (UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp.group(*x.src)),
  # tensor core with a 0 input is acc
  (UPat(Ops.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
  (UPat(Ops.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),
  # ** self folding **
  # x!=0 -> (bool)x
  (UPat.var("x")!=0, lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
  # ** where **
  # push cast to branches
  (UPat.var("s").where(UPat.var("a"), UPat.var("b")).cast().named("cast"), lambda s,a,b,cast: s.where(a.cast(cast.dtype), b.cast(cast.dtype))),
  # ** pow **
  # expand POW with xpow
  ((UPat(Ops.POW, name="p"), lambda p: xpow(*p.src))),
  # ** load/store folding **
  # storing a load of the same INDEX back to that INDEX -> NOOP
  (UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)),
  # store(index, gate.where(alt, load(index))): when gate is false the old value is written back, so instead store alt
  # through an index that is Invalid when gate is false
  (UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"),
    UPat.load(UPat(Ops.INDEX, name="index"))), allow_any_len=True, name="store"),
    lambda index, gate, alt, store: UOp.store(index.src[0].index(gate.where(index.src[1], UOp.invalid())), alt, *store.src[2:])),
  # fold gated LOAD/STORE
  (UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat.const(dtypes.index, Invalid)).or_casted(),), allow_any_len=True, name="x"),
    lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)), # invalid store does nothing. invalid load produces 0
  # NOTE(review): the next comment does not describe the reciprocal rules that follow it; it may be stale
  # # Where after gated load becomes alt value, TODO: this is sort of duplicated with rules in devectorizer
  ((UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()), # 1/(x*x) -> (1/x)*(1/x)
  # 1/(x*x*x) -> (1/x)*(1/x)*(1/x)
  ((UPat.var("x") * UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()*x.reciprocal()),
  ((UPat.var("x") * UPat.cvar("c")).reciprocal(), lambda x,c: x.reciprocal()*c.reciprocal()), # 1/(x*c) -> (1/c)*(1/x)
  (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")), lambda x,d: 1-d), # x*(1/(1+x)) -> 1-1/(1+x)
  # x*((1/(1+x))*y) -> y*(1-1/(1+x))
  (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")*UPat.var("y")), lambda x,y,d: y*(1-d)),
  # x*(1/(1+x)+y) -> (1-1/(1+x))+x*y
  (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")+UPat.var("y")), lambda x,y,d: (1-d)+x*y),
  # move const multiply after REDUCE (NOTE: the mul chain can do this, but only if it's a same dtype reduce)
  ((UPat.var("x")*UPat.cvar("c", vec=False)).reduce(arg=Ops.ADD, name="r", allow_any_len=True), lambda x,c,r: r.replace(src=(x,)+r.src[1:])*c.arg),
  # reduce mul chain, move muls after the reduce
  (UPat(Ops.MUL).reduce(name="r", allow_any_len=True), reduce_mul_chain),
  # clean up GROUP/SINK
  # a GROUP with a single source is that source
  (UPat(Ops.GROUP, src=(UPat.var("x"),)), lambda x: x),
  # splice the sources of any REMOVE_FROM_SINK_LIKE child directly into the SINK/GROUP
  (UPat((Ops.SINK, Ops.GROUP), name="root"),
    lambda root: UOp(root.op, root.dtype, tuple(flatten(x.src if x.op in REMOVE_FROM_SINK_LIKE else (x,) for x in root.src)), root.arg)
    if any(x.op in REMOVE_FROM_SINK_LIKE for x in root.src) else None),
  # remove END with empty NOOP
  (UPat(Ops.END, src=(UPat(Ops.NOOP, src=(), name="noop"),), allow_any_len=True), lambda noop:noop),
  # ** combine terms (opinionated) **
  (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
  # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
  ((UPat.var("x", dtypes.index) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
])