from __future__ import annotations from abc import abstractmethod import functools from math import gcd from itertools import product from tinygrad.helpers import partition from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Iterator # NOTE: Python has different behavior for negative mod and floor div than c # symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod def is_sym_int(x: Any) -> bool: return isinstance(x, (int, Node)) class Node: b: Union[Node, int] min: int max: int def render(self, ops=None, ctx=None) -> Any: if ops is None: ops = render_python assert self.__class__ in (Variable, NumNode) or self.min != self.max return ops[type(self)](self, ops, ctx) def vars(self): return [] def expand_idx(self) -> VariableOrNum: return next((v for v in self.vars() if v.expr is None), NumNode(0)) # expand a Node into List[Node] that enumerates the underlying Variables from min to max # expand increments earlier variables faster than later variables (as specified in the argument) @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none def expand(self, idxs:Optional[Tuple[VariableOrNum, ...]]=None) -> List[Node]: if idxs is None: idxs = (self.expand_idx(),) return [self.substitute(dict(zip(idxs, (NumNode(x) for x in rep)))) for rep in Node.iter_idxs(idxs)] @staticmethod def iter_idxs(idxs:Tuple[VariableOrNum, ...]) -> Iterator[Tuple[int,...]]: yield from (x[::-1] for x in product(*[[x for x in range(v.min, v.max + 1)] for v in idxs[::-1]])) # substitute Variables with the values in var_vals def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: raise RuntimeError(self.__class__.__name__) def unbind(self) -> Tuple[Node, Optional[int]]: return self.substitute({v: v.unbind()[0] for v in self.vars() if v.val is not None}), None @functools.cached_property def key(self) -> str: return self.render(ctx="DEBUG") @functools.cached_property def hash(self) -> int: return hash(self.key) def __repr__(self): return self.render(ctx="REPR") def __str__(self): return "<"+self.key+">" def __hash__(self): return self.hash def __bool__(self): return not (self.max == self.min == 0) def __eq__(self, other:object) -> bool: if not isinstance(other, Node): return NotImplemented return self.key == other.key def __neg__(self): return self*-1 def __add__(self, b:Union[Node,int]): return Variable.sum([self, b if isinstance(b, Node) else Variable.num(b)]) def __radd__(self, b:int): return self+b def __sub__(self, b:Union[Node,int]): return self+-b def __rsub__(self, b:int): return -self+b def __le__(self, b:Union[Node,int]): return self < (b+1) def __gt__(self, b:Union[Node,int]): return (-self) < (-b) def __ge__(self, b:Union[Node,int]): return (-self) < (-b+1) def __lt__(self, b:Union[Node,int]): return create_node(LtNode(self, b)) def __mul__(self, b:Union[Node, int]): if b == 0: return NumNode(0) if b == 1: return self if self.__class__ is NumNode: return NumNode(self.b*b) if isinstance(b, int) else b*self.b return create_node(MulNode(self, b.b)) if isinstance(b, NumNode) else create_node(MulNode(self, b)) def __rmul__(self, b:int): return self*b # *** complex ops *** def __rfloordiv__(self, b:int): if self.min > b >= 0: return NumNode(0) if isinstance(self, NumNode): return NumNode(b // self.b) raise RuntimeError(f"not supported: {b} // {self}") def __floordiv__(self, b:Union[Node,int], factoring_allowed=True): if isinstance(b, Node): if b.__class__ is NumNode: return self // b.b if self == b: return NumNode(1) if (b - self).min > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node raise RuntimeError(f"not supported: {self} // {b}") assert b != 0 if b < 0: return (self//-b)*-1 if b == 1: return self # the numerator of div is not allowed to be negative if self.min < 0: offset = self.min//b # factor out an "offset" to make the numerator positive. don't allowing factoring again return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset return create_node(DivNode(self, b)) def __rmod__(self, b:int): if self.min > b >= 0: return NumNode(b) if isinstance(self, NumNode): return NumNode(b % self.b) raise RuntimeError(f"not supported: {b} % {self}") def __mod__(self, b:Union[Node,int]): if isinstance(b, Node): if b.__class__ is NumNode: return self % b.b if self == b: return NumNode(0) if (b - self).min > 0 and self.min >= 0: return self # b - self simplifies the node raise RuntimeError(f"not supported: {self} % {b}") assert b > 0 if b == 1: return NumNode(0) if self.min >= 0 and self.max < b: return self if (self.min//b) == (self.max//b): return self - (b*(self.min//b)) if self.min < 0: return (self - ((self.min//b)*b)) % b return create_node(ModNode(self, b)) @staticmethod def num(num:int) -> NumNode: return NumNode(num) @staticmethod def factorize(nodes:List[Node]) -> List[Node]: mul_groups: Dict[Node, int] = {} for x in nodes: a,b = (x.a,x.b) if isinstance(x, MulNode) else (x,1) mul_groups[a] = mul_groups.get(a, 0) + b return [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0] @staticmethod def sum(nodes:List[Node]) -> Node: nodes = [x for x in nodes if x.max or x.min] if not nodes: return NumNode(0) if len(nodes) == 1: return nodes[0] new_nodes: List[Node] = [] num_node_sum = 0 for node in SumNode(nodes).flat_components: if node.__class__ is NumNode: num_node_sum += node.b else: new_nodes.append(node) if len(new_nodes) > 1 and len(set([x.a if isinstance(x, MulNode) else x for x in new_nodes])) < len(new_nodes): new_nodes = Node.factorize(new_nodes) if num_node_sum: new_nodes.append(NumNode(num_node_sum)) return create_rednode(SumNode, new_nodes) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0) @staticmethod def ands(nodes:List[Node]) -> Node: if not nodes: return NumNode(1) if len(nodes) == 1: return nodes[0] if any(not x for x in nodes): return NumNode(0) # filter 1s nodes = [x for x in nodes if x.min != x.max] return create_rednode(AndNode, nodes) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1)) # 4 basic node types class Variable(Node): def __new__(cls, expr:Optional[str], nmin:int, nmax:int): assert nmin >= 0 and nmin <= nmax if nmin == nmax: return NumNode(nmin) return super().__new__(cls) def __init__(self, expr:Optional[str], nmin:int, nmax:int): self.expr, self.min, self.max = expr, nmin, nmax self.val:Optional[int] = None def bind(self, val): assert self.val is None and self.min<=val<=self.max, f"cannot bind {val} to {self}" self.val = val return self def unbind(self) -> Tuple[Variable, int]: assert self.val is not None, f"cannot unbind {self}" return Variable(self.expr, self.min, self.max), self.val def vars(self): return [self] def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return var_vals[self] if self in var_vals else self class NumNode(Node): def __init__(self, num:int): assert isinstance(num, int), f"{num} is not an int" self.b:int = num self.min, self.max = num, num def bind(self, val): assert self.b == val, f"cannot bind {val} to {self}" return self def __eq__(self, other): return self.b == other def __hash__(self): return self.hash # needed with __eq__ override def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self def create_node(ret:Node): assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}" if ret.min == ret.max: return NumNode(ret.min) return ret class OpNode(Node): def __init__(self, a:Node, b:Union[Node, int]): self.a, self.b = a, b self.min, self.max = self.get_bounds() def vars(self): return self.a.vars() + (self.b.vars() if isinstance(self.b, Node) else []) @abstractmethod def get_bounds(self) -> Tuple[int, int]: pass class LtNode(OpNode): def __floordiv__(self, b: Union[Node, int], _=False): return (self.a//b) < (self.b//b) def get_bounds(self) -> Tuple[int, int]: if isinstance(self.b, int): return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1) return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min >= self.b.max else (0, 1) def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) < (self.b if isinstance(self.b, int) else self.b.substitute(var_vals)) class MulNode(OpNode): def __lt__(self, b: Union[Node, int]): if isinstance(b, Node) or isinstance(self.b, Node) or self.b == -1: return Node.__lt__(self, b) sgn = 1 if self.b > 0 else -1 return Node.__lt__(self.a*sgn, (b + abs(self.b) - 1)//abs(self.b)) def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b) # two muls in one mul def __floordiv__(self, b: Union[Node, int], factoring_allowed=False): # NOTE: mod negative isn't handled right if self.b % b == 0: return self.a*(self.b//b) if b % self.b == 0 and self.b > 0: return self.a//(b//self.b) return Node.__floordiv__(self, b, factoring_allowed) def __mod__(self, b: Union[Node, int]): a = (self.a * (self.b%b)) return Node.__mod__(a, b) def get_bounds(self) -> Tuple[int, int]: return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b) def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals)) class DivNode(OpNode): def __floordiv__(self, b: Union[Node, int], _=False): return self.a//(self.b*b) # two divs is one div def get_bounds(self) -> Tuple[int, int]: assert self.a.min >= 0 and isinstance(self.b, int) return self.a.min//self.b, self.a.max//self.b def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) // self.b class ModNode(OpNode): def __mod__(self, b: Union[Node, int]): if isinstance(b, Node) or isinstance(self.b, Node): return Node.__mod__(self, b) return self.a % b if gcd(self.b, b) == b else Node.__mod__(self, b) def __floordiv__(self, b: Union[Node, int], factoring_allowed=True): if (self.b % b == 0): return (self.a//b) % (self.b//b) # put the div inside mod return Node.__floordiv__(self, b, factoring_allowed) def get_bounds(self) -> Tuple[int, int]: assert self.a.min >= 0 and isinstance(self.b, int) return (0, self.b-1) if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b) else (self.a.min%self.b, self.a.max%self.b) def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) % self.b class RedNode(Node): def __init__(self, nodes:List[Node]): self.nodes = nodes def vars(self): return functools.reduce(lambda l,x: l+x.vars(), self.nodes, []) class SumNode(RedNode): @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none def __mul__(self, b: Union[Node, int]): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none def __floordiv__(self, b: Union[Node, int], factoring_allowed=True): fully_divided: List[Node] = [] rest: List[Node] = [] if isinstance(b, SumNode): nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode) de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode) if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return NumNode(d) + (self-b*d) // b if isinstance(b, Node): for x in self.flat_components: if x % b == 0: fully_divided.append(x // b) else: rest.append(x) if (sum_fully_divided:=create_rednode(SumNode, fully_divided)) != 0: return sum_fully_divided + create_rednode(SumNode, rest) // b return Node.__floordiv__(self, b, False) if b == 1: return self if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed) fully_divided, rest = [], [] _gcd = b divisor = 1 for x in self.flat_components: if x.__class__ in (NumNode, MulNode): if x.b%b == 0: fully_divided.append(x//b) else: rest.append(x) _gcd = gcd(_gcd, x.b) if x.__class__ == MulNode and divisor == 1 and b%x.b == 0: divisor = x.b else: rest.append(x) _gcd = 1 if _gcd > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(_gcd) // (b//_gcd) if divisor > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(divisor) // (b//divisor) return Node.sum(fully_divided) + Node.__floordiv__(Node.sum(rest), b) @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none def __mod__(self, b: Union[Node, int]): if isinstance(b, SumNode): nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode) de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode) if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return (self-b*d) % b if isinstance(b, Node) and (b - self).min > 0: return self # b - self simplifies the node new_nodes: List[Node] = [] for x in self.nodes: if x.__class__ is NumNode: new_nodes.append(Variable.num(x.b%b)) elif isinstance(x, MulNode): new_nodes.append(x.a * (x.b%b)) else: new_nodes.append(x) return Node.__mod__(Node.sum(new_nodes), b) def __lt__(self, b:Union[Node,int]): lhs: Node = self if isinstance(b, int): new_sum = [] for x in self.nodes: # TODO: should we just force the last one to always be the number if isinstance(x, NumNode): b -= x.b else: new_sum.append(x) lhs = Node.sum(new_sum) nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs] muls, others = partition(nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b) if muls: # NOTE: gcd in python 3.8 takes exactly 2 args mul_gcd = b for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell x.b is int here all_others = Variable.sum(others) if all_others.min >= 0 and all_others.max < mul_gcd: lhs, b = Variable.sum([mul//mul_gcd for mul in muls]), b//mul_gcd return Node.__lt__(lhs, b) def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return Variable.sum([node.substitute(var_vals) for node in self.nodes]) @property def flat_components(self): # recursively expand sumnode components new_nodes = [] for x in self.nodes: new_nodes += (x.flat_components if isinstance(x, SumNode) else [x]) return new_nodes class AndNode(RedNode): def __floordiv__(self, b: Union[Node, int], _=True): return Variable.ands([x//b for x in self.nodes]) def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: subed = [] for node in self.nodes: if not (sub:=node.substitute(var_vals)): return NumNode(0) subed.append(sub) return Variable.ands(subed) def create_rednode(typ:Type[RedNode], nodes:List[Node]): ret = typ(nodes) if typ == SumNode: ret.min, ret.max = (sum([x.min for x in nodes]), sum([x.max for x in nodes])) elif typ == AndNode: ret.min, ret.max = (min([x.min for x in nodes]), max([x.max for x in nodes])) return create_node(ret) @functools.lru_cache(maxsize=None) def sym_rename(s) -> str: return f"s{sym_rename.cache_info().currsize}" def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx) def sym_infer(a: Union[Node, int], var_vals: Dict[Variable, int]) -> int: if isinstance(a, (int, float)): return a ret = a.substitute({k:Variable.num(v) for k, v in var_vals.items()}) assert isinstance(ret, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}" return ret.b # symbolic int sint = Union[Node, int] VariableOrNum = Union[Variable, NumNode] render_python: Dict[Type, Callable] = { Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self.val is not None else ''}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})" if ctx == "REPR" else f"{self.expr}"), NumNode: lambda self,ops,ctx: f"{self.b}", MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})", DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})", ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})", LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})", SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})", AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})" }