from typing import Any, Callable
import functools
from dataclasses import dataclass
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
from tinygrad.ops import PatternMatcher, graph_rewrite, UOp
from tinygrad.renderer import Renderer

# import all pattern matchers here
from tinygrad.codegen.lowerer import pm_quant, pm_lowerer, get_index
from tinygrad.codegen.symbolic import sym, symbolic_simple, gep_pushing
from tinygrad.codegen.expander import migrate_indexing, pm_store_ignore, pm_move_ignore, pm_delete_ignore, expander
from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexing, devectorize, \
  pm_reduce, ReduceContext, correct_load_store, pm_render, get_late_rewrite_patterns
from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext

@dataclass
class RewriteStep:
  """A single graph_rewrite pass: a PatternMatcher plus the arguments it is invoked with."""
  pm: PatternMatcher                     # patterns applied by this pass
  ctx: Callable[[UOp], Any]|None = None  # optional factory, called with the sink to build the ctx given to graph_rewrite
  name: str|None = None                  # forwarded to graph_rewrite as name=
  bottom_up: bool = False                # forwarded to graph_rewrite as bottom_up=

  def __call__(self, sink:UOp):
    """Run this pass on `sink` and return the result of graph_rewrite."""
    # the context is built per call from the sink being rewritten; no factory means ctx=None
    return graph_rewrite(sink, self.pm, ctx=self.ctx(sink) if self.ctx is not None else None, name=self.name, bottom_up=self.bottom_up)

def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]):
  """Apply `rewrites` to `sink` in order, feeding each step's output into the next; returns `sink` itself if the list is empty."""
  return functools.reduce(lambda x,f: f(x), rewrites, sink)

def get_rewrites_for_renderer(opts:Renderer, linearizer:bool=True) -> list[RewriteStep]:
  """
  Return the ordered list of rewrite steps for `opts`.

  The current QUANTIZE, DEVECTORIZE and TRANSCENDENTAL values are read here and passed on,
  so they are part of the cache key of the underlying cached builder.
  The returned list is a fresh shallow copy: the builder's result lives in a functools.cache,
  and handing that object out would let a caller's in-place edit alter every later lookup.
  """
  # cache with the values of the context vars
  return list(_get_rewrites_for_renderer(opts, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value))

@functools.cache
def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
  """
  Build the rewrite pipeline for one (opts, linearizer, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) combination.

  Memoized with functools.cache on all five arguments, so they must be hashable and the
  returned list object is shared between calls with equal arguments.
  """
  # ** lowerer (rewrite_shapetracker_with_index) **
  ret: list[RewriteStep] = []
  if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
  ret.append(RewriteStep(pm_lowerer, lambda ast: get_index(ast, opts), name="lowerer"))

  # ** expander (expand_rewrite) **
  ret.append(RewriteStep(sym+migrate_indexing, name="initial symbolic"))

  # ignore (for masked stores)
  ret.append(RewriteStep(pm_store_ignore, name="store_ignore"))
  ret.append(RewriteStep(pm_move_ignore, name="move_ignore"))

  # expand + remove surviving ignores
  ret.append(RewriteStep(pm_delete_ignore+sym+expander, name="expander"))

  # ** devectorizer (full_graph_rewrite) **
  # remove reduce
  ret.append(RewriteStep(pm_reduce+gep_pushing, lambda _: ReduceContext(), name="remove_reduce"))

  # devectorize (TODO: does this need opts?)
  # >=2 drops both devectorize and correct_load_store, any other truthy value includes devectorize, 0 keeps only correct_load_store
  if _DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing
  elif _DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing
  else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
  ret.append(RewriteStep(pm_devectorize, lambda _: opts, name="devectorize"))

  supported_ops = tuple(opts.code_for_op.keys())
  # a renderer without an extra_matcher contributes an empty PatternMatcher
  extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([])

  # optional pre matcher
  if opts.pre_matcher is not None: ret.append(RewriteStep(opts.pre_matcher, name="pre_matcher"))

  # final rules for the renderer (without sym)
  pm_final_rewrite = symbolic_simple+get_late_rewrite_patterns(supported_ops, _TRANSCENDENTAL>=2)+pm_render+extra_matcher
  ret.append(RewriteStep(pm_final_rewrite, lambda _: opts, name="final rewrite"))

  # ** linearizer **
  if linearizer:
    ret.append(RewriteStep(block_create, ctx=BlockContext.from_sink, name="Linearizer: Create Blocks", bottom_up=True))
    ret.append(RewriteStep(pm_blockend_merge, name="Linearizer: Merge Blockends"))
    ret.append(RewriteStep(block_merge, name="Linearizer: Merge Blocks"))
    ret.append(RewriteStep(pm_finalize, name="Linearizer: Finalize"))

  return ret

def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, linearizer:bool=False) -> UOp:
  """Run the full pipeline for `opts` (a default-constructed Renderer when None) on `sink`; linearizer steps only if requested."""
  # NOTE(review): a fresh Renderer() is built on every call with opts=None; if Renderer hashes by identity
  # each such call adds a new entry to the _get_rewrites_for_renderer cache -- confirm against Renderer
  return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), linearizer))

def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]:
  """Run the full pipeline with the linearizer enabled and return the resulting `.arg.lst` as a new list."""
  return list(full_rewrite_to_sink(sink, opts, linearizer=True).arg.lst)