Merge pull request #2410 from sguada/datum_transform
Datum transform
jeffdonahue committed May 30, 2015
2 parents 7e98145 + e048b17 commit 8b05a02
Showing 7 changed files with 205 additions and 99 deletions.
36 changes: 36 additions & 0 deletions include/caffe/data_transformer.hpp
@@ -62,6 +62,7 @@ class DataTransformer {
*/
void Transform(const vector<cv::Mat> & mat_vector,
Blob<Dtype>* transformed_blob);

/**
* @brief Applies the transformation defined in the data layer's
* transform_param block to a cv::Mat
@@ -87,6 +88,41 @@
*/
void Transform(Blob<Dtype>* input_blob, Blob<Dtype>* transformed_blob);

/**
* @brief Infers the shape the transformed blob will have when
* the transformation is applied to the data.
*
* @param datum
* Datum containing the data to be transformed.
*/
vector<int> InferBlobShape(const Datum& datum);
/**
* @brief Infers the shape the transformed blob will have when
* the transformation is applied to the data.
* It uses the first element to infer the shape of the blob.
*
* @param datum_vector
* A vector of Datum containing the data to be transformed.
*/
vector<int> InferBlobShape(const vector<Datum> & datum_vector);
/**
* @brief Infers the shape the transformed blob will have when
* the transformation is applied to the data.
* It uses the first element to infer the shape of the blob.
*
* @param mat_vector
* A vector of Mat containing the data to be transformed.
*/
vector<int> InferBlobShape(const vector<cv::Mat> & mat_vector);
/**
* @brief Infers the shape the transformed blob will have when
* the transformation is applied to the data.
*
* @param cv_img
* cv::Mat containing the data to be transformed.
*/
vector<int> InferBlobShape(const cv::Mat& cv_img);

protected:
/**
* @brief Generates a random integer from Uniform({0, 1, ..., n-1}).
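
The InferBlobShape overloads declared above let a data layer size its top blob before any transformation runs. Below is a minimal usage sketch, not code from this commit: the helper name TransformOne is hypothetical, and it assumes `param` is a populated TransformationParameter and `datum` a populated Datum.

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/data_transformer.hpp"
#include "caffe/proto/caffe.pb.h"

// Hypothetical helper: size the destination blob with InferBlobShape, then
// transform the datum into it.
void TransformOne(const caffe::TransformationParameter& param,
                  const caffe::Datum& datum,
                  caffe::Blob<float>* transformed) {
  caffe::DataTransformer<float> transformer(param, caffe::TEST);
  transformer.InitRand();  // needed before Transform when mirroring is enabled
  // Shape is {1, channels, crop_size or height, crop_size or width}.
  transformed->Reshape(transformer.InferBlobShape(datum));
  transformer.Transform(datum, transformed);
}

For a batch, InferBlobShape(datum_vector) returns the same shape with shape[0] set to the vector size, so the identical pattern sizes a batched top blob.
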
116 changes: 112 additions & 4 deletions src/caffe/data_transformer.cpp
@@ -125,10 +125,31 @@ void DataTransformer<Dtype>::Transform(const Datum& datum,
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const Datum& datum,
Blob<Dtype>* transformed_blob) {
// If datum is encoded, decode it and transform the resulting cv::Mat.
if (datum.encoded()) {
CHECK(!(param_.force_color() && param_.force_gray()))
<< "cannot set both force_color and force_gray";
cv::Mat cv_img;
if (param_.force_color() || param_.force_gray()) {
// If force_color is set, decode in color; otherwise decode in gray.
cv_img = DecodeDatumToCVMat(datum, param_.force_color());
} else {
cv_img = DecodeDatumToCVMatNative(datum);
}
// Transform the decoded cv::Mat into the blob.
return Transform(cv_img, transformed_blob);
} else {
if (param_.force_color() || param_.force_gray()) {
LOG(ERROR) << "force_color and force_gray only apply to encoded datum";
}
}

const int crop_size = param_.crop_size();
const int datum_channels = datum.channels();
const int datum_height = datum.height();
const int datum_width = datum.width();

// Check dimensions.
const int channels = transformed_blob->channels();
const int height = transformed_blob->height();
const int width = transformed_blob->width();
@@ -139,8 +160,6 @@ void DataTransformer<Dtype>::Transform(const Datum& datum,
CHECK_LE(width, datum_width);
CHECK_GE(num, 1);

const int crop_size = param_.crop_size();

if (crop_size) {
CHECK_EQ(crop_size, height);
CHECK_EQ(crop_size, width);
@@ -196,10 +215,12 @@ void DataTransformer<Dtype>::Transform(const vector<cv::Mat> & mat_vector,
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,
Blob<Dtype>* transformed_blob) {
const int crop_size = param_.crop_size();
const int img_channels = cv_img.channels();
const int img_height = cv_img.rows;
const int img_width = cv_img.cols;

// Check dimensions.
const int channels = transformed_blob->channels();
const int height = transformed_blob->height();
const int width = transformed_blob->width();
@@ -212,7 +233,6 @@ void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,

CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte";

const int crop_size = param_.crop_size();
const Dtype scale = param_.scale();
const bool do_mirror = param_.mirror() && Rand(2);
const bool has_mean_file = param_.has_mean_file();
@@ -297,11 +317,23 @@ void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,
template<typename Dtype>
void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob,
Blob<Dtype>* transformed_blob) {
const int crop_size = param_.crop_size();
const int input_num = input_blob->num();
const int input_channels = input_blob->channels();
const int input_height = input_blob->height();
const int input_width = input_blob->width();

if (transformed_blob->count() == 0) {
// Initialize transformed_blob with the right shape.
if (crop_size) {
transformed_blob->Reshape(input_num, input_channels,
crop_size, crop_size);
} else {
transformed_blob->Reshape(input_num, input_channels,
input_height, input_width);
}
}

const int num = transformed_blob->num();
const int channels = transformed_blob->channels();
const int height = transformed_blob->height();
@@ -313,7 +345,7 @@ void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob,
CHECK_GE(input_height, height);
CHECK_GE(input_width, width);

const int crop_size = param_.crop_size();

const Dtype scale = param_.scale();
const bool do_mirror = param_.mirror() && Rand(2);
const bool has_mean_file = param_.has_mean_file();
@@ -395,6 +427,82 @@ void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob,
}
}

template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(const Datum& datum) {
if (datum.encoded()) {
CHECK(!(param_.force_color() && param_.force_gray()))
<< "cannot set both force_color and force_gray";
cv::Mat cv_img;
if (param_.force_color() || param_.force_gray()) {
// If force_color is set, decode in color; otherwise decode in gray.
cv_img = DecodeDatumToCVMat(datum, param_.force_color());
} else {
cv_img = DecodeDatumToCVMatNative(datum);
}
// Infer the blob shape from the decoded cv::Mat.
return InferBlobShape(cv_img);
}

const int crop_size = param_.crop_size();
const int datum_channels = datum.channels();
const int datum_height = datum.height();
const int datum_width = datum.width();
// Check dimensions.
CHECK_GT(datum_channels, 0);
CHECK_GE(datum_height, crop_size);
CHECK_GE(datum_width, crop_size);
// Build BlobShape.
vector<int> shape(4);
shape[0] = 1;
shape[1] = datum_channels;
shape[2] = (crop_size)? crop_size: datum_height;
shape[3] = (crop_size)? crop_size: datum_width;
return shape;
}

template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(
const vector<Datum> & datum_vector) {
const int num = datum_vector.size();
CHECK_GT(num, 0) << "There is no datum to in the vector";
// Use first datum in the vector to InferBlobShape.
vector<int> shape = InferBlobShape(datum_vector[0]);
// Adjust num to the size of the vector.
shape[0] = num;
return shape;
}

template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(const cv::Mat& cv_img) {
const int crop_size = param_.crop_size();
const int img_channels = cv_img.channels();
const int img_height = cv_img.rows;
const int img_width = cv_img.cols;
// Check dimensions.
CHECK_GT(img_channels, 0);
CHECK_GE(img_height, crop_size);
CHECK_GE(img_width, crop_size);
// Build BlobShape.
vector<int> shape(4);
shape[0] = 1;
shape[1] = img_channels;
shape[2] = (crop_size)? crop_size: img_height;
shape[3] = (crop_size)? crop_size: img_width;
return shape;
}

template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(
const vector<cv::Mat> & mat_vector) {
const int num = mat_vector.size();
CHECK_GT(num, 0) << "There is no cv_img to in the vector";
// Use first cv_img in the vector to InferBlobShape.
vector<int> shape = InferBlobShape(mat_vector[0]);
// Adjust num to the size of the vector.
shape[0] = num;
return shape;
}

template <typename Dtype>
void DataTransformer<Dtype>::InitRand() {
const bool needs_rand = param_.mirror() ||
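
One behavioral change above deserves a note: Transform(Blob*, Blob*) now reshapes an empty destination blob itself, using crop_size when it is set. A short sketch of the new behavior, assuming `transformer` is a configured caffe::DataTransformer<float> and `input` is an already-filled N x C x H x W blob (both names are illustrative, not part of the commit):

caffe::Blob<float> transformed;               // left unshaped, so count() == 0
transformer.Transform(&input, &transformed);
// transformed is now N x C x crop_size x crop_size when crop_size > 0,
// otherwise N x C x H x W, and holds the transformed data.
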
10 changes: 6 additions & 4 deletions src/caffe/layers/base_data_layer.cpp
@@ -20,11 +20,11 @@ void BaseDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
} else {
output_labels_ = true;
}
// The subclasses should setup the size of bottom and top
DataLayerSetUp(bottom, top);
data_transformer_.reset(
new DataTransformer<Dtype>(transform_param_, this->phase_));
data_transformer_->InitRand();
// The subclasses should setup the size of bottom and top
DataLayerSetUp(bottom, top);
}

template <typename Dtype>
@@ -62,13 +62,15 @@ void BasePrefetchingDataLayer<Dtype>::Forward_cpu(
JoinPrefetchThread();
DLOG(INFO) << "Thread joined";
// Reshape to loaded data.
top[0]->Reshape(this->prefetch_data_.num(), this->prefetch_data_.channels(),
this->prefetch_data_.height(), this->prefetch_data_.width());
top[0]->ReshapeLike(prefetch_data_);
// Copy the data
caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(),
top[0]->mutable_cpu_data());
DLOG(INFO) << "Prefetch copied";
if (this->output_labels_) {
// Reshape to loaded labels.
top[1]->ReshapeLike(prefetch_label_);
// Copy the labels.
caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(),
top[1]->mutable_cpu_data());
}
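
ReshapeLike copies the full shape of the prefetch blob in one call, matching the four-argument Reshape it replaces, and the label top is now reshaped the same way before the copy. A small illustration with made-up dimensions:

caffe::Blob<float> prefetch_data(16, 3, 227, 227);   // e.g. one prefetched batch
caffe::Blob<float> top;
top.ReshapeLike(prefetch_data);                      // top becomes 16 x 3 x 227 x 227
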
6 changes: 4 additions & 2 deletions src/caffe/layers/base_data_layer.cu
@@ -10,12 +10,14 @@ void BasePrefetchingDataLayer<Dtype>::Forward_gpu(
// First, join the thread
JoinPrefetchThread();
// Reshape to loaded data.
top[0]->Reshape(this->prefetch_data_.num(), this->prefetch_data_.channels(),
this->prefetch_data_.height(), this->prefetch_data_.width());
top[0]->ReshapeLike(this->prefetch_data_);
// Copy the data
caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(),
top[0]->mutable_gpu_data());
if (this->output_labels_) {
// Reshape to loaded labels.
top[1]->ReshapeLike(prefetch_label_);
// Copy the labels.
caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(),
top[1]->mutable_gpu_data());
}
(Diffs for the remaining 3 changed files are not shown.)
