
Dramatic Performance Drop of CLIPVisionModel Related Model After Upgrading transformers From 4.27.4 to 4.28.x #23096

Closed
submartingales opened this issue May 2, 2023 · 17 comments


@submartingales

System Info


Related versions are transformers==4.27.4 and transformers==4.28.1.

Who can help?

@amyeroberts

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I'm confident that a CLIPVisionModel loaded with from_pretrained (e.g. from the laion pretrained CLIP ViT) outputs a totally different tensor for exactly the same input image across the two versions, and that using transformers==4.28.1 leads to a dramatic performance drop, for reasons worth digging into. Extensive tests have been conducted to verify that this issue seems unrelated to the torch version (e.g. 2.0 or 1.13).

You can probably reproduce this by loading from laion/CLIP-ViT-H-14-laion2B-s32B-b79K, which is quite popular.
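
For reference, a minimal comparison sketch along these lines might look as follows; the checkpoint name is taken from the report, while the image path and output filename are placeholders and the overall setup is an assumption, not the reporter's actual notebook code:

import torch
import transformers
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel

checkpoint = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
processor = CLIPImageProcessor.from_pretrained(checkpoint)
model = CLIPVisionModel.from_pretrained(checkpoint)
model.eval()

image = Image.open("sample.jpg")  # placeholder: any fixed test image
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Run once per environment (4.27.4 and 4.28.1), then load both saved tensors
# and compare them with torch.allclose
torch.save(outputs.pooler_output, f"pooler_output_{transformers.__version__}.pt")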

Expected behavior

The two versions should output almost the same tensor given the same input image.

@submartingales
Author

I think this issue may affect a large number of vision workloads, and I hope it can be resolved as soon as possible.

@sgugger
Collaborator

sgugger commented May 2, 2023

Also cc @younesbelkada

@amyeroberts
Collaborator

Hi @submartingales, thanks for reporting!

So that I can pin down the issue, is the input image the same before being passed to the processor? e.g.

import torch
from transformers import CLIPProcessor, CLIPModel

torch.manual_seed(0)

# Dummy image which is always the same for each version
image = torch.randint(0, 256, (3, 300, 300))

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

# Model inputs might change if there's a change in processing logic
inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True)
outputs = model(**inputs)

Or are the pixel_values exactly the same?

import torch
from transformers import CLIPProcessor, CLIPModel

torch.manual_seed(0)

pixel_values = torch.rand(1, 3, 224, 224)
input_ids = torch.Tensor(
    [[49406,   320,  1125,   539,   320,  2368, 49407],
     [49406,   320,  1125,   539,   320,  1929, 49407]]
).long()

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

# The model inputs are exactly the same for different library versions
inputs = {"input_ids": input_ids, "pixel_values": pixel_values}
outputs = model(**inputs)

With regards to expected behaviour, could you give some more information about what's changed? Specifically, what is being measured in terms of performance, e.g. is it the CLIP loss? And by how much has it changed?

@submartingales
Author

@amyeroberts I will make two notebooks to clarify, please wait for several minutes.

@submartingales
Author

submartingales commented May 2, 2023

@amyeroberts Actually, we cannot disclose all of the resources required to run the notebooks, for reasons you will understand once you have read them. But the performance drop (the last cell's output; higher is better) is consistent across platforms, and the only variable is the transformers version, so at least for now we believe the change in model behaviour is caused by the package upgrade.

The only difference between the two notebooks in the zip file is the transformers version; the checkpoints we load are exactly the same.
two-version-notebook.zip

@submartingales
Author

submartingales commented May 2, 2023

@amyeroberts In short, every time we try upgrading transformers for new features, no matter which torch version we use or which platform we run on, our prediction workflow fails. For the specific task solved in the notebooks, another observation I can offer: once transformers is upgraded to 4.28.1, the final prediction (the output for each input image), with the same weights loaded, can differ in magnitude by over a thousand times per input image, which ultimately results in the performance drop.

@submartingales
Author

submartingales commented May 2, 2023

The two uploaded notebooks demonstrate the performance drop at inference time. What we are experiencing at the same time is that, during training with a cosine loss related to the task in the notebooks, the transformers==4.27.4 powered model converges easily on about 50k images, while the transformers==4.28.1 based model won't converge on even 1k images.

The architecture we have chosen is straightforward, and if we load from_pretrained('laion/CLIP-ViT-H-14-laion2B-s32B-b79K') (internet connection restrictions on certain platforms aside), the problem still exists.
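
As a rough illustration of the "cosine loss" training setup described above, a minimal sketch might look like the following; the projection head, its output dimension, and the target embeddings are assumptions for illustration, not the reporter's actual code:

import torch
import torch.nn as nn
from transformers import CLIPVisionModel

vision = CLIPVisionModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
head = nn.Linear(vision.config.hidden_size, 512)  # assumed projection dimension

def cosine_loss(pixel_values: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # 1 - cos(pred, target), averaged over the batch
    features = vision(pixel_values=pixel_values).pooler_output
    preds = head(features)
    return (1 - nn.functional.cosine_similarity(preds, targets, dim=-1)).mean()

If the pixel statistics fed to the vision tower change between library versions, an objective like this would see very different feature distributions, which is consistent with the convergence issues reported.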

@submartingales
Author

submartingales commented May 2, 2023

@amyeroberts With respect to the output difference, in the 8th cell of the two given notebooks we can see that the tensor output for the first sample differs.

The 4.28.1 version gives

-0.776564
1.751475
1.938180
0.474142
-0.191921
...

while 4.27.4 gives

-2.197644
2.167892
-0.369088
-0.928763
-3.423420
...

@submartingales
Author

submartingales commented May 3, 2023


@amyeroberts The "over a thousand times" I mention above relates to a similar strategy with another weight checkpoint, which is not included in two-version-notebook.zip.

My coworkers suspect that something important in the overall CLIP workflow changed between 4.27.4 and 4.28.1, which has caused incompatibility issues.

@submartingales
Author

@amyeroberts Any progress on this issue? If you can roughly locate the code change related to this issue, I am happy to submit a pull request to fix it.

@amyeroberts
Collaborator

Hi @submartingales, thanks for sharing more details and the notebooks.

I suspect this is related to a change in the cropping behaviour identified in a similar issue.

The fastest way to regain the old behaviour whilst waiting for the fix to be merged would be to implement an image processor which overrides the cropping behaviour, e.g. something like this:

from typing import Dict, Optional, Union

import numpy as np
from transformers import CLIPTokenizer, CLIPImageProcessor, CLIPProcessor
from transformers.image_transforms import to_channel_dimension_format
from transformers.image_utils import ChannelDimension, get_image_size, infer_channel_dimension_format
from transformers.image_processing_utils import get_size_dict


class NewCLIPImageProcessor(CLIPImageProcessor):
    def center_crop(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        size = get_size_dict(size)
        if "height" not in size or "width" not in size:
            raise ValueError(f"The `size` parameter must contain the keys (height, width). Got {size.keys()}")

        image = to_channel_dimension_format(image, ChannelDimension.FIRST)
        if data_format is None:
            data_format = infer_channel_dimension_format(image)

        image_height, image_width = get_image_size(image)
        crop_height, crop_width = size["height"], size["width"]

        crop_top = int((image_height - crop_height + 1) * 0.5)
        crop_left = int((image_width - crop_width + 1) * 0.5)

        image = image[:, crop_top : crop_top + crop_height, crop_left : crop_left + crop_width]
        image = to_channel_dimension_format(image, data_format)
        return image

image_processor = NewCLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor(image_processor=image_processor, tokenizer=tokenizer)
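
For completeness, the patched processor can then be used in place of the stock one; a short usage sketch, continuing from the snippet above and reusing the dummy image from the earlier example:

import torch
from transformers import CLIPModel

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

# Dummy image as in the first snippet; `processor` is the patched one built above
image = torch.randint(0, 256, (3, 300, 300))
inputs = processor(text=["a photo of a cat"], images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model(**inputs)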

@submartingales
Author

@amyeroberts Thank you for the code, but we are sorry to report that after updating the processor with your snippet, the problem still exists and the model's output under transformers==4.28.1 does not change to what it should be.

@submartingales
Author

@amyeroberts We have since run further experiments using transformers==4.29.2: the output change persists there, and the output tensor is allclose to the one produced by transformers==4.28.1.

@sgugger
Collaborator

sgugger commented May 18, 2023

@submartingales If you do not share a reproducer of the bug, there is really nothing we can do to help.

@submartingales
Author

@sgugger We have now made all resources required to reproduce the bug public, in two public notebooks with all related checkpoints loaded from public datasets. Any account can now copy & edit the notebooks and reproduce the behaviour change with "pin to original environment" checked.

@amyeroberts
Collaborator

Hi @submartingales, thanks for sharing the repro.

I've tracked the change in the model outputs down to a bug fix in 4.28.x: #22458.

In the shared notebooks, in the ImageDataset class, the images are converted to torch tensors in the __getitem__ method using ToTensor(). ToTensor() doesn't just convert the PIL image to a tensor, it also scales the pixel values to [0, 1].

The image transforms library uses Pillow to resize images. If the input is an array, it's first converted to a PIL image and then converted back to an array. To convert an array to a PIL.Image.Image, its pixels must be integer values in [0, 255]. In 4.27.4, if the input had pixel values in [0, 1], it was rescaled so this conversion could happen, but the output array wasn't rescaled back down, so the output array had pixel values in [0, 255].
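
As a small illustration of that Pillow constraint (plain numpy/PIL, not transformers code):

import numpy as np
from PIL import Image

arr = np.random.rand(224, 224, 3)  # float pixel values in [0, 1]
# Image.fromarray needs uint8 values in [0, 255], so the array must be rescaled first
pil_image = Image.fromarray((arr * 255).astype(np.uint8))
back = np.array(pil_image)  # values are now integers in [0, 255], not floats in [0, 1]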

If using ToTensor, the image processor should have do_rescale=False set to prevent the pixel values being divided by 255 twice. This was likely the cause of the degraded performance (as the images in 4.27.4 had their pixel values multiplied by 255 when resizing, nullifying this double divide).
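
In code, the suggested fix amounts to something like the following sketch; the checkpoint is the one from the report and the random tensor stands in for the ToTensor() output from the dataset's __getitem__:

import torch
from transformers import CLIPImageProcessor

processor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")

# Stand-in for an image already scaled to [0, 1] by torchvision's ToTensor()
tensor_image = torch.rand(3, 224, 224)

# do_rescale=False tells the processor not to divide by 255 a second time
inputs = processor(images=tensor_image, do_rescale=False, return_tensors="pt")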

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
