change llama/modeling.py to opt npu performance (#8342)
* change llama/modeling.py to opt npu performance

* update

* update

* Update modeling.py

* add judge

---------

Co-authored-by: Wang Huan <wanghuan29@baidu.com>
Galaxy1458 and wanghuancoder committed Apr 30, 2024
1 parent e7de0fa commit 1c781d8
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions paddlenlp/transformers/llama/modeling.py
@@ -97,6 +97,7 @@ def swiglu(x, y=None):
"LlamaPretrainingCriterion",
]

npu_is_casual = False

def _get_interleave(n):
def _get_interleave_power_of_2(n):
@@ -244,7 +245,7 @@ def scaled_dot_product_attention(
attention_mask is None,
True,
False,
False,
npu_is_casual,
)[0]
else:
attn_output = F.scaled_dot_product_attention(
@@ -1118,6 +1119,7 @@ def __init__(self, config, layerwise_recompute: bool = False):
self.layerwise_recompute = layerwise_recompute
self.recompute_granularity = config.recompute_granularity


def forward(
self,
hidden_states: paddle.Tensor,
@@ -1612,11 +1614,12 @@ def forward(
attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
) # [bs, 1, seq_len, seq_len]
if self.config.use_flash_attention:
is_casual = is_casual_mask(attention_mask)
if get_env_device() != "npu":
is_casual = is_casual_mask(attention_mask)
if is_casual and alibi is None:
attention_mask = None
else:
npu_is_casual = is_casual
attention_mask = attention_mask.astype("bool")
hidden_states = inputs_embeds
# decoder layers
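
Taken together, the hunks above introduce a module-level flag: npu_is_casual is defined near the top of modeling.py, the forward pass on NPU devices records whether the current attention mask is causal in it (instead of dropping the mask as the non-NPU path does), and scaled_dot_product_attention passes it to the fused attention call in place of the previous hard-coded False. A minimal, self-contained sketch of that pattern, with illustrative names rather than the actual PaddleNLP code:

# Sketch of a module-level flag shared between a model's forward() and a
# free attention helper, mirroring the hunks above. Names are illustrative.
npu_is_casual = False  # default when the mask has not been marked causal


def scaled_dot_product_attention_sketch(query):
    # The helper reads the module-level flag instead of a hard-coded False,
    # so a fused kernel could take its faster causal branch when possible.
    return {"output": query, "causal": npu_is_casual}


def forward_sketch(mask_is_causal):
    # Rebinding a module-level name from inside a function needs `global`.
    global npu_is_casual
    npu_is_casual = mask_is_causal
    return scaled_dot_product_attention_sketch("q")


print(forward_sketch(True)["causal"])   # True
print(forward_sketch(False)["causal"])  # False
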
@@ -1722,9 +1725,12 @@ def forward(self, prediction_scores, masked_lm_labels):
_hcg = fleet.get_hybrid_communicate_group()
masked_lm_loss = ConcatSePMaskedLoss.apply(masked_lm_loss, axis=1, group=_hcg.get_sep_parallel_group())
# skip ignore_index which loss == 0
masked_lm_loss = masked_lm_loss[masked_lm_loss > 0]
loss = paddle.mean(masked_lm_loss)

# masked_lm_loss = masked_lm_loss[masked_lm_loss > 0]
# loss = paddle.mean(masked_lm_loss)
binary_sequence = paddle.where(masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss))
sum_ = paddle.sum(binary_sequence)
loss = 0 if sum_ == 0 else paddle.sum(masked_lm_loss * binary_sequence) / sum_

return loss
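
The rewritten loss keeps the same value, the mean of the non-zero token losses, but avoids boolean indexing: masked_lm_loss[masked_lm_loss > 0] produces a tensor whose length depends on the data, while the paddle.where mask has the same shape as the loss, which presumably suits NPU execution better. A small sketch (standard Paddle ops, made-up values) showing the two forms agree:

import paddle

masked_lm_loss = paddle.to_tensor([0.0, 2.0, 0.0, 4.0])

# Old form: boolean indexing, result length depends on the data.
old_loss = paddle.mean(masked_lm_loss[masked_lm_loss > 0])

# New form: a 0/1 mask with the loss tensor's shape, then a masked mean.
binary_sequence = paddle.where(
    masked_lm_loss > 0,
    paddle.ones_like(masked_lm_loss),
    paddle.zeros_like(masked_lm_loss),
)
sum_ = paddle.sum(binary_sequence)
new_loss = 0 if sum_ == 0 else paddle.sum(masked_lm_loss * binary_sequence) / sum_

print(float(old_loss), float(new_loss))  # both print 3.0
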

