diff --git a/tools/lib/tests/__init__.py b/tools/lib/tests/__init__.py index e69de29bb2..32219c6ce9 100644 --- a/tools/lib/tests/__init__.py +++ b/tools/lib/tests/__init__.py @@ -0,0 +1,12 @@ +import tempfile + +from unittest import mock + +def temporary_cache_dir(func): + def wrapper(*args, **kwargs): + with tempfile.TemporaryDirectory() as temp_dir: + cache_dir_patch = mock.patch("openpilot.tools.lib.url_file.CACHE_DIR", temp_dir) + cache_dir_patch.start() + func(*args, **kwargs) + cache_dir_patch.stop() + return wrapper \ No newline at end of file diff --git a/tools/lib/tests/test_caching.py b/tools/lib/tests/test_caching.py index 1e44b96489..0cb7d66d54 100644 --- a/tools/lib/tests/test_caching.py +++ b/tools/lib/tests/test_caching.py @@ -1,18 +1,14 @@ #!/usr/bin/env python3 import os -import shutil import unittest - -os.environ["COMMA_CACHE"] = "/tmp/__test_cache__" -from openpilot.tools.lib.url_file import URLFile, CACHE_DIR +from openpilot.tools.lib.url_file import URLFile +from tools.lib.tests import temporary_cache_dir class TestFileDownload(unittest.TestCase): def compare_loads(self, url, start=0, length=None): """Compares range between cached and non cached version""" - shutil.rmtree(CACHE_DIR) - file_cached = URLFile(url, cache=True) file_downloaded = URLFile(url, cache=False) @@ -35,6 +31,7 @@ class TestFileDownload(unittest.TestCase): self.assertEqual(file_cached.get_length(), file_downloaded.get_length()) self.assertEqual(response_cached, response_downloaded) + @temporary_cache_dir def test_small_file(self): # Make sure we don't force cache os.environ["FILEREADER_CACHE"] = "0" @@ -55,6 +52,7 @@ class TestFileDownload(unittest.TestCase): for i in range(length // 100): self.compare_loads(small_file_url, 100 * i, 100) + @temporary_cache_dir def test_large_file(self): large_file_url = "https://commadataci.blob.core.windows.net/openpilotci/0375fdf7b1ce594d/2019-06-13--08-32-25/3/qlog.bz2" # Load the end 100 bytes of both files