# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
from __future__ import annotations
from dataclasses import dataclass
import functools
from typing import Optional , Callable
from tinygrad . helpers import merge_dicts , getenv
from tinygrad . shape . view import View , strides_for_shape , unravel
from tinygrad . dtype import dtypes
from tinygrad . ops import UOp , Ops , graph_rewrite , Variable , sint , sint_to_uop , Context
from tinygrad . codegen . symbolic import sym , split_uop , symbolic_flat , uop_given_valid , simplify_valid
def overflow(u:UOp):
  """True when u's value bounds do not fit in a signed 32-bit int."""
  lo, hi = dtypes.min(dtypes.int), dtypes.max(dtypes.int)
  return not (lo <= u.vmin and u.vmax <= hi)
# If a node overflows, its srcs need to be checked to see whether the overflow is the result of an ALU
# operation or the node simply inherits its dtype from srcs. Upcast is either `Ops.CAST`+`replace` or just `replace`.
def upcast(u:UOp):
  """Recursively rewrite an int32 index tree to int64 wherever its bounds exceed the int32 range."""
  new_srcs = tuple(upcast(s) for s in u.src)
  # non-int nodes are kept as-is, only their srcs are (possibly) rewritten
  if u.dtype.scalar() is not dtypes.int: return u.replace(src=new_srcs)
  wide = dtypes.int64.vec(u.dtype.count) if u.dtype.count > 1 else dtypes.int64
  widened = u.replace(dtype=wide, src=tuple(s.cast(wide) for s in new_srcs))
  # this node itself is out of int32 range: keep the widened version
  if overflow(u): return widened
  # check the ORIGINAL srcs: the new srcs carry Ops.CAST whose vmin/vmax hide the real bounds.
  # cast back is required because if this node is in range, siblings would never be upcasted
  if any(overflow(s) for s in u.src): return widened.cast(u.dtype)
  return u.replace(src=new_srcs)
# pooling ops may overflow before folding, causing an unnecessary upcast
def folded_upcast(u:UOp):
  """Symbolically simplify u first (with match-stat tracking off), then apply the int64 upcast pass."""
  with Context(TRACK_MATCH_STATS=0):
    simplified = graph_rewrite(u, sym, {})
    return upcast(simplified)
