openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

191 lines
7.9 KiB

from __future__ import annotations
import math, itertools, functools
from typing import List, Dict, Callable, Type, Union
from tinygrad.helpers import partition, all_same
# NOTE: Python's floor division and modulo behave differently from C's for negative operands.
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
def create_node(typ:Type[Node], *args):
    """Construct ``typ(*args)`` and collapse it to a constant when its range is a single value.

    Asserts that the new node's ``min`` does not exceed its ``max``. When both
    bounds are equal a ``NumNode`` holding that value is returned instead of
    the freshly built node.
    """
    node = typ(*args)
    lo, hi = node.min, node.max
    assert lo <= hi, f"min greater than max! {lo} {hi} when creating {typ} {args}"
    # a node that can only take one value is just that number
    return NumNode(lo) if lo == hi else node
class Node:
    """Base class of the symbolic integer expression tree.

    Every node carries inclusive integer bounds ``min`` and ``max``. The
    arithmetic operators build new nodes and simplify eagerly; whenever the
    bounds of a result coincide, ``create_node`` returns a ``NumNode`` instead.
    Two nodes are equal (and hash equal) when their DEBUG renderings match.
    """
    b: int
    min: int
    max: int

    def render(self, ops=None, ctx=None) -> str:
        """Render this node to a string.

        ops: table mapping a node type to a callable ``(node, ops, ctx) -> str``;
             ``render_python`` is used when omitted.
        ctx: value forwarded unchanged to the renderers (``"DEBUG"`` makes the
             Python renderer include variable bounds).
        """
        if ops is None: ops = render_python
        # only NumNode may have min == max
        assert isinstance(self, NumNode) or self.min != self.max
        return ops[type(self)](self, ops, ctx)

    @functools.cached_property
    def key(self) -> str:
        """DEBUG rendering of the node, computed once; used for equality and hashing."""
        return self.render(ctx="DEBUG")

    def __repr__(self): return "<"+self.key+">"

    # defining __eq__ alone would set __hash__ to None and make nodes unhashable;
    # hash on the same key that __eq__ compares so equal nodes hash equal
    def __hash__(self): return hash(self.key)

    def __eq__(self, other:object) -> bool:
        if not isinstance(other, Node): return NotImplemented
        return self.key == other.key

    def __neg__(self): return self*-1
    def __add__(self, b:Union[Node, int]): return Variable.sum([self, b if isinstance(b, Node) else Variable.num(b)])
    def __sub__(self, b:Union[Node, int]): return self+-b
    def __ge__(self, b:int): return create_node(GeNode, self, b)
    def __lt__(self, b:int): return create_node(LtNode, self, b)

    def __mul__(self, b:int):
        """Multiply by the integer constant ``b``."""
        if b == 0: return NumNode(0)
        elif b == 1: return self
        if isinstance(self, MulNode): return self.a*(self.b*b) # two muls is one mul
        if isinstance(self, SumNode): return Variable.sum([x*b for x in self.nodes]) # distribute mul into sum
        return create_node(MulNode, self, b)

    # *** complex ops ***

    def __floordiv__(self, b:int):
        """Floor-divide by the non-zero integer constant ``b``, simplifying where possible."""
        assert b != 0
        if b < 0: return (self//-b)*-1
        if b == 1: return self
        if isinstance(self, DivNode): return self.a//(self.b*b) # two divs is one div
        if isinstance(self, MulNode) and self.b % b == 0: return self.a*(self.b//b)
        if isinstance(self, MulNode) and b % self.b == 0: return self.a//(b//self.b)
        if isinstance(self, SumNode):
            # terms whose multiplier/value is a multiple of b can be divided exactly
            factors, tmp_nofactor = partition(self.nodes, lambda x: (isinstance(x, (MulNode, NumNode))) and x.b%b == 0)
            nofactor = []
            # ugh, i doubt this is universally right
            for x in tmp_nofactor:
                if isinstance(x, NumNode):
                    # split a constant into a multiple of b (exactly divisible) plus its remainder
                    if (x.b%b) != x.b:
                        factors.append(Variable.num(x.b - (x.b%b))) # python does floor division
                    nofactor.append(Variable.num(x.b%b))
                else:
                    nofactor.append(x)
            gcd = [math.gcd(x.b, b) if isinstance(x, (MulNode, NumNode)) else None for x in nofactor]
            if len(factors) > 0:
                # these don't have to be the same, just having a common factor
                if len(gcd) > 0 and all_same(gcd) and gcd[0] is not None and gcd[0] > 1:
                    nofactor_term = Variable.sum([(x.a * (x.b//gcd[0])) if isinstance(x, MulNode) else Variable.num(x.b//gcd[0]) for x in nofactor])//(b//gcd[0])
                else:
                    nofactor_term = Variable.sum(nofactor)//b
                return Variable.sum([(x.a * (x.b//b)) if isinstance(x, MulNode) else Variable.num(x.b//b) for x in factors] + [nofactor_term])
            else:
                # no exact factors: try dividing by a multiplier that itself divides b
                muls = [x.b for x in nofactor if isinstance(x, MulNode)]
                for m in muls:
                    if m > 1 and b%m == 0:
                        return (self//m)//(b//m)
        if self.min < 0:
            # NOTE(review): offset is negative here, so adding offset*b lowers the minimum
            # further instead of raising it -- confirm this branch can terminate
            offset = self.min//b
            return (self+offset*b)//b - offset
        return create_node(DivNode, self, b)

    def __mod__(self, b:int):
        """Reduce modulo the positive integer constant ``b``, simplifying where possible."""
        assert b > 0
        if b == 1: return NumNode(0)
        # reduce constants and multipliers modulo b before building the node
        if isinstance(self, SumNode):
            new_nodes = []
            for x in self.nodes:
                if isinstance(x, NumNode): new_nodes.append(Variable.num(x.b%b))
                elif isinstance(x, MulNode): new_nodes.append(x.a * (x.b%b))
                else: new_nodes.append(x)
            a = Variable.sum(new_nodes)
        elif isinstance(self, MulNode):
            a = self.a * (self.b%b)
        else:
            a = self
        # already inside [0, b): the mod is a no-op
        if a.min >= 0 and a.max < b: return a
        # NOTE(review): (a.min//b)*b is negative when a.min < 0, so this shifts the range
        # further down -- confirm this branch can terminate
        if a.min < 0: return (a + ((a.min//b)*b)) % b
        return create_node(ModNode, a, b)

    @staticmethod
    def num(num:int) -> Node:
        """Return a constant node holding ``num``."""
        return NumNode(num)

    @staticmethod
    def sum(nodes:List[Node]) -> Node:
        """Add ``nodes`` together.

        Nested sums are flattened, constants are folded into one number, terms
        with an equal operand are merged by adding their multipliers, and zero
        terms are dropped. Returns ``NumNode(0)`` when nothing remains and the
        single remaining term when there is only one.
        """
        # expand any sums inside one sum
        if any([isinstance(x, SumNode) for x in nodes]):
            nodes, sum_nodes = partition(nodes, lambda x: not isinstance(x, SumNode))
            for x in sum_nodes: nodes += x.nodes
            return Variable.sum(nodes)
        # combine any numbers inside a sum
        nodes, num_nodes = partition(nodes, lambda x: not isinstance(x, NumNode))
        nodes.append(NumNode(sum([x.b for x in num_nodes])))
        # combine any MulNodes that factorize (big hack sticking the MulNode(x, 1) on things)
        nodes, mul_nodes = partition(nodes, lambda x: not isinstance(x, MulNode))
        mul_nodes += [MulNode(x, 1) for x in nodes]
        mul_nodes = sorted(mul_nodes, key=lambda x: x.a.render()) # group by equality (ugh, uses render!)
        new_nodes = [k * sum(x.b for x in g) for k, g in itertools.groupby(mul_nodes, key=lambda x: x.a)]
        # unwrap the temporary MulNode(x, 1) wrappers again
        nodes = [x if not isinstance(x, MulNode) or x.b != 1 else x.a for x in new_nodes]
        # filter 0s
        nodes = [x for x in nodes if x.min != 0 or x.max != 0]
        return create_node(SumNode, nodes) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(0))

    @staticmethod
    def ands(nodes:List[Node]) -> Node:
        """AND ``nodes`` together.

        Returns ``NumNode(0)`` if any input is the constant 0. Remaining
        constant inputs are dropped; ``NumNode(1)`` is returned when nothing
        is left and the single remaining node when there is only one.
        """
        if any((x.min == 0 and x.max == 0) for x in nodes): return NumNode(0)
        # filter 1s
        nodes = [x for x in nodes if x.min != x.max]
        return create_node(AndNode, nodes) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))
# 4 basic node types
class Variable(Node):
    """Named symbolic integer with inclusive bounds ``[nmin, nmax]``.

    Constructing a Variable whose bounds coincide yields a ``NumNode`` instead.
    """
    def __new__(cls, expr:str, nmin:int, nmax:int):
        assert 0 <= nmin <= nmax
        if nmin != nmax:
            return super().__new__(cls)
        # a single-valued variable is just a constant
        return NumNode(nmin)

    def __init__(self, expr:str, nmin:int, nmax:int):
        self.expr = expr
        self.min = nmin
        self.max = nmax
class NumNode(Node):
    """Integer constant: the value is stored in ``b`` and is also both bounds."""
    def __init__(self, num:int):
        self.b = self.min = self.max = num
class OpNode(Node):
    """Node applying an operation to a node ``a`` and an integer ``b``.

    Subclasses supply ``minmax(a, b)`` returning the ``(min, max)`` of the result.
    """
    @staticmethod
    def minmax(a, b):
        # placeholder: raises ZeroDivisionError unless a subclass overrides it
        return (1//0, 1//0)

    def __init__(self, a:Node, b:int):
        self.a = a
        self.b = b
        bounds = self.minmax(a, b)
        self.min, self.max = bounds
class RedNode(Node):
    """Node reducing a list of nodes to one value.

    Subclasses supply ``minmax(nodes)`` returning the ``(min, max)`` of the result.
    """
    @staticmethod
    def minmax(nodes):
        # placeholder: raises ZeroDivisionError unless a subclass overrides it
        return (1//0, 1//0)

    def __init__(self, nodes:List[Node]):
        self.nodes = nodes
        bounds = self.minmax(nodes)
        self.min, self.max = bounds
# operation nodes
class GeNode(OpNode):
    """Comparison ``a >= b``; the bounds are 0 or 1."""
    @staticmethod
    def minmax(a, b):
        # the comparison is least likely to hold at a.min and most likely at a.max
        lowest = int(a.min >= b)
        highest = int(a.max >= b)
        return (lowest, highest)
class LtNode(OpNode):
    """Comparison ``a < b``; the bounds are 0 or 1."""
    @staticmethod
    def minmax(a, b):
        # the comparison is least likely to hold at a.max and most likely at a.min
        lowest = int(a.max < b)
        highest = int(a.min < b)
        return (lowest, highest)
class MulNode(OpNode):
    """Product ``a * b`` of a node and an integer constant."""
    @staticmethod
    def minmax(a, b):
        from_min, from_max = a.min*b, a.max*b
        # a negative multiplier swaps which end of the range is smaller
        if b >= 0:
            return (from_min, from_max)
        return (from_max, from_min)
class DivNode(OpNode):
    """Floor division ``a // b`` of a node with non-negative minimum by an integer."""
    @staticmethod
    def minmax(a, b):
        lo, hi = a.min, a.max
        # the numerator must not be able to go negative
        assert lo >= 0
        return (lo//b, hi//b)
class ModNode(OpNode):
    """Remainder ``a % b`` of a node with non-negative minimum by an integer."""
    @staticmethod
    def minmax(a, b):
        lo, hi = a.min, a.max
        assert lo >= 0
        lo_rem, hi_rem = lo%b, hi%b
        # the range covers a full period, or wraps past a multiple of b
        covers_period = hi - lo >= b
        wraps_around = lo != hi and lo_rem >= hi_rem
        if covers_period or wraps_around:
            return (0, b-1)
        return (lo_rem, hi_rem)
# reduce nodes
class SumNode(RedNode):
    """Sum of several nodes; the bounds are the sums of the operands' bounds."""
    @staticmethod
    def minmax(nodes):
        lowest = sum(x.min for x in nodes)
        highest = sum(x.max for x in nodes)
        return (lowest, highest)
class AndNode(RedNode):
    """AND of several nodes; bounded by the smallest min and the largest max of the operands."""
    @staticmethod
    def minmax(nodes):
        lowest = min(x.min for x in nodes)
        highest = max(x.max for x in nodes)
        return (lowest, highest)
def _render_variable(self, ops, ctx):
    """Render a Variable as its name, with its bounds appended when ctx is "DEBUG"."""
    if ctx == "DEBUG":
        return f"{self.expr}<{self.min},{self.max}>"
    return f"{self.expr}"

def _render_binary(symbol):
    """Build a renderer producing "(<a><symbol><b>)" for a node with operand `a` and integer `b`."""
    def _render(self, ops, ctx):
        return f"({self.a.render(ops,ctx)}{symbol}{self.b})"
    return _render

def _render_joined(separator):
    """Build a renderer joining the sorted renderings of `self.nodes` with `separator`, in parentheses."""
    def _render(self, ops, ctx):
        parts = sorted([x.render(ops,ctx) for x in self.nodes])
        return f"({separator.join(parts)})"
    return _render

# renderer table producing Python-syntax expressions; each entry is called as fn(node, ops, ctx)
render_python : Dict[Type, Callable] = {
    Variable: _render_variable,
    NumNode: lambda self,ops,ctx: f"{self.b}",
    MulNode: _render_binary("*"),
    DivNode: _render_binary("//"),
    ModNode: _render_binary("%"),
    GeNode: _render_binary(">="),
    LtNode: _render_binary("<"),
    SumNode: _render_joined("+"),
    AndNode: _render_joined("&&"),
}