Math.cuh (forked from pytorch/pytorch)
#pragma once
#include <ATen/AccumulateType.h>
namespace at {
namespace native {
/*
 * The following function was converted to CUDA form from code that comes
 * with the following copyright notice. It has been released under the BSD license.
 *
 * Cephes Math Library Release 2.8: June, 2000
 * Copyright 1984, 1987, 1992, 2000 by Stephen L. Moshier
 */
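// calc_digamma computes the digamma function psi(x) = d/dx ln(Gamma(x)).
// Poles (x == 0, or x a negative integer) return infinity. Negative
// non-integer arguments are handled with the reflection formula
// psi(x) = psi(1 - x) - pi / tan(pi * x), computed in double because tan is
// very sensitive to rounding of its argument near the poles. The recurrence
// psi(x) = psi(x + 1) - 1/x then raises the argument to at least 10, where
// the Cephes asymptotic series log(x) - 1/(2x) - z * polevl(z, A), z = 1/x^2,
// is evaluated; A holds the Bernoulli-based coefficients B(2k)/(2k) in
// highest-power-first order, and PSI_10 = psi(10) is returned directly when
// the shifted argument lands exactly on 10.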
template <typename scalar_t>
static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) {
  using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
  static const double PI_f64 = 3.14159265358979323846;
  const accscalar_t PSI_10 = 2.25175258906672110764;
  const accscalar_t A[] = {
      8.33333333333333333333E-2,
      -2.10927960927960927961E-2,
      7.57575757575757575758E-3,
      -4.16666666666666666667E-3,
      3.96825396825396825397E-3,
      -8.33333333333333333333E-3,
      8.33333333333333333333E-2,
  };

  accscalar_t x = static_cast<accscalar_t>(in);
  if (x == 0) {
    return static_cast<scalar_t>(INFINITY);
  }

  bool x_is_integer = x == ::floor(x);
  accscalar_t result = 0;
  if (x < 0) {
    if (x_is_integer) {
      return static_cast<scalar_t>(INFINITY);
    }
    // Rounding errors in tan's input can really affect the output
    // for extreme values, so we always perform this computation in double.
    result = static_cast<accscalar_t>(- PI_f64 / ::tan(PI_f64 * static_cast<double>(x)));
    x = 1 - x;
  }

  while (x < 10) {
    result -= 1 / x;
    x += 1;
  }
  if (x == 10) {
    return static_cast<scalar_t>(result + PSI_10);
  }

  accscalar_t y = 0;
  if (x < 1.0e17) {
    accscalar_t z = 1.0 / (x * x);
    accscalar_t polevl_result = 0;
    for (int i = 0; i <= 6; i++) {
      polevl_result = polevl_result * z + A[i];
    }
    y = z * polevl_result;
  }
  return static_cast<scalar_t>(::log(x) - (0.5 / x) - y + result);
}
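
// calc_trigamma computes the trigamma function psi'(x), the derivative of
// digamma. Arguments below 1/2 are reflected with
// psi'(x) + psi'(1 - x) = pi^2 / sin^2(pi * x); the recurrence
// psi'(x) = psi'(x + 1) + 1/x^2 is then applied six times before the
// asymptotic series (1 + 1/(2x) + 1/(6x^2) - 1/(30x^4) + 1/(42x^6)) / x
// is evaluated at the shifted argument.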
template <typename scalar_t>
static inline __host__ __device__ scalar_t calc_trigamma(scalar_t in) {
  using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
  const accscalar_t PI = 3.14159265358979323846;
  accscalar_t x = static_cast<accscalar_t>(in);
  accscalar_t sign = +1;
  accscalar_t result = 0;
  if (x < 0.5f) {
    sign = -1;
    accscalar_t sin_pi_x = ::sin(PI * x);
    result -= (PI * PI) / (sin_pi_x * sin_pi_x);
    x = 1 - x;
  }
  for (int i = 0; i < 6; ++i) {
    result += 1 / (x * x);
    x += 1;
  }
  const accscalar_t one = static_cast<scalar_t>(1);
  const accscalar_t ixx = 1 / (x * x);
  result += (1 + 1 / (2 * x) + ixx * (one / 6 - ixx * (one / 30 - ixx * (one / 42)))) / x;
  return static_cast<scalar_t>(sign * result);
}
} // namespace native
} // namespace at
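
A rough usage sketch (not part of the original header): the kernel below applies calc_digamma elementwise to a device buffer. The kernel name, buffer names, and launch configuration are hypothetical; only at::native::calc_digamma comes from the file above.

#include <cstdint>
#include <cuda_runtime.h>
// #include "Math.cuh"  // this header, for at::native::calc_digamma

// Hypothetical elementwise kernel: out[i] = digamma(in[i]).
template <typename scalar_t>
__global__ void digamma_elementwise(const scalar_t* in, scalar_t* out, int64_t n) {
  int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = at::native::calc_digamma(in[i]);
  }
}

// Possible launch for n floats already resident on the device:
//   const int threads = 256;
//   const int blocks = static_cast<int>((n + threads - 1) / threads);
//   digamma_elementwise<float><<<blocks, threads>>>(d_in, d_out, n);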