import os
import shutil
import tempfile
import contextlib
from typing import Optional


def rm_not_exists_ok(path: str) -> None:
    """Remove the file at *path*, tolerating the case where it is already gone.

    Any ``OSError`` raised by ``os.remove`` is swallowed as long as *path*
    does not exist afterwards; if the path is still there the error is
    re-raised.
    """
    try:
        os.remove(path)
    except OSError:
        # Only propagate the failure when something is actually left behind.
        if os.path.exists(path):
            raise


def rm_tree_or_link(path: str) -> None:
    """Remove *path* if it is a symlink or a directory.

    A symlink is unlinked (its target is left untouched); a real directory is
    removed recursively.  Anything else (regular file, missing path) is left
    alone.
    """
    # The symlink check must come first: ``isdir`` follows links, and we never
    # want to recurse into a link's target.
    if os.path.islink(path):
        os.unlink(path)
    elif os.path.isdir(path):
        shutil.rmtree(path)


def get_tmpdir_on_same_filesystem(path: str) -> str:
    """Return the temp directory to use for files destined for *path*.

    The choice is made purely from the "/"-separated components of the
    normalised path:

    * second component ``scratch``  -> ``/scratch/tmp``
    * third component ``runner``    -> ``/<second component>/runner/tmp``
    * anything else                 -> ``/tmp``

    NOTE(review): the component indices assume an absolute POSIX path (whose
    first component is the empty string) -- confirm callers never pass
    relative paths.
    """
    parts = os.path.normpath(path).split("/")
    if len(parts) > 1 and parts[1] == "scratch":
        return "/scratch/tmp"
    if len(parts) > 2 and parts[2] == "runner":
        return f"/{parts[1]}/runner/tmp"
    return "/tmp"


class NamedTemporaryDir:
    """Temporary directory that is deleted on ``close()`` or context exit.

    The directory is created immediately on construction via
    ``tempfile.mkdtemp``; *temp_dir* is the parent directory to create it in
    (``None`` lets ``tempfile`` pick its default).
    """

    def __init__(self, temp_dir=None):
        self._path = tempfile.mkdtemp(dir=temp_dir)

    @property
    def name(self):
        """Path of the temporary directory."""
        return self._path

    def close(self):
        """Recursively delete the temporary directory."""
        shutil.rmtree(self._path)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()


class CallbackReader:
    """Wraps a file, but overrides the read method to also
    call a callback function with the number of bytes read so far."""

    def __init__(self, f, callback, *args):
        self.f = f
        self.callback = callback
        # Extra positional arguments passed to the callback before the count.
        self.cb_args = args
        # Running total of ``len(chunk)`` over every ``read`` call.
        self.total_read = 0

    def __getattr__(self, attr):
        # Only invoked for attributes not found on the wrapper itself, so
        # everything except ``read`` is delegated to the wrapped file.
        return getattr(self.f, attr)

    def read(self, *args, **kwargs):
        """Read from the wrapped file and report the cumulative amount read.

        The callback is invoked as ``callback(*cb_args, total_read)`` after
        every read, including reads that return an empty chunk.
        """
        chunk = self.f.read(*args, **kwargs)
        self.total_read += len(chunk)
        self.callback(*self.cb_args, self.total_read)
        return chunk


def _get_fileobject_func(writer, temp_dir):
    """Return a zero-argument callable that calls
    ``writer.get_fileobject(dir=temp_dir)``."""
    def _get_fileobject():
        return writer.get_fileobject(dir=temp_dir)
    return _get_fileobject


@contextlib.contextmanager
def _atomic_write_via_tmp(path, temp_dir, mode, buffering, encoding, newline):
    """Yield a temp file in *temp_dir* and move it onto *path* on success.

    Shared implementation of the public ``atomic_write_*`` helpers.  If the
    caller's ``with`` body raises, or the final ``os.replace`` fails, the
    temporary file is deleted so no partial file is left behind, and the
    original exception propagates.
    """
    tmp_file = tempfile.NamedTemporaryFile(mode=mode, buffering=buffering, encoding=encoding,
                                           newline=newline, dir=temp_dir, delete=False)
    try:
        # Leaving the inner ``with`` flushes and closes the file, so the data
        # is fully written before it is moved into place.
        with tmp_file:
            yield tmp_file
        os.replace(tmp_file.name, path)
    except BaseException:
        # Best-effort cleanup: a failure to unlink must not mask the error
        # that brought us here.
        with contextlib.suppress(OSError):
            os.unlink(tmp_file.name)
        raise


@contextlib.contextmanager
def atomic_write_on_fs_tmp(path: str, mode: str = 'w', buffering: int = -1, encoding: Optional[str] = None,
                           newline: Optional[str] = None):
    """Write to a file atomically using a temporary file in a temporary directory on the same filesystem as path.

    Yields the open temporary file.  On a clean exit the file is closed and
    moved onto *path* with ``os.replace``; on an exception the temporary file
    is removed and *path* is left untouched.  The temporary directory is the
    one returned by ``get_tmpdir_on_same_filesystem(path)``.
    """
    temp_dir = get_tmpdir_on_same_filesystem(path)
    with _atomic_write_via_tmp(path, temp_dir, mode, buffering, encoding, newline) as tmp_file:
        yield tmp_file


@contextlib.contextmanager
def atomic_write_in_dir(path: str, mode: str = 'w', buffering: int = -1, encoding: Optional[str] = None,
                        newline: Optional[str] = None, overwrite: bool = False):
    """Write to a file atomically using a temporary file in the same directory as the destination file.

    Yields the open temporary file.  On a clean exit the file is closed and
    moved onto *path* with ``os.replace``; on an exception the temporary file
    is removed and *path* is left untouched.

    Raises:
        FileExistsError: if *path* already exists and *overwrite* is false.
            The check is made up front, before any temporary file is created.
    """
    dir_name = os.path.dirname(path)

    if not overwrite and os.path.exists(path):
        raise FileExistsError(f"File '{path}' already exists. To overwrite it, set 'overwrite' to True.")

    with _atomic_write_via_tmp(path, dir_name, mode, buffering, encoding, newline) as tmp_file:
        yield tmp_file