from typing import Any , Callable
import itertools , inspect , functools , types
from tinygrad . helpers import partition , dedup , Context
from tinygrad . uop . ops import UPat , UPatAny , UOp , Ops , PatternMatcher , graph_rewrite , deconstruct_function
class UPatCompileError(Exception):
  """Raised when a UPat cannot be rendered into compiled matcher source."""
# **** UPat compiled ****
def _get_clause(self:UPat, base:UOp, depth=0) -> UOp:
  """Translate the pattern `self` into a UOp tree describing when `base` matches it.

  The result is always an Ops.AND node. Its sources are a mix of:
    * Ops.CUSTOM checks: `arg` is a format string where `{0}` stands for `base` and `{1}` (when present)
      for an Ops.BIND node wrapping a Python value,
    * an Ops.STORE of `base` into an Ops.DEFINE_VAR when the pattern is named,
    * an Ops.RANGE for a repeated-source pattern,
    * an Ops.OR of alternatives, and nested Ops.AND nodes for child patterns.

  Args:
    self: the pattern to translate.
    base: the UOp standing for the expression under test.
    depth: nesting level of repeat matches; only used to name the loop variable ("ituop{depth}").

  Returns:
    The Ops.AND node combining every condition.

  Raises:
    RuntimeError: if `self.src` is neither a single tuple, a single itertools.repeat, nor several tuples.
  """
  def bind(value) -> UOp: return UOp(Ops.BIND, arg=value)
  def check(fmt:str, *operands:UOp) -> UOp: return UOp(Ops.CUSTOM, src=(base, *operands), arg=fmt)

  # UPatAny: any one of the alternatives may accept `base`
  if isinstance(self, UPatAny):
    assert len(self.src) == 1
    alternatives = tuple(_get_clause(alt, base, depth) for alt in self.src[0])
    return UOp(Ops.AND, src=(UOp(Ops.OR, src=alternatives),))

  conds:list[UOp] = []

  # op: membership test against a bound tuple for several ops, inline comparison with the value for one
  if self.op is not None:
    if len(self.op) > 1: conds.append(check("{0}.op in {1}", bind(tuple(int(o) for o in self.op))))
    else: conds.append(check("{0}.op == "+str(self.op[0].value)))

  # arg: ints are inlined into the source text, anything else is compared through a bound value
  if self.arg is not None:
    if isinstance(self.arg, int): conds.append(check("{0}.arg == "+str(int(self.arg))))
    else: conds.append(check("{0}.arg == {1}", bind(self.arg)))

  # number of sources: exact when strict_length is set, otherwise a lower bound
  if self.strict_length or self.required_len > 0:
    comparator = " == " if self.strict_length else " >= "
    conds.append(check("len({0}.src)"+comparator+str(self.required_len)))

  # a named pattern captures `base` under that name
  if self.name is not None:
    conds.append(UOp(Ops.STORE, src=(UOp(Ops.DEFINE_VAR, arg=self.name), base)))

  # dtype: accepted when either the dtype itself or its `_scalar` attribute matches
  if self.dtype is not None:
    if len(self.dtype) > 1: conds.append(check("({0}.dtype in {1} or {0}.dtype._scalar in {1})", bind(tuple(self.dtype))))
    else: conds.append(check("({0}.dtype == {1} or {0}.dtype._scalar == {1})", bind(self.dtype[0])))

  srcs = self.src
  if srcs is not None:
    # `only` is the sole entry of src, or None when src does not have exactly one entry
    only = srcs[0] if len(srcs) == 1 else None
    if isinstance(only, tuple):
      # single match: child i is checked against base.gep(i)
      conds.extend(_get_clause(child, base.gep(idx), depth) for idx, child in enumerate(only))
    elif isinstance(only, itertools.repeat):
      # repeat match: the same child pattern has to hold for every element of base.src
      loop_var = UOp(Ops.NOOP, arg=f"ituop{depth}")
      per_item = _get_clause(next(only), loop_var, depth+1)
      conds.append(UOp(Ops.RANGE, src=(per_item, loop_var, base), arg="all([{0} for {1} in {2}.src])"))
    elif len(srcs) > 1 and all(isinstance(option, tuple) for option in srcs):
      # multi match (fork): one AND per candidate source tuple, any of which may accept
      branches = tuple(UOp(Ops.AND, src=tuple(_get_clause(child, base.gep(idx), depth) for idx, child in enumerate(option)))
                       for option in srcs)
      conds.append(UOp(Ops.OR, src=branches))
    else: raise RuntimeError("broken")

  return UOp(Ops.AND, src=tuple(conds))
