You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
266 lines
14 KiB
266 lines
14 KiB
import math
|
|
from tinygrad.dtype import dtypes, DType
|
|
from tinygrad.helpers import polyN
|
|
from tinygrad.ops import UOp
|
|
|
|
# Scalar float dtypes accepted by the routines in this file; they assert `dtype.scalar()` is one of these.
TRANSCENDENTAL_SUPPORTED_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64)
|
|
|
|
def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp):
  """
  Pick a replacement for `x` depending on what kind of value it holds:
  +inf -> `inf`, -inf -> `_inf`, NaN -> `nan`, anything else -> `ratio`.
  """
  # everything that is not -inf falls through to `ratio`
  unless_neg_inf = x.ne(-math.inf).where(ratio, _inf)
  # x != x is the NaN test
  unless_nan = x.ne(x).where(nan, unless_neg_inf)
  return x.ne(math.inf).where(unless_nan, inf)
|
|
|
|
# *** helper functions for bit manipulation ***
|
|
def mantissa_bits(d:DType) -> int:
  """Return entry 1 of `dtypes.finfo` for the scalar type of `d` (used in this file as the mantissa width in bits)."""
  finfo = dtypes.finfo(d.scalar())
  return finfo[1]
|
|
def exponent_bias(d:DType) -> int:
  """Return the exponent bias for the scalar float type of `d`. Raises KeyError for types not listed."""
  bias_by_dtype = {dtypes.float16: 15, dtypes.float32: 127, dtypes.float64: 1023}
  return bias_by_dtype[d.scalar()]
|
|
def exponent_mask(d:DType) -> int:
  """Return the all-ones mask of the exponent field for the scalar float type of `d`. Raises KeyError for types not listed."""
  mask_by_dtype = {dtypes.float16: 31, dtypes.float32: 255, dtypes.float64: 2047}
  return mask_by_dtype[d.scalar()]
|
|
|
|
# **** utils ****
|
|
def shr(x:UOp, y:int) -> UOp:
  """Shift `x` right by `y` bits, written as a floor division by 2**y."""
  divisor = 2**y
  return x // divisor
|
|
def shl(x:UOp, y:int) -> UOp:
  """Shift `x` left by `y` bits, written as a multiplication by 2**y."""
  factor = 2**y
  return x * factor
|
|
|
|
def rintk(d:UOp) -> UOp:
  """
  Round the float `d` to an integer, rounding away from 0.
  The result is cast to the signed integer dtype of the same width as `d`, with the same vector count.
  """
  int_for_float = {dtypes.float16: dtypes.int16, dtypes.float32: dtypes.int32, dtypes.float64: dtypes.int64}
  out_dtype = int_for_float[d.dtype.scalar()].vec(d.dtype.vcount)
  # nudge by half a unit in the direction of the sign before the integer cast
  half = (d<0.0).where(d.const_like(-0.5), d.const_like(0.5))
  return (d + half).cast(out_dtype)
|
|
|
|
def pow2if(q:UOp, float_dtype:DType):
  """
  cast(2^q, float_dtype), where q is any integer in the range of [-126, 127].
  int64 and int32 inputs always map to float64 and float32; `float_dtype` is only used for int16 inputs.
  """
  float_for_int = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype}
  out_dtype = float_for_int[q.dtype.scalar()].vec(q.dtype.vcount)
  # build the float directly: biased exponent shifted into the exponent field, zero mantissa
  biased_exponent = q + exponent_bias(out_dtype)
  return shl(biased_exponent, mantissa_bits(out_dtype)).bitcast(out_dtype)
|
|
|
|
def ilogb2k(d:UOp) -> UOp:
  """Calculate the integer part of log2(d), where d is a normalized fp value in the range of [0, +inf)."""
  assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
  int_for_float = {dtypes.float16: dtypes.int16, dtypes.float32: dtypes.int32, dtypes.float64: dtypes.int64}
  # reinterpret the float bits as a same-width signed integer
  dint = d.bitcast(int_for_float[d.dtype.scalar()].vec(d.dtype.vcount))
  # isolate the exponent field, then remove the bias
  # -1 <= ilog2bk(d) <= 128
  biased_exponent = shr(dint, mantissa_bits(d.dtype)) & exponent_mask(d.dtype)
  return biased_exponent - exponent_bias(d.dtype)
