Skip to content

Commit

Permalink
Make DatabaseIterator abstract class
Browse files Browse the repository at this point in the history
  • Loading branch information
hagenw committed Aug 21, 2024
1 parent 54ee461 commit 920c857
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 15 deletions.
32 changes: 17 additions & 15 deletions audb/core/stream.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import abc
import os
import typing

Expand All @@ -22,7 +23,7 @@
from audb.core.lock import FolderLock


class DatabaseIterator(audformat.Database):
class DatabaseIterator(audformat.Database, metaclass=abc.ABCMeta):
r"""Database iterator.
This class cannot be created directly,
Expand Down Expand Up @@ -187,6 +188,21 @@ def __next__(self) -> pd.DataFrame:

return df

@abc.abstractmethod
def _initialize_stream(self) -> typing.Iterable:
r"""Create table iterator object.
This method needs to be implemented
for the table file types
in the classes,
that inherit from :class:`audb.DatabaseIterator`.
Returns:
table iterator
"""
return # pragma: nocover

@staticmethod
def _cleanup_database(db: audformat.Database, table: str):
r"""Remove parts of database, not used by table.
Expand Down Expand Up @@ -261,20 +277,6 @@ def _get_batch(self) -> pd.DataFrame:

return df

def _initialize_stream(self) -> typing.Iterable:
r"""Create table iterator object.
This method needs to be implemented
for the table file types
in the classes,
that inherit from :class:`audb.DatabaseIterator`.
Returns:
table iterator
"""
return None # pragma: nocover

def _load_media(self, df: pd.DataFrame):
r"""Load media file for batch.
Expand Down
31 changes: 31 additions & 0 deletions tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,3 +419,34 @@ def test_errors(self, table: str, error, error_msg: str):
version=self.version,
verbose=False,
)


def test_database_iterator_error():
db = audformat.Database("db")
db["some"] = audformat.Table()
table = "some"
error_msg = (
"Can't instantiate abstract class DatabaseIterator "
"with abstract method _initialize_stream"
)
with pytest.raises(TypeError, match=error_msg):
audb.DatabaseIterator(
db,
table,
version=None,
map=None,
batch_size=16,
shuffle=False,
buffer_size=100_000,
only_metadata=False,
bit_depth=None,
channels=None,
format=None,
mixdown=False,
sampling_rate=None,
full_path=False,
cache_root=None,
num_workers=1,
timeout=-1,
verbose=False,
)

0 comments on commit 920c857

Please sign in to comment.