diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py
index 65d65869afd292..45a0f59276fc1a 100644
--- a/src/transformers/generation/flax_utils.py
+++ b/src/transformers/generation/flax_utils.py
@@ -58,11 +58,14 @@ class FlaxGreedySearchOutput(ModelOutput):
 
 
     Args:
-        sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
+        sequences (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
             The generated sequences.
+        scores (`jnp.ndarray` of shape `(batch_size, sequence_length, vocab_size)`):
+            The scores (log probabilities) of the generated tokens.
     """
 
     sequences: jnp.ndarray = None
+    scores: jnp.ndarray = None
 
 
 @flax.struct.dataclass
@@ -72,7 +75,7 @@ class FlaxSampleOutput(ModelOutput):
 
 
     Args:
-        sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
+        sequences (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
             The generated sequences.
     """
 
@@ -86,20 +89,21 @@ class FlaxBeamSearchOutput(ModelOutput):
 
 
     Args:
-        sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
+        sequences (`jnp.ndarray` of shape `(batch_size, num_return_sequences, sequence_length)`):
             The generated sequences.
-        scores (`jnp.ndarray` of shape `(batch_size,)`):
+        sequences_scores (`jnp.ndarray` of shape `(batch_size, num_return_sequences)`):
             The scores (log probabilities) of the generated sequences.
     """
 
     sequences: jnp.ndarray = None
-    scores: jnp.ndarray = None
+    sequences_scores: jnp.ndarray = None
 
 
 @flax.struct.dataclass
 class GreedyState:
     cur_len: jnp.ndarray
     sequences: jnp.ndarray
+    scores: Optional[jnp.ndarray]
     running_token: jnp.ndarray
     is_sent_finished: jnp.ndarray
     model_kwargs: Dict[str, jnp.ndarray]
@@ -419,6 +423,7 @@ def generate(
                 generation_config.max_length,
                 generation_config.pad_token_id,
                 generation_config.eos_token_id,
+                output_scores=generation_config.output_scores,
                 logits_processor=logits_processor,
                 trace=trace,
                 params=params,
@@ -439,6 +444,8 @@ def generate(
                 model_kwargs=model_kwargs,
             )
         elif not generation_config.do_sample and generation_config.num_beams > 1:
+            if generation_config.num_return_sequences > generation_config.num_beams:
+                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
             # broadcast input_ids & encoder_outputs
             input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams)
 
@@ -458,9 +465,11 @@ def generate(
                 generation_config.max_length,
                 generation_config.pad_token_id,
                 generation_config.eos_token_id,
+                output_scores=generation_config.output_scores,
                 length_penalty=generation_config.length_penalty,
                 early_stopping=generation_config.early_stopping,
                 logits_processor=logits_processor,
+                num_beam_hyps_to_keep=generation_config.num_return_sequences,
                 trace=trace,
                 params=params,
                 num_return_sequences=generation_config.num_return_sequences,
@@ -562,6 +571,7 @@ def _greedy_search(
         max_length: Optional[int] = None,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
+        output_scores: Optional[bool] = None,
         logits_processor: Optional[FlaxLogitsProcessorList] = None,
         trace: bool = True,
         params: Optional[Dict[str, jnp.ndarray]] = None,
@@ -571,6 +581,7 @@ def _greedy_search(
         max_length = max_length if max_length is not None else self.generation_config.max_length
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
 
         batch_size, cur_len = input_ids.shape
 
@@ -581,6 +592,29 @@ def _greedy_search(
         # per batch-item holding current token in loop.
         sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
         sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))
+        if output_scores:
+            if hasattr(self.config, "vocab_size") and self.config.vocab_size is not None:
+                vocab_size = self.config.vocab_size
+            elif (
+                hasattr(self.config, "decoder")
+                and hasattr(self.config.decoder, "vocab_size")
+                and self.config.decoder.vocab_size is not None
+            ):
+                vocab_size = self.config.decoder.vocab_size
+            elif (
+                hasattr(self.config, "encoder")
+                and hasattr(self.config.encoder, "vocab_size")
+                and self.config.encoder.vocab_size is not None
+            ):
+                vocab_size = self.config.encoder.vocab_size
+            else:
+                raise TypeError(
+                    f"The current model class ({self.__class__.__name__}) does not have a recognized "
+                    "`vocab_size`, so it does not support `output_scores`."
+                )
+            scores = jnp.ones((batch_size, max_length, vocab_size)) * np.array(-1.0e7)
+        else:
+            scores = None
 
         # per batch-item state bit indicating if sentence has finished.
         is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
@@ -595,6 +629,7 @@ def _greedy_search(
         state = GreedyState(
             cur_len=cur_len,
             sequences=sequences,
+            scores=scores,
             running_token=input_ids,
             is_sent_finished=is_sent_finished,
             model_kwargs=model_kwargs,
@@ -613,10 +648,10 @@ def greedy_search_body_fn(state):
             logits = model_outputs.logits[:, -1]
 
             # apply min_length, ...
-            logits = logits_processor(state.sequences, logits, state.cur_len)
-
-            next_token = jnp.argmax(logits, axis=-1)
+            next_tokens_scores = logits_processor(state.sequences, logits, state.cur_len)
+            next_token = jnp.argmax(next_tokens_scores, axis=-1)
+            tokens_scores = state.scores.at[:, state.cur_len, :].set(next_tokens_scores) if output_scores else None
 
             next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished
             next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
             next_token = next_token[:, None]
@@ -626,6 +661,7 @@ def greedy_search_body_fn(state):
             return GreedyState(
                 cur_len=state.cur_len + 1,
                 sequences=next_sequences,
+                scores=tokens_scores,
                 running_token=next_token,
                 is_sent_finished=next_is_sent_finished,
                 model_kwargs=next_model_kwargs,
@@ -640,7 +676,12 @@ def greedy_search_body_fn(state):
         else:
             state = lax.while_loop(greedy_search_cond_fn, greedy_search_body_fn, state)
 
-        return FlaxGreedySearchOutput(sequences=state.sequences)
+        if output_scores:
+            final_scores = state.scores
+        else:
+            final_scores = None
+
+        return FlaxGreedySearchOutput(sequences=state.sequences, scores=final_scores)
 
     def _sample(
         self,
@@ -745,9 +786,11 @@ def _beam_search(
         max_length: Optional[int] = None,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
+        output_scores: Optional[bool] = None,
         length_penalty: Optional[float] = None,
         early_stopping: Optional[Union[bool, str]] = None,
         logits_processor: Optional[FlaxLogitsProcessorList] = None,
+        num_beam_hyps_to_keep: Optional[int] = None,
         trace: bool = True,
         params: Optional[Dict[str, jnp.ndarray]] = None,
         num_return_sequences: Optional[int] = None,
@@ -795,6 +838,7 @@ def gather_fn(tensor):
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
         length_penalty = length_penalty if length_penalty is not None else self.generation_config.length_penalty
         early_stopping = early_stopping if early_stopping is not None else self.generation_config.early_stopping
+        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
         num_return_sequences = (
             num_return_sequences if num_return_sequences is not None else self.generation_config.num_return_sequences
         )
@@ -1003,6 +1047,6 @@ def beam_search_body_fn(state, input_ids_length=1):
 
         # Take best beams for each batch (the score is sorted in descending order)
         sequences = flatten_beam_dim(sequences[:, :num_return_sequences, :])
-        scores = flatten_beam_dim(scores[:, :num_return_sequences])
+        sequences_scores = flatten_beam_dim(scores[:, :num_return_sequences]) if output_scores else None
 
-        return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
+        return FlaxBeamSearchOutput(sequences=sequences, sequences_scores=sequences_scores)
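
For reference, a minimal usage sketch of the greedy-search `scores` field introduced above. It assumes this patch is applied; the checkpoint, prompt, and generation settings are illustrative placeholders, not part of the diff.

```python
# Hedged sketch: exercising FlaxGreedySearchOutput.scores (assumes this patch is applied;
# "t5-small" and the prompt are placeholders).
from transformers import AutoTokenizer, FlaxAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = FlaxAutoModelForSeq2SeqLM.from_pretrained("t5-small")
inputs = tokenizer(["translate English to German: Thank you."], return_tensors="np")

# Greedy search (do_sample=False, num_beams=1). With output_scores=True, `scores` is
# allocated as (batch_size, max_length, vocab_size), pre-filled with -1e7, and each
# decoding step writes the processed logits at its cur_len position.
out = model.generate(inputs["input_ids"], max_length=20, output_scores=True)
print(out.sequences.shape)
print(None if out.scores is None else out.scores.shape)
```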
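
A companion sketch for beam search with the renamed `sequences_scores` field; the same caveats apply (illustrative checkpoint and settings).

```python
# Hedged sketch: beam search returning FlaxBeamSearchOutput.sequences_scores
# (assumes this patch; "t5-small" and the settings are placeholders).
from transformers import AutoTokenizer, FlaxAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = FlaxAutoModelForSeq2SeqLM.from_pretrained("t5-small")
inputs = tokenizer(["translate English to German: Thank you."], return_tensors="np")

# num_return_sequences must not exceed num_beams (enforced by the new ValueError),
# and output_scores=True is needed for sequences_scores to be populated.
out = model.generate(
    inputs["input_ids"],
    max_length=20,
    num_beams=4,
    num_return_sequences=2,
    output_scores=True,
)
# The beam dimension is flattened by flatten_beam_dim, so there is one sequence and
# one log-probability score per returned hypothesis.
print(out.sequences.shape, None if out.sequences_scores is None else out.sequences_scores.shape)
print(tokenizer.batch_decode(out.sequences, skip_special_tokens=True))
```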