from __future__ import annotations
import heapq
from collections import defaultdict, deque
from dataclasses import dataclass, replace
from tinygrad.ops import UOp, Ops, PatternMatcher, UPat, GroupOp
from tinygrad.helpers import dedup, partition, all_same, flatten
from tinygrad.spec import type_verify
# NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed
def block_reorder(lst: list[UOp]) -> list[UOp]:
    """Reorder the UOps of one block into a toposort close to a priority-based "ideal" order.

    Every UOp gets a priority: the minimum of 0, the priorities of its children
    inside this block, -1000 if it is a LOAD and -1500 if it is a BARRIER.
    Taking the minimum over the children means a UOp never ranks after
    something that depends on it. The UOps are then emitted in dependency
    order, always picking the ready UOp with the best ideal rank.

    Args:
        lst: the UOps of a single block. It must already be locally toposorted
            (sources before their users), because priorities are propagated
            from children while walking the list backwards.

    Returns:
        A new list holding the same UOps in the chosen order.
    """
    members = set(lst)
    children: defaultdict[UOp, list[UOp]] = defaultdict(list)
    pending: dict[UOp, int] = {}
    prio: dict[UOp, int] = {}

    # walk users before sources so each child's priority exists when its source needs it
    for node in reversed(lst):
        local_srcs = [s for s in node.src if s in members]
        for s in local_srcs:
            children[s].append(node)
        pending[node] = len(local_srcs)
        # loads go to the beginning of the block; BARRIER value is a hack for barrier grouping
        candidates = [0, *(prio[c] for c in children[node])]
        if node.op is Ops.LOAD:
            candidates.append(-1000)
        if node.op is Ops.BARRIER:
            candidates.append(-1500)
        prio[node] = min(candidates)

    # position of every UOp in the "ideal" order (priority first, tuplize as tie-break)
    ideal = sorted(lst, key=lambda x: (prio[x],) + x.tuplize)
    rank = {node: pos for pos, node in enumerate(ideal)}

    # emit UOps whose local sources are all placed, lowest ideal rank first
    ready = [(rank[node], node) for node in lst if pending[node] == 0]
    heapq.heapify(ready)
    newlst = []
    while ready:
        _, node = heapq.heappop(ready)
        newlst.append(node)
        for child in children[node]:
            pending[child] -= 1
            if pending[child] == 0:
                heapq.heappush(ready, (rank[child], child))

    assert len(newlst) == len(lst), f" len mismatch { len ( newlst ) } != { len ( lst ) } "
    return newlst
# ***** basic block *****
def disp(y: UOp) -> str:
    """Return a short label for ``y``; BasicBlock2.__repr__ uses it for context entries.

    An IF is labelled with its ``id``, a RANGE with ``str(y.arg)``, and every
    other op with a fixed placeholder.
    """
    op = y.op
    if op is Ops.IF:
        return f' IF { id ( y ) } '
    return str(y.arg) if op is Ops.RANGE else " <NONE> "
@dataclass(frozen=True, eq=False)
class BasicBlock2:
    """Immutable payload stored in the ``arg`` of BLOCK / BLOCKEND / BLOCKFINAL UOps.

    ``eq=False`` keeps the default identity-based equality and hashing, and
    ordering is explicitly forbidden through ``__lt__``.
    """
    # the UOps contained in the block, in order
    lst: tuple[UOp, ...]
    # the context this block lives in
    ctx: tuple[UOp, ...] = ()
    # for a BLOCKEND: the UOp whose scope is being closed
    end: UOp | None = None
    # reference count checked by the merging passes
    cnt: int = 0
    # context handed to children, when it differs from ctx
    child_ctx: tuple[UOp, ...] | None = None

    def __lt__(self, _: BasicBlock2):
        raise RuntimeError(" no comparing basic blocks ")

    def __repr__(self):
        # debugging view: end marker, cnt, both contexts, the length, then one op per line
        pieces = [
            f" { ( str ( disp ( self . end ) ) + ' ' ) if self . end is not None else ' ' } ",
            f' f { self . cnt } ',
            f" { [ disp ( y ) for y in self . ctx ] } { [ disp ( y ) for y in self . child_ctx ] if self . child_ctx is not None else ' - ' } ",
            f" { len ( self . lst ) } ",
            " \n ",
            ' \n '.join(str(x.op) for x in self.lst),
        ]
        return "".join(pieces)

    def last_ctx(self):
        """Return ``child_ctx`` when it is set (even if empty), otherwise ``ctx``."""
        if self.child_ctx is None:
            return self.ctx
        return self.child_ctx
