Support multi-outputs feature for broadcast ops #38329
Conversation
✅ This PR's description meets the template requirements!
Thanks for your contribution!
Force-pushed from 3d77a4c to 474b13c
Force-pushed from 474b13c to ffccae7
@@ -208,6 +231,7 @@ __device__ void ElementwiseBroadcastKernelImpl(
     Functor func) {
   InT args[Arity][VecSize];
   OutT result[VecSize];
+  ScalarType<OutT> vec_result[NumOuts][VecSize];
This `vec_result` could be moved into the `WriteData` function at L199 and defined there.
Done, modified as requested.
@@ -441,8 +475,8 @@ __device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src,
     }
   } else {
     // Vector type
-    const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
-    const int kVectorsPerThread = NX / kVectorSize;
+    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
Does `constexpr` have any performance impact (it is supported since C++14)? In theory the compiler evaluates a `const` like this at compile time as well.
This change is made from a readability standpoint: switching to `constexpr` makes it explicit that the value is computed at compile time.
@@ -170,11 +170,11 @@ struct DimensionsTransform {
   }
 };

-template <typename T, int VecSize, int Rank, bool IsBoundary = false>
+template <typename T, int VecSize, int Rank, bool IsBoundary>
Why remove the default template argument?
While making the changes I noticed that `LoadData` is only called from `ElementwiseBroadcastKernelImpl`, which already sets `bool IsBoundary = false`, so I removed the default template argument here.
Reverted to the original state.
__device__ __forceinline__ void LoadData(
    T *dst,
    const T *__restrict__ src,
-   uint32_t block_offset,
+   int block_offset,
Why change from `uint32_t` to `int`?
Same reason as the previous comment: `LoadData` is only called from `ElementwiseBroadcastKernelImpl`, whose parameter list declares `int block_offset`, so I changed this to match. However, since KP uses `uint32_t`, I will push another commit to change it back.
Reverted to the original `uint32_t`.
@@ -428,6 +428,40 @@ __device__ __forceinline__ void ReadDataReduce(
 * src: The register pointer, the size is NX * NY.
 * size: The current block needs to load size elements continuously.
 */

+#if defined(__NVCC__)
Are these two newly added? Was something incomplete in the previous functionality?
The functionality was already complete. The purpose of this change was to separate, at compile time, the compute branches taken by the two `IsBoundary` cases; essentially it emulates the C++17 construct:
if constexpr (condition) {
}
After checking, the partial-specialization change here caused a performance regression, so it has been reverted to the original implementation.
… support_multi-output_for_broadcast
LGTM
PR types
New features
PR changes
OPs
Describe
This PR reduces the launch overhead incurred by `LaunchBroadcastElementwiseCudaKernel` when a broadcast op produces multiple outputs. The dimensions of `out_1` and `out_2` must be the same, and the `axis` settings of the two functors must be identical. `paddle::framework::Array` is adopted as the input data type, and the functor is designed around it.
When calling `LaunchBroadcastElementwiseCudaKernel`, the template parameters change from `<InT, OutT, functor>` to `<InT, OutT, functor, NumOuts>`, where `NumOuts` expresses the number of outputs of the functor (2 in this example). Its default value is 1, so the existing single-output functor usage remains compatible. Before and after adding multi-output support, `elementwise` compute performance is essentially unchanged:
[Fig. 1: before multi-output support]
[Fig. 2: after multi-output support]
`NumOuts` is used to implement the multi-output case, but in that case the `ReturnType` template parameter in `Function_traits` is itself of type `paddle::framework::Array<OutT, NumOuts>`. I have been trying to obtain `NumOuts` from `ReturnType`, but have not yet found an effective way to do so.