from __future__ import annotations
import os , functools , platform , time , re , contextlib , operator , hashlib , pickle , sqlite3
import numpy as np
from typing import Dict , Tuple , Union , List , NamedTuple , Final , Iterator , ClassVar , Optional , Iterable , Any , TypeVar , TYPE_CHECKING
if TYPE_CHECKING : # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
from typing_extensions import TypeGuard
T = TypeVar("T")  # the TypeVar name must match the variable it is bound to


def prod(x: Iterable[T]) -> Union[T, int]:
    """Multiply together every element of ``x``.

    NOTE: returns the int ``1`` when ``x`` is empty, regardless of the
    element type of ``x``.
    """
    return functools.reduce(operator.mul, x, 1)
# NOTE: helpers is not allowed to import from anything else in tinygrad
# True when platform.system() reports "Darwin".
OSX = platform.system() == "Darwin"
# True when the CI environment variable is set to a non-empty string.
CI = os.getenv("CI", "") != ""
def dedup(x):
    """Return the items of ``x`` as a list with repeats dropped.

    The first occurrence of each item is kept, so list order is retained.
    Items must be hashable.
    """
    first_seen = {}
    for item in x:
        first_seen.setdefault(item, None)
    return list(first_seen)
def argfix(*x):
    """Normalize varargs so ``f(1, 2)`` and ``f([1, 2])`` look the same.

    When the first positional argument is exactly a ``tuple`` or ``list``
    (subclasses do not count), it is returned as a tuple and any further
    arguments are ignored. Otherwise the arguments are returned unchanged.
    """
    if not x:
        return x
    head = x[0]
    if head.__class__ in (tuple, list):
        return tuple(head)
    return x
def argsort(x):
    """Return the indices that would sort ``x``, in a container of the same type as ``x``.

    See https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
    """
    order = list(range(len(x)))
    order.sort(key=x.__getitem__)  # stable, like sorted()
    return type(x)(order)
def all_same(items):
    """Return True when every element of ``items`` equals the first one.

    An empty ``items`` yields True, because the first element is never looked up.
    """
    for item in items:
        if not item == items[0]:
            return False
    return True
def all_int(t: Tuple[Any, ...]) -> TypeGuard[Tuple[int, ...]]:
    """Return True when every element of ``t`` is an ``int`` instance (``bool`` included)."""
    for s in t:
        if not isinstance(s, int):
            return False
    return True
def colored(st, color, background=False):
    """Wrap ``st`` in ANSI colour escape codes (replaces the termcolor library).

    color: one of black/red/green/yellow/blue/magenta/cyan/white. An
        all-uppercase name adds 60 to the code; ``None`` returns ``st`` unchanged.
    background: when true, adds 10 to the code.

    Raises ValueError when ``color`` is not one of the known names.
    """
    if color is None:
        return st
    names = ['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white']
    code = 10 * background + 60 * (color.upper() == color) + 30 + names.index(color.lower())
    return f"\u001b[{code}m{st}\u001b[0m"
def ansistrip(s):
    """Return ``s`` with ANSI escape sequences removed.

    Removes ``ESC[K`` and any ``ESC[...m`` sequence (non-greedy up to the first ``m``).
    """
    return re.sub('\x1b\\[(K|.*?m)', '', s)
def ansilen(s):
    """Return the length of ``s`` after ``ansistrip`` has removed its ANSI escape sequences."""
    visible = ansistrip(s)
    return len(visible)
def make_pair(x: Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]:
    """Repeat an int ``cnt`` times as a tuple; return anything that is not an int unchanged."""
    if isinstance(x, int):
        return (x,) * cnt
    return x
def flatten(l: Union[List, Iterator]):
    """Concatenate the items of each sub-iterable of ``l`` into one list (one level deep)."""
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat
def fromimport(mod, frm):
    """Programmatic ``from mod import frm``: import module ``mod`` and return its attribute ``frm``."""
    module = __import__(mod, fromlist=[frm])
    return getattr(module, frm)
