-
Notifications
You must be signed in to change notification settings - Fork 7
/
train.py
119 lines (81 loc) · 2.73 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import numpy as np
import os.path
import math
import csv
import re
from keras.callbacks import Callback, ModelCheckpoint, EarlyStopping
from PIL import Image
from s3 import Client
from models import cnn
# Shared S3 client; used below to list/download slices and upload checkpoints.
s3 = Client()
# Number of samples in every batch produced by generate()/loader().
BATCH_SIZE = 64
# Maximum number of epochs passed to fit_generator.
EPOCHS = 50
# Patience value handed to the EarlyStopping callback.
EARLY_STOPPING_PATIENCE = 5
# Checkpoint file name; it is stored locally under 'models/'.
MODEL_NAME = 'model.hdf5'
def load_labels():
    """Read ``labels.csv`` into the module-level ``labels`` dict.

    Each non-empty row maps its first column (the key) to its second
    column (the value). Values are stored exactly as read, i.e. as
    strings. Blank rows (for example a trailing empty line) are skipped
    instead of raising ``IndexError``.

    Raises:
        IOError/OSError: if ``labels.csv`` cannot be opened.
        IndexError: if a non-empty row has fewer than two columns.
    """
    global labels
    labels = {}
    with open('labels.csv') as f:
        for row in csv.reader(f):
            # csv.reader yields an empty list for a blank line.
            if not row:
                continue
            labels[row[0]] = row[1]
def generate(ids):
    """Yield batches built from ``ids`` forever, reshuffling on every pass.

    On each pass the ids are shuffled and cut into ``len(ids) // BATCH_SIZE``
    full batches; ids left over after the last full batch are dropped for
    that pass. Every batch is handed to ``loader`` and its result is
    yielded.

    Args:
        ids: indexable sequence of slice identifiers.

    Raises:
        ValueError: on iteration, if ``ids`` holds fewer than ``BATCH_SIZE``
            entries. Without this check the endless loop would spin
            forever without ever yielding a batch.
    """
    while True:
        batches = len(ids) // BATCH_SIZE
        if batches == 0:
            raise ValueError(
                'need at least %d ids to build a batch, got %d'
                % (BATCH_SIZE, len(ids)))
        indexes = np.arange(len(ids))
        np.random.shuffle(indexes)
        for i in range(batches):
            batch_start = i * BATCH_SIZE
            batch_end = batch_start + BATCH_SIZE
            ids_to_load = [ids[k] for k in indexes[batch_start:batch_end]]
            yield loader(ids_to_load)
def loader(ids):
    """Build one ``(x, y)`` batch for the given slice ids.

    ``x`` is a float array of shape ``(BATCH_SIZE, 1, 128, 1291)`` and
    ``y`` an int array of length ``BATCH_SIZE``. Row ``n`` is filled from
    the n-th id: the image data from ``get_audio_slice`` and the label
    looked up in ``labels`` by the id extracted from the slice name.
    Both arrays come from ``np.empty``, so rows beyond ``len(ids)`` are
    left uninitialised.
    """
    features = np.empty((BATCH_SIZE, 1, 128, 1291))
    targets = np.empty(BATCH_SIZE, dtype=int)
    for row, slice_path in enumerate(ids):
        features[row] = get_audio_slice(slice_path)
        track_id = get_id_from_slice(slice_path)
        targets[row] = labels[track_id]
    return features, targets
# Group 1 is everything between the first '/' and the trailing
# '-<digits>.png'. The dot is escaped so only a literal '.png' matches.
_SLICE_ID_PATTERN = re.compile(r'/(.*)-[0-9]+\.png')


def get_id_from_slice(slice):
    """Return the id embedded in a slice path.

    For ``'slices/abc-12.png'`` the result is ``'abc'``: the text after
    the first ``/`` up to the last ``-<digits>.png`` suffix.

    Args:
        slice: slice path/key string.

    Raises:
        ValueError: if ``slice`` does not look like
            ``<prefix>/<id>-<digits>.png``.
    """
    match = _SLICE_ID_PATTERN.search(slice)
    if match is None:
        raise ValueError('not a slice path: %r' % (slice,))
    return match.group(1)
