Add Slice Tensor for int64 index in ElemenetwiseKernel #57313

Merged: 22 commits into PaddlePaddle:develop, Oct 10, 2023

Conversation

@AnnaTrainingG (Contributor) commented Sep 14, 2023

PR types

Others

PR changes

Others

Description

Others
Pcard-70459
Shapes 5120×4×384×384 and 1×1×384×384, bfloat16 precision verification.
At this size this PR runs in 9.5 ms; nsys could not collect performance data for the baseline.

With sync, this PR: paddle.add runtime (with synchronization): 0.0124376 s
With sync, baseline: paddle.add runtime (with synchronization): 0.10285941 s

Without sync, this PR: runtime 0.021558 s
Without sync, baseline: runtime 0.007517 s

Note: the baseline timings were taken before commit 3474e09, because that commit's implementation is incorrect and causes out-of-bounds memory access.

import datetime

import numpy as np
import paddle

tensor_a = paddle.rand(shape=[5120, 4, 384, 384], dtype="float16")
tensor_b = paddle.rand(shape=[5120, 1, 384, 384], dtype="float16")

# Time 10 runs of the broadcast add, synchronizing around each run.
elapsed_time = datetime.timedelta()
for _ in range(10):
    paddle.device.cuda.synchronize()
    start_time = datetime.datetime.now()
    tensor_z = paddle.add(tensor_a, tensor_b)
    paddle.device.cuda.synchronize()
    end_time = datetime.datetime.now()
    elapsed_time += end_time - start_time

# Verify the broadcast add against per-slice adds.
a0, a1 = paddle.split(tensor_z, num_or_sections=2, axis=1)
in0, in1 = paddle.split(tensor_a, num_or_sections=2, axis=1)

r0 = paddle.add(tensor_b, in0)
r1 = paddle.add(tensor_b, in1)

result1 = paddle.any(paddle.equal(a0, r0), [0, 1, 2, 3])
result2 = paddle.any(paddle.equal(a1, r1), [0, 1, 2, 3])
np.testing.assert_equal(result1.numpy(), True)
np.testing.assert_equal(result2.numpy(), True)

# Print the average runtime.
print(f"Average runtime: {elapsed_time.total_seconds() / 10} s")

Test script shown above.

The operator cases in op_benchmark for this PR are small, so the reported difference is due to machine fluctuation. (op_benchmark screenshot omitted.)

@paddle-bot bot commented Sep 14, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

#ifndef PADDLE_WITH_XPU_KP
constexpr bool kEnabledInt64IndexKernel = (NumOuts == 1 && kArity <= 3);
auto loader_classifier =
BroadcastTypeClassifier<OutT, Functor, kArity, NumOuts>(ins, outs, axis);
Contributor:

BroadcastTypeClassifier has already been constructed once in the outer layer, so the constructed loader_classifier could be passed as an argument to BroadcastKernelForDifferentVecSize, avoiding a second construction inside the function and the extra CPU overhead.

Contributor Author:

The extra construction has been removed. The two tensors' base addresses and dims differ, so the outer classifier cannot be passed in directly.

auto compute_size = std::numeric_limits<int32_t>::max();
bool use_int64_index_kernel = kEnabledInt64IndexKernel &&
(*outs)[0]->numel() >= compute_size &&
(!loader_classifier.all_elementwise);
Contributor:

all_elementwise means the branch that needs no broadcasting; it also needs to support large-Tensor computation.

Contributor Author:

Fixed.


if (use_int64_index_kernel) { // use_int64_index_kernel
const auto dims_simplifier =
Contributor:

dims_simplifier is also computed inside BroadcastTypeClassifier; if it is needed outside, it could be saved as a member.

Contributor Author:

The two dims are different, so it cannot be reused directly.

int all_rank = dims_simplifier.rank;
auto old_in_dims = dims_simplifier.in_dims;
auto old_out_dims = dims_simplifier.out_dims;
auto old_in_strides = dims_simplifier.in_dims;
Contributor:

Rename old_ -> origin_.

Contributor Author:

done

auto old_out_dims = dims_simplifier.out_dims;
auto old_in_strides = dims_simplifier.in_dims;

old_out_strides.resize(all_rank);
Contributor:

The code as a whole is too long and somewhat unstructured; consider splitting it into several sub-functions by functionality. Also consider whether this logic can be reused elsewhere, for example in the reduce computation.

Contributor Author:

Fixed.

@Xreki (Contributor) left a comment:

  • From the PR description, the computation time in the large-Tensor case does not look very long; add a unit test to guarantee the correctness of the computation.
  • Since nsys cannot profile it, run an end-to-end time comparison instead; paste the test code into the PR description and cover a few more configurations.

Consider encapsulating the whole logic in a separate header file (e.g. tensor_slicer.h) so it can be reused, for example:

class TensorSlicer {
 public:
  using ArgumentsTuple =
      std::tuple<std::vector<const DenseTensor *>, std::vector<DenseTensor *>>;

  TensorSlicer(const std::vector<const DenseTensor *> &ins,
               std::vector<DenseTensor *> *outs) {
    // initialization
  }

  int size() const { return num_splits_; }
  ArgumentsTuple operator[](int i) {
    ...
  }

 private:
  const std::vector<const DenseTensor *> *ins_;  // not owned
  std::vector<DenseTensor *> *outs_;             // not owned
  int num_splits_{0};
  std::vector<int64_t> strides_;  // slicing-related information
};

The outer calling logic would be:

if (numel > compute_size) {
  auto slicer = TensorSlicer(ins, outs);
  for (int i = 0; i < slicer.size(); ++i) {
    auto args = slicer[i];
    BroadcastKernelForDifferentVecSize(std::get<0>(args), std::get<1>(args), ...);
  }
  return;
}
BroadcastKernelForDifferentVecSize(ins, outs, ...);
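
For intuition, a minimal Python-level sketch of the same slicing idea (hypothetical small shapes; paddle.split and paddle.concat stand in for the in-kernel slicing a TensorSlicer would do):

import paddle

# Hypothetical small shapes; the real motivation is tensors whose numel
# exceeds the int32 range, which is too large to materialize here.
x = paddle.rand(shape=[8, 4, 64, 64], dtype="float32")
y = paddle.rand(shape=[8, 1, 64, 64], dtype="float32")

num_splits = 2
x_slices = paddle.split(x, num_or_sections=num_splits, axis=0)
y_slices = paddle.split(y, num_or_sections=num_splits, axis=0)

# Run the regular (int32-indexed) broadcast add on each slice, then stitch
# the partial results back together along the split axis.
out = paddle.concat(
    [paddle.add(xs, ys) for xs, ys in zip(x_slices, y_slices)], axis=0
)
assert out.shape == [8, 4, 64, 64]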

@@ -950,6 +897,215 @@ BroadcastKernelForDifferentVecSize(const KPDevice &ctx,
}
}

static void initDims(std::vector<int64_t> *dims, int size, int64_t value) {
Contributor:

std::vector has a corresponding constructor that directly builds a vector of N elements all equal to value.

Contributor Author:

done

std::reverse(dims->begin(), dims->end());
}

static void UpdateTensor(DenseTensor *x,
Contributor:

SliceTensor would be a more appropriate name for this function.

Contributor Author:

done

int axis,
Functor func,
const int64_t compute_size) {
const auto dims_simplifier =
Contributor:

Simplifying the dims here is not really meaningful; computing the strides of the tensor slices is cheaper than simplifying the dims.

Contributor Author:

This is not only simplifying the dims but also extending the input dims, ensuring that, apart from the 0-D case, the inputs and outputs have the same dims size.
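
A rough Python illustration of that dims extension (hypothetical helper, not the actual C++ code): lower-rank inputs are left-padded with 1s so that, apart from 0-D tensors, every input has the same dims size as the output.

def align_dims(in_dims, out_rank):
    # Hypothetical helper: left-pad an input's dims with 1s so its rank
    # matches the output's rank; broadcasting semantics are unchanged.
    return [1] * (out_rank - len(in_dims)) + list(in_dims)

print(align_dims([384, 384], 4))            # [1, 1, 384, 384]
print(align_dims([5120, 1, 384, 384], 4))   # unchanged: [5120, 1, 384, 384]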

BroadcastKernelForDifferentVecSize<OutT, Functor, kArity, NumOuts>(
ctx, new_ins, &new_outs, axis, func);
}
return;
Contributor:

This line can be removed.

Contributor Author:

Deleted.

}

// compute
DenseTensor tmp_in[kArity];
Contributor:

The Tensor defined here does not appear to be used, because a Tensor with the same name is defined again at L1023.

Contributor Author:

It was a duplicate definition and has been removed.

phi::Array<_ptr_ OutT *, NumOuts> outs_data;
for (int i = 0; i < NumOuts; ++i) {
outs_data[i] = (_ptr_ OutT *)(ctx.Alloc<OutT>((*outs)[i]));
}
Contributor:

outs_data is not used in this function.

Contributor Author:

This was here to allocate the output memory; fixed.

Contributor:

The upper-level BroadcastKernel has already allocated that memory.

Contributor Author:

done

@AnnaTrainingG changed the title from "Broadcast" to "Add Slice Tensor for int64 index in ElemenetwiseKernel" on Sep 25, 2023
@@ -88,6 +88,37 @@ def init_dtype(self):
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
Contributor:

Too few CI machines support bfloat16; the unit test can test only float16.

Contributor Author:

Fixed.

class TestTensorAddSplit(unittest.TestCase):
def _split_compute(self, dtype):
paddle.disable_static()
tensor_a = paddle.rand(shape=[5120, 4, 384, 384], dtype=dtype)
Contributor:

Watch the GPU memory usage: this single tensor is about 5.6 GB. Would it be better to add a separate test file? If it runs alongside other unit tests, could GPU memory easily run out?

Contributor Author:

A new test file has been added.
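
For reference, a quick back-of-the-envelope check of that figure, and of why the int64 index path is needed at all (float16 is 2 bytes per element):

numel = 5120 * 4 * 384 * 384      # elements in one [5120, 4, 384, 384] tensor
print(numel * 2 / 1024**3)        # ~5.63 GiB of GPU memory at float16
print(numel > 2**31 - 1)          # True: numel exceeds the int32 range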

from paddle.base import core


class TestElementwiseOp(OpTest):
Contributor:

This new unit test should not inherit from OpTest, because it is not tested the OpTest way; add it as an ordinary unittest instead.

Contributor Author:

Fixed.
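
A minimal sketch of what such a plain-unittest check might look like (hypothetical, with reduced shapes; the actual test added by this PR lives in its own file and uses the full-size tensors):

import unittest

import numpy as np
import paddle
from paddle.base import core


@unittest.skipIf(
    not core.is_compiled_with_cuda(), "the large-tensor add test requires CUDA"
)
class TestTensorAddSplit(unittest.TestCase):
    def test_float16_add(self):
        paddle.disable_static()
        # Reduced shapes for illustration; the real test uses tensors whose
        # numel exceeds the int32 range to exercise the int64-index path.
        a = paddle.rand(shape=[8, 4, 64, 64], dtype="float16")
        b = paddle.rand(shape=[8, 1, 64, 64], dtype="float16")
        out = paddle.add(a, b)
        expected = a.numpy().astype("float32") + b.numpy().astype("float32")
        np.testing.assert_allclose(
            out.numpy().astype("float32"), expected, rtol=1e-3, atol=1e-3
        )


if __name__ == "__main__":
    unittest.main()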

@paddle-ci-bot bot commented Oct 6, 2023

Sorry to inform you that db6165c's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@Xreki (Contributor) commented Oct 9, 2023

> At this size this PR runs in 9.5 ms; nsys could not collect performance data for the baseline.

Now that the cause of the error on develop has been found, please measure the performance data again.

if (use_int64_index_kernel) {
switch (vec_size) {
case VecSizeL: {
LaunchBroadcastKernelWithInt64IndexHelper<OutT,
Contributor:

The other int64-implementation-related source code can also be removed.

Contributor Author:

Will remove it in a separate follow-up PR.

Contributor Author: (screenshot attached; not reproduced here)

@Xreki (Contributor) left a comment:
LGTM and great work~

@AnnaTrainingG merged commit f147f4b into PaddlePaddle:develop on Oct 10, 2023
27 checks passed
Frida-a pushed a commit to Frida-a/Paddle that referenced this pull request Oct 14, 2023
jiahy0825 pushed a commit to jiahy0825/Paddle that referenced this pull request Oct 16, 2023
danleifeng pushed a commit to danleifeng/Paddle that referenced this pull request Nov 14, 2023