
Add fused_rope forward op #54351

Merged (11 commits) on Jun 29, 2023
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
@@ -302,6 +302,17 @@
  kernel :
    func : frobenius_norm_grad

- backward_op : fused_rope_grad
  forward : fused_rope (Tensor q, Tensor k, Tensor v) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
  args : (Tensor out_q_grad, Tensor out_k_grad, Tensor out_v_grad)
  output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
  optional : out_k_grad, out_v_grad, k_grad, v_grad
  infer_meta :
    func : FusedRopeGradInferMeta
  kernel :
    func : fused_rope_grad
    data_type : out_q_grad

- backward_op : hardswish_grad
  forward : hardswish (Tensor x) -> Tensor(out)
  args : (Tensor x, Tensor out_grad)
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/legacy_ops.yaml
@@ -422,6 +422,17 @@
  optional : skip_update, master_params
  inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out)

- op : fused_rope
Contributor: This can go in fused_ops.yaml.

Contributor (author): Will change it in the next PR.

  args : (Tensor q, Tensor k, Tensor v)
  output : Tensor(out_q), Tensor(out_k), Tensor(out_v)
  infer_meta :
    func : FusedRopeInferMeta
  optional : k, v, out_k, out_v
  kernel :
    func : fused_rope
    data_type : q
  backward : fused_rope_grad
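For context, fused_rope implements rotary position embedding (RoPE): every even/odd feature pair of q, and optionally of k and v, is rotated by an angle that depends on the token position. A sketch of the transform as inferred from the CUDA kernels below, for pair $i$ of a head of size $d$ at sequence position $p$:

$$\theta_i = 10000^{-2i/d}, \qquad \mathrm{out}_{2i} = x_{2i}\cos(p\,\theta_i) - x_{2i+1}\sin(p\,\theta_i), \qquad \mathrm{out}_{2i+1} = x_{2i}\sin(p\,\theta_i) + x_{2i+1}\cos(p\,\theta_i).$$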

- op : gaussian
  args : (IntArray shape, float mean, float std, int seed, DataType dtype, Place place={})
  output: Tensor(out)
27 changes: 27 additions & 0 deletions paddle/phi/infermeta/backward.cc
@@ -1202,4 +1202,31 @@ void IndexAddGradInferMeta(const MetaTensor& index,
}
}

void FusedRopeGradInferMeta(const MetaTensor& dout_q,
                            const MetaTensor& dout_k,
                            const MetaTensor& dout_v,
                            MetaTensor* dq,
                            MetaTensor* dk,
                            MetaTensor* dv) {
  auto input_dims = dout_q.dims();
  PADDLE_ENFORCE_EQ(input_dims.size(),
                    4,
                    phi::errors::InvalidArgument(
                        "Input should be a 4-D tensor of format [N, C, H, W] "
                        "or [N, H, W, C], but got %u.",
                        input_dims.size()));

Contributor: Please change N, C, H, W here to their actual meanings as well; same below.

Contributor (author): OK.

  if (dout_q) {
    dq->set_dims(dout_q.dims());
    dq->set_dtype(dout_q.dtype());
  }
  if (dout_k) {
    dk->set_dims(dout_k.dims());
    dk->set_dtype(dout_k.dtype());
  }
  if (dout_v) {
    dv->set_dims(dout_v.dims());
    dv->set_dtype(dout_v.dtype());
  }
}

} // namespace phi
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/backward.h
@@ -184,6 +184,13 @@ void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,
                                  MetaTensor* x_grad,
                                  MetaTensor* y_grad);

void FusedRopeGradInferMeta(const MetaTensor& dout_q,
                            const MetaTensor& dout_k,
                            const MetaTensor& dout_v,
                            MetaTensor* dq,
                            MetaTensor* dk,
                            MetaTensor* dv);

void GatherNdGradInferMeta(const MetaTensor& x,
                           const MetaTensor& index,
                           const MetaTensor& out_grad,
Expand Down
27 changes: 27 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -3484,6 +3484,33 @@ void FusedConvInferMeta(const MetaTensor& input,
config);
}

void FusedRopeInferMeta(const MetaTensor& q,
                        const MetaTensor& k,
                        const MetaTensor& v,
                        MetaTensor* out_q,
                        MetaTensor* out_k,
                        MetaTensor* out_v) {
  auto input_dims = q.dims();
  PADDLE_ENFORCE_EQ(input_dims.size(),
                    4,
                    phi::errors::InvalidArgument(
                        "Input should be a 4-D tensor of format [N, C, H, W] "
                        "or [N, H, W, C], but got %u.",
                        input_dims.size()));
  if (q) {
    out_q->set_dims(q.dims());
    out_q->set_dtype(q.dtype());
  }
  if (k) {
    out_k->set_dims(k.dims());
    out_k->set_dtype(k.dtype());
  }
  if (v) {
    out_v->set_dims(v.dims());
    out_v->set_dtype(v.dtype());
  }
}

void MoeInferMeta(const MetaTensor& x,
                  const MetaTensor& gate,
                  const MetaTensor& bmm0,
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -673,4 +673,11 @@ void MoeInferMeta(const MetaTensor& x,
                  const std::string& act_type,
                  MetaTensor* out);

void FusedRopeInferMeta(const MetaTensor& q,
                        const MetaTensor& k,
                        const MetaTensor& v,
                        MetaTensor* out_q,
                        MetaTensor* out_k,
                        MetaTensor* out_v);

} // namespace phi
31 changes: 31 additions & 0 deletions paddle/phi/kernels/fused_rope_grad_kernel.h
@@ -0,0 +1,31 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

Contributor: Fused-type kernels don't need a header-file declaration.

Contributor (author): Will fix them all in the next PR.

template <typename T, typename Context>
void FusedRopeGradKernel(const Context& dev_ctx,
                         const DenseTensor& dout_q,
                         const paddle::optional<DenseTensor>& dout_k,
                         const paddle::optional<DenseTensor>& dout_v,
                         DenseTensor* dq,
                         DenseTensor* dk,
                         DenseTensor* dv);

} // namespace phi
30 changes: 30 additions & 0 deletions paddle/phi/kernels/fused_rope_kernel.h
@@ -0,0 +1,30 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

Contributor: Same as above.

Contributor (author): Will fix them all in the next PR.

template <typename T, typename Context>
void FusedRopeKernel(const Context& dev_ctx,
                     const DenseTensor& q,
                     const paddle::optional<DenseTensor>& k,
                     const paddle::optional<DenseTensor>& v,
                     DenseTensor* out_q,
                     DenseTensor* out_k,
                     DenseTensor* out_v);

} // namespace phi
163 changes: 163 additions & 0 deletions paddle/phi/kernels/gpu/fused_rope_grad_kernel.cu
@@ -0,0 +1,163 @@
Contributor: Please put the fused operator implementation under the phi/kernels/fusion/gpu directory.

// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/fused_rope_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"

namespace phi {

template <typename T, typename MPType, int VecSize = 2>
__global__ void VectorizedFusedRopeGradKernel(phi::Array<const T*, 3> ins_data,
                                              int batch_size,
                                              int seq_len,
                                              int num_heads,
                                              int head_dim,
                                              phi::Array<T*, 3> outs_data,
                                              int num_inputs,
                                              MPType div_c) {
  int index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
  int stride = gridDim.x * blockDim.x * VecSize;
  int size = batch_size * seq_len * num_heads * head_dim;
  MPType sin_value[VecSize];
  MPType cos_value[VecSize];
  MPType result[VecSize];
  T store[VecSize];
  using VecType = phi::AlignedVector<T, VecSize>;
  constexpr int kVectorsPerThread = VecSize / 2;

  for (; index < size; index += stride) {
#pragma unroll
    for (int nx = 0; nx < VecSize; ++nx) {
      // compute the rotation angle for this element
      int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
      int pos_seq = index_wc / (num_heads * head_dim);
      MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0);
      // inverse frequency: 1 / 10000^(2 * floor(j / 2) / head_dim)
      MPType inv_freq =
          static_cast<MPType>(1) /
          pow(static_cast<MPType>(10000), idx * static_cast<MPType>(div_c));
      MPType value = pos_seq * inv_freq;
      sin_value[nx] = sin(value);
      cos_value[nx] = cos(value);
    }

#pragma unroll
    for (int iter = 0; iter < 3; iter++) {
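      // num_inputs counts the optional k/v gradients, so iter == 0 (q)
      // always runs and iters 1, 2 run only when dout_k / dout_v were passed.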
      if (iter > num_inputs) break;
      const T* input = ins_data[iter] + index;
      VecType* out = reinterpret_cast<VecType*>(outs_data[iter] + index);

#pragma unroll
      for (int nx = 0; nx < kVectorsPerThread; ++nx) {
        int pr_index = nx * 2;
        int ls_index = pr_index + 1;

        MPType p0 = static_cast<MPType>(input[pr_index]);
        MPType p1 = static_cast<MPType>(input[ls_index]);
        // transpose of the forward rotation: [c, s; -s, c]
        result[pr_index] = cos_value[pr_index] * p0 + sin_value[ls_index] * p1;
        result[ls_index] = cos_value[ls_index] * p1 - sin_value[pr_index] * p0;

        store[pr_index] = static_cast<T>(result[pr_index]);
        store[ls_index] = static_cast<T>(result[ls_index]);
      }
      out[0] = *(reinterpret_cast<VecType*>(store));
    }
  }
}
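Written out (this derivation is editorial, not part of the source), the pair update above applies the transpose of the forward rotation, which for an orthogonal rotation matrix is exactly its inverse:

$$\begin{pmatrix} dq_{2i} \\ dq_{2i+1} \end{pmatrix} = \begin{pmatrix} \cos(p\,\theta_i) & \sin(p\,\theta_i) \\ -\sin(p\,\theta_i) & \cos(p\,\theta_i) \end{pmatrix} \begin{pmatrix} d\mathrm{out}_{2i} \\ d\mathrm{out}_{2i+1} \end{pmatrix}$$

This is why the backward kernel can reuse the same sin/cos computation as the forward and only swap the sign pattern.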

Contributor: I don't quite understand what dout_q means. Is it the forward output out_q?

Contributor (author): It's the dout passed back from the backward pass.

template <typename T, typename Context>
void FusedRopeGradKernel(const Context& dev_ctx,
                         const DenseTensor& dout_q,
                         const paddle::optional<DenseTensor>& dout_k,
                         const paddle::optional<DenseTensor>& dout_v,
                         DenseTensor* dq,
                         DenseTensor* dk,
                         DenseTensor* dv) {
  int numel = dout_q.numel();
  if (numel <= 0) return;
  dev_ctx.template Alloc<T>(dq);
  dq->Resize(dout_q.dims());
  // shape: [batch_size, seq_len, num_heads, head_dim]
  auto batch_size = dout_q.dims()[0];
  auto seq_len = dout_q.dims()[1];
  auto num_heads = dout_q.dims()[2];
  auto head_dim = dout_q.dims()[3];
  PADDLE_ENFORCE_NE(head_dim % 2,
                    1,
                    phi::errors::InvalidArgument(
                        "The head_dim of input must be a multiple of 2."));

  constexpr const int vec_size = 2;
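  // Each thread consumes vec_size (= 2) consecutive elements, i.e. exactly
  // one (even, odd) rotation pair; this pairs with the head_dim check above.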

  auto config =
      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);

  int grid = config.block_per_grid.x;
  int block = config.thread_per_block.x;
  auto stream = dev_ctx.stream();

  phi::Array<T*, 3> outs_data;
  phi::Array<const T*, 3> ins_data;

  ins_data[0] = dout_q.data<T>();
  outs_data[0] = dq->data<T>();
  int num_inputs = 0;

  if (dout_k.get_ptr()) {
    dev_ctx.template Alloc<T>(dk);
    dk->Resize(dout_q.dims());
    outs_data[1] = dk->data<T>();
    ins_data[1] = dout_k->data<T>();
    num_inputs++;
  }

  if (dout_v.get_ptr()) {
    dev_ctx.template Alloc<T>(dv);
    dv->Resize(dout_q.dims());
    outs_data[2] = dv->data<T>();
    ins_data[2] = dout_v->data<T>();
    num_inputs++;
  }

  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType div_c = static_cast<MPType>(1.0f / head_dim);

  VectorizedFusedRopeGradKernel<T, MPType, vec_size>
      <<<grid, block, 0, stream>>>(ins_data,
                                   batch_size,
                                   seq_len,
                                   num_heads,
                                   head_dim,
                                   outs_data,
                                   num_inputs,
                                   div_c);
}
} // namespace phi

PD_REGISTER_KERNEL(fused_rope_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::FusedRopeGradKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {
  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);

Contributor: Do we need to set ALL_BACKEND here?

Contributor (author): Deleted.

Contributor: Was this deleted and then added back? Same for the forward op.

}
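As a sanity check on the index arithmetic, here is a minimal CPU reference for the backward rotation. It is a hypothetical helper for unit tests, not part of this PR; it assumes the contiguous [batch_size, seq_len, num_heads, head_dim] layout used above and float32 for brevity:

#include <cmath>
#include <cstdint>
#include <vector>

// CPU reference for fused_rope_grad on a flat
// [batch_size, seq_len, num_heads, head_dim] buffer (head_dim must be even).
std::vector<float> RopeGradReference(const std::vector<float>& dout,
                                     int64_t batch_size, int64_t seq_len,
                                     int64_t num_heads, int64_t head_dim) {
  std::vector<float> dx(dout.size());
  const int64_t size = batch_size * seq_len * num_heads * head_dim;
  for (int64_t i = 0; i < size; i += 2) {
    // same decomposition as index_wc / pos_seq / idx in the CUDA kernel
    int64_t index_wc = i % (seq_len * num_heads * head_dim);
    int64_t pos_seq = index_wc / (num_heads * head_dim);
    int64_t j = index_wc % head_dim;  // even, because i steps by 2
    float inv_freq =
        1.0f / std::pow(10000.0f, static_cast<float>(j) / head_dim);
    float s = std::sin(pos_seq * inv_freq);
    float c = std::cos(pos_seq * inv_freq);
    // transpose of the forward rotation
    dx[i] = c * dout[i] + s * dout[i + 1];
    dx[i + 1] = c * dout[i + 1] - s * dout[i];
  }
  return dx;
}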