Skip to content

Commit

Permalink
fix link error for parallel rng (apache#9256)
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhliu committed Dec 30, 2017
1 parent 6e5c6f8 commit ecf4c9f
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 15 deletions.
6 changes: 6 additions & 0 deletions src/common/random_generator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ namespace mxnet {
namespace common {
namespace random {

template<>
const int RandGenerator<gpu, float>::kMinNumRandomPerThread = 64;

template<>
const int RandGenerator<gpu, float>::kNumRandomStates = 32768;

__global__ void rand_generator_seed_kernel(curandStatePhilox4_32_10_t *states_,
const int size,
uint32_t seed) {
Expand Down
25 changes: 13 additions & 12 deletions src/common/random_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ template<typename DType>
class RandGenerator<cpu, DType> {
public:
// at least how many random numbers should be generated by one CPU thread.
static const int kMinNumRandomPerThread = 64;
static const int kMinNumRandomPerThread;
// store how many global random states for CPU.
static const int kNumRandomStates = 1024;
static const int kNumRandomStates;

// implementation class for random number generator
class Impl {
Expand Down Expand Up @@ -96,17 +96,23 @@ class RandGenerator<cpu, DType> {

private:
std::mt19937 *states_;
};
}; // class RandGenerator<cpu, DType>

template<typename DType>
const int RandGenerator<cpu, DType>::kMinNumRandomPerThread = 64;

template<typename DType>
const int RandGenerator<cpu, DType>::kNumRandomStates = 1024;

#if MXNET_USE_CUDA

template<typename DType>
class RandGenerator<gpu, DType> {
public:
// at least how many random numbers should be generated by one GPU thread.
static const int kMinNumRandomPerThread = 64;
static const int kMinNumRandomPerThread;
// store how many global random states for GPU.
static const int kNumRandomStates = 32768;
static const int kNumRandomStates;

// uniform number generation in Cuda made consistent with stl (include 0 but exclude 1)
// by using 1.0-curand_uniform().
Expand Down Expand Up @@ -159,16 +165,11 @@ class RandGenerator<gpu, DType> {

private:
curandStatePhilox4_32_10_t *states_;
};
}; // class RandGenerator<gpu, DType>

template<>
class RandGenerator<gpu, double> {
public:
// at least how many random numbers should be generated by one GPU thread.
static const int kMinNumRandomPerThread = 64;
// store how many global random states for GPU.
static const int kNumRandomStates = 32768;

// uniform number generation in Cuda made consistent with stl (include 0 but exclude 1)
// by using 1.0-curand_uniform().
// Needed as some samplers in sampler.h won't be able to deal with
Expand Down Expand Up @@ -209,7 +210,7 @@ class RandGenerator<gpu, double> {

private:
curandStatePhilox4_32_10_t *states_;
};
}; // class RandGenerator<gpu, double>

#endif // MXNET_USE_CUDA

Expand Down
6 changes: 3 additions & 3 deletions src/operator/random/sampler.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ template<typename OP, typename xpu, typename GType, typename ...Args>
inline static void LaunchRNG(mshadow::Stream<xpu> *s,
common::random::RandGenerator<xpu, GType> *gen,
const int N, Args... args) {
const int nloop = (N + RandGenerator<xpu, GType>::kMinNumRandomPerThread - 1) /
RandGenerator<xpu, GType>::kMinNumRandomPerThread;
const int nthread = std::min(nloop, RandGenerator<xpu, GType>::kNumRandomStates);
const int nloop = (N + RandGenerator<xpu>::kMinNumRandomPerThread - 1) /
RandGenerator<xpu>::kMinNumRandomPerThread;
const int nthread = std::min(nloop, RandGenerator<xpu>::kNumRandomStates);
const int step = (N + nthread - 1) / nthread;
Kernel<OP, xpu>::Launch(s, nthread, *gen, N, step, args...);
}
Expand Down

0 comments on commit ecf4c9f

Please sign in to comment.