Skip to content

Commit

Permalink
Add repositories argument to audb.available()
Browse files Browse the repository at this point in the history
  • Loading branch information
hagenw committed Dec 12, 2024
1 parent 21c0634 commit 234c0df
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
9 changes: 8 additions & 1 deletion audb/core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@
def available(
*,
only_latest: bool = False,
repositories: Repository | Sequence[Repository] = None,
) -> pd.DataFrame:
r"""List all databases that are available to the user.
Args:
only_latest: include only latest version of database
repositories: search only in the given repositories
Returns:
table with database name as index,
Expand Down Expand Up @@ -59,7 +61,12 @@ def add_database(name: str, version: str, repository: Repository):
]
)

for repository in config.REPOSITORIES:
if repositories is not None:
repositories = audeer.to_list(repositories)
else:
repositories = config.REPOSITORIES

for repository in repositories:
try:
backend_interface = repository.create_backend_interface()
with backend_interface.backend as backend:
Expand Down
21 changes: 21 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,27 @@ def test_available_broken_dataset(private_and_public_repository):
assert "broken-dataset" not in df


def test_available_repositories(tmpdir):
"""Test repositories argument of available()."""
repositories = []
for n in range(2):
host = audeer.mkdir(tmpdir, f"host{n}")
repo = f"repo{n}"
repositories.append(audb.Repository(repo, host, "file-system"))
# Fake dataset by adding db.yaml file
audeer.touch(audeer.mkdir(host, repo, f"name{n}", "1.0.0"), "db.yaml")
df = audb.available(repositories=repositories)
assert len(df) == 2
for n, repository in enumerate(repositories):
# Test for string and list arguments
for repo_test in [repository, [repository]]:
df = audb.available(repositories=repo_test)
assert len(df) == 1
assert df.host.iloc[0] == repository.host
assert df.repository.iloc[0] == repository.name
assert df.index[0] == f"name{n}"


def test_versions(tmpdir, repository):
"""Test versions() for non existing repositories.
Expand Down

0 comments on commit 234c0df

Please sign in to comment.