-
Notifications
You must be signed in to change notification settings - Fork 18.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor DataLayer using a new DataTransformer
Start the refactoring of the datalayers to avoid data transformation code duplication. So far, only DataLayer has been done.
- Loading branch information
Showing
10 changed files
with
267 additions
and
134 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
#ifndef CAFFE_DATA_TRANSFORMER_HPP | ||
#define CAFFE_DATA_TRANSFORMER_HPP | ||
|
||
#include "caffe/common.hpp" | ||
#include "caffe/proto/caffe.pb.h" | ||
|
||
namespace caffe { | ||
|
||
/** | ||
* @brief Applies common transformations to the input data, such as | ||
* scaling, mirroring, substracting the image mean... | ||
*/ | ||
template <typename Dtype> | ||
class DataTransformer { | ||
public: | ||
explicit DataTransformer(const TransformationParameter& param) | ||
: param_(param) { | ||
phase_ = Caffe::phase(); | ||
} | ||
virtual ~DataTransformer() {} | ||
|
||
void InitRand(); | ||
|
||
/** | ||
* @brief Applies the transformation defined in the data layer's | ||
* transform_param block to the data. | ||
* | ||
* @param batch_item_id | ||
* Datum position within the batch. This is used to compute the | ||
* writing position in the top blob's data | ||
* @param datum | ||
* Datum containing the data to be transformed. | ||
* @param mean | ||
* @param top_data | ||
* This is meant to be the top blob's data. The transformed data will be | ||
* written at the appropriate place within the blob's data. | ||
*/ | ||
void Transform(const int batch_item_id, const Datum& datum, | ||
const Dtype* mean, Dtype* transformed_data); | ||
|
||
protected: | ||
virtual unsigned int Rand(); | ||
|
||
// Tranformation parameters | ||
TransformationParameter param_; | ||
|
||
|
||
shared_ptr<Caffe::RNG> rng_; | ||
Caffe::Phase phase_; | ||
}; | ||
|
||
} // namespace caffe | ||
|
||
#endif // CAFFE_DATA_TRANSFORMER_HPP_ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
#include <string> | ||
|
||
#include "caffe/data_transformer.hpp" | ||
#include "caffe/util/math_functions.hpp" | ||
#include "caffe/util/rng.hpp" | ||
|
||
namespace caffe { | ||
|
||
template<typename Dtype> | ||
void DataTransformer<Dtype>::Transform(const int batch_item_id, | ||
const Datum& datum, | ||
const Dtype* mean, | ||
Dtype* transformed_data) { | ||
|
||
const string& data = datum.data(); | ||
const int channels = datum.channels(); | ||
const int height = datum.height(); | ||
const int width = datum.width(); | ||
const int size = datum.channels() * datum.height() * datum.width(); | ||
|
||
const int crop_size = param_.crop_size(); | ||
const bool mirror = param_.mirror(); | ||
const Dtype scale = param_.scale(); | ||
|
||
|
||
|
||
if (mirror && crop_size == 0) { | ||
LOG(FATAL) << "Current implementation requires mirror and crop_size to be " | ||
<< "set at the same time."; | ||
} | ||
|
||
if (crop_size) { | ||
CHECK(data.size()) << "Image cropping only support uint8 data"; | ||
int h_off, w_off; | ||
// We only do random crop when we do training. | ||
if (phase_ == Caffe::TRAIN) { | ||
h_off = Rand() % (height - crop_size); | ||
w_off = Rand() % (width - crop_size); | ||
} else { | ||
h_off = (height - crop_size) / 2; | ||
w_off = (width - crop_size) / 2; | ||
} | ||
if (mirror && Rand() % 2) { | ||
// Copy mirrored version | ||
for (int c = 0; c < channels; ++c) { | ||
for (int h = 0; h < crop_size; ++h) { | ||
for (int w = 0; w < crop_size; ++w) { | ||
int data_index = (c * height + h + h_off) * width + w + w_off; | ||
int top_index = ((batch_item_id * channels + c) * crop_size + h) | ||
* crop_size + (crop_size - 1 - w); | ||
Dtype datum_element = | ||
static_cast<Dtype>(static_cast<uint8_t>(data[data_index])); | ||
transformed_data[top_index] = | ||
(datum_element - mean[data_index]) * scale; | ||
} | ||
} | ||
} | ||
} else { | ||
// Normal copy | ||
for (int c = 0; c < channels; ++c) { | ||
for (int h = 0; h < crop_size; ++h) { | ||
for (int w = 0; w < crop_size; ++w) { | ||
int top_index = ((batch_item_id * channels + c) * crop_size + h) | ||
* crop_size + w; | ||
int data_index = (c * height + h + h_off) * width + w + w_off; | ||
Dtype datum_element = | ||
static_cast<Dtype>(static_cast<uint8_t>(data[data_index])); | ||
transformed_data[top_index] = | ||
(datum_element - mean[data_index]) * scale; | ||
} | ||
} | ||
} | ||
} | ||
} else { | ||
// we will prefer to use data() first, and then try float_data() | ||
if (data.size()) { | ||
for (int j = 0; j < size; ++j) { | ||
Dtype datum_element = | ||
static_cast<Dtype>(static_cast<uint8_t>(data[j])); | ||
transformed_data[j + batch_item_id * size] = | ||
(datum_element - mean[j]) * scale; | ||
} | ||
} else { | ||
for (int j = 0; j < size; ++j) { | ||
transformed_data[j + batch_item_id * size] = | ||
(datum.float_data(j) - mean[j]) * scale; | ||
} | ||
} | ||
} | ||
} | ||
|
||
template <typename Dtype> | ||
void DataTransformer<Dtype>::InitRand() { | ||
const bool needs_rand = (phase_ == Caffe::TRAIN) && | ||
(param_.mirror() || param_.crop_size()); | ||
if (needs_rand) { | ||
const unsigned int rng_seed = caffe_rng_rand(); | ||
rng_.reset(new Caffe::RNG(rng_seed)); | ||
} else { | ||
rng_.reset(); | ||
} | ||
} | ||
|
||
template <typename Dtype> | ||
unsigned int DataTransformer<Dtype>::Rand() { | ||
CHECK(rng_); | ||
caffe::rng_t* rng = | ||
static_cast<caffe::rng_t*>(rng_->generator()); | ||
return (*rng)(); | ||
} | ||
|
||
INSTANTIATE_CLASS(DataTransformer); | ||
|
||
} // namespace caffe |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.