Skip to content

Commit

Permalink
Implement rotary_embedding for Ascend (#753)
Browse files Browse the repository at this point in the history
* implement rotary_embedding for Ascend

* create functions_ext folder and move rotary_embedding.cpp to it
  • Loading branch information
jfxu-st authored Dec 20, 2023
1 parent 9c47205 commit 2589ad4
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 15 deletions.
3 changes: 3 additions & 0 deletions impl/ascend/convert_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,6 @@

- diopiScatterInpScalar:
dtype: (uint8,int8)->int32

- diopiRotaryEmbedding:
dtype: (float64)->float32
12 changes: 0 additions & 12 deletions impl/ascend/device_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1822,18 +1822,6 @@
),
),

'rotary_emb': dict(
name=["rotary_emb"],
tensor_para=dict(
args=[
{
"ins": ['input'],
"dtype": [Skip(np.float64), Skip(np.float32), Skip(np.float16)],
},
],
),
),

# temp for 910B
'normal_': dict(
name=["normal_"],
Expand Down
2 changes: 1 addition & 1 deletion impl/ascend_npu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ set(OP_PLUGIN_API_SRC
)

set(DIOPI_IMPL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/diopi_impl/)
file(GLOB DIOPI_IMPL_SRC "${DIOPI_IMPL_DIR}/*.cpp")
file(GLOB DIOPI_IMPL_SRC "${DIOPI_IMPL_DIR}/*.cpp" "${DIOPI_IMPL_DIR}/functions_ext/*.cpp")

add_definitions(-DBUILD_LIBTORCH)

Expand Down
1 change: 1 addition & 0 deletions impl/ascend_npu/ascend_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,4 @@ ascend_npu:
- diopiNormalInp
- diopiNorm
- diopiMaxPool2dWithIndices
- diopiRotaryEmbedding
26 changes: 26 additions & 0 deletions impl/ascend_npu/diopi_impl/functions_ext/rotary_embedding.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/**
* @file
* @author DeepLink
* @copyright (c) 2023, DeepLink.
*/

#include "../helper.hpp"
#include "op_plugin/AclOpsInterface.h"

namespace OP_IMPL_NS {

/**
 * @brief Rotary embedding for Ascend, executed through the "RotaryMul" operator.
 *
 * @param ctx          DIOPI context handle.
 * @param out          Tensor handle that receives the operator's output.
 * @param x            Input tensor handle, passed as the first operator input.
 * @param cos          Cosine tensor handle; repeated with shape {1, 1, 1, 2} before use.
 * @param sin          Sine tensor handle; repeated with shape {1, 1, 1, 2} before use.
 * @param conj         When true, the repeated sine tensor is negated in place.
 * @param interleaved  Must be false; the check below fails otherwise.
 *
 * NOTE(review): outAt/xAt/cosAt/sinAt are presumably the at::Tensor views
 * introduced by BEGIN_CALL_ACL_OP for the handles passed to it -- confirm in helper.hpp.
 * NOTE(review): the four-entry repeat shape suggests cos/sin are expected to be
 * at most 4-D -- verify against the callers.
 */
DIOPI_API diopiError_t diopiRotaryEmbedding(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t x, diopiConstTensorHandle_t cos,
                                            diopiConstTensorHandle_t sin, const bool conj, const bool interleaved) {
    // Reject the unsupported layout before any tensor work is issued.
    TORCH_CHECK(false == interleaved, "interleaved=true is currently not supported for Ascend");
    BEGIN_CALL_ACL_OP(out, x, cos, sin);

    // Double cos/sin along the last repeat axis before feeding them to the operator.
    at::Tensor cosTiled = acl_op::repeat(cosAt, {1, 1, 1, 2});
    at::Tensor sinTiled = acl_op::repeat(sinAt, {1, 1, 1, 2});

    // conj flips the sign of the sine term; done in place on the repeated copy.
    if (conj) {
        acl_op::neg_(sinTiled);
    }

    // Assemble and launch the RotaryMul operator: inputs are x, cos, sin (in that order).
    at_npu::native::OpCommand rotaryMulCmd;
    rotaryMulCmd.Name("RotaryMul");
    rotaryMulCmd.Input(xAt);
    rotaryMulCmd.Input(cosTiled);
    rotaryMulCmd.Input(sinTiled);
    rotaryMulCmd.Output(outAt);
    rotaryMulCmd.Run();

    END_CALL_ACL_OP();
}

} // namespace OP_IMPL_NS
4 changes: 2 additions & 2 deletions impl/ascend_npu/diopi_impl/helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,8 @@ template <>
inline std::string dumpArgs(const at::Tensor& t) {
std::stringstream stream;
if (t.defined()) {
stream << " shape:" << t.sizes() << ", numel:" << t.numel() << ", strides:" << t.strides() << " " << t.options() << ", ptr:" << t.data_ptr()
<< ", nbytes:" << t.storage().nbytes() << ", is_contiguous:" << t.is_contiguous();
stream << " shape:" << t.sizes() << ", numel:" << t.numel() << ", strides:" << t.strides() << ", storage_offset:" << t.storage_offset() << " "
<< t.options() << ", ptr:" << t.data_ptr() << ", nbytes:" << t.storage().nbytes() << ", is_contiguous:" << t.is_contiguous() << "\n";
} else {
stream << " undefined" << std::endl;
}
Expand Down

0 comments on commit 2589ad4

Please sign in to comment.