Fix bug related to Boost RNG #297

Merged 17 commits on Apr 9, 2014
include/caffe/common.hpp: 12 additions, 7 deletions

@@ -85,19 +85,21 @@ class Caffe {
    public:
     RNG();
     explicit RNG(unsigned int seed);
-    ~RNG();
-    RNG(const RNG&);
+    explicit RNG(const RNG&);
     RNG& operator=(const RNG&);
     const void* generator() const;
-    void* generator();
+    void set_generator(const void* other_rng);
    private:
     class Generator;
-    Generator* generator_;
+    shared_ptr<Generator> generator_;
  };
 
  // Getters for boost rng, curand, and cublas handles
- inline static RNG &rng_stream() {
-   return Get().random_generator_;
+ inline static const RNG& rng_stream() {
+   if (!Get().random_generator_) {
+     Get().random_generator_.reset(new RNG());
+   }
+   return *(Get().random_generator_);
  }
  inline static cublasHandle_t cublas_handle() { return Get().cublas_handle_; }
  inline static curandGenerator_t curand_generator() {
@@ -118,6 +120,9 @@ class Caffe {
  inline static void set_phase(Phase phase) { Get().phase_ = phase; }
  // Sets the random seed of both boost and curand
  static void set_random_seed(const unsigned int seed);
+ // Sets the boost RNG engine from another RNG engine to maintain state across
+ // variate_generator calls.
+ static void set_generator(const void* other_rng);
  // Sets the device. Since we have cublas and curand stuff, set device also
  // requires us to reset those values.
  static void SetDevice(const int device_id);
@@ -127,7 +132,7 @@ class Caffe {
 protected:
  cublasHandle_t cublas_handle_;
  curandGenerator_t curand_generator_;
- RNG random_generator_;
+ shared_ptr<RNG> random_generator_;
 
  Brew mode_;
  Phase phase_;
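A note on the rng_stream() change above: the RNG member is now a shared_ptr that is created lazily on first access rather than eagerly at startup. A minimal standalone sketch of the same lazily initialized singleton-member pattern (the Context class and the seed value below are illustrative, not Caffe's):

    #include <boost/random/mersenne_twister.hpp>
    #include <boost/shared_ptr.hpp>

    class Context {
     public:
      // Create the engine on first use so constructing the singleton stays
      // cheap and a later reseed call does not waste the initial engine.
      static const boost::mt19937& rng() {
        if (!Get().rng_) {
          Get().rng_.reset(new boost::mt19937(1701));  // illustrative seed
        }
        return *Get().rng_;
      }

     private:
      static Context& Get() {
        static Context instance;  // one per process, like Caffe::Get()
        return instance;
      }
      boost::shared_ptr<boost::mt19937> rng_;
    };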
include/caffe/filler.hpp: 7 additions, 9 deletions

@@ -51,9 +51,8 @@ class UniformFiller : public Filler<Dtype> {
       : Filler<Dtype>(param) {}
   virtual void Fill(Blob<Dtype>* blob) {
     CHECK(blob->count());
-    caffe_vRngUniform<Dtype>(blob->count(), blob->mutable_cpu_data(),
-        Dtype(this->filler_param_.min()),
-        Dtype(this->filler_param_.max()));
+    caffe_rng_uniform<Dtype>(blob->count(), Dtype(this->filler_param_.min()),
+        Dtype(this->filler_param_.max()), blob->mutable_cpu_data());
   }
 };
 
@@ -65,9 +64,8 @@ class GaussianFiller : public Filler<Dtype> {
   virtual void Fill(Blob<Dtype>* blob) {
     Dtype* data = blob->mutable_cpu_data();
     CHECK(blob->count());
-    caffe_vRngGaussian<Dtype>(blob->count(), blob->mutable_cpu_data(),
-        Dtype(this->filler_param_.mean()),
-        Dtype(this->filler_param_.std()));
+    caffe_rng_gaussian<Dtype>(blob->count(), Dtype(this->filler_param_.mean()),
+        Dtype(this->filler_param_.std()), blob->mutable_cpu_data());
   }
 };
 
@@ -79,7 +77,7 @@ class PositiveUnitballFiller : public Filler<Dtype> {
   virtual void Fill(Blob<Dtype>* blob) {
     Dtype* data = blob->mutable_cpu_data();
     DCHECK(blob->count());
-    caffe_vRngUniform<Dtype>(blob->count(), blob->mutable_cpu_data(), 0, 1);
+    caffe_rng_uniform<Dtype>(blob->count(), 0, 1, blob->mutable_cpu_data());
     // We expect the filler to not be called very frequently, so we will
     // just use a simple implementation
     int dim = blob->count() / blob->num();
@@ -113,8 +111,8 @@ class XavierFiller : public Filler<Dtype> {
     CHECK(blob->count());
     int fan_in = blob->count() / blob->num();
     Dtype scale = sqrt(Dtype(3) / fan_in);
-    caffe_vRngUniform<Dtype>(blob->count(), blob->mutable_cpu_data(),
-        -scale, scale);
+    caffe_rng_uniform<Dtype>(blob->count(), -scale, scale,
+        blob->mutable_cpu_data());
   }
 };
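The filler changes are mechanical: each caffe_vRng* call becomes the corresponding caffe_rng_* call, with the output pointer moved from the second argument to the last. A hedged usage sketch of the new argument order (the buffer size and bounds here are invented for illustration):

    #include <vector>
    #include "caffe/util/math_functions.hpp"

    void fill_example() {
      // New order: count, distribution parameters, then the output pointer.
      std::vector<float> weights(64);
      caffe::caffe_rng_uniform<float>(64, -0.5f, 0.5f, &weights[0]);
      caffe::caffe_rng_gaussian<float>(64, 0.0f, 0.01f, &weights[0]);
    }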
include/caffe/util/math_functions.hpp: 23 additions, 4 deletions

@@ -109,14 +109,33 @@ template <typename Dtype>
 Dtype caffe_nextafter(const Dtype b);
 
 template <typename Dtype>
-void caffe_vRngUniform(const int n, Dtype* r, const Dtype a, const Dtype b);
+void caffe_rng_uniform(const int n, const Dtype a, const Dtype b, Dtype* r);
+
+// caffe_gpu_rng_uniform with two arguments generates integers in the range
+// [0, UINT_MAX].
+void caffe_gpu_rng_uniform(const int n, unsigned int* r);
+
+// caffe_gpu_rng_uniform with four arguments generates floats in the range
+// (a, b] (strictly greater than a, less than or equal to b) due to the
+// specification of curandGenerateUniform. With a = 0, b = 1, just calls
+// curandGenerateUniform; with other limits will shift and scale the outputs
+// appropriately after calling curandGenerateUniform.
+template <typename Dtype>
+void caffe_gpu_rng_uniform(const int n, const Dtype a, const Dtype b, Dtype* r);
 
 template <typename Dtype>
-void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a,
-                        const Dtype sigma);
+void caffe_rng_gaussian(const int n, const Dtype mu, const Dtype sigma,
+                        Dtype* r);
+
+template <typename Dtype>
+void caffe_gpu_rng_gaussian(const int n, const Dtype mu, const Dtype sigma,
+                            Dtype* r);
+
+template <typename Dtype>
+void caffe_rng_bernoulli(const int n, const Dtype p, int* r);
 
 template <typename Dtype>
-void caffe_vRngBernoulli(const int n, Dtype* r, const double p);
+void caffe_gpu_rng_bernoulli(const int n, const Dtype p, int* r);
 
 template <typename Dtype>
 void caffe_exp(const int n, const Dtype* a, Dtype* y);
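The comment on the four-argument caffe_gpu_rng_uniform is worth unpacking: curandGenerateUniform produces values in (0, 1], so arbitrary limits need an affine remap. A sketch of that remap under the stated assumption (this shows the idea, not the literal Caffe kernel):

    // Map u in (0, 1] to (a, b]: as u approaches 0 the result approaches a
    // (exclusive), and u = 1 lands exactly on b (inclusive).
    template <typename Dtype>
    void shift_and_scale(const int n, const Dtype a, const Dtype b, Dtype* r) {
      for (int i = 0; i < n; ++i) {
        r[i] = a + r[i] * (b - a);  // r[i] holds a (0, 1] draw on entry
      }
    }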
include/caffe/util/rng.hpp: 7 additions, 3 deletions

@@ -9,9 +9,13 @@
 namespace caffe {
 
 typedef boost::mt19937 rng_t;
-inline rng_t& caffe_rng() {
-  Caffe::RNG &generator = Caffe::rng_stream();
-  return *(caffe::rng_t*) generator.generator();
+
+inline const rng_t& caffe_rng() {
+  return *static_cast<const caffe::rng_t*>(Caffe::rng_stream().generator());
 }
 
+inline void caffe_set_rng(const caffe::rng_t& other) {
+  Caffe::set_generator(static_cast<const void*>(&other));
+}
+
 }  // namespace caffe
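caffe_set_rng() is the other half of the bug this PR is named for: boost::variate_generator takes its engine by value, so draws advance only the internal copy and the global stream would otherwise repeat. A hedged sketch of the intended round trip (copy out, draw, write the advanced engine back), assuming only the headers shown in this diff:

    #include <boost/random/mersenne_twister.hpp>
    #include <boost/random/uniform_real.hpp>
    #include <boost/random/variate_generator.hpp>
    #include "caffe/util/rng.hpp"

    void draw_and_persist(const int n, double* out) {
      // variate_generator copies the engine, so the global state is untouched...
      caffe::rng_t engine = caffe::caffe_rng();
      boost::variate_generator<caffe::rng_t, boost::uniform_real<double> >
          sampler(engine, boost::uniform_real<double>(0.0, 1.0));
      for (int i = 0; i < n; ++i) {
        out[i] = sampler();
      }
      // ...until the advanced copy is stored back for the next caller.
      caffe::caffe_set_rng(sampler.engine());
    }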
include/caffe/vision_layers.hpp: 3 additions, 3 deletions

@@ -70,8 +70,8 @@ class DropoutLayer : public NeuronLayer<Dtype> {
       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
 
   shared_ptr<SyncedMemory> rand_vec_;
-  float threshold_;
-  float scale_;
+  Dtype threshold_;
+  Dtype scale_;
   unsigned int uint_thres_;
 };
 
@@ -607,7 +607,7 @@ class PoolingLayer : public Layer<Dtype> {
   int width_;
   int pooled_height_;
   int pooled_width_;
-  Blob<float> rand_idx_;
+  Blob<Dtype> rand_idx_;
 };
 
 template <typename Dtype>
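Retyping threshold_, scale_, and rand_idx_ from float to Dtype removes a silent narrowing when a layer is instantiated as Layer<double>. A minimal illustration of why member scalars should follow the class template parameter (the class below is invented for the example):

    // With Dtype members, ScaledDropout<double> keeps full double precision;
    // hard-coded float members would truncate every product below.
    template <typename Dtype>
    class ScaledDropout {
     public:
      explicit ScaledDropout(Dtype threshold)
          : threshold_(threshold), scale_(Dtype(1) / (Dtype(1) - threshold)) {}
      Dtype apply(Dtype x, bool keep) const {
        return keep ? x * scale_ : Dtype(0);
      }
     private:
      Dtype threshold_;
      Dtype scale_;
    };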
src/caffe/common.cpp: 22 additions, 20 deletions

@@ -60,7 +60,11 @@ void Caffe::set_random_seed(const unsigned int seed) {
     LOG(ERROR) << "Curand not available. Skipping setting the curand seed.";
   }
   // RNG seed
-  Get().random_generator_ = RNG(seed);
+  Get().random_generator_.reset(new RNG(seed));
+}
+
+void Caffe::set_generator(const void* other_rng) {
+  Get().random_generator_->set_generator(other_rng);
 }
 
 void Caffe::SetDevice(const int device_id) {
@@ -117,36 +121,34 @@ void Caffe::DeviceQuery() {
 
 class Caffe::RNG::Generator {
  public:
-  caffe::rng_t rng;
+  Generator() : rng_(new caffe::rng_t(cluster_seedgen())) {}
+  explicit Generator(unsigned int seed) : rng_(new caffe::rng_t(seed)) {}
+  explicit Generator(const caffe::rng_t& other) :
+      rng_(new caffe::rng_t(other)) {}
+  const caffe::rng_t& rng() const { return *rng_; }
+ private:
+  shared_ptr<caffe::rng_t> rng_;
 };
 
-Caffe::RNG::RNG()
-    : generator_(new Generator) {
-  generator_->rng = caffe::rng_t(cluster_seedgen());
-}
-
-Caffe::RNG::RNG(unsigned int seed)
-    : generator_(new Generator) {
-  generator_->rng = caffe::rng_t(seed);
-}
+Caffe::RNG::RNG() : generator_(new Generator) { }
 
-Caffe::RNG::~RNG() { delete generator_; }
+Caffe::RNG::RNG(unsigned int seed) : generator_(new Generator(seed)) { }
 
-Caffe::RNG::RNG(const RNG& other) : generator_(new Generator) {
-  *generator_ = *other.generator_;
-}
+Caffe::RNG::RNG(const RNG& other) : generator_(new Generator(*other.generator_))
+{ }
 
 Caffe::RNG& Caffe::RNG::operator=(const RNG& other) {
-  *generator_ = *other.generator_;
+  generator_.reset(other.generator_.get());
   return *this;
 }
 
-void* Caffe::RNG::generator() {
-  return &generator_->rng;
+const void* Caffe::RNG::generator() const {
+  return static_cast<const void*>(&generator_->rng());
 }
 
-const void* Caffe::RNG::generator() const {
-  return &generator_->rng;
+void Caffe::RNG::set_generator(const void* other_rng) {
+  const caffe::rng_t& rng = *static_cast<const caffe::rng_t*>(other_rng);
+  return generator_.reset(new Generator(rng));
 }
 
 const char* cublasGetErrorString(cublasStatus_t error) {
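The rewritten Generator is a small pimpl: common.hpp only forward-declares it, so the Boost headers stay out of every translation unit that includes Caffe's common header, and the engine itself now lives behind a shared_ptr instead of a raw pointer with a hand-written destructor. A toy sketch of the same idiom with illustrative names:

    #include <boost/shared_ptr.hpp>

    // The public class exposes no engine type; the implementation is cloned
    // on copy so the two objects advance independently afterwards.
    class Rng {
     public:
      Rng() : impl_(new Impl(0)) {}
      explicit Rng(unsigned int seed) : impl_(new Impl(seed)) {}
      Rng(const Rng& other) : impl_(new Impl(*other.impl_)) {}

     private:
      struct Impl {
        explicit Impl(unsigned int seed) : state(seed) {}
        unsigned int state;  // stand-in for the real boost::mt19937
      };
      boost::shared_ptr<Impl> impl_;
    };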
src/caffe/layers/dropout_layer.cpp: 3 additions, 3 deletions

@@ -20,7 +20,7 @@ void DropoutLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   DCHECK(threshold_ > 0.);
   DCHECK(threshold_ < 1.);
   scale_ = 1. / (1. - threshold_);
-  uint_thres_ = (unsigned int)(UINT_MAX * threshold_);
+  uint_thres_ = static_cast<unsigned int>(UINT_MAX * threshold_);
 }
 
 template <typename Dtype>
@@ -32,12 +32,12 @@ Dtype DropoutLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   const int count = bottom[0]->count();
   if (Caffe::phase() == Caffe::TRAIN) {
     // Create random numbers
-    caffe_vRngBernoulli<int>(count, mask, 1. - threshold_);
+    caffe_rng_bernoulli(count, 1. - threshold_, mask);
     for (int i = 0; i < count; ++i) {
       top_data[i] = bottom_data[i] * mask[i] * scale_;
     }
   } else {
-    memcpy(top_data, bottom_data, bottom[0]->count() * sizeof(Dtype));
+    caffe_copy(bottom[0]->count(), bottom_data, top_data);
   }
   return Dtype(0);
 }
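For readers new to the layer: this is inverted dropout. Each unit survives with probability 1 - threshold_ and survivors are scaled by 1 / (1 - threshold_), so the expected activation matches test time and the else branch can be a plain copy. A self-contained CPU sketch under those assumptions (std::rand stands in for caffe_rng_bernoulli):

    #include <cstdlib>
    #include <vector>

    void dropout_forward(const std::vector<float>& in, float drop_prob,
                         std::vector<float>* out) {
      const float scale = 1.0f / (1.0f - drop_prob);
      out->resize(in.size());
      for (size_t i = 0; i < in.size(); ++i) {
        // Keep with probability 1 - drop_prob, then rescale survivors so the
        // expected value of (*out)[i] equals in[i].
        const bool keep =
            std::rand() / static_cast<float>(RAND_MAX) >= drop_prob;
        (*out)[i] = keep ? in[i] * scale : 0.0f;
      }
    }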
src/caffe/layers/dropout_layer.cu: 8 additions, 7 deletions

@@ -8,6 +8,7 @@
 #include "caffe/layer.hpp"
 #include "caffe/syncedmem.hpp"
 #include "caffe/vision_layers.hpp"
+#include "caffe/util/math_functions.hpp"
 
 using std::max;
 
@@ -30,17 +31,16 @@ Dtype DropoutLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   Dtype* top_data = (*top)[0]->mutable_gpu_data();
   const int count = bottom[0]->count();
   if (Caffe::phase() == Caffe::TRAIN) {
-    CURAND_CHECK(curandGenerate(Caffe::curand_generator(),
-        (unsigned int*)(rand_vec_->mutable_gpu_data()), count));
+    unsigned int* mask =
+        static_cast<unsigned int*>(rand_vec_->mutable_gpu_data());
+    caffe_gpu_rng_uniform(count, mask);
     // set thresholds
     // NOLINT_NEXT_LINE(whitespace/operators)
     DropoutForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
-        count, bottom_data, (unsigned int*)rand_vec_->gpu_data(), uint_thres_,
-        scale_, top_data);
+        count, bottom_data, mask, uint_thres_, scale_, top_data);
     CUDA_POST_KERNEL_CHECK;
   } else {
-    CUDA_CHECK(cudaMemcpy(top_data, bottom_data,
-        count * sizeof(Dtype), cudaMemcpyDeviceToDevice));
+    caffe_gpu_copy(count, bottom_data, top_data);
   }
   return Dtype(0);
 }
@@ -62,7 +62,8 @@ void DropoutLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   if (propagate_down) {
     const Dtype* top_diff = top[0]->gpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
-    const unsigned int* mask = (unsigned int*)rand_vec_->gpu_data();
+    const unsigned int* mask =
+        static_cast<const unsigned int*>(rand_vec_->gpu_data());
     const int count = (*bottom)[0]->count();
     // NOLINT_NEXT_LINE(whitespace/operators)
     DropoutBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
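On the GPU the Bernoulli draw is implicit: the two-argument caffe_gpu_rng_uniform fills the mask with integers uniform on [0, UINT_MAX], and comparing each against uint_thres_ = UINT_MAX * threshold_ keeps a unit with probability roughly 1 - threshold_. A sketch of what a kernel like DropoutForward computes per element (an illustration, not the kernel's actual source):

    template <typename Dtype>
    __global__ void DropoutForwardSketch(const int n, const Dtype* in,
        const unsigned int* mask, const unsigned int uint_thres,
        const Dtype scale, Dtype* out) {
      for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
           i += blockDim.x * gridDim.x) {
        // mask[i] > uint_thres happens with probability ~ 1 - threshold.
        out[i] = in[i] * (mask[i] > uint_thres) * scale;
      }
    }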
src/caffe/layers/pooling_layer.cu: 4 additions, 4 deletions

@@ -73,7 +73,7 @@ __global__ void StoPoolForwardTrain(const int nthreads,
     const Dtype* bottom_data,
     const int num, const int channels, const int height,
     const int width, const int pooled_height, const int pooled_width,
-    const int kernel_size, const int stride, float* rand_idx, Dtype* top_data) {
+    const int kernel_size, const int stride, Dtype* rand_idx, Dtype* top_data) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     int pw = index % pooled_width;
     int ph = (index / pooled_width) % pooled_height;
@@ -163,8 +163,8 @@ Dtype PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   case PoolingParameter_PoolMethod_STOCHASTIC:
     if (Caffe::phase() == Caffe::TRAIN) {
       // We need to create the random index as well.
-      CURAND_CHECK(curandGenerateUniform(Caffe::curand_generator(),
-          rand_idx_.mutable_gpu_data(), count));
+      caffe_gpu_rng_uniform(count, Dtype(0), Dtype(1),
+          rand_idx_.mutable_gpu_data());
       // NOLINT_NEXT_LINE(whitespace/operators)
       StoPoolForwardTrain<Dtype><<<CAFFE_GET_BLOCKS(count),
                                    CAFFE_CUDA_NUM_THREADS>>>(
@@ -257,7 +257,7 @@ __global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff,
 
 template <typename Dtype>
 __global__ void StoPoolBackward(const int nthreads,
-    const float* rand_idx, const Dtype* top_diff,
+    const Dtype* rand_idx, const Dtype* top_diff,
     const int num, const int channels, const int height,
     const int width, const int pooled_height, const int pooled_width,
     const int kernel_size, const int stride, Dtype* bottom_diff) {
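Context for the rand_idx_ change: in stochastic pooling each output draws one uniform number, and the kernel scans the pooling window's running sum of activations until it crosses rand * sum, so an input is selected with probability proportional to its value. A CPU sketch of that selection rule for a single window (a standalone illustration, not Caffe's kernel):

    // Pick an index from one pooling window with probability proportional to
    // its non-negative activation; u is a uniform draw in (0, 1].
    int stochastic_pick(const float* window, const int size, const float u) {
      float sum = 0.0f;
      for (int i = 0; i < size; ++i) sum += window[i];
      const float target = u * sum;
      float cumsum = 0.0f;
      for (int i = 0; i < size; ++i) {
        cumsum += window[i];
        if (cumsum >= target) return i;
      }
      return size - 1;  // guard against floating-point round-off
    }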
src/caffe/test/test_common.cpp: 4 additions, 6 deletions

@@ -36,16 +36,14 @@ TEST_F(CommonTest, TestRandSeedCPU) {
   SyncedMemory data_a(10 * sizeof(int));
   SyncedMemory data_b(10 * sizeof(int));
   Caffe::set_random_seed(1701);
-  caffe_vRngBernoulli(10,
-      reinterpret_cast<int*>(data_a.mutable_cpu_data()), 0.5);
+  caffe_rng_bernoulli(10, 0.5, static_cast<int*>(data_a.mutable_cpu_data()));
 
   Caffe::set_random_seed(1701);
-  caffe_vRngBernoulli(10,
-      reinterpret_cast<int*>(data_b.mutable_cpu_data()), 0.5);
+  caffe_rng_bernoulli(10, 0.5, static_cast<int*>(data_b.mutable_cpu_data()));
 
   for (int i = 0; i < 10; ++i) {
-    EXPECT_EQ(((const int*)(data_a.cpu_data()))[i],
-        ((const int*)(data_b.cpu_data()))[i]);
+    EXPECT_EQ(static_cast<const int*>(data_a.cpu_data())[i],
+        static_cast<const int*>(data_b.cpu_data())[i]);
   }
 }
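The test's invariant is simply that reseeding reproduces the identical Bernoulli stream through the new API. The same check in miniature (a sketch assuming the declarations shown earlier in this diff):

    #include <vector>
    #include "caffe/common.hpp"
    #include "caffe/util/math_functions.hpp"

    // True iff two identically seeded draws agree element-wise.
    bool SeedIsReproducible() {
      std::vector<int> a(10), b(10);
      caffe::Caffe::set_random_seed(1701);
      caffe::caffe_rng_bernoulli(10, 0.5, &a[0]);
      caffe::Caffe::set_random_seed(1701);
      caffe::caffe_rng_bernoulli(10, 0.5, &b[0]);
      return a == b;
    }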