def strip_parens(fst):
    """Drop one pair of enclosing parentheses from ``fst`` when they wrap the whole string.

    The outer pair is removed only if ``fst`` starts with '(' and ends with ')'
    and, inside them, the first '(' does not come after the first ')'. That
    keeps strings like "(a)+(b)" intact. Any other string, including the empty
    string, is returned unchanged.
    """
    # Slices instead of fst[0] / fst[-1] so an empty string does not raise IndexError.
    if not (fst[:1] == '(' and fst[-1:] == ')'):
        return fst
    inner = fst[1:-1]
    return inner if inner.find('(') <= inner.find(')') else fst
def merge_dicts(ds: Iterable[Dict]) -> Dict:
    """Merge several dicts into one, refusing conflicting values for the same key.

    ds: any iterable of dicts, including a generator. Values must be hashable,
        because (key, value) pairs are collected into a set.

    Raises AssertionError when two dicts map the same key to different values.
    """
    ds = list(ds)  # materialize: ds is walked twice and may be a one-shot iterator
    kvs = {(k, v) for d in ds for k, v in d.items()}
    # A key appearing with two different values produces more pairs than distinct keys.
    assert len(kvs) == len({k for k, _ in kvs}), f"cannot merge, {kvs} contains different values for the same key"
    return {k: v for d in ds for k, v in d.items()}
def partition(lst, fxn):
    """Split ``lst`` into two lists by the predicate ``fxn``.

    Returns ``(matching, rest)``: items for which ``fxn(item)`` is truthy, then
    the remaining items. Both lists keep the input order.
    """
    matching: List[Any] = []
    rest: List[Any] = []
    for item in lst:
        if fxn(item):
            matching.append(item)
        else:
            rest.append(item)
    return matching, rest
@functools.lru_cache(maxsize=None)
def getenv(key, default=0):
    """Read environment variable ``key`` and convert it to the type of ``default``.

    When the variable is unset, ``default`` itself goes through the same
    conversion. Results are memoized per ``(key, default)`` by ``lru_cache``,
    so later changes to the environment are not seen for arguments already
    looked up.
    """
    raw = os.getenv(key, default)
    convert = type(default)
    return convert(raw)
class Context(contextlib.ContextDecorator):
    """Temporarily override ContextVar values, as a ``with`` block or a decorator.

    Keyword arguments name existing ContextVar keys (looked up in
    ``ContextVar._cache``) and give the values to use while the context is active.
    """
    stack: ClassVar[List[dict[str, int]]] = [{}]

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def __enter__(self):
        # Store current state: snapshot every known variable into the top stack frame.
        Context.stack[-1] = {name: var.value for name, var in ContextVar._cache.items()}
        # Update to the new temporary state.
        for name, new_value in self.kwargs.items():
            ContextVar._cache[name].value = new_value
        # Store the temporary state so we know what to undo later.
        Context.stack.append(self.kwargs)

    def __exit__(self, *args):
        overridden = Context.stack.pop()
        snapshot = Context.stack[-1]
        for name in overridden:
            var = ContextVar._cache[name]
            # Fall back to the current value when the snapshot has no entry for this key.
            var.value = snapshot.get(name, var.value)
class ContextVar:
    """A named value, created once per key and initialized through ``getenv``.

    Constructing a ContextVar with a key that already exists returns the
    existing instance; ``default_value`` is ignored in that case.
    """
    _cache: ClassVar[Dict[str, ContextVar]] = {}
    value: int

    def __new__(cls, key, default_value):
        cached = ContextVar._cache.get(key)
        if cached is not None:
            return cached
        # Register the instance before reading the environment, as the one-liner form did.
        instance = super().__new__(cls)
        ContextVar._cache[key] = instance
        instance.value = getenv(key, default_value)
        return instance

    def __bool__(self):
        return bool(self.value)

    def __ge__(self, x):
        return self.value >= x

    def __gt__(self, x):
        return self.value > x

    def __lt__(self, x):
        return self.value < x
