You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
			
				
					203 lines
				
				7.1 KiB
			
		
		
			
		
	
	
					203 lines
				
				7.1 KiB
			| 
											12 hours ago
										 | import unittest
 | ||
|  | from tinygrad import dtypes
 | ||
|  | from tinygrad.uop.ops import UOp, graph_rewrite_map, _substitute
 | ||
|  | from tinygrad.uop.symbolic import symbolic
 | ||
|  | 
 | ||
|  | class TestRewriteMap(unittest.TestCase):
 | ||
|  |   def test_substitute(self):
 | ||
|  |     a = UOp.variable('a', 0, 10)
 | ||
|  |     b = UOp.variable('b', 0, 10)
 | ||
|  |     c = UOp.variable('c', 0, 10)
 | ||
|  |     e = UOp.variable('e', 0, 10)
 | ||
|  |     ret = (a+b)*c
 | ||
|  |     sub = {a+b: e}
 | ||
|  |     sub_map = graph_rewrite_map(ret, _substitute, sub, bottom_up=True)
 | ||
|  |     self.assertIs(sub_map[a+b], e)
 | ||
|  |     self.assertIs(sub_map[(a+b)*c], e*c)
 | ||
|  | 
 | ||
|  |   def test_substitute_depth_2(self):
 | ||
|  |     a = UOp.variable('a', 0, 10)
 | ||
|  |     b = UOp.variable('b', 0, 10)
 | ||
|  |     c = UOp.variable('c', 0, 10)
 | ||
|  |     d = UOp.variable('d', 0, 10)
 | ||
|  |     e = UOp.variable('e', 0, 10)
 | ||
|  |     f = UOp.variable('f', 0, 10)
 | ||
|  |     ret = (a+b)*c+d
 | ||
|  |     sub = {a+b: e, (a+b)*c: f}
 | ||
|  |     sub_map = graph_rewrite_map(ret, _substitute, sub, bottom_up=True)
 | ||
|  |     self.assertIs(sub_map[a+b], e)
 | ||
|  |     self.assertIs(sub_map[(a+b)*c], f)
 | ||
|  | 
 | ||
|  |   def test_multistage_substitute(self):
 | ||
|  |     a = UOp.variable('a', 0, 10)
 | ||
|  |     b = UOp.variable('b', 0, 10)
 | ||
|  |     c = UOp.variable('c', 0, 10)
 | ||
|  |     d = UOp.variable('d', 0, 10)
 | ||
|  |     sub1 = {a+b:c}
 | ||
|  |     start = (a+b)*c
 | ||
|  |     # stage 1: (a+b)*c -> c*c
 | ||
|  |     sub_map1 = graph_rewrite_map(start, _substitute, sub1, bottom_up=True)
 | ||
|  |     self.assertIs(sub_map1[(a+b)*c], c*c)
 | ||
|  |     # stage 2: c*c -> d
 | ||
|  |     sub2 = {c*c:d}
 | ||
|  |     sub_map2 = graph_rewrite_map(sub_map1[start], _substitute, sub2, input_map=sub_map1, bottom_up=True)
 | ||
|  |     # (a+b)*c -> c*c -> d
 | ||
|  |     self.assertIs(sub_map2[(a+b)*c], d)
 | ||
|  | 
 | ||
|  |   def test_add_zero(self):
 | ||
|  |     # Build a small graph: add(0, add(const=0, const=5))
 | ||
|  |     zero_node = UOp.const(dtypes.index, 0)
 | ||
|  |     five_node = UOp.const(dtypes.index, 5)
 | ||
|  |     inner_add = zero_node + five_node
 | ||
|  |     root_add = zero_node + inner_add
 | ||
|  | 
 | ||
|  |     # Perform top-down rewrite
 | ||
|  |     node_map = graph_rewrite_map(root_add, symbolic)
 | ||
|  | 
 | ||
|  |     # We expect that add(0, add(0, 5)) -> add(0, 5) -> 5
 | ||
|  |     # Check the mapping
 | ||
|  |     assert node_map[root_add] == five_node
 | ||
|  |     assert node_map[inner_add] == five_node
 | ||
|  |     # zero_node and five_node map to themselves
 | ||
|  |     assert node_map[zero_node] == zero_node
 | ||
|  |     assert node_map[five_node] == five_node
 | ||
|  | 
 | ||
|  |   def test_double_neg(self):
 | ||
