Refactor DataLayer using a new DataTransformer
Start refactoring the data layers to avoid duplicating data transformation
code. So far, only DataLayer has been converted.
arntanguy committed Aug 20, 2014
1 parent bf61d4f commit f6ffd8e
Showing 10 changed files with 267 additions and 134 deletions.
8 changes: 4 additions & 4 deletions include/caffe/data_layers.hpp
@@ -12,6 +12,7 @@

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/data_transformer.hpp"
#include "caffe/filler.hpp"
#include "caffe/internal_thread.hpp"
#include "caffe/layer.hpp"
@@ -24,12 +25,12 @@ namespace caffe {

// TODO: DataLayer, ImageDataLayer, and WindowDataLayer all have the
// same basic structure and a lot of duplicated code.

template <typename Dtype>
class DataLayer : public Layer<Dtype>, public InternalThread {
public:
explicit DataLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
: Layer<Dtype>(param),
data_transformer_(param.data_param().transform_param()) {}
virtual ~DataLayer();
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
@@ -53,11 +54,10 @@ class DataLayer : public Layer<Dtype>, public InternalThread {

virtual void CreatePrefetchThread();
virtual void JoinPrefetchThread();
virtual unsigned int PrefetchRand();
// The thread's function
virtual void InternalThreadEntry();

shared_ptr<Caffe::RNG> prefetch_rng_;
DataTransformer<Dtype> data_transformer_;

// LEVELDB
shared_ptr<leveldb::DB> db_;
55 changes: 55 additions & 0 deletions include/caffe/data_transformer.hpp
@@ -0,0 +1,55 @@
#ifndef CAFFE_DATA_TRANSFORMER_HPP
#define CAFFE_DATA_TRANSFORMER_HPP

#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {

/**
* @brief Applies common transformations to the input data, such as
* scaling, mirroring, subtracting the image mean...
*/
template <typename Dtype>
class DataTransformer {
public:
explicit DataTransformer(const TransformationParameter& param)
: param_(param) {
phase_ = Caffe::phase();
}
virtual ~DataTransformer() {}

void InitRand();

/**
* @brief Applies the transformation defined in the data layer's
* transform_param block to the data.
*
* @param batch_item_id
* Datum position within the batch. This is used to compute the
* writing position in the top blob's data
* @param datum
* Datum containing the data to be transformed.
* @param mean
* @param top_data
* This is meant to be the top blob's data. The transformed data will be
* written at the appropriate place within the blob's data.
*/
void Transform(const int batch_item_id, const Datum& datum,
const Dtype* mean, Dtype* transformed_data);

protected:
virtual unsigned int Rand();

// Transformation parameters
TransformationParameter param_;


shared_ptr<Caffe::RNG> rng_;
Caffe::Phase phase_;
};

} // namespace caffe

#endif  // CAFFE_DATA_TRANSFORMER_HPP

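As a side note, here is a minimal usage sketch of the new class, assuming a TransformationParameter, a Datum, a mean blob, and a prefetch blob are already set up; the variable names below are illustrative only and not part of this commit:

DataTransformer<float> transformer(transform_param);
transformer.InitRand();  // seeds the RNG only for TRAIN-phase mirroring/cropping

const float* mean = data_mean.cpu_data();             // mean values, one per datum element
float* top_data = prefetch_data.mutable_cpu_data();   // destination for the whole batch
for (int item_id = 0; item_id < batch_size; ++item_id) {
  // Writes the transformed datum into the item_id-th slot of top_data.
  transformer.Transform(item_id, datum, mean, top_data);
}
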
114 changes: 114 additions & 0 deletions src/caffe/data_transformer.cpp
@@ -0,0 +1,114 @@
#include <string>

#include "caffe/data_transformer.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/util/rng.hpp"

namespace caffe {

template<typename Dtype>
void DataTransformer<Dtype>::Transform(const int batch_item_id,
const Datum& datum,
const Dtype* mean,
Dtype* transformed_data) {

const string& data = datum.data();
const int channels = datum.channels();
const int height = datum.height();
const int width = datum.width();
const int size = datum.channels() * datum.height() * datum.width();

const int crop_size = param_.crop_size();
const bool mirror = param_.mirror();
const Dtype scale = param_.scale();



if (mirror && crop_size == 0) {
LOG(FATAL) << "Current implementation requires mirror and crop_size to be "
<< "set at the same time.";
}

if (crop_size) {
CHECK(data.size()) << "Image cropping only supports uint8 data";
int h_off, w_off;
// We only do random crop when we do training.
if (phase_ == Caffe::TRAIN) {
h_off = Rand() % (height - crop_size);
w_off = Rand() % (width - crop_size);
} else {
h_off = (height - crop_size) / 2;
w_off = (width - crop_size) / 2;
}
if (mirror && Rand() % 2) {
// Copy mirrored version
for (int c = 0; c < channels; ++c) {
for (int h = 0; h < crop_size; ++h) {
for (int w = 0; w < crop_size; ++w) {
int data_index = (c * height + h + h_off) * width + w + w_off;
int top_index = ((batch_item_id * channels + c) * crop_size + h)
* crop_size + (crop_size - 1 - w);
Dtype datum_element =
static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
transformed_data[top_index] =
(datum_element - mean[data_index]) * scale;
}
}
}
} else {
// Normal copy
for (int c = 0; c < channels; ++c) {
for (int h = 0; h < crop_size; ++h) {
for (int w = 0; w < crop_size; ++w) {
int top_index = ((batch_item_id * channels + c) * crop_size + h)
* crop_size + w;
int data_index = (c * height + h + h_off) * width + w + w_off;
Dtype datum_element =
static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
transformed_data[top_index] =
(datum_element - mean[data_index]) * scale;
}
}
}
}
} else {
// We prefer to use data() first, and then fall back to float_data().
if (data.size()) {
for (int j = 0; j < size; ++j) {
Dtype datum_element =
static_cast<Dtype>(static_cast<uint8_t>(data[j]));
transformed_data[j + batch_item_id * size] =
(datum_element - mean[j]) * scale;
}
} else {
for (int j = 0; j < size; ++j) {
transformed_data[j + batch_item_id * size] =
(datum.float_data(j) - mean[j]) * scale;
}
}
}
}

template <typename Dtype>
void DataTransformer<Dtype>::InitRand() {
const bool needs_rand = (phase_ == Caffe::TRAIN) &&
(param_.mirror() || param_.crop_size());
if (needs_rand) {
const unsigned int rng_seed = caffe_rng_rand();
rng_.reset(new Caffe::RNG(rng_seed));
} else {
rng_.reset();
}
}

template <typename Dtype>
unsigned int DataTransformer<Dtype>::Rand() {
CHECK(rng_);
caffe::rng_t* rng =
static_cast<caffe::rng_t*>(rng_->generator());
return (*rng)();
}

INSTANTIATE_CLASS(DataTransformer);

} // namespace caffe
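
The cropped branch above writes into a dense batch × channels × crop_size × crop_size blob. A quick sanity check of the index arithmetic, with hypothetical values:

// top_index = ((batch_item_id * channels + c) * crop_size + h) * crop_size + w
// e.g. batch_item_id = 1, channels = 3, crop_size = 227, c = h = w = 0:
//   ((1 * 3 + 0) * 227 + 0) * 227 + 0 = 154587
// which is exactly one item (3 * 227 * 227 elements) past the start of the blob.
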
98 changes: 10 additions & 88 deletions src/caffe/layers/data_layer.cpp
@@ -23,20 +23,8 @@ void DataLayer<Dtype>::InternalThreadEntry() {
if (output_labels_) {
top_label = prefetch_label_.mutable_cpu_data();
}
const Dtype scale = this->layer_param_.data_param().scale();
const int batch_size = this->layer_param_.data_param().batch_size();
const int crop_size = this->layer_param_.data_param().crop_size();
const bool mirror = this->layer_param_.data_param().mirror();

if (mirror && crop_size == 0) {
LOG(FATAL) << "Current implementation requires mirror and crop_size to be "
<< "set at the same time.";
}
// datum scales
const int channels = datum_channels_;
const int height = datum_height_;
const int width = datum_width_;
const int size = datum_size_;
const Dtype* mean = data_mean_.cpu_data();
for (int item_id = 0; item_id < batch_size; ++item_id) {
// get a blob
@@ -56,66 +44,13 @@ void DataLayer<Dtype>::InternalThreadEntry() {
LOG(FATAL) << "Unknown database backend";
}

const string& data = datum.data();
if (crop_size) {
CHECK(data.size()) << "Image cropping only support uint8 data";
int h_off, w_off;
// We only do random crop when we do training.
if (phase_ == Caffe::TRAIN) {
h_off = PrefetchRand() % (height - crop_size);
w_off = PrefetchRand() % (width - crop_size);
} else {
h_off = (height - crop_size) / 2;
w_off = (width - crop_size) / 2;
}
if (mirror && PrefetchRand() % 2) {
// Copy mirrored version
for (int c = 0; c < channels; ++c) {
for (int h = 0; h < crop_size; ++h) {
for (int w = 0; w < crop_size; ++w) {
int top_index = ((item_id * channels + c) * crop_size + h)
* crop_size + (crop_size - 1 - w);
int data_index = (c * height + h + h_off) * width + w + w_off;
Dtype datum_element =
static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
top_data[top_index] = (datum_element - mean[data_index]) * scale;
}
}
}
} else {
// Normal copy
for (int c = 0; c < channels; ++c) {
for (int h = 0; h < crop_size; ++h) {
for (int w = 0; w < crop_size; ++w) {
int top_index = ((item_id * channels + c) * crop_size + h)
* crop_size + w;
int data_index = (c * height + h + h_off) * width + w + w_off;
Dtype datum_element =
static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
top_data[top_index] = (datum_element - mean[data_index]) * scale;
}
}
}
}
} else {
// we will prefer to use data() first, and then try float_data()
if (data.size()) {
for (int j = 0; j < size; ++j) {
Dtype datum_element =
static_cast<Dtype>(static_cast<uint8_t>(data[j]));
top_data[item_id * size + j] = (datum_element - mean[j]) * scale;
}
} else {
for (int j = 0; j < size; ++j) {
top_data[item_id * size + j] =
(datum.float_data(j) - mean[j]) * scale;
}
}
}
// Apply data transformations (mirror, scale, crop...)
data_transformer_.Transform(item_id, datum, mean, top_data);

if (output_labels_) {
top_label[item_id] = datum.label();
}

// go to the next iter
switch (this->layer_param_.data_param().backend()) {
case DataParameter_DB_LEVELDB:
@@ -244,7 +179,7 @@ void DataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
}

// image
int crop_size = this->layer_param_.data_param().crop_size();
int crop_size = this->layer_param_.data_param().transform_param().crop_size();
if (crop_size > 0) {
(*top)[0]->Reshape(this->layer_param_.data_param().batch_size(),
datum.channels(), crop_size, crop_size);
@@ -274,8 +209,9 @@ void DataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
CHECK_GT(datum_height_, crop_size);
CHECK_GT(datum_width_, crop_size);
// check if we want to have mean
if (this->layer_param_.data_param().has_mean_file()) {
const string& mean_file = this->layer_param_.data_param().mean_file();
if (this->layer_param_.data_param().transform_param().has_mean_file()) {
const string& mean_file =
this->layer_param_.data_param().transform_param().mean_file();
LOG(INFO) << "Loading mean file from " << mean_file;
BlobProto blob_proto;
ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto);
@@ -305,15 +241,9 @@ void DataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
template <typename Dtype>
void DataLayer<Dtype>::CreatePrefetchThread() {
phase_ = Caffe::phase();
const bool prefetch_needs_rand = (phase_ == Caffe::TRAIN) &&
(this->layer_param_.data_param().mirror() ||
this->layer_param_.data_param().crop_size());
if (prefetch_needs_rand) {
const unsigned int prefetch_rng_seed = caffe_rng_rand();
prefetch_rng_.reset(new Caffe::RNG(prefetch_rng_seed));
} else {
prefetch_rng_.reset();
}

data_transformer_.InitRand();

CHECK(!StartInternalThread()) << "Pthread execution failed";
}

@@ -322,14 +252,6 @@ void DataLayer<Dtype>::JoinPrefetchThread() {
CHECK(!WaitForInternalThreadToExit()) << "Pthread joining failed";
}

template <typename Dtype>
unsigned int DataLayer<Dtype>::PrefetchRand() {
CHECK(prefetch_rng_);
caffe::rng_t* prefetch_rng =
static_cast<caffe::rng_t*>(prefetch_rng_->generator());
return (*prefetch_rng)();
}

template <typename Dtype>
void DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {

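The LayerSetUp changes above also show the new parameter layout: the transformation settings now live in a TransformationParameter message nested inside DataParameter, instead of as flat data_param fields. A hedged sketch of the resulting accessor pattern inside a layer method (local names are illustrative only):

const TransformationParameter& transform_param =
    this->layer_param_.data_param().transform_param();
const int crop_size = transform_param.crop_size();
const bool mirror = transform_param.mirror();
const float scale = transform_param.scale();
if (transform_param.has_mean_file()) {
  const string& mean_file = transform_param.mean_file();
  // The mean blob is then read from mean_file, as in LayerSetUp above.
}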