Skip to content

Commit

Permalink
Generation: get special tokens from model config (#30899)
Browse files Browse the repository at this point in the history
* fix

* let's do this way?

* codestyle

* update

* add tests
  • Loading branch information
zucchini-nlp authored May 22, 2024
1 parent 1d568df commit b1065aa
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
24 changes: 23 additions & 1 deletion src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1361,6 +1361,23 @@ def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_l
self._cache.reset()
return self._cache

def _get_decoder_start_token_id(
self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
) -> int:
decoder_start_token_id = (
decoder_start_token_id
if decoder_start_token_id is not None
else self.generation_config.decoder_start_token_id
)
bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id

if decoder_start_token_id is not None:
return decoder_start_token_id
elif bos_token_id is not None:
return bos_token_id
else:
return

def _prepare_special_tokens(
self,
generation_config: GenerationConfig,
Expand All @@ -1385,11 +1402,16 @@ def _tensor_or_none(token, device=None):
return token
return torch.tensor(token, device=device, dtype=torch.long)

# for BC we also try to get `decoder_start_token_id` from model's generation config (#30892)
if self.config.is_encoder_decoder:
generation_config.decoder_start_token_id = self._get_decoder_start_token_id(
generation_config.decoder_start_token_id, generation_config.bos_token_id
)

bos_token_id = _tensor_or_none(generation_config.bos_token_id, device=device)
eos_token_id = _tensor_or_none(generation_config.eos_token_id, device=device)
pad_token_id = _tensor_or_none(generation_config.pad_token_id, device=device)
decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id, device=device)
decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id

# We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
if eos_token_id is not None and eos_token_id.ndim == 0:
Expand Down
30 changes: 30 additions & 0 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
GenerateBeamEncoderDecoderOutput,
GenerateDecoderOnlyOutput,
GenerateEncoderDecoderOutput,
GenerationConfig,
GreedySearchDecoderOnlyOutput,
GreedySearchEncoderDecoderOutput,
LogitsProcessorList,
Expand Down Expand Up @@ -2478,6 +2479,35 @@ def test_batched_decoder_start_id(self):

self.assertListEqual(outputs.tolist(), outputs_batched_ids.tolist())

def test_decoder_start_id_from_config(self):
# Refer to: (#30899)
articles = [
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
"Michael Phelps is arguably the most decorated Olympian of all time.",
]
bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
torch_device
)
input_ids = bart_tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
decoder_start_token_id = bart_model.generation_config.decoder_start_token_id

# we should be able to take `decoder_start_token_id` from model's generation config if user passes a `GenerationConfig` type
outputs = bart_model.generate(input_ids, generation_config=GenerationConfig(do_sample=False))

# If the generatoin config has no `decoder_start_token_id` or `bos_token_id`, we will raise an error unless user passes it in config
bart_model.generation_config.decoder_start_token_id = None
bart_model.generation_config.bos_token_id = None
outputs_with_user_id = bart_model.generate(
input_ids,
generation_config=GenerationConfig(do_sample=False, decoder_start_token_id=decoder_start_token_id),
)

self.assertListEqual(outputs.tolist(), outputs_with_user_id.tolist())

with self.assertRaises(ValueError):
outputs = bart_model.generate(input_ids, generation_config=GenerationConfig(do_sample=False))

def test_contrastive_search_batched(self):
# PT-only test: TF doesn't have constrained beam search
# Tests that contrastive search works with batched inputs (i.e. has the same output as for non-batched inputs)
Expand Down

0 comments on commit b1065aa

Please sign in to comment.