[Bug fixes] fix generation cache bug (#5535)
* fix generation cache bug

* update beam_search & group_beam_search method
wj-Mcat authored Apr 6, 2023
1 parent 019ce12 commit 8737f68
Showing 1 changed file with 31 additions and 14 deletions.
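Why this matters: when a model's forward pass returns a tuple, the first element is the logits and the later elements carry the incremental decoding cache. The pre-fix code overwrote `outputs` with `outputs[0]`, so the cache was discarded before the kwargs-update step that follows could carry it into the next decoding iteration. Below is a minimal, self-contained sketch of that failure mode and the fix; the toy model and helper names are illustrative only, not the PaddleNLP API.

# Toy sketch (illustrative names, not the PaddleNLP API): a forward pass
# that returns (logits, cache), and a decode step that must keep the whole
# tuple alive so the cache reaches the next iteration.

def toy_model(input_ids, cache=None):
    logits = [0.1, 0.9]            # stand-in for a logits tensor
    new_cache = (cache or 0) + 1   # stand-in for an attention cache
    return logits, new_cache

def decode_step(model_kwargs):
    outputs = toy_model(**model_kwargs)

    # Buggy pre-fix pattern: `outputs = outputs[0]` here would throw away
    # the cache in outputs[1], forcing a full recompute every step.

    # Fixed pattern: read the logits without overwriting `outputs`.
    logits = outputs[0] if isinstance(outputs, tuple) else outputs

    # The cache is still reachable when updating the next step's kwargs.
    if isinstance(outputs, tuple):
        model_kwargs["cache"] = outputs[1]
    return logits, model_kwargs

kwargs = {"input_ids": [1, 2, 3]}
for _ in range(3):
    logits, kwargs = decode_step(kwargs)
print(kwargs["cache"])  # 3 -> the cache survived across all steps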
paddlenlp/transformers/generation_utils.py
@@ -998,10 +998,13 @@ def greedy_search(self, input_ids, logits_processors, max_length, pad_token_id,
         model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
         outputs = self(**model_inputs)
-        outputs = outputs[0] if isinstance(outputs, tuple) else outputs
 
-        # To hundle the logits is a ModelOutput
-        logits = outputs.logits if isinstance(outputs, ModelOutput) else outputs
+        if isinstance(outputs, tuple):
+            logits = outputs[0]
+        elif isinstance(outputs, ModelOutput):
+            logits = outputs.logits
+        else:
+            logits = outputs
 
         # [batch_size, vocab_size]
         next_token_logits = logits[:, -1, :]
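The new branch normalizes three possible return shapes: a tuple (logits first), a ModelOutput dataclass (logits on the `.logits` attribute), or a bare logits tensor. A minimal sketch of that dispatch, using a stand-in class rather than the real paddlenlp ModelOutput:

# Stand-in for paddlenlp's ModelOutput, just enough to show the dispatch.
class FakeModelOutput:
    def __init__(self, logits):
        self.logits = logits

def extract_logits(outputs):
    # Mirrors the fixed branch: tuple -> first element, ModelOutput ->
    # .logits attribute, anything else is assumed to already be the logits.
    if isinstance(outputs, tuple):
        return outputs[0]
    elif isinstance(outputs, FakeModelOutput):
        return outputs.logits
    return outputs

assert extract_logits(("logits", "cache")) == "logits"
assert extract_logits(FakeModelOutput("logits")) == "logits"
assert extract_logits("logits") == "logits"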
@@ -1061,10 +1064,13 @@ def sample(
         # prepare model inputs & get model output
         model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
         outputs = self(**model_inputs)
-        outputs = outputs[0] if isinstance(outputs, tuple) else outputs
+        if isinstance(outputs, tuple):
+            logits = outputs[0]
+        elif isinstance(outputs, ModelOutput):
+            logits = outputs.logits
+        else:
+            logits = outputs
 
-        # To hundle the logits is a ModelOutput
-        logits = outputs.logits if isinstance(outputs, ModelOutput) else outputs
         # [batch_size, vocab_size]
         logits = logits[:, -1, :]
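Each decoding method then slices out the last position: the model emits logits of shape [batch_size, seq_len, vocab_size], and only the final step is needed to pick the next token. A quick shape check, assuming paddle is installed:

import paddle

logits = paddle.zeros([2, 5, 100])   # [batch_size, seq_len, vocab_size]
next_token_logits = logits[:, -1, :]
print(next_token_logits.shape)       # [2, 100]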

@@ -1147,7 +1153,12 @@ def _forward_(**args):
             return self(**model_inputs, **immutable)
 
         def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_flag, model_kwargs):
-            logits = outputs[0] if isinstance(outputs, tuple) else outputs
+            if isinstance(outputs, tuple):
+                logits = outputs[0]
+            elif isinstance(outputs, ModelOutput):
+                logits = outputs.logits
+            else:
+                logits = outputs
 
             # [batch_size, vocab_size]
             logits = logits[:, -1, :]
@@ -1255,10 +1266,13 @@ def beam_search(
         model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
         outputs = self(**model_inputs)
-        outputs = outputs[0] if isinstance(outputs, tuple) else outputs
 
-        # To hundle the logits is a ModelOutput
-        logits = outputs.logits if isinstance(outputs, ModelOutput) else outputs
+        if isinstance(outputs, tuple):
+            logits = outputs[0]
+        elif isinstance(outputs, ModelOutput):
+            logits = outputs.logits
+        else:
+            logits = outputs
 
         # [batch_size, vocab_size]
         logits = logits[:, -1, :]

@@ -1395,11 +1410,13 @@ def group_beam_search(
             )
 
             group_input_ids = input_ids[batch_group_indices]
-            outputs = outputs[0] if isinstance(outputs, tuple) else outputs
             # select outputs of beams of current group only
-
-            # To hundle the logits is a ModelOutput
-            logits = outputs.logits if isinstance(outputs, ModelOutput) else outputs
+            if isinstance(outputs, tuple):
+                logits = outputs[0]
+            elif isinstance(outputs, ModelOutput):
+                logits = outputs.logits
+            else:
+                logits = outputs
 
             logits = logits[:, -1, :]
             logits = paddle.index_select(logits, paddle.to_tensor(batch_group_indices))
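In group_beam_search the logits are additionally restricted to the beams of the current group via paddle.index_select, which gathers rows by index along axis 0 by default. A small sketch, again assuming paddle is installed:

import paddle

logits = paddle.rand([4, 100])        # [num_beams, vocab_size]
batch_group_indices = [0, 2]          # beams belonging to the current group
group_logits = paddle.index_select(logits, paddle.to_tensor(batch_group_indices))
print(group_logits.shape)             # [2, 100]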