-
Notifications
You must be signed in to change notification settings - Fork 67
/
config_load.py
84 lines (76 loc) · 2.51 KB
/
config_load.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
Temporary file for testing loading from config without modifying armory.utils
"""
from armory.datasets import load, preprocessing, generator, filtering
def load_dataset(
    name=None,
    version=None,
    config=None,
    batch_size=1,
    num_batches=None,
    epochs=1,
    split="test",
    framework="numpy",
    preprocessor_name="DEFAULT",
    preprocessor_kwargs=None,
    shuffle_files=False,
    label_key="label",  # TODO: make this smarter or more flexible
    index=None,
    class_ids=None,
    drop_remainder=False,
):
    """Load dataset ``name`` and wrap it in a ``generator.ArmoryDataGenerator``.

    All parameters are keyword elements by design (they mirror config keys).
    ``class_ids`` and ``index`` are type-checked *before* the dataset is
    loaded, so a malformed config fails fast instead of after ``load.load``.

    Args:
        name: Dataset name passed to ``load.load``. When
            ``preprocessor_name == "DEFAULT"`` it is also used to look up a
            registered preprocessor. Required.
        version: Forwarded to ``load.load``.
        config: Forwarded to ``load.load``.
        batch_size: Forwarded to ``ArmoryDataGenerator``.
        num_batches: Forwarded to ``ArmoryDataGenerator``.
        epochs: Forwarded to ``ArmoryDataGenerator``.
        split: Forwarded to ``ArmoryDataGenerator``. Also passed to
            ``preprocessing.infer_from_dataset_info`` when a preprocessor has
            to be inferred.
        framework: Forwarded to ``ArmoryDataGenerator``.
        preprocessor_name: ``None`` disables preprocessing. ``"DEFAULT"`` uses
            ``preprocessing.get(name)`` when ``preprocessing.has(name)``, and
            otherwise ``preprocessing.infer_from_dataset_info(info, split)``.
            Any other value is looked up with ``preprocessing.get``.
        preprocessor_kwargs: Extra keyword arguments passed to the
            preprocessor on every call. Ignored when no preprocessor is
            resolved.
        shuffle_files: Forwarded to ``load.load``. Also used as the
            generator's ``shuffle_elements``, so both shuffles are toggled
            together.
        label_key: Forwarded to ``filtering.get_filter_by_class``.
        index: ``None``, a list/tuple (passed as a list to
            ``filtering.get_enum_filter_by_index``), or a str (passed to
            ``filtering.get_enum_filter_by_slice``).
        class_ids: ``None``, an int, or a list/tuple of class ids. Passed as a
            list to ``filtering.get_filter_by_class``.
        drop_remainder: Forwarded to ``ArmoryDataGenerator``.

    Returns:
        The ``generator.ArmoryDataGenerator`` built over the loaded dataset.

    Raises:
        ValueError: If ``name`` is None, or ``class_ids``/``index`` has an
            unsupported type.
    """
    if name is None:
        raise ValueError("name must be specified, not None")

    # Cheap, purely type-based validation happens before the dataset load so
    # configuration mistakes surface immediately.
    class_ids = _as_class_id_list(class_ids)
    index = _as_index_spec(index)

    info, ds_dict = load.load(
        name, version=version, config=config, shuffle_files=shuffle_files
    )

    if class_ids is None:
        element_filter = None
    else:
        element_filter = filtering.get_filter_by_class(class_ids, label_key=label_key)

    if index is None:
        index_filter = None
    elif isinstance(index, str):
        index_filter = filtering.get_enum_filter_by_slice(index)
    else:
        index_filter = filtering.get_enum_filter_by_index(index)

    preprocessing_fn = _resolve_preprocessor(
        preprocessor_name, preprocessor_kwargs, name, info, split
    )

    return generator.ArmoryDataGenerator(
        info,
        ds_dict,
        split=split,
        batch_size=batch_size,
        framework=framework,
        epochs=epochs,
        drop_remainder=drop_remainder,
        num_batches=num_batches,
        index_filter=index_filter,
        element_filter=element_filter,
        element_map=preprocessing_fn,
        # File-level and element-level shuffling are driven by the same flag.
        shuffle_elements=shuffle_files,
        key_map=None,
    )


def _as_class_id_list(class_ids):
    """Normalize ``class_ids`` to a list (or None); raise ValueError on bad types."""
    if class_ids is None:
        return None
    if isinstance(class_ids, int):
        return [class_ids]
    if isinstance(class_ids, (list, tuple)):
        # Copy so later mutation by the caller cannot affect the filter.
        return list(class_ids)
    raise ValueError(
        f"class_ids must be a list, tuple, int, or None, not {type(class_ids)}"
    )


def _as_index_spec(index):
    """Validate ``index``; return None, the str unchanged, or a list copy."""
    if index is None or isinstance(index, str):
        return index
    if isinstance(index, (list, tuple)):
        return list(index)
    raise ValueError(
        f"index must be a list, tuple, str, or None, not {type(index)}"
    )


def _resolve_preprocessor(preprocessor_name, preprocessor_kwargs, name, info, split):
    """Return the element-map callable for ``load_dataset``, or None."""
    if preprocessor_name is None:
        return None
    if preprocessor_name == "DEFAULT":
        if preprocessing.has(name):
            preprocessor = preprocessing.get(name)
        else:
            preprocessor = preprocessing.infer_from_dataset_info(info, split)
    else:
        preprocessor = preprocessing.get(preprocessor_name)

    if preprocessor is None or preprocessor_kwargs is None:
        return preprocessor

    def preprocess(x):
        # Bind the configured kwargs to every preprocessor call.
        return preprocessor(x, **preprocessor_kwargs)

    return preprocess