diff --git a/docs/source/en/model_doc/clip.md b/docs/source/en/model_doc/clip.md
index be98edfb2a81d4..8c1e11c398c180 100644
--- a/docs/source/en/model_doc/clip.md
+++ b/docs/source/en/model_doc/clip.md
@@ -184,6 +184,11 @@ The resource should ideally demonstrate something new instead of duplicating an
 [[autodoc]] FlaxCLIPTextModel
     - __call__
 
+## FlaxCLIPTextModelWithProjection
+
+[[autodoc]] FlaxCLIPTextModelWithProjection
+    - __call__
+
 ## FlaxCLIPVisionModel
 
 [[autodoc]] FlaxCLIPVisionModel
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 9b95aadffccc6f..9024481269a6f4 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -3965,6 +3965,7 @@
             "FlaxCLIPPreTrainedModel",
             "FlaxCLIPTextModel",
+            "FlaxCLIPTextModelWithProjection",
             "FlaxCLIPTextPreTrainedModel",
             "FlaxCLIPVisionModel",
             "FlaxCLIPVisionPreTrainedModel",
         ]
@@ -7388,6 +7389,7 @@
         FlaxCLIPModel,
         FlaxCLIPPreTrainedModel,
         FlaxCLIPTextModel,
+        FlaxCLIPTextModelWithProjection,
         FlaxCLIPTextPreTrainedModel,
         FlaxCLIPVisionModel,
         FlaxCLIPVisionPreTrainedModel,
diff --git a/src/transformers/models/clip/__init__.py b/src/transformers/models/clip/__init__.py
index 1f079783bed674..0ee0cfb0915f33 100644
--- a/src/transformers/models/clip/__init__.py
+++ b/src/transformers/models/clip/__init__.py
@@ -94,6 +94,7 @@
         "FlaxCLIPPreTrainedModel",
         "FlaxCLIPTextModel",
+        "FlaxCLIPTextModelWithProjection",
         "FlaxCLIPTextPreTrainedModel",
         "FlaxCLIPVisionModel",
         "FlaxCLIPVisionPreTrainedModel",
     ]
@@ -167,6 +168,7 @@
            FlaxCLIPModel,
            FlaxCLIPPreTrainedModel,
            FlaxCLIPTextModel,
+            FlaxCLIPTextModelWithProjection,
            FlaxCLIPTextPreTrainedModel,
            FlaxCLIPVisionModel,
            FlaxCLIPVisionPreTrainedModel,
