Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LLM] support sparse attention for LLAMA #8592

Merged
merged 5 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions paddlenlp/transformers/llama/fusion_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@
attention_mask,
output_attentions,
alibi=None,
attn_mask_startend_row_indices=None,
sequence_parallel=False,
reshard_layer=None,
npu_is_casual=False,
Expand Down Expand Up @@ -208,13 +209,23 @@
is_causal=True,
)
else:
attn_output = F.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=attention_mask is None,
)
if attn_mask_startend_row_indices is not None:
assert alibi is None, "flash_attention_with_sparse_mask not support alibi"
attn_output = F.flash_attention_with_sparse_mask(

Check warning on line 214 in paddlenlp/transformers/llama/fusion_ops.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/fusion_ops.py#L212-L214

Added lines #L212 - L214 were not covered by tests
query_states,
key_states,
value_states,
attn_mask_start_row_indices=attn_mask_startend_row_indices,
is_causal=True,
)
else:
attn_output = F.scaled_dot_product_attention(

Check warning on line 222 in paddlenlp/transformers/llama/fusion_ops.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/fusion_ops.py#L222

Added line #L222 was not covered by tests
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=attention_mask is None,
)
attn_weights = None

if reshard_layer is not None:
Expand Down
23 changes: 19 additions & 4 deletions paddlenlp/transformers/llama/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@
attention_mask,
output_attentions,
alibi=None,
attn_mask_startend_row_indices=None,
sequence_parallel=False,
reshard_layer=None,
npu_is_casual=False,
Expand All @@ -228,6 +229,7 @@
attention_mask,
output_attentions,
alibi,
attn_mask_startend_row_indices,
sequence_parallel,
reshard_layer,
npu_is_casual,
Expand Down Expand Up @@ -816,6 +818,7 @@
output_attentions: bool = False,
use_cache: bool = False,
alibi: Optional[paddle.Tensor] = None,
attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
npu_is_casual: bool = False,
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
Expand Down Expand Up @@ -1014,6 +1017,7 @@
attention_mask,
output_attentions,
alibi,
attn_mask_startend_row_indices,
self.sequence_parallel,
reshard_layer=self.reshard_layer,
use_reentrant=self.config.recompute_use_reentrant,
Expand All @@ -1027,6 +1031,7 @@
attention_mask,
output_attentions,
alibi,
attn_mask_startend_row_indices,
self.sequence_parallel,
reshard_layer=self.reshard_layer,
npu_is_casual=npu_is_casual,
Expand Down Expand Up @@ -1082,6 +1087,7 @@
past_key_value: Optional[Tuple[paddle.Tensor]] = None,
use_cache: Optional[bool] = False,
alibi: Optional[paddle.Tensor] = None,
attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
npu_is_casual: bool = False,
) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]:
"""
Expand Down Expand Up @@ -1119,6 +1125,7 @@
output_attentions,
use_cache,
alibi,
attn_mask_startend_row_indices,
use_reentrant=self.config.recompute_use_reentrant,
)
else:
Expand All @@ -1130,6 +1137,7 @@
output_attentions,
use_cache,
alibi,
attn_mask_startend_row_indices=attn_mask_startend_row_indices,
npu_is_casual=npu_is_casual,
)

Expand Down Expand Up @@ -1459,6 +1467,7 @@
past_key_value: Tensor,
use_cache: bool,
alibi=None,
attn_mask_startend_row_indices=None,
):
def create_custom_forward(module):
def custom_forward(*inputs):
Expand All @@ -1475,6 +1484,7 @@
past_key_value,
use_cache,
alibi,
attn_mask_startend_row_indices,
use_reentrant=self.config.recompute_use_reentrant,
)

Expand All @@ -1491,6 +1501,7 @@
output_attentions=False,
output_hidden_states=None,
return_dict=False,
attn_mask_startend_row_indices=None,
**kwargs,
):
if self.sequence_parallel and use_cache:
Expand Down Expand Up @@ -1537,10 +1548,10 @@
if self.config.context_parallel_degree > 1 and (attention_mask is not None or self.config.alibi):
raise NotImplementedError("Ring FlashAttention dosen't support attention_mask or alibi")
# embed positions
if attention_mask is None:
if attn_mask_startend_row_indices is None and attention_mask is None:
# [bs, seq_len]
attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
if self.config.alibi:
if attn_mask_startend_row_indices is None and self.config.alibi:
if self.config.use_long_sequence_strategies:
alibi_layer = LongSequenceStrategies.build_long_sequence_strategy(
self.config.long_sequence_strategy_type,
Expand Down Expand Up @@ -1571,14 +1582,14 @@

if use_casual_mask:
attention_mask = None
else:
elif attn_mask_startend_row_indices is None:

Check warning on line 1585 in paddlenlp/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling.py#L1585

Added line #L1585 was not covered by tests
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
) # [bs, 1, seq_len, seq_len]

is_casual = False

if self.config.use_flash_attention and get_env_device() != "gcu":
if attn_mask_startend_row_indices is None and self.config.use_flash_attention and get_env_device() != "gcu":
if use_casual_mask:
is_casual = True
else:
Expand Down Expand Up @@ -1615,6 +1626,7 @@
past_key_value,
use_cache,
alibi=alibi,
attn_mask_startend_row_indices=attn_mask_startend_row_indices,
)
else:
layer_outputs = decoder_layer(
Expand All @@ -1625,6 +1637,7 @@
past_key_value,
use_cache,
alibi=alibi,
attn_mask_startend_row_indices=attn_mask_startend_row_indices,
npu_is_casual=is_casual,
)

Expand Down Expand Up @@ -1886,6 +1899,7 @@
output_attentions=None,
output_hidden_states=None,
return_dict=None,
attn_mask_startend_row_indices=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
Expand All @@ -1902,6 +1916,7 @@
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
attn_mask_startend_row_indices=attn_mask_startend_row_indices,
)

hidden_states = outputs[0] # [bs, seq_len, dim]
Expand Down
80 changes: 66 additions & 14 deletions paddlenlp/transformers/llama/modeling_pp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict

import paddle
import paddle.distributed.fleet as fleet
import paddle.nn as nn
Expand Down Expand Up @@ -47,36 +49,48 @@

def parse_args(args):
if isinstance(args, tuple):
if len(args) == 4:
hidden_states, attention_mask, position_ids, alibi = args
if len(args) == 3:
hidden_states, attention_mask, position_ids = args
if len(args) == 5:
hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids, alibi = args
elif len(args) == 4:
hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = args
alibi = None
elif len(args) == 3:
hidden_states, attention_mask, attn_mask_startend_row_indices = args
position_ids = None

Check warning on line 59 in paddlenlp/transformers/llama/modeling_pp.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_pp.py#L52-L59

Added lines #L52 - L59 were not covered by tests
alibi = None
elif len(args) == 2:
hidden_states, attention_mask = args
attn_mask_startend_row_indices = None

Check warning on line 63 in paddlenlp/transformers/llama/modeling_pp.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_pp.py#L63

Added line #L63 was not covered by tests
position_ids = None
alibi = None
else:
hidden_states = args
attention_mask, position_ids, alibi = None, None, None
attention_mask, attn_mask_startend_row_indices, position_ids, alibi = None, None, None, None

Check warning on line 68 in paddlenlp/transformers/llama/modeling_pp.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_pp.py#L68

Added line #L68 was not covered by tests

if position_ids is not None:
position_ids.stop_gradient = True

if attention_mask is not None:
attention_mask.stop_gradient = True

if attn_mask_startend_row_indices is not None:
attn_mask_startend_row_indices.stop_gradient = True

Check warning on line 77 in paddlenlp/transformers/llama/modeling_pp.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_pp.py#L76-L77

Added lines #L76 - L77 were not covered by tests

if alibi is not None:
alibi.stop_gradient = True

return hidden_states, attention_mask, position_ids, alibi
return hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids, alibi

Check warning on line 82 in paddlenlp/transformers/llama/modeling_pp.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_pp.py#L82

Added line #L82 was not covered by tests


def return_args(hidden_states, attention_mask=None, position_ids=None, alibi=None):
def return_args(
hidden_states, attention_mask=None, attn_mask_startend_row_indices=None, position_ids=None, alibi=None
):
ret = (hidden_states,)

if attention_mask is not None:
ret += (attention_mask.clone(),)
if attn_mask_startend_row_indices is not None:
ret += (attn_mask_startend_row_indices.clone(),)

Check warning on line 93 in paddlenlp/transformers/llama/modeling_pp.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_pp.py#L92-L93

Added lines #L92 - L93 were not covered by tests
if position_ids is not None:
ret += (position_ids.clone(),)
if alibi is not None:
Expand Down Expand Up @@ -114,7 +128,7 @@
Returns:
_type_: _description_
"""
input_ids, attention_mask, position_ids, alibi = parse_args(args)
input_ids, attention_mask, attn_mask_startend_row_indices, position_ids, alibi = parse_args(args)

