You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
244 lines
8.6 KiB
244 lines
8.6 KiB
1 month ago
|
import unittest, math
|
||
|
from tinygrad.ops import UOp, Ops
|
||
|
from tinygrad.dtype import dtypes
|
||
|
|
||
|
class TestVminVmaxProperties(unittest.TestCase):
  """Bounds (vmin/vmax) of scalar UOps: constants, comparisons, shifts and elementwise ALU ops."""

  def test_vmin_vmax_constant(self):
    # both bounds of a constant are the constant itself
    c = UOp.const(dtypes.int32, 42)
    self.assertEqual(c.vmin, 42)
    self.assertEqual(c.vmax, 42)

  def test_vmin_vmax_cmpne(self):
    # `!=` against a constant is fully decided, so lower and upper bound coincide
    c = UOp.const(dtypes.int32, 42)
    for other, expected in ((42, False), (43, True), (41, True)):
      ne = c != other
      self.assertEqual(ne.vmin, expected)
      self.assertEqual(ne.vmax, expected)

  def test_vmin_vmax_addition_with_variable(self):
    # adding a constant shifts both ends of the variable's range [10, 20]
    v = UOp.variable('x', 10, 20)
    total = v + 5
    self.assertEqual(total.vmin, 15)
    self.assertEqual(total.vmax, 25)

  def test_vmin_vmax_multiplication_with_variable(self):
    # a positive factor scales both ends of a range that crosses zero
    v = UOp.variable('x', -3, 4)
    prod = v * 2
    self.assertEqual(prod.vmin, -6)
    self.assertEqual(prod.vmax, 8)

  def test_vmin_vmax_variable_inside_special(self):
    # SPECIAL whose arg carries a DEFINE_VAR (i in [1, 10]) instead of a plain int
    bound = UOp(Ops.DEFINE_VAR, dtypes.int, arg=('i', 1, 10))
    idx = UOp(Ops.SPECIAL, dtypes.int, arg=('gidx0', bound))
    self.assertEqual(idx.vmin, 0)
    self.assertEqual(idx.vmax, 10)

  def test_vmin_vmax_multiplication_0_inf(self):
    # 0.0 multiplied by a float loaded from a global buffer (presumably a value with no known range)
    zero = UOp.const(dtypes.float, 0.0)
    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
    loaded = UOp.load(buf, UOp.const(dtypes.int, 0), dtype=dtypes.float)
    prod = zero * loaded
    # TODO: these should be 0, but definitely should not be nan
    self.assertEqual(prod.vmin, -math.inf)
    self.assertEqual(prod.vmax, math.inf)

  def test_vmin_vmax_with_negative_multiplication(self):
    # a negative factor swaps which end of [2, 5] produces the min and the max
    v = UOp.variable('x', 2, 5)
    prod = v * -3
    self.assertEqual(prod.vmin, -15)
    self.assertEqual(prod.vmax, -6)

  def test_vmin_vmax_with_negative_multiplication2(self):
    # negative factor again, this time with a range [-2, 5] that crosses zero
    v = UOp.variable('x', -2, 5)
    prod = v * -3
    self.assertEqual(prod.vmin, -15)
    self.assertEqual(prod.vmax, 6)

  def test_vmin_vmax_nested_min_max(self):
    # clamp [0, 10] from below at 5, then from above at 8
    v = UOp.variable('x', 0, 10)
    clamped = v.maximum(5).minimum(8)
    self.assertEqual(clamped.vmin, 5)
    self.assertEqual(clamped.vmax, 8)

  def test_vmin_vmax_where(self):
    # x < 5 can go either way for x in [0, 10], so the result spans both branches
    x = UOp.variable('x', 0, 10)
    y = UOp.variable('y', 1, 11)
    z = UOp.variable('z', 2, 12)
    picked = (x < 5).where(y, z)
    self.assertEqual(picked.vmin, 1)
    self.assertEqual(picked.vmax, 12)

  def test_vmin_vmax_shl(self):
    # left shift by a constant amount
    shifted = UOp.variable('x', 0, 10) << 5
    self.assertEqual(shifted.vmin, 0)
    self.assertEqual(shifted.vmax, 10 << 5)

  def test_vmin_vmax_shr(self):
    # right shift by a constant amount
    shifted = UOp.variable('x', 0, 10) >> 2
    self.assertEqual(shifted.vmin, 0)
    self.assertEqual(shifted.vmax, 10 >> 2)
|
||
|
|
||
|
class TestVminVmaxDivMod(unittest.TestCase):
  """Bounds (vmin/vmax) of integer division and modulo by a constant."""

  def test_vmin_vmax_division_positive(self):
    # [10, 20] floor-divided by a positive constant
    v = UOp.variable('x', 10, 20)
    q = v // 2
    self.assertEqual(q.vmin, 5)
    self.assertEqual(q.vmax, 10)

  def test_vmin_vmax_division_negative(self):
    # [10, 20] floor-divided by a negative constant: the bounds come out negated and swapped
    v = UOp.variable('x', 10, 20)
    q = v // -2
    self.assertEqual(q.vmin, -10)
    self.assertEqual(q.vmax, -5)

  def test_vmin_vmax_mod_positive(self):
    # remainder by 3 of a range wider than the divisor
    v = UOp.variable('x', 10, 20)
    r = v % 3
    self.assertEqual(r.vmin, 0)
    self.assertEqual(r.vmax, 2)

  @unittest.skip("broken")
  def test_vmin_vmax_mod_negative(self):
    # remainder by a negative constant; expectations here follow Python's sign convention
    v = UOp.variable('x', 10, 20)
    r = v % -3
    self.assertEqual(r.vmin, -2)
    self.assertEqual(r.vmax, 0)

  def test_vmin_vmax_division_with_mixed_range(self):
    # dividend range [-10, 10] crosses zero
    v = UOp.variable('x', -10, 10)
    q = v // 3
    self.assertEqual(q.vmin, -4)  # -10//3 == -4 under Python floor division
    self.assertEqual(q.vmax, 3)   # 10//3 == 3

  def test_vmin_vmax_mod_with_mixed_range(self):
    # dividend range [-10, 10] crosses zero; expected bounds are [0, divisor - 1]
    v = UOp.variable('x', -10, 10)
    r = v % 4
    self.assertEqual(r.vmin, 0)
    self.assertEqual(r.vmax, 3)
