Skip to content

Commit

Permalink
[Phi] move gaussian_random, fix fp16 (#40122)
Browse files Browse the repository at this point in the history
[Phi] move gaussian_random, fix fp16
  • Loading branch information
windstamp committed Mar 4, 2022
1 parent b7bbe39 commit 8374065
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions paddle/phi/kernels/gpu/gaussian_random_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -81,22 +81,25 @@ void GaussianRandomKernel(const Context& dev_ctx,
int device_id = dev_ctx.GetPlace().GetDeviceId();
auto gen_cuda = paddle::framework::GetDefaultCUDAGenerator(device_id);

using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
if (gen_cuda->GetIsInitPy() && seed_flag) {
if (FLAGS_use_curand) {
using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
funcs::normal_distribution<MT> dist;
funcs::normal_transform<MT> trans(mean, std);
funcs::distribution_and_transform<T>(dev_ctx, tensor, dist, trans);
} else {
auto seed_offset = gen_cuda->IncrementOffset(1);
int64_t gen_offset = size * seed_offset.second;
auto func =
GaussianGenerator<MT>(mean, std, seed_offset.first, gen_offset);
IndexKernel<T, GaussianGenerator<MT>>(dev_ctx, tensor, func);
auto func = GaussianGenerator<T>(static_cast<T>(mean),
static_cast<T>(std),
seed_offset.first,
gen_offset);
IndexKernel<T, GaussianGenerator<T>>(dev_ctx, tensor, func);
}
} else {
auto func = GaussianGenerator<MT>(mean, std, seed);
IndexKernel<T, GaussianGenerator<MT>>(dev_ctx, tensor, func);
auto func =
GaussianGenerator<T>(static_cast<T>(mean), static_cast<T>(std), seed);
IndexKernel<T, GaussianGenerator<T>>(dev_ctx, tensor, func);
}
}

Expand Down

0 comments on commit 8374065

Please sign in to comment.