Check warning on line 131 in paddlenlp/transformers/llama/modeling_pp.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_pp.py#L131

Added line #L131 was not covered by tests
input_embeds = self.embed_tokens(input_ids)
if self.sequence_parallel:
from paddlenlp.transformers import ScatterOp
Expand All @@ -128,6 +142,9 @@
batch_size, seq_length = input_ids.shape
alibi = None
if self.config.alibi:
assert (

Check warning on line 145 in paddlenlp/transformers/llama/modeling_pp.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_pp.py#L145

Added line #L145 was not covered by tests
attn_mask_startend_row_indices is None
), "alibi and attn_mask_startend_row_indices can not be set at same time"
# embed positions
mask = (
attention_mask
Expand All @@ -150,6 +167,9 @@
alibi.stop_gradient = True

if attention_mask is not None:
assert (

Check warning on line 170 in paddlenlp/transformers/llama/modeling_pp.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_pp.py#L170

Added line #L170 was not covered by tests
attn_mask_startend_row_indices is None
), "attention_mask and attn_mask_startend_row_indices can not be set at same time"
attention_mask = LlamaModel._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), 0, input_embeds.dtype
)
Expand All @@ -166,28 +186,30 @@
)
attention_mask.stop_gradient = True

return return_args(input_embeds, attention_mask, position_ids, alibi)
return return_args(input_embeds, attention_mask, attn_mask_startend_row_indices, position_ids, alibi)