|
||
|
|
||
|
class TestVminVmaxVConst(unittest.TestCase):
  """Bounds (vmin/vmax) of vector constants: the min and max over all lanes."""

  def test_vmin_vmax_vconst_single_element(self):
    # one-lane vector: both bounds equal the only element
    vec = UOp.const(dtypes.int32.vec(1), (42,))
    self.assertEqual(vec.vmin, 42)
    self.assertEqual(vec.vmax, 42)

  def test_vmin_vmax_vconst_multiple_elements(self):
    # four lanes with mixed signs
    vec = UOp.const(dtypes.int32.vec(4), (10, 20, -5, 7))
    self.assertEqual(vec.vmin, -5)
    self.assertEqual(vec.vmax, 20)

  def test_vmin_vmax_vconst_all_equal(self):
    # identical lanes collapse to a single value
    vec = UOp.const(dtypes.int32.vec(3), (7, 7, 7))
    self.assertEqual(vec.vmin, 7)
    self.assertEqual(vec.vmax, 7)

  def test_vmin_vmax_vconst_with_negative_values(self):
    # every lane negative
    vec = UOp.const(dtypes.int32.vec(4), (-10, -20, -5, -15))
    self.assertEqual(vec.vmin, -20)
    self.assertEqual(vec.vmax, -5)

  def test_vmin_vmax_vconst_with_floats(self):
    # float lanes
    vec = UOp.const(dtypes.float32.vec(3), (1.5, -3.2, 0.0))
    self.assertEqual(vec.vmin, -3.2)
    self.assertEqual(vec.vmax, 1.5)
|
||
|
|
||
|
class TestConstFactor(unittest.TestCase):
  """Expected results of UOp.const_factor() for constants and simple arithmetic."""

  def test_const_factor_constant(self):
    # a bare constant is its own factor
    c = UOp.const(dtypes.int32, 42)
    self.assertEqual(c.const_factor(), 42)

  def test_const_factor_addition(self):
    # sum of two constants: expected factor is gcd(30, 12) == 6
    total = UOp.const(dtypes.int32, 30) + UOp.const(dtypes.int32, 12)
    self.assertEqual(total.const_factor(), 6)

  def test_const_factor_multiplication(self):
    # product of two constants: expected factor is the first operand, 5
    prod = UOp.const(dtypes.int32, 5) * UOp.const(dtypes.int32, 7)
    self.assertEqual(prod.const_factor(), 5)

  def test_const_factor_with_variable(self):
    # variable scaled by 3
    v = UOp.variable('x', 10, 20)
    scaled = v * 3
    self.assertEqual(scaled.const_factor(), 3)

  def test_const_factor_division(self):
    # integer division: expected factor is 1
    v = UOp.variable('x', 10, 20)
    q = v // 4
    self.assertEqual(q.const_factor(), 1)

  def test_const_factor_multiplication_of_var_and_const(self):
    # variable scaled by 4
    v = UOp.variable('x', 6, 18)
    scaled = v * 4
    self.assertEqual(scaled.const_factor(), 4)

  @unittest.skip("broken")
  def test_const_factor_multiplication_of_consts_and_vars(self):
    # chained constant multipliers: expected factor is 3 * 5 == 15
    v = UOp.variable('x', 10, 20)
    scaled = (v * 3) * 5
    self.assertEqual(scaled.const_factor(), 15)
|
||
|
|
||
|
class TestDivides(unittest.TestCase):
  """UOp.divides(n): the exact quotient as a UOp, or None when it does not divide evenly."""

  def test_divides_constant_exact(self):
    # 42 / 7 leaves no remainder, so the quotient is 6
    c = UOp.const(dtypes.int32, 42)
    quotient = c.divides(7)
    self.assertIsNotNone(quotient)
    self.assertEqual(quotient.const_factor(), 6)

  def test_divides_constant_inexact(self):
    # 42 is not a multiple of 5
    c = UOp.const(dtypes.int32, 42)
    self.assertIsNone(c.divides(5))

  @unittest.skip("broken")
  def test_divides_variable_and_constant(self):
    # (x * 6) / 6 is expected to give back x itself
    v = UOp.variable('x', 10, 20)
    quotient = (v * 6).divides(6)
    self.assertIsNotNone(quotient)
    self.assertEqual(quotient, v)

  def test_divides_complex_expression(self):
    # (x * 6 + 18) / 6 -> x + 3, whose const_factor is 1
    v = UOp.variable('x', 10, 20)
    expr = (v * 6) + 18
    quotient = expr.divides(6)
    self.assertIsNotNone(quotient)
    self.assertEqual(quotient.const_factor(), 1)

  def test_divides_with_inexact_factors(self):
    # x * 4 is not known to be a multiple of 3
    v = UOp.variable('x', 15, 45)
    self.assertIsNone((v * 4).divides(3))
|
||
|
|
||
|
if __name__ == '__main__':
  # run every TestCase in this module when executed directly as a script
  unittest.main()
|