Skip to content

Commit

Permalink
Merge pull request BVLC#1238 from kmatzen/db
Browse files Browse the repository at this point in the history
leveldb/lmdb refactoring
  • Loading branch information
sguada committed Oct 15, 2014
2 parents 69f8047 + 70a11e7 commit b0c5905
Show file tree
Hide file tree
Showing 21 changed files with 1,838 additions and 480 deletions.
56 changes: 28 additions & 28 deletions examples/cifar10/convert_cifar_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,18 @@

#include "glog/logging.h"
#include "google/protobuf/text_format.h"
#include "leveldb/db.h"
#include "stdint.h"

#include "caffe/dataset_factory.hpp"
#include "caffe/proto/caffe.pb.h"

using std::string;

using caffe::Dataset;
using caffe::DatasetFactory;
using caffe::Datum;
using caffe::shared_ptr;

const int kCIFARSize = 32;
const int kCIFARImageNBytes = 3072;
const int kCIFARBatchSize = 10000;
Expand All @@ -31,26 +36,21 @@ void read_image(std::ifstream* file, int* label, char* buffer) {
return;
}

void convert_dataset(const string& input_folder, const string& output_folder) {
// Leveldb options
leveldb::Options options;
options.create_if_missing = true;
options.error_if_exists = true;
void convert_dataset(const string& input_folder, const string& output_folder,
const string& db_type) {
shared_ptr<Dataset<string, Datum> > train_dataset =
DatasetFactory<string, Datum>(db_type);
CHECK(train_dataset->open(output_folder + "/cifar10_train_" + db_type,
Dataset<string, Datum>::New));
// Data buffer
int label;
char str_buffer[kCIFARImageNBytes];
string value;
caffe::Datum datum;
Datum datum;
datum.set_channels(3);
datum.set_height(kCIFARSize);
datum.set_width(kCIFARSize);

LOG(INFO) << "Writing Training data";
leveldb::DB* train_db;
leveldb::Status status;
status = leveldb::DB::Open(options, output_folder + "/cifar10_train_leveldb",
&train_db);
CHECK(status.ok()) << "Failed to open leveldb.";
for (int fileid = 0; fileid < kCIFARTrainBatches; ++fileid) {
// Open files
LOG(INFO) << "Training Batch " << fileid + 1;
Expand All @@ -62,17 +62,19 @@ void convert_dataset(const string& input_folder, const string& output_folder) {
read_image(&data_file, &label, str_buffer);
datum.set_label(label);
datum.set_data(str_buffer, kCIFARImageNBytes);
datum.SerializeToString(&value);
snprintf(str_buffer, kCIFARImageNBytes, "%05d",
int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d",
fileid * kCIFARBatchSize + itemid);
train_db->Put(leveldb::WriteOptions(), string(str_buffer), value);
CHECK(train_dataset->put(string(str_buffer, length), datum));
}
}
CHECK(train_dataset->commit());
train_dataset->close();

LOG(INFO) << "Writing Testing data";
leveldb::DB* test_db;
CHECK(leveldb::DB::Open(options, output_folder + "/cifar10_test_leveldb",
&test_db).ok()) << "Failed to open leveldb.";
shared_ptr<Dataset<string, Datum> > test_dataset =
DatasetFactory<string, Datum>(db_type);
CHECK(test_dataset->open(output_folder + "/cifar10_test_" + db_type,
Dataset<string, Datum>::New));
// Open files
std::ifstream data_file((input_folder + "/test_batch.bin").c_str(),
std::ios::in | std::ios::binary);
Expand All @@ -81,28 +83,26 @@ void convert_dataset(const string& input_folder, const string& output_folder) {
read_image(&data_file, &label, str_buffer);
datum.set_label(label);
datum.set_data(str_buffer, kCIFARImageNBytes);
datum.SerializeToString(&value);
snprintf(str_buffer, kCIFARImageNBytes, "%05d", itemid);
test_db->Put(leveldb::WriteOptions(), string(str_buffer), value);
int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", itemid);
CHECK(test_dataset->put(string(str_buffer, length), datum));
}

delete train_db;
delete test_db;
CHECK(test_dataset->commit());
test_dataset->close();
}

int main(int argc, char** argv) {
if (argc != 3) {
if (argc != 4) {
printf("This script converts the CIFAR dataset to the leveldb format used\n"
"by caffe to perform classification.\n"
"Usage:\n"
" convert_cifar_data input_folder output_folder\n"
" convert_cifar_data input_folder output_folder db_type\n"
"Where the input folder should contain the binary batch files.\n"
"The CIFAR dataset could be downloaded at\n"
" http://www.cs.toronto.edu/~kriz/cifar.html\n"
"You should gunzip them after downloading.\n");
} else {
google::InitGoogleLogging(argv[0]);
convert_dataset(string(argv[1]), string(argv[2]));
convert_dataset(string(argv[1]), string(argv[2]), string(argv[3]));
}
return 0;
}
11 changes: 6 additions & 5 deletions examples/cifar10/create_cifar10.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@

EXAMPLE=examples/cifar10
DATA=data/cifar10
DBTYPE=lmdb

echo "Creating leveldb..."
echo "Creating $DBTYPE..."

rm -rf $EXAMPLE/cifar10_train_leveldb $EXAMPLE/cifar10_test_leveldb
rm -rf $EXAMPLE/cifar10_train_$DBTYPE $EXAMPLE/cifar10_test_$DBTYPE

./build/examples/cifar10/convert_cifar_data.bin $DATA $EXAMPLE
./build/examples/cifar10/convert_cifar_data.bin $DATA $EXAMPLE $DBTYPE

echo "Computing image mean..."

./build/tools/compute_image_mean $EXAMPLE/cifar10_train_leveldb \
$EXAMPLE/mean.binaryproto leveldb
./build/tools/compute_image_mean $EXAMPLE/cifar10_train_$DBTYPE \
$EXAMPLE/mean.binaryproto $DBTYPE

echo "Done."
2 changes: 1 addition & 1 deletion examples/feature_extraction/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ Extract Features

Now everything necessary is in place.

./build/tools/extract_features.bin models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel examples/_temp/imagenet_val.prototxt fc7 examples/_temp/features 10
./build/tools/extract_features.bin models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel examples/_temp/imagenet_val.prototxt fc7 examples/_temp/features 10 lmdb

The name of the feature blob that you extract is `fc7`, which represents the highest-level feature of the reference model.
We can use any other layer as well, such as `conv5` or `pool3`.
Expand Down
14 changes: 3 additions & 11 deletions include/caffe/data_layers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@

#include "boost/scoped_ptr.hpp"
#include "hdf5.h"
#include "leveldb/db.h"
#include "lmdb.h"

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/data_transformer.hpp"
#include "caffe/dataset.hpp"
#include "caffe/filler.hpp"
#include "caffe/internal_thread.hpp"
#include "caffe/layer.hpp"
Expand Down Expand Up @@ -101,15 +100,8 @@ class DataLayer : public BasePrefetchingDataLayer<Dtype> {
protected:
virtual void InternalThreadEntry();

// LEVELDB
shared_ptr<leveldb::DB> db_;
shared_ptr<leveldb::Iterator> iter_;
// LMDB
MDB_env* mdb_env_;
MDB_dbi mdb_dbi_;
MDB_txn* mdb_txn_;
MDB_cursor* mdb_cursor_;
MDB_val mdb_key_, mdb_value_;
shared_ptr<Dataset<string, Datum> > dataset_;
Dataset<string, Datum>::const_iterator iter_;
};

/**
Expand Down
Loading

0 comments on commit b0c5905

Please sign in to comment.