-
Notifications
You must be signed in to change notification settings - Fork 31
/
train.py
60 lines (49 loc) · 1.79 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import numpy as np
import cv2
from glob import glob
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, CSVLogger, TensorBoard
from data import load_data, tf_dataset
from model import build_model
def iou_np(y_true, y_pred, smooth=1e-15):
    """Compute a soft intersection-over-union score for two NumPy arrays.

    The score is ``(sum(t * p) + smooth) / (sum(t) + sum(p) - sum(t * p) + smooth)``
    summed over every element of the inputs. ``y_pred`` is used as-is (it is
    not thresholded), so when it holds probabilities this is a "soft" IoU.

    Args:
        y_true: Ground-truth mask array (presumably values in [0, 1] — TODO
            confirm against ``data.tf_dataset``).
        y_pred: Predicted mask array, multiplied element-wise with ``y_true``.
        smooth: Small constant added to numerator and denominator so two empty
            masks score 1.0 instead of dividing zero by zero.

    Returns:
        The IoU as a ``np.float32`` scalar.
    """
    intersection = (y_true * y_pred).sum()
    union = y_true.sum() + y_pred.sum() - intersection
    score = (intersection + smooth) / (union + smooth)
    # tf.numpy_function requires the returned dtype to match the declared Tout.
    return score.astype(np.float32)


def iou(y_true, y_pred):
    """Keras metric: soft IoU of a batch, computed by ``iou_np`` on the host.

    Args:
        y_true: Ground-truth tensor supplied by Keras.
        y_pred: Prediction tensor supplied by Keras.

    Returns:
        A scalar ``tf.float32`` tensor.
    """
    score = tf.numpy_function(iou_np, [y_true, y_pred], tf.float32)
    # numpy_function cannot infer output shapes; iou_np always yields a scalar,
    # so declare that statically for downstream Keras shape handling.
    score.set_shape(())
    return score
def num_steps(num_samples, batch_size):
    """Return how many batches are needed to cover ``num_samples`` items.

    This is ceiling division: a trailing partial batch still counts as a step.
    """
    return -(-num_samples // batch_size)


if __name__ == "__main__":
    ## Dataset
    path = "CVC-612/"
    (train_x, train_y), (valid_x, valid_y), (test_x, test_y) = load_data(path)

    ## Hyperparameters
    batch = 8
    lr = 1e-4
    epochs = 20

    train_dataset = tf_dataset(train_x, train_y, batch=batch)
    valid_dataset = tf_dataset(valid_x, valid_y, batch=batch)

    model = build_model()
    opt = tf.keras.optimizers.Adam(lr)
    metrics = ["acc", tf.keras.metrics.Recall(), tf.keras.metrics.Precision(), iou]
    model.compile(loss="binary_crossentropy", optimizer=opt, metrics=metrics)

    # The checkpoint and CSV log are written under "files/"; create it so a
    # fresh checkout does not fail on the first write.
    os.makedirs("files", exist_ok=True)

    callbacks = [
        # Keep the weights with the lowest validation loss instead of
        # overwriting the file every epoch with the latest ones: EarlyStopping
        # below runs with restore_best_weights=False, so nothing else
        # preserves the best model.
        ModelCheckpoint("files/model.h5", monitor="val_loss", save_best_only=True),
        ReduceLROnPlateau(monitor="val_loss", factor=0.1, patience=4),
        CSVLogger("files/data.csv"),
        TensorBoard(),
        EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=False),
    ]

    # Steps are passed explicitly to fit(); presumably tf_dataset repeats the
    # data, so Keras needs to be told where an epoch ends — verify in data.py.
    train_steps = num_steps(len(train_x), batch)
    valid_steps = num_steps(len(valid_x), batch)

    model.fit(
        train_dataset,
        validation_data=valid_dataset,
        epochs=epochs,
        steps_per_epoch=train_steps,
        validation_steps=valid_steps,
        callbacks=callbacks,
    )