diff --git a/common/file_helpers.py b/common/file_helpers.py index a29eafdd9f..53d8d71439 100644 --- a/common/file_helpers.py +++ b/common/file_helpers.py @@ -1,7 +1,7 @@ import os import shutil import tempfile -from atomicwrites import AtomicWriter +import contextlib def rm_not_exists_ok(path): @@ -71,19 +71,20 @@ def _get_fileobject_func(writer, temp_dir): return writer.get_fileobject(dir=temp_dir) return _get_fileobject -def atomic_write_on_fs_tmp(path, **kwargs): - """Creates an atomic writer using a temporary file in a temporary directory - on the same filesystem as path. - """ - # TODO(mgraczyk): This use of AtomicWriter relies on implementation details to set the temp - # directory. - writer = AtomicWriter(path, **kwargs) - return writer._open(_get_fileobject_func(writer, get_tmpdir_on_same_filesystem(path))) - - -def atomic_write_in_dir(path, **kwargs): - """Creates an atomic writer using a temporary file in the same directory - as the destination file. - """ - writer = AtomicWriter(path, **kwargs) - return writer._open(_get_fileobject_func(writer, os.path.dirname(path))) +@contextlib.contextmanager +def atomic_write_on_fs_tmp(path, mode='w', buffering=-1, encoding=None, newline=None): + """Write to a file atomically using a temporary file in a temporary directory on the same filesystem as path.""" + temp_dir = get_tmpdir_on_same_filesystem(path) + with tempfile.NamedTemporaryFile(mode=mode, buffering=buffering, encoding=encoding, newline=newline, dir=temp_dir, delete=False) as tmp_file: + yield tmp_file + tmp_file_name = tmp_file.name + os.replace(tmp_file_name, path) + +@contextlib.contextmanager +def atomic_write_in_dir(path, mode='w', buffering=-1, encoding=None, newline=None): + """Write to a file atomically using a temporary file in the same directory as the destination file.""" + dir_name = os.path.dirname(path) + with tempfile.NamedTemporaryFile(mode=mode, buffering=buffering, encoding=encoding, newline=newline, dir=dir_name, delete=False) as tmp_file: + yield tmp_file + tmp_file_name = tmp_file.name + os.replace(tmp_file_name, path)