From a5cf998f8fe56017237b8f257725dfec27d1aa88 Mon Sep 17 00:00:00 2001
From: DMH_coco <294270681@qq.com>
Date: Mon, 21 Feb 2022 18:23:19 +0800
Subject: [PATCH 1/7] Add ProphetNet model

---
 paddlenlp/transformers/prophetnet/__init__.py |    0
 paddlenlp/transformers/prophetnet/modeling.py | 1279 +++++++++++++++++
 .../transformers/prophetnet/tokenizer.py      |  490 +++++++
 3 files changed, 1769 insertions(+)
 create mode 100644 paddlenlp/transformers/prophetnet/__init__.py
 create mode 100644 paddlenlp/transformers/prophetnet/modeling.py
 create mode 100644 paddlenlp/transformers/prophetnet/tokenizer.py

diff --git a/paddlenlp/transformers/prophetnet/__init__.py b/paddlenlp/transformers/prophetnet/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/paddlenlp/transformers/prophetnet/modeling.py b/paddlenlp/transformers/prophetnet/modeling.py
new file mode 100644
index 000000000000..c34b38cd0f5d
--- /dev/null
+++ b/paddlenlp/transformers/prophetnet/modeling.py
@@ -0,0 +1,1279 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Optional, Tuple
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle import Tensor
+from paddle.nn import Layer
+
+import paddlenlp
+from .. import PretrainedModel
+
+__all__ = [
+    'ProphetNetModel', 'ProphetNetPretrainedModel', 'ProphetNetEncoder', 'ProphetNetDecoder',
+    'ProphetNetForConditionalGeneration'
+]
+
+ACT2FN = {"gelu": F.gelu}
+
+
+def ngram_attention_bias(sequence_length, ngram, dtype):
+    """
+    This function computes the bias for the predict stream.
+    """
+    left_block = paddle.ones((ngram, sequence_length, sequence_length), dtype=dtype) * float("-inf")
+    right_block = left_block.detach().clone()
+    # create bias
+    for stream_idx in range(ngram):
+        right_block[stream_idx] = right_block[stream_idx].fill_diagonal_(0, wrap=False)
+        left_block[stream_idx] = paddle.triu(left_block[stream_idx], diagonal=-stream_idx + 1)
+
+    left_block[:, :, 0] = 0
+    return paddle.concat([left_block, right_block], axis=2)
+
+
+def compute_relative_buckets(num_buckets, max_distance, relative_positions, is_bidirectional=False):
+    """
+    This function computes individual parts of the relative position buckets. For more detail, see paper.
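+
+    Offsets within `num_buckets // 2` of the query fall into exact buckets; larger
+    distances are bucketed logarithmically up to `max_distance` (T5-style bucketing).
+
+    Shape-only sketch (bucket values depend on `num_buckets`/`max_distance`):
+
+        >>> rel_pos = paddle.to_tensor([[0, -1, -2]])  # query-relative offsets
+        >>> buckets = compute_relative_buckets(32, 128, rel_pos)  # same shape as rel_pos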
+ """ + inv_relative_positions = -relative_positions + rel_positions_bucket = 0 + + if is_bidirectional: + num_buckets = num_buckets // 2 + rel_positions_bucket = (rel_positions_bucket + paddle.cast(paddle.less_than(inv_relative_positions, + paddle.zeros_like( + inv_relative_positions)), + dtype=paddle.int32) * num_buckets) + inv_relative_positions = paddle.abs(inv_relative_positions) + else: + inv_relative_positions = paddle.cast(paddle.less_than(paddle.zeros_like(inv_relative_positions), + inv_relative_positions), + dtype=paddle.int32) * inv_relative_positions + + max_exact = num_buckets // 2 + is_small = paddle.less_than(inv_relative_positions, paddle.to_tensor(max_exact).cast(dtype=paddle.int32)) + val_if_large = max_exact + paddle.log(paddle.cast(inv_relative_positions, dtype=paddle.float32) + / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + val_if_large_num_buckets = paddle.ones_like(val_if_large) * (num_buckets - 1) + val_if_large_lt = paddle.cast(paddle.less_than(val_if_large, val_if_large_num_buckets), dtype=paddle.int32) + val_if_large = paddle.cast(val_if_large_lt * val_if_large, dtype=paddle.int32) + \ + (1 - val_if_large_lt) * val_if_large_num_buckets + rel_positions_bucket = rel_positions_bucket + paddle.where(is_small, paddle.cast(inv_relative_positions, + dtype=paddle.int32), + val_if_large) + return rel_positions_bucket + + +def compute_all_stream_relative_buckets(num_buckets, max_distance, position_ids): + """ + This function computes both main and predict relative position buckets. For more detail, see paper. + """ + # main stream + main_stream_relative_positions = paddle.tile(paddle.unsqueeze(position_ids, axis=1), + repeat_times=[1, position_ids.shape[-1], 1]) + main_stream_relative_positions = main_stream_relative_positions - paddle.unsqueeze(position_ids, axis=-1) + + # predicting stream + predicting_stream_relative_positions = paddle.unsqueeze(paddle.concat([position_ids - 1, position_ids], + axis=-1), axis=1) + predicting_stream_relative_positions = paddle.tile(predicting_stream_relative_positions, + repeat_times=[1, position_ids.shape[-1], 1]) + predicting_stream_relative_positions = predicting_stream_relative_positions - paddle.unsqueeze(position_ids, + axis=-1) + + # get both position buckets + main_relative_position_buckets = compute_relative_buckets( + num_buckets, max_distance, main_stream_relative_positions, is_bidirectional=False + ) + predict_relative_position_buckets = compute_relative_buckets( + num_buckets, max_distance, predicting_stream_relative_positions, is_bidirectional=False + ) + return main_relative_position_buckets, predict_relative_position_buckets + + +class ProphetNetPretrainedModel(PretrainedModel): + """ + An abstract class for pretrained Prophetnet models. It provides Prophetnet related + `model_config_file`, `pretrained_init_configuration`, `resource_files_names`, + `pretrained_resource_files_map`, `base_model_prefix` for downloading and + loading pretrained models. 
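+
+    Typical entry point (illustrative; assumes the referenced weights are
+    registered in `pretrained_resource_files_map`):
+
+        >>> model = ProphetNetModel.from_pretrained("prophetnet-large-uncased")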
+ """ + model_config_file = "" + pretrained_init_configuration = { + "prophetnet-large-uncased": { + "activation_dropout": 0.1, + "activation_function": "gelu", + "attention_dropout": 0.1, + "bos_token_id": 102, + "decoder_ffn_dim": 4096, + "decoder_layerdrop": 0.0, + "decoder_max_position_embeddings": 514, + "decoder_start_token_id": 102, + "disable_ngram_loss": False, + "dropout": 0.1, + "encoder_ffn_dim": 4096, + "encoder_layerdrop": 0.0, + "encoder_max_position_embeddings": 513, + "eos_token_id": 102, + "eps": 0.1, + "hidden_size": 1024, + "init_std": 0.02, + "max_position_embeddings": 512, + "ngram": 2, + "num_buckets": 32, + "num_decoder_attention_heads": 16, + "num_decoder_layers": 12, + "num_encoder_attention_heads": 16, + "num_encoder_layers": 12, + "pad_token_id": 0, + "relative_max_distance": 128, + "length_penalty": 2.0, + "no_repeat_ngram_size": 3, + "num_beams": 4, + "max_length": 142, + "vocab_size": 30522 + }, + } + resource_files_names = {"model_state": "model_state.pdparams"} + pretrained_resource_files_map = {} + base_model_prefix = "prophetnet" + + def init_weights(self, layer): + if isinstance(layer, nn.Linear): + layer.weight.set_value(paddle.tensor.normal(mean=0.0, + std=self.init_std if hasattr(self, "init_std") else + self.prophetnet.config["init_std"], + shape=layer.weight.shape)) + if layer.bias is not None: + layer.bias.set_value(paddle.tensor.zeros(layer.bias.shape)) + + def _shift_right(self, input_ids): + decoder_start_token_id = self.prophetnet.decoder_start_token_id + pad_token_id = self.prophetnet.config["pad_token_id"] + + assert (decoder_start_token_id is not None), \ + "self.model.config.decoder_start_token_id has to be defined. " \ + "In ProphetNet it is usually set to the pad_token_id. See ProphetNet docs for more information" + + # shift inputs to the right + shifted_input_ids = paddle.zeros_like(input_ids) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids_mask = paddle.cast(shifted_input_ids == -100, dtype=paddle.int32) + shifted_input_ids = shifted_input_ids_mask * pad_token_id + (1 - shifted_input_ids_mask) * shifted_input_ids + + assert paddle.sum(paddle.cast(shifted_input_ids >= 0, dtype=paddle.int32)).item() == shifted_input_ids.shape[ + -1], "Verify that `shifted_input_ids` has only positive values" + + return shifted_input_ids + + +class ProphetNetPositionalEmbeddings(nn.Embedding): + """ + ProphetNetPositional Embeddings. + """ + + def __init__(self, max_position_embeddings, hidden_size, pad_token_id): + self.max_length = max_position_embeddings + super(ProphetNetPositionalEmbeddings, self).__init__(max_position_embeddings, hidden_size, pad_token_id) + + def forward(self, inputs_shape, attention_mask=None, past_key_values=None, position_ids=None): + assert (position_ids is None) or (self._padding_idx is None), \ + "If position_ids is pre-computed then padding_idx should not be set." 
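+
+        # Two paths below: during incremental decoding (`past_key_values` is
+        # passed) each step uses a single fresh position id derived from the
+        # cached key length; otherwise position ids are accumulated from the
+        # attention mask and clipped to the embedding range.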
+        if position_ids is None:
+            if past_key_values is not None:
+                # position_ids is the same for every token when decoding a single step
+                # Without the int() cast, it doesn't work in some cases when exporting to ONNX
+                prev_num_input_ids = past_key_values[0][0].shape[2]
+                num_input_ids = inputs_shape[1] + prev_num_input_ids
+                position_ids = paddle.ones((1, 1), dtype='int64') * (
+                    int(self._padding_idx + num_input_ids)
+                )
+            else:
+                if attention_mask is None:
+                    attention_mask = paddle.ones(inputs_shape, dtype='int64')
+
+                # retrieve position_ids from input_ids / attention_mask
+                position_ids = paddle.cast(
+                    paddle.cast(paddle.cumsum(attention_mask, axis=1), dtype=attention_mask.dtype) *
+                    attention_mask, dtype=paddle.int64) + self._padding_idx
+
+                # make sure position_ids are not bigger than max_length
+                position_ids = paddle.clip(position_ids, min=0, max=self.max_length - 1)
+
+        return super().forward(position_ids), position_ids
+
+    def _forward(self, position_ids):
+        return super().forward(position_ids)
+
+
+class ProphetNetAttention(Layer):
+    """
+    Multi-headed attention from 'Attention Is All You Need' paper.
+    """
+
+    def __init__(self,
+                 hidden_size,
+                 attention_dropout,
+                 dropout,
+                 num_attn_heads: int):
+        super().__init__()
+
+        self.attention_dropout = attention_dropout
+        self.dropout = dropout
+        self.num_attn_heads = num_attn_heads
+        self.head_dim = hidden_size // num_attn_heads
+
+        assert (self.head_dim * num_attn_heads == hidden_size), \
+            "`config.hidden_size` must be divisible by `config.num_encoder_attention_heads` and `config.num_decoder_attention_heads`"
+
+        self.key_proj = nn.Linear(hidden_size, hidden_size)
+        self.value_proj = nn.Linear(hidden_size, hidden_size)
+        self.query_proj = nn.Linear(hidden_size, hidden_size)
+
+        self.out_proj = nn.Linear(hidden_size, hidden_size)
+
+    def _shape(self, tensor: paddle.Tensor, seq_len: int, bsz: int):
+        return paddle.transpose(paddle.reshape(tensor, [bsz, seq_len, self.num_attn_heads, self.head_dim]),
+                                (0, 2, 1, 3))
+
+    def forward(self,
+                hidden_states,
+                key_value_states: Optional[Tensor] = None,
+                attention_mask: Optional[Tensor] = None,
+                past_key_value: Optional[Tuple[Tensor]] = None) -> Tuple[Tensor, Optional[Tensor]]:
+
+        batch_size, tgt_len, hidden_size = hidden_states.shape
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+        assert hidden_states.shape == [batch_size, tgt_len, hidden_size, ], \
+            f"Size of hidden states should be {batch_size, tgt_len, hidden_size}, but is {hidden_states.shape}"
+
+        # previous time steps are cached - no need to recompute key and value if they are static
+        query_states = self.query_proj(hidden_states) / (self.head_dim ** 0.5)
+
+        if is_cross_attention and past_key_value is not None:
+            # reuse k, v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape(self.key_proj(key_value_states), -1, batch_size)
+            value_states = self._shape(self.value_proj(key_value_states), -1, batch_size)
+        else:
+            # self_attention
+            key_states = self._shape(self.key_proj(hidden_states), -1, batch_size)
+            value_states = self._shape(self.value_proj(hidden_states), -1, batch_size)
+
+        if is_cross_attention:
+            # if cross-attention, save Tuple(Tensor, Tensor) of all cross-attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + # project states into the correct shape + proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim) + query_states = paddle.reshape(self._shape(query_states, tgt_len, batch_size), proj_shape) + key_states = paddle.reshape(key_states, proj_shape) + value_states = paddle.reshape(value_states, proj_shape) + + src_len = key_states.shape[1] + attn_weights = paddle.bmm(query_states, key_states.transpose((0, 2, 1))) + assert attn_weights.shape == [batch_size * self.num_attn_heads, tgt_len, src_len, ], \ + f"`attn_weights` should be of size {batch_size * self.num_attn_heads, tgt_len, src_len}, but is of size {attn_weights.shape}" + + # This is part of a workaround to get around fork/join parallelism not supporting Optional types. + if attention_mask is not None and len(attention_mask.shape) == 0: + attention_mask = None + assert attention_mask is None or attention_mask.shape == [self.num_attn_heads * batch_size, 1, src_len, ], \ + f"`attention_mask` should be `None` or of shape attention_mask.shape == {batch_size * self.num_attn_heads, 1, src_len}, but is {attention_mask.shape}" + + if attention_mask is not None: # don't attend to padding symbols + attn_weights = attn_weights + attention_mask + + attn_weights = F.softmax(attn_weights, axis=-1) + + attn_probs = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) + + attn_output = paddle.bmm(attn_probs, value_states) + assert attn_output.shape == [batch_size * self.num_attn_heads, tgt_len, self.head_dim, ], \ + f"`attn_output` should be of shape {batch_size * self.num_attn_heads, tgt_len, self.head_dim}, but is of shape {attn_output.shape}" + + attn_output = (paddle.reshape(paddle.transpose(paddle.reshape(attn_output, + (batch_size, self.num_attn_heads, tgt_len, + self.head_dim)), + (0, 2, 1, 3)), + (batch_size, tgt_len, hidden_size))) + + attn_output = self.out_proj(attn_output) + + attn_output = F.dropout(attn_output, p=self.dropout, training=self.training) + return attn_output, past_key_value + + +class ProphetNetFeedForward(Layer): + """ + This is the residual two feed-forward layer block based on the original Transformer implementation. 
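+
+    Computation sketch (dropout omitted): `output(activation_fn(intermediate(x)))`,
+    mapping `hidden_size -> ffn_dim -> hidden_size`, e.g.:
+
+        >>> ffn = ProphetNetFeedForward(1024, "gelu", 0.1, 0.1, ffn_dim=4096)
+        >>> ffn(paddle.randn([2, 8, 1024])).shape  # -> [2, 8, 1024]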
+ """ + + def __init__(self, + hidden_size, + activation_function, + activation_dropout, + dropout, + ffn_dim: int): + super(ProphetNetFeedForward, self).__init__() + self.activation_fn = ACT2FN[activation_function] + self.intermediate = nn.Linear(hidden_size, ffn_dim) + self.output = nn.Linear(ffn_dim, hidden_size) + self.activation_dropout = activation_dropout + self.dropout = dropout + + def forward(self, hidden_states): + hidden_states = self.intermediate(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.output(hidden_states) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + return hidden_states + + +class ProphetNetNgramSelfAttention(Layer): + def __init__(self, + hidden_size, + num_buckets, + relative_max_distance, + num_decoder_attention_heads, + dropout, + attention_dropout, + ngram): + super(ProphetNetNgramSelfAttention, self).__init__() + + self.hidden_size = hidden_size + + self.num_buckets = num_buckets + self.relative_max_distance = relative_max_distance + self.num_attn_heads = num_decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.head_dim = hidden_size // self.num_attn_heads + self.ngram = ngram + + assert (self.head_dim * self.num_attn_heads == hidden_size), \ + "config.hidden_size must be divisible by num_attn_heads" + # key, value, query projection + self.key_proj = nn.Linear(hidden_size, hidden_size) + self.value_proj = nn.Linear(hidden_size, hidden_size) + self.query_proj = nn.Linear(hidden_size, hidden_size) + + # out projection + self.out_proj = nn.Linear(hidden_size, hidden_size) + + # rel position embeddings + self.relative_pos_embeddings = nn.Linear(hidden_size, self.num_buckets * self.num_attn_heads) + + def _shape(self, tensor, seq_len, batch_size): + return paddle.transpose(paddle.reshape(tensor, + (batch_size, seq_len, self.num_attn_heads, self.head_dim)), + (0, 2, 1, 3)) + + def forward(self, + hidden_states, + past_key_value: Optional[Tuple[Tensor]] = None, + attention_mask=None, + extended_predict_attention_mask=None, + main_relative_position_buckets=None, + predict_relative_position_buckets=None, + position_ids=None): + batch_size, ngram_sequence_length, hidden_size = hidden_states.shape + + assert hidden_states.shape == [batch_size, ngram_sequence_length, hidden_size, ], \ + f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape {hidden_states.shape}" + + # project + query_states = self.query_proj(hidden_states) + key_states = self.key_proj(hidden_states) + value_states = self.value_proj(hidden_states) + + # normalize + query_states = query_states / (self.head_dim ** 0.5) + + # reshape + query_states = self._shape(query_states, ngram_sequence_length, batch_size) + key_states = self._shape(key_states, -1, batch_size) + value_states = self._shape(value_states, -1, batch_size) + + proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim) + + query_states = paddle.reshape(query_states, proj_shape) + key_states = paddle.reshape(key_states, proj_shape) + value_states = paddle.reshape(value_states, proj_shape) + + # chunk into main stream and predict stream + hidden_states_list = paddle.chunk(hidden_states, 1 + self.ngram, axis=1) + + query_states_list = paddle.chunk(query_states, 1 + self.ngram, axis=1) + key_states_list = paddle.chunk(key_states, 1 + self.ngram, axis=1) + 
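+        # The packed sequence axis holds the main stream followed by `ngram`
+        # predict streams; each chunk call recovers one list entry per stream.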
value_states_list = paddle.chunk(value_states, 1 + self.ngram, axis=1) + + main_hidden_states, hidden_states_predict_list = hidden_states_list[0], hidden_states_list[1:] + main_query_states, predict_query_states_list = query_states_list[0], query_states_list[1:] + main_key_states, predict_key_states_list = key_states_list[0], key_states_list[1:] + main_value_states, predict_value_states_list = value_states_list[0], value_states_list[1:] + + # saved states are stored with shape (batch_size, num_attn_heads, seq_len, head_dim) + if past_key_value is not None: + prev_main_key_states = past_key_value[0].reshape([batch_size * self.num_attn_heads, -1, self.head_dim]) + main_key_states = paddle.concat((prev_main_key_states, main_key_states), axis=1) + prev_main_value_states = past_key_value[1].reshape([batch_size * self.num_attn_heads, -1, self.head_dim]) + main_value_states = paddle.concat((prev_main_value_states, main_value_states), axis=1) + + # Update cache + past_key_value = (paddle.reshape(main_key_states, (batch_size, self.num_attn_heads, -1, self.head_dim)), + paddle.reshape(main_value_states, (batch_size, self.num_attn_heads, -1, self.head_dim)),) + + # get seq_length of main stream only + sequence_length = ngram_sequence_length // (1 + self.ngram) + + # MAIN-STREAM + # main attn weights + main_attn_weights = paddle.bmm(main_query_states, paddle.transpose(main_key_states, (0, 2, 1))) + + # retrieve relative position embeddings for each layer -> see paper for more details + main_relative_pos_embeddings = self.get_main_relative_pos_embeddings(main_hidden_states, + main_attn_weights, + position_ids, + main_relative_position_buckets) + + main_attn_weights = main_attn_weights + main_relative_pos_embeddings + + if attention_mask is not None: + main_attn_weights = main_attn_weights + attention_mask + + main_attn_probs = F.softmax(main_attn_weights, axis=-1, dtype=main_attn_weights.dtype) + + main_attn_probs = F.dropout(main_attn_probs, p=self.attention_dropout, training=self.training) + # project to attn_output + main_attn_output = paddle.bmm(main_attn_probs, main_value_states) + + # reshape so that num_heads dim is merged into last `head_dim` axis + main_attn_output = (paddle.reshape(paddle.transpose( + paddle.reshape(main_attn_output, (batch_size, self.num_attn_heads, sequence_length, self.head_dim)) + , (0, 2, 1, 3)), (batch_size, 1, sequence_length, hidden_size))) + main_attn_output = self.out_proj(main_attn_output) + + # PREDICT-STREAM + # [ngram, B*head, T, c] + predict_query_states = paddle.reshape(paddle.concat(predict_query_states_list, axis=0), + (self.ngram, -1, sequence_length, self.head_dim)) + # [ngram, B*head, 2*T, c] + predict_key_states = paddle.concat([paddle.unsqueeze(paddle.concat([main_key_states, key], axis=1), axis=0) + for key in predict_key_states_list], axis=0) + + # [ngram, T, B, C] + predict_hidden_states = paddle.reshape(paddle.concat(hidden_states_predict_list, axis=0), + (self.ngram, sequence_length, batch_size, hidden_size)) + + # [ngram, B*head, 2*T, c] + predict_value_states = paddle.concat([paddle.unsqueeze(paddle.concat([main_value_states, v_p], axis=1), axis=0) + for v_p in predict_value_states_list], axis=0) + + # [ngram, B*head, T, 2*T] + predict_attn_weights = paddlenlp.ops.einsum("nbtc,nbsc->nbts", (predict_query_states, predict_key_states)) + + # [ngram, B*head, T, S] + # retrieve relative position embeddings for each layer -> see paper for more details + predict_relative_pos_embeddings = self.get_predict_relative_pos_embeddings(predict_hidden_states, + 
predict_attn_weights, position_ids, + predict_relative_position_buckets) + + # [ngram, B*head, T, 2*T] + predict_attn_weights = predict_attn_weights + predict_relative_pos_embeddings + + if extended_predict_attention_mask is not None: + predict_attn_weights = predict_attn_weights + paddle.cast(extended_predict_attention_mask, + predict_attn_weights.dtype) + + predict_attn_probs = F.softmax(predict_attn_weights, axis=-1, dtype=predict_attn_weights.dtype) + + predict_attn_probs = F.dropout(predict_attn_probs, p=self.attention_dropout, training=self.training) + # project to attention output + # [ngram, B*head, T, c] + predict_attn_output = paddlenlp.ops.einsum("nbts,nbsc->nbtc", (predict_attn_probs, predict_value_states)) + + # reshape so that num_heads dim is merged into last `head_dim` axis + # [ngram, B, T, C] + predict_attn_output = (paddle.reshape(paddle.transpose(paddle.reshape(predict_attn_output, (self.ngram, + batch_size, + self.num_attn_heads, + sequence_length, + self.head_dim)), + (1, 0, 3, 2, 4)), + (batch_size, self.ngram, sequence_length, hidden_size))) + predict_attn_output = self.out_proj(predict_attn_output) + + # concat to single attn output + # [B, 1+ngram*T, C] + attn_output = paddle.reshape(paddle.concat([main_attn_output, predict_attn_output], axis=1), + (batch_size, -1, hidden_size)) + # reshape into better form for `config.output_attentions` + main_attn_probs = paddle.reshape(main_attn_probs, (batch_size, self.num_attn_heads, sequence_length, -1)) + predict_attn_probs = paddle.transpose(paddle.reshape(predict_attn_probs, (self.ngram, + batch_size, + self.num_attn_heads, + sequence_length, -1)), + (1, 0, 2, 3, 4)) + + attn_output = F.dropout(attn_output, p=self.dropout, training=self.training) + + return attn_output, main_attn_probs, predict_attn_probs, past_key_value + + def get_main_relative_pos_embeddings(self, hidden_states, attn_weights, position_ids, + main_relative_position_buckets): + # input hidden_states [B,T,C], input attn_weights [T*head,T,S], input position_ids [B,T] or [1,1] + + if main_relative_position_buckets is None: + batch_size, sequence_length = hidden_states.shape[:2] + relative_positions = (paddle.tile(paddle.unsqueeze(paddle.unsqueeze(paddle.arange(1, + attn_weights. 
+ shape[-1] + 1), + axis=0), + axis=0), + repeat_times=[batch_size, sequence_length, 1])) + relative_positions = relative_positions - paddle.tile(paddle.unsqueeze(position_ids, axis=0), + repeat_times=[batch_size, sequence_length, + 1]) # [B, T, s] + main_relative_position_buckets = compute_relative_buckets( + self.num_buckets, self.relative_max_distance, relative_positions, False + ) + + rel_pos_embeddings = self.relative_pos_embeddings(hidden_states) # [B,T,Buckets*head] + rel_pos_embeddings = paddle.transpose(paddle.reshape(rel_pos_embeddings, + (rel_pos_embeddings.shape[:2] + + [self.num_buckets, self.num_attn_heads])), + (0, 3, 1, 2)) # [B,T,Buckets,head] + rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:2] + [-1]) # [B*head,T,Buckets] + + main_relative_position_buckets = paddle.cast(paddle.reshape(paddle.tile(main_relative_position_buckets, + repeat_times=[1, self.num_attn_heads, + 1]), + (-1, main_relative_position_buckets.shape[-1])), + dtype=paddle.int64) # [B*head*T, T] + rel_pos_embeddings = paddle.reshape(rel_pos_embeddings, + (-1, rel_pos_embeddings.shape[-1])) # [B*head*T,Buckets] + + main_relative_position_buckets_index = paddle.tile(main_relative_position_buckets.unsqueeze(2), + repeat_times=[1, 1, 2]) + main_relative_position_buckets_index[:, :, 0] = \ + paddle.tile(paddle.arange(0, main_relative_position_buckets_index.shape[0]).unsqueeze(1), + repeat_times=[1, main_relative_position_buckets_index.shape[1]]) + + main_relative_pos_embeddings = paddle.reshape(paddle.gather_nd(rel_pos_embeddings, + index=main_relative_position_buckets_index), + (attn_weights.shape[:2] + [-1])) + return main_relative_pos_embeddings + + def get_predict_relative_pos_embeddings(self, hidden_states, attn_weights, position_ids, + predict_relative_position_buckets): + # input hidden_states [ngram, T,B,C], + # input attn_weights [ngram, B*head,T,S], + # input position_ids [B,T] or [1,1], + # input predict_relative_position_buckets [B,T, 2*T] or None + sequence_length, batch_size = hidden_states.shape[1:3] + + if predict_relative_position_buckets is None: + key_sequence_length = attn_weights.shape[-1] + assert position_ids[0][0] == key_sequence_length - 1, \ + "`position_ids` are incorrect. They should be of the format 1 2 3 4 5 ... 
(key_sequence_length - 1)" + relative_positions = (paddle.tile(paddle.unsqueeze(paddle.unsqueeze(paddle.arange(0, key_sequence_length), + axis=0), + axis=0), + repeat_times=[batch_size, sequence_length, 1])) + + relative_positions = relative_positions - paddle.tile(paddle.unsqueeze(position_ids, axis=0), + repeat_times=[batch_size, sequence_length, 1]) + predict_relative_position_buckets = compute_relative_buckets(self.num_buckets, + self.relative_max_distance, + relative_positions, False) + + hidden_states = paddle.transpose(hidden_states, (0, 2, 1, 3)) # [ngram, B, T, C] + rel_pos_embeddings = paddle.reshape(self.relative_pos_embeddings(hidden_states), + hidden_states.shape[:-1] + + [self.num_buckets, + self.num_attn_heads]) # [ngram, B, T, bucket, head] + rel_pos_embeddings = paddle.reshape(paddle.transpose(rel_pos_embeddings, + (0, 1, 4, 2, 3)), + (self.ngram * batch_size * self.num_attn_heads, + sequence_length, -1)) # [ngram*B*head, T, bucket] + + predict_relative_position_buckets = paddle.tile(paddle.unsqueeze(predict_relative_position_buckets, + axis=0), + repeat_times=[self.ngram, 1, self.num_attn_heads, + 1]) # [ngram, B, head*T, S] + + rel_pos_embeddings = paddle.reshape(rel_pos_embeddings, (-1, rel_pos_embeddings.shape[-1])) + predict_relative_position_buckets = paddle.cast(paddle.reshape(predict_relative_position_buckets, + (-1, + predict_relative_position_buckets.shape[-1])), + dtype=paddle.int64) # [ngram*B*head*T, S] + + predict_relative_position_buckets_index = paddle.tile(predict_relative_position_buckets.unsqueeze(2), + repeat_times=[1, 1, 2]) + predict_relative_position_buckets_index[:, :, 0] = \ + paddle.tile(paddle.arange(0, predict_relative_position_buckets_index.shape[0]).unsqueeze(1), + repeat_times=[1, predict_relative_position_buckets_index.shape[1]]) + + predict_relative_pos_embeddings = paddle.reshape(paddle.gather_nd(rel_pos_embeddings, + index=predict_relative_position_buckets_index), + (self.ngram, batch_size * self.num_attn_heads, + sequence_length, -1)) # [ngram, B*head, T, S] + + return predict_relative_pos_embeddings + + +class ProphetNetEncoderLayer(Layer): + """ + Encoder block for Prophetnet + """ + + def __init__(self, + hidden_size, + encoder_ffn_dim, + activation_function, + activation_dropout, + attention_dropout, + dropout, + num_encoder_attention_heads): + super(ProphetNetEncoderLayer, self).__init__() + # 1st residual block + self.self_attn = ProphetNetAttention(hidden_size, attention_dropout, dropout, num_encoder_attention_heads) + self.self_attn_layer_norm = nn.LayerNorm(hidden_size) + + # 2nd residual block + self.feed_forward = ProphetNetFeedForward(hidden_size, activation_function, activation_dropout, dropout, + encoder_ffn_dim) + self.feed_forward_layer_norm = nn.LayerNorm(hidden_size) + + def forward(self, + hidden_states, + attention_mask): + # 1st residual block + attention_output, _ = self.self_attn(hidden_states=hidden_states, + attention_mask=attention_mask) + hidden_states = self.self_attn_layer_norm(attention_output + hidden_states) + + # 2nd residual block + feed_forward_output = self.feed_forward(hidden_states) + hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states) + return hidden_states + + +class ProphetNetDecoderLayer(Layer): + """ + Decoder block for Prophetnet + """ + + def __init__(self, + hidden_size, + num_buckets, + relative_max_distance, + num_decoder_attention_heads, + activation_function, + activation_dropout, + dropout, + attention_dropout, + ngram, + decoder_ffn_dim, + 
add_cross_attention): + super(ProphetNetDecoderLayer, self).__init__() + # 1st residual block + self.self_attn = ProphetNetNgramSelfAttention(hidden_size, num_buckets, relative_max_distance, + num_decoder_attention_heads, dropout, attention_dropout, ngram) + self.self_attn_layer_norm = nn.LayerNorm(hidden_size) + + # 2nd residual block + if add_cross_attention: + self.cross_attn = ProphetNetAttention(hidden_size, attention_dropout, dropout, num_decoder_attention_heads) + self.cross_attn_layer_norm = nn.LayerNorm(hidden_size) + + # 3rd residual block + self.feed_forward = ProphetNetFeedForward(hidden_size, activation_function, activation_dropout, dropout, + decoder_ffn_dim) + self.feed_forward_layer_norm = nn.LayerNorm(hidden_size) + + def forward(self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attn_mask=None, + extended_predict_attention_mask=None, + main_relative_position_buckets=None, + predict_relative_position_buckets=None, + position_ids=None, + past_key_value=None, + use_cache: bool = True): + # 1st residual block + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + ngram_attention_output, self_attn_weights, self_attn_weights_ngram, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + extended_predict_attention_mask=extended_predict_attention_mask, + main_relative_position_buckets=main_relative_position_buckets, + predict_relative_position_buckets=predict_relative_position_buckets, + position_ids=position_ids, + ) + hidden_states = self.self_attn_layer_norm(hidden_states + ngram_attention_output) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + if encoder_hidden_states is not None: + # 2nd residual block + attention_output, cross_attn_present_key_value = self.cross_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attn_mask, + past_key_value=cross_attn_past_key_value, + ) + hidden_states = self.cross_attn_layer_norm(attention_output + hidden_states) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # 3rd residual block + feed_forward_output = self.feed_forward(hidden_states) + hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states) + + outputs = (hidden_states,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class ProphetNetEncoder(ProphetNetPretrainedModel): + r""" + word_embeddings (:obj:`torch.nn.Embeddings` of shape :obj:`(config.vocab_size, config.hidden_size)`, `optional`): + The word embedding parameters. This can be used to initialize :class:`~transformers.ProphetNetEncoder` with + pre-defined word embeddings instead of randomly initialized word embeddings. 
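+
+    Standalone usage sketch (illustrative; assumes registered pretrained weights):
+
+        >>> encoder = ProphetNetModel.from_pretrained("prophetnet-large-uncased").get_encoder()
+        >>> hidden_states = encoder(input_ids=paddle.to_tensor([[1, 2, 3]]))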
+ """ + + def __init__(self, + word_embeddings, + vocab_size, + hidden_size, + pad_token_id, + max_position_embeddings, + encoder_ffn_dim, + activation_function, + activation_dropout, + attention_dropout, + dropout, + num_encoder_attention_heads, + num_encoder_layers, + init_std): + super(ProphetNetEncoder, self).__init__() + self.init_std = init_std + if word_embeddings is not None: + self.word_embeddings = word_embeddings + else: + self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id) + + self.position_embeddings = ProphetNetPositionalEmbeddings(max_position_embeddings, hidden_size, pad_token_id) + self.embeddings_layer_norm = nn.LayerNorm(hidden_size) + + self.layers = nn.LayerList([ProphetNetEncoderLayer(hidden_size, + encoder_ffn_dim, + activation_function, + activation_dropout, + attention_dropout, + dropout, + num_encoder_attention_heads) for _ in + range(num_encoder_layers)]) + + self.apply(self.init_weights) + + def forward(self, + input_ids=None, + attention_mask=None): + if input_ids is None: + raise ValueError("Input_ids cannot be None.") + inputs_embeds = self.word_embeddings(input_ids) + + # prepare attention mask + if attention_mask is not None: + extended_attention_mask = (paddle.tile(1.0 - attention_mask.unsqueeze(1), + repeat_times=[self.config["num_encoder_attention_heads"], + 1, 1])) * -10000.0 + extended_attention_mask = paddle.cast(extended_attention_mask, dtype=inputs_embeds.dtype) + extended_attention_mask.stop_gradient = True + else: + extended_attention_mask = None + + position_embeddings, position_ids = self.position_embeddings(inputs_embeds.shape[:2]) + + hidden_states = inputs_embeds + position_embeddings + hidden_states = self.embeddings_layer_norm(hidden_states) + hidden_states = F.dropout(hidden_states, p=self.config["dropout"], training=self.training) + + for idx, encoder_layer in enumerate(self.layers): + hidden_states = encoder_layer(hidden_states, attention_mask=extended_attention_mask) + return hidden_states + + +class ProphetNetDecoder(ProphetNetPretrainedModel): + def __init__(self, + word_embeddings, + vocab_size, + hidden_size, + pad_token_id, + max_position_embeddings, + relative_max_distance, + ngram, + num_buckets, + num_decoder_attention_heads, + decoder_ffn_dim, + activation_function, + activation_dropout, + dropout, + attention_dropout, + add_cross_attention, + num_decoder_layers, + init_std): + super(ProphetNetDecoder, self).__init__() + self.init_std = init_std + self.ngram = ngram + self.num_buckets = num_buckets + self.relative_max_distance = relative_max_distance + self.dropout = dropout + self.max_target_positions = max_position_embeddings + self.add_cross_attention = add_cross_attention + if word_embeddings is not None: + self.word_embeddings = word_embeddings + else: + self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id) + + self.position_embeddings = ProphetNetPositionalEmbeddings(max_position_embeddings, hidden_size, pad_token_id) + + self.ngram_embeddings = nn.Embedding(self.ngram, hidden_size) + self.layers = nn.LayerList([ProphetNetDecoderLayer(hidden_size, + num_buckets, + relative_max_distance, + num_decoder_attention_heads, + activation_function, + activation_dropout, + dropout, + attention_dropout, + ngram, + decoder_ffn_dim, + add_cross_attention) for _ in + range(num_decoder_layers)]) + self.embeddings_layer_norm = nn.LayerNorm(hidden_size) + + self.apply(self.init_weights) + + def forward(self, + input_ids=None, + attention_mask=None, + 
encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=True): + if input_ids is None: + raise ValueError("Decoder input_ids cannot be None.") + inputs_embeds = self.word_embeddings(input_ids) + batch_size, sequence_length = inputs_embeds.shape[:2] + + main_stream_pos_embed, position_ids = self.position_embeddings((batch_size, sequence_length), + past_key_values=past_key_values) + + if past_key_values is not None: + main_relative_position_buckets, predict_relative_position_buckets = None, None + else: + main_relative_position_buckets, predict_relative_position_buckets = self.compute_buffered_relative_buckets( + position_ids) + predicting_stream_pos_embed = self.position_embeddings._forward(position_ids + 1) + + # add position embeddings + hidden_states = inputs_embeds + main_stream_pos_embed + + ngram_embeddings = self.ngram_embeddings.weight + + # prepare attention mask + if past_key_values is not None: + assert hidden_states.shape[1] == 1, \ + "At the moment `use_cache` is only supported for `decoder_input_ids` of length 1" + + ngram_hidden_states = [paddle.tile((ngram_embeddings[ngram - 1] + predicting_stream_pos_embed), + repeat_times=[batch_size, 1, 1]) + for ngram in range(self.ngram)] + extended_attention_mask = None + extended_predict_attention_mask = None + else: + ngram_hidden_states = [(ngram_embeddings[ngram - 1] + predicting_stream_pos_embed) + for ngram in range(self.ngram)] + extended_attention_mask = self.prepare_attention_mask(hidden_states, attention_mask) + extended_predict_attention_mask = self.prepare_predict_attention_mask(hidden_states, attention_mask) + extended_attention_mask.stop_gradient = True + extended_predict_attention_mask.stop_gradient = True + + # prepare encoder attention mask + if encoder_attention_mask is not None: + extended_encoder_attention_mask = (1.0 - paddle.tile(encoder_attention_mask[:, None, :], + repeat_times=[ + self.config["num_decoder_attention_heads"], + 1, 1])) * -10000.0 + extended_encoder_attention_mask = paddle.cast(extended_encoder_attention_mask, dtype=inputs_embeds.dtype) + else: + extended_encoder_attention_mask = None + + hidden_states = paddle.concat([hidden_states] + ngram_hidden_states, axis=1) + + if self.embeddings_layer_norm: + hidden_states = self.embeddings_layer_norm(hidden_states) + + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + + present_key_values = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + layer_outputs = decoder_layer(hidden_states, + attention_mask=extended_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attn_mask=extended_encoder_attention_mask, + extended_predict_attention_mask=extended_predict_attention_mask, + main_relative_position_buckets=main_relative_position_buckets, + predict_relative_position_buckets=predict_relative_position_buckets, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache) + + hidden_states = layer_outputs[0] + + if use_cache: + present_key_values += (layer_outputs[1],) + + last_hidden_state = hidden_states[:, :sequence_length] # 1-gram + last_hidden_state_ngram = hidden_states[:, sequence_length:] if self.ngram > 0 else None # 2-gram + return tuple(v for v in [last_hidden_state, + last_hidden_state_ngram, + present_key_values] if v is not None) + + def compute_buffered_relative_buckets(self, position_ids): + batch_size, sequence_length 
= position_ids.shape + + if not hasattr(self, '_main_relative_buckets') or self._main_relative_buckets is None: + position_ids = paddle.tile(paddle.arange(1, self.max_target_positions + 1), repeat_times=[1, 1]) + self._main_relative_buckets, self._predict_relative_buckets = compute_all_stream_relative_buckets( + self.num_buckets, self.relative_max_distance, position_ids) + + # buffer relative buckets + main_relative_buckets = paddle.tile(self._main_relative_buckets[:, :sequence_length, :sequence_length], + repeat_times=[batch_size, 1, 1]) + predict_relative_buckets = paddle.tile(paddle.concat( + [self._predict_relative_buckets[:, :sequence_length, :sequence_length], + self._predict_relative_buckets[:, :sequence_length, + self.max_target_positions: self.max_target_positions + sequence_length]], axis=2), + repeat_times=[batch_size, 1, 1]) + + return main_relative_buckets, predict_relative_buckets + + def prepare_attention_mask(self, hidden_states, attention_mask): + batch_size, seq_length = hidden_states.shape[:2] + + # get causal mask + if not hasattr(self, '_causal_mask') or self._causal_mask is None: + causal_mask = paddle.full((self.max_target_positions, self.max_target_positions), -float("inf"), + dtype=hidden_states.dtype) + self._causal_mask = paddle.triu(causal_mask, 1) + extended_causal_mask = paddle.expand(self._causal_mask[:seq_length, :seq_length].unsqueeze(0), + shape=[batch_size, seq_length, seq_length]) + + # add usual attention mask + if attention_mask is not None: + extended_attention_mask = (1.0 - attention_mask.unsqueeze(1)) * -10000.0 + extended_attention_mask = extended_causal_mask + extended_attention_mask + else: + extended_attention_mask = extended_causal_mask + return paddle.cast(paddle.tile(extended_attention_mask, + repeat_times=[self.config["num_decoder_attention_heads"], 1, 1]), + dtype=hidden_states.dtype) + + def prepare_predict_attention_mask(self, hidden_states, attention_mask): + batch_size, seq_length = hidden_states.shape[:2] + + # get causal mask + if not hasattr(self, '_predict_causal_mask') or self._predict_causal_mask is None: + self._predict_causal_mask = ngram_attention_bias(self.max_target_positions, self.ngram, hidden_states.dtype) + predict_causal_mask = paddle.concat([self._predict_causal_mask[:, :seq_length, :seq_length], + self._predict_causal_mask[:, :seq_length, + self.max_target_positions: self.max_target_positions + seq_length]], + axis=-1) + extended_predict_causal_mask = paddle.expand(predict_causal_mask[:, None, :, :], + shape=predict_causal_mask.shape[:1] + [ + batch_size] + predict_causal_mask.shape[1:]) + + # add usual attention mask + if attention_mask is not None: + extended_attention_mask = (1.0 - attention_mask[None, :, None, :]) * -10000.0 + extended_attention_mask = extended_attention_mask.expand((self.ngram, batch_size, seq_length, seq_length)) + # predicted stream attention_mask should always be 0 + extended_attention_mask = paddle.concat([extended_attention_mask, + paddle.zeros_like(extended_attention_mask)], + axis=-1) + extended_predict_attention_mask = extended_predict_causal_mask + extended_attention_mask + else: + extended_predict_attention_mask = extended_predict_causal_mask + return paddle.cast(extended_predict_attention_mask.tile([1, self.config["num_decoder_attention_heads"], 1, 1]), + dtype=hidden_states.dtype) + + +class ProphetNetModel(ProphetNetPretrainedModel): + def __init__(self, + vocab_size, + bos_token_id=102, + pad_token_id=0, + eos_token_id=102, + hidden_size=1024, + decoder_start_token_id=102, + 
max_position_embeddings=512, + activation_function="gelu", + activation_dropout=0.1, + dropout=0.1, + relative_max_distance=128, + ngram=2, + num_buckets=32, + encoder_ffn_dim=4096, + num_encoder_attention_heads=16, + num_encoder_layers=12, + decoder_ffn_dim=4096, + num_decoder_attention_heads=16, + num_decoder_layers=12, + attention_dropout=0.1, + init_std=0.02, + eps=0.1, + add_cross_attention=True, + disable_ngram_loss=False, + **kwargs): + super(ProphetNetModel, self).__init__() + self.init_std = init_std + self.eps = eps + self.pad_token_id = pad_token_id + self.disable_ngram_loss = disable_ngram_loss + self.decoder_start_token_id = decoder_start_token_id + self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id) + + self.encoder = ProphetNetEncoder(self.word_embeddings, + vocab_size, + hidden_size, + pad_token_id, + max_position_embeddings, + encoder_ffn_dim, + activation_function, + activation_dropout, + attention_dropout, + dropout, + num_encoder_attention_heads, + num_encoder_layers, + init_std) + + self.decoder = ProphetNetDecoder(self.word_embeddings, + vocab_size, + hidden_size, + pad_token_id, + max_position_embeddings, + relative_max_distance, + ngram, + num_buckets, + num_decoder_attention_heads, + decoder_ffn_dim, + activation_function, + activation_dropout, + dropout, + attention_dropout, + add_cross_attention, + num_decoder_layers, + init_std) + + self.apply(self.init_weights) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def forward(self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_output: Optional[Tuple] = None, + use_cache=True, + past_key_values=None): + if attention_mask is None: + assert input_ids is not None, "input_ids should be " \ + "specified when generating attention_mask" + attention_mask = paddle.cast( + input_ids != self.pad_token_id, + dtype=paddle.get_default_dtype()) + + if decoder_attention_mask is None: + assert decoder_input_ids is not None, "decoder_input_ids should be " \ + "specified when generating decoder_attention_mask" + decoder_attention_mask = paddle.cast( + decoder_input_ids != self.pad_token_id, + dtype=paddle.get_default_dtype()) + if encoder_output is None: + encoder_output = self.encoder(input_ids=input_ids, + attention_mask=attention_mask) + decoder_outputs = self.decoder(input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_output, + encoder_attention_mask=attention_mask, + use_cache=use_cache, + past_key_values=past_key_values) + return decoder_outputs + (encoder_output,) + + +class Linear_wo_bias(Layer): + def __init__(self, + in_features, + out_features, + weight_attr=None, + name=None): + super(Linear_wo_bias, self).__init__() + self._dtype = self._helper.get_default_dtype() + self._weight_attr = weight_attr + self.weight = self.create_parameter( + shape=[in_features, out_features], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False) + self.name = name + + def forward(self, input): + out = F.linear( + x=input, weight=self.weight, name=self.name) + return out + + def extra_repr(self): + name_str = ', name={}'.format(self.name) if self.name else '' + return 'in_features={}, out_features={}, dtype={}{}'.format( + self.weight.shape[0], self.weight.shape[1], self._dtype, name_str) + + +class ProphetNetForConditionalGeneration(ProphetNetPretrainedModel): + def __init__(self, prophetnet): + super(ProphetNetForConditionalGeneration, 
self).__init__() + self.prophetnet = prophetnet + self.padding_idx = prophetnet.word_embeddings._padding_idx + + self.lm_head = Linear_wo_bias(self.prophetnet.config["hidden_size"], self.prophetnet.config["vocab_size"]) + + # Initialize weights and apply final processing + self.apply(self.init_weights) + + def forward(self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_output=None, + labels=None, + use_cache=True, + past_key_values=None): + if labels is not None and decoder_input_ids is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + outputs = self.prophetnet(input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_output=encoder_output, + use_cache=use_cache, + past_key_values=past_key_values) + + batch_size, sequence_length = decoder_input_ids.shape + + predicting_streams = paddle.reshape(outputs[1], + (batch_size, self.prophetnet.config["ngram"], sequence_length, -1)) + predict_logits = self.lm_head(predicting_streams) + + logits = predict_logits[:, 0] + if use_cache: + past_key_values = outputs[2] + return logits, past_key_values, predict_logits + else: + return logits, predict_logits + + def prepare_inputs_for_generation(self, + decoder_input_ids, + attention_mask=None, + decoder_attention_mask=None, + cache=None, + use_cache=None, + encoder_output=None): + assert encoder_output is not None, "`encoder_output` have to be passed for generation." + if cache is not None: + decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1) + + # first step, decoder_cached_states are empty + return {"input_ids": None, # encoder_outputs is defined. input_ids not needed + "decoder_input_ids": decoder_input_ids, + "encoder_output": encoder_output, + "decoder_attention_mask": decoder_attention_mask, + "attention_mask": attention_mask, + "use_cache": use_cache, + "past_key_values": cache} + + def prepare_decoder_input_ids_from_labels(self, labels): + return self._shift_right(labels) + + def get_encoder(self): + return self.prophetnet.encoder + + def get_decoder(self): + return self.prophetnet.decoder + + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError as e: + try: + return getattr(getattr(self, self.base_model_prefix), name) + except AttributeError: + try: + return getattr(self, self.base_model_prefix).config[name] + except KeyError: + raise e diff --git a/paddlenlp/transformers/prophetnet/tokenizer.py b/paddlenlp/transformers/prophetnet/tokenizer.py new file mode 100644 index 000000000000..e3117265c2d1 --- /dev/null +++ b/paddlenlp/transformers/prophetnet/tokenizer.py @@ -0,0 +1,490 @@ +import collections +import logging +import os +from collections import OrderedDict +from typing import List + +from .. import PretrainedTokenizer, BasicTokenizer, WordpieceTokenizer + + +class Trie: + """ + Trie in Python. Creates a Trie out of a list of words. The trie is used to split on `added_tokens` in one pass + Loose reference https://en.wikipedia.org/wiki/Trie + """ + + def __init__(self): + self.data = {} + + def add(self, word: str): + """ + Passes over every char (utf-8 char) on word and recursively adds it to the internal `data` trie representation. + The special key `""` is used to represent termination. 
+ + This function is idempotent, adding twice the same word will leave the trie unchanged + + Example: + + ```python + >>> trie = Trie() + >>> trie.add("Hello 友達") + >>> trie.data + {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}} + >>> trie.add("Hello") + >>> trie.data + {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}} + ``` + """ + if not word: + # Prevent empty string + return + ref = self.data + for char in word: + ref[char] = char in ref and ref[char] or {} + ref = ref[char] + ref[""] = 1 + + def split(self, text: str) -> List[str]: + """ + Will look for the words added to the trie within `text`. Output is the original string splitted along the + boundaries of the words found. + + This trie will match the longest possible word first ! + + Example: + + ```python + >>> trie = Trie() + >>> trie.split("[CLS] This is a extra_id_100") + ["[CLS] This is a extra_id_100"] + >>> trie.add("[CLS]") + >>> trie.add("extra_id_1") + >>> trie.add("extra_id_100") + >>> trie.split("[CLS] This is a extra_id_100") + ["[CLS]", " This is a ", "extra_id_100"] + ``` + """ + # indexes are counted left of the chars index. + # "hello", index 0, is left of h, index 1 is between h and e. + # index 5 is right of the "o". + + # States are going to capture every possible start (indexes as above) + # as keys, and have as values, a pointer to the position in the trie + # where we're at. This is a partial match for now. + # This enables to keep track of multiple matches while we're iterating + # the string + # If the trie contains, "blowing", and "lower" and we encounter the + # string "blower", we need to split into ["b", "lower"]. + # This is where we need to keep track of multiple possible starts. + states = OrderedDict() + + # This will contain every indices where we need + # to cut. + # We force to cut at offset 0 and len(text) (added later) + offsets = [0] + + # This is used by the lookahead which needs to skip over + # some text where the full match exceeded the place in the initial + # for loop + skip = None + # Main loop, Giving this algorithm O(n) complexity + for current, current_char in enumerate(text): + if skip and current < skip: + # Prevents the lookahead for matching twice + # like extra_id_100 and id_100 + continue + + # This will track every state + # that stop matching, we need to stop tracking them. + # If we look at "lowball", we're going to match "l" (add it to states), "o", "w", then + # fail on "b", we need to remove 0 from the valid states. + to_remove = set() + # Whenever we found a match, we need to drop everything + # this is a greedy algorithm, it will match on the first found token + reset = False + + # In this case, we already have partial matches (But unfinished) + for start, trie_pointer in states.items(): + if "" in trie_pointer: + # This is a final match, we need to reset and + # store the results in `offsets`. 
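+                    # (a complete token ends at `current`, but a longer token may
+                    # also start at `start`, hence the lookahead that follows)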
+ + # Lookahead to match longest first + # Important in case of extra_id_1 vs extra_id_100 + # Here we are also actively looking for other earlier partial + # matches + # "[CLS]", "L", we need to match CLS even if L is special + for lookstart, looktrie_pointer in states.items(): + if lookstart > start: + # This partial match is later, we can stop looking + break + elif lookstart < start: + # This partial match is earlier, the trie pointer + # was already updated, so index is + 1 + lookahead_index = current + 1 + end = current + 1 + else: + # Here lookstart == start and + # looktrie_pointer == trie_pointer + # It wasn't updated yet so indices are current ones + lookahead_index = current + end = current + next_char = text[lookahead_index] if lookahead_index < len(text) else None + while next_char in looktrie_pointer: + looktrie_pointer = looktrie_pointer[next_char] + lookahead_index += 1 + if "" in looktrie_pointer: + start = lookstart + end = lookahead_index + skip = lookahead_index + + if lookahead_index == len(text): + # End of string + break + next_char = text[lookahead_index] + # End lookahead + + # Storing and resetting + offsets.append(start) + offsets.append(end) + reset = True + break + elif current_char in trie_pointer: + # The current character being looked at has a match within the trie + # update the pointer (it will be stored back into states later). + trie_pointer = trie_pointer[current_char] + + # Storing back the new pointer into the states. + # Partial matches got longer by one. + states[start] = trie_pointer + else: + # The new character has not match in the trie, we need + # to stop keeping track of this partial match. + # We can't do it directly within the loop because of how + # python iteration works + to_remove.add(start) + + # Either clearing the full start (we found a real match) + # Or clearing only the partial matches that didn't work. + if reset: + states = {} + else: + for start in to_remove: + del states[start] + + # If this character is a starting character within the trie + # start keeping track of this partial match. + if current_char in self.data: + states[current] = self.data[current_char] + + # We have a cut at the end with states. + for start, trie_pointer in states.items(): + if "" in trie_pointer: + # This is a final match, we need to reset and + # store the results in `offsets`. + end = len(text) + offsets.append(start) + offsets.append(end) + # Longest cut is always the one with lower start so the first + # item so we need to break. + break + + return self.cut_text(text, offsets) + + def cut_text(self, text, offsets): + # We have all the offsets now, we just need to do the actual splitting. + # We need to eventually add the first part of the string and the eventual + # last part. + offsets.append(len(text)) + tokens = [] + start = 0 + for end in offsets: + if start > end: + logging.error( + "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it anyway." 
+ ) + continue + elif start == end: + # This might happen if there's a match at index 0 + # we're also preventing zero-width cuts in case of two + # consecutive matches + continue + tokens.append(text[start:end]) + start = end + + return tokens + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip("\n") + vocab[token] = index + return vocab + + +def create_trie(unique_no_split_tokens): + trie = Trie() + for token in unique_no_split_tokens: + trie.add(token) + return trie + + +class ProphetNetTokenizer(PretrainedTokenizer): + r""" + Construct a ProphetNetTokenizer. Based on WordPiece. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. + Users should refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + do_basic_tokenize (`bool`, *optional*, defaults to `True`): + Whether or not to do basic tokenization before WordPiece. + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + x_sep_token (`str`, *optional*, defaults to `"[X_SEP]"`): + Special second separator token, which can be generated by + [`ProphetNetForConditionalGeneration`]. It is used to separate bullet-point like + sentences in summarization, *e.g.*. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. 
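+
+    Example (illustrative; assumes a local WordPiece vocab file whose entries
+    include the words below):
+
+    ```python
+    >>> tokenizer = ProphetNetTokenizer(vocab_file="prophetnet.tokenizer")
+    >>> tokenizer.tokenize("Hello world [SEP]")
+    ['hello', 'world', '[SEP]']
+    ```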
+    """
+
+    resource_files_names = {"vocab_file": "prophetnet.tokenizer"}
+    pretrained_resource_files_map = {}
+
+    def __init__(self,
+                 vocab_file,
+                 do_lower_case=True,
+                 do_basic_tokenize=True,
+                 unk_token="[UNK]",
+                 sep_token="[SEP]",
+                 bos_token="[SEP]",
+                 eos_token="[SEP]",
+                 cls_token="[CLS]",
+                 x_sep_token="[X_SEP]",
+                 pad_token="[PAD]",
+                 mask_token="[MASK]"):
+        self.unique_no_split_tokens = [x_sep_token, unk_token, sep_token, bos_token, eos_token, cls_token, pad_token,
+                                       mask_token]
+        self.tokens_trie = create_trie(self.unique_no_split_tokens)
+        self.vocab = load_vocab(vocab_file)
+        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
+        self.do_basic_tokenize = do_basic_tokenize
+        if do_basic_tokenize:
+            self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
+        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=unk_token)
+
+    @property
+    def vocab_size(self):
+        return len(self.vocab)
+
+    def get_vocab(self):
+        return dict(self.vocab)
+
+    def tokenize(self, text):
+        return self._tokenize(text)
+
+    def _tokenize(self, text):
+        """
+        Converts a string to a list of tokens.
+
+        Args:
+            text (str): The text to be tokenized.
+
+        Returns:
+            List[str]: A list of strings representing the converted tokens.
+        """
+        no_split_token = set(self.unique_no_split_tokens)
+        tokens = self.tokens_trie.split(text)
+        for i, token in enumerate(tokens):
+            if token in no_split_token:
+                left = tokens[i - 1] if i > 0 else None
+                right = tokens[i + 1] if i < len(tokens) - 1 else None
+                # We strip the whitespace around special tokens by default
+                if right:
+                    tokens[i + 1] = right.lstrip()
+                if left:
+                    tokens[i - 1] = left.rstrip()
+        # ["This is something", "<special_token_1>", "else"]
+        tokenized_text = []
+        for token in tokens:
+            # Need to skip any empty (fully stripped) tokens
+            if not token:
+                continue
+            if token in no_split_token:
+                tokenized_text.append(token)
+            else:
+                tokenized_text.extend(self._tokenize_function(token))
+        # ["This", " is", " something", "<special_token_1>", "else"]
+        return tokenized_text
+
+    def _tokenize_function(self, text):
+        split_tokens = []
+        if self.do_basic_tokenize:
+            for token in self.basic_tokenizer.tokenize(text):
+                split_tokens += self.wordpiece_tokenizer.tokenize(token)
+        else:
+            split_tokens = self.wordpiece_tokenizer.tokenize(text)
+        return split_tokens
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) to an id using the vocab."""
+        return self.vocab.get(token, self.vocab.get(self.unk_token))
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) to a token (str) using the vocab."""
+        return self.ids_to_tokens.get(index, self.unk_token)
+
+    def convert_tokens_to_ids(self, tokens):
+        """
+        Converts a token or a sequence of tokens to an id or a sequence of
+        ids, using the vocabulary. Override it if needed.
+
+        Args:
+            tokens (str or List[str]): One or more tokens to be converted.
+
+        Returns:
+            int or List[int]: The converted token id(s).
+        """
+        if not isinstance(tokens, (list, tuple)):
+            return self._convert_token_to_id(tokens)
+        else:
+            return [self._convert_token_to_id(token) for token in tokens]
+
+    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
+        """
+        Converts a single index or a sequence of indices to a token or
+        a sequence of tokens, using the vocabulary and added tokens.
+
+        Args:
+            ids (int or List[int]):
+                The token id (or token ids) to be converted to token(s).
+            skip_special_tokens (bool, optional):
+                Whether or not to remove special tokens in the decoding.
+                Defaults to `False`, which means special tokens are kept.
+
+        Returns:
+            str or List[str]: The decoded token(s).
+        """
+        if not isinstance(ids, (list, tuple)):
+            return self._convert_id_to_token(ids)
+        tokens = [self._convert_id_to_token(_id) for _id in ids]
+        if skip_special_tokens:
+            return [
+                token for token in tokens
+                if token not in self.all_special_tokens
+            ]
+        return tokens
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (strings) to a single string."""
+        out_string = " ".join(tokens).replace(" ##", "").strip()
+        return out_string
+
+    def convert_ids_to_string(self, ids):
+        return self.convert_tokens_to_string(self.convert_ids_to_tokens(ids))
+
+    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer's `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        if token_ids_1 is None:
+            return ([0] * len(token_ids_0)) + [1]
+        return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+
+    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
+        """
+        Create a mask from the two sequences passed, to be used in a sequence-pair classification task. A ProphetNet
+        sequence pair mask has the following format:
+
+        ```
+        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+        | first sequence    | second sequence |
+        ```
+
+        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of token type IDs according to the given sequence(s).
+        """
+        sep = [self.sep_token_id]
+        if token_ids_1 is None:
+            return len(token_ids_0 + sep) * [0]
+        return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
+
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+        and adding special tokens. A ProphetNet sequence has the following format:
+
+        - single sequence: `X [SEP]`
+        - pair of sequences: `A [SEP] B [SEP]`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of input IDs with the appropriate special tokens.
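+
+        Example (illustrative only; the ids are placeholders):
+
+        ```
+        tokenizer.build_inputs_with_special_tokens([8, 9])
+        # -> [8, 9, sep_token_id]                     (single sequence: X [SEP])
+        tokenizer.build_inputs_with_special_tokens([8, 9], [10])
+        # -> [8, 9, sep_token_id, 10, sep_token_id]   (pair: A [SEP] B [SEP])
+        ```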
+ """ + if token_ids_1 is None: + return token_ids_0 + [self.sep_token_id] + sep = [self.sep_token_id] + return token_ids_0 + sep + token_ids_1 + sep + + def save_vocabulary(self, save_directory): + index = 0 + vocab_file = os.path.join(save_directory, self.resource_files_names["vocab_file"]) + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + logging.warning(f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!") + index = token_index + writer.write(token + "\n") + index += 1 From 6f984b95975f735382ba3ae021a12af46bb2108b Mon Sep 17 00:00:00 2001 From: DMH_coco <294270681@qq.com> Date: Mon, 21 Feb 2022 18:47:27 +0800 Subject: [PATCH 2/7] update prohetnet --- paddlenlp/transformers/prophetnet/__init__.py | 2 ++ paddlenlp/transformers/prophetnet/modeling.py | 3 +-- paddlenlp/transformers/prophetnet/tokenizer.py | 14 ++++++++++++++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/paddlenlp/transformers/prophetnet/__init__.py b/paddlenlp/transformers/prophetnet/__init__.py index e69de29bb2d1..cd8b34346bd0 100644 --- a/paddlenlp/transformers/prophetnet/__init__.py +++ b/paddlenlp/transformers/prophetnet/__init__.py @@ -0,0 +1,2 @@ +from .modeling import * +from .tokenizer import * \ No newline at end of file diff --git a/paddlenlp/transformers/prophetnet/modeling.py b/paddlenlp/transformers/prophetnet/modeling.py index c34b38cd0f5d..b5343d273c2d 100644 --- a/paddlenlp/transformers/prophetnet/modeling.py +++ b/paddlenlp/transformers/prophetnet/modeling.py @@ -13,13 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Optional, Tuple - import paddle import paddle.nn as nn import paddle.nn.functional as F from paddle import Tensor from paddle.nn import Layer +from typing import Optional, Tuple import paddlenlp from .. import PretrainedModel diff --git a/paddlenlp/transformers/prophetnet/tokenizer.py b/paddlenlp/transformers/prophetnet/tokenizer.py index e3117265c2d1..90951276e541 100644 --- a/paddlenlp/transformers/prophetnet/tokenizer.py +++ b/paddlenlp/transformers/prophetnet/tokenizer.py @@ -1,3 +1,17 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import collections import logging import os From 555ad31550dad269e4466655384714900b8007dd Mon Sep 17 00:00:00 2001 From: DMH_coco <294270681@qq.com> Date: Mon, 21 Feb 2022 19:00:28 +0800 Subject: [PATCH 3/7] update format --- paddlenlp/transformers/prophetnet/__init__.py | 2 +- paddlenlp/transformers/prophetnet/modeling.py | 1005 +++++++++-------- .../transformers/prophetnet/tokenizer.py | 47 +- 3 files changed, 587 insertions(+), 467 deletions(-) diff --git a/paddlenlp/transformers/prophetnet/__init__.py b/paddlenlp/transformers/prophetnet/__init__.py index cd8b34346bd0..aa915c710509 100644 --- a/paddlenlp/transformers/prophetnet/__init__.py +++ b/paddlenlp/transformers/prophetnet/__init__.py @@ -1,2 +1,2 @@ from .modeling import * -from .tokenizer import * \ No newline at end of file +from .tokenizer import * diff --git a/paddlenlp/transformers/prophetnet/modeling.py b/paddlenlp/transformers/prophetnet/modeling.py index b5343d273c2d..61520cd91e67 100644 --- a/paddlenlp/transformers/prophetnet/modeling.py +++ b/paddlenlp/transformers/prophetnet/modeling.py @@ -24,8 +24,8 @@ from .. import PretrainedModel __all__ = [ - 'ProphetNetModel', 'ProphetNetPretrainedModel', 'ProphetNetEncoder', 'ProphetNetDecoder', - 'ProphetNetForConditionalGeneration' + 'ProphetNetModel', 'ProphetNetPretrainedModel', 'ProphetNetEncoder', + 'ProphetNetDecoder', 'ProphetNetForConditionalGeneration' ] ACT2FN = {"gelu": F.gelu} @@ -35,18 +35,24 @@ def ngram_attention_bias(sequence_length, ngram, dtype): """ This function computes the bias for the predict stream """ - left_block = paddle.ones((ngram, sequence_length, sequence_length), dtype=dtype) * float("-inf") + left_block = paddle.ones( + (ngram, sequence_length, sequence_length), dtype=dtype) * float("-inf") right_block = left_block.detach().clone() # create bias for stream_idx in range(ngram): - right_block[stream_idx] = right_block[stream_idx].fill_diagonal_(0, wrap=False) - left_block[stream_idx] = paddle.triu(left_block[stream_idx], diagonal=-stream_idx + 1) + right_block[stream_idx] = right_block[stream_idx].fill_diagonal_( + 0, wrap=False) + left_block[stream_idx] = paddle.triu(left_block[stream_idx], + diagonal=-stream_idx + 1) left_block[:, :, 0] = 0 return paddle.concat([left_block, right_block], axis=2) -def compute_relative_buckets(num_buckets, max_distance, relative_positions, is_bidirectional=False): +def compute_relative_buckets(num_buckets, + max_distance, + relative_positions, + is_bidirectional=False): """ This function computes individual parts of the relative position buckets. For more detail, see paper. 
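
    In brief (summarizing the implementation below): distances smaller than
    num_buckets // 2 each get their own exact bucket, while larger distances
    are binned logarithmically up to max_distance, approximately

        bucket(d) = max_exact + log(d / max_exact)
                    / log(max_distance / max_exact) * (num_buckets - max_exact)

    with everything at or beyond max_distance clamped into the last bucket.
    In the bidirectional case, half of the bucket range is reserved for one
    sign of the relative position first.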
""" @@ -55,54 +61,70 @@ def compute_relative_buckets(num_buckets, max_distance, relative_positions, is_b if is_bidirectional: num_buckets = num_buckets // 2 - rel_positions_bucket = (rel_positions_bucket + paddle.cast(paddle.less_than(inv_relative_positions, - paddle.zeros_like( - inv_relative_positions)), - dtype=paddle.int32) * num_buckets) + rel_positions_bucket = (rel_positions_bucket + paddle.cast( + paddle.less_than(inv_relative_positions, + paddle.zeros_like(inv_relative_positions)), + dtype=paddle.int32) * num_buckets) inv_relative_positions = paddle.abs(inv_relative_positions) else: - inv_relative_positions = paddle.cast(paddle.less_than(paddle.zeros_like(inv_relative_positions), - inv_relative_positions), - dtype=paddle.int32) * inv_relative_positions + inv_relative_positions = paddle.cast( + paddle.less_than(paddle.zeros_like(inv_relative_positions), + inv_relative_positions), + dtype=paddle.int32) * inv_relative_positions max_exact = num_buckets // 2 - is_small = paddle.less_than(inv_relative_positions, paddle.to_tensor(max_exact).cast(dtype=paddle.int32)) - val_if_large = max_exact + paddle.log(paddle.cast(inv_relative_positions, dtype=paddle.float32) - / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) - val_if_large_num_buckets = paddle.ones_like(val_if_large) * (num_buckets - 1) - val_if_large_lt = paddle.cast(paddle.less_than(val_if_large, val_if_large_num_buckets), dtype=paddle.int32) + is_small = paddle.less_than( + inv_relative_positions, + paddle.to_tensor(max_exact).cast(dtype=paddle.int32)) + val_if_large = max_exact + paddle.log( + paddle.cast(inv_relative_positions, dtype=paddle.float32) / max_exact + ) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + val_if_large_num_buckets = paddle.ones_like(val_if_large) * (num_buckets - + 1) + val_if_large_lt = paddle.cast(paddle.less_than(val_if_large, + val_if_large_num_buckets), + dtype=paddle.int32) val_if_large = paddle.cast(val_if_large_lt * val_if_large, dtype=paddle.int32) + \ (1 - val_if_large_lt) * val_if_large_num_buckets - rel_positions_bucket = rel_positions_bucket + paddle.where(is_small, paddle.cast(inv_relative_positions, - dtype=paddle.int32), - val_if_large) + rel_positions_bucket = rel_positions_bucket + paddle.where( + is_small, paddle.cast(inv_relative_positions, dtype=paddle.int32), + val_if_large) return rel_positions_bucket -def compute_all_stream_relative_buckets(num_buckets, max_distance, position_ids): +def compute_all_stream_relative_buckets(num_buckets, max_distance, + position_ids): """ This function computes both main and predict relative position buckets. For more detail, see paper. 
""" # main stream - main_stream_relative_positions = paddle.tile(paddle.unsqueeze(position_ids, axis=1), - repeat_times=[1, position_ids.shape[-1], 1]) - main_stream_relative_positions = main_stream_relative_positions - paddle.unsqueeze(position_ids, axis=-1) + main_stream_relative_positions = paddle.tile( + paddle.unsqueeze(position_ids, axis=1), + repeat_times=[1, position_ids.shape[-1], 1]) + main_stream_relative_positions = main_stream_relative_positions - paddle.unsqueeze( + position_ids, axis=-1) # predicting stream - predicting_stream_relative_positions = paddle.unsqueeze(paddle.concat([position_ids - 1, position_ids], - axis=-1), axis=1) - predicting_stream_relative_positions = paddle.tile(predicting_stream_relative_positions, - repeat_times=[1, position_ids.shape[-1], 1]) - predicting_stream_relative_positions = predicting_stream_relative_positions - paddle.unsqueeze(position_ids, - axis=-1) + predicting_stream_relative_positions = paddle.unsqueeze(paddle.concat( + [position_ids - 1, position_ids], axis=-1), + axis=1) + predicting_stream_relative_positions = paddle.tile( + predicting_stream_relative_positions, + repeat_times=[1, position_ids.shape[-1], 1]) + predicting_stream_relative_positions = predicting_stream_relative_positions - paddle.unsqueeze( + position_ids, axis=-1) # get both position buckets main_relative_position_buckets = compute_relative_buckets( - num_buckets, max_distance, main_stream_relative_positions, is_bidirectional=False - ) + num_buckets, + max_distance, + main_stream_relative_positions, + is_bidirectional=False) predict_relative_position_buckets = compute_relative_buckets( - num_buckets, max_distance, predicting_stream_relative_positions, is_bidirectional=False - ) + num_buckets, + max_distance, + predicting_stream_relative_positions, + is_bidirectional=False) return main_relative_position_buckets, predict_relative_position_buckets @@ -155,10 +177,12 @@ class ProphetNetPretrainedModel(PretrainedModel): def init_weights(self, layer): if isinstance(layer, nn.Linear): - layer.weight.set_value(paddle.tensor.normal(mean=0.0, - std=self.init_std if hasattr(self, "init_std") else - self.prophetnet.config["init_std"], - shape=layer.weight.shape)) + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.init_std if hasattr(self, "init_std") else + self.prophetnet.config["init_std"], + shape=layer.weight.shape)) if layer.bias is not None: layer.bias.set_value(paddle.tensor.zeros(layer.bias.shape)) @@ -177,11 +201,15 @@ def _shift_right(self, input_ids): assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." 
        # replace possible -100 values in labels by `pad_token_id`
-        shifted_input_ids_mask = paddle.cast(shifted_input_ids == -100, dtype=paddle.int32)
-        shifted_input_ids = shifted_input_ids_mask * pad_token_id + (1 - shifted_input_ids_mask) * shifted_input_ids
+        shifted_input_ids_mask = paddle.cast(shifted_input_ids == -100,
+                                             dtype=paddle.int32)
+        shifted_input_ids = shifted_input_ids_mask * pad_token_id + (
+            1 - shifted_input_ids_mask) * shifted_input_ids
 
-        assert paddle.sum(paddle.cast(shifted_input_ids >= 0, dtype=paddle.int32)).item() == shifted_input_ids.shape[
-            -1], "Verify that `shifted_input_ids` has only positive values"
+        assert paddle.sum(
+            paddle.cast(shifted_input_ids >= 0, dtype=paddle.int32)).item(
+            ) == shifted_input_ids.shape[
+                -1], "Verify that `shifted_input_ids` has only positive values"
 
         return shifted_input_ids
 
@@ -190,12 +218,16 @@ class ProphetNetPositionalEmbeddings(nn.Embedding):
     """
    ProphetNet positional embeddings.
     """
-
     def __init__(self, max_position_embeddings, hidden_size, pad_token_id):
         self.max_length = max_position_embeddings
-        super(ProphetNetPositionalEmbeddings, self).__init__(max_position_embeddings, hidden_size, pad_token_id)
+        super(ProphetNetPositionalEmbeddings,
+              self).__init__(max_position_embeddings, hidden_size, pad_token_id)
 
-    def forward(self, inputs_shape, attention_mask=None, past_key_values=None, position_ids=None):
+    def forward(self,
+                inputs_shape,
+                attention_mask=None,
+                past_key_values=None,
+                position_ids=None):
         assert (position_ids is None) or (self._padding_idx is None), \
             "If position_ids is pre-computed, then padding_idx should not be set."
 
@@ -205,20 +237,23 @@ def forward(self, inputs_shape, attention_mask=None, past_key_values=None, posit
             # Without the int() cast, it doesn't work in some cases when exporting to ONNX
             prev_num_input_ids = past_key_values[0][0].shape[2]
             num_input_ids = inputs_shape[1] + prev_num_input_ids
-            position_ids = paddle.ones((1, 1), dtype='int64') * (
-                int(self._padding_idx + num_input_ids)
-            )
+            position_ids = paddle.ones(
+                (1, 1),
+                dtype='int64') * (int(self._padding_idx + num_input_ids))
         else:
             if attention_mask is None:
                 attention_mask = paddle.ones(inputs_shape, dtype='int64')
 
             # retrieve position_ids from input_ids / attention_mask
             position_ids = paddle.cast(
-                paddle.cast(paddle.cumsum(attention_mask, axis=1), dtype=attention_mask.dtype) *
-                attention_mask, dtype=paddle.int64) + self._padding_idx
+                paddle.cast(paddle.cumsum(attention_mask, axis=1),
+                            dtype=attention_mask.dtype) * attention_mask,
+                dtype=paddle.int64) + self._padding_idx
 
             # make sure position_ids are not bigger than max_length
-            position_ids = paddle.clip(position_ids, min=0, max=self.max_length - 1)
+            position_ids = paddle.clip(position_ids,
+                                       min=0,
+                                       max=self.max_length - 1)
 
         return super().forward(position_ids), position_ids
 
@@ -230,11 +265,7 @@ class ProphetNetAttention(Layer):
     """
     Multi-headed attention from 'Attention Is All You Need' paper.
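
    Shape flow (a sketch of the implementation below; the residual add and
    LayerNorm live in the calling encoder/decoder layer):

        query : [batch, tgt_len, hidden] -> [batch * num_attn_heads, tgt_len, head_dim]
        scores: bmm(q, k^T) (+ attention_mask) -> softmax -> dropout
        output: bmm(probs, v) -> [batch, tgt_len, hidden] -> out_proj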
""" - - def __init__(self, - hidden_size, - attention_dropout, - dropout, + def __init__(self, hidden_size, attention_dropout, dropout, num_attn_heads: int): super().__init__() hidden_size = hidden_size @@ -254,14 +285,18 @@ def __init__(self, self.out_proj = nn.Linear(hidden_size, hidden_size) def _shape(self, tensor: paddle.Tensor, seq_len: int, bsz: int): - return paddle.transpose(paddle.reshape(tensor, [bsz, seq_len, self.num_attn_heads, self.head_dim]), - (0, 2, 1, 3)) - - def forward(self, - hidden_states, - key_value_states: Optional[Tensor] = None, - attention_mask: Optional[Tensor] = None, - past_key_value: Optional[Tuple[Tensor]] = None) -> Tuple[Tensor, Optional[Tensor]]: + return paddle.transpose( + paddle.reshape(tensor, + [bsz, seq_len, self.num_attn_heads, self.head_dim]), + (0, 2, 1, 3)) + + def forward( + self, + hidden_states, + key_value_states: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + past_key_value: Optional[Tuple[Tensor]] = None + ) -> Tuple[Tensor, Optional[Tensor]]: batch_size, tgt_len, hidden_size = hidden_states.shape @@ -272,7 +307,7 @@ def forward(self, f"Size of hidden states should be {batch_size, tgt_len, hidden_size}, but is {hidden_states.shape}" # previous time steps are cached - no need to recompute key and value if they are static - query_states = self.query_proj(hidden_states) / (self.head_dim ** 0.5) + query_states = self.query_proj(hidden_states) / (self.head_dim**0.5) if is_cross_attention and past_key_value is not None: # reuse k,v, cross_attentions @@ -280,12 +315,16 @@ def forward(self, value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.key_proj(key_value_states), -1, batch_size) - value_states = self._shape(self.value_proj(key_value_states), -1, batch_size) + key_states = self._shape(self.key_proj(key_value_states), -1, + batch_size) + value_states = self._shape(self.value_proj(key_value_states), -1, + batch_size) else: # self_attention - key_states = self._shape(self.key_proj(hidden_states), -1, batch_size) - value_states = self._shape(self.value_proj(hidden_states), -1, batch_size) + key_states = self._shape(self.key_proj(hidden_states), -1, + batch_size) + value_states = self._shape(self.value_proj(hidden_states), -1, + batch_size) if is_cross_attention: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
@@ -296,7 +335,8 @@ def forward(self, # project states into the correct shape proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim) - query_states = paddle.reshape(self._shape(query_states, tgt_len, batch_size), proj_shape) + query_states = paddle.reshape( + self._shape(query_states, tgt_len, batch_size), proj_shape) key_states = paddle.reshape(key_states, proj_shape) value_states = paddle.reshape(value_states, proj_shape) @@ -316,21 +356,26 @@ def forward(self, attn_weights = F.softmax(attn_weights, axis=-1) - attn_probs = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_probs = F.dropout(attn_weights, + p=self.attention_dropout, + training=self.training) attn_output = paddle.bmm(attn_probs, value_states) assert attn_output.shape == [batch_size * self.num_attn_heads, tgt_len, self.head_dim, ], \ f"`attn_output` should be of shape {batch_size * self.num_attn_heads, tgt_len, self.head_dim}, but is of shape {attn_output.shape}" - attn_output = (paddle.reshape(paddle.transpose(paddle.reshape(attn_output, - (batch_size, self.num_attn_heads, tgt_len, - self.head_dim)), - (0, 2, 1, 3)), - (batch_size, tgt_len, hidden_size))) + attn_output = (paddle.reshape( + paddle.transpose( + paddle.reshape( + attn_output, + (batch_size, self.num_attn_heads, tgt_len, self.head_dim)), + (0, 2, 1, 3)), (batch_size, tgt_len, hidden_size))) attn_output = self.out_proj(attn_output) - attn_output = F.dropout(attn_output, p=self.dropout, training=self.training) + attn_output = F.dropout(attn_output, + p=self.dropout, + training=self.training) return attn_output, past_key_value @@ -338,13 +383,8 @@ class ProphetNetFeedForward(Layer): """ This is the residual two feed-forward layer block based on the original Transformer implementation. 
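
    Computation (a sketch of the forward pass below; the residual add and
    LayerNorm are applied by the calling layer, not here):

        h = dropout_act(act(W_in(x)))
        y = dropout(W_out(h))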
""" - - def __init__(self, - hidden_size, - activation_function, - activation_dropout, - dropout, - ffn_dim: int): + def __init__(self, hidden_size, activation_function, activation_dropout, + dropout, ffn_dim: int): super(ProphetNetFeedForward, self).__init__() self.activation_fn = ACT2FN[activation_function] self.intermediate = nn.Linear(hidden_size, ffn_dim) @@ -356,20 +396,19 @@ def forward(self, hidden_states): hidden_states = self.intermediate(hidden_states) hidden_states = self.activation_fn(hidden_states) - hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = F.dropout(hidden_states, + p=self.activation_dropout, + training=self.training) hidden_states = self.output(hidden_states) - hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = F.dropout(hidden_states, + p=self.dropout, + training=self.training) return hidden_states class ProphetNetNgramSelfAttention(Layer): - def __init__(self, - hidden_size, - num_buckets, - relative_max_distance, - num_decoder_attention_heads, - dropout, - attention_dropout, + def __init__(self, hidden_size, num_buckets, relative_max_distance, + num_decoder_attention_heads, dropout, attention_dropout, ngram): super(ProphetNetNgramSelfAttention, self).__init__() @@ -394,12 +433,15 @@ def __init__(self, self.out_proj = nn.Linear(hidden_size, hidden_size) # rel position embeddings - self.relative_pos_embeddings = nn.Linear(hidden_size, self.num_buckets * self.num_attn_heads) + self.relative_pos_embeddings = nn.Linear( + hidden_size, self.num_buckets * self.num_attn_heads) def _shape(self, tensor, seq_len, batch_size): - return paddle.transpose(paddle.reshape(tensor, - (batch_size, seq_len, self.num_attn_heads, self.head_dim)), - (0, 2, 1, 3)) + return paddle.transpose( + paddle.reshape( + tensor, + (batch_size, seq_len, self.num_attn_heads, self.head_dim)), + (0, 2, 1, 3)) def forward(self, hidden_states, @@ -420,10 +462,11 @@ def forward(self, value_states = self.value_proj(hidden_states) # normalize - query_states = query_states / (self.head_dim ** 0.5) + query_states = query_states / (self.head_dim**0.5) # reshape - query_states = self._shape(query_states, ngram_sequence_length, batch_size) + query_states = self._shape(query_states, ngram_sequence_length, + batch_size) key_states = self._shape(key_states, -1, batch_size) value_states = self._shape(value_states, -1, batch_size) @@ -440,164 +483,211 @@ def forward(self, key_states_list = paddle.chunk(key_states, 1 + self.ngram, axis=1) value_states_list = paddle.chunk(value_states, 1 + self.ngram, axis=1) - main_hidden_states, hidden_states_predict_list = hidden_states_list[0], hidden_states_list[1:] - main_query_states, predict_query_states_list = query_states_list[0], query_states_list[1:] - main_key_states, predict_key_states_list = key_states_list[0], key_states_list[1:] - main_value_states, predict_value_states_list = value_states_list[0], value_states_list[1:] + main_hidden_states, hidden_states_predict_list = hidden_states_list[ + 0], hidden_states_list[1:] + main_query_states, predict_query_states_list = query_states_list[ + 0], query_states_list[1:] + main_key_states, predict_key_states_list = key_states_list[ + 0], key_states_list[1:] + main_value_states, predict_value_states_list = value_states_list[ + 0], value_states_list[1:] # saved states are stored with shape (batch_size, num_attn_heads, seq_len, head_dim) if past_key_value is not None: - prev_main_key_states = 
past_key_value[0].reshape([batch_size * self.num_attn_heads, -1, self.head_dim]) - main_key_states = paddle.concat((prev_main_key_states, main_key_states), axis=1) - prev_main_value_states = past_key_value[1].reshape([batch_size * self.num_attn_heads, -1, self.head_dim]) - main_value_states = paddle.concat((prev_main_value_states, main_value_states), axis=1) + prev_main_key_states = past_key_value[0].reshape( + [batch_size * self.num_attn_heads, -1, self.head_dim]) + main_key_states = paddle.concat( + (prev_main_key_states, main_key_states), axis=1) + prev_main_value_states = past_key_value[1].reshape( + [batch_size * self.num_attn_heads, -1, self.head_dim]) + main_value_states = paddle.concat( + (prev_main_value_states, main_value_states), axis=1) # Update cache - past_key_value = (paddle.reshape(main_key_states, (batch_size, self.num_attn_heads, -1, self.head_dim)), - paddle.reshape(main_value_states, (batch_size, self.num_attn_heads, -1, self.head_dim)),) + past_key_value = ( + paddle.reshape( + main_key_states, + (batch_size, self.num_attn_heads, -1, self.head_dim)), + paddle.reshape( + main_value_states, + (batch_size, self.num_attn_heads, -1, self.head_dim)), + ) # get seq_length of main stream only sequence_length = ngram_sequence_length // (1 + self.ngram) # MAIN-STREAM # main attn weights - main_attn_weights = paddle.bmm(main_query_states, paddle.transpose(main_key_states, (0, 2, 1))) + main_attn_weights = paddle.bmm( + main_query_states, paddle.transpose(main_key_states, (0, 2, 1))) # retrieve relative position embeddings for each layer -> see paper for more details - main_relative_pos_embeddings = self.get_main_relative_pos_embeddings(main_hidden_states, - main_attn_weights, - position_ids, - main_relative_position_buckets) + main_relative_pos_embeddings = self.get_main_relative_pos_embeddings( + main_hidden_states, main_attn_weights, position_ids, + main_relative_position_buckets) main_attn_weights = main_attn_weights + main_relative_pos_embeddings if attention_mask is not None: main_attn_weights = main_attn_weights + attention_mask - main_attn_probs = F.softmax(main_attn_weights, axis=-1, dtype=main_attn_weights.dtype) + main_attn_probs = F.softmax(main_attn_weights, + axis=-1, + dtype=main_attn_weights.dtype) - main_attn_probs = F.dropout(main_attn_probs, p=self.attention_dropout, training=self.training) + main_attn_probs = F.dropout(main_attn_probs, + p=self.attention_dropout, + training=self.training) # project to attn_output main_attn_output = paddle.bmm(main_attn_probs, main_value_states) # reshape so that num_heads dim is merged into last `head_dim` axis - main_attn_output = (paddle.reshape(paddle.transpose( - paddle.reshape(main_attn_output, (batch_size, self.num_attn_heads, sequence_length, self.head_dim)) - , (0, 2, 1, 3)), (batch_size, 1, sequence_length, hidden_size))) + main_attn_output = (paddle.reshape( + paddle.transpose( + paddle.reshape(main_attn_output, + (batch_size, self.num_attn_heads, + sequence_length, self.head_dim)), (0, 2, 1, 3)), + (batch_size, 1, sequence_length, hidden_size))) main_attn_output = self.out_proj(main_attn_output) # PREDICT-STREAM # [ngram, B*head, T, c] - predict_query_states = paddle.reshape(paddle.concat(predict_query_states_list, axis=0), - (self.ngram, -1, sequence_length, self.head_dim)) + predict_query_states = paddle.reshape( + paddle.concat(predict_query_states_list, axis=0), + (self.ngram, -1, sequence_length, self.head_dim)) # [ngram, B*head, 2*T, c] - predict_key_states = 
paddle.concat([paddle.unsqueeze(paddle.concat([main_key_states, key], axis=1), axis=0) - for key in predict_key_states_list], axis=0) + predict_key_states = paddle.concat([ + paddle.unsqueeze(paddle.concat([main_key_states, key], axis=1), + axis=0) for key in predict_key_states_list + ], + axis=0) # [ngram, T, B, C] - predict_hidden_states = paddle.reshape(paddle.concat(hidden_states_predict_list, axis=0), - (self.ngram, sequence_length, batch_size, hidden_size)) + predict_hidden_states = paddle.reshape( + paddle.concat(hidden_states_predict_list, axis=0), + (self.ngram, sequence_length, batch_size, hidden_size)) # [ngram, B*head, 2*T, c] - predict_value_states = paddle.concat([paddle.unsqueeze(paddle.concat([main_value_states, v_p], axis=1), axis=0) - for v_p in predict_value_states_list], axis=0) + predict_value_states = paddle.concat([ + paddle.unsqueeze(paddle.concat([main_value_states, v_p], axis=1), + axis=0) for v_p in predict_value_states_list + ], + axis=0) # [ngram, B*head, T, 2*T] - predict_attn_weights = paddlenlp.ops.einsum("nbtc,nbsc->nbts", (predict_query_states, predict_key_states)) + predict_attn_weights = paddlenlp.ops.einsum( + "nbtc,nbsc->nbts", (predict_query_states, predict_key_states)) # [ngram, B*head, T, S] # retrieve relative position embeddings for each layer -> see paper for more details - predict_relative_pos_embeddings = self.get_predict_relative_pos_embeddings(predict_hidden_states, - predict_attn_weights, position_ids, - predict_relative_position_buckets) + predict_relative_pos_embeddings = self.get_predict_relative_pos_embeddings( + predict_hidden_states, predict_attn_weights, position_ids, + predict_relative_position_buckets) # [ngram, B*head, T, 2*T] predict_attn_weights = predict_attn_weights + predict_relative_pos_embeddings if extended_predict_attention_mask is not None: - predict_attn_weights = predict_attn_weights + paddle.cast(extended_predict_attention_mask, - predict_attn_weights.dtype) + predict_attn_weights = predict_attn_weights + paddle.cast( + extended_predict_attention_mask, predict_attn_weights.dtype) - predict_attn_probs = F.softmax(predict_attn_weights, axis=-1, dtype=predict_attn_weights.dtype) + predict_attn_probs = F.softmax(predict_attn_weights, + axis=-1, + dtype=predict_attn_weights.dtype) - predict_attn_probs = F.dropout(predict_attn_probs, p=self.attention_dropout, training=self.training) + predict_attn_probs = F.dropout(predict_attn_probs, + p=self.attention_dropout, + training=self.training) # project to attention output # [ngram, B*head, T, c] - predict_attn_output = paddlenlp.ops.einsum("nbts,nbsc->nbtc", (predict_attn_probs, predict_value_states)) + predict_attn_output = paddlenlp.ops.einsum( + "nbts,nbsc->nbtc", (predict_attn_probs, predict_value_states)) # reshape so that num_heads dim is merged into last `head_dim` axis # [ngram, B, T, C] - predict_attn_output = (paddle.reshape(paddle.transpose(paddle.reshape(predict_attn_output, (self.ngram, - batch_size, - self.num_attn_heads, - sequence_length, - self.head_dim)), - (1, 0, 3, 2, 4)), - (batch_size, self.ngram, sequence_length, hidden_size))) + predict_attn_output = (paddle.reshape( + paddle.transpose( + paddle.reshape(predict_attn_output, + (self.ngram, batch_size, self.num_attn_heads, + sequence_length, self.head_dim)), + (1, 0, 3, 2, 4)), + (batch_size, self.ngram, sequence_length, hidden_size))) predict_attn_output = self.out_proj(predict_attn_output) # concat to single attn output # [B, 1+ngram*T, C] - attn_output = paddle.reshape(paddle.concat([main_attn_output, 
predict_attn_output], axis=1), - (batch_size, -1, hidden_size)) + attn_output = paddle.reshape( + paddle.concat([main_attn_output, predict_attn_output], axis=1), + (batch_size, -1, hidden_size)) # reshape into better form for `config.output_attentions` - main_attn_probs = paddle.reshape(main_attn_probs, (batch_size, self.num_attn_heads, sequence_length, -1)) - predict_attn_probs = paddle.transpose(paddle.reshape(predict_attn_probs, (self.ngram, - batch_size, - self.num_attn_heads, - sequence_length, -1)), - (1, 0, 2, 3, 4)) - - attn_output = F.dropout(attn_output, p=self.dropout, training=self.training) + main_attn_probs = paddle.reshape( + main_attn_probs, + (batch_size, self.num_attn_heads, sequence_length, -1)) + predict_attn_probs = paddle.transpose( + paddle.reshape(predict_attn_probs, + (self.ngram, batch_size, self.num_attn_heads, + sequence_length, -1)), (1, 0, 2, 3, 4)) + + attn_output = F.dropout(attn_output, + p=self.dropout, + training=self.training) return attn_output, main_attn_probs, predict_attn_probs, past_key_value - def get_main_relative_pos_embeddings(self, hidden_states, attn_weights, position_ids, + def get_main_relative_pos_embeddings(self, hidden_states, attn_weights, + position_ids, main_relative_position_buckets): # input hidden_states [B,T,C], input attn_weights [T*head,T,S], input position_ids [B,T] or [1,1] if main_relative_position_buckets is None: batch_size, sequence_length = hidden_states.shape[:2] - relative_positions = (paddle.tile(paddle.unsqueeze(paddle.unsqueeze(paddle.arange(1, - attn_weights. - shape[-1] + 1), - axis=0), - axis=0), - repeat_times=[batch_size, sequence_length, 1])) - relative_positions = relative_positions - paddle.tile(paddle.unsqueeze(position_ids, axis=0), - repeat_times=[batch_size, sequence_length, - 1]) # [B, T, s] + relative_positions = (paddle.tile( + paddle.unsqueeze(paddle.unsqueeze(paddle.arange( + 1, attn_weights.shape[-1] + 1), + axis=0), + axis=0), + repeat_times=[batch_size, sequence_length, 1])) + relative_positions = relative_positions - paddle.tile( + paddle.unsqueeze(position_ids, axis=0), + repeat_times=[batch_size, sequence_length, 1]) # [B, T, s] main_relative_position_buckets = compute_relative_buckets( - self.num_buckets, self.relative_max_distance, relative_positions, False - ) - - rel_pos_embeddings = self.relative_pos_embeddings(hidden_states) # [B,T,Buckets*head] - rel_pos_embeddings = paddle.transpose(paddle.reshape(rel_pos_embeddings, - (rel_pos_embeddings.shape[:2] + - [self.num_buckets, self.num_attn_heads])), - (0, 3, 1, 2)) # [B,T,Buckets,head] - rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:2] + [-1]) # [B*head,T,Buckets] - - main_relative_position_buckets = paddle.cast(paddle.reshape(paddle.tile(main_relative_position_buckets, - repeat_times=[1, self.num_attn_heads, - 1]), - (-1, main_relative_position_buckets.shape[-1])), - dtype=paddle.int64) # [B*head*T, T] - rel_pos_embeddings = paddle.reshape(rel_pos_embeddings, - (-1, rel_pos_embeddings.shape[-1])) # [B*head*T,Buckets] - - main_relative_position_buckets_index = paddle.tile(main_relative_position_buckets.unsqueeze(2), - repeat_times=[1, 1, 2]) + self.num_buckets, self.relative_max_distance, + relative_positions, False) + + rel_pos_embeddings = self.relative_pos_embeddings( + hidden_states) # [B,T,Buckets*head] + rel_pos_embeddings = paddle.transpose( + paddle.reshape(rel_pos_embeddings, + (rel_pos_embeddings.shape[:2] + + [self.num_buckets, self.num_attn_heads])), + (0, 3, 1, 2)) # [B,T,Buckets,head] + rel_pos_embeddings = 
rel_pos_embeddings.reshape( + attn_weights.shape[:2] + [-1]) # [B*head,T,Buckets] + + main_relative_position_buckets = paddle.cast( + paddle.reshape( + paddle.tile(main_relative_position_buckets, + repeat_times=[1, self.num_attn_heads, 1]), + (-1, main_relative_position_buckets.shape[-1])), + dtype=paddle.int64) # [B*head*T, T] + rel_pos_embeddings = paddle.reshape( + rel_pos_embeddings, + (-1, rel_pos_embeddings.shape[-1])) # [B*head*T,Buckets] + + main_relative_position_buckets_index = paddle.tile( + main_relative_position_buckets.unsqueeze(2), repeat_times=[1, 1, 2]) main_relative_position_buckets_index[:, :, 0] = \ paddle.tile(paddle.arange(0, main_relative_position_buckets_index.shape[0]).unsqueeze(1), repeat_times=[1, main_relative_position_buckets_index.shape[1]]) - main_relative_pos_embeddings = paddle.reshape(paddle.gather_nd(rel_pos_embeddings, - index=main_relative_position_buckets_index), - (attn_weights.shape[:2] + [-1])) + main_relative_pos_embeddings = paddle.reshape( + paddle.gather_nd(rel_pos_embeddings, + index=main_relative_position_buckets_index), + (attn_weights.shape[:2] + [-1])) return main_relative_pos_embeddings - def get_predict_relative_pos_embeddings(self, hidden_states, attn_weights, position_ids, + def get_predict_relative_pos_embeddings(self, hidden_states, attn_weights, + position_ids, predict_relative_position_buckets): # input hidden_states [ngram, T,B,C], # input attn_weights [ngram, B*head,T,S], @@ -609,48 +699,55 @@ def get_predict_relative_pos_embeddings(self, hidden_states, attn_weights, posit key_sequence_length = attn_weights.shape[-1] assert position_ids[0][0] == key_sequence_length - 1, \ "`position_ids` are incorrect. They should be of the format 1 2 3 4 5 ... (key_sequence_length - 1)" - relative_positions = (paddle.tile(paddle.unsqueeze(paddle.unsqueeze(paddle.arange(0, key_sequence_length), - axis=0), - axis=0), - repeat_times=[batch_size, sequence_length, 1])) - - relative_positions = relative_positions - paddle.tile(paddle.unsqueeze(position_ids, axis=0), - repeat_times=[batch_size, sequence_length, 1]) - predict_relative_position_buckets = compute_relative_buckets(self.num_buckets, - self.relative_max_distance, - relative_positions, False) - - hidden_states = paddle.transpose(hidden_states, (0, 2, 1, 3)) # [ngram, B, T, C] - rel_pos_embeddings = paddle.reshape(self.relative_pos_embeddings(hidden_states), - hidden_states.shape[:-1] + - [self.num_buckets, - self.num_attn_heads]) # [ngram, B, T, bucket, head] - rel_pos_embeddings = paddle.reshape(paddle.transpose(rel_pos_embeddings, - (0, 1, 4, 2, 3)), - (self.ngram * batch_size * self.num_attn_heads, - sequence_length, -1)) # [ngram*B*head, T, bucket] - - predict_relative_position_buckets = paddle.tile(paddle.unsqueeze(predict_relative_position_buckets, - axis=0), - repeat_times=[self.ngram, 1, self.num_attn_heads, - 1]) # [ngram, B, head*T, S] - - rel_pos_embeddings = paddle.reshape(rel_pos_embeddings, (-1, rel_pos_embeddings.shape[-1])) - predict_relative_position_buckets = paddle.cast(paddle.reshape(predict_relative_position_buckets, - (-1, - predict_relative_position_buckets.shape[-1])), - dtype=paddle.int64) # [ngram*B*head*T, S] - - predict_relative_position_buckets_index = paddle.tile(predict_relative_position_buckets.unsqueeze(2), - repeat_times=[1, 1, 2]) + relative_positions = (paddle.tile( + paddle.unsqueeze(paddle.unsqueeze(paddle.arange( + 0, key_sequence_length), + axis=0), + axis=0), + repeat_times=[batch_size, sequence_length, 1])) + + relative_positions = relative_positions - 
paddle.tile( + paddle.unsqueeze(position_ids, axis=0), + repeat_times=[batch_size, sequence_length, 1]) + predict_relative_position_buckets = compute_relative_buckets( + self.num_buckets, self.relative_max_distance, + relative_positions, False) + + hidden_states = paddle.transpose(hidden_states, + (0, 2, 1, 3)) # [ngram, B, T, C] + rel_pos_embeddings = paddle.reshape( + self.relative_pos_embeddings(hidden_states), + hidden_states.shape[:-1] + [self.num_buckets, self.num_attn_heads + ]) # [ngram, B, T, bucket, head] + rel_pos_embeddings = paddle.reshape( + paddle.transpose(rel_pos_embeddings, (0, 1, 4, 2, 3)), + (self.ngram * batch_size * self.num_attn_heads, sequence_length, + -1)) # [ngram*B*head, T, bucket] + + predict_relative_position_buckets = paddle.tile( + paddle.unsqueeze(predict_relative_position_buckets, axis=0), + repeat_times=[self.ngram, 1, self.num_attn_heads, + 1]) # [ngram, B, head*T, S] + + rel_pos_embeddings = paddle.reshape(rel_pos_embeddings, + (-1, rel_pos_embeddings.shape[-1])) + predict_relative_position_buckets = paddle.cast( + paddle.reshape(predict_relative_position_buckets, + (-1, predict_relative_position_buckets.shape[-1])), + dtype=paddle.int64) # [ngram*B*head*T, S] + + predict_relative_position_buckets_index = paddle.tile( + predict_relative_position_buckets.unsqueeze(2), + repeat_times=[1, 1, 2]) predict_relative_position_buckets_index[:, :, 0] = \ paddle.tile(paddle.arange(0, predict_relative_position_buckets_index.shape[0]).unsqueeze(1), repeat_times=[1, predict_relative_position_buckets_index.shape[1]]) - predict_relative_pos_embeddings = paddle.reshape(paddle.gather_nd(rel_pos_embeddings, - index=predict_relative_position_buckets_index), - (self.ngram, batch_size * self.num_attn_heads, - sequence_length, -1)) # [ngram, B*head, T, S] + predict_relative_pos_embeddings = paddle.reshape( + paddle.gather_nd(rel_pos_embeddings, + index=predict_relative_position_buckets_index), + (self.ngram, batch_size * self.num_attn_heads, sequence_length, + -1)) # [ngram, B*head, T, S] return predict_relative_pos_embeddings @@ -659,36 +756,34 @@ class ProphetNetEncoderLayer(Layer): """ Encoder block for Prophetnet """ - - def __init__(self, - hidden_size, - encoder_ffn_dim, - activation_function, - activation_dropout, - attention_dropout, - dropout, + def __init__(self, hidden_size, encoder_ffn_dim, activation_function, + activation_dropout, attention_dropout, dropout, num_encoder_attention_heads): super(ProphetNetEncoderLayer, self).__init__() # 1st residual block - self.self_attn = ProphetNetAttention(hidden_size, attention_dropout, dropout, num_encoder_attention_heads) + self.self_attn = ProphetNetAttention(hidden_size, attention_dropout, + dropout, + num_encoder_attention_heads) self.self_attn_layer_norm = nn.LayerNorm(hidden_size) # 2nd residual block - self.feed_forward = ProphetNetFeedForward(hidden_size, activation_function, activation_dropout, dropout, + self.feed_forward = ProphetNetFeedForward(hidden_size, + activation_function, + activation_dropout, dropout, encoder_ffn_dim) self.feed_forward_layer_norm = nn.LayerNorm(hidden_size) - def forward(self, - hidden_states, - attention_mask): + def forward(self, hidden_states, attention_mask): # 1st residual block attention_output, _ = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask) - hidden_states = self.self_attn_layer_norm(attention_output + hidden_states) + hidden_states = self.self_attn_layer_norm(attention_output + + hidden_states) # 2nd residual block feed_forward_output = 
self.feed_forward(hidden_states) - hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states) + hidden_states = self.feed_forward_layer_norm(feed_forward_output + + hidden_states) return hidden_states @@ -696,32 +791,28 @@ class ProphetNetDecoderLayer(Layer): """ Decoder block for Prophetnet """ - - def __init__(self, - hidden_size, - num_buckets, - relative_max_distance, - num_decoder_attention_heads, - activation_function, - activation_dropout, - dropout, - attention_dropout, - ngram, - decoder_ffn_dim, - add_cross_attention): + def __init__(self, hidden_size, num_buckets, relative_max_distance, + num_decoder_attention_heads, activation_function, + activation_dropout, dropout, attention_dropout, ngram, + decoder_ffn_dim, add_cross_attention): super(ProphetNetDecoderLayer, self).__init__() # 1st residual block - self.self_attn = ProphetNetNgramSelfAttention(hidden_size, num_buckets, relative_max_distance, - num_decoder_attention_heads, dropout, attention_dropout, ngram) + self.self_attn = ProphetNetNgramSelfAttention( + hidden_size, num_buckets, relative_max_distance, + num_decoder_attention_heads, dropout, attention_dropout, ngram) self.self_attn_layer_norm = nn.LayerNorm(hidden_size) # 2nd residual block if add_cross_attention: - self.cross_attn = ProphetNetAttention(hidden_size, attention_dropout, dropout, num_decoder_attention_heads) + self.cross_attn = ProphetNetAttention(hidden_size, + attention_dropout, dropout, + num_decoder_attention_heads) self.cross_attn_layer_norm = nn.LayerNorm(hidden_size) # 3rd residual block - self.feed_forward = ProphetNetFeedForward(hidden_size, activation_function, activation_dropout, dropout, + self.feed_forward = ProphetNetFeedForward(hidden_size, + activation_function, + activation_dropout, dropout, decoder_ffn_dim) self.feed_forward_layer_norm = nn.LayerNorm(hidden_size) @@ -738,7 +829,8 @@ def forward(self, use_cache: bool = True): # 1st residual block # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attn_past_key_value = past_key_value[: + 2] if past_key_value is not None else None ngram_attention_output, self_attn_weights, self_attn_weights_ngram, present_key_value = self.self_attn( hidden_states=hidden_states, past_key_value=self_attn_past_key_value, @@ -748,10 +840,12 @@ def forward(self, predict_relative_position_buckets=predict_relative_position_buckets, position_ids=position_ids, ) - hidden_states = self.self_attn_layer_norm(hidden_states + ngram_attention_output) + hidden_states = self.self_attn_layer_norm(hidden_states + + ngram_attention_output) # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attn_past_key_value = past_key_value[ + -2:] if past_key_value is not None else None if encoder_hidden_states is not None: # 2nd residual block attention_output, cross_attn_present_key_value = self.cross_attn( @@ -760,19 +854,21 @@ def forward(self, attention_mask=encoder_attn_mask, past_key_value=cross_attn_past_key_value, ) - hidden_states = self.cross_attn_layer_norm(attention_output + hidden_states) + hidden_states = self.cross_attn_layer_norm(attention_output + + hidden_states) # add cross-attn to positions 3,4 of present_key_value tuple present_key_value = present_key_value + cross_attn_present_key_value # 3rd residual block feed_forward_output = 
self.feed_forward(hidden_states) - hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states) + hidden_states = self.feed_forward_layer_norm(feed_forward_output + + hidden_states) - outputs = (hidden_states,) + outputs = (hidden_states, ) if use_cache: - outputs += (present_key_value,) + outputs += (present_key_value, ) return outputs @@ -783,88 +879,71 @@ class ProphetNetEncoder(ProphetNetPretrainedModel): The word embedding parameters. This can be used to initialize :class:`~transformers.ProphetNetEncoder` with pre-defined word embeddings instead of randomly initialized word embeddings. """ - - def __init__(self, - word_embeddings, - vocab_size, - hidden_size, - pad_token_id, - max_position_embeddings, - encoder_ffn_dim, - activation_function, - activation_dropout, - attention_dropout, - dropout, - num_encoder_attention_heads, - num_encoder_layers, - init_std): + def __init__(self, word_embeddings, vocab_size, hidden_size, pad_token_id, + max_position_embeddings, encoder_ffn_dim, activation_function, + activation_dropout, attention_dropout, dropout, + num_encoder_attention_heads, num_encoder_layers, init_std): super(ProphetNetEncoder, self).__init__() self.init_std = init_std if word_embeddings is not None: self.word_embeddings = word_embeddings else: - self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id) + self.word_embeddings = nn.Embedding(vocab_size, + hidden_size, + padding_idx=pad_token_id) - self.position_embeddings = ProphetNetPositionalEmbeddings(max_position_embeddings, hidden_size, pad_token_id) + self.position_embeddings = ProphetNetPositionalEmbeddings( + max_position_embeddings, hidden_size, pad_token_id) self.embeddings_layer_norm = nn.LayerNorm(hidden_size) - self.layers = nn.LayerList([ProphetNetEncoderLayer(hidden_size, - encoder_ffn_dim, - activation_function, - activation_dropout, - attention_dropout, - dropout, - num_encoder_attention_heads) for _ in - range(num_encoder_layers)]) + self.layers = nn.LayerList([ + ProphetNetEncoderLayer(hidden_size, encoder_ffn_dim, + activation_function, activation_dropout, + attention_dropout, dropout, + num_encoder_attention_heads) + for _ in range(num_encoder_layers) + ]) self.apply(self.init_weights) - def forward(self, - input_ids=None, - attention_mask=None): + def forward(self, input_ids=None, attention_mask=None): if input_ids is None: raise ValueError("Input_ids cannot be None.") inputs_embeds = self.word_embeddings(input_ids) # prepare attention mask if attention_mask is not None: - extended_attention_mask = (paddle.tile(1.0 - attention_mask.unsqueeze(1), - repeat_times=[self.config["num_encoder_attention_heads"], - 1, 1])) * -10000.0 - extended_attention_mask = paddle.cast(extended_attention_mask, dtype=inputs_embeds.dtype) + extended_attention_mask = (paddle.tile( + 1.0 - attention_mask.unsqueeze(1), + repeat_times=[self.config["num_encoder_attention_heads"], 1, 1 + ])) * -10000.0 + extended_attention_mask = paddle.cast(extended_attention_mask, + dtype=inputs_embeds.dtype) extended_attention_mask.stop_gradient = True else: extended_attention_mask = None - position_embeddings, position_ids = self.position_embeddings(inputs_embeds.shape[:2]) + position_embeddings, position_ids = self.position_embeddings( + inputs_embeds.shape[:2]) hidden_states = inputs_embeds + position_embeddings hidden_states = self.embeddings_layer_norm(hidden_states) - hidden_states = F.dropout(hidden_states, p=self.config["dropout"], training=self.training) + hidden_states = 
F.dropout(hidden_states, + p=self.config["dropout"], + training=self.training) for idx, encoder_layer in enumerate(self.layers): - hidden_states = encoder_layer(hidden_states, attention_mask=extended_attention_mask) + hidden_states = encoder_layer( + hidden_states, attention_mask=extended_attention_mask) return hidden_states class ProphetNetDecoder(ProphetNetPretrainedModel): - def __init__(self, - word_embeddings, - vocab_size, - hidden_size, - pad_token_id, - max_position_embeddings, - relative_max_distance, - ngram, - num_buckets, - num_decoder_attention_heads, - decoder_ffn_dim, - activation_function, - activation_dropout, - dropout, - attention_dropout, - add_cross_attention, - num_decoder_layers, + def __init__(self, word_embeddings, vocab_size, hidden_size, pad_token_id, + max_position_embeddings, relative_max_distance, ngram, + num_buckets, num_decoder_attention_heads, decoder_ffn_dim, + activation_function, activation_dropout, dropout, + attention_dropout, add_cross_attention, num_decoder_layers, init_std): super(ProphetNetDecoder, self).__init__() self.init_std = init_std @@ -877,23 +956,23 @@ def __init__(self, if word_embeddings is not None: self.word_embeddings = word_embeddings else: - self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id) + self.word_embeddings = nn.Embedding(vocab_size, + hidden_size, + padding_idx=pad_token_id) - self.position_embeddings = ProphetNetPositionalEmbeddings(max_position_embeddings, hidden_size, pad_token_id) + self.position_embeddings = ProphetNetPositionalEmbeddings( + max_position_embeddings, hidden_size, pad_token_id) self.ngram_embeddings = nn.Embedding(self.ngram, hidden_size) - self.layers = nn.LayerList([ProphetNetDecoderLayer(hidden_size, - num_buckets, - relative_max_distance, - num_decoder_attention_heads, - activation_function, - activation_dropout, - dropout, - attention_dropout, - ngram, - decoder_ffn_dim, - add_cross_attention) for _ in - range(num_decoder_layers)]) + self.layers = nn.LayerList([ + ProphetNetDecoderLayer(hidden_size, num_buckets, + relative_max_distance, + num_decoder_attention_heads, + activation_function, activation_dropout, + dropout, attention_dropout, ngram, + decoder_ffn_dim, add_cross_attention) + for _ in range(num_decoder_layers) + ]) self.embeddings_layer_norm = nn.LayerNorm(hidden_size) self.apply(self.init_weights) @@ -910,15 +989,16 @@ def forward(self, inputs_embeds = self.word_embeddings(input_ids) batch_size, sequence_length = inputs_embeds.shape[:2] - main_stream_pos_embed, position_ids = self.position_embeddings((batch_size, sequence_length), - past_key_values=past_key_values) + main_stream_pos_embed, position_ids = self.position_embeddings( + (batch_size, sequence_length), past_key_values=past_key_values) if past_key_values is not None: main_relative_position_buckets, predict_relative_position_buckets = None, None else: main_relative_position_buckets, predict_relative_position_buckets = self.compute_buffered_relative_buckets( position_ids) - predicting_stream_pos_embed = self.position_embeddings._forward(position_ids + 1) + predicting_stream_pos_embed = self.position_embeddings._forward( + position_ids + 1) # add position embeddings hidden_states = inputs_embeds + main_stream_pos_embed @@ -930,80 +1010,105 @@ def forward(self, assert hidden_states.shape[1] == 1, \ "At the moment `use_cache` is only supported for `decoder_input_ids` of length 1" - ngram_hidden_states = [paddle.tile((ngram_embeddings[ngram - 1] + predicting_stream_pos_embed), - 
repeat_times=[batch_size, 1, 1]) - for ngram in range(self.ngram)] + ngram_hidden_states = [ + paddle.tile( + (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed), + repeat_times=[batch_size, 1, 1]) + for ngram in range(self.ngram) + ] extended_attention_mask = None extended_predict_attention_mask = None else: - ngram_hidden_states = [(ngram_embeddings[ngram - 1] + predicting_stream_pos_embed) - for ngram in range(self.ngram)] - extended_attention_mask = self.prepare_attention_mask(hidden_states, attention_mask) - extended_predict_attention_mask = self.prepare_predict_attention_mask(hidden_states, attention_mask) + ngram_hidden_states = [ + (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed) + for ngram in range(self.ngram) + ] + extended_attention_mask = self.prepare_attention_mask( + hidden_states, attention_mask) + extended_predict_attention_mask = self.prepare_predict_attention_mask( + hidden_states, attention_mask) extended_attention_mask.stop_gradient = True extended_predict_attention_mask.stop_gradient = True # prepare encoder attention mask if encoder_attention_mask is not None: - extended_encoder_attention_mask = (1.0 - paddle.tile(encoder_attention_mask[:, None, :], - repeat_times=[ - self.config["num_decoder_attention_heads"], - 1, 1])) * -10000.0 - extended_encoder_attention_mask = paddle.cast(extended_encoder_attention_mask, dtype=inputs_embeds.dtype) + extended_encoder_attention_mask = (1.0 - paddle.tile( + encoder_attention_mask[:, None, :], + repeat_times=[self.config["num_decoder_attention_heads"], 1, 1 + ])) * -10000.0 + extended_encoder_attention_mask = paddle.cast( + extended_encoder_attention_mask, dtype=inputs_embeds.dtype) else: extended_encoder_attention_mask = None - hidden_states = paddle.concat([hidden_states] + ngram_hidden_states, axis=1) + hidden_states = paddle.concat([hidden_states] + ngram_hidden_states, + axis=1) if self.embeddings_layer_norm: hidden_states = self.embeddings_layer_norm(hidden_states) - hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = F.dropout(hidden_states, + p=self.dropout, + training=self.training) present_key_values = () if use_cache else None for idx, decoder_layer in enumerate(self.layers): - past_key_value = past_key_values[idx] if past_key_values is not None else None + past_key_value = past_key_values[ + idx] if past_key_values is not None else None - layer_outputs = decoder_layer(hidden_states, - attention_mask=extended_attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attn_mask=extended_encoder_attention_mask, - extended_predict_attention_mask=extended_predict_attention_mask, - main_relative_position_buckets=main_relative_position_buckets, - predict_relative_position_buckets=predict_relative_position_buckets, - position_ids=position_ids, - past_key_value=past_key_value, - use_cache=use_cache) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=extended_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attn_mask=extended_encoder_attention_mask, + extended_predict_attention_mask=extended_predict_attention_mask, + main_relative_position_buckets=main_relative_position_buckets, + predict_relative_position_buckets= + predict_relative_position_buckets, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache) hidden_states = layer_outputs[0] if use_cache: - present_key_values += (layer_outputs[1],) + present_key_values += (layer_outputs[1], ) last_hidden_state = hidden_states[:, 
:sequence_length] # 1-gram - last_hidden_state_ngram = hidden_states[:, sequence_length:] if self.ngram > 0 else None # 2-gram - return tuple(v for v in [last_hidden_state, - last_hidden_state_ngram, - present_key_values] if v is not None) + last_hidden_state_ngram = hidden_states[:, + sequence_length:] if self.ngram > 0 else None # 2-gram + return tuple( + v for v in + [last_hidden_state, last_hidden_state_ngram, present_key_values] + if v is not None) def compute_buffered_relative_buckets(self, position_ids): batch_size, sequence_length = position_ids.shape - if not hasattr(self, '_main_relative_buckets') or self._main_relative_buckets is None: - position_ids = paddle.tile(paddle.arange(1, self.max_target_positions + 1), repeat_times=[1, 1]) + if not hasattr(self, '_main_relative_buckets' + ) or self._main_relative_buckets is None: + position_ids = paddle.tile(paddle.arange( + 1, self.max_target_positions + 1), + repeat_times=[1, 1]) self._main_relative_buckets, self._predict_relative_buckets = compute_all_stream_relative_buckets( self.num_buckets, self.relative_max_distance, position_ids) # buffer relative buckets - main_relative_buckets = paddle.tile(self._main_relative_buckets[:, :sequence_length, :sequence_length], - repeat_times=[batch_size, 1, 1]) - predict_relative_buckets = paddle.tile(paddle.concat( - [self._predict_relative_buckets[:, :sequence_length, :sequence_length], - self._predict_relative_buckets[:, :sequence_length, - self.max_target_positions: self.max_target_positions + sequence_length]], axis=2), + main_relative_buckets = paddle.tile( + self._main_relative_buckets[:, :sequence_length, :sequence_length], repeat_times=[batch_size, 1, 1]) + predict_relative_buckets = paddle.tile(paddle.concat([ + self. + _predict_relative_buckets[:, :sequence_length, :sequence_length], + self._predict_relative_buckets[:, :sequence_length, + self.max_target_positions:self. 
+ max_target_positions + + sequence_length] + ], + axis=2), + repeat_times=[batch_size, 1, 1]) return main_relative_buckets, predict_relative_buckets @@ -1012,48 +1117,65 @@ def prepare_attention_mask(self, hidden_states, attention_mask): # get causal mask if not hasattr(self, '_causal_mask') or self._causal_mask is None: - causal_mask = paddle.full((self.max_target_positions, self.max_target_positions), -float("inf"), - dtype=hidden_states.dtype) + causal_mask = paddle.full( + (self.max_target_positions, self.max_target_positions), + -float("inf"), + dtype=hidden_states.dtype) self._causal_mask = paddle.triu(causal_mask, 1) - extended_causal_mask = paddle.expand(self._causal_mask[:seq_length, :seq_length].unsqueeze(0), - shape=[batch_size, seq_length, seq_length]) + extended_causal_mask = paddle.expand( + self._causal_mask[:seq_length, :seq_length].unsqueeze(0), + shape=[batch_size, seq_length, seq_length]) # add usual attention mask if attention_mask is not None: - extended_attention_mask = (1.0 - attention_mask.unsqueeze(1)) * -10000.0 + extended_attention_mask = (1.0 - + attention_mask.unsqueeze(1)) * -10000.0 extended_attention_mask = extended_causal_mask + extended_attention_mask else: extended_attention_mask = extended_causal_mask - return paddle.cast(paddle.tile(extended_attention_mask, - repeat_times=[self.config["num_decoder_attention_heads"], 1, 1]), + return paddle.cast(paddle.tile( + extended_attention_mask, + repeat_times=[self.config["num_decoder_attention_heads"], 1, 1]), dtype=hidden_states.dtype) def prepare_predict_attention_mask(self, hidden_states, attention_mask): batch_size, seq_length = hidden_states.shape[:2] # get causal mask - if not hasattr(self, '_predict_causal_mask') or self._predict_causal_mask is None: - self._predict_causal_mask = ngram_attention_bias(self.max_target_positions, self.ngram, hidden_states.dtype) - predict_causal_mask = paddle.concat([self._predict_causal_mask[:, :seq_length, :seq_length], - self._predict_causal_mask[:, :seq_length, - self.max_target_positions: self.max_target_positions + seq_length]], + if not hasattr( + self, + '_predict_causal_mask') or self._predict_causal_mask is None: + self._predict_causal_mask = ngram_attention_bias( + self.max_target_positions, self.ngram, hidden_states.dtype) + predict_causal_mask = paddle.concat([ + self._predict_causal_mask[:, :seq_length, :seq_length], + self._predict_causal_mask[:, :seq_length, + self.max_target_positions:self. 
+ max_target_positions + seq_length] + ], axis=-1) - extended_predict_causal_mask = paddle.expand(predict_causal_mask[:, None, :, :], - shape=predict_causal_mask.shape[:1] + [ - batch_size] + predict_causal_mask.shape[1:]) + extended_predict_causal_mask = paddle.expand( + predict_causal_mask[:, None, :, :], + shape=predict_causal_mask.shape[:1] + [batch_size] + + predict_causal_mask.shape[1:]) # add usual attention mask if attention_mask is not None: - extended_attention_mask = (1.0 - attention_mask[None, :, None, :]) * -10000.0 - extended_attention_mask = extended_attention_mask.expand((self.ngram, batch_size, seq_length, seq_length)) + extended_attention_mask = ( + 1.0 - attention_mask[None, :, None, :]) * -10000.0 + extended_attention_mask = extended_attention_mask.expand( + (self.ngram, batch_size, seq_length, seq_length)) # predicted stream attention_mask should always be 0 - extended_attention_mask = paddle.concat([extended_attention_mask, - paddle.zeros_like(extended_attention_mask)], + extended_attention_mask = paddle.concat([ + extended_attention_mask, + paddle.zeros_like(extended_attention_mask) + ], axis=-1) extended_predict_attention_mask = extended_predict_causal_mask + extended_attention_mask else: extended_predict_attention_mask = extended_predict_causal_mask - return paddle.cast(extended_predict_attention_mask.tile([1, self.config["num_decoder_attention_heads"], 1, 1]), + return paddle.cast(extended_predict_attention_mask.tile( + [1, self.config["num_decoder_attention_heads"], 1, 1]), dtype=hidden_states.dtype) @@ -1090,39 +1212,22 @@ def __init__(self, self.pad_token_id = pad_token_id self.disable_ngram_loss = disable_ngram_loss self.decoder_start_token_id = decoder_start_token_id - self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id) - - self.encoder = ProphetNetEncoder(self.word_embeddings, - vocab_size, - hidden_size, - pad_token_id, - max_position_embeddings, - encoder_ffn_dim, - activation_function, - activation_dropout, - attention_dropout, - dropout, - num_encoder_attention_heads, - num_encoder_layers, - init_std) - - self.decoder = ProphetNetDecoder(self.word_embeddings, - vocab_size, - hidden_size, - pad_token_id, - max_position_embeddings, - relative_max_distance, - ngram, - num_buckets, - num_decoder_attention_heads, - decoder_ffn_dim, - activation_function, - activation_dropout, - dropout, - attention_dropout, - add_cross_attention, - num_decoder_layers, - init_std) + self.word_embeddings = nn.Embedding(vocab_size, + hidden_size, + padding_idx=pad_token_id) + + self.encoder = ProphetNetEncoder( + self.word_embeddings, vocab_size, hidden_size, pad_token_id, + max_position_embeddings, encoder_ffn_dim, activation_function, + activation_dropout, attention_dropout, dropout, + num_encoder_attention_heads, num_encoder_layers, init_std) + + self.decoder = ProphetNetDecoder( + self.word_embeddings, vocab_size, hidden_size, pad_token_id, + max_position_embeddings, relative_max_distance, ngram, num_buckets, + num_decoder_attention_heads, decoder_ffn_dim, activation_function, + activation_dropout, dropout, attention_dropout, add_cross_attention, + num_decoder_layers, init_std) self.apply(self.init_weights) @@ -1143,9 +1248,8 @@ def forward(self, if attention_mask is None: assert input_ids is not None, "input_ids should be " \ "specified when generating attention_mask" - attention_mask = paddle.cast( - input_ids != self.pad_token_id, - dtype=paddle.get_default_dtype()) + attention_mask = paddle.cast(input_ids != self.pad_token_id, + 
dtype=paddle.get_default_dtype()) if decoder_attention_mask is None: assert decoder_input_ids is not None, "decoder_input_ids should be " \ @@ -1162,28 +1266,22 @@ def forward(self, encoder_attention_mask=attention_mask, use_cache=use_cache, past_key_values=past_key_values) - return decoder_outputs + (encoder_output,) + return decoder_outputs + (encoder_output, ) class Linear_wo_bias(Layer): - def __init__(self, - in_features, - out_features, - weight_attr=None, - name=None): + def __init__(self, in_features, out_features, weight_attr=None, name=None): super(Linear_wo_bias, self).__init__() self._dtype = self._helper.get_default_dtype() self._weight_attr = weight_attr - self.weight = self.create_parameter( - shape=[in_features, out_features], - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False) + self.weight = self.create_parameter(shape=[in_features, out_features], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False) self.name = name def forward(self, input): - out = F.linear( - x=input, weight=self.weight, name=self.name) + out = F.linear(x=input, weight=self.weight, name=self.name) return out def extra_repr(self): @@ -1198,7 +1296,8 @@ def __init__(self, prophetnet): self.prophetnet = prophetnet self.padding_idx = prophetnet.word_embeddings._padding_idx - self.lm_head = Linear_wo_bias(self.prophetnet.config["hidden_size"], self.prophetnet.config["vocab_size"]) + self.lm_head = Linear_wo_bias(self.prophetnet.config["hidden_size"], + self.prophetnet.config["vocab_size"]) # Initialize weights and apply final processing self.apply(self.init_weights) @@ -1225,8 +1324,9 @@ def forward(self, batch_size, sequence_length = decoder_input_ids.shape - predicting_streams = paddle.reshape(outputs[1], - (batch_size, self.prophetnet.config["ngram"], sequence_length, -1)) + predicting_streams = paddle.reshape( + outputs[1], + (batch_size, self.prophetnet.config["ngram"], sequence_length, -1)) predict_logits = self.lm_head(predicting_streams) logits = predict_logits[:, 0] @@ -1248,13 +1348,16 @@ def prepare_inputs_for_generation(self, decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1) # first step, decoder_cached_states are empty - return {"input_ids": None, # encoder_outputs is defined. input_ids not needed - "decoder_input_ids": decoder_input_ids, - "encoder_output": encoder_output, - "decoder_attention_mask": decoder_attention_mask, - "attention_mask": attention_mask, - "use_cache": use_cache, - "past_key_values": cache} + return { + "input_ids": + None, # encoder_outputs is defined. input_ids not needed + "decoder_input_ids": decoder_input_ids, + "encoder_output": encoder_output, + "decoder_attention_mask": decoder_attention_mask, + "attention_mask": attention_mask, + "use_cache": use_cache, + "past_key_values": cache + } def prepare_decoder_input_ids_from_labels(self, labels): return self._shift_right(labels) diff --git a/paddlenlp/transformers/prophetnet/tokenizer.py b/paddlenlp/transformers/prophetnet/tokenizer.py index 90951276e541..ee9d1c04e031 100644 --- a/paddlenlp/transformers/prophetnet/tokenizer.py +++ b/paddlenlp/transformers/prophetnet/tokenizer.py @@ -26,7 +26,6 @@ class Trie: Trie in Python. Creates a Trie out of a list of words. 
The trie is used to split on `added_tokens` in one pass Loose reference https://en.wikipedia.org/wiki/Trie """ - def __init__(self): self.data = {} @@ -143,7 +142,9 @@ def split(self, text: str) -> List[str]: # It wasn't updated yet so indices are current ones lookahead_index = current end = current - next_char = text[lookahead_index] if lookahead_index < len(text) else None + next_char = text[ + lookahead_index] if lookahead_index < len( + text) else None while next_char in looktrie_pointer: looktrie_pointer = looktrie_pointer[next_char] lookahead_index += 1 @@ -297,15 +298,20 @@ def __init__(self, x_sep_token="[X_SEP]", pad_token="[PAD]", mask_token="[MASK]"): - self.unique_no_split_tokens = [x_sep_token, unk_token, sep_token, bos_token, eos_token, cls_token, pad_token, - mask_token] + self.unique_no_split_tokens = [ + x_sep_token, unk_token, sep_token, bos_token, eos_token, cls_token, + pad_token, mask_token + ] self.tokens_trie = create_trie(self.unique_no_split_tokens) self.vocab = load_vocab(vocab_file) - self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) + self.ids_to_tokens = collections.OrderedDict([ + (ids, tok) for tok, ids in self.vocab.items() + ]) self.do_basic_tokenize = do_basic_tokenize if do_basic_tokenize: self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) - self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=unk_token) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, + unk_token=unk_token) @property def vocab_size(self): @@ -417,7 +423,10 @@ def convert_tokens_to_string(self, tokens): def convert_ids_to_string(self, ids): return self.convert_tokens_to_string(self.convert_ids_to_tokens(ids)) - def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False): + def get_special_tokens_mask(self, + token_ids_0, + token_ids_1=None, + already_has_special_tokens=False): """ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer `prepare_for_model` method. @@ -435,14 +444,17 @@ def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_spe """ if already_has_special_tokens: return super().get_special_tokens_mask( - token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True - ) + token_ids_0=token_ids_0, + token_ids_1=token_ids_1, + already_has_special_tokens=True) if token_ids_1 is None: return ([0] * len(token_ids_0)) + [1] return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] - def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): + def create_token_type_ids_from_sequences(self, + token_ids_0, + token_ids_1=None): """ Create a mask from the two sequences passed to be used in a sequence-pair classification task. A ProphetNet sequence pair mask has the following format: @@ -469,7 +481,9 @@ def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): return len(token_ids_0 + sep) * [0] return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] - def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]: + def build_inputs_with_special_tokens(self, + token_ids_0, + token_ids_1=None) -> List[int]: """ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and adding special tokens. 
A BERT sequence has the following format: @@ -493,12 +507,15 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> Lis def save_vocabulary(self, save_directory): index = 0 - vocab_file = os.path.join(save_directory, self.resource_files_names["vocab_file"]) + vocab_file = os.path.join(save_directory, + self.resource_files_names["vocab_file"]) with open(vocab_file, "w", encoding="utf-8") as writer: - for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + for token, token_index in sorted(self.vocab.items(), + key=lambda kv: kv[1]): if index != token_index: - logging.warning(f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." - " Please check that the vocabulary is not corrupted!") + logging.warning( + f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!") index = token_index writer.write(token + "\n") index += 1 From 74e7318e1bcb81f9e103c38294325ff70deb4d01 Mon Sep 17 00:00:00 2001 From: DMH_coco <294270681@qq.com> Date: Mon, 21 Feb 2022 20:18:52 +0800 Subject: [PATCH 4/7] pre commit --- paddlenlp/transformers/prophetnet/modeling.py | 455 ++++++++++-------- .../transformers/prophetnet/tokenizer.py | 19 +- 2 files changed, 252 insertions(+), 222 deletions(-) diff --git a/paddlenlp/transformers/prophetnet/modeling.py b/paddlenlp/transformers/prophetnet/modeling.py index 61520cd91e67..365788771e4c 100644 --- a/paddlenlp/transformers/prophetnet/modeling.py +++ b/paddlenlp/transformers/prophetnet/modeling.py @@ -42,8 +42,8 @@ def ngram_attention_bias(sequence_length, ngram, dtype): for stream_idx in range(ngram): right_block[stream_idx] = right_block[stream_idx].fill_diagonal_( 0, wrap=False) - left_block[stream_idx] = paddle.triu(left_block[stream_idx], - diagonal=-stream_idx + 1) + left_block[stream_idx] = paddle.triu( + left_block[stream_idx], diagonal=-stream_idx + 1) left_block[:, :, 0] = 0 return paddle.concat([left_block, right_block], axis=2) @@ -68,8 +68,9 @@ def compute_relative_buckets(num_buckets, inv_relative_positions = paddle.abs(inv_relative_positions) else: inv_relative_positions = paddle.cast( - paddle.less_than(paddle.zeros_like(inv_relative_positions), - inv_relative_positions), + paddle.less_than( + paddle.zeros_like(inv_relative_positions), + inv_relative_positions), dtype=paddle.int32) * inv_relative_positions max_exact = num_buckets // 2 @@ -77,17 +78,21 @@ def compute_relative_buckets(num_buckets, inv_relative_positions, paddle.to_tensor(max_exact).cast(dtype=paddle.int32)) val_if_large = max_exact + paddle.log( - paddle.cast(inv_relative_positions, dtype=paddle.float32) / max_exact - ) / math.log(max_distance / max_exact) * (num_buckets - max_exact) - val_if_large_num_buckets = paddle.ones_like(val_if_large) * (num_buckets - - 1) - val_if_large_lt = paddle.cast(paddle.less_than(val_if_large, - val_if_large_num_buckets), - dtype=paddle.int32) + paddle.cast( + inv_relative_positions, dtype=paddle.float32) / + max_exact) / math.log(max_distance / max_exact) * (num_buckets - + max_exact) + val_if_large_num_buckets = paddle.ones_like(val_if_large) * ( + num_buckets - 1) + val_if_large_lt = paddle.cast( + paddle.less_than(val_if_large, val_if_large_num_buckets), + dtype=paddle.int32) val_if_large = paddle.cast(val_if_large_lt * val_if_large, dtype=paddle.int32) + \ (1 - val_if_large_lt) * val_if_large_num_buckets rel_positions_bucket = rel_positions_bucket + paddle.where( - is_small, paddle.cast(inv_relative_positions, 
dtype=paddle.int32), + is_small, + paddle.cast( + inv_relative_positions, dtype=paddle.int32), val_if_large) return rel_positions_bucket @@ -99,15 +104,16 @@ def compute_all_stream_relative_buckets(num_buckets, max_distance, """ # main stream main_stream_relative_positions = paddle.tile( - paddle.unsqueeze(position_ids, axis=1), + paddle.unsqueeze( + position_ids, axis=1), repeat_times=[1, position_ids.shape[-1], 1]) main_stream_relative_positions = main_stream_relative_positions - paddle.unsqueeze( position_ids, axis=-1) # predicting stream - predicting_stream_relative_positions = paddle.unsqueeze(paddle.concat( - [position_ids - 1, position_ids], axis=-1), - axis=1) + predicting_stream_relative_positions = paddle.unsqueeze( + paddle.concat( + [position_ids - 1, position_ids], axis=-1), axis=1) predicting_stream_relative_positions = paddle.tile( predicting_stream_relative_positions, repeat_times=[1, position_ids.shape[-1], 1]) @@ -201,15 +207,16 @@ def _shift_right(self, input_ids): assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids_mask = paddle.cast(shifted_input_ids == -100, - dtype=paddle.int32) + shifted_input_ids_mask = paddle.cast( + shifted_input_ids == -100, dtype=paddle.int32) shifted_input_ids = shifted_input_ids_mask * pad_token_id + ( 1 - shifted_input_ids_mask) * shifted_input_ids assert paddle.sum( - paddle.cast(shifted_input_ids >= 0, dtype=paddle.int32)).item( - ) == shifted_input_ids.shape[ - -1], "Verify that `shifted_input_ids` has only positive values" + paddle.cast( + shifted_input_ids >= 0, dtype=paddle.int32) + ).item() == shifted_input_ids.shape[ + -1], "Verify that `shifted_input_ids` has only positive values" return shifted_input_ids @@ -218,6 +225,7 @@ class ProphetNetPositionalEmbeddings(nn.Embedding): """ ProphetNetPositional Embeddings. """ + def __init__(self, max_position_embeddings, hidden_size, pad_token_id): self.max_length = max_position_embeddings super(ProphetNetPositionalEmbeddings, @@ -246,14 +254,15 @@ def forward(self, # retrieve position_ids from input_ids / attention_mask position_ids = paddle.cast( - paddle.cast(paddle.cumsum(attention_mask, axis=1), - dtype=attention_mask.dtype) * attention_mask, + paddle.cast( + paddle.cumsum( + attention_mask, axis=1), + dtype=attention_mask.dtype) * attention_mask, dtype=paddle.int64) + self._padding_idx # make sure position_ids are not bigger then max_length - position_ids = paddle.clip(position_ids, - min=0, - max=self.max_length - 1) + position_ids = paddle.clip( + position_ids, min=0, max=self.max_length - 1) return super().forward(position_ids), position_ids @@ -265,7 +274,11 @@ class ProphetNetAttention(Layer): """ Multi-headed attention from 'Attention Is All You Need' paper. 
""" - def __init__(self, hidden_size, attention_dropout, dropout, + + def __init__(self, + hidden_size, + attention_dropout, + dropout, num_attn_heads: int): super().__init__() hidden_size = hidden_size @@ -290,13 +303,12 @@ def _shape(self, tensor: paddle.Tensor, seq_len: int, bsz: int): [bsz, seq_len, self.num_attn_heads, self.head_dim]), (0, 2, 1, 3)) - def forward( - self, - hidden_states, - key_value_states: Optional[Tensor] = None, - attention_mask: Optional[Tensor] = None, - past_key_value: Optional[Tuple[Tensor]] = None - ) -> Tuple[Tensor, Optional[Tensor]]: + def forward(self, + hidden_states, + key_value_states: Optional[Tensor]=None, + attention_mask: Optional[Tensor]=None, + past_key_value: Optional[Tuple[Tensor]]=None) -> Tuple[ + Tensor, Optional[Tensor]]: batch_size, tgt_len, hidden_size = hidden_states.shape @@ -315,16 +327,16 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.key_proj(key_value_states), -1, - batch_size) - value_states = self._shape(self.value_proj(key_value_states), -1, - batch_size) + key_states = self._shape( + self.key_proj(key_value_states), -1, batch_size) + value_states = self._shape( + self.value_proj(key_value_states), -1, batch_size) else: # self_attention - key_states = self._shape(self.key_proj(hidden_states), -1, - batch_size) - value_states = self._shape(self.value_proj(hidden_states), -1, - batch_size) + key_states = self._shape( + self.key_proj(hidden_states), -1, batch_size) + value_states = self._shape( + self.value_proj(hidden_states), -1, batch_size) if is_cross_attention: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -356,9 +368,8 @@ def forward( attn_weights = F.softmax(attn_weights, axis=-1) - attn_probs = F.dropout(attn_weights, - p=self.attention_dropout, - training=self.training) + attn_probs = F.dropout( + attn_weights, p=self.attention_dropout, training=self.training) attn_output = paddle.bmm(attn_probs, value_states) assert attn_output.shape == [batch_size * self.num_attn_heads, tgt_len, self.head_dim, ], \ @@ -373,9 +384,8 @@ def forward( attn_output = self.out_proj(attn_output) - attn_output = F.dropout(attn_output, - p=self.dropout, - training=self.training) + attn_output = F.dropout( + attn_output, p=self.dropout, training=self.training) return attn_output, past_key_value @@ -383,8 +393,13 @@ class ProphetNetFeedForward(Layer): """ This is the residual two feed-forward layer block based on the original Transformer implementation. 
""" - def __init__(self, hidden_size, activation_function, activation_dropout, - dropout, ffn_dim: int): + + def __init__(self, + hidden_size, + activation_function, + activation_dropout, + dropout, + ffn_dim: int): super(ProphetNetFeedForward, self).__init__() self.activation_fn = ACT2FN[activation_function] self.intermediate = nn.Linear(hidden_size, ffn_dim) @@ -396,13 +411,11 @@ def forward(self, hidden_states): hidden_states = self.intermediate(hidden_states) hidden_states = self.activation_fn(hidden_states) - hidden_states = F.dropout(hidden_states, - p=self.activation_dropout, - training=self.training) + hidden_states = F.dropout( + hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.output(hidden_states) - hidden_states = F.dropout(hidden_states, - p=self.dropout, - training=self.training) + hidden_states = F.dropout( + hidden_states, p=self.dropout, training=self.training) return hidden_states @@ -433,8 +446,8 @@ def __init__(self, hidden_size, num_buckets, relative_max_distance, self.out_proj = nn.Linear(hidden_size, hidden_size) # rel position embeddings - self.relative_pos_embeddings = nn.Linear( - hidden_size, self.num_buckets * self.num_attn_heads) + self.relative_pos_embeddings = nn.Linear(hidden_size, self.num_buckets * + self.num_attn_heads) def _shape(self, tensor, seq_len, batch_size): return paddle.transpose( @@ -445,7 +458,7 @@ def _shape(self, tensor, seq_len, batch_size): def forward(self, hidden_states, - past_key_value: Optional[Tuple[Tensor]] = None, + past_key_value: Optional[Tuple[Tensor]]=None, attention_mask=None, extended_predict_attention_mask=None, main_relative_position_buckets=None, @@ -510,16 +523,16 @@ def forward(self, (batch_size, self.num_attn_heads, -1, self.head_dim)), paddle.reshape( main_value_states, - (batch_size, self.num_attn_heads, -1, self.head_dim)), - ) + (batch_size, self.num_attn_heads, -1, self.head_dim)), ) # get seq_length of main stream only sequence_length = ngram_sequence_length // (1 + self.ngram) # MAIN-STREAM # main attn weights - main_attn_weights = paddle.bmm( - main_query_states, paddle.transpose(main_key_states, (0, 2, 1))) + main_attn_weights = paddle.bmm(main_query_states, + paddle.transpose(main_key_states, + (0, 2, 1))) # retrieve relative position embeddings for each layer -> see paper for more details main_relative_pos_embeddings = self.get_main_relative_pos_embeddings( @@ -531,13 +544,11 @@ def forward(self, if attention_mask is not None: main_attn_weights = main_attn_weights + attention_mask - main_attn_probs = F.softmax(main_attn_weights, - axis=-1, - dtype=main_attn_weights.dtype) + main_attn_probs = F.softmax( + main_attn_weights, axis=-1, dtype=main_attn_weights.dtype) - main_attn_probs = F.dropout(main_attn_probs, - p=self.attention_dropout, - training=self.training) + main_attn_probs = F.dropout( + main_attn_probs, p=self.attention_dropout, training=self.training) # project to attn_output main_attn_output = paddle.bmm(main_attn_probs, main_value_states) @@ -553,26 +564,34 @@ def forward(self, # PREDICT-STREAM # [ngram, B*head, T, c] predict_query_states = paddle.reshape( - paddle.concat(predict_query_states_list, axis=0), + paddle.concat( + predict_query_states_list, axis=0), (self.ngram, -1, sequence_length, self.head_dim)) # [ngram, B*head, 2*T, c] - predict_key_states = paddle.concat([ - paddle.unsqueeze(paddle.concat([main_key_states, key], axis=1), - axis=0) for key in predict_key_states_list - ], - axis=0) + predict_key_states = paddle.concat( + [ + paddle.unsqueeze( + 
paddle.concat( + [main_key_states, key], axis=1), axis=0) + for key in predict_key_states_list + ], + axis=0) # [ngram, T, B, C] predict_hidden_states = paddle.reshape( - paddle.concat(hidden_states_predict_list, axis=0), + paddle.concat( + hidden_states_predict_list, axis=0), (self.ngram, sequence_length, batch_size, hidden_size)) # [ngram, B*head, 2*T, c] - predict_value_states = paddle.concat([ - paddle.unsqueeze(paddle.concat([main_value_states, v_p], axis=1), - axis=0) for v_p in predict_value_states_list - ], - axis=0) + predict_value_states = paddle.concat( + [ + paddle.unsqueeze( + paddle.concat( + [main_value_states, v_p], axis=1), axis=0) + for v_p in predict_value_states_list + ], + axis=0) # [ngram, B*head, T, 2*T] predict_attn_weights = paddlenlp.ops.einsum( @@ -591,13 +610,13 @@ def forward(self, predict_attn_weights = predict_attn_weights + paddle.cast( extended_predict_attention_mask, predict_attn_weights.dtype) - predict_attn_probs = F.softmax(predict_attn_weights, - axis=-1, - dtype=predict_attn_weights.dtype) + predict_attn_probs = F.softmax( + predict_attn_weights, axis=-1, dtype=predict_attn_weights.dtype) - predict_attn_probs = F.dropout(predict_attn_probs, - p=self.attention_dropout, - training=self.training) + predict_attn_probs = F.dropout( + predict_attn_probs, + p=self.attention_dropout, + training=self.training) # project to attention output # [ngram, B*head, T, c] predict_attn_output = paddlenlp.ops.einsum( @@ -617,7 +636,8 @@ def forward(self, # concat to single attn output # [B, 1+ngram*T, C] attn_output = paddle.reshape( - paddle.concat([main_attn_output, predict_attn_output], axis=1), + paddle.concat( + [main_attn_output, predict_attn_output], axis=1), (batch_size, -1, hidden_size)) # reshape into better form for `config.output_attentions` main_attn_probs = paddle.reshape( @@ -628,9 +648,8 @@ def forward(self, (self.ngram, batch_size, self.num_attn_heads, sequence_length, -1)), (1, 0, 2, 3, 4)) - attn_output = F.dropout(attn_output, - p=self.dropout, - training=self.training) + attn_output = F.dropout( + attn_output, p=self.dropout, training=self.training) return attn_output, main_attn_probs, predict_attn_probs, past_key_value @@ -642,13 +661,14 @@ def get_main_relative_pos_embeddings(self, hidden_states, attn_weights, if main_relative_position_buckets is None: batch_size, sequence_length = hidden_states.shape[:2] relative_positions = (paddle.tile( - paddle.unsqueeze(paddle.unsqueeze(paddle.arange( - 1, attn_weights.shape[-1] + 1), - axis=0), - axis=0), + paddle.unsqueeze( + paddle.unsqueeze( + paddle.arange(1, attn_weights.shape[-1] + 1), axis=0), + axis=0), repeat_times=[batch_size, sequence_length, 1])) relative_positions = relative_positions - paddle.tile( - paddle.unsqueeze(position_ids, axis=0), + paddle.unsqueeze( + position_ids, axis=0), repeat_times=[batch_size, sequence_length, 1]) # [B, T, s] main_relative_position_buckets = compute_relative_buckets( self.num_buckets, self.relative_max_distance, @@ -666,8 +686,9 @@ def get_main_relative_pos_embeddings(self, hidden_states, attn_weights, main_relative_position_buckets = paddle.cast( paddle.reshape( - paddle.tile(main_relative_position_buckets, - repeat_times=[1, self.num_attn_heads, 1]), + paddle.tile( + main_relative_position_buckets, + repeat_times=[1, self.num_attn_heads, 1]), (-1, main_relative_position_buckets.shape[-1])), dtype=paddle.int64) # [B*head*T, T] rel_pos_embeddings = paddle.reshape( @@ -675,14 +696,15 @@ def get_main_relative_pos_embeddings(self, hidden_states, attn_weights, (-1, 
rel_pos_embeddings.shape[-1])) # [B*head*T,Buckets] main_relative_position_buckets_index = paddle.tile( - main_relative_position_buckets.unsqueeze(2), repeat_times=[1, 1, 2]) + main_relative_position_buckets.unsqueeze(2), + repeat_times=[1, 1, 2]) main_relative_position_buckets_index[:, :, 0] = \ paddle.tile(paddle.arange(0, main_relative_position_buckets_index.shape[0]).unsqueeze(1), repeat_times=[1, main_relative_position_buckets_index.shape[1]]) main_relative_pos_embeddings = paddle.reshape( - paddle.gather_nd(rel_pos_embeddings, - index=main_relative_position_buckets_index), + paddle.gather_nd( + rel_pos_embeddings, index=main_relative_position_buckets_index), (attn_weights.shape[:2] + [-1])) return main_relative_pos_embeddings @@ -700,14 +722,15 @@ def get_predict_relative_pos_embeddings(self, hidden_states, attn_weights, assert position_ids[0][0] == key_sequence_length - 1, \ "`position_ids` are incorrect. They should be of the format 1 2 3 4 5 ... (key_sequence_length - 1)" relative_positions = (paddle.tile( - paddle.unsqueeze(paddle.unsqueeze(paddle.arange( - 0, key_sequence_length), - axis=0), - axis=0), + paddle.unsqueeze( + paddle.unsqueeze( + paddle.arange(0, key_sequence_length), axis=0), + axis=0), repeat_times=[batch_size, sequence_length, 1])) relative_positions = relative_positions - paddle.tile( - paddle.unsqueeze(position_ids, axis=0), + paddle.unsqueeze( + position_ids, axis=0), repeat_times=[batch_size, sequence_length, 1]) predict_relative_position_buckets = compute_relative_buckets( self.num_buckets, self.relative_max_distance, @@ -725,7 +748,8 @@ def get_predict_relative_pos_embeddings(self, hidden_states, attn_weights, -1)) # [ngram*B*head, T, bucket] predict_relative_position_buckets = paddle.tile( - paddle.unsqueeze(predict_relative_position_buckets, axis=0), + paddle.unsqueeze( + predict_relative_position_buckets, axis=0), repeat_times=[self.ngram, 1, self.num_attn_heads, 1]) # [ngram, B, head*T, S] @@ -744,8 +768,9 @@ def get_predict_relative_pos_embeddings(self, hidden_states, attn_weights, repeat_times=[1, predict_relative_position_buckets_index.shape[1]]) predict_relative_pos_embeddings = paddle.reshape( - paddle.gather_nd(rel_pos_embeddings, - index=predict_relative_position_buckets_index), + paddle.gather_nd( + rel_pos_embeddings, + index=predict_relative_position_buckets_index), (self.ngram, batch_size * self.num_attn_heads, sequence_length, -1)) # [ngram, B*head, T, S] @@ -756,6 +781,7 @@ class ProphetNetEncoderLayer(Layer): """ Encoder block for Prophetnet """ + def __init__(self, hidden_size, encoder_ffn_dim, activation_function, activation_dropout, attention_dropout, dropout, num_encoder_attention_heads): @@ -767,16 +793,15 @@ def __init__(self, hidden_size, encoder_ffn_dim, activation_function, self.self_attn_layer_norm = nn.LayerNorm(hidden_size) # 2nd residual block - self.feed_forward = ProphetNetFeedForward(hidden_size, - activation_function, - activation_dropout, dropout, - encoder_ffn_dim) + self.feed_forward = ProphetNetFeedForward( + hidden_size, activation_function, activation_dropout, dropout, + encoder_ffn_dim) self.feed_forward_layer_norm = nn.LayerNorm(hidden_size) def forward(self, hidden_states, attention_mask): # 1st residual block - attention_output, _ = self.self_attn(hidden_states=hidden_states, - attention_mask=attention_mask) + attention_output, _ = self.self_attn( + hidden_states=hidden_states, attention_mask=attention_mask) hidden_states = self.self_attn_layer_norm(attention_output + hidden_states) @@ -791,6 +816,7 @@ class 
ProphetNetDecoderLayer(Layer): """ Decoder block for Prophetnet """ + def __init__(self, hidden_size, num_buckets, relative_max_distance, num_decoder_attention_heads, activation_function, activation_dropout, dropout, attention_dropout, ngram, @@ -810,10 +836,9 @@ def __init__(self, hidden_size, num_buckets, relative_max_distance, self.cross_attn_layer_norm = nn.LayerNorm(hidden_size) # 3rd residual block - self.feed_forward = ProphetNetFeedForward(hidden_size, - activation_function, - activation_dropout, dropout, - decoder_ffn_dim) + self.feed_forward = ProphetNetFeedForward( + hidden_size, activation_function, activation_dropout, dropout, + decoder_ffn_dim) self.feed_forward_layer_norm = nn.LayerNorm(hidden_size) def forward(self, @@ -826,7 +851,7 @@ def forward(self, predict_relative_position_buckets=None, position_ids=None, past_key_value=None, - use_cache: bool = True): + use_cache: bool=True): # 1st residual block # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 self_attn_past_key_value = past_key_value[: @@ -838,8 +863,7 @@ def forward(self, extended_predict_attention_mask=extended_predict_attention_mask, main_relative_position_buckets=main_relative_position_buckets, predict_relative_position_buckets=predict_relative_position_buckets, - position_ids=position_ids, - ) + position_ids=position_ids, ) hidden_states = self.self_attn_layer_norm(hidden_states + ngram_attention_output) @@ -852,8 +876,7 @@ def forward(self, hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attn_mask, - past_key_value=cross_attn_past_key_value, - ) + past_key_value=cross_attn_past_key_value, ) hidden_states = self.cross_attn_layer_norm(attention_output + hidden_states) @@ -879,6 +902,7 @@ class ProphetNetEncoder(ProphetNetPretrainedModel): The word embedding parameters. This can be used to initialize :class:`~transformers.ProphetNetEncoder` with pre-defined word embeddings instead of randomly initialized word embeddings. 
""" + def __init__(self, word_embeddings, vocab_size, hidden_size, pad_token_id, max_position_embeddings, encoder_ffn_dim, activation_function, activation_dropout, attention_dropout, dropout, @@ -888,9 +912,8 @@ def __init__(self, word_embeddings, vocab_size, hidden_size, pad_token_id, if word_embeddings is not None: self.word_embeddings = word_embeddings else: - self.word_embeddings = nn.Embedding(vocab_size, - hidden_size, - padding_idx=pad_token_id) + self.word_embeddings = nn.Embedding( + vocab_size, hidden_size, padding_idx=pad_token_id) self.position_embeddings = ProphetNetPositionalEmbeddings( max_position_embeddings, hidden_size, pad_token_id) @@ -915,10 +938,11 @@ def forward(self, input_ids=None, attention_mask=None): if attention_mask is not None: extended_attention_mask = (paddle.tile( 1.0 - attention_mask.unsqueeze(1), - repeat_times=[self.config["num_encoder_attention_heads"], 1, 1 - ])) * -10000.0 - extended_attention_mask = paddle.cast(extended_attention_mask, - dtype=inputs_embeds.dtype) + repeat_times=[ + self.config["num_encoder_attention_heads"], 1, 1 + ])) * -10000.0 + extended_attention_mask = paddle.cast( + extended_attention_mask, dtype=inputs_embeds.dtype) extended_attention_mask.stop_gradient = True else: extended_attention_mask = None @@ -928,9 +952,8 @@ def forward(self, input_ids=None, attention_mask=None): hidden_states = inputs_embeds + position_embeddings hidden_states = self.embeddings_layer_norm(hidden_states) - hidden_states = F.dropout(hidden_states, - p=self.config["dropout"], - training=self.training) + hidden_states = F.dropout( + hidden_states, p=self.config["dropout"], training=self.training) for idx, encoder_layer in enumerate(self.layers): hidden_states = encoder_layer( @@ -956,21 +979,19 @@ def __init__(self, word_embeddings, vocab_size, hidden_size, pad_token_id, if word_embeddings is not None: self.word_embeddings = word_embeddings else: - self.word_embeddings = nn.Embedding(vocab_size, - hidden_size, - padding_idx=pad_token_id) + self.word_embeddings = nn.Embedding( + vocab_size, hidden_size, padding_idx=pad_token_id) self.position_embeddings = ProphetNetPositionalEmbeddings( max_position_embeddings, hidden_size, pad_token_id) self.ngram_embeddings = nn.Embedding(self.ngram, hidden_size) self.layers = nn.LayerList([ - ProphetNetDecoderLayer(hidden_size, num_buckets, - relative_max_distance, - num_decoder_attention_heads, - activation_function, activation_dropout, - dropout, attention_dropout, ngram, - decoder_ffn_dim, add_cross_attention) + ProphetNetDecoderLayer( + hidden_size, num_buckets, relative_max_distance, + num_decoder_attention_heads, activation_function, + activation_dropout, dropout, attention_dropout, ngram, + decoder_ffn_dim, add_cross_attention) for _ in range(num_decoder_layers) ]) self.embeddings_layer_norm = nn.LayerNorm(hidden_size) @@ -1034,22 +1055,22 @@ def forward(self, if encoder_attention_mask is not None: extended_encoder_attention_mask = (1.0 - paddle.tile( encoder_attention_mask[:, None, :], - repeat_times=[self.config["num_decoder_attention_heads"], 1, 1 - ])) * -10000.0 + repeat_times=[ + self.config["num_decoder_attention_heads"], 1, 1 + ])) * -10000.0 extended_encoder_attention_mask = paddle.cast( extended_encoder_attention_mask, dtype=inputs_embeds.dtype) else: extended_encoder_attention_mask = None - hidden_states = paddle.concat([hidden_states] + ngram_hidden_states, - axis=1) + hidden_states = paddle.concat( + [hidden_states] + ngram_hidden_states, axis=1) if self.embeddings_layer_norm: hidden_states = 
self.embeddings_layer_norm(hidden_states) - hidden_states = F.dropout(hidden_states, - p=self.dropout, - training=self.training) + hidden_states = F.dropout( + hidden_states, p=self.dropout, training=self.training) present_key_values = () if use_cache else None @@ -1065,8 +1086,7 @@ def forward(self, encoder_attn_mask=extended_encoder_attention_mask, extended_predict_attention_mask=extended_predict_attention_mask, main_relative_position_buckets=main_relative_position_buckets, - predict_relative_position_buckets= - predict_relative_position_buckets, + predict_relative_position_buckets=predict_relative_position_buckets, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache) @@ -1080,7 +1100,8 @@ def forward(self, last_hidden_state_ngram = hidden_states[:, sequence_length:] if self.ngram > 0 else None # 2-gram return tuple( - v for v in + v + for v in [last_hidden_state, last_hidden_state_ngram, present_key_values] if v is not None) @@ -1089,9 +1110,9 @@ def compute_buffered_relative_buckets(self, position_ids): if not hasattr(self, '_main_relative_buckets' ) or self._main_relative_buckets is None: - position_ids = paddle.tile(paddle.arange( - 1, self.max_target_positions + 1), - repeat_times=[1, 1]) + position_ids = paddle.tile( + paddle.arange(1, self.max_target_positions + 1), + repeat_times=[1, 1]) self._main_relative_buckets, self._predict_relative_buckets = compute_all_stream_relative_buckets( self.num_buckets, self.relative_max_distance, position_ids) @@ -1099,16 +1120,18 @@ def compute_buffered_relative_buckets(self, position_ids): main_relative_buckets = paddle.tile( self._main_relative_buckets[:, :sequence_length, :sequence_length], repeat_times=[batch_size, 1, 1]) - predict_relative_buckets = paddle.tile(paddle.concat([ - self. - _predict_relative_buckets[:, :sequence_length, :sequence_length], - self._predict_relative_buckets[:, :sequence_length, - self.max_target_positions:self. 
- max_target_positions + - sequence_length] - ], - axis=2), - repeat_times=[batch_size, 1, 1]) + predict_relative_buckets = paddle.tile( + paddle.concat( + [ + self._predict_relative_buckets[:, :sequence_length, : + sequence_length], + self._predict_relative_buckets[:, :sequence_length, + self.max_target_positions: + self.max_target_positions + + sequence_length] + ], + axis=2), + repeat_times=[batch_size, 1, 1]) return main_relative_buckets, predict_relative_buckets @@ -1128,15 +1151,18 @@ def prepare_attention_mask(self, hidden_states, attention_mask): # add usual attention mask if attention_mask is not None: - extended_attention_mask = (1.0 - - attention_mask.unsqueeze(1)) * -10000.0 + extended_attention_mask = ( + 1.0 - attention_mask.unsqueeze(1)) * -10000.0 extended_attention_mask = extended_causal_mask + extended_attention_mask else: extended_attention_mask = extended_causal_mask - return paddle.cast(paddle.tile( - extended_attention_mask, - repeat_times=[self.config["num_decoder_attention_heads"], 1, 1]), - dtype=hidden_states.dtype) + return paddle.cast( + paddle.tile( + extended_attention_mask, + repeat_times=[ + self.config["num_decoder_attention_heads"], 1, 1 + ]), + dtype=hidden_states.dtype) def prepare_predict_attention_mask(self, hidden_states, attention_mask): batch_size, seq_length = hidden_states.shape[:2] @@ -1147,13 +1173,13 @@ def prepare_predict_attention_mask(self, hidden_states, attention_mask): '_predict_causal_mask') or self._predict_causal_mask is None: self._predict_causal_mask = ngram_attention_bias( self.max_target_positions, self.ngram, hidden_states.dtype) - predict_causal_mask = paddle.concat([ - self._predict_causal_mask[:, :seq_length, :seq_length], - self._predict_causal_mask[:, :seq_length, - self.max_target_positions:self. - max_target_positions + seq_length] - ], - axis=-1) + predict_causal_mask = paddle.concat( + [ + self._predict_causal_mask[:, :seq_length, :seq_length], self. 
+ _predict_causal_mask[:, :seq_length, self.max_target_positions: + self.max_target_positions + seq_length] + ], + axis=-1) extended_predict_causal_mask = paddle.expand( predict_causal_mask[:, None, :, :], shape=predict_causal_mask.shape[:1] + [batch_size] + @@ -1166,17 +1192,19 @@ def prepare_predict_attention_mask(self, hidden_states, attention_mask): extended_attention_mask = extended_attention_mask.expand( (self.ngram, batch_size, seq_length, seq_length)) # predicted stream attention_mask should always be 0 - extended_attention_mask = paddle.concat([ - extended_attention_mask, - paddle.zeros_like(extended_attention_mask) - ], - axis=-1) + extended_attention_mask = paddle.concat( + [ + extended_attention_mask, + paddle.zeros_like(extended_attention_mask) + ], + axis=-1) extended_predict_attention_mask = extended_predict_causal_mask + extended_attention_mask else: extended_predict_attention_mask = extended_predict_causal_mask - return paddle.cast(extended_predict_attention_mask.tile( - [1, self.config["num_decoder_attention_heads"], 1, 1]), - dtype=hidden_states.dtype) + return paddle.cast( + extended_predict_attention_mask.tile( + [1, self.config["num_decoder_attention_heads"], 1, 1]), + dtype=hidden_states.dtype) class ProphetNetModel(ProphetNetPretrainedModel): @@ -1212,9 +1240,8 @@ def __init__(self, self.pad_token_id = pad_token_id self.disable_ngram_loss = disable_ngram_loss self.decoder_start_token_id = decoder_start_token_id - self.word_embeddings = nn.Embedding(vocab_size, - hidden_size, - padding_idx=pad_token_id) + self.word_embeddings = nn.Embedding( + vocab_size, hidden_size, padding_idx=pad_token_id) self.encoder = ProphetNetEncoder( self.word_embeddings, vocab_size, hidden_size, pad_token_id, @@ -1242,14 +1269,15 @@ def forward(self, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, - encoder_output: Optional[Tuple] = None, + encoder_output: Optional[Tuple]=None, use_cache=True, past_key_values=None): if attention_mask is None: assert input_ids is not None, "input_ids should be " \ "specified when generating attention_mask" - attention_mask = paddle.cast(input_ids != self.pad_token_id, - dtype=paddle.get_default_dtype()) + attention_mask = paddle.cast( + input_ids != self.pad_token_id, + dtype=paddle.get_default_dtype()) if decoder_attention_mask is None: assert decoder_input_ids is not None, "decoder_input_ids should be " \ @@ -1258,14 +1286,15 @@ def forward(self, decoder_input_ids != self.pad_token_id, dtype=paddle.get_default_dtype()) if encoder_output is None: - encoder_output = self.encoder(input_ids=input_ids, - attention_mask=attention_mask) - decoder_outputs = self.decoder(input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - encoder_hidden_states=encoder_output, - encoder_attention_mask=attention_mask, - use_cache=use_cache, - past_key_values=past_key_values) + encoder_output = self.encoder( + input_ids=input_ids, attention_mask=attention_mask) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_output, + encoder_attention_mask=attention_mask, + use_cache=use_cache, + past_key_values=past_key_values) return decoder_outputs + (encoder_output, ) @@ -1274,10 +1303,11 @@ def __init__(self, in_features, out_features, weight_attr=None, name=None): super(Linear_wo_bias, self).__init__() self._dtype = self._helper.get_default_dtype() self._weight_attr = weight_attr - self.weight = self.create_parameter(shape=[in_features, out_features], - 
attr=self._weight_attr, - dtype=self._dtype, - is_bias=False) + self.weight = self.create_parameter( + shape=[in_features, out_features], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False) self.name = name def forward(self, input): @@ -1314,13 +1344,14 @@ def forward(self, if labels is not None and decoder_input_ids is None: # get decoder inputs from shifting lm labels to the right decoder_input_ids = self._shift_right(labels) - outputs = self.prophetnet(input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - encoder_output=encoder_output, - use_cache=use_cache, - past_key_values=past_key_values) + outputs = self.prophetnet( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_output=encoder_output, + use_cache=use_cache, + past_key_values=past_key_values) batch_size, sequence_length = decoder_input_ids.shape diff --git a/paddlenlp/transformers/prophetnet/tokenizer.py b/paddlenlp/transformers/prophetnet/tokenizer.py index ee9d1c04e031..13e037a9d71e 100644 --- a/paddlenlp/transformers/prophetnet/tokenizer.py +++ b/paddlenlp/transformers/prophetnet/tokenizer.py @@ -26,6 +26,7 @@ class Trie: Trie in Python. Creates a Trie out of a list of words. The trie is used to split on `added_tokens` in one pass Loose reference https://en.wikipedia.org/wiki/Trie """ + def __init__(self): self.data = {} @@ -159,7 +160,7 @@ def split(self, text: str) -> List[str]: next_char = text[lookahead_index] # End lookahead - # Storing and resetting + # Storing and resetting offsets.append(start) offsets.append(end) reset = True @@ -304,14 +305,13 @@ def __init__(self, ] self.tokens_trie = create_trie(self.unique_no_split_tokens) self.vocab = load_vocab(vocab_file) - self.ids_to_tokens = collections.OrderedDict([ - (ids, tok) for tok, ids in self.vocab.items() - ]) + self.ids_to_tokens = collections.OrderedDict( + [(ids, tok) for tok, ids in self.vocab.items()]) self.do_basic_tokenize = do_basic_tokenize if do_basic_tokenize: self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) - self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, - unk_token=unk_token) + self.wordpiece_tokenizer = WordpieceTokenizer( + vocab=self.vocab, unk_token=unk_token) @property def vocab_size(self): @@ -481,8 +481,7 @@ def create_token_type_ids_from_sequences(self, return len(token_ids_0 + sep) * [0] return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] - def build_inputs_with_special_tokens(self, - token_ids_0, + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]: """ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and @@ -510,8 +509,8 @@ def save_vocabulary(self, save_directory): vocab_file = os.path.join(save_directory, self.resource_files_names["vocab_file"]) with open(vocab_file, "w", encoding="utf-8") as writer: - for token, token_index in sorted(self.vocab.items(), - key=lambda kv: kv[1]): + for token, token_index in sorted( + self.vocab.items(), key=lambda kv: kv[1]): if index != token_index: logging.warning( f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." 
From fb76a3b93ccd29fa24d4ccaf8a064d26fda2780b Mon Sep 17 00:00:00 2001
From: DMH_coco <294270681@qq.com>
Date: Mon, 21 Feb 2022 21:51:28 +0800
Subject: [PATCH 5/7] add prophetnet example

---
 .../text_summarization/prophetnet/README.md   | 234 +++
 .../text_summarization/prophetnet/eval.py     |  73 ++
 .../prophetnet/evaluate/cnndm/bs_pyrouge.py   | 658 ++++++++++++++++++
 .../evaluate/cnndm/postprocess_cnn_dm.py      | 261 +++++++
 .../prophetnet/evaluate/gigaword/__init__.py  |   0
 .../evaluate/gigaword/bs_pyrouge.py           | 658 ++++++++++++++++++
 .../prophetnet/evaluate/gigaword/eval.py      | 380 ++++++++++
 .../text_summarization/prophetnet/generate.py | 341 +++++++++
 .../prophetnet/requirements.txt               |   5 +
 .../text_summarization/prophetnet/run_eval.sh |  37 +
 .../prophetnet/run_train.sh                   |  29 +
 .../prophetnet/train_prophetnet.py            | 321 +++++++++
 .../prophetnet/uncase_tokenize_data.py        | 117 ++++
 .../prophetnet/uncompress_data.sh             |  12 +
 14 files changed, 3126 insertions(+)
 create mode 100644 examples/text_summarization/prophetnet/README.md
 create mode 100644 examples/text_summarization/prophetnet/eval.py
 create mode 100644 examples/text_summarization/prophetnet/evaluate/cnndm/bs_pyrouge.py
 create mode 100644 examples/text_summarization/prophetnet/evaluate/cnndm/postprocess_cnn_dm.py
 create mode 100644 examples/text_summarization/prophetnet/evaluate/gigaword/__init__.py
 create mode 100644 examples/text_summarization/prophetnet/evaluate/gigaword/bs_pyrouge.py
 create mode 100644 examples/text_summarization/prophetnet/evaluate/gigaword/eval.py
 create mode 100644 examples/text_summarization/prophetnet/generate.py
 create mode 100644 examples/text_summarization/prophetnet/requirements.txt
 create mode 100644 examples/text_summarization/prophetnet/run_eval.sh
 create mode 100644 examples/text_summarization/prophetnet/run_train.sh
 create mode 100644 examples/text_summarization/prophetnet/train_prophetnet.py
 create mode 100644 examples/text_summarization/prophetnet/uncase_tokenize_data.py
 create mode 100644 examples/text_summarization/prophetnet/uncompress_data.sh

diff --git a/examples/text_summarization/prophetnet/README.md b/examples/text_summarization/prophetnet/README.md
new file mode 100644
index 000000000000..c7f83cf430ce
--- /dev/null
+++ b/examples/text_summarization/prophetnet/README.md
@@ -0,0 +1,234 @@
+# Prophetnet
+
+## Model Introduction
+
+ProphetNet is a seq2seq pretraining model. At each time step during training, ProphetNet learns to predict the next N tokens simultaneously. This self-supervised objective encourages the model to plan for more distant future tokens and prevents it from overfitting to strong local correlations.
+
+This project is an open-source implementation of ProphetNet for text summarization on PaddlePaddle 2.2. It contains the code for fine-tuning and generation on the CNN/DailyMail and Gigaword datasets.
+
+### Dependencies
+
+```
+pip install -r requirements.txt
+python -m pip install paddlepaddle-gpu==2.2.2.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
+pip install paddlenlp==2.2.3
+```
+
+### Code Structure
+
+The main files of this project are organized as follows:
+
+```text
+├── train_prophetnet.py # main entry point for fine-tuning
+├── generate.py # main entry point for generation
+├── eval.py # entry point for evaluating generated results
+├── uncase_tokenize_data.py # data preprocessing
+├── uncompress_data.sh # data decompression script
+├── run_train.sh # training script
+├── run_eval.sh # evaluation script
+├── requirements.txt # dependency file
+└── README.md # this document
+```
+
+### Data Preparation
+
+Download the GLGE datasets: [link](https://drive.google.com/file/d/1F4zppa9Gqrh6iNyVsZJkxfbm5waalqEA/view)
+
+Download the GLGE test sets: [link](https://drive.google.com/file/d/11lDXIG87dChIfukq3x2Wx4r5_duCRm_J/view)
+
+Place glge_public.tar and glge_hidden_v1.1.tar.gz in the project root directory, then run:
+
+```
+bash uncompress_data.sh
+```
+
+### Download the Pretrained Weights and Vocabulary
+
+Model weights and vocabulary: [download link](https://pan.baidu.com/s/1FOnd01rNvDJoONYegacq1Q), extraction code: o28q. Place them in the project root directory after downloading.
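+
+As a quick sanity check, the downloaded checkpoint can be loaded directly once the ProphetNet module from this repository is importable. The snippet below is a minimal sketch, not part of the training pipeline: it assumes the weights sit at `./model_state.pdparams` in the project root and that a `prophetnet-large-uncased` configuration is registered for `from_pretrained`; adjust both names to your setup.
+
+```python
+import paddle
+from paddlenlp.transformers import ProphetNetModel
+
+# Load the downloaded checkpoint and peek at what it contains.
+state_dict = paddle.load("./model_state.pdparams")
+print(len(state_dict), "parameter tensors in the checkpoint")
+
+# "prophetnet-large-uncased" is an assumed configuration name; use
+# whichever name is actually registered in the modeling code.
+model = ProphetNetModel.from_pretrained("prophetnet-large-uncased")
+model.set_state_dict(state_dict)
+model.eval()
+```
+
+If the loaded parameter names match the model without warnings, the checkpoint and the implementation line up and fine-tuning can start from it.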
+
+### Data Preprocessing
+
+```
+python uncase_tokenize_data.py --dataset <dataset>
+```
+
+Notes:
+
+- `<dataset>` can be `cnndm` or `gigaword`.
+
+### Model Training
+
+```
+bash run_train.sh
+```
+
+Or run the fine-tuning program directly:
+
+- cnndm:
+
+```
+python train_prophetnet.py \
+    --dataset=cnndm \
+    --pretrained_model_path=./model_state.pdparams \
+    --batch_size=4 \
+    --epochs=4 \
+    --lr=0.0001 \
+    --warmup_init_lr=1e-07 \
+    --warmup_updates=1000 \
+    --clip_norm=0.1 \
+    --num_workers=4 \
+    --output_dir=./ckpt/cnndm
+```
+
+- gigaword:
+
+```
+python train_prophetnet.py \
+    --dataset=gigaword \
+    --pretrained_model_path=./model_state.pdparams \
+    --batch_size=16 \
+    --epochs=6 \
+    --lr=0.0001 \
+    --warmup_init_lr=1e-07 \
+    --warmup_updates=1000 \
+    --clip_norm=0.1 \
+    --num_workers=8 \
+    --output_dir=./ckpt/gigaword
+```
+
+The parameters are as follows:
+
+- `dataset`: the dataset to use, either cnndm or gigaword.
+
+- `pretrained_model_path`: path to the local pretrained weights used for initialization, e.g. ./model_state.pdparams.
+
+- `batch_size`: training batch size.
+
+- `epochs`: number of training epochs.
+
+- `lr`: learning rate.
+
+- `warmup_init_lr`: initial learning rate for warmup.
+
+- `warmup_updates`: number of warmup steps.
+
+- `clip_norm`: gradient clipping norm.
+
+- `num_workers`: number of data loading workers.
+
+- `output_dir`: directory where the fine-tuned weights are stored.
+
+Fine-tuned model weights:
+
+- cnndm: [link](https://pan.baidu.com/s/1cemrUDxkqEW9raoasJ_VKw), extraction code: 1egi
+
+- gigaword: [link](https://pan.baidu.com/s/1qRH2FStT3vNQtDjZLkYJBQ), extraction code: on5v
+
+### Model Evaluation
+
+Evaluation uses the [evaluation scripts](https://pan.baidu.com/s/1FOnd01rNvDJoONYegacq1Q) from the original ProphetNet repository. They depend on pyrouge, so install rouge in advance:
+
+```
+pip install git+https://github.com/pltrdy/pyrouge
+```
+
+```
+bash run_eval.sh
+```
+
+Or run the generation program directly:
+
+- cnndm:
+
+```
+python generate.py \
+    --dataset=cnndm \
+    --vocab_file=./prophetnet.tokenizer \
+    --output_path=./generate/cnndm/generate.txt \
+    --min_target_length=45 \
+    --max_target_length=110 \
+    --decode_strategy=beam_search \
+    --num_beams=4 \
+    --length_penalty=1.2 \
+    --batch_size=16 \
+    --ignore_pad_token_for_loss=True \
+    --early_stopping=True \
+    --logging_steps=100 \
+    --device=gpu

+python eval.py --dataset cnndm --generated ./generate/cnndm/generate.txt
+```
+
+- gigaword:
+
+```
+python generate.py \
+    --dataset=gigaword \
+    --vocab_file=./prophetnet.tokenizer \
+    --output_path=./generate/gigaword/generate.txt \
+    --min_target_length=1 \
+    --max_target_length=200 \
+    --decode_strategy=beam_search \
+    --num_beams=4 \
+    --length_penalty=1.6 \
+    --batch_size=16 \
+    --ignore_pad_token_for_loss=True \
+    --early_stopping=True \
+    --logging_steps=100 \
+    --device=gpu

+python eval.py --dataset gigaword --generated ./generate/gigaword/generate.txt
+```
+
+The parameters are as follows:
+
+- `dataset`: the dataset to use, either cnndm or gigaword.
+
+- `vocab_file`: the vocabulary file.
+
+- `output_path`: path where the generated results are saved.
+
+- `min_target_length`: minimum decoding length.
+
+- `max_target_length`: maximum decoding length.
+
+- `decode_strategy`: decoding strategy.
+
+- `num_beams`: beam width for beam_search decoding.
+
+- `length_penalty`: exponential length penalty for beam_search decoding.
+
+- `batch_size`: evaluation batch size.
+
+- `ignore_pad_token_for_loss`: ignore padding when computing the loss.
+
+- `early_stopping`: whether to stop prediction once the end token is generated.
+
+- `logging_steps`: logging interval.
+
+- `device`: the device to use.
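+
+For a single ad-hoc summary without going through `generate.py`, the sketch below shows the intended generation API. It is illustrative only: the tokenizer class name `ProphetNetTokenizer`, the `prophetnet-large-uncased` configuration name, and the checkpoint path are assumptions about what this example registers and produces; substitute the identifiers from your own run.
+
+```python
+import paddle
+from paddlenlp.transformers import (ProphetNetForConditionalGeneration,
+                                    ProphetNetTokenizer)
+
+# Assumed names; replace with whatever this example actually registers.
+tokenizer = ProphetNetTokenizer.from_pretrained("prophetnet-large-uncased")
+model = ProphetNetForConditionalGeneration.from_pretrained(
+    "prophetnet-large-uncased")
+# Path of a fine-tuned checkpoint written by train_prophetnet.py (assumed).
+model.set_state_dict(paddle.load("./ckpt/gigaword/model_state.pdparams"))
+model.eval()
+
+article = "japan 's nikkei stock average rose #.## percent early friday ."
+input_ids = paddle.to_tensor([tokenizer(article)["input_ids"]])
+ids, scores = model.generate(input_ids=input_ids,
+                             max_length=64,
+                             decode_strategy="beam_search",
+                             num_beams=4,
+                             length_penalty=1.6)
+print("summary:", tokenizer.convert_ids_to_string(ids[0].numpy().tolist()))
+```
+
+The decoding flags mirror the `generate.py` arguments above, so the output of this snippet and of the script should be comparable.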
diff --git a/examples/text_summarization/prophetnet/eval.py b/examples/text_summarization/prophetnet/eval.py
new file mode 100644
index 000000000000..cba318b43963
--- /dev/null
+++ b/examples/text_summarization/prophetnet/eval.py
@@ -0,0 +1,73 @@
+import argparse
+import os
+import re
+import sys
+from os import listdir
+from os.path import isfile, join
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--dataset",
+    type=str,
+    help="choose from all, or 1 of 8 dataset like cnndm, gigaword etc.")
+parser.add_argument("--generated", type=str, help="generated output file.")
+
+args = parser.parse_args()
+
+data_root_path = 'data'
+
+support_dataset = ['cnndm', 'gigaword']
+# Only the rouge1_f/rouge2_f/rougeL_f group names are consumed below; the
+# remaining templates are kept for reference.
+files2rouge_template = '.*ROUGE-1 Average_F: (?P<rouge1_f>\d+(\.\d*)?|\.\d+).*ROUGE-2 Average_F: (?P<rouge2_f>\d+(\.\d*)?|\.\d+).*ROUGE-L Average_F: (?P<rougeL_f>\d+(\.\d*)?|\.\d+).*'
+# gigaword_template = '.*ROUGE-1: (?P<rouge1_f>\d+(\.\d*)?|\.\d+).*ROUGE-2: (?P<rouge2_f>\d+(\.\d*)?|\.\d+).*ROUGE-L: (?P<rougeL_f>\d+(\.\d*)?|\.\d+).*'
+qg_template = '.*Bleu_4: (?P<bleu4>\d+(\.\d*)?|\.\d+).*METEOR: (?P<meteor>\d+(\.\d*)?|\.\d+).*ROUGE_L: (?P<rougeL>\d+(\.\d*)?|\.\d+).*'
+personachat_template = '.*?(?P<d1>[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?).*?(?P<d2>[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?).*Bleu_1: (?P<bleu1>\d+(\.\d*)?|\.\d+).*Bleu_2: (?P<bleu2>\d+(\.\d*)?|\.\d+).*'
+
+
+def scale_up(d):
+    return {k: float(d[k]) * 100 for k in d.keys()}
+
+
+def eval_one_dataset():
+    golden_file = f"{data_root_path}/{args.dataset}_data/test.tgt"
+
+    eval_template = {
+        'cnndm':
+        f"python ./evaluate/cnndm/postprocess_cnn_dm.py --generated {generated_file} --golden {golden_file}",
+        'gigaword':
+        f"python ./evaluate/gigaword/eval.py --perl --pred {generated_file} --gold {golden_file}",
+    }
+
+    cmd = eval_template[args.dataset]
+    try:
+        output = os.popen(cmd).read()
+        if args.dataset in ['cnndm', 'gigaword']:
+            d = re.search(files2rouge_template,
+                          output.replace("\n", " ")).groupdict()
+            d = scale_up(d)
+            print(
+                f"{args.dataset}\trouge1/rouge2/rougeL\t{d['rouge1_f']:.2f}/{d['rouge2_f']:.2f}/{d['rougeL_f']:.2f}"
+            )
+    except:
+        print("Unexpected error:", sys.exc_info()[0])
+        print(f"{args.dataset} evaluate failed!")
+
+
+if args.dataset != 'all':
+    generated_file = args.generated
+    eval_one_dataset()
+else:
+    output_root_path = args.generated
+    onlyfolders = [
+        f for f in listdir(output_root_path)
+        if not isfile(join(args.generated, f))
+    ]
+    for dataset in support_dataset:
+        for folder in onlyfolders:
+            if folder.startswith(dataset):
+                for hypo_file in listdir(args.generated + '/' + folder):
+                    if 'hypo' in hypo_file or 'score' in hypo_file:
+                        generated_file = args.generated + '/' + folder + '/' + hypo_file
+                        print(f"{dataset}\tpredict_file:{generated_file}")
+                        args.dataset = dataset
+                        args.generated = generated_file
+                        eval_one_dataset()
diff --git a/examples/text_summarization/prophetnet/evaluate/cnndm/bs_pyrouge.py b/examples/text_summarization/prophetnet/evaluate/cnndm/bs_pyrouge.py
new file mode 100644
index 000000000000..efee4b9a911c
--- /dev/null
+++ b/examples/text_summarization/prophetnet/evaluate/cnndm/bs_pyrouge.py
@@ -0,0 +1,658 @@
+from __future__ import print_function, unicode_literals, division
+
+import codecs
+import os
+import platform
+import re
+from functools import partial
+from subprocess import check_output
+from tempfile import mkdtemp
+
+try:
+    from configparser import ConfigParser
+except ImportError:
+    from ConfigParser import ConfigParser
+
+import logging
+from pyrouge.utils import log
+from pyrouge.utils.file_utils import verify_dir
+
+REMAP = {
+    "-lrb-": "(",
+    "-rrb-": ")",
+    "-lcb-": "{",
"-rcb-": "}", + "-lsb-": "[", + "-rsb-": "]", + "``": '"', + "''": '"' +} + + +def clean(x): + return re.sub(r"-lrb-|-rrb-|-lcb-|-rcb-|-lsb-|-rsb-|``|''", + lambda m: REMAP.get(m.group()), x) + + +class DirectoryProcessor: + @staticmethod + def process(input_dir, output_dir, function): + """ + Apply function to all files in input_dir and save the resulting ouput + files in output_dir. + + """ + if not os.path.exists(output_dir): + os.makedirs(output_dir) + logger = log.get_global_console_logger() + logger.info("Processing files in {}.".format(input_dir)) + input_file_names = os.listdir(input_dir) + for input_file_name in input_file_names: + input_file = os.path.join(input_dir, input_file_name) + with codecs.open(input_file, "r", encoding="UTF-8") as f: + input_string = f.read() + output_string = function(input_string) + output_file = os.path.join(output_dir, input_file_name) + with codecs.open(output_file, "w", encoding="UTF-8") as f: + f.write(clean(output_string.lower())) + logger.info("Saved processed files to {}.".format(output_dir)) + + +class Rouge155(object): + """ + This is a wrapper for the ROUGE 1.5.5 summary evaluation package. + This class is designed to simplify the evaluation process by: + + 1) Converting summaries into a format ROUGE understands. + 2) Generating the ROUGE configuration file automatically based + on filename patterns. + + This class can be used within Python like this: + + rouge = Rouge155() + rouge.system_dir = 'test/systems' + rouge.model_dir = 'test/models' + + # The system filename pattern should contain one group that + # matches the document ID. + rouge.system_filename_pattern = 'SL.P.10.R.11.SL062003-(\d+).html' + + # The model filename pattern has '#ID#' as a placeholder for the + # document ID. If there are multiple model summaries, pyrouge + # will use the provided regex to automatically match them with + # the corresponding system summary. Here, [A-Z] matches + # multiple model summaries for a given #ID#. + rouge.model_filename_pattern = 'SL.P.10.R.[A-Z].SL062003-#ID#.html' + + rouge_output = rouge.evaluate() + print(rouge_output) + output_dict = rouge.output_to_dict(rouge_ouput) + print(output_dict) + -> {'rouge_1_f_score': 0.95652, + 'rouge_1_f_score_cb': 0.95652, + 'rouge_1_f_score_ce': 0.95652, + 'rouge_1_precision': 0.95652, + [...] + + + To evaluate multiple systems: + + rouge = Rouge155() + rouge.system_dir = '/PATH/TO/systems' + rouge.model_dir = 'PATH/TO/models' + for system_id in ['id1', 'id2', 'id3']: + rouge.system_filename_pattern = \ + 'SL.P/.10.R.{}.SL062003-(\d+).html'.format(system_id) + rouge.model_filename_pattern = \ + 'SL.P.10.R.[A-Z].SL062003-#ID#.html' + rouge_output = rouge.evaluate(system_id) + print(rouge_output) + + """ + + def __init__(self, rouge_dir=None, rouge_args=None, temp_dir=None): + """ + Create a Rouge155 object. + + rouge_dir: Directory containing Rouge-1.5.5.pl + rouge_args: Arguments to pass through to ROUGE if you + don't want to use the default pyrouge + arguments. 
+ + """ + self.temp_dir = temp_dir + self.log = log.get_global_console_logger() + self.log.setLevel(logging.WARNING) + self.__set_dir_properties() + self._config_file = None + self._settings_file = self.__get_config_path() + self.__set_rouge_dir(rouge_dir) + self.args = self.__clean_rouge_args(rouge_args) + self._system_filename_pattern = None + self._model_filename_pattern = None + + def save_home_dir(self): + config = ConfigParser() + section = 'pyrouge settings' + config.add_section(section) + config.set(section, 'home_dir', self._home_dir) + with open(self._settings_file, 'w') as f: + config.write(f) + self.log.info("Set ROUGE home directory to {}.".format(self._home_dir)) + + @property + def settings_file(self): + """ + Path of the setttings file, which stores the ROUGE home dir. + + """ + return self._settings_file + + @property + def bin_path(self): + """ + The full path of the ROUGE binary (although it's technically + a script), i.e. rouge_home_dir/ROUGE-1.5.5.pl + + """ + if self._bin_path is None: + raise Exception( + "ROUGE path not set. Please set the ROUGE home directory " + "and ensure that ROUGE-1.5.5.pl exists in it.") + return self._bin_path + + @property + def system_filename_pattern(self): + """ + The regular expression pattern for matching system summary + filenames. The regex string. + + E.g. "SL.P.10.R.11.SL062003-(\d+).html" will match the system + filenames in the SPL2003/system folder of the ROUGE SPL example + in the "sample-test" folder. + + Currently, there is no support for multiple systems. + + """ + return self._system_filename_pattern + + @system_filename_pattern.setter + def system_filename_pattern(self, pattern): + self._system_filename_pattern = pattern + + @property + def model_filename_pattern(self): + """ + The regular expression pattern for matching model summary + filenames. The pattern needs to contain the string "#ID#", + which is a placeholder for the document ID. + + E.g. "SL.P.10.R.[A-Z].SL062003-#ID#.html" will match the model + filenames in the SPL2003/system folder of the ROUGE SPL + example in the "sample-test" folder. + + "#ID#" is a placeholder for the document ID which has been + matched by the "(\d+)" part of the system filename pattern. + The different model summaries for a given document ID are + matched by the "[A-Z]" part. + + """ + return self._model_filename_pattern + + @model_filename_pattern.setter + def model_filename_pattern(self, pattern): + self._model_filename_pattern = pattern + + @property + def config_file(self): + return self._config_file + + @config_file.setter + def config_file(self, path): + config_dir, _ = os.path.split(path) + verify_dir(config_dir, "configuration file") + self._config_file = path + + def split_sentences(self): + """ + ROUGE requires texts split into sentences. In case the texts + are not already split, this method can be used. + + """ + from pyrouge.utils.sentence_splitter import PunktSentenceSplitter + self.log.info("Splitting sentences.") + ss = PunktSentenceSplitter() + + def sent_split_to_string(s): + return "\n".join(ss.split(s)) + + process_func = partial( + DirectoryProcessor.process, function=sent_split_to_string) + self.__process_summaries(process_func) + + @staticmethod + def convert_summaries_to_rouge_format(input_dir, output_dir): + """ + Convert all files in input_dir into a format ROUGE understands + and saves the files to output_dir. The input files are assumed + to be plain text with one sentence per line. + + input_dir: Path of directory containing the input files. 
+ output_dir: Path of directory in which the converted files + will be saved. + + """ + DirectoryProcessor.process(input_dir, output_dir, + Rouge155.convert_text_to_rouge_format) + + @staticmethod + def convert_text_to_rouge_format(text, title="dummy title"): + """ + Convert a text to a format ROUGE understands. The text is + assumed to contain one sentence per line. + + text: The text to convert, containg one sentence per line. + title: Optional title for the text. The title will appear + in the converted file, but doesn't seem to have + any other relevance. + + Returns: The converted text as string. + + """ + sentences = text.split("\n") + sent_elems = [ + "[{i}] " + "{text}".format( + i=i, text=sent) for i, sent in enumerate( + sentences, start=1) + ] + html = """ + +{title} + + +{elems} + +""".format( + title=title, elems="\n".join(sent_elems)) + + return html + + @staticmethod + def write_config_static(system_dir, + system_filename_pattern, + model_dir, + model_filename_pattern, + config_file_path, + system_id=None): + """ + Write the ROUGE configuration file, which is basically a list + of system summary files and their corresponding model summary + files. + + pyrouge uses regular expressions to automatically find the + matching model summary files for a given system summary file + (cf. docstrings for system_filename_pattern and + model_filename_pattern). + + system_dir: Path of directory containing + system summaries. + system_filename_pattern: Regex string for matching + system summary filenames. + model_dir: Path of directory containing + model summaries. + model_filename_pattern: Regex string for matching model + summary filenames. + config_file_path: Path of the configuration file. + system_id: Optional system ID string which + will appear in the ROUGE output. + + """ + system_filenames = [f for f in os.listdir(system_dir)] + system_models_tuples = [] + + system_filename_pattern = re.compile(system_filename_pattern) + for system_filename in sorted(system_filenames): + match = system_filename_pattern.match(system_filename) + if match: + id = match.groups(0)[0] + model_filenames = [model_filename_pattern.replace('#ID#', id)] + # model_filenames = Rouge155.__get_model_filenames_for_id( + # id, model_dir, model_filename_pattern) + system_models_tuples.append( + (system_filename, sorted(model_filenames))) + if not system_models_tuples: + raise Exception("Did not find any files matching the pattern {} " + "in the system summaries directory {}.".format( + system_filename_pattern.pattern, system_dir)) + + with codecs.open(config_file_path, 'w', encoding='utf-8') as f: + f.write('') + for task_id, (system_filename, model_filenames) in enumerate( + system_models_tuples, start=1): + eval_string = Rouge155.__get_eval_string( + task_id, system_id, system_dir, system_filename, model_dir, + model_filenames) + f.write(eval_string) + f.write("") + + def write_config(self, config_file_path=None, system_id=None): + """ + Write the ROUGE configuration file, which is basically a list + of system summary files and their matching model summary files. + + This is a non-static version of write_config_file_static(). + + config_file_path: Path of the configuration file. + system_id: Optional system ID string which will + appear in the ROUGE output. 
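+
+        If config_file_path is omitted, the configuration is written to
+        "rouge_conf.xml" inside a freshly created temporary directory.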
+ + """ + if not system_id: + system_id = 1 + if (not config_file_path) or (not self._config_dir): + self._config_dir = mkdtemp(dir=self.temp_dir) + config_filename = "rouge_conf.xml" + else: + config_dir, config_filename = os.path.split(config_file_path) + verify_dir(config_dir, "configuration file") + self._config_file = os.path.join(self._config_dir, config_filename) + Rouge155.write_config_static( + self._system_dir, self._system_filename_pattern, self._model_dir, + self._model_filename_pattern, self._config_file, system_id) + self.log.info("Written ROUGE configuration to {}".format( + self._config_file)) + + def evaluate(self, system_id=1, rouge_args=None): + """ + Run ROUGE to evaluate the system summaries in system_dir against + the model summaries in model_dir. The summaries are assumed to + be in the one-sentence-per-line HTML format ROUGE understands. + + system_id: Optional system ID which will be printed in + ROUGE's output. + + Returns: Rouge output as string. + + """ + self.write_config(system_id=system_id) + options = self.__get_options(rouge_args) + command = [self._bin_path] + options + self.log.info("Running ROUGE with command {}".format(" ".join(command))) + rouge_output = check_output(command).decode("UTF-8") + return rouge_output + + def convert_and_evaluate(self, + system_id=1, + split_sentences=False, + rouge_args=None): + """ + Convert plain text summaries to ROUGE format and run ROUGE to + evaluate the system summaries in system_dir against the model + summaries in model_dir. Optionally split texts into sentences + in case they aren't already. + + This is just a convenience method combining + convert_summaries_to_rouge_format() and evaluate(). + + split_sentences: Optional argument specifying if + sentences should be split. + system_id: Optional system ID which will be printed + in ROUGE's output. + + Returns: ROUGE output as string. + + """ + if split_sentences: + self.split_sentences() + self.__write_summaries() + rouge_output = self.evaluate(system_id, rouge_args) + return rouge_output + + def output_to_dict(self, output): + """ + Convert the ROUGE output into python dictionary for further + processing. + + """ + # 0 ROUGE-1 Average_R: 0.02632 (95%-conf.int. 0.02632 - 0.02632) + pattern = re.compile(r"(\d+) (ROUGE-\S+) (Average_\w): (\d.\d+) " + r"\(95%-conf.int. (\d.\d+) - (\d.\d+)\)") + results = {} + for line in output.split("\n"): + match = pattern.match(line) + if match: + sys_id, rouge_type, measure, result, conf_begin, conf_end = \ + match.groups() + measure = { + 'Average_R': 'recall', + 'Average_P': 'precision', + 'Average_F': 'f_score' + }[measure] + rouge_type = rouge_type.lower().replace("-", '_') + key = "{}_{}".format(rouge_type, measure) + results[key] = float(result) + results["{}_cb".format(key)] = float(conf_begin) + results["{}_ce".format(key)] = float(conf_end) + return results + + ################################################################### + # Private methods + + def __set_rouge_dir(self, home_dir=None): + """ + Verfify presence of ROUGE-1.5.5.pl and data folder, and set + those paths. + + """ + if not home_dir: + self._home_dir = self.__get_rouge_home_dir_from_settings() + else: + self._home_dir = home_dir + self.save_home_dir() + self._bin_path = os.path.join(self._home_dir, 'ROUGE-1.5.5.pl') + self.data_dir = os.path.join(self._home_dir, 'data') + if not os.path.exists(self._bin_path): + raise Exception("ROUGE binary not found at {}. 
Please set the " + "correct path by running pyrouge_set_rouge_path " + "/path/to/rouge/home.".format(self._bin_path)) + + def __get_rouge_home_dir_from_settings(self): + config = ConfigParser() + with open(self._settings_file) as f: + if hasattr(config, "read_file"): + config.read_file(f) + else: + # use deprecated python 2.x method + config.readfp(f) + rouge_home_dir = config.get('pyrouge settings', 'home_dir') + return rouge_home_dir + + @staticmethod + def __get_eval_string(task_id, system_id, system_dir, system_filename, + model_dir, model_filenames): + """ + ROUGE can evaluate several system summaries for a given text + against several model summaries, i.e. there is an m-to-n + relation between system and model summaries. The system + summaries are listed in the tag and the model summaries + in the tag. pyrouge currently only supports one system + summary per text, i.e. it assumes a 1-to-n relation between + system and model summaries. + + """ + peer_elems = "

<P ID=\"{id}\">{name}</P>

".format( + id=system_id, name=system_filename) + + model_elems = [ + "{name}".format( + id=chr(65 + i), name=name) + for i, name in enumerate(model_filenames) + ] + + model_elems = "\n\t\t\t".join(model_elems) + eval_string = """ + + {model_root} + {peer_root} + + + + {peer_elems} + + + {model_elems} + + +""".format( + task_id=task_id, + model_root=model_dir, + model_elems=model_elems, + peer_root=system_dir, + peer_elems=peer_elems) + return eval_string + + def __process_summaries(self, process_func): + """ + Helper method that applies process_func to the files in the + system and model folders and saves the resulting files to new + system and model folders. + + """ + temp_dir = mkdtemp(dir=self.temp_dir) + new_system_dir = os.path.join(temp_dir, "system") + os.mkdir(new_system_dir) + new_model_dir = os.path.join(temp_dir, "model") + os.mkdir(new_model_dir) + self.log.info("Processing summaries. Saving system files to {} and " + "model files to {}.".format(new_system_dir, + new_model_dir)) + process_func(self._system_dir, new_system_dir) + process_func(self._model_dir, new_model_dir) + self._system_dir = new_system_dir + self._model_dir = new_model_dir + + def __write_summaries(self): + self.log.info("Writing summaries.") + self.__process_summaries(self.convert_summaries_to_rouge_format) + + @staticmethod + def __get_model_filenames_for_id(id, model_dir, model_filenames_pattern): + pattern = re.compile(model_filenames_pattern.replace('#ID#', id)) + model_filenames = [f for f in os.listdir(model_dir) if pattern.match(f)] + if not model_filenames: + raise Exception( + "Could not find any model summaries for the system" + " summary with ID {}. Specified model filename pattern was: " + "{}".format(id, model_filenames_pattern)) + return model_filenames + + def __get_options(self, rouge_args=None): + """ + Get supplied command line arguments for ROUGE or use default + ones. + + """ + if self.args: + options = self.args.split() + elif rouge_args: + options = rouge_args.split() + else: + options = [ + '-e', + self._data_dir, + '-c', + 95, + # '-2', + # '-1', + # '-U', + '-m', + # '-v', + '-r', + 1000, + '-n', + 2, + # '-w', 1.2, + '-a', + ] + options = list(map(str, options)) + + options = self.__add_config_option(options) + return options + + def __create_dir_property(self, dir_name, docstring): + """ + Generate getter and setter for a directory property. + + """ + property_name = "{}_dir".format(dir_name) + private_name = "_" + property_name + setattr(self, private_name, None) + + def fget(self): + return getattr(self, private_name) + + def fset(self, path): + verify_dir(path, dir_name) + setattr(self, private_name, path) + + p = property(fget=fget, fset=fset, doc=docstring) + setattr(self.__class__, property_name, p) + + def __set_dir_properties(self): + """ + Automatically generate the properties for directories. + + """ + directories = [ + ("home", "The ROUGE home directory."), + ("data", "The path of the ROUGE 'data' directory."), + ("system", "Path of the directory containing system summaries."), + ("model", "Path of the directory containing model summaries."), + ] + for (dirname, docstring) in directories: + self.__create_dir_property(dirname, docstring) + + def __clean_rouge_args(self, rouge_args): + """ + Remove enclosing quotation marks, if any. 
+ + """ + if not rouge_args: + return + quot_mark_pattern = re.compile('"(.+)"') + match = quot_mark_pattern.match(rouge_args) + if match: + cleaned_args = match.group(1) + return cleaned_args + else: + return rouge_args + + def __add_config_option(self, options): + return options + [self._config_file] + + def __get_config_path(self): + if platform.system() == "Windows": + parent_dir = os.getenv("APPDATA") + config_dir_name = "pyrouge" + elif os.name == "posix": + parent_dir = os.path.expanduser("~") + config_dir_name = ".pyrouge" + else: + parent_dir = os.path.dirname(__file__) + config_dir_name = "" + config_dir = os.path.join(parent_dir, config_dir_name) + if not os.path.exists(config_dir): + os.makedirs(config_dir) + return os.path.join(config_dir, 'settings.ini') + + +if __name__ == "__main__": + import argparse + from utils.argparsers import rouge_path_parser + + parser = argparse.ArgumentParser(parents=[rouge_path_parser]) + args = parser.parse_args() + + rouge = Rouge155(args.rouge_home) + rouge.save_home_dir() diff --git a/examples/text_summarization/prophetnet/evaluate/cnndm/postprocess_cnn_dm.py b/examples/text_summarization/prophetnet/evaluate/cnndm/postprocess_cnn_dm.py new file mode 100644 index 000000000000..a32c99faa9cd --- /dev/null +++ b/examples/text_summarization/prophetnet/evaluate/cnndm/postprocess_cnn_dm.py @@ -0,0 +1,261 @@ +import argparse +import os +import shutil +import string +import tempfile +import time + +from bs_pyrouge import Rouge155 + +parser = argparse.ArgumentParser() +parser.add_argument("--generated", type=str, help="generated output file.") +parser.add_argument("--golden", type=str, help="Gold output file.") +parser.add_argument( + "--duplicate_rate", + type=float, + default=0.7, + help="If the duplicat rate (compared with history) is large, we can discard the current sentence." +) +parser.add_argument( + "--trunc_len", + type=int, + default=0, + help="Truncate line by the maximum length.") +args = parser.parse_args() + +fin = open(args.generated, 'r', encoding='utf-8') +fgolden = open(args.golden, 'r', encoding='utf-8') +dedup_rate = args.duplicate_rate +trunc_len = args.trunc_len + +_tok_dict = { + "(": "-LRB-", + ")": "-RRB-", + "[": "-LSB-", + "]": "-RSB-", + "{": "-LCB-", + "}": "-RCB-" +} + + +def _is_digit(w): + for ch in w: + if not (ch.isdigit() or ch == ','): + return False + return True + + +def fix_tokenization(text): + input_tokens = text.split() + output_tokens = [] + has_left_quote = False + has_left_single_quote = False + + i = 0 + prev_dash = False + while i < len(input_tokens): + tok = input_tokens[i] + flag_prev_dash = False + if tok in _tok_dict.keys(): + output_tokens.append(_tok_dict[tok]) + i += 1 + elif tok == "\"": + if has_left_quote: + output_tokens.append("''") + else: + output_tokens.append("``") + has_left_quote = not has_left_quote + i += 1 + elif tok == "'" and len(output_tokens) > 0 and output_tokens[-1].endswith("n") and i < len(input_tokens) - 1 and \ + input_tokens[i + 1] == "t": + output_tokens[-1] = output_tokens[-1][:-1] + output_tokens.append("n't") + i += 2 + elif tok == "'" and i < len(input_tokens) - 1 and input_tokens[ + i + 1] in ("s", "d", "ll"): + output_tokens.append("'" + input_tokens[i + 1]) + i += 2 + elif tok == "'": + if has_left_single_quote: + output_tokens.append("'") + else: + output_tokens.append("`") + has_left_single_quote = not has_left_single_quote + i += 1 + elif tok == "." and i < len(input_tokens) - 2 and input_tokens[ + i + 1] == "." 
and input_tokens[i + 2] == ".": + output_tokens.append("...") + i += 3 + elif tok == "," and len(output_tokens) > 0 and _is_digit(output_tokens[ + -1]) and i < len(input_tokens) - 1 and _is_digit(input_tokens[ + i + 1]): + # $ 3 , 000 -> $ 3,000 + output_tokens[-1] += ',' + input_tokens[i + 1] + i += 2 + elif tok == "." and len(output_tokens) > 0 and output_tokens[-1].isdigit() and i < len(input_tokens) - 1 and \ + input_tokens[i + 1].isdigit(): + # 3 . 03 -> $ 3.03 + output_tokens[-1] += '.' + input_tokens[i + 1] + i += 2 + elif tok == "." and len(output_tokens) > 0 and len(output_tokens[ + -1]) == 1 and output_tokens[-1].isupper() and i < len( + input_tokens) - 2 and len(input_tokens[ + i + 1]) == 1 and input_tokens[i + 1].isupper( + ) and input_tokens[i + 2] == '.': + # U . N . -> U.N. + k = i + 3 + while k + 2 < len(input_tokens): + if len(input_tokens[k + 1]) == 1 and input_tokens[ + k + 1].isupper() and input_tokens[k + 2] == '.': + k += 2 + else: + break + output_tokens[-1] += ''.join(input_tokens[i:k]) + i += 2 + elif tok == "-": + if i < len(input_tokens) - 1 and input_tokens[i + 1] == "-": + output_tokens.append("--") + i += 2 + elif i == len(input_tokens) - 1 or i == 0: + output_tokens.append("-") + i += 1 + elif output_tokens[-1] not in string.punctuation and input_tokens[ + i + 1][0] not in string.punctuation: + output_tokens[-1] += "-" + i += 1 + flag_prev_dash = True + else: + output_tokens.append("-") + i += 1 + elif prev_dash and len(output_tokens) > 0 and tok[ + 0] not in string.punctuation: + output_tokens[-1] += tok + i += 1 + else: + output_tokens.append(tok) + i += 1 + prev_dash = flag_prev_dash + return " ".join(output_tokens) + + +def remove_duplicate(l_list, duplicate_rate): + tk_list = [l.lower().split() for l in l_list] + r_list = [] + history_set = set() + for i, w_list in enumerate(tk_list): + w_set = set(w_list) + if len(w_set & history_set) / len(w_set) <= duplicate_rate: + r_list.append(l_list[i]) + history_set |= w_set + return r_list + + +def test_rouge(cand, ref): + temp_dir = tempfile.mkdtemp() + candidates = cand + references = ref + assert len(candidates) == len(references) + + cnt = len(candidates) + current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) + tmp_dir = os.path.join(temp_dir, "rouge-tmp-{}".format(current_time)) + if not os.path.isdir(tmp_dir): + os.mkdir(tmp_dir) + os.mkdir(tmp_dir + "/candidate") + os.mkdir(tmp_dir + "/reference") + try: + for i in range(cnt): + if len(references[i]) < 1: + continue + with open( + tmp_dir + "/candidate/cand.{}.txt".format(i), + "w", + encoding="utf-8") as f: + f.write(candidates[i]) + with open( + tmp_dir + "/reference/ref.{}.txt".format(i), + "w", + encoding="utf-8") as f: + f.write(references[i]) + r = Rouge155(temp_dir=temp_dir) + r.model_dir = tmp_dir + "/reference/" + r.system_dir = tmp_dir + "/candidate/" + r.model_filename_pattern = 'ref.#ID#.txt' + r.system_filename_pattern = r'cand.(\d+).txt' + rouge_results = r.convert_and_evaluate() + print(rouge_results) + results_dict = r.output_to_dict(rouge_results) + finally: + if os.path.isdir(tmp_dir): + shutil.rmtree(tmp_dir) + return results_dict + + +def rouge_results_to_str(results_dict): + return ">> ROUGE-F(1/2/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-R(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\n".format( + results_dict["rouge_1_f_score"] * 100, results_dict["rouge_2_f_score"] * + 100, results_dict["rouge_l_f_score"] * 100, + results_dict["rouge_1_recall"] * 100, results_dict["rouge_2_recall"] * + 100, results_dict["rouge_l_recall"] * 100) + + +def 
count_tokens(tokens):
+    counter = {}
+    for t in tokens:
+        if t in counter.keys():
+            counter[t] += 1
+        else:
+            counter[t] = 1
+    return counter
+
+
+def get_f1(text_a, text_b):
+    tokens_a = text_a.lower().split()
+    tokens_b = text_b.lower().split()
+    if len(tokens_a) == 0 or len(tokens_b) == 0:
+        return 1 if len(tokens_a) == len(tokens_b) else 0
+    set_a = count_tokens(tokens_a)
+    set_b = count_tokens(tokens_b)
+    match = 0
+    for token in set_a.keys():
+        if token in set_b.keys():
+            match += min(set_a[token], set_b[token])
+    p = match / len(tokens_a)
+    r = match / len(tokens_b)
+    return 2.0 * p * r / (p + r + 1e-5)
+
+
+generated_list = []
+for line in fin:
+    buf = []
+    for sentence in line.strip().split('[X_SEP]'):
+        sentence = fix_tokenization(sentence)
+        if any(get_f1(sentence, s) > 1.0 for s in buf):
+            continue
+        s_len = len(sentence.split())
+        if s_len <= 4:
+            continue
+        buf.append(sentence)
+    if dedup_rate < 1:
+        buf = remove_duplicate(buf, dedup_rate)
+    if trunc_len:
+        num_left = trunc_len
+        trunc_list = []
+        for bit in buf:
+            tk_list = bit.split()
+            n = min(len(tk_list), num_left)
+            trunc_list.append(' '.join(tk_list[:n]))
+            num_left -= n
+            if num_left <= 0:
+                break
+    else:
+        trunc_list = buf
+    generated_list.append("\n".join(trunc_list))
+
+golden_list = []
+for line in fgolden:
+    line = line.strip().replace(" <S_SEP> ", '\n')
+    golden_list.append(line)
+
+scores = test_rouge(generated_list, golden_list)
+print(rouge_results_to_str(scores))
diff --git a/examples/text_summarization/prophetnet/evaluate/gigaword/__init__.py b/examples/text_summarization/prophetnet/evaluate/gigaword/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/examples/text_summarization/prophetnet/evaluate/gigaword/bs_pyrouge.py b/examples/text_summarization/prophetnet/evaluate/gigaword/bs_pyrouge.py
new file mode 100644
index 000000000000..f7523c29d3ba
--- /dev/null
+++ b/examples/text_summarization/prophetnet/evaluate/gigaword/bs_pyrouge.py
@@ -0,0 +1,658 @@
+from __future__ import print_function, unicode_literals, division
+
+import codecs
+import logging
+import os
+import platform
+import re
+from functools import partial
+from subprocess import check_output
+from tempfile import mkdtemp
+
+try:
+    from configparser import ConfigParser
+except ImportError:
+    from ConfigParser import ConfigParser
+
+from pyrouge.utils import log
+from pyrouge.utils.file_utils import verify_dir
+
+REMAP = {
+    "-lrb-": "(",
+    "-rrb-": ")",
+    "-lcb-": "{",
+    "-rcb-": "}",
+    "-lsb-": "[",
+    "-rsb-": "]",
+    "``": '"',
+    "''": '"'
+}
+
+
+def clean(x):
+    return re.sub(r"-lrb-|-rrb-|-lcb-|-rcb-|-lsb-|-rsb-|``|''",
+                  lambda m: REMAP.get(m.group()), x)
+
+
+class DirectoryProcessor:
+    @staticmethod
+    def process(input_dir, output_dir, function):
+        """
+        Apply function to all files in input_dir and save the resulting output
+        files in output_dir.
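+
+        input_dir:  Path of the directory holding the files to process.
+        output_dir: Path of the directory the processed (lower-cased,
+                    bracket-normalised) copies are written to.
+        function:   Callable applied to each file's contents before the
+                    clean-up step.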
+ + """ + if not os.path.exists(output_dir): + os.makedirs(output_dir) + logger = log.get_global_console_logger() + logger.info("Processing files in {}.".format(input_dir)) + input_file_names = os.listdir(input_dir) + for input_file_name in input_file_names: + input_file = os.path.join(input_dir, input_file_name) + with codecs.open(input_file, "r", encoding="UTF-8") as f: + input_string = f.read() + output_string = function(input_string) + output_file = os.path.join(output_dir, input_file_name) + with codecs.open(output_file, "w", encoding="UTF-8") as f: + f.write(clean(output_string.lower())) + logger.info("Saved processed files to {}.".format(output_dir)) + + +class Rouge155(object): + """ + This is a wrapper for the ROUGE 1.5.5 summary evaluation package. + This class is designed to simplify the evaluation process by: + + 1) Converting summaries into a format ROUGE understands. + 2) Generating the ROUGE configuration file automatically based + on filename patterns. + + This class can be used within Python like this: + + rouge = Rouge155() + rouge.system_dir = 'test/systems' + rouge.model_dir = 'test/models' + + # The system filename pattern should contain one group that + # matches the document ID. + rouge.system_filename_pattern = 'SL.P.10.R.11.SL062003-(\d+).html' + + # The model filename pattern has '#ID#' as a placeholder for the + # document ID. If there are multiple model summaries, pyrouge + # will use the provided regex to automatically match them with + # the corresponding system summary. Here, [A-Z] matches + # multiple model summaries for a given #ID#. + rouge.model_filename_pattern = 'SL.P.10.R.[A-Z].SL062003-#ID#.html' + + rouge_output = rouge.evaluate() + print(rouge_output) + output_dict = rouge.output_to_dict(rouge_ouput) + print(output_dict) + -> {'rouge_1_f_score': 0.95652, + 'rouge_1_f_score_cb': 0.95652, + 'rouge_1_f_score_ce': 0.95652, + 'rouge_1_precision': 0.95652, + [...] + + + To evaluate multiple systems: + + rouge = Rouge155() + rouge.system_dir = '/PATH/TO/systems' + rouge.model_dir = 'PATH/TO/models' + for system_id in ['id1', 'id2', 'id3']: + rouge.system_filename_pattern = \ + 'SL.P/.10.R.{}.SL062003-(\d+).html'.format(system_id) + rouge.model_filename_pattern = \ + 'SL.P.10.R.[A-Z].SL062003-#ID#.html' + rouge_output = rouge.evaluate(system_id) + print(rouge_output) + + """ + + def __init__(self, rouge_dir=None, rouge_args=None, temp_dir=None): + """ + Create a Rouge155 object. + + rouge_dir: Directory containing Rouge-1.5.5.pl + rouge_args: Arguments to pass through to ROUGE if you + don't want to use the default pyrouge + arguments. + + """ + self.temp_dir = temp_dir + self.log = log.get_global_console_logger() + self.log.setLevel(logging.WARNING) + self.__set_dir_properties() + self._config_file = None + self._settings_file = self.__get_config_path() + self.__set_rouge_dir(rouge_dir) + self.args = self.__clean_rouge_args(rouge_args) + self._system_filename_pattern = None + self._model_filename_pattern = None + + def save_home_dir(self): + config = ConfigParser() + section = 'pyrouge settings' + config.add_section(section) + config.set(section, 'home_dir', self._home_dir) + with open(self._settings_file, 'w') as f: + config.write(f) + self.log.info("Set ROUGE home directory to {}.".format(self._home_dir)) + + @property + def settings_file(self): + """ + Path of the setttings file, which stores the ROUGE home dir. 
+ + """ + return self._settings_file + + @property + def bin_path(self): + """ + The full path of the ROUGE binary (although it's technically + a script), i.e. rouge_home_dir/ROUGE-1.5.5.pl + + """ + if self._bin_path is None: + raise Exception( + "ROUGE path not set. Please set the ROUGE home directory " + "and ensure that ROUGE-1.5.5.pl exists in it.") + return self._bin_path + + @property + def system_filename_pattern(self): + """ + The regular expression pattern for matching system summary + filenames. The regex string. + + E.g. "SL.P.10.R.11.SL062003-(\d+).html" will match the system + filenames in the SPL2003/system folder of the ROUGE SPL example + in the "sample-test" folder. + + Currently, there is no support for multiple systems. + + """ + return self._system_filename_pattern + + @system_filename_pattern.setter + def system_filename_pattern(self, pattern): + self._system_filename_pattern = pattern + + @property + def model_filename_pattern(self): + """ + The regular expression pattern for matching model summary + filenames. The pattern needs to contain the string "#ID#", + which is a placeholder for the document ID. + + E.g. "SL.P.10.R.[A-Z].SL062003-#ID#.html" will match the model + filenames in the SPL2003/system folder of the ROUGE SPL + example in the "sample-test" folder. + + "#ID#" is a placeholder for the document ID which has been + matched by the "(\d+)" part of the system filename pattern. + The different model summaries for a given document ID are + matched by the "[A-Z]" part. + + """ + return self._model_filename_pattern + + @model_filename_pattern.setter + def model_filename_pattern(self, pattern): + self._model_filename_pattern = pattern + + @property + def config_file(self): + return self._config_file + + @config_file.setter + def config_file(self, path): + config_dir, _ = os.path.split(path) + verify_dir(config_dir, "configuration file") + self._config_file = path + + def split_sentences(self): + """ + ROUGE requires texts split into sentences. In case the texts + are not already split, this method can be used. + + """ + from pyrouge.utils.sentence_splitter import PunktSentenceSplitter + self.log.info("Splitting sentences.") + ss = PunktSentenceSplitter() + + def sent_split_to_string(s): + return "\n".join(ss.split(s)) + + process_func = partial( + DirectoryProcessor.process, function=sent_split_to_string) + self.__process_summaries(process_func) + + @staticmethod + def convert_summaries_to_rouge_format(input_dir, output_dir): + """ + Convert all files in input_dir into a format ROUGE understands + and saves the files to output_dir. The input files are assumed + to be plain text with one sentence per line. + + input_dir: Path of directory containing the input files. + output_dir: Path of directory in which the converted files + will be saved. + + """ + DirectoryProcessor.process(input_dir, output_dir, + Rouge155.convert_text_to_rouge_format) + + @staticmethod + def convert_text_to_rouge_format(text, title="dummy title"): + """ + Convert a text to a format ROUGE understands. The text is + assumed to contain one sentence per line. + + text: The text to convert, containg one sentence per line. + title: Optional title for the text. The title will appear + in the converted file, but doesn't seem to have + any other relevance. + + Returns: The converted text as string. 
+ + """ + sentences = text.split("\n") + sent_elems = [ + "[{i}] " + "{text}".format( + i=i, text=sent) for i, sent in enumerate( + sentences, start=1) + ] + html = """ + +{title} + + +{elems} + +""".format( + title=title, elems="\n".join(sent_elems)) + + return html + + @staticmethod + def write_config_static(system_dir, + system_filename_pattern, + model_dir, + model_filename_pattern, + config_file_path, + system_id=None): + """ + Write the ROUGE configuration file, which is basically a list + of system summary files and their corresponding model summary + files. + + pyrouge uses regular expressions to automatically find the + matching model summary files for a given system summary file + (cf. docstrings for system_filename_pattern and + model_filename_pattern). + + system_dir: Path of directory containing + system summaries. + system_filename_pattern: Regex string for matching + system summary filenames. + model_dir: Path of directory containing + model summaries. + model_filename_pattern: Regex string for matching model + summary filenames. + config_file_path: Path of the configuration file. + system_id: Optional system ID string which + will appear in the ROUGE output. + + """ + system_filenames = [f for f in os.listdir(system_dir)] + system_models_tuples = [] + + system_filename_pattern = re.compile(system_filename_pattern) + for system_filename in sorted(system_filenames): + match = system_filename_pattern.match(system_filename) + if match: + id = match.groups(0)[0] + model_filenames = [model_filename_pattern.replace('#ID#', id)] + # model_filenames = Rouge155.__get_model_filenames_for_id( + # id, model_dir, model_filename_pattern) + system_models_tuples.append( + (system_filename, sorted(model_filenames))) + if not system_models_tuples: + raise Exception("Did not find any files matching the pattern {} " + "in the system summaries directory {}.".format( + system_filename_pattern.pattern, system_dir)) + + with codecs.open(config_file_path, 'w', encoding='utf-8') as f: + f.write('') + for task_id, (system_filename, model_filenames) in enumerate( + system_models_tuples, start=1): + eval_string = Rouge155.__get_eval_string( + task_id, system_id, system_dir, system_filename, model_dir, + model_filenames) + f.write(eval_string) + f.write("") + + def write_config(self, config_file_path=None, system_id=None): + """ + Write the ROUGE configuration file, which is basically a list + of system summary files and their matching model summary files. + + This is a non-static version of write_config_file_static(). + + config_file_path: Path of the configuration file. + system_id: Optional system ID string which will + appear in the ROUGE output. + + """ + if not system_id: + system_id = 1 + if (not config_file_path) or (not self._config_dir): + self._config_dir = mkdtemp(dir=self.temp_dir) + config_filename = "rouge_conf.xml" + else: + config_dir, config_filename = os.path.split(config_file_path) + verify_dir(config_dir, "configuration file") + self._config_file = os.path.join(self._config_dir, config_filename) + Rouge155.write_config_static( + self._system_dir, self._system_filename_pattern, self._model_dir, + self._model_filename_pattern, self._config_file, system_id) + self.log.info("Written ROUGE configuration to {}".format( + self._config_file)) + + def evaluate(self, system_id=1, rouge_args=None): + """ + Run ROUGE to evaluate the system summaries in system_dir against + the model summaries in model_dir. 
The summaries are assumed to + be in the one-sentence-per-line HTML format ROUGE understands. + + system_id: Optional system ID which will be printed in + ROUGE's output. + + Returns: Rouge output as string. + + """ + self.write_config(system_id=system_id) + options = self.__get_options(rouge_args) + command = [self._bin_path] + options + self.log.info("Running ROUGE with command {}".format(" ".join(command))) + rouge_output = check_output(command).decode("UTF-8") + return rouge_output + + def convert_and_evaluate(self, + system_id=1, + split_sentences=False, + rouge_args=None): + """ + Convert plain text summaries to ROUGE format and run ROUGE to + evaluate the system summaries in system_dir against the model + summaries in model_dir. Optionally split texts into sentences + in case they aren't already. + + This is just a convenience method combining + convert_summaries_to_rouge_format() and evaluate(). + + split_sentences: Optional argument specifying if + sentences should be split. + system_id: Optional system ID which will be printed + in ROUGE's output. + + Returns: ROUGE output as string. + + """ + if split_sentences: + self.split_sentences() + self.__write_summaries() + rouge_output = self.evaluate(system_id, rouge_args) + return rouge_output + + def output_to_dict(self, output): + """ + Convert the ROUGE output into python dictionary for further + processing. + + """ + # 0 ROUGE-1 Average_R: 0.02632 (95%-conf.int. 0.02632 - 0.02632) + pattern = re.compile(r"(\d+) (ROUGE-\S+) (Average_\w): (\d.\d+) " + r"\(95%-conf.int. (\d.\d+) - (\d.\d+)\)") + results = {} + for line in output.split("\n"): + match = pattern.match(line) + if match: + sys_id, rouge_type, measure, result, conf_begin, conf_end = \ + match.groups() + measure = { + 'Average_R': 'recall', + 'Average_P': 'precision', + 'Average_F': 'f_score' + }[measure] + rouge_type = rouge_type.lower().replace("-", '_') + key = "{}_{}".format(rouge_type, measure) + results[key] = float(result) + results["{}_cb".format(key)] = float(conf_begin) + results["{}_ce".format(key)] = float(conf_end) + return results + + ################################################################### + # Private methods + + def __set_rouge_dir(self, home_dir=None): + """ + Verfify presence of ROUGE-1.5.5.pl and data folder, and set + those paths. + + """ + if not home_dir: + self._home_dir = self.__get_rouge_home_dir_from_settings() + else: + self._home_dir = home_dir + self.save_home_dir() + self._bin_path = os.path.join(self._home_dir, 'ROUGE-1.5.5.pl') + self.data_dir = os.path.join(self._home_dir, 'data') + if not os.path.exists(self._bin_path): + raise Exception("ROUGE binary not found at {}. Please set the " + "correct path by running pyrouge_set_rouge_path " + "/path/to/rouge/home.".format(self._bin_path)) + + def __get_rouge_home_dir_from_settings(self): + config = ConfigParser() + with open(self._settings_file) as f: + if hasattr(config, "read_file"): + config.read_file(f) + else: + # use deprecated python 2.x method + config.readfp(f) + rouge_home_dir = config.get('pyrouge settings', 'home_dir') + return rouge_home_dir + + @staticmethod + def __get_eval_string(task_id, system_id, system_dir, system_filename, + model_dir, model_filenames): + """ + ROUGE can evaluate several system summaries for a given text + against several model summaries, i.e. there is an m-to-n + relation between system and model summaries. The system + summaries are listed in the tag and the model summaries + in the tag. 
pyrouge currently only supports one system + summary per text, i.e. it assumes a 1-to-n relation between + system and model summaries. + + """ + peer_elems = "

<P ID=\"{id}\">{name}</P>

".format( + id=system_id, name=system_filename) + + model_elems = [ + "{name}".format( + id=chr(65 + i), name=name) + for i, name in enumerate(model_filenames) + ] + + model_elems = "\n\t\t\t".join(model_elems) + eval_string = """ + + {model_root} + {peer_root} + + + + {peer_elems} + + + {model_elems} + + +""".format( + task_id=task_id, + model_root=model_dir, + model_elems=model_elems, + peer_root=system_dir, + peer_elems=peer_elems) + return eval_string + + def __process_summaries(self, process_func): + """ + Helper method that applies process_func to the files in the + system and model folders and saves the resulting files to new + system and model folders. + + """ + temp_dir = mkdtemp(dir=self.temp_dir) + new_system_dir = os.path.join(temp_dir, "system") + os.mkdir(new_system_dir) + new_model_dir = os.path.join(temp_dir, "model") + os.mkdir(new_model_dir) + self.log.info("Processing summaries. Saving system files to {} and " + "model files to {}.".format(new_system_dir, + new_model_dir)) + process_func(self._system_dir, new_system_dir) + process_func(self._model_dir, new_model_dir) + self._system_dir = new_system_dir + self._model_dir = new_model_dir + + def __write_summaries(self): + self.log.info("Writing summaries.") + self.__process_summaries(self.convert_summaries_to_rouge_format) + + @staticmethod + def __get_model_filenames_for_id(id, model_dir, model_filenames_pattern): + pattern = re.compile(model_filenames_pattern.replace('#ID#', id)) + model_filenames = [f for f in os.listdir(model_dir) if pattern.match(f)] + if not model_filenames: + raise Exception( + "Could not find any model summaries for the system" + " summary with ID {}. Specified model filename pattern was: " + "{}".format(id, model_filenames_pattern)) + return model_filenames + + def __get_options(self, rouge_args=None): + """ + Get supplied command line arguments for ROUGE or use default + ones. + + """ + if self.args: + options = self.args.split() + elif rouge_args: + options = rouge_args.split() + else: + options = [ + '-e', + self._data_dir, + '-c', + 95, + # '-2', + # '-1', + # '-U', + '-m', + # '-v', + '-r', + 1000, + '-n', + 2, + # '-w', 1.2, + '-a', + ] + options = list(map(str, options)) + + options = self.__add_config_option(options) + return options + + def __create_dir_property(self, dir_name, docstring): + """ + Generate getter and setter for a directory property. + + """ + property_name = "{}_dir".format(dir_name) + private_name = "_" + property_name + setattr(self, private_name, None) + + def fget(self): + return getattr(self, private_name) + + def fset(self, path): + verify_dir(path, dir_name) + setattr(self, private_name, path) + + p = property(fget=fget, fset=fset, doc=docstring) + setattr(self.__class__, property_name, p) + + def __set_dir_properties(self): + """ + Automatically generate the properties for directories. + + """ + directories = [ + ("home", "The ROUGE home directory."), + ("data", "The path of the ROUGE 'data' directory."), + ("system", "Path of the directory containing system summaries."), + ("model", "Path of the directory containing model summaries."), + ] + for (dirname, docstring) in directories: + self.__create_dir_property(dirname, docstring) + + def __clean_rouge_args(self, rouge_args): + """ + Remove enclosing quotation marks, if any. 
+ + """ + if not rouge_args: + return + quot_mark_pattern = re.compile('"(.+)"') + match = quot_mark_pattern.match(rouge_args) + if match: + cleaned_args = match.group(1) + return cleaned_args + else: + return rouge_args + + def __add_config_option(self, options): + return options + [self._config_file] + + def __get_config_path(self): + if platform.system() == "Windows": + parent_dir = os.getenv("APPDATA") + config_dir_name = "pyrouge" + elif os.name == "posix": + parent_dir = os.path.expanduser("~") + config_dir_name = ".pyrouge" + else: + parent_dir = os.path.dirname(__file__) + config_dir_name = "" + config_dir = os.path.join(parent_dir, config_dir_name) + if not os.path.exists(config_dir): + os.makedirs(config_dir) + return os.path.join(config_dir, 'settings.ini') + + +if __name__ == "__main__": + import argparse + from utils.argparsers import rouge_path_parser + + parser = argparse.ArgumentParser(parents=[rouge_path_parser]) + args = parser.parse_args() + + rouge = Rouge155(args.rouge_home) + rouge.save_home_dir() diff --git a/examples/text_summarization/prophetnet/evaluate/gigaword/eval.py b/examples/text_summarization/prophetnet/evaluate/gigaword/eval.py new file mode 100644 index 000000000000..84fa32bed177 --- /dev/null +++ b/examples/text_summarization/prophetnet/evaluate/gigaword/eval.py @@ -0,0 +1,380 @@ +"""BERT finetuning runner.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import glob +import json +import logging +import os +import shutil +import string +import tempfile +import time +from multiprocessing import Pool, cpu_count +from pathlib import Path + +# pip install py-rouge +import rouge + +# from pytorch_pretrained_bert.tokenization import BertTokenizer +# pip install pyrouge +from bs_pyrouge import Rouge155 + +logging.basicConfig( + format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt='%m/%d/%Y %H:%M:%S', + level=logging.INFO) +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + +parser = argparse.ArgumentParser() + +# Required parameters +parser.add_argument("--gold", type=str, help="Gold output file.") +parser.add_argument("--pred", type=str, help="Input prediction file.") +parser.add_argument( + "--split", type=str, default="", help="Data split (train/dev/test).") +parser.add_argument("--save_best", action='store_true', help="Save best epoch.") +parser.add_argument( + "--only_eval_best", action='store_true', help="Only evaluate best epoch.") +parser.add_argument( + "--trunc_len", + type=int, + default=0, + help="Truncate line by the maximum length.") +default_process_count = max(1, cpu_count() - 1) +parser.add_argument( + "--processes", + type=int, + default=default_process_count, + help="Number of processes to use (default %(default)s)") +parser.add_argument( + "--perl", action='store_true', help="Using the perl script.") +parser.add_argument( + '--lazy_eval', + action='store_true', + help="Skip evaluation if the .rouge file exists.") +args = parser.parse_args() + +SPECIAL_TOKEN = ["[UNK]", "[PAD]", "[CLS]", "[MASK]"] +evaluator = rouge.Rouge( + metrics=['rouge-n', 'rouge-l'], + max_n=2, + limit_length=False, + apply_avg=True, + weight_factor=1.2) + + +def test_rouge(cand, ref): + temp_dir = tempfile.mkdtemp() + candidates = cand + references = ref + assert len(candidates) == len(references) + + cnt = len(candidates) + current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) + tmp_dir = os.path.join(temp_dir, 
"rouge-tmp-{}".format(current_time)) + if not os.path.isdir(tmp_dir): + os.mkdir(tmp_dir) + os.mkdir(tmp_dir + "/candidate") + os.mkdir(tmp_dir + "/reference") + try: + for i in range(cnt): + if len(references[i]) < 1: + continue + with open( + tmp_dir + "/candidate/cand.{}.txt".format(i), + "w", + encoding="utf-8") as f: + f.write(candidates[i]) + with open( + tmp_dir + "/reference/ref.{}.txt".format(i), + "w", + encoding="utf-8") as f: + f.write(references[i]) + r = Rouge155(temp_dir=temp_dir) + r.model_dir = tmp_dir + "/reference/" + r.system_dir = tmp_dir + "/candidate/" + r.model_filename_pattern = 'ref.#ID#.txt' + r.system_filename_pattern = r'cand.(\d+).txt' + rouge_results = r.convert_and_evaluate() + print(rouge_results) + results_dict = r.output_to_dict(rouge_results) + finally: + if os.path.isdir(tmp_dir): + shutil.rmtree(tmp_dir) + return results_dict + + +def rouge_results_to_str(results_dict): + return ">> ROUGE-F(1/2/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-R(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\n".format( + results_dict["rouge_1_f_score"] * 100, results_dict["rouge_2_f_score"] * + 100, results_dict["rouge_l_f_score"] * 100, + results_dict["rouge_1_recall"] * 100, results_dict["rouge_2_recall"] * + 100, results_dict["rouge_l_recall"] * 100) + + +def count_tokens(tokens): + counter = {} + for t in tokens: + if t in counter.keys(): + counter[t] += 1 + else: + counter[t] = 1 + return counter + + +def get_f1(text_a, text_b): + tokens_a = text_a.lower().split() + tokens_b = text_b.lower().split() + if len(tokens_a) == 0 or len(tokens_b) == 0: + return 1 if len(tokens_a) == len(tokens_b) else 0 + set_a = count_tokens(tokens_a) + set_b = count_tokens(tokens_b) + match = 0 + for token in set_a.keys(): + if token in set_b.keys(): + match += min(set_a[token], set_b[token]) + p = match / len(tokens_a) + r = match / len(tokens_b) + return 2.0 * p * r / (p + r + 1e-5) + + +_tok_dict = { + "(": "-lrb-", + ")": "-rrb-", + "[": "-lsb-", + "]": "-rsb-", + "{": "-lcb-", + "}": "-rcb-", + "[UNK]": "UNK", + '&': '&', + '<': '<', + '>': '>' +} + + +def _is_digit(w): + for ch in w: + if not (ch.isdigit() or ch == ','): + return False + return True + + +def fix_tokenization(text): + input_tokens = text.split() + output_tokens = [] + has_left_quote = False + has_left_single_quote = False + + i = 0 + prev_dash = False + while i < len(input_tokens): + tok = input_tokens[i] + flag_prev_dash = False + if tok in _tok_dict.keys(): + output_tokens.append(_tok_dict[tok]) + i += 1 + elif tok == "\"": + if has_left_quote: + output_tokens.append("''") + else: + output_tokens.append("``") + has_left_quote = not has_left_quote + i += 1 + elif tok == "'" and len(output_tokens) > 0 and output_tokens[-1].endswith("n") and i < len(input_tokens) - 1 and \ + input_tokens[i + 1] == "t": + output_tokens[-1] = output_tokens[-1][:-1] + output_tokens.append("n't") + i += 2 + elif tok == "'" and i < len(input_tokens) - 1 and input_tokens[ + i + 1] in ("s", "d", "ll"): + output_tokens.append("'" + input_tokens[i + 1]) + i += 2 + elif tok == "'": + if has_left_single_quote: + output_tokens.append("'") + else: + output_tokens.append("`") + has_left_single_quote = not has_left_single_quote + i += 1 + elif tok == "." and i < len(input_tokens) - 2 and input_tokens[ + i + 1] == "." 
and input_tokens[i + 2] == ".": + output_tokens.append("...") + i += 3 + elif tok == "," and len(output_tokens) > 0 and _is_digit(output_tokens[ + -1]) and i < len(input_tokens) - 1 and _is_digit(input_tokens[ + i + 1]): + # $ 3 , 000 -> $ 3,000 + output_tokens[-1] += ',' + input_tokens[i + 1] + i += 2 + elif tok == "." and len(output_tokens) > 0 and output_tokens[-1].isdigit() and i < len(input_tokens) - 1 and \ + input_tokens[i + 1].isdigit(): + # 3 . 03 -> $ 3.03 + output_tokens[-1] += '.' + input_tokens[i + 1] + i += 2 + elif tok == "." and len(output_tokens) > 0 and len(output_tokens[ + -1]) == 1 and output_tokens[-1].isupper() and i < len( + input_tokens) - 2 and len(input_tokens[ + i + 1]) == 1 and input_tokens[i + 1].isupper( + ) and input_tokens[i + 2] == '.': + # U . N . -> U.N. + k = i + 3 + while k + 2 < len(input_tokens): + if len(input_tokens[k + 1]) == 1 and input_tokens[ + k + 1].isupper() and input_tokens[k + 2] == '.': + k += 2 + else: + break + output_tokens[-1] += ''.join(input_tokens[i:k]) + i += 2 + elif tok == "-": + if i < len(input_tokens) - 1 and input_tokens[i + 1] == "-": + output_tokens.append("--") + i += 2 + elif i == len(input_tokens) - 1 or i == 0: + output_tokens.append("-") + i += 1 + elif output_tokens[-1] not in string.punctuation and input_tokens[ + i + 1][0] not in string.punctuation: + output_tokens[-1] += "-" + i += 1 + flag_prev_dash = True + else: + output_tokens.append("-") + i += 1 + elif prev_dash and len(output_tokens) > 0 and tok[ + 0] not in string.punctuation: + output_tokens[-1] += tok + i += 1 + else: + output_tokens.append(tok) + i += 1 + prev_dash = flag_prev_dash + return " ".join(output_tokens) + + +def process_eval(eval_fn): + gold_list = [] + with open(args.gold, "r", encoding="utf-8") as f_in: + for l in f_in: + line = l.strip() + gold_list.append(line) + + pred_list = [] + with open(eval_fn, "r", encoding="utf-8") as f_in: + for l in f_in: + buf = [] + sentence = fix_tokenization(l.strip()).replace('1', '#') + buf.append(sentence) + if args.trunc_len: + num_left = args.trunc_len + trunc_list = [] + for bit in buf: + tk_list = bit.split() + n = min(len(tk_list), num_left) + trunc_list.append(' '.join(tk_list[:n])) + num_left -= n + if num_left <= 0: + break + else: + trunc_list = buf + line = "\n".join(trunc_list) + pred_list.append(line) + with open(eval_fn + '.post', 'w', encoding='utf-8') as f_out: + for l in pred_list: + f_out.write(l.strip()) + f_out.write('\n') + # rouge scores + if len(pred_list) < len(gold_list): + # evaluate subset + gold_list = gold_list[:len(pred_list)] + assert len(pred_list) == len(gold_list) + if args.perl: + scores = test_rouge(pred_list, gold_list) + else: + scores = evaluator.get_scores(pred_list, [[it] for it in gold_list]) + return eval_fn, scores + + +def main(): + if args.perl: + eval_fn_list = list(glob.glob(args.pred)) + else: + eval_fn_list = [ + eval_fn for eval_fn in glob.glob(args.pred) + if not (args.lazy_eval and Path(eval_fn + ".rouge").exists()) + ] + eval_fn_list = list( + filter(lambda fn: not (fn.endswith('.post') or fn.endswith('.rouge')), + eval_fn_list)) + + if args.only_eval_best: + best_epoch_dict = {} + for dir_path in set(Path(fn).parent for fn in eval_fn_list): + fn_save = os.path.join(dir_path, 'save_best.dev') + if Path(fn_save).exists(): + with open(fn_save, 'r') as f_in: + __, o_name, __ = f_in.read().strip().split('\n') + epoch = o_name.split('.')[1] + best_epoch_dict[dir_path] = epoch + new_eval_fn_list = [] + for fn in eval_fn_list: + dir_path = Path(fn).parent + if 
+            if dir_path in best_epoch_dict:
+                if Path(fn).name.split('.')[1] == best_epoch_dict[dir_path]:
+                    new_eval_fn_list.append(fn)
+        eval_fn_list = new_eval_fn_list
+
+    logger.info("***** Evaluation: %s *****", ','.join(eval_fn_list))
+    num_pool = max(1, min(args.processes, len(eval_fn_list)))
+    logger.info("processes: %d, files: %d, pool size: %d", args.processes,
+                len(eval_fn_list), num_pool)
+    p = Pool(num_pool)
+    r_list = p.imap_unordered(process_eval, eval_fn_list)
+    r_list = sorted([(fn, scores) for fn, scores in r_list], key=lambda x: x[0])
+    rg2_dict = {}
+    for fn, scores in r_list:
+        logger.info(fn)
+        if args.perl:
+            print(rouge_results_to_str(scores))
+        else:
+            rg2_dict[fn] = scores['rouge-2']['f']
+            print("ROUGE-1: {}\tROUGE-2: {}\tROUGE-L: {}\n".format(
+                scores['rouge-1']['f'], scores['rouge-2']['f'],
+                scores['rouge-l']['f']))
+            with open(fn + ".rouge", 'w') as f_out:
+                f_out.write(
+                    json.dumps({
+                        'rg1': scores['rouge-1']['f'],
+                        'rg2': scores['rouge-2']['f']
+                    }))
+    p.close()
+    p.join()
+
+    if args.save_best:
+        # find best results
+        group_dict = {}
+        for k, v in rg2_dict.items():
+            d_name, o_name = Path(k).parent, Path(k).name
+            if (d_name not in group_dict) or (v > group_dict[d_name][1]):
+                group_dict[d_name] = (o_name, v)
+        # compare and save the best result
+        for k, v in group_dict.items():
+            fn = os.path.join(k, 'save_best.' + args.split)
+            o_name_s, rst_s = v
+            should_save = True
+            if Path(fn).exists():
+                with open(fn, 'r') as f_in:
+                    rst_f = float(f_in.read().strip().split('\n')[-1])
+                if rst_s <= rst_f:
+                    should_save = False
+            if should_save:
+                with open(fn, 'w') as f_out:
+                    f_out.write('{0}\n{1}\n{2}\n'.format(k, o_name_s, rst_s))
+
+
+if __name__ == "__main__":
+    main()
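+
+# Illustrative behaviour of fix_tokenization on hypothetical system output
+# (the exact outputs depend on _tok_dict and _is_digit defined above):
+#   "the U . S . gained 3 , 000 jobs"  ->  "the U.S. gained 3,000 jobs"
+#   "a well - known model"             ->  "a well-known model"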
" + ) + parser.add_argument( + "--max_target_length", + default=110, + type=int, + help="The maximum total sequence length for target text after " + "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded." + "during ``evaluate`` and ``predict``.", ) + parser.add_argument( + '--decode_strategy', + default='beam_search', + type=str, + help='The decode strategy in generation.') + parser.add_argument( + '--top_k', + default=2, + type=int, + help='The number of highest probability vocabulary tokens to keep for top-k sampling.' + ) + parser.add_argument( + '--top_p', + default=1.0, + type=float, + help='The cumulative probability for top-p sampling.') + parser.add_argument( + '--num_beams', + default=5, + type=int, + help='The number of beams for beam search.') + parser.add_argument( + '--length_penalty', + default=1.2, + type=float, + help='The exponential penalty to the sequence length for beam search.') + parser.add_argument( + '--early_stopping', + default=False, + type=eval, + help='Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.' + ) + parser.add_argument( + "--diversity_rate", + default=0.0, + type=float, + help="The diversity of beam search. ") + parser.add_argument( + "--num_beam_groups", + default=1, + type=int, + help="Number of groups to divide `num_beams` into in order to use DIVERSE BEAM SEARCH." + ) + parser.add_argument( + "--repetition_penalty", + default=1.0, + type=float, + help="Number of groups to divide `num_beams` into in order to use DIVERSE BEAM SEARCH." + ) + parser.add_argument( + "--batch_size", + default=4, + type=int, + help="Batch size per GPU/CPU for testing or evaluation.") + parser.add_argument( + "--seed", default=42, type=int, help="random seed for initialization") + parser.add_argument( + "--device", + default="gpu", + type=str, + choices=["cpu", "gpu", "xpu"], + help="The device to select to train the model, is must be cpu/gpu/xpu.") + parser.add_argument( + "--ignore_pad_token_for_loss", + default=True, + type=bool, + help="Whether to ignore the tokens corresponding to " + "padded labels in the loss computation or not.", ) + parser.add_argument( + "--logging_steps", + type=int, + default=100, + help="Log every X updates steps.") + args = parser.parse_args() + return args + + +def set_seed(args): + # Use the same data seed(for data shuffle) for all procs to guarantee data + # consistency after sharding. + random.seed(args.seed) + np.random.seed(args.seed) + # Maybe different op seeds(for dropout) for different procs is better. By: + # `paddle.seed(args.seed + paddle.distributed.get_rank())` + paddle.seed(args.seed) + + +def compute_metrics(preds, + labels, + tokenizer, + ignore_pad_token_for_loss=True, + compute_rouge_=True): + def compute_rouge(predictions, + references, + rouge_types=None, + use_stemmer=True): + if rouge_types is None: + rouge_types = ["rouge1", "rouge2", "rougeLsum"] + + scorer = rouge_scorer.RougeScorer( + rouge_types=rouge_types, use_stemmer=use_stemmer) + aggregator = scoring.BootstrapAggregator() + + for ref, pred in zip(references, predictions): + score = scorer.score(ref, pred) + aggregator.add_scores(score) + result = aggregator.aggregate() + result = { + key: round(value.mid.fmeasure * 100, 4) + for key, value in result.items() + } + return result + + def post_process_seq(seq, + bos_idx, + eos_idx, + output_bos=False, + output_eos=False): + """ + Post-process the decoded sequence. 
+ """ + eos_pos = len(seq) - 1 + for i, idx in enumerate(seq): + if idx == eos_idx: + eos_pos = i + break + seq = [ + idx for idx in seq[:eos_pos + 1] + if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx) + ] + return seq + + if ignore_pad_token_for_loss: + labels = np.asarray(labels) + labels = np.where(labels != -100, labels, tokenizer.pad_token_id) + decoded_preds, decoded_labels = [], [] + for pred, label in zip(preds, labels): + pred_id = post_process_seq(pred, tokenizer.bos_token_id, + tokenizer.eos_token_id) + label_id = post_process_seq(label, tokenizer.bos_token_id, + tokenizer.eos_token_id) + decoded_preds.append(tokenizer.convert_ids_to_string(pred_id)) + decoded_labels.append(tokenizer.convert_ids_to_string(label_id)) + + if compute_rouge_: + rouge_result = compute_rouge(decoded_preds, decoded_labels) + return rouge_result, decoded_preds + else: + return decoded_preds, decoded_labels + + +def read(data_path): + data_path_src = data_path[0] + data_path_tgt = data_path[1] + with open(data_path_src, 'r', encoding='utf-8') as f_d_s: + src_lines_length = len(f_d_s.readlines()) + with open(data_path_tgt, 'r', encoding='utf-8') as f_d_t: + tgt_lines_length = len(f_d_t.readlines()) + assert src_lines_length == tgt_lines_length + with open(data_path_src, 'r', encoding='utf-8') as f_d_s: + with open(data_path_tgt, 'r', encoding='utf-8') as f_d_t: + for row_d_s, row_d_t in tqdm( + zip(f_d_s, f_d_t), total=src_lines_length): + yield {'article': row_d_s, 'highlights': row_d_t} + + +def convert_example(is_test=False): + def warpper(example): + """convert an example into necessary features""" + tokens = example['article'] + labels = example['highlights'] + src_ids, src_attention_mask_ids = tokens.split("$1$") + src_ids = [int(i) for i in src_ids.split(" ")] + src_attention_mask_ids = [ + int(i) for i in src_attention_mask_ids.split(" ") + ] + + if not is_test: + labels, decoder_input_attention_mask_ids = labels.split("$1$") + labels = [int(i) for i in labels.split(" ")] + decoder_input_attention_mask_ids = [ + int(i) for i in decoder_input_attention_mask_ids.split(" ") + ] + decoder_input_ids = [labels[-1]] + labels[:-1] + + return src_ids, src_attention_mask_ids, decoder_input_ids, decoder_input_attention_mask_ids, labels + + else: + labels, _ = labels.split("$1$") + labels = [int(i) for i in labels.split(" ")] + return src_ids, src_attention_mask_ids, labels + + return warpper + + +@paddle.no_grad() +def generate(args): + paddle.set_device(args.device) + tokenizer = ProphetNetTokenizer(vocab_file=args.vocab_file) + model = ProphetNetModel(vocab_size=30522) + model = ProphetNetForConditionalGeneration(model) + + ckpt = paddle.load("./ckpt/" + args.dataset + "/model_best.pdparams") + + model.load_dict(ckpt['model']) + + test_data_src = 'data/' + args.dataset + '_data/uncased_tok_data/test.src' + test_data_tgt = 'data/' + args.dataset + '_data/uncased_tok_data/test.tgt' + + test_dataset = load_dataset( + read, data_path=[test_data_src, test_data_tgt], lazy=False) + + trunc = convert_example(is_test=True) + + test_dataset = test_dataset.map(trunc) + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id), # src_ids + Pad(axis=0, pad_val=0), # attn mask + Pad(axis=0, pad_val=tokenizer.pad_token_id) # labels + ): fn(samples) + + batch_sampler = BatchSampler( + test_dataset, batch_size=args.batch_size, shuffle=False) + test_data_loader = DataLoader( + dataset=test_dataset, + batch_sampler=batch_sampler, + num_workers=0, + collate_fn=batchify_fn, + 
return_list=True)
+
+    model.eval()
+    total_time = 0.0
+    start_time = time.time()
+    all_preds = []
+    all_labels = []
+    for step, batch in tqdm(
+            enumerate(test_data_loader), total=len(test_data_loader)):
+        input_ids, attention_mask, labels = batch
+        preds, _ = model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            max_length=args.max_target_length,
+            min_length=args.min_target_length,
+            decode_strategy=args.decode_strategy,
+            top_k=args.top_k,
+            top_p=args.top_p,
+            num_beams=args.num_beams,
+            length_penalty=args.length_penalty,
+            early_stopping=args.early_stopping,
+            diversity_rate=args.diversity_rate,
+            num_beam_groups=args.num_beam_groups,
+            repetition_penalty=args.repetition_penalty)
+        total_time += (time.time() - start_time)
+        all_preds.extend(preds.numpy())
+        all_labels.extend(labels.numpy())
+        if step % args.logging_steps == 0:
+            print('step %d - %.3fs/step' %
+                  (step, total_time / args.logging_steps))
+            total_time = 0.0
+        start_time = time.time()
+    decoded_preds, _ = compute_metrics(
+        all_preds,
+        all_labels,
+        tokenizer,
+        args.ignore_pad_token_for_loss,
+        compute_rouge_=False)
+    if not os.path.exists(
+            os.path.abspath(
+                os.path.dirname(args.output_path) + os.path.sep + ".")):
+        os.makedirs(
+            os.path.abspath(
+                os.path.dirname(args.output_path) + os.path.sep + "."))
+    with open(args.output_path, 'w', encoding='utf-8') as fout:
+        for decoded_pred in decoded_preds:
+            fout.write(decoded_pred + '\n')
+    print('Save generated result into: %s' % args.output_path)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    pprint(args)
+    generate(args)
diff --git a/examples/text_summarization/prophetnet/requirements.txt b/examples/text_summarization/prophetnet/requirements.txt
new file mode 100644
index 000000000000..4da80ee18539
--- /dev/null
+++ b/examples/text_summarization/prophetnet/requirements.txt
@@ -0,0 +1,5 @@
+configparser==5.2.0
+nltk==3.6.7
+numpy==1.21.0
+tqdm==4.62.3
+py-rouge==1.1
\ No newline at end of file
diff --git a/examples/text_summarization/prophetnet/run_eval.sh b/examples/text_summarization/prophetnet/run_eval.sh
new file mode 100644
index 000000000000..a1e70a2bf6f0
--- /dev/null
+++ b/examples/text_summarization/prophetnet/run_eval.sh
@@ -0,0 +1,37 @@
+DATASET=$1
+
+if [ "$DATASET" = cnndm ]
+then
+python generate.py \
+    --dataset=cnndm \
+    --vocab_file=./prophetnet.tokenizer \
+    --output_path=./generate/cnndm/generate.txt \
+    --min_target_length=45 \
+    --max_target_length=110 \
+    --decode_strategy=beam_search \
+    --num_beams=4 \
+    --length_penalty=1.2 \
+    --batch_size=16 \
+    --ignore_pad_token_for_loss=True \
+    --early_stopping=True \
+    --logging_steps=100 \
+    --device=gpu
+else
+python generate.py \
+    --dataset=gigaword \
+    --vocab_file=./prophetnet.tokenizer \
+    --output_path=./generate/gigaword/generate.txt \
+    --min_target_length=1 \
+    --max_target_length=200 \
+    --decode_strategy=beam_search \
+    --num_beams=4 \
+    --length_penalty=1.6 \
+    --batch_size=16 \
+    --ignore_pad_token_for_loss=True \
+    --early_stopping=True \
+    --logging_steps=100 \
+    --device=gpu
+fi
+
+
+python eval.py --dataset $DATASET --generated ./generate/$DATASET/generate.txt
\ No newline at end of file
diff --git a/examples/text_summarization/prophetnet/run_train.sh b/examples/text_summarization/prophetnet/run_train.sh
new file mode 100644
index 000000000000..efe119491c1d
--- /dev/null
+++ b/examples/text_summarization/prophetnet/run_train.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+DATASET=$1
+
+if [ "$DATASET" == cnndm ]
+then
+python train_prophetnet.py \
+    --dataset=cnndm \
+    --pretrained_model_path=./model_state.pdparams \
+    --batch_size=4 \
+    --epochs=4 \
+    --lr=0.0001 \
+    --warmup_init_lr=1e-07 \
+    --warmup_updates=1000 \
+    --clip_norm=0.1 \
+    --num_workers=4 \
+    --output_dir=./ckpt/cnndm
+else
+python train_prophetnet.py \
+    --dataset=gigaword \
+    --pretrained_model_path=./model_state.pdparams \
+    --batch_size=16 \
+    --epochs=6 \
+    --lr=0.0001 \
+    --warmup_init_lr=1e-07 \
+    --warmup_updates=1000 \
+    --clip_norm=0.1 \
+    --num_workers=8 \
+    --output_dir=./ckpt/gigaword
+fi
\ No newline at end of file
diff --git a/examples/text_summarization/prophetnet/train_prophetnet.py b/examples/text_summarization/prophetnet/train_prophetnet.py
new file mode 100644
index 000000000000..3170d8418a33
--- /dev/null
+++ b/examples/text_summarization/prophetnet/train_prophetnet.py
@@ -0,0 +1,321 @@
+import argparse
+import os
+
+import paddle
+from paddle.io import DataLoader
+from tqdm import tqdm
+
+from paddlenlp.data import Pad, Tuple
+from paddlenlp.datasets import load_dataset
+from paddlenlp.transformers.prophetnet.modeling import ProphetNetForConditionalGeneration, ProphetNetModel
+from paddlenlp.transformers.prophetnet.tokenizer import ProphetNetTokenizer
+
+parser = argparse.ArgumentParser()
+# Required parameters
+parser.add_argument(
+    "--dataset",
+    default="gigaword",
+    choices=["cnndm", "gigaword"],
+    type=str,
+    help="The dataset to fine-tune on. ")
+parser.add_argument(
+    "--pretrained_model_path", default="./model_state.pdparams", type=str)
+parser.add_argument("--batch_size", default=24, type=int)
+parser.add_argument("--epochs", default=3, type=int)
+parser.add_argument("--lr", default=0.0001, type=float)
+parser.add_argument("--weight_decay", default=0.0, type=float)
+parser.add_argument("--warmup_init_lr", default=1e-07, type=float)
+parser.add_argument("--warmup_updates", default=1000, type=int)
+parser.add_argument("--clip_norm", default=0.1, type=float)
+parser.add_argument("--num_workers", default=4, type=int)
+parser.add_argument("--output_dir", default="./ckpt/gigaword", type=str)
+
+args = parser.parse_args()
+
+
+def read(data_path):
+    data_path_src = data_path[0]
+    data_path_tgt = data_path[1]
+    with open(data_path_src, 'r', encoding='utf-8') as f_d_s:
+        src_lines_length = len(f_d_s.readlines())
+    with open(data_path_tgt, 'r', encoding='utf-8') as f_d_t:
+        tgt_lines_length = len(f_d_t.readlines())
+    assert src_lines_length == tgt_lines_length
+    with open(data_path_src, 'r', encoding='utf-8') as f_d_s:
+        with open(data_path_tgt, 'r', encoding='utf-8') as f_d_t:
+            for row_d_s, row_d_t in tqdm(
+                    zip(f_d_s, f_d_t), total=src_lines_length):
+                yield {'article': row_d_s, 'highlights': row_d_t}
+
+
+train_data_src = 'data/' + args.dataset + '_data/uncased_tok_data/train.src'
+train_data_tgt = 'data/' + args.dataset + '_data/uncased_tok_data/train.tgt'
+
+dev_data_src = 'data/' + args.dataset + '_data/uncased_tok_data/dev.src'
+dev_data_tgt = 'data/' + args.dataset + '_data/uncased_tok_data/dev.tgt'
+
+train_dataset = load_dataset(
+    read, data_path=[train_data_src, train_data_tgt], lazy=False)
+
+dev_dataset = load_dataset(
+    read, data_path=[dev_data_src, dev_data_tgt], lazy=False)
+
+t = ProphetNetTokenizer(vocab_file="prophetnet.tokenizer")
+
+
+class InverseSquareRootSchedule(paddle.optimizer.lr.LRScheduler):
+    def __init__(self,
+                 warmup_init_lr,
+                 warmup_end_lr,
+                 warmup_updates,
+                 last_epoch=-1,
+                 verbose=False):
+        self.lr_step = (warmup_end_lr - warmup_init_lr) / warmup_updates
+        self.decay_factor = warmup_end_lr * warmup_updates**0.5
+
self.warmup_updates = warmup_updates + self.warmup_init_lr = warmup_init_lr + super(InverseSquareRootSchedule, self).__init__(warmup_init_lr, + last_epoch, verbose) + + def get_lr(self): + if self.last_epoch < self.warmup_updates: + self.base_lr = self.warmup_init_lr + self.last_epoch * self.lr_step + else: + self.base_lr = self.decay_factor * self.last_epoch**-0.5 + return self.base_lr + + +def convert_example(is_test=False): + def warpper(example): + """convert an example into necessary features""" + tokens = example['article'] + labels = example['highlights'] + src_ids, src_attention_mask_ids = tokens.split("$1$") + src_ids = [int(i) for i in src_ids.split(" ")] + src_attention_mask_ids = [ + int(i) for i in src_attention_mask_ids.split(" ") + ] + + if not is_test: + labels, decoder_input_attention_mask_ids = labels.split("$1$") + labels = [int(i) for i in labels.split(" ")] + decoder_input_attention_mask_ids = [ + int(i) for i in decoder_input_attention_mask_ids.split(" ") + ] + decoder_input_ids = [labels[-1]] + labels[:-1] + return src_ids, src_attention_mask_ids, decoder_input_ids, decoder_input_attention_mask_ids, labels + + else: + return src_ids, src_attention_mask_ids + + return warpper + + +trunc = convert_example() + +train_dataset = train_dataset.map(trunc) +dev_dataset = dev_dataset.map(trunc) + +batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=t.pad_token_id), # src_ids + Pad(axis=0, pad_val=0), # src_pids + Pad(axis=0, pad_val=t.pad_token_id), # tgt_ids + Pad(axis=0, pad_val=0), # tgt_pids + Pad(axis=0, pad_val=t.pad_token_id) # label +): fn(samples) + +batch_size = args.batch_size + +train_data_loader = DataLoader( + dataset=train_dataset, + batch_size=batch_size, + shuffle=True, + collate_fn=batchify_fn, + use_shared_memory=False, + num_workers=args.num_workers) + +dev_data_loader = DataLoader( + dataset=dev_dataset, + batch_size=batch_size * 2, + shuffle=True, + collate_fn=batchify_fn, + use_shared_memory=False, + num_workers=args.num_workers) + +epochs = args.epochs +lr = args.lr +weight_decay = args.weight_decay +warmup_init_lr = args.warmup_init_lr +warmup_updates = args.warmup_updates +clip_norm = args.clip_norm +output_dir = args.output_dir + +best_valid_loss = None +start_epoch = 0 + +model = ProphetNetModel( + **ProphetNetModel.pretrained_init_configuration["prophetnet-large-uncased"]) +model = ProphetNetForConditionalGeneration(model) + +lr_scheduler = InverseSquareRootSchedule(warmup_init_lr, lr, warmup_updates) + +optimizer = paddle.optimizer.Adam( + learning_rate=lr_scheduler, + parameters=model.parameters(), + weight_decay=weight_decay, + grad_clip=paddle.nn.ClipGradByNorm(clip_norm)) + +if os.path.exists(output_dir) and len(os.listdir(output_dir)) > 1: + ckpt_file_last = os.listdir(output_dir)[-2] + ckpt_load = paddle.load(os.path.join(output_dir, ckpt_file_last)) + best_valid_loss = ckpt_load["best_valid_loss"] + start_epoch = ckpt_load["epoch"] + 1 + optimizer.set_state_dict(ckpt_load["optimizer"]) + model.load_dict(ckpt_load["model"]) +else: + model.load_dict(paddle.load(args.pretrained_model_path)) + +accumulate_batchs_num = int(32 * 16 / batch_size) + +scaler = paddle.amp.GradScaler(init_loss_scaling=1024) + + +def compute_loss(model, logits, labels, ignore_index=-100): + expend_targets = paddle.cast( + paddle.zeros((model.prophetnet.config["ngram"], labels.shape[0], + labels.shape[1])).fill_(ignore_index), + dtype=paddle.int32) + + for i in range(model.prophetnet.config["ngram"]): + if i > 0 and model.prophetnet.disable_ngram_loss: + 
break + expend_targets[i, :, :] = labels.cast(dtype=paddle.int32) # B,Ngram,Seq + + logits = logits.transpose([1, 0, 2, 3]) + + if model.prophetnet.eps > 0.0: + expend_targets_mask = paddle.cast( + expend_targets != ignore_index, dtype=paddle.float32) + expend_targets = paddle.nn.functional.one_hot( + expend_targets, num_classes=model.vocab_size) + expend_targets = paddle.nn.functional.label_smooth( + expend_targets, epsilon=model.prophetnet.eps) + loss = paddle.nn.functional.cross_entropy( + logits, expend_targets, soft_label=True, reduction='none').squeeze() + loss = paddle.sum(expend_targets_mask * + loss) / expend_targets_mask.sum() + else: + loss = paddle.nn.functional.cross_entropy( + logits, + expend_targets.cast(dtype=paddle.int64), + ignore_index=ignore_index) + + return loss + + +@paddle.no_grad() +def valid(data): + model.eval() + losses = 0 + with tqdm(total=len(data)) as bar: + for step, batch in enumerate(data, start=1): + src_ids, src_attention_mask_ids, decoder_input_ids, decoder_input_attention_mask_ids, label_ids = batch + src_ids = src_ids.cast(dtype=paddle.int32) + src_attention_mask_ids = src_attention_mask_ids.cast( + dtype=paddle.int32) + decoder_input_ids = decoder_input_ids.cast(dtype=paddle.int32) + decoder_input_attention_mask_ids = decoder_input_attention_mask_ids.cast( + dtype=paddle.int32) + label_ids = label_ids.cast(dtype=paddle.int64) + _, _, logits = model( + input_ids=src_ids, + attention_mask=src_attention_mask_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_input_attention_mask_ids) + loss = compute_loss( + model, logits, label_ids, ignore_index=model.padding_idx) + losses += loss.detach().numpy() + bar.update(1) + return losses / step + + +def train(): + global_step = 1 + global best_valid_loss + model.train() + for epoch in range(start_epoch, epochs): + with tqdm(total=int(len(train_data_loader) / + accumulate_batchs_num)) as train_bar: + for step, batch in enumerate(train_data_loader, start=1): + src_ids, src_attention_mask_ids, decoder_input_ids, decoder_input_attention_mask_ids, label_ids = batch + src_ids = src_ids.cast(dtype=paddle.int32) + src_attention_mask_ids = src_attention_mask_ids.cast( + dtype=paddle.int32) + decoder_input_ids = decoder_input_ids.cast(dtype=paddle.int32) + decoder_input_attention_mask_ids = decoder_input_attention_mask_ids.cast( + dtype=paddle.int32) + label_ids = label_ids.cast(dtype=paddle.int64) + with paddle.amp.auto_cast(): + _, _, logits = model( + input_ids=src_ids, + attention_mask=src_attention_mask_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_input_attention_mask_ids) + loss = compute_loss( + model, + logits, + label_ids, + ignore_index=model.padding_idx) + + scaled = scaler.scale(loss) + scaled.backward() + if (step + 1) % accumulate_batchs_num == 0: + scaler.minimize(optimizer, scaled) + lr_scheduler.step() + optimizer.clear_grad() + train_bar.update(1) + train_bar.set_description( + "global step %d, epoch: %d, batch: %d, loss: %f, lr: %.3e" + % + (global_step, epoch, step, loss, lr_scheduler.get_lr())) + global_step += 1 + + valid_loss = valid(dev_data_loader) + best_ckpt_path = os.path.join(output_dir, "model_best.pdparams") + if best_valid_loss is None: + best_valid_loss = valid_loss + save(best_ckpt_path, { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "epoch": epoch, + "best_valid_loss": best_valid_loss + }) + else: + if valid_loss < best_valid_loss: + best_valid_loss = valid_loss + save(best_ckpt_path, { + "model": 
model.state_dict(),
+                    "optimizer": optimizer.state_dict(),
+                    "epoch": epoch,
+                    "best_valid_loss": best_valid_loss
+                })
+        print("valid loss: %f, best valid loss: %f" %
+              (valid_loss, best_valid_loss))
+        ckpt_path = os.path.join(output_dir, "model_%d.pdparams" % epoch)
+        save(ckpt_path, {
+            "model": model.state_dict(),
+            "optimizer": optimizer.state_dict(),
+            "epoch": epoch,
+            "best_valid_loss": best_valid_loss
+        })
+
+
+def save(path, obj):
+    if not os.path.exists(
+            os.path.abspath(os.path.dirname(path) + os.path.sep + ".")):
+        os.makedirs(os.path.abspath(os.path.dirname(path) + os.path.sep + "."))
+    paddle.save(obj, path)
+
+
+if __name__ == "__main__":
+    train()
diff --git a/examples/text_summarization/prophetnet/uncase_tokenize_data.py b/examples/text_summarization/prophetnet/uncase_tokenize_data.py
new file mode 100644
index 000000000000..b12eb62da53d
--- /dev/null
+++ b/examples/text_summarization/prophetnet/uncase_tokenize_data.py
@@ -0,0 +1,117 @@
+import argparse
+import os
+
+import tqdm
+from nltk.tokenize.treebank import TreebankWordDetokenizer
+
+from paddlenlp.transformers.prophetnet.tokenizer import ProphetNetTokenizer
+
+
+def uncased_process(fin, fout, keep_sep=False, max_len=512):
+    tokenizer = ProphetNetTokenizer(vocab_file="prophetnet.tokenizer")
+    fin = open(fin, 'r', encoding='utf-8')
+    fout = open(fout, 'w', encoding='utf-8')
+    twd = TreebankWordDetokenizer()
+    for line in tqdm.tqdm(fin.readlines()):
+        line = line.strip().replace('``', '"').replace('\'\'',
+                                                       '"').replace('`', '\'')
+        # sentence boundaries in the raw cnndm files are marked by <S_SEP>
+        s_list = [
+            twd.detokenize(
+                x.strip().split(' '), convert_parentheses=True)
+            for x in line.split('<S_SEP>')
+        ]
+        if keep_sep:
+            output_string = " [X_SEP] ".join(s_list)
+        else:
+            output_string = " ".join(s_list)
+        encoded_string = tokenizer(
+            output_string, return_attention_mask=True, max_seq_len=max_len)
+        ids, attention_mask_ids = encoded_string[
+            "input_ids"][:max_len], encoded_string["attention_mask"][:max_len]
+        output_string = "$1$".join([
+            " ".join([str(i) for i in ids]),
+            " ".join([str(i) for i in attention_mask_ids])
+        ])
+        fout.write('{}\n'.format(output_string))
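+
+
+# Layout of every line written above and by tokenize_with_bert_uncase below
+# (illustrative ids): token ids and attention mask joined by a "$1$" marker,
+#   "7632 2015 1997$1$1 1 1"
+# train_prophetnet.py and generate.py split each line on "$1$" to recover
+# the ids and the mask.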
+def tokenize_with_bert_uncase(fin, fout, max_len=512):
+    fin = open(fin, 'r', encoding='utf-8')
+    fout = open(fout, 'w', encoding='utf-8')
+    tokenizer = ProphetNetTokenizer(vocab_file="prophetnet.tokenizer")
+    for line in tqdm.tqdm(fin.readlines()):
+        encoded_string = tokenizer(
+            line, return_attention_mask=True, max_seq_len=max_len)
+        ids, attention_mask_ids = encoded_string[
+            "input_ids"][:max_len], encoded_string["attention_mask"][:max_len]
+        output_string = "$1$".join([
+            " ".join([str(i) for i in ids]),
+            " ".join([str(i) for i in attention_mask_ids])
+        ])
+        fout.write('{}\n'.format(output_string))
+
+
+def tokenize_data(dataset):
+    dataset = dataset + "_data"
+    input_dir = './data/%s' % (dataset)
+    output_dir = './data/%s/uncased_tok_data' % (dataset)
+    if not os.path.isdir(output_dir):
+        os.makedirs(output_dir)
+    # "_data" was appended above, so compare against the suffixed name
+    if dataset == 'cnndm_data':
+        uncased_process(
+            '%s/train.src' % input_dir,
+            '%s/train.src' % output_dir,
+            keep_sep=False)
+        uncased_process(
+            '%s/dev.src' % input_dir, '%s/dev.src' % output_dir, keep_sep=False)
+        uncased_process(
+            '%s/test.src' % input_dir,
+            '%s/test.src' % output_dir,
+            keep_sep=False)
+        uncased_process(
+            '%s/train.tgt' % input_dir,
+            '%s/train.tgt' % output_dir,
+            keep_sep=True,
+            max_len=128)
+        uncased_process(
+            '%s/dev.tgt' % input_dir, '%s/dev.tgt' % output_dir, keep_sep=True)
+        uncased_process(
+            '%s/test.tgt' % input_dir,
+            '%s/test.tgt' % output_dir,
+            keep_sep=True)
+    else:
+        tokenize_with_bert_uncase('%s/train.src' % input_dir,
+                                  '%s/train.src' % output_dir)
+        tokenize_with_bert_uncase('%s/train.tgt' % input_dir,
+                                  '%s/train.tgt' % output_dir)
+        tokenize_with_bert_uncase('%s/dev.src' % input_dir,
+                                  '%s/dev.src' % output_dir)
+        tokenize_with_bert_uncase('%s/dev.tgt' % input_dir,
+                                  '%s/dev.tgt' % output_dir)
+        tokenize_with_bert_uncase('%s/test.src' % input_dir,
+                                  '%s/test.src' % output_dir)
+        tokenize_with_bert_uncase('%s/test.tgt' % input_dir,
+                                  '%s/test.tgt' % output_dir)
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--dataset",
+    type=str,
+    help="choose dataset from all, or one of: cnndm, gigaword")
+args = parser.parse_args()

+DATASET_LIST = ['cnndm', 'gigaword']
+
+if args.dataset != 'all' and args.dataset not in DATASET_LIST:
+    print('please choose dataset from all, or one of: cnndm, gigaword')
+    exit()
+else:
+    if args.dataset == 'all':
+        dataset_list = DATASET_LIST
+    else:
+        dataset_list = [args.dataset]
+
+print(dataset_list)
+for dataset in dataset_list:
+    tokenize_data(dataset)
diff --git a/examples/text_summarization/prophetnet/uncompress_data.sh b/examples/text_summarization/prophetnet/uncompress_data.sh
new file mode 100644
index 000000000000..392596ae14b2
--- /dev/null
+++ b/examples/text_summarization/prophetnet/uncompress_data.sh
@@ -0,0 +1,12 @@
+tar -xvf ./glge_public.tar
+tar -zxvf ./glge_hidden_v1.1.tar.gz
+
+DATA=./data
+DATASETS=(cnndm gigaword)
+mkdir $DATA
+for DATASET in ${DATASETS[@]}; do
+  echo $DATASET
+  mkdir $DATA/$DATASET\_data
+  mv ./glge-released-dataset/easy/$DATASET\_data/org_data/* $DATA/$DATASET\_data/
+  mv ./glge-hidden-dataset/easy/$DATASET\_data/org_data/* $DATA/$DATASET\_data/
+done
From 80e2dca5718ded70e8c1dd7a5b087cf26bade124 Mon Sep 17 00:00:00 2001
From: DMH_coco <294270681@qq.com>
Date: Tue, 1 Mar 2022 11:15:08 +0800
Subject: [PATCH 6/7] update tokenizer.py,run_train.sh,train_prophetnet.py

---
 .../prophetnet/run_train.sh                   |   4 +-
 .../prophetnet/train_prophetnet.py            |  16 +-
 .../transformers/prophetnet/tokenizer.py      | 212 +-----------------
 3 files changed, 11 insertions(+), 221 deletions(-)

diff --git a/examples/text_summarization/prophetnet/run_train.sh b/examples/text_summarization/prophetnet/run_train.sh
index efe119491c1d..fe0fa26356b9 100644
--- a/examples/text_summarization/prophetnet/run_train.sh
+++ b/examples/text_summarization/prophetnet/run_train.sh
@@ -10,7 +10,7 @@ python train_prophetnet.py \
     --epochs=4 \
     --lr=0.0001 \
     --warmup_init_lr=1e-07 \
-    --warmup_updates=1000 \
+    --warmup_steps=1000 \
     --clip_norm=0.1 \
     --num_workers=4 \
     --output_dir=./ckpt/cnndm
@@ -22,7 +22,7 @@ python train_prophetnet.py \
     --epochs=6 \
     --lr=0.0001 \
     --warmup_init_lr=1e-07 \
-    --warmup_updates=1000 \
+    --warmup_steps=1000 \
     --clip_norm=0.1 \
     --num_workers=8 \
     --output_dir=./ckpt/gigaword
diff --git a/examples/text_summarization/prophetnet/train_prophetnet.py b/examples/text_summarization/prophetnet/train_prophetnet.py
index 3170d8418a33..d6ed4e5306df 100644
--- a/examples/text_summarization/prophetnet/train_prophetnet.py
+++ b/examples/text_summarization/prophetnet/train_prophetnet.py
@@ -25,7 +25,7 @@
 parser.add_argument("--lr", default=0.0001, type=float)
 parser.add_argument("--weight_decay", default=0.0, type=float)
 parser.add_argument("--warmup_init_lr", default=1e-07, type=float)
-parser.add_argument("--warmup_updates", default=1000, type=int)
+parser.add_argument("--warmup_steps", default=1000, type=int)
 parser.add_argument("--clip_norm", default=0.1, type=float)
parser.add_argument("--num_workers", default=4, type=int) parser.add_argument("--output_dir", default="./ckpt/gigaword", type=str) @@ -67,18 +67,18 @@ class InverseSquareRootSchedule(paddle.optimizer.lr.LRScheduler): def __init__(self, warmup_init_lr, warmup_end_lr, - warmup_updates, + warmup_steps, last_epoch=-1, verbose=False): - self.lr_step = (warmup_end_lr - warmup_init_lr) / warmup_updates - self.decay_factor = warmup_end_lr * warmup_updates**0.5 - self.warmup_updates = warmup_updates + self.lr_step = (warmup_end_lr - warmup_init_lr) / warmup_steps + self.decay_factor = warmup_end_lr * warmup_steps**0.5 + self.warmup_steps = warmup_steps self.warmup_init_lr = warmup_init_lr super(InverseSquareRootSchedule, self).__init__(warmup_init_lr, last_epoch, verbose) def get_lr(self): - if self.last_epoch < self.warmup_updates: + if self.last_epoch < self.warmup_steps: self.base_lr = self.warmup_init_lr + self.last_epoch * self.lr_step else: self.base_lr = self.decay_factor * self.last_epoch**-0.5 @@ -146,7 +146,7 @@ def warpper(example): lr = args.lr weight_decay = args.weight_decay warmup_init_lr = args.warmup_init_lr -warmup_updates = args.warmup_updates +warmup_steps = args.warmup_steps clip_norm = args.clip_norm output_dir = args.output_dir @@ -157,7 +157,7 @@ def warpper(example): **ProphetNetModel.pretrained_init_configuration["prophetnet-large-uncased"]) model = ProphetNetForConditionalGeneration(model) -lr_scheduler = InverseSquareRootSchedule(warmup_init_lr, lr, warmup_updates) +lr_scheduler = InverseSquareRootSchedule(warmup_init_lr, lr, warmup_steps) optimizer = paddle.optimizer.Adam( learning_rate=lr_scheduler, diff --git a/paddlenlp/transformers/prophetnet/tokenizer.py b/paddlenlp/transformers/prophetnet/tokenizer.py index 13e037a9d71e..b3ae379ec133 100644 --- a/paddlenlp/transformers/prophetnet/tokenizer.py +++ b/paddlenlp/transformers/prophetnet/tokenizer.py @@ -15,220 +15,10 @@ import collections import logging import os -from collections import OrderedDict from typing import List from .. import PretrainedTokenizer, BasicTokenizer, WordpieceTokenizer - - -class Trie: - """ - Trie in Python. Creates a Trie out of a list of words. The trie is used to split on `added_tokens` in one pass - Loose reference https://en.wikipedia.org/wiki/Trie - """ - - def __init__(self): - self.data = {} - - def add(self, word: str): - """ - Passes over every char (utf-8 char) on word and recursively adds it to the internal `data` trie representation. - The special key `""` is used to represent termination. - - This function is idempotent, adding twice the same word will leave the trie unchanged - - Example: - - ```python - >>> trie = Trie() - >>> trie.add("Hello 友達") - >>> trie.data - {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}} - >>> trie.add("Hello") - >>> trie.data - {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}} - ``` - """ - if not word: - # Prevent empty string - return - ref = self.data - for char in word: - ref[char] = char in ref and ref[char] or {} - ref = ref[char] - ref[""] = 1 - - def split(self, text: str) -> List[str]: - """ - Will look for the words added to the trie within `text`. Output is the original string splitted along the - boundaries of the words found. - - This trie will match the longest possible word first ! 
- - Example: - - ```python - >>> trie = Trie() - >>> trie.split("[CLS] This is a extra_id_100") - ["[CLS] This is a extra_id_100"] - >>> trie.add("[CLS]") - >>> trie.add("extra_id_1") - >>> trie.add("extra_id_100") - >>> trie.split("[CLS] This is a extra_id_100") - ["[CLS]", " This is a ", "extra_id_100"] - ``` - """ - # indexes are counted left of the chars index. - # "hello", index 0, is left of h, index 1 is between h and e. - # index 5 is right of the "o". - - # States are going to capture every possible start (indexes as above) - # as keys, and have as values, a pointer to the position in the trie - # where we're at. This is a partial match for now. - # This enables to keep track of multiple matches while we're iterating - # the string - # If the trie contains, "blowing", and "lower" and we encounter the - # string "blower", we need to split into ["b", "lower"]. - # This is where we need to keep track of multiple possible starts. - states = OrderedDict() - - # This will contain every indices where we need - # to cut. - # We force to cut at offset 0 and len(text) (added later) - offsets = [0] - - # This is used by the lookahead which needs to skip over - # some text where the full match exceeded the place in the initial - # for loop - skip = None - # Main loop, Giving this algorithm O(n) complexity - for current, current_char in enumerate(text): - if skip and current < skip: - # Prevents the lookahead for matching twice - # like extra_id_100 and id_100 - continue - - # This will track every state - # that stop matching, we need to stop tracking them. - # If we look at "lowball", we're going to match "l" (add it to states), "o", "w", then - # fail on "b", we need to remove 0 from the valid states. - to_remove = set() - # Whenever we found a match, we need to drop everything - # this is a greedy algorithm, it will match on the first found token - reset = False - - # In this case, we already have partial matches (But unfinished) - for start, trie_pointer in states.items(): - if "" in trie_pointer: - # This is a final match, we need to reset and - # store the results in `offsets`. - - # Lookahead to match longest first - # Important in case of extra_id_1 vs extra_id_100 - # Here we are also actively looking for other earlier partial - # matches - # "[CLS]", "L", we need to match CLS even if L is special - for lookstart, looktrie_pointer in states.items(): - if lookstart > start: - # This partial match is later, we can stop looking - break - elif lookstart < start: - # This partial match is earlier, the trie pointer - # was already updated, so index is + 1 - lookahead_index = current + 1 - end = current + 1 - else: - # Here lookstart == start and - # looktrie_pointer == trie_pointer - # It wasn't updated yet so indices are current ones - lookahead_index = current - end = current - next_char = text[ - lookahead_index] if lookahead_index < len( - text) else None - while next_char in looktrie_pointer: - looktrie_pointer = looktrie_pointer[next_char] - lookahead_index += 1 - if "" in looktrie_pointer: - start = lookstart - end = lookahead_index - skip = lookahead_index - - if lookahead_index == len(text): - # End of string - break - next_char = text[lookahead_index] - # End lookahead - - # Storing and resetting - offsets.append(start) - offsets.append(end) - reset = True - break - elif current_char in trie_pointer: - # The current character being looked at has a match within the trie - # update the pointer (it will be stored back into states later). 
- trie_pointer = trie_pointer[current_char] - - # Storing back the new pointer into the states. - # Partial matches got longer by one. - states[start] = trie_pointer - else: - # The new character has not match in the trie, we need - # to stop keeping track of this partial match. - # We can't do it directly within the loop because of how - # python iteration works - to_remove.add(start) - - # Either clearing the full start (we found a real match) - # Or clearing only the partial matches that didn't work. - if reset: - states = {} - else: - for start in to_remove: - del states[start] - - # If this character is a starting character within the trie - # start keeping track of this partial match. - if current_char in self.data: - states[current] = self.data[current_char] - - # We have a cut at the end with states. - for start, trie_pointer in states.items(): - if "" in trie_pointer: - # This is a final match, we need to reset and - # store the results in `offsets`. - end = len(text) - offsets.append(start) - offsets.append(end) - # Longest cut is always the one with lower start so the first - # item so we need to break. - break - - return self.cut_text(text, offsets) - - def cut_text(self, text, offsets): - # We have all the offsets now, we just need to do the actual splitting. - # We need to eventually add the first part of the string and the eventual - # last part. - offsets.append(len(text)) - tokens = [] - start = 0 - for end in offsets: - if start > end: - logging.error( - "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it anyway." - ) - continue - elif start == end: - # This might happen if there's a match at index 0 - # we're also preventing zero-width cuts in case of two - # consecutive matches - continue - tokens.append(text[start:end]) - start = end - - return tokens +from ..tokenizer_utils import Trie def load_vocab(vocab_file): From 7518275f660833f6348579d531ebfc487ff50f24 Mon Sep 17 00:00:00 2001 From: DMH_coco <294270681@qq.com> Date: Fri, 4 Mar 2022 11:15:06 +0800 Subject: [PATCH 7/7] remove evaluate/gigaword/__init__.py --- .../text_summarization/prophetnet/evaluate/gigaword/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 examples/text_summarization/prophetnet/evaluate/gigaword/__init__.py diff --git a/examples/text_summarization/prophetnet/evaluate/gigaword/__init__.py b/examples/text_summarization/prophetnet/evaluate/gigaword/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000
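
A minimal end-to-end usage sketch for the scripts added in this series
(assumes the GLGE archives glge_public.tar and glge_hidden_v1.1.tar.gz and
the pretrained weights ./model_state.pdparams have already been downloaded;
all paths are the defaults used by run_train.sh and run_eval.sh):

    bash uncompress_data.sh                    # unpack GLGE data into ./data
    python uncase_tokenize_data.py --dataset gigaword
    bash run_train.sh gigaword                 # fine-tune; checkpoints in ./ckpt/gigaword
    bash run_eval.sh gigaword                  # generate summaries and report ROUGE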