# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
from __future__ import annotations
from dataclasses import dataclass
import functools
from typing import Optional , Callable
from tinygrad . helpers import merge_dicts , getenv
from tinygrad . shape . view import View , strides_for_shape
from tinygrad . dtype import dtypes
from tinygrad . ops import UOp , Ops , graph_rewrite , split_uop , symbolic_flat , Variable , sint , uop_given_valid , simplify_valid
@functools.lru_cache(None)
def views_to_indexed_uops(views: tuple[View, ...], _idxs: Optional[tuple[UOp, ...]] = None) -> tuple[UOp, UOp]:
  """Collapse a stack of views into one (index, valid) pair.

  The last view is indexed first (using `_idxs` when given). The remaining views are then walked from
  back to front: each one is minified, the running index is split into one component per dimension of
  its shape, and the view is indexed with those components and the running valid expression.

  Results are memoized with `functools.lru_cache`, so both arguments have to be hashable.
  """
  flat_idx, valid = views[-1].to_indexed_uops(_idxs)
  for outer in reversed(views[0:-1]):
    outer = outer.minify()
    # peel the flat index apart from the last dimension towards the first:
    # `stride` is the product of the dimensions already consumed
    stride, components = 1, []
    for dim in outer.shape[::-1]:
      components.insert(0, (flat_idx // stride) % dim)
      stride *= dim
    flat_idx, valid = outer.to_indexed_uops(components, valid)
  return flat_idx, valid
@functools.lru_cache(None)
def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[Optional[sint], ...]:
  """Recover a per-axis stride for the last view's shape from the combined index expression.

  Each entry of the result is one of:
    * the stride read off a bare RANGE term (1) or a RANGE * CONST term of the index,
    * 0 when the axis' RANGE does not appear in the index at all,
    * None when no such simple term was found, or (unless `ignore_valid` is set) when the axis'
      RANGE appears in the valid expression.

  A single view without a mask short-circuits to that view's own `strides`.
  Results are memoized with `functools.lru_cache`.
  """
  # NOTE: if a stride is not always valid, it will be None
  if len(views) == 1 and views[-1].mask is None:
    return views[-1].strides
  strides: list[Optional[sint]] = [None] * len(views[-1].shape)
  raw_idx, raw_valid = views_to_indexed_uops(views)
  idx, valid = graph_rewrite(raw_idx, symbolic_flat), graph_rewrite(raw_valid, symbolic_flat)
  # TODO: always apply these in to_indexed_uops?
  simplified_valid = simplify_valid(valid)
  if simplified_valid is not None:
    valid = simplified_valid
  narrowed_idx = uop_given_valid(valid, idx)
  if narrowed_idx is not None:
    idx = graph_rewrite(narrowed_idx, symbolic_flat)
  # every additive term that is RANGE or RANGE*CONST (either operand order) fixes that axis' stride
  for term in split_uop(idx, Ops.ADD):
    if term.op is Ops.RANGE:
      strides[term.arg] = 1
    elif term.op is Ops.MUL:
      lhs, rhs = term.src[0], term.src[1]
      if lhs.op is Ops.RANGE and rhs.op is Ops.CONST:
        strides[lhs.arg] = rhs.arg
      elif rhs.op is Ops.RANGE and lhs.op is Ops.CONST:
        strides[rhs.arg] = lhs.arg
  # axes whose RANGE never shows up anywhere in the index get stride 0
  live_axes = [node.arg for node in idx.toposort if node.op is Ops.RANGE]
  strides = [stride if axis in live_axes else 0 for axis, stride in enumerate(strides)]
  if not ignore_valid:
    # any axis that the valid expression depends on is reported as None
    for node in valid.toposort:
      if node.op is Ops.RANGE:
        strides[node.arg] = None
  return tuple(strides)
@dataclass(frozen=True, order=True)
class ShapeTracker:
  """Immutable stack of views describing movement operations on a buffer without copying it.

  Every operation returns a new ShapeTracker. The last view supplies the tracker's shape and size
  and is the one the movement ops act on.
  """
  views: tuple[View, ...]  # ordered stack of views; views[-1] is the most recently applied one

  def _replace_last(self, view: View) -> ShapeTracker:
    """Return a tracker with the same leading views and `view` in place of the last one."""
    return ShapeTracker(self.views[0:-1] + (view,))

  def __add__(self, st: ShapeTracker) -> ShapeTracker:
    """Append the views of `st` onto this tracker, simplifying after each appended view."""
    # one view at a time = better simplification
    return functools.reduce(lambda acc, v: ShapeTracker(acc.views + (v,)).simplify(), st.views, self)

  def invert(self, out_shape: tuple[sint, ...]) -> Optional[ShapeTracker]:
    """Build a tracker from the inverse of every view, or return None if any view has no inverse.

    Views are walked from last to first. Each is inverted against the shape of the view that
    precedes it in the stack; the first view is inverted against `out_shape`. The result is
    finally reshaped to `out_shape`.
    """
    backwards = self.views[::-1]
    target_shapes = [v.shape for v in backwards[1:]] + [out_shape]
    inverted_views: list[View] = []
    for view, target in zip(backwards, target_shapes):
      inverse = view.invert(target)
      if inverse is None:
        return None
      inverted_views.append(inverse)
    return ShapeTracker(tuple(inverted_views)).reshape(out_shape)

  @staticmethod
  def from_shape(shape: tuple[sint, ...]) -> ShapeTracker:
    """Create a single-view tracker from `View.create(shape)`."""
    return ShapeTracker((View.create(shape),))

  @property
  def contiguous(self) -> bool:
    """The sole view's `contiguous` flag; False whenever there is more than one view."""
    if len(self.views) != 1:
      return False
    return self.views[0].contiguous

  @property
  def consecutive(self) -> bool:
    """True for a single unmasked view whose strides equal `strides_for_shape` of its shape."""
    if len(self.views) != 1:
      return False
    only = self.views[0]
    if only.mask is not None:
      return False
    return only.strides == strides_for_shape(only.shape)

  @property
  def shape(self) -> tuple[sint, ...]:
    """Shape of the last view."""
    return self.views[-1].shape

  @property
  def size(self) -> int:
    """Size reported by the last view."""
    return self.views[-1].size()

  def reduce(self, axis: tuple[int, ...]) -> tuple[sint, ...]:
    """Return the current shape with every dimension listed in `axis` replaced by 1."""
    return tuple(1 if dim in axis else extent for dim, extent in enumerate(self.shape))

  def to_uop(self) -> UOp:
    """Wrap this tracker as the argument of a void-typed VIEW UOp with no sources."""
    return UOp(Ops.VIEW, dtypes.void, (), self)

  def to_indexed_uops(self, _idxs: Optional[list[UOp] | tuple[UOp, ...]] = None) -> tuple[UOp, UOp]:
    """Return the (index, valid) pair for this tracker's views via `views_to_indexed_uops`."""
    # the cached helper needs hashable arguments, so a list of indices is converted to a tuple
    hashable_idxs = None if _idxs is None else tuple(_idxs)
    return views_to_indexed_uops(self.views, hashable_idxs)

  def real_size(self) -> int:
    """Return the index expression's `vmax` plus one.

    Returns 0 when any dimension of the shape is 0 or when the valid expression's `vmax` is falsy.
    Presumably this is the number of buffer elements the tracker can reach (not confirmed here).
    """
    if 0 in self.shape:
      return 0
    idx, valid = self.to_indexed_uops()
    if not valid.vmax:
      return 0
    assert idx.vmax < 1e12, f"real_size broken for {self}"
    return int(idx.vmax + 1)

  def vars(self) -> set[Variable]:
    """Union of the `vars()` of every view."""
    found: set[Variable] = set()
    for view in self.views:
      found.update(view.vars())
    return found

  @property
  def var_vals(self) -> dict[Variable, int]:
    """Merge the (variable, value) pair returned by `unbind()` for every variable in `vars()`."""
    return merge_dicts([{var: val} for var, val in (v.unbind() for v in self.vars())])

  def unbind(self) -> tuple[ShapeTracker, dict[Variable, int]]:
    """Unbind every view; return the tracker of unbound views and the merged per-view values."""
    stripped_views, per_view_vals = zip(*[view.unbind() for view in self.views])
    return ShapeTracker(tuple(stripped_views)), merge_dicts(per_view_vals)

  def real_strides(self, ignore_valid=False) -> tuple[Optional[sint], ...]:
    """Per-axis strides of this tracker as computed by `views_to_real_strides`."""
    return views_to_real_strides(self.views, ignore_valid)

  def unit_stride_axes(self, ignore_valid=False) -> list[int]:
    """Indices of the axes whose real stride equals 1."""
    return [axis for axis, stride in enumerate(self.real_strides(ignore_valid)) if stride == 1]

  def axis_is_masked(self, axis: int) -> bool:
    """True if the simplified valid expression contains a RANGE whose arg is `axis`."""
    _, valid = self.to_indexed_uops()
    simplified = graph_rewrite(valid, symbolic_flat)
    masked_axes = [node.arg for node in simplified.toposort if node.op is Ops.RANGE]
    return axis in masked_axes

  def simplify(self) -> ShapeTracker:
    """Repeatedly merge the last two views while their sum is not None; return self if nothing merges."""
    st = self
    while len(st.views) >= 2:
      merged = st.views[-2] + st.views[-1]
      if merged is None:
        break
      st = ShapeTracker(st.views[:-2] + (merged,))
    return st

  # *** under this line are the movement ops ***

  def pad(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker:
    """Apply `pad(arg)` to the last view."""
    return self._replace_last(self.views[-1].pad(arg))

  def shrink(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker:
    """Apply `shrink(arg)` to the last view."""
    return self._replace_last(self.views[-1].shrink(arg))

  def expand(self, new_shape: tuple[sint, ...]) -> ShapeTracker:
    """Apply `expand(new_shape)` to the last view."""
    return self._replace_last(self.views[-1].expand(new_shape))

  def permute(self, axis: tuple[int, ...]) -> ShapeTracker:
    """Apply `permute(axis)` to the last view."""
    return self._replace_last(self.views[-1].permute(axis))

  def stride(self, mul: tuple[int, ...]) -> ShapeTracker:
    """Apply `stride(mul)` to the last view."""
    return self._replace_last(self.views[-1].stride(mul))

  def reshape(self, new_shape: tuple[sint, ...]) -> ShapeTracker:
    """Reshape to `new_shape`.

    When the MERGE_VIEW environment setting (default 1) is truthy and the last view can express
    the reshape itself, that view is replaced; otherwise a fresh view is pushed on the stack.
    """
    if getenv("MERGE_VIEW", 1):
      merged = self.views[-1].reshape(new_shape)
      if merged is not None:
        return self._replace_last(merged)
    return ShapeTracker(self.views + (View.create(new_shape),))

  def mop(self, op, arg):
    """Dispatch movement op `op` with `arg` through the module-level `mops` table."""
    handler = mops[op]
    return handler(self, arg)
# dispatch table used by ShapeTracker.mop: movement op -> unbound ShapeTracker method
mops: dict[Ops, Callable] = {
  Ops.RESHAPE: ShapeTracker.reshape,
  Ops.PERMUTE: ShapeTracker.permute,
  Ops.EXPAND: ShapeTracker.expand,
  Ops.SHRINK: ShapeTracker.shrink,
  Ops.STRIDE: ShapeTracker.stride,
  Ops.PAD: ShapeTracker.pad,
}