import math from tinygrad.dtype import dtypes, DType from tinygrad.helpers import polyN from tinygrad.ops import UOp TRANSCENDENTAL_SUPPORTED_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64) def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp): """replace inf -> inf, -inf -> _inf, nan -> nan, otherwise -> ratio""" return x.ne(math.inf).where(x.ne(x).where(nan, x.ne(-math.inf).where(ratio, _inf)), inf) # *** helper functions for bit manipulation *** def mantissa_bits(d:DType) -> int: return dtypes.finfo(d.scalar())[1] def exponent_bias(d:DType) -> int: return {dtypes.float64: 1023, dtypes.float32: 127, dtypes.float16: 15}[d.scalar()] def exponent_mask(d:DType) -> int: return {dtypes.float64: 2047, dtypes.float32: 255, dtypes.float16: 31}[d.scalar()] # **** utils **** def shr(x:UOp, y:int) -> UOp: return x // (2**y) def shl(x:UOp, y:int) -> UOp: return x * (2**y) def rintk(d:UOp) -> UOp: """round d:float to int away from 0""" out_dtype = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype.scalar()].vec(d.dtype.vcount) return (d + (d<0.0).where(d.const_like(-0.5), d.const_like(0.5))).cast(out_dtype) def pow2if(q:UOp, float_dtype:DType): """cast(2^q, float_dtype) where q is any integer in the range of [-126, 127]""" out_dtype = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype}[q.dtype.scalar()].vec(q.dtype.vcount) return shl(q + exponent_bias(out_dtype), mantissa_bits(out_dtype)).bitcast(out_dtype) def ilogb2k(d:UOp) -> UOp: """calculate the integer part of log2(d), where d is normalized fp value in the range of [0, +inf).""" assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES dint = d.bitcast({dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype.scalar()].vec(d.dtype.vcount)) # -1 <= ilog2bk(d) <= 128 return (shr(dint, mantissa_bits(d.dtype)) & exponent_mask(d.dtype)) - exponent_bias(d.dtype) def ldexp3k(d:UOp, e:UOp) -> UOp: """d*2^e. e is a number obtained by casting an integer in the range [-127, 127] to a float. d is any float number.""" assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES dtype = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype.scalar()].vec(d.dtype.count) m1 = d.bitcast(dtype) m2 = shl(e.cast(dtype), mantissa_bits(d.dtype)) return (m1 + m2).bitcast(d.dtype).cast(d.dtype) def ldexp2k(d:UOp, e:UOp) -> UOp: """d*2^e. much faster than ldexp3k but risky. d > 0 and d is not denormal.""" assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype.scalar() in (dtypes.int16, dtypes.int32, dtypes.int64) return (d * pow2if(shr(e, 1), d.dtype)) * pow2if(e - shr(e, 1), d.dtype) def frexp(v:UOp) -> tuple[UOp, UOp]: """frexp(v) -> (mantissa, exponent) assuming v != 0""" assert v.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES # m1 = masks for mantissa, m2 = masks to normalize the mantissa. m1 = {dtypes.float64: 0x000FFFFFFFFFFFFF, dtypes.float32: 0x807FFFFF, dtypes.float16: 0x83FF}[v.dtype.scalar()] m2 = {dtypes.float64: 0x3FE0000000000000, dtypes.float32: 0x3F000000, dtypes.float16: 0x3800}[v.dtype.scalar()] bits = v.bitcast({dtypes.float64: dtypes.uint64, dtypes.float32: dtypes.uint32, dtypes.float16: dtypes.uint16}[v.dtype.scalar()].vec(v.dtype.count)) exponent = shr(bits, mantissa_bits(v.dtype)) & exponent_mask(v.dtype) # Set the exponent bits appropriately to normalize the mantissa into the range of [0.5, 1.0). mantissa = ((bits & m1) | m2).bitcast(v.dtype) exp = exponent - exponent_bias(v.dtype) + 1 return mantissa, exp # *** reduction algorithms for sine *** def payne_hanek_reduction(d:UOp) -> tuple[UOp, UOp]: """ Performs Payne-Hanek Reduction: computes the remainder of `d` modulo pi/2 for the values `d` where 39800.0 <= d <= +Inf Returns a tuple of `(r, q)`: - `r`[d.dtype] is the reminder value corresponding to `round_to_nearest(x % pi/2)`. - `q`[int32] is an integer, and q % 4 is corresponding to the quadrant of the original angle `d`. """ assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES # https://stackoverflow.com/questions/30463616/payne-hanek-algorithm-implementation-in-c/30465751#30465751 # 190 bits of 2/pi for Payne-Hanek style argument reduction two_over_pi_f = [0x00000000, 0x28be60db, 0x9391054a, 0x7f09d5f4, 0x7d4d3770, 0x36d8a566, 0x4f10e410] intermediate_dtype = dtypes.float32.vec(d.dtype.count) if d.dtype.base == dtypes.float16 else d.dtype f, e = frexp(d) ia = (f.cast(intermediate_dtype) * 4.294967296e9).cast_vec(dtypes.uint64) # extract 96 relevant bits of 2/pi based on magnitude of argument i = shr(e.cast_vec(dtypes.uint64), 5) e = e.cast_vec(dtypes.int32) & 31 offset = 32 - e def _take(an:UOp, offset:int, count:int=0) -> UOp: """an = two_over_pi_f[i+offset]""" if count+offset < len(two_over_pi_f) - 1: an = i.ne(count).where(_take(an, offset, count=count+1), an.const_like(two_over_pi_f[count+offset])) return an def _shl_lazy(x, y): return (x.cast_vec(dtypes.uint64) * pow2if(y, d.dtype).cast_vec(dtypes.uint64)).cast_vec(dtypes.uint32) def _shr_lazy(x, y): return (x.cast_vec(dtypes.uint64) // pow2if(y, d.dtype).cast_vec(dtypes.uint64)).cast_vec(dtypes.uint32) a = [_take(UOp.const(dtypes.uint32.vec(d.dtype.count), 0), i) for i in range(4)] # (two_over_pi_f[Int(i) + n] << e) | (two_over_pi_f[Int(i) + n+1] >> (nbits - e)) # Note: e >= 1 for all numbers d >= 1.0. assume e != 0 hi = _shl_lazy(a[0], e) | _shr_lazy(a[1], offset) mi = _shl_lazy(a[1], e) | _shr_lazy(a[2], offset) lo = _shl_lazy(a[2], e) | _shr_lazy(a[3], offset) def _hp_mul(x:UOp, y:UOp) -> UOp: return x.cast_vec(dtypes.uint64) * y.cast_vec(dtypes.uint64) # compute x * 2/pi p = shl(_hp_mul(ia, hi), 32) + _hp_mul(ia, mi) + shr(_hp_mul(ia, lo), 32) # round quotient to nearest q = shr(p, 62).cast_vec(dtypes.int32) p = p & 0x3fffffffffffffff r = (p.cast(intermediate_dtype) * (3.4061215800865545e-19)).cast(d.dtype) # if fraction >= 0.5, r -= pi/2, q += 1 return (f<0.5).where(r, r - math.pi/2), (f<0.5).where(q, q + 1) def cody_waite_reduction(d:UOp) -> tuple[UOp, UOp]: """ Performs Cody-Waite Reduction: computes the reminder of `d` modulo pi/2 for the values `d` where 0 <= abs(d) <= 39800.0 Returns a tuple of `(r, q)`, where the output format is the same as that of `payne_hanek_reduction`. """ def _reduce_d(x:UOp, q:UOp): # https://github.com/shibatch/sleef/blob/4e08851f59fc2b545f9c393c6a23dfd311a26308/src/libm/sleefdp.c#L789-L823 if x.dtype == dtypes.float64: # https://github.com/shibatch/sleef/blob/f6d8a841fbfddd26ce712834d4da220cd76048fb/src/common/misc.h#L77 PI_A, PI_B, PI_C, PI_D = 3.1415926218032836914, 3.1786509424591713469e-08, 1.2246467864107188502e-16, 1.2736634327021899816e-24 d = qdh * -PI_A + x d = q * -PI_A + d d = qdh * -PI_B + d d = q * -PI_B + d d = qdh * -PI_C + d d = q * -PI_C + d d = (qdh + q) * -PI_D + d elif x.dtype == dtypes.float16: # [FIXME] when reducing `d`, FP16 needs FP32 precision to achieve 1.0 ULP precision. d = _reduce_d(x.cast(dtypes.float32), q.cast(dtypes.float32)).cast(dtypes.float16) else: # https://github.com/shibatch/sleef/blob/4e08851f59fc2b545f9c393c6a23dfd311a26308/src/libm/sleefsp.c#L464-L503 d = q * -3.1414794921875 + x d = q * -0.00011315941810607910156 + d d = q * -1.9841872589410058936e-09 + d d = q * -1.2154201256553420762e-10 + d return d m_1_pi = 0.318309886183790671537767526745028724 qdh = (d * (m_1_pi / 2.0**24)).cast_vec(dtypes.int64).cast(d.dtype) * (2.0**24) quadrant = rintk(d * m_1_pi -qdh) if d.dtype.base == dtypes.float64 else rintk(d * m_1_pi) return _reduce_d(d, quadrant.cast(d.dtype)), quadrant.cast_vec(dtypes.int32) # *** approximate sine on small angle. *** def trig_poly(d:UOp, coeff32, coeff64): return d * (polyN(d*d, coeff64) if d.dtype == dtypes.float64 else polyN(d*d, coeff32)) # approximate sine on [-pi/2, pi/2] def sin_poly(d:UOp) -> UOp: return trig_poly(d, [2.6083159809786593541503e-06, -0.0001981069071916863322258, 0.00833307858556509017944336, -0.166666597127914428710938, 1.0], [-7.97255955009037868891952e-18, 2.81009972710863200091251e-15, -7.64712219118158833288484e-13, 1.60590430605664501629054e-10, -2.50521083763502045810755e-08, 2.75573192239198747630416e-06, -0.000198412698412696162806809, 0.00833333333333332974823815, -0.166666666666666657414808, 1.0]) def _ifand(q:UOp, n:int): return (q & n).ne(0) def sin_poly_small(d:UOp, q:UOp) -> UOp: r = sin_poly(d) return r * _ifand(q, 1).where(r.const_like(-1), r.const_like(1)) def sin_poly_large(d:UOp, q:UOp) -> UOp: r = sin_poly(d + _ifand(q, 1).where(d.const_like(math.pi / 2), d.const_like(0))) return r * _ifand(q, 2).where(r.const_like(-1), r.const_like(1)) # *** toplevel functions for xsin/xlog2/xexp2 *** def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp: """ Implements a 1.0 ULP approximation for Ops.SIN. - fast=True assumes x <= switch_over. - switch_over is the threshold for switching to payne_hanek_reduction. """ assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES # mask +-inf/nan as zero x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d) # x_sign = sign(x) x_sign = x.ne(0).where((x<0).where(x.const_like(-1), x.const_like(1)), x.const_like(0)) x_abs = x * x_sign r, q = (cody_waite_reduction if fast else payne_hanek_reduction)(x_abs) if fast: result = sin_poly_small(r, q) else: # Payne Hanek Reduction assumes abs(x) >= pi/4, so for smaller values, use cody_waite_reduction. r_small, q_small = cody_waite_reduction(x_abs) result = (x_abs UOp: """ Implements a 1.0 ULP approximation for Ops.EXP2 - Paper: https://arxiv.org/pdf/2001.09258 """ assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES # mask +=inf/nan as zero. x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d) q = rintk(x) # s = d - round(d) s = x - q.cast(x.dtype) # a polynomial approximation with 13 non-zero terms in the range of [−(log 2)/2,(log 2)/2]. if d.dtype == dtypes.float64: u = polyN(s, [0.4434359082926529454e-9, 0.7073164598085707425e-8, 0.1017819260921760451e-6, 0.1321543872511327615e-5, 0.1525273353517584730e-4, 0.1540353045101147808e-3, 0.1333355814670499073e-2, 0.9618129107597600536e-2, 0.5550410866482046596e-1, 0.2402265069591012214e+0, 0.6931471805599452862e+0, 0.1000000000000000000e+1]) else: u = polyN(s, [0.1535920892e-3, 0.1339262701e-2, 0.9618384764e-2, 0.5550347269e-1, 0.2402264476e+0, 0.6931471825e+0, 1.0]) u = ldexp2k(u, q) # u*2^q upper, lower = {dtypes.float64: (1024, -2000), dtypes.float32: (128, -150), dtypes.float16: (23, -22)}[d.dtype.scalar()] # Replace x >= upper with +inf u = (d >= upper).where(d.const_like(math.inf), u) # Replace x < lower with zero. u = (d UOp: """ Implements a 1.0 ULP approximation for Ops.LOG2 Paper: https://arxiv.org/pdf/2001.09258 5.5 """ assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES # TODO: float16 denormal need float32 to achieve precision if d.dtype == dtypes.float16: return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16) FLT_MIN = d.const_like(1e-6 if d.dtype == dtypes.float16 else 1e-4) is_denormal = d UOp: # start with b ** e = exp2(e * log2(b)) ret = (base < 0).where(-base, base).log2().mul(exponent).exp2() # negative base adjustment: nan for non-integer exponent and -1 for odd exponent non_int = exponent != exponent.cast_vec(dtypes.int32).cast(exponent.dtype) adj = non_int.where(ret.const_like(math.nan), (exponent < 0).where(-exponent, exponent).cast_vec(dtypes.int32).mod(2).cast_vec(dtypes.bool).where(ret.const_like(-1), ret.const_like(1))) # fix 0 ** 0 = 1 return (base.eq(0) & exponent.eq(0)).where(ret.const_like(1), ret * (base < 0).where(adj, ret.const_like(1)))