# Variables that a Context can override temporarily; each starts from the
# environment variable of the same name and defaults to 0.
DEBUG = ContextVar("DEBUG", 0)
IMAGE = ContextVar("IMAGE", 0)
BEAM = ContextVar("BEAM", 0)
NOOPT = ContextVar("NOOPT", 0)
# Plain settings read once from the environment.
GRAPH = getenv("GRAPH", 0)
GRAPHPATH = getenv("GRAPHPATH", "/tmp/net")
class Timing(contextlib.ContextDecorator):
    """Measure the wall time of a ``with`` block (or decorated call) and print it in milliseconds.

    prefix: text printed before the measurement.
    on_exit: optional callable given the elapsed nanoseconds; the string it
        returns is appended to the printed line.
    enabled: when false nothing is printed, but ``self.et`` is still recorded.
    """

    def __init__(self, prefix="", on_exit=None, enabled=True):
        self.prefix = prefix
        self.on_exit = on_exit
        self.enabled = enabled

    def __enter__(self):
        self.st = time.perf_counter_ns()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.et = time.perf_counter_ns() - self.st  # elapsed nanoseconds
        if self.enabled:
            suffix = self.on_exit(self.et) if self.on_exit else ""
            print(f"{self.prefix}{self.et*1e-6:.2f} ms" + suffix)
# **** tinygrad now supports dtypes! *****
class DType(NamedTuple):
    """Immutable description of a data type.

    Instances compare and hash as plain tuples of their fields.
    """
    priority: int  # this determines when things get upcasted
    itemsize: int  # size of one item (presumably in bytes; verify against callers)
    name: str
    np: Optional[type]  # TODO: someday this will be removed with the "remove numpy" project
    sz: int = 1  # element count; the vector dtypes (e.g. float4) set this above 1

    def __repr__(self):
        # Looks up the attribute name this dtype is registered under on `dtypes`.
        return f"dtypes.{INVERSE_DTYPES_DICT[self]}"
# dependent typing?
class ImageDType(DType):
    """A DType that additionally carries a ``shape``.

    ``shape`` is stored as an instance attribute rather than a tuple field, so
    hashing and equality are overridden to take it into account.
    """

    def __new__(cls, priority, itemsize, name, np, shape):
        # shape is deliberately not passed on: it is not one of the tuple fields.
        return super().__new__(cls, priority, itemsize, name, np)

    def __init__(self, priority, itemsize, name, np, shape):
        self.shape: Tuple[int, ...] = shape  # arbitrary arg for the dtype, used in image for the shape
        super().__init__()

    def __repr__(self):
        return f"dtypes.{self.name}({self.shape})"

    # TODO: fix this to not need these
    def __hash__(self):
        return hash((super().__hash__(), self.shape))

    def __eq__(self, x):
        # Anything that is not an ImageDType has no shape to compare. Report
        # "not equal" instead of raising AttributeError on x.shape, which
        # happened for None, non-tuples, and field-equal plain DTypes.
        if not isinstance(x, ImageDType):
            return False
        return super().__eq__(x) and self.shape == x.shape

    def __ne__(self, x):
        return not self.__eq__(x)
class PtrDType(DType):
    """Pointer flavour of an existing DType: same fields, different repr."""

    def __new__(cls, dt: DType):
        return super().__new__(cls, dt.priority, dt.itemsize, dt.name, dt.np, dt.sz)

    def __repr__(self):
        return f"ptr.{super().__repr__()}"
