# experiment_helper.py — 130 lines (97 loc), 5.66 KB
import typing
import wandb
from algorithms.algorithms_utils import AlgorithmsEnum
from algorithms.base_classes import SGDBasedRecommenderAlgorithm, SparseMatrixBasedRecommenderAlgorithm
from algorithms.naive_algs import PopularItems
from algorithms.sgd_alg import DeepMatrixFactorization, ECF
from conf.conf_parser import parse_conf_file, parse_conf, save_yaml
from data.data_utils import DatasetsEnum, get_dataloader
from data.dataset import TrainRecDataset, ECFTrainRecDataset
from eval.eval import evaluate_recommender_algorithm, FullEvaluator
from train.rec_losses import RecommenderSystemLoss, RecommenderSystemLossesEnum
from train.trainer import Trainer
from utilities.utils import reproducible
from wandb_conf import PROJECT_NAME, ENTITY_NAME
def run_train_val(alg: AlgorithmsEnum, dataset: DatasetsEnum, conf: typing.Union[str, dict]):
    """Train a recommender algorithm and evaluate it on the validation split.

    Dispatches on the algorithm family:
      * SGDBasedRecommenderAlgorithm — trained via ``Trainer``; validation
        happens inside ``trainer.fit()``.
      * SparseMatrixBasedRecommenderAlgorithm — fitted on the training
        sampling matrix, then evaluated with ``FullEvaluator``, and the
        fitted model is saved to ``conf['model_path']``.
      * ``rand`` / ``pop`` naive baselines — built from the training data and
        evaluated directly (no model file is saved for these).

    :param alg: algorithm enum member; ``alg.value`` is the algorithm class.
    :param dataset: dataset enum member (used for naming/tagging the run).
    :param conf: path to a conf file, or an already-loaded conf dict.
    :return: tuple ``(metrics_values, conf)`` — validation metrics and the
        fully parsed configuration dict.
    :raises ValueError: if the algorithm family is not handled by any branch.

    NOTE(review): indentation was reconstructed from a flattened scrape of
    this file — verify the nesting against the original repository.
    """
    print('Starting Train-Val')
    print(f'Algorithm is {alg.name} - Dataset is {dataset.name}')

    # Accept either a path (parse it) or a ready dict, then normalize it.
    if isinstance(conf, str):
        conf = parse_conf_file(conf)
    conf = parse_conf(conf, alg, dataset)

    if conf['running_settings']['use_wandb']:
        wandb.init(project=PROJECT_NAME, entity=ENTITY_NAME, config=conf, tags=[alg.name, dataset.name],
                   group=f'{alg.name} - {dataset.name} - train/val', name=conf['time_run'], job_type='train/val')

    # Fix RNG seeds so the run is repeatable.
    reproducible(conf['running_settings']['seed'])

    if issubclass(alg.value, SGDBasedRecommenderAlgorithm):
        train_loader = get_dataloader(conf, 'train')
        val_loader = get_dataloader(conf, 'val')

        rec_loss = RecommenderSystemLossesEnum[conf['rec_loss']]

        # Rebind `alg` / `rec_loss` from enum members to built instances.
        alg = alg.value.build_from_conf(conf, train_loader.dataset)
        rec_loss = rec_loss.value.build_from_conf(conf, train_loader.dataset)

        trainer = Trainer(alg, train_loader, val_loader, rec_loss, conf)

        # Validation happens within the Trainer
        metrics_values = trainer.fit()

        save_yaml(conf['model_path'], conf)
    elif issubclass(alg.value, SparseMatrixBasedRecommenderAlgorithm):
        train_dataset = TrainRecDataset(conf['dataset_path'])
        val_loader = get_dataloader(conf, 'val')

        alg = alg.value.build_from_conf(conf, train_dataset)

        # -- Training --
        alg.fit(train_dataset.sampling_matrix)

        # -- Validation --
        evaluator = FullEvaluator(aggr_by_group=True, n_groups=val_loader.dataset.n_user_groups,
                                  user_to_user_group=val_loader.dataset.user_to_user_group)
        metrics_values = evaluate_recommender_algorithm(alg, val_loader, evaluator,
                                                        verbose=conf['running_settings']['batch_verbose'])

        alg.save_model_to_path(conf['model_path'])
        save_yaml(conf['model_path'], conf)

        if conf['running_settings']['use_wandb']:
            wandb.log(metrics_values)
    elif alg in [AlgorithmsEnum.rand, AlgorithmsEnum.pop]:
        # Naive baselines: built from the training data, no fitting step.
        train_dataset = TrainRecDataset(conf['dataset_path'])
        val_loader = get_dataloader(conf, 'val')

        alg = alg.value.build_from_conf(conf, train_dataset)

        evaluator = FullEvaluator(aggr_by_group=True, n_groups=val_loader.dataset.n_user_groups,
                                  user_to_user_group=val_loader.dataset.user_to_user_group)
        metrics_values = evaluate_recommender_algorithm(alg, val_loader, evaluator,
                                                        verbose=conf['running_settings']['batch_verbose'])

        save_yaml(conf['model_path'], conf)

        if conf['running_settings']['use_wandb']:
            wandb.log(metrics_values)
    else:
        raise ValueError(f'Training for {alg.value} has been not implemented')

    if conf['running_settings']['use_wandb']:
        wandb.finish()

    return metrics_values, conf
