-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
59 lines (48 loc) · 2.28 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import argparse
from types import SimpleNamespace
import wandb # 0.17.1
# Utils
from utils import make, seed_everything
from config import DEVICE
from train import train
# Baseline hyper-parameters for a run. When the script is executed directly,
# parse_args() merges any command-line overrides into this namespace in place
# before it is handed to model_pipeline().
default_config = SimpleNamespace(
    **{
        "epochs": 20,
        "classes": 1,
        "batch_size": 512,
        "learning_rate": 0.001,
        "normalize": True,
        "augmented": False,
        "optimizer": "adam",
        "dataset": "ThermalBreastCancer",
        "architecture": "vit",
        "fine_tune": "classifier",
        "n_channels": 3,
    }
)
def parse_args():
"Override default argments"
argparser = argparse.ArgumentParser(description="Process hyper-parameters")
argparser.add_argument('--batch_size', type=int, default=default_config.batch_size, help="batch size")
argparser.add_argument('--learning_rate', type=float, default=default_config.learning_rate, help="learning rate")
argparser.add_argument('--optimizer', type=str, default=default_config.optimizer, help="optimizer")
argparser.add_argument('--normalize', type=bool, default=default_config.normalize, help="normalize")
argparser.add_argument('--augmented', type=bool, default=default_config.augmented, help="augmented")
argparser.add_argument('--architecture', type=str, default=default_config.architecture, help="architecture")
argparser.add_argument('--fine_tune', type=str, default=default_config.fine_tune, help="fine_tune")
argparser.add_argument('--n_channels', type=int, default=default_config.n_channels, help="n_channels")
args = argparser.parse_args()
vars(default_config).update(vars(args))
return
def model_pipeline(hyperparameters):
    """Start a W&B run, build the training components from its config and train.

    Args:
        hyperparameters: Config object handed to ``wandb.init(config=...)``.

    Returns:
        The model produced by ``make`` after ``train`` has been run on it.
    """
    # NOTE(review): seed_everything is project-local; presumably it seeds the RNGs.
    seed_everything()

    run_kwargs = dict(project="hocv-project", entity="ai-uis", config=hyperparameters)
    with wandb.init(**run_kwargs):
        # Read hyper-parameters back through wandb.config so the values that are
        # logged are exactly the values used for this run.
        config = wandb.config

        (
            model,
            train_loader,
            test_loader,
            criterion,
            optimizer,
            accuracy_fn,
            f1_score_fn,
            recall_fn,
            precision_fn,
            epochs,
        ) = make(config=config)

        train(
            model,
            train_loader,
            test_loader,
            criterion,
            optimizer,
            accuracy_fn,
            f1_score_fn,
            recall_fn,
            precision_fn,
            epochs,
            DEVICE,
        )
        return model
if __name__ == "__main__":
    # parse_args() writes CLI overrides into default_config in place, so the
    # namespace passed below already reflects the command line.
    parse_args()
    model_pipeline(hyperparameters=default_config)