forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
VolumetricDilatedMaxPooling.cu
159 lines (140 loc) · 4.94 KB
/
VolumetricDilatedMaxPooling.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#include <THCUNN/THCUNN.h>
#include <THC/THCTensor.hpp>
#include <THCUNN/common.h>
#include <THC/THCDeviceTensor.cuh>
#include <THC/THCDeviceTensorUtils.cuh>
#include <THC/THCDeviceUtils.cuh>
#include <TH/THHalf.h>
#include <THCUNN/THCHalfAutoNumerics.cuh>
#include <THC/THCAtomics.cuh>
// Forward pass of 3-D dilated max pooling (runtime kernel width `kW`).
//
// Each thread produces one output element:
//   - threadIdx/blockIdx .x -> output column, .y -> output row;
//   - (blockIdx.z + offsetZ) is a flattened (slice, output frame) index that
//     is split with output.getSize(1).
//   NOTE(review): offsetZ presumably lets the host split a large z extent
//   over several launches -- confirm against the launcher.
//
// `inputData` is indexed as a dense [slice][inputT][inputH][inputW] array.
//
// For every output element the kernel scans the dilated window, clipped to
// the input bounds, and writes:
//   - output  : the maximum value found (NaN wins over any other value);
//   - indices : the flat offset of that value inside its slice
//               (t * inputH * inputW + h * inputW + w) plus TH_INDEX_BASE.
// If the clipped window is empty, output receives THCNumerics<Dtype>::min()
// and indices receives -1 + TH_INDEX_BASE.
template <typename Dtype>
__global__ void cuda_VolumetricDilatedMaxPooling_updateOutput(
  Dtype* inputData, int inputT, int inputH, int inputW,
  THCDeviceTensor<THCIndex_t, 4> indices,
  THCDeviceTensor<Dtype, 4> output,
  int kT, int kH, int kW,
  int dT, int dH, int dW,
  int padT, int padH, int padW,
  int dilationT, int dilationH, int dilationW,
  int offsetZ)
{
  int oColumn = blockIdx.x * blockDim.x + threadIdx.x;
  int oRow    = blockIdx.y * blockDim.y + threadIdx.y;
  int oFrame  = (blockIdx.z + offsetZ) % output.getSize(1); // output frame/time
  int slice   = (blockIdx.z + offsetZ) / output.getSize(1); // output slice/feature

  if (oRow < output.getSize(2) && oColumn < output.getSize(3))
  {
    // Top-left-front corner of the window in input coordinates; may be
    // negative when the window starts inside the padding.
    int tStart = oFrame  * dT - padT;
    int hStart = oRow    * dH - padH;
    int wStart = oColumn * dW - padW;

    // Exclusive end of the dilated window, clipped to the input extent.
    int tEnd = min(tStart + (kT - 1) * dilationT + 1, inputT);
    int hEnd = min(hStart + (kH - 1) * dilationH + 1, inputH);
    int wEnd = min(wStart + (kW - 1) * dilationW + 1, inputW);

    // Skip taps that fall into the padding while staying on the dilation
    // grid, so only real input elements are visited.
    while(tStart < 0)
      tStart += dilationT;
    while(hStart < 0)
      hStart += dilationH;
    while(wStart < 0)
      wStart += dilationW;

    int index = 0;
    int maxIndex = -1;

    // Move to the start of this slice. The product is evaluated in 64-bit:
    // with plain int arithmetic it overflows as soon as the input holds
    // 2^31 or more elements, producing an out-of-bounds pointer.
    // The per-slice `index` below is still an int -- this assumes a single
    // slice has fewer than 2^31 elements.
    inputData += static_cast<long long>(slice) * inputT * inputH * inputW;

    Dtype max = THCNumerics<Dtype>::min();

    for (int t = tStart; t < tEnd; t += dilationT)
    {
      for (int h = hStart; h < hEnd; h += dilationH)
      {
        for (int w = wStart; w < wEnd; w += dilationW)
        {
          index = t * inputH * inputW + h * inputW + w;
          Dtype val = inputData[index];

          // The explicit isnan test makes NaN propagate: `max < val` is
          // false whenever either operand is NaN.
          if ((max < val) || THCNumerics<Dtype>::isnan(val))
          {
            max = val;
            maxIndex = index;
          }
        }
      }
    }

    output[slice][oFrame][oRow][oColumn] = max;
    indices[slice][oFrame][oRow][oColumn] = maxIndex + TH_INDEX_BASE;
  }
}
// Forward pass of 3-D dilated max pooling, specialised on the kernel width.
//
// Same contract as the runtime-width overload, except that the pooling width
// is the compile-time constant KERNEL_WIDTH instead of a `kW` argument.
//
// Each thread produces one output element:
//   - threadIdx/blockIdx .x -> output column, .y -> output row;
//   - (blockIdx.z + offsetZ) is a flattened (slice, output frame) index that
//     is split with output.getSize(1).
//
// `inputData` is indexed as a dense [slice][inputT][inputH][inputW] array.
//
// Writes the window maximum to `output` (NaN wins over any other value) and
// the flat per-slice offset of that element plus TH_INDEX_BASE to `indices`.
// If the clipped window is empty, output receives THCNumerics<Dtype>::min()
// and indices receives -1 + TH_INDEX_BASE.
template <int KERNEL_WIDTH, typename Dtype>
__global__ void cuda_VolumetricDilatedMaxPooling_updateOutput(
  Dtype* inputData, int inputT, int inputH, int inputW,
  THCDeviceTensor<THCIndex_t, 4> indices,
  THCDeviceTensor<Dtype, 4> output,
  int kT, int kH,
  int dT, int dH, int dW,
  int padT, int padH, int padW,
  int dilationT, int dilationH, int dilationW,
  int offsetZ)
{
  int oColumn = blockIdx.x * blockDim.x + threadIdx.x;
  int oRow    = blockIdx.y * blockDim.y + threadIdx.y;
  int oFrame  = (blockIdx.z + offsetZ) % output.getSize(1); // output frame/time
  int slice   = (blockIdx.z + offsetZ) / output.getSize(1); // output slice/feature

  if (oRow < output.getSize(2) && oColumn < output.getSize(3))
  {
    // Window origin in input coordinates; negative inside the padding.
    int tStart = oFrame  * dT - padT;
    int hStart = oRow    * dH - padH;
    int wStart = oColumn * dW - padW;

    // Exclusive end of the dilated window, clipped to the input extent.
    int tEnd = min(tStart + (kT - 1) * dilationT + 1, inputT);
    int hEnd = min(hStart + (kH - 1) * dilationH + 1, inputH);
    int wEnd = min(wStart + (KERNEL_WIDTH - 1) * dilationW + 1, inputW);

    // Skip taps that fall into the padding while staying on the dilation
    // grid.
    while(tStart < 0)
      tStart += dilationT;
    while(hStart < 0)
      hStart += dilationH;
    while(wStart < 0)
      wStart += dilationW;

    int index = 0;
    int maxIndex = -1;

    // Base of this slice, computed in 64-bit: the int product
    // slice * inputT * inputH * inputW overflows once the input holds
    // 2^31 or more elements. The per-slice `index` remains an int, which
    // assumes a single slice has fewer than 2^31 elements.
    Dtype* sliceData =
      inputData + static_cast<long long>(slice) * inputT * inputH * inputW;

    Dtype max = THCNumerics<Dtype>::min();

    for (int t = tStart; t < tEnd; t += dilationT)
    {
      for (int h = hStart; h < hEnd; h += dilationH)
      {
        for (int w = wStart; w < wEnd; w += dilationW)
        {
          index = t * inputH * inputW + h * inputW + w;
          Dtype val = sliceData[index];

          // Propagate NaN exactly like the runtime-width overload does;
          // `max < val` alone is false for NaN, so without the isnan test
          // the result would depend on which overload the host selected.
          if ((max < val) || THCNumerics<Dtype>::isnan(val))
          {
            max = val;
            maxIndex = index;
          }
        }
      }
    }

    output[slice][oFrame][oRow][oColumn] = max;
    indices[slice][oFrame][oRow][oColumn] = maxIndex + TH_INDEX_BASE;
  }
}
// Backward pass of 3-D dilated max pooling.
//
// Each thread handles one gradOutput element:
//   - threadIdx/blockIdx .x -> output column, .y -> output row;
//   - (blockIdx.z + offsetZ) is a flattened (slice, output frame) index that
//     is split with gradOutput.getSize(1).
//
// The thread reads the flat per-slice offset stored in `indices` (after
// removing TH_INDEX_BASE) and adds its gradOutput value to that position of
// `gradInputData`, which is indexed as a dense
// [slice][inputT][inputH][inputW] array. Entries whose stored index decodes
// to -1 are skipped.
//
// The kernel only accumulates; NOTE(review): gradInputData is presumably
// zeroed by the caller before launch -- confirm.
//
// dT/dH/dW, padT/padH/padW and dilationT/dilationH/dilationW are not used by
// the body; they are kept so the signature stays unchanged for callers.
template <typename Dtype>
__global__ void cuda_VolumetricDilatedMaxPooling_updateGradInput(
  THCDeviceTensor<Dtype, 4> gradOutput,
  THCDeviceTensor<THCIndex_t, 4> indices,
  Dtype* gradInputData,
  int inputT, int inputH, int inputW,
  int dT, int dH, int dW,
  int padT, int padH, int padW,
  int dilationT, int dilationH, int dilationW,
  int offsetZ)
{
  int oColumn = blockIdx.x * blockDim.x + threadIdx.x;
  int oRow    = blockIdx.y * blockDim.y + threadIdx.y;
  int oFrame  = (blockIdx.z + offsetZ) % gradOutput.getSize(1); // output frame/time
  int slice   = (blockIdx.z + offsetZ) / gradOutput.getSize(1); // output slice/feature

  if (oRow < gradOutput.getSize(2) && oColumn < gradOutput.getSize(3))
  {
    int maxIndex = indices[slice][oFrame][oRow][oColumn] - TH_INDEX_BASE;
    if (maxIndex != -1) {
      // Base of this slice, computed in 64-bit: the int product
      // slice * inputT * inputH * inputW overflows once gradInput holds
      // 2^31 or more elements, which would scatter the gradient to a wrong
      // or out-of-bounds address.
      Dtype* gradSlice =
        gradInputData + static_cast<long long>(slice) * inputT * inputH * inputW;

      // Atomic because several output positions can record the same input
      // element as their maximum and would otherwise race on this address.
      atomicAdd(&gradSlice[maxIndex],
                gradOutput[slice][oFrame][oRow][oColumn]);
    }
  }
}
#include <THCUNN/generic/VolumetricDilatedMaxPooling.cu>
#include <THC/THCGenerateFloatTypes.h>