# *** pattern matcher ***
def do_process_and(a:UOp) -> UOp|None:
  """Normalise one Ops.AND node of a clause tree.

  In order:
    1. sources of nested ANDs are spliced into this AND, and OR sources are set aside;
    2. several ORs are merged into a single OR over the cartesian product of their branches;
    3. STORE sources are either appended to every branch of the OR (when there is one), or
       de-duplicated by target: a second store to the same target becomes a CMPNE between the
       first stored value and the new one;
    4. duplicate sources are removed with dedup.

  Args:
    a: the Ops.AND node to process.

  Returns:
    The rebuilt Ops.AND node, or None when none of the steps changed anything.

  Raises:
    UPatCompileError: if the AND directly contains four or more OR sources.
  """
  changed = False
  terms:list[UOp] = []
  ors:list[UOp] = []

  # remove any nested ANDs, extract or clauses
  for child in a.src:
    if child.op is Ops.OR: ors.append(child)
    elif child.op is Ops.AND:
      terms.extend(child.src)
      changed = True
    else: terms.append(child)

  # too big to compile
  if len(ors) >= 4: raise UPatCompileError("too big to compile")

  # one or clause max: several ORs collapse into the product of their branches
  if len(ors) > 1:
    combos = itertools.product(*(alt.src for alt in ors))
    ors = [UOp(Ops.OR, src=tuple(UOp(Ops.AND, src=combo) for combo in combos))]
    changed = True

  # handle stores
  stores, terms = partition(terms, lambda t: t.op is Ops.STORE)
  if stores and ors:
    # with an or clause present, the stores move into each of its branches
    assert len(ors) == 1 and all(branch.op is Ops.AND for branch in ors[0].src)
    extra = tuple(stores)
    ors = [UOp(Ops.OR, src=tuple(branch.replace(src=branch.src+extra) for branch in ors[0].src))]
    changed = True
  elif stores:
    # first value stored to each target; a later store to the same target is a compare against it
    first_value:dict[UOp, UOp] = {}
    for store in stores:
      target, value = store.src[0], store.src[1]
      if target in first_value:
        terms.append(UOp(Ops.CMPNE, src=(first_value[target], value)))
        changed = True
      else: first_value[target] = value
    # put the (now unique) stores back after the remaining terms
    terms.extend(UOp(Ops.STORE, src=(target, value)) for target, value in first_value.items())

  # reassemble; dropping a duplicate also counts as a change
  merged = dedup(terms + ors)
  if len(merged) != len(terms) + len(ors): changed = True
  return UOp(Ops.AND, src=tuple(merged)) if changed else None
# processor
pm_proc = PatternMatcher ( [ ( UPat ( Ops . AND , name = " a " ) , do_process_and ) ] , compiled = False )
# renderer
def wrap(ctx, x) -> UOp:
  """Register `x.arg` in `ctx` under a fresh name and return an Ops.NOOP whose arg is that name.

  The name is "a" followed by the number of entries `ctx` held before the insertion.
  """
  name = f"a{len(ctx)}"
  ctx[name] = x.arg
  return UOp(Ops.NOOP, arg=name)
