-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
80 lines (78 loc) · 3.5 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# Directory prefix for the digital_music data files. Only neumf_config reads
# through it at the moment; gmf_config and mlp_config reference it solely in
# their commented-out *_path alternatives.
base = "./dataset/digital_music/"
# Hyper-parameters and file locations for the GMF model.
# The active data files are the ml-1m ones; the commented-out *_path entries
# point at the digital_music split (via the module-level `base` prefix).
# NOTE(review): num_users / num_items are 0 here, while mlp_config and
# neumf_config hard-code 6040 / 3706 — presumably the training code fills
# these in from the data; confirm.
# NOTE(review): problem is 'topk' but both checkpoint paths are named
# "new_prediction_gmf" — confirm the naming is intended.
gmf_config = dict(
    num_epoch=85,
    batch_size=2048,
    # optimizer='sgd',  # alternative to 'adam'
    sgd_lr=1e-2,
    optimizer='adam',
    adam_lr=7e-3,
    num_users=0,
    num_items=0,
    latent_dim_gmf=128,
    problem='topk',  # 'topk' or 'prediction'
    k=10,
    num_negative=4,
    l2_regularization=0,
    # train_path=base + 'digital_music.train.rating',
    # val_path=base + 'digital_music.valid.rating',
    # test_path=base + 'digital_music.test.rating',
    train_path='./dataset/ml-1m.train.rating',
    val_path=None,
    test_path='./dataset/ml-1m.test.negative',
    model_dir='checkpoints/new_prediction_gmf.model',
    pretrained=False,
    pretrained_gmf_dir='checkpoints/new_prediction_gmf.model',
    val_batch_size=2048,
)
# Hyper-parameters and file locations for the MLP model.
# The active data files are the ml-1m ones; the commented-out *_path entries
# point at the digital_music split (via the module-level `base` prefix).
mlp_config = dict(
    num_epoch=50,
    batch_size=2048,
    val_batch_size=2048,
    # optimizer='sgd',  # alternative to 'adam'
    sgd_lr=5e-3,
    optimizer='adam',
    adam_lr=5e-3,
    num_users=6040,
    num_items=3706,
    num_layers_mlp=2,
    latent_dim_mlp=16,
    problem='topk',  # 'topk' or 'prediction'
    num_negative=4,  # alternative value noted: 3
    k=10,
    # 0.0000001 was the other value noted here; the MLP model is sensitive
    # to its hyper-parameters.
    l2_regularization=0,
    # train_path=base + 'digital_music.train.rating',
    # val_path=base + 'digital_music.valid.rating',
    # test_path=base + 'digital_music.test.rating',
    train_path='./dataset/ml-1m.train.rating',
    val_path=None,
    test_path='./dataset/ml-1m.test.negative',
    pretrained=False,
    pretrained_mlp_dir=None,
    model_dir='checkpoints/new/new_topk_mlp.model',
)
# Hyper-parameters and file locations for the NeuMF model. Judging by its
# keys, it has a GMF part and an MLP part (separate latent/output sizes and
# optional pretrained checkpoints for each).
# The active data files are the digital_music ones under `base`; the
# commented-out *_path entries point at ml-1m instead.
# NOTE(review): num_users=6040 / num_items=3706 match the ml-1m counts used
# in mlp_config, yet the active paths read the digital_music split — confirm
# these are overwritten at runtime or update them for this dataset.
# NOTE(review): problem is 'prediction' but model_dir is named "topk_ncf" —
# confirm the checkpoint name is intended.
neumf_config = dict(
    num_epoch=60,
    batch_size=128,
    val_batch_size=128,
    num_layers_mlp=2,
    # optimizer='sgd',  # alternative to 'adam'
    sgd_lr=1e-3,
    optimizer='adam',
    adam_lr=1e-4,
    num_users=6040,
    num_items=3706,
    problem='prediction',  # alternative: 'topk'
    latent_dim_gmf=16,
    latent_dim_mlp=32,
    num_negative=4,
    k=10,
    l2_regularization=1e-3,
    gmf_out_dim=8,
    mlp_out_dim=8,
    pretrained=False,
    pretrained_gmf_dir='checkpoints/prediction_gmf.model',
    pretrained_mlp_dir='checkpoints/prediction_mlp.model',
    model_dir='checkpoints/topk_ncf.model',
    train_path=base + 'digital_music.train.rating',
    val_path=base + 'digital_music.valid.rating',
    test_path=base + 'digital_music.test.rating',
    # train_path='./dataset/ml-1m.train.rating',
    # val_path=None,
    # test_path='./dataset/ml-1m.test.negative',
)