Skip to content

Commit

Permalink
Improve type hinting for aiida.orm.nodes.data.singlefile
Browse files Browse the repository at this point in the history
  • Loading branch information
sphuber committed Aug 14, 2023
1 parent 9bde86e commit b9d087d
Showing 1 changed file with 33 additions and 10 deletions.
43 changes: 33 additions & 10 deletions aiida/orm/nodes/data/singlefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import io
import os
import pathlib
import typing as t

from aiida.common import exceptions

Expand All @@ -28,15 +29,15 @@ class SinglefileData(Data):
DEFAULT_FILENAME = 'file.txt'

@classmethod
def from_string(cls, content: str, filename: str | pathlib.Path | None = None, **kwargs):
def from_string(cls, content: str, filename: str | pathlib.Path | None = None, **kwargs: t.Any) -> 'SinglefileData':
"""Construct a new instance and set ``content`` as its contents.
:param content: The content as a string.
:param filename: Specify filename to use (defaults to ``file.txt``).
"""
return cls(io.StringIO(content), filename, **kwargs)

def __init__(self, file, filename: str | pathlib.Path | None = None, **kwargs):
def __init__(self, file: str | t.IO, filename: str | pathlib.Path | None = None, **kwargs: t.Any) -> None:
"""Construct a new instance and set the contents to that of the file.
:param file: an absolute filepath or filelike object whose contents to copy.
Expand All @@ -50,15 +51,35 @@ def __init__(self, file, filename: str | pathlib.Path | None = None, **kwargs):
self.set_file(file, filename=filename)

@property
def filename(self):
def filename(self) -> str:
"""Return the name of the file stored.
:return: the filename under which the file is stored in the repository
"""
return self.base.attributes.get('filename')

@t.overload
@contextlib.contextmanager
def open(self, path=None, mode='r'):
def open(self, path: str, mode: t.Literal['r']) -> t.Iterator[t.TextIO]:
...

@t.overload
@contextlib.contextmanager
def open(self, path: None, mode: t.Literal['r']) -> t.Iterator[t.TextIO]:
...

@t.overload
@contextlib.contextmanager
def open(self, path: str, mode: t.Literal['rb']) -> t.Iterator[t.BinaryIO]:
...

@t.overload
@contextlib.contextmanager
def open(self, path: None, mode: t.Literal['rb']) -> t.Iterator[t.BinaryIO]:
...

@contextlib.contextmanager # type: ignore[misc]
def open(self, path: str | None = None, mode: t.Literal['r', 'rb'] = 'r') -> t.Iterator[t.BinaryIO | t.TextIO]:
"""Return an open file handle to the content of this data node.
:param path: the relative path of the object within the repository.
Expand All @@ -71,15 +92,15 @@ def open(self, path=None, mode='r'):
with self.base.repository.open(path, mode=mode) as handle:
yield handle

def get_content(self):
def get_content(self) -> str:
"""Return the content of the single file stored for this data node.
:return: the content of the file as a string
"""
with self.open() as handle:
with self.open(mode='r') as handle: # type: ignore[call-overload]
return handle.read()

def set_file(self, file, filename: str | pathlib.Path | None = None):
def set_file(self, file: str | t.IO, filename: str | pathlib.Path | None = None) -> None:
"""Store the content of the file in the node's repository, deleting any other existing objects.
:param file: an absolute filepath or filelike object whose contents to copy
Expand Down Expand Up @@ -114,17 +135,17 @@ def set_file(self, file, filename: str | pathlib.Path | None = None):
pass

if is_filelike:
self.base.repository.put_object_from_filelike(file, key)
self.base.repository.put_object_from_filelike(file, key) # type: ignore[arg-type]
else:
self.base.repository.put_object_from_file(file, key)
self.base.repository.put_object_from_file(file, key) # type: ignore[arg-type]

# Delete any other existing objects (minus the current `key` which was already removed from the list)
for existing_key in existing_object_names:
self.base.repository.delete_object(existing_key)

self.base.attributes.set('filename', key)

def _validate(self):
def _validate(self) -> bool:
"""Ensure that there is one object stored in the repository, whose key matches value set for `filename` attr."""
super()._validate()

Expand All @@ -139,3 +160,5 @@ def _validate(self):
raise exceptions.ValidationError(
f'respository files {objects} do not match the `filename` attribute `{filename}`.'
)

return True

0 comments on commit b9d087d

Please sign in to comment.