|
|
|
|
def ldexp3k(d:UOp, e:UOp) -> UOp:
  """
  d*2^e. e is a number obtained by casting an integer in the range [-127, 127] to a float. d is any float number.
  Implemented by adding `e` directly into the exponent field of the bit pattern of `d`.
  """
  assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
  int_for_float = {dtypes.float16: dtypes.int16, dtypes.float32: dtypes.int32, dtypes.float64: dtypes.int64}
  int_dtype = int_for_float[d.dtype.scalar()].vec(d.dtype.count)
  value_bits = d.bitcast(int_dtype)
  # move the integer exponent delta up to where the exponent field starts
  exponent_delta = shl(e.cast(int_dtype), mantissa_bits(d.dtype))
  return (value_bits + exponent_delta).bitcast(d.dtype).cast(d.dtype)
|
|
|
|
def ldexp2k(d:UOp, e:UOp) -> UOp:
  """
  d*2^e. much faster than ldexp3k but risky. d > 0 and d is not denormal.
  The scaling is split into two multiplications, by 2^(e>>1) and by 2^(e - (e>>1)).
  """
  assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype.scalar() in (dtypes.int16, dtypes.int32, dtypes.int64)
  first_half = shr(e, 1)
  second_half = e - shr(e, 1)
  return (d * pow2if(first_half, d.dtype)) * pow2if(second_half, d.dtype)
|
|
|
|
def frexp(v:UOp) -> tuple[UOp, UOp]:
  """
  frexp(v) -> (mantissa, exponent) assuming v != 0

  - mantissa: `v` with its exponent field replaced so that the magnitude lies in [0.5, 1.0); the sign of `v` is kept.
  - exponent: the exponent field of `v` with the bias removed, plus one.
  """
  assert v.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
  # m1 = masks keeping the sign bit and the mantissa bits, m2 = exponent bits that normalize the mantissa.
  # NOTE: the float64 mask now keeps the sign bit (0x8...) like the float32 and float16 masks do;
  # without it the mantissa of a negative float64 came back positive.
  m1 = {dtypes.float64: 0x800FFFFFFFFFFFFF, dtypes.float32: 0x807FFFFF, dtypes.float16: 0x83FF}[v.dtype.scalar()]
  m2 = {dtypes.float64: 0x3FE0000000000000, dtypes.float32: 0x3F000000, dtypes.float16: 0x3800}[v.dtype.scalar()]
  # work on the raw bits through a same-width unsigned integer
  bits = v.bitcast({dtypes.float64: dtypes.uint64, dtypes.float32: dtypes.uint32, dtypes.float16: dtypes.uint16}[v.dtype.scalar()].vec(v.dtype.count))
  exponent = shr(bits, mantissa_bits(v.dtype)) & exponent_mask(v.dtype)
  # Set the exponent bits appropriately to normalize the mantissa into the range of [0.5, 1.0).
  mantissa = ((bits & m1) | m2).bitcast(v.dtype)
  # +1 because the mantissa is normalized to [0.5, 1.0) instead of [1.0, 2.0)
  exp = exponent - exponent_bias(v.dtype) + 1
  return mantissa, exp
