Skip to content

Commit

Permalink
Add fused_rope forward op (#54351)
Browse files Browse the repository at this point in the history
* style

* more

* update ctest

* Update legacy_backward.yaml

* Update legacy_ops.yaml

* Update legacy_ops.yaml

* update

* update

* update for move
  • Loading branch information
AnnaTrainingG authored Jun 29, 2023
1 parent 7c89b97 commit a215c46
Show file tree
Hide file tree
Showing 13 changed files with 674 additions and 0 deletions.
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,17 @@
kernel :
func : frobenius_norm_grad

- backward_op : fused_rope_grad
forward: fused_rope (Tensor q, Tensor k, Tensor v) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
args : (Tensor out_q_grad, Tensor out_k_grad,Tensor out_v_grad)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
optional : out_k_grad, out_v_grad, k_grad, v_grad
infer_meta :
func : FusedRopeGradInferMeta
kernel :
func : fused_rope_grad
data_type : out_q_grad

- backward_op : hardswish_grad
forward : hardswish (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
Expand Down
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/legacy_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,17 @@
optional : skip_update, master_params
inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out)

- op : fused_rope
args : (Tensor q, Tensor k, Tensor v)
output : Tensor(out_q), Tensor(out_k), Tensor(out_v)
infer_meta :
func : FusedRopeInferMeta
optional : k,v, out_k, out_v
kernel :
func : fused_rope
data_type : q
backward: fused_rope_grad

- op : gaussian
args : (IntArray shape, float mean, float std, int seed, DataType dtype, Place place={})
output: Tensor(out)
Expand Down
27 changes: 27 additions & 0 deletions paddle/phi/infermeta/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1202,4 +1202,31 @@ void IndexAddGradInferMeta(const MetaTensor& index,
}
}

void FusedRopeGradInferMeta(const MetaTensor& dout_q,
const MetaTensor& dout_k,
const MetaTensor& dout_v,
MetaTensor* dq,
MetaTensor* dk,
MetaTensor* dv) {
auto input_dims = dout_q.dims();
PADDLE_ENFORCE_EQ(input_dims.size(),
4,
phi::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W] "
"or [N, H, W, C], but got %u.",
input_dims.size()));
if (dout_q) {
dq->set_dims(dout_q.dims());
dq->set_dtype(dout_q.dtype());
}
if (dout_k) {
dk->set_dims(dout_k.dims());
dk->set_dtype(dout_k.dtype());
}
if (dout_v) {
dv->set_dims(dout_v.dims());
dv->set_dtype(dout_v.dtype());
}
}

} // namespace phi
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,13 @@ void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,
MetaTensor* x_grad,
MetaTensor* y_grad);

void FusedRopeGradInferMeta(const MetaTensor& dout_q,
const MetaTensor& dout_k,
const MetaTensor& dout_v,
MetaTensor* dq,
MetaTensor* dk,
MetaTensor* dv);

void GatherNdGradInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& out_grad,
Expand Down
27 changes: 27 additions & 0 deletions paddle/phi/infermeta/multiary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3484,6 +3484,33 @@ void FusedConvInferMeta(const MetaTensor& input,
config);
}

void FusedRopeInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* out_q,
MetaTensor* out_k,
MetaTensor* out_v) {
auto input_dims = q.dims();
PADDLE_ENFORCE_EQ(input_dims.size(),
4,
phi::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W] "
"or [N, H, W, C], but got %u.",
input_dims.size()));
if (q) {
out_q->set_dims(q.dims());
out_q->set_dtype(q.dtype());
}
if (k) {
out_k->set_dims(k.dims());
out_k->set_dtype(k.dtype());
}
if (v) {
out_v->set_dims(v.dims());
out_v->set_dtype(v.dtype());
}
}

