[NPU]Custom fusion operator unification #8431

Merged: 34 commits merged on May 14, 2024

Commits (34)
a5ed9ed
update
Galaxy1458 May 9, 2024
8ebdcfa
Merge branch 'develop' of https://github.com/Galaxy1458/PaddleNLP int…
Galaxy1458 May 9, 2024
bd0aa87
add llama-npu-opt-script
Galaxy1458 May 9, 2024
ce921ab
Merge branch 'PaddlePaddle:develop' into develop
Galaxy1458 May 9, 2024
cc24132
Update dev_opt_lora.sh
Galaxy1458 May 9, 2024
036d03c
Update dev_opt_ppt.sh
Galaxy1458 May 9, 2024
8dd2d02
Update dev_opt_lora.sh
Galaxy1458 May 9, 2024
96e69aa
Update dev_opt_ppt.sh
Galaxy1458 May 9, 2024
a35ba59
Update dev_opt_sft.sh
Galaxy1458 May 9, 2024
68388a7
Rename dev_opt_lora.sh to llama_npu_opt_lora.sh
Galaxy1458 May 11, 2024
fee8f04
Update dev_opt_ppt.sh
Galaxy1458 May 11, 2024
783de3b
Rename dev_opt_ppt.sh to llama_npu_opt_ppt.sh
Galaxy1458 May 11, 2024
10f9415
Update llama_npu_opt_lora.sh
Galaxy1458 May 11, 2024
f3d96e5
Update and rename dev_opt_sft.sh to llama_npu_opt_sft.sh
Galaxy1458 May 11, 2024
e51cc9a
Merge branch 'PaddlePaddle:develop' into develop
Galaxy1458 May 13, 2024
6771aa9
add funsion ops
Galaxy1458 May 13, 2024
61dc79c
add funsion ops
Galaxy1458 May 13, 2024
558200f
add funsion ops
Galaxy1458 May 13, 2024
f387c30
add funsion ops
Galaxy1458 May 13, 2024
a12947b
add funsion ops
Galaxy1458 May 13, 2024
aff105e
add funsion ops
Galaxy1458 May 13, 2024
075c8de
add funsion ops
Galaxy1458 May 13, 2024
15f2fe3
add funsion ops
Galaxy1458 May 13, 2024
2741769
add funsion ops
Galaxy1458 May 13, 2024
12fc048
add funsion ops
Galaxy1458 May 13, 2024
f678361
add funsion ops
Galaxy1458 May 13, 2024
9b2ca6b
add funsion ops
Galaxy1458 May 13, 2024
cac0f8e
add funsion ops
Galaxy1458 May 13, 2024
73866a2
add funsion ops
Galaxy1458 May 13, 2024
d8f1950
add funsion ops
Galaxy1458 May 13, 2024
9a2f1c5
add funsion ops
Galaxy1458 May 13, 2024
df78b71
update
Galaxy1458 May 14, 2024
8c3cd0d
Update fusion_ops.py
Galaxy1458 May 14, 2024
0a6d6b8
update
Galaxy1458 May 14, 2024
189 changes: 189 additions & 0 deletions paddlenlp/transformers/llama/fusion_ops.py
@@ -0,0 +1,189 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
import paddle.nn.functional as F

try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None

try:
    from paddle.incubate.nn.functional import swiglu
except ImportError:

    def swiglu(x, y=None):
        if y is None:
            x, y = paddle.chunk(x, chunks=2, axis=-1)
        return F.silu(x) * y


from paddle.utils import try_import

from paddlenlp.utils.tools import get_env_device

try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None

try:
    if get_env_device() == "npu":
        from paddle.base import core

        for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
            if lib.endswith(".so"):
                paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib)
    from paddle.nn.functional.flash_attention import flash_attention
except:
    flash_attention = None


def fusion_rope(query_states, key_states, value_states, hidden_states, position_ids, past_key_value, rotary_emb):
    assert past_key_value is None, "fuse rotary not support cache kv for now"
    batch_size, seq_length, num_heads, head_dim = query_states.shape
    _, kv_seq_len, num_key_value_heads, _ = key_states.shape
    cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
    if get_env_device() == "npu":
        query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
        key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
    else:
        # paddle version > 2.6 or develop support q and k/v with different num_heads
        paddle_version = float(paddle.__version__[:3])
        if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads):
            query_states, _, _ = fused_rotary_position_embedding(
                query_states,
                None,
                None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
            key_states, _, _ = fused_rotary_position_embedding(
                key_states,
                None,
                None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
        else:
            query_states, key_states, _ = fused_rotary_position_embedding(
                query_states,
                key_states,
                v=None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
    return query_states, key_states


def rms_norm_fused(x_in, w, eps):
    fused_ln = try_import("fused_ln")
    return fused_ln.fused_rms_norm(x_in, w, eps)[0]


def fusion_rms_norm(hidden_states, weight, variance_epsilon):
    if get_env_device() == "npu":
        return core.eager._run_custom_op("rms_norm_npu", hidden_states, weight, variance_epsilon)[0]
    elif get_env_device() == "xpu":
        try:
            import paddle_xpu_nn  # noqa: F821

            return paddle_xpu_nn.xpu_rms_norm(hidden_states, weight, variance_epsilon)[0]
        except ImportError:
            raise NotImplementedError(
                f"Implementation of fused_rms_norm is not available on {get_env_device()}. Please install paddle_xpu to use this feature"
            )
    return rms_norm_fused(hidden_states, weight, variance_epsilon)


def fusion_flash_attention(
    query_states,
    config,
    key_states,
    value_states,
    attention_mask,
    output_attentions,
    alibi=None,
    sequence_parallel=False,
    reshard_layer=None,
    npu_is_casual=False,
):
    bsz, q_len, num_heads, head_dim = query_states.shape
    _, kv_seq_len, _, _ = value_states.shape
    version = paddle.version.full_version
    if version != "0.0.0" and version <= "2.5.2":
        if alibi is not None:
            raise ValueError("Flash Attention doesn't support alibi")
        attn_output, attn_weights = flash_attention(
            query_states,
            key_states,
            value_states,
            causal=True,
            return_softmax=output_attentions,
        )
    else:
        if alibi is not None:
            alibi = alibi.reshape([bsz, num_heads, 1, -1])
            attention_mask = attention_mask.cast(alibi.dtype) + alibi
        if get_env_device() == "npu":
            attn_output = core.eager._run_custom_op(
                "flash_attention_npu",
                query_states,
                key_states,
                value_states,
                None,
                attention_mask,
                0.0,
                attention_mask is None,
                True,
                False,
                npu_is_casual,
            )[0]
        else:
            attn_output = F.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                is_causal=attention_mask is None,
            )
        attn_weights = None

    if reshard_layer is not None:
        # attn_output shape: [bs, seqlen, num_head/sep, head_dim]
        attn_output = reshard_layer(
            attn_output,
            split_axis=1,
            concat_axis=2,
        )
        # attn_output shape: [bs, seqlen/sep, num_head, head_dim]
        assert (
            config.sep_parallel_degree > 1 and q_len % config.sep_parallel_degree == 0
        ), f"q_len:{q_len}, config.sep_parallel_degree:{config.sep_parallel_degree}"
        q_len = q_len // config.sep_parallel_degree
        num_heads = num_heads * config.sep_parallel_degree

    if sequence_parallel:
        attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
    else:
        attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
    return (attn_output, attn_weights) if output_attentions else attn_output
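Of the helpers above, the pure-Paddle `swiglu` fallback is the only one that runs without vendor custom kernels, so it is the easiest to sanity-check locally. The snippet below is a minimal usage sketch, assuming a standard CPU/GPU Paddle install where the module imports cleanly; the tensor shapes and the hand-computed comparison are illustrative assumptions, not part of the PR.

```python
import paddle
import paddle.nn.functional as F

from paddlenlp.transformers.llama import fusion_ops

# Illustrative shape: [batch, seq_len, 2 * intermediate_size].
x = paddle.randn([2, 8, 64])

# Single-tensor form: the fallback splits the last axis in half and gates with SiLU.
out = fusion_ops.swiglu(x)
print(out.shape)  # expected: [2, 8, 32]

# Two-tensor form should match F.silu(gate) * up computed by hand.
gate, up = paddle.chunk(x, chunks=2, axis=-1)
assert paddle.allclose(fusion_ops.swiglu(gate, up), F.silu(gate) * up)
```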
143 changes: 27 additions & 116 deletions paddlenlp/transformers/llama/modeling.py
@@ -55,7 +55,6 @@
    )
except:
    pass
from paddle.utils import try_import

from paddlenlp.transformers.conversion_utils import (
    StateDictNameMapping,
@@ -81,14 +80,16 @@

try:
    if get_env_device() == "npu":
        from paddle.base import core

        for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
            if lib.endswith(".so"):
                paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib)
Reviewer comment (Collaborator): Check carefully whether any of this code is unneeded, and make sure to delete it.

    from paddle.nn.functional.flash_attention import flash_attention
except:
    flash_attention = None
from . import fusion_ops

rms_norm_fused = fusion_ops.rms_norm_fused

__all__ = [
    "LlamaModel",
@@ -215,67 +216,22 @@
    _, kv_seq_len, _, _ = value_states.shape

    if config.use_flash_attention and flash_attention:
        return fusion_ops.fusion_flash_attention(
            query_states,
            config,
            key_states,
            value_states,
            attention_mask,
            output_attentions,
            alibi,
            sequence_parallel,
            reshard_layer,
            npu_is_casual,
        )

        # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
        # Torch Flash Attention input [ bz, nhead, seqlen, head_dim]

        version = paddle.version.full_version
        if version != "0.0.0" and version <= "2.5.2":
            if alibi is not None:
                raise ValueError("Flash Attention doesn't support alibi")
            attn_output, attn_weights = flash_attention(
                query_states,
                key_states,
                value_states,
                causal=True,
                return_softmax=output_attentions,
            )
        else:
            if alibi is not None:
                alibi = alibi.reshape([bsz, num_heads, 1, -1])
                attention_mask = attention_mask.cast(alibi.dtype) + alibi
            if get_env_device() == "npu":
                attn_output = core.eager._run_custom_op(
                    "flash_attention_npu",
                    query_states,
                    key_states,
                    value_states,
                    None,
                    attention_mask,
                    0.0,
                    attention_mask is None,
                    True,
                    False,
                    npu_is_casual,
                )[0]
            else:
                attn_output = F.scaled_dot_product_attention(
                    query_states,
                    key_states,
                    value_states,
                    attn_mask=attention_mask,
                    is_causal=attention_mask is None,
                )
            attn_weights = None

        if reshard_layer is not None:
            # attn_output shape: [bs, seqlen, num_head/sep, head_dim]
            attn_output = reshard_layer(
                attn_output,
                split_axis=1,
                concat_axis=2,
            )
            # attn_output shape: [bs, seqlen/sep, num_head, head_dim]
            assert (
                config.sep_parallel_degree > 1 and q_len % config.sep_parallel_degree == 0
            ), f"q_len:{q_len}, config.sep_parallel_degree:{config.sep_parallel_degree}"
            q_len = q_len // config.sep_parallel_degree
            num_heads = num_heads * config.sep_parallel_degree

        if sequence_parallel:
            attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
        else:
            attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
        return (attn_output, attn_weights) if output_attentions else attn_output
    else:
        # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim]
        query_states = paddle.transpose(query_states, [0, 2, 1, 3])
@@ -385,11 +341,6 @@
    return expanded_mask


def rms_norm_fused(x_in, w, eps):
    fused_ln = try_import("fused_ln")
    return fused_ln.fused_rms_norm(x_in, w, eps)[0]


class LlamaRMSNorm(nn.Layer):
    def __init__(self, config):
        super().__init__()
@@ -407,18 +358,7 @@

    def forward(self, hidden_states):
        if self.config.use_fused_rms_norm:
            if get_env_device() == "npu":
                return core.eager._run_custom_op("rms_norm_npu", hidden_states, self.weight, self.variance_epsilon)[0]
            elif get_env_device() == "xpu":
                try:
                    import paddle_xpu_nn  # noqa: F821

                    return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
                except ImportError:
                    raise NotImplementedError(
                        f"Implementation of fused_rms_norm is not available on {get_env_device()}. Please install paddle_xpu to use this feature"
                    )
            return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)
            return fusion_ops.fusion_rms_norm(hidden_states, self.weight, self.variance_epsilon)

        if paddle.in_dynamic_mode():
            with paddle.amp.auto_cast(False):
@@ -974,45 +914,16 @@
        batch_size, seq_length, _, _ = query_states.shape
        position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
        if self.use_fused_rope:
            assert past_key_value is None, "fuse rotary not support cache kv for now"
            batch_size, seq_length, num_heads, head_dim = query_states.shape
            _, kv_seq_len, num_key_value_heads, _ = key_states.shape
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
            if get_env_device() == "npu":
                query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
                key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
            else:
                # paddle version > 2.6 or develop support q and k/v with different num_heads
                paddle_version = float(paddle.__version__[:3])
                if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads):
                    query_states, _, _ = fused_rotary_position_embedding(
                        query_states,
                        None,
                        None,
                        sin=sin,
                        cos=cos,
                        position_ids=position_ids,
                        use_neox_rotary_style=False,
                    )
                    key_states, _, _ = fused_rotary_position_embedding(
                        key_states,
                        None,
                        None,
                        sin=sin,
                        cos=cos,
                        position_ids=position_ids,
                        use_neox_rotary_style=False,
                    )
                else:
                    query_states, key_states, _ = fused_rotary_position_embedding(
                        query_states,
                        key_states,
                        v=None,
                        sin=sin,
                        cos=cos,
                        position_ids=position_ids,
                        use_neox_rotary_style=False,
                    )
            query_states, key_states = fusion_ops.fusion_rope(
                query_states,
                key_states,
                value_states,
                hidden_states,
                position_ids,
                past_key_value,
                self.rotary_emb,
            )

        else:
            if self.config.use_long_sequence_strategies:
                cos, sin = self.rotary_emb(seq_len=kv_seq_len)
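For readers following the refactor, here is a minimal, self-contained sketch of the delegation pattern the PR applies to rope, RMS norm, and flash attention alike: the model layer calls a single fusion helper, and the per-device branching (NPU/XPU custom ops versus the generic path) lives in that helper rather than in modeling.py. The helper and layer below are illustrative stand-ins written for this note, not PaddleNLP APIs, and the generic branch uses a plain-Paddle RMS norm in place of the fused kernel.

```python
import paddle


def fusion_rms_norm_sketch(hidden_states, weight, eps, device="cpu"):
    # One entry point with backend-specific branches inside, mirroring the
    # structure of fusion_ops.fusion_rms_norm (custom NPU/XPU ops omitted).
    if device in ("npu", "xpu"):
        raise NotImplementedError("fused custom kernels are outside this sketch")
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    return hidden_states * paddle.rsqrt(variance + eps) * weight


class TinyRMSNorm(paddle.nn.Layer):
    # Mirrors how LlamaRMSNorm.forward now delegates to a single fusion helper.
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = self.create_parameter(
            shape=[hidden_size],
            default_initializer=paddle.nn.initializer.Constant(1.0),
        )
        self.eps = eps

    def forward(self, x):
        return fusion_rms_norm_sketch(x, self.weight, self.eps)


layer = TinyRMSNorm(16)
print(layer(paddle.randn([2, 8, 16])).shape)  # [2, 8, 16]
```

The payoff of this structure shows up in the diff stats above: modeling.py shrinks (27 additions against 116 deletions in the excerpt shown) while the device-specific branching is maintained in one module.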