You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							352 lines
						
					
					
						
							12 KiB
						
					
					
				
			
		
		
	
	
							352 lines
						
					
					
						
							12 KiB
						
					
					
				| import unittest, math
 | |
| from tinygrad.uop.ops import UOp, Ops
 | |
| from tinygrad.dtype import dtypes, Invalid
 | |
| 
 | |
class TestVminVmaxProperties(unittest.TestCase):
  """Range analysis (vmin/vmax) of scalar UOps built from constants, variables and ALU ops."""
  def test_vmin_vmax_constant(self):
    # vmin and vmax for a constant
    uop = UOp.const(dtypes.int32, 42)
    self.assertEqual(uop.vmin, 42)
    self.assertEqual(uop.vmax, 42)

  def test_vmin_vmax_cmpne(self):
    # != between two constants is a bool whose range collapses to the single known result
    uop = UOp.const(dtypes.int32, 42)
    def assert_exactly(u, x):
      self.assertEqual(u.vmin, x)
      self.assertEqual(u.vmax, x)
    assert_exactly(uop != 42, False)
    assert_exactly(uop != 43, True)
    assert_exactly(uop != 41, True)

  def test_vmin_vmax_addition_with_variable(self):
    # vmin and vmax for addition with a variable
    x = UOp.variable('x', 10, 20)
    uop = x + 5
    self.assertEqual(uop.vmin, 15)
    self.assertEqual(uop.vmax, 25)

  def test_vmin_vmax_subtraction_with_variable(self):
    # variable minus constant shifts the range; constant minus variable also flips it
    x = UOp.variable('x', 10, 20)
    uop = x - 5
    self.assertEqual(uop.vmin, 5)
    self.assertEqual(uop.vmax, 15)
    uop = 5 - x
    self.assertEqual(uop.vmin, -15)
    self.assertEqual(uop.vmax, -5)

  def test_vmin_vmax_and_with_variable(self):
    x = UOp.variable('x', 10, 20)
    # tight: e.g. 10 & 5 == 0 and 13 & 5 == 5
    uop = x & 5
    self.assertEqual(uop.vmin, 0)
    self.assertEqual(uop.vmax, 5)

    # also tight: 15 & 15 == 15 and 16 & 15 == 0
    uop = x & 15
    self.assertEqual(uop.vmin, 0)
    self.assertEqual(uop.vmax, 15)

    # this can be improved: no x in [10, 20] has the 32 bit set, so x & 32 is always 0,
    # yet the computed vmax is 20 (the upper bound of x)
    uop = x & 32
    self.assertEqual(uop.vmin, 0)
    self.assertEqual(uop.vmax, 20)

  def test_vmin_vmax_multiplication_with_variable(self):
    # vmin and vmax for multiplication of a sign-crossing variable by a positive constant
    x = UOp.variable('x', -3, 4)
    uop = x * 2
    self.assertEqual(uop.vmin, -6)
    self.assertEqual(uop.vmax, 8)

  def test_vmin_vmax_variable_inside_special(self):
    # a SPECIAL whose bound is a DEFINE_VAR in [1, 10] is expected to range over [0, 9]
    uop = UOp(Ops.SPECIAL, dtypes.int, arg='gidx0', src=(UOp(Ops.DEFINE_VAR, dtypes.int, arg=('i', 1, 10)),))
    self.assertEqual(uop.vmin, 0)
    self.assertEqual(uop.vmax, 9)

  def test_vmin_vmax_multiplication_0_inf(self):
    # 0.0 times a float loaded from memory, whose value is not statically bounded (it may be +/-inf);
    # the product's range must not be computed as nan from 0 * inf
    x = UOp.const(dtypes.float, 0.0)
    y = UOp.load(UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0), UOp.const(dtypes.int, 0), dtype=dtypes.float)
    uop = x * y
    # TODO: these should be 0, but definitely should not be nan
    self.assertEqual(uop.vmin, -math.inf)
    self.assertEqual(uop.vmax, math.inf)

  def test_vmin_vmax_with_negative_multiplication(self):
    # multiplying a positive range by a negative constant swaps which end becomes vmin
    x = UOp.variable('x', 2, 5)
    uop = x * -3
    self.assertEqual(uop.vmin, -15)
    self.assertEqual(uop.vmax, -6)

  def test_vmin_vmax_with_negative_multiplication2(self):
    # same, but the variable's range crosses zero
    x = UOp.variable('x', -2, 5)
    uop = x * -3
    self.assertEqual(uop.vmin, -15)
    self.assertEqual(uop.vmax, 6)

  def test_vmin_vmax_nested_min_max(self):
    # vmin and vmax with nested min/max operations (clamp x into [5, 8])
    x = UOp.variable('x', 0, 10)
    uop = x.maximum(5).minimum(8)
    self.assertEqual(uop.vmin, 5)
    self.assertEqual(uop.vmax, 8)

  def test_vmin_vmax_where(self):
    # a select spans the union of both branches' ranges
    x = UOp.variable('x', 0, 10)
    y = UOp.variable('y', 1, 11)
    z = UOp.variable('z', 2, 12)
    uop = (x<5).where(y, z)
    self.assertEqual(uop.vmin, 1)
    self.assertEqual(uop.vmax, 12)

  def test_vmin_vmax_shl(self):
    x = UOp.variable('x', 0, 10) << 5
    self.assertEqual(x.vmin, 0)
    self.assertEqual(x.vmax, 10 << 5)

  def test_vmin_vmax_shr(self):
    x = UOp.variable('x', 0, 10) >> 2
    self.assertEqual(x.vmin, 0)
    self.assertEqual(x.vmax, 10 >> 2)

  def test_vmin_vmax_cast(self):
    x = UOp.variable('x', -10, 10, dtypes.int)
    # int -> float preserves the range
    x_float = x.cast(dtypes.float)
    self.assertEqual(x_float.vmin, -10)
    self.assertEqual(x_float.vmax, 10)
    # int -> bool can be either value
    x_bool = x.cast(dtypes.bool)
    self.assertEqual(x_bool.vmin, False)
    self.assertEqual(x_bool.vmax, True)
    # negative values don't fit in uint, so the full uint range is expected
    x_uint = x.cast(dtypes.uint)
    self.assertEqual(x_uint.vmin, dtypes.min(dtypes.uint))
    self.assertEqual(x_uint.vmax, dtypes.max(dtypes.uint))

  def test_vmin_vmax_invalid(self):
    # Invalid must not report a single-point range
    i = UOp.invalid()
    self.assertNotEqual(i.vmin, i.vmax)

  def test_vmin_vmax_invalid_vconst(self):
    # Invalid lanes widen the range beyond that of the valid lanes (0 and 4)
    x = UOp.const(dtypes.index.vec(4), (0, 4, Invalid, Invalid))
    self.assertLess(x.vmin, 0)
    self.assertGreater(x.vmax, 4)
 | |
