You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
141 lines
5.1 KiB
141 lines
5.1 KiB
import unittest, tarfile, io, os, pathlib, tempfile
|
|
import numpy as np
|
|
from tinygrad import Tensor
|
|
from tinygrad.nn.state import tar_extract
|
|
|
|
class TestTarExtractFile(unittest.TestCase):
|
|
def setUp(self):
|
|
self.test_dir = tempfile.mkdtemp()
|
|
self.test_files = {
|
|
'file1.txt': b'Hello, World!',
|
|
'file2.bin': b'\x00\x01\x02\x03\x04',
|
|
'empty_file.txt': b''
|
|
}
|
|
self.tar_path = os.path.join(self.test_dir, 'test.tar')
|
|
with tarfile.open(self.tar_path, 'w') as tar:
|
|
for filename, content in self.test_files.items():
|
|
file_path = os.path.join(self.test_dir, filename)
|
|
with open(file_path, 'wb') as f:
|
|
f.write(content)
|
|
tar.add(file_path, arcname=filename)
|
|
|
|
# Create invalid tar file
|
|
self.invalid_tar_path = os.path.join(self.test_dir, 'invalid.tar')
|
|
with open(self.invalid_tar_path, 'wb') as f:
|
|
f.write(b'This is not a valid tar file')
|
|
|
|
def tearDown(self):
|
|
for filename in self.test_files:
|
|
os.remove(os.path.join(self.test_dir, filename))
|
|
os.remove(self.tar_path)
|
|
os.remove(self.invalid_tar_path)
|
|
os.rmdir(self.test_dir)
|
|
|
|
def test_tar_extract_returns_dict(self):
|
|
result = tar_extract(self.tar_path)
|
|
self.assertIsInstance(result, dict)
|
|
|
|
def test_tar_extract_correct_keys(self):
|
|
result = tar_extract(self.tar_path)
|
|
self.assertEqual(set(result.keys()), set(self.test_files.keys()))
|
|
|
|
def test_tar_extract_content_size(self):
|
|
result = tar_extract(self.tar_path)
|
|
for filename, content in self.test_files.items():
|
|
self.assertEqual(len(result[filename]), len(content))
|
|
|
|
def test_tar_extract_content_values(self):
|
|
result = tar_extract(self.tar_path)
|
|
for filename, content in self.test_files.items():
|
|
np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8))
|
|
|
|
def test_tar_extract_empty_file(self):
|
|
result = tar_extract(self.tar_path)
|
|
self.assertEqual(len(result['empty_file.txt']), 0)
|
|
|
|
def test_tar_extract_non_existent_file(self):
|
|
with self.assertRaises(FileNotFoundError):
|
|
tar_extract('non_existent_file.tar')
|
|
|
|
def test_tar_extract_invalid_file(self):
|
|
with self.assertRaises(tarfile.ReadError):
|
|
tar_extract(self.invalid_tar_path)
|
|
|
|
class TestTarExtractPAX(unittest.TestCase):
|
|
tar_format = tarfile.PAX_FORMAT
|
|
max_link_len = 1000_000
|
|
test_files = {
|
|
'a/file1.txt': b'Hello, World!',
|
|
'a/b/file2.bin': b'\x00\x01\x02\x03\x04',
|
|
'empty_file.txt': b'',
|
|
'512file': b'a' * 512,
|
|
'long_file': b'some data' * 100,
|
|
'very' * 15 + '/' + 'very' * 15 + '_long_filename.txt': b'Hello, World!!',
|
|
'very' * 200 + '_long_filename.txt': b'Hello, World!!!',
|
|
}
|
|
|
|
def create_tar_tensor(self):
|
|
fobj = io.BytesIO()
|
|
test_dirs = set(os.path.dirname(k) for k in self.test_files.keys()).difference({ '' })
|
|
with tarfile.open(fileobj=fobj, mode='w', format=self.tar_format) as tar:
|
|
for dirname in test_dirs:
|
|
dir_info = tarfile.TarInfo(name=dirname)
|
|
dir_info.type = tarfile.DIRTYPE
|
|
tar.addfile(dir_info)
|
|
|
|
for filename, content in self.test_files.items():
|
|
file_info = tarfile.TarInfo(name=filename)
|
|
file_info.size = len(content)
|
|
tar.addfile(file_info, io.BytesIO(content))
|
|
|
|
if len(filename) < self.max_link_len:
|
|
link_info = tarfile.TarInfo(name=filename + '.lnk')
|
|
link_info.type = tarfile.SYMTYPE
|
|
link_info.linkname = filename
|
|
tar.addfile(link_info)
|
|
return Tensor(fobj.getvalue())
|
|
|
|
def test_tar_extract_returns_dict(self):
|
|
result = tar_extract(self.create_tar_tensor())
|
|
self.assertIsInstance(result, dict)
|
|
|
|
def test_tar_extract_correct_keys(self):
|
|
result = tar_extract(self.create_tar_tensor())
|
|
self.assertEqual(set(result.keys()), set(self.test_files.keys()))
|
|
|
|
def test_tar_extract_content_size(self):
|
|
result = tar_extract(self.create_tar_tensor())
|
|
for filename, content in self.test_files.items():
|
|
self.assertEqual(len(result[filename]), len(content))
|
|
|
|
def test_tar_extract_content_values(self):
|
|
result = tar_extract(self.create_tar_tensor())
|
|
for filename, content in self.test_files.items():
|
|
np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8))
|
|
|
|
def test_tar_extract_empty_file(self):
|
|
result = tar_extract(self.create_tar_tensor())
|
|
self.assertEqual(len(result['empty_file.txt']), 0)
|
|
|
|
def test_tar_extract_non_existent_file(self):
|
|
with self.assertRaises(FileNotFoundError):
|
|
tar_extract(Tensor(pathlib.Path('non_existent_file.tar')))
|
|
|
|
def test_tar_extract_invalid_file(self):
|
|
with self.assertRaises(tarfile.ReadError):
|
|
tar_extract(Tensor(b'This is not a valid tar file'))
|
|
|
|
def test_tar_extract_invalid_file_long(self):
|
|
with self.assertRaises(tarfile.ReadError):
|
|
tar_extract(Tensor(b'This is not a valid tar file'*100))
|
|
|
|
class TestTarExtractUSTAR(TestTarExtractPAX):
  """Run the inherited tar_extract tests against archives written in USTAR format."""
  tar_format = tarfile.USTAR_FORMAT
  # Symlink entries are only generated for file names shorter than this.
  max_link_len = 100
  # Reuse the parent's fixtures, keeping only names shorter than 256 characters.
  test_files = {name: data for name, data in TestTarExtractPAX.test_files.items() if len(name) < 256}
|
|
|
|
class TestTarExtractGNU(TestTarExtractPAX):
  """Run the inherited tar_extract tests against archives written in GNU tar format."""
  # Only the on-disk format changes; fixtures and link limit come from the parent.
  tar_format = tarfile.GNU_FORMAT
|
|
|
|
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
  unittest.main()
|
|
|