openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

142 lines
5.1 KiB

import unittest, tarfile, io, os, pathlib, tempfile
import numpy as np
from tinygrad import Tensor
from tinygrad.nn.state import tar_extract
class TestTarExtractFile(unittest.TestCase):
  """Tests for `tar_extract` when it is given the path of a tar archive on disk."""
  def setUp(self):
    # TemporaryDirectory + addCleanup removes the whole directory even if a later line of
    # setUp raises (tearDown would not run then), and it tolerates any extra files in it.
    tmp = tempfile.TemporaryDirectory()
    self.addCleanup(tmp.cleanup)
    self.test_dir = tmp.name
    self.test_files = {
      'file1.txt': b'Hello, World!',
      'file2.bin': b'\x00\x01\x02\x03\x04',
      'empty_file.txt': b''
    }
    # Write each payload to disk, then pack it into test.tar under its bare filename.
    self.tar_path = os.path.join(self.test_dir, 'test.tar')
    with tarfile.open(self.tar_path, 'w') as tar:
      for filename, content in self.test_files.items():
        file_path = os.path.join(self.test_dir, filename)
        with open(file_path, 'wb') as f:
          f.write(content)
        tar.add(file_path, arcname=filename)
    # Create invalid tar file
    self.invalid_tar_path = os.path.join(self.test_dir, 'invalid.tar')
    with open(self.invalid_tar_path, 'wb') as f:
      f.write(b'This is not a valid tar file')

  def test_tar_extract_returns_dict(self):
    result = tar_extract(self.tar_path)
    self.assertIsInstance(result, dict)

  def test_tar_extract_correct_keys(self):
    result = tar_extract(self.tar_path)
    self.assertEqual(set(result.keys()), set(self.test_files.keys()))

  def test_tar_extract_content_size(self):
    result = tar_extract(self.tar_path)
    for filename, content in self.test_files.items():
      self.assertEqual(len(result[filename]), len(content))

  def test_tar_extract_content_values(self):
    result = tar_extract(self.tar_path)
    for filename, content in self.test_files.items():
      np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8))

  def test_tar_extract_empty_file(self):
    result = tar_extract(self.tar_path)
    self.assertEqual(len(result['empty_file.txt']), 0)

  def test_tar_extract_non_existent_file(self):
    # A name inside the fresh temp dir is guaranteed not to exist, whatever the cwd contains.
    with self.assertRaises(FileNotFoundError):
      tar_extract(os.path.join(self.test_dir, 'non_existent_file.tar'))

  def test_tar_extract_invalid_file(self):
    with self.assertRaises(tarfile.ReadError):
      tar_extract(self.invalid_tar_path)
class TestTarExtractPAX(unittest.TestCase):
  """Tests `tar_extract` on an in-memory archive written in `tar_format`.

  Subclasses rerun every test for another archive format by overriding `tar_format`,
  `max_link_len` and, where needed, `test_files`.
  """
  tar_format = tarfile.PAX_FORMAT
  # A `<name>.lnk` symlink is only added for members whose name is shorter than this.
  max_link_len = 1_000_000
  test_files = {
    'a/file1.txt': b'Hello, World!',
    'a/b/file2.bin': b'\x00\x01\x02\x03\x04',
    'empty_file.txt': b'',
    '512file': b'a' * 512,
    'long_file': b'some data' * 100,
    'very' * 15 + '/' + 'very' * 15 + '_long_filename.txt': b'Hello, World!!',
    'very' * 200 + '_long_filename.txt': b'Hello, World!!!',
  }

  def create_tar_tensor(self):
    """Build a tar archive of `test_files` in memory and return its raw bytes as a Tensor.

    Besides one regular-file member per entry, the archive holds a directory member for each
    distinct dirname of the entries and a `<name>.lnk` symlink for every entry whose name is
    shorter than `max_link_len`. The tests expect those extra members to be absent from the
    `tar_extract` result.
    """
    fobj = io.BytesIO()
    # Sorted so the archive layout is identical on every run (set iteration order depends on
    # string-hash randomization); a directory also sorts before any path nested under it.
    test_dirs = sorted({os.path.dirname(k) for k in self.test_files} - {''})
    with tarfile.open(fileobj=fobj, mode='w', format=self.tar_format) as tar:
      for dirname in test_dirs:
        dir_info = tarfile.TarInfo(name=dirname)
        dir_info.type = tarfile.DIRTYPE
        tar.addfile(dir_info)
      for filename, content in self.test_files.items():
        file_info = tarfile.TarInfo(name=filename)
        file_info.size = len(content)
        tar.addfile(file_info, io.BytesIO(content))
        if len(filename) < self.max_link_len:
          link_info = tarfile.TarInfo(name=filename + '.lnk')
          link_info.type = tarfile.SYMTYPE
          link_info.linkname = filename
          tar.addfile(link_info)
    return Tensor(fobj.getvalue())

  def test_tar_extract_returns_dict(self):
    result = tar_extract(self.create_tar_tensor())
    self.assertIsInstance(result, dict)

  def test_tar_extract_correct_keys(self):
    # Only regular files are expected as keys: no directory or `.lnk` symlink members.
    result = tar_extract(self.create_tar_tensor())
    self.assertEqual(set(result.keys()), set(self.test_files.keys()))

  def test_tar_extract_content_size(self):
    result = tar_extract(self.create_tar_tensor())
    for filename, content in self.test_files.items():
      self.assertEqual(len(result[filename]), len(content))

  def test_tar_extract_content_values(self):
    result = tar_extract(self.create_tar_tensor())
    for filename, content in self.test_files.items():
      np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8))

  def test_tar_extract_empty_file(self):
    result = tar_extract(self.create_tar_tensor())
    self.assertEqual(len(result['empty_file.txt']), 0)

  def test_tar_extract_non_existent_file(self):
    # A name inside a fresh temp dir is guaranteed not to exist, whatever the cwd contains.
    with tempfile.TemporaryDirectory() as tmp, self.assertRaises(FileNotFoundError):
      tar_extract(Tensor(pathlib.Path(tmp) / 'non_existent_file.tar'))

  def test_tar_extract_invalid_file(self):
    with self.assertRaises(tarfile.ReadError):
      tar_extract(Tensor(b'This is not a valid tar file'))

  def test_tar_extract_invalid_file_long(self):
    # Same garbage repeated so the payload spans several 512-byte tar blocks.
    with self.assertRaises(tarfile.ReadError):
      tar_extract(Tensor(b'This is not a valid tar file'*100))
class TestTarExtractUSTAR(TestTarExtractPAX):
  """Reruns the PAX test suite against a USTAR (POSIX.1-1988) archive."""
  # USTAR stores a link target in a 100-byte field, so only shorter names get a symlink.
  max_link_len = 100
  tar_format = tarfile.USTAR_FORMAT
  # USTAR member names top out at 256 characters, so the very long single-component name is dropped.
  test_files = {name: data for name, data in TestTarExtractPAX.test_files.items() if len(name) <= 255}
class TestTarExtractGNU(TestTarExtractPAX):
  """Reruns the PAX test suite, with the full file set and link limit, against a GNU-format archive."""
  # GNU tar format supports arbitrarily long names and linknames, so nothing else is overridden.
  tar_format = tarfile.GNU_FORMAT
if __name__ == '__main__':
  # Run every TestCase in this module when the file is executed directly.
  unittest.main()