forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
CUDALoops.cuh
335 lines (305 loc) · 10.6 KB
/
CUDALoops.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
#pragma once
// This file provides two functions to help write GPU elementwise kernels:
//
// gpu_kernel(TensorIterator iter, <lambda>)
// gpu_kernel_with_scalars(TensorIterator iter, <lambda>)
//
// The gpu_kernel_with_scalars generates specializations that support a
// single scalar CPU argument, such as from `cuda_tensor + 5`. The CPU scalar
// is lifted to a kernel parameter instead of copying to device memory.
// This should be used in conjunction with TensorIterator::allow_cpu_scalars_,
// which is the default for TensorIterator::binary_op. Otherwise, all inputs
// and the output must be on the GPU.
//
// For example, to write a reciprocal kernel for GPU float Tensors:
//
//   gpu_kernel(iter, []GPU_LAMBDA(float a) {
//     return 1.0f / a;
//   });
//
// To write a multiplication kernel for GPU float Tensors where one argument
// may be a CPU scalar:
//
//   gpu_kernel_with_scalars(iter, []GPU_LAMBDA(float a, float b) {
//     return a * b;
//   });
//
// See BinaryOpsKernel.cu for the complete implementation
//
#include <iostream>
#include <tuple>
#include <type_traits>
#include <ATen/core/Array.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/TensorIterator.h>
#include <c10/core/DynamicCast.h>
#include <c10/core/ScalarType.h>
#include <c10/macros/Macros.h>
#include <c10/util/C++17.h>
#include <c10/util/TypeCast.h>
// ASSERT_HOST_DEVICE_LAMBDA(type): compile-time check that `type` is the
// closure type of an extended __host__ __device__ lambda.
//
// The check is built on the NVCC type trait
// __nv_is_extended_host_device_lambda_closure_type, so it is only emitted when
// __NVCC__ is defined. With any other compiler the macro expands to nothing
// and no check is performed. On failure the static_assert message is the
// stringized `type` followed by " must be a __host__ __device__ lambda".
#ifdef __NVCC__
#define ASSERT_HOST_DEVICE_LAMBDA(type) \
static_assert( \
__nv_is_extended_host_device_lambda_closure_type(type), \
#type " must be a __host__ __device__ lambda")
#else
#define ASSERT_HOST_DEVICE_LAMBDA(type)
#endif
namespace at {
namespace native {
// Elementwise kernel used by launch_vectorized_kernel (contiguous operands,
// no dtype conversion on load/store).
//
// Block `blockIdx.x` works on the part of the N elements that starts at
// `block_work_size() * blockIdx.x`:
//   * if at least `block_work_size()` elements are left from that start, the
//     block runs with the `vectorized<vec_size, array_t>` memory policy
//     (vectorized memory access);
//   * otherwise the block holds the remainder and falls back to the `unroll`
//     policy, which is told how many elements are left and is set up with
//     trivial offset calculators and cast-free load/store.
//
// N    - total number of elements.
// f    - elementwise functor; its arity fixes the number of inputs.
// data - operand pointer array, forwarded unchanged to the memory policy.
template <int vec_size, typename func_t, typename array_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
  const int elems_left = N - block_work_size() * blockIdx.x;

  if (elems_left >= block_work_size()) {
    // Full tile: every element of this block's slice exists.
    elementwise_kernel_helper(
        f, memory::policies::vectorized<vec_size, array_t>(data));
    return;
  }

  // Partial tile: only `elems_left` elements remain for this block.
  using traits = function_traits<func_t>;
  using in_calc_t = TrivialOffsetCalculator<traits::arity>;
  using out_calc_t = TrivialOffsetCalculator<1>;
  using tail_policy_t = memory::policies::unroll<
      array_t,
      in_calc_t,
      out_calc_t,
      memory::LoadWithoutCast,
      memory::StoreWithoutCast>;

  auto in_calc = in_calc_t();
  auto out_calc = out_calc_t();
  auto load_op = memory::LoadWithoutCast();
  auto store_op = memory::StoreWithoutCast();
  auto tail_policy =
      tail_policy_t(data, elems_left, in_calc, out_calc, load_op, store_op);
  elementwise_kernel_helper(f, tail_policy);
}
// Elementwise kernel that always uses the `unroll` memory policy.
//
// Block `blockIdx.x` starts at element `block_work_size() * blockIdx.x`; the
// number of elements left from that start is handed to the policy together
// with the caller-supplied offset calculators and load/store functors.
//
// N    - total number of elements.
// f    - elementwise functor.
// data - operand pointer array.
// ic   - input offset calculator.
// oc   - output offset calculator.
// l    - loader functor.
// s    - storer functor.
template <
    typename func_t,
    typename array_t,
    typename inp_calc_t,
    typename out_calc_t,
    typename loader_t,
    typename storer_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void unrolled_elementwise_kernel(
    int N,
    func_t f,
    array_t data,
    inp_calc_t ic,
    out_calc_t oc,
    loader_t l,
    storer_t s) {
  using policy_t = memory::policies::
      unroll<array_t, inp_calc_t, out_calc_t, loader_t, storer_t>;

  const int elems_left = N - block_work_size() * blockIdx.x;
  auto block_policy = policy_t(data, elems_left, ic, oc, l, s);
  elementwise_kernel_helper(f, block_policy);
}
// Host-side launcher for the trivially indexed (1-D), no-dynamic-casting case.
//
// Requires 0 < N <= INT32_MAX (internal assert). The grid holds
// ceil(N / block_work_size()) blocks of num_threads() threads and is launched
// on the current CUDA stream. The vector width is whatever
// memory::can_vectorize_up_to<func_t>(data) reports:
//   4 or 2 -> vectorized_elementwise_kernel with that width;
//   1      -> unrolled_elementwise_kernel with trivial offset calculators and
//             cast-free load/store;
//   other  -> internal assert failure.
//
// N    - number of elements.
// f    - elementwise functor.
// data - operand pointer array passed to the kernel.
template <typename func_t, typename array_t>
static inline void launch_vectorized_kernel(
    int64_t N,
    const func_t& f,
    array_t data) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());

  const int64_t num_blocks = (N + block_work_size() - 1) / block_work_size();
  auto cuda_stream = at::cuda::getCurrentCUDAStream();
  const int width = memory::can_vectorize_up_to<func_t>(data);

  if (width == 4) {
    vectorized_elementwise_kernel<4, func_t, array_t>
        <<<num_blocks, num_threads(), 0, cuda_stream>>>(N, f, data);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else if (width == 2) {
    vectorized_elementwise_kernel<2, func_t, array_t>
        <<<num_blocks, num_threads(), 0, cuda_stream>>>(N, f, data);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else if (width == 1) {
    using traits = function_traits<func_t>;
    auto in_calc = TrivialOffsetCalculator<traits::arity>();
    auto out_calc = TrivialOffsetCalculator<1>();
    auto load_op = memory::LoadWithoutCast();
    auto store_op = memory::StoreWithoutCast();
    unrolled_elementwise_kernel<func_t, array_t>
        <<<num_blocks, num_threads(), 0, cuda_stream>>>(
            N, f, data, in_calc, out_calc, load_op, store_op);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size");
  }
}
// Host-side launcher for unrolled_elementwise_kernel.
//
// Requires 0 < N <= INT32_MAX (internal assert). Launches
// ceil(N / block_work_size()) blocks of num_threads() threads on the current
// CUDA stream and checks the launch afterwards.
//
// N      - number of elements.
// f      - elementwise functor.
// data   - operand pointer array.
// ic, oc - input / output offset calculators forwarded to the kernel.
// l, s   - loader / storer functors forwarded to the kernel.
template <
    typename func_t,
    typename array_t,
    typename inp_calc_t,
    typename out_calc_t,
    typename loader_t,
    typename storer_t>
static inline void launch_unrolled_kernel(
    int64_t N,
    const func_t& f,
    array_t data,
    inp_calc_t ic,
    out_calc_t oc,
    loader_t l,
    storer_t s) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());

  const int64_t num_blocks = (N + block_work_size() - 1) / block_work_size();
  auto cuda_stream = at::cuda::getCurrentCUDAStream();

  unrolled_elementwise_kernel<func_t, array_t>
      <<<num_blocks, num_threads(), 0, cuda_stream>>>(
          N, f, data, ic, oc, l, s);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Index-driven elementwise kernel: thread `threadIdx.x` of block `blockIdx.x`
// calls f(idx) for
//     idx = nt * vt * blockIdx.x + threadIdx.x + k * nt,   k = 0 .. vt - 1,
// for every such idx that is smaller than N.
//
// Template parameters:
//   nt - distance between consecutive indices handled by one thread; it is
//        also the first launch-bounds argument, so launch_legacy_kernel uses
//        nt threads per block.
//   vt - maximum number of indices handled by one thread.
// Parameters:
//   N - number of indices to cover; the kernel does nothing when N <= 0.
//   f - callable invoked as f(int idx).
//
// The running index is kept unsigned on purpose. `blockIdx.x` is unsigned, so
// the start index is computed in unsigned arithmetic anyway; storing it (and
// the later `+= nt` steps) in a signed int can exceed INT_MAX when N is close
// to INT_MAX. A wrapped, negative index would pass `idx < N` and call f out of
// range. Comparing unsigned values avoids both the signed overflow and the
// bogus call; the cast back to int is safe because idx < N <= INT_MAX there.
template <int nt, int vt, typename func_t>
C10_LAUNCH_BOUNDS_2(nt, 4)
__global__ void elementwise_kernel(int N, func_t f) {
  if (N <= 0) {
    // Nothing to do; also keeps the unsigned comparison below well-defined.
    return;
  }
  const unsigned int limit = static_cast<unsigned int>(N);
  constexpr int nv = nt * vt;
  unsigned int idx = nv * blockIdx.x + threadIdx.x;
#pragma unroll
  for (int i = 0; i < vt; i++) {
    // Indices only grow, so once idx reaches the limit all later iterations
    // are no-ops as well.
    if (idx < limit) {
      f(static_cast<int>(idx));
      idx += nt;
    }
  }
}
// Host-side launcher for elementwise_kernel<nt, vt>.
//
// Requires 0 <= N <= INT32_MAX (internal assert); N == 0 is a no-op.
// Otherwise launches ceil(N / (nt * vt)) blocks of nt threads on the current
// CUDA stream and checks the launch afterwards.
//
// N - number of indices to cover.
// f - callable invoked on the device as f(int idx).
template <int nt, int vt, typename func_t>
static void launch_legacy_kernel(int64_t N, const func_t& f) {
  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
  if (N != 0) {
    const dim3 threads(nt);
    const dim3 blocks((N + threads.x * vt - 1) / (threads.x * vt));
    auto cuda_stream = at::cuda::getCurrentCUDAStream();
    elementwise_kernel<nt, vt, func_t>
        <<<blocks, threads, 0, cuda_stream>>>(N, f);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}
// Calls `f` with one argument per index in the pack. Argument K is loaded with
// c10::load, typed as the K-th parameter type of `f`, from the address
// `data[K] + i * strides[K]`.
//
// f       - functor to call.
// data    - one base pointer per argument.
// strides - one offset per argument, multiplied by `i`.
// i       - multiplier applied to every stride.
// Returns whatever `f` returns.
template <typename traits, typename func_t, typename index_t, size_t... ARG>
C10_HOST_DEVICE typename traits::result_type invoke_impl(
    const func_t& f,
    char* const C10_RESTRICT data[],
    const index_t strides[],
    int i,
    std::index_sequence<ARG...>) {
  // Mark both as used; presumably this silences unused-parameter warnings
  // when the index pack is empty and the expansion below mentions neither.
  (void)i;
  (void)strides;
  return f(c10::load<typename traits::template arg<ARG>::type>(
      data[ARG] + i * strides[ARG])...);
}
// Convenience entry point for the cast-free invoke_impl: builds the index
// sequence 0 .. arity-1 from the functor's traits and forwards everything.
//
// f       - functor to call.
// data    - one base pointer per argument of `f`.
// strides - one offset per argument of `f`.
// i       - multiplier applied to every stride.
// Returns the result of calling `f` on the loaded arguments.
template <
    typename func_t,
    typename index_t,
    typename traits = function_traits<func_t>>
C10_HOST_DEVICE typename traits::result_type invoke(
    const func_t& f,
    char* const C10_RESTRICT data[],
    const index_t strides[],
    int i) {
  return invoke_impl<traits>(
      f, data, strides, i, std::make_index_sequence<traits::arity>{});
}
// Dtype-aware variant of invoke_impl. Argument K is read with
// c10::fetch_and_cast from `data[K] + i * strides[K]`, using `dtypes[K]` as
// the source scalar type and the K-th parameter type of `f` as the target.
//
// f       - functor to call.
// data    - one base pointer per argument.
// strides - one offset per argument, multiplied by `i`.
// dtypes  - one ScalarType per argument.
// i       - multiplier applied to every stride.
// Returns whatever `f` returns.
template <typename traits, typename func_t, typename index_t, size_t... ARG>
C10_HOST_DEVICE typename traits::result_type invoke_impl(
    const func_t& f,
    char* const C10_RESTRICT data[],
    const index_t strides[],
    const ScalarType dtypes[],
    int i,
    std::index_sequence<ARG...>) {
  // Mark both as used; presumably this silences unused-parameter warnings
  // when the index pack is empty and the expansion below mentions neither.
  (void)i;
  (void)strides;
  return f(c10::fetch_and_cast<typename traits::template arg<ARG>::type>(
      dtypes[ARG], data[ARG] + i * strides[ARG])...);
}
// Convenience entry point for the dtype-aware invoke_impl: builds the index
// sequence 0 .. arity-1 from the functor's traits and forwards everything.
//
// f       - functor to call.
// data    - one base pointer per argument of `f`.
// strides - one offset per argument of `f`.
// dtypes  - one ScalarType per argument of `f`.
// i       - multiplier applied to every stride.
// Returns the result of calling `f` on the fetched-and-cast arguments.
template <
    typename func_t,
    typename index_t,
    typename traits = function_traits<func_t>>
C10_HOST_DEVICE typename traits::result_type invoke(
    const func_t& f,
    char* const C10_RESTRICT data[],
    const index_t strides[],
    const ScalarType dtypes[],
    int i) {
  return invoke_impl<traits>(
      f, data, strides, dtypes, i, std::make_index_sequence<traits::arity>{});
}
// Runs `f` elementwise over `iter` when no dynamic casting is needed.
//
// Preconditions (internal asserts): the iterator supports 32-bit indexing,
// has exactly one output and as many inputs as `f` has parameters, and
// needs_dynamic_casting<func_t>::check(iter) is false.
//
// Operand 0 is the output; operands 1..arity are the inputs.
//   * Contiguous iterator: delegate to launch_vectorized_kernel.
//   * Otherwise: launch_legacy_kernel with 128 threads per block and an
//     unroll factor of 2 when sizeof(result) >= 4, else 4. Each index is
//     turned into per-operand offsets by the offset calculator, the inputs
//     are read through invoke(), and the result is written through a pointer
//     of the functor's result type.
template <typename func_t>
void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
  using traits = function_traits<func_t>;
  using arg0_t = typename traits::result_type;
  constexpr int num_operands = traits::arity + 1;

  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
  TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));

  // Collect the base pointer of every operand.
  at::detail::Array<char*, num_operands> data;
  for (int operand = 0; operand < num_operands; ++operand) {
    data[operand] = (char*)iter.data_ptr(operand);
  }

  const int64_t numel = iter.numel();

  if (iter.is_contiguous()) {
    launch_vectorized_kernel(numel, f, data);
    return;
  }

  // Strided path: one offset-calculator lookup per element.
  auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
  constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 2 : 4;
  launch_legacy_kernel<128, unroll_factor>(numel, [=] GPU_LAMBDA(int idx) {
    auto offsets = offset_calc.get(idx);
    arg0_t* dst = (arg0_t*)(data[0] + offsets[0]);
    // The offsets double as "strides" with multiplier 1, so invoke() reads
    // input k from data[k + 1] + offsets[k + 1].
    *dst = invoke(f, &data.data[1], &offsets.data[1], 1);
  });
}
// Runs `f` elementwise over `iter`, converting dtypes on load/store when
// required.
//
// If needs_dynamic_casting<func_t>::check(iter) is false, everything is
// delegated to gpu_kernel_impl_nocast. Otherwise (internal asserts: 32-bit
// indexing, one output, as many inputs as `f` has parameters):
//   * Contiguous iterator: launch_unrolled_kernel with trivial offset
//     calculators and LoadWithCast / StoreWithCast built from the iterator.
//   * Otherwise: launch_legacy_kernel<128, 4>. Each index is turned into
//     per-operand offsets, the inputs are fetched through the dtype-aware
//     invoke(), and the result is written with c10::cast_and_store using the
//     output's dtype.
//
// Operand 0 is the output; operands 1..arity are the inputs.
template <typename func_t>
void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
  if (!needs_dynamic_casting<func_t>::check(iter)) {
    return gpu_kernel_impl_nocast(iter, f);
  }

  using traits = function_traits<func_t>;
  using arg0_t = typename traits::result_type;
  constexpr int num_operands = traits::arity + 1;

  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);

  // Collect the base pointer of every operand.
  at::detail::Array<char*, num_operands> data;
  for (int operand = 0; operand < num_operands; ++operand) {
    data[operand] = (char*)iter.data_ptr(operand);
  }

  const int64_t numel = iter.numel();

  if (iter.is_contiguous()) {
    auto cast_loader = memory::LoadWithCast<traits::arity>(iter);
    auto cast_storer = memory::StoreWithCast<1>(iter);
    auto in_calc = TrivialOffsetCalculator<traits::arity>();
    auto out_calc = TrivialOffsetCalculator<1>();
    launch_unrolled_kernel(
        numel, f, data, in_calc, out_calc, cast_loader, cast_storer);
    return;
  }

  // Strided path: record every operand's dtype so the device lambda can
  // convert on the fly.
  at::detail::Array<ScalarType, num_operands> dtypes;
  for (int operand = 0; operand < num_operands; ++operand) {
    dtypes[operand] = iter.dtype(operand);
  }
  auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
  launch_legacy_kernel<128, 4>(numel, [=] GPU_LAMBDA(int idx) {
    auto offsets = offset_calc.get(idx);
    void* dst = data[0] + offsets[0];
    // The offsets double as "strides" with multiplier 1, so invoke() reads
    // input k from data[k + 1] + offsets[k + 1] using dtypes[k + 1].
    arg0_t value =
        invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
    c10::cast_and_store<arg0_t>(dtypes[0], dst, value);
  });
}
} // namespace native
} // namespace at