openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

78 lines
4.0 KiB

from typing import Any, Callable
import functools
from dataclasses import dataclass
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
from tinygrad.ops import PatternMatcher, graph_rewrite, UOp
from tinygrad.renderer import Renderer
# import all pattern matchers here
from tinygrad.codegen.lowerer import pm_quant, pm_lowerer, get_index
from tinygrad.codegen.symbolic import sym, symbolic_simple, gep_pushing
from tinygrad.codegen.expander import migrate_indexing, pm_store_ignore, pm_move_ignore, pm_delete_ignore, expander
from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexing, devectorize, \
pm_reduce, ReduceContext, correct_load_store, pm_render, get_late_rewrite_patterns
from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
@dataclass
class RewriteStep:
  """A single stage of the codegen pipeline: one PatternMatcher applied through graph_rewrite.

  pm: the rewrite rules for this stage.
  ctx: optional factory; when set it is called with the incoming sink and its result becomes the
    graph_rewrite ctx. When unset, graph_rewrite receives ctx=None.
  name: label forwarded to graph_rewrite.
  bottom_up: forwarded unchanged to graph_rewrite.
  """
  pm: PatternMatcher
  ctx: Callable[[UOp], Any]|None = None
  name: str|None = None
  bottom_up: bool = False
  def __call__(self, sink:UOp):
    # build the per-invocation context first (if a factory was given), then rewrite
    rewrite_ctx = None if self.ctx is None else self.ctx(sink)
    return graph_rewrite(sink, self.pm, ctx=rewrite_ctx, name=self.name, bottom_up=self.bottom_up)
def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]):
  """Thread `sink` through every step of `rewrites` in order and return the final result.

  Each step receives the output of the previous one; an empty list returns `sink` unchanged.
  """
  current = sink
  for step in rewrites:
    current = step(current)
  return current
def get_rewrites_for_renderer(opts:Renderer, linearizer:bool=True) -> list[RewriteStep]:
  """Return the ordered list of RewriteSteps that lowers a kernel AST for the renderer `opts`.

  opts: target renderer the pipeline is built for.
  linearizer: when True, the block-linearization steps are appended at the end.

  The current values of QUANTIZE, DEVECTORIZE and TRANSCENDENTAL are read here and passed to
  _get_rewrites_for_renderer as explicit arguments. That makes them part of its functools.cache
  key, so a change to any of those context vars produces a freshly built pipeline.
  """
  # the cached builder hands out the *same* list object on every hit; return a shallow copy so a
  # caller that appends to or reorders its pipeline cannot corrupt the cache for later callers
  return list(_get_rewrites_for_renderer(opts, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value))
@functools.cache
def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
  """Build the ordered codegen pipeline for `opts`, memoized by functools.cache.

  opts: target renderer. The pipeline reads its device, code_for_op, extra_matcher and
    pre_matcher, and `opts` itself is the rewrite ctx of the devectorize and final steps.
  linearizer: when True, the four block-linearization steps are appended at the end.
  _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL: snapshot values of the matching context vars. They
    are explicit parameters only so that they become part of the cache key.

  Returns the RewriteSteps in execution order. The same list object is returned on every cache
  hit, so it must not be mutated.
  """
  # ** lowerer (rewrite_shapetracker_with_index) **
  ret: list[RewriteStep] = []
  # quantization rules are added only when QUANTIZE is set and the device is CPU or DSP
  if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
  # the lowerer's rewrite ctx is get_index(ast, opts), computed from the sink this step receives
  ret.append(RewriteStep(pm_lowerer, lambda ast: get_index(ast, opts), name="lowerer"))
  # ** expander (expand_rewrite) **
  ret.append(RewriteStep(sym+migrate_indexing, name="initial symbolic"))
  # ignore (for masked stores)
  ret.append(RewriteStep(pm_store_ignore, name="store_ignore"))
  ret.append(RewriteStep(pm_move_ignore, name="move_ignore"))
  # expand + remove surviving ignores
  ret.append(RewriteStep(pm_delete_ignore+sym+expander, name="expander"))
  # ** devectorizer (full_graph_rewrite) **
  # remove reduce; the ctx factory builds a fresh ReduceContext every time this step runs
  ret.append(RewriteStep(pm_reduce+gep_pushing, lambda _: ReduceContext(), name="remove_reduce"))
  # devectorize (TODO: does this need opts?)
  # DEVECTORIZE >= 2: neither the devectorize nor the correct_load_store rules are included
  # DEVECTORIZE truthy (below 2): devectorize plus correct_load_store
  # DEVECTORIZE falsy: correct_load_store without devectorize
  if _DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing
  elif _DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing
  else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
  ret.append(RewriteStep(pm_devectorize, lambda _: opts, name="devectorize"))
  # keys of opts.code_for_op, forwarded to get_late_rewrite_patterns below
  # NOTE(review): presumably these are the ops the renderer emits directly, so the late patterns can
  # rewrite the rest; confirm in devectorizer
  supported_ops = tuple(opts.code_for_op.keys())
  # an empty PatternMatcher stands in when the renderer has no extra rules, so the sum below always works
  extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([])
  # optional pre matcher
  if opts.pre_matcher is not None: ret.append(RewriteStep(opts.pre_matcher, name="pre_matcher"))
  # final rules for the renderer (without sym)
  # NOTE(review): the meaning of the TRANSCENDENTAL>=2 flag is defined inside get_late_rewrite_patterns; confirm there
  pm_final_rewrite = symbolic_simple+get_late_rewrite_patterns(supported_ops, _TRANSCENDENTAL>=2)+pm_render+extra_matcher
  ret.append(RewriteStep(pm_final_rewrite, lambda _: opts, name="final rewrite"))
  # ** linearizer **
  if linearizer:
    # block creation runs bottom-up, with its ctx built from the sink by BlockContext.from_sink
    ret.append(RewriteStep(block_create, ctx=BlockContext.from_sink, name="Linearizer: Create Blocks", bottom_up=True))
    ret.append(RewriteStep(pm_blockend_merge, name="Linearizer: Merge Blockends"))
    ret.append(RewriteStep(block_merge, name="Linearizer: Merge Blocks"))
    ret.append(RewriteStep(pm_finalize, name="Linearizer: Finalize"))
  return ret
@functools.cache
def _default_renderer() -> Renderer:
  """Return the single fallback Renderer used when a caller passes opts=None (created on first use)."""
  return Renderer()

def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, linearizer:bool=False) -> UOp:
  """Run the codegen pipeline for `opts` over `sink` and return the rewritten sink.

  opts: target renderer. When None, a shared default Renderer() is used.
  linearizer: when True, the block-linearization steps are also applied.
  """
  # Reuse one default Renderer instead of constructing a new one per call. The pipeline builder is
  # functools.cache'd and keyed on the renderer. If Renderer hashes by identity (its definition is
  # not shown here), a fresh instance per call would never hit the cache and would add a new,
  # never-evicted entry every time.
  return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else _default_renderer(), linearizer))
def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]:
  """Run the full pipeline, including the linearizer, on `sink`.

  Returns a new list holding the UOps stored in `.arg.lst` of the linearized sink.
  """
  linearized = full_rewrite_to_sink(sink, opts, linearizer=True)
  return list(linearized.arg.lst)