Commit 2e0f03a

Merge branch 'main' into hf_contrib

m-momeni authored Dec 12, 2024
2 parents e7f6749 + a691ccb
Showing 8 changed files with 167 additions and 110 deletions.
196 changes: 96 additions & 100 deletions src/transformers/generation/candidate_generator.py
@@ -208,56 +208,15 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
vocabulary_size)` containing the logits associated to each candidate.
"""
input_ids = input_ids.to(self.assistant_model.device)

# Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
new_cur_len = input_ids.shape[-1]
max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
# Calculate new tokens to generate
min_new_tokens, max_new_tokens = self._calculate_new_tokens(input_ids)
if max_new_tokens == 0:
return input_ids, None

# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
# (which implicitly contains the number of accepted candidates from the previous round)
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
if has_past_key_values:
new_cache_size = new_cur_len - 1
self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
) # the assistant does not have the token after the last match, hence the -1

self.assistant_kwargs = _prepare_attention_mask(
self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
)
self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)

# 2. Forecast next N tokens using the assistant model.
assistant_generation_kwargs = {
self.input_ids_key: input_ids,
"min_new_tokens": min_new_tokens,
"max_new_tokens": max_new_tokens,
"generation_config": self.generation_config,
"logits_processor": self.logits_processor,
}

assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)

# 3. Update variables for the next round of candidate generation
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values

if (
is_sklearn_available()
and self.assistant_model.generation_config.assistant_confidence_threshold
and type(self) is AssistedCandidateGenerator
):
scores_tensor = torch.cat(assistant_output.scores, dim=0)
scores_softmax = torch.softmax(scores_tensor, dim=-1)
ids = assistant_output.sequences[-1, -len(assistant_output.scores) :]
p = scores_softmax[range(len(ids)), ids]
self.probs.extend(p.tolist())

# 4. Prepare variables for output
candidate_logits = torch.stack(assistant_output.scores, dim=1)
candidate_ids = assistant_output.sequences
# Update past key values and masks
self._update_past_and_masks(input_ids)
# Generate candidates
generation_args = self._prepare_generation_args(input_ids, min_new_tokens, max_new_tokens)
candidate_ids, candidate_logits = self._generate_candidates(generation_args)
return candidate_ids, candidate_logits

def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
@@ -318,6 +277,55 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):

self.assistant_model.generation_config.assistant_confidence_threshold = best_threshold

def _calculate_new_tokens(self, input_ids: torch.LongTensor) -> Tuple[int, int]:
"""Calculate the minimum and maximum number of new tokens to generate."""
new_cur_len = input_ids.shape[-1]
max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
return min_new_tokens, max_new_tokens
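# A worked example of the arithmetic above, with hypothetical numbers: if
# generation_config.max_length = 50, the current input length is 40, and
# num_assistant_tokens = 20, then max_new_tokens = min(20, 50 - 40 - 1) = 9;
# one slot stays reserved for the extra token the target model emits while verifying.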

def _update_past_and_masks(self, input_ids: torch.LongTensor, remove_from_pkv: int = 0) -> bool:
"""Update past key values and attention masks for subsequent generation rounds."""
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
if has_past_key_values:
new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
)  # the assistant does not have the token after the last match, hence the -1
self.assistant_kwargs = _prepare_attention_mask(
self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
)
self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1])
return has_past_key_values

def _prepare_generation_args(self, input_ids: torch.LongTensor, min_new_tokens: int, max_new_tokens: int) -> Dict:
"""Prepare arguments for the generation call."""
return {
self.input_ids_key: input_ids,
"min_new_tokens": min_new_tokens,
"max_new_tokens": max_new_tokens,
"generation_config": self.generation_config,
"logits_processor": self.logits_processor,
}

def _generate_candidates(self, generation_args: Dict) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""Generate candidate sequences using the assistant model."""
assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
if (
is_sklearn_available()
and self.assistant_model.generation_config.assistant_confidence_threshold
and type(self) is AssistedCandidateGenerator
):
scores_tensor = torch.cat(assistant_output.scores, dim=0)
scores_softmax = torch.softmax(scores_tensor, dim=-1)
ids = assistant_output.sequences[-1, -len(assistant_output.scores) :]
p = scores_softmax[range(len(ids)), ids]
self.probs.extend(p.tolist())
candidate_logits = torch.stack(assistant_output.scores, dim=1)
candidate_ids = assistant_output.sequences
return candidate_ids, candidate_logits
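
For orientation, these helpers are internal; users reach AssistedCandidateGenerator through generate. A minimal sketch of assisted (speculative) generation, assuming the two checkpoint names below are placeholders for a compatible target/draft pair:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("target-checkpoint")  # hypothetical checkpoint name
model = AutoModelForCausalLM.from_pretrained("target-checkpoint")
assistant = AutoModelForCausalLM.from_pretrained("draft-checkpoint")  # hypothetical smaller draft model

inputs = tokenizer("The capital of France is", return_tensors="pt")
# Passing assistant_model routes generation through AssistedCandidateGenerator.
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))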


class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
"""
@@ -367,6 +375,7 @@ def __init__(

self.target_tokenizer = target_tokenizer
self.assistant_tokenizer = assistant_tokenizer
self.prev_target_ids_len: Optional[int] = None
self.prev_assistant_ids = None
self.target_lookbehind = assistant_model.generation_config.target_lookbehind
self.assistant_lookbehind = assistant_model.generation_config.assistant_lookbehind
@@ -497,27 +506,50 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
return input_ids, None

input_ids = input_ids.to(self.assistant_model.device)
remove_from_pkv = 0

assistant_input_ids, remove_from_pkv = self._prepare_assistant_input_ids(input_ids)
self.prev_assistant_ids = assistant_input_ids

min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - assistant_input_ids.shape[-1]), 0)

self._update_past_and_masks(assistant_input_ids, remove_from_pkv)
generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)
self.assistant_kwargs.pop("attention_mask", None)

assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
new_target_ids = self._process_assistant_outputs(input_ids, assistant_output.sequences, assistant_input_ids)

# Update state
self.prev_target_ids_len = input_ids.shape[1]
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
self.prev_assistant_ids = assistant_output.sequences

if self.prev_target_ids_len >= new_target_ids.shape[1]:
return input_ids, None

return new_target_ids, None

def _prepare_assistant_input_ids(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, int]:
"""Converts target input IDs to assistant input IDs, handling discrepancies."""
convert_kwargs = {
"source_tokenizer": self.target_tokenizer,
"destination_tokenizer": self.assistant_tokenizer,
}
remove_from_pkv = 0

# Since re-encoding the tokens may result in tokenization discrepancies, we use 2 look behind values
# (one for each conversion) which mark where to start looking for the overlap between the
# source and target encodings, to ensure the new tokens include the correct prompt suffix.
if self.prev_assistant_ids is not None and input_ids.shape[1] > self.target_lookbehind:
if self.prev_assistant_ids is not None and self.prev_target_ids_len > self.target_lookbehind:
# input_ids contains all target prompt input ids and some new target input ids
start_index_in_target_window = input_ids.shape[1] - self.target_lookbehind
start_index_in_target_window = self.prev_target_ids_len - self.target_lookbehind

new_assistant_ids = self.convert_source_tokens_to_target_tokens(
input_ids[:, start_index_in_target_window:], **convert_kwargs
)
prompt_use_length = new_assistant_ids.shape[1]
prompt_use = self.prev_assistant_ids[:, -prompt_use_length:]

discrepancy_length, new_tokens_only, discrepancy_only = (
AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt_use, new_assistant_ids)
discrepancy_length, new_tokens_only, discrepancy_only = self._get_tokens_diag(
prompt_use, new_assistant_ids
)
assistant_input_ids = self.prev_assistant_ids

@@ -538,58 +570,29 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
else:
# edge case: in case of no intersection between prompt and new_assistant_ids
assistant_input_ids = torch.cat([assistant_input_ids, new_assistant_ids], dim=-1)

else:
assistant_input_ids = self.convert_source_tokens_to_target_tokens(input_ids, **convert_kwargs)
self.prev_target_ids_len = input_ids.shape[1]

self.prev_assistant_ids = assistant_input_ids
new_cur_len = assistant_input_ids.shape[-1]
min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)

# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
# (which implicitly contains the number of accepted candidates from the previous round)
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
if has_past_key_values:
new_cache_size = new_cur_len - 1 - remove_from_pkv
self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
) # the assistant does not have the token after the last match, hence the -1

self.assistant_kwargs = _prepare_attention_mask(
self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
)
self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)

# 2. Forecast next N tokens using the assistant model.
assistant_generation_kwargs = {
self.input_ids_key: assistant_input_ids,
"min_new_tokens": min_new_tokens,
"max_new_tokens": max_new_tokens,
"generation_config": self.generation_config,
"logits_processor": self.logits_processor,
}

self.assistant_kwargs.pop("attention_mask", None)

assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)
return assistant_input_ids, remove_from_pkv

def _process_assistant_outputs(
self, input_ids: torch.LongTensor, assistant_sequences: torch.LongTensor, assistant_input_ids: torch.LongTensor
) -> torch.LongTensor:
"""Processes assistant outputs to obtain target input IDs."""
num_prev_assistant = self.prev_assistant_ids.shape[1]
start_assistant_look_index = num_prev_assistant - self.assistant_lookbehind
if start_assistant_look_index < 0:
start_assistant_look_index = 0

new_target_ids_from_window = self.convert_source_tokens_to_target_tokens(
assistant_output.sequences[:, start_assistant_look_index:],
assistant_sequences[:, start_assistant_look_index:],
source_tokenizer=self.assistant_tokenizer,
destination_tokenizer=self.target_tokenizer,
)
target_prompt_use_length = new_target_ids_from_window.shape[1]

target_prompt_use = input_ids[:, -target_prompt_use_length:]

_, target_new_tokens_only, _ = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
target_prompt_use, new_target_ids_from_window
)
_, target_new_tokens_only, _ = self._get_tokens_diag(target_prompt_use, new_target_ids_from_window)

new_target_ids = input_ids

Expand All @@ -603,14 +606,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
if hasattr(self.generation_config, "max_length"):
new_target_ids = new_target_ids[:, : self.generation_config.max_length]

# 3. Update variables for the next round of candidate generation
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values

# 4. Prepare variables for output
if input_ids.shape[1] >= new_target_ids.shape[1]:
return input_ids, None

return new_target_ids, None
return new_target_ids


class PromptLookupCandidateGenerator(CandidateGenerator):
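AssistedCandidateGeneratorDifferentTokenizers backs assisted generation when the draft and target models use different tokenizers. A hedged sketch of the call that exercises it (checkpoint names are placeholders; passing both tokenizers is what selects this class):

from transformers import AutoModelForCausalLM, AutoTokenizer

target_tokenizer = AutoTokenizer.from_pretrained("target-checkpoint")  # hypothetical
assistant_tokenizer = AutoTokenizer.from_pretrained("draft-checkpoint")  # hypothetical, different vocab
model = AutoModelForCausalLM.from_pretrained("target-checkpoint")
assistant = AutoModelForCausalLM.from_pretrained("draft-checkpoint")

inputs = target_tokenizer("The capital of France is", return_tensors="pt")
outputs = model.generate(
    **inputs,
    assistant_model=assistant,
    tokenizer=target_tokenizer,
    assistant_tokenizer=assistant_tokenizer,
    max_new_tokens=20,
)
print(target_tokenizer.decode(outputs[0], skip_special_tokens=True))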
6 changes: 3 additions & 3 deletions src/transformers/modeling_utils.py
@@ -29,7 +29,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial, wraps
from multiprocessing import Process
from threading import Thread
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
from zipfile import is_zipfile

@@ -3825,11 +3825,11 @@ def from_pretrained(
**has_file_kwargs,
}
if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs):
Process(
Thread(
target=auto_conversion,
args=(pretrained_model_name_or_path,),
kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs},
name="Process-auto_conversion",
name="Thread-auto_conversion",
).start()
else:
# Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file.
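The switch from Process to Thread makes the background safetensors auto-conversion share the parent interpreter (no pickling of arguments, same authentication state) while staying fire-and-forget. A minimal sketch of the pattern, with auto_conversion_stub as a hypothetical stand-in for the real auto_conversion helper:

from threading import Thread

def auto_conversion_stub(model_id: str, ignore_errors_during_conversion: bool = False) -> None:
    # Hypothetical stand-in: the real helper opens a safetensors conversion PR.
    print(f"converting {model_id} (ignore errors: {ignore_errors_during_conversion})")

Thread(
    target=auto_conversion_stub,
    args=("some-org/some-model",),
    kwargs={"ignore_errors_during_conversion": True},
    name="Thread-auto_conversion",
).start()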
2 changes: 1 addition & 1 deletion src/transformers/pipelines/audio_utils.py
@@ -68,7 +68,7 @@ def ffmpeg_microphone(
The name of the format of the audio samples to be returned by ffmpeg. The standard is `f32le`, `s16le`
could also be used.
ffmpeg_input_device (`str`, *optional*):
The indentifier of the input device to be used by ffmpeg (i.e. ffmpeg's '-i' argument). If unset,
The identifier of the input device to be used by ffmpeg (i.e. ffmpeg's '-i' argument). If unset,
the default input device will be used. See `https://www.ffmpeg.org/ffmpeg-devices.html#Input-Devices`
for how to specify and list input devices.
ffmpeg_additional_args (`list[str]`, *optional*):
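For context, a hedged usage sketch of ffmpeg_microphone (requires a working ffmpeg install and a microphone; the loop body is illustrative):

from transformers.pipelines.audio_utils import ffmpeg_microphone

# Stream one-second chunks of 16 kHz audio from the default input device.
for chunk in ffmpeg_microphone(sampling_rate=16000, chunk_length_s=1.0):
    print(len(chunk))  # chunk is raw f32le bytes by default
    break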
2 changes: 1 addition & 1 deletion src/transformers/safetensors_conversion.py
@@ -67,7 +67,7 @@ def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs):
# security breaches.
pr = previous_pr(api, model_id, pr_title, token=token)

if pr is None or (not private and pr.author != "SFConvertBot"):
if pr is None or (not private and pr.author != "SFconvertbot"):
spawn_conversion(token, private, model_id)
pr = previous_pr(api, model_id, pr_title, token=token)
else:
21 changes: 19 additions & 2 deletions src/transformers/testing_utils.py
@@ -28,6 +28,7 @@
import subprocess
import sys
import tempfile
import threading
import time
import unittest
from collections import defaultdict
@@ -2311,12 +2312,28 @@ class RequestCounter:

def __enter__(self):
self._counter = defaultdict(int)
self.patcher = patch.object(urllib3.connectionpool.log, "debug", wraps=urllib3.connectionpool.log.debug)
self._thread_id = threading.get_ident()
self._extra_info = []

def patched_with_thread_info(func):
def wrap(*args, **kwargs):
self._extra_info.append(threading.get_ident())
return func(*args, **kwargs)

return wrap

self.patcher = patch.object(
urllib3.connectionpool.log, "debug", side_effect=patched_with_thread_info(urllib3.connectionpool.log.debug)
)
self.mock = self.patcher.start()
return self

def __exit__(self, *args, **kwargs) -> None:
for call in self.mock.call_args_list:
assert len(self.mock.call_args_list) == len(self._extra_info)

for thread_id, call in zip(self._extra_info, self.mock.call_args_list):
if thread_id != self._thread_id:
continue
log = call.args[0] % call.args[1:]
for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"):
if method in log:
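The thread bookkeeping above exists because the safetensors auto-conversion now runs on a background Thread (see the modeling_utils change in this commit), so its requests could otherwise leak into the counts; only calls made from the thread that entered the context are tallied. Usage per the class's own docstring (exact counts depend on cache state):

from transformers import AutoTokenizer
from transformers.testing_utils import RequestCounter

with RequestCounter() as counter:
    _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
assert counter["GET"] == 0
assert counter["HEAD"] == 1
assert counter.total_calls == 1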
4 changes: 2 additions & 2 deletions src/transformers/tokenization_utils_base.py
@@ -28,7 +28,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from inspect import isfunction
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

import numpy as np
from packaging import version
@@ -1527,7 +1527,7 @@ def get_vocab(self) -> Dict[str, int]:
def apply_chat_template(
self,
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
tools: Optional[List[Dict]] = None,
tools: Optional[List[Union[Dict, Callable]]] = None,
documents: Optional[List[Dict[str, str]]] = None,
chat_template: Optional[str] = None,
add_generation_prompt: bool = False,
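With tools widened to accept callables, a plain Python function can be passed directly; its signature and docstring are converted to a JSON schema internally. A minimal sketch (whether tools are rendered depends on the model's chat template; tokenizer and messages are assumed to exist):

def get_current_temperature(location: str) -> float:
    """
    Gets the current temperature at a location.

    Args:
        location: The location to get the temperature for, e.g. "Paris, France".
    """
    return 22.0  # dummy value for illustration

prompt = tokenizer.apply_chat_template(
    messages,
    tools=[get_current_temperature],  # a Callable instead of a JSON-schema dict
    add_generation_prompt=True,
    tokenize=False,
)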
3 changes: 2 additions & 1 deletion src/transformers/utils/chat_template_utils.py
@@ -15,6 +15,7 @@
import inspect
import json
import re
import types
from contextlib import contextmanager
from datetime import datetime
from functools import lru_cache
@@ -97,7 +98,7 @@ def _parse_type_hint(hint: str) -> Dict:
"Couldn't parse this type hint, likely due to a custom class or object: ", hint
)

elif origin is Union:
elif origin is Union or (hasattr(types, "UnionType") and origin is types.UnionType):
# Recurse into each of the subtypes in the Union, except None, which is handled separately at the end
subtypes = [_parse_type_hint(t) for t in args if t is not type(None)]
if len(subtypes) == 1:
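The added branch covers PEP 604 unions, which have origin types.UnionType rather than typing.Union on Python 3.10+. A small illustration (requires Python 3.10+ for the | syntax):

import types
from typing import Optional, Union, get_origin

assert get_origin(Optional[int]) is Union  # typing-style union
assert get_origin(int | None) is types.UnionType  # PEP 604 union

def is_union_origin(origin) -> bool:
    # Mirrors the patched condition in _parse_type_hint.
    return origin is Union or (hasattr(types, "UnionType") and origin is types.UnionType)

assert is_union_origin(get_origin(Optional[int]))
assert is_union_origin(get_origin(int | None))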