from typing import List
import unittest , time , pytest
from tinygrad import dtypes , Device
from tinygrad . helpers import DEBUG
from tinygrad . ops import Ops , UOp , KernelInfo , UPat , PatternMatcher , track_rewrites
from tinygrad . renderer import Renderer
from tinygrad . codegen . lowerer import rewrite_shapetracker_with_index
from tinygrad . codegen . devectorizer import full_graph_rewrite , graph_rewrite , sym
from tinygrad . codegen . expander import expander , expand_rewrite
from tinygrad . codegen . linearize import linearize_uop
from tinygrad . shape . shapetracker import ShapeTracker , View
simple_pm = PatternMatcher ( [
( UPat . cvar ( ' x ' , dtypes . int ) , lambda x : UOp . const ( dtypes . float , 1.0 ) + UOp . const ( dtypes . float , 2.0 ) ) ,
( UPat . cvar ( ' x ' ) + UPat . cvar ( ' y ' ) , lambda x , y : UOp . const ( dtypes . float , x . arg + y . arg ) ) ,
( UPat . cvar ( ' x ' ) * UPat . cvar ( ' y ' ) * UPat . cvar ( ' z ' ) , lambda x , y , z : UOp . const ( dtypes . float , x . arg * y . arg * z . arg ) ) ,
( ( UPat . var ( ' x ' ) + UPat . cvar ( ' c1 ' ) ) + UPat . cvar ( ' c2 ' ) , lambda x , c1 , c2 : x + ( c1 . arg + c2 . arg ) ) ,
] )
def to_uops_list ( u : List [ UOp ] ) - > List [ UOp ] : return linearize_uop ( full_graph_rewrite ( UOp . sink ( * u ) ) )
class TestGraphRewriteEfficiency ( unittest . TestCase ) :
def test_create_many_uops ( self ) :
c1 = UOp . const ( dtypes . int , 1 )
c2 = UOp . const ( dtypes . int , 2 )
st = time . perf_counter ( )
uops = [ UOp ( Ops . ADD , dtypes . int , ( c1 , c2 ) ) for _ in range ( 10000 ) ]
et = time . perf_counter ( ) - st
print ( f " created { len ( uops ) } uops in { et * 1000 : .2f } ms " )
def test_expand_rewrite ( self ) :
sink = UOp ( Ops . SINK , dtypes . void , arg = KernelInfo ( local_dims = 2 , upcasted = 4 , dont_use_locals = False ) , src = (
UOp ( Ops . STORE , dtypes . void , arg = None , src = (
UOp ( Ops . DEFINE_GLOBAL , dtypes . float . ptr ( ) , arg = 0 , src = ( ) ) ,
UOp ( Ops . VIEW , dtypes . void , arg = ShapeTracker ( views = ( View ( shape = ( 2 , 4 , 64 , 8 , 16 , 1 , 1 , 3 , 3 , 4 , 1 ) ,
strides = ( 1179648 , 9216 , 1 , 147456 , 576 , 0 , 0 , 64 , 192 , 36864 , 0 ) ,
offset = 0 , mask = None , contiguous = False ) , ) ) , src = ( ) ) ,
UOp ( Ops . REDUCE_AXIS , dtypes . float , arg = ( Ops . ADD , ( 5 , 6 , 10 ) ) , src = (
UOp ( Ops . CAST , dtypes . float , arg = None , src = (
UOp ( Ops . MUL , dtypes . half , arg = None , src = (
UOp ( Ops . LOAD , dtypes . half , arg = None , src = (
UOp ( Ops . DEFINE_GLOBAL , dtypes . half . ptr ( ) , arg = 1 , src = ( ) ) ,
UOp ( Ops . VIEW , dtypes . void , arg = ShapeTracker ( views = (
View ( shape = ( 1 , 1024 , 1 , 64 , 4 , 17 , 4 , 17 ) , strides = ( 0 , 14400 , 0 , 225 , 0 , 15 , 0 , 1 ) , offset = - 16 ,
mask = ( ( 0 , 1 ) , ( 0 , 1024 ) , ( 0 , 1 ) , ( 0 , 64 ) , ( 0 , 4 ) , ( 1 , 16 ) , ( 0 , 4 ) , ( 1 , 16 ) ) , contiguous = False ) ,
View ( shape = ( 2 , 4 , 64 , 8 , 16 , 16 , 15 , 3 , 3 , 4 , 15 ) , strides = ( 0 , 73984 , 4734976 , 0 , 4624 , 295936 , 68 , 18 , 1224 , 0 , 1 ) , offset = 0 ,
mask = None , contiguous = False ) ) ) , src = ( ) ) , ) ) ,
UOp ( Ops . LOAD , dtypes . half , arg = None , src = (
UOp ( Ops . DEFINE_GLOBAL , dtypes . half . ptr ( ) , arg = 2 , src = ( ) ) ,
UOp ( Ops . VIEW , dtypes . void , arg = ShapeTracker ( views = (
View ( shape = ( 2 , 4 , 64 , 8 , 16 , 16 , 15 , 3 , 3 , 4 , 15 ) , strides = ( 7200 , 0 , 230400 , 900 , 0 , 14400 , 15 , 0 , 0 , 225 , 1 ) , offset = 0 ,
mask = None , contiguous = False ) , ) ) , src = ( ) ) , ) ) , ) ) , ) ) , ) ) , ) ) , ) )
lower_sink = rewrite_shapetracker_with_index ( sink , Device [ Device . DEFAULT ] . renderer )
cnt = [ 0 ]
old_init = UOp . __init__
def uop_hook ( self , * args , * * kwargs ) :
cnt [ 0 ] + = 1
old_init ( self , * args , * * kwargs )
UOp . __init__ = uop_hook
st = time . perf_counter ( )
new_sink = full_graph_rewrite ( lower_sink )
et = time . perf_counter ( ) - st
UOp . __init__ = old_init
print ( f " rewrote in { et * 1000 : .2f } ms, from { len ( lower_sink . toposort ) } -> { len ( new_sink . toposort ) } , creating { cnt [ 0 ] } uops " )
class TestGraphRewriteConst ( unittest . TestCase ) :
def test_gep_const ( self ) :
v1 = UOp . const ( dtypes . int . vec ( 3 ) , ( 0 , 1 , 2 ) )
v2 = v1 . gep ( 1 )
ret = graph_rewrite ( v2 , sym )
self . assertEqual ( ret . dtype , dtypes . int )
self . assertEqual ( ret . arg , 1 )
def test_gep_const_single ( self ) :
v1 = UOp . const ( dtypes . int . vec ( 3 ) , 4 )
v2 = v1 . gep ( 1 )
ret = graph_rewrite ( v2 , sym )
self . assertEqual ( ret . dtype , dtypes . int )
self . assertEqual ( ret . arg , 4 )
def test_add_const ( self ) :
v1 = UOp . const ( dtypes . int . vec ( 3 ) , ( 0 , 1 , 2 ) )
v2 = UOp . const ( dtypes . int . vec ( 3 ) , ( 5 , 6 , 7 ) )
ret = graph_rewrite ( v1 + v2 , sym )
self . assertEqual ( ret . op , Ops . VCONST )
self . assertEqual ( ret . dtype , dtypes . int . vec ( 3 ) )
self . assertEqual ( ret . arg , ( 5 , 7 , 9 ) )
def test_add_const_lose_v ( self ) :
v1 = UOp . const ( dtypes . int . vec ( 3 ) , ( 0 , 1 , 2 ) )
v2 = UOp . const ( dtypes . int . vec ( 3 ) , ( 2 , 1 , 0 ) )
ret = graph_rewrite ( v1 + v2 , sym )
self . assertEqual ( ret . op , Ops . CONST )
self . assertEqual ( ret . dtype , dtypes . int . vec ( 3 ) )
self . assertEqual ( ret . arg , 2 )
xfail_broken_const_wraparound = pytest . mark . xfail ( reason = " const folding does not properly implement modular arithmetic " )
class TestModularWraparound ( unittest . TestCase ) :
def _test ( self , uop : UOp , expected : int ) :
results = to_uops_list ( [ uop ] )
self . assertEqual ( len ( results ) , 1 )
self . assertEqual ( results [ 0 ] . op , Ops . CONST )
self . assertEqual ( results [ 0 ] . dtype , uop . dtype )
self . assertEqual ( results [ 0 ] . arg , expected )
@xfail_broken_const_wraparound
def test_cast ( self ) :
t = self . _test
t ( UOp . const ( dtypes . uint , 0xABCD17D6 ) . cast ( dtypes . uint8 ) , 0xD6 )
t ( UOp . const ( dtypes . uint , 0xABCD17D6 ) . cast ( dtypes . uint8 ) . cast ( dtypes . uint ) , 0xD6 )
@xfail_broken_const_wraparound
def test_mul ( self ) :
t = self . _test
t ( UOp . const ( dtypes . uint , 0xABCD17D6 ) * 0xAABBCCDD , 1147018174 )
t ( UOp . const ( dtypes . int , 0xABCD17D6 ) * 10 , - 1241321892 )
@xfail_broken_const_wraparound
def test_div ( self ) :
t = self . _test
t ( UOp . const ( dtypes . uint , 0xABCD17D6 ) * 0xAABBCCDD / / 11 , 104274379 )
t ( UOp . const ( dtypes . int , 0xABCD17D6 ) * 10 / / 11 , - 112847444 )
@xfail_broken_const_wraparound
def test_neg ( self ) :
t = self . _test
t ( - UOp . const ( dtypes . uint8 , 1 ) , 0xFF )
t ( - UOp . const ( dtypes . uint16 , 1 ) , 0xFFFF )
t ( - UOp . const ( dtypes . uint32 , 1 ) , 0xFFFFFFFF )
t ( - UOp . const ( dtypes . uint64 , 1 ) , 0xFFFFFFFFFFFFFFFF )
@xfail_broken_const_wraparound
def test_neg_min_int ( self ) :
t = self . _test
t ( - UOp . const ( dtypes . int8 , - 2 * * 7 ) , - 2 * * 7 )
t ( - UOp . const ( dtypes . int16 , - 2 * * 15 ) , - 2 * * 15 )
t ( - UOp . const ( dtypes . int32 , - 2 * * 31 ) , - 2 * * 31 )
t ( - UOp . const ( dtypes . int64 , - 2 * * 63 ) , - 2 * * 63 )
@xfail_broken_const_wraparound
def test_payne_hanek_reduction_bug ( self ) :
t = self . _test
a = ( UOp . const ( dtypes . uint , 43748177600 ) . cast ( dtypes . uint ) | 36 ) . cast ( dtypes . ulong )
b = 2536655455 * a + 4294967296 * UOp . const ( dtypes . ulong , 25366554550 )
c = ( b + 2261737165 ) / / 4611686018427387904
t ( c , 0 )
class TestGraphRewrite ( unittest . TestCase ) :
def test_dedup ( self ) :
v1 = UOp ( Ops . DEFINE_VAR , dtypes . float )
v2 = UOp ( Ops . DEFINE_VAR , dtypes . float )
nout = graph_rewrite ( v1 + v2 , PatternMatcher ( [ ] ) )
self . assertIs ( nout . src [ 0 ] , nout . src [ 1 ] )
# NOTE: this shows why we can't have a UOp in arg
@unittest . expectedFailure
def test_no_dedup_args ( self ) :
a1 = UOp ( Ops . DEFINE_VAR , dtypes . int , ( ) , ( " a1 " , UOp . const ( dtypes . int , 0 ) , UOp . const ( dtypes . int , 11 ) ) )
a2 = UOp ( Ops . DEFINE_VAR , dtypes . int , ( ) , ( " a2 " , UOp . const ( dtypes . int , 0 ) , UOp . const ( dtypes . int , 11 ) ) )
sink = a1 . sink ( a2 )
define_vars = [ x for x in graph_rewrite ( sink , PatternMatcher ( [ ] ) ) . toposort if x . op is Ops . DEFINE_VAR ]
self . assertEqual ( len ( define_vars ) , 1 )
def test_simple ( self ) :
c1 = UOp . const ( dtypes . float , 1.0 )
c2 = UOp . const ( dtypes . float , 2.0 )
nout = graph_rewrite ( c1 + c2 , simple_pm )
self . assertEqual ( nout . op , Ops . CONST )
self . assertEqual ( nout . arg , 3.0 )
def test_depth_2_late ( self ) :
c1 = UOp . const ( dtypes . float , 1.0 )
c2 = UOp . const ( dtypes . float , 2.0 )
c3 = UOp . const ( dtypes . float , 3.0 )
nout = graph_rewrite ( c1 * c2 * ( c3 + c3 ) , simple_pm )
self . assertEqual ( nout . op , Ops . CONST )
self . assertEqual ( nout . arg , 12.0 )
def test_double ( self ) :
c1 = UOp . const ( dtypes . float , 1.0 )
c2 = UOp . const ( dtypes . float , 2.0 )
c3 = UOp . const ( dtypes . float , 3.0 )
nout = graph_rewrite ( c1 + c2 + c3 , simple_pm )
self . assertEqual ( nout . op , Ops . CONST )
self . assertEqual ( nout . arg , 6.0 )
def test_triple ( self ) :
c1 = UOp . const ( dtypes . float , 1.0 )
c2 = UOp . const ( dtypes . float , 2.0 )
c3 = UOp . const ( dtypes . float , 3.0 )
c4 = UOp . const ( dtypes . float , 4.0 )
nout = graph_rewrite ( c1 + c2 + c3 + c4 , simple_pm )
self . assertEqual ( nout . op , Ops . CONST )
self . assertEqual ( nout . arg , 10.0 )
def test_diamond ( self ) :
c1 = UOp . const ( dtypes . float , 1.0 )
c2 = UOp . const ( dtypes . float , 2.0 )
c3 = UOp . const ( dtypes . float , 3.0 )
nout = graph_rewrite ( ( c1 + c2 ) + ( c1 + c3 ) , simple_pm )
self . assertEqual ( nout . op , Ops . CONST )
self . assertEqual ( nout . arg , 7.0 )
def test_magic_4 ( self ) :
c1 = UOp . const ( dtypes . int , 4.0 )
nout = graph_rewrite ( c1 , simple_pm )
self . assertEqual ( nout . op , Ops . CONST )
self . assertEqual ( nout . arg , 3.0 )
def test_depth_2_fold ( self ) :
v = UOp ( Ops . DEFINE_VAR , dtypes . float )
c1 = UOp . const ( dtypes . float , 1.0 )
c2 = UOp . const ( dtypes . float , 2.0 )
nout = graph_rewrite ( v + c1 + c2 , simple_pm )
self . assertEqual ( nout . op , Ops . ADD )
self . assertEqual ( nout . src [ 0 ] . op , Ops . DEFINE_VAR )
self . assertEqual ( nout . src [ 1 ] . op , Ops . CONST )
self . assertEqual ( nout . src [ 1 ] . arg , 3.0 )
def test_commutative_work ( self ) :
a = UOp . variable ( ' a ' , 0 , 1 )
b = UOp . variable ( ' b ' , 0 , 1 )
self . assertIs ( ( a + b ) . simplify ( ) , ( b + a ) . simplify ( ) )
def test_consts_go_last_right_away ( self ) :
a = UOp . variable ( ' a ' , 0 , 1 )
tst = ( 2 + a ) . simplify ( )
self . assertIs ( tst . src [ 0 ] , a )
self . assertIs ( tst . src [ 1 ] , a . const_like ( 2 ) )
def test_consts_go_last ( self ) :
a = UOp . variable ( ' a ' , 0 , 1 )
b = UOp . variable ( ' b ' , 0 , 1 )
c = UOp . variable ( ' c ' , 0 , 1 )
d = UOp . variable ( ' d ' , 0 , 1 )
outs = [ 2 + a , 2 + a + d + 3 + b + c + 4 , UOp ( Ops . ADD , a . dtype , src = ( a . const_like ( 2 ) , a ) ) , ( 4 + d ) + c + ( 2 + a ) + b ]
for out in outs :
sink = graph_rewrite ( out , sym )
print ( sink . render ( ) )
self . assertEqual ( sink . op , Ops . ADD )
self . assertEqual ( sink . src [ 1 ] . op , Ops . CONST )
self . assertEqual ( len ( [ x for x in sink . toposort if x . op is Ops . CONST ] ) , 1 )
class TestUOpGraph ( unittest . TestCase ) :
def test_add_constant_fold ( self ) :
c1 = UOp ( Ops . CONST , dtypes . float , arg = 1.0 )
c2 = UOp ( Ops . CONST , dtypes . float , arg = 2.0 )
out = UOp ( Ops . ADD , dtypes . float , ( c1 , c2 ) )
uops = to_uops_list ( [ out ] )
self . assertEqual ( len ( uops ) , 1 )
out = uops [ - 1 ]
self . assertEqual ( out . op , Ops . CONST )
self . assertEqual ( out . arg , 3.0 )
def test_where_same_fold ( self ) :
v = UOp . variable ( ' tmp ' , 0 , 1 )
c0 = UOp ( Ops . CONST , dtypes . int , arg = 0 )
vc = UOp ( Ops . CMPNE , dtypes . bool , ( v , c0 ) )
c1 = UOp ( Ops . CONST , dtypes . float , arg = 1.0 )
out = UOp ( Ops . WHERE , dtypes . float , ( vc , c1 , c1 ) )
uops = to_uops_list ( [ out ] )
self . assertEqual ( len ( uops ) , 1 )
out = uops [ - 1 ]
self . assertEqual ( out . op , Ops . CONST )
self . assertEqual ( out . arg , 1.0 )
def test_where_const_fold ( self ) :
bf = UOp ( Ops . CONST , dtypes . bool , arg = False )
c1 = UOp ( Ops . CONST , dtypes . float , arg = 1.0 )
c2 = UOp ( Ops . CONST , dtypes . float , arg = 2.0 )
out = UOp ( Ops . WHERE , dtypes . float , ( bf , c1 , c2 ) )
uops = to_uops_list ( [ out ] )
self . assertEqual ( len ( uops ) , 1 )
out = uops [ - 1 ]
self . assertEqual ( out . op , Ops . CONST )
self . assertEqual ( out . arg , 2.0 )
def test_const_cast ( self ) :
bf = UOp ( Ops . CONST , dtypes . bool , arg = False )
out = UOp ( Ops . CAST , dtypes . int , ( bf , ) )
uops = to_uops_list ( [ out ] )
self . assertEqual ( len ( uops ) , 1 )
out = uops [ - 1 ]
self . assertEqual ( out . op , Ops . CONST )
self . assertEqual ( out . arg , 0 )
@unittest . skip ( " this test isn ' t valid uops " )
def test_noop_vectorize_fold ( self ) :
d0 = UOp ( Ops . DEFINE_GLOBAL , dtypes . float . ptr ( ) , arg = 0 )
idx = UOp . const ( dtypes . int , 0 )
ld = UOp ( Ops . LOAD , dtypes . float . vec ( 2 ) , ( d0 , idx ) )
vec = UOp ( Ops . VECTORIZE , dtypes . float . vec ( 2 ) , ( ld , ) )
x = UOp ( Ops . GEP , dtypes . float , ( vec , ) , arg = 0 )
alu = UOp ( Ops . SQRT , dtypes . float , ( x , ) )
out = UOp ( Ops . STORE , dtypes . void , ( d0 , idx , alu ) )
uops = to_uops_list ( [ out ] )
self . assertEqual ( len ( [ x for x in uops if x . op is Ops . VECTORIZE ] ) , 0 )
def test_gep_vec_fold ( self ) :
d0 = UOp ( Ops . DEFINE_GLOBAL , dtypes . float . ptr ( ) , ( ) , 0 )
d1 = UOp ( Ops . DEFINE_GLOBAL , dtypes . float . ptr ( ) , ( ) , 1 )
d2 = UOp ( Ops . DEFINE_GLOBAL , dtypes . float . ptr ( ) , ( ) , 2 )
idx = UOp . const ( dtypes . int , 0 )
def _test_vec ( geps , count = 4 ) :
vec = UOp ( Ops . VECTORIZE , dtypes . float . vec ( count ) , geps )
out = UOp ( Ops . STORE , dtypes . void , ( d0 . index ( idx ) , vec ) )
uops = to_uops_list ( [ out ] )
if DEBUG > = 4 :
from tinygrad import Device
print ( Device [ Device . DEFAULT ] . renderer . render ( " test " , uops ) )
return uops [ - 1 ] . src [ - 1 ]
# possible
val = UOp ( Ops . LOAD , dtypes . float . vec ( 4 ) , ( d1 . index ( idx ) , ) )
xyzw = tuple ( UOp ( Ops . GEP , dtypes . float , ( val , ) , ( i , ) ) for i in range ( 4 ) )
self . assertIs ( _test_vec ( xyzw ) . op , Ops . LOAD )
# unaligned
val = UOp ( Ops . LOAD , dtypes . float . vec ( 4 ) , ( d1 . index ( idx ) , ) )
wzyx = tuple ( UOp ( Ops . GEP , dtypes . float , ( val , ) , ( i , ) ) for i in reversed ( range ( 4 ) ) )
self . assertIs ( _test_vec ( wzyx ) . op , Ops . VECTORIZE )
# different_size
val = UOp ( Ops . LOAD , dtypes . float . vec ( 2 ) , ( d1 . index ( idx ) , ) )
xy = tuple ( UOp ( Ops . GEP , dtypes . float , ( val , ) , ( i , ) ) for i in range ( 2 ) )
self . assertIs ( _test_vec ( xy + xy ) . op , Ops . VECTORIZE )
val = UOp ( Ops . LOAD , dtypes . float . vec ( 4 ) , ( d1 . index ( idx ) , ) )
xy = tuple ( UOp ( Ops . GEP , dtypes . float , ( val , ) , ( i , ) ) for i in range ( 2 ) )
self . assertIs ( _test_vec ( xy , count = 2 ) . op , Ops . VECTORIZE )
# different vals
val1 = UOp ( Ops . LOAD , dtypes . float . vec ( 2 ) , ( d1 . index ( idx ) , ) )
val2 = UOp ( Ops . LOAD , dtypes . float . vec ( 2 ) , ( d2 . index ( idx ) , ) )
xy1 = tuple ( UOp ( Ops . GEP , dtypes . float , ( val1 , ) , ( i , ) ) for i in range ( 2 ) )
xy2 = tuple ( UOp ( Ops . GEP , dtypes . float , ( val2 , ) , ( i , ) ) for i in range ( 2 ) )
self . assertIs ( _test_vec ( xy1 + xy2 ) . op , Ops . VECTORIZE )
def test_gep_vec_const_fold ( self ) :
for vec_size in [ 2 , 4 , 8 ] :
consts = [ UOp . const ( dtypes . float , float ( i ) ) for i in range ( vec_size ) ]
vec = UOp ( Ops . VECTORIZE , dtypes . float . vec ( vec_size ) , tuple ( consts ) )
uops = to_uops_list ( [ UOp ( Ops . GEP , dtypes . float , ( vec , ) , ( i , ) ) for i in range ( vec_size ) ] )
for uop , const in zip ( uops , consts ) :
self . assertEqual ( uop , const )
def test_wmma_vectorize_fold ( self ) :
for i in [ 2 , 4 , 8 ] :
vec = UOp ( Ops . VECTORIZE , dtypes . half . vec ( i ) , tuple ( UOp . const ( dtypes . half , 0.0 ) for _ in range ( i ) ) )
var = UOp ( Ops . DEFINE_VAR , dtypes . half . vec ( i ) )
acc = UOp . variable ( ' acc ' , 0 , 1 , dtypes . half . vec ( i ) )
wmma = UOp ( Ops . WMMA , dtypes . half . vec ( i ) , ( vec , var , acc ) )
uops = to_uops_list ( [ wmma ] )
self . assertEqual ( uops [ 0 ] , acc )
self . assertEqual ( len ( uops ) , 1 )
for i in [ 2 , 4 , 8 ] :
var = UOp ( Ops . DEFINE_VAR , dtypes . half . vec ( i ) )
vec = UOp ( Ops . VECTORIZE , dtypes . half . vec ( i ) , tuple ( UOp . const ( dtypes . half , 0.0 ) for _ in range ( i ) ) )
acc = UOp . variable ( ' acc ' , 0 , 1 , dtypes . half . vec ( i ) )
wmma = UOp ( Ops . WMMA , dtypes . half . vec ( i ) , ( var , vec , acc ) )
uops = to_uops_list ( [ wmma ] )
self . assertEqual ( uops [ 0 ] , acc )
self . assertEqual ( len ( uops ) , 1 )
@unittest . skip ( " wmma is wrong here, it needs an arg " )
def test_wmma_vectorize_no_fold ( self ) :
for i in [ 4 , 8 ] :
vec = UOp ( Ops . VECTORIZE , dtypes . half . vec ( i ) ,
tuple ( UOp . const ( dtypes . half , 0.0 ) for _ in range ( i / / 2 ) ) +
tuple ( UOp ( Ops . DEFINE_VAR , dtypes . half , arg = ( f ' tmp { j } ' , UOp . const ( dtypes . half , 0 ) , UOp . const ( dtypes . half , 1 ) ) ) for j in range ( i / / 2 ) ) )
var = UOp ( Ops . DEFINE_VAR , dtypes . half . vec ( i ) , arg = ( f ' tmp { i } ' , UOp . const ( dtypes . half , 0 ) , UOp . const ( dtypes . half , 1 ) ) )
acc = UOp ( Ops . DEFINE_VAR , dtypes . half . vec ( i ) , arg = ( ' acc ' , UOp . const ( dtypes . half , 0 ) , UOp . const ( dtypes . half , 1 ) ) )
wmma = UOp ( Ops . WMMA , dtypes . half . vec ( i ) , ( vec , var , acc ) )
uops = to_uops_list ( [ wmma ] )
self . assertEqual ( uops [ - 1 ] , wmma )
for i in [ 4 , 8 ] :
var = UOp ( Ops . DEFINE_VAR , dtypes . half . vec ( i ) , arg = ( f ' tmp { i } ' , UOp . const ( dtypes . half , 0 ) , UOp . const ( dtypes . half , 1 ) ) )
vec = UOp ( Ops . VECTORIZE , dtypes . half . vec ( i ) ,
tuple ( UOp . const ( dtypes . half , 0.0 ) for _ in range ( i / / 2 ) ) +
tuple ( UOp ( Ops . DEFINE_VAR , dtypes . half , arg = ( f ' tmp { j } ' , UOp . const ( dtypes . half , 0 ) , UOp . const ( dtypes . half , 1 ) ) ) for j in range ( i / / 2 ) ) )
acc = UOp ( Ops . DEFINE_VAR , dtypes . half . vec ( i ) , arg = ( ' acc ' , UOp . const ( dtypes . half , 0 ) , UOp . const ( dtypes . half , 1 ) ) )
wmma = UOp ( Ops . WMMA , dtypes . half . vec ( i ) , ( var , vec , acc ) )
uops = to_uops_list ( [ wmma ] )
self . assertEqual ( uops [ - 1 ] , wmma )
for i in [ 2 , 4 , 8 ] :
vec = UOp ( Ops . VECTORIZE , dtypes . half . vec ( i ) ,
tuple ( UOp . const ( dtypes . half , 1.0 if j == 0 else 0.0 ) for j in range ( i ) ) )
var = UOp ( Ops . DEFINE_VAR , dtypes . half . vec ( i ) , arg = ( f ' tmp { i } ' , UOp . const ( dtypes . half , 0 ) , UOp . const ( dtypes . half , 1 ) ) )
acc = UOp ( Ops . DEFINE_VAR , dtypes . half . vec ( i ) , arg = ( ' acc ' , UOp . const ( dtypes . half , 0 ) , UOp . const ( dtypes . half , 1 ) ) )
wmma = UOp ( Ops . WMMA , dtypes . half . vec ( i ) , ( vec , var , acc ) )
uops = to_uops_list ( [ wmma ] )
self . assertEqual ( uops [ - 1 ] , wmma )
for i in [ 2 , 4 , 8 ] :
var = UOp ( Ops . DEFINE_VAR , dtypes . half . vec ( i ) , arg = ( f ' tmp { i } ' , UOp . const ( dtypes . half , 0 ) , UOp . const ( dtypes . half , 1 ) ) )
vec = UOp ( Ops . VECTORIZE , dtypes . half . vec ( i ) ,
tuple ( UOp . const ( dtypes . half , 1.0 if j == 0 else 0.0 ) for j in range ( i ) ) )
acc = UOp ( Ops . DEFINE_VAR , dtypes . half . vec ( i ) , arg = ( ' acc ' , UOp . const ( dtypes . half , 0 ) , UOp . const ( dtypes . half , 1 ) ) )
wmma = UOp ( Ops . WMMA , dtypes . half . vec ( i ) , ( var , vec , acc ) )
uops = to_uops_list ( [ wmma ] )
self . assertEqual ( uops [ - 1 ] , wmma )
def test_cast_alu_fold ( self ) :
d0 = UOp ( Ops . DEFINE_GLOBAL , dtypes . bool . ptr ( ) , arg = 0 )
d1 = UOp ( Ops . DEFINE_GLOBAL , dtypes . int . ptr ( ) , arg = 1 )
idx = UOp . const ( dtypes . int , 0 )
ld = UOp ( Ops . LOAD , dtypes . int , ( d1 . index ( idx ) , ) )
alu = ( ld < 1 ) . cast ( dtypes . bool )
out = UOp ( Ops . STORE , dtypes . void , ( d0 . index ( idx ) , alu ) )
uops = to_uops_list ( [ out ] )
self . assertEqual ( len ( [ x for x in uops if x . op is Ops . CAST ] ) , 0 )
def test_double_cast_fold ( self ) :
d0 = UOp ( Ops . DEFINE_GLOBAL , dtypes . float . ptr ( ) , arg = 0 )
d1 = UOp ( Ops . DEFINE_GLOBAL , dtypes . int . ptr ( ) , arg = 1 )
idx = UOp . const ( dtypes . int , 0 )
ld = UOp ( Ops . LOAD , dtypes . int , ( d1 . index ( idx ) , ) )
alu = ld . cast ( dtypes . float ) . cast ( dtypes . float )
out = UOp ( Ops . STORE , dtypes . void , ( d0 . index ( idx ) , alu ) )
uops = to_uops_list ( [ out ] )
self . assertEqual ( len ( [ x for x in uops if x . op is Ops . CAST ] ) , 1 )
def test_depth_2_const_fold ( self ) :
v = UOp . variable ( " tmp " , 0 , 1 )
c2 = UOp ( Ops . CONST , dtypes . int , arg = 2 )
c4 = UOp ( Ops . CONST , dtypes . int , arg = 4 )
vc = UOp ( Ops . ADD , dtypes . int , ( v , c2 ) )
out = UOp ( Ops . ADD , dtypes . int , ( vc , c4 ) )
uops = to_uops_list ( [ out ] )
self . assertEqual ( len ( uops ) , 3 )
out = uops [ - 1 ]
self . assertEqual ( out . op , Ops . ADD )
self . assertEqual ( out . src [ 1 ] . op , Ops . CONST )
self . assertEqual ( out . src [ 1 ] . arg , 6 )
def test_bitcast_to_same_dtype_fold ( self ) :
for dt in dtypes . ints + dtypes . floats + ( dtypes . bool , ) :
d0 = UOp ( Ops . DEFINE_GLOBAL , dt . ptr ( ) , arg = 0 )
v = UOp ( Ops . LOAD , dt , ( d0 . index ( UOp . const ( dtypes . int , 0 ) ) , ) )
uops = to_uops_list ( [ v . bitcast ( dt ) ] )
self . assertEqual ( len ( [ x for x in uops if x . op is Ops . BITCAST ] ) , 0 , f " dtype = { dt } " )
def test_out_of_bounds_access ( self ) :
glbl0 = UOp ( Ops . DEFINE_GLOBAL , dtypes . int . ptr ( 16 ) , ( ) , 0 )
ld0 = UOp ( Ops . LOAD , dtypes . int , ( glbl0 . index ( UOp . const ( dtypes . int , 42 ) ) , ) )
with self . assertRaises ( RuntimeError ) : to_uops_list ( [ ld0 ] )
def test_fold_gated_load ( self ) :
glbl0 = UOp ( Ops . DEFINE_GLOBAL , dtypes . int . ptr ( ) , ( ) , 0 )
glbl1 = UOp ( Ops . DEFINE_GLOBAL , dtypes . int . ptr ( ) , ( ) , 1 )
glbl2 = UOp ( Ops . DEFINE_GLOBAL , dtypes . int . ptr ( ) , ( ) , 2 )
idx = UOp . const ( dtypes . int , 0 )
ld0 = UOp ( Ops . LOAD , dtypes . int , ( glbl1 . index ( idx , UOp . const ( dtypes . bool , False ) ) , ) )
ld1 = UOp ( Ops . LOAD , dtypes . int , ( glbl2 . index ( idx , UOp . const ( dtypes . bool , True ) ) , ) )
uops = to_uops_list ( [ UOp ( Ops . STORE , dtypes . void , ( glbl0 . index ( idx ) , ld1 + ld0 ) ) ] )
ld0 = uops [ - 1 ] . src [ - 1 ]
# the gate and invalid value are deleted from ld1
self . assertEqual ( ld0 , UOp . load ( glbl2 . index ( idx ) , dtype = dtypes . int ) )
def test_fold_gated_load_local ( self ) :
glbl0 = UOp ( Ops . DEFINE_GLOBAL , dtypes . int . ptr ( ) , ( ) , 0 )
smem = UOp ( Ops . DEFINE_LOCAL , dtypes . int . ptr ( size = 18 , local = True ) , ( ) , " temp " )
lidx = UOp ( Ops . SPECIAL , dtypes . int , ( ) , ( " lidx0 " , 16 ) )
st = UOp ( Ops . STORE , dtypes . void , ( smem . index ( lidx ) , UOp . load ( glbl0 . index ( lidx ) , dtype = dtypes . int ) ) )
barrier = UOp ( Ops . BARRIER , dtypes . void , ( st , ) )
ld0 = UOp ( Ops . LOAD , dtypes . int , ( smem . index ( lidx + 1 , UOp . const ( dtypes . bool , False ) ) , barrier ) )
ld1 = UOp ( Ops . LOAD , dtypes . int , ( smem . index ( lidx + 2 , UOp . const ( dtypes . bool , True ) ) , barrier ) )
uops = to_uops_list ( [ UOp ( Ops . STORE , dtypes . void , ( glbl0 . index ( lidx ) , ld1 + ld0 ) ) ] )
ld0 = uops [ - 1 ] . src [ - 1 ]
# the gate and invalid value are deleted from ld1
self . assertEqual ( ld0 . src [ 0 ] , smem . index ( lidx + 2 ) )
def test_fold_gated_store ( self ) :
glbl = UOp ( Ops . DEFINE_GLOBAL , dtypes . int . ptr ( ) , ( ) , 0 )
idx0 = UOp . const ( dtypes . int , 0 )
idx1 = UOp . const ( dtypes . int , 0 )
val = UOp . const ( dtypes . int , 42 )
st0 = UOp ( Ops . STORE , dtypes . void , ( glbl . index ( idx0 , UOp . const ( dtypes . bool , False ) ) , val ) )
st1 = UOp ( Ops . STORE , dtypes . void , ( glbl . index ( idx1 , UOp . const ( dtypes . bool , True ) ) , val ) )
uops = to_uops_list ( [ st0 , st1 ] )
# only the second store happens
self . assertEqual ( len ( uops ) , 5 )
self . assertEqual ( uops [ - 1 ] , UOp . store ( glbl . index ( idx1 ) , val ) )
@unittest . skip ( " this is a uop type error " )
def test_asserts_bad_gate ( self ) :
glbl0 = UOp ( Ops . DEFINE_GLOBAL , dtypes . int . ptr ( ) , ( ) , 0 )
idx = UOp . const ( dtypes . int , 0 )
bad_gate = UOp . const ( dtypes . int , 1 )
with self . assertRaises ( AssertionError ) : to_uops_list ( [ UOp ( Ops . STORE , dtypes . void , ( glbl0 , idx , UOp . const ( dtypes . int , 42 ) , bad_gate ) ) ] )
def test_switched_range_order ( self ) :
glbl = UOp ( Ops . DEFINE_GLOBAL , dtypes . int . ptr ( ) , ( ) , 0 )
c0 = UOp . const ( dtypes . int , 0 )
c2 = UOp . const ( dtypes . int , 2 )
cf = UOp . const ( dtypes . float , 0.0 )
r1 = UOp ( Ops . RANGE , dtypes . int , ( c0 , c2 ) , 0 )
r2 = UOp ( Ops . RANGE , dtypes . int , ( c0 , c2 ) , 1 )
alu = UOp ( Ops . MUL , dtypes . int , ( r2 , r1 ) )
store = UOp ( Ops . STORE , dtypes . void , ( glbl . index ( alu ) , cf ) )
uops = to_uops_list ( [ store ] )
ranges = [ x for x in uops if x . op is Ops . RANGE ]
endranges = [ x for x in uops if x . op is Ops . ENDRANGE ]
# ranges are closed in the right order
self . assertEqual ( endranges [ - 1 ] . src [ 0 ] , ranges [ 0 ] )
@track_rewrites ( )
def expander_rewrite ( sink ) : return graph_rewrite ( sink , sym + expander )
@track_rewrites ( )
def float4_rewrite ( sink ) : return full_graph_rewrite ( sink , Renderer ( ) )
class TestExpander ( unittest . TestCase ) :
def test_expand_add_broadcast ( self ) :
e1 = UOp ( Ops . UNROLL , dtypes . int , ( UOp . const ( dtypes . int . vec ( 4 ) , tuple ( x for x in range ( 4 ) ) ) , ) , ( ( 1 , 4 ) , ) )
sink = expander_rewrite ( e1 + 3 )
assert sink . op is Ops . UNROLL and len ( sink . src [ 0 ] . arg ) == 4
self . assertTupleEqual ( sink . src [ 0 ] . arg , ( 3 , 4 , 5 , 6 ) )
def test_contract_simple ( self ) :
e1 = UOp ( Ops . UNROLL , dtypes . int , ( UOp . const ( dtypes . int . vec ( 4 ) , tuple ( x for x in range ( 4 ) ) ) , ) , ( ( 1 , 4 ) , ) )
con = UOp ( Ops . CONTRACT , dtypes . int . vec ( 4 ) , ( e1 , ) , ( ( 1 , 4 ) , ) )
sink = expander_rewrite ( con )
self . assertEqual ( sink . op , Ops . VCONST )
self . assertTupleEqual ( sink . arg , ( 0 , 1 , 2 , 3 ) )
def test_contract_axis_1 ( self ) :
e1 = UOp ( Ops . UNROLL , dtypes . int , ( UOp . const ( dtypes . int . vec ( 16 ) , tuple ( x for x in range ( 16 ) ) ) , ) , ( ( 1 , 4 ) , ( 2 , 4 ) ) )
con = UOp ( Ops . CONTRACT , dtypes . int . vec ( 4 ) , ( e1 , ) , ( ( 1 , 4 ) , ) )
sink = expander_rewrite ( con )
assert sink . op is Ops . UNROLL and len ( sink . src [ 0 ] . arg ) == 16 and sink . arg == ( ( 2 , 4 ) , )
assert sink . src [ 0 ] . op is Ops . VCONST
self . assertTupleEqual ( sink . src [ 0 ] . arg [ 0 : 4 ] , ( 0 , 4 , 8 , 12 ) )
self . assertTupleEqual ( sink . src [ 0 ] . arg [ 12 : ] , ( 3 , 7 , 11 , 15 ) )
def test_contract_axis_2 ( self ) :
e1 = UOp ( Ops . UNROLL , dtypes . int , ( UOp . const ( dtypes . int . vec ( 16 ) , tuple ( x for x in range ( 16 ) ) ) , ) , ( ( 1 , 4 ) , ( 2 , 4 ) ) )
con = UOp ( Ops . CONTRACT , dtypes . int . vec ( 4 ) , ( e1 , ) , ( ( 2 , 4 ) , ) )
sink = expander_rewrite ( con )
assert sink . op is Ops . UNROLL and len ( sink . src [ 0 ] . arg ) == 16 and sink . arg == ( ( 1 , 4 ) , )
assert sink . src [ 0 ] . op is Ops . VCONST
self . assertTupleEqual ( sink . src [ 0 ] . arg [ 0 : 4 ] , ( 0 , 1 , 2 , 3 ) )
self . assertTupleEqual ( sink . src [ 0 ] . arg [ 12 : ] , ( 12 , 13 , 14 , 15 ) )
def test_contract_axis_2_big ( self ) :
e1 = UOp ( Ops . UNROLL , dtypes . int , ( UOp . const ( dtypes . int . vec ( 16 ) , tuple ( x for x in range ( 16 ) ) ) , ) , ( ( 1 , 2 ) , ( 2 , 2 ) , ( 3 , 2 ) , ( 4 , 2 ) ) )
con = UOp ( Ops . CONTRACT , dtypes . int . vec ( 2 ) , ( e1 , ) , ( ( 2 , 2 ) , ) )
sink = expander_rewrite ( con )
assert sink . op is Ops . UNROLL and sink . arg == ( ( 1 , 2 ) , ( 3 , 2 ) , ( 4 , 2 ) )
self . assertTupleEqual ( sink . src [ 0 ] . arg [ 0 : 2 ] , ( 0 , 4 ) )
self . assertTupleEqual ( sink . src [ 0 ] . arg [ 12 : 14 ] , ( 10 , 14 ) )
def test_contract_multi_axis ( self ) :
e1 = UOp ( Ops . UNROLL , dtypes . int , ( UOp . const ( dtypes . int . vec ( 16 ) , tuple ( x for x in range ( 16 ) ) ) , ) , ( ( 1 , 2 ) , ( 2 , 2 ) , ( 3 , 2 ) , ( 4 , 2 ) ) )
sink = expander_rewrite ( UOp ( Ops . CONTRACT , dtypes . int . vec ( 4 ) , ( e1 , ) , ( ( 3 , 2 ) , ( 2 , 2 ) ) ) )
assert sink . op is Ops . UNROLL and sink . arg == ( ( 1 , 2 ) , ( 4 , 2 ) )
self . assertTupleEqual ( sink . src [ 0 ] . arg [ 0 : 4 ] , ( 0 , 4 , 2 , 6 ) )
sink = expander_rewrite ( UOp ( Ops . CONTRACT , dtypes . int . vec ( 4 ) , ( e1 , ) , ( ( 2 , 2 ) , ( 3 , 2 ) ) ) )
assert sink . op is Ops . UNROLL and sink . arg == ( ( 1 , 2 ) , ( 4 , 2 ) )
self . assertTupleEqual ( sink . src [ 0 ] . arg [ 0 : 4 ] , ( 0 , 2 , 4 , 6 ) )
def test_contract_mid ( self ) :
e1 = UOp ( Ops . UNROLL , dtypes . int , ( UOp . const ( dtypes . int . vec ( 8 ) , tuple ( x for x in range ( 8 ) ) ) , ) , ( ( 1 , 2 ) , ( 2 , 2 ) , ( 3 , 2 ) ) )
con = UOp ( Ops . CONTRACT , dtypes . int . vec ( 2 ) , ( e1 , ) , ( ( 2 , 2 ) , ) )
sink = expander_rewrite ( con )
assert sink . op is Ops . UNROLL and sink . arg == ( ( 1 , 2 ) , ( 3 , 2 ) )
assert sink . src [ 0 ] . op is Ops . VCONST and len ( sink . src [ 0 ] . arg ) == 8
self . assertTupleEqual ( sink . src [ 0 ] . arg , ( 0 , 2 , 1 , 3 , 4 , 6 , 5 , 7 ) )
def test_contract_no_expand ( self ) :
e1 = UOp ( Ops . DEFINE_VAR , dtypes . int )
con = UOp ( Ops . CONTRACT , dtypes . int . vec ( 2 ) , ( e1 , ) , ( ( 2 , 2 ) , ) )
sink = expander_rewrite ( con )
assert sink . op is Ops . VECTORIZE and len ( sink . src ) == 2
assert sink . src [ 0 ] == sink . src [ 1 ]
def test_contract_half_expand ( self ) :
e1 = UOp ( Ops . UNROLL , dtypes . int , ( UOp . const ( dtypes . int . vec ( 4 ) , tuple ( x for x in range ( 4 ) ) ) , ) , ( ( 1 , 4 ) , ) )
con = UOp ( Ops . CONTRACT , dtypes . int . vec ( 8 ) , ( e1 , ) , ( ( 1 , 4 ) , ( 2 , 2 ) ) )
sink = expander_rewrite ( con )
assert sink . op is Ops . VCONST and len ( sink . arg ) == 8
assert sink . arg [ 0 ] == sink . arg [ 1 ]
assert sink . arg [ 0 ] != sink . arg [ 2 ]
assert sink . arg [ 6 ] == sink . arg [ 7 ]
def test_expand_same_axis ( self ) :
e1 = UOp ( Ops . UNROLL , dtypes . int , ( UOp . const ( dtypes . int . vec ( 4 ) , tuple ( x for x in range ( 4 ) ) ) , ) , ( ( 1 , 4 ) , ) )
e2 = UOp ( Ops . UNROLL , dtypes . int , ( UOp . const ( dtypes . int . vec ( 4 ) , tuple ( 4 * x for x in range ( 4 ) ) ) , ) , ( ( 1 , 4 ) , ) )
sink = expander_rewrite ( e1 + e2 )
self . assertEqual ( sink . op , Ops . UNROLL )
self . assertEqual ( sink . src [ 0 ] . op , Ops . VCONST )
self . assertTupleEqual ( sink . src [ 0 ] . arg , ( 0 , 5 , 10 , 15 ) )
def test_expand_different_axis ( self , flip = False ) :
e1 = UOp ( Ops . UNROLL , dtypes . int , ( UOp . const ( dtypes . int . vec ( 4 ) , tuple ( 4 * x for x in range ( 4 ) ) ) , ) , ( ( 1 , 4 ) , ) )
e2 = UOp ( Ops . UNROLL , dtypes . int , ( UOp . const ( dtypes . int . vec ( 4 ) , tuple ( x for x in range ( 4 ) ) ) , ) , ( ( 2 , 4 ) , ) )
sink = expander_rewrite ( ( e2 + e1 ) if flip else ( e1 + e2 ) )
assert sink . op is Ops . UNROLL and len ( sink . src [ 0 ] . arg ) == 16
assert sink . arg == ( ( 1 , 4 ) , ( 2 , 4 ) )
self . assertTupleEqual ( sink . src [ 0 ] . arg , ( 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 ) )
def test_expand_different_axis_flip ( self ) : self . test_expand_different_axis ( True )
@unittest . skip ( " no longer supported " )
def test_reduce_known_axis ( self ) :
e1 = UOp ( Ops . UNROLL , dtypes . int , tuple ( UOp . const ( dtypes . int , x ) for x in range ( 4 ) ) , ( ( 1 , 4 ) , ) )
sink = UOp ( Ops . REDUCE , dtypes . int , ( 3 * e1 , e1 ) , Ops . ADD )
sink = expander_rewrite ( sink )
assert sink . op is Ops . CONST
self . assertEqual ( sink . arg , 3 * ( 0 + 1 + 2 + 3 ) )
@unittest . skip ( " no longer supported " )
def test_reduce_const ( self ) :
e1 = UOp ( Ops . UNROLL , dtypes . int , tuple ( UOp . const ( dtypes . int , x ) for x in range ( 4 ) ) , ( ( 1 , 4 ) , ) )
sink = UOp ( Ops . REDUCE , dtypes . int , ( UOp . const ( dtypes . int , 3 ) , e1 ) , Ops . ADD )
sink = expander_rewrite ( sink )
assert sink . op is Ops . CONST
self . assertEqual ( sink . arg , 3 * 4 )
@unittest . skip ( " no longer supported " )
def test_double_expand ( self ) :
e1 = UOp ( Ops . UNROLL , dtypes . int , tuple ( UOp . const ( dtypes . int , x ) for x in range ( 4 ) ) , ( ( 2 , 4 ) , ) )
e2 = UOp ( Ops . UNROLL , dtypes . int , tuple ( UOp . const ( dtypes . int , 4 + x ) for x in range ( 4 ) ) , ( ( 2 , 4 ) , ) )
e = UOp ( Ops . UNROLL , dtypes . int , ( e1 , e2 ) , ( ( 1 , 2 ) , ) )
sink = expander_rewrite ( e )
assert sink . op is Ops . UNROLL and len ( sink . src ) == 8
assert sink . arg == ( ( 1 , 2 ) , ( 2 , 4 ) )
self . assertListEqual ( [ x . arg for x in sink . src ] , [ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ] )
@unittest . skip ( " no longer supported " )
def test_double_expand_reverse ( self ) :
e1 = UOp ( Ops . UNROLL , dtypes . int , tuple ( UOp . const ( dtypes . int , x ) for x in range ( 4 ) ) , ( ( 1 , 4 ) , ) )
e2 = UOp ( Ops . UNROLL , dtypes . int , tuple ( UOp . const ( dtypes . int , 4 + x ) for x in range ( 4 ) ) , ( ( 1 , 4 ) , ) )
e = UOp ( Ops . UNROLL , dtypes . int , ( e1 , e2 ) , ( ( 2 , 2 ) , ) )
sink = expander_rewrite ( e )
assert sink . op is Ops . UNROLL and len ( sink . src ) == 8
assert sink . arg == ( ( 1 , 4 ) , ( 2 , 2 ) )
self . assertListEqual ( [ x . arg for x in sink . src ] , [ 0 , 4 , 1 , 5 , 2 , 6 , 3 , 7 ] )
@unittest . skip ( " no longer supported " )
def test_double_expand_middle ( self ) :
e1 = UOp ( Ops . UNROLL , dtypes . int , tuple ( UOp . const ( dtypes . int , x ) for x in range ( 4 ) ) , ( ( 1 , 2 ) , ( 3 , 2 ) ) )
e2 = UOp ( Ops . UNROLL , dtypes . int , tuple ( UOp . const ( dtypes . int , 4 + x ) for x in range ( 4 ) ) , ( ( 1 , 2 ) , ( 3 , 2 ) ) )
e = UOp ( Ops . UNROLL , dtypes . int , ( e1 , e2 ) , ( ( 2 , 2 ) , ) )
sink = expander_rewrite ( e )
assert sink . op is Ops . UNROLL and len ( sink . src ) == 8
assert sink . arg == ( ( 1 , 2 ) , ( 2 , 2 ) , ( 3 , 2 ) )
self . assertListEqual ( [ x . arg for x in sink . src ] , [ 0 , 1 , 4 , 5 , 2 , 3 , 6 , 7 ] )
# does this need to work?
@unittest . expectedFailure
@unittest . skip
def test_reduce_different_axis ( self ) :
e1 = UOp ( Ops . UNROLL , dtypes . int , tuple ( UOp . const ( dtypes . int , x ) for x in range ( 4 ) ) , ( ( 1 , 4 ) , ) )
e2 = UOp ( Ops . UNROLL , dtypes . int , tuple ( UOp . const ( dtypes . int , x ) for x in range ( 4 ) ) , ( ( 2 , 4 ) , ) )
sink = UOp ( Ops . REDUCE , dtypes . int , ( e1 , e2 ) , Ops . ADD )
sink = expander_rewrite ( sink )
print ( sink )
class TestIFUOps ( unittest . TestCase ) :
def test_create_ifs ( self ) :
gbuf = UOp ( Ops . DEFINE_GLOBAL , dtypes . float . ptr ( ) , ( ) , 0 )
sbuf = UOp ( Ops . DEFINE_LOCAL , dtypes . float . ptr ( size = 4 , local = True ) , ( ) , " smem " )
valid = UOp ( Ops . SPECIAL , dtypes . int , ( ) , ( " gidx0 " , 10 ) ) < 5
lidx = UOp ( Ops . SPECIAL , dtypes . int , ( ) , ( " lidx0 " , 4 ) )
gate = valid & ( lidx . ne ( 2 ) )
idx = UOp . const ( dtypes . int , 0 )
st = UOp ( Ops . STORE , dtypes . void , ( sbuf . index ( idx ) , UOp . const ( dtypes . float , 42 ) ) )
barrier = UOp ( Ops . BARRIER , dtypes . void , ( st , ) )
lbuf = UOp ( Ops . LOAD , dtypes . float , ( sbuf . index ( UOp . const ( dtypes . int , 0 ) ) , barrier ) )
store = UOp ( Ops . STORE , dtypes . void , ( gbuf . index ( UOp . const ( dtypes . int , 0 ) , gate ) , lbuf ) )
sink = UOp ( Ops . SINK , dtypes . void , ( store , ) )
sink = full_graph_rewrite ( expand_rewrite ( sink ) )
if_uops = [ u for u in sink . toposort if u . op is Ops . IF ]
self . assertEqual ( len ( if_uops ) , 1 )
self . assertEqual ( if_uops [ 0 ] . src [ 0 ] , gate )
for st in sink . src :
self . assertEqual ( len ( st . src ) , 2 )
def test_expand_ifs_one_gate ( self ) :
gbuf = UOp ( Ops . DEFINE_GLOBAL , dtypes . float . ptr ( ) , ( ) , 0 )
sbuf = UOp ( Ops . DEFINE_LOCAL , dtypes . float . ptr ( size = 16 , local = True ) , ( ) , " smem " )
valid = UOp ( Ops . SPECIAL , dtypes . int , ( ) , ( " gidx0 " , 4 ) ) < 1
lidx = UOp ( Ops . SPECIAL , dtypes . int , ( ) , ( " lidx0 " , 16 ) )
gate = valid & ( lidx . ne ( 2 ) )
st = UOp ( Ops . STORE , dtypes . void , ( sbuf , lidx , UOp . const ( dtypes . float , 42 ) ) )
barrier = UOp ( Ops . BARRIER , dtypes . void , ( st , ) )
lbufs = [ UOp ( Ops . LOAD , dtypes . float , ( sbuf . index ( UOp . const ( dtypes . int , i ) ) , barrier ) ) for i in range ( 4 ) ]
stores = [ UOp ( Ops . STORE , dtypes . void , ( gbuf . index ( UOp . const ( dtypes . int , i ) , gate ) , lbufs [ i ] ) ) for i in range ( 4 ) ]
sink = UOp ( Ops . SINK , dtypes . void , tuple ( stores ) )
sink = full_graph_rewrite ( expand_rewrite ( sink ) )
if_uops = [ u for u in sink . toposort if u . op is Ops . IF ]
self . assertEqual ( len ( if_uops ) , 1 )
self . assertEqual ( if_uops [ 0 ] . src [ 0 ] , gate )
for st in sink . src :
self . assertEqual ( len ( st . src ) , 2 )
# this will be fixed with the merge gated stores bounty
@unittest . expectedFailure
def test_expand_ifs_dumb ( self ) :
buf = UOp ( Ops . DEFINE_GLOBAL , dtypes . float . ptr ( ) , ( ) , 0 )
valid = UOp ( Ops . SPECIAL , dtypes . int , ( ) , ( " gidx0 " , 10 ) ) < 5
lidx = UOp ( Ops . SPECIAL , dtypes . int , ( ) , ( " lidx0 " , 4 ) )
gate = valid & ( lidx . ne ( 2 ) )
stores = [ UOp ( Ops . STORE , dtypes . void , ( buf , UOp . const ( dtypes . int , i ) , UOp . const ( dtypes . float , i ) , gate ) ) for i in range ( 4 ) ]
sink = UOp ( Ops . SINK , dtypes . void , tuple ( stores ) )
sink = full_graph_rewrite ( sink )
if_uops = [ u for u in sink . toposort if u . op is Ops . IF ]
self . assertEqual ( len ( if_uops ) , 1 )
self . assertEqual ( if_uops [ 0 ] . src [ 0 ] , gate )
for st in sink . src :
self . assertEqual ( len ( st . src ) , 2 )
if __name__ == ' __main__ ' :
unittest . main ( verbosity = 2 )