Self-speculation (Layer-Skip Llama) (huggingface#34240)
* 😅

* early exit (huggingface#34244)

* mvp

* docs and tests

* a few fixes

* no shared cache

* Apply suggestions from code review

Co-authored-by: Mostafa Elhoushi <m.elhoushi@ieee.org>

* docs

* make fix-copies

* cohere fix

* [test all]

* [test all] consistent model code copies

* [test all] make fix-copies :D

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Mostafa Elhoushi <m.elhoushi@ieee.org>

* Update src/transformers/generation/candidate_generator.py

* Update src/transformers/generation/configuration_utils.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* [test all] don't use a stand-alone attribute; fix test

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Joao Gante <joao@huggingface.co>
Co-authored-by: Mostafa Elhoushi <m.elhoushi@ieee.org>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
5 people authored and BernardZach committed Dec 5, 2024
1 parent ada9c87 commit 0a9894b
Showing 15 changed files with 178 additions and 51 deletions.
70 changes: 47 additions & 23 deletions docs/source/en/generation_strategies.md
@@ -416,16 +416,6 @@ Assisted decoding assumes the main and assistant models have the same tokenizer,
Currently, only greedy search and sampling are supported with assisted decoding, and assisted decoding doesn't support batched inputs.
To learn more about assisted decoding, check [this blog post](https://huggingface.co/blog/assisted-generation).

#### Universal Assisted Decoding

Universal Assisted Decoding (UAD) adds support for main and assistant models with different tokenizers.
To use it, simply pass the tokenizers using the `tokenizer` and `assistant_tokenizer` arguments (see below).
Internally, the main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the assistant encoding, which are
in turn re-encoded into main model candidate tokens. Validation then proceeds as explained above.
The re-encoding steps involve decoding token ids into text and then encoding the text using a different tokenizer.
Since re-encoding the tokens may result in tokenization discrepancies, UAD finds the longest common subsequence between the source and target encodings,
to ensure the new tokens include the correct prompt suffix.

To enable assisted decoding, set the `assistant_model` argument with a model.

```python
@@ -445,7 +435,36 @@ To enable assisted decoding, set the `assistant_model` argument with a model.
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```

If the main and assistant models have different tokenizers, use Universal Assisted Decoding.
When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
just as in multinomial sampling. However, in assisted decoding, lowering the temperature may improve latency.

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
>>> set_seed(42) # For reproducibility

>>> prompt = "Alice and Bob"
>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"
>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped"

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
>>> outputs = model.generate(**inputs, assistant_model=assistant_model, do_sample=True, temperature=0.5)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Alice and Bob, a couple of friends of mine, who are both in the same office as']
```

#### Universal Assisted Decoding

Universal Assisted Decoding (UAD) adds support for main and assistant models with different tokenizers.
To use it, simply pass the tokenizers using the `tokenizer` and `assistant_tokenizer` arguments (see below).
Internally, the main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the assistant encoding, which are
in turn re-encoded into main model candidate tokens. Validation then proceeds as explained above.
The re-encoding steps involve decoding token ids into text and then encoding the text using a different tokenizer.
Since re-encoding the tokens may result in tokenization discrepancies, UAD finds the longest common subsequence between the source and target encodings,
to ensure the new tokens include the correct prompt suffix.

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -465,30 +484,35 @@ If the main and assistant models have different tokenizers, use Universal Assist
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```

When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency.
#### Prompt Lookup

Alternatively, you can set `prompt_lookup_num_tokens` to trigger n-gram-based assisted decoding instead of
model-based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).
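
For reference, a prompt-lookup call might look like the following sketch (not part of this commit's diff; the checkpoint is an assumption reused from the examples above, and the generated text is omitted because it depends on the model):

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> prompt = "The quick brown fox jumps over the lazy dog. The quick brown"
>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"  # assumed checkpoint

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> # candidate tokens are looked up from n-grams already present in the prompt, so no assistant model is needed
>>> outputs = model.generate(**inputs, prompt_lookup_num_tokens=3, max_new_tokens=20)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)  # output omitted
```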

#### Self-Speculative Decoding

An LLM can be trained to also apply its language modeling head to earlier hidden states, effectively
skipping layers to yield a lower-quality output -- a technique called early exiting.
We use the lower-quality early-exit output as the assistant output, and apply self-speculation with the remaining layers to fix it. The final generation of this self-speculative approach is the same (or has the same distribution) as the original model's generation.
If the model you're using was trained to do early exit, you can pass
`assistant_early_exit` (an integer) to `generate`. In this case, the assistant model will be the same model but exiting early, hence the
"self-speculative" name. Because the assistant model is a portion of the target model, caches and weights can be shared, which results in lower memory requirements. As in other assisted generation methods, the final generated result has the same quality as if no assistant had been used.

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
>>> set_seed(42) # For reproducibility
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> prompt = "Alice and Bob"
>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"
>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped"
>>> checkpoint = "facebook/layerskip-llama3.2-1B"

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
>>> outputs = model.generate(**inputs, assistant_model=assistant_model, do_sample=True, temperature=0.5)
>>> outputs = model.generate(**inputs, assistant_early_exit=4, do_sample=False, max_new_tokens=20)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Alice and Bob, a couple of friends of mine, who are both in the same office as']
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```

Alternatively, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed
to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).

### DoLa Decoding

**D**ecoding by C**o**ntrasting **La**yers (DoLa) is a contrastive decoding strategy to improve the factuality and reduce the
29 changes: 16 additions & 13 deletions src/transformers/cache_utils.py
@@ -433,19 +433,22 @@ def update(
self._seen_tokens += key_states.shape[-2]

# Update the cache
if len(self.key_cache) <= layer_idx:
# There may be skipped layers, fill them with empty lists
for _ in range(len(self.key_cache), layer_idx):
self.key_cache.append([])
self.value_cache.append([])
self.key_cache.append(key_states)
self.value_cache.append(value_states)
elif len(self.key_cache[layer_idx]) == 0: # fills previously skipped layers; checking for tensor causes errors
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
if key_states is not None:
if len(self.key_cache) <= layer_idx:
# There may be skipped layers, fill them with empty lists
for _ in range(len(self.key_cache), layer_idx):
self.key_cache.append([])
self.value_cache.append([])
self.key_cache.append(key_states)
self.value_cache.append(value_states)
elif (
len(self.key_cache[layer_idx]) == 0
): # fills previously skipped layers; checking for tensor causes errors
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

return self.key_cache[layer_idx], self.value_cache[layer_idx]
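
For intuition, here is a hedged sketch (not from the diff) of the new skipped-layer handling, assuming this `update` belongs to `DynamicCache`: layers skipped during an early-exit pass are held as empty-list placeholders and filled the first time a full-depth pass reaches them.

```python
import torch
from transformers import DynamicCache

cache = DynamicCache()
k = v = torch.zeros(1, 2, 3, 8)   # (batch, num_heads, seq_len, head_dim)

cache.update(k, v, layer_idx=0)   # normal update for layer 0
cache.update(k, v, layer_idx=2)   # layer 1 was skipped -> an empty-list placeholder is inserted for it
print(cache.key_cache[1])         # []

cache.update(k, v, layer_idx=1)   # a later full-depth pass fills the skipped slot
print(cache.key_cache[1].shape)   # torch.Size([1, 2, 3, 8])
```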

8 changes: 7 additions & 1 deletion src/transformers/generation/__init__.py
@@ -49,6 +49,7 @@
_import_structure["candidate_generator"] = [
"AssistedCandidateGenerator",
"CandidateGenerator",
"EarlyExitCandidateGenerator",
"PromptLookupCandidateGenerator",
]
_import_structure["logits_process"] = [
@@ -206,7 +207,12 @@
else:
from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .candidate_generator import AssistedCandidateGenerator, CandidateGenerator, PromptLookupCandidateGenerator
from .candidate_generator import (
AssistedCandidateGenerator,
CandidateGenerator,
EarlyExitCandidateGenerator,
PromptLookupCandidateGenerator,
)
from .logits_process import (
AlternatingCodebooksLogitsProcessor,
ClassifierFreeGuidanceLogitsProcessor,
56 changes: 56 additions & 0 deletions src/transformers/generation/candidate_generator.py
@@ -670,6 +670,62 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F
return


class EarlyExitCandidateGenerator(AssistedCandidateGenerator):
"""
`CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
candidates through the use of **the model itself**, exiting early. Can only be used with models that support early
exit, e.g., `facebook/layerskip-llama3.2-1B`.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
assistant_model (`PreTrainedModel`):
The original model. This model must support early exit (i.e. is trained to compute logits in earlier
layers).
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call.
logits_processor (`LogitsProcessorList`):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
model_kwargs (`Dict`):
The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
model as well.
inputs_tensor (`torch.Tensor`, *optional*):
The model input tensor. In encoder-decoder models, this is the encoder input.
"""

def __init__(
self,
input_ids: torch.LongTensor,
assistant_model: "PreTrainedModel",
generation_config: "GenerationConfig",
model_kwargs: Dict,
inputs_tensor: Optional[torch.Tensor] = None,
logits_processor: "LogitsProcessorList" = None,
):
super().__init__(
input_ids=input_ids,
assistant_model=assistant_model,
generation_config=generation_config,
model_kwargs=model_kwargs,
inputs_tensor=inputs_tensor,
logits_processor=logits_processor,
)
# We have to move early exit out of the generation config, otherwise the assistant will also call `generate`
# with early exit
self.assistant_early_exit = self.generation_config.assistant_early_exit
self.generation_config.assistant_early_exit = None

def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
# Temporarily sets the number of hidden layers to the early exit value
base_model = getattr(self.assistant_model, self.assistant_model.base_model_prefix)
original_num_hidden_layers = base_model.config.num_hidden_layers
base_model.config.num_hidden_layers = self.assistant_early_exit
candidate_ids, candidate_logits = super().get_candidates(input_ids)
base_model.config.num_hidden_layers = original_num_hidden_layers
return candidate_ids, candidate_logits


def _crop_past_key_values(model, past_key_values, max_length):
"""Crops the past key values up to a certain maximum length."""
new_past = []
16 changes: 11 additions & 5 deletions src/transformers/generation/configuration_utils.py
@@ -353,10 +353,13 @@ class GenerationConfig(PushToHubMixin):
than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_
(defined by `num_assistant_tokens`) is not yet reached. It is an unsupervised version of the dynamic speculation lookahead
from Dynamic Speculation Lookahead Accelerates Speculative Decoding of Large Language Models <https://arxiv.org/abs/2405.04304>.
prompt_lookup_num_tokens (`int`, *optional*, default to `None`):
prompt_lookup_num_tokens (`int`, *optional*):
The number of tokens to be output as candidate tokens.
max_matching_ngram_size (`int`, *optional*, default to `None`):
max_matching_ngram_size (`int`, *optional*):
The maximum ngram size to be considered for matching in the prompt. Defaults to 2 if not provided.
assistant_early_exit(`int`, *optional*):
If set to a positive integer, early exit of the model will be used as an assistant. Can only be used with
models that support early exit (i.e. models where logits from intermediate layers can be interpreted by the LM head).
> Wild card
@@ -454,10 +457,9 @@ def __init__(self, **kwargs):
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 20)
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "constant")
self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", 0.4)

# Prompt lookup decoding
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None)
self.assistant_early_exit = kwargs.pop("assistant_early_exit", None)

# Wild card
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
@@ -534,7 +536,11 @@ def get_generation_mode(self, assistant_model: Optional["PreTrainedModel"] = Non
generation_mode = GenerationMode.BEAM_SEARCH

# Assisted generation may extend some generation modes
if assistant_model is not None or self.prompt_lookup_num_tokens is not None:
if (
assistant_model is not None
or self.prompt_lookup_num_tokens is not None
or self.assistant_early_exit is not None
):
if generation_mode in ("greedy_search", "sample"):
generation_mode = GenerationMode.ASSISTED_GENERATION
else:
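
A small sketch of the effect on mode selection (the classes and the `assistant_early_exit` argument come from this file's diff; the specific values are assumptions):

```python
from transformers.generation.configuration_utils import GenerationConfig, GenerationMode

config = GenerationConfig(assistant_early_exit=4, do_sample=False)

# No separate assistant model is passed, but the mode is still promoted to assisted
# generation because `assistant_early_exit` is set (the branch added above).
assert config.get_generation_mode(assistant_model=None) == GenerationMode.ASSISTED_GENERATION
```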
12 changes: 11 additions & 1 deletion src/transformers/generation/utils.py
@@ -54,6 +54,7 @@
AssistedCandidateGenerator,
AssistedCandidateGeneratorDifferentTokenizers,
CandidateGenerator,
EarlyExitCandidateGenerator,
PromptLookupCandidateGenerator,
_crop_past_key_values,
_prepare_attention_mask,
@@ -822,7 +823,16 @@ def _get_candidate_generator(
"""
different_tokenizers = all(v is not None for v in (assistant_model, target_tokenizer, assistant_tokenizer))

if generation_config.prompt_lookup_num_tokens is not None:
if generation_config.assistant_early_exit is not None:
candidate_generator = EarlyExitCandidateGenerator(
input_ids=input_ids,
assistant_model=self,
generation_config=generation_config,
model_kwargs=model_kwargs,
inputs_tensor=inputs_tensor,
logits_processor=logits_processor,
)
elif generation_config.prompt_lookup_num_tokens is not None:
candidate_generator = PromptLookupCandidateGenerator(
eos_token_id=generation_config._eos_token_tensor,
num_output_tokens=generation_config.prompt_lookup_num_tokens,
2 changes: 1 addition & 1 deletion src/transformers/models/cohere/modeling_cohere.py
@@ -890,7 +890,7 @@ def forward(
all_self_attns = () if output_attentions else None
next_decoder_cache = None

for decoder_layer in self.layers:
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)

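The same one-line change is repeated for the remaining decoder models below. A hypothetical toy module (not from the diff) shows why slicing the layer list by `config.num_hidden_layers` enables early exit — the candidate generator above temporarily lowers that value for the draft pass and then restores it:

```python
import torch
from torch import nn

class ToyConfig:
    num_hidden_layers = 8  # stands in for the model's config

class ToyDecoder(nn.Module):
    def __init__(self, config, hidden=16):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(config.num_hidden_layers)])

    def forward(self, x):
        # Mirrors the change above: only the first `config.num_hidden_layers` layers run.
        for layer in self.layers[: self.config.num_hidden_layers]:
            x = layer(x)
        return x

config = ToyConfig()
model = ToyDecoder(config)
x = torch.randn(1, 16)

full = model(x)                # all 8 layers
config.num_hidden_layers = 4   # "early exit": the draft pass uses only the first 4 layers
draft = model(x)
config.num_hidden_layers = 8   # restore the original depth for verification
```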
2 changes: 1 addition & 1 deletion src/transformers/models/gemma/modeling_gemma.py
@@ -808,7 +808,7 @@ def forward(
all_self_attns = () if output_attentions else None
next_decoder_cache = None

for decoder_layer in self.layers:
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)

2 changes: 1 addition & 1 deletion src/transformers/models/gemma/modular_gemma.py
@@ -886,7 +886,7 @@ def forward(
all_self_attns = () if output_attentions else None
next_decoder_cache = None

for decoder_layer in self.layers:
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)

2 changes: 1 addition & 1 deletion src/transformers/models/gemma2/modeling_gemma2.py
@@ -823,7 +823,7 @@ def forward(
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

for decoder_layer in self.layers:
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)

2 changes: 1 addition & 1 deletion src/transformers/models/gemma2/modular_gemma2.py
@@ -653,7 +653,7 @@ def forward(
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

for decoder_layer in self.layers:
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)

2 changes: 1 addition & 1 deletion src/transformers/models/glm/modeling_glm.py
@@ -789,7 +789,7 @@ def forward(
all_self_attns = () if output_attentions else None
next_decoder_cache = None

for decoder_layer in self.layers:
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)

2 changes: 1 addition & 1 deletion src/transformers/models/llama/modeling_llama.py
@@ -893,7 +893,7 @@ def forward(
all_self_attns = () if output_attentions else None
next_decoder_cache = None

for decoder_layer in self.layers:
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)

2 changes: 1 addition & 1 deletion src/transformers/models/olmoe/modeling_olmoe.py
@@ -995,7 +995,7 @@ def forward(
all_router_logits = () if output_router_logits else None
next_decoder_cache = None

for decoder_layer in self.layers:
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)

