Revert 22152 MaskedImageCompletionOutput changes #22187

Merged
src/transformers/modeling_outputs.py: 0 additions & 28 deletions
```diff
@@ -1281,34 +1281,6 @@ class ImageSuperResolutionOutput(ModelOutput):
     attentions: Optional[Tuple[torch.FloatTensor]] = None
 
 
-@dataclass
-class MaskedImageCompletionOutput(ModelOutput):
-    """
-    Base class for outputs of masked image completion / in-painting models.
-
-    Args:
-        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
-            Reconstruction loss.
-        reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
-            Reconstructed / completed images.
-        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
-            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
-            one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
-            (also called feature maps) of the model at the output of each stage.
-        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
-            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
-            sequence_length)`.
-
-            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
-            heads.
-    """
-
-    loss: Optional[torch.FloatTensor] = None
-    reconstruction: torch.FloatTensor = None
-    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    attentions: Optional[Tuple[torch.FloatTensor]] = None
-
-
 @dataclass
 class Wav2Vec2BaseModelOutput(ModelOutput):
     """
```
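With the dedicated class removed, masked image modeling goes back to reusing the generic `MaskedLMOutput` container, with the reconstructed pixel values carried in its `logits` field. For reference, a paraphrased sketch of that container; see `src/transformers/modeling_outputs.py` for the authoritative definition, which also carries a full docstring:

```python
from dataclasses import dataclass
from typing import Optional, Tuple

import torch

from transformers.utils import ModelOutput


# Paraphrased sketch of the existing MaskedLMOutput container that this
# revert falls back to; not part of the diff above.
@dataclass
class MaskedLMOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None  # here: the reconstructed pixel values
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
```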
src/transformers/models/vit/modeling_vit.py: 5 additions & 10 deletions
```diff
@@ -25,12 +25,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...modeling_outputs import (
-    BaseModelOutput,
-    BaseModelOutputWithPooling,
-    ImageClassifierOutput,
-    MaskedImageCompletionOutput,
-)
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
@@ -648,7 +643,7 @@ def __init__(self, config: ViTConfig) -> None:
         self.post_init()
 
     @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=MaskedImageCompletionOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         pixel_values: Optional[torch.Tensor] = None,
@@ -658,7 +653,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ) -> Union[tuple, MaskedImageCompletionOutput]:
+    ) -> Union[tuple, MaskedLMOutput]:
         r"""
         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
             Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
@@ -728,9 +723,9 @@ def forward(
             output = (reconstructed_pixel_values,) + outputs[1:]
             return ((masked_im_loss,) + output) if masked_im_loss is not None else output
 
-        return MaskedImageCompletionOutput(
+        return MaskedLMOutput(
             loss=masked_im_loss,
-            reconstruction=reconstructed_pixel_values,
+            logits=reconstructed_pixel_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
```
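Downstream code that reads the output of `ViTForMaskedImageModeling` is affected by the field rename: the reconstructed image now lives in `outputs.logits` rather than `outputs.reconstruction`. A minimal usage sketch under the reverted API; the checkpoint name is illustrative, and any ViT checkpoint loadable by this head class behaves the same way:

```python
import torch
from transformers import ViTForMaskedImageModeling

# Illustrative checkpoint; the decoder head may be freshly initialized
# when loading a backbone-only checkpoint like this one.
model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k")
model.eval()

pixel_values = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    outputs = model(pixel_values)

# After this revert the result is a MaskedLMOutput, so read `logits`
# (previously `reconstruction`).
reconstructed = outputs.logits  # shape (1, 3, 224, 224)
```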
tests/models/vit/test_modeling_vit.py: 2 additions & 2 deletions
```diff
@@ -134,7 +134,7 @@ def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
         model.eval()
         result = model(pixel_values)
         self.parent.assertEqual(
-            result.reconstruction.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
+            result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
         )
 
         # test greyscale images
@@ -145,7 +145,7 @@ def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
 
         pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
         result = model(pixel_values)
-        self.parent.assertEqual(result.reconstruction.shape, (self.batch_size, 1, self.image_size, self.image_size))
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
 
     def create_and_check_for_image_classification(self, config, pixel_values, labels):
         config.num_labels = self.type_sequence_label_size
```