Commit

Merge remote-tracking branch 'upstream/master'
Tile Layer BVLC#2083
ctrevino committed Aug 26, 2015
2 parents 331f497 + 990835f commit e303c28
Showing 15 changed files with 877 additions and 25 deletions.
67 changes: 67 additions & 0 deletions include/caffe/common_layers.hpp
@@ -186,6 +186,44 @@ class EltwiseLayer : public Layer<Dtype> {
bool stable_prod_grad_;
};

/**
* @brief A layer for learning "embeddings" of one-hot vector input.
* Equivalent to an InnerProductLayer with one-hot vectors as input, but
* for efficiency the input is the "hot" index of each column itself.
*
* TODO(dox): thorough documentation for Forward, Backward, and proto params.
*/
template <typename Dtype>
class EmbedLayer : public Layer<Dtype> {
public:
explicit EmbedLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);

virtual inline const char* type() const { return "Embed"; }
virtual inline int ExactNumBottomBlobs() const { return 1; }
virtual inline int ExactNumTopBlobs() const { return 1; }

protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

int M_;
int K_;
int N_;
bool bias_term_;
Blob<Dtype> bias_multiplier_;
};
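The doc comment above describes Forward as an InnerProduct over implicit one-hot vectors: multiplying a one-hot row vector of length K by a K x N weight matrix simply selects one row, so the layer can skip the multiply and copy that row directly. A minimal standalone sketch of the equivalence (plain C++ with illustrative names, not the Caffe API):

#include <cassert>
#include <vector>

// Sketch: embedding lookup for a single input index, with the K x N weight
// matrix stored row-major. This returns exactly what an InnerProduct with a
// one-hot input of length K would compute, minus the wasted multiplies.
std::vector<float> embed_forward_one(const std::vector<float>& weight,  // K * N
                                     int K, int N, int index) {
  assert(index >= 0 && index < K);  // mirrors the DCHECKs in Forward_cpu below
  return std::vector<float>(weight.begin() + index * N,
                            weight.begin() + (index + 1) * N);
}

EmbedLayer::Forward_cpu in embed_layer.cpp further down does this per input element with caffe_copy, then adds an optional bias through a GEMM.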

/**
* @brief Takes two+ Blobs, interprets the last Blob as a selector and
* filters the remaining Blobs accordingly using the selector data (0 means that
@@ -624,6 +662,35 @@ class SliceLayer : public Layer<Dtype> {
vector<int> slice_point_;
};

/**
* @brief Copy a Blob along specified dimensions.
*/
template <typename Dtype>
class TileLayer : public Layer<Dtype> {
public:
explicit TileLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);

virtual inline const char* type() const { return "Tile"; }
virtual inline int ExactNumBottomBlobs() const { return 1; }
virtual inline int ExactNumTopBlobs() const { return 1; }

protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);

virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

unsigned int axis_, tiles_, outer_dim_, inner_dim_;
};
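TileLayer's declaration only carries axis_, tiles_, outer_dim_, and inner_dim_; tile_layer.cpp and tile_layer.cu are among the 15 changed files but are not expanded on this page. A hedged sketch of the forward pass these members suggest, assuming outer_dim_ counts the dimensions before axis_ and inner_dim_ counts the dimensions from axis_ onward (illustrative only, not the committed implementation):

#include <algorithm>

// Each contiguous inner block of the input is written `tiles` times in a row,
// so the output shape matches the input's except that the extent of the tiled
// axis is multiplied by `tiles`.
template <typename Dtype>
void tile_forward_sketch(const Dtype* bottom_data, Dtype* top_data,
                         int outer_dim, int inner_dim, int tiles) {
  for (int i = 0; i < outer_dim; ++i) {
    for (int t = 0; t < tiles; ++t) {
      std::copy(bottom_data, bottom_data + inner_dim, top_data);  // caffe_copy in Caffe terms
      top_data += inner_dim;
    }
    bottom_data += inner_dim;
  }
}

The backward pass would correspondingly sum the tiles_ gradient copies of each block back into bottom_diff.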

} // namespace caffe

#endif // CAFFE_COMMON_LAYERS_HPP_
11 changes: 8 additions & 3 deletions include/caffe/test/test_gradient_check_util.hpp
@@ -45,6 +45,10 @@ class GradientChecker {
void CheckGradientEltwise(Layer<Dtype>* layer,
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);

// Checks the gradient of a single output with respect to particular input
// blob(s). If check_bottom = i >= 0, check only the ith bottom Blob.
// If check_bottom == -1, check everything -- all bottom Blobs and all
// param Blobs. Otherwise (if check_bottom < -1), check only param Blobs.
void CheckGradientSingle(Layer<Dtype>* layer,
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top,
int check_bottom, int top_id, int top_data_id, bool element_wise = false);
@@ -83,21 +87,22 @@ void GradientChecker<Dtype>::CheckGradientSingle(Layer<Dtype>* layer,
// First, figure out what blobs we need to check against, and zero init
// parameter blobs.
vector<Blob<Dtype>*> blobs_to_check;
vector<bool> propagate_down(bottom.size(), check_bottom < 0);
vector<bool> propagate_down(bottom.size(), check_bottom == -1);
for (int i = 0; i < layer->blobs().size(); ++i) {
Blob<Dtype>* blob = layer->blobs()[i].get();
caffe_set(blob->count(), static_cast<Dtype>(0), blob->mutable_cpu_diff());
blobs_to_check.push_back(blob);
}
if (check_bottom < 0) {
if (check_bottom == -1) {
for (int i = 0; i < bottom.size(); ++i) {
blobs_to_check.push_back(bottom[i]);
}
} else {
} else if (check_bottom >= 0) {
CHECK_LT(check_bottom, bottom.size());
blobs_to_check.push_back(bottom[check_bottom]);
propagate_down[check_bottom] = true;
}
CHECK_GT(blobs_to_check.size(), 0) << "No blobs to check.";
// Compute the gradient analytically using Backward
Caffe::set_random_seed(seed_);
// Ignore the loss from the layer (it's just the weighted sum of the losses
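The added comment defines the check_bottom convention: i >= 0 checks only the ith bottom Blob, -1 checks all bottom and parameter Blobs, and anything below -1 checks parameter Blobs only. The last mode matters for layers that cannot backpropagate to their input, such as the new EmbedLayer below, whose Backward_cpu CHECKs that propagate_down[0] is false. A hedged sketch of how a test might use it; the blob setup and the EmbedParameter setters are assumptions (the proto change is not expanded on this page), and only the checker calls come from this header. In practice a test would loop this over every top element rather than check a single one.

// Inside a Caffe layer test (namespace caffe; Dtype fixed to double for brevity).
typedef double Dtype;
Blob<Dtype> bottom_blob(vector<int>(1, 4));  // 4 indices into a 5-row table
Blob<Dtype> top_blob;
for (int i = 0; i < 4; ++i) { bottom_blob.mutable_cpu_data()[i] = i % 5; }
vector<Blob<Dtype>*> bottom_vec(1, &bottom_blob), top_vec(1, &top_blob);

LayerParameter layer_param;  // EmbedParameter fields assumed from embed_layer.cpp below
layer_param.mutable_embed_param()->set_num_output(10);
layer_param.mutable_embed_param()->set_input_dim(5);
layer_param.mutable_embed_param()->mutable_weight_filler()->set_type("uniform");
layer_param.mutable_embed_param()->mutable_weight_filler()->set_min(-1);
layer_param.mutable_embed_param()->mutable_weight_filler()->set_max(1);

EmbedLayer<Dtype> layer(layer_param);
layer.SetUp(bottom_vec, top_vec);
GradientChecker<Dtype> checker(1e-2, 1e-3);  // stepsize, threshold
// check_bottom = -2 (< -1): numerically check only the parameter Blobs; the
// integer-index bottom has no defined gradient.
checker.CheckGradientSingle(&layer, bottom_vec, top_vec,
    -2 /* check_bottom */, 0 /* top_id */, 0 /* top_data_id */);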
35 changes: 35 additions & 0 deletions include/caffe/util/gpu_util.cuh
@@ -0,0 +1,35 @@
#ifndef CAFFE_UTIL_GPU_UTIL_H_
#define CAFFE_UTIL_GPU_UTIL_H_

namespace caffe {

template <typename Dtype>
inline __device__ Dtype caffe_gpu_atomic_add(const Dtype val, Dtype* address);

template <>
inline __device__
float caffe_gpu_atomic_add(const float val, float* address) {
return atomicAdd(address, val);
}

// double atomicAdd implementation taken from:
// http://docs.nvidia.com/cuda/cuda-c-programming-guide/#axzz3PVCpVsEG
template <>
inline __device__
double caffe_gpu_atomic_add(const double val, double* address) {
unsigned long long int* address_as_ull = // NOLINT(runtime/int)
// NOLINT_NEXT_LINE(runtime/int)
reinterpret_cast<unsigned long long int*>(address);
unsigned long long int old = *address_as_ull; // NOLINT(runtime/int)
unsigned long long int assumed; // NOLINT(runtime/int)
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed,
__double_as_longlong(val + __longlong_as_double(assumed)));
} while (assumed != old);
return __longlong_as_double(old);
}

} // namespace caffe

#endif // CAFFE_UTIL_GPU_UTIL_H_
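caffe_gpu_atomic_add gives device code a single overloaded name for float and double accumulation: atomicAdd covers float natively, while the double overload emulates it with the atomicCAS loop from the linked CUDA programming guide. A minimal sketch of the scatter-accumulate pattern this enables, for example when several GPU threads must add into the same embedding-weight row (illustrative kernel, not the committed embed_layer.cu):

#include "caffe/util/gpu_util.cuh"

namespace caffe {

// For each flattened (n, j), add top_diff[n * N + j] into
// weight_diff[indices[n] * N + j]. Distinct n can map to the same row, so a
// plain += would race; caffe_gpu_atomic_add makes the accumulation safe.
template <typename Dtype>
__global__ void ScatterAddRows(const int nthreads, const Dtype* top_diff,
    const Dtype* indices, const int N, Dtype* weight_diff) {
  for (int t = blockIdx.x * blockDim.x + threadIdx.x; t < nthreads;
       t += blockDim.x * gridDim.x) {
    const int n = t / N;
    const int j = t % N;
    const int row = static_cast<int>(indices[n]);
    caffe_gpu_atomic_add(top_diff[t], weight_diff + row * N + j);
  }
}

}  // namespace caffe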
13 changes: 7 additions & 6 deletions src/caffe/layers/concat_layer.cpp
@@ -76,13 +76,14 @@ void ConcatLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
int offset_concat_axis = 0;
const int top_concat_axis = top[0]->shape(concat_axis_);
for (int i = 0; i < bottom.size(); ++i) {
if (!propagate_down[i]) { continue; }
Dtype* bottom_diff = bottom[i]->mutable_cpu_diff();
const int bottom_concat_axis = bottom[i]->shape(concat_axis_);
for (int n = 0; n < num_concats_; ++n) {
caffe_copy(bottom_concat_axis * concat_input_size_, top_diff +
(n * top_concat_axis + offset_concat_axis) * concat_input_size_,
bottom_diff + n * bottom_concat_axis * concat_input_size_);
if (propagate_down[i]) {
Dtype* bottom_diff = bottom[i]->mutable_cpu_diff();
for (int n = 0; n < num_concats_; ++n) {
caffe_copy(bottom_concat_axis * concat_input_size_, top_diff +
(n * top_concat_axis + offset_concat_axis) * concat_input_size_,
bottom_diff + n * bottom_concat_axis * concat_input_size_);
}
}
offset_concat_axis += bottom_concat_axis;
}
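The restructuring above is a behavior fix, not just style: offset_concat_axis must advance for every bottom Blob, including those with propagate_down[i] == false, because each Blob still occupies its slice of top. The old early continue skipped that accumulation, so any later Blob read the wrong region of top_diff. A tiny worked sketch of the bookkeeping (plain C++, illustrative sizes):

#include <cstdio>

int main() {
  // Three bottoms of channel extent 2, 3, 4 concatenated into a top of extent 9.
  const int sizes[3] = {2, 3, 4};
  const bool prop[3] = {true, false, true};
  int offset = 0;
  for (int i = 0; i < 3; ++i) {
    if (prop[i]) {
      std::printf("bottom %d reads top_diff[%d, %d)\n", i, offset, offset + sizes[i]);
    }
    offset += sizes[i];  // must advance even when prop[i] is false (2 -> 5 -> 9)
  }
  // With the old early `continue`, bottom 2 would have read from offset 2, not 5.
  return 0;
}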
17 changes: 9 additions & 8 deletions src/caffe/layers/concat_layer.cu
@@ -53,15 +53,16 @@ void ConcatLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const int top_concat_axis = top[0]->shape(concat_axis_);
const bool kForward = false;
for (int i = 0; i < bottom.size(); ++i) {
if (!propagate_down[i]) { continue; }
Dtype* bottom_diff = bottom[i]->mutable_gpu_diff();
const int bottom_concat_axis = bottom[i]->shape(concat_axis_);
const int bottom_concat_size = bottom_concat_axis * concat_input_size_;
const int nthreads = bottom_concat_size * num_concats_;
Concat<Dtype> // NOLINT_NEXT_LINE(whitespace/operators)
<<<CAFFE_GET_BLOCKS(nthreads), CAFFE_CUDA_NUM_THREADS>>>(
nthreads, top_diff, kForward, num_concats_, concat_input_size_,
top_concat_axis, bottom_concat_axis, offset_concat_axis, bottom_diff);
if (propagate_down[i]) {
Dtype* bottom_diff = bottom[i]->mutable_gpu_diff();
const int bottom_concat_size = bottom_concat_axis * concat_input_size_;
const int nthreads = bottom_concat_size * num_concats_;
Concat<Dtype> // NOLINT_NEXT_LINE(whitespace/operators)
<<<CAFFE_GET_BLOCKS(nthreads), CAFFE_CUDA_NUM_THREADS>>>(
nthreads, top_diff, kForward, num_concats_, concat_input_size_,
top_concat_axis, bottom_concat_axis, offset_concat_axis, bottom_diff);
}
offset_concat_axis += bottom_concat_axis;
}
}
122 changes: 122 additions & 0 deletions src/caffe/layers/embed_layer.cpp
@@ -0,0 +1,122 @@
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/common_layers.hpp"
#include "caffe/filler.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

template <typename Dtype>
void EmbedLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
N_ = this->layer_param_.embed_param().num_output();
CHECK_GT(N_, 0) << "EmbedLayer num_output must be positive.";
K_ = this->layer_param_.embed_param().input_dim();
CHECK_GT(K_, 0) << "EmbedLayer input_dim must be positive.";
bias_term_ = this->layer_param_.embed_param().bias_term();
// Check if we need to set up the weights
if (this->blobs_.size() > 0) {
LOG(INFO) << "Skipping parameter initialization";
} else {
if (bias_term_) {
this->blobs_.resize(2);
} else {
this->blobs_.resize(1);
}
// Initialize the weights --
// transposed from InnerProductLayer for spatial locality.
vector<int> weight_shape(2);
weight_shape[0] = K_;
weight_shape[1] = N_;
this->blobs_[0].reset(new Blob<Dtype>(weight_shape));
// fill the weights
shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>(
this->layer_param_.embed_param().weight_filler()));
weight_filler->Fill(this->blobs_[0].get());
// If necessary, initialize and fill the bias term
if (bias_term_) {
vector<int> bias_shape(1, N_);
this->blobs_[1].reset(new Blob<Dtype>(bias_shape));
shared_ptr<Filler<Dtype> > bias_filler(GetFiller<Dtype>(
this->layer_param_.embed_param().bias_filler()));
bias_filler->Fill(this->blobs_[1].get());
}
} // parameter initialization
this->param_propagate_down_.resize(this->blobs_.size(), true);
}

template <typename Dtype>
void EmbedLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
// Figure out the dimensions
M_ = bottom[0]->count();
vector<int> top_shape = bottom[0]->shape();
top_shape.push_back(N_);
top[0]->Reshape(top_shape);
// Set up the bias multiplier
if (bias_term_) {
vector<int> bias_shape(1, M_);
bias_multiplier_.Reshape(bias_shape);
caffe_set(M_, Dtype(1), bias_multiplier_.mutable_cpu_data());
}
}

template <typename Dtype>
void EmbedLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
const Dtype* weight = this->blobs_[0]->cpu_data();
Dtype* top_data = top[0]->mutable_cpu_data();
int index;
for (int n = 0; n < M_; ++n) {
index = static_cast<int>(bottom_data[n]);
DCHECK_GE(index, 0);
DCHECK_LT(index, K_);
DCHECK_EQ(static_cast<Dtype>(index), bottom_data[n]) << "non-integer input";
caffe_copy(N_, weight + index * N_, top_data + n * N_);
}
if (bias_term_) {
const Dtype* bias = this->blobs_[1]->cpu_data();
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, Dtype(1),
bias_multiplier_.cpu_data(), bias, Dtype(1), top_data);
}
}

template <typename Dtype>
void EmbedLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
CHECK(!propagate_down[0]) << "Can't backpropagate to EmbedLayer input.";
if (this->param_propagate_down_[0]) {
const Dtype* top_diff = top[0]->cpu_diff();
const Dtype* bottom_data = bottom[0]->cpu_data();
// Gradient with respect to weight
Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff();
int index;
for (int n = 0; n < M_; ++n) {
index = static_cast<int>(bottom_data[n]);
DCHECK_GE(index, 0);
DCHECK_LT(index, K_);
DCHECK_EQ(static_cast<Dtype>(index), bottom_data[n])
<< "non-integer input";
caffe_axpy(N_, Dtype(1), top_diff + n * N_, weight_diff + index * N_);
}
}
if (bias_term_ && this->param_propagate_down_[1]) {
const Dtype* top_diff = top[0]->cpu_diff();
Dtype* bias_diff = this->blobs_[1]->mutable_cpu_diff();
caffe_cpu_gemv<Dtype>(CblasTrans, M_, N_, Dtype(1), top_diff,
bias_multiplier_.cpu_data(), Dtype(1), bias_diff);
}
}

#ifdef CPU_ONLY
STUB_GPU(EmbedLayer);
#endif

INSTANTIATE_CLASS(EmbedLayer);
REGISTER_LAYER_CLASS(Embed);

} // namespace caffe
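Two details of the CPU implementation above are easy to miss. Reshape flattens every index: M_ is bottom[0]->count(), so a bottom of shape (batch, length) yields a top of shape (batch, length, N_). And the bias is applied through a GEMM against bias_multiplier_, a length-M_ vector of ones; a plain-C++ sketch of what that GEMM computes (illustrative, not the Caffe API):

// Equivalent of caffe_cpu_gemm(NoTrans, NoTrans, M, N, 1, 1, ones, bias, 1, top):
// top (M x N) += ones (M x 1) * bias (1 x N), i.e. every output row gets the
// same bias vector added.
void add_bias_sketch(float* top, const float* bias, int M, int N) {
  for (int m = 0; m < M; ++m) {
    for (int j = 0; j < N; ++j) {
      top[m * N + j] += bias[j];
    }
  }
}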