|
|
|
|
# *** reduction algorithms for sine ***
|
|
def payne_hanek_reduction(d:UOp) -> tuple[UOp, UOp]:
  """
  Performs Payne-Hanek Reduction: computes the remainder of `d` modulo pi/2 for the values `d` where
    39800.0 <= d <= +Inf
  Returns a tuple of `(r, q)`:
  - `r`[d.dtype] is the remainder value corresponding to `round_to_nearest(x % pi/2)`.
  - `q`[int32] is an integer, and q % 4 corresponds to the quadrant of the original angle `d`.
  """
  assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
  # https://stackoverflow.com/questions/30463616/payne-hanek-algorithm-implementation-in-c/30465751#30465751
  # 190 bits of 2/pi for Payne-Hanek style argument reduction
  two_over_pi_f = [0x00000000, 0x28be60db, 0x9391054a, 0x7f09d5f4, 0x7d4d3770, 0x36d8a566, 0x4f10e410]
  n_words = len(two_over_pi_f)

  # float16 inputs run the floating point steps in float32
  if d.dtype.base == dtypes.float16: intermediate_dtype = dtypes.float32.vec(d.dtype.count)
  else: intermediate_dtype = d.dtype

  frac, raw_exp = frexp(d)
  # 4.294967296e9 == 2**32: scale the fraction up before converting it to an integer
  ia = (frac.cast(intermediate_dtype) * 4.294967296e9).cast_vec(dtypes.uint64)
  # extract 96 relevant bits of 2/pi based on magnitude of argument
  word_idx = shr(raw_exp.cast_vec(dtypes.uint64), 5)
  bit_shift = raw_exp.cast_vec(dtypes.int32) & 31
  inv_shift = 32 - bit_shift

  def _take(an:UOp, start:int, count:int=0) -> UOp:
    """an = two_over_pi_f[word_idx+start], expressed as a chain of selects over the possible values of word_idx"""
    if count+start >= n_words - 1: return an
    remaining = _take(an, start, count=count+1)
    return word_idx.ne(count).where(remaining, an.const_like(two_over_pi_f[count+start]))

  def _as_u64(u:UOp) -> UOp: return u.cast_vec(dtypes.uint64)
  # shifts by a non-constant amount, done as multiply / divide by 2^y in uint64 and narrowed to uint32
  def _shl_lazy(x, y): return (_as_u64(x) * _as_u64(pow2if(y, d.dtype))).cast_vec(dtypes.uint32)
  def _shr_lazy(x, y): return (_as_u64(x) // _as_u64(pow2if(y, d.dtype))).cast_vec(dtypes.uint32)

  zero_u32 = UOp.const(dtypes.uint32.vec(d.dtype.count), 0)
  words = [_take(zero_u32, k) for k in range(4)]
  # (two_over_pi_f[Int(i) + n] << e) | (two_over_pi_f[Int(i) + n+1] >> (nbits - e))
  # Note: e >= 1 for all numbers d >= 1.0. assume e != 0
  hi, mi, lo = (_shl_lazy(words[k], bit_shift) | _shr_lazy(words[k+1], inv_shift) for k in range(3))

  def _hp_mul(x:UOp, y:UOp) -> UOp: return _as_u64(x) * _as_u64(y)
  # compute x * 2/pi
  prod = shl(_hp_mul(ia, hi), 32) + _hp_mul(ia, mi) + shr(_hp_mul(ia, lo), 32)

  # round quotient to nearest
  q = shr(prod, 62).cast_vec(dtypes.int32)
  frac_bits = prod & 0x3fffffffffffffff
  # 3.4061215800865545e-19 == (pi/2) / 2**62: turn the low 62 bits back into an angle
  r = (frac_bits.cast(intermediate_dtype) * (3.4061215800865545e-19)).cast(d.dtype)

  # if fraction >= 0.5, r -= pi/2, q += 1
  below_half = frac<0.5
  return below_half.where(r, r - math.pi/2), below_half.where(q, q + 1)
|
|
|
|
def cody_waite_reduction(d:UOp) -> tuple[UOp, UOp]:
  """
  Performs Cody-Waite Reduction for the values `d` where
    0 <= abs(d) <= 39800.0
  The quotient is round(d / pi); that many multiples of pi, split into several constants of
  decreasing magnitude, are subtracted from `d` one part at a time.
  Returns a tuple of `(r, q)`: `r` is the reduced value in the dtype of `d`, `q` is the quotient as int32.
  """
  m_1_pi = 0.318309886183790671537767526745028724
  # coarse part of the quotient (a multiple of 2**24); only the float64 path below makes use of it
  qdh = (d * (m_1_pi / 2.0**24)).cast_vec(dtypes.int64).cast(d.dtype) * (2.0**24)

  def _reduce_d(x:UOp, q:UOp):
    # https://github.com/shibatch/sleef/blob/4e08851f59fc2b545f9c393c6a23dfd311a26308/src/libm/sleefdp.c#L789-L823
    if x.dtype == dtypes.float64:
      # https://github.com/shibatch/sleef/blob/f6d8a841fbfddd26ce712834d4da220cd76048fb/src/common/misc.h#L77
      PI_A, PI_B, PI_C, PI_D = 3.1415926218032836914, 3.1786509424591713469e-08, 1.2246467864107188502e-16, 1.2736634327021899816e-24
      acc = x
      # for each of the three leading parts, subtract the coarse quotient first and the fine one second
      for part in (PI_A, PI_B, PI_C):
        acc = qdh * -part + acc
        acc = q * -part + acc
      return (qdh + q) * -PI_D + acc
    if x.dtype == dtypes.float16:
      # [FIXME] when reducing `d`, FP16 needs FP32 precision to achieve 1.0 ULP precision.
      return _reduce_d(x.cast(dtypes.float32), q.cast(dtypes.float32)).cast(dtypes.float16)
    # https://github.com/shibatch/sleef/blob/4e08851f59fc2b545f9c393c6a23dfd311a26308/src/libm/sleefsp.c#L464-L503
    acc = x
    for part in (3.1414794921875, 0.00011315941810607910156, 1.9841872589410058936e-09, 1.2154201256553420762e-10):
      acc = q * -part + acc
    return acc

  # float64 rounds only what is left after removing the coarse quotient
  if d.dtype.base == dtypes.float64: quadrant = rintk(d * m_1_pi -qdh)
  else: quadrant = rintk(d * m_1_pi)
  return _reduce_d(d, quadrant.cast(d.dtype)), quadrant.cast_vec(dtypes.int32)
|
|
|
|
# *** approximate sine on small angle. ***
|
|
def trig_poly(d:UOp, coeff32, coeff64):
  """Evaluate d * polyN(d*d, coeffs), using `coeff64` when `d` is float64 and `coeff32` otherwise."""
  coeffs = coeff64 if d.dtype == dtypes.float64 else coeff32
  return d * polyN(d*d, coeffs)
|
|
# approximate sine on [-pi/2, pi/2]
|
|
def sin_poly(d:UOp) -> UOp:
  """Polynomial approximation of sine for an already reduced argument `d`, evaluated through `trig_poly`."""
  # coefficients used for every dtype except float64
  coeff32 = [2.6083159809786593541503e-06, -0.0001981069071916863322258, 0.00833307858556509017944336, -0.166666597127914428710938, 1.0]
  # coefficients used for float64
  coeff64 = [-7.97255955009037868891952e-18, 2.81009972710863200091251e-15, -7.64712219118158833288484e-13, 1.60590430605664501629054e-10,
             -2.50521083763502045810755e-08, 2.75573192239198747630416e-06, -0.000198412698412696162806809, 0.00833333333333332974823815,
             -0.166666666666666657414808, 1.0]
  return trig_poly(d, coeff32, coeff64)
|
|
|
|
def _ifand(q:UOp, n:int):
  """True where `q & n` is non-zero, i.e. `q` has at least one of the bits of `n` set."""
  masked = q & n
  return masked.ne(0)
|
|
|
|
def sin_poly_small(d:UOp, q:UOp) -> UOp:
  """sin_poly(d) with the sign flipped when bit 0 of `q` is set (odd `q`)."""
  r = sin_poly(d)
  sign = _ifand(q, 1).where(r.const_like(-1), r.const_like(1))
  return r * sign
|
|
|
|
def sin_poly_large(d:UOp, q:UOp) -> UOp:
  """
  sin_poly evaluated at `d`, shifted by pi/2 when bit 0 of `q` is set,
  with the sign of the result flipped when bit 1 of `q` is set.
  """
  shift = _ifand(q, 1).where(d.const_like(math.pi / 2), d.const_like(0))
  r = sin_poly(d + shift)
  sign = _ifand(q, 2).where(r.const_like(-1), r.const_like(1))
  return r * sign
|
|
|
|
# *** toplevel functions for xsin/xlog2/xexp2 ***
|
|
|
|
def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
  """
  Implements a 1.0 ULP approximation for Ops.SIN.
  - fast=True assumes x <= switch_over, so only cody_waite_reduction is used.
  - switch_over is the threshold for switching to payne_hanek_reduction.
  """
  assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
  zero, nan = d.const_like(0.0), d.const_like(math.nan)
  # mask +-inf/nan as zero
  x = _lazy_map_numbers(d, zero, zero, zero, d)
  # x_sign = sign(x)
  nonzero_sign = (x<0).where(x.const_like(-1), x.const_like(1))
  x_sign = x.ne(0).where(nonzero_sign, x.const_like(0))
  x_abs = x * x_sign
  if fast:
    r, q = cody_waite_reduction(x_abs)
    result = sin_poly_small(r, q)
  else:
    r_large, q_large = payne_hanek_reduction(x_abs)
    # Payne Hanek Reduction assumes abs(x) >= pi/4, so for smaller values, use cody_waite_reduction.
    r_small, q_small = cody_waite_reduction(x_abs)
    result = (x_abs<switch_over).where(sin_poly_small(r_small, q_small), sin_poly_large(r_large, q_large))
  # adjusts the sign for abs(x)
  result = result * x_sign
  # sin(Inf) = NaN, sin(-Inf) = NaN, sin(NaN) = NaN
  return _lazy_map_numbers(d, nan, nan, nan, result)
|
|
|
|
def xexp2(d:UOp) -> UOp:
  """
  Implements a 1.0 ULP approximation for Ops.EXP2
  - Paper: https://arxiv.org/pdf/2001.09258
  """
  assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
  zero = d.const_like(0.0)
  # mask +-inf/nan as zero.
  x = _lazy_map_numbers(d, zero, zero, zero, d)
  q = rintk(x)
  # s = d - round(d)
  s = x - q.cast(x.dtype)
  # a polynomial approximation with 13 non-zero terms in the range of [−(log 2)/2,(log 2)/2].
  if d.dtype == dtypes.float64:
    coeffs = [0.4434359082926529454e-9, 0.7073164598085707425e-8, 0.1017819260921760451e-6, 0.1321543872511327615e-5, 0.1525273353517584730e-4,
              0.1540353045101147808e-3, 0.1333355814670499073e-2, 0.9618129107597600536e-2, 0.5550410866482046596e-1, 0.2402265069591012214e+0,
              0.6931471805599452862e+0, 0.1000000000000000000e+1]
  else:
    coeffs = [0.1535920892e-3, 0.1339262701e-2, 0.9618384764e-2, 0.5550347269e-1, 0.2402264476e+0, 0.6931471825e+0, 1.0]
  u = ldexp2k(polyN(s, coeffs), q) # u*2^q
  limits = {dtypes.float64: (1024, -2000), dtypes.float32: (128, -150), dtypes.float16: (23, -22)}
  upper, lower = limits[d.dtype.scalar()]
  # Replace x >= upper with +inf
  u = (d >= upper).where(d.const_like(math.inf), u)
  # Replace x < lower with zero.
  u = (d<lower).where(d.const_like(0.0), u)
  # exp2(NaN) = NaN
  return d.ne(d).where(d.const_like(math.nan), u)
|
|
|
|
def xlog2(d:UOp) -> UOp:
  """
  Implements a 1.0 ULP approximation for Ops.LOG2
  Paper: https://arxiv.org/pdf/2001.09258 5.5
  """
  assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
  # TODO: float16 denormal need float32 to achieve precision
  if d.dtype == dtypes.float16:
    return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16)
  FLT_MIN = d.const_like(1e-6 if d.dtype == dtypes.float16 else 1e-4)
  # inputs below the threshold are scaled up by 2^64 here and the exponent is corrected by -64 afterwards
  is_denormal = d<FLT_MIN
  a = is_denormal.where(d * (2 ** 64), d)

  # split a into exponent and mantissa: a = m * 2^e
  e_scaled = ilogb2k(a * (1.0 / 0.75)).cast(a.dtype)
  m = ldexp3k(a, -e_scaled)
  e = is_denormal.where(e_scaled - 64, e_scaled)

  x = (m - 1.0) / (m + 1.0)
  x2 = x * x
  if d.dtype == dtypes.float64:
    t = polyN(x2, [0.2211941750456081490e+0, 0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0,
                   0.4121985945485324709e+0, 0.5770780162997058982e+0, 0.96179669392608091449])
    s_hi = e+x*2.885390081777926774
    s_lo = e.const_like(0)
  else:
    t = polyN(x2, [0.4374550283e+0, 0.5764790177e+0, 0.9618012905120])
    s_hi = e+x*2.8853900432586669922
    s_lo = x*3.2734474483568488616e-08
  r = t * (x * x2) + (s_hi + s_lo)

  nan, neg_inf = r.const_like(math.nan), r.const_like(-math.inf)
  # log2(Inf) = Inf
  r = d.ne(math.inf).where(r, r.const_like(math.inf))
  # log2(x) = NaN for x < 0
  r = (d<-0.0).where(nan, r)
  # log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true.
  # log2_zero = the value of unmasked xlog2(0.0).
  unmasked_zero_result = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79}
  log2_zero = unmasked_zero_result[d.dtype.scalar()]
  r = r.ne(log2_zero).where(r, neg_inf)
  # log2(NaN) = NaN
  r = d.ne(d).where(nan, r)
  # log2(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true. so making reciprocal.
  return d.reciprocal().ne(-math.inf).where(r, neg_inf)
|
|
|
|
def xpow(base:UOp, exponent:UOp) -> UOp:
  """
  base ** exponent, computed as exp2(exponent * log2(abs(base))) with fix-ups:
  - negative base: NaN for a non-integer exponent, negated result for an odd exponent.
  - 0 ** 0 is 1.
  """
  one = None  # placeholder so the name is obviously local; set once `ret` exists
  # start with b ** e = exp2(e * log2(b))
  abs_base = (base < 0).where(-base, base)
  ret = abs_base.log2().mul(exponent).exp2()
  one = ret.const_like(1)
  # negative base adjustment: nan for non-integer exponent and -1 for odd exponent
  non_int = exponent != exponent.cast_vec(dtypes.int32).cast(exponent.dtype)
  abs_exponent = (exponent < 0).where(-exponent, exponent)
  is_odd = abs_exponent.cast_vec(dtypes.int32).mod(2).cast_vec(dtypes.bool)
  adj = non_int.where(ret.const_like(math.nan), is_odd.where(ret.const_like(-1), one))
  # fix 0 ** 0 = 1
  both_zero = base.eq(0) & exponent.eq(0)
  return both_zero.where(one, ret * (base < 0).where(adj, one))
|
|
|