|  |     """
 | ||
|  |     Test rewriting neg(neg(5)) => 5 using symbolic.
 | ||
|  |     """
 | ||
|  |     # In some versions of TinyGrad, you might do: (-(-five_node))
 | ||
|  |     five_node = UOp.const(dtypes.index, 5)
 | ||
|  |     # If your code allows UOp(...), do that; else you might do something like:
 | ||
|  |     # double_neg_five = -(-five_node)
 | ||
|  |     # But let's be explicit:
 | ||
|  |     neg_five = -five_node
 | ||
|  |     double_neg_five = -neg_five
 | ||
|  | 
 | ||
|  |     node_map = graph_rewrite_map(double_neg_five, symbolic)
 | ||
|  | 
 | ||
|  |     # node_map should map double_neg_five -> five_node
 | ||
|  |     self.assertEqual(node_map[double_neg_five], five_node)
 | ||
|  |     # five_node maps to itself
 | ||
|  |     self.assertEqual(node_map[five_node], five_node)
 | ||
|  | 
 | ||
|  |   def test_add_zero_and_double_neg(self):
 | ||
|  |     """
 | ||
|  |     Combine both rewrites: add(0, neg(neg(5))) => add(0, 5) => 5
 | ||
|  |     """
 | ||
|  |     zero_node = UOp.const(dtypes.index, 0)
 | ||
|  |     five_node = UOp.const(dtypes.index, 5)
 | ||
|  |     neg_five = -five_node
 | ||
|  |     double_neg_five = -neg_five
 | ||
|  |     root_add = zero_node + double_neg_five
 | ||
|  | 
 | ||
|  |     node_map = graph_rewrite_map(root_add, symbolic)
 | ||
|  | 
 | ||
|  |     # node_map: root_add -> five_node, double_neg_five -> five_node
 | ||
|  |     self.assertEqual(node_map[root_add], five_node)
 | ||
|  |     self.assertEqual(node_map[double_neg_five], five_node)
 | ||
|  |     # zero_node, five_node map to themselves
 | ||
|  |     self.assertEqual(node_map[zero_node], zero_node)
 | ||
|  |     self.assertEqual(node_map[five_node], five_node)
 | ||
|  | 
 | ||
|  |   def test_multi_var_rewrites(self):
 | ||
|  |     x_var = UOp.variable('x', 0, 10)
 | ||
|  |     y_var = UOp.variable('y', -5, 5)
 | ||
|  |     zero_node = UOp.const(dtypes.index, 0)
 | ||
|  | 
 | ||
|  |     sum_with_zero = y_var + zero_node    # (y + 0)
 | ||
|  |     combined = x_var + sum_with_zero     # x + (y + 0)
 | ||
|  |     double_neg = -(-combined)           # neg(neg(x + y))
 | ||
|  |     final_expr = zero_node + double_neg  # 0 + (x + y)
 | ||
|  | 
 | ||
|  |     node_map = graph_rewrite_map(final_expr, symbolic)
 | ||
|  | 
 | ||
|  |     # The final root should be (x_var + y_var).
 | ||
|  |     expected = x_var + y_var
 | ||
|  | 
 | ||
|  |     # Each sub-expression has its own "final" result.
 | ||
|  |     # (y + 0) -> y_var
 | ||
|  |     self.assertEqual(node_map[sum_with_zero], y_var)
 | ||
|  |     # (x + (y+0)) -> (x + y)
 | ||
|  |     self.assertEqual(node_map[combined], expected)
 | ||
|  |     # neg(neg(x+y)) -> (x + y)
 | ||
|  |     self.assertEqual(node_map[double_neg], expected)
 | ||
|  |     # 0 + (x+y) -> (x + y)
 | ||
|  |     self.assertEqual(node_map[final_expr], expected)
 | ||
|  | 
 | ||
|  |     # x_var, y_var, zero_node remain unchanged
 | ||
|  |     self.assertEqual(node_map[x_var], x_var)
 | ||
|  |     self.assertEqual(node_map[y_var], y_var)
 | ||
|  |     self.assertEqual(node_map[zero_node], zero_node)
 | ||
|  | 
 | ||
|  |   def test_complex_multi_var_edges(self):
 | ||
