Skip to content

Commit

Permalink
Uniformize model processors (huggingface#31368)
Browse files Browse the repository at this point in the history
* add initial design for uniform processors + align model

* add uniform processors for altclip + chinese_clip

* add uniform processors for blip + blip2

* fix mutable default 👀

* add configuration test

* handle structured kwargs w defaults + add test

* protect torch-specific test

* fix style

* fix

* rebase

* update processor to generic kwargs + test

* fix style

* add sensible kwargs merge

* update test

* fix assertEqual

* move kwargs merging to processing common

* rework kwargs for type hinting

* just get Unpack from extensions

* run-slow[align]

* handle kwargs passed as nested dict

* add from_pretrained test for nested kwargs handling

* [run-slow]align

* update documentation + imports

* update audio inputs

* protect audio types, silly

* try removing imports

* make things simpler

* simplerer

* move out kwargs test to common mixin

* [run-slow]align

* skip tests for old processors

* [run-slow]align, clip

* !$#@!! protect imports, darn it

* [run-slow]align, clip

* [run-slow]align, clip

* update common processor testing

* add altclip

* add chinese_clip

* add pad_size

* [run-slow]align, clip, chinese_clip, altclip

* remove duplicated tests

* fix

* add blip, blip2, bridgetower

Added tests for bridgetower which override common. Also modified common
tests to force center cropping if existing

* fix

* update doc

* improve documentation for default values

* add model_max_length testing

This parameter depends on tokenizers received.

* Raise if kwargs are specified in two places

* fix

* removed copied from

* match defaults

* force padding

* fix tokenizer test

* clean defaults

* move tests to common

* add missing import

* fix

* adapt bridgetower tests to shortest edge

* uniformize donut processor + tests

* add wav2vec2

* extend common testing to audio processors

* add testing + bert version

* propagate common kwargs to different modalities

* BC order of arguments

* check py version

* revert kwargs merging

* add draft overlap test

* update

* fix blip2 and wav2vec due to updates

* fix copies

* ensure overlapping kwargs do not disappear

* replace .pop by .get to handle duplicated kwargs

* fix copies

* fix missing import

* add clearly wav2vec2_bert to uniformized models

* fix copies

* increase number of features

* fix style

* [run-slow] blip, blip2, bridgetower, donut, wav2vec2, wav2vec2_bert

* [run-slow] blip, blip_2, bridgetower, donut, wav2vec2, wav2vec2_bert

* fix concatenation

* [run-slow] blip, blip_2, bridgetower, donut, wav2vec2, wav2vec2_bert

* Update tests/test_processing_common.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* 🧹

* address comments

* clean up + tests

* [run-slow] instructblip, blip, blip_2, bridgetower, donut, wav2vec2, wav2vec2_bert

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
  • Loading branch information
2 people authored and NielsRogge committed Oct 21, 2024
1 parent b6c3ce1 commit c21341f
Show file tree
Hide file tree
Showing 18 changed files with 772 additions and 276 deletions.
123 changes: 54 additions & 69 deletions src/transformers/models/blip/processing_blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,25 @@
from typing import List, Optional, Union

from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput


class BlipProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"add_special_tokens": True,
"padding": False,
"stride": 0,
"return_overflowing_tokens": False,
"return_special_tokens_mask": False,
"return_offsets_mapping": False,
"return_token_type_ids": False,
"return_length": False,
"verbose": True,
},
"images_kwargs": {},
}


class BlipProcessor(ProcessorMixin):
Expand Down Expand Up @@ -51,84 +67,53 @@ def __init__(self, image_processor, tokenizer, **kwargs):
def __call__(
self,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_token_type_ids: bool = False,
return_length: bool = False,
verbose: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None,
audio=None,
videos=None,
**kwargs: Unpack[BlipProcessorKwargs],
) -> BatchEncoding:
"""
This method uses [`BlipImageProcessor.__call__`] method to prepare image(s) for the model, and
[`BertTokenizerFast.__call__`] to prepare text for the model.
Please refer to the docstring of the above two methods for more information.
Args:
images (`ImageInput`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
"""
if images is None and text is None:
raise ValueError("You have to specify either images or text.")

# Get only text
if images is None:
self.current_processor = self.tokenizer
text_encoding = self.tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_token_type_ids=return_token_type_ids,
return_length=return_length,
verbose=verbose,
return_tensors=return_tensors,
**kwargs,
)
return text_encoding

