You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
39 lines
1.0 KiB
39 lines
1.0 KiB
1 month ago
|
import unittest
|
||
|
import multiprocessing.shared_memory as shared_memory
|
||
|
from tinygrad.helpers import CI
|
||
|
from tinygrad.tensor import Tensor, Device
|
||
|
import numpy as np
|
||
|
|
||
|
class TestRawShmBuffer(unittest.TestCase):
|
||
|
def test_e2e(self):
|
||
|
t = Tensor.randn(2, 2, 2).realize()
|
||
|
|
||
|
# copy to shm
|
||
|
shm_name = (s := shared_memory.SharedMemory(create=True, size=t.nbytes())).name
|
||
|
s.close()
|
||
|
t_shm = t.to(f"disk:shm:{shm_name}").realize()
|
||
|
|
||
|
# copy from shm
|
||
|
t2 = t_shm.to(Device.DEFAULT).realize()
|
||
|
|
||
|
assert np.allclose(t.numpy(), t2.numpy())
|
||
|
s.unlink()
|
||
|
|
||
|
@unittest.skipIf(CI, "CI doesn't like big shared memory")
|
||
|
def test_e2e_big(self):
|
||
|
t = Tensor.randn(2048, 2048, 8).realize()
|
||
|
|
||
|
# copy to shm
|
||
|
shm_name = (s := shared_memory.SharedMemory(create=True, size=t.nbytes())).name
|
||
|
s.close()
|
||
|
t_shm = t.to(f"disk:shm:{shm_name}").realize()
|
||
|
|
||
|
# copy from shm
|
||
|
t2 = t_shm.to(Device.DEFAULT).realize()
|
||
|
|
||
|
assert np.allclose(t.numpy(), t2.numpy())
|
||
|
s.unlink()
|
||
|
|
||
|
if __name__ == "__main__":
|
||
|
unittest.main()
|