-
Notifications
You must be signed in to change notification settings - Fork 8
/
config.py
executable file
·127 lines (85 loc) · 2.27 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
''' Config Proto '''
# Configuration module for training the primate (lemur) face-recognition
# network. Consumers import this file and read the module-level constants.
import sys
import os
####### INPUT OUTPUT #######
# Name of the current model for output (comment was truncated: "me of the...")
fold_number = 1  # cross-validation fold index; also baked into the run name below
name = 'primate_net_fold_' + str(fold_number)
# The folder to save log and model
log_base_dir = './log/'
# The interval between writing summary
summary_interval = 5
# Dataset
dataset_path = "LemurDataset"
# Cross-validation Parameters
K_CV = 5 # Number of cross-validation folds (training/testing splits)
splits_path = './splits'  # directory holding precomputed train/test split files
#Target image size for the input of network
# NOTE(review): axis order ([h,w] vs [w,h]) is ambiguous here since both are 112 — confirm against the loader
image_size = [112,112]
# 3 channels means RGB, 1 channel for grayscale
channels = 3
# Resize images before processing, assign as (w,h) or False
resize = (112,112)
# Preprocessing for training
# Each entry is an (operation_name, args_list) pair applied in order by the data pipeline.
preprocess_train = [
    ('resize', [(112,112)]),
    ('random_flip', []),
    # 'deb' selects a standardization mode — presumably a dataset-specific
    # mean/std scheme; verify against the preprocessing implementation
    ('standardize', ['deb'])
]
# Test-time preprocessing: same as training minus the random augmentation.
preprocess_test = [
    ('resize', [(112,112)]),
    ('standardize', ['deb'])
]
# Number of GPUs
num_gpus = 1
####### NETWORK #######
# Auto alignment network
localization_net = None
# The network architecture
network = "nets/lemur_net.py"
# Model version, only for some networks
model_version = 'lemur'
# Number of dimensions in the embedding space
embedding_size = 512
####### TRAINING STRATEGY #######
# Optimizer
optimizer = "RMSPROP"
# Number of samples per batch
batch_size = 128
# Number of batches per epoch
epoch_size = 80
# Number of epochs
num_epochs = 300
# learning rate strategy
learning_rate_strategy = 'step'
# learning rate schedule
# Maps starting epoch -> learning rate; only the initial rate is active,
# the stepped decay entries below are currently disabled.
learning_rate_schedule = {
    0: 0.01,
    #400: 0.01,
    #480: 0.001,
    #5000: 0.001,
    #7000: 0.0001
}
# Multiply the learning rate for variables that contain certain keywords
# A multiplier of 0.0 effectively freezes any variables whose names
# contain the keyword (here: the InceptionResnetV2 backbone).
learning_rate_multipliers = {
    'InceptionResnetV2': 0.000,
}
# Build batches with random templates rather than instances
template_batch = False
# Restore model
restore_model = None
# Keywords to filter restore variables, set None for all
restore_scopes = None
# Weight decay for model variables
weight_decay = 5e-4
# Keep probability for dropouts
keep_prob = 1.0  # 1.0 disables dropout entirely
####### LOSS FUNCTION #######
# Scale for the logits
# Maps loss name -> its hyperparameters; only 'split' is active,
# the alternatives are kept commented out for experimentation.
losses = {
    #'softmax': {},
    #'cosine': {'gamma': 'auto'},
    # 'angular': {'m': 4, 'lamb_min':5.0, 'lamb_max':1500.0},
    'split': {'gamma': 'auto'}
    # 'norm': {'alpha': 1e-5},
}