DistributionRandomKernel.cu (forked from ROCm/pytorch)
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cuda/DistributionTemplates.h>

namespace at::native {

// Fills `iter`'s output with random integers in [base, base + range),
// using the supplied generator or the default CUDA generator.
void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional<Generator> gen_) {
  auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
  at::native::templates::cuda::random_from_to_kernel(iter, range, base, gen);
}

// Draws random values spanning the full 64-bit range.
void random_full_64_bits_range_kernel(TensorIteratorBase& iter, std::optional<Generator> gen_) {
  auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
  at::native::templates::cuda::random_full_64_bits_range_kernel(iter, gen);
}

// Default random_(): the range is implied by the output dtype.
void random_kernel(TensorIteratorBase& iter, std::optional<Generator> gen_) {
  auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
  at::native::templates::cuda::random_kernel(iter, gen);
}

// Register these CUDA implementations with the device-generic dispatch stubs.
REGISTER_DISPATCH(random_from_to_stub, &random_from_to_kernel)
REGISTER_DISPATCH(random_stub, &random_kernel)
REGISTER_DISPATCH(random_full_64_bits_range_stub, &random_full_64_bits_range_kernel)

} // namespace at::native
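
For context, a minimal sketch of how these registered kernels are reached from the libtorch C++ frontend: calling Tensor::random_ on a CUDA tensor dispatches through random_from_to_stub / random_stub to the functions defined above. This example is illustrative only, assumes a CUDA-enabled libtorch build, and is not part of this file.

#include <torch/torch.h>

int main() {
  // Integer tensor on the GPU; random_(from, to) dispatches (via
  // random_from_to_stub) to random_from_to_kernel above.
  auto t = torch::empty({4, 4},
                        torch::dtype(torch::kInt64).device(torch::kCUDA));
  t.random_(/*from=*/0, /*to=*/100);  // uniform integers in [0, 100)

  // No bounds: random_() goes through random_stub, with the range
  // implied by the output dtype.
  auto u = torch::empty({4, 4},
                        torch::dtype(torch::kInt32).device(torch::kCUDA));
  u.random_();
}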