forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
GatedLinearUnit.cpp
142 lines (124 loc) · 5.1 KB
/
GatedLinearUnit.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#include <ATen/ATen.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/Activation.h>
namespace at {
namespace meta {
// Shape/dtype inference for glu: validates that `self` can be split in two
// equal halves along `dim` and configures this TensorIterator as a binary op
// over (first half, second half), producing an output the size of one half.
// The element-wise work is done later by glu_stub in the glu_out impl, which
// runs on this same iterator object.
TORCH_META_FUNC(glu) (
    const Tensor& self, int64_t dim
) {
  // this can't pass anyway because a 0-dimensional tensor has "size" 1, which
  // can't be evenly halved, but give a nicer error message here.
  TORCH_CHECK(self.dim() > 0, "glu does not support 0-dimensional tensors");
  auto wrap_dim = maybe_wrap_dim(dim, self.dim());
  const int64_t nIn = self.size(wrap_dim);
  TORCH_CHECK(nIn % 2 == 0, "Halving dimension must be even, but dimension ",
              wrap_dim, " is size ", nIn);
  // size output to half of input
  const int64_t selfSize = nIn / 2;
  Tensor firstHalf = self.narrow(wrap_dim, 0, selfSize);
  Tensor secondHalf = self.narrow(wrap_dim, selfSize, selfSize);
  // firstHalf/secondHalf are locals that die when this function returns, but
  // the iterator configured here is used afterwards by the impl function.
  // The iterator must therefore own these views: build_binary_op takes owning
  // references, whereas build_borrowing_binary_op would keep dangling ones.
  build_binary_op(maybe_get_output(), firstHalf, secondHalf);
}
} // namespace meta
namespace native {
// Dispatch stubs for the GLU kernels. Device-specific implementations are
// presumably registered elsewhere (REGISTER_DISPATCH in kernel files not
// visible here). They are invoked in this file as follows:
//   glu_stub          - forward kernel, from the glu_out structured impl
//   glu_backward_stub - fused kernel for the second half of the gradient,
//                       from glu_backward_cpu_out
//   glu_jvp_stub      - forward-mode derivative kernel, from glu_jvp
// Each NOLINTNEXTLINE must stay directly above the line it suppresses.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(glu_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(glu_backward_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(glu_jvp_stub);
// Structured impl for glu.out: runs the device's GLU kernel over the iterator
// that the glu meta function configured (output plus the two input halves).
// `self`, `dim` and `out` were already consumed by that meta step, so they are
// not read again here.
TORCH_IMPL_FUNC(glu_out) (const Tensor& self, int64_t dim, const Tensor& out) {
  const auto target_device = device_type();
  glu_stub(target_device, *this);
}
// Out-variant of the GLU backward pass (CPU entry point).
//
// `input` is split along `dim` into halves a (first) and b (second). The
// gradient buffer `grad_input` is resized to `input`'s shape and filled as:
//   first half  <- sigmoid(b) * grad_output
//   second half <- written by glu_backward_stub from (sigmoid(b), a, grad_output);
//                  mathematically a * sigmoid(b) * (1 - sigmoid(b)) * grad_output
//                  (kernel body not visible here - confirm against the kernel).
// Returns `grad_input`.
Tensor& glu_backward_cpu_out(const Tensor& grad_output, const Tensor& input,
                             int64_t dim, Tensor& grad_input) {
  TORCH_CHECK(input.dim() > 0, "glu does not support 0-dimensional tensors");
  const auto split_dim = maybe_wrap_dim(dim, input.dim());
  const int64_t full_len = input.size(split_dim);
  TORCH_CHECK(full_len % 2 == 0, "Halving dimension must be even, but dimension ",
              split_dim, " is size ", full_len);

  grad_input.resize_as_(input);
  const int64_t half_len = full_len / 2;

  // Views onto the two halves of the input and of the gradient buffer.
  const Tensor a = input.narrow(split_dim, 0, half_len);
  const Tensor b = input.narrow(split_dim, half_len, half_len);
  Tensor grad_a = grad_input.narrow(split_dim, 0, half_len);
  Tensor grad_b = grad_input.narrow(split_dim, half_len, half_len);

  // Stage sigmoid(b) inside grad_a: the fused kernel below reads it, so it
  // has to be written before that kernel runs.
  at::sigmoid_out(grad_a, b);

  // Fused computation of the second half of the gradient.
  auto backward_iter = at::TensorIteratorConfig()
      .add_output(grad_b)
      .add_input(grad_a)
      .add_input(a)
      .add_input(grad_output)
      .build();
  glu_backward_stub(backward_iter.device_type(), backward_iter);

  // Only now scale the staged sigmoid(b) by grad_output; doing it earlier
  // would corrupt the values the fused kernel consumes.
  grad_a.mul_(grad_output);
  return grad_input;
}
// Functional GLU backward on CPU: allocates an empty buffer with the input's
// options and lets glu_backward_cpu_out resize and fill it.
Tensor glu_backward_cpu(const Tensor& grad_output, const Tensor& input, int64_t dim) {
  Tensor grad_input = at::empty({0}, input.options());
  glu_backward_cpu_out(grad_output, input, dim, grad_input);
  return grad_input;
}
// Forward-mode derivative (JVP) of glu.
//
// `glu` is the primal output, `x` the primal input and `dx` the input tangent.
// With x split along `dim` into halves a and b, the fused glu_jvp_stub receives
// (glu, b, da, db) and writes the output tangent. Mathematically the result is
//   da * sigmoid(b) + glu * (1 - sigmoid(b)) * db
// (kernel body not visible here - confirm against the glu_jvp kernel).
// `dim` may be negative; it is wrapped against x's rank.
Tensor glu_jvp(
    const Tensor& glu,
    const Tensor& x,
    const Tensor& dx,
    int64_t dim
) {
  const int64_t axis = maybe_wrap_dim(dim, x.dim());
  // The GLU output spans one half of x along `axis`, so its extent is also
  // the offset of the second half.
  const auto half = glu.size(axis);

  const auto x_second = x.narrow(axis, half, half);
  const auto dx_first = dx.narrow(axis, 0, half);
  const auto dx_second = dx.narrow(axis, half, half);

  auto out_tangent = at::empty_like(glu);
  auto jvp_iter = at::TensorIteratorConfig()
      .add_output(out_tangent)
      .add_input(glu)
      .add_input(x_second)
      .add_input(dx_first)
      .add_input(dx_second)
      .build();
  glu_jvp_stub(jvp_iter.device_type(), jvp_iter);
  return out_tangent;
}
// Forward-mode derivative (JVP) of the GLU backward pass.
//
// Notation: x is split along `dim` into halves a (first) and b (second), and
// glu = a * sigmoid(b). The GLU backward maps grad_glu to
//   grad_x = cat(grad_glu * sigmoid(b),
//                grad_glu * sigmoid(b) * a * (1 - sigmoid(b)))  along dim.
// Given tangents dgrad_glu (of grad_glu) and dx (of x), this returns the
// tangent of grad_x, laid out the same way: two halves concatenated on dim.
//
//   grad_x    - the GLU backward result; its first half is reused as
//               grad_glu * sigmoid(b) rather than recomputed (assumes it was
//               produced from these same grad_glu and x).
//   grad_glu  - gradient w.r.t. the GLU output; only its size along `dim` is
//               read, to recover the half length.
//   x         - GLU input.
//   dgrad_glu - tangent of grad_glu.
//   dx        - tangent of x.
//   dim       - the halved dimension; negative values are wrapped.
Tensor glu_backward_jvp(
    const Tensor& grad_x,
    const Tensor& grad_glu,
    const Tensor& x,
    const Tensor& dgrad_glu,
    const Tensor& dx,
    int64_t dim
) {
  dim = maybe_wrap_dim(dim, x.dim());
  const auto glu_size = grad_glu.size(dim);
  const auto a = x.narrow(dim, 0, glu_size);
  const auto b = x.narrow(dim, glu_size, glu_size);
  const auto da = dx.narrow(dim, 0, glu_size);
  const auto db = dx.narrow(dim, glu_size, glu_size);
  // grad_x_a = grad_glu * sigmoid(b)
  // (The second half of grad_x, grad_x_a * a * (1 - sigmoid(b)), is not needed
  // by the formulas below, so it is not extracted.)
  const auto grad_x_a = grad_x.narrow(dim, 0, glu_size);
  const auto sig_b = at::sigmoid(b);
  // TODO: use glu from forward.
  // TODO: fuse kernels.
  const auto glu = a * sig_b;
  // db * (1 - sigmoid(b)), written as a subtraction.
  const auto db_neg_sig_b = db - db * sig_b;
  // dgrad_x_a = d(grad_glu * sigmoid(b))
  //           = dgrad_glu * sigmoid(b) + grad_glu * sigmoid(b) * (1 - sigmoid(b)) * db
  //           = dgrad_glu * sig_b + grad_x_a * (db - db * sig_b)
  //           = dgrad_glu * sig_b + grad_x_a * db_neg_sig_b
  const auto dgrad_x_a = dgrad_glu * sig_b + grad_x_a * db_neg_sig_b;
  // dgrad_x_b = d(grad_glu * sigmoid(b) * a * (1 - sigmoid(b)))
  //           = d(grad_glu * sigmoid(b)) * a * (1 - sigmoid(b))
  //             + grad_glu * sigmoid(b) * da * (1 - sigmoid(b))
  //             - grad_glu * sigmoid(b) * a * sigmoid(b) * (1 - sigmoid(b)) * db
  //           = dgrad_x_a * a * (1 - sigmoid(b))
  //             + (grad_glu * sigmoid(b)) * (da * (1 - sigmoid(b)) - a * sigmoid(b) * (1 - sigmoid(b)) * db)
  //           = dgrad_x_a * (a - glu) + grad_x_a * (da - da * sig_b - glu * db_neg_sig_b)
  // using a * sigmoid(b) == glu.
  const auto dgrad_x_b = dgrad_x_a * (a - glu) + grad_x_a * (da - da * sig_b - glu * db_neg_sig_b);
  return at::cat({dgrad_x_a, dgrad_x_b}, dim);
}
} // at::native
} // at