Commit

Add descriptive docstring to TemperatureLogitsWarper
It addresses huggingface#24783
nablabits committed Jul 19, 2023
1 parent ee4250a commit 3dcc5a6
Showing 1 changed file with 64 additions and 2 deletions.
66 changes: 64 additions & 2 deletions src/transformers/generation/logits_process.py
@@ -172,11 +172,73 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:

class TemperatureLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] for temperature (exponential scaling output probability distribution).
[`LogitsWarper`] for temperature (exponential scaling output probability distribution), which effectively means that
we can control the randomness of the predicted tokens.
<Tip>
Make sure that `do_sample=True` is included in the `generate` arguments, otherwise the temperature value won't have
any effect. As noted in `generation_strategies#Assisted-Decoding`, this parameter, used together with assisted
decoding, helps to improve latency.
</Tip>
Args:
temperature (`float`):
The value used to module the logits distribution.
The value used to modulate the logits distribution. A value of `1.0` leaves the output unchanged with respect
to calling `generate` without temperature.
Examples:
```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> def get_fresh_inference(t, tokens=4, num_beams=1):
...     torch.manual_seed(0)
...     checkpoint = "gpt2"
...     tokenizer = AutoTokenizer.from_pretrained(checkpoint)
...     model = AutoModelForCausalLM.from_pretrained(checkpoint)
...     model.config.pad_token_id = model.config.eos_token_id
...     model.generation_config.pad_token_id = model.config.eos_token_id
...     prompt = "the quick brown fox jumps"
...     input_ids = tokenizer.encode(prompt, return_tensors="pt")
...     outputs = model.generate(
...         input_ids=input_ids, max_new_tokens=tokens, temperature=t, do_sample=True, num_beams=num_beams
...     )
...     print(t, tokenizer.decode(outputs[0], skip_special_tokens=True))
>>> t_range = (0.5, 1.0, 2.0, 7.0, 20.0)
>>> # With a low `max_new_tokens`, the sampler runs out of variety quickly across temperatures.
>>> for t in t_range:
...     get_fresh_inference(t)
0.5 the quick brown fox jumps over the fence and
1.0 the quick brown fox jumps around to his knees
2.0 the quick brown fox jumps around to his knees
7.0 the quick brown fox jumps around to his knees
20.0 the quick brown fox jumps around to his knees

>>> # Passing `num_beams` mitigates this by running a beam search, which increases the variance.
>>> for t in t_range:
...     get_fresh_inference(t, num_beams=4)
0.5 the quick brown fox jumps out of the bushes
1.0 the quick brown fox jumps out of the bushes
2.0 the quick brown fox jumps up and falls over
7.0 the quick brown fox jumps over to them to
20.0 the quick brown fox jumps back by himself.)

>>> # Requesting more tokens also yields different outputs per temperature, even without `num_beams`.
>>> for t in t_range:
...     get_fresh_inference(t, tokens=10)
0.5 the quick brown fox jumps over the fence and runs up the fence. The
1.0 the quick brown fox jumps around to his knees. He stares around while waiting
2.0 the quick brown fox jumps around to his knees. His big big paws catch
7.0 the quick brown fox jumps around to his knees at one moment but after waiting
20.0 the quick brown fox jumps around to his knees at one moment's reminder."
```
"""

def __init__(self, temperature: float):
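For readers skimming the diff, here is a minimal sketch (not part of the commit) of what the warper described in the new docstring does in isolation: it rescales the next-token logits before sampling, so values below `1.0` sharpen the softmax distribution and values above `1.0` flatten it. The call signature matches the `__call__` shown in the hunk header; the toy `scores` tensor and the temperature values are made up for illustration, and calling the warper directly like this is only for demonstration (inside `generate` it is applied automatically).

```python
import torch

from transformers import TemperatureLogitsWarper

# Toy next-token logits for a 3-token vocabulary (illustrative values only).
scores = torch.tensor([[4.0, 2.0, 1.0]])

for t in (0.5, 1.0, 2.0):
    warper = TemperatureLogitsWarper(temperature=t)
    # This warper only rescales the scores; it does not inspect `input_ids`.
    warped = warper(input_ids=None, scores=scores)
    probs = torch.softmax(warped, dim=-1)
    print(f"temperature={t}: {probs.tolist()}")
```

With `temperature=1.0` the resulting probabilities are unchanged, which is the behaviour the updated argument description calls out; smaller values concentrate probability mass on the top tokens, larger values push the distribution towards uniform.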
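The tip about assisted decoding is not exercised by the doctest above, so here is a hedged sketch of how temperature and assisted decoding can be combined, assuming a `transformers` version that supports the `assistant_model` argument of `generate`; the choice of `gpt2-large` as the main model and `gpt2` as the assistant is purely illustrative.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
model = AutoModelForCausalLM.from_pretrained("gpt2-large")
assistant = AutoModelForCausalLM.from_pretrained("gpt2")  # smaller model that drafts candidate tokens

inputs = tokenizer("the quick brown fox jumps", return_tensors="pt")
outputs = model.generate(
    **inputs,
    assistant_model=assistant,  # enables assisted decoding
    do_sample=True,             # required for `temperature` to have any effect
    temperature=0.5,            # lower values tend to make the drafted tokens easier to accept
    max_new_tokens=20,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```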

0 comments on commit 3dcc5a6
