// This program converts a set of images to a lmdb/leveldb by storing them
// as Datum proto buffers.
// Usage:
// convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME
//
// where ROOTFOLDER is the root folder that holds all the images, and LISTFILE
// should be a list of files as well as their labels, in the format as
// subfolder1/file1.JPEG 7
// ....
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <leveldb/db.h>
#include <leveldb/write_batch.h>
#include <lmdb.h>
#include <sys/stat.h>
#include <algorithm>
#include <fstream> // NOLINT(readability/streams)
#include <string>
#include <utility>
#include <vector>
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/io.hpp"
#include "caffe/util/rng.hpp"
using namespace caffe; // NOLINT(build/namespaces)
using std::pair;
using std::string;
DEFINE_bool(gray, false,
"When this option is on, treat images as grayscale ones");
DEFINE_bool(shuffle, false,
"Randomly shuffle the order of images and their labels");
DEFINE_string(backend, "lmdb", "The backend for storing the result");
DEFINE_int32(resize_width, 0, "Width images are resized to");
DEFINE_int32(resize_height, 0, "Height images are resized to");
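// Note: with the default of 0 for both resize flags, images are stored at
// their original dimensions; resizing is only applied when both values are
// positive.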
int main(int argc, char** argv) {
::google::InitGoogleLogging(argv[0]);
#ifndef GFLAGS_GFLAGS_H_
namespace gflags = google;
#endif
gflags::SetUsageMessage("Convert a set of images to the leveldb/lmdb\n"
"format used as input for Caffe.\n"
"Usage:\n"
" convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME\n"
"The ImageNet dataset for the training demo is at\n"
" http://www.image-net.org/download-images\n");
gflags::ParseCommandLineFlags(&argc, &argv, true);
if (argc != 4) {
gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/convert_imageset");
return 1;
}
bool is_color = !FLAGS_gray;
std::ifstream infile(argv[2]);
std::vector<std::pair<string, int> > lines;
string filename;
int label;
while (infile >> filename >> label) {
lines.push_back(std::make_pair(filename, label));
}
if (FLAGS_shuffle) {
// randomly shuffle data
LOG(INFO) << "Shuffling data";
shuffle(lines.begin(), lines.end());
}
LOG(INFO) << "A total of " << lines.size() << " images.";
const string& db_backend = FLAGS_backend;
const char* db_path = argv[3];
int resize_height = std::max<int>(0, FLAGS_resize_height);
int resize_width = std::max<int>(0, FLAGS_resize_width);
// Open new db
// lmdb
MDB_env *mdb_env;
MDB_dbi mdb_dbi;
MDB_val mdb_key, mdb_data;
MDB_txn *mdb_txn;
// leveldb
leveldb::DB* db;
leveldb::Options options;
options.error_if_exists = true;
options.create_if_missing = true;
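  // 268435456 bytes = 256 MB write buffer for leveldb.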
options.write_buffer_size = 268435456;
leveldb::WriteBatch* batch = NULL;
// Open db
if (db_backend == "leveldb") { // leveldb
LOG(INFO) << "Opening leveldb " << db_path;
leveldb::Status status = leveldb::DB::Open(
options, db_path, &db);
CHECK(status.ok()) << "Failed to open leveldb " << db_path
<< ". Is it already existing?";
batch = new leveldb::WriteBatch();
} else if (db_backend == "lmdb") { // lmdb
LOG(INFO) << "Opening lmdb " << db_path;
CHECK_EQ(mkdir(db_path, 0744), 0)
<< "mkdir " << db_path << "failed";
CHECK_EQ(mdb_env_create(&mdb_env), MDB_SUCCESS) << "mdb_env_create failed";
CHECK_EQ(mdb_env_set_mapsize(mdb_env, 1099511627776), MDB_SUCCESS) // 1TB
<< "mdb_env_set_mapsize failed";
CHECK_EQ(mdb_env_open(mdb_env, db_path, 0, 0664), MDB_SUCCESS)
<< "mdb_env_open failed";
CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS)
<< "mdb_txn_begin failed";
CHECK_EQ(mdb_open(mdb_txn, NULL, 0, &mdb_dbi), MDB_SUCCESS)
<< "mdb_open failed. Does the lmdb already exist?";
} else {
LOG(FATAL) << "Unknown db backend " << db_backend;
}
// Storing to db
string root_folder(argv[1]);
Datum datum;
int count = 0;
const int kMaxKeyLength = 256;
char key_cstr[kMaxKeyLength];
int data_size;
bool data_size_initialized = false;
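  // All images must yield Datum blobs of the same size; the first successfully
  // read image fixes the expected size, and later images are checked against it.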
for (int line_id = 0; line_id < lines.size(); ++line_id) {
if (!ReadImageToDatum(root_folder + lines[line_id].first,
lines[line_id].second, resize_height, resize_width, is_color, &datum)) {
continue;
}
if (!data_size_initialized) {
data_size = datum.channels() * datum.height() * datum.width();
data_size_initialized = true;
} else {
const string& data = datum.data();
CHECK_EQ(data.size(), data_size) << "Incorrect data field size "
<< data.size();
}
    // Build a sequential key: zero-padded line index followed by the image
    // path, so entries sort (and are read back) in insertion order.
snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id,
lines[line_id].first.c_str());
string value;
datum.SerializeToString(&value);
string keystr(key_cstr);
// Put in db
if (db_backend == "leveldb") { // leveldb
batch->Put(keystr, value);
} else if (db_backend == "lmdb") { // lmdb
mdb_data.mv_size = value.size();
mdb_data.mv_data = reinterpret_cast<void*>(&value[0]);
mdb_key.mv_size = keystr.size();
mdb_key.mv_data = reinterpret_cast<void*>(&keystr[0]);
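      // mdb_put (without MDB_RESERVE) copies the key and value into the
      // database, so pointing at the local string buffers here is safe.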
CHECK_EQ(mdb_put(mdb_txn, mdb_dbi, &mdb_key, &mdb_data, 0), MDB_SUCCESS)
<< "mdb_put failed";
} else {
LOG(FATAL) << "Unknown db backend " << db_backend;
}
if (++count % 1000 == 0) {
// Commit txn
if (db_backend == "leveldb") { // leveldb
db->Write(leveldb::WriteOptions(), batch);
delete batch;
batch = new leveldb::WriteBatch();
} else if (db_backend == "lmdb") { // lmdb
CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS)
<< "mdb_txn_commit failed";
CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS)
<< "mdb_txn_begin failed";
} else {
LOG(FATAL) << "Unknown db backend " << db_backend;
}
LOG(ERROR) << "Processed " << count << " files.";
}
}
// write the last batch
if (count % 1000 != 0) {
if (db_backend == "leveldb") { // leveldb
db->Write(leveldb::WriteOptions(), batch);
delete batch;
delete db;
} else if (db_backend == "lmdb") { // lmdb
CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS) << "mdb_txn_commit failed";
mdb_close(mdb_env, mdb_dbi);
mdb_env_close(mdb_env);
} else {
LOG(FATAL) << "Unknown db backend " << db_backend;
}
LOG(ERROR) << "Processed " << count << " files.";
}
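  // Note: when count is an exact multiple of 1000, the last batch has already
  // been committed inside the loop, so no data is lost; the db handles are
  // simply left for the OS to reclaim at process exit.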
return 0;
}