Skip to content

Commit

Permalink
fix dirichlet op tempalte paramters and macro code style
Browse files Browse the repository at this point in the history
  • Loading branch information
cxxly committed Dec 27, 2021
1 parent a290251 commit 818d4fa
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 57 deletions.
12 changes: 6 additions & 6 deletions paddle/fluid/operators/dirichlet_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,23 @@

namespace paddle {
namespace operators {
template <typename T, typename uniform_sampler_t, typename normal_sampler_t>
template <typename T, typename UniformSamplerT, typename NormalSamplerT>
struct GammaCPUFunctor {
GammaCPUFunctor(const T* alpha, T* gamma,
BaseSampler<T, uniform_sampler_t> uniform,
BaseSampler<T, normal_sampler_t> normal)
BaseSampler<T, UniformSamplerT> uniform,
BaseSampler<T, NormalSamplerT> normal)
: alpha_(alpha), gamma_(gamma), uniform_(uniform), normal_(normal) {}

HOST void operator()(int64_t index) {
auto sample = sample_gamma<T, T, uniform_sampler_t, normal_sampler_t>(
auto sample = sample_gamma<T, T, UniformSamplerT, NormalSamplerT>(
alpha_[index], uniform_, normal_);
gamma_[index] = std::max(std::numeric_limits<T>::min(), sample);
}

const T* alpha_;
T* gamma_;
BaseSampler<T, uniform_sampler_t> uniform_;
BaseSampler<T, normal_sampler_t> normal_;
BaseSampler<T, UniformSamplerT> uniform_;
BaseSampler<T, NormalSamplerT> normal_;
};

template <typename T>
Expand Down
24 changes: 12 additions & 12 deletions paddle/fluid/operators/dirichlet_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@
#endif

#if defined(PADDLE_WITH_CUDA)
using compatRandStatePhilox4_32_10_t = curandStatePhilox4_32_10_t;
#define compat_rand_init curand_init
#define compat_rand_uniform curand_uniform
#define compat_rand_normal curand_normal
using COMPAT_RANDSTATEPHILOX4_32_10_T = curandStatePhilox4_32_10_t;
#define COMPAT_RAND_INIT curand_init
#define COMPAT_RAND_UNIFORM curand_uniform
#define COMPAT_RAND_NORMAL curand_normal
#elif defined(PADDLE_WITH_HIP)
using compatRandStatePhilox4_32_10_t = hiprandStatePhilox4_32_10_t;
#define compat_rand_init hiprand_init
#define compat_rand_uniform hiprand_uniform
#define compat_rand_normal hiprand_normal
using COMPAT_RANDSTATEPHILOX4_32_10_T = hiprandStatePhilox4_32_10_t;
#define COMPAT_RAND_INIT hiprand_init
#define COMPAT_RAND_UNIFORM hiprand_uniform
#define COMPAT_RAND_NORMAL hiprand_normal
#endif

namespace paddle {
Expand All @@ -47,14 +47,14 @@ struct GammaCUDAFunctor {

DEVICE void operator()(int64_t index) {
// curand initialization
compatRandStatePhilox4_32_10_t state;
compat_rand_init(/*seed=*/seed_, /*subsequence=*/index, /*offset=*/offset_,
COMPAT_RANDSTATEPHILOX4_32_10_T state;
COMPAT_RAND_INIT(/*seed=*/seed_, /*subsequence=*/index, /*offset=*/offset_,
&state);

// sample
auto uniform_lambda = [&state]() { return compat_rand_uniform(&state); };
auto uniform_lambda = [&state]() { return COMPAT_RAND_UNIFORM(&state); };
BaseSampler<T, decltype(uniform_lambda)> standard_uniform(uniform_lambda);
auto normal_lambda = [&state]() { return compat_rand_normal(&state); };
auto normal_lambda = [&state]() { return COMPAT_RAND_NORMAL(&state); };
BaseSampler<T, decltype(normal_lambda)> standard_normal(normal_lambda);

auto sample =
Expand Down
77 changes: 38 additions & 39 deletions paddle/fluid/operators/dirichlet_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,37 +20,37 @@

// ROCM hcc doesn't work well with using std:: in kernel functions
#if defined(PADDLE_WITH_CUDA)
#define compat_exp exp
#define compat_ceil ceil
#define compat_floor floor
#define compat_log log
#define compat_pow pow
#define compat_sqrt sqrt
#define compat_tan tan
#define compat_abs abs
#define compat_log1p log1p
#define COMPAT_EXP exp
#define COMPAT_CEIL ceil
#define COMPAT_FLOOR floor
#define COMPAT_LOG log
#define COMPAT_POW pow
#define COMPAT_SQRT sqrt
#define COMPAT_TAN tan
#define COMPAT_ABS abs
#define COMPAT_LOG1P log1p
#else
#define compat_exp std::exp
#define compat_ceil std::ceil
#define compat_floor std::floor
#define compat_log std::log
#define compat_pow std::pow
#define compat_sqrt std::sqrt
#define compat_tan std::tan
#define compat_abs std::abs
#define compat_log1p std::log1p
#define COMPAT_EXP std::exp
#define COMPAT_CEIL std::ceil
#define COMPAT_FLOOR std::floor
#define COMPAT_LOG std::log
#define COMPAT_POW std::pow
#define COMPAT_SQRT std::sqrt
#define COMPAT_TAN std::tan
#define COMPAT_ABS std::abs
#define COMPAT_LOG1P std::log1p
#endif

namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
struct DirichletSampler;

template <typename scalar_t, typename sampler_t>
template <typename ScalarT, typename SamplerT>
struct BaseSampler {
sampler_t sampler_;
HOSTDEVICE BaseSampler(const sampler_t& sampler) : sampler_(sampler) {}
HOSTDEVICE scalar_t sample() { return sampler_(); }
SamplerT sampler_;
HOSTDEVICE BaseSampler(const SamplerT& sampler) : sampler_(sampler) {}
HOSTDEVICE ScalarT sample() { return sampler_(); }
};

// `sample_gamma` is d from Numpy's distributions.c, and add support for
Expand Down Expand Up @@ -78,39 +78,38 @@ struct BaseSampler {
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/

template <typename scalar_t, typename accscalar_t, typename uniform_sampler_t,
typename normal_sampler_t>
HOSTDEVICE scalar_t
sample_gamma(scalar_t alpha,
BaseSampler<accscalar_t, uniform_sampler_t> standard_uniform,
BaseSampler<accscalar_t, normal_sampler_t> standard_normal) {
accscalar_t scale = 1.0f;
template <typename ScalarT, typename AccscalarT, typename UniformSamplerT,
typename NormalSamplerT>
HOSTDEVICE ScalarT sample_gamma(
ScalarT alpha, BaseSampler<AccscalarT, UniformSamplerT> standard_uniform,
BaseSampler<AccscalarT, NormalSamplerT> standard_normal) {
AccscalarT scale = 1.0f;

// Boost alpha for higher acceptance probability.
if (alpha < 1.0f) {
if (alpha == 0.f) return 0.f;
scale *= compat_pow(1 - standard_uniform.sample(), 1.0f / alpha);
scale *= COMPAT_POW(1 - standard_uniform.sample(), 1.0f / alpha);
alpha += 1.0f;
}

// This implements the acceptance-rejection method of Marsaglia and Tsang
// (2000)
// doi:10.1145/358407.358414
const accscalar_t d = alpha - 1.0f / 3.0f;
const accscalar_t c = 1.0f / compat_sqrt(9.0f * d);
const AccscalarT d = alpha - 1.0f / 3.0f;
const AccscalarT c = 1.0f / COMPAT_SQRT(9.0f * d);
for (;;) {
accscalar_t x, y;
AccscalarT x, y;
do {
x = standard_normal.sample();
y = 1.0f + c * x;
} while (y <= 0);
const accscalar_t v = y * y * y;
const accscalar_t u = 1 - standard_uniform.sample();
const accscalar_t xx = x * x;
const AccscalarT v = y * y * y;
const AccscalarT u = 1 - standard_uniform.sample();
const AccscalarT xx = x * x;
if (u < 1.0f - 0.0331f * xx * xx)
return static_cast<scalar_t>(scale * d * v);
if (compat_log(u) < 0.5f * xx + d * (1.0f - v + compat_log(v)))
return static_cast<scalar_t>(scale * d * v);
return static_cast<ScalarT>(scale * d * v);
if (COMPAT_LOG(u) < 0.5f * xx + d * (1.0f - v + COMPAT_LOG(v)))
return static_cast<ScalarT>(scale * d * v);
}
}

Expand Down

0 comments on commit 818d4fa

Please sign in to comment.