train_sentences.py
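"""Train sentence-level lip-reading models on LRS2.

Supports three architectures (ResNet + attention, WLSNet, ResNet + CTC) and an
optional curriculum pretraining phase on short word sequences.

Example invocation (illustrative; adjust paths and flags to your setup):
    python train_sentences.py --data data/datasets/lrs2 --model resnet \
        --batch_size 16 --epochs 20 --pretrain
"""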
import argparse
import psutil
import torch
from pytorch_trainer import (EarlyStopping, ModelCheckpoint, Trainer,
                             WandbLogger)
from src.models.lrs2_resnet_attn import LRS2ResnetAttn
from src.models.lrs2_resnet_ctc import LRS2ResnetCTC
from src.models.wlsnet import WLSNet

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', default="data/datasets/lrs2")
    parser.add_argument('--model', default="resnet")
    parser.add_argument('--lm_path')
    parser.add_argument("--checkpoint_dir", type=str, default='data/checkpoints/lrs2')
    parser.add_argument("--checkpoint", type=str)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--weight_decay", type=float, default=1e-5)
    parser.add_argument("--workers", type=int, default=None)
    parser.add_argument("--resnet", type=int, default=18)
    parser.add_argument("--pretrained", default=True, type=lambda x: (str(x).lower() == 'true'))
    parser.add_argument("--pretrain", default=False, action='store_true')
    parser.add_argument("--use_amp", default=False, action='store_true')
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    # Derive defaults: one worker per physical core, and disable the
    # pretrained-weights flag when resuming from a checkpoint.
    args.workers = psutil.cpu_count(logical=False) if args.workers is None else args.workers
    args.pretrained = False if args.checkpoint is not None else args.pretrained
    if args.model == 'resnet':
        model = LRS2ResnetAttn(
            hparams=args,
            in_channels=1,
        )
    elif args.model == 'wlsnet':
        model = WLSNet(
            hparams=args,
            in_channels=1,
        )
    else:
        model = LRS2ResnetCTC(
            hparams=args,
            in_channels=1,
            augmentations=False,
        )
    logger = WandbLogger(
        project='lrs2',
        model=model,
    )
    model.logger = logger
    trainer = Trainer(
        seed=args.seed,
        logger=logger,
        gpu_id=0,
        epochs=args.epochs,
        use_amp=args.use_amp,
    )
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable parameters: {trainable_params}")
    logger.log('parameters', trainable_params)
    logger.log_hyperparams(args)
    # When resuming, run an initial validation pass from the checkpoint.
    if args.checkpoint is not None:
        model.pretrain = False
        logs = trainer.validate(model, checkpoint=args.checkpoint)
        logger.log_metrics(logs)
        print(f"Initial metrics: {logs}")
    if args.pretrain:
        train_epochs = args.epochs
        model.pretrain = True
        print("Pretraining model")
        # Curriculum stages: (max_timesteps, max_text_len, num_words, num_epochs).
        # Each stage trains on progressively longer clips with more words.
        curriculum = [
            [32, 16, 1, 30],
            [64, 32, 2, 20],
            [96, 40, 3, 20],
            [120, 48, 4, 20],
            [132, 56, 6, 15],
            [148, 64, 8, 10],
            [148, 72, 10, 10],
        ]
        for max_timesteps, max_text_len, num_words, num_epochs in curriculum:
            checkpoint_callback = ModelCheckpoint(
                directory=args.checkpoint_dir,
                period=num_epochs,
                prefix=f"lrs2_pretrain_{num_words}",
            )
            trainer.checkpoint_callback = checkpoint_callback
            model.max_timesteps = max_timesteps
            model.max_text_len = max_text_len
            model.pretrain = True
            model.pretrain_words = num_words
            trainer.epochs = num_epochs
            args.epochs = num_epochs
            trainer.fit(model)
            logger.save_file(checkpoint_callback.last_checkpoint_path)
        args.epochs = train_epochs
        model.pretrain = False
        trainer.validate(model)
        print("Pretraining finished")
    checkpoint_callback = ModelCheckpoint(
        directory=args.checkpoint_dir,
        save_best_only=True,
        monitor='val_cer',
        mode='min',
        prefix="lrs2",
    )
    trainer.checkpoint_callback = checkpoint_callback
    model.pretrain = False
    model.max_timesteps = 112
    model.max_text_len = 100
    trainer.epochs = args.epochs
    trainer.fit(model)
    logger.save_file(checkpoint_callback.last_checkpoint_path)