# tests/test_loaders.py
import importlib
import inspect
import io
import os
import sys

import pytest
import requests

import mirdata
from mirdata import core
from tests.test_utils import get_attributes_and_properties

DATASETS = mirdata.DATASETS

# Datasets for which the first track id in the index is not a suitable test
# track; maps dataset name -> track id to use instead.
CUSTOM_TEST_TRACKS = {
    "beatles": "0111",
    "cante100": "008",
    "compmusic_jingju_acappella": "lseh-Tan_Yang_jia-Hong_yang_dong-qm",
    "compmusic_otmm_makam": "cafcdeaf-e966-4ff0-84fb-f660d2b68365",
    "giantsteps_key": "3",
    "dali": "4b196e6c99574dd49ad00d56e132712b",
    "da_tacos": "coveranalysis#W_163992#P_547131",
    "freesound_one_shot_percussive_sounds": "183",
    "giantsteps_tempo": "113",
    "gtzan_genre": "country.00000",
    "guitarset": "03_BN3-119-G_solo",
    "irmas": "1",
    "medley_solos_db": "d07b1fc0-567d-52c2-fef4-239f31c9d40e",
    "medleydb_melody": "MusicDelta_Beethoven",
    "mridangam_stroke": "224030",
    "rwc_classical": "RM-C003",
    "rwc_jazz": "RM-J004",
    "rwc_popular": "RM-P001",
    "salami": "2",
    "saraga_carnatic": "116_Bhuvini_Dasudane",
    "saraga_hindustani": "59_Bairagi",
    "tinysol": "Fl-ord-C4-mf-N-T14d",
    "dagstuhl_choirset": "DCS_LI_QuartetB_Take04_B2",
    "tonas": "01-D_AMairena",
}

TEST_DATA_HOME = "tests/resources/mir_datasets"


def test_dataset_attributes():
    """Check that every dataset defines the required attributes and metadata."""
    for dataset_name in DATASETS:
        dataset = mirdata.initialize(
            dataset_name, os.path.join(TEST_DATA_HOME, dataset_name), version="test"
        )
        assert (
            dataset.name == dataset_name
        ), "{}.name attribute does not match the dataset name".format(dataset_name)
        assert (
            dataset.bibtex is not None
        ), "No BIBTEX information provided for {}".format(dataset_name)
        assert (
            dataset._license_info is not None
        ), "No LICENSE information provided for {}".format(dataset_name)
        assert (
            isinstance(dataset.remotes, dict) or dataset.remotes is None
        ), "{}.REMOTES must be a dictionary".format(dataset_name)
        assert isinstance(dataset._index, dict), "{}.DATA is not properly set".format(
            dataset_name
        )
        assert (
            isinstance(dataset._download_info, str) or dataset._download_info is None
        ), "{}.DOWNLOAD_INFO must be a string".format(dataset_name)
        assert type(dataset._track_class) == type(
            core.Track
        ), "{}._track_class must be a class, like core.Track".format(dataset_name)
        assert callable(dataset.download), "{}.download is not a function".format(
            dataset_name
        )


def test_smart_open():
    """Ensure loader modules stay compatible with remote filesystems."""
    for dataset_name in DATASETS:
        # import module
        dataset_module = importlib.import_module(f"mirdata.datasets.{dataset_name}")
        code_lines = inspect.getsource(dataset_module)

        # check that os.path.exists is not called (see the sketch after this test):
        assert "os.path.exists" not in code_lines, (
            f"{dataset_name} uses os.path.exists, "
            + "which does not work with remote filesystems. Instead use a try/except. "
            + "See orchset.py for an example."
        )

        # if open is called, make sure smart_open is imported
        not_overwritten_message = (
            f"{dataset_name} uses open, but does not overwrite it with smart_open. "
            + "Add the line `from smart_open import open` to the top of the module."
        )
        overwritten_wrong_message = (
            f"{dataset_name} overwrites open using something other than smart_open. "
            + "Add the line `from smart_open import open` to the top of the module."
        )
        if "open(" in code_lines:
            assert hasattr(dataset_module, "open"), not_overwritten_message
            assert (
                dataset_module.open.__module__ == "smart_open.smart_open_lib"
            ), overwritten_wrong_message
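

# A minimal sketch of the loader pattern the test above enforces. The function
# name and the `annotation_path` argument are illustrative, not part of mirdata's
# API; the point is that a loader opens the file inside a try/except instead of
# calling os.path.exists, so that smart_open can transparently handle remote paths.
def _example_remote_safe_load(annotation_path):
    try:
        # `open` here may be the builtin or smart_open's replacement
        with open(annotation_path) as fhandle:
            return fhandle.read()
    except FileNotFoundError:
        raise FileNotFoundError("{} not found or unreadable".format(annotation_path))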


def test_cite_and_license():
    """Check that cite() and license() print without errors."""
    for dataset_name in DATASETS:
        dataset = mirdata.initialize(
            dataset_name, os.path.join(TEST_DATA_HOME, dataset_name), version="test"
        )

        # capture stdout while calling cite() and license()
        text_trap = io.StringIO()
        sys.stdout = text_trap
        dataset.cite()
        sys.stdout = sys.__stdout__

        text_trap = io.StringIO()
        sys.stdout = text_trap
        dataset.license()
        sys.stdout = sys.__stdout__


KNOWN_ISSUES = {}  # key is dataset module name, value is a collection of REMOTES keys
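# e.g. KNOWN_ISSUES = {"some_dataset": ["audio"]}  # hypothetical entry: skips
# the link check for the "audio" remote of some_dataset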
# Datasets whose download() is not called directly in this test.
DOWNLOAD_EXCEPTIONS = ["maestro", "slakh", "gtzan_genre"]


def test_download(mocker):
    """Check download signatures and that every remote URL is reachable."""
    for dataset_name in DATASETS:
        print(dataset_name)
        dataset = mirdata.initialize(
            dataset_name, os.path.join(TEST_DATA_HOME, dataset_name), version="test"
        )

        # test parameters & defaults
        assert callable(dataset.download), "{}.download is not callable".format(
            dataset_name
        )
        params = inspect.signature(dataset.download).parameters
        expected_params = [
            ("partial_download", None),
            ("force_overwrite", False),
            ("cleanup", False),
        ]
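        # i.e., a conforming dataset is expected to expose, roughly:
        #     def download(self, partial_download=None, force_overwrite=False, cleanup=False)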
        for exp in expected_params:
            assert exp[0] in params, "{}.download must have {} as a parameter".format(
                dataset_name, exp[0]
            )
            assert (
                params[exp[0]].default == exp[1]
            ), "The default value of {} in {}.download must be {}".format(
                exp[0], dataset_name, exp[1]
            )

        # check that the download method can be called without errors
        if dataset.remotes != {}:
            mocker.patch.object(dataset, "remotes")
            if dataset_name not in DOWNLOAD_EXCEPTIONS:
                try:
                    dataset.download()
                except:
                    assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])
            mocker.resetall()

            # check that links are online
            for key in dataset.remotes:
                # skip this link if it's in known issues
                if dataset_name in KNOWN_ISSUES and key in KNOWN_ISSUES[dataset_name]:
                    continue
                url = dataset.remotes[key].url
                try:
                    request = requests.head(url)
                    assert request.ok, "Link {} for {} does not return OK".format(
                        url, dataset_name
                    )
                except requests.exceptions.ConnectionError:
                    assert False, "Link {} for {} is unreachable".format(
                        url, dataset_name
                    )
                except:
                    assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])
        else:
            try:
                dataset.download()
            except:
                assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])


# This is magically skipped by the remote fixture `skip_local` in conftest.py
# when tests are run with the --local flag
def test_validate(skip_local):
    """Check that validate() runs without errors for every dataset."""
    for dataset_name in DATASETS:
        dataset = mirdata.initialize(
            dataset_name, os.path.join(TEST_DATA_HOME, dataset_name), version="test"
        )
        try:
            dataset.validate()
        except:
            assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])

        try:
            dataset.validate(verbose=False)
        except:
            assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])


def test_load_and_trackids():
    """Check track_ids and load_tracks() for every dataset."""
    for dataset_name in DATASETS:
        dataset = mirdata.initialize(
            dataset_name, os.path.join(TEST_DATA_HOME, dataset_name), version="test"
        )
        try:
            track_ids = dataset.track_ids
        except:
            assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])
        assert type(track_ids) is list, "{}.track_ids should return a list".format(
            dataset_name
        )
        trackid_len = len(track_ids)

        # if the dataset has tracks, test the loaders
        if dataset._track_class is not None:
            try:
                choice_track = dataset.choice_track()
            except:
                assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])
            assert isinstance(
                choice_track, core.Track
            ), "{}.choice_track must return an instance of type core.Track".format(
                dataset_name
            )

            try:
                dataset_data = dataset.load_tracks()
            except:
                assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])
            assert isinstance(
                dataset_data, dict
            ), "{}.load_tracks should return a dictionary".format(dataset_name)
            assert len(dataset_data.keys()) == trackid_len, (
                "the dictionary returned by {}.load_tracks() does not have the same "
                "number of elements as {}.track_ids".format(dataset_name, dataset_name)
            )


def test_track():
    """Check the track() loader, its attributes, jams conversion, and repr."""
    for dataset_name in DATASETS:
        dataset = mirdata.initialize(
            dataset_name, os.path.join(TEST_DATA_HOME, dataset_name), version="test"
        )

        # if the dataset doesn't have a track object, make sure it raises
        # NotImplementedError and move on to the next dataset
        if dataset._track_class is None:
            with pytest.raises(NotImplementedError):
                dataset.track("~faketrackid~?!")
            continue

        if dataset_name in CUSTOM_TEST_TRACKS:
            trackid = CUSTOM_TEST_TRACKS[dataset_name]
        else:
            trackid = dataset.track_ids[0]

        # test data home specified
        try:
            track_test = dataset.track(trackid)
        except:
            assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])
        assert isinstance(
            track_test, core.Track
        ), "{}.track must be an instance of type core.Track".format(dataset_name)
        assert hasattr(
            track_test, "to_jams"
        ), "{}.track must have a to_jams method".format(dataset_name)

        # test calling all attributes, properties and cached properties
        track_data = get_attributes_and_properties(track_test)
        for attr in track_data["attributes"]:
            ret = getattr(track_test, attr)
        for prop in track_data["properties"]:
            ret = getattr(track_test, prop)
        for cprop in track_data["cached_properties"]:
            ret = getattr(track_test, cprop)

        # Validate the JAMS schema
        try:
            jam = track_test.to_jams()
        except:
            assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])
        assert jam.validate(), "JAMS validation failed for {}.track({})".format(
            dataset_name, trackid
        )

        # will fail if something goes wrong with __repr__
        try:
            text_trap = io.StringIO()
            sys.stdout = text_trap
            print(track_test)
            sys.stdout = sys.__stdout__
        except:
            assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])

        with pytest.raises(ValueError):
            dataset.track("~faketrackid~?!")


# This tests the case where there is no data in data_home.
# It makes sure that the track can be initialized and the
# attributes accessed, but that anything requiring data
# files errors (all properties and cached properties).
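# For instance (hypothetical attribute names): with data_home="not/a/real/path",
# a plain attribute like track.track_id should still be readable, while a
# property that loads from disk, e.g. track.audio, should raise.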
def test_track_placeholder_case():
    """Check track behavior when data_home points to a nonexistent path."""
    data_home_dir = "not/a/real/path"
    for dataset_name in DATASETS:
        print(dataset_name)
        data_home = os.path.join(data_home_dir, dataset_name)
        dataset = mirdata.initialize(dataset_name, data_home, version="test")

        if not dataset._track_class:
            continue

        if dataset_name in CUSTOM_TEST_TRACKS:
            trackid = CUSTOM_TEST_TRACKS[dataset_name]
        else:
            trackid = dataset.track_ids[0]

        try:
            track_test = dataset.track(trackid)
        except:
            assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])

        track_data = get_attributes_and_properties(track_test)
        for attr in track_data["attributes"]:
            ret = getattr(track_test, attr)
        for prop in track_data["properties"]:
            with pytest.raises(Exception):
                ret = getattr(track_test, prop)
        for cprop in track_data["cached_properties"]:
            with pytest.raises(Exception):
                ret = getattr(track_test, cprop)


# for load_* functions which require more than one argument
# module_name: {function_name: {parameter2: value, parameter3: value}}
EXCEPTIONS = {
    "dali": {"load_annotations_granularity": {"granularity": "notes"}},
    "guitarset": {
        "load_pitch_contour": {"string_num": 1},
        "load_notes": {"string_num": 1},
        "load_chords": {"leadsheet_version": False},
    },
    "tonas": {"load_f0": {"corrected": True}},
}
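# e.g. for dali, test_load_methods below effectively calls
#     load_annotations_granularity("a/fake/filepath", granularity="notes")
# and expects an IOError.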
SKIP = {
    "acousticbrainz_genre": [
        "load_all_train",
        "load_all_validation",
        "load_tagtraum_validation",
        "load_tagtraum_train",
        "load_allmusic_train",
        "load_allmusic_validation",
        "load_lastfm_train",
        "load_lastfm_validation",
        "load_discogs_train",
        "load_discogs_validation",
    ]
}


def test_load_methods():
    """Check that every module-level load_* function is documented and
    raises IOError on a nonexistent filepath."""
    for dataset_name in DATASETS:
        dataset_module = importlib.import_module(f"mirdata.datasets.{dataset_name}")
        all_methods = dir(dataset_module)
        load_methods = [
            getattr(dataset_module, m) for m in all_methods if m.startswith("load_")
        ]
        for load_method in load_methods:
            method_name = load_method.__name__

            # skip overrides; add to the SKIP dictionary to skip a specific load method
            if dataset_name in SKIP and method_name in SKIP[dataset_name]:
                continue

            if load_method.__doc__ is None:
                raise ValueError(
                    "mirdata.datasets.{}.{} has no documentation".format(
                        dataset_name, method_name
                    )
                )

            # add to the EXCEPTIONS dictionary above if your load_* function needs
            # more than one argument.
            if dataset_name in EXCEPTIONS and method_name in EXCEPTIONS[dataset_name]:
                extra_params = EXCEPTIONS[dataset_name][method_name]
                with pytest.raises(IOError):
                    load_method("a/fake/filepath", **extra_params)
            else:
                with pytest.raises(IOError):
                    load_method("a/fake/filepath")


CUSTOM_TEST_MTRACKS = {}


def test_multitracks():
    """Check MultiTrack loading and jams conversion for datasets that define one."""
    data_home_dir = "tests/resources/mir_datasets"
    for dataset_name in DATASETS:
        dataset = mirdata.initialize(
            dataset_name, os.path.join(TEST_DATA_HOME, dataset_name), version="test"
        )

        # TODO this is currently an opt-in test. Make it an opt-out test
        # once #265 is addressed
        if dataset_name in CUSTOM_TEST_MTRACKS:
            mtrack_id = CUSTOM_TEST_MTRACKS[dataset_name]
        else:
            # there are no multitracks
            continue

        try:
            mtrack_default = dataset.MultiTrack(mtrack_id)
        except:
            assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])

        # test data home specified
        data_home = os.path.join(data_home_dir, dataset_name)
        dataset_specific = mirdata.initialize(dataset_name, data_home=data_home)
        try:
            mtrack_test = dataset_specific.MultiTrack(mtrack_id, data_home=data_home)
        except:
            assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])

        assert isinstance(
            mtrack_test, core.MultiTrack
        ), "{}.MultiTrack must be an instance of type core.MultiTrack".format(
            dataset_name
        )
        assert hasattr(
            mtrack_test, "to_jams"
        ), "{}.MultiTrack must have a to_jams method".format(dataset_name)

        # Validate the JAMS schema
        try:
            jam = mtrack_test.to_jams()
        except:
            assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])
        assert jam.validate(), "JAMS validation failed for {}.MultiTrack({})".format(
            dataset_name, mtrack_id
        )


def test_random_splits():
    """Check that random splits partition all track (and multitrack) ids."""
    split = [0.9, 0.1]
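    # get_random_track_splits is expected to return a dict mapping split names
    # to lists of track ids, e.g. {"train": [...90% of ids...], "test": [...10%...]}
    # (the key names here are an assumption; only the sizes are checked below)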
    for dataset_name in DATASETS:
        dataset = mirdata.initialize(
            dataset_name, os.path.join(TEST_DATA_HOME, dataset_name), version="test"
        )

        # check that the split functions raise when there is nothing to split
        if dataset._track_class is None:
            with pytest.raises(AttributeError):
                dataset.get_random_track_splits(split)
        if dataset._multitrack_class is None:
            with pytest.raises(AttributeError):
                dataset.get_random_mtrack_splits(split)

        # check splits for tracks
        if dataset._track_class:
            splits = dataset.get_random_track_splits(split)
            assert len(dataset.track_ids) == sum([len(i) for i in splits.values()])

        # check splits for multitracks
        if dataset._multitrack_class:
            splits = dataset.get_random_mtrack_splits(split)
            assert len(dataset.mtrack_ids) == sum([len(i) for i in splits.values()])


def test_predetermined_splits():
    """Check that predetermined splits, where defined, are disjoint and valid."""
    # datasets that are required to define predetermined track/mtrack splits
    required_track = ["irmas", "mtg_jamendo_autotagging_moodtheme", "slakh", "tinysol"]
    required_mtrack = ["slakh"]
    for dataset_name in DATASETS:
        print(dataset_name)
        dataset = mirdata.initialize(
            dataset_name, os.path.join(TEST_DATA_HOME, dataset_name), version="test"
        )

        # test custom get_track_splits functions
        try:
            splits = dataset.get_track_splits()
            assert isinstance(splits, dict)
            used_tracks = set()
            for k in splits:
                assert all([t in dataset.track_ids for t in splits[k]])
                this_split = set(splits[k])
                assert not used_tracks.intersection(this_split)
                used_tracks.update(this_split)
        except (AttributeError, NotImplementedError):
            assert dataset_name not in required_track

        # test custom get_mtrack_splits functions
        try:
            splits = dataset.get_mtrack_splits()
            assert isinstance(splits, dict)
            used_tracks = set()
            for k in splits:
                assert all([t in dataset.mtrack_ids for t in splits[k]])
                this_split = set(splits[k])
                assert not used_tracks.intersection(this_split)
                used_tracks.update(this_split)
        except (AttributeError, NotImplementedError):
            assert dataset_name not in required_mtrack