def _sort_ctx(inp):
    """Return the deduplicated items of ``inp`` as a tuple sorted by their ``tuplize`` key."""
    unique = dedup(inp)
    ordered = sorted(unique, key=lambda x: x.tuplize)
    return tuple(ordered)
# ***** block context *****
@dataclass
class BlockContext:
    """Per-UOp bookkeeping computed once from the sink and used while building blocks."""
    # how many times each UOp appears as a source of another UOp
    child_count: dict[UOp, int]
    # the context each UOp itself lives in
    block_ctxs: dict[UOp, tuple[UOp, ...]]
    # the context a UOp hands to its children, only stored when it differs
    child_ctxs: dict[UOp, tuple[UOp, ...]]

    def last_ctx(self, u):
        """Return the child context of ``u`` if one is recorded, otherwise its block context."""
        child = self.child_ctxs.get(u)
        return self.block_ctxs[u] if child is None else child

    @staticmethod
    def from_sink(sink: UOp) -> BlockContext:
        """Walk ``sink.toposort()`` and fill in child counts and contexts for every UOp."""
        result = BlockContext({}, {}, {})
        for node in sink.toposort():
            result.child_count[node] = 0
            inherited: list[UOp] = []
            for parent in node.src:
                # NOTE: a parent that appears several times in src is counted once per appearance
                result.child_count[parent] += 1
                inherited.extend(result.last_ctx(parent))

            # the block ctx is the union of what every parent hands down
            own_ctx = _sort_ctx(inherited)
            result.block_ctxs[node] = own_ctx

            # RANGE/IF add to the next ctx, STORE/ASSIGN subtract from it
            if node.op in {Ops.RANGE, Ops.IF}:
                result.child_ctxs[node] = _sort_ctx(own_ctx + (node,))
            elif node.op is Ops.STORE:
                # ugh, deal with non-reduce locals. probably wrong
                if any(x.op is Ops.DEFINE_LOCAL for x in node.src[0].toposort()):
                    idx_context = result.last_ctx(node.src[0])
                    store_context = result.last_ctx(node.src[1])
                    result.child_ctxs[node] = tuple(
                        y for y in store_context if y not in idx_context and y.op is Ops.RANGE
                    )
                else:
                    result.child_ctxs[node] = ()
            elif node.op is Ops.ASSIGN:
                assert node.src[0].op is Ops.DEFINE_ACC
                # everything listed after the first src of the DEFINE_ACC leaves the context
                acc_tail = node.src[0].src[1:]
                result.child_ctxs[node] = tuple(
                    y for y in result.last_ctx(node.src[1]) if y not in acc_tail
                )
        return result
# ***** make blocks *****
# Ops that are never put inside a block's lst: make_block_bottom_up always treats them as
# unmergable sources, block_create does not match them, and finalize requires that the
# final BLOCK has only these as sources and places them ahead of the block's own UOps.
DONT_PLACE_IN_BLOCK = { Ops . DEFINE_GLOBAL , Ops . DEFINE_LOCAL , Ops . DEFINE_VAR , Ops . SPECIAL , Ops . CONST }
def add_blockends(base_block: UOp, new_ctx: tuple[UOp, ...], current_ctx: tuple[UOp, ...], cnt: int = 1) -> UOp:
    """Wrap ``base_block`` in one BLOCKEND per context entry that ``current_ctx`` lacks.

    Entries of ``new_ctx`` missing from ``current_ctx`` are closed from last to
    first. Each BLOCKEND holds a single ENDIF (for an IF) or ENDRANGE (for
    anything else), carries the context that remains after removing the closed
    entry, and references the block it wraps ``cnt`` times.

    Returns:
        The outermost BLOCKEND, or ``base_block`` itself when nothing has to be closed.
    """
    to_close = [z for z in new_ctx if z not in current_ctx]
    wrapped = base_block
    remaining = new_ctx
    for r in reversed(to_close):
        remaining = tuple(z for z in remaining if z is not r)
        closer_op = Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE
        closer = UOp(closer_op, src=(r,))
        payload = BasicBlock2((closer,), tuple(remaining), end=r, cnt=cnt)
        wrapped = UOp(Ops.BLOCKEND, src=(wrapped,) * cnt, arg=payload)
    return wrapped
