from __future__ import annotations
import math
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import prod, make_tuple, flatten
from tinygrad.nn import optim, state, datasets  # noqa: F401

class BatchNorm:
  """
  Applies Batch Normalization over a 2D or 3D input.

  - Described: https://paperswithcode.com/method/batch-normalization
  - Paper: https://arxiv.org/abs/1502.03167v3

  See: `Tensor.batchnorm`

  ```python exec="true" session="tensor"
  from tinygrad import Tensor, dtypes, nn
  import numpy as np
  np.set_printoptions(precision=4)
  ```

  ```python exec="true" source="above" session="tensor" result="python"
  norm = nn.BatchNorm(3)
  t = Tensor.rand(2, 3, 4, 4)
  print(t.mean().item(), t.std().item())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = norm(t)
  print(t.mean().item(), t.std().item())
  ```
  """
  def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
    self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
    self.weight:Tensor|None = Tensor.ones(sz) if affine else None
    self.bias:Tensor|None = Tensor.zeros(sz) if affine else None
    self.num_batches_tracked = Tensor.zeros(1, dtype='long' if is_dtype_supported(dtypes.long) else 'int', requires_grad=False)
    if track_running_stats: self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)

  def calc_stats(self, x:Tensor) -> tuple[Tensor, Tensor]:
    shape_mask:list[int] = [1, -1, *([1]*(x.ndim-2))]
    if self.track_running_stats and not Tensor.training: return self.running_mean, self.running_var.reshape(shape=shape_mask).expand(x.shape)
    # This requires two full memory accesses to x
    # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
    # There are "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
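    # rough sketch of Welford's single-pass update (not what this code does), for each streamed element xi:
    #   count += 1; delta = xi - mean; mean += delta/count; m2 += delta*(xi - mean)
    # then variance = m2/count, so mean and variance come out of one pass over x instead of two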
    batch_mean = x.mean(axis=(reduce_axes:=tuple(x for x in range(x.ndim) if x != 1)))
    y = (x - batch_mean.detach().reshape(shape=shape_mask))  # d(var)/d(mean) = 0
    batch_var = (y*y).mean(axis=reduce_axes)
    return batch_mean, batch_var

  def __call__(self, x:Tensor) -> Tensor:
    batch_mean, batch_var = self.calc_stats(x)
    # NOTE: wow, this is done all throughout training in most PyTorch models
    if self.track_running_stats and Tensor.training:
      self.running_mean.assign((1-self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
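      # the numel/(numel - C) factor below rescales the biased batch variance to the unbiased one (Bessel's correction with n = elements per channel)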
      self.running_var.assign((1-self.momentum) * self.running_var + self.momentum * x.numel()/(x.numel()-x.shape[1]) * batch_var.detach())
      self.num_batches_tracked += 1
    return x.batchnorm(self.weight, self.bias, batch_mean, batch_var.add(self.eps).rsqrt())
BatchNorm2d = BatchNorm3d = BatchNorm

def Conv1d(in_channels:int, out_channels:int, kernel_size:int, stride=1, padding:int|str=0, dilation=1, groups=1, bias=True) -> Conv2d:
  """
  Applies a 1D convolution over an input signal composed of several input planes.

  See: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d

  ```python exec="true" source="above" session="tensor" result="python"
  conv = nn.Conv1d(1, 1, 3)
  t = Tensor.rand(1, 1, 4)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = conv(t)
  print(t.numpy())
  ```
  """
  return Conv2d(in_channels, out_channels, (kernel_size,), stride, padding, dilation, groups, bias)

class Conv2d:
  """
  Applies a 2D convolution over an input signal composed of several input planes.

  See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d

  ```python exec="true" source="above" session="tensor" result="python"
  conv = nn.Conv2d(1, 1, 3)
  t = Tensor.rand(1, 1, 4, 4)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = conv(t)
  print(t.numpy())
  ```
  """
  def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...], stride=1, padding:int|tuple[int, ...]|str=0,
               dilation=1, groups=1, bias=True):
    self.kernel_size = make_tuple(kernel_size, 2)
    if isinstance(padding, str):
      if padding.lower() != 'same': raise ValueError(f"Invalid padding string {padding!r}, only 'same' is supported")
      if stride != 1: raise ValueError("padding='same' is not supported for strided convolutions")
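      # total 'same' padding per dim is dilation*(k-1), split as evenly as possible between the two sides;
      # kernel_size is reversed because the flat padding tuple is assumed to be ordered last dimension first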
      pad = [(d*(k-1)//2, d*(k-1) - d*(k-1)//2) for d,k in zip(make_tuple(dilation, len(self.kernel_size)), self.kernel_size[::-1])]
      padding = tuple(flatten(pad))
    self.stride, self.dilation, self.groups, self.padding = stride, dilation, groups, padding
    scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
    self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale)
    self.bias:Tensor|None = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None

  def __call__(self, x:Tensor) -> Tensor: return x.conv2d(self.weight, self.bias, self.groups, self.stride, self.dilation, self.padding)

def ConvTranspose1d(in_channels:int, out_channels:int, kernel_size:int, stride=1, padding=0, output_padding=0, dilation=1,
                    groups=1, bias=True) -> ConvTranspose2d:
  """
  Applies a 1D transposed convolution operator over an input signal composed of several input planes.

  See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d

  ```python exec="true" source="above" session="tensor" result="python"
  conv = nn.ConvTranspose1d(1, 1, 3)
  t = Tensor.rand(1, 1, 4)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = conv(t)
  print(t.numpy())
  ```
  """
  return ConvTranspose2d(in_channels, out_channels, (kernel_size,), stride, padding, output_padding, dilation, groups, bias)

class ConvTranspose2d(Conv2d):
  """
  Applies a 2D transposed convolution operator over an input image.

  See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d

  ```python exec="true" source="above" session="tensor" result="python"
  conv = nn.ConvTranspose2d(1, 1, 3)
  t = Tensor.rand(1, 1, 4, 4)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = conv(t)
  print(t.numpy())
  ```
  """
  def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...], stride=1, padding=0, output_padding=0,
               dilation=1, groups=1, bias=True):
    super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
    scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
    self.weight = Tensor.uniform(in_channels, out_channels//groups, *self.kernel_size, low=-scale, high=scale)
    self.output_padding = output_padding

  def __call__(self, x:Tensor) -> Tensor:
    return x.conv_transpose2d(self.weight, self.bias, self.groups, self.stride, self.dilation, self.padding, self.output_padding)

class Linear:
  """
  Applies a linear transformation to the incoming data.

  See: https://pytorch.org/docs/stable/generated/torch.nn.Linear

  ```python exec="true" source="above" session="tensor" result="python"
  lin = nn.Linear(3, 4)
  t = Tensor.rand(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = lin(t)
  print(t.numpy())
  ```
  """
  def __init__(self, in_features:int, out_features:int, bias=True):
    bound = 1 / math.sqrt(in_features)
    self.weight = Tensor.uniform(out_features, in_features, low=-bound, high=bound)
    self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None

  def __call__(self, x:Tensor) -> Tensor: return x.linear(self.weight.transpose(), self.bias)

class GroupNorm:
  """
  Applies Group Normalization over a mini-batch of inputs.

  - Described: https://paperswithcode.com/method/group-normalization
  - Paper: https://arxiv.org/abs/1803.08494v3

  ```python exec="true" source="above" session="tensor" result="python"
  norm = nn.GroupNorm(2, 12)
  t = Tensor.rand(2, 12, 4, 4) * 2 + 1
  print(t.mean().item(), t.std().item())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = norm(t)
  print(t.mean().item(), t.std().item())
  ```
  """
  def __init__(self, num_groups:int, num_channels:int, eps=1e-5, affine=True):
    self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
    self.weight:Tensor|None = Tensor.ones(num_channels) if affine else None
    self.bias:Tensor|None = Tensor.zeros(num_channels) if affine else None

  def __call__(self, x:Tensor) -> Tensor:
    # reshape for layernorm to work as group norm
    # subtract mean and divide stddev
    x = x.reshape(x.shape[0], self.num_groups, -1).layernorm(eps=self.eps).reshape(x.shape)
    if self.weight is None or self.bias is None: return x
    # elementwise_affine on channels
    return x * self.weight.reshape(1, -1, *[1]*(x.ndim-2)) + self.bias.reshape(1, -1, *[1]*(x.ndim-2))

class InstanceNorm:
  """
  Applies Instance Normalization over a mini-batch of inputs.

  - Described: https://paperswithcode.com/method/instance-normalization
  - Paper: https://arxiv.org/abs/1607.08022v3

  ```python exec="true" source="above" session="tensor" result="python"
  norm = nn.InstanceNorm(3)
  t = Tensor.rand(2, 3, 4, 4) * 2 + 1
  print(t.mean().item(), t.std().item())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = norm(t)
  print(t.mean().item(), t.std().item())
  ```
  """
  def __init__(self, num_features:int, eps=1e-5, affine=True):
    self.num_features, self.eps = num_features, eps
    self.weight:Tensor|None = Tensor.ones(num_features) if affine else None
    self.bias:Tensor|None = Tensor.zeros(num_features) if affine else None

  def __call__(self, x:Tensor) -> Tensor:
    x = x.reshape(x.shape[0], self.num_features, -1).layernorm(eps=self.eps).reshape(x.shape)
    if self.weight is None or self.bias is None: return x
    return x * self.weight.reshape(1, -1, *[1]*(x.ndim-2)) + self.bias.reshape(1, -1, *[1]*(x.ndim-2))

class LayerNorm:
  """
  Applies Layer Normalization over a mini-batch of inputs.

  - Described: https://paperswithcode.com/method/layer-normalization
  - Paper: https://arxiv.org/abs/1607.06450v1

  ```python exec="true" source="above" session="tensor" result="python"
  norm = nn.LayerNorm(3)
  t = Tensor.rand(2, 5, 3) * 2 + 1
  print(t.mean().item(), t.std().item())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = norm(t)
  print(t.mean().item(), t.std().item())
  ```
  """
  def __init__(self, normalized_shape:int|tuple[int, ...], eps=1e-5, elementwise_affine=True):
    self.normalized_shape:tuple[int, ...] = make_tuple(normalized_shape, 1)
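    # normalization runs over the trailing axes matching normalized_shape, e.g. (-1,) for a 1D shape or (-1, -2) for a 2D one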
    self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
    self.weight:Tensor|None = Tensor.ones(*self.normalized_shape) if elementwise_affine else None
    self.bias:Tensor|None = Tensor.zeros(*self.normalized_shape) if elementwise_affine else None

  def __call__(self, x:Tensor) -> Tensor:
    assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
    x = x.layernorm(eps=self.eps, axis=self.axis)
    if not self.elementwise_affine: return x
    return x * self.weight + self.bias

class LayerNorm2d(LayerNorm):
  """
  Applies Layer Normalization over a mini-batch of 2D inputs.

  See: `LayerNorm`

  ```python exec="true" source="above" session="tensor" result="python"
  norm = nn.LayerNorm2d(3)
  t = Tensor.rand(2, 3, 4, 4) * 2 + 1
  print(t.mean().item(), t.std().item())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = norm(t)
  print(t.mean().item(), t.std().item())
  ```
  """
  def __call__(self, x:Tensor) -> Tensor: return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

class RMSNorm:
  """
  Applies Root Mean Square Normalization to input.

  - Described: https://paperswithcode.com/method/rmsnorm
  - Paper: https://arxiv.org/abs/1910.07467

  ```python exec="true" source="above" session="tensor" result="python"
  norm = nn.RMSNorm(4)
  t = Tensor.arange(12, dtype=dtypes.float).reshape(3, 4)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(norm(t).numpy())
  ```
  """
  def __init__(self, dim:int, eps=1e-6): self.eps, self.weight = eps, Tensor.ones(dim)
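  # _norm divides x by sqrt(mean(x^2) + eps) over the last axis; __call__ does this in float32, casts back, then applies the learned scale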
  def _norm(self, x:Tensor) -> Tensor: return x * (x.square().mean(-1, keepdim=True) + self.eps).rsqrt()

  def __call__(self, x:Tensor) -> Tensor: return self._norm(x.float()).cast(x.dtype) * self.weight

class Embedding:
  """
  A simple lookup table that stores embeddings of a fixed dictionary and size.

  See: https://pytorch.org/docs/stable/generated/torch.nn.Embedding

  ```python exec="true" source="above" session="tensor" result="python"
  emb = nn.Embedding(10, 3)
  print(emb(Tensor([1, 2, 3, 1])).numpy())
  ```
  """
  def __init__(self, vocab_size:int, embed_size:int):
    self.vocab_sz, self.embed_sz, self.weight = vocab_size, embed_size, Tensor.glorot_uniform(vocab_size, embed_size)

  def __call__(self, idx:Tensor) -> Tensor:
    if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).unsqueeze(-1)
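    # the lookup is done as a one-hot reduction: broadcast arange against idx, compare, and sum the matches against the weight rather than gathering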
    big_shp = idx.shape + (self.vocab_sz, self.embed_sz)
    arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), self.weight.expand(big_shp)
    return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype)

class LSTMCell:
  """
  A long short-term memory (LSTM) cell.

  Args:
    input_size: The number of expected features in the input `x`
    hidden_size: The number of features in the hidden state `h`
    bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`
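
  A minimal usage sketch (the hidden and cell states default to zeros when `hc` is not passed):

  ```python exec="true" source="above" session="tensor" result="python"
  lstm = nn.LSTMCell(3, 4)
  x = Tensor.rand(2, 3)
  h, c = lstm(x)
  print(h.numpy())
  ```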
"""
def __init__ ( self , input_size : int , hidden_size : int , bias : bool = True ) :
stdv = 1.0 / math . sqrt ( hidden_size )
self . weight_ih = Tensor . uniform ( hidden_size * 4 , input_size , low = - stdv , high = stdv )
self . weight_hh = Tensor . uniform ( hidden_size * 4 , hidden_size , low = - stdv , high = stdv )
self . bias_ih : Tensor | None = Tensor . zeros ( hidden_size * 4 ) if bias else None
self . bias_hh : Tensor | None = Tensor . zeros ( hidden_size * 4 ) if bias else None
def __call__ ( self , x : Tensor , hc : tuple [ Tensor , Tensor ] | None = None ) - > tuple [ Tensor , Tensor ] :
if hc is None : hc = ( Tensor . zeros ( x . size ( 0 ) , self . weight_hh . size ( 1 ) , dtype = x . dtype , device = x . device ) , ) * 2
gates = x . linear ( self . weight_ih . T , self . bias_ih ) + hc [ 0 ] . linear ( self . weight_hh . T , self . bias_hh )
i , f , g , o = gates . chunk ( 4 , dim = 1 )
i , f , g , o = i . sigmoid ( ) , f . sigmoid ( ) , g . tanh ( ) , o . sigmoid ( )
new_c = f * hc [ 1 ] + i * g
new_h = o * new_c . tanh ( )
return ( new_h . contiguous ( ) , new_c . contiguous ( ) )