import math
from typing import Tuple, Optional, cast
from tinygrad.helpers import argsort, DType
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
from tinygrad.tensor import Function
from tinygrad.lazy import LazyBuffer
from tinygrad.shape.symbolic import sint

class Contiguous(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output

class ContiguousBackward(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer: return x
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.contiguous()

class Cast(Function):
  def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer:
    self.input_dtype, self.bitcast = x.dtype, bitcast
    return x.cast(dtype, bitcast)
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.cast(self.input_dtype, self.bitcast)

# ************* unary ops *************
class Zero(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer: return x.const(0)
  def backward(self, grad:LazyBuffer) -> LazyBuffer: return grad.const(0)

class Neg(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer: return x.e(UnaryOps.NEG)
  def backward(self, grad:LazyBuffer) -> LazyBuffer: return grad.e(UnaryOps.NEG)

class Sin(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.x = x
    return x.e(UnaryOps.SIN)
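  # d/dx sin(x) = cos(x), and cos(x) = sin(pi/2 - x), so the derivative is expressed with the SIN primitive alone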
  def backward(self, grad:LazyBuffer) -> LazyBuffer:
    return self.x.const(math.pi / 2).e(BinaryOps.SUB, self.x).e(UnaryOps.SIN).e(BinaryOps.MUL, grad)

# NOTE: maximum(x, 0) behaves differently where x=0
class Relu(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.ret = x.e(BinaryOps.MAX, x.const(0))
    return self.ret
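  # relu'(x) is 1 where the output is positive and 0 elsewhere; (0 < ret) builds exactly that mask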
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).e(BinaryOps.MUL, grad_output)

class Log(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.x = x
    # ln(x) = log2(x) * ln(2)
    return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))
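  # d/dx ln(x) = 1/x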
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.e(BinaryOps.DIV, self.x)

class Exp(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    # exp(x) = exp2(x / ln(2))
    self.ret = x.e(BinaryOps.MUL, x.const(1 / math.log(2))).e(UnaryOps.EXP2)
    return self.ret
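  # d/dx exp(x) = exp(x), so backward reuses the saved forward result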
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return self.ret.e(BinaryOps.MUL, grad_output)

class Sqrt(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.ret = x.e(UnaryOps.SQRT)
    return self.ret
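  # d/dx sqrt(x) = 1 / (2 * sqrt(x))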
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.e(BinaryOps.DIV, self.ret.e(BinaryOps.MUL, self.ret.const(2)))

# NOTE: the implicit derivative of sigmoid is not stable
# https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
# TODO: have the backend automatically find this
class Sigmoid(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    # sigmoid(x) = 1 / (1 + exp(-x)), with exp(-x) computed as exp2(-x / ln(2))
    self.ret = x.const(1).e(BinaryOps.DIV, x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)))
    return self.ret
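  # d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))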
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return self.ret.e(BinaryOps.MUL, self.ret.const(1).e(BinaryOps.SUB, self.ret)).e(BinaryOps.MUL, grad_output)

# ************* binary ops *************
class Less(Function):
  # comparison is not differentiable, so no backward is defined
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    return x.e(BinaryOps.CMPLT, y)

class Add(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    return x.e(BinaryOps.ADD, y)
  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    return grad_output if self.needs_input_grad[0] else None, \
           grad_output if self.needs_input_grad[1] else None

class Sub(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    return x.e(BinaryOps.SUB, y)
  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    return grad_output if self.needs_input_grad[0] else None, \
           grad_output.e(UnaryOps.NEG) if self.needs_input_grad[1] else None

class Mul(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    self.x, self.y = x, y
    return x.e(BinaryOps.MUL, y)
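  # product rule: d(x*y)/dx = y and d(x*y)/dy = x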
  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    return self.y.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \
           self.x.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None

class Div(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    self.x, self.y = x, y
    return x.e(BinaryOps.DIV, y)
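  # quotient rule: d(x/y)/dx = 1/y and d(x/y)/dy = -x/y**2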
  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    return grad_output.e(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \
           grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None

# ************* ternary ops *************
class Where(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer:
    self.x = x
    return x.e(TernaryOps.WHERE, y, z)
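  # the boolean condition gets no gradient; grad_output is routed to y where x is true and to z elsewhere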
  def backward(self, grad_output:LazyBuffer) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]:
    return None, \
           self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) if self.needs_input_grad[1] else None, \
           self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) if self.needs_input_grad[2] else None

# ************* reduce ops *************
class Sum(Function):
  def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
    self.input_shape = x.shape
    return x.r(ReduceOps.SUM, new_shape)
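  # each input element contributes to the sum with weight 1, so its gradient is grad_output broadcast back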
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.expand(self.input_shape)

class Max(Function):
  def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
    self.x, self.ret = x, x.r(ReduceOps.MAX, new_shape)
    return self.ret
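  # e.g. x = [2, 5, 5] with max 5: the mask is [0, 1, 1] and each tied location receives grad_output / 2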
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    # 1s in locations where the max was chosen (can be two or more locations)
    max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape)))
    # split the gradient evenly among all tied maxima
    div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
    return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape))

# ************* movement ops *************
# NOTE: this is sum in reverse
class Expand(Function):
  def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
    self.input_shape = x.shape
    return x.expand(shape)
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.r(ReduceOps.SUM, self.input_shape)

class Reshape(Function):
  def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
    self.input_shape = x.shape
    return x.reshape(shape)
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.reshape(self.input_shape)

class Permute(Function):
  def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer:
    self.input_order = order
    return x.permute(order)
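  # argsort of a permutation is its inverse, e.g. order (2, 0, 1) inverts to (1, 2, 0)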
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.permute(argsort(self.input_order))

class Pad(Function):
  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
    self.narg = tuple([(p[0], s + p[0]) for s, p in zip(x.shape, arg)])
    return x.pad(arg)
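  # undo the pad by shrinking back to the original region, e.g. size 3 padded by (1, 2) is recovered with the window (1, 4)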
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.shrink(self.narg)

class Shrink(Function):
  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
    self.narg = tuple([(p[0], s - p[1]) for s, p in zip(x.shape, arg)])
    return x.shrink(arg)
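  # undo the shrink by padding the gradient back out, e.g. window (1, 4) on size 6 is undone by padding (1, 2)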
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    assert all(isinstance(x[0], int) and isinstance(x[1], int) for x in self.narg), "symbolic shrink does not support backward"
    # need this cast because mypy cannot narrow the type even with assert
    return grad_output.pad(cast(Tuple[Tuple[int, int], ...], self.narg))

class Flip(Function):
  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
    self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))])
    return x.stride(self.arg)
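  # a stride of -1 reverses an axis; reversing is its own inverse, so backward applies the same strides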
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.stride(self.arg)