diff --git a/src/transformers/models/clip/modeling_flax_clip.py b/src/transformers/models/clip/modeling_flax_clip.py
index 750e5b05485866..5aeaa5d960a773 100644
--- a/src/transformers/models/clip/modeling_flax_clip.py
+++ b/src/transformers/models/clip/modeling_flax_clip.py
@@ -155,6 +155,36 @@
 """
 
 
+@flax.struct.dataclass
+class FlaxCLIPTextModelOutput(ModelOutput):
+    """
+    Base class for text model's outputs that also contains a pooling of the last hidden states.
+
+    Args:
+        text_embeds (`jnp.ndarray` of shape `(batch_size, output_dim)`):
+            The text embeddings obtained by applying the projection layer to the pooled output of
+            [`FlaxCLIPTextModel`].
+        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    text_embeds: jnp.ndarray = None
+    last_hidden_state: jnp.ndarray = None
+    hidden_states: Optional[Tuple[jnp.ndarray]] = None
+    attentions: Optional[Tuple[jnp.ndarray]] = None
+
+
 @flax.struct.dataclass
 class FlaxCLIPOutput(ModelOutput):
     """
@@ -1007,6 +1037,78 @@ class FlaxCLIPTextModel(FlaxCLIPTextPreTrainedModel):
 )
 
 
+class FlaxCLIPTextModelWithProjectionModule(nn.Module):
+    config: CLIPTextConfig
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        self.text_model = FlaxCLIPTextTransformer(self.config, dtype=self.dtype)
+        self.text_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype)
+
+    def __call__(
+        self,
+        input_ids,
+        attention_mask,
+        position_ids,
+        deterministic: bool = True,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+    ):
+        text_outputs = self.text_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            deterministic=deterministic,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        pooled_output = text_outputs[1]
+        text_embeds = self.text_projection(pooled_output)
+
+        if not return_dict:
+            return (text_embeds, text_outputs[0]) + text_outputs[2:]
+
+        return FlaxCLIPTextModelOutput(
+            text_embeds=text_embeds,
+            last_hidden_state=text_outputs.last_hidden_state,
+            hidden_states=text_outputs.hidden_states,
+            attentions=text_outputs.attentions,
+        )
+
+
+class FlaxCLIPTextModelWithProjection(FlaxCLIPTextPreTrainedModel):
+    module_class = FlaxCLIPTextModelWithProjectionModule
+
+
+FLAX_CLIP_TEXT_MODEL_WITH_PROJECTION_DOCSTRING = """
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, FlaxCLIPTextModelWithProjection
+
+    >>> model = FlaxCLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
+    >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+
+    >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np")
+
+    >>> outputs = model(**inputs)
+    >>> text_embeds = outputs.text_embeds
+    ```
+"""
+
+overwrite_call_docstring(
+    FlaxCLIPTextModelWithProjection, CLIP_TEXT_INPUTS_DOCSTRING + FLAX_CLIP_TEXT_MODEL_WITH_PROJECTION_DOCSTRING
+)
+append_replace_return_docstrings(
+    FlaxCLIPTextModelWithProjection, output_type=FlaxCLIPTextModelOutput, config_class=CLIPTextConfig
+)
+
+
 class FlaxCLIPVisionModule(nn.Module):
     config: CLIPVisionConfig
     dtype: jnp.dtype = jnp.float32
diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py
index 78be4ef747e96a..7e5b78d3e6fc86 100644
--- a/src/transformers/utils/dummy_flax_objects.py
+++ b/src/transformers/utils/dummy_flax_objects.py
@@ -562,6 +562,13 @@ def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])
 
 
+class FlaxCLIPTextModelWithProjection(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
 class FlaxCLIPTextPreTrainedModel(metaclass=DummyObject):
     _backends = ["flax"]
 
diff --git a/tests/models/clip/test_modeling_flax_clip.py b/tests/models/clip/test_modeling_flax_clip.py
index 565c641aef632b..c1d05081ca5310 100644
--- a/tests/models/clip/test_modeling_flax_clip.py
+++ b/tests/models/clip/test_modeling_flax_clip.py
@@ -19,7 +19,12 @@
         convert_pytorch_state_dict_to_flax,
         load_flax_weights_in_pytorch_model,
     )
-    from transformers.models.clip.modeling_flax_clip import FlaxCLIPModel, FlaxCLIPTextModel, FlaxCLIPVisionModel
+    from transformers.models.clip.modeling_flax_clip import (
+        FlaxCLIPModel,
+        FlaxCLIPTextModel,
+        FlaxCLIPTextModelWithProjection,
+        FlaxCLIPVisionModel,
+    )
 
 if is_torch_available():
     import torch
@@ -315,7 +320,7 @@ def prepare_config_and_inputs_for_common(self):
 
 @require_flax
 class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):
-    all_model_classes = (FlaxCLIPTextModel,) if is_flax_available() else ()
+    all_model_classes = (FlaxCLIPTextModel, FlaxCLIPTextModelWithProjection) if is_flax_available() else ()
 
     def setUp(self):
         self.model_tester = FlaxCLIPTextModelTester(self)
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 98f2436ae3af45..678294a1b0dcd1 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -205,6 +205,7 @@
     "TFGroupViTTextModel",
     "TFGroupViTVisionModel",
     "FlaxCLIPTextModel",
+    "FlaxCLIPTextModelWithProjection",
     "FlaxCLIPVisionModel",
     "FlaxWav2Vec2ForCTC",
     "DetrForSegmentation",