diff --git a/network_file_system/__init__.py b/network_file_system/__init__.py index c3e52eb..39fb1e1 100644 --- a/network_file_system/__init__.py +++ b/network_file_system/__init__.py @@ -12,6 +12,8 @@ import os import shutil import errno +import tempfile +import builtins def copy(src, dst): @@ -95,3 +97,56 @@ def read(path, mode="rt"): with open(path, mode) as f: content = f.read() return content + + +def open(file, mode, tmp_dir=None): + if "r" in str.lower(mode): + return builtins.open(file=file, mode=mode) + elif "w" in str.lower(mode): + return NfsFileWriter(file=file, mode=mode, tmp_dir=tmp_dir) + else: + AttributeError("'mode' must either be 'r' or 'w'.") + + +class NfsFileWriter: + """ + Write to tmp-dir first and move to final destination after closeing. + This guarantees that when the output 'file' exists, the file is also + complete. + """ + + def __init__(self, file, mode, tmp_dir=None): + """ + tmp_dir : str (default: None) + Path to the emporary-directory where the file is initially written + to. Make sure this is a fast drive. + """ + self.file = file + self.mode = mode + self.tmp_dir = tmp_dir + self.ready = False + self.closed = False + + def close(self): + self.rc = self.f.close() + move(src=self.tmp_file, dst=self.file) + self.tmp.cleanup() + self.closed = True + return self.rc + + def write(self, payload): + if not self.ready: + self.__enter__() + return self.f.write(payload) + + def __enter__(self): + self.tmp = tempfile.TemporaryDirectory(dir=self.tmp_dir) + self.tmp_file = os.path.join( + self.tmp.name, os.path.basename(self.file) + ) + self.f = builtins.open(file=self.tmp_file, mode=self.mode) + self.ready = True + return self.f + + def __exit__(self, exc_type, exc_value, exc_tb): + return self.close() diff --git a/network_file_system/tests/test_open.py b/network_file_system/tests/test_open.py new file mode 100644 index 0000000..7d0d0c5 --- /dev/null +++ b/network_file_system/tests/test_open.py @@ -0,0 +1,26 @@ +import network_file_system as nfs +import tempfile +import os + + +def test_open_manually(): + with tempfile.TemporaryDirectory() as tmp: + file = os.path.join(tmp, "123.txt") + assert not os.path.exists(file) + f = nfs.open(file, "wt") + assert not os.path.exists(file) + f.write("omg\n") + assert not os.path.exists(file) + f.close() + assert os.path.exists(file) + + +def test_open_context(): + with tempfile.TemporaryDirectory() as tmp: + file = os.path.join(tmp, "123.txt") + + with nfs.open(file, "wt") as f: + f.write("ralerale") + + with nfs.open(file, "rt") as f: + assert f.read() == "ralerale" diff --git a/setup.py b/setup.py index f357ff9..238a363 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="network_file_system_sebastian-achim-mueller", - version="0.0.1", + version="0.0.2", description="Safe copy, move, and write on remote drives", long_description=long_description, long_description_content_type="text/x-rst",