Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Node: add the walk method to iterate over repository content #4935

Merged
merged 1 commit into from
May 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions aiida/orm/nodes/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@
"""Interface to the file repository of a node instance."""
import contextlib
import io
import pathlib
import tempfile
import typing
from typing import BinaryIO, Dict, Iterable, Iterator, List, Tuple, Union

from aiida.common import exceptions
from aiida.repository import Repository, File
from aiida.repository.backend import SandboxRepositoryBackend

__all__ = ('NodeRepositoryMixin',)

FilePath = Union[str, pathlib.PurePosixPath]


class NodeRepositoryMixin:
"""Interface to the file repository of a node instance.
Expand Down Expand Up @@ -67,7 +70,7 @@ def _repository(self, repository: Repository) -> None:

self._repository_instance = repository

def repository_serialize(self) -> typing.Dict:
def repository_serialize(self) -> Dict:
"""Serialize the metadata of the repository content into a JSON-serializable format.
:return: dictionary with the content metadata.
Expand All @@ -82,7 +85,7 @@ def check_mutability(self):
if self.is_stored:
raise exceptions.ModificationNotAllowed('the node is stored and therefore the repository is immutable.')

def list_objects(self, path: str = None) -> typing.List[File]:
def list_objects(self, path: str = None) -> List[File]:
"""Return a list of the objects contained in this repository sorted by name, optionally in given sub directory.
:param path: the relative path where to store the object in the repository.
Expand All @@ -93,7 +96,7 @@ def list_objects(self, path: str = None) -> typing.List[File]:
"""
return self._repository.list_objects(path)

def list_object_names(self, path: str = None) -> typing.List[str]:
def list_object_names(self, path: str = None) -> List[str]:
"""Return a sorted list of the object names contained in this repository, optionally in the given sub directory.
:param path: the relative path where to store the object in the repository.
Expand All @@ -105,7 +108,7 @@ def list_object_names(self, path: str = None) -> typing.List[str]:
return self._repository.list_object_names(path)

@contextlib.contextmanager
def open(self, path: str, mode='r') -> typing.Iterator[typing.BinaryIO]:
def open(self, path: str, mode='r') -> Iterator[BinaryIO]:
"""Open a file handle to an object stored under the given key.
.. note:: this should only be used to open a handle to read an existing file. To write a new file use the method
Expand All @@ -127,7 +130,7 @@ def open(self, path: str, mode='r') -> typing.Iterator[typing.BinaryIO]:
else:
yield handle

def get_object_content(self, path: str, mode='r') -> typing.Union[str, bytes]:
def get_object_content(self, path: str, mode='r') -> Union[str, bytes]:
"""Return the content of a object identified by key.
:param key: fully qualified identifier for the object within the repository.
Expand Down Expand Up @@ -190,6 +193,19 @@ def put_object_from_tree(self, filepath: str, path: str = None):
self._repository.put_object_from_tree(filepath, path)
self._update_repository_metadata()

def walk(self, path: FilePath = None) -> Iterable[Tuple[pathlib.PurePosixPath, List[str], List[str]]]:
"""Walk over the directories and files contained within this repository.
.. note:: the order of the dirname and filename lists that are returned is not necessarily sorted. This is in
line with the ``os.walk`` implementation where the order depends on the underlying file system used.
:param path: the relative path of the directory within the repository whose contents to walk.
:return: tuples of root, dirnames and filenames just like ``os.walk``, with the exception that the root path is
always relative with respect to the repository root, instead of an absolute path and it is an instance of
``pathlib.PurePosixPath`` instead of a normal string
"""
yield from self._repository.walk(path)

def delete_object(self, path: str):
"""Delete the object from the repository.
Expand Down
29 changes: 29 additions & 0 deletions tests/orm/node/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# pylint: disable=redefined-outer-name,protected-access
"""Tests for the :mod:`aiida.orm.nodes.repository` module."""
import io
import pathlib

import pytest

Expand Down Expand Up @@ -147,3 +148,31 @@ def test_sealed():

with pytest.raises(exceptions.ModificationNotAllowed):
node.put_object_from_filelike(io.BytesIO(b'content'), 'path')


@pytest.mark.usefixtures('clear_database_before_test')
def test_walk():
"""Test the ``NodeRepositoryMixin.walk`` method."""
node = Data()
node.put_object_from_filelike(io.BytesIO(b'content'), 'relative/path')

results = []
for root, dirnames, filenames in node.walk():
results.append((root, sorted(dirnames), sorted(filenames)))

assert sorted(results) == [
(pathlib.Path('.'), ['relative'], []),
(pathlib.Path('relative'), [], ['path']),
]

# Check that the method still works after storing the node
node.store()

results = []
for root, dirnames, filenames in node.walk():
results.append((root, sorted(dirnames), sorted(filenames)))

assert sorted(results) == [
(pathlib.Path('.'), ['relative'], []),
(pathlib.Path('relative'), [], ['path']),
]