| 
 | |
class TestVminVmaxDivMod(unittest.TestCase):
  # The expected values below encode truncating (C-style) integer division and a remainder that
  # takes the sign of the dividend, not Python's floor semantics: e.g. for x in [10, 20],
  # x // -3 is expected to bottom out at -6 (20 / -3 truncated), not -7.
  # (name, lower bound, upper bound, expected (vmin, vmax) of the remainder by +/-3)
  _mod_cases = (
    ('positive', 10, 20, (0, 2)),
    ('negative', -20, -10, (-2, 0)),
    ('mixed', -20, 20, (-2, 2)),
  )

  def _assert_range(self, uop, lo, hi):
    self.assertEqual(uop.vmin, lo)
    self.assertEqual(uop.vmax, hi)

  def test_vmin_vmax_division_positive(self):
    # dividend in [10, 20], positive constant divisor
    self._assert_range(UOp.variable('x', 10, 20) // 2, 5, 10)

  def test_vmin_vmax_division_negative(self):
    # negative constant divisor; dividend entirely positive, entirely negative, then straddling 0
    table = [
      ((10, 20), [(-2, (-10, -5)), (-3, (-6, -3))]),
      ((-20, -10), [(-2, (5, 10)), (-3, (3, 6))]),
      ((-10, 10), [(-2, (-5, 5)), (-3, (-3, 3))]),
    ]
    for (lo, hi), per_divisor in table:
      var = UOp.variable('x', lo, hi)
      for divisor, (want_min, want_max) in per_divisor:
        self._assert_range(var // divisor, want_min, want_max)

  def test_vmin_vmax_div_symbolic(self):
    a = UOp.variable('x', 1, 10)
    b = UOp.variable('y', 3, 5)
    # symbolic dividend and divisor, every sign combination
    self._assert_range(a // b, 0, 3)
    self._assert_range((-a) // b, -3, 0)
    self._assert_range(a // (-b), -3, 0)
    self._assert_range((-a) // (-b), 0, 3)

    # constant dividend, symbolic divisor
    self._assert_range(100 // b, 20, 33)
    self._assert_range((-100) // b, -33, -20)
    self._assert_range(100 // (-b), -33, -20)
    self._assert_range((-100) // (-b), 20, 33)

  def test_vmin_vmax_mod_positive(self):
    # modulo by a positive constant: the result's sign follows the dividend
    for name, lo, hi, (want_min, want_max) in self._mod_cases:
      self._assert_range(UOp.variable(name, lo, hi) % 3, want_min, want_max)

  def test_vmin_vmax_mod_negative(self):
    # modulo by a negative constant gives the same ranges as by +3: the divisor's sign is irrelevant
    for name, lo, hi, (want_min, want_max) in self._mod_cases:
      self._assert_range(UOp.variable(name, lo, hi) % -3, want_min, want_max)
 | |
| 
 | |
class TestVminVmaxVConst(unittest.TestCase):
  """vmin/vmax of vector constants (extremes over all lanes) and of a single lane taken from a vector."""
  def test_vmin_vmax_vconst_single_element(self):
    # vmin and vmax for a single-element vector constant
    uop = UOp.const(dtypes.int32.vec(1), (42,))
    self.assertEqual(uop.vmin, 42)
    self.assertEqual(uop.vmax, 42)

  def test_vmin_vmax_vconst_multiple_elements(self):
    # vmin and vmax for a multi-element vector constant
    uop = UOp.const(dtypes.int32.vec(4), (10, 20, -5, 7))
    self.assertEqual(uop.vmin, -5)
    self.assertEqual(uop.vmax, 20)

  def test_vmin_vmax_vconst_all_equal(self):
    # vmin and vmax for a vector where all elements are equal
    uop = UOp.const(dtypes.int32.vec(3), (7, 7, 7))
    self.assertEqual(uop.vmin, 7)
    self.assertEqual(uop.vmax, 7)

  def test_vmin_vmax_vconst_with_negative_values(self):
    # vmin and vmax for a vector constant containing only negative values
    uop = UOp.const(dtypes.int32.vec(4), (-10, -20, -5, -15))
    self.assertEqual(uop.vmin, -20)
    self.assertEqual(uop.vmax, -5)

  def test_vmin_vmax_vconst_with_floats(self):
    # vmin and vmax for a vector constant of float values
    uop = UOp.const(dtypes.float32.vec(3), (1.5, -3.2, 0.0))
    self.assertEqual(uop.vmin, -3.2)
    self.assertEqual(uop.vmax, 1.5)

  def test_vmin_vmax_vconst_with_bools(self):
    # vmin and vmax for a vector constant of bool values must be real bools, not ints
    uop = UOp.const(dtypes.bool.vec(3), (True, False, False))
    self.assertIs(uop.vmin, False)
    self.assertIs(uop.vmax, True)

  def test_vmin_vmax_vector_with_gep(self):
    # lane 0 (GEP) of (int.vec(2) load // 32). The loaded values are not statically bounded, so the
    # expected range is the int32 range divided by 32: -2**31 / 32 == -67108864, (2**31 - 1) // 32 == 67108863
    d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
    idx = UOp.const(dtypes.int, 0)
    val = UOp(Ops.LOAD, dtypes.int.vec(2), (d1.index(idx),))
    uop = (val // 32).gep(0)
    self.assertEqual(uop.vmin, -67108864)
    self.assertEqual(uop.vmax, 67108863)
 | |
| 
 | |
class TestConstFactor(unittest.TestCase):
  # const_factor() reports an integer known to divide the expression (not necessarily the largest one)
  def test_const_factor_constant(self):
    # a bare constant is its own factor
    self.assertEqual(UOp.const(dtypes.int32, 42).const_factor(), 42)

  def test_const_factor_addition(self):
    # sum of two constants: the gcd of the addends, gcd(30, 12) == 6
    total = UOp.const(dtypes.int32, 30) + UOp.const(dtypes.int32, 12)
    self.assertEqual(total.const_factor(), 6)

  def test_const_factor_multiplication(self):
    # product of two constants: one of the factors (5) is reported, not the full product 35
    product = UOp.const(dtypes.int32, 5) * UOp.const(dtypes.int32, 7)
    self.assertEqual(product.const_factor(), 5)

  def test_const_factor_with_variable(self):
    # a variable scaled by a constant has that constant as its factor
    scaled = UOp.variable('x', 10, 20) * 3
    self.assertEqual(scaled.const_factor(), 3)

  def test_const_factor_division(self):
    # integer division leaves no known factor beyond 1
    quotient = UOp.variable('x', 10, 20) // 4
    self.assertEqual(quotient.const_factor(), 1)

  def test_const_factor_multiplication_of_var_and_const(self):
    # variable times 4 has the constant factor 4
    scaled = UOp.variable('x', 6, 18) * 4
    self.assertEqual(scaled.const_factor(), 4)

  @unittest.skip("broken")
  def test_const_factor_multiplication_of_consts_and_vars(self):
    # nested constant multipliers should combine: (x * 3) * 5 has factor 3 * 5 == 15
    scaled = (UOp.variable('x', 10, 20) * 3) * 5
    self.assertEqual(scaled.const_factor(), 15)
 | |
| 
 | |
class TestDivides(unittest.TestCase):
  # divides(v) yields the expression divided by v when that division is exact, otherwise None
  def test_divides_constant_exact(self):
    quotient = UOp.const(dtypes.int32, 42).divides(7)
    self.assertIsNotNone(quotient)
    self.assertEqual(quotient.const_factor(), 6)  # 42 / 7 == 6

  def test_divides_constant_inexact(self):
    # 42 is not a multiple of 5, so there is no exact quotient
    self.assertIsNone(UOp.const(dtypes.int32, 42).divides(5))

  @unittest.skip("broken")
  def test_divides_variable_and_constant(self):
    # scaling a variable by 6 and then dividing by 6 should give back the variable itself
    var = UOp.variable('x', 10, 20)
    quotient = (var * 6).divides(6)
    self.assertIsNotNone(quotient)
    self.assertEqual(quotient, var)

  def test_divides_complex_expression(self):
    # every term of (x * 6) + 18 is a multiple of 6; the quotient x + 3 has const_factor 1
    quotient = ((UOp.variable('x', 10, 20) * 6) + 18).divides(6)
    self.assertIsNotNone(quotient)
    self.assertEqual(quotient.const_factor(), 1)

  def test_divides_with_inexact_factors(self):
    # x * 4 is not known to be a multiple of 3, so dividing by 3 fails
    product = UOp.variable('x', 15, 45) * 4
    self.assertIsNone(product.divides(3))
 | |
| 
 | |
if __name__ == "__main__":
  # run every TestCase defined in this module when it is executed directly as a script
  unittest.main()
 | |
| 
 |