Skip to content

Commit

Permalink
Added contrastive loss layer, associated tests, and a siamese network…
Browse files Browse the repository at this point in the history
… example using shared weights and the contrastive loss.
  • Loading branch information
Nick Carlevaris-Bianco committed Sep 8, 2014
1 parent fc921bf commit d149c9a
Show file tree
Hide file tree
Showing 14 changed files with 1,323 additions and 2 deletions.
123 changes: 123 additions & 0 deletions examples/siamese/convert_mnist_siamese_data.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
//
// This script converts the MNIST dataset to the leveldb format used
// by caffe to train siamese network.
// Usage:
// convert_mnist_data input_image_file input_label_file output_db_file
// The MNIST dataset could be downloaded at
// http://yann.lecun.com/exdb/mnist/
#include <fstream> // NOLINT(readability/streams)
#include <string>

#include "glog/logging.h"
#include "google/protobuf/text_format.h"
#include "leveldb/db.h"
#include "stdint.h"

#include "caffe/proto/caffe.pb.h"
#include "caffe/util/math_functions.hpp"

uint32_t swap_endian(uint32_t val) {
val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0xFF00FF);
return (val << 16) | (val >> 16);
}

void read_image(std::ifstream* image_file, std::ifstream* label_file,
uint32_t index, uint32_t rows, uint32_t cols,
char* pixels, char* label) {
image_file->seekg(index * rows * cols + 16);
image_file->read(pixels, rows * cols);
label_file->seekg(index + 8);
label_file->read(label, 1);
}

void convert_dataset(const char* image_filename, const char* label_filename,
const char* db_filename) {
// Open files
std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);
CHECK(image_file) << "Unable to open file " << image_filename;
CHECK(label_file) << "Unable to open file " << label_file;
// Read the magic and the meta data
uint32_t magic;
uint32_t num_items;
uint32_t num_labels;
uint32_t rows;
uint32_t cols;

image_file.read(reinterpret_cast<char*>(&magic), 4);
magic = swap_endian(magic);
CHECK_EQ(magic, 2051) << "Incorrect image file magic.";
label_file.read(reinterpret_cast<char*>(&magic), 4);
magic = swap_endian(magic);
CHECK_EQ(magic, 2049) << "Incorrect label file magic.";
image_file.read(reinterpret_cast<char*>(&num_items), 4);
num_items = swap_endian(num_items);
label_file.read(reinterpret_cast<char*>(&num_labels), 4);
num_labels = swap_endian(num_labels);
CHECK_EQ(num_items, num_labels);
image_file.read(reinterpret_cast<char*>(&rows), 4);
rows = swap_endian(rows);
image_file.read(reinterpret_cast<char*>(&cols), 4);
cols = swap_endian(cols);

// Open leveldb
leveldb::DB* db;
leveldb::Options options;
options.create_if_missing = true;
options.error_if_exists = true;
leveldb::Status status = leveldb::DB::Open(
options, db_filename, &db);
CHECK(status.ok()) << "Failed to open leveldb " << db_filename
<< ". Is it already existing?";

char label_i;
char label_j;
char* pixels = new char[2 * rows * cols];
const int kMaxKeyLength = 10;
char key[kMaxKeyLength];
std::string value;

caffe::Datum datum;
datum.set_channels(2); // one channel for each image in the pair
datum.set_height(rows);
datum.set_width(cols);
LOG(INFO) << "A total of " << num_items << " items.";
LOG(INFO) << "Rows: " << rows << " Cols: " << cols;
for (int itemid = 0; itemid < num_items; ++itemid) {
int i = caffe::caffe_rng_rand() % num_items; // pick a random pair
int j = caffe::caffe_rng_rand() % num_items;
read_image(&image_file, &label_file, i, rows, cols,
pixels, &label_i);
read_image(&image_file, &label_file, j, rows, cols,
pixels + (rows * cols), &label_j);
datum.set_data(pixels, 2*rows*cols);
if (label_i == label_j) {
datum.set_label(1);
} else {
datum.set_label(0);
}
datum.SerializeToString(&value);
snprintf(key, kMaxKeyLength, "%08d", itemid);
db->Put(leveldb::WriteOptions(), std::string(key), value);
}

