This tutorial follows the outline of the PyTorch Transfer Learning Tutorial.
We will use transfer learning to leverage a pretrained ResNet model on a small dataset. This dataset is made of images of ants and bees that we want to classify. There are roughly 240 training images and 150 validation images, each of size 224x224. The dataset is so small that training even the simplest convolutional neural network on it from scratch would be very difficult.
Instead, the original tutorial proposes two alternatives for training the classifier:
- Finetuning the pretrained model. We start from a ResNet-18 model pretrained on the 1000 ImageNet categories, replace the last layer with a binary classifier, and train the resulting model as usual.
- Using a pretrained model as a feature extractor. The pretrained model weights are frozen. We run this model and store the outputs of the last layer before the final classifier, then train a binary classifier on the resulting features.
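One practical difference between the two alternatives is how often the expensive ResNet backbone has to run. The plain-Rust sketch below (not tch-rs code; the function names are hypothetical) counts backbone forward passes. It borrows the 211 training images and 1000 epochs used later in this tutorial, and it ignores the backward passes that finetuning additionally needs.

```rust
/// Backbone forward passes when finetuning: every epoch pushes every
/// training image through the full network again.
fn finetune_backbone_passes(num_images: u64, num_epochs: u64) -> u64 {
    num_images * num_epochs
}

/// Backbone forward passes with a frozen feature extractor: features are
/// computed once and reused by every epoch that trains the small classifier.
fn feature_extractor_backbone_passes(num_images: u64, _num_epochs: u64) -> u64 {
    num_images
}

fn main() {
    let (images, epochs) = (211, 1000);
    println!("finetuning: {} passes", finetune_backbone_passes(images, epochs));
    println!(
        "frozen features: {} passes",
        feature_extractor_backbone_passes(images, epochs)
    );
}
```

This is why the second alternative is cheap enough to run for many epochs on a CPU.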
We will focus on the second alternative, but first we need to get the code building and running, and we also have to download the dataset and the pretrained weights.
Run the following commands to download the latest tch-rs version and run the tests. This installs the CPU version of libtorch if necessary.
git clone https://github.com/LaurentMazare/tch-rs.git
cd tch-rs
cargo test
The ants and bees dataset can be downloaded here. The weights for a ResNet-18 network pretrained on ImageNet can be downloaded as resnet18.ot.
Once this is done and the dataset has been extracted we can build and run the code with:
cargo run --example transfer-learning resnet18.ot hymenoptera_data
Let us now have a look at the code from main.rs.
The dataset is loaded via some helper functions.
let dataset = imagenet::load_from_dir(dataset_dir)?;
println!("{:?}", dataset);
The println! macro prints the dimensions of the tensors that have been created. For training, the tensor has shape 211x3x224x224. This corresponds to 211 images of height and width both 224 with 3 channels (PyTorch uses the NCHW ordering for image data). The testing image tensor has dimensions 127x3x224x224, so there are 127 images with the same size as used in training.
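To make the NCHW ordering concrete, the plain-Rust sketch below (no tch-rs involved; the `Nchw` type is hypothetical) computes the element count of the training tensor and the offset of an element in a contiguous buffer. The byte count assumes 32-bit floats.

```rust
/// Shape of an image batch in PyTorch's NCHW order.
struct Nchw {
    n: usize, // number of images
    c: usize, // channels
    h: usize, // height
    w: usize, // width
}

impl Nchw {
    /// Total number of scalar elements in the tensor.
    fn numel(&self) -> usize {
        self.n * self.c * self.h * self.w
    }

    /// Offset of element (n, c, h, w) in a contiguous row-major buffer:
    /// the width index varies fastest, the batch index slowest.
    fn offset(&self, n: usize, c: usize, h: usize, w: usize) -> usize {
        ((n * self.c + c) * self.h + h) * self.w + w
    }
}

fn main() {
    let train = Nchw { n: 211, c: 3, h: 224, w: 224 };
    let numel = train.numel();
    // Assuming f32 storage, i.e. 4 bytes per element.
    println!("{} elements, {} bytes as f32", numel, numel * 4);
    // The second image starts one full 3x224x224 block after the first.
    println!("second image starts at offset {}", train.offset(1, 0, 0, 0));
}
```

This prints 31761408 elements (about 121 MiB as f32) and an offset of 150528 = 3 * 224 * 224 for the second image.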
The pixel data from the dataset is converted to features by running a pre-trained ResNet model.
let mut vs = tch::nn::VarStore::new(tch::Device::Cpu);
let net = resnet::resnet18_no_final_layer(&vs.root());
vs.load(weights)?;
let train_images = tch::no_grad(|| dataset.train_images.apply_t(&net, false));
let test_images = tch::no_grad(|| dataset.test_images.apply_t(&net, false));
This snippet performs the following steps:
- A variable store vs is created to hold the network weights.
- A ResNet-18 model is created using this variable store. At this point the model weights are randomly initialized.
- vs.load(weights) loads the weights stored in a given file and copies their values into the corresponding tensors. Tensors are named in the serialized file in a way that matches the names we used when creating the ResNet model.
- Finally, for the training and testing image tensors, apply_t performs a forward pass on the model and returns the resulting tensor; the false argument runs the model in evaluation rather than training mode. In this case the result is a vector of 512 values per sample. The no_grad closure informs PyTorch that there is no need to keep a graph of the forward pass, as we do not plan on asking for gradients.
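In the standard ResNet-18 architecture, those 512 values come from the last convolutional block, which outputs 512 channels (7x7 spatial positions for a 224x224 input); each channel is averaged over its positions. The plain-Rust sketch below illustrates that global average pooling step on a contiguous CxHxW buffer. It is an illustration of the operation under that assumption, not tch-rs's implementation, and `global_avg_pool` is a hypothetical helper.

```rust
/// Global average pooling: reduce a contiguous C x H x W feature map to a
/// vector of C values by averaging each channel over its H * W positions.
fn global_avg_pool(features: &[f32], c: usize, h: usize, w: usize) -> Vec<f32> {
    assert!(h > 0 && w > 0, "spatial dimensions must be non-zero");
    assert_eq!(features.len(), c * h * w, "buffer does not match C x H x W");
    let plane = h * w;
    features
        .chunks(plane) // one chunk per channel, since channels are outermost
        .map(|channel| channel.iter().sum::<f32>() / plane as f32)
        .collect()
}

fn main() {
    // A dummy 512 x 7 x 7 map in which channel k is filled with the value k.
    let (c, h, w) = (512, 7, 7);
    let map: Vec<f32> = (0..c)
        .flat_map(|k| std::iter::repeat(k as f32).take(h * w))
        .collect();
    let pooled = global_avg_pool(&map, c, h, w);
    println!("{} values, pooled[3] = {}", pooled.len(), pooled[3]);
}
```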
Now that we have precomputed the output of the ResNet model on our training and testing images we will train a linear binary classifier to recognize ants vs bees.
We start by defining a model. For this we need a variable store to hold the trainable variables.
let vs = tch::nn::VarStore::new(tch::Device::Cpu);
let linear = nn::linear(vs.root(), 512, dataset.labels, Default::default());
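The linear layer maps each 512-value feature vector to one logit per label (two here: ants and bees) via an affine map logits = W x + b. The plain-Rust sketch below shows that computation and the resulting number of trainable parameters. `Linear` here is a hypothetical stand-in, not tch-rs's `nn::Linear`, and it is zero-initialized only to keep the example simple.

```rust
/// Plain-Rust affine layer y = W x + b. `weight` is stored row-major as
/// out_dim x in_dim, `bias` has one entry per output.
struct Linear {
    in_dim: usize,
    out_dim: usize,
    weight: Vec<f32>,
    bias: Vec<f32>,
}

impl Linear {
    fn zeros(in_dim: usize, out_dim: usize) -> Self {
        Linear {
            in_dim,
            out_dim,
            weight: vec![0.0; out_dim * in_dim],
            bias: vec![0.0; out_dim],
        }
    }

    /// Trainable scalars: one weight per (output, input) pair plus one bias per output.
    fn num_params(&self) -> usize {
        self.weight.len() + self.bias.len()
    }

    /// Logits for a single feature vector.
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        assert_eq!(x.len(), self.in_dim);
        (0..self.out_dim)
            .map(|o| {
                let row = &self.weight[o * self.in_dim..(o + 1) * self.in_dim];
                row.iter().zip(x).map(|(w, xi)| w * xi).sum::<f32>() + self.bias[o]
            })
            .collect()
    }
}

fn main() {
    let layer = Linear::zeros(512, 2);
    println!("{} trainable parameters", layer.num_params());
    println!("{:?}", layer.forward(&vec![1.0f32; 512]));
}
```

With 512 inputs and 2 outputs, only 512 * 2 + 2 = 1026 scalars are trained, which is what makes the tiny dataset workable.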
We will use stochastic gradient descent to minimize the cross-entropy loss on the classification task. To do this, we create an SGD optimizer and then iterate on the training dataset. After each epoch the accuracy is computed on the testing set and printed.
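For a single sample, the cross-entropy on logits z with label y is -log softmax(z)[y] = log(sum_j exp(z_j)) - z_y, and the batch loss is its mean. The plain-Rust sketch below computes it with the usual max-shift for numerical stability. It illustrates the standard definition and is not tch-rs's `cross_entropy_for_logits`; the helper names are hypothetical.

```rust
/// Numerically stable log-sum-exp of a slice of logits.
fn log_sum_exp(logits: &[f32]) -> f32 {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    max + logits.iter().map(|&z| (z - max).exp()).sum::<f32>().ln()
}

/// Mean cross-entropy over a batch: for each sample,
/// -log softmax(logits)[label] = log_sum_exp(logits) - logits[label].
fn mean_cross_entropy(batch: &[Vec<f32>], labels: &[usize]) -> f32 {
    assert_eq!(batch.len(), labels.len());
    let total: f32 = batch
        .iter()
        .zip(labels)
        .map(|(z, &y)| log_sum_exp(z) - z[y])
        .sum();
    total / batch.len() as f32
}

fn main() {
    // Two equal logits: the model is maximally unsure, so the loss is ln 2.
    let loss = mean_cross_entropy(&[vec![0.0f32, 0.0]], &[1]);
    println!("{:.4}", loss);
}
```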
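// `build` comes from the tch::nn::OptimizerConfig trait, which must be in scope.
// `sgd` is mutable because backward_step takes the optimizer by mutable reference.
let mut sgd = nn::Sgd::default().build(&vs, 1e-3)?;
for epoch_idx in 1..1001 {
    // Forward pass on all precomputed training features at once.
    let predicted = train_images.apply(&linear);
    let loss = predicted.cross_entropy_for_logits(&dataset.train_labels);
    // Compute gradients and update the linear layer's weights.
    sgd.backward_step(&loss);
    // Evaluate on the precomputed test features after each epoch.
    let test_accuracy = test_images
        .apply(&linear)
        .accuracy_for_logits(&dataset.test_labels);
    println!("{} {:.2}%", epoch_idx, 100. * f64::from(test_accuracy));
}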
On each training step the model output is computed through a forward pass. The cross-entropy loss is then evaluated on the resulting logits using the training labels. The backward pass then evaluates gradients for the trainable variables of our model and these variables are updated by the optimizer.
let predicted = train_images.apply(&linear);
let loss = predicted.cross_entropy_for_logits(&dataset.train_labels);
sgd.backward_step(&loss);
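Because the whole train_images tensor is used on every step, each step here is a full-batch gradient update. For a linear classifier with cross-entropy loss, the gradient with respect to the logits is softmax(logits) minus the one-hot label. The plain-Rust sketch below performs one such step by hand; it assumes plain SGD with no momentum or weight decay, and `softmax` and `sgd_step` are hypothetical helpers, not tch-rs APIs.

```rust
/// Softmax probabilities from logits, shifted by the max for stability.
fn softmax(z: &[f32]) -> Vec<f32> {
    let max = z.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = z.iter().map(|&v| (v - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.into_iter().map(|e| e / sum).collect()
}

/// One full-batch gradient step on the mean cross-entropy of a linear
/// classifier. `w` is row-major out_dim x in_dim, `b` has out_dim entries.
fn sgd_step(w: &mut [f32], b: &mut [f32], xs: &[Vec<f32>], ys: &[usize], lr: f32) {
    let out_dim = b.len();
    let in_dim = w.len() / out_dim;
    let n = xs.len() as f32;
    let mut grad_w = vec![0.0f32; w.len()];
    let mut grad_b = vec![0.0f32; out_dim];
    for (x, &y) in xs.iter().zip(ys) {
        // Forward pass: logits = W x + b.
        let logits: Vec<f32> = (0..out_dim)
            .map(|o| (0..in_dim).map(|i| w[o * in_dim + i] * x[i]).sum::<f32>() + b[o])
            .collect();
        // Backward pass: d(loss)/d(logits) = softmax(logits) - one_hot(y), averaged over n.
        let p = softmax(&logits);
        for o in 0..out_dim {
            let g = (p[o] - if o == y { 1.0 } else { 0.0 }) / n;
            grad_b[o] += g;
            for i in 0..in_dim {
                grad_w[o * in_dim + i] += g * x[i];
            }
        }
    }
    // Update: move every parameter against its gradient.
    for (wi, gi) in w.iter_mut().zip(&grad_w) {
        *wi -= lr * gi;
    }
    for (bi, gi) in b.iter_mut().zip(&grad_b) {
        *bi -= lr * gi;
    }
}

fn main() {
    // Toy problem: class 0 has feature 0 set, class 1 has feature 1 set.
    let xs = vec![vec![1.0f32, 0.0], vec![0.0, 1.0]];
    let ys = vec![0usize, 1];
    let (mut w, mut b) = (vec![0.0f32; 4], vec![0.0f32; 2]);
    for _ in 0..100 {
        sgd_step(&mut w, &mut b, &xs, &ys, 0.5);
    }
    println!("w = {:?}, b = {:?}", w, b);
}
```

After training, the weight linking each class to its own feature is positive and the cross weights are negative, so both toy samples are classified correctly.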
After each epoch the accuracy is evaluated on the testing set and printed out.
let test_accuracy = test_images
.apply(&linear)
.accuracy_for_logits(&dataset.test_labels);
println!("{} {:.2}%", epoch_idx, 100. * f64::from(test_accuracy));
This should result in a 94.5% accuracy on the testing set.
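Conceptually, the accuracy reported here is the fraction of test samples whose highest-scoring logit matches the label. The plain-Rust sketch below computes it; `argmax` and `accuracy` are hypothetical helpers rather than tch-rs's `accuracy_for_logits`, and ties go to the first index in this sketch.

```rust
/// Index of the largest logit (the first one wins on ties).
fn argmax(z: &[f32]) -> usize {
    let mut best = 0;
    for (i, &v) in z.iter().enumerate() {
        if v > z[best] {
            best = i;
        }
    }
    best
}

/// Fraction of samples whose highest-scoring class equals the label.
fn accuracy(batch: &[Vec<f32>], labels: &[usize]) -> f64 {
    assert_eq!(batch.len(), labels.len());
    let correct = batch
        .iter()
        .zip(labels)
        .filter(|&(z, &y)| argmax(z) == y)
        .count();
    correct as f64 / batch.len() as f64
}

fn main() {
    let logits = vec![
        vec![2.0f32, 1.0],
        vec![0.0, 3.0],
        vec![5.0, -1.0],
        vec![0.5, 0.2],
    ];
    let labels = vec![0usize, 1, 0, 1];
    println!("{:.2}%", 100.0 * accuracy(&logits, &labels));
}
```

Here three of the four samples are classified correctly, so the program prints 75.00%.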
The whole code for this example can be found in main.rs.