-
Notifications
You must be signed in to change notification settings - Fork 2
/
__init__.py
31 lines (27 loc) · 993 Bytes
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from .classification import ImageClassificationEnv
from .detection import ObjectDetectionEnv
from gym.envs.registration import register
# Maps each task name to the name of the environment class implementing it.
# The class name is used both as the gym id prefix and as the entry-point
# attribute, so each environment is registered from this single table.
BASENAMES = {
    'classification': 'ImageClassificationEnv',
    'detection': 'ObjectDetectionEnv'
}

# Scenarios every environment is registered for (shared by all tasks).
SCENARIOS = ['basic', 'rotation', 'shear', 'hierarchical']

# Datasets each task is registered with.
DATASETS = {
    'classification': ['cifar10', 'imagenet'],
    'detection': ['coco'],
}

# Register one gym environment per (task, scenario, dataset) combination, with
# ids of the form '<ClassName>-<scenario>-<dataset>-v0', e.g.
# 'ImageClassificationEnv-basic-cifar10-v0'. The scenario and dataset are
# forwarded to the environment constructor as keyword arguments.
#
# The entry point is built from __name__ rather than a hard-coded 'envs' so
# that registration still resolves if this package is imported under a
# different dotted path (it evaluates to 'envs:<ClassName>' when the package
# is named 'envs'). Underscore-prefixed loop names keep these helpers out of
# the package's public namespace.
for _task, _basename in BASENAMES.items():
    for _scenario in SCENARIOS:
        for _dataset in DATASETS[_task]:
            register(
                id='{basename}-{scenario}-{dataset}-v0'.format(
                    basename=_basename, scenario=_scenario, dataset=_dataset),
                entry_point='{module}:{basename}'.format(
                    module=__name__, basename=_basename),
                kwargs={
                    'scenario': _scenario,
                    'dataset': _dataset
                }
            )