Skip to content

Commit

Permalink
Dataset object, heavily inspired by the RFC in #219
Browse files Browse the repository at this point in the history
  • Loading branch information
Rachel Bittner committed Oct 20, 2020
1 parent ad87782 commit a7c3d03
Show file tree
Hide file tree
Showing 3 changed files with 220 additions and 69 deletions.
37 changes: 19 additions & 18 deletions mirdata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,25 @@


from .version import version as __version__
from .dataset import Dataset

__all__ = [
'beatles',
'dali',
'giantsteps_tempo',
'giantsteps_key',
'groove_midi',
'gtzan_genre',
'guitarset',
'ikala',
'maestro',
'medley_solos_db',
'medleydb_melody',
'medleydb_pitch',
'orchset',
'rwc_classical',
'rwc_jazz',
'rwc_popular',
'salami',
'tinysol',
"beatles",
"dali",
"giantsteps_tempo",
"giantsteps_key",
"groove_midi",
"gtzan_genre",
"guitarset",
"ikala",
"maestro",
"medley_solos_db",
"medleydb_melody",
"medleydb_pitch",
"orchset",
"rwc_classical",
"rwc_jazz",
"rwc_popular",
"salami",
"tinysol",
]
165 changes: 165 additions & 0 deletions mirdata/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
"""Module containing the Dataset base class
"""
import importlib
import os
import random

from mirdata import download_utils
from mirdata import utils


class Dataset(object):
"""mirdata Dataset object
Usage example:
orchset = mirdata.Dataset('orchset') # get the orchset dataset
orchset.download() # download orchset
orchset.validate() # validate orchset
track = orchset.choice() # load a random track
print(track) # see what data a track contains
orchset.track_ids() # load all track ids
Attributes:
dataset (str): the identifier of the dataset
bibtex (str): dataset citation/s in bibtex format
remotes (dict): data to be downloaded
index (dict): dataset file index
download_info (str): download instructions or caveats
track_object (mirdata.track.Track): an uninstantiated Track object
dataset_dir (str): dataset save folder
readme (str): information about the dataset
data_home (str): path where mirdata will look for the dataset
"""

def __init__(self, dataset, data_home=None):
"""Inits a dataset by name and data location"""
module = importlib.import_module("mirdata.{}".format(dataset))
self.dataset = dataset
self.bibtex = module.bibtex
self.remotes = module.REMOTES
self.index = module.DATA.index
self.download_info = getattr(module, "DOWNLOAD_INFO", None)
self.track_object = getattr(module, "Track", None)
self.dataset_dir = module.DATASET_DIR
self.download_fn = getattr(module, "download", download_utils.downloader)
self.readme = module.__doc__

if data_home is None:
self.data_home = self.default_path
else:
self.data_home = data_home

@property
def default_path(self):
"""Get the default path for the dataset
Returns:
default_path (str): Local path to the dataset
"""
mir_datasets_dir = os.path.join(os.getenv("HOME", "/tmp"), "mir_datasets")
return os.path.join(mir_datasets_dir, self.dataset_dir)

def track(self, track_id):
"""Load a track by track_id
Args:
track_id (str): track id of the track
Returns:
track (dataset.Track): an instance of this dataset's Track object
"""
if self.track_object is None:
raise NotImplementedError
else:
return self.track_object(track_id, self.data_home)

def load_tracks(self):
"""Load all tracks in the dataset
Returns:
(dict): {`track_id`: track data}
Raises:
NotImplementedError: If the dataset does not support Track objects
"""
return {self.track(track_id) for track_id in self.track_ids}

def choice(self):
"""Choose a random track
Returns:
track (dataset.Track): a random Track object
"""
return self.track(random.choice(self.track_ids))

def cite(self):
"""Print the reference"""
# TODO: use pybtex to convert to MLA
print("========== BibTeX ==========")
if isinstance(self.bibtex, str):
print(self.bibtex)
else:
print("\n".join(self.bibtex.values()))

def download(self, force_overwrite=False, cleanup=True, partial_download=None):
"""Download data to `save_dir` and optionally print a message.
Args:
save_dir (str):
The directory to download the data
remotes (dict or None):
A dictionary of RemoteFileMetadata tuples of data in zip format.
If None, there is no data to download
partial_download (list or None):
A list of keys of remotes to partially download.
If None, all data is downloaded
info_message (str or None):
A string of info to print when this function is called.
If None, no string is printed.
force_overwrite (bool):
If True, existing files are overwritten by the downloaded files.
cleanup (bool):
Whether to delete the zip/tar file after extracting.
Raises:
ValueError: if invalid keys are passed to partial_download
IOError: if a downloaded file's checksum is different from expected
"""
self.download_fn(
self.data_home,
remotes=self.remotes,
partial_download=partial_download,
info_message=self.download_info,
force_overwrite=force_overwrite,
cleanup=cleanup,
)

@utils.cached_property
def track_ids(self):
"""Return track ids
Returns:
(list): A list of track ids
"""
return list(self.index.keys())

def validate(self, verbose=True):
"""Validate if the stored dataset is a valid version
Args:
verbose (bool): If False, don't print output
Returns:
missing_files (list): List of file paths that are in the dataset index
but missing locally
invalid_checksums (list): List of file paths that file exists in the dataset
index but has a different checksum compare to the reference checksum
"""
missing_files, invalid_checksums = utils.validator(
self.index, self.data_home, verbose=verbose
)
return missing_files, invalid_checksums

Loading

0 comments on commit a7c3d03

Please sign in to comment.