Skip to content

Commit

Permalink
Merge pull request huggingface#1 from ondewo/hannan_updates_flax
Browse files Browse the repository at this point in the history
update flax_utils.py
  • Loading branch information
teddius authored May 7, 2023
2 parents ef42c2c + 6ecae94 commit 45bd44f
Showing 1 changed file with 55 additions and 11 deletions.
66 changes: 55 additions & 11 deletions src/transformers/generation/flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,14 @@ class FlaxGreedySearchOutput(ModelOutput):
Args:
sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
sequences (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
The generated sequences.
scores (`jnp.ndarray` of shape `(batch_size, sequence_length, vocab_size)`):
The scores (log probabilities) of the generated tokens.
"""

sequences: jnp.ndarray = None
scores: jnp.ndarray = None


@flax.struct.dataclass
Expand All @@ -72,7 +75,7 @@ class FlaxSampleOutput(ModelOutput):
Args:
sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
sequences (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
The generated sequences.
"""

Expand All @@ -86,20 +89,21 @@ class FlaxBeamSearchOutput(ModelOutput):
Args:
sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
sequences (`jnp.ndarray` of shape `(batch_size, num_return_sequences, sequence_length)`):
The generated sequences.
scores (`jnp.ndarray` of shape `(batch_size,)`):
sequences_scores (`jnp.ndarray` of shape `(batch_size, num_return_sequences)`):
The scores (log probabilities) of the generated sequences.
"""

sequences: jnp.ndarray = None
scores: jnp.ndarray = None
sequences_scores: jnp.ndarray = None


@flax.struct.dataclass
class GreedyState:
    """Loop-carried state for the greedy-search decoding loop (a flax struct, so it
    can be threaded through `lax.while_loop` / `lax.fori_loop`)."""

    # Current generated length; incremented by one each decoding step.
    cur_len: jnp.ndarray
    # Full output buffer of shape (batch_size, max_length), pre-filled with
    # pad_token_id and updated in place at position `cur_len` each step.
    sequences: jnp.ndarray
    # Per-step token scores accumulated when `output_scores` is enabled,
    # otherwise None — presumably shape (batch_size, max_length, vocab_size);
    # TODO(review): confirm against the `_greedy_search` initialization.
    scores: Optional[jnp.ndarray]
    # The most recently generated token(s), fed back as the next model input.
    running_token: jnp.ndarray
    # Boolean flag per batch item, set once the EOS token has been produced.
    is_sent_finished: jnp.ndarray
    # Cached model keyword arguments (e.g. past key/values), updated each step.
    model_kwargs: Dict[str, jnp.ndarray]
Expand Down Expand Up @@ -419,6 +423,7 @@ def generate(
generation_config.max_length,
generation_config.pad_token_id,
generation_config.eos_token_id,
output_scores=generation_config.output_scores,
logits_processor=logits_processor,
trace=trace,
params=params,
Expand All @@ -439,6 +444,8 @@ def generate(
model_kwargs=model_kwargs,
)
elif not generation_config.do_sample and generation_config.num_beams > 1:
if generation_config.num_return_sequences > generation_config.num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
# broadcast input_ids & encoder_outputs
input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams)

Expand All @@ -458,9 +465,11 @@ def generate(
generation_config.max_length,
generation_config.pad_token_id,
generation_config.eos_token_id,
output_scores=generation_config.output_scores,
length_penalty=generation_config.length_penalty,
early_stopping=generation_config.early_stopping,
logits_processor=logits_processor,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
trace=trace,
params=params,
num_return_sequences=generation_config.num_return_sequences,
Expand Down Expand Up @@ -562,6 +571,7 @@ def _greedy_search(
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
output_scores: Optional[bool] = None,
logits_processor: Optional[FlaxLogitsProcessorList] = None,
trace: bool = True,
params: Optional[Dict[str, jnp.ndarray]] = None,
Expand All @@ -571,6 +581,7 @@ def _greedy_search(
max_length = max_length if max_length is not None else self.generation_config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores

batch_size, cur_len = input_ids.shape

Expand All @@ -581,6 +592,29 @@ def _greedy_search(
# per batch-item holding current token in loop.
sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))
if output_scores:
if hasattr(self.config, "vocab_size") and self.config.vocab_size is not None:
vocab_size = self.config.vocab_size
elif (
hasattr(self.config, "decoder")
and hasattr(self.config.decoder, "vocab_size")
and self.config.decoder.vocab_size is not None
):
vocab_size = self.config.decoder.vocab_size
elif (
hasattr(self.config, "encoder")
and hasattr(self.config.encoder, "vocab_size")
and self.config.encoder.vocab_size is not None
):
vocab_size = self.config.encoder.vocab_size
else:
raise TypeError(
f"The current model class ({self.__class__.__name__}) has not a recognized "
f"`vocab_size` , as it doesn't support output_scores."
)
scores = jnp.ones((batch_size, max_length, vocab_size)) * np.array(-1.0e7)
else:
scores = None

# per batch-item state bit indicating if sentence has finished.
is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
Expand All @@ -595,6 +629,7 @@ def _greedy_search(
state = GreedyState(
cur_len=cur_len,
sequences=sequences,
scores=scores,
running_token=input_ids,
is_sent_finished=is_sent_finished,
model_kwargs=model_kwargs,
Expand All @@ -613,10 +648,10 @@ def greedy_search_body_fn(state):
logits = model_outputs.logits[:, -1]

# apply min_length, ...
logits = logits_processor(state.sequences, logits, state.cur_len)

next_token = jnp.argmax(logits, axis=-1)
next_tokens_scores = logits_processor(state.sequences, logits, state.cur_len)

next_token = jnp.argmax(next_tokens_scores, axis=-1)
tokens_scores = state.scores.at[:, state.cur_len, :].set(next_tokens_scores) if output_scores else None
next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished
next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
next_token = next_token[:, None]
Expand All @@ -626,6 +661,7 @@ def greedy_search_body_fn(state):
return GreedyState(
cur_len=state.cur_len + 1,
sequences=next_sequences,
scores=tokens_scores,
running_token=next_token,
is_sent_finished=next_is_sent_finished,
model_kwargs=next_model_kwargs,
Expand All @@ -640,7 +676,12 @@ def greedy_search_body_fn(state):
else:
state = lax.while_loop(greedy_search_cond_fn, greedy_search_body_fn, state)

return FlaxGreedySearchOutput(sequences=state.sequences)
if output_scores:
final_scores = state.scores
else:
final_scores = None

return FlaxGreedySearchOutput(sequences=state.sequences, scores=final_scores)

def _sample(
self,
Expand Down Expand Up @@ -745,9 +786,11 @@ def _beam_search(
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
output_scores: Optional[bool] = None,
length_penalty: Optional[float] = None,
early_stopping: Optional[Union[bool, str]] = None,
logits_processor: Optional[FlaxLogitsProcessorList] = None,
num_beam_hyps_to_keep: Optional[int] = None,
trace: bool = True,
params: Optional[Dict[str, jnp.ndarray]] = None,
num_return_sequences: Optional[int] = None,
Expand Down Expand Up @@ -795,6 +838,7 @@ def gather_fn(tensor):
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
length_penalty = length_penalty if length_penalty is not None else self.generation_config.length_penalty
early_stopping = early_stopping if early_stopping is not None else self.generation_config.early_stopping
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.generation_config.num_return_sequences
)
Expand Down Expand Up @@ -1003,6 +1047,6 @@ def beam_search_body_fn(state, input_ids_length=1):

# Take best beams for each batch (the score is sorted in descending order)
sequences = flatten_beam_dim(sequences[:, :num_return_sequences, :])
scores = flatten_beam_dim(scores[:, :num_return_sequences])
sequences_scores = flatten_beam_dim(scores[:, :num_return_sequences]) if output_scores else None

return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
return FlaxBeamSearchOutput(sequences=sequences, sequences_scores=sequences_scores)

0 comments on commit 45bd44f

Please sign in to comment.