-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
131 lines (100 loc) · 3.81 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# Author: Simon Andersen
# Date: 2023-04-01
# Description: Main file for the project
# - Modules
import glob
import pickle
import numpy
# - music21 modules
from music21 import converter, instrument, note, chord
# - Keras modules
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import LSTM
from keras.layers import Activation
from keras.layers import BatchNormalization as BatchNorm
from keras.utils import np_utils
from keras.callbacks import ModelCheckpoint
# - Main method for running the entire training process
def train_network():
    """Run the full training pipeline from MIDI parsing to model fitting.

    The steps are: read the note list, count the distinct entries to get the
    vocabulary size, turn the notes into input/target arrays, build the
    model for that input, and fit it.
    """
    # Every note/chord label extracted from the MIDI corpus.
    corpus = get_notes()

    # Number of distinct labels; this is the size of the output layer.
    vocab_size = len(set(corpus))

    # Input windows and their targets for the network.
    inputs, targets = prepare_sequences(corpus, vocab_size)

    # Build the model for these inputs and start training right away.
    train(create_network(inputs, vocab_size), inputs, targets)
# - Preparation of data of the notes and chords from the midi files
def get_notes():
    """Collect every note and chord from the MIDI files in ``CMajorAMinor/``.

    Each ``note.Note`` is stored as the string form of its pitch and each
    ``chord.Chord`` as its normal-order values joined with ``.``. The
    resulting list is pickled to ``data/notes`` and also returned.

    Returns:
        list[str]: The encoded notes and chords of all files, file by file.
    """
    import os  # local import: only needed to create the output directory

    notes = []

    # glob does not guarantee an order, so sort to keep the corpus (and
    # therefore the training sequences) reproducible between runs.
    for file in sorted(glob.glob("CMajorAMinor/*.mid")):
        # Announce the file first so a failing parse can be traced to it.
        print(f"Parsing {file}")
        # converter.parse is the music21 call that reads the MIDI file.
        midi = converter.parse(file)

        try:  # file has instrument parts
            s2 = instrument.partitionByInstrument(midi)
            notes_to_parse = s2.parts[0].recurse()
        except Exception:
            # Best-effort fallback when the file cannot be split by
            # instrument: use the flat note stream instead. Catching
            # Exception (not a bare except) lets Ctrl-C still interrupt.
            notes_to_parse = midi.flat.notes

        for element in notes_to_parse:
            if isinstance(element, note.Note):
                notes.append(str(element.pitch))
            elif isinstance(element, chord.Chord):
                notes.append('.'.join(str(n) for n in element.normalOrder))

    # Make sure the target directory exists so the parsing work is not lost
    # to a FileNotFoundError at the very end.
    os.makedirs('data', exist_ok=True)
    with open('data/notes', 'wb') as filepath:
        pickle.dump(notes, filepath)

    return notes
def prepare_sequences(notes, n_vocab, sequence_length=100):
    """Turn the note list into input windows and one-hot targets.

    Every distinct entry of ``notes`` is mapped to an integer code following
    sorted order. Each run of ``sequence_length`` consecutive codes becomes
    one input pattern and the code that follows it is that pattern's target.

    Args:
        notes: Sequence of sortable, hashable note/chord labels.
        n_vocab: Number of classes. Used to scale the inputs and as the
            width of the one-hot output; the caller in this file passes
            ``len(set(notes))``.
        sequence_length: Number of consecutive entries per input pattern.
            Defaults to 100.

    Returns:
        tuple: ``(network_input, network_output)`` where ``network_input``
        has shape ``(n_patterns, sequence_length, 1)`` and holds the codes
        divided by ``n_vocab``, and ``network_output`` is a float32 one-hot
        array of shape ``(n_patterns, n_vocab)``.

    Raises:
        ValueError: If ``notes`` does not contain more than
            ``sequence_length`` entries, so no pattern can be built.
    """
    pitchnames = sorted(set(notes))
    note_to_int = {pitch: number for number, pitch in enumerate(pitchnames)}

    # Encode once instead of looking every label up again per window.
    encoded = [note_to_int[item] for item in notes]

    n_patterns = len(encoded) - sequence_length
    if n_patterns <= 0:
        raise ValueError(
            "need more than %d notes to build a training sequence, got %d"
            % (sequence_length, len(encoded))
        )

    windows = [encoded[i:i + sequence_length] for i in range(n_patterns)]
    network_input = numpy.reshape(windows, (n_patterns, sequence_length, 1))
    network_input = network_input / float(n_vocab)

    # The target of window i is the entry right after it, i.e. the encoded
    # list shifted by sequence_length.
    targets = encoded[sequence_length:]

    # One-hot encode with a fixed width of n_vocab. Inferring the width from
    # the largest target would yield too few columns whenever the
    # highest-coded pitch never occurs as a target, and the result would no
    # longer match an output layer of n_vocab units.
    network_output = numpy.zeros((n_patterns, n_vocab), dtype="float32")
    network_output[numpy.arange(n_patterns), targets] = 1.0

    return (network_input, network_output)
def create_network(network_input, n_vocab):
    """Assemble and compile the stacked-LSTM model.

    Args:
        network_input: Array whose ``shape[1]`` and ``shape[2]`` give the
            input shape of the first LSTM layer.
        n_vocab: Number of units in the final Dense layer, which is followed
            by a softmax activation.

    Returns:
        The compiled ``Sequential`` model (categorical cross-entropy loss,
        rmsprop optimizer).
    """
    timesteps = network_input.shape[1]
    features = network_input.shape[2]

    # The layer stack, listed in the order it is added to the model.
    stack = (
        LSTM(
            512,
            input_shape=(timesteps, features),
            recurrent_dropout=0.3,
            return_sequences=True,
        ),
        LSTM(512, return_sequences=True, recurrent_dropout=0.3),
        LSTM(512),
        BatchNorm(),
        Dropout(0.3),
        Dense(256),
        Activation('relu'),
        BatchNorm(),
        Dropout(0.3),
        Dense(n_vocab),
        Activation('softmax'),
    )

    model = Sequential()
    for layer in stack:
        model.add(layer)

    model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
    return model
def train(model, network_input, network_output):
    """Fit ``model`` on the prepared data with checkpointing.

    A ``ModelCheckpoint`` callback monitoring ``loss`` in ``min`` mode with
    ``save_best_only`` enabled is attached; its file name pattern embeds the
    epoch number and the loss value.

    Args:
        model: Compiled model exposing ``fit``.
        network_input: Training inputs passed to ``fit``.
        network_output: Training targets passed to ``fit``.
    """
    weights_path = "weights/weights-improvement-{epoch:02d}-{loss:.4f}-bigger.hdf5"
    best_loss_saver = ModelCheckpoint(
        weights_path,
        monitor='loss',
        verbose=0,
        save_best_only=True,
        mode='min',
    )

    model.fit(
        network_input,
        network_output,
        epochs=200,
        batch_size=128,
        callbacks=[best_loss_saver],
    )
# Script entry point: start the training pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == '__main__':
    train_network()