You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
54 lines
1.4 KiB
54 lines
1.4 KiB
2 days ago
|
import unittest
|
||
|
from tinygrad.helpers import DEBUG
|
||
|
from tinygrad.dtype import dtypes
|
||
|
from tinygrad.ops import UPat, track_rewrites, GroupOp, Ops
|
||
|
from tinygrad.upat import _get_code, upat_compile
|
||
|
import dis
|
||
|
|
||
|
@track_rewrites()
def do_compile(up):
    """Compile the pattern ``up`` and print what was generated for it.

    Prints a banner with the pattern, asks ``_get_code(up, False)`` for its
    result, builds a matcher with ``upat_compile`` using a callback that
    accepts any keyword arguments and returns ``None``, then prints the first
    element of the ``_get_code`` result. When ``DEBUG >= 2`` the compiled
    matcher is also disassembled with ``dis.dis``.

    Returns the first element of the ``_get_code`` result (presumably the
    generated matcher source -- confirm against ``tinygrad.upat``).
    """
    print("\n***** COMPILE", up)
    generated = _get_code(up, False)
    matcher = upat_compile(up, lambda **kwargs: None)
    # Only the first element of the generated result is shown and returned.
    code_text = generated[0]
    print(code_text)
    if DEBUG >= 2:
        dis.dis(matcher)
    return code_text
|
||
|
|
||
|
class TestUPatCompile(unittest.TestCase):
    """Run a handful of UPat shapes through ``do_compile``.

    The tests make no assertions on the generated code; each one passes as
    long as ``do_compile`` finishes without raising.
    """

    def test_double(self):
        # x*c0 + x*c1: the variable "x" appears in both products.
        pattern = UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1")
        do_compile(pattern)

    def test_single(self):
        # Sum of two differently named variables.
        pattern = UPat.var("x") + UPat.var("y")
        do_compile(pattern)

    def test_xpx(self):
        # The same variable name on both sides of the add.
        pattern = UPat.var("x") + UPat.var("x")
        do_compile(pattern)

    def test_xp0(self):
        # A variable plus the Python literal 0.
        pattern = UPat.var("x") + 0
        do_compile(pattern)

    def test_bool(self):
        # Product of two variables that both carry dtype=dtypes.bool.
        pattern = UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool)
        do_compile(pattern)

    def test_single_c(self):
        # (x + y) * c: an add nested inside a multiply.
        pattern = (UPat.var("x") + UPat.var("y")) * UPat.var("c")
        do_compile(pattern)

    def test_const_folding(self):
        # Any ALU op except THREEFRY, named "a", with src given as a VCONST/CONST pattern.
        pattern = UPat(GroupOp.ALU-{Ops.THREEFRY}, name="a", src=UPat((Ops.VCONST, Ops.CONST)))
        do_compile(pattern)

    @unittest.skip("fix this")
    def test_range_named(self):
        # this should be one src, but this should also still work
        pattern = UPat(Ops.CAST, dtypes.float, UPat.var("x", dtypes.bfloat16))
        do_compile(pattern)
|
||
|
|
||
|
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|