Qwen support position_ids (#7359)
* qwen add position ids

* qwen add position ids

* bug fix

* bug fix
wtmlon authored Nov 2, 2023
1 parent 5a50bdd commit fc2bf9c
Showing 2 changed files with 33 additions and 8 deletions.
39 changes: 32 additions & 7 deletions paddlenlp/transformers/qwen/modeling.py
@@ -216,6 +216,7 @@ def forward(
hidden_states,
layer_past=None,
attention_mask=None,
position_ids=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=False,
@@ -259,11 +260,11 @@ def forward(
v=None,
sin=sin,
cos=cos,
position_ids=None,
position_ids=position_ids,
use_neox_rotary_style=False,
)
else:
query, key = apply_rotary_pos_emb(query, key, cos, sin)
query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids=position_ids)

if layer_past is not None:
past_key, past_value = layer_past[0], layer_past[1]
@@ -352,6 +353,7 @@ def forward(
hidden_states,
layer_past=None,
attention_mask=None,
position_ids=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=False,
Expand All @@ -363,6 +365,7 @@ def forward(
layernorm_output,
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=use_cache,
output_attentions=output_attentions,
)
@@ -578,6 +581,7 @@ def recompute_training(
hidden_states,
layer_past,
attention_mask,
position_ids,
encoder_hidden_states,
encoder_attention_mask,
use_cache,
Expand All @@ -594,6 +598,7 @@ def custom_forward(*inputs):
hidden_states,
layer_past,
attention_mask,
position_ids,
encoder_hidden_states,
encoder_attention_mask,
use_cache,
@@ -635,6 +640,7 @@ def forward(
input_ids=None,
past_key_values=None,
attention_mask=None,
position_ids=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
@@ -702,6 +708,7 @@ def forward(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
Expand All @@ -712,6 +719,7 @@ def forward(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
@@ -839,6 +847,10 @@ def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder
model_kwargs["cache"] = outputs.past_key_values
model_kwargs["past_key_values"] = outputs.past_key_values

if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None:
position_ids = model_kwargs["position_ids"]
model_kwargs["position_ids"] = paddle.concat([position_ids, position_ids[..., -1:] + 1], axis=-1)

# update attention_mask
if not is_encoder_decoder and "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
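
For reference, the new branch above advances position_ids by one slot per decoding step, mirroring how the attention mask is extended. A minimal standalone sketch of that update (tensor values are illustrative, not taken from the model):

```python
import paddle

# illustrative [batch_size, seq_len] positions for a prompt of length 4
position_ids = paddle.to_tensor([[0, 1, 2, 3]])

# same update as in update_model_kwargs_for_generation:
# append (last position + 1) along the sequence axis after each step
position_ids = paddle.concat([position_ids, position_ids[..., -1:] + 1], axis=-1)

print(position_ids.tolist())  # [[0, 1, 2, 3, 4]]
```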
Expand All @@ -852,10 +864,13 @@ def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder
return model_kwargs

def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)

if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)

attention_mask = kwargs.get("attention_mask", None)
if position_ids is not None:
position_ids = position_ids[:, -1].unsqueeze(-1)

if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
Expand All @@ -867,6 +882,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"position_ids": position_ids,
}
)
return model_inputs
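
As a quick illustration of the slicing added in prepare_inputs_for_generation: once a KV cache exists, only the newest token and its position are fed back in. A hedged sketch (the cache object is a stand-in):

```python
import paddle

position_ids = paddle.to_tensor([[0, 1, 2, 3, 4]])  # positions seen so far
past_key_values = ("dummy",)  # stands in for a non-empty KV cache

if past_key_values:  # mirrors the check in prepare_inputs_for_generation
    position_ids = position_ids[:, -1].unsqueeze(-1)

print(position_ids.tolist())  # [[4]]
```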
Expand All @@ -888,6 +904,7 @@ def forward(
input_ids=None,
past_key_values=None,
attention_mask=None,
position_ids=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
Expand All @@ -904,6 +921,7 @@ def forward(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
@@ -985,9 +1003,16 @@ def rotate_half(x):
return paddle.concat([-x2, x1], axis=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
cos = cos[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim]
sin = sin[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim]
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):

if position_ids is None:
cos = cos[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim]
sin = sin[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim]
else:
cos = cos.squeeze(axis=[0, 2]) # [seq_len, dim]
sin = sin.squeeze(axis=[0, 2]) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
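
To make the shape handling in the new position_ids branch of apply_rotary_pos_emb concrete, here is a minimal sketch of the cos/sin selection (shapes follow the comments in the diff; the cache contents are random placeholders):

```python
import paddle

batch_size, seq_len, max_pos, head_dim = 2, 4, 16, 8

# rotary caches as produced upstream: [1, max_pos, 1, head_dim]
cos = paddle.randn([1, max_pos, 1, head_dim])
sin = paddle.randn([1, max_pos, 1, head_dim])

# per-sample positions, e.g. for left-padded prompts: [batch_size, seq_len]
position_ids = paddle.to_tensor([[0, 1, 2, 3],
                                 [5, 6, 7, 8]])

# the selection performed by the new else-branch
cos_sel = cos.squeeze(axis=[0, 2])[position_ids].unsqueeze(2)  # [bs, seq_len, 1, head_dim]
sin_sel = sin.squeeze(axis=[0, 2])[position_ids].unsqueeze(2)

print(cos_sel.shape)  # [2, 4, 1, 8]
```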
2 changes: 1 addition & 1 deletion paddlenlp/transformers/qwen/tokenizer.py
@@ -63,7 +63,7 @@ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
class QWenTokenizer(PretrainedTokenizer):
"""QWen tokenizer."""

model_input_names = ["input_ids", "attention_mask"]
model_input_names = ["input_ids", "attention_mask", "position_ids"]
resource_files_names = VOCAB_FILES_NAMES

def __init__(
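
With position_ids now listed among the tokenizer's model inputs, callers can pass explicit positions to the model's forward. One common way to derive them from an attention mask, shown here as an assumption rather than something this commit adds, is a clipped cumulative sum over real tokens:

```python
import paddle

# hypothetical left-padded batch; 0 marks padding
attention_mask = paddle.to_tensor([[0, 0, 1, 1, 1],
                                   [1, 1, 1, 1, 1]])

position_ids = paddle.cumsum(attention_mask, axis=-1) - 1
position_ids = paddle.clip(position_ids, min=0)

print(position_ids.tolist())
# [[0, 0, 0, 1, 2],
#  [0, 1, 2, 3, 4]]
# these can then be passed alongside input_ids and attention_mask,
# e.g. model(input_ids=..., attention_mask=attention_mask, position_ids=position_ids)
```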
