Add fused_rope forward op #54351
@@ -1202,4 +1202,31 @@ void IndexAddGradInferMeta(const MetaTensor& index,
  }
}

void FusedRopeGradInferMeta(const MetaTensor& dout_q,
                            const MetaTensor& dout_k,
                            const MetaTensor& dout_v,
                            MetaTensor* dq,
                            MetaTensor* dk,
                            MetaTensor* dv) {
  auto input_dims = dout_q.dims();
  PADDLE_ENFORCE_EQ(input_dims.size(),
                    4,
                    phi::errors::InvalidArgument(
                        "Input should be a 4-D tensor of format [N, C, H, W] "
                        "or [N, H, W, C], but got %u.",
                        input_dims.size()));

Review: Here the ...
Reply: OK.

  if (dout_q) {
    dq->set_dims(dout_q.dims());
    dq->set_dtype(dout_q.dtype());
  }
  if (dout_k) {
    dk->set_dims(dout_k.dims());
    dk->set_dtype(dout_k.dtype());
  }
  if (dout_v) {
    dv->set_dims(dout_v.dims());
    dv->set_dtype(dout_v.dtype());
  }
}

}  // namespace phi
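For orientation: the CUDA grad kernel later in this diff reads the four dimensions of dout_q as batch_size, seq_len, num_heads, and head_dim, and the grads simply mirror the shape and dtype of the corresponding dout tensors. A minimal sketch of that layout contract (illustrative only; the helper names below are hypothetical and not part of the PR):

```cpp
#include "paddle/phi/core/ddim.h"

// Hypothetical helper: how FusedRopeGradKernel in this PR interprets dout_q.dims().
struct RopeDims {
  int64_t batch_size, seq_len, num_heads, head_dim;
};

inline RopeDims UnpackRopeDims(const phi::DDim& dims) {
  // Expected rank-4 layout: [batch_size, seq_len, num_heads, head_dim].
  return {dims[0], dims[1], dims[2], dims[3]};
}
```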
@@ -0,0 +1,31 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void FusedRopeGradKernel(const Context& dev_ctx,
                         const DenseTensor& dout_q,
                         const paddle::optional<DenseTensor>& dout_k,
                         const paddle::optional<DenseTensor>& dout_v,
                         DenseTensor* dq,
                         DenseTensor* dk,
                         DenseTensor* dv);

}  // namespace phi

Review: Fused kernels don't need a header-file declaration.
Reply: Will fix this uniformly in the next PR.
@@ -0,0 +1,30 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void FusedRopeKernel(const Context& dev_ctx,
                     const DenseTensor& q,
                     const paddle::optional<DenseTensor>& k,
                     const paddle::optional<DenseTensor>& v,
                     DenseTensor* out_q,
                     DenseTensor* out_k,
                     DenseTensor* out_v);

}  // namespace phi

Review: Same as above.
Reply: Will fix this uniformly in the next PR.
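To make concrete what the fused forward op computes, here is a minimal, unfused CPU-style sketch of rotary position embedding on a [batch_size, seq_len, num_heads, head_dim] buffer. This assumes the standard RoPE formulation and is purely illustrative; the function name and signature are hypothetical and not code from this PR.

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Illustrative reference only: applies rotary position embedding in place to a
// contiguous [batch_size, seq_len, num_heads, head_dim] buffer. The fused GPU
// kernels in this PR compute the same pairwise rotation, vectorized and for
// q/k/v in a single launch.
void RotaryEmbeddingReference(std::vector<float>* x,
                              int64_t batch_size,
                              int64_t seq_len,
                              int64_t num_heads,
                              int64_t head_dim) {
  for (int64_t b = 0; b < batch_size; ++b) {
    for (int64_t s = 0; s < seq_len; ++s) {
      for (int64_t h = 0; h < num_heads; ++h) {
        float* p = x->data() + (((b * seq_len + s) * num_heads + h) * head_dim);
        for (int64_t i = 0; i < head_dim; i += 2) {
          // theta = pos * 10000^(-i / head_dim), matching the idx * div_c
          // exponent computed in the grad kernel later in this diff.
          float theta =
              s * std::pow(10000.0f, -static_cast<float>(i) / head_dim);
          float c = std::cos(theta), sn = std::sin(theta);
          float x0 = p[i], x1 = p[i + 1];
          p[i] = x0 * c - x1 * sn;
          p[i + 1] = x0 * sn + x1 * c;
        }
      }
    }
  }
}
```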
@@ -0,0 +1,163 @@

Review: Fused operator implementations should go under ...

// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/fused_rope_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"

namespace phi {

template <typename T, typename MPType, int VecSize = 2>
__global__ void VectorizedFusedRopeGradKernel(phi::Array<const T*, 3> ins_data,
                                              int batch_size,
                                              int seq_len,
                                              int num_heads,
                                              int head_dim,
                                              phi::Array<T*, 3> outs_data,
                                              int num_inputs,
                                              MPType div_c) {
  int index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
  int stride = gridDim.x * blockDim.x * VecSize;
  int size = batch_size * seq_len * num_heads * head_dim;
  MPType sin_value[VecSize];
  MPType cos_value[VecSize];
  MPType result[VecSize];
  T store[VecSize];
  using VecType = phi::AlignedVector<T, VecSize>;
  constexpr int kVectorsPerThread = VecSize / 2;

  for (; index < size; index += stride) {
#pragma unroll
    for (int nx = 0; nx < VecSize; ++nx) {
      // get sin_index and cos_index
      int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
      int pos_seq = index_wc / (num_heads * head_dim);
      MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0);
      MPType indicses =
          static_cast<MPType>(1) /
          pow(static_cast<MPType>(10000), idx * static_cast<MPType>(div_c));
      MPType value = pos_seq * indicses;
      sin_value[nx] = sin(value);
      cos_value[nx] = cos(value);
    }

#pragma unroll
    for (int iter = 0; iter < 3; iter++) {
      if (iter > num_inputs) break;
      const T* input = ins_data[iter] + index;
      VecType* out = reinterpret_cast<VecType*>(outs_data[iter] + index);

#pragma unroll
      for (int nx = 0; nx < kVectorsPerThread; ++nx) {
        int pr_index = nx * 2;
        int ls_index = pr_index + 1;

        MPType p0 = static_cast<MPType>(input[pr_index]);
        MPType p1 = static_cast<MPType>(input[ls_index]);
        result[pr_index] = cos_value[pr_index] * p0 + sin_value[ls_index] * p1;
        result[ls_index] = cos_value[ls_index] * p1 - sin_value[pr_index] * p0;

        store[pr_index] = static_cast<T>(result[pr_index]);
        store[ls_index] = static_cast<T>(result[ls_index]);
      }
      out[0] = *(reinterpret_cast<VecType*>(store));
    }
  }
}
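A note on the sign pattern in the pair loop above: assuming the standard rotary-embedding formulation, the backward pass applies the transpose of the forward rotation to the incoming gradient, which is what the cos/sin combination in result[pr_index] / result[ls_index] computes.

```latex
\begin{aligned}
\text{forward:}\quad  & y_{2i}  = x_{2i}\cos\theta_i - x_{2i+1}\sin\theta_i, &
  y_{2i+1}  &= x_{2i}\sin\theta_i + x_{2i+1}\cos\theta_i,\\
\text{backward:}\quad & dx_{2i} = g_{2i}\cos\theta_i + g_{2i+1}\sin\theta_i, &
  dx_{2i+1} &= g_{2i+1}\cos\theta_i - g_{2i}\sin\theta_i,
\end{aligned}
```

with \(\theta_i = \text{pos} \cdot 10000^{-2i/\text{head\_dim}}\) (the `idx * div_c` exponent in the code) and \(g\) the incoming gradient `dout`.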

template <typename T, typename Context>
void FusedRopeGradKernel(const Context& dev_ctx,
                         const DenseTensor& dout_q,
                         const paddle::optional<DenseTensor>& dout_k,
                         const paddle::optional<DenseTensor>& dout_v,
                         DenseTensor* dq,
                         DenseTensor* dk,
                         DenseTensor* dv) {

Review (on dout_q): I don't quite understand ...
Reply: It is just the dout passed back from backpropagation.

  int numel = dout_q.numel();
  if (numel <= 0) return;
  dev_ctx.template Alloc<T>(dq);
  dq->Resize(dout_q.dims());
  // small size for broadcast
  auto batch_size = dout_q.dims()[0];
  auto num_heads = dout_q.dims()[2];
  auto head_dim = dout_q.dims()[3];
  auto seq_len = dout_q.dims()[1];
  PADDLE_ENFORCE_NE(head_dim % 2,
                    1,
                    phi::errors::InvalidArgument(
                        "The head_dim of input must be a multiple of 2."));

  constexpr const int vec_size = 2;

  auto config =
      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);

  int grid = config.block_per_grid.x;
  int block = config.thread_per_block.x;
  auto stream = dev_ctx.stream();

  phi::Array<T*, 3> outs_data;
  phi::Array<const T*, 3> ins_data;

  ins_data[0] = dout_q.data<T>();
  outs_data[0] = dq->data<T>();
  int num_inputs = 0;

  if (dout_k.get_ptr()) {
    dev_ctx.template Alloc<T>(dk);
    dk->Resize(dout_q.dims());
    outs_data[1] = dk->data<T>();
    ins_data[1] = dout_k->data<T>();
    num_inputs++;
  }

  if (dout_v.get_ptr()) {
    dev_ctx.template Alloc<T>(dv);
    dv->Resize(dout_q.dims());
    outs_data[2] = dv->data<T>();
    ins_data[2] = dout_v->data<T>();
    num_inputs++;
  }

  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType div_c = static_cast<MPType>(1.0f / head_dim);

  VectorizedFusedRopeGradKernel<T, MPType, vec_size>
      <<<grid, block, 0, stream>>>(ins_data,
                                   batch_size,
                                   seq_len,
                                   num_heads,
                                   head_dim,
                                   outs_data,
                                   num_inputs,
                                   div_c);
}
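To make the launch arithmetic concrete, take a hypothetical dout_q of shape [batch_size = 2, seq_len = 8, num_heads = 4, head_dim = 64] (example numbers, not from the PR): numel = 2 * 8 * 4 * 64 = 4096 and div_c = 1 / 64. With vec_size = 2, each thread iteration rotates one even/odd pair, so GetGpuLaunchConfig1D only has to cover about 4096 / 2 = 2048 vectorized work items, and the grid-stride loop in VectorizedFusedRopeGradKernel picks up any remainder.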
}  // namespace phi

PD_REGISTER_KERNEL(fused_rope_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::FusedRopeGradKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {
  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
}

Review: Here we need to set ...
Reply: It has already been removed.
Review: Was it removed and then added back? Same question for the forward kernel.
Review: This can go in fused_ops.yaml.
Reply: Will change it in the next PR.