Check warning on line 189 in paddlenlp/transformers/llama/modeling_pp.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_pp.py#L189

Added line #L189 was not covered by tests


class LlamaDecoderLayerPipe(LlamaDecoderLayer):
def forward(self, args):
hidden_states, attention_mask, position_ids, alibi = parse_args(args)
hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids, alibi = parse_args(args)

Check warning on line 194 in paddlenlp/transformers/llama/modeling_pp.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_pp.py#L194

Added line #L194 was not covered by tests
# we can't distinguish
# hidden_states, attention_mask, position_ids or
# hidden_states, attention_mask, alibi

if self.config.alibi and alibi is None and position_ids is not None:
alibi = position_ids
position_ids = None

has_gradient = not hidden_states.stop_gradient
if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:
if attention_mask is not None or alibi is not None:
if attention_mask is not None or alibi is not None or attn_mask_startend_row_indices is not None:

Check warning on line 205 in paddlenlp/transformers/llama/modeling_pp.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_pp.py#L205

Added line #L205 was not covered by tests
hidden_states = recompute(
super().forward,
hidden_states,
position_ids=position_ids,
attention_mask=attention_mask,
alibi=alibi,
attn_mask_startend_row_indices=attn_mask_startend_row_indices,
use_reentrant=False,
)
else:
Expand All @@ -196,14 +218,19 @@
super().forward,
hidden_states,
position_ids=position_ids,
attn_mask_startend_row_indices=attn_mask_startend_row_indices,
use_reentrant=self.config.recompute_use_reentrant,
)
else:
hidden_states = super().forward(
hidden_states, position_ids=position_ids, attention_mask=attention_mask, alibi=alibi
hidden_states,
position_ids=position_ids,
attention_mask=attention_mask,
alibi=alibi,
attn_mask_startend_row_indices=attn_mask_startend_row_indices,
)

return return_args(hidden_states, attention_mask, position_ids, alibi)
return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids, alibi)

Check warning on line 233 in paddlenlp/transformers/llama/modeling_pp.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_pp.py#L233

Added line #L233 was not covered by tests


class LlamaRMSNormPipe(nn.Layer):
Expand All @@ -212,7 +239,7 @@
self.norm = LlamaRMSNorm(config)

def forward(self, args):
hidden_states, attention_mask, position_ids, alibi = parse_args(args)
hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids, alibi = parse_args(args)

Check warning on line 242 in paddlenlp/transformers/llama/modeling_pp.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_pp.py#L242

Added line #L242 was not covered by tests
return self.norm(hidden_states)


Expand All @@ -232,6 +259,31 @@

# DONOT Add base_model_prefix !!!!

@classmethod
def _prepare_pipeline_inputs_func(cls, inputs):

first_stage_keys = ["input_ids", "attention_mask", "attn_mask_startend_row_indices", "position_ids"]
last_stage_keys = ["labels"]

Check warning on line 266 in paddlenlp/transformers/llama/modeling_pp.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_pp.py#L265-L266

Added lines #L265 - L266 were not covered by tests

def get_expected_keys(inputs, keys):
ret = tuple([inputs.pop(k) if k in inputs else None for k in keys])
if len(ret) == 1:
ret = ret[0]
return ret

Check warning on line 272 in paddlenlp/transformers/llama/modeling_pp.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_pp.py#L268-L272

Added lines #L268 - L272 were not covered by tests

if type(inputs) is dict or type(inputs) is OrderedDict:
return [

Check warning on line 275 in paddlenlp/transformers/llama/modeling_pp.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_pp.py#L274-L275

Added lines #L274 - L275 were not covered by tests
get_expected_keys(inputs, first_stage_keys),
get_expected_keys(inputs, last_stage_keys),
]

keys = list(inputs[0].keys())
inputs_batch = {key: [data.pop(key) for data in inputs] for key in keys}
return [

Check warning on line 282 in paddlenlp/transformers/llama/modeling_pp.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_pp.py#L280-L282

Added lines #L280 - L282 were not covered by tests
get_expected_keys(inputs_batch, first_stage_keys),
get_expected_keys(inputs_batch, last_stage_keys),
]

def __init__(self, config):
self.config = config

Expand Down
Loading