Skip to content

Commit

Permalink
Node: add the walk method to iterate over repository content
Browse files Browse the repository at this point in the history
The `walk` method is added that simply forwards to the same method of
the underlying `Repository` instance that is assigned to the node. This
allows to recursively iterate over the directories and files that are
stored in the node's repository folder in the style of `os.walk`.
  • Loading branch information
sphuber committed May 7, 2021
1 parent f49d10b commit d171647
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 6 deletions.
28 changes: 22 additions & 6 deletions aiida/orm/nodes/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@
"""Interface to the file repository of a node instance."""
import contextlib
import io
import pathlib
import tempfile
import typing
from typing import BinaryIO, Dict, Iterable, Iterator, List, Tuple, Union

from aiida.common import exceptions
from aiida.repository import Repository, File
from aiida.repository.backend import SandboxRepositoryBackend

__all__ = ('NodeRepositoryMixin',)

FilePath = Union[str, pathlib.PurePosixPath]


class NodeRepositoryMixin:
"""Interface to the file repository of a node instance.
Expand Down Expand Up @@ -67,7 +70,7 @@ def _repository(self, repository: Repository) -> None:

self._repository_instance = repository

def repository_serialize(self) -> typing.Dict:
def repository_serialize(self) -> Dict:
"""Serialize the metadata of the repository content into a JSON-serializable format.
:return: dictionary with the content metadata.
Expand All @@ -82,7 +85,7 @@ def check_mutability(self):
if self.is_stored:
raise exceptions.ModificationNotAllowed('the node is stored and therefore the repository is immutable.')

def list_objects(self, path: str = None) -> typing.List[File]:
def list_objects(self, path: str = None) -> List[File]:
"""Return a list of the objects contained in this repository sorted by name, optionally in given sub directory.
:param path: the relative path where to store the object in the repository.
Expand All @@ -93,7 +96,7 @@ def list_objects(self, path: str = None) -> typing.List[File]:
"""
return self._repository.list_objects(path)

def list_object_names(self, path: str = None) -> typing.List[str]:
def list_object_names(self, path: str = None) -> List[str]:
"""Return a sorted list of the object names contained in this repository, optionally in the given sub directory.
:param path: the relative path where to store the object in the repository.
Expand All @@ -105,7 +108,7 @@ def list_object_names(self, path: str = None) -> typing.List[str]:
return self._repository.list_object_names(path)

@contextlib.contextmanager
def open(self, path: str, mode='r') -> typing.Iterator[typing.BinaryIO]:
def open(self, path: str, mode='r') -> Iterator[BinaryIO]:
"""Open a file handle to an object stored under the given key.
.. note:: this should only be used to open a handle to read an existing file. To write a new file use the method
Expand All @@ -127,7 +130,7 @@ def open(self, path: str, mode='r') -> typing.Iterator[typing.BinaryIO]:
else:
yield handle

def get_object_content(self, path: str, mode='r') -> typing.Union[str, bytes]:
def get_object_content(self, path: str, mode='r') -> Union[str, bytes]:
"""Return the content of a object identified by key.
:param key: fully qualified identifier for the object within the repository.
Expand Down Expand Up @@ -190,6 +193,19 @@ def put_object_from_tree(self, filepath: str, path: str = None):
self._repository.put_object_from_tree(filepath, path)
self._update_repository_metadata()

def walk(self, path: FilePath = None) -> Iterable[Tuple[pathlib.PurePosixPath, List[str], List[str]]]:
"""Walk over the directories and files contained within this repository.
.. note:: the order of the dirname and filename lists that are returned is not necessarily sorted. This is in
line with the ``os.walk`` implementation where the order depends on the underlying file system used.
:param path: the relative path of the directory within the repository whose contents to walk.
:return: tuples of root, dirnames and filenames just like ``os.walk``, with the exception that the root path is
always relative with respect to the repository root, instead of an absolute path and it is an instance of
``pathlib.PurePosixPath`` instead of a normal string
"""
yield from self._repository.walk(path)

def delete_object(self, path: str):
"""Delete the object from the repository.
Expand Down
29 changes: 29 additions & 0 deletions tests/orm/node/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# pylint: disable=redefined-outer-name,protected-access
"""Tests for the :mod:`aiida.orm.nodes.repository` module."""
import io
import pathlib

import pytest

Expand Down Expand Up @@ -147,3 +148,31 @@ def test_sealed():

with pytest.raises(exceptions.ModificationNotAllowed):
node.put_object_from_filelike(io.BytesIO(b'content'), 'path')


@pytest.mark.usefixtures('clear_database_before_test')
def test_walk():
"""Test the ``NodeRepositoryMixin.walk`` method."""
node = Data()
node.put_object_from_filelike(io.BytesIO(b'content'), 'relative/path')

results = []
for root, dirnames, filenames in node.walk():
results.append((root, sorted(dirnames), sorted(filenames)))

assert sorted(results) == [
(pathlib.Path('.'), ['relative'], []),
(pathlib.Path('relative'), [], ['path']),
]

# Check that the method still works after storing the node
node.store()

results = []
for root, dirnames, filenames in node.walk():
results.append((root, sorted(dirnames), sorted(filenames)))

assert sorted(results) == [
(pathlib.Path('.'), ['relative'], []),
(pathlib.Path('relative'), [], ['path']),
]

0 comments on commit d171647

Please sign in to comment.