class dtypes:
    """Namespace holding the DType constants and helpers for classifying them.

    Attribute order matters: the module builds its name lookup tables from
    ``dtypes.__dict__``, so an alias defined after its target (``half``,
    ``float``, ``double``) is the name that wins in the inverse table.
    """

    @staticmethod  # static methods on top, or bool in the type info will refer to dtypes.bool
    def is_int(x: DType) -> bool:
        return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)

    @staticmethod
    def is_float(x: DType) -> bool:
        return x in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes._half4, dtypes._float2, dtypes._float4)

    @staticmethod
    def is_unsigned(x: DType) -> bool:
        return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)

    @staticmethod
    def from_np(x) -> DType:
        # Keyed by the numpy dtype name, e.g. "float32".
        return DTYPES_DICT[np.dtype(x).name]

    @staticmethod
    def fields() -> Dict[str, DType]:
        return DTYPES_DICT

    bool: Final[DType] = DType(0, 1, "bool", np.bool_)
    float16: Final[DType] = DType(0, 2, "half", np.float16)
    half = float16
    float32: Final[DType] = DType(4, 4, "float", np.float32)
    float = float32
    float64: Final[DType] = DType(0, 8, "double", np.float64)
    double = float64
    int8: Final[DType] = DType(0, 1, "char", np.int8)
    int16: Final[DType] = DType(1, 2, "short", np.int16)
    int32: Final[DType] = DType(2, 4, "int", np.int32)
    int64: Final[DType] = DType(3, 8, "long", np.int64)
    uint8: Final[DType] = DType(0, 1, "unsigned char", np.uint8)
    uint16: Final[DType] = DType(1, 2, "unsigned short", np.uint16)
    uint32: Final[DType] = DType(2, 4, "unsigned int", np.uint32)
    uint64: Final[DType] = DType(3, 8, "unsigned long", np.uint64)

    # NOTE: bfloat16 isn't supported in numpy
    bfloat16: Final[DType] = DType(0, 2, "__bf16", None)

    # NOTE: these are internal dtypes, should probably check for that
    _int2: Final[DType] = DType(2, 4*2, "int2", None, 2)
    _half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
    _float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
    _float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
    _arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)

    # NOTE: these are image dtypes
    @staticmethod
    def imageh(shp):
        return ImageDType(100, 2, "imageh", np.float16, shp)

    @staticmethod
    def imagef(shp):
        return ImageDType(100, 4, "imagef", np.float32, shp)
# Name -> DType table built from the attributes of `dtypes`, skipping dunder
# entries and the helper methods.
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
DTYPES_DICT = {
    k: v for k, v in dtypes.__dict__.items()
    if not (k.startswith('__') or callable(v) or v.__class__ == staticmethod)
}
# DType -> name table; when several names share one DType the last-defined name wins.
INVERSE_DTYPES_DICT = {v: k for k, v in DTYPES_DICT.items()}
class GlobalCounters:
    """Counters kept as class attributes, so every user of the class shares them."""
    global_ops: ClassVar[int] = 0
    global_mem: ClassVar[int] = 0
    time_sum_s: ClassVar[float] = 0.0
    kernel_count: ClassVar[int] = 0
    mem_used: ClassVar[int] = 0  # NOTE: this is not reset
    mem_cached: ClassVar[int] = 0  # NOTE: this is not reset

    @staticmethod
    def reset():
        """Zero the op, memory-traffic, time and kernel counters; mem_used and mem_cached are left alone."""
        GlobalCounters.global_ops = 0
        GlobalCounters.global_mem = 0
        GlobalCounters.time_sum_s = 0.0
        GlobalCounters.kernel_count = 0
# *** universal database cache ***
# Path of the sqlite file backing the cache (override with the CACHEDB env var).
CACHEDB = getenv("CACHEDB", "/tmp/tinygrad_cache")
CACHELEVEL = getenv("CACHELEVEL", 2)
# Compared against the ("meta", "version") cache entry; a mismatch clears the cache file.
VERSION = 6
_db_connection = None  # lazily opened sqlite3 connection shared by the cache helpers


