[NPU] Custom fusion operator unification (#8431)
* update

* add llama-npu-opt-script

* Update dev_opt_lora.sh

* Update dev_opt_ppt.sh

* Update dev_opt_lora.sh

* Update dev_opt_ppt.sh

* Update dev_opt_sft.sh

* Rename dev_opt_lora.sh to llama_npu_opt_lora.sh

* Update dev_opt_ppt.sh

* Rename dev_opt_ppt.sh to llama_npu_opt_ppt.sh

* Update llama_npu_opt_lora.sh

* Update and rename dev_opt_sft.sh to llama_npu_opt_sft.sh

* add fusion ops

* add fusion ops

* add fusion ops

* add fusion ops

* add fusion ops

* add fusion ops

* add fusion ops

* add fusion ops

* add fusion ops

* add fusion ops

* add fusion ops

* add fusion ops

* add fusion ops

* add fusion ops

* add fusion ops

* add fusion ops

* update

* Update fusion_ops.py

* update
Galaxy1458 authored May 14, 2024
1 parent 53ad2da commit 05acad5
Showing 2 changed files with 216 additions and 116 deletions.
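
In short, the device-specific fused kernels that were previously inlined in paddlenlp/transformers/llama/modeling.py — fused rotary position embedding, fused RMSNorm, and fused (flash) attention — move into a new paddlenlp/transformers/llama/fusion_ops.py module, so the NPU/XPU/GPU dispatch lives in one place and modeling.py only calls the shared helpers. The sketch below is a simplified excerpt of the call sites as they appear in the diff that follows; the bare names weight, variance_epsilon and rotary_emb stand in for the module attributes (self.weight, self.variance_epsilon, self.rotary_emb) used in modeling.py, so treat it as illustrative rather than additional API.

from paddlenlp.transformers.llama import fusion_ops  # modeling.py itself uses "from . import fusion_ops"

# LlamaRMSNorm.forward, when config.use_fused_rms_norm is enabled:
hidden_states = fusion_ops.fusion_rms_norm(hidden_states, weight, variance_epsilon)

# LlamaAttention.forward, when use_fused_rope is enabled (no cache kv):
query_states, key_states = fusion_ops.fusion_rope(
    query_states, key_states, value_states, hidden_states, position_ids, past_key_value, rotary_emb
)

# scaled_dot_product_attention, when config.use_flash_attention is enabled
# (returns (attn_output, attn_weights) when output_attentions is True):
attn_output = fusion_ops.fusion_flash_attention(
    query_states, config, key_states, value_states, attention_mask, output_attentions
)
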
189 changes: 189 additions & 0 deletions paddlenlp/transformers/llama/fusion_ops.py
@@ -0,0 +1,189 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
import paddle.nn.functional as F

try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None

try:
    from paddle.incubate.nn.functional import swiglu
except ImportError:

    def swiglu(x, y=None):
        if y is None:
            x, y = paddle.chunk(x, chunks=2, axis=-1)
        return F.silu(x) * y


from paddle.utils import try_import

from paddlenlp.utils.tools import get_env_device

try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None
try:
    if get_env_device() == "npu":
        from paddle.base import core

        for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
            if lib.endswith(".so"):
                paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib)
    from paddle.nn.functional.flash_attention import flash_attention
except:
    flash_attention = None


def fusion_rope(query_states, key_states, value_states, hidden_states, position_ids, past_key_value, rotary_emb):
    assert past_key_value is None, "fuse rotary not support cache kv for now"
    batch_size, seq_length, num_heads, head_dim = query_states.shape
    _, kv_seq_len, num_key_value_heads, _ = key_states.shape
    cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
    if get_env_device() == "npu":
        query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
        key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
    else:
        # paddle version > 2.6 or develop support q and k/v with different num_heads
        paddle_version = float(paddle.__version__[:3])
        if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads):
            query_states, _, _ = fused_rotary_position_embedding(
                query_states,
                None,
                None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
            key_states, _, _ = fused_rotary_position_embedding(
                key_states,
                None,
                None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
        else:
            query_states, key_states, _ = fused_rotary_position_embedding(
                query_states,
                key_states,
                v=None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
    return query_states, key_states


def rms_norm_fused(x_in, w, eps):
    fused_ln = try_import("fused_ln")
    return fused_ln.fused_rms_norm(x_in, w, eps)[0]


def fusion_rms_norm(hidden_states, weight, variance_epsilon):
    if get_env_device() == "npu":
        return core.eager._run_custom_op("rms_norm_npu", hidden_states, weight, variance_epsilon)[0]
    elif get_env_device() == "xpu":
        try:
            import paddle_xpu_nn  # noqa: F821

            return paddle_xpu_nn.xpu_rms_norm(hidden_states, weight, variance_epsilon)[0]
        except ImportError:
            raise NotImplementedError(
                f"Implementation of fused_rms_norm is not available on {get_env_device()}. Please install paddle_xpu to use this feature"
            )
    return rms_norm_fused(hidden_states, weight, variance_epsilon)


def fusion_flash_attention(
    query_states,
    config,
    key_states,
    value_states,
    attention_mask,
    output_attentions,
    alibi=None,
    sequence_parallel=False,
    reshard_layer=None,
    npu_is_casual=False,
):
    bsz, q_len, num_heads, head_dim = query_states.shape
    _, kv_seq_len, _, _ = value_states.shape
    version = paddle.version.full_version
    if version != "0.0.0" and version <= "2.5.2":
        if alibi is not None:
            raise ValueError("Flash Attention doesn't support alibi")
        attn_output, attn_weights = flash_attention(
            query_states,
            key_states,
            value_states,
            causal=True,
            return_softmax=output_attentions,
        )
    else:
        if alibi is not None:
            alibi = alibi.reshape([bsz, num_heads, 1, -1])
            attention_mask = attention_mask.cast(alibi.dtype) + alibi
        if get_env_device() == "npu":
            attn_output = core.eager._run_custom_op(
                "flash_attention_npu",
                query_states,
                key_states,
                value_states,
                None,
                attention_mask,
                0.0,
                attention_mask is None,
                True,
                False,
                npu_is_casual,
            )[0]
        else:
            attn_output = F.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                is_causal=attention_mask is None,
            )
        attn_weights = None

    if reshard_layer is not None:
        # attn_output shape: [bs, seqlen, num_head/sep, head_dim]
        attn_output = reshard_layer(
            attn_output,
            split_axis=1,
            concat_axis=2,
        )
        # attn_output shape: [bs, seqlen/sep, num_head, head_dim]
        assert (
            config.sep_parallel_degree > 1 and q_len % config.sep_parallel_degree == 0
        ), f"q_len:{q_len}, config.sep_parallel_degree:{config.sep_parallel_degree}"
        q_len = q_len // config.sep_parallel_degree
        num_heads = num_heads * config.sep_parallel_degree

    if sequence_parallel:
        attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
    else:
        attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
    return (attn_output, attn_weights) if output_attentions else attn_output
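
A side note on the fallback defined near the top of this file: when paddle.incubate.nn.functional.swiglu is unavailable, the module falls back to a pure-Paddle implementation that splits the last axis in half and gates one half with SiLU. A minimal, self-contained sketch of that behaviour (the shapes are arbitrary example values):

import paddle
import paddle.nn.functional as F

def swiglu(x, y=None):
    # Same fallback as above: split the last axis when only one tensor is given.
    if y is None:
        x, y = paddle.chunk(x, chunks=2, axis=-1)
    return F.silu(x) * y

x = paddle.randn([2, 8])  # example activations; the last axis is split into two halves of 4
out = swiglu(x)           # equivalent to F.silu(x[..., :4]) * x[..., 4:]
print(out.shape)          # [2, 4]
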
143 changes: 27 additions & 116 deletions paddlenlp/transformers/llama/modeling.py
@@ -55,7 +55,6 @@ def swiglu(x, y=None):
    )
except:
    pass
from paddle.utils import try_import

from paddlenlp.transformers.conversion_utils import (
    StateDictNameMapping,
@@ -81,14 +80,16 @@ def swiglu(x, y=None):

try:
    if get_env_device() == "npu":
        from paddle.base import core

        for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
            if lib.endswith(".so"):
                paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib)
    from paddle.nn.functional.flash_attention import flash_attention
except:
    flash_attention = None
from . import fusion_ops

rms_norm_fused = fusion_ops.rms_norm_fused

__all__ = [
    "LlamaModel",
@@ -215,67 +216,22 @@ def scaled_dot_product_attention(
    _, kv_seq_len, _, _ = value_states.shape

    if config.use_flash_attention and flash_attention:
        return fusion_ops.fusion_flash_attention(
            query_states,
            config,
            key_states,
            value_states,
            attention_mask,
            output_attentions,
            alibi,
            sequence_parallel,
            reshard_layer,
            npu_is_casual,
        )

        # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
        # Torch Flash Attention input [ bz, nhead, seqlen, head_dim]

        version = paddle.version.full_version
        if version != "0.0.0" and version <= "2.5.2":
            if alibi is not None:
                raise ValueError("Flash Attention doesn't support alibi")
            attn_output, attn_weights = flash_attention(
                query_states,
                key_states,
                value_states,
                causal=True,
                return_softmax=output_attentions,
            )
        else:
            if alibi is not None:
                alibi = alibi.reshape([bsz, num_heads, 1, -1])
                attention_mask = attention_mask.cast(alibi.dtype) + alibi
            if get_env_device() == "npu":
                attn_output = core.eager._run_custom_op(
                    "flash_attention_npu",
                    query_states,
                    key_states,
                    value_states,
                    None,
                    attention_mask,
                    0.0,
                    attention_mask is None,
                    True,
                    False,
                    npu_is_casual,
                )[0]
            else:
                attn_output = F.scaled_dot_product_attention(
                    query_states,
                    key_states,
                    value_states,
                    attn_mask=attention_mask,
                    is_causal=attention_mask is None,
                )
            attn_weights = None

        if reshard_layer is not None:
            # attn_output shape: [bs, seqlen, num_head/sep, head_dim]
            attn_output = reshard_layer(
                attn_output,
                split_axis=1,
                concat_axis=2,
            )
            # attn_output shape: [bs, seqlen/sep, num_head, head_dim]
            assert (
                config.sep_parallel_degree > 1 and q_len % config.sep_parallel_degree == 0
            ), f"q_len:{q_len}, config.sep_parallel_degree:{config.sep_parallel_degree}"
            q_len = q_len // config.sep_parallel_degree
            num_heads = num_heads * config.sep_parallel_degree

        if sequence_parallel:
            attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
        else:
            attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
        return (attn_output, attn_weights) if output_attentions else attn_output
    else:
        # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim]
        query_states = paddle.transpose(query_states, [0, 2, 1, 3])
@@ -385,11 +341,6 @@ def _expand_2d_mask(mask, dtype, tgt_length):
    return expanded_mask


def rms_norm_fused(x_in, w, eps):
    fused_ln = try_import("fused_ln")
    return fused_ln.fused_rms_norm(x_in, w, eps)[0]


class LlamaRMSNorm(nn.Layer):
    def __init__(self, config):
        super().__init__()
@@ -407,18 +358,7 @@ def __init__(self, config):

    def forward(self, hidden_states):
        if self.config.use_fused_rms_norm:
            if get_env_device() == "npu":
                return core.eager._run_custom_op("rms_norm_npu", hidden_states, self.weight, self.variance_epsilon)[0]
            elif get_env_device() == "xpu":
                try:
                    import paddle_xpu_nn  # noqa: F821

                    return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
                except ImportError:
                    raise NotImplementedError(
                        f"Implementation of fused_rms_norm is not available on {get_env_device()}. Please install paddle_xpu to use this feature"
                    )
            return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)
            return fusion_ops.fusion_rms_norm(hidden_states, self.weight, self.variance_epsilon)

        if paddle.in_dynamic_mode():
            with paddle.amp.auto_cast(False):
@@ -974,45 +914,16 @@ def forward(
            batch_size, seq_length, _, _ = query_states.shape
            position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
            if self.use_fused_rope:
                assert past_key_value is None, "fuse rotary not support cache kv for now"
                batch_size, seq_length, num_heads, head_dim = query_states.shape
                _, kv_seq_len, num_key_value_heads, _ = key_states.shape
                cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
                if get_env_device() == "npu":
                    query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
                    key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
                else:
                    # paddle version > 2.6 or develop support q and k/v with different num_heads
                    paddle_version = float(paddle.__version__[:3])
                    if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads):
                        query_states, _, _ = fused_rotary_position_embedding(
                            query_states,
                            None,
                            None,
                            sin=sin,
                            cos=cos,
                            position_ids=position_ids,
                            use_neox_rotary_style=False,
                        )
                        key_states, _, _ = fused_rotary_position_embedding(
                            key_states,
                            None,
                            None,
                            sin=sin,
                            cos=cos,
                            position_ids=position_ids,
                            use_neox_rotary_style=False,
                        )
                    else:
                        query_states, key_states, _ = fused_rotary_position_embedding(
                            query_states,
                            key_states,
                            v=None,
                            sin=sin,
                            cos=cos,
                            position_ids=position_ids,
                            use_neox_rotary_style=False,
                        )
                query_states, key_states = fusion_ops.fusion_rope(
                    query_states,
                    key_states,
                    value_states,
                    hidden_states,
                    position_ids,
                    past_key_value,
                    self.rotary_emb,
                )

            else:
                if self.config.use_long_sequence_strategies:
                    cos, sin = self.rotary_emb(seq_len=kv_seq_len)
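
One detail of fusion_ops.fusion_rope, which replaces the inline rotary-embedding branch above: when the NPU custom op is not used, the helper gates the GQA handling on the installed Paddle release, since (per the code comment) only Paddle > 2.6 or develop builds accept q and k with different head counts in fused_rotary_position_embedding. A small worked example of that check, assuming a release build of Paddle is installed:

import paddle

# Mirrors the gate in fusion_ops.fusion_rope.
# "2.6.1"[:3] -> "2.6" -> 2.6  (release build: rotate q and k in separate calls when
#                               num_heads != num_key_value_heads)
# "0.0.0"[:3] -> "0.0" -> 0.0  (develop build: treated as new enough for the combined call)
paddle_version = float(paddle.__version__[:3])
needs_separate_qk = (paddle_version != 0.0) and (paddle_version <= 2.6)
print(paddle.__version__, "->", paddle_version, "separate q/k rope for GQA:", needs_separate_qk)
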
