Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pickle_tables argument to load functions #444

Merged
merged 9 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 52 additions & 16 deletions audb/core/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ def _get_tables_from_backend(
db_root: str,
deps: Dependencies,
backend_interface: typing.Type[audbackend.interface.Base],
pickle_tables: bool,
num_workers: typing.Optional[int],
verbose: bool,
):
Expand All @@ -538,6 +539,12 @@ def _get_tables_from_backend(
db_root: database root
deps: database dependencies
backend_interface: backend interface
pickle_tables: if ``True``,
tables are stored in their original format,
and as pickle files
in the cache.
This allows for faster loading,
when loading from cache
num_workers: number of workers
verbose: if ``True``, show progress bar

Expand Down Expand Up @@ -576,20 +583,24 @@ def job(table: str):
deps.version(table_file),
)

table_files = [table_file]

# Cache table as PKL file
pickle_file = f"db.{table}.pkl"
table_path = os.path.join(db_root_tmp, f"db.{table}")
db[table].load(table_path)
db[table].save(
table_path,
storage_format=audformat.define.TableStorageFormat.PICKLE,
)
if pickle_tables:
pickle_file = f"db.{table}.pkl"
table_path = os.path.join(db_root_tmp, f"db.{table}")
db[table].load(table_path)
db[table].save(
table_path,
storage_format=audformat.define.TableStorageFormat.PICKLE,
)
table_files.append(pickle_file)

# Move tables from tmp folder to database root
for file in [pickle_file, table_file]:
for table_file in table_files:
audeer.move_file(
os.path.join(db_root_tmp, file),
os.path.join(db_root, file),
os.path.join(db_root_tmp, table_file),
os.path.join(db_root, table_file),
)

