-
Notifications
You must be signed in to change notification settings - Fork 19
/
visualize_results.py
74 lines (71 loc) · 2.52 KB
/
visualize_results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""
This file does the following
1. plot the accuracy difference in accuracy (and cost) between different models more easily
2. This is good for comparing the accuracy of different models, but no good to explore deeper on the specific metrics
"""
import os.path
from typing import Union
import pandas as pd
from jsonargparse import ActionConfigFile, ArgumentParser
from evaluation.eval import EvaluationResultCollection
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument(
"--input-file",
type=str,
help="input file path to saved EvaluationCollection",
)
parser.add_argument(
"--data-path",
type=str,
help="input file path to the original input data file, that has a `model_name` column with the model names",
)
parser.add_argument(
"--wanted-eval-names",
type=Union[str, list[str]],
help="evaluation name to plot, if None, plots all evals",
default=None,
)
parser.add_argument(
"--data-name",
type=str,
default="default",
help="The name to use when saving in a subfolder",
)
parser.add_argument(
"--config", action=ActionConfigFile, help="path to the configuration file"
)
args = parser.parse_args()
save_path = f"data/{args.data_name}/analysis_results/"
if not os.path.exists(save_path):
os.makedirs(save_path)
# save the config to the data folder
parser.save(
args,
f"data/{args.data_name}/visualize_results.yaml",
format="yaml",
overwrite=True,
)
# Load the EvaluationResultCollection
evaluation_collection = EvaluationResultCollection.load(args.input_file)
# Get all the router names in the collection
router_names = evaluation_collection.get_router_names()
if args.data_path.endswith(".csv"):
dataset_df = pd.read_csv(args.data_path)
else:
dataset_df = pd.read_pickle(args.data_path)
if args.wanted_eval_names is None:
eval_names_to_plot = evaluation_collection.get_eval_names()
else:
if isinstance(args.wanted_eval_names, str):
eval_names_to_plot = [args.wanted_eval_names]
else:
eval_names_to_plot = args.wanted_eval_names
for name in eval_names_to_plot:
evaluation_collection.plot_performance_vs_cost(
save_file_path=save_path,
show_plot=True,
eval_name=name,
plot_ndch=True,
models_to_route=dataset_df.model_name.unique().tolist(),
)