void MoeInferMeta(const MetaTensor& x,
const MetaTensor& gate,
const MetaTensor& bmm0,
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/multiary.h
Original file line number Diff line number Diff line change
Expand Up @@ -673,4 +673,11 @@ void MoeInferMeta(const MetaTensor& x,
const std::string& act_type,
MetaTensor* out);

void FusedRopeInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* out_q,
MetaTensor* out_k,
MetaTensor* out_v);

} // namespace phi
31 changes: 31 additions & 0 deletions paddle/phi/kernels/fused_rope_grad_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void FusedRopeGradKernel(const Context& dev_ctx,
const DenseTensor& dout_q,
const paddle::optional<DenseTensor>& dout_k,
const paddle::optional<DenseTensor>& dout_v,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv);

} // namespace phi
30 changes: 30 additions & 0 deletions paddle/phi/kernels/fused_rope_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void FusedRopeKernel(const Context& dev_ctx,
const DenseTensor& q,
const paddle::optional<DenseTensor>& k,
const paddle::optional<DenseTensor>& v,
DenseTensor* out_q,
DenseTensor* out_k,
DenseTensor* out_v);

} // namespace phi
163 changes: 163 additions & 0 deletions paddle/phi/kernels/gpu/fused_rope_grad_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/fused_rope_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
namespace phi {

template <typename T, typename MPType, int VecSize = 2>
__global__ void VectorizedFusedRopeGradKernel(phi::Array<const T*, 3> ins_data,
int batch_size,
int seq_len,
int num_heads,
int head_dim,
phi::Array<T*, 3> outs_data,
int num_inputs,
MPType div_c) {
int index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
int stride = gridDim.x * blockDim.x * VecSize;
int size = batch_size * seq_len * num_heads * head_dim;
MPType sin_value[VecSize];
MPType cos_value[VecSize];
MPType result[VecSize];
T store[VecSize];
using VecType = phi::AlignedVector<T, VecSize>;
constexpr int kVectorsPerThread = VecSize / 2;

for (; index < size; index += stride) {
#pragma unroll
for (int nx = 0; nx < VecSize; ++nx) {
// get sin_index and cos_index
int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
int pos_seq = index_wc / (num_heads * head_dim);
MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0);
MPType indicses =
static_cast<MPType>(1) /
pow(static_cast<MPType>(10000), idx * static_cast<MPType>(div_c));
MPType value = pos_seq * indicses;
sin_value[nx] = sin(value);
cos_value[nx] = cos(value);
}

#pragma unroll
for (int iter = 0; iter < 3; iter++) {
if (iter > num_inputs) break;
const T* input = ins_data[iter] + index;
VecType* out = reinterpret_cast<VecType*>(outs_data[iter] + index);

#pragma unroll
for (int nx = 0; nx < kVectorsPerThread; ++nx) {
int pr_index = nx * 2;
int ls_index = pr_index + 1;

MPType p0 = static_cast<MPType>(input[pr_index]);
MPType p1 = static_cast<MPType>(input[ls_index]);
result[pr_index] = cos_value[pr_index] * p0 + sin_value[ls_index] * p1;
result[ls_index] = cos_value[ls_index] * p1 - sin_value[pr_index] * p0;

store[pr_index] = static_cast<T>(result[pr_index]);
store[ls_index] = static_cast<T>(result[ls_index]);
}
out[0] = *(reinterpret_cast<VecType*>(store));
}
}
}

template <typename T, typename Context>
void FusedRopeGradKernel(const Context& dev_ctx,
const DenseTensor& dout_q,
const paddle::optional<DenseTensor>& dout_k,
const paddle::optional<DenseTensor>& dout_v,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv) {
int numel = dout_q.numel();
if (numel <= 0) return;
dev_ctx.template Alloc<T>(dq);
dq->Resize(dout_q.dims());
// small size for broadcast
auto batch_size = dout_q.dims()[0];
auto num_heads = dout_q.dims()[2];
auto head_dim = dout_q.dims()[3];
auto seq_len = dout_q.dims()[1];
PADDLE_ENFORCE_NE(head_dim % 2,
1,
phi::errors::InvalidArgument(
"The head_dim of input must be a multiple of 2."));

constexpr const int vec_size = 2;

auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);

int grid = config.block_per_grid.x;
int block = config.thread_per_block.x;
auto stream = dev_ctx.stream();

phi::Array<T*, 3> outs_data;
phi::Array<const T*, 3> ins_data;

ins_data[0] = dout_q.data<T>();
outs_data[0] = dq->data<T>();
int num_inputs = 0;

if (dout_k.get_ptr()) {
dev_ctx.template Alloc<T>(dk);
dk->Resize(dout_q.dims());
outs_data[1] = dk->data<T>();
ins_data[1] = dout_k->data<T>();
num_inputs++;
}

if (dout_v.get_ptr()) {
dev_ctx.template Alloc<T>(dv);
dv->Resize(dout_q.dims());
outs_data[2] = dv->data<T>();
ins_data[2] = dout_v->data<T>();
num_inputs++;
}

using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType div_c = static_cast<MPType>(1.0f / head_dim);

VectorizedFusedRopeGradKernel<T, MPType, vec_size>
<<<grid, block, 0, stream>>>(ins_data,
batch_size,
seq_len,
num_heads,
head_dim,
outs_data,
num_inputs,
div_c);
}
} // namespace phi

PD_REGISTER_KERNEL(fused_rope_grad,
GPU,
ALL_LAYOUT,
phi::FusedRopeGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
}
Loading

0 comments on commit a215c46

Please sign in to comment.