Skip to content

Commit

Permalink
add_update_kwags_self
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaoxiaohehe001 committed Aug 25, 2023
1 parent 3258050 commit 55d1f6d
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions paddlenlp/experimental/transformers/generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ def generate(
)
return ret

@staticmethod
def update_model_kwargs_for_generation(cache, just_decoder, next_tokens, eos_token_id, config, model_kwargs):
def update_model_kwargs_for_generation(self, cache, just_decoder, next_tokens, eos_token_id, model_kwargs):
if cache is None:
model_kwargs["step_idx"] = paddle.where(
model_kwargs["seq_len_encoder"] == 0,
Expand All @@ -163,7 +162,7 @@ def update_model_kwargs_for_generation(cache, just_decoder, next_tokens, eos_tok
if cache is None:
# encoder's generation
model_kwargs["tgt_ids"] = paddle.where(just_decoder, model_kwargs["tgt_ids"], next_tokens)
if config["position_encoding_2d"] and config.position_encoding_2d is True:
if self.model.config["position_encoding_2d"] and self.model.config.position_encoding_2d is True:
tgt_pos = model_kwargs["tgt_pos"]
new_position_id = tgt_pos[:, 0, :].clone()
new_block_id = tgt_pos[:, 1, :].clone()
Expand All @@ -183,7 +182,7 @@ def update_model_kwargs_for_generation(cache, just_decoder, next_tokens, eos_tok
)
else:
model_kwargs["tgt_ids"] = next_tokens
if config["position_encoding_2d"] and config.position_encoding_2d is True:
if self.model.config["position_encoding_2d"] and self.model.config.position_encoding_2d is True:
tgt_pos = model_kwargs["tgt_pos"]
new_position_id = tgt_pos[:, 0, :].clone()
new_block_id = tgt_pos[:, 1, :].clone()
Expand Down Expand Up @@ -268,7 +267,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
paddle.distributed.broadcast(next_tokens, 0)

model_kwargs = self.update_model_kwargs_for_generation(
cache, just_decoder, next_tokens, eos_token_id, self.model.config, model_kwargs
cache, just_decoder, next_tokens, eos_token_id, model_kwargs
)

save_with_output(
Expand Down

0 comments on commit 55d1f6d

Please sign in to comment.