audeer.run_tasks(
Expand Down Expand Up @@ -692,6 +703,7 @@ def _load_files(
deps: Dependencies,
flavor: Flavor,
cache_root: str,
pickle_tables: bool,
num_workers: int,
verbose: bool,
) -> typing.Optional[CachedVersions]:
Expand Down Expand Up @@ -722,6 +734,12 @@ def _load_files(
deps: database dependency object
flavor: database flavor object
cache_root: root path of cache
pickle_tables: if ``True``,
tables are stored in their original format,
and as pickle files
in the cache.
This allows for faster loading,
when loading from cache
num_workers: number of workers to use
verbose: if ``True`` show progress bars
for each step
Expand Down Expand Up @@ -778,6 +796,7 @@ def _load_files(
db_root,
deps,
backend_interface,
pickle_tables,
num_workers,
verbose,
)
Expand Down Expand Up @@ -980,6 +999,7 @@ def load(
media: typing.Union[str, typing.Sequence[str]] = None,
removed_media: bool = False,
full_path: bool = True,
pickle_tables: bool = True,
cache_root: str = None,
num_workers: typing.Optional[int] = 1,
timeout: float = -1,
Expand Down Expand Up @@ -1049,6 +1069,12 @@ def load(
misc tables will be empty
removed_media: keep rows that reference removed media
full_path: replace relative with absolute file paths
pickle_tables: if ``True``,
tables are stored in their original format,
and as pickle files
in the cache.
This allows for faster loading,
when loading from cache
cache_root: cache folder where databases are stored.
If not set :meth:`audb.default_cache_root` is used
num_workers: number of parallel jobs or 1 for sequential
Expand Down Expand Up @@ -1180,6 +1206,7 @@ def load(
deps,
flavor,
cache_root,
pickle_tables,
num_workers,
verbose,
)
Expand Down Expand Up @@ -1215,6 +1242,7 @@ def load(
deps,
flavor,
cache_root,
False,
num_workers,
verbose,
)
Expand Down Expand Up @@ -1581,6 +1609,7 @@ def load_media(
deps,
flavor,
cache_root,
False,
num_workers,
verbose,
)
Expand All @@ -1603,6 +1632,7 @@ def load_table(
table: str,
*,
version: str = None,
pickle_tables: bool = True,
cache_root: str = None,
num_workers: typing.Optional[int] = 1,
verbose: bool = True,
Expand All @@ -1621,6 +1651,12 @@ def load_table(
name: name of database
table: load table from database
version: version of database
pickle_tables: if ``True``,
tables are stored in their original format,
and as pickle files
in the cache.
This allows for faster loading,
when loading from cache
cache_root: cache folder where databases are stored.
If not set :meth:`audb.default_cache_root` is used
num_workers: number of parallel jobs or 1 for sequential
Expand Down Expand Up @@ -1685,14 +1721,14 @@ def load_table(

# Load table
tables = _misc_tables_used_in_scheme(db) + [table]
for table in tables:
table_file = os.path.join(db_root, f"db.{table}")
for _table in tables:
table_file = os.path.join(db_root, f"db.{_table}")
if not (
os.path.exists(f"{table_file}.csv")
or os.path.exists(f"{table_file}.pkl")
):
_load_files(
[table],
[_table],
"table",
backend_interface,
db_root,
Expand All @@ -1702,10 +1738,10 @@ def load_table(
deps,
Flavor(),
cache_root,
pickle_tables,
num_workers,
verbose,
)
table = audformat.Table()
table.load(table_file)
db[_table].load(table_file)

return table._df
return db[table]._df
28 changes: 21 additions & 7 deletions audb/core/load_to.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def load_to(
*,
version: str = None,
only_metadata: bool = False,
pickle_tables: bool = True,
cache_root: str = None,
num_workers: typing.Optional[int] = 1,
verbose: bool = True,
Expand All @@ -331,6 +332,12 @@ def load_to(
name: name of database
version: version string, latest if ``None``
only_metadata: load only header and tables of database
pickle_tables: if ``True``,
tables are stored in their original format,
and as pickle files
in ``root``.
This allows for faster loading,
when loading from ``root``
cache_root: cache folder where databases are stored.
If not set :meth:`audb.default_cache_root` is used.
Only used to read the dependencies of the requested version
Expand Down Expand Up @@ -464,13 +471,20 @@ def load_to(

# save database and PKL tables

db.save(
db_root,
storage_format=audformat.define.TableStorageFormat.PICKLE,
update_other_formats=False,
num_workers=num_workers,
verbose=verbose,
)
if pickle_tables:
db.save(
db_root,
storage_format=audformat.define.TableStorageFormat.PICKLE,
update_other_formats=False,
num_workers=num_workers,
verbose=verbose,
)
else:
db.save(
db_root,
header_only=True,
verbose=verbose,
)

# remove the temporal directory
# to signal all files were correctly loaded
Expand Down
83 changes: 83 additions & 0 deletions tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,89 @@ def test_load_media(cache, version, media, format):
assert paths2 == paths


@pytest.mark.parametrize("pickle_tables", [True, False])
@pytest.mark.parametrize("name, version, table", [(DB_NAME, "1.0.0", "emotion")])
class TestLoadPickle:
    r"""Ensure cached tables are pickled only on request.

    A table fetched from a backend
    is always cached in its original storage format.
    An additional pickle copy is written
    if and only if ``pickle_tables`` is ``True``.

    """

    def assert_table_exist_in_cache(
        self, db_root, table, storage_format, pickle_tables
    ):
        """Check cached table files for the expected formats.

        Args:
            db_root: database root folder
            table: table ID
            storage_format: original table format, ``"csv"`` or ``"parquet"``
            pickle_tables: if ``True``,
                a pickle copy of the table is expected in the cache,
                otherwise it must be absent

        """
        # A file in the original storage format must always be present
        assert os.path.exists(
            audeer.path(db_root, f"db.{table}.{storage_format}")
        )
        # The pickle copy exists exactly when it was requested
        has_pickle = os.path.exists(audeer.path(db_root, f"db.{table}.pkl"))
        assert has_pickle == pickle_tables

    def test_load_pickle(self, storage_format, name, version, table, pickle_tables):
        """Table caching behavior of audb.load().

        Args:
            storage_format: storage_format fixture
            name: database name
            version: database version
            table: table ID
            pickle_tables: if ``True``,
                request pickle copies of tables in the cache

        """
        db = audb.load(
            name,
            version=version,
            tables=table,
            pickle_tables=pickle_tables,
            only_metadata=True,
            verbose=False,
        )
        self.assert_table_exist_in_cache(db.root, table, storage_format, pickle_tables)

    def test_load_table_pickle(
        self, cache, storage_format, name, version, table, pickle_tables
    ):
        r"""Table caching behavior of audb.load_table().

        Args:
            cache: cache fixture
            storage_format: storage_format fixture
            name: database name
            version: database version
            table: table ID
            pickle_tables: if ``True``,
                request pickle copies of tables in the cache

        """
        audb.load_table(
            name,
            table,
            version=version,
            pickle_tables=pickle_tables,
            verbose=False,
        )
        self.assert_table_exist_in_cache(
            audeer.path(cache, name, version), table, storage_format, pickle_tables
        )


@pytest.mark.parametrize(
"version, table",
[
Expand Down
22 changes: 19 additions & 3 deletions tests/test_publish_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,14 +313,20 @@ def test_database_load(
db = audb.load(db.name, version=version, verbose=False, full_path=False)
assert_db_saved_to_dir(db, db.root, [storage_format, "pkl"])

def test_updated_database_save(self, build_dir, db, repository, storage_format):
@pytest.mark.parametrize("pickle_tables", [True, False])
def test_updated_database_save(
self, build_dir, db, repository, storage_format, pickle_tables
):
r"""Test correct files are stored to build dir after database update.

Args:
build_dir: build dir fixture
db: database fixture
repository: repository fixture
storage_format: storage format of database tables
pickle_tables: if ``True``,
``audb.load_to()`` should store tables
as pickle files as well

"""
# Publish first version
Expand All @@ -329,12 +335,22 @@ def test_updated_database_save(self, build_dir, db, repository, storage_format):

# Clear build dir to force audb.load_to() to load from backend
audeer.rmdir(build_dir)
db = audb.load_to(build_dir, db.name, version=version, verbose=False)
db = audb.load_to(
build_dir,
db.name,
version=version,
pickle_tables=pickle_tables,
verbose=False,
)

# Update database
db = update_db(db)
db.save(build_dir, storage_format=storage_format)
assert_db_saved_to_dir(db, db.root, [storage_format, "pkl"])
if pickle_tables:
expected_formats = [storage_format, "pkl"]
else:
expected_formats = [storage_format]
assert_db_saved_to_dir(db, db.root, expected_formats)

def test_updated_database_publish(
self,
Expand Down
Loading