-
Notifications
You must be signed in to change notification settings - Fork 51
/
evaluate.py
32 lines (24 loc) · 1.13 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
from modelvshuman import Plot, Evaluate
from modelvshuman import constants as c
from plotting_definition import plotting_definition_template
def run_evaluation():
    """Evaluate a fixed set of models on the toolbox's default OOD datasets.

    Results are written out by the toolbox itself; nothing is returned.
    """
    model_names = ["resnet50", "bagnet33", "simclr_resnet50x1"]
    dataset_names = c.DEFAULT_DATASETS  # or e.g. ["cue-conflict", "uniform-noise"]
    # Pass evaluation parameters as keyword arguments directly.
    Evaluate()(
        model_names,
        dataset_names,
        batch_size=64,
        print_predictions=True,
        num_workers=20,
    )
def run_plotting():
    """Plot previously computed evaluation results.

    Uses the toolbox's default plot types and the template plotting
    definition; figures are saved under ``example-figures/``.
    Nothing is returned — the Plot call produces the figure files.
    """
    plot_types = c.DEFAULT_PLOT_TYPES  # or e.g. ["accuracy", "shape-bias"]
    plotting_def = plotting_definition_template
    figure_dirname = "example-figures/"
    # PEP 8: no spaces around '=' for keyword arguments.
    Plot(plot_types=plot_types, plotting_definition=plotting_def,
         figure_directory_name=figure_dirname)
# In examples/plotting_definition.py, you can edit
# plotting_definition_template as desired: this will let
# the toolbox know which models to plot, and which colours to use etc.
def main():
    """Script entry point: evaluate models, then plot the results."""
    # 1. evaluate models on out-of-distribution datasets
    run_evaluation()
    # 2. plot the evaluation results
    run_plotting()


if __name__ == "__main__":
    main()