In the previous tutorial we defined a neural network. Now let's train it on the actual data.
Before you can use data, typically some preprocessing is required. In this case, it's minimal – all the images are already the same size and are grayscale. With the built-in functionality, we can convert the Fashion MNIST image archives into a dataset object that we can use for model training.
val (train, test) = fashionMnist()
Note that we work with three sets of data in total.
The test set is held back: we won't touch it until we are satisfied with the model and want to confirm its performance on unseen data.
The train set is split further into newTrain and validation sets, which we'll use during the training and validation process.
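To make the split concrete, here is a minimal, library-independent sketch of a ratio-based split. The helper name splitByRatio and the 0.95 ratio are illustrative assumptions rather than part of the tutorial's API; if the dataset object returned by fashionMnist() offers its own splitting method, prefer that one.

```kotlin
import kotlin.math.roundToInt

// Hypothetical helper: the first part receives `ratio` of the elements,
// the second part receives the rest.
fun <T> splitByRatio(items: List<T>, ratio: Double): Pair<List<T>, List<T>> {
    require(ratio in 0.0..1.0) { "ratio must be between 0 and 1" }
    val cut = (items.size * ratio).roundToInt()
    return items.take(cut) to items.drop(cut)
}

fun main() {
    // Stand-in for the 60,000 examples of the Fashion MNIST training set.
    val train = (0 until 60_000).toList()
    val (newTrain, validation) = splitByRatio(train, 0.95)
    println("newTrain: ${newTrain.size}, validation: ${validation.size}")
}
```

With these assumed numbers, most of the data stays available for training while a small slice is reserved for checking progress along the way.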
Now everything is ready to train the model. Use the fit() method for this:
model.use {
    it.compile(
        optimizer = Adam(),
        loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
        metric = Metrics.ACCURACY
    )

    it.summary()

    // You can think of the training process as "fitting" the model to describe the given data :)
    it.fit(
        dataset = train,
        epochs = 10,
        batchSize = 100
    )

    val accuracy = it.evaluate(dataset = test, batchSize = 100).metrics[Metrics.ACCURACY]
    println("Accuracy: $accuracy")

    it.save(File("src/model/my_model"))
}
Here are some important parameters that we need to pass to the fit() method:
epochs - The number of passes over the data that you want the training process to perform. One epoch is one full pass over the training dataset.
batchSize - How many examples are used at a time for each update of the model's parameters (aka weights and biases).
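The interplay of the two parameters can be sketched as simple arithmetic. The function below is a hypothetical illustration, not a library call, and it assumes one parameter update per batch with a final partial batch counted as a full step. The figure of 60,000 is the size of the full Fashion MNIST training set.

```kotlin
// Hypothetical helper: number of parameter updates performed during training,
// assuming one update per batch and a partial last batch counted as one step.
fun updateSteps(examples: Int, batchSize: Int, epochs: Int): Int {
    require(examples >= 0 && batchSize > 0 && epochs >= 0) { "invalid arguments" }
    val batchesPerEpoch = (examples + batchSize - 1) / batchSize // ceiling division
    return batchesPerEpoch * epochs
}

fun main() {
    // 600 batches per epoch, repeated for 10 epochs.
    println(updateSteps(examples = 60_000, batchSize = 100, epochs = 10)) // prints 6000
}
```

A smaller batchSize gives more, noisier updates per epoch; a larger one gives fewer, smoother updates.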
After the model has been trained, it's important to evaluate its performance on data it did not see during training, so that we can check how well it generalizes. Here we use the test set:
val accuracy = it.evaluate(dataset = test, batchSize = 100).metrics[Metrics.ACCURACY]
println("Accuracy: $accuracy")
Accuracy: 0.8821001648902893
NOTE: The results are nondeterministic, so you may get a slightly different Accuracy value.
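To interpret the number, it helps to recall how classification accuracy is usually defined: the share of examples whose predicted label matches the true label. The sketch below assumes Metrics.ACCURACY follows this standard definition; the accuracy function and the sample labels are made up for illustration.

```kotlin
// Accuracy = number of correct predictions / total number of predictions.
fun accuracy(predicted: List<Int>, actual: List<Int>): Double {
    require(predicted.size == actual.size && predicted.isNotEmpty()) {
        "label lists must be non-empty and of equal size"
    }
    val correct = predicted.zip(actual).count { (p, a) -> p == a }
    return correct.toDouble() / predicted.size
}

fun main() {
    // Made-up class labels: four of the five predictions are correct.
    val predicted = listOf(9, 2, 1, 1, 6)
    val actual = listOf(9, 2, 1, 1, 4)
    println(accuracy(predicted, actual)) // prints 0.8
}
```

Under that definition, an accuracy of about 0.88 means roughly 88 out of every 100 test images were classified correctly.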
When we are happy with the model's evaluation metric, we can save the model for future use in the production environment.
it.save(File("model/my_model"), writingMode = WritingMode.OVERRIDE)
And just like that, we have trained, evaluated, and saved a deep learning model that we can now use to generate predictions (aka inference). In the next tutorial, you'll learn how to load and use the model.