# this will be the new test_ops for the next level
# schedule confirms the right things are capable of fusing
# NOTE: this has overlap with external_test_opt.py
import unittest
import numpy as np
import functools
from typing import List, Optional, Union, cast
from tinygrad import nn, dtypes, Device, Tensor
from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType, ImageDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, merge_views, GroupOp
from tinygrad.codegen.symbolic import symbolic_simple
from tinygrad.spec import type_verify, shape_spec
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, sym
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
from extra.models.llama import precompute_freqs_cis

remove_movement_ops = merge_views

def verify_ast(sink:UOp): return type_verify(list(sink.toposort), shape_spec)

class KernelCountException(Exception): pass
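# check_schedule is the main helper in this file: it builds the schedule for the given Tensor(s) or UOp sink,
# lowers every ScheduleItem to an ExecItem, optionally keeps only the items that compile to real kernels
# (filter_sink=True), and raises KernelCountException if the kernel count differs from `allowed`.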
def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_sink=True):
  if to_prerealize:
    for pre in to_prerealize: pre.schedule()
  if isinstance(t, Tensor): sched = t.schedule()
  elif isinstance(t, List) and isinstance(t[0], Tensor): sched = Tensor.schedule(*t)
  else:
    assert isinstance(t, UOp), f"can't schedule {t}"
    sched, _, __ = create_schedule_with_vars(t.sink())
  # test lowering all the ScheduleItems to ExecItems
  lowered = [x[1] for x in lower_schedule(sched.copy())]
  if filter_sink: sched = [s for s, ei in zip(sched, lowered) if isinstance(ei.prg, CompiledRunner)]
  if len(sched) != allowed:
    print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
    if DEBUG >= 3:
      for i, s in enumerate(sched):
        print("kernel", i+1)
        print(s.ast)
    raise KernelCountException(f"{len(sched)=} != {allowed}")
  return sched
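# realize every parameter of a module (or list of modules) up front so weight initialization
# does not show up in the schedules under test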
def _realize_weights(m):
  for p in nn.state.get_parameters(m): p.realize()
def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
  old_default_float, dtypes.default_float = dtypes.default_float, dtype
  dtypes.default_float = dtype
  Tensor.manual_seed(0)
  BS, CIN = 2, 3
  img = Tensor.randn(BS, CIN, 64, 64, requires_grad=True).realize()
  w = Tensor.uniform(16, CIN, 3, 3, requires_grad=True).realize()
  ret = Tensor.conv2d(img, w).relu().mean().backward()
  dtypes.default_float = old_default_float
  with Context(**kwargs): s = Tensor.schedule(ret, img.grad, w.grad)
  run_schedule(s.copy())
  cnt = len([si for si in s if si.ast.op is Ops.SINK])
  assert cnt == allowed, f"expected {allowed} kernels, got {cnt}"
  if getenv("CHECK", 1):
    import torch
    ref_img = torch.tensor(img.numpy(), requires_grad=True)
    ref_w = torch.tensor(w.numpy(), requires_grad=True)
    torch.nn.functional.conv2d(ref_img, ref_w).relu().mean().backward()
    assert ref_img.grad is not None and ref_w.grad is not None and img.grad is not None and w.grad is not None
    np.testing.assert_allclose(img.grad.numpy(), ref_img.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
    np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
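# helper: run the scheduler's movement-op merging and symbolic simplification rewrite on a SINK graph,
# with @track_rewrites(named=True) so the rewrite steps can be inspected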
@track_rewrites(named=True)
def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, {})
class TestSchedule(unittest.TestCase):
  @unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch")
  def test_error_on_device_mismatch(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10, device="CPU")
    c = a + b
    with self.assertRaises(RuntimeError): check_schedule(c, 1)

  @unittest.skipUnless(is_dtype_supported(dtypes.half) and getenv("CAST_AFTER_EXPAND"), "need half and CAST_AFTER_EXPAND=1")
  def test_expand_buffer_before_cast(self):
    a = Tensor.randn(4, 2, 1).realize().permute((1, 0, 2))
    b = a.cast(dtypes.half).expand((2, 4, 4)) + 2
    run_schedule(check_schedule(b, 1))
    np.testing.assert_allclose(b.numpy(), np.broadcast_to(a.numpy().astype(np.float16), (2, 4, 4)) + 2)

  def test_empty_is_not_realized(self):
    a = Tensor.empty(10)
    child = a + 2
    assert not a.lazydata.is_realized
    child.realize()
    assert a.lazydata.is_realized

  # NOTE: because empty does not have an ExecItem, if realize is called on a childless empty it never gets allocated.
  def test_childless_empty_never_allocates(self):
    a = Tensor.empty(10)
    a.realize()
    assert not a.lazydata.is_realized

  def test_simplify_padded_const(self):
    a = Tensor.empty(1022).cummax(axis=0)
    sched = check_schedule(a, 5)
    ast = sched[0].ast
    self.assertLessEqual(len([u for u in ast.toposort if u.op is Ops.WHERE]), 6)

  def test_basic_binop_fusion(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = Tensor.empty(10)
    d = a + b + c
    check_schedule(d, 1)

  def test_basic_binop_fusion_deep(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = Tensor.empty(10)
    d = Tensor.empty(10)
    e = a + b + c + d
    check_schedule(e, 1)

  def test_mulacc_fusion(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = (a * b).sum()
    check_schedule(c, 1)

  def test_mulacc_relu_fusion(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = (a * b).sum().relu()
    check_schedule(c, 1)

  def test_binop_reshape_fusion(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = Tensor.empty(5, 2)
    d = (a + b).reshape(5, 2) + c
    check_schedule(d, 1)

  def test_binop_permute_fusion(self):
    a = Tensor.empty(2, 5)
    b = Tensor.empty(2, 5)
    c = Tensor.empty(5, 2)
    d = (a + b).permute(1, 0) + c
    check_schedule(d, 1)

  def test_constants_are_embedded(self):
    a = Tensor.empty(3, 3) * 2
    check_schedule(a, 1, filter_sink=False)

  def tests_constants_are_folded(self):
    a = Tensor(2)
    check_schedule(a, 0)

  def test_constants_can_store(self):
    a = Tensor(2).contiguous()
    run_schedule(check_schedule(a, 1))
    np.testing.assert_equal(a.numpy(), 2)
  def test_binop_elu_fusion(self):
    a = Tensor.empty(10)
    b = a.elu()
    check_schedule(b, 1)

  def test_binop_reshape_reduce_fusion(self):
    a = Tensor.empty(100)
    b = Tensor.empty(100)
    c = (a + b).reshape(10, 10).sum(axis=0, keepdim=True)
    check_schedule(c, 1)

  def test_reduce_reshape_binop_fusion(self):
    a = Tensor.empty(10, 10)
    b = Tensor.empty(10)
    c = a.sum(axis=0) + b
    check_schedule(c, 1)

  # not pushing permutes through reduces
  def test_reduce_permute_binop_fusion(self):
    a = Tensor.empty(10, 10, 10)
    b = Tensor.empty(10, 10, 1)
    c = a.sum(axis=0, keepdim=True).permute(2, 1, 0) + b
    with self.assertRaises(KernelCountException): check_schedule(c, 1)

  def test_allow_push_permutes(self):
    a = Tensor.randn(10, 10, 10).realize()
    b = Tensor.randn(10, 10, 1).realize()
    c = a.sum(axis=0, keepdim=True).permute(2, 1, 0) + b
    with Context(DONT_GROUP_REDUCES=1): run_schedule(check_schedule(c, 1))
    np.testing.assert_allclose(c.numpy(), np.sum(a.numpy(), axis=0, keepdims=True).transpose(2, 1, 0) + b.numpy())

  def test_binop_early_reshape_reduce_fusion(self):
    a = Tensor.empty(100)
    b = Tensor.empty(100)
    c = Tensor.empty(10, 10)
    d = ((a + b).reshape(10, 10) + c).sum(axis=0)
    check_schedule(d, 1)

  def test_diamond_folded(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = Tensor.empty(10)
    d = Tensor.empty(10)
    ab = a + b
    e = (ab + c) + (ab + d)
    check_schedule(e, 1)

  def test_cache_binaryop(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = a + b
    d = a + b
    check_schedule(d, 0, [c])

  # failing in new lazy
  def test_cache_binaryop_reshaped(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = a + b
    d = a.reshape(10, 1) + b.reshape(10, 1)
    with self.assertRaises(KernelCountException): check_schedule(d, 0, [c])

  # failing in new lazy
  def test_cache_binaryop_transpose(self):
    a = Tensor.empty(10, 10)
    b = Tensor.empty(10, 10)
    c = (a.T * b.T).T  #.contiguous()
    d = a * b
    with self.assertRaises(KernelCountException): check_schedule(d, 0, [c])

  def test_cache_two_reduceops(self):
    a = Tensor.empty(10)
    b = a.sum()
    c = a.sum()
    bc = b + c
    check_schedule(bc, 1)

  def test_cache_reduce_parent(self):
    x = Tensor.empty(32)
    r0 = x.mean(axis=0, keepdim=True)
    r1 = (x - r0).sum(axis=0).div(2)
    out = r0 + r1
    schedule = check_schedule(out, 2)
    reduceops = [x for si in schedule for x in si.ast.toposort if x.op is Ops.REDUCE_AXIS]
    assert len(reduceops) == 2

  def test_cache_reduce_multiple_children(self):
    x = Tensor.empty(32)
    y = Tensor.empty(4, 4)
    r0 = x.mean(axis=0, keepdim=True)
    r1 = (x - r0).sum(axis=0).div(2)
    out0 = r0 + y
    out1 = r1 + y
    schedule = check_schedule([out0, out1], 4)
    reduceops = [x for si in schedule for x in si.ast.toposort if x.op is Ops.REDUCE_AXIS]
    assert len(reduceops) == 2

  def test_div_collapse_buffer(self):
    a = Tensor.full((4,), 4.0).contiguous().realize()
    b = Tensor.full((4,), 2.0).contiguous().realize()
    expr = (a * b) / b
    check_schedule(expr, 0)
    np.testing.assert_allclose(expr.numpy(), np.full((4,), 4.0))

  def test_div_collapse_const(self):
    a = Tensor.full((4,), 4.0).contiguous().realize()
    expr = a / a
    check_schedule(expr, 0)
    np.testing.assert_allclose(expr.numpy(), np.full((4,), 1.0))
  def test_div_collapse(self):
    a = Tensor.full((4,), 1.0).contiguous().realize()
    b = Tensor.full((4,), 2.0).contiguous().realize()
    c = Tensor.full((4,), 3.0).contiguous().realize()
    GlobalCounters.reset()
    expr = (a / b) / c
    expr.realize()
    self.assertEqual(GlobalCounters.kernel_count, 1)
    self.assertLessEqual(GlobalCounters.global_ops, 4*3)
    np.testing.assert_allclose(expr.numpy(), (a.numpy() / b.numpy()) / c.numpy())

  def test_dedup_assign(self):
    a = Tensor.ones(4).contiguous().realize()
    b = Tensor.full((4,), 2.).contiguous()
    first = a.assign(b)
    second = a.assign(b)
    check_schedule([first, second], 1)

  # NOTE: this is causing "LAZYCACHE=1 incorrectly reuses contiguous const" #4562
  # should contiguous dedup?
  def test_dedup_contiguous(self):
    a = Tensor.ones(4).contiguous()
    b = Tensor.ones(4).contiguous()
    sched = check_schedule([a, b], 1)
    run_schedule(sched)
    # a and b share the same underlying device memory
    self.assertIs(a.lazydata.realized, b.lazydata.realized)

  def test_copy_dedups(self):
    src = Tensor.ones(4).contiguous().realize()
    a = src.clone()
    b = src.clone()
    sched = check_schedule([a, b], 1, filter_sink=False)
    run_schedule(sched)
    # a and b are assigned to the same device Buffer
    self.assertIs(a.lazydata.realized, b.lazydata.realized)

  # EMPTY is assigned to a unique device Buffer
  def test_no_dedup_empty(self):
    a = Tensor.empty((4,))
    b = Tensor.empty((4,))
    # NOTE: empty does not have any schedule
    check_schedule([a, b], 0, filter_sink=False)
    self.assertIsNot(a.lazydata.buffer, b.lazydata.buffer)

  def test_dedup_outputs(self):
    a = Tensor.full((4, 4), 1.).contiguous().realize()
    b = Tensor.full((4, 4), 1.).contiguous().realize()
    check_schedule([a + b, a + b], 1)

  def test_fold_double_unary(self):
    y = Tensor.empty(2)
    out = y.sum(keepdim=True).sqrt().neg()
    check_schedule(out, 1)

  #@unittest.skip("may want to reconsider this")
  def test_fold_batchnorm(self):
    with Tensor.train():
      img = Tensor.empty(1, 32, 4, 4)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      out = bn(img)
      check_schedule(out, 3)

  def test_fold_conv_batchnorm_notrain(self):
    with Tensor.train(False):
      img = Tensor.empty(1, 3, 8, 8)
      c1 = nn.Conv2d(3, 32, 3)
      bn = nn.BatchNorm2d(32, track_running_stats=True)
      out = bn(c1(img)).relu()
      check_schedule(out, 1, [c1.weight, c1.bias])

  def test_fold_conv_batchnorm_notrain_no_running_stats(self):
    with Tensor.train(False):
      img = Tensor.empty(1, 3, 8, 8)
      c1 = nn.Conv2d(3, 32, 3)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      out = bn(c1(img)).relu()
      check_schedule(out, 4, [c1.weight, c1.bias])

  def test_fold_conv_batchnorm(self):
    with Tensor.train():
      img = Tensor.empty(1, 3, 8, 8)
      c1 = nn.Conv2d(3, 32, 3)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      out = bn(c1(img)).relu()
      check_schedule(out, 4, [c1.weight, c1.bias])

  @unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
  def test_fold_conv_batchnorm_optim(self):
    # this is too high
    for optim, cnt in [(nn.optim.Adam, 30), (nn.optim.SGD, 11)]:
      with self.subTest(optim=optim.__name__):
        with Tensor.train():
          img = Tensor.ones(1, 3, 4, 4)
          c1 = nn.Conv2d(3, 32, 3)
          bn = nn.BatchNorm2d(32, track_running_stats=False)
          _realize_weights([c1, bn])
          opt = optim(nn.state.get_parameters([c1, bn]))
          img_bn = bn(c1(img)).elu().sum()
          opt.zero_grad()
          img_bn.backward()
          check_schedule(opt.schedule_step(), cnt)

  def test_fold_batchnorm_backward(self):
    with Context(FUSE_CONV_BW=1):
      with Tensor.train():
        x = Tensor.empty((2, 16, 8, 8)).contiguous()
        bn = nn.BatchNorm2d(16)
        bn.weight.requires_grad = bn.bias.requires_grad = x.requires_grad = True
        fw = bn(x).contiguous_backward().relu().contiguous()
        fw.sum().backward()
        # TODO: this is too many
        check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 10)
  def test_fold_conv_relu(self):
    c1 = nn.Conv2d(3, 16, 3)
    # run
    img = Tensor.ones(2, 3, 64, 64)
    out = c1(img).relu()
    check_schedule(out, 1, [c1.weight, c1.bias])

  def test_fold_conv_relu_alt(self):
    img = Tensor.ones(1, 4, 8, 8)
    c1 = nn.Conv2d(4, 4, kernel_size=3)
    c2 = nn.Conv2d(4, 4, kernel_size=3)
    img_conv = img.sequential([c1, Tensor.relu, c2, Tensor.relu])
    check_schedule(img_conv, 2, [*nn.state.get_parameters(c1), *nn.state.get_parameters(c2), img])

  def test_fold_conv_relu_nobias(self):
    img = Tensor.ones(1, 4, 8, 8)
    c1 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
    c2 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
    out = img.sequential([c1, Tensor.relu, c2, Tensor.relu])
    check_schedule(out, 2, [c1.weight, c2.weight, img])

  def test_fold_conv_elu(self):
    c1 = nn.Conv2d(3, 16, 3)
    # run
    img = Tensor.rand(2, 3, 64, 64)
    out = c1(img).elu()
    check_schedule(out, 1, [c1.weight, c1.bias, img])

  def test_fold_conv_elu_alt(self):
    img = Tensor.ones(1, 4, 8, 8).contiguous()
    c1 = nn.Conv2d(4, 4, kernel_size=3)
    c2 = nn.Conv2d(4, 4, kernel_size=3)
    img_conv = img.sequential([c1, Tensor.elu, c2, Tensor.elu])
    check_schedule(img_conv, 2, [*nn.state.get_parameters(c1), *nn.state.get_parameters(c2), img])

  def test_two_sum(self):
    img = Tensor.empty(64, 64)
    x = (img.sum(0) + img.sum(1))
    out = x.relu()
    check_schedule(out, 2)

  #@unittest.skip("failing in old lazy")
  def test_push_permute_through_reshape(self):
    a = Tensor.empty(16, 16)
    b = Tensor.empty(16, 16)
    c = (a + b).reshape(4, 4, 4, 4).permute(2, 3, 0, 1).contiguous()
    check_schedule(c, 1)

  #@unittest.skip("failing in old lazy")
  def test_push_permute_through_reshape_alt(self):
    a = Tensor.empty(4, 4, 4, 4)
    b = Tensor.empty(4, 4, 4, 4)
    c = (a + b).reshape(16, 16).permute(1, 0).contiguous()
    check_schedule(c, 1)

  def test_no_binop_rerun(self):
    a = Tensor.empty(16)
    b = Tensor.empty(16)
    c = a + b
    d = (a + b).reshape(16, 1)
    check_schedule(d, 0, [c])

  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_multi_permute_should_collapse(self):
    a = Tensor.empty(4, 4, 4, 4)
    b = Tensor.empty(16)
    c = a.sum((0, 1)).cast(dtypes.float16).permute(1, 0).reshape(4, 4, 1).permute(1, 0, 2).reshape(16) + b
    check_schedule(c, 1)

  def test_fancy_reshape_fusion(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = a + b
    d = a.reshape(10, 1) + b.reshape(10, 1)
    out = c.sum() + d.sum()
    with self.assertRaises(KernelCountException): check_schedule(out, 1)

  def test_children_dont_push(self):
    a = Tensor.empty(10, 10, 1)
    b = Tensor.empty(10, 10, 1)
    d = (a + b).expand(10, 10, 10)
    e = (a + b).permute(2, 1, 0)
    f = d + e
    check_schedule(f, 2)

  # failing in new lazy
  def test_dont_fuse_binops_with_children(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = Tensor.empty(10)
    keep_me = a + b
    e = keep_me.sum()  # noqa: F841 give keep_me a child (NOTE: BinaryOps won't be a child since it will instant fuse)
    d = keep_me + c
    with self.assertRaises(KernelCountException): check_schedule(d, 2)
    with self.assertRaises(KernelCountException): check_schedule(keep_me, 0, [d])

  #@unittest.skip("failing in old lazy")
  def test_permute_breaks_fusion(self):
    a = Tensor.empty(10, 10, 10)
    b = Tensor.empty(10, 10)
    c = (a.sum(axis=2) + b).permute(1, 0)
    d = c.permute(1, 0)
    check_schedule(d, 1)

  def test_some_permute_fusion(self):
    a = Tensor.empty(8192, 16)
    b = Tensor.empty(1, 16)
    d = (a.T + b.expand(8192, 16).T)
    c = a + b.expand(8192, 16)
    e = d.T
    check_schedule(c, 1)
    check_schedule(e, 1)

  def test_shrink_fuse(self):
    a = Tensor.empty(8192, 16)
    b = Tensor.empty(8192, 16)
    c = a * b
    d = Tensor.empty(1, 16)
    e = c[0] * d
    check_schedule(e, 1)

  def test_expand_nofuse(self):
    a = Tensor.empty(1, 16)
    b = Tensor.empty(1, 16)
    c = a * b
    d = Tensor.empty(8192, 16)
    e = c * d
    check_schedule(e, 2)
  # this is the failing case in openpilot...it's very simple like this
  def test_image_conv_fusion(self):
    w1 = Tensor.empty(16, 16, 1, 1)
    b1 = Tensor.empty(16)
    w2 = Tensor.empty(16, 16, 1, 1)
    b2 = Tensor.empty(16)
    w3 = Tensor.empty(16, 16, 1, 1)
    b3 = Tensor.empty(16)
    x = Tensor.empty(1, 16, 32, 32)
    x = base = x.image_conv2d(w1, b1)
    x = x.image_conv2d(w2, b2) + base
    x = x.image_conv2d(w3, b3)
    # NOOP, 3 convs, contiguous
    with self.assertRaises(KernelCountException): check_schedule(x, 5)

  def test_image_conv_fusion_minimal(self):
    b1 = Tensor.empty(16)
    b2 = Tensor.empty(16)
    def p(x): return x.permute(1, 0).contiguous().reshape(32, 16, 1).expand(32, 16, 16).sum(axis=2).permute(1, 0)
    x = Tensor.empty(16, 32)
    x = base = p(x) + b1.reshape(16, 1)
    x = p(x)
    x = x + b2.reshape(16, 1)
    x = x + base
    del base
    x = p(x)
    check_schedule(x, 4)

  def test_image_conv_fusion_more_minimal(self):
    b1 = Tensor.empty(16)
    def p(x): return x.permute(1, 0).contiguous().reshape(32, 16, 1).expand(32, 16, 16).sum(axis=2).permute(1, 0)
    x = Tensor.empty(16, 32)
    x = base = p(x) + b1.reshape(16, 1)
    x = p(x)
    del base
    check_schedule(x, 3)

  def test_resnet_block(self):
    old_training = Tensor.training
    Tensor.training = False
    in_planes, planes = 64, 64
    conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
    bn1 = nn.BatchNorm2d(planes)
    conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False)
    bn2 = nn.BatchNorm2d(planes)
    x = Tensor.empty(1, 64, 32, 32)
    out = bn1(conv1(x)).relu()
    out = bn2(conv2(out))
    out = (out + x).relu()
    check_schedule(out, 2, [conv1.weight, conv2.weight])
    Tensor.training = old_training

  def test_contiguous_while_contiguous(self):
    x = Tensor.empty(1, 64, 32, 32)
    out = x.contiguous()
    check_schedule(out, 0, filter_sink=False)

  def test_contiguous_while_not_contiguous(self):
    x = Tensor.empty(1, 64, 32, 32)
    out = x.permute(0, 2, 3, 1).contiguous()
    check_schedule(out, 1, filter_sink=False)

  def test_fold_with_contiguous(self):
    a = Tensor.randn(16, 16, 16).realize()
    b = Tensor.randn(16, 16).realize()
    c = (a.sum(2).contiguous() + b).contiguous()
    check_schedule(c, 2)

  @unittest.skip("no longer supported")
  def test_double_from(self):
    x = Tensor([1, 2, 3, 4])
    out = x.to('python')
    check_schedule(out, 0, filter_sink=False)

  def _alu_from_tensor(self, t:Tensor):
    s = [s for s in t.schedule() if s.ast.op is Ops.SINK]
    self.assertEqual(len(s), 1)
    return [u.op for u in s[0].ast.toposort if u.op in GroupOp.ALU]

  def test_2_pow_is_exp2(self):
    t = 2.0 ** Tensor([1.0, 2.0, 3.0])
    self.assertEqual(self._alu_from_tensor(t), [Ops.EXP2])

  def test_pow_05_is_sqrt(self):
    t = Tensor([1.0, 2.0, 3.0]) ** 0.5
    self.assertEqual(self._alu_from_tensor(t), [Ops.SQRT])

  def test_pow_neg_05_is_rsqrt(self):
    t = Tensor([1.0, 2.0, 3.0]) ** -0.5
    self.assertEqual(self._alu_from_tensor(t), [Ops.RECIP, Ops.SQRT])

  def test_pow_2_has_1_mul(self):
    t = Tensor([1.0, 2.0, 3.0]) ** Tensor(2.0)
    self.assertEqual(self._alu_from_tensor(t), [Ops.MUL])

  def test_pow_8_has_3_muls(self):
    t = Tensor([1.0, 2.0, 3.0]) ** 8
    self.assertEqual(self._alu_from_tensor(t), [Ops.MUL, Ops.MUL, Ops.MUL])

  def test_pow_const_tensor_to_zero(self):
    x = Tensor([1, 2, 3, 4])
    out = x ** Tensor(0.0)
    # NOTE: this is ConstBuffer 0 + ConstBuffer 1
    check_schedule(out, 0)
  def test_zero_size(self):
    x = Tensor.empty(2, 3, 0)
    out = x + 1
    check_schedule(out, 0, filter_sink=False)

  def test_reduce_permute_nofuse(self):
    x = Tensor.empty(32, 32, 32)
    y = Tensor.empty(32, 32)
    out = x.sum(axis=2).T + y
    check_schedule(out, 2)

  def test_two_elus_sum(self):
    x = Tensor.empty(32, 32)
    y = Tensor.empty(32, 32)
    out = x.sum(1).relu().elu() + y.sum(1).relu().elu()
    check_schedule(out, 2)

  # multireduce spec
  @unittest.skipUnless(SPLIT_REDUCEOP, "Testing split reduceop requires SPLIT_REDUCEOP")
  def test_preserve_multistage_reduce(self):
    big_enough = getenv("REDUCEOP_SPLIT_THRESHOLD", 32768)
    x = Tensor.randn(big_enough).realize()
    out = (x - x.max(keepdim=True)).max()
    run_schedule(check_schedule(out, 4))
    np.testing.assert_allclose(out.numpy(), (x.numpy() - x.numpy().max(keepdims=True)).max())

  def test_multistage_reduce(self):
    x = Tensor.empty(32, 32, 32)
    out = x.sum(2).relu().sum(1)
    check_schedule(out, 2)

  def test_multistage_reduce_fork(self):
    x = Tensor.empty(32, 32, 32)
    x = x.sum(2)
    out2 = x + 1
    out = x.relu().sum(1) + out2[0]
    check_schedule(out, 2)

  # multireduce spec
  @unittest.skip("these two Tensors are the same")
  def test_example_matmul(self):
    x = Tensor.eye(64, requires_grad=True)
    y = Tensor.eye(64, requires_grad=True)
    z = y.matmul(x).sum()
    z.backward()
    out = x.grad.contiguous()
    run_schedule(check_schedule(out, 2))
    np.testing.assert_allclose(out.numpy(), np.ones((64, 64)))

  def test_example_matmul_contig(self):
    x = Tensor.eye(64, requires_grad=True).contiguous().realize()
    y = Tensor.eye(64, requires_grad=True).contiguous().realize()
    z = y.matmul(x).sum()
    z.backward()
    out = x.grad.contiguous()
    run_schedule(check_schedule(out, 2))
    np.testing.assert_allclose(out.numpy(), np.ones((64, 64)))

  def test_example_matmul_same(self):
    x = Tensor.eye(64, requires_grad=True)
    z = x.matmul(x).sum()
    z.backward()
    out = x.grad.contiguous()
    run_schedule(check_schedule(out, 2))
    # NOTE: the gradient flows twice
    np.testing.assert_allclose(out.numpy(), 2*np.ones((64, 64)))

  def test_contiguous_add(self):
    x = Tensor.empty(32)
    y = Tensor.empty(32)
    z = Tensor.empty(32)
    out = (x + y).contiguous() + z
    check_schedule(out, 2)

  def test_double_sum_ref(self):
    x = Tensor.empty(32, 32, 32)
    x = x.sum(2)
    out = x + x[:, 4]
    check_schedule(out, 2)

  def test_reduce_shrink(self):
    x = Tensor.empty(32, 32)
    y = Tensor.empty(16)
    x = x.sum(1)
    x = x[:16]
    out = x + y
    check_schedule(out, 2)  # TODO: this should be 1

  # multireduce spec
  def test_multireduce_shrink(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(32, 32).realize()
    b = Tensor.randn(32, 32).realize()
    c = Tensor.randn(16).realize()
    a_out = a.sum(1)
    a_out = a_out[:16]
    b_out = b.sum(1)
    b_out = b_out[:16]
    out = a_out + b_out + c
    # run_schedule(check_schedule(out, 2))  # TODO: this should be 1 (can we make it 1 with the new linearizer?)
    run_schedule(check_schedule(out, 3))
    np.testing.assert_allclose(out.numpy(), a.numpy().sum(axis=1)[:16] + b.numpy().sum(axis=1)[:16] + c.numpy(), atol=1e-4, rtol=1e-4)

  # broken due to const folding; the two contiguous calls become different kernels
  # NOTE: passes after delete_lazy
  def test_const_no_recompute(self):
    x = Tensor(2) + Tensor(2)
    y = Tensor(2) + Tensor(2)
    out = x.contiguous() + y.contiguous()
    check_schedule(out, 2, filter_sink=False)
  # multireduce spec
  @unittest.expectedFailure
  def test_reduce_same_size(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(4, 4).realize()
    out0 = a.sum() + 2
    out1 = a.sum() + 4
    out2 = out0 * out1
    run_schedule(check_schedule([out0, out1, out2], 1))
    np.testing.assert_allclose(out0.numpy(), out0_np := a.numpy().sum() + 2, atol=1e-4, rtol=1e-6)
    np.testing.assert_allclose(out1.numpy(), out1_np := a.numpy().sum() + 4, atol=1e-4, rtol=1e-6)
    np.testing.assert_allclose(out2.numpy(), out0_np * out1_np, atol=1e-4, rtol=1e-6)

  # multireduce spec
  @unittest.expectedFailure
  def test_reduce_multiple_paths(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(4, 4).realize()
    out0 = a.sum().exp2()
    # out1 has two paths to a.sum()
    out1 = a.sum() + out0
    run_schedule(check_schedule([out0, out1], 1))
    np.testing.assert_allclose(out0.numpy(), out0_np := np.exp2(a.numpy().sum()), atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(out1.numpy(), a.numpy().sum() + out0_np, atol=1e-4, rtol=1e-6)

  # multireduce spec
  def test_multireduce_reduce_multiple_paths(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(4, 4).realize()
    out0 = a.sum().exp2()
    out1 = a.sum() + out0
    b = (a + out0 + out1)
    out2 = b.sum().exp2()
    out3 = b.sum() + out2
    # run_schedule(check_schedule([out0, out1, out2, out3], 1))
    run_schedule(check_schedule([out0, out1, out2, out3], 6))
    np.testing.assert_allclose(out0.numpy(), np_out0 := np.exp2(a.numpy().sum()), atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(out1.numpy(), np_out1 := a.numpy().sum() + np_out0, atol=1e-4, rtol=1e-4)
    np_b = (a.numpy() + np_out0 + np_out1)
    np.testing.assert_allclose(out2.numpy(), np_out2 := np.exp2(np_b.sum()), atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(out3.numpy(), np_b.sum() + np_out2, atol=1e-4, rtol=1e-4)

  # multireduce spec
  def test_reduce_ext_reduce_child(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(4, 4).realize()
    b = Tensor.randn(4, 4).realize()
    # b.sum() is not a descendant of the fused nodes
    out0 = a.sum() + b.sum() + 2
    out1 = a.sum() + b.sum() + 4
    # run_schedule(check_schedule([out0, out1], 1))
    run_schedule(check_schedule([out0, out1], 4))
    np.testing.assert_allclose(out0.numpy(), a.numpy().sum() + b.numpy().sum() + 2, atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(out1.numpy(), a.numpy().sum() + b.numpy().sum() + 4, atol=1e-4, rtol=1e-4)

  # multireduce spec
  def test_reduce_multiple_paths_midreduce(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(4, 4).realize()
    r = a.sum()
    out0 = r.exp2()
    # reduce node in the indirect path from r to out2
    out1 = (a - out0).max()
    out2 = r + out1
    # run_schedule(check_schedule([r, out0, out1, out2], 1))
    run_schedule(check_schedule([r, out0, out1, out2], 4))
    np.testing.assert_allclose(r.numpy(), r_np := a.numpy().sum(), atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(out0.numpy(), out0_np := np.exp2(r_np), atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(out1.numpy(), out1_np := (a.numpy() - out0_np).max(), atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(out2.numpy(), r_np + out1_np, atol=1e-4, rtol=1e-4)

  # multireduce spec
  def test_reduce_multiple_paths_midreduce_fused(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(4, 4).realize()
    b = Tensor.randn(4, 4).realize()
    out0 = a.sum() + 4
    out1 = b.max() + out0*2
    out2 = a.sum() + out1
    # run_schedule(check_schedule([out0, out1, out2], 1))
    run_schedule(check_schedule([out0, out1, out2], 4))
    np.testing.assert_allclose(out0.numpy(), out0_np := a.numpy().sum() + 4, atol=1e-4, rtol=1e-6)
    np.testing.assert_allclose(out1.numpy(), out1_np := b.numpy().max() + out0_np*2, atol=1e-4, rtol=1e-6)
    np.testing.assert_allclose(out2.numpy(), a.numpy().sum() + out1_np, atol=1e-4, rtol=1e-6)

  # multireduce spec
  def test_reduce_multiple_paths_midexpand(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(4, 4).realize()
    b = Tensor.randn(4, 4, 4).realize()
    r = a.sum()
    out0 = r.exp2()
    # e1 is in the indirect path from a.sum() to out1
    e = b + out0
    out1 = r + e[0][0][0]
    # run_schedule(check_schedule([r, out0, out1, e], 3))  # 1 or 2 or 3? should be 1 (one reduce) but the different outputs might make it 3
    run_schedule(check_schedule([r, out0, out1, e], 4))
    np.testing.assert_allclose(r.numpy(), r_np := a.numpy().sum(), atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(out0.numpy(), out0_np := np.exp2(r_np), atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(e.numpy(), e_np := b.numpy() + out0_np, atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(out1.numpy(), r_np + e_np[0][0][0], atol=1e-4, rtol=1e-4)

  # changed by multireduce
  def test_reduce_expand_child(self):
    Tensor.manual_seed(0)
    a = Tensor.randn((32, 32, 32)).realize()
    b = Tensor.randn((1, 16)).realize()
    out0 = a.sum() + 2
    out1 = a.sum() + b
    # run_schedule(check_schedule([out0, out1], 2))
    run_schedule(check_schedule([out0, out1], 4))
    np.testing.assert_allclose(out0.numpy(), a.numpy().sum() + 2, atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(out1.numpy(), a.numpy().sum() + b.numpy(), atol=1e-4, rtol=1e-4)

  @unittest.expectedFailure
  def test_reduce_shrink_child(self):
    a = Tensor.empty(100, 100)
    b = Tensor.empty(10,)
    c = a.sum() + b[0]
    d = a.sum() + 2
    check_schedule([c, d], 1)

  def test_reduce_multiple_paths_midshrink(self):
    a = Tensor.empty(4, 4)
    r = a.sum(axis=1)
    out0 = r.exp2()
    out1 = out0[0] + out0
    check_schedule([r, out0, out1], 3)

  def test_reduce_shrink_output(self):
    a = Tensor.empty(4, 4)
    r = a.sum(keepdim=True)
    out0 = r.exp2()
    out1 = out0[0] + Tensor.empty(1,)
    check_schedule([r, out0, out1], 3)
  # multireduce spec
  def test_std_multireduce_fusion(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 32).realize()
    out = x.std(-1)
    run_schedule(check_schedule(out, 2))
    np.testing.assert_allclose(out.numpy(), x.numpy().std(axis=-1, ddof=1), atol=1e-4, rtol=1e-4)

  # multireduce spec
  def test_argmin_multireduce_fusion(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 32).realize()
    out = x.argmin(-1)
    run_schedule(check_schedule(out, 3))
    np.testing.assert_equal(out.numpy(), x.numpy().argmin(axis=-1))

  # multireduce spec
  def test_argmax_multireduce_fusion(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 32).realize()
    out = x.argmax(-1)
    run_schedule(check_schedule(out, 3))
    np.testing.assert_equal(out.numpy(), x.numpy().argmax(axis=-1))

  def test_scaled_dot_product_attention_multireduce_fusion(self):
    Tensor.manual_seed(0)
    q = Tensor.randn(32, 8, 16, 64).realize()
    k = Tensor.randn(32, 8, 16, 64).realize()
    v = Tensor.randn(32, 8, 16, 64).realize()
    out = Tensor.scaled_dot_product_attention(q, k, v)
    run_schedule(check_schedule(out, 5))
    if getenv("CHECK", 1):
      import torch
      compare = torch.nn.functional.scaled_dot_product_attention(torch.tensor(q.numpy()), torch.tensor(k.numpy()), torch.tensor(v.numpy()))
      np.testing.assert_allclose(out.numpy(), compare.numpy(), atol=1e-6, rtol=1e-3)

  # multireduce spec
  def test_ugly_reduceop_pairing(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(4, 32).realize()
    b = Tensor.randn(4, 32).realize()
    c = Tensor.randn(4, 32).realize()
    out = (c * a.sum(-1, keepdim=True)).sum(-1) + (b * a.sum(-1, keepdim=True)).sum(-1)  # a.sum has >1 children but should still fuse
    # run_schedule(check_schedule(out, 1))
    run_schedule(check_schedule(out, 3))
    np.testing.assert_allclose(out.numpy(), \
      (c.numpy() * a.numpy().sum(axis=-1, keepdims=True)).sum(-1) + (b.numpy() * a.numpy().sum(axis=-1, keepdims=True)).sum(-1), atol=1e-4, rtol=1e-4)

  # multireduce spec
  def test_reduce_expand_reduce_fusion(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(4, 32).realize()
    out = (a + a.sum(-1, keepdim=True)).sum(-1)
    # run_schedule(check_schedule(out, 1))
    run_schedule(check_schedule(out, 2))
    np.testing.assert_allclose(out.numpy(), (a.numpy() + a.numpy().sum(axis=-1, keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)

  # multireduce spec
  def test_reduce_expand_reduce_expand_fusion(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(4, 32).realize()
    out = a + (a + a.sum(-1, keepdim=True)).sum(-1, keepdim=True)
    # run_schedule(check_schedule(out, 2))
    run_schedule(check_schedule(out, 3))
    np.testing.assert_allclose(out.numpy(), \
      a.numpy() + (a.numpy() + a.numpy().sum(axis=-1, keepdims=True)).sum(axis=-1, keepdims=True), atol=1e-4, rtol=1e-4)

  # multireduce spec
  def test_branching_reduces_and_expands_fusion(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(4, 32).realize()
    out0 = a + a.sum(-1, keepdim=True)
    out1 = out0.sum(-1)
    # run_schedule(check_schedule(out, 2))
    run_schedule(check_schedule([out0, out1], 3))
    np.testing.assert_allclose(out0.numpy(), a.numpy() + a.numpy().sum(axis=-1, keepdims=True), atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(out1.numpy(), (a.numpy() + a.numpy().sum(axis=-1, keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)

  # multireduce spec
  def test_multireduce_fusion_simple_sequential(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 32).realize()
    y = Tensor.randn(4, 32).realize()
    out = (y + x.sum(axis=-1, keepdim=True)).sum(axis=-1)
    # run_schedule(check_schedule(out, 1))
    run_schedule(check_schedule(out, 2))
    np.testing.assert_allclose(out.numpy(), (y.numpy() + x.numpy().sum(axis=-1, keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)

  # multireduce spec
  def test_multireduce_fusion_simple_parallel(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 32).realize()
    y = Tensor.randn(4, 32).realize()
    out = y.sum(axis=-1) + x.sum(axis=-1)
    # run_schedule(check_schedule(out, 1))
    run_schedule(check_schedule(out, 2))
    np.testing.assert_allclose(out.numpy(), y.numpy().sum(axis=-1) + x.numpy().sum(axis=-1), atol=1e-4, rtol=1e-4)

  # multireduce spec
  def test_multireduce_fusion_sequential(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 32).realize()
    out = x.std(-1)
    # run_schedule(check_schedule(out, 1))
    run_schedule(check_schedule(out, 2))
    np.testing.assert_allclose(out.numpy(), x.numpy().std(axis=-1, ddof=1), atol=1e-4, rtol=1e-4)

  # multireduce spec
  def test_multireduce_fusion_parallel(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 32).realize()
    y = Tensor.randn(4, 32).realize()
    out = x.std(-1) + y.std(-1)
    # run_schedule(check_schedule(out, 1))
    run_schedule(check_schedule(out, 4))
    np.testing.assert_allclose(out.numpy(), x.numpy().std(axis=-1, ddof=1) + y.numpy().std(axis=-1, ddof=1), atol=1e-4, rtol=1e-4)

  # multireduce spec
  def test_multireduce_diffops_sequential(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 32).realize()
    out = (x - x.max(-1, keepdim=True)).sum(-1)
    # run_schedule(check_schedule(out, 1))
    run_schedule(check_schedule(out, 2))
    np.testing.assert_allclose(out.numpy(), (x.numpy() - x.numpy().max(axis=-1, keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)

  # multireduce spec
  def test_multireduce_fusion_diffops_parallel(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 32).realize()
    y = Tensor.randn(4, 32).realize()
    out = x.sum(-1) + y.max(-1)
    # run_schedule(check_schedule(out, 1))
    run_schedule(check_schedule(out, 2))
    np.testing.assert_allclose(out.numpy(), x.numpy().sum(axis=-1) + y.numpy().max(axis=-1), atol=1e-4, rtol=1e-4)

  # multireduce spec
  def test_multireduce_fusion_sequential_and_parallel(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 32).realize()
    y = Tensor.randn(4, 32).realize()
    mu = (x - x.max(axis=-1, keepdim=True)).mean(axis=-1, keepdim=True) + (y - y.max(axis=-1, keepdim=True)).mean(axis=-1, keepdim=True)
    out = [((x - mu).square().sum(-1)/x.shape[-1]).sqrt(), ((y - mu).square().sum(-1)/y.shape[-1]).sqrt()]
    np_mu = (x.numpy() - x.numpy().max(axis=-1, keepdims=True)).mean(axis=-1, keepdims=True) + \
      (y.numpy() - y.numpy().max(axis=-1, keepdims=True)).mean(axis=-1, keepdims=True)
    # run_schedule(check_schedule(out, 1))
    run_schedule(check_schedule(out, 6))
    np.testing.assert_allclose(out[0].numpy(), np.sqrt(np.square(x.numpy() - np_mu).sum(-1)/x.shape[-1]), atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(out[1].numpy(), np.sqrt(np.square(y.numpy() - np_mu).sum(-1)/y.shape[-1]), atol=1e-4, rtol=1e-4)

  # multireduce spec
  def test_multimatmul_fusion(self):
    Tensor.manual_seed(0)
    a, b = Tensor.randn(4, 64).realize(), Tensor.rand(64, 8).realize()
    c, d = Tensor.randn(4, 64).realize(), Tensor.rand(64, 8).realize()
    out = a@b + c@d
    # run_schedule(check_schedule(out, 1))
    run_schedule(check_schedule(out, 2))
    np.testing.assert_allclose(out.numpy(), a.numpy()@b.numpy() + c.numpy()@d.numpy(), atol=1e-4, rtol=1e-4)
  def test_softmax_fusion(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 12, 64, 64).realize()
    out = x.softmax()
    run_schedule(check_schedule(out, 3))
    expected = (x_exp := np.exp(x.numpy() - x.numpy().max(-1, keepdims=True)))/x_exp.sum(-1, keepdims=True)
    np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4)

  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_softmax_upcast(self):
    # input half, softmax in float
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.half).realize()
    out = x.softmax(dtype=dtypes.float)
    sched = out.schedule()
    self.assertEqual(len(sched), 3)
    self.assertEqual(sched[0].bufs[0].dtype, dtypes.half)

    # input float, softmax in float
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.float).realize()
    out = x.softmax(dtype=dtypes.float)
    sched = out.schedule()
    self.assertEqual(len(sched), 3)
    self.assertEqual(sched[0].bufs[0].dtype, dtypes.float)

  def test_softmax_backward(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 12, 64, 64, requires_grad=True).realize()
    x.softmax().sum().backward()
    run_schedule(check_schedule(x.grad, 4))

  # changed by: multireduce spec
  def test_layernorm_onelayer_fusion(self):
    Tensor.manual_seed(0)
    layer = nn.LayerNorm([10, 10])
    layer.weight = Tensor.randn(10, 10).realize()
    layer.bias = Tensor.randn(10, 10).realize()
    x = Tensor.randn(20, 5, 10, 10).realize()
    out = layer(x)
    # run_schedule(check_schedule(out, 2))
    run_schedule(check_schedule(out, 3))
    y = (x.numpy() - x.numpy().mean(layer.axis, keepdims=True))
    expected = y / np.sqrt((y*y).mean(layer.axis, keepdims=True) + layer.eps)
    np.testing.assert_allclose(out.numpy(), expected * layer.weight.numpy() + layer.bias.numpy(), atol=1e-4, rtol=1e-4)

  def test_scaled_dot_product_attention_fusion(self):
    x, y, z, m = (Tensor.empty(32, 8, 16, 16) for _ in range(4))
    out = Tensor.scaled_dot_product_attention(x, y, z, attn_mask=m)
    check_schedule(out, 5)

  def test_scaled_dot_product_attention_causal_fusion(self):
    x, y, z = (Tensor.empty(32, 8, 16, 16) for _ in range(3))
    out = Tensor.scaled_dot_product_attention(x, y, z, is_causal=True)
    check_schedule(out, 5)

  def test_adam_step_fusion(self):
    with Tensor.train():
      x = Tensor.empty(4, 64, 768)
      layer = nn.Linear(768, 768*4)
      _realize_weights(layer)
      opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4)
      layer(x).relu().sum().backward()
      check_schedule(opt.schedule_step(), 16)

  def test_adam_conv_fuse(self):
    with Tensor.train():
      img = Tensor.empty(2, 3, 4, 4)
      c1 = nn.Conv2d(3, 32, 3)
      _realize_weights(c1)
      opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4)
      opt.zero_grad()
      c1(img).relu().sum().backward()
      check_schedule(opt.schedule_step(), 16)

  def test_adam_2convs_fuse(self):
    with Tensor.train():
      img = Tensor.empty(2, 3, 4, 4)
      c1 = nn.Conv2d(3, 16, 3, bias=False)
      c2 = nn.Conv2d(16, 32, 2, bias=False)
      _realize_weights([c1, c2])
      opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
      opt.zero_grad()
      c2(c1(img).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 20)

  def test_sgd_conv_fuse(self):
    with Tensor.train():
      img = Tensor.empty(2, 3, 4, 4)
      c1 = nn.Conv2d(3, 32, 3)
      _realize_weights(c1)
      opt = nn.optim.SGD(nn.state.get_parameters(c1))
      opt.zero_grad()
      c1(img).relu().sum().backward()
      check_schedule(opt.schedule_step(), 3)

  def test_sgd_2convs_fuse(self):
    with Tensor.train():
      img = Tensor.empty(2, 3, 4, 4)
      c1 = nn.Conv2d(3, 16, 3, bias=False)
      c2 = nn.Conv2d(16, 32, 2, bias=False)
      _realize_weights([c1, c2])
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]))
      opt.zero_grad()
      c2(c1(img).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 7)

  @unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
  def test_fold_2convs_sgd_nesterov_momentum_wd(self):
    with Tensor.train():
      img = Tensor.empty(2, 3, 4, 4)
      c1 = nn.Conv2d(3, 16, 3, bias=False)
      c2 = nn.Conv2d(16, 32, 2, bias=False)
      _realize_weights([c1, c2])
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1)
      opt.zero_grad()
      c2(c1(img).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 13)

  def test_sgd_4convs_fuse(self):
    with Tensor.train():
      img = Tensor.empty(2, 3, 64, 64)
      c1 = nn.Conv2d(3, 4, 3, bias=False)
      c2 = nn.Conv2d(4, 8, 3, bias=False)
      c3 = nn.Conv2d(8, 16, 3, bias=False)
      c4 = nn.Conv2d(16, 32, 3, bias=False)
      _realize_weights([c1, c2, c3, c4])
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
      opt.zero_grad()
      c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 17)

  def test_sgd_4convs_fuse_conv_bw(self):
    with Tensor.train():
      img = Tensor.empty(2, 3, 64, 64)
      c1 = nn.Conv2d(3, 4, 3, bias=False)
      c2 = nn.Conv2d(4, 8, 3, bias=False)
      c3 = nn.Conv2d(8, 16, 3, bias=False)
      c4 = nn.Conv2d(16, 32, 3, bias=False)
      _realize_weights([c1, c2, c3, c4])
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
      opt.zero_grad()
      c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
      with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 14)
  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_prefer_half_buffer(self):
    x = Tensor.ones(4).contiguous().realize()
    # y = Tensor.ones(4).contiguous().realize()
    z = Tensor.ones(4, 4).contiguous().realize()
    # should not create extra kernel if output will be realized anyways
    dummy = x.sum().half().float()
    check_schedule(dummy, 1)
    dummy = x.sum().half().float().contiguous() + 1
    check_schedule(dummy, 2)
    # shared between two outputs
    shared = x.sum().half().float()
    a = shared * 2
    b = shared * 3
    sched = check_schedule([a, b], 3)
    # store reduceop in half
    self.assertEqual(sched[0].bufs[0].dtype, dtypes.half)
    # fuse cast with the child kernel
    self.assertEqual(sched[1].bufs[0].dtype, dtypes.float)
    self.assertEqual(sched[2].bufs[0].dtype, dtypes.float)
    # reduce
    a = z.sum(axis=0).half().float().sum(axis=0)
    sched = check_schedule(a, 2)
    self.assertEqual(sched[0].bufs[0].dtype, dtypes.half)
    self.assertEqual(sched[1].bufs[0].dtype, dtypes.float)
    # expand
    # expand will realize just after the .float(), so requires change to realize-before-expand
    # normal = (x.sum().half().float().reshape(1) * y).sum()
    # sched = check_schedule(normal, 2)
    # for si in sched[:-1]: assert all(out.dtype == dtypes.half for out in si.outputs[:-1])
    # parallel reduce
    # a = x.sum().half().float() * y.sum().half().float()
    # b = a + 1
    # c = a + 2
    # sched = check_schedule([b, c], 4)
    # doesn't store either in half because it doesn't chase

  def test_reduce_simple_chase(self):
    a = Tensor.empty(4, 4, 4)
    r = a.sum(0) + 6
    b = r.sum(0) * 4
    c = r.sum(1) * 2
    schedule = check_schedule([b, c], 3)
    self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD)

  # multireduce spec
  def test_multireduce_simple_chase(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(4, 4, 4).realize()
    r = (a + (a.sum(0, keepdim=True) + 6)).sum(0) * 2
    b = r.sum(0) + 8
    c = r.sum(1) + 12
    np_r = (a.numpy() + (a.numpy().sum(0) + 6)).sum(0) * 2
    # schedule = check_schedule([b, c], 3)
    # self.assertIs(schedule[0].ast[0].src[0].arg, Ops.MUL)
    schedule = check_schedule([b, c], 4)
    run_schedule(schedule)
    np.testing.assert_allclose(b.numpy(), np_r.sum(0) + 8, atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(c.numpy(), np_r.sum(1) + 12, atol=1e-4, rtol=1e-4)
  def test_push_permute_chase(self):
    a = Tensor.empty(4, 4, 4)
    b = Tensor.empty(4, 4)
    r = a.sum(2) + b
    d = r.T * 4
    e = r * d
    schedule = check_schedule([d, e], 3)
    self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD)

  # multireduce spec
  def test_multireduce_push_permute_chase(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(4, 4, 4).realize()
    b = Tensor.randn(4, 4).realize()
    r = a.sum(2) + b
    d = r.T * 4
    e = r * (d + a).sum(2)
    schedule = check_schedule([d, e], 3)  # make sure it doesn't fuse
    self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD)
    run_schedule(schedule)
    np.testing.assert_allclose(d.numpy(), (a.numpy().sum(2) + b.numpy()).T * 4, atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(e.numpy(), (a.numpy().sum(2) + b.numpy()) * (d.numpy() + a.numpy()).sum(2), atol=1e-4, rtol=1e-4)

  def test_push_shrink_chase(self):
    a = Tensor.empty(16, 16)
    b = Tensor.empty(4)
    c = Tensor.empty(16,)
    r = a.sum(1) + c
    d = r[:4] * b
    schedule = check_schedule(d, 2)
    self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD)

  # multireduce spec
  def test_multireduce_push_shrink_chase(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(16, 16).realize()
    b = Tensor.randn(4).realize()
    c = Tensor.randn(16,).realize()
    d = Tensor.randn(16, 16).realize()
    r = a.sum(1) + c
    out = r[:4] * b + d.sum(1)[:4]
    # schedule = check_schedule(out, 2)
    schedule = check_schedule(out, 3)
    self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD)
    run_schedule(schedule)
    np.testing.assert_allclose(out.numpy(), (a.numpy().sum(1) + c.numpy())[:4] * b.numpy() + d.numpy().sum(1)[:4], atol=1e-4, rtol=1e-4)

  def test_midreduce_nochase(self):
    a = Tensor.empty(16, 16)
    b = (a.sum(0) + a.max(1)) + 2
    schedule = check_schedule(b, 2)
    self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.REDUCE_AXIS)

  # multireduce spec
  def test_multireduce_midreduce_nochase(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(16, 16).realize()
    b = (a.sum(0) + a.max(0) + a.max(1) + a.sum(1)) + 2
    # schedule = check_schedule(b, 2)
    schedule = check_schedule(b, 4)
    self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.REDUCE_AXIS)
    run_schedule(schedule)
    np.testing.assert_allclose(b.numpy(), a.numpy().sum(0) + a.numpy().max(0) + a.numpy().max(1) + a.numpy().sum(1) + 2, atol=1e-4, rtol=1e-4)
  # changed by: multireduce spec
  # pattern in test_transformer
  def test_partial_fuse1(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(16, 16).realize()
    b = Tensor.randn(16, 16).realize()
    c = a.sum() + 2
    d = (a.sum() - b.sum()) * 4
    # run_schedule(check_schedule([c, d], 1))
    run_schedule(check_schedule([c, d], 3))
    np.testing.assert_allclose(c.numpy(), a.numpy().sum() + 2, atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(d.numpy(), (a.numpy().sum() - b.numpy().sum()) * 4, atol=1e-4, rtol=1e-4)

  # changed by: multireduce spec
  # pattern in conv
  def test_partial_fuse2(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(16, 16).realize()
    b = Tensor.randn(16, 16).realize()
    c = a.sum() + 2
    d = b.sum() - c
    # run_schedule(check_schedule([c, d], 1))
    run_schedule(check_schedule([c, d], 2))
    np.testing.assert_allclose(c.numpy(), a.numpy().sum() + 2, atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(d.numpy(), b.numpy().sum() - (a.numpy().sum() + 2), atol=1e-4, rtol=1e-4)

  # changed by: multireduce spec
  # pattern in adam
  @unittest.expectedFailure
  def test_partial_fuse3(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(16, 16).realize()
    b = Tensor.randn(16, 16).realize()
    c = a.sum() + 2
    d = a.sum() * 2
    e = c * d
    f = b.sum() - e
    # run_schedule(check_schedule([c, d, e, f], 1))
    run_schedule(check_schedule([c, d, e, f], 2))
    np.testing.assert_allclose(c.numpy(), c_np := a.numpy().sum() + 2, atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(d.numpy(), d_np := a.numpy().sum() * 2, atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(e.numpy(), e_np := c_np * d_np, atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(f.numpy(), b.numpy().sum() - e_np, atol=1e-4, rtol=1e-4)

  # changed by: multireduce spec
  @unittest.expectedFailure
  def test_partial_fuse4(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(16, 16).realize()
    b = Tensor.randn(16, 16).realize()
    c = a.sum() + 2
    d = a.sum() * 2
    e = c * d
    f = (b - d).sum() - e
    # run_schedule(check_schedule([c, d, e, f], 1))
    run_schedule(check_schedule([c, d, e, f], 3))
    np.testing.assert_allclose(c.numpy(), c_np := a.numpy().sum() + 2, atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(d.numpy(), d_np := a.numpy().sum() * 2, atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(e.numpy(), e_np := c_np * d_np, atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(f.numpy(), (b.numpy() - d_np).sum() - e_np, atol=1e-4, rtol=1e-4)
def test_pad_reduce_safe ( self ) :
Tensor . manual_seed ( 0 )
a = Tensor . rand ( 3 , 4 , 5 ) . realize ( )
b = Tensor . rand ( 3 , 4 , 5 ) . realize ( )
out = ( a + b ) . pad ( ( ( 0 , 1 ) , ( 0 , 1 ) , ( 0 , 1 ) ) , value = 1.0 ) . sum ( ) . contiguous ( )
run_schedule ( check_schedule ( out , 1 ) )
np . testing . assert_allclose ( out . numpy ( ) , np . pad ( a . numpy ( ) + b . numpy ( ) , ( ( 0 , 1 ) , ( 0 , 1 ) , ( 0 , 1 ) ) , constant_values = 1.0 ) . sum ( ) , atol = 1e-5 , rtol = 1e-6 )
# multireduce spec
def test_multireduce_pad_reduce_safe ( self ) :
Tensor . manual_seed ( 0 )
a = Tensor . randn ( 3 , 4 , 5 ) . realize ( )
b = Tensor . randn ( 3 , 4 , 5 ) . realize ( )
out = ( a . pad ( ( ( 0 , 1 ) , ( 0 , 1 ) , ( 0 , 1 ) ) , value = 1.0 ) . sum ( keepdim = True ) + b . pad ( ( ( 0 , 1 ) , ( 0 , 1 ) , ( 0 , 1 ) ) , value = 1.0 ) . sum ( ) ) . contiguous ( )
# run_schedule(check_schedule(out, 1))
run_schedule ( check_schedule ( out , 2 ) )
np . testing . assert_allclose ( out . numpy ( ) , np . pad ( a . numpy ( ) , ( ( 0 , 1 ) , ( 0 , 1 ) , ( 0 , 1 ) ) , constant_values = 1.0 ) . sum ( keepdims = True ) + \
np . pad ( b . numpy ( ) , ( ( 0 , 1 ) , ( 0 , 1 ) , ( 0 , 1 ) ) , constant_values = 1.0 ) . sum ( ) , atol = 1e-4 , rtol = 1e-4 )
def test_pad_reduce_unsafe ( self ) :
Tensor . manual_seed ( 0 )
a = Tensor . rand ( 3 , 4 , 5 ) . realize ( )
out = a . log2 ( ) . pad ( ( ( 0 , 1 ) , ( 0 , 1 ) , ( 0 , 1 ) ) , value = 1.0 ) . sum ( ) . contiguous ( )
run_schedule ( check_schedule ( out , 2 ) )
np . testing . assert_allclose ( out . numpy ( ) , np . pad ( np . log2 ( a . numpy ( ) ) , ( ( 0 , 1 ) , ( 0 , 1 ) , ( 0 , 1 ) ) , constant_values = 1.0 ) . sum ( ) , atol = 1e-5 , rtol = 1e-6 )
# multireduce spec
def test_multireduce_pad_reduce_unsafe ( self ) :
Tensor . manual_seed ( 0 )
a = Tensor . randn ( 3 , 4 , 5 ) . abs ( ) . realize ( )
b = Tensor . randn ( 3 , 4 , 5 ) . abs ( ) . realize ( )
out = ( a . log2 ( ) . pad ( ( ( 0 , 1 ) , ( 0 , 1 ) , ( 0 , 1 ) ) , value = 1.0 ) . sum ( ) + b ) . abs ( ) . log2 ( ) . pad ( ( ( 0 , 1 ) , ( 0 , 1 ) , ( 0 , 1 ) ) , value = 1.0 ) . sum ( ) . contiguous ( )
# run_schedule(check_schedule(out, 1))
run_schedule ( check_schedule ( out , 4 ) )
np . testing . assert_allclose ( out . numpy ( ) , np . pad ( np . log2 ( np . abs ( np . pad ( np . log2 ( a . numpy ( ) ) , ( ( 0 , 1 ) , ( 0 , 1 ) , ( 0 , 1 ) ) , constant_values = 1.0 ) . sum ( ) + \
b . numpy ( ) ) ) , ( ( 0 , 1 ) , ( 0 , 1 ) , ( 0 , 1 ) ) , constant_values = 1.0 ) . sum ( ) , atol = 3e-4 , rtol = 1e-6 )
def test_shrink_pad_safe ( self ) :
a = Tensor . ones ( ( 3 , ) ) . contiguous ( ) . realize ( )
b = Tensor . ones ( ( 3 , ) ) . contiguous ( ) . realize ( )
out = ( a + b ) . shrink ( ( ( 0 , 1 ) , ) ) . pad ( ( ( 0 , 1 ) , ) ) . contiguous ( )
run_schedule ( check_schedule ( out , 1 ) )
np . testing . assert_equal ( out . numpy ( ) , [ 2 , 0 ] )
def test_shrink_pad_unsafe ( self ) :
a = Tensor . ones ( ( 3 , ) ) . contiguous ( ) . realize ( )
out = a . exp2 ( ) . shrink ( ( ( 0 , 1 ) , ) ) . pad ( ( ( 0 , 1 ) , ) ) . contiguous ( )
run_schedule ( check_schedule ( out , 2 ) )
np . testing . assert_equal ( out . numpy ( ) , [ 2 , 0 ] )
def test_base_change_shrink_pad ( self ) :
a = Tensor . ones ( 3 , 3 ) . contiguous ( ) . realize ( )
b = a . exp2 ( )
c = b [ : - 1 , : - 1 ]
d = c . pad ( ( ( 0 , 1 ) , ( 0 , 1 ) ) ) * 2
run_schedule ( check_schedule ( d , 2 ) )
np . testing . assert_equal ( d . numpy ( ) , np . pad ( np . exp2 ( a . numpy ( ) ) [ : - 1 , : - 1 ] , ( ( 0 , 1 ) , ( 0 , 1 ) ) ) * 2 )
def test_base_change_expand_pad ( self ) :
a = Tensor . ones ( 3 , 3 ) . contiguous ( ) . realize ( )
b = a . exp2 ( )
c = b [ : , None , : ]
d = c . pad ( ( ( 0 , 0 ) , ( 1 , 1 ) , ( 0 , 0 ) ) ) * 2
run_schedule ( check_schedule ( d , 2 ) )
np . testing . assert_equal ( d . numpy ( ) , np . pad ( np . exp2 ( a . numpy ( ) ) [ : , None , : ] , ( ( 0 , 0 ) , ( 1 , 1 ) , ( 0 , 0 ) ) ) * 2 )
# TODO like openpilot with imagef
@unittest . skipUnless ( is_dtype_supported ( dtypes . half ) , " need half " )
def test_base_change_expand_expand ( self ) :
a = Tensor . ones ( 4 , 4 ) . contiguous ( ) . realize ( )
b = a . cast ( dtypes . half ) . expand ( 2 , 4 , 4 )
c = b . cast ( dtypes . int ) . expand ( 2 , 2 , 4 , 4 )
run_schedule ( check_schedule ( c , 2 ) )
np . testing . assert_equal ( c . numpy ( ) , np . ones ( ( ( 2 , 2 , 4 , 4 ) ) , dtype = np . int32 ) )
def test_base_change_pad_expand ( self ) :
a = Tensor . full ( ( 4 , 4 ) , 1. ) . contiguous ( ) . realize ( )
b = Tensor . full ( ( 4 , 4 ) , 2. ) . contiguous ( ) . realize ( )
c = ( a + b ) . pad ( ( ( 1 , 1 ) , ( 1 , 1 ) ) )
d = c . cast ( dtypes . int ) . expand ( ( 2 , 6 , 6 ) ) * 4
run_schedule ( check_schedule ( d , 2 ) )
c_np = np . pad ( ( np . full ( ( 4 , 4 ) , 2. , dtype = np . float32 ) + np . full ( ( 4 , 4 ) , 1. , dtype = np . float32 ) ) , ( ( 1 , 1 ) , ( 1 , 1 ) ) , constant_values = 0.0 )
np . testing . assert_equal ( d . numpy ( ) , np . broadcast_to ( c_np . astype ( np . half ) , ( 2 , * c_np . shape ) ) * 4 )
def test_pad_reduce_unsafe_multiview_st ( self ) :
P = Tensor . ones ( 3 , 3 ) . contiguous ( )
sums = P . sum ( axis = 1 , keepdim = True )
P /= sums
p = P [ 0 ]
p = p . pad ( ( ( 1 , 0 ) , ) )
p = p . repeat ( [ 2 ] )
run_schedule ( check_schedule ( p , 3 ) )
tiny_ret = p . numpy ( )
P = np . ones ( ( 3 , 3 ) , dtype = np . float32 )
sums = P . sum ( axis = 1 , keepdims = True )
P /= sums
p = P [ 0 ]
p = np . pad ( p , ( 1 , 0 ) , 'constant' )
p = np . tile ( p , 2 )
np . testing . assert_allclose ( tiny_ret , p )
def test_bitcast_fuses ( self ) :
x = cast ( UOp , Tensor . empty ( 1 , dtype = dtypes . float32 ) . realize ( ) . lazydata )
a = x . alu ( Ops . EXP2 ) . bitcast ( dtypes . int32 )
b = x . bitcast ( dtypes . int32 )
b = a . alu ( Ops . ADD , b )
check_schedule ( b , 1 ) # this should fuse when it makes sense
@unittest . skip ( " disabling subbuffer manually isn ' t supported anymore " )
def test_bitcast_disable_subbufer ( self ) :
x = cast ( UOp , Tensor . empty ( 1 , dtype = dtypes . float32 ) . realize ( ) . lazydata )
a = x . alu ( Ops . EXP2 ) . cast ( dtypes . int32 , True , allow_buffer_view = False )
b = x . cast ( dtypes . int32 , True , allow_buffer_view = False )
b = a . alu ( Ops . ADD , b )
check_schedule ( b , 1 )
def test_reduceop_reshape_dont_push ( self ) :
Tensor . manual_seed ( 0 )
x = Tensor . randn ( 10 , 20 ) . realize ( )
out = x . argmax ( 1 )
run_schedule ( check_schedule ( out , 3 ) ) # TODO: push a reduceop through a reshape
def test_conv2d ( self ) : _test_conv2d ( 7 )
def test_conv2d_fused ( self ) : _test_conv2d ( 6 , FUSE_CONV_BW = 1 )
@unittest . skipUnless ( is_dtype_supported ( dtypes . half ) and is_dtype_supported ( dtypes . ulong ) , " need half and ulong " )
def test_conv2d_half ( self ) : _test_conv2d ( 7 , dtype = dtypes . half )
@unittest . skipUnless ( is_dtype_supported ( dtypes . half ) , " need half " )
@unittest . skipIf ( Device . DEFAULT == " WEBGPU " , " Causes other tests to fail " )
@unittest . expectedFailure
def test_conv2d_fused_half ( self ) : _test_conv2d ( 5 , dtype = dtypes . half )
@unittest . skip ( " splitting kernels exceeding device buffer count is not yet supported " )
def _test_buf_cnt ( self , cnt : int , allowed : int ) :
#if (m:=BUF_LIMIT.get(Device.DEFAULT)) is None or m != 32: self.skipTest(f"test needs a buf_max of 32 {Device.DEFAULT}")
alu = functools . reduce ( lambda x , y : x + y , [ Tensor . ones ( ( 1 , 1 ) ) . contiguous ( ) . realize ( ) for _ in range ( cnt - 1 ) ] )
s = alu . schedule ( )
assert len ( s ) == allowed
run_schedule ( s )
expected = functools . reduce ( lambda x , y : x + y , [ np . ones ( ( 1 , 1 ) ) for _ in range ( cnt - 1 ) ] )
np . testing . assert_equal ( alu . numpy ( ) , expected )
def test_buf_cnt_at_limit ( self ) : self . _test_buf_cnt ( 31 , allowed = 1 )
@unittest . expectedFailure
def test_buf_cnt_over_limit ( self ) : self . _test_buf_cnt ( 32 , allowed = 2 )
@unittest . expectedFailure
def test_buf_cnt_over_limit_alt ( self ) : self . _test_buf_cnt ( 63 , allowed = 3 )
def test_schedule_mem_used ( self ) :
base = GlobalCounters . mem_used
Tensor . ones ( 256 ) . contiguous ( ) . realize ( )
Tensor . ones ( 5 , 5 ) . contiguous ( ) . schedule ( )
self . assertEqual ( GlobalCounters . mem_used - base , 0 )
@unittest . skip ( " TODO: this is consistently creating non reproducible failures " )
def test_schedule_mem_used_with_inputs ( self ) :
base = GlobalCounters . mem_used
x = Tensor . ones ( 256 ) . contiguous ( ) . realize ( )
( x + Tensor . ones ( 256 ) . contiguous ( ) ) . schedule ( )
self . assertEqual ( GlobalCounters . mem_used - base , 1024 )
def test_const_schedule ( self ) :
constv = Tensor . empty ( 2 , 2 ) . lazydata . const_like ( 10 )
check_schedule ( constv , 0 )
def test_const_schedule_contig ( self ) :
constv = Tensor . empty ( 2 , 2 ) . lazydata . const_like ( 10 ) . contiguous ( )
check_schedule ( constv , 1 )
@unittest . skipIf ( Device . DEFAULT != " GPU " , " image only supported on GPU " )
def test_image_matmul ( self ) :
with Context ( IMAGE = 2 ) :
x = Tensor . randn ( ( 9 , 9 ) ) . realize ( )
y = Tensor . randn ( ( 9 , 9 ) ) . realize ( )
out = x @y
run_schedule ( check_schedule ( out , 3 ) )
np . testing . assert_allclose ( out . numpy ( ) , x . numpy ( ) @y . numpy ( ) , atol = 1e-4 , rtol = 1e-4 )
self . assertIsInstance ( out . dtype , ImageDType )
self . assertIsNotNone ( out . lazydata . base . realized )
self . assertIsInstance ( out . lazydata . base . realized . dtype , ImageDType )
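# _test_fusion is a small harness for the late-fusion tests below: it realizes random inputs with DEBUG and
# TRACK_MATCH_STATS silenced, asserts f(*args) schedules into exactly cnt kernels, and, unless COMPARE=0,
# checks the numerics against the same function applied to torch tensors.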
def _test_fusion ( self , shapes , f , cnt ) :
with Context ( DEBUG = 0 , TRACK_MATCH_STATS = 0 ) : args = [ Tensor . randn ( s ) . realize ( ) for s in shapes ]
run_schedule ( check_schedule ( compare := f ( * args ) , cnt ) )
if getenv ( " COMPARE " , 1 ) :
import torch
good = f ( * [ torch . tensor ( x . numpy ( ) ) for x in args ] )
np . testing . assert_allclose ( compare . numpy ( ) , good . numpy ( ) , atol = 1e-4 , rtol = 1e-4 )
def test_late_fusion_simple ( self ) :
self . _test_fusion ( [ ( 4 , 4 ) , ( 4 , 1 ) ] , lambda a , b : a . sum ( 1 , keepdim = True ) + b , 1 )
def test_late_fusion_post_reshape ( self ) :
self . _test_fusion ( [ ( 4 , 4 ) , ( 1 , 4 ) ] , lambda a , b : a . sum ( 1 ) . reshape ( b . shape ) + b , 1 )
def test_late_fusion_post_permute ( self ) :
self . _test_fusion ( [ ( 4 , 6 , 4 ) , ( 4 , 4 , 1 ) ] , lambda a , b : a . sum ( 1 , keepdim = True ) . permute ( ( 2 , 0 , 1 ) ) + b , 2 )
def test_late_fusion_double_transpose ( self ) :
self . _test_fusion ( [ ( 32 , 16 , 1 ) ] ,
lambda a : ( a . expand ( 32 , 16 , 16 ) . sum ( ( 2 , ) , keepdim = True ) . permute ( ( 1 , 0 , 2 ) ) + 2 ) . permute ( ( 1 , 0 , 2 ) ) . contiguous ( ) , 1 )
def test_late_fusion_post_expand ( self ) :
self . _test_fusion ( [ ( 32 , 32 ) ] , lambda a : a - a . sum ( 1 ) , 2 )
def test_cast_padded_view ( self ) :
a = Tensor . arange ( 4 ) . reshape ( 1 , 4 )
casted_view = a . pad ( ( ( 0 , 1 ) , ( 0 , 0 ) ) ) . cast ( dtypes . float )
casted_view . realize ( )
self . assertEqual ( casted_view . lazydata . base . realized . size , 4 )
realized_view = casted_view . contiguous ( ) . realize ( )
self . assertEqual ( realized_view . lazydata . base . realized . size , 8 )
self . assertListEqual ( realized_view . tolist ( ) , [ [ 0.0 , 1.0 , 2.0 , 3.0 ] , [ 0.0 , 0.0 , 0.0 , 0.0 ] ] )
# NOTE: we only reorder CAST if it's an EXPAND
def test_cast_after_shrink ( self ) :
a = Tensor . arange ( 4 ) . reshape ( 1 , 4 )
casted_view = a . shrink ( ( ( 0 , 1 ) , ( 0 , 2 ) ) ) . cast ( dtypes . float )
casted_view . realize ( )
self . assertEqual ( casted_view . lazydata . base . realized . size , 2 )
realized_view = casted_view . contiguous ( ) . realize ( )
self . assertEqual ( realized_view . lazydata . base . realized . size , 2 )
self . assertListEqual ( realized_view . tolist ( ) , [ [ 0 , 1 ] ] )
def test_cast_const_view ( self ) :
a = Tensor . ones ( ( 4 , 4 ) , dtype = dtypes . float32 )
casted_view = a . cast ( dtypes . int32 )
run_schedule ( check_schedule ( casted_view , 0 ) )
self . assertIsNone ( casted_view . lazydata . base . realized )
realized_const_view = casted_view . contiguous ( )
run_schedule ( check_schedule ( realized_const_view , 1 ) )
self . assertListEqual ( realized_const_view . tolist ( ) , [ [ 1 , 1 , 1 , 1 ] , [ 1 , 1 , 1 , 1 ] , [ 1 , 1 , 1 , 1 ] , [ 1 , 1 , 1 , 1 ] ] )
def test_cast_padded_const ( self ) :
a = Tensor ( 1 , dtype = dtypes . int32 ) . reshape ( 1 , 1 ) . pad ( ( ( 1 , 1 ) , None ) )
casted_view = a . cast ( dtypes . float32 )
run_schedule ( check_schedule ( casted_view , 0 ) )
realized_const_view = casted_view . contiguous ( )
run_schedule ( check_schedule ( realized_const_view , 1 ) )
self . assertListEqual ( realized_const_view . tolist ( ) , [ [ 0 ] , [ 1 ] , [ 0 ] ] )
class TestIndexing ( unittest . TestCase ) :
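# TestIndexing uses its own check_schedule: it schedules under FUSE_ARANGE (default 1), lowers the schedule,
# counts only the CompiledRunner kernels (copies and views don't count), then runs every ExecItem so the
# outputs can still be compared against numpy.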
def check_schedule ( self , xt : Union [ Tensor , List [ Tensor ] ] , cnt : int ) :
with Context ( FUSE_ARANGE = getenv ( " FUSE_ARANGE " , 1 ) ) :
lst = [ xt ] if isinstance ( xt , Tensor ) else xt
s = Tensor . schedule ( * lst )
lowered = [ x [ 1 ] for x in lower_schedule ( s . copy ( ) ) ]
kernels = [ ei for ei in list ( lowered ) if isinstance ( ei . prg , CompiledRunner ) ]
if FUSE_ARANGE : self . assertEqual ( len ( kernels ) , cnt )
for ei in lowered : ei . run ( do_update_stats = True )
return s
def test_simple_indexing ( self ) :
X = Tensor . randn ( 10 , 10 ) . realize ( )
idxs = Tensor ( [ 0 , 2 ] ) . realize ( )
xt = X [ idxs ]
self . check_schedule ( xt , 2 )
np . testing . assert_equal ( xt . numpy ( ) , X . numpy ( ) [ idxs . numpy ( ) ] )
@unittest . skip ( " TODO: support pads in graph_rewrite " )
def test_simple_indexing_alt ( self ) :
X = Tensor . arange ( 16 ) . reshape ( 4 , 4 )
xt = X [ [ 1 , 2 ] , [ 1 , 2 ] ]
self . check_schedule ( xt , 3 )
np . testing . assert_equal ( xt . numpy ( ) , ( np . arange ( 16 ) . reshape ( 4 , 4 ) ) [ [ 1 , 2 ] , [ 1 , 2 ] ] )
def test_advanced_indexing ( self ) :
X = Tensor . arange ( 10 ) + 1
xt = X [ [ 0 ] ]
self . check_schedule ( xt , 2 )
np . testing . assert_equal ( xt . numpy ( ) , ( np . arange ( 10 ) + 1 ) [ [ 0 ] ] )
@unittest . expectedFailure
def test_advanced_indexing_alt ( self ) :
X = Tensor . arange ( 6 ) . reshape ( 3 , 2 ) + 1
xt = X [ [ Tensor ( [ 2 ] ) , Tensor ( [ 1 ] ) ] ]
self . check_schedule ( xt , 6 )
np . testing . assert_equal ( xt . numpy ( ) , 6 )
@unittest . skip ( " TODO: support pads in graph_rewrite " )
def test_advanced_simple_indexing_combined ( self ) :
X = Tensor . arange ( 16 ) . reshape ( 4 , 4 )
xt = X [ 1 : 2 , [ 1 , 2 ] ]
self . check_schedule ( xt , 2 )
def test_push_through_reshape ( self ) :
Tensor . manual_seed ( 0 )
x = Tensor . randn ( 10 , 20 ) . realize ( )
out = x . argmax ( 1 )
self . check_schedule ( out , 2 )
np . testing . assert_allclose ( out . numpy ( ) , np . argmax ( x . numpy ( ) , 1 ) )
def test_arange_push_through_expand ( self ) :
Tensor . manual_seed ( 0 )
a = Tensor . arange ( 4 , )
b = Tensor . randn ( 4 , 4 ) . realize ( )
out = a + b
self . check_schedule ( out , 1 )
np . testing . assert_allclose ( out . numpy ( ) , np . arange ( 4 ) + b . numpy ( ) )
def test_argmin ( self ) :
Tensor . manual_seed ( 0 )
x = Tensor . randn ( 4 , 32 ) . realize ( )
out = x . argmin ( - 1 )
self . check_schedule ( out , 2 )
np . testing . assert_equal ( out . numpy ( ) , x . numpy ( ) . argmin ( axis = - 1 ) )
def test_argmax ( self ) :
Tensor . manual_seed ( 0 )
x = Tensor . randn ( 4 , 32 ) . realize ( )
out = x . argmax ( - 1 )
self . check_schedule ( out , 2 )
np . testing . assert_equal ( out . numpy ( ) , x . numpy ( ) . argmax ( axis = - 1 ) )
def test_arange_transposed ( self ) :
Tensor . manual_seed ( 0 )
x = Tensor . randint ( 4 , 1 ) . realize ( )
a = ( Tensor . arange ( 4 , ) * x ) . T
self . check_schedule ( a , 1 )
np . testing . assert_equal ( a . numpy ( ) , ( np . arange ( 4 ) * x . numpy ( ) ) . T )
def test_arange_transposed_descendants ( self ) :
Tensor . manual_seed ( 0 )
x = Tensor . randint ( 4 , 1 ) . realize ( )
a = ( Tensor . arange ( 4 , ) * x ) . T
b = Tensor . randint ( 4 , 4 ) . realize ( )
out = a + b
self . check_schedule ( out , 1 )
np . testing . assert_equal ( out . numpy ( ) , ( np . arange ( 4 ) * x . numpy ( ) ) . T + b . numpy ( ) )
def test_arange_index ( self ) :
Tensor . manual_seed ( 0 )
x = Tensor . randn ( 5 , 2 ) . realize ( )
a = Tensor . arange ( 10 )
out = ( x + a [ 2 ] ) . sum ( )
self . check_schedule ( out , 2 )
np . testing . assert_allclose ( out . numpy ( ) , ( x . numpy ( ) + np . arange ( 10 ) [ 2 ] ) . sum ( ) , atol = 1e-5 , rtol = 1e-6 )
@unittest . skip ( " TOOD: FUSE_ARANGE overrules Tensor.arange().contiguous() " )
def test_arange_index_contiguous ( self ) :
Tensor . manual_seed ( 0 )
x = Tensor . randn ( 5 , 2 ) . realize ( )
a = Tensor . arange ( 10 ) . contiguous ( )
out = ( x + a [ 2 ] ) . sum ( )
self . check_schedule ( out , 3 )
np . testing . assert_allclose ( out . numpy ( ) , ( x . numpy ( ) + np . arange ( 10 ) [ 2 ] ) . sum ( ) , atol = 1e-5 , rtol = 1e-6 )
def test_arange_index_child ( self ) :
Tensor . manual_seed ( 0 )
x = Tensor . randn ( 5 , 2 ) . realize ( )
a = Tensor . arange ( 10 ) + 1
out = ( x + a [ 2 ] ) . sum ( )
self . check_schedule ( out , 2 )
np . testing . assert_allclose ( out . numpy ( ) , ( x . numpy ( ) + ( np . arange ( 10 ) + 1 ) [ 2 ] ) . sum ( ) , atol = 1e-5 , rtol = 1e-6 )
@unittest . skip ( " TOOD: FUSE_ARANGE overrules Tensor.arange().contiguous() " )
def test_arange_index_contiguous_child ( self ) :
Tensor . manual_seed ( 0 )
x = Tensor . randn ( 5 , 2 ) . realize ( )
a = ( Tensor . arange ( 10 ) + 1 ) . contiguous ( )
out = ( x + a [ 2 ] ) . sum ( )
self . check_schedule ( out , 3 )
np . testing . assert_allclose ( out . numpy ( ) , ( x . numpy ( ) + ( np . arange ( 10 ) + 1 ) [ 2 ] ) . sum ( ) , atol = 1e-5 , rtol = 1e-6 )
def test_arange_childless_base ( self ) :
a = Tensor . arange ( 4 )
self . check_schedule ( a , 1 )
np . testing . assert_equal ( a . numpy ( ) , np . arange ( 4 ) )
def test_arange_childless_view ( self ) :
a = Tensor . arange ( 4 ) . reshape ( 2 , 2 )
a [ 0 ] = 4
np . testing . assert_equal ( a . numpy ( ) , [ [ 4 , 4 ] , [ 2 , 3 ] ] )
def test_arange_group_childless_base ( self ) :
Tensor . manual_seed ( 0 )
x = Tensor . randint ( 4 ) . realize ( )
a = Tensor . arange ( 4 ) + x
self . check_schedule ( a , 1 )
np . testing . assert_equal ( a . numpy ( ) , np . arange ( 4 ) + x . numpy ( ) )
def test_arange_group_childless_view ( self ) :
Tensor . manual_seed ( 0 )
x = Tensor . ones ( 4 ) . contiguous ( ) . realize ( )
a = Tensor . arange ( 4 ) + x
a [ 0 ] = 6
np . testing . assert_equal ( a . numpy ( ) , [ 6. , 2. , 3. , 4. ] )
#@unittest.skipUnless(Device.DEFAULT in view_supported_devices, "need view")
@unittest . skip ( " BUFFER_VIEW no longer supported on non-disk devices " )
def test_arange_view_op ( self ) :
a = Tensor . arange ( 12 ) . reshape ( 4 , 3 ) . shrink ( ( ( 1 , 2 ) , ( 1 , 3 ) ) ) . contiguous ( )
sched = self . check_schedule ( a , 1 )
self . assertIs ( sched [ 1 ] . ast . op , Ops . BUFFER_VIEW )
np . testing . assert_equal ( a . numpy ( ) , [ [ 4 , 5 ] ] )
@unittest . skipIf ( Device . DEFAULT == " CPU " , " tests copy from ext device " )
def test_arange_shrink_copy ( self ) :
a = Tensor . arange ( 12 ) . reshape ( 4 , 3 ) . shrink ( ( ( 1 , 2 ) , ( 1 , 3 ) ) ) . to ( " CPU " )
sched = self . check_schedule ( a , 1 )
self . assertIs ( sched [ - 1 ] . ast . op , Ops . COPY )
np . testing . assert_equal ( a . numpy ( ) , [ [ 4 , 5 ] ] )
@unittest . skipIf ( Device . DEFAULT == " CPU " , " tests copy from ext device " )
def test_arange_expand_copy ( self ) :
a = Tensor . arange ( 4 ) . reshape ( 2 , 2 , 1 ) . expand ( 2 , 2 , 2 ) . contiguous ( ) . to ( " CPU " )
sched = self . check_schedule ( a , 1 )
self . assertIs ( sched [ 1 ] . ast . op , Ops . COPY )
self . assertIs ( sched [ 0 ] . ast . src [ 0 ] . src [ 2 ] . op , Ops . ADD )
np . testing . assert_equal ( a . numpy ( ) , [ [ [ 0 , 0 ] , [ 1 , 1 ] ] , [ [ 2 , 2 ] , [ 3 , 3 ] ] ] )
@unittest . skip ( " TODO: support pads in graph_rewrite " )
#@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_precompute_freqs_cis ( self ) :
args = { " dim " : 32 if CI else 128 , " end " : 2048 if CI else 8192 , " theta " : 10000 , " dtype " : dtypes . half }
fused = precompute_freqs_cis ( * * args )
self . check_schedule ( fused , 1 )
if getenv ( " CHECK " , 1 ) :
ref = precompute_freqs_cis ( **args )
run_schedule ( check_schedule ( ref , 3 ) )
np . testing . assert_equal ( fused . numpy ( ) , ref . numpy ( ) )
@unittest . skip ( " TOOD: FUSE_ARANGE overrules this contiguous " )
def test_fuse_assign_contiguous ( self ) :
x = Tensor . zeros ( 4 , 4 , dtype = dtypes . int ) . contiguous ( ) . realize ( )
a = Tensor . arange ( 8 ) . reshape ( 4 , 2 )
self . check_schedule ( x . shrink ( ( None , ( 0 , 2 ) ) ) . assign ( a . contiguous ( ) ) , 2 )
np . testing . assert_equal ( x . numpy ( ) , [ [ 0 , 1 , 0 , 0 ] , [ 2 , 3 , 0 , 0 ] , [ 4 , 5 , 0 , 0 ] , [ 6 , 7 , 0 , 0 ] ] )
def test_assign_non_contiguous ( self ) :
x = Tensor . zeros ( 4 , 4 , dtype = dtypes . int ) . contiguous ( ) . realize ( )
y = Tensor . randint ( 4 , 2 )
a = Tensor . arange ( 8 ) . reshape ( 4 , 2 ) + y
x . shrink ( ( None , ( 0 , 2 ) ) ) . assign ( a ) . realize ( )
xref = np . zeros ( ( 4 , 4 ) , dtype = int )
xref [ : , : 2 ] = np . arange ( 8 ) . reshape ( 4 , 2 ) + y . numpy ( )
np . testing . assert_equal ( x . numpy ( ) , xref )
def test_sparse_categorical_crossentropy_simple ( self ) :
X = Tensor ( [ [ 0 , 2 , 3 ] , [ 1 , 2 , 3 ] ] ) . realize ( )
Y = Tensor ( [ 1 , 2 ] ) . realize ( )
loss = X . sparse_categorical_crossentropy ( Y )
self . check_schedule ( loss , 4 )
np . testing . assert_allclose ( loss . item ( ) , 0.878309 , atol = 1e-5 , rtol = 1e-6 )
@unittest . skipIf ( Device . DEFAULT == " WEBGPU " , " Validation error on WebGPU " )
def test_mnist_val ( self ) :
from tinygrad . nn . datasets import mnist
import torch
_ , Y_train , _ , _ = mnist ( )
samples = Tensor . randint ( BS := getenv ( " BS " , 512 ) , high = cast ( int , Y_train . shape [ - 1 ] ) ) . realize ( )
yt = Tensor . randn ( BS , 10 ) . realize ( )
with Context ( SPLIT_REDUCEOP = 0 ) :
loss = yt . sparse_categorical_crossentropy ( Y_train [ samples ] )
self . check_schedule ( loss , 6 )
loss_fused = loss . numpy ( )
loss_ref = torch . nn . CrossEntropyLoss ( ) ( torch . tensor ( yt . numpy ( ) ) , torch . tensor ( Y_train . numpy ( ) ) [ torch . tensor ( samples . numpy ( ) ) ] )
np . testing . assert_allclose ( loss_fused , loss_ref . numpy ( ) , atol = 1e-6 , rtol = 1e-6 )
@unittest . expectedFailure
def test_arange_fuse_grouped_children ( self ) :
X = Tensor . randn ( 4 , 4 ) . realize ( )
r = ( X + Tensor . arange ( 16 ) . reshape ( 4 , 4 ) ) . sum ( )
out0 = r + 2
out1 = r + 3
self . check_schedule ( [ out0 , out1 ] , 1 )
r_ref = ( X . numpy ( ) + np . arange ( 16 ) . reshape ( 4 , 4 ) ) . sum ( )
np . testing . assert_allclose ( out0 . numpy ( ) , r_ref + 2 , rtol = 2e-7 )
np . testing . assert_allclose ( out1 . numpy ( ) , r_ref + 3 , rtol = 2e-7 )
@unittest . expectedFailure
def test_fold_arange_view ( self ) :
X = Tensor . randn ( 4 , 4 ) . realize ( )
r = ( X + Tensor . arange ( 16 ) . reshape ( 4 , 4 ) . contiguous ( ) ) . sum ( 1 , keepdim = True )
self . check_schedule ( [ r ] , 1 )
np . testing . assert_allclose ( r . numpy ( ) , ( X . numpy ( ) + np . arange ( 16 ) . reshape ( 4 , 4 ) ) . sum ( 1 , keepdims = True ) )
@unittest . skip ( " multi output isn ' t supported " )
def test_multiview_arange_children ( self ) :
X = Tensor . randn ( 2 , 3 , 4 , 4 ) . numpy ( )
with Context ( FUSE_ARANGE = 1 ) :
compare = Tensor ( X ) . interpolate ( size = ( 2 , 2 ) , mode = " linear " ) . numpy ( )
with Context ( FUSE_ARANGE = 0 , TRACK_MATCH_STATS = 0 ) :
ref = Tensor ( X ) . interpolate ( size = ( 2 , 2 ) , mode = " linear " ) . numpy ( )
np . testing . assert_allclose ( ref , compare , atol = 1e-5 , rtol = 1e-6 )
def test_recursive_swizzle ( self ) :
a = Tensor ( [ 1 , 2 , 3 , 4 ] ) . realize ( )
for _ in range ( 24 ) : a = a + a
new_uop = swizzle_rewrite ( a . lazydata . reshape ( ( 4 , 1 ) ) )
self . assertEqual ( new_uop . st , ShapeTracker . from_shape ( ( 4 , ) ) . reshape ( ( 4 , 1 ) ) )
self . assertEqual ( swizzle_cnt ( new_uop ) , 0 )
def test_no_rewrite_elementwise ( self ) :
a = Tensor . empty ( 32 , 32 )
b = Tensor . empty ( 32 , 32 )
sink = ( a + b ) . schedule ( ) [ 0 ] . ast
self . assertEqual ( swizzle_cnt ( sink ) , 0 )
def test_simple_store_reshape ( self ) :
a = Tensor . empty ( 32 , 32 ) . sum ( axis = 1 ) + Tensor . empty ( 1 , 32 )
ast = a . schedule ( ) [ 0 ] . ast
self . assertEqual ( ast . shape , ( 32 , 1 ) )
self . assertEqual ( a . lazydata . shape , ( 1 , 32 ) )
def test_no_reshape_reduceop ( self ) :
a = Tensor . empty ( 32 , 32 ) . sum ( axis = ( 1 , ) ) . contiguous ( )
ast = a . schedule ( ) [ 0 ] . ast
self . assertEqual ( ast . shape , ( 32 , 1 ) )
self . assertEqual ( a . lazydata . shape , ( 32 , ) )
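# helpers for the swizzle tests: swizzle_rewrite pushes VIEWs left and then right through the graph, and
# swizzle_cnt counts the VIEWs still sitting on a non-BUFFER source, i.e. the swizzles that did not resolve.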
@track_rewrites ( named = True )
def swizzle_rewrite ( u : UOp ) -> UOp : return graph_rewrite ( graph_rewrite ( u , view_left ) , view_right )
def swizzle_cnt ( u : UOp ) -> int : return len ( [ x for x in u . toposort if x . op is Ops . VIEW and len ( x . src ) != 0 and x . src [ 0 ] . op is not Ops . BUFFER ] )
class TestSwizzle ( unittest . TestCase ) :
def test_swizzle_simple ( self ) :
Tensor . manual_seed ( 0 )
with Context ( DEBUG = 0 , TRACK_MATCH_STATS = 0 ) :
a = Tensor . randint ( 32 , 32 ) . realize ( )
r = ( a + a ) . sum ( 1 ) . sum ( 0 )
# double reduce collapses to a single reduce
with Context ( DONT_GROUP_REDUCES = 1 ) :
run_schedule ( check_schedule ( r , 1 ) )
self . assertEqual ( r . numpy ( ) , ( a . numpy ( ) + a . numpy ( ) ) . sum ( 1 ) . sum ( 0 ) )
def test_single_swizzle ( self ) :
Tensor . manual_seed ( 0 )
with Context ( DEBUG = 0 , TRACK_MATCH_STATS = 0 ) :
a = Tensor . randint ( 4 , 1 ) . realize ( )
b = Tensor . ones ( ( 1 , 1 ) , dtype = a . dtype ) . contiguous ( ) . realize ( )
# ADD(REDUCE(RESHAPE(LOAD)), LOAD) becomes ADD(REDUCE(RESHAPE(LOAD)), RESHAPE(LOAD))
r = a . sum ( 0 ) + b
run_schedule ( check_schedule ( r , 1 ) )
self . assertEqual ( r . numpy ( ) , a . numpy ( ) . sum ( 0 ) + 1 )
def test_double_swizzle_possible ( self ) :
Tensor . manual_seed ( 0 )
with Context ( DEBUG = 0 , TRACK_MATCH_STATS = 0 ) :
a = Tensor . randint ( 4 , ) . realize ( )
b = Tensor . randint ( 4 , ) . realize ( )
# parallel reduce!
add = a . sum ( 0 ) + b . sum ( 0 )
with Context ( DONT_GROUP_REDUCES = 1 ) :
run_schedule ( check_schedule ( add , 1 ) )
self . assertEqual ( add . numpy ( ) , a . numpy ( ) . sum ( 0 ) + b . numpy ( ) . sum ( 0 ) )
@unittest . skip ( " TODO: how do we express the norm " )
def test_softmax_one_kernel ( self ) :
Tensor . manual_seed ( 0 )
with Context ( DEBUG = 0 , TRACK_MATCH_STATS = 0 ) :
a = Tensor . randn ( 32 , 32 ) . realize ( )
t = a . softmax ( )
with Context ( DONT_GROUP_REDUCES = 1 , DONT_REALIZE_EXPAND = 1 ) :
check_schedule ( t , 1 )
def test_argmax_one_kernel ( self ) :
Tensor . manual_seed ( 0 )
with Context ( DEBUG = 0 , TRACK_MATCH_STATS = 0 ) :
a = Tensor . randn ( 10 , 20 ) . realize ( )
t = a . argmax ( 0 )
with Context ( DONT_GROUP_REDUCES = 1 , DONT_REALIZE_EXPAND = 1 ) : t . realize ( )
def test_swizzle_reduceop ( self ) :
Tensor . manual_seed ( 0 )
x = Tensor . randn ( 4 , 4 ) . realize ( )
y = Tensor . randn ( 4 , 4 , 4 ) . realize ( )
out = x . reshape ( 4 , 4 , 1 ) . expand ( 4 , 4 , 4 ) . sum ( axis = ( 1 , ) ) + y
with Context ( DONT_REALIZE_EXPAND = 1 , DONT_GROUP_REDUCES = 1 ) :
run_schedule ( check_schedule ( out , 1 ) )
np . testing . assert_allclose ( out . numpy ( ) , np . tile ( x . numpy ( ) . reshape ( 4 , 4 , 1 ) , ( 1 , 1 , 4 ) ) . sum ( axis = 1 ) + y . numpy ( ) )
def test_permute_rewrite ( self ) :
x = Tensor . randn ( 4 , 4 , 16 ) . realize ( )
y = Tensor . randn ( 4 , 1 , 16 ) . realize ( )
z = Tensor . randn ( 4 , 4 , 1 ) . realize ( )
t = ( x * y ) . sum ( axis = ( 0 , 2 ) ) . reshape ( 1 , 4 , 1 ) . permute ( 0 , 2 , 1 ) + z
with Context ( DONT_GROUP_REDUCES = 1 , DONT_REALIZE_EXPAND = 1 ) : run_schedule ( check_schedule ( t , 1 ) )
t_np = ( x . numpy ( ) * y . numpy ( ) ) . sum ( axis = ( 0 , 2 ) ) . reshape ( 1 , 4 , 1 ) . transpose ( 0 , 2 , 1 ) + z . numpy ( )
np . testing . assert_allclose ( t . numpy ( ) , t_np , atol = 1e-6 , rtol = 1e-3 )
@unittest . skip ( " TODO: this swizzle isn ' t resolvable when there ' s a mask " )
def test_swizzle_failure_permute ( self ) :
a = Tensor . empty ( 45 , 65 ) . T . reshape ( 65 , 1 , 45 ) . pad ( ( None , None , ( 0 , 45 ) ) ) . expand ( 65 , 45 , 90 )
b = Tensor . empty ( 45 , 65 )
a_reduce = a . sum ( axis = ( 2 , ) , keepdim = True ) . sum ( axis = ( 1 , ) )
b_reduce = b . sum ( axis = ( 0 , ) )
t = a_reduce + b_reduce
with Context ( DONT_GROUP_REDUCES = 1 , DONT_REALIZE_EXPAND = 1 ) : run_schedule ( check_schedule ( t , 1 ) )
def test_parallel_reduce_possible ( self ) :
Tensor . manual_seed ( 0 )
x = Tensor . randn ( 4 , 2 , 2 ) . realize ( )
y = Tensor . randn ( 4 , 2 , 2 ) . realize ( )
t = x . sum ( axis = 1 ) + y . sum ( axis = 1 )
with Context ( DONT_GROUP_REDUCES = 1 ) : run_schedule ( check_schedule ( t , 1 ) )
np . testing . assert_allclose ( t . numpy ( ) , x . numpy ( ) . sum ( axis = 1 ) + y . numpy ( ) . sum ( axis = 1 ) , atol = 1e-6 , rtol = 1e-3 )
# kernels can only have 1 or n in each dim
@unittest . expectedFailure
def test_dont_parallelize_different_n ( self ) :
Tensor . manual_seed ( 0 )
x = Tensor . randn ( 4 , 2 , 2 ) . realize ( )
y = Tensor . randn ( 4 , 3 , 2 ) . realize ( )
t = x . sum ( axis = 1 ) + y . sum ( axis = 1 )
with Context ( DONT_GROUP_REDUCES = 1 ) : run_schedule ( check_schedule ( t , 1 ) )
np . testing . assert_allclose ( t . numpy ( ) , x . numpy ( ) . sum ( axis = 1 ) + y . numpy ( ) . sum ( axis = 1 ) , atol = 1e-6 , rtol = 1e-3 )
def test_unsafe_pad ( self ) :
x = Tensor . full ( ( 2 , 2 ) , 1.0 ) . contiguous ( )
y = x * x . sum ( ( 1 , ) ) . reciprocal ( )
t = y . pad ( ( ( 0 , 1 ) , None ) ) . contiguous ( )
swizzled = swizzle_rewrite ( t . lazydata )
sched = check_schedule ( swizzled . sink ( ) , 3 )
output_buffer = sched [ - 1 ] . bufs [ 0 ]
run_schedule ( sched )
self . assertListEqual ( output_buffer . as_buffer ( ) . cast ( " f " ) . tolist ( ) , [ 0.5 , 0.5 , 0.5 , 0.5 , 0. , 0. ] )
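# store_val extracts the value UOp of a kernel's STORE (ast.src[0] is the STORE, src[2] is what gets stored);
# zero_pm matches a literal CONST 0. TestView uses them to assert that fully masked views collapse to zero.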
def store_val ( si : ScheduleItem ) : return si . ast . src [ 0 ] . src [ 2 ]
zero_pm = UPat ( Ops . CONST , arg = 0 )
class TestView ( unittest . TestCase ) :
def test_all_masked_out ( self ) :
# start with non CONST Ops
a = Tensor . rand ( 10 , 10 ) . realize ( )
# all masked out, degrades to const 0
b = a . pad ( ( ( 0 , 10 ) , None ) ) [ 10 : ]
sched = check_schedule ( b . contiguous ( ) , 1 )
assert zero_pm . match ( store_val ( sched [ - 1 ] ) , { } )
run_schedule ( sched )
np . testing . assert_equal ( b . numpy ( ) , 0 )
def test_mask_dim_1 ( self ) :
# mask out dim = 1 works too
a = Tensor . rand ( 10 , 10 ) . realize ( )
b = a . pad ( ( None , ( 0 , 10 ) ) ) [ : , 10 : ]
assert b . shape == ( 10 , 10 )
sched = check_schedule ( b . contiguous ( ) , 1 )
self . assertEqual ( sched [ - 1 ] . ast . full_shape , ( 10 , 10 ) )
assert zero_pm . match ( store_val ( sched [ - 1 ] ) , { } )
run_schedule ( sched )
np . testing . assert_equal ( b . numpy ( ) , 0 )
def test_zero_size_alt ( self ) :
a = Tensor . empty ( 135 , 0 , 9 )
b = a . pad ( ( ( 0 , 0 ) , ( 0 , 0 ) , ( 18 , 0 ) ) )
check_schedule ( b , 0 )
def test_partial_mask ( self ) :
# partial masked out does not degrade into CONST
a = Tensor . rand ( 10 , 10 ) . realize ( )
b = a . pad ( ( ( 0 , 5 ) , None ) ) [ 5 : ]
assert b . shape == ( 10 , 10 )
sched = check_schedule ( b . contiguous ( ) , 1 )
self . assertEqual ( store_val ( sched [ - 1 ] ) . op , Ops . LOAD )
self . assertEqual ( store_val ( sched [ - 1 ] ) . st_arg , b . lazydata . st )
run_schedule ( sched )
np . testing . assert_allclose ( b . numpy ( ) , np . pad ( a . numpy ( ) , ( ( 0 , 5 ) , ( 0 , 0 ) ) ) [ 5 : ] )
# a*VIEW(x), where VIEW(x) = 0
# x collapses along with its children
def test_parent_view_collapses ( self ) :
a = Tensor ( [ 1 , 2 ] )
b = Tensor . arange ( 3 ) . contiguous ( )
bv = b . pad ( ( ( 0 , 2 ) , ) ) [ - 2 : ]
# this becomes a late a*0
late_mul = a * bv
check_schedule ( late_mul , 0 )
# the arange doesn't realize
self . assertIsNone ( b . lazydata . base . realized )
# mul doesn't realize
self . assertIsNone ( late_mul . lazydata . base . realized )
self . assertEqual ( late_mul . tolist ( ) , [ 0 , 0 ] )
# SINK has two branches:
# a*VIEW(x), where VIEW(x) = 0
# x+2
# as long as one child realizes, x does not collapse
def test_parent_multiple_children_no_collapse ( self ) :
a = Tensor ( [ 1 , 2 ] )
b = Tensor . arange ( 3 ) . contiguous ( )
bv = b . pad ( ( ( 0 , 2 ) , ) ) [ - 2 : ]
late_mul = a * bv
other_child = b + 2
s = check_schedule ( [ late_mul , other_child ] , 2 )
# the arange becomes a BUFFER
self . assertIs ( b . lazydata . base . op , Ops . BUFFER )
# mul still collapses
self . assertIs ( late_mul . lazydata . base . op , Ops . CONST )
run_schedule ( s )
self . assertEqual ( other_child . tolist ( ) , [ 2 , 3 , 4 ] )
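# tensor_rewrite applies the scheduler's movement-op removal plus symbolic_simple to a tensor's base UOp,
# so TestSimplifier can check what the graph folds to before any kernel is created.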
def tensor_rewrite ( t ) -> UOp : return graph_rewrite ( t . lazydata . base , remove_movement_ops + symbolic_simple )
class TestSimplifier ( unittest . TestCase ) :
def test_sink_childless_const ( self ) :
x = Tensor ( 0 )
check_schedule ( x , 0 )
def test_sink_childless_const_alt_expanded ( self ) :
x = Tensor . zeros ( 4 , 4 ) . contiguous ( )
check_schedule ( x , 1 )
def test_all_const_uops ( self ) :
a = Tensor ( 4 ) * Tensor ( 2 )
sink = tensor_rewrite ( a )
assert UPat . cvar ( ) . match ( sink , { } )
def test_masked_const_elementwise ( self ) :
a = Tensor . eye ( 10 ) @Tensor . eye ( 10 )
sink = tensor_rewrite ( a )
assert UPat ( Ops . REDUCE_AXIS , src = ( UPat . cvar ( ) . view ( ) * UPat . cvar ( ) . view ( ) , ) ) . match ( sink , { } )
def test_elementwise_ops ( self ) :
a = Tensor . empty ( 4 , 4 , dtype = dtypes . int )
sink = tensor_rewrite ( a * 0 )
assert UPat ( Ops . CONST , arg = 0 ) . match ( sink , { } )
self . assertIs ( tensor_rewrite ( a * 1 ) . base , a . lazydata . base )
self . assertIs ( tensor_rewrite ( a + 0 ) . base , a . lazydata . base )
self . assertIs ( tensor_rewrite ( a // 1 ) . base , a . lazydata . base )
def test_cast_folding ( self ) :
a = Tensor ( 1.0 ) . cast ( dtypes . int )
sink = tensor_rewrite ( a )
assert UPat . cvar ( dtype = dtypes . int ) . match ( sink , { } )
def test_const_folding_mul ( self ) :
a = Tensor ( [ 1 ] )
sink = tensor_rewrite ( a * 0 )
assert UPat ( Ops . CONST , arg = 0 ) . match ( sink , { } ) , f " expected { sink } to collapse to a const 0 "
assert sink . shape == a . shape
def test_const_folding_ne ( self ) :
a = Tensor ( [ 1 ] )
sink = tensor_rewrite ( a != a )
assert UPat ( Ops . CONST , arg = False ) . match ( sink , { } ) , f " expected { sink } to collapse to a const False "
assert sink . shape == a . shape
def test_const_folding_lt ( self ) :
a = Tensor ( [ 1 ] )
sink = tensor_rewrite ( a < a )
assert UPat ( Ops . CONST , arg = False ) . match ( sink , { } ) , f " expected { sink } to collapse to a const False "
assert sink . shape == a . shape
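# tensor_const_pm accepts the two forms a constant-backed tensor can take: a device-bound CONST, or a BIND
# of a DEFINE_VAR (a bound Variable) together with its CONST value.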
tensor_const_pm = PatternMatcher ( [
( UPat ( Ops . CONST , src = ( UPat ( Ops . VIEW , src = ( UPat ( Ops . DEVICE ) , ) ) , ) ) , lambda : True ) ,
( UPat ( Ops . BIND , src = ( UPat ( Ops . DEFINE_VAR , src = ( UPat ( Ops . VIEW , src = ( UPat ( Ops . DEVICE ) , ) ) ) ) , UPat ( Ops . CONST ) ) ) , lambda : True ) ,
] )
class TestConst ( unittest . TestCase ) :
# ** part 1: basic functionality of a tensor directly created from CONST
def test_tensor_const ( self ) :
a = Tensor ( 1 )
print ( a . lazydata )
self . assertTrue ( tensor_const_pm . rewrite ( a . lazydata ) )
def test_tensor_variable ( self ) :
vv = UOp . variable ( " a " , 0 , 10 ) . bind ( 1 )
a = Tensor ( vv )
print ( a . lazydata )
self . assertTrue ( tensor_const_pm . rewrite ( a . lazydata ) )
def test_const_schedule ( self ) :
a = Tensor . ones ( ( 4 , 4 ) )
sched = a . schedule ( )
self . assertEqual ( len ( sched ) , 0 )
def test_const_contiguous_schedule ( self ) :
# this ends up in the big graph
a = Tensor . ones ( ( 4 , ) ) . contiguous ( )
sched = a . schedule ( )
self . assertEqual ( len ( sched ) , 1 )
def test_const_ast ( self ) :
a = Tensor . ones ( ( 4 , ) ) . pad ( ( 1 , 1 ) ) . contiguous ( )
sched = a . schedule ( )
print ( sched [ 0 ] . ast )
const_ast_pattern = UPat ( Ops . SINK , src = ( UPat . store ( UPat ( ) , UPat ( ) , UPat . where ( UPat ( Ops . VALID ) , UPat . cvar ( " x " ) , UPat ( Ops . CONST , arg = 0 ) ) ) , ) )
self . assertEqual ( len ( const_ast_pattern . match ( sched [ 0 ] . ast , { } ) ) , 1 )
run_schedule ( sched )
self . assertListEqual ( a . tolist ( ) , [ 0 , 1 , 1 , 1 , 1 , 0 ] )
def test_unmasked_const_ast ( self ) :
a = Tensor . ones ( ( 4 , ) ) . contiguous ( )
sched = a . schedule ( )
print ( sched [ 0 ] . ast )
const_ast_pattern = UPat ( Ops . SINK , src = ( UPat . store ( UPat ( ) , UPat ( ) , UPat ( Ops . CONST ) ) , ) )
self . assertEqual ( len ( const_ast_pattern . match ( sched [ 0 ] . ast , { } ) ) , 1 )
run_schedule ( sched )
self . assertListEqual ( a . tolist ( ) , [ 1 , 1 , 1 , 1 ] )
# ** part 2: scheduler behavior when const folding happens later
def test_const_folding_no_realize ( self ) :
a = Tensor ( [ 1 , 2 , 3 , 4 ] ) * 0
sched = a . schedule ( )
self . assertEqual ( len ( sched ) , 0 )
def test_src_const_folding ( self ) :
with Context ( TRACK_MATCH_STATS = 0 ) :
a = Tensor . full ( ( 4 , ) , 1 ) . contiguous ( ) . realize ( )
b = Tensor . full ( ( 4 , ) , 2 ) . contiguous ( ) . realize ( )
mul0 = a * 0
add = b + mul0
sched = add . schedule ( )
self . assertEqual ( len ( sched ) , 0 )
# b+0 and b share the same underlying device memory
self . assertIs ( add . lazydata . buffer , b . lazydata . buffer )
self . assertListEqual ( add . tolist ( ) , [ 2 , 2 , 2 , 2 ] )
def test_src_masked_const_folding ( self ) :
with Context ( TRACK_MATCH_STATS = 0 ) :
a = Tensor . full ( ( 4 , ) , 1 ) . contiguous ( ) . realize ( )
b = Tensor . full ( ( 6 , ) , 2 ) . contiguous ( ) . realize ( )
mul0 = a * 0
add = b + mul0 . pad ( ( 1 , 1 ) , value = 2 )
sched = add . schedule ( )
self . assertEqual ( len ( sched ) , 1 )
run_schedule ( sched )
# add gets assigned to a new buffer
self . assertIsNot ( add . lazydata . base . realized , b . lazydata . base . realized )
self . assertListEqual ( add . tolist ( ) , [ 4 , 2 , 2 , 2 , 2 , 4 ] )
# ** part 3: Tensor variable bindings
#@unittest.expectedFailure # TODO: should schedule assert if you try to realize a Variable?
def test_var_schedule ( self ) :
vv = UOp . variable ( " a " , 0 , 10 ) . bind ( 1 )
a = Tensor ( vv )
sched = a . schedule ( )
self . assertEqual ( len ( sched ) , 0 )
def test_add_tvar ( self ) :
vv = UOp . variable ( " a " , 0 , 10 ) . bind ( 1 )
a = Tensor ( vv ) + 2
sched , var_vals = a . schedule_with_vars ( )
self . assertEqual ( len ( sched ) , 1 )
run_schedule ( sched , var_vals )
self . assertEqual ( a . tolist ( ) , 3 )
@unittest . skipIf ( Device . DEFAULT == " CPU " , " tests copy from another device to cpu " )
class TestCopyFolding ( unittest . TestCase ) :
def test_const_copy_is_free ( self ) :
b = Tensor ( 1 ) . to ( " CPU " )
check_schedule ( b , 0 , filter_sink = False )
assert b . item ( ) == 1
def test_late_const_copy_folding ( self ) :
a = Tensor . arange ( 3 ) . realize ( )
zeros = Tensor . zeros ( 3 ) . realize ( )
b = ( a * zeros ) . to ( " CPU " )
run_schedule ( check_schedule ( b , 0 , filter_sink = False ) )
self . assertListEqual ( b . tolist ( ) , [ 0 , 0 , 0 ] )
def test_alu_after_copy ( self ) :
a = Tensor . ones ( ( 4 , ) ) . to ( " CPU " ) . lazydata
b = Tensor . empty ( 4 , device = " CPU " ) . lazydata
add = a + b
add = schedule_graph_rewrite ( add )
assert all_same ( [ x . device for x in add . src ] ) , f " ALU has different devices! { [ x . device for x in add . src ] } "
def test_copy_to_same_device ( self ) :
a = Tensor . empty ( 4 ) . lazydata
b = a . copy_to_device ( a . device )
check_schedule ( b , 0 , filter_sink = False )
b = schedule_graph_rewrite ( b )
# NOTE: Tensor.empty(4) always creates a VIEW(BUFFER) with ShapeTracker((4,)); the scheduler simplifies this
# to just a BUFFER because the buffer already has shape (4,)
self . assertIs ( b , a . base )
def test_copy_to_same_device_alt ( self ) :
a = Tensor . empty ( 4 , 4 ) . lazydata
b = a . copy_to_device ( a . device )
check_schedule ( b , 0 , filter_sink = False )
b = schedule_graph_rewrite ( b )
self . assertIs ( b . base , a . base )
def test_clone ( self ) :
a = Tensor . empty ( 4 ) . lazydata
check_schedule ( a . clone ( ) , 1 , filter_sink = False )
# NOTE: moving copy before view might change this
def test_shrink_copy ( self ) :
a = Tensor . arange ( 4 )
view = a . shrink ( ( ( 0 , 2 ) , ) )
b = view . clone ( )
run_schedule ( check_schedule ( b , 2 , filter_sink = False ) )
self . assertEqual ( b . lazydata . base . buffer . size , 2 )
self . assertEqual ( b . lazydata . size , 2 )
self . assertListEqual ( b . tolist ( ) , [ 0 , 1 ] )
def test_expanded_copy ( self ) :
a = Tensor . arange ( 2 )
view = a . reshape ( 2 , 1 ) . expand ( 2 , 2 )
b = view . clone ( )
run_schedule ( check_schedule ( b , 2 , filter_sink = False ) )
self . assertEqual ( b . lazydata . base . buffer . size , 2 )
self . assertEqual ( b . lazydata . size , 4 )
self . assertListEqual ( b . tolist ( ) , [ [ 0 , 0 ] , [ 1 , 1 ] ] )
def test_permuted_copy ( self ) :
a = Tensor . arange ( 4 )
b = a . reshape ( 2 , 2 ) . permute ( 1 , 0 )
b . realize ( )
self . assertListEqual ( b . tolist ( ) , [ [ 0 , 2 ] , [ 1 , 3 ] ] )
def test_permute_on_disk ( self ) :
with open ( temp ( 'dt_arange_4_permute' ) , "wb" ) as f : f . write ( Tensor . arange ( 4 ) . realize ( ) . lazydata . base . buffer . as_buffer ( ) )
a = Tensor . empty ( 4 , dtype = dtypes . int32 , device = f"disk:{temp('dt_arange_4_permute')}" )
b = a . reshape ( 2 , 2 ) . permute ( 1 , 0 ) . to ( " CPU " )
b . realize ( )
self . assertListEqual ( b . tolist ( ) , [ [ 0 , 2 ] , [ 1 , 3 ] ] )
def test_permute_after_shrink ( self ) :
a = Tensor . arange ( 5 )
b = a . shrink ( ( ( 0 , 4 ) , ) ) . reshape ( 2 , 2 ) . permute ( 1 , 0 ) . to ( " CPU " )
b . realize ( )
self . assertListEqual ( b . tolist ( ) , [ [ 0 , 2 ] , [ 1 , 3 ] ] )
# NOTE: disk permute must come after COPY
# TODO: this is wrong because of the permute
@unittest . expectedFailure
def test_permute_after_shrink_on_disk ( self ) :
with open ( temp ( 'dt_arange_5_permute' ) , "wb" ) as f : f . write ( Tensor . arange ( 5 ) . realize ( ) . lazydata . base . buffer . as_buffer ( ) )
a = Tensor . empty ( 5 , dtype = dtypes . int32 , device = f"disk:{temp('dt_arange_5_permute')}" )
b = a . shrink ( ( ( 0 , 4 ) , ) ) . reshape ( 2 , 2 ) . permute ( 1 , 0 ) . to ( " CPU " )
b . realize ( )
self . assertListEqual ( b . tolist ( ) , [ [ 0 , 2 ] , [ 1 , 3 ] ] )
class TestTensorUOpSpec ( unittest . TestCase ) :
def test_const_must_be_unmasked ( self ) :
a = Tensor . ones ( ( 4 , 4 ) ) . pad ( ( 2 , 2 ) )
unsafe_push_views = PatternMatcher ( [
( UPat . cvar ( " root " ) . view ( name = " view " ) , lambda root , view : root . replace ( src = tuple ( x . view ( view . st ) for x in root . src ) ) ) ,
] )
a . lazydata = graph_rewrite ( a . lazydata . sink ( ) , remove_movement_ops + merge_views + unsafe_push_views )
with self . assertRaisesRegex ( RuntimeError , " UOp verification failed " ) :
a . schedule ( )
def test_expanded_const_ok ( self ) :
a = Tensor . ones ( ( 4 , 4 ) )
t = graph_rewrite ( a . lazydata . sink ( ) , remove_movement_ops + merge_views )
create_schedule_with_vars ( t )
# NOTE: changing symbolic CONST VIEWs is not allowed
@unittest . expectedFailure
def test_symbolic_shape_ok ( self ) :
a = Tensor . ones ( 4 )
vi = UOp . variable ( " i " , 1 , 10 ) . bind ( 4 )
a . lazydata = graph_rewrite ( a . reshape ( vi ) . sum ( ) . lazydata , remove_movement_ops + merge_views )
a . schedule ( )
class TestBufferUOp ( unittest . TestCase ) :
# BUFFER has a ShapeTracker of shape=(n,) and stride=(1,)
def test_buffer_has_buffer ( self ) :
buf = Tensor . empty ( 10 )
self . assertIsNotNone ( buf . lazydata . buffer )
self . assertEqual ( buf . lazydata . st , ShapeTracker . from_shape ( ( 10 , ) ) )
# the device Buffer remains unallocated until we run the schedule
self . assertFalse ( buf . lazydata . buffer . is_allocated ( ) )
add = buf + 1
sched = add . schedule ( )
self . assertFalse ( buf . lazydata . buffer . is_allocated ( ) )
run_schedule ( sched )
self . assertTrue ( buf . lazydata . buffer . is_allocated ( ) )
def test_buffer_has_unique_buffer ( self ) :
buf = Tensor . empty ( 10 )
buf1 = buf . lazydata . buffer
buf2 = buf . lazydata . buffer
self . assertIs ( buf1 , buf2 )
# we also allow VIEW(BUFFER) to access the underlying device Buffer, as long as it's contiguous
def test_buffer_view_allowed ( self ) :
add = Tensor . empty ( 1 , 1 ) + Tensor . empty ( 1 , 1 )
add . realize ( )
self . assertIsNotNone ( add . lazydata . buffer )
self . assertEqual ( add . lazydata . shape , ( 1 , 1 ) )
def test_buffer_view_not_allowed ( self ) :
permuted_view = Tensor . empty ( 1 , 2 , 3 ) . permute ( 0 , 2 , 1 )
merged = graph_rewrite ( permuted_view . lazydata , remove_movement_ops )
with self . assertRaisesRegex ( AssertionError , " VIEW only works here if it ' s contiguous " ) :
merged . buffer # cannot access Buffer of a non contiguous VIEW
def test_buffer_only_after_realize ( self ) :
a = Tensor ( [ 1 ] ) + Tensor ( [ 2 ] )
# accessing realized will return None
self . assertIsNone ( a . lazydata . realized )
# accessing Buffer will assert
with self . assertRaisesRegex ( AssertionError , " must be BUFFER " ) :
a . lazydata . buffer # there is no BUFFER on an unrealized ADD
# Buffer only exists once we realize it
a . realize ( )
self . assertIsNotNone ( a . lazydata . buffer )
def test_const_does_not_realize ( self ) :
a = Tensor ( 1 ) + Tensor ( 2 )
run_schedule ( check_schedule ( a , 0 ) )
self . assertIsNone ( a . lazydata . base . realized )
def test_var_does_not_realize ( self ) :
a = Tensor ( UOp . variable ( " a " , 0 , 10 ) . bind ( 1 ) )
run_schedule ( check_schedule ( a , 0 ) )
self . assertIsNone ( a . lazydata . base . realized )
def test_view_does_not_realize ( self ) :
a = Tensor . randn ( 1 , 4 ) . expand ( 4 , 4 )
a . realize ( )
self . assertEqual ( a . lazydata . base . realized . size , 4 )
a2 = a . contiguous ( ) . realize ( )
self . assertEqual ( a2 . lazydata . base . realized . size , 16 )
class TestContiguous ( unittest . TestCase ) :
def test_contiguous_buffer ( self ) :
a = Tensor . empty ( 4 )
b = a . contiguous ( )
check_schedule ( b , 0 )
def test_contiguous_buffer_view ( self ) :
a = Tensor . empty ( 4 )
b = a . reshape ( ( 2 , 2 ) ) . contiguous ( )
check_schedule ( b , 0 )
def test_non_contiguous_buffer_view ( self ) :
a = Tensor . empty ( 4 , 1 )
b = a . expand ( ( 4 , 4 ) ) . contiguous ( )
check_schedule ( b , 1 )
def test_size_change_buffer_view ( self ) :
a = Tensor . empty ( 4 )
b = a . reshape ( ( 1 , 1 , 4 ) ) . shrink ( ( ( 0 , 1 ) , ( 0 , 1 ) , ( 0 , 3 ) ) ) . contiguous ( )
check_schedule ( b , 1 )
def test_double_contiguous_realizes_once ( self ) :
a = Tensor . empty ( 4 , 1 )
b = a . expand ( ( 4 , 4 ) ) . contiguous ( ) . contiguous ( )
check_schedule ( b , 1 )
def test_view_does_not_realize ( self ) :
a = Tensor . empty ( 4 )
b = a . expand ( ( 4 , 4 ) )
check_schedule ( b , 0 )
self . assertEqual ( b . lazydata . base . buffer . size , 4 )
def test_contiguous_view_realizes ( self ) :
a = Tensor . empty ( 4 )
b = a . expand ( ( 4 , 4 ) ) . contiguous ( )
check_schedule ( b , 1 )
self . assertEqual ( b . lazydata . base . buffer . size , 16 )
class TestUOpBecome ( unittest . TestCase ) :
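# after scheduling, a tensor's lazydata is replaced ("becomes") by whatever actually backs it: a newly
# created BUFFER, an existing tensor's buffer (possibly with movement ops stacked on top), or a CONST if
# the computation folded away entirely.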
# the simplest case, if we create a new BUFFER for this tensor UOp
def test_new_buffer ( self ) :
a = Tensor . empty ( 4 , 4 )
b = Tensor . empty ( 4 , 4 )
add = a + b
check_schedule ( add , 1 )
# NOTE: realized base is always a flat buffer
assert UPat ( Ops . BUFFER ) . match ( add . lazydata . base , { } )
# the Tensor UOp can optionally stack a VIEW on top of the BUFFER, in this case to preserve the (4, 4) shape of the tensor
assert add . lazydata is not add . lazydata . base
self . assertEqual ( add . lazydata . size , 16 )
self . assertEqual ( add . lazydata . shape , ( 4 , 4 ) )
def test_new_buffer_view ( self ) :
a = Tensor . empty ( 4 , 4 )
b = Tensor . empty ( 4 , 4 )
add = ( a + b ) . reshape ( 8 , 2 )
check_schedule ( add , 1 )
assert UPat ( Ops . BUFFER ) . match ( add . lazydata . base , { } )
# the shape is preserved in the becomes_map.
self . assertEqual ( add . lazydata . shape , ( 8 , 2 ) )
assert add . lazydata is not add . lazydata . base
def test_new_flat_buffer ( self ) :
a = Tensor . empty ( 4 , )
b = Tensor . empty ( 4 , )
add = a + b
check_schedule ( add , 1 )
# BUFFER already has shape (4,), so this tensor just becomes a contiguous BUFFER
assert UPat ( Ops . BUFFER ) . match ( add . lazydata , { } )
# sometimes we prefer to perform an op before movement ops; in that case we stack the mops on top of the new buffer
def test_reorder_expand ( self ) :
a = Tensor . empty ( 4 , 1 )
b = a . expand ( 4 , 4 ) . reciprocal ( )
check_schedule ( b , 1 )
self . assertEqual ( b . lazydata . base . buffer . size , 4 )
self . assertEqual ( b . lazydata . st , ShapeTracker . from_shape ( ( 4 , 1 ) ) . expand ( ( 4 , 4 ) ) )
def test_become_existing_buffer ( self ) :
a = Tensor . empty ( 4 , 4 )
b = a * 1
assert UPat ( Ops . MUL ) . match ( b . lazydata , { } ) # before scheduling it's a mul
check_schedule ( b , 0 )
assert UPat ( Ops . VIEW , src = ( UPat ( Ops . BUFFER ) ) ) . match ( b . lazydata , { } ) # scheduling merges all MovementOps into a single VIEW
self . assertIs ( a . lazydata . base . buffer , b . lazydata . base . buffer )
def test_become_buf_with_mops ( self ) :
a = Tensor . empty ( 2 , 4 , 2 )
noop = a . shrink ( ( ( 1 , 2 ) , ( 0 , 4 ) , ( 0 , 2 ) ) ) . reshape ( 4 , 2 ) * 1 + 0
# before realizing, this tensor is base
assert noop . lazydata is noop . lazydata . base
noop . realize ( )
# it becomes a realized view after realize
assert noop . lazydata is not noop . lazydata . base
assert noop . lazydata . base . op is Ops . BUFFER
late_add = noop + 2
late_add . realize ( )
def test_become_const_in_base ( self ) :
a = Tensor . empty ( 4 )
b = a * 0
assert UPat ( Ops . MUL ) . match ( b . lazydata , { } ) # before scheduling it's a mul
check_schedule ( b , 0 )
assert UPat ( Ops . CONST , arg = 0 ) . match ( b . lazydata . base , { } ) # scheduling replaces the tensor lazydata with a VIEW(BUFFER)
def test_become_const_in_view ( self ) :
# if we shrink the base down to size 0, only the VIEW becomes CONST; the base is unchanged.
add = Tensor . empty ( 2 , 2 ) + Tensor . empty ( 2 , 2 )
b = add . shrink ( ( ( 0 , 1 ) , ( 0 , 0 ) ) )
check_schedule ( b , 0 )
assert UPat ( Ops . CONST , arg = 0 ) . match ( b . lazydata , { } )
self . assertEqual ( b . shape , ( 1 , 0 ) )
# the base is untouched.
assert UPat ( Ops . ADD ) . match ( add . lazydata , { } )
def test_become_const_from_const ( self ) :
const_add = Tensor ( 1 ) + Tensor ( 2 )
assert UPat ( Ops . ADD ) . match ( const_add . lazydata , { } )
check_schedule ( const_add , 0 )
assert UPat ( Ops . CONST , arg = 3 ) . match ( const_add . lazydata . base , { } )
# tensors can become another realized tensor source
def test_become_existing_buf_simple ( self ) :
a = Tensor . empty ( 4 , 4 )
b = a + 0
check_schedule ( b , 0 )
assert b . lazydata . base . op is Ops . BUFFER
self . assertIs ( a . lazydata , b . lazydata )
# they can also chain other movement ops on top of the tensor source
def test_become_existing_buf_view ( self ) :
a = Tensor . empty ( 4 , 4 )
b = a . permute ( ( 1 , 0 ) ) + 0
check_schedule ( b , 0 )
self . assertEqual ( b . lazydata . st , a . lazydata . permute ( ( 1 , 0 ) ) . st )
def test_become_existing_buf_view_alt ( self ) :
a = Tensor . empty ( 4 , 4 )
b = a . permute ( ( 1 , 0 ) ) . reshape ( ( 8 , 2 ) ) + 0
check_schedule ( b , 0 )
self . assertEqual ( b . lazydata . st , a . lazydata . permute ( ( 1 , 0 ) ) . reshape ( ( 8 , 2 ) ) . st )
# they can also have other base parents that simplify away; in that case we just backtrack to the chained mops
def test_become_existing_buf_complex ( self ) :
a = Tensor . empty ( 4 , 4 )
b = ( a . permute ( ( 1 , 0 ) ) + 0 ) . reshape ( ( 8 , 2 ) ) + 0
check_schedule ( b , 0 )
self . assertEqual ( b . lazydata . st , a . lazydata . permute ( ( 1 , 0 ) ) . reshape ( ( 8 , 2 ) ) . st )
assert b . lazydata . base . op is Ops . BUFFER
def test_become_multiple_choices ( self ) :
a = Tensor . empty ( 16 )
b = ( a . reshape ( 1 , 1 , 4 , 1 , 4 ) + 0 ) . reshape ( 1 , 1 , 4 , 4 ) . shrink ( ( ( 0 , 1 ) , ( 0 , 1 ) , ( 0 , 3 ) , ( 0 , 3 ) ) ) + 0
c = ( a . reshape ( 1 , 1 , 4 , 4 ) + 0 ) . shrink ( ( ( 0 , 1 ) , ( 0 , 1 ) , ( 0 , 3 ) , ( 0 , 3 ) ) ) + 0
check_schedule ( [ b , c ] , 0 )
assert all_same ( [ x . lazydata . base . realized for x in [ a , b , c ] ] )
# these movement ops result in the same ShapeTracker
assert b . lazydata . st == c . lazydata . st
assert b . lazydata is c . lazydata
assert UPat ( Ops . VIEW , src = ( UPat ( Ops . BUFFER ) , ) ) . match ( c . lazydata , { } )
def test_setitem_becomes_view_of_base ( self ) :
a = Tensor . full ( ( 4 , ) , 2. ) . contiguous ( ) . realize ( )
b = a . shrink ( ( ( 0 , 2 ) , ) ) . assign ( Tensor . full ( ( 2 , ) , 1.0 ) )
b . realize ( )
assert b . lazydata . is_realized
assert b . lazydata . base . buffer . _base is None
def test_setitem_offset ( self ) :
a = Tensor . full ( ( 16 , ) , 0. ) . contiguous ( ) . realize ( )
b = Tensor . full ( ( 16 , ) , 1. ) . contiguous ( ) . realize ( )
a_view = a [ 4 : ] . reshape ( 3 , 4 ) . shrink ( ( ( 0 , 2 ) , ( 0 , 2 ) ) ) . reshape ( ( 4 , ) )
b . shrink ( ( ( 0 , 4 ) , ) ) . assign ( a_view ) . realize ( )
self . assertListEqual ( b . tolist ( ) , [ 0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ] )
if __name__ == ' __main__ ' :
unittest . main ( verbosity = 2 )