
nabarup add comments in HammingDiversityLogitsProcessor. Issue huggin…
Nabarup-Maity committed Aug 18, 2023
1 parent 427adc8 commit eede150
Showing 1 changed file with 46 additions and 7 deletions.
53 changes: 46 additions & 7 deletions src/transformers/generation/logits_process.py
@@ -1086,19 +1086,58 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
class HammingDiversityLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] that enforces diverse beam search. Note that this logits processor is only effective for
[`PreTrainedModel.group_beam_search`]. Hamming diversity is computed by counting how often each token has already
been chosen by earlier beam groups at the current time step; selecting a token in the current group that a previous
group used at the same time step is penalized. See [Diverse Beam Search: Decoding Diverse Solutions from Neural
Sequence Models](https://arxiv.org/pdf/1610.02424.pdf) for more details.
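As a minimal sketch of this mechanism (simplified from the actual implementation; the function and argument names here are illustrative, not part of the library):
```python
import torch


def hamming_diversity_penalty(
    scores: torch.FloatTensor, previous_group_tokens: torch.LongTensor, diversity_penalty: float
) -> torch.FloatTensor:
    # Count how often each vocabulary token was chosen by earlier groups at this time step...
    vocab_size = scores.shape[-1]
    token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size)
    # ...and subtract the penalty, scaled by that frequency, from the current group's scores.
    return scores - diversity_penalty * token_frequency
```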
Args:
diversity_penalty (`float`):
This value is subtracted from a beam's score if it generates a token same as any beam from other group at a
particular time. Note that `diversity_penalty` is only effective if `group beam search` is enabled.
particular time. Note that `diversity_penalty` is only effective if `group beam search` is enabled. diversity_penalty parameter during beam search with Hamming Diversity can influence the diversity and creativity of the generated text.
num_beams (`int`):
Number of beams used for group beam search. Increasing `num_beams` generally increases diversity by exploring
more candidate sequences in parallel. `num_beams` must be divisible by `num_beam_groups` for group beam search.
See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
num_beam_groups (`int`):
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
Increasing `num_beam_groups` encourages diversity between groups, since each group is penalized for reusing
tokens that earlier groups selected at the same time step. See [this paper](https://arxiv.org/pdf/1610.02424.pdf)
for more details.
Examples:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> # Each example below limits the generated continuation to six new tokens.
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> inputs = tokenizer(["Artificial intelligence is transforming industries"], return_tensors="pt")

>>> # 1. A moderate diversity penalty strikes a balance between generating varied outputs and maintaining
>>> # coherence, so the outputs mix diverse and relevant content.
>>> beam_ids = model.generate(inputs["input_ids"], num_beams=2, num_beam_groups=2, diversity_penalty=0.5, max_new_tokens=6)
>>> print(tokenizer.batch_decode(beam_ids, skip_special_tokens=True)[0])
Artificial intelligence is transforming industries, and it's not just
>>> # 2. Increasing the number of beams lets the model explore more alternatives in parallel. Since the number
>>> # of beam groups is unchanged, you might still observe repetition in the generated sequences.
>>> beam_ids = model.generate(inputs["input_ids"], num_beams=4, num_beam_groups=2, diversity_penalty=0.5, max_new_tokens=6)
>>> print(tokenizer.batch_decode(beam_ids, skip_special_tokens=True)[0])
Artificial intelligence is transforming industries, and it's not just
>>> # 3. With even more beams, there is a higher likelihood of exploring diverse possibilities.
>>> beam_ids = model.generate(inputs["input_ids"], num_beams=8, num_beam_groups=2, diversity_penalty=0.5, max_new_tokens=6)
>>> print(tokenizer.batch_decode(beam_ids, skip_special_tokens=True)[0])
Artificial intelligence is transforming industries and the lives of millions of
>>> # 4. With more beam groups, the model can explore different sequences within and across groups, increasing
>>> # diversity. The relatively low diversity penalty may still lead to some repetition.
>>> beam_ids = model.generate(inputs["input_ids"], num_beams=4, num_beam_groups=4, diversity_penalty=0.5, max_new_tokens=6)
>>> print(tokenizer.batch_decode(beam_ids, skip_special_tokens=True)[0])
Artificial intelligence is transforming industries, and it's not just
>>> # 5. A high diversity penalty promotes diversity at the expense of coherence; the outputs may become more
>>> # fragmented and less coherent due to the aggressive avoidance of repetition.
>>> beam_ids = model.generate(inputs["input_ids"], num_beams=4, num_beam_groups=4, diversity_penalty=10.9, max_new_tokens=6)
>>> print(tokenizer.batch_decode(beam_ids, skip_special_tokens=True)[0])
Artificial intelligence is transforming industries like healthcare, education, and
```

Overall, these examples show how different combinations of `num_beams`, `num_beam_groups`, and `diversity_penalty` influence the diversity and coherence of the generated outputs. Note that the outputs depend heavily on the chosen pretrained model.
"""

def __init__(self, diversity_penalty: float, num_beams: int, num_beam_groups: int):
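For reference, a minimal sketch of constructing the processor directly; in practice `generate` adds it automatically when `diversity_penalty > 0` is passed together with group beam search:
```python
from transformers import HammingDiversityLogitsProcessor

# Mirrors the __init__ signature shown in the diff above; the values are illustrative.
processor = HammingDiversityLogitsProcessor(diversity_penalty=0.5, num_beams=4, num_beam_groups=2)
```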