delete db;
delete pixels;
}

int main(int argc, char** argv) {
if (argc != 4) {
printf("This script converts the MNIST dataset to the leveldb format used\n"
"by caffe to train a siamese network.\n"
"Usage:\n"
" convert_mnist_data input_image_file input_label_file "
"output_db_file\n"
"The MNIST dataset could be downloaded at\n"
" http://yann.lecun.com/exdb/mnist/\n"
"You should gunzip them after downloading.\n");
} else {
google::InitGoogleLogging(argv[0]);
convert_dataset(argv[1], argv[2], argv[3]);
}
return 0;
}
21 changes: 21 additions & 0 deletions examples/siamese/create_mnist_siamese.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#!/usr/bin/env sh
# This script converts the mnist data into leveldb format.

EXAMPLES=./build/examples/siamese
DATA=./data/mnist

echo "Creating leveldb..."

rm -rf ./examples/siamese/mnist_siamese_train_leveldb
rm -rf ./examples/siamese/mnist_siamese_test_leveldb

$EXAMPLES/convert_mnist_siamese_data.bin \
$DATA/train-images-idx3-ubyte \
$DATA/train-labels-idx1-ubyte \
./examples/siamese/mnist_siamese_train_leveldb
$EXAMPLES/convert_mnist_siamese_data.bin \
$DATA/t10k-images-idx3-ubyte \
$DATA/t10k-labels-idx1-ubyte \
./examples/siamese/mnist_siamese_test_leveldb

echo "Done."
169 changes: 169 additions & 0 deletions examples/siamese/mnist_siamese.ipynb

Large diffs are not rendered by default.

95 changes: 95 additions & 0 deletions examples/siamese/mnist_siamese.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
name: "mnist_siamese"
input: "data"
input_dim: 10000
input_dim: 1
input_dim: 28
input_dim: 28

layers {
name: "conv1"
type: CONVOLUTION
bottom: "data"
top: "conv1"
blobs_lr: 1
blobs_lr: 2
convolution_param {
num_output: 20
kernel_size: 5
stride: 1
}
}
layers {
name: "pool1"
type: POOLING
bottom: "conv1"
top: "pool1"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
layers {
name: "conv2"
type: CONVOLUTION
bottom: "pool1"
top: "conv2"
blobs_lr: 1
blobs_lr: 2
convolution_param {
num_output: 50
kernel_size: 5
stride: 1
}
}
layers {
name: "pool2"
type: POOLING
bottom: "conv2"
top: "pool2"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
layers {
name: "ip1"
type: INNER_PRODUCT
bottom: "pool2"
top: "ip1"
blobs_lr: 1
blobs_lr: 2
inner_product_param {
num_output: 500
}
}
layers {
name: "relu1"
type: RELU
bottom: "ip1"
top: "ip1"
}
layers {
name: "ip2"
type: INNER_PRODUCT
bottom: "ip1"
top: "ip2"
blobs_lr: 1
blobs_lr: 2
inner_product_param {
num_output: 10
}
}

layers {
name: "feat"
type: INNER_PRODUCT
bottom: "ip2"
top: "feat"
blobs_lr: 1
blobs_lr: 2
inner_product_param {
num_output: 2
}
}
25 changes: 25 additions & 0 deletions examples/siamese/mnist_siamese_solver.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# The train/test net protocol buffer definition
net: "examples/siamese/mnist_siamese_train_test.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100
# Carry out testing every 500 training iterations.
test_interval: 500
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0000
# The learning rate policy
lr_policy: "inv"
gamma: 0.0001
power: 0.75
# Display every 100 iterations
display: 100
# The maximum number of iterations
max_iter: 50000
# snapshot intermediate results
snapshot: 5000
snapshot_prefix: "examples/siamese/mnist_siamese"
# solver mode: CPU or GPU
solver_mode: GPU
Loading

0 comments on commit d149c9a

Please sign in to comment.