diff --git a/audb/__init__.py b/audb/__init__.py index 329998d8..2f323736 100644 --- a/audb/__init__.py +++ b/audb/__init__.py @@ -19,6 +19,8 @@ from audb.core.load_to import load_to from audb.core.publish import publish from audb.core.repository import Repository +from audb.core.stream import DatabaseIterator +from audb.core.stream import stream __all__ = [] diff --git a/audb/core/load.py b/audb/core/load.py index 75cb985b..4059a488 100644 --- a/audb/core/load.py +++ b/audb/core/load.py @@ -25,7 +25,7 @@ from audb.core.utils import lookup_backend -CachedVersions = typing.Sequence[typing.Tuple[audeer.StrictVersion, str, Dependencies],] +CachedVersions = typing.Sequence[typing.Tuple[audeer.StrictVersion, str, Dependencies]] def _cached_versions( @@ -805,7 +805,28 @@ def _misc_tables_used_in_scheme( if scheme.uses_table: misc_tables_used_in_scheme.append(scheme.labels) - return list(set(misc_tables_used_in_scheme)) + return audeer.unique(misc_tables_used_in_scheme) + + +def _misc_tables_used_in_table( + table: audformat.Table, +) -> typing.List[str]: + r"""List of misc tables that are used inside schemes of a table. + + Args: + table: table object + + Returns: + unique list of misc tables used in schemes of the table + + """ + misc_tables_used_in_table = [] + for column_id, column in table.columns.items(): + if column.scheme_id is not None: + scheme = table.db.schemes[column.scheme_id] + if scheme.uses_table: + misc_tables_used_in_table.append(scheme.labels) + return audeer.unique(misc_tables_used_in_table) def _missing_files( diff --git a/audb/core/stream.py b/audb/core/stream.py new file mode 100644 index 00000000..4c78b5ae --- /dev/null +++ b/audb/core/stream.py @@ -0,0 +1,598 @@ +from __future__ import annotations + +import abc +import os +import typing + +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as parquet + +import audformat + +from audb.core.api import dependencies +from audb.core.cache import database_cache_root +from audb.core.dependencies import error_message_missing_object +from audb.core.flavor import Flavor +from audb.core.load import _load_files +from audb.core.load import _misc_tables_used_in_table +from audb.core.load import _update_path +from audb.core.load import latest_version +from audb.core.load import load_header_to +from audb.core.load import load_media +from audb.core.lock import FolderLock + + +class DatabaseIterator(audformat.Database, metaclass=abc.ABCMeta): + r"""Database iterator. + + This class cannot be created directly, + but only by calling :func:`audb.stream`. + + Examples: + Create :class:`audb.DatabaseIterator` object. + + >>> db = audb.stream( + ... "emodb", + ... "files", + ... version="1.4.1", + ... batch_size=4, + ... only_metadata=True, + ... full_path=False, + ... verbose=False, + ... ) + + The :class:`audb.DatabaseIterator` object + is restricted to the requested table, + and all related schemes + and misc tables + used as labels in a related scheme. + + >>> db + name: emodb + ... + schemes: + age: {description: Age of speaker, dtype: int, minimum: 0} + duration: {dtype: time} + gender: + description: Gender of speaker + dtype: str + labels: [female, male] + language: {description: Language of speaker, dtype: str} + speaker: {description: The actors could produce each sentence as often as they liked + and were asked to remember a real situation from their past when they had felt + this emotion., dtype: int, labels: speaker} + transcription: + description: Sentence produced by actor. + dtype: str + labels: ... 
+ tables: + files: + type: filewise + columns: + duration: {scheme_id: duration} + speaker: {scheme_id: speaker} + transcription: {scheme_id: transcription} + misc_tables: + speaker: + levels: {speaker: int} + columns: + age: {scheme_id: age} + gender: {scheme_id: gender} + language: {scheme_id: language} + ... + + Request the first batch of data. + + >>> next(db) + duration speaker transcription + file + wav/03a01Fa.wav 0 days 00:00:01.898250 3 a01 + wav/03a01Nc.wav 0 days 00:00:01.611250 3 a01 + wav/03a01Wa.wav 0 days 00:00:01.877812500 3 a01 + wav/03a02Fc.wav 0 days 00:00:02.006250 3 a02 + + During the iteration, + the :class:`audb.DatabaseIterator` object + provides access to the current batch of data. + + >>> db["files"].get(map={"speaker": "age"}) + duration transcription age + file + wav/03a01Fa.wav 0 days 00:00:01.898250 a01 31 + wav/03a01Nc.wav 0 days 00:00:01.611250 a01 31 + wav/03a01Wa.wav 0 days 00:00:01.877812500 a01 31 + wav/03a02Fc.wav 0 days 00:00:02.006250 a02 31 + + """ # noqa: E501 + + def __init__( + self, + db: audformat.Database, + table: str, + *, + version: str, + map: typing.Dict[str, typing.Union[str, typing.Sequence[str]]], + batch_size: int, + shuffle: bool, + buffer_size: int, + only_metadata: bool, + bit_depth: int, + channels: typing.Union[int, typing.Sequence[int]], + format: str, + mixdown: bool, + sampling_rate: int, + full_path: bool, + cache_root: str, + num_workers: typing.Optional[int], + timeout: float, + verbose: bool, + ): + self._cleanup_database(db, table) + + # Transfer attributes of database object + for attr in db.__dict__.keys(): + setattr(self, attr, getattr(db, attr)) + + self._table = table + self._version = version + self._map = map + self._batch_size = batch_size + self._shuffle = shuffle + self._buffer_size = buffer_size + self._only_metadata = only_metadata + self._bit_depth = bit_depth + self._channels = channels + self._format = format + self._mixdown = mixdown + self._sampling_rate = sampling_rate + self._full_path = full_path + self._cache_root = cache_root + self._num_workers = num_workers + self._timeout = timeout + self._verbose = verbose + + if shuffle: + self._samples = buffer_size + else: + self._samples = batch_size + self._buffer = pd.DataFrame() + self._stream = self._initialize_stream() + + def __iter__(self) -> DatabaseIterator: + r"""Iterator generator.""" + return self + + def __next__(self) -> pd.DataFrame: + r"""Iterate database.""" + # Load part of table + df = self._get_batch() + self[self._table]._df = df + + # Load corresponding media files + self._load_media(df) + + # Map column values + if self._map is not None: + df = self[self._table].get(map=self._map) + + # Adjust full paths and file extensions in table + _update_path( + self, + self.root, + self._full_path, + self._format, + self._num_workers, + self._verbose, + ) + + return df + + @abc.abstractmethod + def _initialize_stream(self) -> typing.Iterable: + r"""Create table iterator object. + + This method needs to be implemented + for the table file types + in the classes, + that inherit from :class:`audb.DatabaseIterator`. + + Returns: + table iterator + + """ + return # pragma: nocover + + @staticmethod + def _cleanup_database(db: audformat.Database, table: str): + r"""Remove parts of database, not used by table. 
+ + Args: + db: database object + table: table ID + + """ + tables = [table] + misc_tables = _misc_tables_used_in_table(db[table]) + + # Remove non-requested table + db.drop_tables([table for table in list(db.tables) if table not in tables]) + + # Remove unused schemes + used_schemes = [] + for table in misc_tables + tables: + for column_id, column in db[table].columns.items(): + if column.scheme_id is not None: + used_schemes.append(column.scheme_id) + for scheme in list(db.schemes): + if scheme not in used_schemes: + del db.schemes[scheme] + + # Remove misc tables not required by the schemes of table + db.drop_tables( + [ + misc_table + for misc_table in list(db.misc_tables) + if misc_table not in misc_tables + tables + ] + ) + + # Remove unused splits + used_splits = [ + db[table].split_id for table in list(db) if db[table].split_id is not None + ] + for split in list(db.splits): + if split not in used_splits: + del db.splits[split] + + def _get_batch(self) -> pd.DataFrame: + r"""Read table batch. + + Returns: + dataframe + + """ + if self._shuffle: + buffer_read_length = self._batch_size + df1 = pd.DataFrame() + if len(self._buffer) < self._batch_size: + if len(self._buffer) > 0: + # Empty current buffer, + # before refilling + df1 = self._buffer + buffer_read_length = self._batch_size - len(self._buffer) + self._buffer = self._read_dataframe() + # Shuffle data + self._buffer = self._buffer.sample(frac=1) + + df2 = self._buffer.iloc[:buffer_read_length, :] + self._buffer.drop(index=df2.index, inplace=True) + df = pd.concat([df1, df2]) + + else: + df = self._read_dataframe() + + if len(df) == 0: + raise StopIteration + + return df + + def _load_media(self, df: pd.DataFrame): + r"""Load media file for batch. + + Args: + df: dataframe of batch + + """ + if audformat.is_segmented_index(df.index): + media = list(df.index.get_level_values("file")) + elif audformat.is_filewise_index(df.index): + media = list(df.index) + else: + media = [] + if not self._only_metadata and len(media) > 0: + load_media( + self.name, + media, + version=self._version, + bit_depth=self._bit_depth, + channels=self._channels, + format=self._format, + mixdown=self._mixdown, + sampling_rate=self._sampling_rate, + cache_root=self._cache_root, + num_workers=self._num_workers, + timeout=self._timeout, + verbose=self._verbose, + ) + + def _postprocess_batch(self, batch: typing.Any) -> pd.DataFrame: + r"""Post-process batch data to desired dataframe. + + Args: + batch: input data + + Returns: + dataframe + + """ + return batch # pragma: nocover + + def _postprocess_dataframe(self, df: pd.DataFrame) -> pd.DataFrame: + r"""Post-process dataframe to have correct index and data types. + + Args: + df: dataframe + + Returns: + dataframe + + """ + # Adjust dtypes and set index + df = self[self._table]._pyarrow_convert_dtypes(df, convert_all=False) + index_columns = list(self[self._table]._levels_and_dtypes.keys()) + df = self[self._table]._set_index(df, index_columns) + return df + + def _read_dataframe(self) -> pd.DataFrame: + r"""Read dataframe from table. 
+
+        Returns:
+            dataframe
+
+        """
+        try:
+            df = self._postprocess_dataframe(
+                self._postprocess_batch(next(iter(self._stream)))
+            )
+        except StopIteration:
+            # Ensure an empty dataframe is returned
+            # at the last iteration,
+            # when no remaining data is left
+            df = pd.DataFrame()
+        return df
+
+
+class DatabaseIteratorCsv(DatabaseIterator):
+    def _initialize_stream(self):
+        # Prepare settings for csv file reading
+
+        # Index columns
+        columns_and_dtypes = self[self._table]._levels_and_dtypes.copy()
+        # Add data columns
+        for column_id, column in self[self._table].columns.items():
+            if column.scheme_id is not None:
+                columns_and_dtypes[column_id] = self.schemes[column.scheme_id].dtype
+            else:
+                columns_and_dtypes[column_id] = audformat.define.DataType.OBJECT
+
+        # Replace data type with converter for dates or timestamps
+        converters = {}
+        dtypes_wo_converters = {}
+        for column, dtype in columns_and_dtypes.items():
+            if dtype == audformat.define.DataType.DATE:
+                converters[column] = lambda x: pd.to_datetime(x)
+            elif dtype == audformat.define.DataType.TIME:
+                converters[column] = lambda x: pd.to_timedelta(x)
+            else:
+                dtypes_wo_converters[column] = audformat.core.common.to_pandas_dtype(
+                    dtype
+                )
+
+        if self._samples == 0:
+            # `pandas.read_csv()` does not support `chunksize=0`
+            return []
+        else:
+            file = os.path.join(self.root, f"db.{self._table}.csv")
+            return pd.read_csv(
+                file,
+                chunksize=self._samples,
+                usecols=list(columns_and_dtypes.keys()),
+                dtype=dtypes_wo_converters,
+                converters=converters,
+                float_precision="round_trip",
+            )
+
+
+class DatabaseIteratorParquet(DatabaseIterator):
+    def _initialize_stream(self) -> pa.RecordBatch:
+        file = os.path.join(self.root, f"db.{self._table}.parquet")
+        return parquet.ParquetFile(file).iter_batches(batch_size=self._samples)
+
+    def _postprocess_batch(self, batch: pa.RecordBatch) -> pd.DataFrame:
+        df = batch.to_pandas(
+            deduplicate_objects=False,
+            types_mapper={
+                pa.string(): pd.StringDtype(),
+            }.get,  # we have to provide a callable, not a dict
+        )
+        return df
+
+
+def stream(
+    name: str,
+    table: str,
+    *,
+    version: str = None,
+    map: typing.Dict[str, typing.Union[str, typing.Sequence[str]]] = None,
+    batch_size: int = 16,
+    shuffle: bool = False,
+    buffer_size: int = 100_000,
+    only_metadata: bool = False,
+    bit_depth: int = None,
+    channels: typing.Union[int, typing.Sequence[int]] = None,
+    format: str = None,
+    mixdown: bool = False,
+    sampling_rate: int = None,
+    full_path: bool = True,
+    cache_root: str = None,
+    num_workers: typing.Optional[int] = 1,
+    timeout: float = -1,
+    verbose: bool = True,
+) -> DatabaseIterator:
+    r"""Stream table and media files of a database.
+
+    Loads only the first ``batch_size`` rows
+    of a table into memory,
+    and downloads only the related media files,
+    if any media files are requested.
+
+    By setting
+    ``bit_depth``,
+    ``channels``,
+    ``format``,
+    ``mixdown``,
+    and ``sampling_rate``,
+    we can request a specific flavor of the database.
+    In that case, media files are automatically converted to the desired
+    properties (see also :class:`audb.Flavor`).
+
+    Args:
+        name: name of database
+        table: name of table
+        version: version string, latest if ``None``
+        map: map scheme or scheme fields to column values.
+            For example, if your table holds a column ``speaker`` with
+            speaker IDs, which is assigned to a scheme that contains a
+            dict mapping speaker IDs to age and gender entries,
+            ``map={'speaker': ['age', 'gender']}``
+            will replace the column with two new columns that map ID
+            values to age and gender, respectively.
+            To also keep the original column with speaker IDs, you can do
+            ``map={'speaker': ['speaker', 'age', 'gender']}``
+        batch_size: number of table rows
+            to return in one iteration
+        shuffle: if ``True``,
+            it first reads ``buffer_size`` rows from the table
+            and selects ``batch_size`` rows randomly from them
+        buffer_size: number of table rows
+            to be loaded
+            when ``shuffle`` is ``True``
+        only_metadata: load only header and tables of database
+        bit_depth: bit depth, one of ``16``, ``24``, ``32``
+        channels: channel selection, see :func:`audresample.remix`.
+            Note that media files with too few channels
+            will be first upsampled by repeating the existing channels.
+            E.g. ``channels=[0, 1]`` upsamples all mono files to stereo,
+            and ``channels=[1]`` returns the second channel
+            of all multi-channel files
+            and all mono files
+        format: file format, one of ``'flac'``, ``'wav'``
+        mixdown: apply mono mix-down
+        sampling_rate: sampling rate in Hz, one of
+            ``8000``, ``16000``, ``22500``, ``44100``, ``48000``
+        full_path: replace relative with absolute file paths
+        cache_root: cache folder where databases are stored.
+            If not set, :meth:`audb.default_cache_root` is used
+        num_workers: number of parallel jobs or 1 for sequential
+            processing. If ``None``, it will be set to the number of
+            processors on the machine multiplied by 5
+        timeout: maximum wait time if another thread or process is already
+            accessing the database. If timeout is reached, ``None`` is
+            returned. If timeout < 0, the method will block until the
+            database can be accessed
+        verbose: show debug messages
+
+    Returns:
+        database object
+
+    Raises:
+        ValueError: if a table is requested
+            that is not part of the database
+        ValueError: if a non-supported ``bit_depth``,
+            ``format``,
+            or ``sampling_rate``
+            is requested
+        RuntimeError: if a flavor is requested,
+            but the database contains media files
+            that don't contain audio,
+            e.g. text files
+
+    Examples:
+        >>> import numpy as np
+        >>> np.random.seed(1)
+        >>> db = audb.stream(
+        ...     "emodb",
+        ...     "files",
+        ...     version="1.4.1",
+        ...     batch_size=4,
+        ...     shuffle=True,
+        ...     only_metadata=True,
+        ...     full_path=False,
+        ...     verbose=False,
+        ...
) + >>> next(db) + duration speaker transcription + file + wav/14a05Fb.wav 0 days 00:00:03.128687500 14 a05 + wav/15a05Eb.wav 0 days 00:00:03.993562500 15 a05 + wav/12a05Nd.wav 0 days 00:00:03.185875 12 a05 + wav/13a07Na.wav 0 days 00:00:01.911687500 13 a07 + + """ + if version is None: + version = latest_version(name) + + # Extract kwargs + # to pass on to the DatabaseIterator constructor + kwargs = dict((k, v) for (k, v) in locals().items() if k not in ["name", "table"]) + + flavor = Flavor( + bit_depth=bit_depth, + channels=channels, + format=format, + mixdown=mixdown, + sampling_rate=sampling_rate, + ) + db_root = database_cache_root(name, version, cache_root, flavor) + + deps = dependencies( + name, + version=version, + cache_root=cache_root, + verbose=verbose, + ) + + if table not in deps.table_ids: + msg = error_message_missing_object("table", [table], name, version) + raise ValueError(msg) + + with FolderLock(db_root): + # Start with database header without tables + db, backend_interface = load_header_to( + db_root, + name, + version, + flavor=flavor, + add_audb_meta=True, + ) + + # Misc tables required by schemes of requested table + misc_tables = _misc_tables_used_in_table(db[table]) + + # Load table files + _load_files( + misc_tables + [table], + "table", + backend_interface, + db_root, + db, + version, + None, + deps, + Flavor(), + cache_root, + False, # pickle_tables + num_workers, + verbose, + ) + + # Load misc tables completely + for misc_table in misc_tables: + table_file = os.path.join(db_root, f"db.{misc_table}") + db[misc_table].load(table_file) + + if os.path.exists(os.path.join(db_root, f"db.{table}.parquet")): + return DatabaseIteratorParquet(db, table, **kwargs) + else: + return DatabaseIteratorCsv(db, table, **kwargs) diff --git a/docs/api-src/audb.rst b/docs/api-src/audb.rst index f430a5fc..4e9fa411 100644 --- a/docs/api-src/audb.rst +++ b/docs/api-src/audb.rst @@ -10,6 +10,7 @@ audb :nosignatures: config + DatabaseIterator Dependencies Flavor Repository @@ -35,4 +36,5 @@ audb publish remove_media repository + stream versions diff --git a/docs/conf.py b/docs/conf.py index c3a898b0..19458180 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -74,7 +74,9 @@ "__contains__", "__eq__", "__getitem__", + "__iter__", "__len__", + "__next__", ] # HTML -------------------------------------------------------------------- diff --git a/docs/load.rst b/docs/load.rst index a133308c..020fa028 100644 --- a/docs/load.rst +++ b/docs/load.rst @@ -327,6 +327,78 @@ from the tables. db["emotion"].get() +.. _streaming: + +Streaming +--------- + +:func:`audb.stream` provides a pseudo-streaming mode, +which helps to load large datasets. +It will only load ``batch_size`` number of rows +from a selected table into memory, +and download only matching media files +in each iteration. +The table and media files +are still stored in the cache. + +.. Prefetch data with only_metadata=True +.. jupyter-execute:: + :hide-code: + + db = audb.stream( + "emodb", + "emotion", + version="1.4.1", + batch_size=4, + only_metadata=True, + full_path=False, + verbose=False, + ) + +.. code-block:: python + + db = audb.stream( + "emodb", + "emotion", + version="1.4.1", + batch_size=4, + full_path=False, + verbose=False, + ) + +It returns an :class:`audb.DatabaseIterator` object, +which behaves as :class:`audformat.Database`, +but provides the ability +to iterate over the database: + +.. jupyter-execute:: + + next(db) + +With ``shuffle=True``, +a user can request +that the data is returned in a random order. 
+:func:`audb.stream` will then load ``buffer_size`` rows
+into a buffer and select ``batch_size`` rows randomly from those.
+
+.. jupyter-execute::
+
+    import numpy as np
+    np.random.seed(1)
+    db = audb.stream(
+        "emodb",
+        "emotion",
+        version="1.4.1",
+        batch_size=4,
+        shuffle=True,
+        buffer_size=100_000,
+        only_metadata=True,
+        full_path=False,
+        verbose=False,
+    )
+    next(db)
+
+
 .. _corresponding audformat documentation: https://audeering.github.io/audformat/accessing-data.html
 .. _combine tables: https://audeering.github.io/audformat/combine-tables.html
 .. _map labels: https://audeering.github.io/audformat/map-scheme.html
diff --git a/pyproject.toml b/pyproject.toml
index f4f1053e..653f2e54 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,6 +68,7 @@ build-backend = 'setuptools.build_meta'
 
 [tool.codespell]
 builtin = 'clear,rare,informal,usage,names'
 skip = './audb.egg-info,./build,./docs/api,./docs/_templates,./docs/pics'
+ignore-words-list = 'sie,Sie,unter'
 uri-ignore-words-list = 'ist'
 
diff --git a/tests/test_stream.py b/tests/test_stream.py
new file mode 100644
index 00000000..2810dcb0
--- /dev/null
+++ b/tests/test_stream.py
@@ -0,0 +1,452 @@
+import os
+import typing
+
+import numpy as np
+import pandas as pd
+import pytest
+
+import audeer
+import audformat
+import audiofile
+
+import audb
+
+
+def create_audio_files(root: str, index: pd.Index, *, sampling_rate=16_000):
+    r"""Create audio files.
+
+    Given an index with relative paths,
+    and a root folder,
+    it will create audio files
+    for every entry in the index.
+
+    The audio files have a duration of 1 second,
+    and have a constant magnitude of 1.
+
+    Args:
+        root: root folder of audio files
+        index: name of audio files to create in ``root``
+        sampling_rate: sampling rate in Hz
+
+    """
+    for file in index:
+        path = audeer.path(root, file)
+        signal = np.ones((1, sampling_rate))
+        audiofile.write(path, signal, sampling_rate)
+
+
+class TestStreaming:
+    r"""Test streaming functionality of audb.
+
+    This test tackles ``audb.stream()``
+    and its returned class ``audb.DatabaseIterator``.
+
+    """
+
+    name = "db"
+    """Name of test database."""
+
+    version = "1.0.0"
+    """Version of test database."""
+
+    seed = 2
+    """Seed for random operations."""
+
+    @classmethod
+    @pytest.fixture(scope="class", autouse=True, params=["parquet", "csv"])
+    def setup(cls, request, tmpdir_factory, persistent_repository):
+        r"""Publish a database.
+
+        This creates a database,
+        consisting of 5 audio files,
+        2 misc tables (1 used as scheme labels),
+        and 2 tables (1 filewise, 1 segmented).
+
+        The database is published to a new repository
+        for each ``storage_format``,
+        and all tests are run for a given ``storage_format``.
+ + Args: + request: request fixture to access params + tmpdir_factory: tmpdir_factory fixture + persistent_repository: persistent_repository fixture + + """ + storage_format = request.param + + db_root = tmpdir_factory.mktemp("build") + + db = audformat.Database(cls.name) + + # Misc table for scheme labels + db.schemes["year-of-birth"] = audformat.Scheme("date") + db["speaker"] = audformat.MiscTable( + pd.Index([0, 1, 2, 3, 4], name="speaker", dtype="Int64") + ) + db["speaker"]["year-of-birth"] = audformat.Column(scheme_id="year-of-birth") + db["speaker"]["year-of-birth"].set([1910, 1920, 1930, 1940, 1950]) + db.schemes["speaker"] = audformat.Scheme("int", labels="speaker") + + # Misc table + db.schemes["full-name"] = audformat.Scheme("str") + db["acronym"] = audformat.MiscTable( + pd.Index(["CCC"], name="acronym", dtype="string") + ) + db["acronym"]["full-name"] = audformat.Column(scheme_id="full-name") + db["acronym"]["full-name"].set(["Concordance Correlation Coefficient"]) + db["acronym"]["speaker"] = audformat.Column(scheme_id="speaker") + db["acronym"]["speaker"].set([0]) + + # Filewise table + db.schemes["quality"] = audformat.Scheme("int", labels=[1, 2, 3]) + index = audformat.filewise_index([f"file{n}.wav" for n in range(5)]) + create_audio_files(db_root, index) + db["files"] = audformat.Table(index) + db["files"]["quality"] = audformat.Column(scheme_id="quality") + db["files"]["quality"].set([1, 1, 2, 2, 3]) + db["files"]["speaker"] = audformat.Column(scheme_id="speaker") + db["files"]["speaker"].set([0, 1, 2, 3, 4]) + + # Segmented table + index = audformat.segmented_index( + files=["file0.wav", "file0.wav", "file1.wav", "file2.wav"], + starts=[0, 0.5, 0.1, 0.1], + ends=[0.4, 0.9, 0.8, 0.9], + ) + db["segments"] = audformat.Table(index) + db["segments"]["noise"] = audformat.Column() + db["segments"]["noise"].set(["low", "low", "low", "high"]) + + db.save(db_root, storage_format=storage_format) + + # Ensure we start with an empty repository + # for each storage format + audeer.rmdir(persistent_repository.host) + audeer.mkdir(persistent_repository.host, persistent_repository.name) + + audb.publish(db_root, cls.version, persistent_repository) + + yield + + # Clean up persistent repository at end + audeer.rmdir(persistent_repository.host) + audeer.mkdir(persistent_repository.host, persistent_repository.name) + + @pytest.mark.parametrize("table", ["files", "segments", "speaker", "acronym"]) + @pytest.mark.parametrize("batch_size", [0, 1, 2, 10]) + def test_batch_size(self, table: str, batch_size: int): + r"""Test table batching. + + If batch size is 0, + no batch should be returned. + If batch size is greater than table length, + a single batch, + containing the whole table, + should be returned. 
+
+        Args:
+            table: table to stream
+            batch_size: number of table rows per batch
+
+        """
+        db = audb.stream(
+            self.name,
+            table,
+            version=self.version,
+            batch_size=batch_size,
+            only_metadata=True,
+            full_path=False,
+            verbose=False,
+        )
+        batches = [df for df in db]
+        # Create expected dataframe from original table
+        expected_df = audb.load_table(
+            self.name,
+            table,
+            version=self.version,
+            verbose=False,
+        )
+        if batch_size == 0:
+            assert batches == []
+        else:
+            pd.testing.assert_frame_equal(pd.concat(batches), expected_df)
+            if batch_size > len(expected_df):
+                assert len(batches) == 1
+
+    @pytest.mark.parametrize("table", ["files", "acronym"])
+    @pytest.mark.parametrize("batch_size", [0, 1, 2, 10])
+    @pytest.mark.parametrize("buffer_size", [0, 1, 2, 10])
+    def test_buffer_size(self, table: str, batch_size: int, buffer_size: int):
+        r"""Test buffer size when shuffling table batches.
+
+        If batch size is 0 or buffer size is 0,
+        no batch should be returned.
+
+        Args:
+            table: table to stream
+            batch_size: number of table rows per batch
+            buffer_size: size of the buffer used to read the table
+                before shuffling
+
+        """
+        np.random.seed(self.seed)
+        db = audb.stream(
+            self.name,
+            table,
+            version=self.version,
+            batch_size=batch_size,
+            buffer_size=buffer_size,
+            shuffle=True,
+            only_metadata=True,
+            full_path=False,
+            verbose=False,
+        )
+        batches = [df for df in db]
+        # Create expected dataframe from original table
+        np.random.seed(self.seed)
+        expected_df = audb.load_table(
+            self.name,
+            table,
+            version=self.version,
+            verbose=False,
+        )
+        if batch_size == 0 or buffer_size == 0:
+            assert batches == []
+        else:
+            df = pd.concat(batches)
+            # Ensure data is shuffled (except for a buffer size of 1)
+            if buffer_size == 1 or len(df) == 1:
+                assert list(df.index) == list(expected_df.index)
+            else:
+                assert list(df.index) != list(expected_df.index)
+            # Ensure all index entries appear in shuffled batches
+            assert sorted(list(df.index)) == sorted(list(expected_df.index))
+
+    @pytest.mark.parametrize(
+        "table, expected_tables, expected_schemes",
+        [
+            ("speaker", ["speaker"], ["year-of-birth"]),
+            (
+                "acronym",
+                ["acronym", "speaker"],
+                ["year-of-birth", "full-name", "speaker"],
+            ),
+            ("files", ["files", "speaker"], ["year-of-birth", "quality", "speaker"]),
+        ],
+    )
+    def test_db_cleanup(
+        self,
+        table: str,
+        expected_tables: typing.List,
+        expected_schemes: typing.List,
+    ):
+        r"""Test removal of non-selected tables and schemes.
+
+        The database object (``audb.DatabaseIterator``)
+        should not contain unneeded tables or schemes.
+        If a misc table is used as scheme labels
+        in a scheme of the requested table,
+        it and its schemes
+        will also be part of the database object.
+
+        Args:
+            table: table to stream
+            expected_tables: expected tables in database
+            expected_schemes: expected schemes in database
+
+        """
+        db = audb.stream(
+            self.name,
+            table,
+            version=self.version,
+            only_metadata=True,
+            verbose=False,
+        )
+        assert sorted(list(db.tables) + list(db.misc_tables)) == sorted(expected_tables)
+        assert list(db.schemes) == sorted(expected_schemes)
+
+    @pytest.mark.parametrize("full_path", [False, True])
+    def test_full_path(self, full_path: bool):
+        r"""Test full path in tables.
+ + Args: + full_path: if ``True``, + the path to media files + should start with ``db.root`` + + """ + db = audb.stream( + self.name, + "files", + only_metadata=True, + full_path=full_path, + verbose=False, + ) + df = next(db) + path = df.index[0] + if full_path: + assert path == os.path.join(db.root, "file0.wav") + else: + assert path == "file0.wav" + + @pytest.mark.parametrize( + "table, map", + [ + ("acronym", {"speaker": "year-of-birth"}), + ("acronym", {"speaker": ["year-of-birth", "speaker"]}), + ("acronym", {"speaker": ["speaker", "year-of-birth"]}), + ("files", {"speaker": "year-of-birth"}), + ], + ) + def test_map(self, table: str, map: typing.Dict): + r"""Test mapping of scheme labels. + + Args: + table: table to stream + map: mapping of column with scheme labels + + """ + db = audb.stream( + self.name, + table, + version=self.version, + map=map, + batch_size=16, + only_metadata=True, + verbose=False, + ) + df = next(db) + expected_df = audb.load_table( + self.name, + table, + version=self.version, + map=map, + verbose=False, + ) + pd.testing.assert_frame_equal(df, expected_df) + + @pytest.mark.parametrize( + "only_metadata, table, expected_number_of_media_files", + [ + (True, "files", 5), + (False, "files", 5), + (True, "segments", 3), + (False, "segments", 3), + (True, "acronym", 0), + (False, "acronym", 0), + (True, "speaker", 0), + (False, "speaker", 0), + ], + ) + def test_only_metadata( + self, + only_metadata: bool, + table: str, + expected_number_of_media_files: int, + ): + r"""Test streaming with and without media files. + + Args: + only_metadata: if ``True``, + only the table should be streamed + table: table to stream + expected_number_of_media_files: expected number of downloaded media files + + """ + db = audb.stream( + self.name, + table, + version=self.version, + only_metadata=only_metadata, + verbose=False, + ) + next(db) + assert len(db.files) == expected_number_of_media_files + if not only_metadata: + for file in db.files: + assert os.path.exists(file) + + @pytest.mark.parametrize("shuffle", [False, True]) + @pytest.mark.parametrize("table", ["files", "segments", "speaker", "acronym"]) + def test_shuffle(self, shuffle: bool, table: str): + r"""Test table batch shuffling. 
+ + Args: + shuffle: if returned table rows should be shuffled + table: table to stream + + """ + np.random.seed(self.seed) + db = audb.stream( + self.name, + table, + version=self.version, + batch_size=16, + buffer_size=16, + shuffle=shuffle, + only_metadata=True, + full_path=False, + verbose=False, + ) + df = next(db) + # Create expected dataframe from original table + np.random.seed(self.seed) + expected_df = audb.load_table( + self.name, + table, + version=self.version, + verbose=False, + ) + if shuffle: + expected_df = expected_df.sample(frac=1) + pd.testing.assert_frame_equal(df, expected_df) + + @pytest.mark.parametrize( + "table, error, error_msg", + [ + ( + "non-existent", + ValueError, + "Could not find the table 'non-existent' in db v1.0.0", + ), + ], + ) + def test_errors(self, table: str, error, error_msg: str): + with pytest.raises(error, match=error_msg): + audb.stream( + self.name, + table, + version=self.version, + verbose=False, + ) + + +def test_database_iterator_error(): + db = audformat.Database("db") + db["some"] = audformat.Table() + table = "some" + error_msg = ( + "Can't instantiate abstract class DatabaseIterator " + "with abstract method _initialize_stream" + ) + with pytest.raises(TypeError, match=error_msg): + audb.DatabaseIterator( + db, + table, + version=None, + map=None, + batch_size=16, + shuffle=False, + buffer_size=100_000, + only_metadata=False, + bit_depth=None, + channels=None, + format=None, + mixdown=False, + sampling_rate=None, + full_path=False, + cache_root=None, + num_workers=1, + timeout=-1, + verbose=False, + )
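
A minimal usage sketch (not part of the patch): it mirrors the call shown in docs/load.rst and the ``for df in db`` iteration used in tests/test_stream.py, and assumes the public emodb dataset, version 1.4.1, is reachable.

import audb

# Stream the "emotion" table of emodb in batches of 4 rows;
# only_metadata=True means no media files are downloaded.
db = audb.stream(
    "emodb",
    "emotion",
    version="1.4.1",
    batch_size=4,
    only_metadata=True,
    full_path=False,
    verbose=False,
)

# Each iteration yields a dataframe with at most `batch_size` rows;
# iteration stops once the table is exhausted.
for df in db:
    print(df.shape)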