import unittest
from tinygrad import Tensor , TinyJit , Variable , dtypes
import numpy as np
class TestSetitem(unittest.TestCase):
    """Tests for assigning into a Tensor with ``t[index] = value``.

    Most tests perform the assignment on a Tensor and compare ``t.numpy()``
    with either the same assignment done on a NumPy array or a literal
    expected result.
    """

    def test_simple_setitem(self):
        # Each case is (target shape, index expression, value to assign).
        cases = (
            ((6, 6), (slice(2, 4), slice(3, 5)), Tensor.ones(2, 2)),
            ((6, 6), (slice(2, 4), slice(3, 5)), Tensor([1., 2.])),
            ((6, 6), (slice(2, 4), slice(3, 5)), 1.0),
            ((6, 6), (3, 4), 1.0),
            ((6, 6), (3, None, 4, None), 1.0),
            ((4, 4, 4, 4), (Ellipsis, slice(1, 3), slice(None)), Tensor(4)),
            ((4, 4, 4, 4), (Ellipsis, slice(1, 3)), 4),
            ((4, 4, 4, 4), (2, slice(1, 3), None, 1), 4),
            ((4, 4, 4, 4), (slice(1, 3), slice(None), slice(0, 4, 2)), 4),
            ((4, 4, 4, 4), (slice(1, 3), slice(None), slice(None), slice(0, 3)), 4),
            ((6, 6), (slice(1, 5, 2), slice(0, 5, 3)), 1.0),
            ((6, 6), (slice(5, 1, -2), slice(5, 0, -3)), 1.0),
        )
        for shp, slc, val in cases:
            # subTest names the failing (shape, index) pair and lets the
            # remaining cases run instead of aborting on the first failure.
            with self.subTest(shape=shp, index=slc):
                t = Tensor.zeros(shp).contiguous()
                t[slc] = val
                # Reference result: the same assignment on a NumPy array.
                n = np.zeros(shp)
                n[slc] = val.numpy() if isinstance(val, Tensor) else val
                np.testing.assert_allclose(t.numpy(), n)

    def test_setitem_into_unrealized(self):
        # No .contiguous()/.realize() is called on the target before assigning.
        t = Tensor.arange(4).reshape(2, 2)
        t[1] = 5
        np.testing.assert_allclose(t.numpy(), [[0, 1], [5, 5]])

    def test_setitem_dtype(self):
        # Assigning a float, int or bool scalar must not change the target dtype.
        for dt in (dtypes.int, dtypes.float, dtypes.bool):
            for v in (5., 5, True):
                with self.subTest(dtype=dt, value=v):
                    t = Tensor.ones(6, 6, dtype=dt).contiguous()
                    t[1] = v
                    self.assertEqual(t.dtype, dt)

    def test_setitem_into_noncontiguous(self):
        # The target is first asserted to be non-contiguous; assigning into it
        # is then expected to raise RuntimeError.
        t = Tensor.ones(4)
        self.assertFalse(t.lazydata.st.contiguous)
        with self.assertRaises(RuntimeError):
            t[1] = 5

    @unittest.skip(" TODO: flaky ")
    def test_setitem_inplace_operator(self):
        # Every augmented assignment starts from a fresh [[0, 1], [2, 3]] tensor
        # and only touches row 1.
        t = Tensor.arange(4).reshape(2, 2).contiguous()
        t[1] += 2
        np.testing.assert_allclose(t.numpy(), [[0, 1], [4, 5]])

        t = Tensor.arange(4).reshape(2, 2).contiguous()
        t[1] -= 1
        np.testing.assert_allclose(t.numpy(), [[0, 1], [1, 2]])

        t = Tensor.arange(4).reshape(2, 2).contiguous()
        t[1] *= 2
        np.testing.assert_allclose(t.numpy(), [[0, 1], [4, 6]])

        # NOTE: have to manually cast setitem target to least_upper_float for div
        t = Tensor.arange(4, dtype=dtypes.float).reshape(2, 2).contiguous()
        t[1] /= 2
        np.testing.assert_allclose(t.numpy(), [[0, 1], [1, 1.5]])

        t = Tensor.arange(4).reshape(2, 2).contiguous()
        t[1] **= 2
        np.testing.assert_allclose(t.numpy(), [[0, 1], [4, 9]])

        t = Tensor.arange(4).reshape(2, 2).contiguous()
        t[1] ^= 5
        np.testing.assert_allclose(t.numpy(), [[0, 1], [7, 6]])

    # History: this used to be marked expectedFailure because the second
    # in-place update raised "RuntimeError: can't double realize in one
    # schedule"; it has been passing since delete_forced_realize.
    def test_setitem_consecutive_inplace_operator(self):
        t = Tensor.arange(4).reshape(2, 2).contiguous()
        t[1] += 2
        t = t.contiguous()
        t[1] -= 1
        np.testing.assert_allclose(t.numpy(), [[0, 1], [3, 4]])

    def test_setitem_overlapping_indices(self):
        # When the same index is listed twice, the expected result keeps the
        # last value written to it.
        t = Tensor([1, 2, 3, 4])
        # regular overlapping indices
        t[[1, 1]] = Tensor([5, 6])
        np.testing.assert_allclose(t.numpy(), [1, 6, 3, 4])
        # overlapping indices where the overwritten (first) value is zero
        t[[1, 1]] = Tensor([0, 1])
        np.testing.assert_allclose(t.numpy(), [1, 1, 3, 4])

    def test_setitem_overlapping_indices_with_0(self):
        # Same as above, but the surviving (last) value is zero.
        t = Tensor([1, 2, 3, 4])
        t[[1, 1]] = Tensor([1, 0])
        np.testing.assert_allclose(t.numpy(), [1, 0, 3, 4])

    def test_setitem_with_1_in_shape(self):
        # Duplicate row index on a (3, 1) tensor; the last written row wins.
        t = Tensor([[1], [2], [3]])
        t[[0, 0]] = Tensor([[1], [2]])
        np.testing.assert_allclose(t.numpy(), [[2], [2], [3]])

    def test_fancy_setitem(self):
        # Paired list indices select elements (1, 3) and (2, 2), as in NumPy.
        t = Tensor.zeros(6, 6).contiguous()
        t[[1, 2], [3, 2]] = 3
        n = np.zeros((6, 6))
        n[[1, 2], [3, 2]] = 3
        np.testing.assert_allclose(t.numpy(), n)

    def test_simple_jit_setitem(self):
        @TinyJit
        def f(t: Tensor, a: Tensor):
            t[2:4, 3:5] = a

        # Call the jitted function five times, each with a fresh target and a
        # different fill value, and compare with the NumPy equivalent.
        for i in range(1, 6):
            t = Tensor.zeros(6, 6).contiguous().realize()
            a = Tensor.full((2, 2), fill_value=i, dtype=dtypes.float).contiguous()
            f(t, a)

            n = np.zeros((6, 6))
            n[2:4, 3:5] = np.full((2, 2), i)
            np.testing.assert_allclose(t.numpy(), n)

    def test_jit_setitem_variable_offset(self):
        @TinyJit
        def f(t: Tensor, a: Tensor, v: Variable):
            # Assign into the single row [v, v + 1) selected by the variable.
            t.shrink(((v, v + 1), None)).assign(a).realize()

        t = Tensor.zeros(6, 6).contiguous().realize()
        n = np.zeros((6, 6))

        # The same target is reused across calls: call i writes i + 1 into row i.
        for i in range(6):
            v = Variable("v", 0, 6).bind(i)
            a = Tensor.full((1, 6), fill_value=i + 1, dtype=dtypes.float).contiguous()
            n[i, :] = i + 1
            f(t, a, v)
            np.testing.assert_allclose(t.numpy(), n)

        # After all six calls, row r holds r + 1 in every column.
        expected = [[row_value] * 6 for row_value in range(1, 7)]
        np.testing.assert_allclose(t.numpy(), expected)

    def test_setitem_overlapping_inplace1(self):
        # Source and destination are overlapping slices of the same tensor; the
        # expected rows are the source values as they were before the write.
        t = Tensor([[3.0], [2.0], [1.0]]).contiguous()
        t[1:] = t[:-1]
        self.assertEqual(t.tolist(), [[3.0], [3.0], [2.0]])

    def test_setitem_overlapping_inplace2(self):
        # Same overlap, shifting the rows in the other direction.
        t = Tensor([[3.0], [2.0], [1.0]]).contiguous()
        t[:-1] = t[1:]
        self.assertEqual(t.tolist(), [[2.0], [1.0], [1.0]])
