Add Slice Tensor for int64 index in ElementwiseKernel #57313
Changes from 15 commits
@@ -518,7 +518,7 @@ HOSTDEVICE static int64_t ConvertSrcIdxToDstIdx(
     const phi::Array<int64_t, phi::DDim::kMaxRank + 1> &dst_strides,
     int rank) {
   int64_t dst_idx = 0;
-  int64_t old_src_idx = src_idx;
+  int64_t origin_src_idx = src_idx;
   for (int k = 0; k < rank; ++k) {
     auto local_idx = src_idx / src_strides[k + 1];
     src_idx -= local_idx * src_strides[k + 1];
@@ -871,59 +871,6 @@ BroadcastKernelForDifferentVecSize(const KPDevice &ctx,
   int vec_size = GetVectorizedSizeForTensors(ins, *outs);
 #endif

-#ifndef PADDLE_WITH_XPU_KP
-  constexpr bool kEnabledInt64IndexKernel = (NumOuts == 1 && Arity <= 3);
-  bool use_int64_index_kernel =
-      kEnabledInt64IndexKernel &&
-      (*outs)[0]->numel() >= std::numeric_limits<int32_t>::max();
-  if (use_int64_index_kernel) {
-    switch (vec_size) {
-      case VecSizeL: {
-        LaunchBroadcastKernelWithInt64IndexHelper<OutT,
-                                                  Functor,
-                                                  Arity,
-                                                  NumOuts,
-                                                  VecSizeL>::Run(ctx,
-                                                                 ins,
-                                                                 outs,
-                                                                 axis,
-                                                                 func);
-        break;
-      }
-      case VecSizeM: {
-        LaunchBroadcastKernelWithInt64IndexHelper<OutT,
-                                                  Functor,
-                                                  Arity,
-                                                  NumOuts,
-                                                  VecSizeM>::Run(ctx,
-                                                                 ins,
-                                                                 outs,
-                                                                 axis,
-                                                                 func);
-        break;
-      }
-      case VecSizeS: {
-        LaunchBroadcastKernelWithInt64IndexHelper<OutT,
-                                                  Functor,
-                                                  Arity,
-                                                  NumOuts,
-                                                  VecSizeS>::Run(ctx,
-                                                                 ins,
-                                                                 outs,
-                                                                 axis,
-                                                                 func);
-        break;
-      }
-      default: {
-        PADDLE_THROW(phi::errors::Unimplemented(
-            "Unsupported vectorized size: %d!", vec_size));
-        break;
-      }
-    }
-    return;
-  }
-#endif
-
   auto classifier =
       BroadcastTypeClassifier<OutT, Functor, Arity, NumOuts>(ins, outs, axis);
   switch (vec_size) {
@@ -950,6 +897,195 @@ BroadcastKernelForDifferentVecSize(const KPDevice &ctx,
  }
}

static void updateStridesDims(std::vector<int64_t> *strides,
                              std::vector<int64_t> *dims) {
  for (int i = 1; i < strides->size(); i++) {
    (*strides)[i] = (*strides)[i - 1] * (*dims)[i - 1];
  }
  // reverse origin_in_dim and origin_in_stride if in's dim_size > 0
  std::reverse(strides->begin(), strides->end());
  std::reverse(dims->begin(), dims->end());
}
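
A quick illustration of what updateStridesDims produces may help. This is a hedged Python sketch, not Paddle code, and it assumes the incoming dims are listed innermost-first, which is what the final reversal suggests:

```python
# Hedged sketch of updateStridesDims (illustration only, not Paddle code).
def update_strides_dims(strides, dims):
    for i in range(1, len(strides)):
        strides[i] = strides[i - 1] * dims[i - 1]
    strides.reverse()
    dims.reverse()

strides, dims = [1, 1, 1], [2, 3, 4]   # dims assumed innermost-first here
update_strides_dims(strides, dims)
print(dims, strides)                    # [4, 3, 2] [6, 2, 1]  (row-major strides)
```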

static void SliceTensor(DenseTensor *x,
                        const DenseTensor *share,
                        const std::vector<int64_t> &out_compute_dims,
                        int64_t offset) {
  auto new_dim = make_ddim(out_compute_dims);
  DenseTensorMeta meta(share->dtype(),
                       new_dim,
                       share->layout(),
                       offset * SizeOf(share->dtype()));
  x->set_meta(meta);
  x->ShareBufferWith(*(share), true);
  x->Resize(new_dim);
}
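
In effect, SliceTensor builds a view over the shared allocation, starting at an element offset (converted to bytes in the meta) and reshaped to the slice dims. A rough NumPy analogue of that behaviour, under the assumption that no data is copied, is:

```python
import numpy as np

# Rough analogue of the SliceTensor view semantics (assumption: no copy is made,
# the slice simply reinterprets part of the shared buffer at an element offset).
buf = np.arange(24, dtype=np.float32)           # stands in for the shared allocation
offset, slice_dims = 8, (2, 2, 4)               # hypothetical offset and slice shape
view = buf[offset:offset + np.prod(slice_dims)].reshape(slice_dims)
view[0, 0, 0] = -1.0                            # writes through to buf: it is a view
assert buf[8] == -1.0
```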

template <typename OutT, typename Functor, int kArity, int NumOuts = 1>
void BroadcastKernelSplit(const KPDevice &ctx,
                          const std::vector<const DenseTensor *> &ins,
                          std::vector<DenseTensor *> *outs,
                          int axis,
                          Functor func,
                          const int64_t compute_size) {
  const auto dims_simplifier =
      BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis);
  if (VLOG_IS_ON(6)) {
    DimsSimplifiedLogger<int64_t>::Log(
        ins, outs, dims_simplifier, "GPU Broadcast");
  }

Review comment (on the BroadcastDimsSimplifier call; cut off in the page capture): Here, for …
Reply: This does more than simplify the dims; it also expands the input dims, guaranteeing that, apart from 0-D inputs, the input and output dims_size are the same.
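
To make the reply concrete, here is a hedged Python sketch of the rank-alignment idea it describes. It is an illustration of the behaviour, not the actual BroadcastDimsSimplifier implementation; align_rank and its exact padding rule are assumptions made for the example:

```python
# Hedged sketch: pad every non-0D input with 1s so its rank matches the output
# rank (0-D inputs are left alone). Not the real BroadcastDimsSimplifier code.
def align_rank(in_dims, out_rank, axis):
    if len(in_dims) == 0:                      # 0-D input: leave it as-is
        return in_dims
    return [1] * axis + in_dims + [1] * (out_rank - axis - len(in_dims))

print(align_rank([384, 384], out_rank=4, axis=2))           # [1, 1, 384, 384]
print(align_rank([5120, 1, 384, 384], out_rank=4, axis=0))  # unchanged
```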

  int all_rank = dims_simplifier.rank;
  std::vector<int64_t> origin_out_strides(all_rank, 1);
  auto origin_in_dims = dims_simplifier.in_dims;
  auto origin_out_dims = dims_simplifier.out_dims;
  auto origin_in_strides = dims_simplifier.in_dims;

  // for split
  std::vector<int64_t> loop_num_out(all_rank, 1);
  std::vector<int64_t> loop_num_out_stride(all_rank, 1);

  // for input's offset
  std::vector<int64_t> ins_offset(kArity, 0);
  std::vector<int64_t> ins_scale_for_dim(kArity, 0);

  // init offset and check in's dim
  for (int k = 0; k < kArity; k++) {
    ins_scale_for_dim[k] = ins[k]->dims().size() == 0 ? 0 : 1;
    if (ins_scale_for_dim[k]) {
      origin_in_strides[k][0] = 1;
    }
  }

  updateStridesDims(&origin_out_strides, &origin_out_dims);
  for (int k = 0; k < kArity; k++) {
    if (ins_scale_for_dim[k]) {
      updateStridesDims(&origin_in_strides[k], &origin_in_dims[k]);
    }
  }

  // init out_split_dim and in_split_dims
  auto out_split_dim = origin_out_dims;
  auto in_split_dims = origin_in_dims;

  // init
  int64_t loop_num = 1;
  int64_t split_idx = 0;

  for (int r = 0; r < all_rank; r++) {
    // if compute_size is too small, split_size may be 0, but the
    // dim num must be >= 1
    int64_t split_size = compute_size / origin_out_strides[r];
    out_split_dim[r] = std::max(split_size, static_cast<int64_t>(1));
    loop_num_out[r] =
        (origin_out_dims[r] + out_split_dim[r] - 1) / out_split_dim[r];
    loop_num *= loop_num_out[r];

    for (int k = 0; k < kArity; k++) {
      if (ins_scale_for_dim[k]) {
        in_split_dims[k][r] = std::min(origin_in_dims[k][r], out_split_dim[r]);
      }
    }

    // split_idx is the index of the last split dim
    if (split_size != 0) {
      split_idx = r;
      break;
    }
  }

  loop_num_out_stride[all_rank - 1] = 1;
  for (int r = all_rank - 2; r >= 0; r--) {
    loop_num_out_stride[r] = loop_num_out_stride[r + 1] * loop_num_out[r + 1];
  }
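
The planning loop above decides how the (simplified) output is chopped into launches. A hedged worked example, using made-up dims and an artificially small compute_size so the arithmetic stays readable (in the PR the budget is really INT32_MAX):

```python
# Hedged walk-through of the split planning (made-up numbers, output side only).
out_dims     = [6, 4, 5]       # hypothetical simplified output dims
out_strides  = [20, 5, 1]      # row-major strides of out_dims
compute_size = 50              # pretend per-launch element budget

out_split_dim, loop_num_out = list(out_dims), [1] * len(out_dims)
loop_num, split_idx = 1, 0
for r in range(len(out_dims)):
    split_size = compute_size // out_strides[r]              # 50 // 20 = 2 at r = 0
    out_split_dim[r] = max(split_size, 1)
    loop_num_out[r] = -(-out_dims[r] // out_split_dim[r])    # ceiling division
    loop_num *= loop_num_out[r]
    if split_size != 0:                                      # first dim that fits: stop
        split_idx = r
        break

print(out_split_dim, loop_num_out, loop_num, split_idx)
# [2, 4, 5] [3, 1, 1] 3 0  -> 3 launches, each covering at most 2*4*5 = 40 elements
```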

  // compute
  for (int iter = 0; iter < loop_num; iter++) {
    std::vector<const DenseTensor *> new_ins = {};
    std::vector<DenseTensor *> new_outs = {};
    phi::DenseTensor tmp_in[kArity];
    DenseTensor tmp_out[NumOuts];

    int64_t tmp_size = iter;
    int64_t out_offset = 0;
    // compute the offset before last split dim
    for (int i = 0; i < split_idx; i++) {
      auto repeat_times = tmp_size / loop_num_out_stride[i];
      out_offset += repeat_times * origin_out_strides[i];
      for (int k = 0; k < kArity; k++) {
        if (ins_scale_for_dim[k]) {
          ins_offset[k] +=
              (repeat_times % origin_in_dims[k][i]) * origin_in_strides[k][i];
        }
      }
      tmp_size = tmp_size % loop_num_out_stride[i];
    }
    // tmp_size is the last split_dims's repeat idx
    auto pre_deal_size = tmp_size * out_split_dim[split_idx];
    out_offset += pre_deal_size * origin_out_strides[split_idx];
    // compute_size
    auto remainder_size = origin_out_dims[split_idx] - pre_deal_size;

    // get current compute size
    auto out_compute_dims = out_split_dim;
    out_compute_dims[split_idx] =
        std::min(out_split_dim[split_idx], remainder_size);

    // in + compute_size
    auto in_compute_dims = in_split_dims;
    for (int k = 0; k < kArity; k++) {
      if (ins_scale_for_dim[k]) {
        auto split_repeat =
            origin_in_dims[k][split_idx] == origin_out_dims[split_idx]
                ? tmp_size
                : 0;
        ins_offset[k] += split_repeat * in_split_dims[k][split_idx] *
                         origin_in_strides[k][split_idx];
        in_compute_dims[k][split_idx] =
            std::min(in_split_dims[k][split_idx], out_compute_dims[split_idx]);
      }
      SliceTensor(&tmp_in[k],
                  ins[k],
                  in_compute_dims[k],
                  ins_scale_for_dim[k] * ins_offset[k]);
      new_ins.emplace_back(&tmp_in[k]);
      ins_offset[k] = 0;
    }

    for (int n = 0; n < NumOuts; n++) {
      SliceTensor(&tmp_out[n], (*outs)[n], out_compute_dims, out_offset);
      new_outs.emplace_back(&tmp_out[n]);
    }

    BroadcastKernelForDifferentVecSize<OutT, Functor, kArity, NumOuts>(
        ctx, new_ins, &new_outs, axis, func);
  }
}
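
Continuing the made-up numbers from the planning sketch above, each iteration of this loop just picks an element offset and a (possibly shortened) slice shape for its launch. The variables below restate that toy example so the snippet runs on its own:

```python
# Hedged continuation of the toy example: how each launch finds its output slice.
out_dims, out_strides = [6, 4, 5], [20, 5, 1]
out_split_dim, loop_num, split_idx = [2, 4, 5], 3, 0

for it in range(loop_num):
    tmp = it                              # no dims before split_idx = 0 in this toy case
    pre = tmp * out_split_dim[split_idx]  # rows already covered along the split dim
    off = pre * out_strides[split_idx]    # element offset of this launch's output view
    rows = min(out_split_dim[split_idx], out_dims[split_idx] - pre)
    print(it, off, [rows] + out_dims[split_idx + 1:])
# 0 0 [2, 4, 5]
# 1 40 [2, 4, 5]
# 2 80 [2, 4, 5]
```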

template <typename OutT, typename Functor, int kArity, int NumOuts = 1>
void BroadcastKernelApply(const KPDevice &ctx,
                          const std::vector<const DenseTensor *> &ins,
                          std::vector<DenseTensor *> *outs,
                          int axis,
                          Functor func) {
#ifndef PADDLE_WITH_XPU_KP
  constexpr bool kEnabledInt64IndexKernel = (NumOuts == 1 && kArity <= 3);
  // check whether the int64-index (split) path is needed
  auto compute_size = std::numeric_limits<int32_t>::max();
  bool use_int64_index_kernel =
      kEnabledInt64IndexKernel && (*outs)[0]->numel() >= compute_size;

  if (use_int64_index_kernel) {  // use_int64_index_kernel
    BroadcastKernelSplit<OutT, Functor, kArity, NumOuts>(
        ctx, ins, outs, axis, func, compute_size);
    return;
  }
#endif
  BroadcastKernelForDifferentVecSize<OutT, Functor, kArity, NumOuts>(
      ctx, ins, outs, axis, func);
}
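
For a sense of when this split path actually triggers, here is a quick check with the tensor shape used in the new unit test, assuming numel is simply the product of the output dims:

```python
# When does BroadcastKernelApply take the split path? numel must reach INT32_MAX.
INT32_MAX = 2**31 - 1                    # 2_147_483_647
numel = 5120 * 4 * 384 * 384             # 3_019_898_880, the test's output shape
print(numel >= INT32_MAX)                # True -> BroadcastKernelSplit is used
```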

template <typename OutT, typename Functor, int NumOuts = 1>
void BroadcastKernel(const KPDevice &ctx,
                     const std::vector<const DenseTensor *> &ins,

@@ -1014,7 +1150,7 @@ void BroadcastKernel(const KPDevice &ctx,
     max_rank = std::max(max_rank, (*outs)[0]->dims().size());
   }
   axis = axis == -1 ? max_rank - min_rank : axis;
-  BroadcastKernelForDifferentVecSize<OutT, Functor, kArity, NumOuts>(
+  BroadcastKernelApply<OutT, Functor, kArity, NumOuts>(
       ctx, ins, outs, axis, func);
 }

@@ -88,6 +88,37 @@ def init_dtype(self):
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA and do not support bfloat16",
)

Review comment (on the skip condition): Very few CI machines support bfloat16; the unit test could cover float16 only.
Reply: Updated.

class TestTensorAddSplit(unittest.TestCase):
    def _split_compute(self, dtype):
        paddle.disable_static()
        tensor_a = paddle.rand(shape=[5120, 4, 384, 384], dtype=dtype)
        tensor_b = paddle.rand(shape=[5120, 1, 384, 384], dtype=dtype)
        tensor_z = paddle.subtract(tensor_a, tensor_b)

Review comment (on the tensor_a line): Mind the GPU memory usage; this single tensor is already about 5.6 GB (a rough estimate is sketched after this test method). Wouldn't it be better to add a separate test file? If this runs alongside the other unit tests, is GPU memory likely to blow up?
Reply: A separate test file has been added.

        in0, in1 = paddle.split(tensor_a, num_or_sections=2, axis=1)
        (
            out0,
            out1,
        ) = paddle.split(tensor_z, num_or_sections=2, axis=1)

        split_add0 = paddle.subtract(tensor_b, in0)
        split_add1 = paddle.subtract(tensor_b, in1)

        result1 = paddle.any(paddle.equal(out0, split_add0), [0, 1, 2, 3])
        result2 = paddle.any(paddle.equal(out1, split_add1), [0, 1, 2, 3])
        paddle.device.cuda.synchronize()
        np.testing.assert_equal(result1.numpy(), True)
        np.testing.assert_equal(result2.numpy(), True)
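
To back up the reviewer's 5.6 GB figure, a rough estimate, assuming float16 (2 bytes per element) and counting only tensor_a:

```python
# Rough memory estimate behind the review comment above (float16 = 2 bytes/element).
elements = 5120 * 4 * 384 * 384          # 3_019_898_880 elements in tensor_a
gib = elements * 2 / 1024**3             # bytes -> GiB
print(round(gib, 2))                     # ~5.62 GiB for tensor_a alone
```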

    def test_float16_add(self):
        self._split_compute("float16")


class TestTensorAddSplit1(TestTensorAddSplit):
    def test_bfloat16_add(self):
        self._split_compute("bfloat16")


class TestElementwiseBF16OP(TestElementwiseOp):
    def setUp(self):
        self.op_type = "elementwise_sub"

Review comment: The other source code related to the old int64-index implementation can also be deleted.
Reply: Will delete it in a separate follow-up PR.