forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
AdaptiveAveragePooling.cu
432 lines (357 loc) · 13.1 KB
/
AdaptiveAveragePooling.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAApplyUtils.cuh"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/NativeFunctions.h"
#include "ATen/TensorUtils.h"
#include "ATen/Utils.h"
#include "c10/util/Exception.h"
#include <THC/THCGeneral.h>
#include "THC/THCNumerics.cuh"
#include <algorithm>
#include <cfloat>
#include <cmath>
// Pooling-window bounds. For index a out of b positions mapped onto an axis of
// length c, the window is [START_IND(a,b,c), END_IND(a,b,c)), i.e.
// [floor(a*c/b), ceil((a+1)*c/b)). Computed with exact 64-bit integer
// arithmetic (assumes a, c >= 0 and b > 0); going through float loses
// precision once a*c exceeds 2^24 and yields off-by-one windows.
#define START_IND(a,b,c) ((int)(((int64_t)(a) * (int64_t)(c)) / (int64_t)(b)))
#define END_IND(a,b,c) ((int)((((int64_t)(a) + 1) * (int64_t)(c) + (int64_t)(b) - 1) / (int64_t)(b)))
#define CUDA_MAX_THREADS 1024 // this is safe, in reality 256 is our limit
namespace at {
namespace native {
namespace {
// 4d tensor B x D x H x W
// All kernels view batch dim B and feature dim D as collapsed.
/*
 * Description:
 *    Adaptively average-pools a stack of 2D planes (dims 2 and 3 of a
 *    B x D x H x W tensor with B and D collapsed into a single plane index).
 *
 * Launch layout expected:
 *    blockIdx.x                 -> plane index
 *    blockIdx.y / threadIdx.y   -> output rows, stepping by blockDim.y * gridDim.y
 *    threadIdx.x                -> output columns, stepping by blockDim.x
 *
 * input  : plane p starts at p * istrideD; rows/cols use istrideH / istrideW
 * output : dense osizeH x osizeW per plane
 */
template <typename T>
__global__ void adaptiveaveragepool(const T *input, T *output,
                        int isizeH, int isizeW,
                        int osizeH, int osizeW,
                        int64_t istrideD, int64_t istrideH, int64_t istrideW)
{
  // Plane offsets in 64-bit: plane * osizeH * osizeW overflows int for
  // outputs with more than 2^31 elements.
  const int64_t plane = blockIdx.x;
  output += plane * osizeH * osizeW;
  input += plane * istrideD;

  const int ostartH = blockDim.y*blockIdx.y + threadIdx.y;
  const int ostepH = blockDim.y*gridDim.y;
  const int ostartW = threadIdx.x;
  const int ostepW = blockDim.x;

  // For all output pixels handled by this thread...
  for (int oh = ostartH; oh < osizeH; oh += ostepH) {
    const int istartH = START_IND(oh, osizeH, isizeH);
    const int iendH = END_IND(oh, osizeH, isizeH);
    const int kH = iendH - istartH;

    for (int ow = ostartW; ow < osizeW; ow += ostepW) {
      const int istartW = START_IND(ow, osizeW, isizeW);
      const int iendW = END_IND(ow, osizeW, isizeW);
      const int kW = iendW - istartW;

      // Average over the kH x kW input window.
      const T *ptr_input = input + istartH*istrideH + istartW*istrideW;
      T sum = ScalarConvert<int, T>::to(0);
      for (int ih = 0; ih < kH; ++ih) {
        for (int iw = 0; iw < kW; ++iw) {
          T val = ptr_input[iw*istrideW];
          sum += val;
        }
        ptr_input += istrideH; // next input line
      }
      output[(int64_t)oh*osizeW + ow] = sum / kH / kW;
    }
  }
}
/*
 * Description:
 *    Computes gradInput from gradOutput without atomics. Each thread owns
 *    input pixels and sums the share of every output pixel whose window is
 *    mapped onto it. Results are accumulated with +=, so gradInput must be
 *    zero-filled beforehand.
 *    Both buffers are indexed densely: isizeH x isizeW / osizeH x osizeW per plane.
 *
 * Launch layout: blockIdx.x -> plane; blockIdx.y / threadIdx.y -> input rows
 * (step blockDim.y * gridDim.y); threadIdx.x -> input columns (step blockDim.x).
 */
template <typename T>
__global__ void adaptiveaveragegradinput(
  T *gradInput, const T *gradOutput,
  int isizeH, int isizeW, int osizeH, int osizeW
)
{
  // Plane offsets in 64-bit to avoid int overflow on large tensors.
  const int64_t plane = blockIdx.x;
  gradOutput += plane * osizeH * osizeW;
  gradInput += plane * isizeH * isizeW;

  const int istartH = blockDim.y*blockIdx.y + threadIdx.y;
  const int istepH = blockDim.y*gridDim.y;
  const int istartW = threadIdx.x;
  const int istepW = blockDim.x;

  for (int ih = istartH; ih < isizeH; ih += istepH) {
    // Output rows whose pooling windows are taken to cover input row ih.
    const int ostartH = START_IND(ih, isizeH, osizeH);
    const int oendH = END_IND(ih, isizeH, osizeH);
    for (int iw = istartW; iw < isizeW; iw += istepW) {
      const int ostartW = START_IND(iw, isizeW, osizeW);
      const int oendW = END_IND(iw, isizeW, osizeW);

      T *ptr_gradInput = gradInput + (int64_t)ih*isizeW + iw;
      for (int oh = ostartH; oh < oendH; ++oh) {
        // Window extent of output row oh: end - start (always positive).
        const int kH = END_IND(oh, osizeH, isizeH) - START_IND(oh, osizeH, isizeH);
        for (int ow = ostartW; ow < oendW; ++ow) {
          const int kW = END_IND(ow, osizeW, isizeW) - START_IND(ow, osizeW, isizeW);
          T grad_delta = gradOutput[(int64_t)oh*osizeW + ow] / kH / kW;
          *ptr_gradInput += grad_delta;
        }
      }
    }
  }
}
/*
 * Description:
 *    Computes gradInput from gradOutput by scattering: each thread takes
 *    output pixels and atomically adds grad / (kH * kW) to every input pixel
 *    of that pixel's pooling window. Windows of neighbouring outputs can
 *    overlap, hence atomicAdd. gradInput is accumulated into, so it must be
 *    zero-filled beforehand; both buffers are indexed densely per plane.
 *
 * Launch layout: blockIdx.x -> plane; blockIdx.y / threadIdx.y -> output rows
 * (step blockDim.y * gridDim.y); threadIdx.x -> output columns (step blockDim.x).
 */
template <typename T>
__global__ void atomicadaptiveaveragegradinput(
  T *gradInput, const T *gradOutput,
  int isizeH, int isizeW, int osizeH, int osizeW
)
{
  // Plane offsets in 64-bit to avoid int overflow on large tensors.
  const int64_t plane = blockIdx.x;
  gradOutput += plane * osizeH * osizeW;
  gradInput += plane * isizeH * isizeW;

  const int ostartH = blockDim.y*blockIdx.y + threadIdx.y;
  const int ostepH = blockDim.y*gridDim.y;
  const int ostartW = threadIdx.x;
  const int ostepW = blockDim.x;

  for (int oh = ostartH; oh < osizeH; oh += ostepH) {
    const int istartH = START_IND(oh, osizeH, isizeH);
    const int kH = END_IND(oh, osizeH, isizeH) - istartH;

    for (int ow = ostartW; ow < osizeW; ow += ostepW) {
      const int istartW = START_IND(ow, osizeW, isizeW);
      const int kW = END_IND(ow, osizeW, isizeW) - istartW;

      // Spread this output gradient evenly over its input window.
      T *ptr_gradInput = gradInput + (int64_t)istartH*isizeW + istartW;
      T grad_delta = gradOutput[(int64_t)oh*osizeW + ow] / kW / kH;
      for (int ih = 0; ih < kH; ++ih) {
        for (int iw = 0; iw < kW; ++iw) {
          // atomic add since different threads could update same variable
          atomicAdd(&(ptr_gradInput[iw]), grad_delta);
        }
        ptr_gradInput += isizeW; // next input line
      }
    }
  }
}
// Forward adaptive average pooling.
// input: 3D (D x H x W) or 4D (B x D x H x W), all dimensions non-empty.
// output_size: {osizeH, osizeW}. output is resized to D x osizeH x osizeW
// (3D) or B x D x osizeH x osizeW (4D) and filled on the current CUDA stream.
void adaptive_avg_pool2d_out_cuda_template(
  Tensor& output,
  const Tensor& input,
  IntList output_size)
{
  TensorArg input_arg{ input, "input", 1 },
            output_arg{ output, "output", 2 };
  checkAllSameGPU("cudnn_adaptive_avg_pooling2d", {input_arg, output_arg});

  for (int64_t i = 0; i < input.ndimension(); i++) {
    AT_CHECK(input.size(i) > 0,
      "adaptive_avg_pooling2d(): expected input to have non-empty spatial dimensions, "
      "but input has sizes ", input.sizes(), " with dimension ", i, " being "
      "empty");
  }

  AT_CHECK((input.ndimension() == 3 || input.ndimension() == 4),
    "non-empty 3D or 4D (batch mode) tensor expected for input");
  // output_size[0] / output_size[1] below are unchecked accesses.
  AT_CHECK(output_size.size() == 2,
    "adaptive_avg_pool2d(): output_size must have exactly 2 elements, but got ",
    output_size.size());

  const bool batch_mode = input.ndimension() == 4;
  // In batch mode the kernel visits B*D planes spaced istrideD apart, which is
  // only valid for a contiguous tensor; a 3D input is read through its own strides.
  const Tensor input_ = batch_mode ? input.contiguous() : input;
  const int64_t sizeB = batch_mode ? input_.size(0) : 1;
  const int64_t dimD = batch_mode ? 1 : 0;
  const int64_t sizeD = input_.size(dimD);
  const int64_t isizeH = input_.size(dimD + 1);
  const int64_t isizeW = input_.size(dimD + 2);
  const int64_t istrideD = input_.stride(dimD);
  const int64_t istrideH = input_.stride(dimD + 1);
  const int64_t istrideW = input_.stride(dimD + 2);
  const int64_t osizeH = output_size[0];
  const int64_t osizeW = output_size[1];

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    input.type(), "adaptive_avg_pool2d", [&] {
      scalar_t *input_data = input_.data<scalar_t>();
      if (batch_mode) {
        output.resize_({sizeB, sizeD, osizeH, osizeW});
      } else {
        output.resize_({sizeD, osizeH, osizeW});
      }
      scalar_t *output_data = output.data<scalar_t>();

      // cuda blocks & threads: one block column per plane, a few blocks
      // along y when there are few planes.
      int blocksH = std::max<int64_t>((int)(16L / sizeD), 1);
      dim3 blocks(sizeB * sizeD, blocksH);
      dim3 threads(32, 8);

      // run averagepool kernel
      adaptiveaveragepool <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
        input_data, output_data,
        isizeH, isizeW, osizeH, osizeW,
        istrideD, istrideH, istrideW);
    }
  );
  THCudaCheck(cudaGetLastError());
}
// Backward of adaptive average pooling: accumulates d(loss)/d(input) into gradInput.
// Preconditions not established here: gradInput has input's shape, is
// contiguous (the kernels index it densely) and is zero-filled (the kernels
// accumulate into it). gradOutput_ is made contiguous locally.
void adaptive_avg_pool2d_backward_out_cuda_template(
  Tensor& gradInput,
  const Tensor& gradOutput_,
  const Tensor& input)
{
  TensorArg grad_input_arg{ gradInput, "gradInput", 1 },
            grad_output_arg{ gradOutput_, "gradOutput_", 2 },
            input_arg{ input, "input", 3 };
  checkAllSameGPU("cudnn_adaptive_avg_pooling2d_out",
                  {grad_input_arg, grad_output_arg, input_arg});

  AT_CHECK((input.ndimension() == 3 || input.ndimension() == 4),
    "non-empty 3D or 4D (batch mode) tensor expected for input");

  // Nothing to compute for an empty input; this also avoids the division by
  // a zero sizeD and a zero-sized grid launch below.
  if (input.numel() == 0) {
    return;
  }

  const bool atomic = true; // suboptimal, but without atomic it doesn't pass the tests

  Tensor gradOutput = gradOutput_.contiguous();

  // A 3D input is handled as a batch of one.
  const bool batch_mode = input.ndimension() == 4;
  const int64_t sizeB = batch_mode ? input.size(0) : 1;
  const int64_t dimD = batch_mode ? 1 : 0;
  const int64_t sizeD = input.size(dimD);
  const int64_t isizeH = input.size(dimD + 1);
  const int64_t isizeW = input.size(dimD + 2);
  const int64_t osizeH = gradOutput.size(dimD + 1);
  const int64_t osizeW = gradOutput.size(dimD + 2);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    input.type(), "adaptive_avg_pool2d_backward", [&] {
      scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
      scalar_t *gradInput_data = gradInput.data<scalar_t>();

      // cuda blocks & threads:
      int blocksH = std::max((int)(16L / sizeD), 1);
      dim3 blocks(sizeB * sizeD, blocksH);
      dim3 threads(32, 8);

      if (atomic) {
        // run updateGradInput kernel, accumulate gradients atomically
        atomicadaptiveaveragegradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
          gradInput_data, gradOutput_data,
          isizeH, isizeW, osizeH, osizeW);
      } else {
        // run updateGradInput kernel (one thread per input pixel, no atomics)
        adaptiveaveragegradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
          gradInput_data, gradOutput_data,
          isizeH, isizeW, osizeH, osizeW);
      }
    }
  );
  THCudaCheck(cudaGetLastError());
}
} // namespace
Tensor& adaptive_avg_pool2d_out_cuda(
  Tensor& output,
  const Tensor& input,
  IntList output_size)
{
  // out= variant: the template resizes `output` and writes the pooled values.
  adaptive_avg_pool2d_out_cuda_template(output, input, output_size);
  return output;
}
Tensor adaptive_avg_pool2d_cuda(
  at::Tensor const& input,
  IntList output_size)
{
  // Start from an empty placeholder with input's options; the template
  // resizes it to the pooled shape before writing.
  Tensor result = at::empty({0}, input.options());
  adaptive_avg_pool2d_out_cuda_template(result, input, output_size);
  return result;
}
// out= variant of the backward pass. The backward kernels accumulate into a
// densely indexed gradInput, so the buffer handed to them must be contiguous
// and zero-filled; resize_as_ does not initialise the contents.
Tensor& adaptive_avg_pool2d_backward_out_cuda(
  Tensor& gradInput,
  const Tensor& gradOutput,
  const Tensor& input)
{
  gradInput.resize_as_(input);
  if (gradInput.is_contiguous()) {
    gradInput.zero_();
    adaptive_avg_pool2d_backward_out_cuda_template(
      gradInput, gradOutput, input);
  } else {
    // A pre-shaped non-contiguous out tensor keeps its strides through
    // resize_as_; compute into a dense buffer and copy the result back.
    Tensor gradInput_dense = at::zeros(input.sizes(), gradInput.options());
    adaptive_avg_pool2d_backward_out_cuda_template(
      gradInput_dense, gradOutput, input);
    gradInput.copy_(gradInput_dense);
  }
  return gradInput;
}
Tensor adaptive_avg_pool2d_backward_cuda(
  const Tensor& gradOutput,
  const Tensor& input)
{
  // Zero-filled buffer shaped like input: the backward kernels accumulate into it.
  Tensor grad_input = at::zeros_like(input);
  adaptive_avg_pool2d_backward_out_cuda_template(grad_input, gradOutput, input);
  return grad_input;
}
} // at::native
} // at
// Drop the file-local helper macros so they do not leak past this point.
#undef CUDA_MAX_THREADS
#undef START_IND
#undef END_IND