
Add FlaxCLIPTextModelWithProjection (#25254)
* Add FlaxClipTextModelWithProjection

This is necessary to support the Flax port of Stable Diffusion XL: https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/fb6d705fb518524cabc79c77f13a0e7921bcab3a/text_encoder_2/config.json#L3

Co-authored-by: Martin Müller <martin.muller.me@gmail.com>
Co-authored-by: Juan Acevedo <juancevedo@gmail.com>

* Use FlaxCLIPTextModelOutput

* make fix-copies again

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Use `return_dict` for consistency with other uses.

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Fix docstring example.

* Add new model to FlaxCLIPTextModelTest

* Add to IGNORE_NON_AUTO_CONFIGURED list

* Fix naming convention.

---------

Co-authored-by: Martin Müller <martin.muller.me@gmail.com>
Co-authored-by: Juan Acevedo <juancevedo@gmail.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
4 people authored Aug 25, 2023
1 parent 8968ffa commit cb8e3ee
Showing 7 changed files with 126 additions and 2 deletions.
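
The commit message cites the Flax port of Stable Diffusion XL as the motivation for the new class. As a hypothetical sketch of that use case (not part of this commit — the `subfolder`/`from_pt` arguments and the availability of convertible weights in that repository are assumptions), the second SDXL text encoder could be loaded like this:

```python
# Hypothetical usage sketch: loading SDXL's second text encoder with the new class.
# Assumes the checkpoint layout linked in the commit message and that the PyTorch
# weights under `text_encoder_2` can be converted on the fly with `from_pt=True`.
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxCLIPTextModelWithProjection

repo = "stabilityai/stable-diffusion-xl-refiner-1.0"
model = FlaxCLIPTextModelWithProjection.from_pretrained(
    repo, subfolder="text_encoder_2", from_pt=True, dtype=jnp.float32
)
tokenizer = AutoTokenizer.from_pretrained(repo, subfolder="tokenizer_2")

inputs = tokenizer(["a photo of a cat"], padding=True, return_tensors="np")
text_embeds = model(**inputs).text_embeds  # (batch_size, projection_dim)
```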
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/clip.md
@@ -184,6 +184,11 @@ The resource should ideally demonstrate something new instead of duplicating an
[[autodoc]] FlaxCLIPTextModel
- __call__

## FlaxCLIPTextModelWithProjection

[[autodoc]] FlaxCLIPTextModelWithProjection
- __call__

## FlaxCLIPVisionModel

[[autodoc]] FlaxCLIPVisionModel
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -3965,6 +3965,7 @@
"FlaxCLIPPreTrainedModel",
"FlaxCLIPTextModel",
"FlaxCLIPTextPreTrainedModel",
"FlaxCLIPTextModelWithProjection",
"FlaxCLIPVisionModel",
"FlaxCLIPVisionPreTrainedModel",
]
@@ -7388,6 +7389,7 @@
FlaxCLIPModel,
FlaxCLIPPreTrainedModel,
FlaxCLIPTextModel,
FlaxCLIPTextModelWithProjection,
FlaxCLIPTextPreTrainedModel,
FlaxCLIPVisionModel,
FlaxCLIPVisionPreTrainedModel,
2 changes: 2 additions & 0 deletions src/transformers/models/clip/__init__.py
@@ -94,6 +94,7 @@
"FlaxCLIPPreTrainedModel",
"FlaxCLIPTextModel",
"FlaxCLIPTextPreTrainedModel",
"FlaxCLIPTextModelWithProjection",
"FlaxCLIPVisionModel",
"FlaxCLIPVisionPreTrainedModel",
]
@@ -167,6 +168,7 @@
FlaxCLIPModel,
FlaxCLIPPreTrainedModel,
FlaxCLIPTextModel,
FlaxCLIPTextModelWithProjection,
FlaxCLIPTextPreTrainedModel,
FlaxCLIPVisionModel,
FlaxCLIPVisionPreTrainedModel,
102 changes: 102 additions & 0 deletions src/transformers/models/clip/modeling_flax_clip.py
@@ -155,6 +155,36 @@
"""


@flax.struct.dataclass
class FlaxCLIPTextModelOutput(ModelOutput):
    """
    Base class for text model's outputs that also contains a pooling of the last hidden states.

    Args:
        text_embeds (`jnp.ndarray` of shape `(batch_size, output_dim)`):
            The text embeddings obtained by applying the projection layer to the pooled output of
            [`FlaxCLIPTextModel`].
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    text_embeds: jnp.ndarray = None
    last_hidden_state: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None


@flax.struct.dataclass
class FlaxCLIPOutput(ModelOutput):
"""
@@ -1007,6 +1037,78 @@ class FlaxCLIPTextModel(FlaxCLIPTextPreTrainedModel):
)


class FlaxCLIPTextModelWithProjectionModule(nn.Module):
    config: CLIPTextConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.text_model = FlaxCLIPTextTransformer(self.config, dtype=self.dtype)
        self.text_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = text_outputs[1]
        text_embeds = self.text_projection(pooled_output)

        if not return_dict:
            return (text_embeds, text_outputs[0]) + text_outputs[2:]

        return FlaxCLIPTextModelOutput(
            text_embeds=text_embeds,
            last_hidden_state=text_outputs.last_hidden_state,
            hidden_states=text_outputs.hidden_states,
            attentions=text_outputs.attentions,
        )


class FlaxCLIPTextModelWithProjection(FlaxCLIPTextPreTrainedModel):
    module_class = FlaxCLIPTextModelWithProjectionModule


FLAX_CLIP_TEXT_MODEL_WITH_PROJECTION_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, FlaxCLIPTextModelWithProjection
    >>> model = FlaxCLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
    >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> text_embeds = outputs.text_embeds
    ```
"""

overwrite_call_docstring(
    FlaxCLIPTextModelWithProjection, CLIP_TEXT_INPUTS_DOCSTRING + FLAX_CLIP_TEXT_MODEL_WITH_PROJECTION_DOCSTRING
)
append_replace_return_docstrings(
    FlaxCLIPTextModelWithProjection, output_type=FlaxCLIPTextModelOutput, config_class=CLIPTextConfig
)


class FlaxCLIPVisionModule(nn.Module):
    config: CLIPVisionConfig
    dtype: jnp.dtype = jnp.float32
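
The heart of the new module is the projection step applied after the text transformer: the pooled output is passed through a bias-free `Dense(config.projection_dim)` layer and returned as `text_embeds`, instead of the raw pooled output that `FlaxCLIPTextModel` exposes. A minimal standalone sketch of that step (shapes and the 512-dim projection are illustrative, not taken from a specific config):

```python
# Minimal sketch of the projection performed by FlaxCLIPTextModelWithProjectionModule.
import jax
import jax.numpy as jnp
import flax.linen as nn


class ProjectionHead(nn.Module):
    projection_dim: int = 512  # stands in for config.projection_dim

    @nn.compact
    def __call__(self, pooled_output):
        # Same idea as `self.text_projection` above: a bias-free linear map.
        return nn.Dense(self.projection_dim, use_bias=False)(pooled_output)


pooled_output = jnp.ones((2, 768))  # (batch_size, hidden_size) from the text transformer
head = ProjectionHead()
params = head.init(jax.random.PRNGKey(0), pooled_output)
text_embeds = head.apply(params, pooled_output)
print(text_embeds.shape)  # (2, 512)
```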
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_flax_objects.py
@@ -562,6 +562,13 @@ def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])


class FlaxCLIPTextModelWithProjection(metaclass=DummyObject):
    _backends = ["flax"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])


class FlaxCLIPTextPreTrainedModel(metaclass=DummyObject):
    _backends = ["flax"]
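
The dummy class above lets `FlaxCLIPTextModelWithProjection` be imported from `transformers` even when Flax is not installed, deferring a clear `ImportError` to the moment the class is actually used. A simplified sketch of that pattern (an assumption about the mechanism — the library's real `DummyObject`/`requires_backends` helpers are more general):

```python
# Simplified sketch of the dummy-object pattern; not the actual transformers code.
import importlib.util


def requires_backends(obj, backends):
    name = getattr(obj, "__name__", obj.__class__.__name__)
    missing = [b for b in backends if importlib.util.find_spec(b) is None]
    if missing:
        raise ImportError(f"{name} requires the following backends: {', '.join(missing)}")


class DummyObject(type):
    # Metaclass: public attribute access on the class (e.g. `from_pretrained`)
    # triggers the backend check before anything else happens.
    def __getattribute__(cls, key):
        if key.startswith("_"):
            return super().__getattribute__(key)
        requires_backends(cls, cls._backends)
        return super().__getattribute__(key)


class FlaxCLIPTextModelWithProjection(metaclass=DummyObject):
    _backends = ["flax"]

    def __init__(self, *args, **kwargs):
        # Instantiation also goes through the backend check.
        requires_backends(self, ["flax"])
```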
9 changes: 7 additions & 2 deletions tests/models/clip/test_modeling_flax_clip.py
@@ -19,7 +19,12 @@
    convert_pytorch_state_dict_to_flax,
    load_flax_weights_in_pytorch_model,
)
from transformers.models.clip.modeling_flax_clip import FlaxCLIPModel, FlaxCLIPTextModel, FlaxCLIPVisionModel
from transformers.models.clip.modeling_flax_clip import (
    FlaxCLIPModel,
    FlaxCLIPTextModel,
    FlaxCLIPTextModelWithProjection,
    FlaxCLIPVisionModel,
)

if is_torch_available():
    import torch
@@ -315,7 +320,7 @@ def prepare_config_and_inputs_for_common(self):

@require_flax
class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):
    all_model_classes = (FlaxCLIPTextModel,) if is_flax_available() else ()
    all_model_classes = (FlaxCLIPTextModel, FlaxCLIPTextModelWithProjection) if is_flax_available() else ()

    def setUp(self):
        self.model_tester = FlaxCLIPTextModelTester(self)
1 change: 1 addition & 0 deletions utils/check_repo.py
@@ -205,6 +205,7 @@
"TFGroupViTTextModel",
"TFGroupViTVisionModel",
"FlaxCLIPTextModel",
"FlaxCLIPTextModelWithProjection",
"FlaxCLIPVisionModel",
"FlaxWav2Vec2ForCTC",
"DetrForSegmentation",