@functools.lru_cache(None)
def views_to_indexed_uops(views:tuple[View, ...], _idxs:Optional[tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
  """Compose a stack of Views into a single (index, valid) UOp pair, starting from the outermost view."""
  idx, valid = views[-1].to_indexed_uops(_idxs)
  # thread the index back through every earlier view, unraveling it into that view's coordinates
  for v in reversed(views[:-1]):
    v = v.minify()
    coords = [sint_to_uop(c) for c in unravel(v.shape, idx)]
    idx, valid = v.to_indexed_uops(coords, valid)
  return idx, valid
@functools.lru_cache(None)
def views_to_real_strides(views:tuple[View, ...], ignore_valid=False) -> tuple[Optional[sint], ...]:
  # NOTE: if a stride is not always valid, it will be None
  # single unmasked view: the stored strides are already exact
  if len(views) == 1 and views[-1].mask is None: return views[-1].strides
  ret:list[Optional[sint]] = [None]*len(views[-1].shape)
  idx, valid = (graph_rewrite(u, symbolic_flat) for u in views_to_indexed_uops(views))
  # TODO: always apply these in to_indexed_uops?
  if (simplified := simplify_valid(valid)) is not None: valid = simplified
  if (constrained := uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(constrained, symbolic_flat)
  # read strides off the additive terms of the flattened index expression
  for term in split_uop(idx, Ops.ADD):
    if term.op is Ops.RANGE: ret[term.arg] = 1
    elif term.op is Ops.MUL:
      # a src cannot be both RANGE and CONST, so at most one branch fires
      a, b = term.src
      if a.op is Ops.RANGE and b.op is Ops.CONST: ret[a.arg] = b.arg
      elif b.op is Ops.RANGE and a.op is Ops.CONST: ret[b.arg] = a.arg
  # ranges that vanished from idx entirely contribute stride 0
  used = {x.arg for x in idx.toposort if x.op is Ops.RANGE}
  ret = [s if i in used else 0 for i, s in enumerate(ret)]
  if not ignore_valid:
    # any axis the validity mask still depends on has no single stride
    for axis in (x.arg for x in valid.toposort if x.op is Ops.RANGE): ret[axis] = None
  return tuple(ret)
@dataclass(frozen=True, order=True)
class ShapeTracker:
  """An immutable stack of Views: applies movement ops to a buffer without requiring a copy."""
  views: tuple[View, ...]

  def __add__(self, st:ShapeTracker) -> ShapeTracker:
    # stack the other tracker's views one at a time: each step gets a chance to simplify
    out = self
    for view in st.views: out = ShapeTracker(out.views + (view,)).simplify()
    return out

  def invert(self, out_shape:tuple[sint, ...]) -> Optional[ShapeTracker]:
    """Return a ShapeTracker that undoes this one, or None if any view is not invertible."""
    rev = self.views[::-1]
    # each reversed view inverts into the shape of the view that preceded it; the last into out_shape
    targets = [v.shape for v in rev[1:]] + [out_shape]
    flipped: list[View] = []
    for view, tgt in zip(rev, targets):
      inv = view.invert(tgt)
      if inv is None: return None
      flipped.append(inv)
    return ShapeTracker(tuple(flipped)).reshape(out_shape)

  @staticmethod
  def from_shape(shape:tuple[sint, ...]) -> ShapeTracker:
    # a fresh contiguous tracker over the given shape
    return ShapeTracker((View.create(shape),))

  @property
  def contiguous(self) -> bool:
    return len(self.views) == 1 and self.views[0].contiguous

  @property
  def consecutive(self) -> bool:
    # one unmasked view whose strides are exactly the default row-major strides for its shape
    if len(self.views) != 1: return False
    v = self.views[0]
    return v.mask is None and v.strides == strides_for_shape(v.shape)

  @property
  def shape(self) -> tuple[sint, ...]: return self.views[-1].shape

  @property
  def size(self) -> int: return self.views[-1].size()

  def reduce(self, axis:tuple[int, ...]) -> tuple[sint, ...]:
    # output shape of a reduce over the given axes: reduced dims collapse to 1
    return tuple(1 if i in axis else s for i, s in enumerate(self.shape))

  def to_uop(self) -> UOp: return UOp(Ops.VIEW, dtypes.void, (), self)

  def to_indexed_uops(self, _idxs:Optional[list[UOp]|tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
    """Lower this tracker to an (index, valid) UOp pair, upcasting to int64 where int32 would overflow."""
    idx, valid = views_to_indexed_uops(self.views, None if _idxs is None else tuple(_idxs))
    return folded_upcast(idx), folded_upcast(valid)

  # upper bound on buffer size required to fit this shapetracker
  def real_size(self) -> int:
    if 0 in self.shape: return 0
    v = self.views[0]
    # only the in-mask region of the first view can ever be addressed
    first = v.shrink(v.mask) if v.mask else v
    idx, _ = views_to_indexed_uops((first,))
    assert idx.vmax < 1e12, f"real_size broken for {self}"
    return int(idx.vmax + 1)

  def vars(self) -> set[Variable]: return set().union(*[v.vars() for v in self.views])

  @property
  def var_vals(self) -> dict[Variable, int]:
    return merge_dicts([dict([v.unbind()]) for v in self.vars()])

  def unbind(self) -> tuple[ShapeTracker, dict[Variable, int]]:
    """Strip bound Variables from every view, returning the unbound tracker and the variable values."""
    pairs = [v.unbind() for v in self.views]
    var_vals = [p[1] for p in pairs]
    if not any(var_vals): return self, {}
    return ShapeTracker(tuple(p[0] for p in pairs)), merge_dicts(var_vals)

  def real_strides(self, ignore_valid=False) -> tuple[Optional[sint], ...]:
    with Context(TRACK_MATCH_STATS=0): return views_to_real_strides(self.views, ignore_valid)

  def unit_stride_axes(self, ignore_valid=False) -> list[int]:
    return [i for i, st in enumerate(self.real_strides(ignore_valid)) if st == 1]

  def axis_is_masked(self, axis:int) -> bool:
    # an axis is masked if the simplified valid expression still references its RANGE
    with Context(TRACK_MATCH_STATS=0):
      _, valid = self.to_indexed_uops()
      return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE]

  def simplify(self) -> ShapeTracker:
    # repeatedly merge the last two views for as long as they compose into one
    vs = self.views
    while len(vs) >= 2 and (merged := vs[-2] + vs[-1]) is not None: vs = vs[:-2] + (merged,)
    return self if vs is self.views else ShapeTracker(vs)

  # *** under this line are the movement ops ***

  def pad(self, arg:tuple[tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[:-1] + (self.views[-1].pad(arg),))
  def shrink(self, arg:tuple[tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[:-1] + (self.views[-1].shrink(arg),))
  def expand(self, new_shape:tuple[sint, ...]) -> ShapeTracker: return ShapeTracker(self.views[:-1] + (self.views[-1].expand(new_shape),))
  def permute(self, axis:tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[:-1] + (self.views[-1].permute(axis),))
  def flip(self, mul:tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[:-1] + (self.views[-1].flip(mul),))

  def reshape(self, new_shape:tuple[sint, ...]) -> ShapeTracker:
    # try to fold the reshape into the last view; fall back to appending a fresh view
    merged = self.views[-1].reshape(new_shape) if getenv("MERGE_VIEW", 1) else None
    if merged is not None: return ShapeTracker(self.views[:-1] + (merged,))
    return ShapeTracker(self.views + (View.create(new_shape),))

  def mop(self, op, arg): return mops[op](self, arg)
# dispatch table mapping each movement-op Ops value to the corresponding ShapeTracker method (used by ShapeTracker.mop)
mops: dict[Ops, Callable] = {Ops.RESHAPE: ShapeTracker.reshape, Ops.PERMUTE: ShapeTracker.permute, Ops.EXPAND: ShapeTracker.expand,
                             Ops.SHRINK: ShapeTracker.shrink, Ops.FLIP: ShapeTracker.flip, Ops.PAD: ShapeTracker.pad}