You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
68 lines
3.3 KiB
68 lines
3.3 KiB
1 month ago
|
import unittest, math
|
||
|
import numpy as np
|
||
|
from tinygrad import dtypes
|
||
|
from tinygrad.ops import UOp
|
||
|
from tinygrad.codegen.transcendental import payne_hanek_reduction, cody_waite_reduction, frexp, rintk, pow2if
|
||
|
from test.helpers import eval_uop
|
||
|
|
||
|
class TestTranscendentalFunctions(unittest.TestCase):
  """Tests for the helper UOps in tinygrad.codegen.transcendental.

  Each test wraps a Python constant in UOp.const, applies the helper under test, and
  evaluates the resulting UOp(s) to concrete values with eval_uop before comparing them
  against expected numbers.
  """

  def test_payne_hanek_reduction(self):
    """payne_hanek_reduction yields an (r, q) pair; check it for float64 inputs around 12*pi."""
    r, q = (eval_uop(u) for u in payne_hanek_reduction(UOp.const(dtypes.float64, 12 * math.pi + 0.1)))
    np.testing.assert_allclose(r, 0.1 - math.pi / 2)
    np.testing.assert_equal(q, 1)

    r, q = (eval_uop(u) for u in payne_hanek_reduction(UOp.const(dtypes.float64, 12 * math.pi)))
    # The expected remainder is exactly 0, where a purely relative tolerance can never pass,
    # so compare with an absolute tolerance instead.
    np.testing.assert_allclose(r, 0.0, atol=1e-8)
    np.testing.assert_equal(q, 4)

    r, q = (eval_uop(u) for u in payne_hanek_reduction(UOp.const(dtypes.float64, 12 * math.pi - 0.1)))
    np.testing.assert_allclose(r, -0.1)
    np.testing.assert_equal(q, 4)

  def test_cody_waite_reduction(self):
    """cody_waite_reduction yields an (r, q) pair; for x = 12*pi + 0.1 expect r = 0.1 and q = 12."""
    r, q = (eval_uop(u) for u in cody_waite_reduction(UOp.const(dtypes.float64, 12 * math.pi + 0.1)))
    np.testing.assert_allclose(r, 0.1)
    np.testing.assert_equal(q, 12)

  def test_frexp(self):
    """frexp yields (mantissa, exponent) for float64 inputs.

    For every case below, abs(x) == mantissa * 2**exponent with mantissa in [0.5, 1).
    The expected mantissa is positive for negative inputs as well, so the sign is not
    carried in the mantissa.
    """
    for x in (1, -1):
      with self.subTest(x=x):
        mantissa, exponent = (eval_uop(u) for u in frexp(UOp.const(dtypes.float64, x)))
        np.testing.assert_equal(mantissa, 0.5)
        np.testing.assert_equal(exponent, 1)

    for x in (2, -2):
      with self.subTest(x=x):
        mantissa, exponent = (eval_uop(u) for u in frexp(UOp.const(dtypes.float64, x)))
        np.testing.assert_equal(mantissa, 0.5)
        np.testing.assert_equal(exponent, 2)

    mantissa, exponent = (eval_uop(u) for u in frexp(UOp.const(dtypes.float64, 5.0)))
    np.testing.assert_equal(mantissa, 0.625)
    np.testing.assert_equal(exponent, 3)

    mantissa, exponent = (eval_uop(u) for u in frexp(UOp.const(dtypes.float64, 1000.0)))
    np.testing.assert_allclose(mantissa, 0.9765625)
    np.testing.assert_equal(exponent, 10)

  def test_rintk(self):
    """rintk rounds a dtypes.float value to the nearest integer.

    Covers exact integers, halfway values, and values just below an integer, for both signs.
    """
    cases = ((0.0, 0), (5.0, 5), (5.5, 6), (5.999, 6), (-5.0, -5), (-5.5, -6), (-5.999, -6))
    for x, expected in cases:
      with self.subTest(x=x):
        np.testing.assert_allclose(eval_uop(rintk(UOp.const(dtypes.float, x))), expected)

  def test_pow2if(self):
    """pow2if(n, dtypes.float) evaluates to 2**n for both positive and negative integer n."""
    cases = ((0, 1.0), (1, 2.0), (2, 4.0), (10, 1024.0), (63, 2**63),
             (-1, 0.5), (-2, 0.25), (-10, 2**-10), (-63, 2**-63))
    for n, expected in cases:
      with self.subTest(n=n):
        np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, n), dtypes.float)), expected)
|
if __name__ == '__main__': unittest.main()  # run this module's test cases when executed directly
|