From a215c46a3fd6e1daad1b4de3917d0b216a90fcee Mon Sep 17 00:00:00 2001
From: niuliling123 <51102941+niuliling123@users.noreply.github.com>
Date: Thu, 29 Jun 2023 13:40:59 +0800
Subject: [PATCH] Add fused_rope forward op (#54351)

* style

* more

* update ctest

* Update legacy_backward.yaml

* Update legacy_ops.yaml

* Update legacy_ops.yaml

* update

* update

* update for move
---
 paddle/phi/api/yaml/legacy_backward.yaml      |  11 ++
 paddle/phi/api/yaml/legacy_ops.yaml           |  11 ++
 paddle/phi/infermeta/backward.cc              |  27 +++
 paddle/phi/infermeta/backward.h               |   7 +
 paddle/phi/infermeta/multiary.cc              |  27 +++
 paddle/phi/infermeta/multiary.h               |   7 +
 paddle/phi/kernels/fused_rope_grad_kernel.h   |  31 ++++
 paddle/phi/kernels/fused_rope_kernel.h        |  30 ++++
 .../phi/kernels/gpu/fused_rope_grad_kernel.cu | 163 +++++++++++++++++
 paddle/phi/kernels/gpu/fused_rope_kernel.cu   | 167 ++++++++++++++++++
 .../paddle/incubate/nn/functional/__init__.py |   2 +
 .../fused_rotary_position_embedding.py        |  47 +++++
 .../test_fused_rotary_position_embedding.py   | 144 +++++++++++++++
 13 files changed, 674 insertions(+)
 create mode 100644 paddle/phi/kernels/fused_rope_grad_kernel.h
 create mode 100644 paddle/phi/kernels/fused_rope_kernel.h
 create mode 100644 paddle/phi/kernels/gpu/fused_rope_grad_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/fused_rope_kernel.cu
 create mode 100644 python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py
 create mode 100644 test/legacy_test/test_fused_rotary_position_embedding.py

diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index df644df098a80..10135adf6a0f4 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -302,6 +302,17 @@
   kernel :
     func : frobenius_norm_grad
 
+- backward_op : fused_rope_grad
+  forward : fused_rope (Tensor q, Tensor k, Tensor v) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
+  args : (Tensor out_q_grad, Tensor out_k_grad, Tensor out_v_grad)
+  output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
+  optional : out_k_grad, out_v_grad, k_grad, v_grad
+  infer_meta :
+    func : FusedRopeGradInferMeta
+  kernel :
+    func : fused_rope_grad
+    data_type : out_q_grad
+
 - backward_op : hardswish_grad
   forward : hardswish (Tensor x) -> Tensor(out)
   args : (Tensor x, Tensor out_grad)
diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml
index fd0d4c1c52005..b440ef6cd98c7 100755
--- a/paddle/phi/api/yaml/legacy_ops.yaml
+++ b/paddle/phi/api/yaml/legacy_ops.yaml
@@ -423,6 +423,17 @@
   optional : skip_update, master_params
   inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out)
 
+- op : fused_rope
+  args : (Tensor q, Tensor k, Tensor v)
+  output : Tensor(out_q), Tensor(out_k), Tensor(out_v)
+  infer_meta :
+    func : FusedRopeInferMeta
+  optional : k, v, out_k, out_v
+  kernel :
+    func : fused_rope
+    data_type : q
+  backward : fused_rope_grad
+
 - op : gaussian
   args : (IntArray shape, float mean, float std, int seed, DataType dtype, Place place={})
   output: Tensor(out)
diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index 2ea81ad52f7ee..c784e295a13cf 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -1202,4 +1202,31 @@ void IndexAddGradInferMeta(const MetaTensor& index,
   }
 }
 
+void FusedRopeGradInferMeta(const MetaTensor& dout_q,
+                            const MetaTensor& dout_k,
+                            const MetaTensor& dout_v,
+                            MetaTensor* dq,
+                            MetaTensor* dk,
+                            MetaTensor* dv) {
+  auto input_dims = dout_q.dims();
+  PADDLE_ENFORCE_EQ(input_dims.size(),
+                    4,
+                    phi::errors::InvalidArgument(
+                        "Input should be a 4-D tensor of shape "
+                        "[batch_size, seq_len, num_heads, head_dim], "
+                        "but got %u dimensions.",
+                        input_dims.size()));
+  if (dout_q) {
+    dq->set_dims(dout_q.dims());
+    dq->set_dtype(dout_q.dtype());
+  }
+  if (dout_k) {
+    dk->set_dims(dout_k.dims());
+    dk->set_dtype(dout_k.dtype());
+  }
+  if (dout_v) {
+    dv->set_dims(dout_v.dims());
+    dv->set_dtype(dout_v.dtype());
+  }
+}
+
 }  // namespace phi
diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h
index 4dc995cb296f6..cb923e16446af 100644
--- a/paddle/phi/infermeta/backward.h
+++ b/paddle/phi/infermeta/backward.h
@@ -184,6 +184,13 @@ void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,
                                   MetaTensor* x_grad,
                                   MetaTensor* y_grad);
 
+void FusedRopeGradInferMeta(const MetaTensor& dout_q,
+                            const MetaTensor& dout_k,
+                            const MetaTensor& dout_v,
+                            MetaTensor* dq,
+                            MetaTensor* dk,
+                            MetaTensor* dv);
+
 void GatherNdGradInferMeta(const MetaTensor& x,
                            const MetaTensor& index,
                            const MetaTensor& out_grad,
diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index 31ea58775ffd5..46b90d5d42bae 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -3484,6 +3484,33 @@ void FusedConvInferMeta(const MetaTensor& input,
                         config);
 }
 
+void FusedRopeInferMeta(const MetaTensor& q,
+                        const MetaTensor& k,
+                        const MetaTensor& v,
+                        MetaTensor* out_q,
+                        MetaTensor* out_k,
+                        MetaTensor* out_v) {
+  auto input_dims = q.dims();
+  PADDLE_ENFORCE_EQ(input_dims.size(),
+                    4,
+                    phi::errors::InvalidArgument(
+                        "Input should be a 4-D tensor of shape "
+                        "[batch_size, seq_len, num_heads, head_dim], "
+                        "but got %u dimensions.",
+                        input_dims.size()));
+  if (q) {
+    out_q->set_dims(q.dims());
+    out_q->set_dtype(q.dtype());
+  }
+  if (k) {
+    out_k->set_dims(k.dims());
+    out_k->set_dtype(k.dtype());
+  }
+  if (v) {
+    out_v->set_dims(v.dims());
+    out_v->set_dtype(v.dtype());
+  }
+}
+
 void MoeInferMeta(const MetaTensor& x,
                   const MetaTensor& gate,
                   const MetaTensor& bmm0,
diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h
index a792544ee005d..d0cc876c840ac 100644
--- a/paddle/phi/infermeta/multiary.h
+++ b/paddle/phi/infermeta/multiary.h
@@ -673,4 +673,11 @@ void MoeInferMeta(const MetaTensor& x,
                   const std::string& act_type,
                   MetaTensor* out);
 
+void FusedRopeInferMeta(const MetaTensor& q,
+                        const MetaTensor& k,
+                        const MetaTensor& v,
+                        MetaTensor* out_q,
+                        MetaTensor* out_k,
+                        MetaTensor* out_v);
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/fused_rope_grad_kernel.h b/paddle/phi/kernels/fused_rope_grad_kernel.h
new file mode 100644
index 0000000000000..26e8ed451d64b
--- /dev/null
+++ b/paddle/phi/kernels/fused_rope_grad_kernel.h
@@ -0,0 +1,31 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void FusedRopeGradKernel(const Context& dev_ctx,
+                         const DenseTensor& dout_q,
+                         const paddle::optional<DenseTensor>& dout_k,
+                         const paddle::optional<DenseTensor>& dout_v,
+                         DenseTensor* dq,
+                         DenseTensor* dk,
+                         DenseTensor* dv);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/fused_rope_kernel.h b/paddle/phi/kernels/fused_rope_kernel.h
new file mode 100644
index 0000000000000..cdced91dcfdef
--- /dev/null
+++ b/paddle/phi/kernels/fused_rope_kernel.h
@@ -0,0 +1,30 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void FusedRopeKernel(const Context& dev_ctx,
+                     const DenseTensor& q,
+                     const paddle::optional<DenseTensor>& k,
+                     const paddle::optional<DenseTensor>& v,
+                     DenseTensor* out_q,
+                     DenseTensor* out_k,
+                     DenseTensor* out_v);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/gpu/fused_rope_grad_kernel.cu b/paddle/phi/kernels/gpu/fused_rope_grad_kernel.cu
new file mode 100644
index 0000000000000..59db5dbdb9ae5
--- /dev/null
+++ b/paddle/phi/kernels/gpu/fused_rope_grad_kernel.cu
@@ -0,0 +1,163 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/fused_rope_grad_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/aligned_vector.h"
+namespace phi {
+
+template <typename T, typename MPType, int VecSize = 2>
+__global__ void VectorizedFusedRopeGradKernel(phi::Array<const T*, 3> ins_data,
+                                              int batch_size,
+                                              int seq_len,
+                                              int num_heads,
+                                              int head_dim,
+                                              phi::Array<T*, 3> outs_data,
+                                              int num_inputs,
+                                              MPType div_c) {
+  int index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
+  int stride = gridDim.x * blockDim.x * VecSize;
+  int size = batch_size * seq_len * num_heads * head_dim;
+  MPType sin_value[VecSize];
+  MPType cos_value[VecSize];
+  MPType result[VecSize];
+  T store[VecSize];
+  using VecType = phi::AlignedVector<T, VecSize>;
+  constexpr int kVectorsPerThread = VecSize / 2;
+
+  for (; index < size; index += stride) {
+#pragma unroll
+    for (int nx = 0; nx < VecSize; ++nx) {
+      // get sin_index and cos_index
+      int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
+      int pos_seq = index_wc / (num_heads * head_dim);
+      MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0);
+      MPType indices =
+          static_cast<MPType>(1) /
+          pow(static_cast<MPType>(10000), idx * static_cast<MPType>(div_c));
+      MPType value = pos_seq * indices;
+      sin_value[nx] = sin(value);
+      cos_value[nx] = cos(value);
+    }
+
+#pragma unroll
+    for (int iter = 0; iter < 3; iter++) {
+      if (iter > num_inputs) break;
+      const T* input = ins_data[iter] + index;
+      VecType* out = reinterpret_cast<VecType*>(outs_data[iter] + index);
+
+#pragma unroll
+      for (int nx = 0; nx < kVectorsPerThread; ++nx) {
+        int pr_index = nx * 2;
+        int ls_index = pr_index + 1;
+
+        MPType p0 = static_cast<MPType>(input[pr_index]);
+        MPType p1 = static_cast<MPType>(input[ls_index]);
+        result[pr_index] = cos_value[pr_index] * p0 + sin_value[ls_index] * p1;
+        result[ls_index] = cos_value[ls_index] * p1 - sin_value[pr_index] * p0;
+
+        store[pr_index] = static_cast<T>(result[pr_index]);
+        store[ls_index] = static_cast<T>(result[ls_index]);
+      }
+      out[0] = *(reinterpret_cast<VecType*>(store));
+    }
+  }
+}
+
+template <typename T, typename Context>
+void FusedRopeGradKernel(const Context& dev_ctx,
+                         const DenseTensor& dout_q,
+                         const paddle::optional<DenseTensor>& dout_k,
+                         const paddle::optional<DenseTensor>& dout_v,
+                         DenseTensor* dq,
+                         DenseTensor* dk,
+                         DenseTensor* dv) {
+  int numel = dout_q.numel();
+  if (numel <= 0) return;
+  dev_ctx.template Alloc<T>(dq);
+  dq->Resize(dout_q.dims());
+  // q, k and v share the shape [batch_size, seq_len, num_heads, head_dim]
+  auto batch_size = dout_q.dims()[0];
+  auto num_heads = dout_q.dims()[2];
+  auto head_dim = dout_q.dims()[3];
+  auto seq_len = dout_q.dims()[1];
+  PADDLE_ENFORCE_NE(head_dim % 2,
+                    1,
+                    phi::errors::InvalidArgument(
+                        "The head_dim of input must be a multiple of 2."));
+
+  constexpr const int vec_size = 2;
+
+  auto config =
+      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
+
+  int grid = config.block_per_grid.x;
+  int block = config.thread_per_block.x;
+  auto stream = dev_ctx.stream();
+
+  phi::Array<T*, 3> outs_data;
+  phi::Array<const T*, 3> ins_data;
+
+  ins_data[0] = dout_q.data<T>();
+  outs_data[0] = dq->data<T>();
+  int num_inputs = 0;
+
+  if (dout_k.get_ptr()) {
+    dev_ctx.template Alloc<T>(dk);
+    dk->Resize(dout_q.dims());
+    outs_data[1] = dk->data<T>();
+    ins_data[1] = dout_k->data<T>();
+    num_inputs++;
+  }
+
+  if (dout_v.get_ptr()) {
+    dev_ctx.template Alloc<T>(dv);
+    dv->Resize(dout_q.dims());
+    outs_data[2] = dv->data<T>();
+    ins_data[2] = dout_v->data<T>();
+    num_inputs++;
+  }
+
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  MPType div_c = static_cast<MPType>(1.0f / head_dim);
+
+  VectorizedFusedRopeGradKernel<T, MPType, vec_size>
+      <<<grid, block, 0, stream>>>(ins_data,
+                                   batch_size,
+                                   seq_len,
+                                   num_heads,
+                                   head_dim,
+                                   outs_data,
+                                   num_inputs,
+                                   div_c);
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(fused_rope_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::FusedRopeGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
+  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
+}
diff --git a/paddle/phi/kernels/gpu/fused_rope_kernel.cu b/paddle/phi/kernels/gpu/fused_rope_kernel.cu
new file mode 100644
index 0000000000000..f378a211a3583
--- /dev/null
+++ b/paddle/phi/kernels/gpu/fused_rope_kernel.cu
@@ -0,0 +1,167 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/fused_rope_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/aligned_vector.h"
+namespace phi {
+
+template <typename T, typename MPType, int VecSize = 2>
+__global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
+                                          int batch_size,
+                                          int seq_len,
+                                          int num_heads,
+                                          int head_dim,
+                                          phi::Array<T*, 3> outs_data,
+                                          int num_inputs,
+                                          MPType div_c) {
+  int index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
+  int stride = gridDim.x * blockDim.x * VecSize;
+  int size = batch_size * seq_len * num_heads * head_dim;
+  MPType sin_value[VecSize];
+  MPType cos_value[VecSize];
+  MPType result[VecSize];
+  T store[VecSize];
+  using VecType = phi::AlignedVector<T, VecSize>;
+  constexpr int kVectorsPerThread = VecSize / 2;
+
+  for (; index < size; index += stride) {
+#pragma unroll
+    for (int nx = 0; nx < VecSize; ++nx) {
+      // get sin_index and cos_index
+      int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
+      int pos_seq = index_wc / (num_heads * head_dim);
+      MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0);
+      MPType indices =
+          static_cast<MPType>(1) /
+          pow(static_cast<MPType>(10000), idx * static_cast<MPType>(div_c));
+      MPType value = pos_seq * indices;
+      sin_value[nx] = sin(value);
+      cos_value[nx] = cos(value);
+    }
+
+#pragma unroll
+    for (int iter = 0; iter < 3; iter++) {
+      if (iter > num_inputs) break;
+      const T* input = ins_data[iter] + index;
+      VecType* out = reinterpret_cast<VecType*>(outs_data[iter] + index);
+
+#pragma unroll
+      for (int nx = 0; nx < kVectorsPerThread; ++nx) {
+        int pr_index = nx * 2;
+        int ls_index = pr_index + 1;
+
+        MPType p0 = static_cast<MPType>(input[pr_index]);
+        MPType p1 = static_cast<MPType>(input[ls_index]);
+
+        result[pr_index] = cos_value[pr_index] * p0;
+        result[pr_index] -= sin_value[pr_index] * p1;
+
+        result[ls_index] = sin_value[ls_index] * p0;
+        result[ls_index] += cos_value[ls_index] * p1;
+
+        store[pr_index] = static_cast<T>(result[pr_index]);
+        store[ls_index] = static_cast<T>(result[ls_index]);
+      }
+      out[0] = *(reinterpret_cast<VecType*>(store));
+    }
+  }
+}
+
+template <typename T, typename Context>
+void FusedRopeKernel(const Context& dev_ctx,
+                     const DenseTensor& q,
+                     const paddle::optional<DenseTensor>& k,
+                     const paddle::optional<DenseTensor>& v,
+                     DenseTensor* out_q,
+                     DenseTensor* out_k,
+                     DenseTensor* out_v) {
+  int numel = q.numel();
+  if (numel <= 0) return;
+  dev_ctx.template Alloc<T>(out_q);
+  out_q->Resize(q.dims());
+  // q, k and v share the shape [batch_size, seq_len, num_heads, head_dim]
+  auto batch_size = q.dims()[0];
+  auto num_heads = q.dims()[2];
+  auto head_dim = q.dims()[3];
+  auto seq_len = q.dims()[1];
+  PADDLE_ENFORCE_NE(head_dim % 2,
+                    1,
+                    phi::errors::InvalidArgument(
+                        "The head_dim of input must be a multiple of 2."));
+
+  constexpr const int vec_size = 2;
+
+  auto config =
+      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
+
+  int grid = config.block_per_grid.x;
+  int block = config.thread_per_block.x;
+  auto stream = dev_ctx.stream();
+
+  phi::Array<T*, 3> outs_data;
+  phi::Array<const T*, 3> ins_data;
+
+  ins_data[0] = q.data<T>();
+  outs_data[0] = out_q->data<T>();
+  int num_inputs = 0;
+
+  if (k.get_ptr()) {
+    dev_ctx.template Alloc<T>(out_k);
+    out_k->Resize(q.dims());
+    ins_data[1] = k->data<T>();
+    outs_data[1] = out_k->data<T>();
+    num_inputs++;
+  }
+
+  if (v.get_ptr()) {
+    dev_ctx.template Alloc<T>(out_v);
+    out_v->Resize(q.dims());
+    ins_data[2] = v->data<T>();
+    outs_data[2] = out_v->data<T>();
+    num_inputs++;
+  }
+
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  MPType div_c = static_cast<MPType>(1.0f / head_dim);
+
+  VectorizedFusedRopeKernel<T, MPType, vec_size>
+      <<<grid, block, 0, stream>>>(ins_data,
+                                   batch_size,
+                                   seq_len,
+                                   num_heads,
+                                   head_dim,
+                                   outs_data,
+                                   num_inputs,
+                                   div_c);
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(fused_rope,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::FusedRopeKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
+  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
+}
diff --git a/python/paddle/incubate/nn/functional/__init__.py b/python/paddle/incubate/nn/functional/__init__.py
index ccccadd284e9e..b5ffd1f26da89 100644
--- a/python/paddle/incubate/nn/functional/__init__.py
+++ b/python/paddle/incubate/nn/functional/__init__.py
@@ -20,6 +20,7 @@
 from .fused_ec_moe import fused_ec_moe
 from .fused_dropout_add import fused_dropout_add
 from .fused_gate_attention import fused_gate_attention
+from .fused_rotary_position_embedding import fused_rotary_position_embedding
 
 
 __all__ = [
@@ -31,4 +32,5 @@
     'fused_bias_dropout_residual_layer_norm',
     'fused_ec_moe',
     'fused_dropout_add',
+    'fused_rotary_position_embedding',
 ]
diff --git a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py
new file mode 100644
index 0000000000000..f63b58a793c73
--- /dev/null
+++ b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from paddle import _C_ops
+from paddle.framework import in_dynamic_mode
+
+
+def fused_rotary_position_embedding(q, k, v):
+    r"""
+    Fused rotary position embedding.
+
+    Args:
+        q (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64.
+        k (Tensor, optional): The input tensor. The data type is bfloat16, float16, float32 or float64.
+        v (Tensor, optional): The input tensor. The data type is bfloat16, float16, float32 or float64.
+
+    Returns:
+        out_q, out_k, out_v: Tensors with rotary position embedding applied, each with the same shape and data type as `q`.
+
+    Examples:
+
+        .. code-block:: python
+
+            # required: gpu
+            import paddle
+            from paddle.incubate.nn.functional import fused_rotary_position_embedding
+
+            q = paddle.randn([1, 1, 4, 10], dtype='float16')
+            k = paddle.randn([1, 1, 4, 10], dtype='float16')
+            v = paddle.randn([1, 1, 4, 10], dtype='float16')
+            out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v)
+    """
+    if in_dynamic_mode():
+        return _C_ops.fused_rope(q, k, v)
diff --git a/test/legacy_test/test_fused_rotary_position_embedding.py b/test/legacy_test/test_fused_rotary_position_embedding.py
new file mode 100644
index 0000000000000..5798415409c1a
--- /dev/null
+++ b/test/legacy_test/test_fused_rotary_position_embedding.py
@@ -0,0 +1,144 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+from paddle.fluid import core
+from paddle.incubate.nn.functional import fused_rotary_position_embedding
+
+
+def deal_qkv(init_q, init_k, init_v):
+    perm = [0, 2, 1, 3]
+    q = paddle.transpose(x=init_q, perm=perm)
+    k = paddle.transpose(x=init_k, perm=perm)
+    v = paddle.transpose(x=init_v, perm=perm)
+    return q, k, v
+
+
+def mult_qkv(value, cos_tensor, sin_tensor):
+    rotate_half_q = paddle.reshape(
+        paddle.stack([value[:, :, :, 1::2], value[:, :, :, 0::2]], axis=-1),
+        paddle.shape(value),
+    )
+    query = paddle.add(
+        paddle.multiply(value, cos_tensor),
+        paddle.multiply(rotate_half_q, sin_tensor),
+    )
+    return query
+
+
+def paddle_fused_rotary_position_embedding(init_q, init_k, init_v):
+    q, k, v = deal_qkv(init_q, init_k, init_v)
+
+    pos_seq = paddle.arange(0, q.shape[2], 1, dtype="float32")
+    indices = paddle.arange(0, q.shape[3], 2, dtype="float32")
+
+    indices = 1 / 10000 ** (indices / q.shape[3])
+    sinusoid_inp = pos_seq.unsqueeze(1) * indices.unsqueeze(0)
+
+    sin_sin = np.empty((q.shape[2] * q.shape[3]), dtype=np.float32)
+    cos_cos = np.empty((q.shape[2] * q.shape[3]), dtype=np.float32)
+    numpy_array = sinusoid_inp.numpy()
+    iter_array = np.nditer(numpy_array)
+
+    i = 0
+
+    for value in iter_array:
+        sin_sin[i * 2] = -1 * np.sin(value)
+        cos_cos[i * 2 + 0] = np.cos(value)
+        sin_sin[i * 2 + 1] = np.sin(value)
+        cos_cos[i * 2 + 1] = np.cos(value)
+        i += 1
+
+    sin_tensor = paddle.reshape(
+        paddle.to_tensor(sin_sin, place=paddle.CPUPlace()),
+        [1, 1, q.shape[2], q.shape[3]],
+    )
+    cos_tensor = paddle.reshape(
+        paddle.to_tensor(cos_cos, place=paddle.CPUPlace()),
+        [1, 1, q.shape[2], q.shape[3]],
+    )
+
+    query = mult_qkv(q, cos_tensor, sin_tensor)
+    value = mult_qkv(v, cos_tensor, sin_tensor)
+    key = mult_qkv(k, cos_tensor, sin_tensor)
+
+    r_query, r_key, r_value = deal_qkv(query, key, value)
+
+    return r_query, r_key, r_value
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA",
+)
+class TestFusedRotaryPositionEmbedding(unittest.TestCase):
+    def setUp(self):
+        self.shape = [1, 16, 1, 16]
+        self.dtype = 'float32'
+        self.training = True
+        self.seed = 1203
+
+    def get_paddle_tensor(self):
+        tmp = paddle.randn(self.shape, self.dtype)
+        tmp.stop_gradient = False
+        return tmp
+
+    def get_forward_backward(self, rope_function, seed):
+        paddle.disable_static()
+        paddle.seed(seed)
+        fw = []
+        bw = []
+        tensor_q = self.get_paddle_tensor()
+        tensor_k = self.get_paddle_tensor()
+        tensor_v = self.get_paddle_tensor()
+        out_q, out_k, out_v = rope_function(tensor_q, tensor_k, tensor_v)
+
+        fw.append(out_q)
+        fw.append(out_k)
+        fw.append(out_v)
+
+        out_gq = paddle.randn(out_q.shape, self.dtype)
+        out_gk = paddle.randn(out_q.shape, self.dtype)
+        out_gv = paddle.randn(out_q.shape, self.dtype)
+        paddle.autograd.backward(
+            [out_q, out_k, out_v], [out_gq, out_gk, out_gv], True
+        )
+        # collect the input gradients so the backward pass is actually compared
+        bw.append(tensor_q.grad)
+        bw.append(tensor_k.grad)
+        bw.append(tensor_v.grad)
+
+        return fw, bw
+
+    def test_fused_rotary_position_embedding(self):
+        p_fw, p_bw = self.get_forward_backward(
+            paddle_fused_rotary_position_embedding, seed=self.seed
+        )
+        f_fw, f_bw = self.get_forward_backward(
+            fused_rotary_position_embedding, seed=self.seed
+        )
+        for i in range(len(p_fw)):
+            np.testing.assert_allclose(
+                p_fw[i].numpy(), f_fw[i].numpy(), rtol=1e-05
+            )
+            np.testing.assert_allclose(
+                p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05
+            )
+
+
+if __name__ == '__main__':
+    unittest.main()
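
For reference while reviewing the kernel math above, here is a minimal NumPy sketch (illustrative only, not part of the patch; the name rope_reference is hypothetical) of the interleaved-pair rotary transform that VectorizedFusedRopeKernel applies to a [batch_size, seq_len, num_heads, head_dim] input. Its output should agree with out_q from fused_rotary_position_embedding up to floating-point tolerance.

    # Illustrative reference only -- mirrors the interleaved-pair RoPE math in the CUDA kernel.
    import numpy as np

    def rope_reference(x):
        # x: [batch_size, seq_len, num_heads, head_dim], head_dim must be even.
        _, seq_len, _, head_dim = x.shape
        pos = np.arange(seq_len, dtype=np.float32)
        # One frequency per element pair: 1 / 10000^(2i / head_dim).
        inv_freq = 1.0 / 10000.0 ** (
            np.arange(0, head_dim, 2, dtype=np.float32) / head_dim
        )
        angle = pos[:, None] * inv_freq[None, :]        # [seq_len, head_dim / 2]
        cos = np.cos(angle)[None, :, None, :]           # broadcast over batch and heads
        sin = np.sin(angle)[None, :, None, :]
        x_even, x_odd = x[..., 0::2], x[..., 1::2]
        out = np.empty_like(x)
        out[..., 0::2] = x_even * cos - x_odd * sin     # result[pr_index] in the kernel
        out[..., 1::2] = x_even * sin + x_odd * cos     # result[ls_index] in the kernel
        return out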