def get_audio_slice(slice):
    """Load a slice image and return its pixel data scaled by 1/255.

    If no local file exists at ``slice`` it is first fetched with
    ``s3.download`` (assumed to write the object to that same local
    path, TODO confirm against the s3 client).

    Args:
        slice: path of the slice image.

    Returns:
        A float numpy array holding the image's uint8 data divided by 255.
    """
    if not os.path.exists(slice):
        s3.download(slice)
    # Close the image file deterministically; the array is built before
    # the context exits and does not depend on the open file afterwards.
    with Image.open(slice) as img:
        image_data = np.asarray(img, dtype=np.uint8)
    return image_data / 255.
class UploadCheckpoint(Callback):
    """Callback that uploads the local checkpoint file to S3 when it changes.

    After every epoch the modification time of ``models/<MODEL_NAME>`` is
    compared with the one recorded at the previous upload; the file is
    uploaded only when the two differ.
    """

    def __init__(self):
        # Let the base Callback initialise its own state first.
        super(UploadCheckpoint, self).__init__()
        # Modification time of the file at the last upload (None = never).
        self.last_change = None
        self.path_local = 'models/' + MODEL_NAME

    def on_epoch_end(self, *args):
        """Upload the checkpoint if it was rewritten since the last upload."""
        try:
            modified = os.path.getmtime(self.path_local)
        except OSError:
            # No checkpoint has been written yet, so there is nothing to
            # upload and training should not be aborted.
            return
        if modified != self.last_change:
            s3.upload(self.path_local)
            # Record the mtime read *before* the upload so that a write
            # happening during the upload is picked up next epoch.
            self.last_change = modified
def segment(slices, ids):
    """Return the slices whose embedded id is one of ``ids``.

    Args:
        slices: iterable of slice paths.
        ids: iterable of hashable ids to keep.

    Returns:
        A list of the matching slice paths, in their original order.
    """
    # Build the lookup set once so each membership test is O(1) instead
    # of scanning the whole id list for every slice.
    wanted = set(ids)
    return [path for path in slices if get_id_from_slice(path) in wanted]
def train():
    """Train the CNN on S3 slices with a 70/30 split by label id.

    The label ids are shuffled and split 70/30; every slice is assigned
    to the training or validation set according to the id in its name.
    The model is fitted from two endless batch generators with three
    callbacks: checkpoint the best model by 'val_acc', upload the
    checkpoint to S3, and stop early.
    """
    all_slices = s3.list_slices()

    # Split on ids rather than on slices, so all slices sharing an id
    # land on the same side of the split.
    label_ids = list(labels)
    np.random.shuffle(label_ids)
    cutoff = int(math.ceil(len(label_ids) * 0.7))
    train_slices = segment(all_slices, label_ids[:cutoff])
    test_slices = segment(all_slices, label_ids[cutoff:])

    training_generator = generate(train_slices)
    validation_generator = generate(test_slices)

    uploader = UploadCheckpoint()
    stopper = EarlyStopping(patience=EARLY_STOPPING_PATIENCE)
    saver = ModelCheckpoint(
        'models/' + MODEL_NAME,
        monitor='val_acc',
        save_best_only=True,
        mode='max')

    # The saver runs before the uploader so a fresh checkpoint is on
    # disk by the time the uploader looks at it.
    fit_options = dict(
        generator=training_generator,
        steps_per_epoch=len(train_slices) // BATCH_SIZE,
        validation_data=validation_generator,
        validation_steps=len(test_slices) // BATCH_SIZE,
        epochs=EPOCHS,
        callbacks=[saver, uploader, stopper])

    model = cnn()
    model.fit_generator(**fit_options)
# Script entry point: the labels are loaded first because train() reads
# the module-level ``labels`` dict that load_labels() populates.
if __name__ == "__main__":
    load_labels()
    train()