You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
32 lines
912 B
32 lines
912 B
import unittest
|
|
from tinygrad.tensor import Tensor
|
|
|
|
class TestMaskedShapeTracker(unittest.TestCase):
  """Elementwise ops on zero-padded tensors must keep the padded region at zero.

  In every case below, `Tensor.pad(((0, k),))` extends a 1-D tensor by k
  trailing elements, and the expected results have zeros in those positions.
  """

  def _assert_result(self, c, shape, expected):
    """Assert that `c` has `shape` and that its realized data equals `expected`.

    Uses unittest assertions rather than bare `assert` so the checks still run
    under `python -O` and report both sides on failure.
    """
    self.assertEqual(c.shape, shape)
    # NOTE(review): a stronger check, `c.lazydata.st.views[0].mask is not None`,
    # is disabled for these tests — TODO confirm whether it should be restored.
    self.assertEqual(c.data().tolist(), expected)

  def test_mul_masked(self):
    # Only `b` is padded: [1, 1] becomes length 5 with zeros at the end.
    a = Tensor([1, 1, 1, 1, 1])
    b = Tensor([1, 1]).pad(((0, 3),))
    self._assert_result(a * b, a.shape, [1.0, 1.0, 0.0, 0.0, 0.0])

  def test_mul_both_masked(self):
    # Both operands are padded identically.
    a = Tensor([1, 1]).pad(((0, 3),))
    b = Tensor([1, 1]).pad(((0, 3),))
    self._assert_result(a * b, a.shape, [1.0, 1.0, 0.0, 0.0, 0.0])

  def test_add_masked(self):
    # Both operands are padded; 0 + 0 in the padded region must stay 0.
    a = Tensor([1, 1]).pad(((0, 2),))
    b = Tensor([1, 1]).pad(((0, 2),))
    self._assert_result(a + b, a.shape, [2.0, 2.0, 0.0, 0.0])
|
|
|
|
if __name__ == "__main__":
  # Script entry point: discover and run the TestCase classes in this module.
  unittest.main(module="__main__")
|
|
|