def make_block_bottom_up(ctx: BlockContext, x: UOp):
    """Grow a BLOCK upward from ``x`` by absorbing every source that can share it.

    Starting from ``x`` (or from the seed UOps of a BLOCKSTART), sources are
    visited breadth-first. A source is absorbed once all of its references
    have been seen and its block context equals the context of this block.
    A fully-seen source with a different context becomes a seed for a new
    BLOCKSTART, grouped by (block ctx, child ctx). Everything else stays a
    plain source of the BLOCK.

    Args:
        ctx: child counts and contexts computed by BlockContext.from_sink.
        x: the UOp to build the block around, or a BLOCKSTART created by an
            earlier call.

    Returns:
        A BLOCK UOp whose ``arg`` is a BasicBlock2 with the reordered UOps and
        whose ``src`` holds the unmerged sources and the new BLOCKSTARTs, each
        wrapped in BLOCKENDs as needed.
    """
    if x.op is Ops.BLOCKSTART:
        # a seed from an earlier call: contexts travel in arg, the seed UOps in src
        current_ctx, child_ctx = x.arg
        lst = list(x.src)
        child_count = 1
    else:
        current_ctx, child_count, child_ctx = ctx.block_ctxs[x], ctx.child_count[x], ctx.child_ctxs.get(x, None)
        lst = [x]

    # count of times we've seen this block, or a seed for a new block if we can't merge it
    unmergable: defaultdict[UOp, int] = defaultdict(int)
    blockseeds: defaultdict[tuple, list[UOp]] = defaultdict(list)

    # add the srcs of this to the frontier
    # NOTE: things may be in here multiple times, that's okay
    # the frontier is consumed first-in first-out; a deque pops from the left in O(1),
    # whereas list.pop(0) shifts every remaining element on each pop
    frontier_nodes: deque[UOp] = deque(flatten(y.src[::-1] for y in lst))
    while frontier_nodes:
        u = frontier_nodes.popleft()
        if u.op not in DONT_PLACE_IN_BLOCK and ctx.child_count[u] == unmergable[u] + 1:
            # count is correct: this is the last outstanding reference to u
            if (newctx := ctx.block_ctxs[u]) == current_ctx:
                # block has same context, merge it, and put the srcs on the frontier
                lst.append(u)
                frontier_nodes.extend(u.src[::-1])
            else:
                # block has different context, add it to blockseeds
                blockseeds[(newctx, ctx.child_ctxs.get(u, None))].append(u)
            # u is now placed (merged or seeded), so it must not also be emitted as a source
            del unmergable[u]
        else:
            # count is incorrect (or it's DONT_PLACE_IN_BLOCK), add it to unmergable
            unmergable[u] += 1

    # add unmergables to sources, once per time they were referenced
    srcs = []
    for u, cnt in unmergable.items():
        srcs += [add_blockends(u, ctx.block_ctxs[u], current_ctx, cnt=cnt)] * cnt

    # add blockseeds, with blockends as needed
    for (new_ctx, new_child_ctx), v in blockseeds.items():
        base_block = UOp(Ops.BLOCKSTART, src=tuple(v), arg=(new_ctx, new_child_ctx))
        srcs.append(add_blockends(base_block, new_ctx, current_ctx))

    # lst was collected from users towards sources; reverse it so block_reorder gets a toposorted list
    lst = block_reorder(lst[::-1])
    bb = BasicBlock2(tuple(lst), ctx=current_ctx, cnt=child_count, child_ctx=child_ctx)
    return UOp(Ops.BLOCK, src=tuple(srcs), arg=bb)
