From 2589ad4eb7b62c47c824f3fc09fd371b5c5c005a Mon Sep 17 00:00:00 2001 From: jfxu-st <143591296+jfxu-st@users.noreply.github.com> Date: Wed, 20 Dec 2023 14:49:20 +0800 Subject: [PATCH] Implement rotary_embedding for Ascend (#753) * implement rotary_embedding for Ascend * create functions_ext folder and move rotary_embedding.cpp to it --- impl/ascend/convert_config.yaml | 3 +++ impl/ascend/device_configs.py | 12 --------- impl/ascend_npu/CMakeLists.txt | 2 +- impl/ascend_npu/ascend_config.yaml | 1 + .../functions_ext/rotary_embedding.cpp | 26 +++++++++++++++++++ impl/ascend_npu/diopi_impl/helper.hpp | 4 +-- 6 files changed, 33 insertions(+), 15 deletions(-) create mode 100644 impl/ascend_npu/diopi_impl/functions_ext/rotary_embedding.cpp diff --git a/impl/ascend/convert_config.yaml b/impl/ascend/convert_config.yaml index fdb8a7ea2..05473086a 100644 --- a/impl/ascend/convert_config.yaml +++ b/impl/ascend/convert_config.yaml @@ -259,3 +259,6 @@ - diopiScatterInpScalar: dtype: (uint8,int8)->int32 + +- diopiRotaryEmbedding: + dtype: (float64)->float32 diff --git a/impl/ascend/device_configs.py b/impl/ascend/device_configs.py index b69c0b287..d53888c9f 100755 --- a/impl/ascend/device_configs.py +++ b/impl/ascend/device_configs.py @@ -1822,18 +1822,6 @@ ), ), - 'rotary_emb': dict( - name=["rotary_emb"], - tensor_para=dict( - args=[ - { - "ins": ['input'], - "dtype": [Skip(np.float64), Skip(np.float32), Skip(np.float16)], - }, - ], - ), - ), - # temp for 910B 'normal_': dict( name=["normal_"], diff --git a/impl/ascend_npu/CMakeLists.txt b/impl/ascend_npu/CMakeLists.txt index a18e3ffdf..0d632f17d 100755 --- a/impl/ascend_npu/CMakeLists.txt +++ b/impl/ascend_npu/CMakeLists.txt @@ -555,7 +555,7 @@ set(OP_PLUGIN_API_SRC ) set(DIOPI_IMPL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/diopi_impl/) -file(GLOB DIOPI_IMPL_SRC "${DIOPI_IMPL_DIR}/*.cpp") +file(GLOB DIOPI_IMPL_SRC "${DIOPI_IMPL_DIR}/*.cpp" "${DIOPI_IMPL_DIR}/functions_ext/*.cpp") add_definitions(-DBUILD_LIBTORCH) diff --git a/impl/ascend_npu/ascend_config.yaml b/impl/ascend_npu/ascend_config.yaml index fbcd305fb..7ba2d697a 100755 --- a/impl/ascend_npu/ascend_config.yaml +++ b/impl/ascend_npu/ascend_config.yaml @@ -211,3 +211,4 @@ ascend_npu: - diopiNormalInp - diopiNorm - diopiMaxPool2dWithIndices +- diopiRotaryEmbedding diff --git a/impl/ascend_npu/diopi_impl/functions_ext/rotary_embedding.cpp b/impl/ascend_npu/diopi_impl/functions_ext/rotary_embedding.cpp new file mode 100644 index 000000000..d62f9c7e1 --- /dev/null +++ b/impl/ascend_npu/diopi_impl/functions_ext/rotary_embedding.cpp @@ -0,0 +1,26 @@ +/** + * @file + * @author DeepLink + * @copyright (c) 2023, DeepLink. + */ + +#include "../helper.hpp" +#include "op_plugin/AclOpsInterface.h" + +namespace OP_IMPL_NS { + +DIOPI_API diopiError_t diopiRotaryEmbedding(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t x, diopiConstTensorHandle_t cos, + diopiConstTensorHandle_t sin, const bool conj, const bool interleaved) { + TORCH_CHECK(false == interleaved, "interleaved=true is currently not supported for Ascend"); + BEGIN_CALL_ACL_OP(out, x, cos, sin); + at::Tensor cosRepeated = acl_op::repeat(cosAt, {1, 1, 1, 2}); + at::Tensor sinRepeated = acl_op::repeat(sinAt, {1, 1, 1, 2}); + if (conj) { + acl_op::neg_(sinRepeated); + } + at_npu::native::OpCommand cmd; + cmd.Name("RotaryMul").Input(xAt).Input(cosRepeated).Input(sinRepeated).Output(outAt).Run(); + END_CALL_ACL_OP(); +} + +} // namespace OP_IMPL_NS diff --git a/impl/ascend_npu/diopi_impl/helper.hpp b/impl/ascend_npu/diopi_impl/helper.hpp index 5c4f53fbd..c78c77abc 100755 --- a/impl/ascend_npu/diopi_impl/helper.hpp +++ b/impl/ascend_npu/diopi_impl/helper.hpp @@ -304,8 +304,8 @@ template <> inline std::string dumpArgs(const at::Tensor& t) { std::stringstream stream; if (t.defined()) { - stream << " shape:" << t.sizes() << ", numel:" << t.numel() << ", strides:" << t.strides() << " " << t.options() << ", ptr:" << t.data_ptr() - << ", nbytes:" << t.storage().nbytes() << ", is_contiguous:" << t.is_contiguous(); + stream << " shape:" << t.sizes() << ", numel:" << t.numel() << ", strides:" << t.strides() << ", storage_offset:" << t.storage_offset() << " " + << t.options() << ", ptr:" << t.data_ptr() << ", nbytes:" << t.storage().nbytes() << ", is_contiguous:" << t.is_contiguous() << "\n"; } else { stream << " undefined" << std::endl; }