Paligemma support for multi-image (#33447)
* update

* Update src/transformers/models/paligemma/processing_paligemma.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* update docs

* better example in tests

* support image tokens

* read token

* Update tests/models/paligemma/test_processing_paligemma.py

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>

* nit: naming

* Update docs/source/en/model_doc/paligemma.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* conflicts after rebasing

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
3 people authored Sep 27, 2024
1 parent 55b7a04 commit 3e039d3
Showing 4 changed files with 223 additions and 50 deletions.
42 changes: 35 additions & 7 deletions docs/source/en/model_doc/paligemma.md
@@ -29,7 +29,20 @@ This model was contributed by [Molbap](https://huggingface.co/Molbap).

## Usage tips

-Inference with PaliGemma can be performed as follows:
+- PaliGemma is not meant for conversational use, and it works best when fine-tuned to a specific use case. Some downstream tasks on which PaliGemma can be fine-tuned include image captioning, visual question answering (VQA), object detection, referring expression segmentation, and document understanding.
+- One can use `PaliGemmaProcessor` to prepare images, text, and optional labels for the model. When fine-tuning a PaliGemma model, the `suffix` argument can be passed to the processor to create the `labels` for the model:
+
+```python
+prompt = "What is on the flower?"
+answer = "a bee"
+inputs = processor(images=raw_image, text=prompt, suffix=answer, return_tensors="pt")
+```
+
+## Usage Example
+
+The model can accept a single image or multiple images. According to the [paper](https://arxiv.org/abs/2407.07726v1), the PaliGemma checkpoint can transfer to tasks that take multiple images as input; NLVR2 is one such task, which asks one question about two images and requires looking at both to give the correct answer. Example code for single- and multi-image inference follows.
+
+### Single-image Inference

```python
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
@@ -44,16 +57,31 @@ raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(raw_image, prompt, return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=20)

-print(processor.decode(output[0], skip_special_tokens=True)[len(prompt):])
+print(processor.decode(output[0], skip_special_tokens=True)[inputs.input_ids.shape[1]:])
```

-- PaliGemma is not meant for conversational use, and it works best when fine-tuned to a specific use case. Some downstream tasks on which PaliGemma can be fine-tuned include image captioning, visual question answering (VQA), object detection, referring expression segmentation, and document understanding.
-- One can use `PaliGemmaProcessor` to prepare images, text, and optional labels for the model. When fine-tuning a PaliGemma model, the `suffix` argument can be passed to the processor to create the `labels` for the model:
+### Multi-image Inference

```python
-prompt = "What is on the flower?"
-answer = "a bee"
-inputs = processor(images=raw_image, text=prompt, suffix=answer, return_tensors="pt")
+model_id = "google/paligemma-3b-ft-nlvr2-448"  # checkpoint tuned for multiple images
+model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
+processor = PaliGemmaProcessor.from_pretrained(model_id)
+
+prompt = "answer en Which of the two pictures shows a snowman, first or second?"
+stop_sign_image = Image.open(
+    requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw
+)
+snow_image = Image.open(
+    requests.get(
+        "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg", stream=True
+    ).raw
+)
+
+inputs = processor(images=[[snow_image, stop_sign_image]], text=prompt, return_tensors="pt")
+
+output = model.generate(**inputs, max_new_tokens=20)
+print(processor.decode(output[0], skip_special_tokens=True)[inputs.input_ids.shape[1]:])
```
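
The processor also batches multi-image prompts, one (possibly nested) list of images per prompt. Below is a minimal sketch of a batched call, extrapolated from the processor tests added in this commit; it reuses `model`, `processor`, and the two images loaded above, and the prompts are illustrative:

```python
prompts = [
    "answer en Is there a snowman in the image, true or false?",
    "answer en Is there a snowman in the image, true or false?",
]
# One nested image list per prompt; len(images) must match len(text).
inputs = processor(text=prompts, images=[[snow_image], [stop_sign_image]], return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=20)
for decoded in processor.batch_decode(output, skip_special_tokens=True):
    print(decoded)
```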

## Resources
91 changes: 70 additions & 21 deletions src/transformers/models/paligemma/processing_paligemma.py
@@ -77,7 +77,7 @@ def _is_str_or_image(elem):
    return isinstance(elem, (str)) or is_image_or_image_url(elem)


-def build_string_from_input(prompt, bos_token, image_seq_len, image_token):
+def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_images):
    """
    Builds a string from the input prompt and image tokens.
    For example, for the call:
@@ -94,8 +94,33 @@ def build_string_from_input(prompt, bos_token, image_seq_len, image_token):
        bos_token (`str`): The beginning of sentence token.
        image_seq_len (`int`): The length of the image sequence.
        image_token (`str`): The image token.
+        num_images (`int`): Number of images in the prompt.
    """
-    return f"{image_token * image_seq_len}{bos_token}{prompt}\n"
+    return f"{image_token * image_seq_len * num_images}{bos_token}{prompt}\n"
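
For illustration, a self-contained sketch of the updated helper; the function body is copied from the diff above, while the prompt and token strings are made up:

```python
def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_images):
    # image_seq_len placeholder tokens per image, then <bos> and the prompt
    return f"{image_token * image_seq_len * num_images}{bos_token}{prompt}\n"

# Two images, each expanded to three image tokens:
print(build_string_from_input("answer en Which image shows a snowman?", "<bos>", 3, "<image>", 2))
# -> <image><image><image><image><image><image><bos>answer en Which image shows a snowman?
```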


+# Copied from transformers.models.llava_next.image_processing_llava_next.make_batched_images
+def make_batched_images(images) -> List[ImageInput]:
+    """
+    Accepts images in list or nested list format, and makes a list of images for preprocessing.
+
+    Args:
+        images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
+            The input image.
+
+    Returns:
+        list: A list of images.
+    """
+    if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
+        return [img for img_list in images for img in img_list]
+
+    elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
+        return images
+
+    elif is_valid_image(images):
+        return [images]
+
+    raise ValueError(f"Could not make batched images from {images}")
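
For reference, a quick sketch of the three input shapes the helper accepts (a toy PIL image stands in for real inputs; assumes `make_batched_images` from the diff above is in scope):

```python
from PIL import Image

img = Image.new("RGB", (8, 8))  # toy image, for illustration only

assert make_batched_images(img) == [img]                      # single image -> singleton list
assert make_batched_images([img, img]) == [img, img]          # flat list -> returned as-is
assert make_batched_images([[img], [img, img]]) == [img] * 3  # nested lists -> flattened
```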


class PaliGemmaProcessor(ProcessorMixin):
@@ -230,29 +255,53 @@ def __call__(
            )
            text = ""

-        if isinstance(text, List) and isinstance(images, List):
-            if len(images) < len(text):
-                raise ValueError(
-                    f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image."
-                )
        if _is_str_or_image(text):
            text = [text]
        elif isinstance(text, list) and _is_str_or_image(text[0]):
            pass
-        if suffix is not None and _is_str_or_image(suffix):
-            suffix = [suffix]
-        if suffix is not None:
-            suffix = [sfx + self.tokenizer.eos_token for sfx in suffix]
-
-        input_strings = [
-            build_string_from_input(
-                prompt=prompt,
-                bos_token=self.tokenizer.bos_token,
-                image_seq_len=self.image_seq_length,
-                image_token=IMAGE_TOKEN,
-            )
-            for prompt in text
-        ]

+        if text is not None and images is not None:
+            if not any(IMAGE_TOKEN in sample for sample in text):
+                logger.warning(
+                    "You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special "
+                    "image tokens in the text, as many tokens as there are images per each text. It is recommended to "
+                    "add `<image>` tokens in the very beginning of your text and `<bos>` token after that. For this call, we will infer how many images "
+                    "each text has and add special tokens."
+                )
+
+                if isinstance(text, List) and isinstance(images, List):
+                    if len(images) != len(text):
+                        raise ValueError(
+                            f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image or list of images."
+                        )
+
+                # make a nested list of lists to be able to iterate over the images and text below
+                if is_valid_image(images):
+                    images = [[images]]
+                elif isinstance(images, list) and is_valid_image(images[0]):
+                    images = [[image] for image in images]
+                elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])):
+                    raise ValueError("images must be an image, list of images or list of list of images")
+
+                if suffix is not None and _is_str_or_image(suffix):
+                    suffix = [suffix]
+                if suffix is not None:
+                    suffix = [sfx + self.tokenizer.eos_token for sfx in suffix]
+
+                input_strings = [
+                    build_string_from_input(
+                        prompt=prompt,
+                        bos_token=self.tokenizer.bos_token,
+                        image_seq_len=self.image_seq_length,
+                        image_token=IMAGE_TOKEN,
+                        num_images=len(image_list) if isinstance(image_list, list) else 1,
+                    )
+                    for prompt, image_list in zip(text, images)
+                ]
+                images = make_batched_images(images)
+            else:
+                text = [sample.replace(IMAGE_TOKEN, IMAGE_TOKEN * self.image_seq_length) for sample in text]
+                input_strings = [f"{sample}\n" for sample in text]

        pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]

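Taken together, the `__call__` changes make explicit and inferred image tokens interchangeable. A sketch distilled from the processor tests added below; the checkpoint choice and the blank test image are assumptions for illustration, and the PaliGemma checkpoints are gated, so loading one requires accepting its license:

```python
from PIL import Image
from transformers import PaliGemmaProcessor

processor = PaliGemmaProcessor.from_pretrained("google/paligemma-3b-pt-224")
image = Image.new("RGB", (224, 224))  # blank placeholder image

# Explicit "<image><bos>" prefix and inferred tokens yield identical encodings.
out_explicit = processor(text="<image><bos>Dummy text!", images=image, return_tensors="np")
out_inferred = processor(text="Dummy text!", images=image, return_tensors="np")

# One prompt with two images must be passed as a nested list ...
out_two = processor(text="Dummy text!", images=[[image, image]], return_tensors="np")

# ... while a flat list of two images with one prompt is ambiguous and is rejected.
try:
    processor(text="Dummy text!", images=[image, image], return_tensors="np")
except ValueError:
    pass  # one text with two images, or a forgotten second prompt? The processor can't tell.
```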
56 changes: 34 additions & 22 deletions tests/models/paligemma/test_modeling_paligemma.py
@@ -326,8 +326,6 @@ def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()

-    @slow
-    @require_read_token
    def test_small_model_integration_test(self):
        # Let's make sure we test the preprocessing to replace what is used
        model_id = "google/paligemma-3b-pt-224"
@@ -349,8 +347,40 @@ def test_small_model_integration_test(self):
            EXPECTED_DECODED_TEXT,
        )

    @slow
    @require_read_token
+    def test_small_model_integration_test_multiimage(self):
+        model_id = "google/paligemma-3b-ft-nlvr2-448"  # checkpoint tuned for multiple images
+        model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
+        processor = PaliGemmaProcessor.from_pretrained(model_id)
+        prompt = "answer en There is no snowman in any of the images. Is this true or false?"
+        stop_sign_image = Image.open(
+            requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw
+        )
+        snow_image = Image.open(
+            requests.get(
+                "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg", stream=True
+            ).raw
+        )
+
+        inputs = processor(text=prompt, images=[[snow_image, snow_image]], return_tensors="pt")
+
+        output = model.generate(**inputs, max_new_tokens=20)
+        EXPECTED_DECODED_TEXT = "answer en There is no snowman in any of the images. Is this true or false?\nFalse"
+
+        self.assertEqual(
+            self.processor.decode(output[0], skip_special_tokens=True),
+            EXPECTED_DECODED_TEXT,
+        )
+
+        # try another prompt, with two different images this time
+        prompt = "answer en There is exactly one snowman. Is this true or false?"
+        inputs = processor(text=prompt, images=[[snow_image, stop_sign_image]], return_tensors="pt")
+        output = model.generate(**inputs, max_new_tokens=20)
+        EXPECTED_DECODED_TEXT = "answer en There is exactly one snowman. Is this true or false?\nTrue"
+        self.assertEqual(
+            self.processor.decode(output[0], skip_special_tokens=True),
+            EXPECTED_DECODED_TEXT,
+        )

    def test_small_model_integration_test_paligemma_VQA(self):
        # Let's make sure we test the preprocessing to replace what is used
        model_id = "google/paligemma-3b-pt-224"
@@ -370,8 +400,6 @@ def test_small_model_integration_test_paligemma_VQA(self):
            EXPECTED_DECODED_TEXT,
        )

-    @slow
-    @require_read_token
    def test_small_model_integration_test_paligemma_empty_prompt(self):
        # Let's make sure we test the preprocessing to replace what is used
        model_id = "google/paligemma-3b-pt-224"
@@ -392,8 +420,6 @@ def test_small_model_integration_test_paligemma_empty_prompt(self):
            EXPECTED_DECODED_TEXT,
        )

-    @slow
-    @require_read_token
    def test_small_model_integration_test_paligemma_batched(self):
        # Let's make sure we test the preprocessing to replace what is used
        model_id = "google/paligemma-3b-pt-224"
@@ -420,9 +446,6 @@ def test_small_model_integration_test_paligemma_batched(self):

        self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)

-    @slow
-    @require_torch
-    @require_read_token
    def test_small_model_integration_test_paligemma_batched_bf16(self):
        # Let's make sure we test the preprocessing to replace what is used
        model_id = "google/paligemma-3b-pt-224"
@@ -452,9 +475,6 @@ def test_small_model_integration_test_paligemma_batched_bf16(self):
        EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow on the beach"]  # fmt: skip
        self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)

-    @slow
-    @require_torch
-    @require_read_token
    def test_small_model_integration_test_paligemma_batched_f16(self):
        # Let's make sure we test the preprocessing to replace what is used
        model_id = "google/paligemma-3b-pt-224"
@@ -485,9 +505,6 @@ def test_small_model_integration_test_paligemma_batched_f16(self):
        EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow on the beach"]  # fmt: skip
        self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)

-    @slow
-    @require_torch
-    @require_read_token
    def test_integration_detection_bug(self):
        # this is a reproducer of https://github.com/huggingface/transformers/issues/31425 where not enough context
        # impacted negatively segmentation generations.
@@ -511,8 +528,6 @@ def test_integration_detection_bug(self):
        EXPECTED_DECODED_TEXT = "detect shoe\n<loc0051><loc0309><loc0708><loc0646> shoe"  # fmt: skip
        self.assertEqual(self.processor.decode(output[0], skip_special_tokens=True), EXPECTED_DECODED_TEXT)

-    @slow
-    @require_read_token
    def test_paligemma_index_error_bug(self):
        # This is a reproducer of https://github.com/huggingface/transformers/pull/28032 and makes sure it does not happen anymore
        # Please refer to that PR, or specifically https://github.com/huggingface/transformers/pull/28032#issuecomment-1860650043 for
@@ -536,9 +551,6 @@ def test_paligemma_index_error_bug(self):
        # Make sure that `generate` works
        _ = model.generate(**inputs, max_new_tokens=20)

-    @slow
-    @require_torch
-    @require_read_token
    def test_paligemma_finetuning_with_suffixes_bf16(self):
        # this is a supplementary test to ensure paligemma fine-tuning that relies on token_type_ids is robust to future changes
        model_id = "google/paligemma-3b-pt-224"
84 changes: 84 additions & 0 deletions tests/models/paligemma/test_processing_paligemma.py
@@ -0,0 +1,84 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest

from transformers import AutoProcessor, GemmaTokenizerFast, PaliGemmaProcessor
from transformers.testing_utils import require_read_token, require_vision
from transformers.utils import is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_vision_available():
    from transformers import SiglipImageProcessor


@require_vision
@require_read_token
class PaliGemmaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
    processor_class = PaliGemmaProcessor

    def setUp(self):
        self.tmpdirname = tempfile.mkdtemp()
        image_processor = SiglipImageProcessor(do_center_crop=False)
        tokenizer = GemmaTokenizerFast.from_pretrained("google/gemma-7b")
        image_processor.image_seq_length = 32

        processor = PaliGemmaProcessor(image_processor=image_processor, tokenizer=tokenizer)
        processor.save_pretrained(self.tmpdirname)

    def get_tokenizer(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer

    def get_image_processor(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    def test_text_with_image_tokens(self):
        image_processor = self.get_component("image_processor")
        tokenizer = self.get_component("tokenizer")

        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
        text_multi_images = "<image><image><bos>Dummy text!"
        text_single_image = "<image><bos>Dummy text!"
        text_no_image = "Dummy text!"

        image = self.prepare_image_inputs()[0]

        out_noimage = processor(text=text_no_image, images=image, return_tensors="np")
        out_singlimage = processor(text=text_single_image, images=image, return_tensors="np")
        for k in out_noimage:
            self.assertTrue(out_noimage[k].tolist() == out_singlimage[k].tolist())

        out_multiimages = processor(text=text_multi_images, images=[image, image], return_tensors="np")
        out_noimage = processor(text=text_no_image, images=[[image, image]], return_tensors="np")

        # We can't be sure of the user's intention here: one text with two images, or a forgotten
        # second text? A flat list of images with a single prompt is therefore rejected.
        with self.assertRaises(ValueError):
            out_noimage = processor(text=text_no_image, images=[image, image], return_tensors="np")

        for k in out_noimage:
            self.assertTrue(out_noimage[k].tolist() == out_multiimages[k].tolist())

        text_batched = ["Dummy text!", "Dummy text!"]
        text_batched_with_image = ["<image><bos>Dummy text!", "<image><bos>Dummy text!"]
        out_images = processor(text=text_batched_with_image, images=[image, image], return_tensors="np")
        out_noimage_nested = processor(text=text_batched, images=[[image], [image]], return_tensors="np")
        out_noimage = processor(text=text_batched, images=[image, image], return_tensors="np")
        for k in out_noimage:
            self.assertTrue(out_noimage[k].tolist() == out_images[k].tolist() == out_noimage_nested[k].tolist())
