You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
78 lines
4.0 KiB
78 lines
4.0 KiB
from typing import Any, Callable
|
|
import functools
|
|
from dataclasses import dataclass
|
|
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
|
|
from tinygrad.ops import PatternMatcher, graph_rewrite, UOp
|
|
from tinygrad.renderer import Renderer
|
|
|
|
# import all pattern matchers here
|
|
from tinygrad.codegen.lowerer import pm_quant, pm_lowerer, get_index
|
|
from tinygrad.codegen.symbolic import sym, symbolic_simple, gep_pushing
|
|
from tinygrad.codegen.expander import migrate_indexing, pm_store_ignore, pm_move_ignore, pm_delete_ignore, expander
|
|
from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexing, devectorize, \
|
|
pm_reduce, ReduceContext, correct_load_store, pm_render, get_late_rewrite_patterns
|
|
from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
|
|
|
|
@dataclass
class RewriteStep:
  """A single pipeline stage: one graph_rewrite pass applied to a sink UOp.

  Calling the step runs graph_rewrite(sink, pm, ...). If `ctx` is set, it is invoked with the incoming sink and its
  return value is passed as graph_rewrite's `ctx`; otherwise ctx=None. `name` and `bottom_up` are forwarded unchanged.
  """
  pm: PatternMatcher                      # matcher used for this pass
  ctx: Callable[[UOp], Any]|None = None   # optional factory: sink -> rewrite context
  name: str|None = None                   # forwarded as graph_rewrite's `name`
  bottom_up: bool = False                 # forwarded as graph_rewrite's `bottom_up`

  def __call__(self, sink:UOp):
    # build the per-call context (if any) from the sink being rewritten
    rewrite_ctx = None if self.ctx is None else self.ctx(sink)
    return graph_rewrite(sink, self.pm, ctx=rewrite_ctx, name=self.name, bottom_up=self.bottom_up)
|
|
|
|
def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]):
  """Run each step of `rewrites` in order, feeding each one the previous step's output; return the last result.

  With an empty `rewrites`, `sink` is returned unchanged.
  """
  for step in rewrites: sink = step(sink)
  return sink
|
|
|
|
def get_rewrites_for_renderer(opts:Renderer, linearizer:bool=True) -> list[RewriteStep]:
  """Return the ordered list of RewriteSteps for renderer `opts`.

  The current QUANTIZE/DEVECTORIZE/TRANSCENDENTAL context-var values are read here and handed to the memoized
  builder, so they become part of its cache key.

  A new list is returned on every call. The builder is wrapped in functools.cache, so returning its list directly
  would let one caller's in-place edits (append/insert/remove) leak into every later call with the same key.
  The copy is shallow: the RewriteStep objects themselves are still shared.
  """
  return list(_get_rewrites_for_renderer(opts, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value))
|
|
|
|
@functools.cache
def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
  """Build the ordered rewrite pipeline for `opts`; results are memoized by functools.cache.

  `_QUANTIZE`, `_DEVECTORIZE` and `_TRANSCENDENTAL` are snapshots of the matching context vars, passed explicitly so
  they take part in the cache key.

  Stage order:
    - quantize, only when _QUANTIZE is truthy and opts.device is "CPU" or "DSP"
    - lowerer, initial symbolic
    - store/move ignore, expander
    - remove_reduce
    - devectorize, with the matcher picked by _DEVECTORIZE
    - the renderer's pre_matcher, if it has one
    - final rewrite
    - four linearizer passes, only when `linearizer` is true

  NOTE: the returned list is the cached object itself.
  """
  steps: list[RewriteStep] = []

  # ** lowerer (rewrite_shapetracker_with_index) **
  if _QUANTIZE and opts.device in {"CPU", "DSP"}: steps.append(RewriteStep(pm_quant, name="quantize"))
  steps += [
    RewriteStep(pm_lowerer, lambda ast: get_index(ast, opts), name="lowerer"),
    # ** expander (expand_rewrite) **
    RewriteStep(sym+migrate_indexing, name="initial symbolic"),
    # ignore handling for masked stores
    RewriteStep(pm_store_ignore, name="store_ignore"),
    RewriteStep(pm_move_ignore, name="move_ignore"),
    # expand, then delete whatever ignores are still present
    RewriteStep(pm_delete_ignore+sym+expander, name="expander"),
    # ** devectorizer (full_graph_rewrite) **: remove reduce first
    RewriteStep(pm_reduce+gep_pushing, lambda _: ReduceContext(), name="remove_reduce"),
  ]

  # devectorize (TODO: does this need opts?)
  # _DEVECTORIZE >= 2 uses neither devectorize nor correct_load_store.
  # Any other truthy value uses both devectorize and correct_load_store.
  # A falsy value uses correct_load_store without devectorize.
  devectorize_pm = (sym+load_store_folding+load_store_indexing if _DEVECTORIZE >= 2 else
                    sym+devectorize+load_store_folding+correct_load_store+load_store_indexing if _DEVECTORIZE else
                    sym+load_store_folding+correct_load_store+load_store_indexing)
  steps.append(RewriteStep(devectorize_pm, lambda _: opts, name="devectorize"))

  # optional renderer-provided pre matcher
  if opts.pre_matcher is not None: steps.append(RewriteStep(opts.pre_matcher, name="pre_matcher"))

  # final rules for the renderer (without sym); a missing extra_matcher becomes an empty PatternMatcher
  extra = opts.extra_matcher
  if extra is None: extra = PatternMatcher([])
  late_patterns = get_late_rewrite_patterns(tuple(opts.code_for_op.keys()), _TRANSCENDENTAL>=2)
  steps.append(RewriteStep(symbolic_simple+late_patterns+pm_render+extra, lambda _: opts, name="final rewrite"))

  # ** linearizer ** (skipped entirely when linearizer is False)
  if not linearizer: return steps
  steps += [
    RewriteStep(block_create, ctx=BlockContext.from_sink, name="Linearizer: Create Blocks", bottom_up=True),
    RewriteStep(pm_blockend_merge, name="Linearizer: Merge Blockends"),
    RewriteStep(block_merge, name="Linearizer: Merge Blocks"),
    RewriteStep(pm_finalize, name="Linearizer: Finalize"),
  ]
  return steps
|
|
|
|
@functools.cache
def _default_renderer() -> Renderer:
  """Return one shared default Renderer, so the opts=None path always presents the same cache key.

  NOTE(review): no __eq__/__hash__ for Renderer is visible here. If it uses default identity hashing, a fresh
  Renderer() per call would miss the functools.cache in the pipeline builder every time and pin a new entry
  (and Renderer) in that cache forever. Sharing one instance avoids this either way.
  """
  return Renderer()

def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, linearizer:bool=False) -> UOp:
  """Apply the whole rewrite pipeline for `opts` to `sink` and return the rewritten sink.

  opts: renderer that selects the pipeline; when None, a shared default Renderer is used.
  linearizer: when True, the linearizer passes are appended to the pipeline.
  """
  renderer = opts if opts is not None else _default_renderer()
  return apply_rewrites(sink, get_rewrites_for_renderer(renderer, linearizer))
|
|
def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]:
  """Run the pipeline with the linearizer enabled and return the final sink's `arg.lst` copied into a new list."""
  linearized = full_rewrite_to_sink(sink, opts, linearizer=True)
  return list(linearized.arg.lst)
|
|
|