[ModelingOutput] add more output for skep model #3146

Merged
merged 10 commits on Sep 7, 2022
227 changes: 196 additions & 31 deletions paddlenlp/transformers/skep/modeling.py
@@ -25,6 +25,15 @@
else:
from paddlenlp.layers.crf import ViterbiDecoder

from ..model_outputs import (
BaseModelOutputWithPoolingAndCrossAttentions,
SequenceClassifierOutput,
TokenClassifierOutput,
QuestionAnsweringModelOutput,
MultipleChoiceModelOutput,
MaskedLMOutput,
CausalLMOutputWithCrossAttentions,
)
from .. import PretrainedModel, register_base_model

__all__ = [
@@ -284,7 +293,10 @@ def forward(self,
input_ids,
token_type_ids=None,
position_ids=None,
attention_mask=None):
attention_mask=None,
output_hidden_states=False,
output_attentions=False,
return_dict=False):
r"""
The SkepModel forward method, overrides the `__call__()` special method.

@@ -319,9 +331,23 @@
For example, its shape can be [batch_size, sequence_length], [batch_size, sequence_length, sequence_length],
[batch_size, num_attention_heads, sequence_length, sequence_length].
Defaults to `None`, which means no positions are prevented from being attended to.
output_hidden_states (bool, optional):
Whether to return the hidden states of all layers.
Defaults to `False`.
output_attentions (bool, optional):
Whether to return the attentions tensors of all attention layers.
Defaults to `False`.
return_dict (bool, optional):
Whether to return a :class:`~paddlenlp.transformers.model_outputs.ModelOutput` object. If `False`, the output
will be a tuple of tensors. Defaults to `False`.

Returns:
    An instance of :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPoolingAndCrossAttentions` if
    `return_dict=True`. Otherwise it returns a tuple of tensors corresponding
    to the ordered, not-None (depending on the input arguments) fields of
    :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPoolingAndCrossAttentions`.

    If the result is a tuple: returns tuple (`sequence_output`, `pooled_output`).

With the fields:

@@ -356,10 +382,26 @@
embedding_output = self.embeddings(input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids)
encoder_outputs = self.encoder(
embedding_output,
attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)

if paddle.is_tensor(encoder_outputs):
encoder_outputs = (encoder_outputs, )

sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions)
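
A minimal usage sketch of the new arguments (assuming the `skep_ernie_1.0_large_ch`
weights used by the docstring examples in this file, and the usual PaddleNLP
tokenizer boilerplate):

.. code-block::

    import paddle
    from paddlenlp.transformers import SkepModel, SkepTokenizer

    tokenizer = SkepTokenizer.from_pretrained('skep_ernie_1.0_large_ch')
    model = SkepModel.from_pretrained('skep_ernie_1.0_large_ch')

    inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
    inputs = {k: paddle.to_tensor([v]) for (k, v) in inputs.items()}

    # Legacy path: a (sequence_output, pooled_output) tuple.
    sequence_output, pooled_output = model(**inputs)

    # New path: a ModelOutput with named, optional fields.
    outputs = model(**inputs,
                    output_hidden_states=True,
                    output_attentions=True,
                    return_dict=True)
    print(outputs.last_hidden_state.shape)  # [1, seq_len, hidden_size]
    print(outputs.pooler_output.shape)      # [1, hidden_size]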

def get_input_embeddings(self) -> nn.Embedding:
"""get skep input word embedding
@@ -409,7 +451,11 @@ def forward(self,
input_ids,
token_type_ids=None,
position_ids=None,
attention_mask=None):
attention_mask=None,
labels=None,
output_hidden_states=False,
output_attentions=False,
return_dict=False):
r"""
The SkepForSequenceClassification forward method, overrides the __call__() special method.

@@ -422,10 +468,25 @@
See :class:`SkepModel`.
attention_mask (Tensor, optional):
See :class:`SkepModel`.
labels (Tensor of shape `(batch_size,)`, optional):
    Labels for computing the sequence classification/regression loss.
    Indices should be in `[0, ..., num_classes - 1]`. If `num_classes == 1`,
    a regression loss is computed (mean-square loss); if `num_classes > 1`,
    a classification loss is computed (cross-entropy).
output_hidden_states (bool, optional):
Whether to return the hidden states of all layers.
Defaults to `False`.
output_attentions (bool, optional):
Whether to return the attentions tensors of all attention layers.
Defaults to `False`.
return_dict (bool, optional):
Whether to return a :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput` object. If
`False`, the output will be a tuple of tensors. Defaults to `False`.

Returns:
    An instance of :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput` if `return_dict=True`.
    Otherwise it returns a tuple of tensors (with `logits` of shape `[batch_size, num_classes]`
    and dtype `float32`) corresponding to the ordered, not-None (depending on the
    input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput`.

Example:
.. code-block::
@@ -441,14 +502,46 @@ def forward(self,
logits = model(**inputs)

"""
outputs = self.skep(input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)

pooled_output = outputs[1]

pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)

loss = None
if labels is not None:
if self.num_classes == 1:
loss_fct = paddle.nn.MSELoss()
loss = loss_fct(logits, labels)
elif labels.dtype == paddle.int64 or labels.dtype == paddle.int32:
loss_fct = paddle.nn.CrossEntropyLoss()
loss = loss_fct(logits.reshape((-1, self.num_classes)),
labels.reshape((-1, )))
else:
loss_fct = paddle.nn.BCEWithLogitsLoss()
loss = loss_fct(logits, labels)

if not return_dict:
output = (logits, ) + outputs[2:]
if loss is not None:
return (loss, ) + output
if len(output) == 1:
return output[0]
return output

return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
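
The `labels` handling above follows the usual three-way convention: mean-square
loss when `num_classes == 1`, cross-entropy for integer labels, and
BCE-with-logits otherwise. A minimal sketch of the integer-label path (assuming
the same `skep_ernie_1.0_large_ch` weights as the docstring example):

.. code-block::

    import paddle
    from paddlenlp.transformers import SkepForSequenceClassification, SkepTokenizer

    tokenizer = SkepTokenizer.from_pretrained('skep_ernie_1.0_large_ch')
    model = SkepForSequenceClassification.from_pretrained('skep_ernie_1.0_large_ch')

    inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
    inputs = {k: paddle.to_tensor([v]) for (k, v) in inputs.items()}

    # int64 labels select the cross-entropy branch.
    labels = paddle.to_tensor([1], dtype='int64')
    output = model(**inputs, labels=labels, return_dict=True)
    print(output.loss)          # scalar cross-entropy loss
    print(output.logits.shape)  # [1, num_classes]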


class SkepForTokenClassification(SkepPretrainedModel):
@@ -482,7 +575,11 @@ def forward(self,
input_ids,
token_type_ids=None,
position_ids=None,
attention_mask=None):
attention_mask=None,
labels=None,
output_hidden_states=False,
output_attentions=False,
return_dict=False):
r"""
The SkepForTokenClassification forward method, overrides the __call__() special method.

@@ -495,10 +592,22 @@
See :class:`SkepModel`.
attention_mask (Tensor, optional):
See :class:`SkepModel`.
labels (Tensor of shape `(batch_size, sequence_length)`, optional):
Labels for computing the token classification loss. Indices should be in `[0, ..., num_classes - 1]`.
output_hidden_states (bool, optional):
Whether to return the hidden states of all layers.
Defaults to `False`.
output_attentions (bool, optional):
Whether to return the attentions tensors of all attention layers.
Defaults to `False`.
return_dict (bool, optional):
Whether to return a :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput` object. If
`False`, the output will be a tuple of tensors. Defaults to `False`.

Returns:
    An instance of :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput` if `return_dict=True`.
    Otherwise it returns a tuple of tensors (with `logits` of shape `[batch_size, sequence_length, num_classes]`
    and dtype `float32`) corresponding to the ordered, not-None (depending on the
    input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput`.

Example:
.. code-block::
@@ -514,14 +623,39 @@ def forward(self,
logits = model(**inputs)

"""
outputs = self.skep(input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)

sequence_output = outputs[0]

sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)

loss = None
if labels is not None:
loss_fct = paddle.nn.CrossEntropyLoss()
loss = loss_fct(logits.reshape((-1, self.num_classes)),
labels.reshape((-1, )))

if not return_dict:
output = (logits, ) + outputs[2:]
if loss is not None:
return (loss, ) + output
if len(output) == 1:
return output[0]
return output

return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
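
As with the sequence-level head, passing per-token `labels` adds a loss to the
output. A minimal sketch (same assumed weights; the all-zero labels are
placeholder class indices):

.. code-block::

    import paddle
    from paddlenlp.transformers import SkepForTokenClassification, SkepTokenizer

    tokenizer = SkepTokenizer.from_pretrained('skep_ernie_1.0_large_ch')
    model = SkepForTokenClassification.from_pretrained('skep_ernie_1.0_large_ch')

    inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
    inputs = {k: paddle.to_tensor([v]) for (k, v) in inputs.items()}

    seq_len = inputs['input_ids'].shape[1]
    labels = paddle.zeros([1, seq_len], dtype='int64')  # one label per token

    output = model(**inputs, labels=labels, return_dict=True)
    print(output.loss)          # scalar token-level cross-entropy
    print(output.logits.shape)  # [1, seq_len, num_classes]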


class SkepCrfForTokenClassification(SkepPretrainedModel):
@@ -564,7 +698,10 @@ def forward(self,
position_ids=None,
attention_mask=None,
seq_lens=None,
labels=None):
labels=None,
output_hidden_states=False,
output_attentions=False,
return_dict=False):
r"""
The SkepCrfForTokenClassification forward method, overrides the __call__() special method.

@@ -584,9 +721,22 @@
labels (Tensor, optional):
The input label tensor.
Its data type should be int64 and its shape is `[batch_size, sequence_length]`.
output_hidden_states (bool, optional):
Whether to return the hidden states of all layers.
Defaults to `False`.
output_attentions (bool, optional):
Whether to return the attentions tensors of all attention layers.
Defaults to `False`.
return_dict (bool, optional):
Whether to return a :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput` object. If
`False`, the output will be a tuple of tensors. Defaults to `False`.

Returns:
    An instance of :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput` if `return_dict=True`.
    Otherwise it returns a tuple of tensors corresponding to the ordered, not-None
    (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput`.

    If `return_dict` is `False`, returns tensor `loss` if `labels` is not None; otherwise, returns tensor `prediction`.

- `loss` (Tensor):
The crf loss. Its data type is float32 and its shape is `[batch_size]`.
@@ -596,13 +746,15 @@
Its data type is int64 and its shape is `[batch_size, sequence_length]`.

"""
outputs = self.skep(input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)

bigru_output, _ = self.gru(outputs[0]) #, sequence_length=seq_lens)
emission = self.fc(bigru_output)

if seq_lens is None:
@@ -616,9 +768,22 @@
seq_lens = paddle.ones(shape=[input_ids_shape[0]],
dtype=paddle.int64) * input_ids_shape[1]

loss, prediction = None, None
if labels is not None:
loss = self.crf_loss(emission, seq_lens, labels)
else:
_, prediction = self.viterbi_decoder(emission, seq_lens)

# FIXME(wj-Mcat): the output of this old version model is single tensor when return_dict is False
if not return_dict:
# when loss is None, return prediction
if labels is not None:
return loss
return prediction

return TokenClassifierOutput(
loss=loss,
logits=prediction,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
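
The CRF head keeps its legacy single-tensor behavior when `return_dict=False`
(see the FIXME above): it returns the crf loss alone when `labels` is given,
and the Viterbi prediction alone otherwise. With `return_dict=True`, the
prediction travels in the `logits` field. A minimal sketch (hedged: assumes
`num_classes` can be passed through `from_pretrained`, as with the other SKEP
heads):

.. code-block::

    import paddle
    from paddlenlp.transformers import SkepCrfForTokenClassification, SkepTokenizer

    tokenizer = SkepTokenizer.from_pretrained('skep_ernie_1.0_large_ch')
    model = SkepCrfForTokenClassification.from_pretrained(
        'skep_ernie_1.0_large_ch', num_classes=3)

    inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
    inputs = {k: paddle.to_tensor([v]) for (k, v) in inputs.items()}

    # Legacy: no labels, return_dict=False -> the Viterbi prediction tensor.
    prediction = model(**inputs)

    # New: the prediction is carried in the `logits` field.
    output = model(**inputs, return_dict=True)
    print(output.logits.shape)  # [1, seq_len], int64 label ids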