|  |     """
 | ||
|  |     Build a multi-variable expression with multiple intermediates:
 | ||
|  | 
 | ||
|  |       x_var = UOp.variable('x', 1, 10)
 | ||
|  |       y_var = UOp.variable('y', -5, 5)
 | ||
|  |       z_var = UOp.variable('z', 0, 5)
 | ||
|  |       zero_node = UOp.const(dtypes.int, 0)
 | ||
|  |       one_node = UOp.const(dtypes.int, 1)
 | ||
|  | 
 | ||
|  |       yz_sum       = y_var + z_var
 | ||
|  |       yz_sum_zero  = yz_sum + zero_node   -> rewrites to yz_sum
 | ||
|  |       yz_neg       = -yz_sum_zero        -> -(y+z)
 | ||
|  |       yz_dneg      = -yz_neg             -> y+z    (double neg gone)
 | ||
|  |       x_plus_yz    = x_var + yz_dneg     -> x + (y+z)
 | ||
|  |       double_neg_x = -(-x_plus_yz)       -> x + (y+z)
 | ||
|  |       final_expr   = double_neg_x * one_node -> x + (y+z)
 | ||
|  | 
 | ||
|  |     We expect the final result to be (x + (y+z)).
 | ||
|  |     Each original node should map to the final node that replaces it,
 | ||
|  |     which might be structurally equivalent but not the same reference.
 | ||
|  |     """
 | ||
|  |     x_var = UOp.variable('x', 1, 10)
 | ||
|  |     y_var = UOp.variable('y', -5, 5)
 | ||
|  |     z_var = UOp.variable('z', 0, 5)
 | ||
|  |     zero_node = UOp.const(dtypes.index, 0)
 | ||
|  |     one_node = UOp.const(dtypes.index, 1)
 | ||
|  | 
 | ||
|  |     # Build sub-expressions
 | ||
|  |     yz_sum = y_var + z_var              # (y + z)
 | ||
|  |     yz_sum_zero = yz_sum + zero_node    # (y + z) + 0
 | ||
|  |     yz_neg = -yz_sum_zero               # -(y+z)
 | ||
|  |     yz_dneg = -yz_neg                   # -(-(y+z)) -> (y+z)
 | ||
|  |     x_plus_yz = x_var + yz_dneg         # x + (y+z)
 | ||
|  |     double_neg_x = -(-x_plus_yz)        # neg(neg(x+(y+z))) -> x+(y+z)
 | ||
|  |     final_expr = double_neg_x * one_node  # (x+(y+z)) * 1 -> x+(y+z)
 | ||
|  | 
 | ||
|  |     node_map = graph_rewrite_map(final_expr, symbolic)
 | ||
|  | 
 | ||
|  |     # (y + z) is unchanged
 | ||
|  |     self.assertEqual(node_map[yz_sum], yz_sum)
 | ||
|  | 
 | ||
|  |     # (y+z) + 0 => (y+z)
 | ||
|  |     self.assertEqual(node_map[yz_sum_zero], yz_sum)
 | ||
|  | 
 | ||
|  |     # -(y+z) remains -(y+z), but might be a new UOp with updated children
 | ||
|  |     # Compare structurally to -(y_var + z_var).
 | ||
|  |     self.assertEqual(node_map[yz_neg], -yz_sum)
 | ||
|  | 
 | ||
|  |     # -(-(y+z)) => (y+z)
 | ||
|  |     self.assertEqual(node_map[yz_dneg], yz_sum)
 | ||
|  | 
 | ||
|  |     # x + (y+z) => might get recreated if yz_dneg was changed, so compare to x + yz_sum
 | ||
|  |     self.assertEqual(node_map[x_plus_yz], x_var + yz_sum)
 | ||
|  | 
 | ||
|  |     # -(-(x+(y+z))) => x + (y+z)
 | ||
|  |     self.assertEqual(node_map[double_neg_x], x_var + yz_sum)
 | ||
|  | 
 | ||
|  |     # (x+(y+z)) * 1 => x+(y+z)
 | ||
|  |     self.assertEqual(node_map[final_expr], x_var + yz_sum)
 | ||
|  | 
 | ||
|  |     # Unchanged atomic nodes map to themselves
 | ||
|  |     self.assertEqual(node_map[x_var], x_var)
 | ||
|  |     self.assertEqual(node_map[y_var], y_var)
 | ||
|  |     self.assertEqual(node_map[z_var], z_var)
 | ||
|  |     self.assertEqual(node_map[zero_node], zero_node)
 | ||
|  |     self.assertEqual(node_map[one_node], one_node)
 | ||
|  | 
 | ||
|  | if __name__ == "__main__":
 | ||
|  |   unittest.main()
 |