-
Notifications
You must be signed in to change notification settings - Fork 18.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added contrastive loss layer, associated tests, and a siamese network…
… example using shared weights and the contrastive loss.
- Loading branch information
Nick Carlevaris-Bianco
committed
Aug 26, 2014
1 parent
9516115
commit 921229a
Showing
14 changed files
with
1,263 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
// | ||
// This script converts the MNIST dataset to the leveldb format used | ||
// by caffe to train siamese network. | ||
// Usage: | ||
// convert_mnist_data input_image_file input_label_file output_db_file | ||
// The MNIST dataset could be downloaded at | ||
// http://yann.lecun.com/exdb/mnist/ | ||
#include <fstream> // NOLINT(readability/streams) | ||
#include <string> | ||
|
||
#include "glog/logging.h" | ||
#include "google/protobuf/text_format.h" | ||
#include "leveldb/db.h" | ||
#include "stdint.h" | ||
|
||
#include "caffe/proto/caffe.pb.h" | ||
#include "caffe/util/math_functions.hpp" | ||
|
||
uint32_t swap_endian(uint32_t val) { | ||
val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0xFF00FF); | ||
return (val << 16) | (val >> 16); | ||
} | ||
|
||
void read_image(std::ifstream* image_file, std::ifstream* label_file, | ||
uint32_t index, uint32_t rows, uint32_t cols, | ||
char* pixels, char* label) { | ||
image_file->seekg(index * rows * cols + 16); | ||
image_file->read(pixels, rows * cols); | ||
label_file->seekg(index + 8); | ||
label_file->read(label, 1); | ||
} | ||
|
||
void convert_dataset(const char* image_filename, const char* label_filename, | ||
const char* db_filename) { | ||
// Open files | ||
std::ifstream image_file(image_filename, std::ios::in | std::ios::binary); | ||
std::ifstream label_file(label_filename, std::ios::in | std::ios::binary); | ||
CHECK(image_file) << "Unable to open file " << image_filename; | ||
CHECK(label_file) << "Unable to open file " << label_file; | ||
// Read the magic and the meta data | ||
uint32_t magic; | ||
uint32_t num_items; | ||
uint32_t num_labels; | ||
uint32_t rows; | ||
uint32_t cols; | ||
|
||
image_file.read(reinterpret_cast<char*>(&magic), 4); | ||
magic = swap_endian(magic); | ||
CHECK_EQ(magic, 2051) << "Incorrect image file magic."; | ||
label_file.read(reinterpret_cast<char*>(&magic), 4); | ||
magic = swap_endian(magic); | ||
CHECK_EQ(magic, 2049) << "Incorrect label file magic."; | ||
image_file.read(reinterpret_cast<char*>(&num_items), 4); | ||
num_items = swap_endian(num_items); | ||
label_file.read(reinterpret_cast<char*>(&num_labels), 4); | ||
num_labels = swap_endian(num_labels); | ||
CHECK_EQ(num_items, num_labels); | ||
image_file.read(reinterpret_cast<char*>(&rows), 4); | ||
rows = swap_endian(rows); | ||
image_file.read(reinterpret_cast<char*>(&cols), 4); | ||
cols = swap_endian(cols); | ||
|
||
// Open leveldb | ||
leveldb::DB* db; | ||
leveldb::Options options; | ||
options.create_if_missing = true; | ||
options.error_if_exists = true; | ||
leveldb::Status status = leveldb::DB::Open( | ||
options, db_filename, &db); | ||
CHECK(status.ok()) << "Failed to open leveldb " << db_filename | ||
<< ". Is it already existing?"; | ||
|
||
char label_i; | ||
char label_j; | ||
char* pixels = new char[2 * rows * cols]; | ||
const int kMaxKeyLength = 10; | ||
char key[kMaxKeyLength]; | ||
std::string value; | ||
|
||
caffe::Datum datum; | ||
datum.set_channels(2); // one channel for each image in the pair | ||
datum.set_height(rows); | ||
datum.set_width(cols); | ||
LOG(INFO) << "A total of " << num_items << " items."; | ||
LOG(INFO) << "Rows: " << rows << " Cols: " << cols; | ||
for (int itemid = 0; itemid < num_items; ++itemid) { | ||
int i = caffe::caffe_rng_rand() % num_items; // pick a random pair | ||
int j = caffe::caffe_rng_rand() % num_items; | ||
read_image(&image_file, &label_file, i, rows, cols, | ||
pixels, &label_i); | ||
read_image(&image_file, &label_file, j, rows, cols, | ||
pixels + (rows * cols), &label_j); | ||
datum.set_data(pixels, 2*rows*cols); | ||
if (label_i == label_j) { | ||
datum.set_label(1); | ||
} else { | ||
datum.set_label(0); | ||
} | ||
datum.SerializeToString(&value); | ||
snprintf(key, kMaxKeyLength, "%08d", itemid); | ||
db->Put(leveldb::WriteOptions(), std::string(key), value); | ||
} | ||
|
||
delete db; | ||
delete pixels; | ||
} | ||
|
||
int main(int argc, char** argv) { | ||
if (argc != 4) { | ||
printf("This script converts the MNIST dataset to the leveldb format used\n" | ||
"by caffe to train a siamese network.\n" | ||
"Usage:\n" | ||
" convert_mnist_data input_image_file input_label_file " | ||
"output_db_file\n" | ||
"The MNIST dataset could be downloaded at\n" | ||
" http://yann.lecun.com/exdb/mnist/\n" | ||
"You should gunzip them after downloading.\n"); | ||
} else { | ||
google::InitGoogleLogging(argv[0]); | ||
convert_dataset(argv[1], argv[2], argv[3]); | ||
} | ||
return 0; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
#!/usr/bin/env sh | ||
# This script converts the mnist data into leveldb format. | ||
|
||
EXAMPLES=../../build/examples/siamese | ||
DATA=../../data/mnist | ||
|
||
echo "Creating leveldb..." | ||
|
||
rm -rf mnist-siamese-train-leveldb | ||
rm -rf mnist-siamese-test-leveldb | ||
|
||
$EXAMPLES/convert_mnist_siamese_data.bin \ | ||
$DATA/train-images-idx3-ubyte \ | ||
$DATA/train-labels-idx1-ubyte \ | ||
mnist-siamese-train-leveldb | ||
$EXAMPLES/convert_mnist_siamese_data.bin \ | ||
$DATA/t10k-images-idx3-ubyte \ | ||
$DATA/t10k-labels-idx1-ubyte \ | ||
mnist-siamese-test-leveldb | ||
|
||
echo "Done." |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
name: "mnist_siamese" | ||
input: "data" | ||
input_dim: 10000 | ||
input_dim: 1 | ||
input_dim: 28 | ||
input_dim: 28 | ||
|
||
layers { | ||
name: "conv1" | ||
type: CONVOLUTION | ||
bottom: "data" | ||
top: "conv1" | ||
blobs_lr: 1 | ||
blobs_lr: 2 | ||
convolution_param { | ||
num_output: 20 | ||
kernel_size: 5 | ||
stride: 1 | ||
} | ||
} | ||
layers { | ||
name: "pool1" | ||
type: POOLING | ||
bottom: "conv1" | ||
top: "pool1" | ||
pooling_param { | ||
pool: MAX | ||
kernel_size: 2 | ||
stride: 2 | ||
} | ||
} | ||
layers { | ||
name: "conv2" | ||
type: CONVOLUTION | ||
bottom: "pool1" | ||
top: "conv2" | ||
blobs_lr: 1 | ||
blobs_lr: 2 | ||
convolution_param { | ||
num_output: 50 | ||
kernel_size: 5 | ||
stride: 1 | ||
} | ||
} | ||
layers { | ||
name: "pool2" | ||
type: POOLING | ||
bottom: "conv2" | ||
top: "pool2" | ||
pooling_param { | ||
pool: MAX | ||
kernel_size: 2 | ||
stride: 2 | ||
} | ||
} | ||
layers { | ||
name: "ip1" | ||
type: INNER_PRODUCT | ||
bottom: "pool2" | ||
top: "ip1" | ||
blobs_lr: 1 | ||
blobs_lr: 2 | ||
inner_product_param { | ||
num_output: 500 | ||
} | ||
} | ||
layers { | ||
name: "relu1" | ||
type: RELU | ||
bottom: "ip1" | ||
top: "ip1" | ||
} | ||
layers { | ||
name: "ip2" | ||
type: INNER_PRODUCT | ||
bottom: "ip1" | ||
top: "ip2" | ||
blobs_lr: 1 | ||
blobs_lr: 2 | ||
inner_product_param { | ||
num_output: 10 | ||
} | ||
} | ||
|
||
layers { | ||
name: "feat" | ||
type: INNER_PRODUCT | ||
bottom: "ip2" | ||
top: "feat" | ||
blobs_lr: 1 | ||
blobs_lr: 2 | ||
inner_product_param { | ||
num_output: 2 | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# The train/test net protocol buffer definition | ||
net: "mnist_siamese_train.prototxt" | ||
# test_iter specifies how many forward passes the test should carry out. | ||
# In the case of MNIST, we have test batch size 100 and 100 test iterations, | ||
# covering the full 10,000 testing images. | ||
test_iter: 100 | ||
# Carry out testing every 500 training iterations. | ||
test_interval: 500 | ||
# The base learning rate, momentum and the weight decay of the network. | ||
base_lr: 0.01 | ||
momentum: 0.9 | ||
weight_decay: 0.0000 | ||
# The learning rate policy | ||
lr_policy: "inv" | ||
gamma: 0.0001 | ||
power: 0.75 | ||
# Display every 100 iterations | ||
display: 100 | ||
# The maximum number of iterations | ||
max_iter: 100000 | ||
# snapshot intermediate results | ||
snapshot: 5000 | ||
snapshot_prefix: "mnist_siamese" | ||
# solver mode: CPU or GPU | ||
solver_mode: GPU |
Oops, something went wrong.