def run_test(alg: AlgorithmsEnum, dataset: DatasetsEnum, conf: typing.Union[str, dict]):
    """Evaluate a previously trained algorithm on the test split.

    The algorithm is rebuilt from the conf, its saved weights are loaded from
    ``conf['model_path']``, and it is scored with ``FullEvaluator`` on the
    test dataloader. Metrics are logged to wandb when enabled.

    :param alg: algorithm enum member; ``alg.value`` is the algorithm class.
    :param dataset: dataset enum member (used for naming/tagging the run).
    :param conf: path to a conf file, or an already-parsed conf dict
        (typically the conf returned by ``run_train_val``).

    NOTE(review): indentation was reconstructed from a flattened scrape;
    in particular the placement of ``load_model_from_path`` after the
    build branches (applying to all algorithms) is inferred — confirm
    against the original repository.
    """
    print('Starting Test')
    print(f'Algorithm is {alg.name} - Dataset is {dataset.name}')

    if isinstance(conf, str):
        conf = parse_conf_file(conf)

    if conf['running_settings']['use_wandb']:
        wandb.init(project=PROJECT_NAME, entity=ENTITY_NAME, config=conf, tags=[alg.name, dataset.name],
                   group=f'{alg.name} - {dataset.name} - test', name=conf['time_run'], job_type='test', reinit=True)

    test_loader = get_dataloader(conf, 'test')

    if alg.value == PopularItems or alg.value == DeepMatrixFactorization:
        # Popular Items requires the popularity distribution over the items learned over the training data
        # DeepMatrixFactorization also requires access to the training data
        alg = alg.value.build_from_conf(conf, TrainRecDataset(conf['dataset_path']))
    elif alg.value == ECF:
        # ECF needs its dedicated training dataset variant to be rebuilt.
        alg = alg.value.build_from_conf(conf, ECFTrainRecDataset(conf['dataset_path']))
    else:
        alg = alg.value.build_from_conf(conf, test_loader.dataset)

    # Restore the weights persisted by the train/val phase.
    alg.load_model_from_path(conf['model_path'])

    evaluator = FullEvaluator(aggr_by_group=True, n_groups=test_loader.dataset.n_user_groups,
                              user_to_user_group=test_loader.dataset.user_to_user_group)
    metrics_values = evaluate_recommender_algorithm(alg, test_loader, evaluator,
                                                    verbose=conf['running_settings']['batch_verbose'])

    if conf['running_settings']['use_wandb']:
        # step=0: a test run logs exactly one metrics snapshot.
        wandb.log(metrics_values, step=0)
        wandb.finish()
def run_train_val_test(alg: AlgorithmsEnum, dataset: DatasetsEnum, conf_path: str):
    """Run the full pipeline: train + validate, then test with the resulting conf.

    :param alg: algorithm enum member to run.
    :param dataset: dataset enum member to run on.
    :param conf_path: path to the configuration file.
    """
    print('Starting Train-Val-Test')
    print(f'Algorithm is {alg.name} - Dataset is {dataset.name}')

    # ------ Run train and Val ------ #
    # Only the parsed conf is needed downstream; the validation metrics
    # returned by run_train_val are unused here (it logs them itself).
    _, conf = run_train_val(alg, dataset, conf_path)

    # ------ Run test ------ #
    run_test(alg, dataset, conf)