openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

58 lines
3.4 KiB

# should this merge with transcendental?
from typing import Callable
import functools
from tinygrad.device import is_dtype_supported
from tinygrad.dtype import dtypes, promo_lattice
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher
from tinygrad.helpers import getenv
from tinygrad.uop.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
from tinygrad.renderer import Renderer
# ***** optional patterns *****
@functools.cache
def magicgu(vmax:int, d:int) -> tuple[int,int]:
  """Find a multiplier/shift pair that replaces unsigned integer division by a constant.

  Calculates (m, s) such that x//d == (x*m) >> s for all 0 <= x <= vmax, with d > 0.
  Adapted from Hacker's Delight, Chapter 10. Results are memoized per (vmax, d).

  Args:
    vmax: largest dividend that has to be handled.
    d: the divisor; expected to be positive.

  Returns:
    (m, s) for the smallest shift s in 0..2*vmax.bit_length() that satisfies the search condition.

  Raises:
    ZeroDivisionError: if d == 0.
    AssertionError: if no shift in the searched range qualifies (not expected for valid inputs).
  """
  # largest value of the form k*d - 1 that does not exceed vmax (-1 when vmax+1 < d)
  nc = (vmax+1)//d * d - 1
  for s in range(2*vmax.bit_length() + 1):
    p = 1 << s
    # distance from p up to the next multiple of d
    r = d - 1 - (p - 1) % d
    if p > nc*r: return (p + r)//d, s
  # explicit raise instead of `assert False`: asserts are stripped under -O, which would turn this into a silent `return None`
  raise AssertionError(f"no magic number found for vmax={vmax}, d={d}")
def fast_idiv(ctx: Renderer|None, x: UOp, d: int) -> UOp|None:
  """Rewrite x//d for a constant d as a multiply followed by a right shift, when that is safe.

  Args:
    ctx: renderer whose device is checked for support of a wider integer dtype; with None only the same-dtype rewrite is tried.
    x: the dividend; only handled when x.vmin >= 0.
    d: the constant divisor (may be negative; the sign is applied to the result).

  Returns:
    sign(d) * ((x*m) >> s), computed either in x.dtype or in promo_lattice[x.dtype][-1] and cast back to x.dtype,
    or None when no rewrite applies.
  """
  # idiv is truncated division, but arithmetic shift is floored division, so can only do non-negative numbers!
  if x.vmin<0: return None
  # there is no magic number for a zero divisor (magicgu would raise ZeroDivisionError), so leave the op alone
  if d == 0: return None
  sign = 1 if d > 0 else -1
  # clamp the dividend bound to what the dtype can hold before searching for the multiplier/shift
  m,s = magicgu(vmax := min(x.vmax, dtypes.max(x.dtype)), abs(d))
  # the product x*m fits in the original dtype: no cast needed
  if m * vmax <= dtypes.max(x.dtype): return sign * ((x*m) >> s)
  # promo_lattice needs to return an unsigned type
  if ctx is not None and dtypes.is_int(next_dtype := promo_lattice[x.dtype][-1]) and is_dtype_supported(next_dtype, ctx.device):
    # do the multiply in the wider dtype, then cast the shifted result back
    if m * vmax <= dtypes.max(next_dtype): return sign * ((x.cast(next_dtype)*m) >> s).cast(x.dtype)
  return None
# lookup table: each power of two 2**k for 0 <= k < 64 maps to its exponent k
powers_of_two = {1 << exponent: exponent for exponent in range(64)}
@functools.cache
def get_late_rewrite_patterns(ops, force_transcendental=False):
  """Build the PatternMatcher of optional late rewrites for a given collection of supported ops.

  Args:
    ops: collection of Ops that is tested with `in`; it is part of the functools.cache key, so it must be hashable.
    force_transcendental: when true, the EXP2/LOG2/SIN rewrites are added even if those ops are present in `ops`.

  Returns:
    A PatternMatcher constructed from the list of (UPat, callable) pairs assembled below.
  """
  # NOTE(review): the lambda parameter names below match the UPat.var/cvar names (plus ctx) -- presumably PatternMatcher
  # binds them by name, so keep them in sync; confirm against PatternMatcher.
  # EXP2/LOG2/SIN on TRANSCENDENTAL_SUPPORTED_DTYPES -> xexp2/xlog2/xsin, for each op missing from ops (or always, when forced)
  pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
    ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
  # rewrite SQRT to xpow 0.5
  if Ops.SQRT not in ops: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
  # rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
  if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
  # rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
  # the .get(c.arg, 0) default (and the entry for 1, which maps to 0) is falsy, so those constants are left unrewritten
  if Ops.SHL in ops: pat += [(UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << v if (v:=powers_of_two.get(c.arg, 0)) else None)]
  if Ops.SHR in ops:
    # no reason to check x<0 for uints
    pat += [(UPat.var("x", dtypes.uints)//UPat.cvar("c"), lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) else None)]
    # general ints: add c-1 where x<0 before shifting; if the comparison x<0 has vmin==vmax it is replaced by that constant
    pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: (x+(l.const_like(l.vmin) if (l:=(x<0)).vmin==l.vmax else l).where(
      c-1, 0)) >> v if (v:=powers_of_two.get(c.arg, 0)) else None)] # (x+(x<0).where(c-1, 0)) >> v
    # multiply+shift division by an arbitrary constant, unless getenv("DISABLE_FAST_IDIV") is truthy
    if not getenv("DISABLE_FAST_IDIV"):
      pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("d"), lambda ctx, x, d: fast_idiv(ctx, x, d.arg))]
      # x % d -> x - d*(x//d), reusing the fast_idiv result for the quotient when one is available
      pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("d"), lambda ctx, x, d: x - d*f if (f:=fast_idiv(ctx, x, d.arg)) is not None else None)]
  if Ops.NEG in ops:
    # x * -1 -> NEG(x)
    pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
    # x + NEG(y) -> SUB(x, y)
    if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))]
  # a*b + c -> MULACC(a, b, c)
  if Ops.MULACC in ops: pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))]
  return PatternMatcher(pat)