# Rule table for block creation: every op in GroupOp.All, except those in DONT_PLACE_IN_BLOCK
# and except BLOCK / BLOCKEND themselves, is handed to make_block_bottom_up.
block_create = PatternMatcher ( [
( UPat ( GroupOp . All - DONT_PLACE_IN_BLOCK . union ( { Ops . BLOCK , Ops . BLOCKEND } ) , name = " x " ) , make_block_bottom_up ) ,
] )
# ***** blockend merging ****
def merge_blockends(sink: UOp) -> UOp | None:
    """Fuse all BLOCKENDs that close the same UOp into a single BLOCKEND.

    Only acts on the final BLOCK, recognised by a SINK as the last UOp of its
    lst. BLOCKENDs are grouped by ``arg.end``. Each group with more than one
    member is replaced by one BLOCKEND that takes the lst of the first member,
    the combined contexts, the summed cnt, and the concatenated sources.

    Returns:
        The rewritten graph, or None when this is not the final BLOCK or
        nothing could be fused.
    """
    # only run on the final BLOCK with the SINK in it
    if sink.arg.lst[-1].op is not Ops.SINK:
        return None

    # group the BLOCKENDs by the UOp they close
    by_end: dict[UOp, list[UOp]] = {}
    for node in sink.toposort():
        if node.op is Ops.BLOCKEND:
            by_end.setdefault(node.arg.end, []).append(node)

    replacements = {}
    for end, group in by_end.items():
        # NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
        if len(group) <= 1:
            continue
        merged_ctx = _sort_ctx(flatten([y.arg.ctx for y in group]))
        total_cnt = sum(y.arg.cnt for y in group)
        payload = BasicBlock2(group[0].arg.lst, merged_ctx, end, cnt=total_cnt)
        merged_srcs = tuple(flatten([member.src for member in group]))
        fused = UOp(Ops.BLOCKEND, src=merged_srcs, arg=payload)
        # NOTE: the merged ctx differing from a member's own ctx can cause problems here
        for member in group:
            replacements[member] = fused

    if not replacements:
        return None
    return sink.substitute(replacements)
# Rule table that offers every BLOCK to merge_blockends (which itself only acts on the final one).
pm_blockend_merge = PatternMatcher ( [ ( UPat ( Ops . BLOCK , name = " sink " ) , merge_blockends ) ] )
# ***** block merging ****
def merge_block(x: UOp):
    """Pull eligible source BLOCKs into ``x`` (a BLOCK or a BLOCKEND).

    A source BLOCK is a candidate when ``x`` is a BLOCK with the same ctx, or
    when ``x`` is a BLOCKEND whose ``end`` is in the source's ctx. A candidate
    is merged only if ``x`` references it exactly ``arg.cnt`` times. Otherwise
    it stays a source, once per reference.

    Returns:
        A UOp with the same op as ``x`` whose lst is the merged blocks' lsts
        followed by ``x``'s own, or None when nothing can be merged.
    """
    kept_srcs = []
    seen: defaultdict[UOp, int] = defaultdict(int)
    for src in x.src:
        candidate = src.op is Ops.BLOCK and (
            (x.op is Ops.BLOCK and x.arg.ctx == src.arg.ctx)
            or (x.op is Ops.BLOCKEND and x.arg.end in src.arg.ctx)
        )
        if candidate:
            seen[src] += 1
        else:
            kept_srcs.append(src)

    # a block can only be absorbed when every reference to it comes from x
    absorbed = []
    for block, times in seen.items():
        if times == block.arg.cnt:
            absorbed.append(block)
        else:
            kept_srcs.extend([block] * times)
    if not absorbed:
        return None

    # create the block: absorbed lsts run before x's own lst
    merged_lst = tuple(flatten([b.arg.lst for b in absorbed])) + x.arg.lst
    merged_src = tuple(flatten([b.src for b in absorbed]) + kept_srcs)
    return UOp(x.op, src=merged_src, arg=replace(x.arg, lst=merged_lst))