# Rules that turn clause nodes into Python source fragments; a rendered fragment is the `arg` of an Ops.NOOP.
# NOTE(review): presumably graph_rewrite applies these until nothing matches, so children are NOOPs before parents render -- confirm
pm_renderer = PatternMatcher([
  # BIND: wrap stores the bound Python value in ctx and substitutes a NOOP naming it
  (UPat(Ops.BIND, name="x"), wrap),
  # CMPNE is actually equal: do_process_and emits it for a repeated store to one target; it renders as an `is` test
  (UPat(Ops.CMPNE, name="x"), lambda x: UOp(Ops.CUSTOM, src=x.src, arg="{0} is {1}")),
  # RANGE can't have OR inside it: the rule matches a RANGE whose first source is an AND with NOOP sources
  # (presumably all of them -- confirm UPat src semantics); those are joined with `and` and the RANGE becomes a CUSTOM
  (UPat(Ops.RANGE, src=(UPat(Ops.AND, src=UPat(Ops.NOOP), name="x"), UPat(), UPat()), name="r"),
    lambda r,x: r.replace(op=Ops.CUSTOM, src=(UOp(Ops.NOOP, arg="("+' and '.join(y.arg for y in x.src)+")"),)+r.src[1:])),
  # CUSTOM over NOOP sources: fill the format string in `arg` with the sources' rendered text
  (UPat(Ops.CUSTOM, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg.format(*[y.arg for y in x.src]))),
  # GEP over a rendered NOOP: append a `.src[i]` subscript, with i the first element of the GEP's arg
  (UPat(Ops.GEP, src=UPat(Ops.NOOP, name="x"), name="g"), lambda x,g: x.replace(arg=x.arg+f".src[{g.arg[0]}]"))
], compiled=False)
def _final_render(x:UOp, has_ctx:bool, depth=1) -> list[str]:
  """Render an Ops.AND node whose sources are already rendered into lines of Python source.

  NOOP sources contribute conditions, STORE sources contribute keyword arguments for the `_fxn` call,
  and a single OR source contributes nested blocks (one recursive render per branch, one level deeper).

  Args:
    x: the Ops.AND node to render.
    has_ctx: when True, `ctx=ctx` is passed as the first argument of the `_fxn` call.
    depth: indentation level of the emitted `if` line.

  Returns:
    With an OR present: an `if <conditions>:` header followed by the lines of every branch.
    Otherwise: one line that calls `_fxn` and returns its result when that result is not None.

  Raises:
    UPatCompileError: if a source is neither OR, STORE nor NOOP.
  """
  assert x.op is Ops.AND
  conds:list[str] = []
  bindings:list[str] = []
  nested:list[str] = []
  for term in x.src:
    if term.op is Ops.NOOP: conds.append(term.arg)
    elif term.op is Ops.STORE:
      target, value = term.src[0], term.src[1]
      assert target.op is Ops.DEFINE_VAR and value.op is Ops.NOOP
      bindings.append(f"{target.arg}={value.arg}")
    elif term.op is Ops.OR:
      # only one OR per AND is supported, and it needs at least one branch
      assert len(nested) == 0 and len(term.src) >= 1
      for branch in term.src: nested.extend(_final_render(branch, has_ctx, depth+1))
    else: raise UPatCompileError(f"can't compile this {term}")

  indent = '  '*depth

  # if we have an or, render it: the branches carry the stores, so none may remain at this level
  if nested:
    assert len(bindings) == 0
    guard = ' and '.join(conds) or 'True'
    return [f"{indent}if {guard}:"] + nested

  # if we don't, this is a final return
  call_args = ', '.join((["ctx=ctx"] if has_ctx else []) + bindings)
  guard = ' and '.join([*conds, f"(_ret:=_fxn({call_args})) is not None"])
  return [f"{indent}if {guard}: return _ret"]
def _get_code(self:UPat, has_ctx:bool):
  """Generate the source of a `compiled_match(uop, ctx)` function for the pattern `self`.

  Args:
    self: the pattern to compile.
    has_ctx: forwarded to _final_render; whether the generated `_fxn` call passes `ctx=ctx`.

  Returns:
    A `(source, dyn_lookup)` tuple, where `dyn_lookup` maps the names referenced by the source to the
    Python values collected while rendering; or None if a UPatCompileError was raised along the way.
  """
  clause = _get_clause(self, UOp(Ops.NOOP, arg="uop"))
  try:
    # TODO: this should be tracked in a "system" rewrite, not untracked or tracked with kernel
    with Context(TRACK_MATCH_STATS=0):
      processed = graph_rewrite(clause, pm_proc, name="process UPat")
      dyn_lookup:dict[str, Any] = {}
      out = graph_rewrite(processed, pm_renderer, ctx=dyn_lookup, name="compile UPat")
    rendered = _final_render(out, has_ctx)
  except UPatCompileError:
    # the pattern can't be compiled; report that to the caller instead of failing
    return None
  else:
    header = [f"# match for {self.location}", "def compiled_match(uop, ctx):"]
    footer = ["  return None"]
    return '\n'.join(header + rendered + footer), dyn_lookup
@functools.cache
def upat_compile(self:UPat, fxn) -> Callable|None:
  """Compile the pattern `self` and its callback `fxn` into a single Python matcher function.

  The callback is rebuilt with types.FunctionType from what deconstruct_function(fxn) returns, and is
  made available to the generated code as `_fxn`. Results are memoised per `(self, fxn)` by functools.cache.

  Args:
    self: the pattern to compile.
    fxn: the callback to invoke on a match; `ctx` is passed to it only if it has a `ctx` parameter.

  Returns:
    The generated `compiled_match` function, or None when _get_code could not compile the pattern.
  """
  matcher_fxn = types.FunctionType(*deconstruct_function(fxn))
  wants_ctx = 'ctx' in inspect.signature(matcher_fxn).parameters
  if (generated := _get_code(self, wants_ctx)) is None: return None
  source, bindings = generated
  # globals of the generated code: the values collected during rendering plus the callback itself
  exec_globals = {**bindings, "_fxn": matcher_fxn}
  exec_locals:dict = {}
  # the executed source is produced by _get_code from the pattern, not taken from external input
  exec(source, exec_globals, exec_locals)  # pylint: disable=W0122
  return exec_locals["compiled_match"]