def db_connection():
    """Return the shared sqlite3 connection to CACHEDB, opening it on first use.

    On first open, the stored ("meta", "version") entry is compared with
    VERSION. On a mismatch (including a brand-new database) the file is
    deleted, reopened, and stamped with the current VERSION.
    """
    global _db_connection
    if _db_connection is None:
        _db_connection = sqlite3.connect(CACHEDB)
        if DEBUG >= 5:
            _db_connection.set_trace_callback(print)
        # diskcache_get re-enters db_connection(); _db_connection is already set, so it just returns it.
        if diskcache_get("meta", "version") != VERSION:
            print("cache is out of date, clearing it")
            # Close before deleting: otherwise the old handle leaks, and the
            # unlink fails on platforms that refuse to delete open files.
            _db_connection.close()
            os.unlink(CACHEDB)
            _db_connection = sqlite3.connect(CACHEDB)
            if DEBUG >= 5:
                _db_connection.set_trace_callback(print)
            diskcache_put("meta", "version", VERSION)
    return _db_connection
def diskcache_get(table: str, key: Union[Dict, str, int]) -> Any:
    """Fetch and unpickle the value stored under ``key`` in ``table``.

    key: a column-name -> value dict, or a bare str/int, which is looked up in
        the single column "key".

    Returns None when the row is missing or the query raises
    sqlite3.OperationalError (e.g. the table doesn't exist).
    """
    if isinstance(key, (str, int)):
        key = {"key": key}
    # Table and column names are interpolated (sqlite cannot bind identifiers); values are bound.
    where = ' AND '.join([f'{x}=?' for x in key.keys()])
    try:
        cur = db_connection().cursor()
    except sqlite3.OperationalError:
        return None
    try:
        row = cur.execute(f"SELECT val FROM {table} WHERE {where}", tuple(key.values())).fetchone()
    except sqlite3.OperationalError:
        return None  # table doesn't exist
    finally:
        cur.close()
    if row is None:
        return None
    # NOTE(review): pickle.loads trusts the cache file; the default CACHEDB
    # lives under /tmp, so confirm nothing untrusted can write to that path.
    return pickle.loads(row[0])
_db_tables = set()  # tables this process has already issued CREATE TABLE IF NOT EXISTS for


def diskcache_put(table: str, key: Union[Dict, str, int], val: Any):
    """Pickle ``val`` and store it under ``key`` in ``table``, replacing any existing row.

    key: a column-name -> value dict, or a bare str/int, which is stored in the
        single column "key". Key values must be str, bool, int, float or bytes;
        any other type raises KeyError when the table is first created.

    Returns ``val`` unchanged.
    """
    if isinstance(key, (str, int)):
        key = {"key": key}
    conn = db_connection()
    cur = conn.cursor()
    try:
        columns = ', '.join(key.keys())
        if table not in _db_tables:
            TYPES = {str: "text", bool: "integer", int: "integer", float: "numeric", bytes: "blob"}
            ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
            cur.execute(f"CREATE TABLE IF NOT EXISTS {table} ({ltypes}, val blob, PRIMARY KEY ({columns}))")
            _db_tables.add(table)
        placeholders = ', '.join(['?'] * len(key))
        cur.execute(f"REPLACE INTO {table} ({columns}, val) VALUES ({placeholders}, ?)", tuple(key.values()) + (pickle.dumps(val),))
        conn.commit()
    finally:
        cur.close()  # release the cursor even when a statement raises
    return val
def diskcache(func):
    """Decorator that caches ``func`` results on disk through diskcache_get / diskcache_put.

    The table is named after the function. The key is the sha256 of the pickled
    ``(args, kwargs)``, so all arguments must be picklable. A ``None`` result
    cannot be told apart from a cache miss and is recomputed on every call.
    """
    @functools.wraps(func)  # also sets wrapper.__wrapped__ = func
    def wrapper(*args, **kwargs) -> bytes:
        table = f"cache_{func.__name__}"
        key = hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest()
        # Compare with None rather than truthiness so cached falsy results (0, b"", [], ...) are reused.
        if (ret := diskcache_get(table, key)) is not None:
            return ret
        return diskcache_put(table, key, func(*args, **kwargs))
    return wrapper