Support multi-output feature for elementwise #38410
… support_multi-output_for_elementwise
… support_multi-output_for_elementwise
int num,
int data_offset,
Functor func) {
InT args[Arity][VecSize];
OutT result[VecSize];
OutType<OutT, NumOuts> result[VecSize];
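Please see whether there is a more suitable name. `OutType` is not quite apt; perhaps call it `PackedOutT`, or make the conditional nature visible with something like `ConditionalOutT`?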
Renamed to `ConditionalT<OutT, NumOuts>`.
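For readers unfamiliar with the pattern behind this name, here is a minimal sketch of how such a conditional alias could be written. It assumes the alias collapses to a plain `OutT` when `NumOuts` is 1 and to a fixed-size array otherwise; `std::array` is used as a stand-in for `paddle::framework::Array`, and the definition in the PR itself may differ.

```cpp
#include <array>
#include <cassert>
#include <type_traits>

// Sketch only: std::array stands in for paddle::framework::Array.
// One output  -> the alias is just OutT (no wrapper overhead).
// Many outputs -> the alias is a fixed-size pack of NumOuts values.
template <typename OutT, int NumOuts>
using ConditionalT =
    typename std::conditional<NumOuts == 1, OutT,
                              std::array<OutT, NumOuts>>::type;

// Compile-time checks of the two branches.
static_assert(std::is_same<ConditionalT<float, 1>, float>::value,
              "single output stays a scalar");
static_assert(
    std::is_same<ConditionalT<float, 2>, std::array<float, 2>>::value,
    "two outputs become a pack of two");
```

With an alias like this, a declaration such as `ConditionalT<OutT, NumOuts> result[VecSize];` keeps the single-output code path unchanged.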
@@ -174,19 +207,39 @@ void LaunchSameDimsElementwiseCudaKernel(
"is %d, the arity of functor is %d.",
ins.size(),
kArity));
PADDLE_ENFORCE_EQ(outs->size(),
Doesn't this need to check the value of `ET`, as the broadcast path does?
Because of this snippet, which relies on `function_traits`:
Paddle/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_no_broadcast.cu.h
Lines 166 to 168 in 706d2c0
using Traits = paddle::platform::FunctionTraits<Functor>;
const int kArity =
    Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;
the value of `ET` is superseded by `kArity`, and the later checks are made against `kArity`.
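To illustrate why `kArity` can stand in for `ET`, here is a simplified sketch of a traits class that deduces the arity from the functor's call operator. `FunctionTraitsSketch`, `FirstIsPointer`, `ResolveArity`, the two functors, and the enumerator values are all hypothetical names and assumptions for this example, not Paddle's actual `FunctionTraits`.

```cpp
#include <cassert>
#include <type_traits>

// Assumed enumerator values, for illustration only.
enum class ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };

// True when the first argument type of a parameter pack is a pointer.
template <typename... Args>
struct FirstIsPointer : std::false_type {};
template <typename First, typename... Rest>
struct FirstIsPointer<First, Rest...> : std::is_pointer<First> {};

// Simplified traits: inspect a non-overloaded `operator() const`.
template <typename T>
struct FunctionTraitsSketch : FunctionTraitsSketch<decltype(&T::operator())> {};

template <typename C, typename R, typename... Args>
struct FunctionTraitsSketch<R (C::*)(Args...) const> {
  static constexpr int arity = static_cast<int>(sizeof...(Args));
  // A functor taking a single pointer hides its real arity from the signature.
  static constexpr bool has_pointer_args =
      sizeof...(Args) == 1 && FirstIsPointer<Args...>::value;
};

// Mirrors the quoted snippet: fall back to ET only for pointer-style functors.
template <ElementwiseType ET, typename Functor>
constexpr int ResolveArity() {
  using Traits = FunctionTraitsSketch<Functor>;
  return Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;
}

struct BinaryAddFunctor {
  float operator()(float a, float b) const { return a + b; }
};
struct PointerSumFunctor {
  float operator()(const float* args) const {
    return args[0] + args[1] + args[2];
  }
};
```

In this sketch, `ResolveArity<ElementwiseType::kUnary, BinaryAddFunctor>()` yields 2 because the signature wins over `ET`, while `ResolveArity<ElementwiseType::kTernary, PointerSumFunctor>()` yields 3 because the signature alone cannot reveal the arity.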
LGTM
@@ -162,7 +162,8 @@ struct DimensionsTransform {
}
};

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
template <ElementwiseType ET, typename InT, typename OutT, typename Functor,
Can the `DimensionsTransform` in this file be deleted now?
Yes, it can be deleted. I will open a PR to remove it here; I had assumed the pten side would take care of removing it.
PADDLE_ENFORCE_EQ(kArity,
2,
PADDLE_ENFORCE_LE(kArity,
ElementwiseType::kTernary,
This should really just be `3`, because `ElementwiseType::kTernary` is an enumerator and could be set to a different value.
OK, I will change this in the same PR as the one that removes `DimensionsTransform`.
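Regarding "This should really just be `3`, because `ElementwiseType::kTernary` is an enumerator and could be set to a different value": this has already been changed in PR38550.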
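The reviewer's concern can be shown with a small sketch. The enumerator values below are assumptions, and the `renumbered` namespace is a hypothetical future edit invented for the example; the point is only that a bound written against an enumerator shifts silently when the enum is renumbered, while a literal does not.

```cpp
#include <cassert>

namespace as_written {
// Assumed values, for illustration only.
enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };
}  // namespace as_written

namespace renumbered {
// Hypothetical later edit that makes the enumerators zero-based.
enum ElementwiseType { kUnary = 0, kBinary = 1, kTernary = 2, kAny = -1 };
}  // namespace renumbered

// Bound taken from an enumerator: depends on how the enum is numbered.
template <typename Enum>
constexpr bool ArityAllowedByEnum(int arity, Enum max_type) {
  return arity <= static_cast<int>(max_type);
}

// Bound taken from a literal: states the real limit (three inputs) directly.
constexpr bool ArityAllowedByLiteral(int arity) { return arity <= 3; }
```

Here a three-input functor passes `ArityAllowedByEnum(3, as_written::kTernary)` but is rejected by `ArityAllowedByEnum(3, renumbered::kTernary)`, whereas `ArityAllowedByLiteral(3)` is unaffected by the enum's numbering.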
PR types
New features
PR changes
OPs
Describe
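Features
The core implementation in `Elementwise_no_broadcast`, which lets elementwise support multiple outputs, consistent with broadcast in PR38329.
The template parameter lists of `LaunchBroadcastElementwiseCudaKernel`, `LaunchSameDimsElementwiseCudaKernel`, and their higher-level wrapper `LaunchElementwiseCudaKernel`.
Usage
Use `paddle::framework::Array` as the input data type and design the functor accordingly.
When calling `LaunchBroadcastElementwiseCudaKernel`, the template arguments change from `<InT, OutT, functor>` to `<InT, OutT, functor, NumOuts>`, where `NumOuts` is the number of output elements the functor produces (2 in this example). It defaults to 1, so existing single-output functor code remains compatible.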