class TestWithGrad(unittest.TestCase):
    """Setitem when the destination or the assigned value has ``requires_grad`` set."""

    def test_no_requires_grad_works(self):
        # Smoke test: neither tensor requests gradients; the assignment is
        # simply executed with no assertion on the result.
        dst = Tensor.rand(8, 8)
        src = Tensor.rand(8)
        first_three_rows = slice(None, 3)
        dst[first_three_rows] = src

    def test_set_into_requires_grad(self):
        # Destination requests gradients -> NotImplementedError is expected.
        dst = Tensor.rand(8, 8, requires_grad=True)
        src = Tensor.rand(8)
        first_three_rows = slice(None, 3)
        with self.assertRaises(NotImplementedError):
            dst[first_three_rows] = src

    def test_set_with_requires_grad(self):
        # Assigned value requests gradients -> NotImplementedError is expected.
        dst = Tensor.rand(8, 8)
        src = Tensor.rand(8, requires_grad=True)
        first_three_rows = slice(None, 3)
        with self.assertRaises(NotImplementedError):
            dst[first_three_rows] = src
# Run the test suite when this file is executed directly as a script.
# The guard compares against the exact string "__main__" (no padding spaces),
# otherwise the comparison is never true and the tests never run.
if __name__ == "__main__":
    unittest.main()