"""Tests for the per-node mapping returned by ``graph_rewrite_map``."""
import unittest

from tinygrad import dtypes
from tinygrad.ops import UOp, graph_rewrite_map, _substitute
from tinygrad.codegen.symbolic import symbolic


class TestRewriteMap(unittest.TestCase):
    """Check that ``graph_rewrite_map`` maps original nodes to their rewritten forms.

    Each test builds a small UOp expression, rewrites it, and then indexes the
    returned map with the *original* nodes (root and intermediates) to verify
    what each one was rewritten to.
    """

    def test_substitute(self):
        """Substituting ``a+b -> e`` maps ``(a+b)*c`` to ``e*c``."""
        a = UOp.variable('a', 0, 10)
        b = UOp.variable('b', 0, 10)
        c = UOp.variable('c', 0, 10)
        e = UOp.variable('e', 0, 10)
        ret = (a+b)*c
        sub = {a+b: e}
        sub_map = graph_rewrite_map(ret, _substitute, sub, bottom_up=True)
        self.assertIs(sub_map[a+b], e)
        self.assertIs(sub_map[(a+b)*c], e*c)

    def test_substitute_depth_2(self):
        """With substitutions for both ``a+b`` and ``(a+b)*c``, each key maps to its own value."""
        a = UOp.variable('a', 0, 10)
        b = UOp.variable('b', 0, 10)
        c = UOp.variable('c', 0, 10)
        d = UOp.variable('d', 0, 10)
        e = UOp.variable('e', 0, 10)
        f = UOp.variable('f', 0, 10)
        ret = (a+b)*c+d
        sub = {a+b: e, (a+b)*c: f}
        sub_map = graph_rewrite_map(ret, _substitute, sub, bottom_up=True)
        self.assertIs(sub_map[a+b], e)
        self.assertIs(sub_map[(a+b)*c], f)

    def test_add_zero(self):
        """``0 + (0 + 5)`` and its inner ``0 + 5`` both map to the constant ``5``."""
        # Build a small graph: add(0, add(0, 5))
        zero_node = UOp.const(dtypes.int, 0)
        five_node = UOp.const(dtypes.int, 5)
        inner_add = zero_node + five_node
        root_add = zero_node + inner_add

        # Rewrite with the default traversal (no bottom_up flag passed).
        node_map = graph_rewrite_map(root_add, symbolic)

        # Expected chain: add(0, add(0, 5)) -> add(0, 5) -> 5.
        # unittest assertions are used (rather than bare ``assert``) so the checks
        # still run under ``python -O`` and match the rest of this class.
        self.assertEqual(node_map[root_add], five_node)
        self.assertEqual(node_map[inner_add], five_node)

        # zero_node and five_node map to themselves
        self.assertEqual(node_map[zero_node], zero_node)
        self.assertEqual(node_map[five_node], five_node)

    def test_double_neg(self):
        """``neg(neg(5))`` maps to ``5`` under ``symbolic``."""
        five_node = UOp.const(dtypes.int, 5)
        # Build the two negations as separate nodes so the outer one can be looked up.
        neg_five = -five_node
        double_neg_five = -neg_five

        node_map = graph_rewrite_map(double_neg_five, symbolic)

        # double_neg_five -> five_node
        self.assertEqual(node_map[double_neg_five], five_node)
        # five_node maps to itself
        self.assertEqual(node_map[five_node], five_node)

    def test_add_zero_and_double_neg(self):
        """Combine both rewrites: ``add(0, neg(neg(5))) => add(0, 5) => 5``."""
        zero_node = UOp.const(dtypes.int, 0)
        five_node = UOp.const(dtypes.int, 5)
        neg_five = -five_node
        double_neg_five = -neg_five
        root_add = zero_node + double_neg_five

        node_map = graph_rewrite_map(root_add, symbolic)

        # root_add -> five_node, double_neg_five -> five_node
        self.assertEqual(node_map[root_add], five_node)
        self.assertEqual(node_map[double_neg_five], five_node)
        # zero_node, five_node map to themselves
        self.assertEqual(node_map[zero_node], zero_node)
        self.assertEqual(node_map[five_node], five_node)

    def test_multi_var_rewrites(self):
        """Every intermediate of ``0 + -(-(x + (y + 0)))`` maps to its simplified form."""
        x_var = UOp.variable('x', 0, 10)
        y_var = UOp.variable('y', -5, 5)
        zero_node = UOp.const(dtypes.int, 0)

        sum_with_zero = y_var + zero_node   # (y + 0)
        combined = x_var + sum_with_zero    # x + (y + 0)
        double_neg = -(-combined)           # neg(neg(x + (y + 0)))
        final_expr = zero_node + double_neg  # 0 + neg(neg(x + (y + 0)))

        node_map = graph_rewrite_map(final_expr, symbolic)

        # The final root should be (x_var + y_var).
        expected = x_var + y_var

        # Each sub-expression has its own "final" result.
        # (y + 0) -> y_var
        self.assertEqual(node_map[sum_with_zero], y_var)
        # (x + (y+0)) -> (x + y)
        self.assertEqual(node_map[combined], expected)
        # neg(neg(x+y)) -> (x + y)
        self.assertEqual(node_map[double_neg], expected)
        # 0 + (x+y) -> (x + y)
        self.assertEqual(node_map[final_expr], expected)

        # x_var, y_var, zero_node remain unchanged
        self.assertEqual(node_map[x_var], x_var)
        self.assertEqual(node_map[y_var], y_var)
        self.assertEqual(node_map[zero_node], zero_node)

    def test_complex_multi_var_edges(self):
        """Three-variable expression with several redundant intermediates.

        The expression ``(-(-(x + -(-((y + z) + 0))))) * 1`` is built step by
        step so that each intermediate node can be looked up afterwards.  The
        final result is expected to be ``x + (y + z)``.

        Each original node should map to the final node that replaces it, which
        might be structurally equivalent but not the same reference.
        """
        x_var = UOp.variable('x', 1, 10)
        y_var = UOp.variable('y', -5, 5)
        z_var = UOp.variable('z', 0, 5)
        zero_node = UOp.const(dtypes.int, 0)
        one_node = UOp.const(dtypes.int, 1)

        # Build sub-expressions
        yz_sum = y_var + z_var                # (y + z)
        yz_sum_zero = yz_sum + zero_node      # (y + z) + 0
        yz_neg = -yz_sum_zero                 # -((y + z) + 0)
        yz_dneg = -yz_neg                     # -(-((y + z) + 0))
        x_plus_yz = x_var + yz_dneg           # x + -(-((y + z) + 0))
        double_neg_x = -(-x_plus_yz)          # neg(neg(x_plus_yz))
        final_expr = double_neg_x * one_node  # neg(neg(x_plus_yz)) * 1

        node_map = graph_rewrite_map(final_expr, symbolic)

        # (y + z) is unchanged
        self.assertEqual(node_map[yz_sum], yz_sum)
        # (y+z) + 0 => (y+z)
        self.assertEqual(node_map[yz_sum_zero], yz_sum)
        # -((y+z) + 0) => -(y+z); may be a new UOp with updated children,
        # so compare against a freshly built -(y + z).
        self.assertEqual(node_map[yz_neg], -yz_sum)
        # -(-(y+z)) => (y+z)
        self.assertEqual(node_map[yz_dneg], yz_sum)
        # x + (y+z) may get recreated if yz_dneg was changed, so compare to x + yz_sum
        self.assertEqual(node_map[x_plus_yz], x_var + yz_sum)
        # -(-(x+(y+z))) => x + (y+z)
        self.assertEqual(node_map[double_neg_x], x_var + yz_sum)
        # (x+(y+z)) * 1 => x+(y+z)
        self.assertEqual(node_map[final_expr], x_var + yz_sum)

        # Unchanged atomic nodes map to themselves
        self.assertEqual(node_map[x_var], x_var)
        self.assertEqual(node_map[y_var], y_var)
        self.assertEqual(node_map[z_var], z_var)
        self.assertEqual(node_map[zero_node], zero_node)
        self.assertEqual(node_map[one_node], one_node)


if __name__ == "__main__":
    unittest.main()