# add pixel_values
encoding_image_processor = self.image_processor(images, return_tensors=return_tensors)
text_encoding = None

# add pixel_values encoding. If we also have text_encoding, update image encoding and return it.
# else, return the text encoding.
output_kwargs = self._merge_kwargs(
BlipProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if text is not None:
text_encoding = self.tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_token_type_ids=return_token_type_ids,
return_length=return_length,
verbose=verbose,
return_tensors=return_tensors,
**kwargs,
)
else:
text_encoding = None

if text_encoding is not None:
encoding_image_processor.update(text_encoding)

return encoding_image_processor
text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
if images is not None:
encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])

if text_encoding is not None:
encoding_image_processor.update(text_encoding)
return encoding_image_processor

return text_encoding

def batch_decode(self, *args, **kwargs):
"""
Expand Down
134 changes: 61 additions & 73 deletions src/transformers/models/blip_2/processing_blip_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,38 @@

from typing import List, Optional, Union

from ...image_processing_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import (
AddedToken,
BatchEncoding,
PaddingStrategy,
PreTokenizedInput,
TextInput,
TruncationStrategy,
)
from ...utils import TensorType, logging
from ...utils import logging


logger = logging.get_logger(__name__)


class Blip2ProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"add_special_tokens": True,
"padding": False,
"stride": 0,
"return_overflowing_tokens": False,
"return_special_tokens_mask": False,
"return_offsets_mapping": False,
"return_token_type_ids": False,
"return_length": False,
"verbose": True,
},
"images_kwargs": {},
}


class Blip2Processor(ProcessorMixin):
r"""
Constructs a BLIP-2 processor which wraps a BLIP image processor and an OPT/T5 tokenizer into a single processor.
Expand Down Expand Up @@ -67,83 +83,55 @@ def __init__(self, image_processor, tokenizer, num_query_tokens=None, **kwargs):
def __call__(
self,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_token_type_ids: bool = False,
return_length: bool = False,
verbose: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None,
audio=None,
videos=None,
**kwargs: Unpack[Blip2ProcessorKwargs],
) -> BatchEncoding:
"""
This method uses [`BlipImageProcessor.__call__`] method to prepare image(s) for the model, and
[`BertTokenizerFast.__call__`] to prepare text for the model.
Please refer to the docstring of the above two methods for more information.
Args:
images (`ImageInput`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
"""
if images is None and text is None:
raise ValueError("You have to specify either images or text.")

# Get only text
if images is None:
self.current_processor = self.tokenizer
text_encoding = self.tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_token_type_ids=return_token_type_ids,
return_length=return_length,
verbose=verbose,
return_tensors=return_tensors,
**kwargs,
)
return text_encoding

# add pixel_values
encoding_image_processor = self.image_processor(images, return_tensors=return_tensors)

output_kwargs = self._merge_kwargs(
Blip2ProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
# BC for explicit return_tensors
if "return_tensors" in output_kwargs["common_kwargs"]:
return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None)
else:
return_tensors = None
encoding = BatchFeature(tensor_type=return_tensors)
if text is not None:
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")

text_encoding = {}
_text_encoding = self.tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_token_type_ids=return_token_type_ids,
return_length=return_length,
verbose=verbose,
return_tensors=None, # hardcode "None" here for prepending image tokens
**kwargs,
)

return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
_text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None)
output_kwargs["text_kwargs"]["return_tensors"] = return_tensors

# if we know how many query tokens, expand text inside processor. We need this hacky manipulation
# because BLIP expects image tokens to be at the beginning even before BOS token
Expand All @@ -164,14 +152,14 @@ def __call__(
)

# cast to desired return tensors type
text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors)
else:
text_encoding = None

if text_encoding is not None:
encoding_image_processor.update(text_encoding)

return encoding_image_processor
encoding.update(BatchEncoding(text_encoding, tensor_type=return_tensors))
# add pixel_values encoding. If we also have text_encoding, update image encoding and return it.
# else, return the text encoding.

if images is not None:
image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"])
encoding.update(image_encoding)
return encoding

# Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
def batch_decode(self, *args, **kwargs):
Expand Down
Loading

0 comments on commit c21341f

Please sign in to comment.