Distributions.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cuda/Distributions.h>
#include <ATen/TensorIterator.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
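
// With per-operator headers enabled, include only the operator headers this
// file actually needs; otherwise fall back to the monolithic ATen headers.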
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_dirichlet_grad_native.h>
#include <ATen/ops/_sample_dirichlet_native.h>
#include <ATen/ops/_standard_gamma_grad_native.h>
#include <ATen/ops/_standard_gamma_native.h>
#include <ATen/ops/binomial_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/poisson_native.h>
#endif

namespace at::native {
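
// Elementwise Poisson sampling: draws one sample per element of `lambda`
// using the default (or user-provided) CUDA generator.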
Tensor _s_poisson_cuda(const Tensor& lambda, c10::optional<Generator> gen_) {
  auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
  Tensor ret = at::empty(lambda.sizes(), lambda.options());
  launch_poisson_cuda_kernel(ret, lambda, gen);
  return ret;
}
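
// Elementwise Binomial(count, prob) sampling. A TensorIterator pairs each
// output element with its corresponding count and probability before the
// kernel launch.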
Tensor _s_binomial_cuda(const Tensor& count, const Tensor& prob, c10::optional<Generator> gen_) {
  auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
  Tensor ret = at::empty(count.sizes(), count.options());
  at::TensorIterator iter = at::TensorIteratorConfig()
      .add_output(ret)
      .add_input(count)
      .add_input(prob)
      .build();
  launch_binomial_cuda_kernel(iter, gen);
  return ret;
}
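
// Elementwise standard Gamma(alpha) sampling (unit rate).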
Tensor _s_gamma_cuda(const Tensor& alpha, c10::optional<Generator> gen_) {
  auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
  Tensor ret = at::empty(alpha.sizes(), alpha.options());
  launch_gamma_kernel(ret, alpha, gen);
  return ret;
}
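
// Dirichlet(alpha) sampling: draw Gamma(alpha) variates, then normalize them
// in place by their sum along the last dimension.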
Tensor _s_dirichlet_cuda(const Tensor& alpha, c10::optional<Generator> gen_) {
  auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
  Tensor ret = at::empty(alpha.sizes(), alpha.options());
  launch_gamma_kernel(ret, alpha, gen);
  auto gamma_sum = ret.sum(/*dim=*/-1, /*keepdim=*/true);
  // `ret` is both input and output: the gamma draws are divided by their
  // per-row sum in place to produce the Dirichlet samples.
  at::TensorIterator iter = at::TensorIteratorConfig()
      .add_output(ret)
      .add_input(ret)
      .add_input(gamma_sum)
      .build();
  launch_dirichlet_kernel(iter);
  return ret;
}
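
// Gradient of a standard gamma sample with respect to its concentration
// parameter `self`, given the sampled value `output` (used for
// reparameterized gradients).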
Tensor _standard_gamma_grad_cuda(const Tensor& self, const Tensor& output) {
  Tensor ret = at::empty(self.sizes(), self.options());
  TensorIterator iter = at::TensorIteratorConfig()
      .add_output(ret)
      .add_input(self)
      .add_input(output)
      .build();
  launch_standard_gamma_grad_kernel(iter);
  return ret;
}
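
// Gradient of a Dirichlet sample `x` with respect to its concentration
// `alpha`, given the total concentration `total`.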
Tensor _dirichlet_grad_cuda(const Tensor& x, const Tensor& alpha, const Tensor& total) {
  Tensor ret = at::empty(x.sizes(), x.options());
  TensorIterator iter = at::TensorIteratorConfig()
      .add_output(ret)
      .add_input(x)
      .add_input(alpha)
      .add_input(total)
      .build();
  launch_dirichlet_grad_kernel(iter);
  return ret;
}
} // namespace at::native