Add fused_rope forward op #54351
Conversation
Your PR was submitted successfully. Thank you for contributing to the open-source project!
Force-pushed from 0b1089d to 8e174d1
Please add a unit test.
@@ -0,0 +1,160 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Please move the fused operator implementation into the phi/kernels/fusion/gpu directory.
template <typename T, int VecSize>
struct alignas(sizeof(T) * VecSize) VectorType {
  T val[VecSize];
};
Why not use AlignedVector directly?
Fixed.
phi::dtype::bfloat16) {
  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
Is setting ALL_BACKEND needed here?
Removed.
Was this removed and then added back? Same for the forward kernel.
#include "paddle/phi/kernels/funcs/aligned_vector.h"

namespace phi {
template <typename T, int VecSize>
Use AlignedVector directly?
Fixed.
int C,
int main_offset,
phi::Array<T*, 3> outs_data,
int break_iter,
Rename break_iter to num_inputs?
Fixed.
auto N = q.dims()[0];
auto H = q.dims()[1];
auto W = q.dims()[2];
auto C = q.dims()[3];
q is a sequence; its four dimensions mean [batch_size, seq_len, num_heads, head_dim]. Please name the variables after the dimension meanings.
Fixed.
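For reference, rotary position embedding over a [batch_size, seq_len, num_heads, head_dim] tensor rotates each (even, odd) pair of the head dimension by a position-dependent angle. A plain NumPy sketch of the math (the function name `rope_ref` and the interleaved pairing convention are illustrative assumptions, not Paddle's exact layout):

```python
import numpy as np

def rope_ref(q):
    """Apply rotary position embedding to q of shape
    [batch_size, seq_len, num_heads, head_dim]; head_dim must be even."""
    batch_size, seq_len, num_heads, head_dim = q.shape
    pos = np.arange(seq_len, dtype=np.float32)
    # Standard RoPE frequencies: 10000^(-2i/head_dim) for each pair i.
    inv_freq = 1.0 / 10000 ** (
        np.arange(0, head_dim, 2, dtype=np.float32) / head_dim)
    theta = np.outer(pos, inv_freq)        # [seq_len, head_dim // 2]
    cos = np.cos(theta)[None, :, None, :]  # broadcast over batch and heads
    sin = np.sin(theta)[None, :, None, :]
    q_even, q_odd = q[..., 0::2], q[..., 1::2]
    out = np.empty_like(q)
    # 2-D rotation applied to each (even, odd) pair.
    out[..., 0::2] = q_even * cos - q_odd * sin
    out[..., 1::2] = q_even * sin + q_odd * cos
    return out

q = np.random.rand(2, 8, 4, 16).astype(np.float32)
out = rope_ref(q)
assert out.shape == q.shape
# Rotations are norm-preserving along the head dimension.
assert np.allclose(np.linalg.norm(out, axis=-1),
                   np.linalg.norm(q, axis=-1), atol=1e-5)
```

With the dimensions named this way, the kernel indexing maps directly onto the [batch_size, seq_len, num_heads, head_dim] layout the reviewer asked for.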
python/paddle/tensor/math.py
Outdated
@@ -621,6 +621,11 @@ def add(x, y, name=None):
    return _elementwise_op(LayerHelper('elementwise_add', **locals()))


def fused_rope(q, k, v):
    if in_dynamic_mode():
        return _C_ops.fused_rope(q, k, v)
The fused_rope API would fit better under paddle.incubate.nn.functional, and the API name should use the full rotary_position_embedding.
Fixed.
Force-pushed from a7e40fd to 80c112d
Force-pushed from 80c112d to 115cc40
Force-pushed from 731faf0 to 4281348
Force-pushed from 72829f9 to b143bd0
@@ -422,6 +422,17 @@
  optional : skip_update, master_params
  inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out)

- op : fused_rope
This can be placed in fused_ops.yaml.
Will change in the next PR.
paddle/phi/infermeta/backward.h
Outdated
@@ -459,5 +459,11 @@ void IndexAddGradInferMeta(const MetaTensor& index,
                           int axis,
                           MetaTensor* x_grad,
                           MetaTensor* add_tensor_grad);
void FusedRopeGradInferMeta(const MetaTensor& dout_q,
Place the functions in alphabetical order.
Fixed.
paddle/phi/infermeta/multiary.cc
Outdated
@@ -3489,5 +3489,33 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
  out_count->set_dims({-1});
  out_count->set_dtype(DataType::INT32);
}
Same as above.
Fixed.
namespace phi {

template <typename T, typename Context>
void FusedRopeGradKernel(const Context& dev_ctx,
Fused kernels do not need header file declarations.
Will fix uniformly in the next PR.
namespace phi {

template <typename T, typename Context>
void FusedRopeKernel(const Context& dev_ctx,
Same as above.
Will fix uniformly in the next PR.
LGTM
LGTM for skipIf
7703138
LGTM. Some of the review suggestions can be addressed in the next PR.
PADDLE_ENFORCE_EQ(input_dims.size(),
                  4,
                  phi::errors::InvalidArgument(
                      "Input should be a 4-D tensor of format [N, C, H, W] "
Change the N, C, H, W here to their actual dimension meanings as well; same below.
OK.
template <typename T, typename Context>
void FusedRopeGradKernel(const Context& dev_ctx,
                         const DenseTensor& dout_q,
I don't quite understand what dout_q means. Is it the forward output out_q?
It is the dout passed in from the backward pass (the gradient of the forward output).
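Since the per-pair rotation RoPE applies is orthogonal, the backward pass just applies the inverse rotation (the same angles, negated) to the incoming gradient dout. A NumPy sketch under that assumption (`apply_rotation` and its `sign` parameter are illustrative names, not the kernel's API):

```python
import numpy as np

def apply_rotation(x, sign=1.0):
    """Rotate (even, odd) pairs of the last dim by position-dependent angles.
    sign=+1 is the forward RoPE; sign=-1 applies the inverse rotation,
    which is what the backward pass applies to the incoming gradient."""
    batch, seq_len, heads, head_dim = x.shape
    pos = np.arange(seq_len, dtype=np.float32)
    inv_freq = 1.0 / 10000 ** (
        np.arange(0, head_dim, 2, dtype=np.float32) / head_dim)
    theta = sign * np.outer(pos, inv_freq)[None, :, None, :]
    cos, sin = np.cos(theta), np.sin(theta)
    even, odd = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = even * cos - odd * sin
    out[..., 1::2] = even * sin + odd * cos
    return out

q = np.random.rand(2, 8, 4, 16).astype(np.float32)
# The gradient of an orthogonal map is its transpose, i.e. the inverse
# rotation: applying it after the forward rotation recovers the input.
assert np.allclose(apply_rotation(apply_rotation(q), sign=-1.0), q, atol=1e-5)
```

This is why the backward kernel can reuse the forward kernel's structure with only the sign of the sine term flipped.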
phi::dtype::bfloat16) {
  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
Was this removed and then added back? Same for the forward kernel.
Fused rotary position embedding.

Args:
    q (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64.
Please add a description of the input tensor shapes to the documentation.
OK, will change in the next PR.
indices = 1 / 10000 ** (indices / q.shape[3])
sinusoid_inp = pos_seq.unsqueeze(1) * indices.unsqueeze(0)

sin_sin = np.empty((q.shape[2] * q.shape[3]), dtype=np.float32)
Why is part of the computation done with Paddle APIs and part with NumPy APIs?
Because the reference needs to be different from the op under test.
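The point behind that reply is the usual unit-test pattern: compute a reference with a deliberately different implementation (for example, a naive element-by-element loop) and compare it against the optimized path, so a shared bug cannot cancel out. A small NumPy sketch of the pattern (the names `rotate_fast` and `rotate_ref` are illustrative, not Paddle's test code):

```python
import numpy as np

def rotate_fast(x):
    """Vectorized path: rotate (even, odd) pairs of the last dim."""
    n = x.shape[-1]
    theta = np.arange(n // 2, dtype=np.float32) * 0.1
    cos, sin = np.cos(theta), np.sin(theta)
    out = np.empty_like(x)
    out[..., 0::2] = x[..., 0::2] * cos - x[..., 1::2] * sin
    out[..., 1::2] = x[..., 0::2] * sin + x[..., 1::2] * cos
    return out

def rotate_ref(x):
    """Naive reference: the same math, written as an explicit loop."""
    out = np.empty_like(x)
    flat_in = x.reshape(-1, x.shape[-1])
    flat_out = out.reshape(-1, x.shape[-1])  # view into out
    for row_in, row_out in zip(flat_in, flat_out):
        for i in range(0, x.shape[-1], 2):
            t = (i // 2) * 0.1
            row_out[i] = row_in[i] * np.cos(t) - row_in[i + 1] * np.sin(t)
            row_out[i + 1] = row_in[i] * np.sin(t) + row_in[i + 1] * np.cos(t)
    return out

x = np.random.rand(3, 8).astype(np.float32)
# Independent implementations agreeing is the actual test signal.
np.testing.assert_allclose(rotate_fast(x), rotate_ref(x), rtol=1e-5)
```

Mixing Paddle and NumPy in the test's reference computation serves the same goal: the reference should not silently share code with the fused op it is checking.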
* style * more * update ctest * Update legacy_backward.yaml * Update legacy_ops.yaml * Update legacy_ops.yaml * update * update * update for move
* Add fused_rope forward op (#54351) * style * more * update ctest * Update legacy_backward.yaml * Update legacy_ops.yaml * Update legacy_ops.yaml * update * update * update for move * Update the rope op according to the comments (#54985) * Update multiary.cc * Update __init__.py * for int64_t and assert * more * remove useless assert first --------- Co-authored-by: sneaxiy <sneaxiy@126.com>
…addlePaddle#55931) * Add fused_rope forward op (PaddlePaddle#54351) * style * more * update ctest * Update legacy_backward.yaml * Update legacy_ops.yaml * Update legacy_ops.yaml * update * update * update for move * Update the rope op according to the comments (PaddlePaddle#54985) * Update multiary.cc * Update __init__.py * for int64_t and assert * more * remove useless assert first --------- Co-authored-by: sneaxiy <sneaxiy@126.com>
PR types
Others
PR changes
Others
Description
Pcard-70458