Skip to content

Commit

Permalink
Add command-line option for size of the db batch size.
Browse files Browse the repository at this point in the history
  • Loading branch information
jyegerlehner committed Oct 14, 2014
1 parent f2a0882 commit 05d5d4b
Showing 1 changed file with 28 additions and 14 deletions.
42 changes: 28 additions & 14 deletions tools/extract_features.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,30 +35,45 @@ int feature_extraction_pipeline(int argc, char** argv) {
"Usage: extract_features pretrained_net_param"
" feature_extraction_proto_file extract_feature_blob_name1[,name2,...]"
" save_feature_leveldb_name1[,name2,...] num_mini_batches [CPU/GPU]"
" [DEVICE_ID=0]\n"
" [DEVICE_ID=0] [DatabaseBatchSize=100]\n"
"Note: you can extract multiple features in one pass by specifying"
" multiple feature blob names and leveldb names seperated by ','."
" The names cannot contain white space characters and the number of blobs"
" and leveldbs must be equal.";
return 1;
}
int arg_pos = num_required_args;

uint db_batch_size = 100;
uint device_id = 0;
bool using_gpu = false;
arg_pos = num_required_args;
if (argc > arg_pos && strcmp(argv[arg_pos], "GPU") == 0) {
LOG(ERROR)<< "Using GPU";
uint device_id = 0;
if (argc > arg_pos + 1) {
device_id = atoi(argv[arg_pos + 1]);
while (arg_pos < argc) {
if ( arg_pos == num_required_args ) {
if (strcmp(argv[arg_pos], "GPU") == 0) {
LOG(ERROR)<< "Using GPU";
Caffe::set_mode(Caffe::GPU);
using_gpu = true;
} else {
LOG(ERROR) << "Using CPU";
Caffe::set_mode(Caffe::CPU);
}
} else if (arg_pos == num_required_args+1) {
device_id = atoi(argv[arg_pos]);
CHECK_GE(device_id, 0);
} else if (arg_pos == num_required_args+2) {
db_batch_size = atoi(argv[arg_pos]);
CHECK_GE(db_batch_size, 1);
}
arg_pos++;
}

if (using_gpu) {
LOG(ERROR) << "Using Device_id=" << device_id;
Caffe::SetDevice(device_id);
Caffe::set_mode(Caffe::GPU);
} else {
LOG(ERROR) << "Using CPU";
Caffe::set_mode(Caffe::CPU);
}

LOG(ERROR) << "Using DB batch size=" << db_batch_size;

Caffe::set_phase(Caffe::TEST);

arg_pos = 0; // the name of the executable
Expand Down Expand Up @@ -144,7 +159,6 @@ int feature_extraction_pipeline(int argc, char** argv) {
vector<Blob<float>*> input_vec;
vector<int> image_indices(num_features, 0);

const int DB_BATCH_SIZE = 100;
for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
feature_extraction_net->Forward(input_vec);
for (int i = 0; i < num_features; ++i) {
Expand All @@ -170,7 +184,7 @@ int feature_extraction_pipeline(int argc, char** argv) {
snprintf(key_str, kMaxKeyStrLength, "%d", image_indices[i]);
feature_batches[i]->Put(string(key_str), value);
++image_indices[i];
if (image_indices[i] % DB_BATCH_SIZE == 0) {
if (image_indices[i] % db_batch_size == 0) {
feature_dbs[i]->Write(leveldb::WriteOptions(),
feature_batches[i].get());
LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
Expand All @@ -182,7 +196,7 @@ int feature_extraction_pipeline(int argc, char** argv) {
} // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
// write the last batch
for (int i = 0; i < num_features; ++i) {
if (image_indices[i] % DB_BATCH_SIZE != 0) {
if (image_indices[i] % db_batch_size != 0) {
feature_dbs[i]->Write(leveldb::WriteOptions(), feature_batches[i].get());
}
LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
Expand Down

0 comments on commit 05d5d4b

Please sign in to comment.