def remove_blockend(x: UOp):
    """Turn a finished BLOCKEND into a BLOCK by folding in the block that opens its scope.

    Nothing happens while any BLOCK/BLOCKEND source still has ``x.arg.end`` in
    its ctx. Otherwise the parent is the source BLOCK whose child_ctx contains
    ``x.arg.end``. The new BLOCK's lst is the DEFINE_ACCs that have
    ``x.arg.end`` as a source, then the parent's lst, then the rest of ``x``'s
    lst.

    Returns:
        The replacement BLOCK, or None when the BLOCKEND must stay or no
        parent block is found.
    """
    end = x.arg.end
    # if there's any remaining blocks that need to go in this BLOCKEND, we don't remove it
    if any(end in y.arg.ctx for y in x.src if y.op in {Ops.BLOCK, Ops.BLOCKEND}):
        return None

    parent_blocks = [y for y in x.src if y.op is Ops.BLOCK and y.arg.child_ctx is not None and end in y.arg.child_ctx]
    assert all_same(parent_blocks), f" should never have two parent blocks (has { len ( parent_blocks ) } ) "
    if not parent_blocks:
        return None

    parent_block = parent_blocks[0]
    assert len(parent_blocks) == parent_block.arg.cnt

    # range needs DEFINE_ACC to be before the range (never in DEFINE_ACC for if)
    early_ops, late_ops = partition(x.arg.lst, lambda y: y.op is Ops.DEFINE_ACC and end in y.src)

    # NOTE: we have to add a barrier at the start if barrier is used in the range
    needs_barrier = (
        x.op is Ops.BLOCKEND
        and any(y.op is Ops.BARRIER for y in late_ops)
        and late_ops[-1].op is Ops.ENDRANGE
    )
    if needs_barrier:
        late_ops = [UOp(Ops.BARRIER)] + late_ops

    merged_lst = tuple(early_ops) + parent_block.arg.lst + tuple(late_ops)
    remaining_ctx = tuple(y for y in x.arg.ctx if y is not end)
    payload = BasicBlock2(merged_lst, remaining_ctx, cnt=x.arg.cnt)
    other_srcs = tuple(y for y in x.src if y is not parent_block)
    return UOp(Ops.BLOCK, src=other_srcs + parent_block.src, arg=payload)
# Rule table for block merging: merge_block is offered every BLOCK and BLOCKEND,
# remove_blockend is offered every BLOCKEND.
block_merge = PatternMatcher ( [
( UPat ( ( Ops . BLOCK , Ops . BLOCKEND ) , name = " x " ) , merge_block ) ,
( UPat ( Ops . BLOCKEND , name = " x " ) , remove_blockend ) ,
] )
# ****** finalize ******
def finalize(sink: UOp) -> UOp:
    """Produce the BLOCKFINAL UOp holding the complete, ordered list of UOps.

    ``sink`` must be a BLOCK whose sources are all DONT_PLACE_IN_BLOCK ops.
    Those sources are deduplicated, sorted by ``tuplize`` and placed ahead of
    the block's own lst. The result is type-verified unless Python runs with
    optimizations enabled.

    Raises:
        RuntimeError: if ``sink`` is not a BLOCK or still has a source that is
            not a DONT_PLACE_IN_BLOCK op.
    """
    if sink.op is not Ops.BLOCK or any(x.op not in DONT_PLACE_IN_BLOCK for x in sink.src):
        raise RuntimeError(" linearize failure ")

    # place the early things
    early = sorted(dedup(sink.src), key=lambda x: x.tuplize)
    lst = [*early, *sink.arg.lst]
    if __debug__:
        type_verify(lst)
    return UOp(Ops.BLOCKFINAL, arg=BasicBlock2(tuple(lst)))
# Rule table that hands every BLOCK to finalize, which raises unless the BLOCK is fully merged.
pm_finalize = PatternMatcher ( [ ( UPat ( Ops . BLOCK , name = " sink " ) , finalize ) ] )