From 17d066448b8d97d719cc56c2191a6c02fd41ec3c Mon Sep 17 00:00:00 2001 From: Jiaqi Liu <709153940@qq.com> Date: Tue, 4 Apr 2023 17:44:37 +0800 Subject: [PATCH] [PPMiniLM P0] Add PretrainedConfig and unit tests (#5520) * update ppminilm * update copyright * update ppminilm tokenizer unittest * fix format of tinybert * remove useless arg --- .../pp-minilm/finetuning/export_model.py | 16 +- .../pp-minilm/pruning/export_model.py | 66 +--- .../pp-minilm/pruning/prune.py | 43 ++- .../pp-minilm/quantization/quant_post.py | 2 +- model_zoo/tinybert/data_augmentation.py | 16 +- model_zoo/tinybert/general_distill.py | 34 +- model_zoo/tinybert/task_distill.py | 24 +- paddlenlp/transformers/ofa_utils.py | 39 +-- .../transformers/ppminilm/configuration.py | 151 +++++++++ paddlenlp/transformers/ppminilm/modeling.py | 301 +++++------------ paddlenlp/transformers/ppminilm/tokenizer.py | 61 +++- tests/transformers/ppminilm/__init__.py | 13 + tests/transformers/ppminilm/test_modeling.py | 308 ++++++++++++++++++ tests/transformers/ppminilm/test_tokenizer.py | 137 ++++++++ 14 files changed, 840 insertions(+), 371 deletions(-) create mode 100644 paddlenlp/transformers/ppminilm/configuration.py create mode 100644 tests/transformers/ppminilm/__init__.py create mode 100644 tests/transformers/ppminilm/test_modeling.py create mode 100644 tests/transformers/ppminilm/test_tokenizer.py diff --git a/examples/model_compression/pp-minilm/finetuning/export_model.py b/examples/model_compression/pp-minilm/finetuning/export_model.py index 04c4693a5fa8..b7e01ca67bab 100644 --- a/examples/model_compression/pp-minilm/finetuning/export_model.py +++ b/examples/model_compression/pp-minilm/finetuning/export_model.py @@ -15,6 +15,8 @@ import os import sys +import paddle + from paddlenlp.trainer.argparser import strtobool from paddlenlp.transformers import PPMiniLMForSequenceClassification @@ -53,13 +55,15 @@ def parse_args(): def do_export(args): save_path = os.path.join(os.path.dirname(args.model_path), "inference") model = PPMiniLMForSequenceClassification.from_pretrained(args.model_path) - is_text_pair = True args.task_name = args.task_name.lower() - if args.task_name in ("tnews", "iflytek", "cluewsc2020"): - is_text_pair = False - model.to_static( - save_path, use_faster_tokenizer=args.save_inference_model_with_tokenizer, is_text_pair=is_text_pair - ) + + input_spec = [ + paddle.static.InputSpec(shape=[None, None], dtype="int64"), # input_ids + paddle.static.InputSpec(shape=[None, None], dtype="int64"), # token_type_ids + ] + model = paddle.jit.to_static(model, input_spec=input_spec) + + paddle.jit.save(model, save_path) def print_arguments(args): diff --git a/examples/model_compression/pp-minilm/pruning/export_model.py b/examples/model_compression/pp-minilm/pruning/export_model.py index 278630e60130..e2088d1547f5 100644 --- a/examples/model_compression/pp-minilm/pruning/export_model.py +++ b/examples/model_compression/pp-minilm/pruning/export_model.py @@ -19,11 +19,9 @@ import sys import paddle -from paddle.common_ops_import import core from paddleslim.nas.ofa import OFA, utils from paddleslim.nas.ofa.convert_super import Convert, supernet -from paddlenlp.trainer.argparser import strtobool from paddlenlp.transformers import PPMiniLMModel sys.path.append("../") @@ -32,13 +30,7 @@ def ppminilm_forward(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None): wtype = self.pooler.dense.fn.weight.dtype if hasattr(self.pooler.dense, "fn") else self.pooler.dense.weight.dtype - if self.use_faster_tokenizer: - 
input_ids, token_type_ids = self.tokenizer( - text=input_ids, - text_pair=token_type_ids, - max_seq_len=self.max_seq_len, - pad_to_max_seq_len=self.pad_to_max_seq_len, - ) + if attention_mask is None: attention_mask = paddle.unsqueeze((input_ids == self.pad_token_id).astype(wtype) * -1e9, axis=[1, 2]) embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids) @@ -103,12 +95,6 @@ def parse_args(): parser.add_argument("--n_gpu", type=int, default=1, help="number of gpus to use, 0 for cpu.") parser.add_argument("--width_mult", type=float, default=1.0, help="width mult you want to export") parser.add_argument("--depth_mult", type=float, default=1.0, help="depth mult you want to export") - parser.add_argument( - "--use_faster_tokenizer", - type=strtobool, - default=True, - help="Whether to use FasterTokenizer to accelerate training or further inference.", - ) args = parser.parse_args() return args @@ -118,7 +104,7 @@ def do_export(args): args.model_type = args.model_type.lower() args.task_name = args.task_name.lower() model_class, tokenizer_class = MODEL_CLASSES[args.model_type] - config_path = os.path.join(args.model_name_or_path, "model_config.json") + config_path = os.path.join(args.model_name_or_path, "config.json") cfg_dict = dict(json.loads(open(config_path).read())) kept_layers_index = {} @@ -132,12 +118,9 @@ def do_export(args): with open(config_path, "w", encoding="utf-8") as f: f.write(json.dumps(cfg_dict, ensure_ascii=False)) - num_labels = cfg_dict["num_classes"] - - model = model_class.from_pretrained(args.model_name_or_path, num_classes=num_labels) - model.use_faster_tokenizer = args.use_faster_tokenizer + model = model_class.from_pretrained(args.model_name_or_path) - origin_model = model_class.from_pretrained(args.model_name_or_path, num_classes=num_labels) + origin_model = model_class.from_pretrained(args.model_name_or_path) os.rename(config_path + "_bak", config_path) @@ -164,30 +147,12 @@ def do_export(args): if isinstance(sublayer, paddle.nn.MultiHeadAttention): sublayer.num_heads = int(args.width_mult * sublayer.num_heads) - is_text_pair = True - if args.task_name in ("tnews", "iflytek", "cluewsc2020"): - is_text_pair = False - - if args.use_faster_tokenizer: - ofa_model.model.add_faster_tokenizer_op() - if is_text_pair: - origin_model_new = ofa_model.export( - best_config, - input_shapes=[[1], [1]], - input_dtypes=[core.VarDesc.VarType.STRINGS, core.VarDesc.VarType.STRINGS], - origin_model=origin_model, - ) - else: - origin_model_new = ofa_model.export( - best_config, input_shapes=[1], input_dtypes=core.VarDesc.VarType.STRINGS, origin_model=origin_model - ) - else: - origin_model_new = ofa_model.export( - best_config, - input_shapes=[[1, args.max_seq_length], [1, args.max_seq_length]], - input_dtypes=["int64", "int64"], - origin_model=origin_model, - ) + origin_model_new = ofa_model.export( + best_config, + input_shapes=[[1, args.max_seq_length], [1, args.max_seq_length]], + input_dtypes=["int64", "int64"], + origin_model=origin_model, + ) for name, sublayer in origin_model_new.named_sublayers(): if isinstance(sublayer, paddle.nn.MultiHeadAttention): @@ -199,10 +164,13 @@ def do_export(args): model_to_save = origin_model_new model_to_save.save_pretrained(output_dir) - if args.static_sub_model is not None: - origin_model_new.to_static( - args.static_sub_model, use_faster_tokenizer=args.use_faster_tokenizer, is_text_pair=is_text_pair - ) + input_spec = [ + paddle.static.InputSpec(shape=[None, None], dtype="int64"), # 
input_ids + paddle.static.InputSpec(shape=[None, None], dtype="int64"), # token_type_ids + ] + origin_model_new = paddle.jit.to_static(origin_model_new, input_spec=input_spec) + + paddle.jit.save(origin_model_new, args.static_sub_model) def print_arguments(args): diff --git a/examples/model_compression/pp-minilm/pruning/prune.py b/examples/model_compression/pp-minilm/pruning/prune.py index cebf5601a02a..19427a8e4ea5 100644 --- a/examples/model_compression/pp-minilm/pruning/prune.py +++ b/examples/model_compression/pp-minilm/pruning/prune.py @@ -27,16 +27,28 @@ from paddle.io import DataLoader from paddleslim.nas.ofa import OFA, DistillConfig, utils from paddleslim.nas.ofa.convert_super import Convert, supernet -from paddleslim.nas.ofa.utils import nlp_utils from paddlenlp.data import Pad, Stack, Tuple from paddlenlp.datasets import load_dataset from paddlenlp.transformers import LinearDecayWithWarmup, PPMiniLMModel +from paddlenlp.transformers.ofa_utils import ( + compute_neuron_head_importance, + encoder_layer_ofa_forward, + encoder_ofa_forward, + mha_ofa_forward, + prepare_qkv_ofa, + reorder_neuron_head, +) from paddlenlp.utils.log import logger sys.path.append("../") from data import METRIC_CLASSES, MODEL_CLASSES, convert_example # noqa: E402 +paddle.nn.MultiHeadAttention.forward = mha_ofa_forward +paddle.nn.MultiHeadAttention._prepare_qkv = prepare_qkv_ofa +paddle.nn.TransformerEncoder.forward = encoder_ofa_forward +paddle.nn.TransformerEncoderLayer.forward = encoder_layer_ofa_forward + def parse_args(): parser = argparse.ArgumentParser() @@ -125,7 +137,11 @@ def parse_args(): help="The device to select to train the model, is must be cpu/gpu/xpu.", ) parser.add_argument( - "--width_mult_list", nargs="+", type=float, default=[1.0, 5 / 6, 2 / 3, 0.5], help="width mult in compress" + "--width_mult_list", + nargs="+", + type=str, + default=["1.0", "5 / 6", "2 / 3", "0.5"], + help="width mult of compression", ) args = parser.parse_args() return args @@ -161,10 +177,6 @@ def evaluate(model, metric, data_loader, width_mult, student=False): # monkey patch for ppminilm forward to accept [attention_mask, head_mask] as attention_mask def ppminilm_forward(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=[None, None]): - if self.use_faster_tokenizer: - input_ids, token_type_ids = self.tokenizer( - text=input_ids, text_pair=token_type_ids, max_seq_len=self.max_seq_len - ) wtype = self.pooler.dense.fn.weight.dtype if hasattr(self.pooler.dense, "fn") else self.pooler.dense.weight.dtype if attention_mask[0] is None: attention_mask[0] = paddle.unsqueeze((input_ids == self.pad_token_id).astype(wtype) * -1e9, axis=[1, 2]) @@ -178,19 +190,6 @@ def ppminilm_forward(self, input_ids, token_type_ids=None, position_ids=None, at PPMiniLMModel.forward = ppminilm_forward -# reorder weights according head importance and neuron importance -def reorder_neuron_head(model, head_importance, neuron_importance): - # reorder heads and ffn neurons - for layer, current_importance in enumerate(neuron_importance): - # reorder heads - idx = paddle.argsort(head_importance[layer], descending=True) - nlp_utils.reorder_head(model.ppminilm.encoder.layers[layer].self_attn, idx) - # reorder neurons - idx = paddle.argsort(paddle.to_tensor(current_importance), descending=True) - nlp_utils.reorder_neuron(model.ppminilm.encoder.layers[layer].linear1.fn, idx, dim=1) - nlp_utils.reorder_neuron(model.ppminilm.encoder.layers[layer].linear2.fn, idx, dim=0) - - def soft_cross_entropy(inp, target): inp_likelihood = 
F.log_softmax(inp, axis=-1) target_prob = F.softmax(target, axis=-1) @@ -273,8 +272,7 @@ def do_train(args): # Step6: Calculate the importance of neurons and head, # and then reorder them according to the importance. - head_importance, neuron_importance = nlp_utils.compute_neuron_head_importance( - args.task_name, + head_importance, neuron_importance = compute_neuron_head_importance( ofa_model.model, dev_data_loader, loss_fct=criterion, @@ -315,6 +313,7 @@ def do_train(args): global_step = 0 tic_train = time.time() best_res = 0.0 + args.width_mult_list = [eval(width_mult) for width_mult in args.width_mult_list] for epoch in range(num_train_epochs): # Step7: Set current epoch and task. ofa_model.set_epoch(epoch) @@ -322,7 +321,7 @@ def do_train(args): for step, batch in enumerate(train_data_loader): global_step += 1 - input_ids, segment_ids, labels = batch + input_ids, segment_ids, _ = batch for width_mult in args.width_mult_list: # Step8: Broadcast supernet config from width_mult, diff --git a/examples/model_compression/pp-minilm/quantization/quant_post.py b/examples/model_compression/pp-minilm/quantization/quant_post.py index 4882c658ea54..7436f4f9f5aa 100644 --- a/examples/model_compression/pp-minilm/quantization/quant_post.py +++ b/examples/model_compression/pp-minilm/quantization/quant_post.py @@ -70,7 +70,7 @@ parser.add_argument( "--use_faster_tokenizer", type=strtobool, - default=True, + default=False, help="Whether to use FasterTokenizer to accelerate training or further inference.", ) diff --git a/model_zoo/tinybert/data_augmentation.py b/model_zoo/tinybert/data_augmentation.py index 13d0a68f9bc6..3043ce8c79ae 100644 --- a/model_zoo/tinybert/data_augmentation.py +++ b/model_zoo/tinybert/data_augmentation.py @@ -14,19 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import random -import sys +import argparse +import csv +import logging import os -import unicodedata +import random import re -import logging -import csv -import argparse +import unicodedata import numpy as np import paddle -from paddlenlp.transformers import BertTokenizer, BertForPretraining +from paddlenlp.transformers import BertForPretraining, BertTokenizer logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO @@ -241,8 +240,6 @@ def _read_tsv(input_file, quotechar=None): reader = csv.reader(f, delimiter="\t", quotechar=quotechar) lines = [] for line in reader: - if sys.version_info[0] == 2: - line = list(unicode(cell, "utf-8") for cell in line) lines.append(line) return lines @@ -477,7 +474,6 @@ def main(): "RTE": {"N": 30}, } - device = paddle.set_device(args.device) if args.task_name in default_params: args.N = default_params[args.task_name]["N"] diff --git a/model_zoo/tinybert/general_distill.py b/model_zoo/tinybert/general_distill.py index b139890bd361..fe7a0ca129b1 100644 --- a/model_zoo/tinybert/general_distill.py +++ b/model_zoo/tinybert/general_distill.py @@ -15,29 +15,27 @@ import argparse import logging import os -import sys import random import time -import math -from functools import partial from concurrent.futures import ThreadPoolExecutor import numpy as np import paddle from paddle.io import DataLoader -import paddle.nn as nn -import paddle.nn.functional as F -from paddle.metric import Metric, Accuracy, Precision, Recall +from paddle.metric import Accuracy -from paddlenlp.datasets import load_dataset -from paddlenlp.data import Stack, Tuple, Pad, Dict -from paddlenlp.utils.tools import TimeCostAverage -from paddlenlp.data.sampler import SamplerHelper +from paddlenlp.data import Pad, Tuple from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman -from paddlenlp.transformers import LinearDecayWithWarmup -from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer -from paddlenlp.transformers import TinyBertModel, TinyBertForPretraining, TinyBertTokenizer +from paddlenlp.transformers import ( + BertForSequenceClassification, + BertTokenizer, + LinearDecayWithWarmup, + TinyBertForPretraining, + TinyBertModel, + TinyBertTokenizer, +) from paddlenlp.transformers.distill_utils import to_distill +from paddlenlp.utils.tools import TimeCostAverage FORMAT = "%(asctime)s-%(levelname)s: %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) @@ -62,7 +60,6 @@ def parse_args(): parser = argparse.ArgumentParser() - # Required parameters parser.add_argument( "--model_type", @@ -253,7 +250,6 @@ def do_train(args): teacher_model_class, tokenizer_class = MODEL_CLASSES[args.teacher_model_type] teacher = teacher_model_class.from_pretrained(args.teacher_model_name_or_path) tokenizer = tokenizer_class.from_pretrained(args.teacher_model_name_or_path) - pad_token_id = teacher.pretrained_init_configuration[args.teacher_model_name_or_path]["pad_token_id"] if paddle.distributed.get_world_size() > 1: student = paddle.DataParallel(student, find_unused_parameters=True) teacher = paddle.DataParallel(teacher, find_unused_parameters=True) @@ -278,7 +274,6 @@ def do_train(args): grad_clip=clip, ) - ce_loss_fct = paddle.nn.CrossEntropyLoss(soft_label=True) mse_loss_fct = paddle.nn.MSELoss() pool = ThreadPoolExecutor(1) @@ -287,7 +282,6 @@ def do_train(args): student = to_distill(student, return_attentions=True, return_layer_outputs=True) global_step = 0 - tic_train = time.time() for 
epoch in range(args.num_train_epochs): files = [ os.path.join(args.input_dir, f) @@ -316,9 +310,6 @@ def do_train(args): data_file = files[ (f_start_id * paddle.distributed.get_world_size() + paddle.distributed.get_rank()) % num_files ] - - previous_file = data_file - train_data_loader, _ = create_pretraining_dataset(data_file, shared_file_list, args, worker_init, tokenizer) # TODO(guosheng): better way to process single file @@ -348,7 +339,6 @@ def cal_intermediate_distill_loss(student, teacher): data_file = files[ (f_id * paddle.distributed.get_world_size() + paddle.distributed.get_rank()) % num_files ] - previous_file = data_file dataset_future = pool.submit( create_pretraining_dataset, data_file, shared_file_list, args, worker_init, tokenizer ) @@ -359,7 +349,7 @@ def cal_intermediate_distill_loss(student, teacher): for step, batch in enumerate(train_data_loader): global_step += 1 input_ids = batch[0] - attention_mask = paddle.unsqueeze((input_ids == pad_token_id).astype("int64") * -1e9, axis=[1, 2]) + student(input_ids) with paddle.no_grad(): teacher(input_ids) diff --git a/model_zoo/tinybert/task_distill.py b/model_zoo/tinybert/task_distill.py index a765aa3bd0eb..d29c482debe1 100644 --- a/model_zoo/tinybert/task_distill.py +++ b/model_zoo/tinybert/task_distill.py @@ -14,27 +14,28 @@ import argparse import logging +import math import os -import sys import random import time -import math from functools import partial import numpy as np import paddle -from paddle.io import DataLoader -import paddle.nn as nn import paddle.nn.functional as F -from paddle.metric import Metric, Accuracy, Precision, Recall +from paddle.io import DataLoader +from paddle.metric import Accuracy +from paddlenlp.data import Pad, Stack, Tuple from paddlenlp.datasets import load_dataset -from paddlenlp.data import Stack, Tuple, Pad, Dict -from paddlenlp.data.sampler import SamplerHelper from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman -from paddlenlp.transformers import LinearDecayWithWarmup -from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer -from paddlenlp.transformers import TinyBertForSequenceClassification, TinyBertTokenizer +from paddlenlp.transformers import ( + BertForSequenceClassification, + BertTokenizer, + LinearDecayWithWarmup, + TinyBertForSequenceClassification, + TinyBertTokenizer, +) from paddlenlp.transformers.distill_utils import to_distill FORMAT = "%(asctime)s-%(levelname)s: %(message)s" @@ -297,7 +298,7 @@ def do_train(args): dataset=dev_ds, batch_sampler=dev_batch_sampler, collate_fn=batchify_fn, num_workers=0, return_list=True ) - num_classes = 1 if train_ds.label_list == None else len(train_ds.label_list) + num_classes = 1 if train_ds.label_list is None else len(train_ds.label_list) student = model_class.from_pretrained(args.student_model_name_or_path, num_classes=num_classes) teacher_model_class, _ = MODEL_CLASSES[args.teacher_model_type] teacher = teacher_model_class.from_pretrained(args.teacher_path, num_classes=num_classes) @@ -337,7 +338,6 @@ def do_train(args): teacher = to_distill(teacher, return_attentions=True, return_qkv=False, return_layer_outputs=True) student = to_distill(student, return_attentions=True, return_qkv=False, return_layer_outputs=True) - pad_token_id = 0 global_step = 0 tic_train = time.time() best_res = 0.0 diff --git a/paddlenlp/transformers/ofa_utils.py b/paddlenlp/transformers/ofa_utils.py index 347acc69cf91..e0f7bb240258 100644 --- a/paddlenlp/transformers/ofa_utils.py +++ b/paddlenlp/transformers/ofa_utils.py 
@@ -88,7 +88,6 @@ def mha_ofa_forward(self, query, key, value, attn_mask=None, cache=None): outs.append(weights) if cache is not None: outs.append(cache) - if hasattr(self.q_proj, "fn") and self.q_proj.fn.cur_config["expand_ratio"] is not None: self.num_heads = int(float(self.num_heads) / self.q_proj.fn.cur_config["expand_ratio"]) return out if len(outs) == 1 else tuple(outs) @@ -117,10 +116,8 @@ def encoder_ofa_forward( head_mask = paddle.unsqueeze(paddle.unsqueeze(paddle.unsqueeze(head_mask, 1), -1), -1) else: head_mask = [None] * self.num_layers - for i, mod in enumerate(self.layers): output = mod(output, src_mask=[src_mask[0], head_mask[i]]) - if self.norm is not None: output = self.norm(output) @@ -288,22 +285,26 @@ def compute_neuron_head_importance( for i, batch in enumerate(data_loader): labels = None - if label_names is not None: - labels = [] - for label in label_names: - labels.append(batch.pop(label)) - labels = tuple(labels) - elif "labels" in batch: - labels = batch.pop("labels") - # For token cls tasks - for key in ("length", "seq_len"): - if key in batch: - batch.pop(key) - elif "start_positions" in batch and "end_positions" in batch: - labels = (batch.pop("start_positions"), batch.pop("end_positions")) - - batch["attention_mask"] = [None, head_mask] - logits = model(**batch) + if isinstance(batch, list): + input_ids, segment_ids, labels = batch + logits = model(input_ids, segment_ids, attention_mask=[None, head_mask]) + else: + if label_names is not None: + labels = [] + for label in label_names: + labels.append(batch.pop(label)) + labels = tuple(labels) + elif "labels" in batch: + labels = batch.pop("labels") + # For token cls tasks + for key in ("length", "seq_len"): + if key in batch: + batch.pop(key) + elif "start_positions" in batch and "end_positions" in batch: + labels = (batch.pop("start_positions"), batch.pop("end_positions")) + + batch["attention_mask"] = [None, head_mask] + logits = model(**batch) if loss_fct is not None: loss = loss_fct(logits, labels) diff --git a/paddlenlp/transformers/ppminilm/configuration.py b/paddlenlp/transformers/ppminilm/configuration.py new file mode 100644 index 000000000000..70330b10eb7d --- /dev/null +++ b/paddlenlp/transformers/ppminilm/configuration.py @@ -0,0 +1,151 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" PPMiniLM model configuration""" +from __future__ import annotations + +from typing import Dict + +from paddlenlp.transformers.configuration_utils import PretrainedConfig + +__all__ = ["PPMINILM_PRETRAINED_INIT_CONFIGURATION", "PPMiniLMConfig", "PPMINILM_PRETRAINED_RESOURCE_FILES_MAP"] + +PPMINILM_PRETRAINED_INIT_CONFIGURATION = { + "ppminilm-6l-768h": { + "attention_probs_dropout_prob": 0.1, + "intermediate_size": 3072, + "hidden_act": "relu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "max_position_embeddings": 512, + "num_attention_heads": 12, + "num_hidden_layers": 6, + "type_vocab_size": 2, + "vocab_size": 21128, + "pad_token_id": 0, + }, +} + +PPMINILM_PRETRAINED_RESOURCE_FILES_MAP = { + "model_state": { + "ppminilm-6l-768h": "https://bj.bcebos.com/paddlenlp/models/transformers/ppminilm-6l-768h/ppminilm-6l-768h.pdparams", + }, +} + + +class PPMiniLMConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`PPMiniLMModel`]. It is used to + instantiate a PPMiniLM model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the PPMiniLM ppminilm-6l-768h architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 21128): + Vocabulary size of the PPMiniLM model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`PPMiniLMModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`PPMiniLMModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. 
For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. + + Examples: + + ```python + >>> from paddlenlp.transformers import PPMiniLMModel, PPMiniLMConfig + + >>> # Initializing a PPMiniLM ppminilm-6l-768h style configuration + >>> configuration = PPMiniLMConfig() + + >>> # Initializing a model from the ppminilm-6l-768h style configuration + >>> model = PPMiniLMModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "ppminilm" + attribute_map: Dict[str, str] = {"dropout": "classifier_dropout", "num_classes": "num_labels"} + pretrained_init_configuration = PPMINILM_PRETRAINED_INIT_CONFIGURATION + + def __init__( + self, + vocab_size: int = 21128, + hidden_size: int = 768, + num_hidden_layers: int = 12, + num_attention_heads: int = 12, + intermediate_size: int = 3072, + hidden_act: str = "gelu", + hidden_dropout_prob: float = 0.1, + attention_probs_dropout_prob: float = 0.1, + max_position_embeddings: int = 512, + type_vocab_size: int = 2, + initializer_range=0.02, + pad_token_id: int = 0, + do_lower_case: bool = True, + is_split_into_words: bool = False, + max_seq_len: int = 128, + pad_to_max_seq_len: bool = False, + layer_norm_eps: float = 1e-12, + **kwargs + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + self.do_lower_case = do_lower_case + self.max_seq_len = max_seq_len + self.is_split_into_words = is_split_into_words + self.pad_token_id = pad_token_id + self.pad_to_max_seq_len = pad_to_max_seq_len + self.initializer_range = initializer_range + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.layer_norm_eps = layer_norm_eps diff --git a/paddlenlp/transformers/ppminilm/modeling.py b/paddlenlp/transformers/ppminilm/modeling.py index 4bf54bc8a664..77cb4bdb92cd 100644 --- a/paddlenlp/transformers/ppminilm/modeling.py +++ b/paddlenlp/transformers/ppminilm/modeling.py @@ -11,16 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os import paddle import paddle.nn as nn -from paddle.common_ops_import import core -from paddlenlp.experimental import FasterPretrainedModel, FasterTokenizer -from paddlenlp.utils.log import logger - -from .. import register_base_model +from ...utils.env import CONFIG_NAME +from .. import PretrainedModel, register_base_model +from .configuration import ( + PPMINILM_PRETRAINED_INIT_CONFIGURATION, + PPMINILM_PRETRAINED_RESOURCE_FILES_MAP, + PPMiniLMConfig, +) __all__ = [ "PPMiniLMModel", @@ -36,23 +37,14 @@ class PPMiniLMEmbeddings(nn.Layer): Include embeddings from word, position and token_type embeddings. 
""" - def __init__( - self, - vocab_size, - hidden_size=768, - hidden_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=2, - pad_token_id=0, - weight_attr=None, - ): + def __init__(self, config: PPMiniLMConfig): super(PPMiniLMEmbeddings, self).__init__() - self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id, weight_attr=weight_attr) - self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size, weight_attr=weight_attr) - self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size, weight_attr=weight_attr) - self.layer_norm = nn.LayerNorm(hidden_size) - self.dropout = nn.Dropout(hidden_dropout_prob) + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, input_ids, token_type_ids=None, position_ids=None): if position_ids is None: @@ -67,7 +59,6 @@ def forward(self, input_ids, token_type_ids=None, position_ids=None): input_embedings = self.word_embeddings(input_ids) position_embeddings = self.position_embeddings(position_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = input_embedings + position_embeddings + token_type_embeddings embeddings = self.layer_norm(embeddings) embeddings = self.dropout(embeddings) @@ -75,9 +66,9 @@ def forward(self, input_ids, token_type_ids=None, position_ids=None): class PPMiniLMPooler(nn.Layer): - def __init__(self, hidden_size, weight_attr=None): + def __init__(self, config: PPMiniLMConfig): super(PPMiniLMPooler, self).__init__() - self.dense = nn.Linear(hidden_size, hidden_size, weight_attr=weight_attr) + self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() def forward(self, hidden_states): @@ -89,7 +80,7 @@ def forward(self, hidden_states): return pooled_output -class PPMiniLMPretrainedModel(FasterPretrainedModel): +class PPMiniLMPretrainedModel(PretrainedModel): r""" An abstract class for pretrained PPMiniLM models. It provides PPMiniLM related `model_config_file`, `pretrained_init_configuration`, `resource_files_names`, @@ -98,34 +89,13 @@ class PPMiniLMPretrainedModel(FasterPretrainedModel): Refer to :class:`~paddlenlp.transformers.model_utils.PretrainedModel` for more details. 
""" - - pretrained_init_configuration = { - "ppminilm-6l-768h": { - "attention_probs_dropout_prob": 0.1, - "intermediate_size": 3072, - "hidden_act": "relu", - "hidden_dropout_prob": 0.1, - "hidden_size": 768, - "initializer_range": 0.02, - "max_position_embeddings": 512, - "num_attention_heads": 12, - "num_hidden_layers": 6, - "type_vocab_size": 4, - "vocab_size": 21128, - "pad_token_id": 0, - }, - } - resource_files_names = {"model_state": "model_state.pdparams", "vocab_file": "vocab.txt"} - pretrained_resource_files_map = { - "model_state": { - "ppminilm-6l-768h": "https://bj.bcebos.com/paddlenlp/models/transformers/ppminilm-6l-768h/ppminilm-6l-768h.pdparams", - }, - "vocab_file": { - "ppminilm-6l-768h": "https://bj.bcebos.com/paddlenlp/models/transformers/ppminilm-6l-768h/vocab.txt", - }, - } + model_config_file = CONFIG_NAME + config_class = PPMiniLMConfig + resource_files_names = {"model_state": "model_state.pdparams"} base_model_prefix = "ppminilm" - use_faster_tokenizer = False + + pretrained_init_configuration = PPMINILM_PRETRAINED_INIT_CONFIGURATION + pretrained_resource_files_map = PPMINILM_PRETRAINED_RESOURCE_FILES_MAP def init_weights(self, layer): """Initialization hook""" @@ -136,53 +106,12 @@ def init_weights(self, layer): layer.weight.set_value( paddle.tensor.normal( mean=0.0, - std=self.initializer_range - if hasattr(self, "initializer_range") - else self.ppminilm.config["initializer_range"], + std=self.config.initializer_range, shape=layer.weight.shape, ) ) elif isinstance(layer, nn.LayerNorm): - layer._epsilon = 1e-12 - - def add_faster_tokenizer_op(self): - self.ppminilm.tokenizer = FasterTokenizer( - self.ppminilm.vocab, - do_lower_case=self.ppminilm.do_lower_case, - is_split_into_words=self.ppminilm.is_split_into_words, - ) - - def to_static(self, output_path, use_faster_tokenizer=True, is_text_pair=False): - self.eval() - self.use_faster_tokenizer = use_faster_tokenizer - # Convert to static graph with specific input description - if self.use_faster_tokenizer: - self.add_faster_tokenizer_op() - if is_text_pair: - model = paddle.jit.to_static( - self, - input_spec=[ - paddle.static.InputSpec(shape=[None], dtype=core.VarDesc.VarType.STRINGS, name="text"), - paddle.static.InputSpec(shape=[None], dtype=core.VarDesc.VarType.STRINGS, name="text_pair"), - ], - ) - else: - model = paddle.jit.to_static( - self, - input_spec=[ - paddle.static.InputSpec(shape=[None], dtype=core.VarDesc.VarType.STRINGS, name="text") - ], - ) - else: - model = paddle.jit.to_static( - self, - input_spec=[ - paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids"), # input_ids - paddle.static.InputSpec(shape=[None, None], dtype="int64", name="token_type_ids"), # segment_ids - ], - ) - paddle.jit.save(model, output_path) - logger.info("Already save the static model to the path %s" % output_path) + layer._epsilon = self.config.layer_norm_eps @register_base_model @@ -198,123 +127,42 @@ class PPMiniLMModel(PPMiniLMPretrainedModel): and refer to the Paddle documentation for all matter related to general usage and behavior. Args: - vocab_size (int): - Vocabulary size of `inputs_ids` in `PPMiniLMModel`. Also is the vocab size of token embedding matrix. - Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling `PPMiniLMModel`. - hidden_size (int, optional): - Dimensionality of the embedding layer, encoder layers and pooler layer. Defaults to `768`. - num_hidden_layers (int, optional): - Number of hidden layers in the Transformer encoder. 
Defaults to `12`. - num_attention_heads (int, optional): - Number of attention heads for each attention layer in the Transformer encoder. - Defaults to `12`. - intermediate_size (int, optional): - Dimensionality of the feed-forward (ff) layer in the encoder. Input tensors - to ff layers are firstly projected from `hidden_size` to `intermediate_size`, - and then projected back to `hidden_size`. Typically `intermediate_size` is larger than `hidden_size`. - Defaults to `3072`. - hidden_act (str, optional): - The non-linear activation function in the feed-forward layer. - ``"gelu"``, ``"relu"`` and any other paddle supported activation functions - are supported. Defaults to `"gelu"`. - hidden_dropout_prob (float, optional): - The dropout probability for all fully connected layers in the embeddings and encoder. - Defaults to `0.1`. - attention_probs_dropout_prob (float, optional): - The dropout probability used in MultiHeadAttention in all encoder layers to drop some attention target. - Defaults to `0.1`. - max_position_embeddings (int, optional): - The maximum value of the dimensionality of position encoding, which dictates the maximum supported length of an input - sequence. Defaults to `512`. - type_vocab_size (int, optional): - The vocabulary size of the `token_type_ids`. - Defaults to `2`. - initializer_range (float, optional): - The standard deviation of the normal initializer for initializing all weight matrices. - Defaults to `0.02`. - - .. note:: - A normal_initializer initializes weight matrices as normal distributions. - See :meth:`PPMiniLMPretrainedModel._init_weights()` for how weights are initialized in `PPMiniLMModel`. - - pad_token_id(int, optional): - The index of padding token in the token vocabulary. - Defaults to `0`. + config (:class:`PPMiniLMConfig`): + An instance of PPMiniLMConfig used to construct PPMiniLMModel. """ - def __init__( - self, - vocab_size, - vocab_file, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=2, - initializer_range=0.02, - pad_token_id=0, - do_lower_case=True, - is_split_into_words=False, - max_seq_len=128, - pad_to_max_seq_len=False, - ): - super(PPMiniLMModel, self).__init__() - if not os.path.isfile(vocab_file): - raise ValueError( - "Can't find a vocabulary file at path '{}'. 
To load the " - "vocabulary from a pretrained model please use " - "`model = PPMiniLMModel.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file) - ) - self.vocab = self.load_vocabulary(vocab_file) - self.do_lower_case = do_lower_case - self.max_seq_len = max_seq_len - self.is_split_into_words = is_split_into_words - self.pad_token_id = pad_token_id - self.pad_to_max_seq_len = pad_to_max_seq_len - self.initializer_range = initializer_range - weight_attr = paddle.ParamAttr( - initializer=nn.initializer.TruncatedNormal(mean=0.0, std=self.initializer_range) - ) - self.embeddings = PPMiniLMEmbeddings( - vocab_size, - hidden_size, - hidden_dropout_prob, - max_position_embeddings, - type_vocab_size, - pad_token_id, - weight_attr, - ) + def __init__(self, config: PPMiniLMConfig): + super(PPMiniLMModel, self).__init__(config) + self.embeddings = PPMiniLMEmbeddings(config) + encoder_layer = nn.TransformerEncoderLayer( - hidden_size, - num_attention_heads, - intermediate_size, - dropout=hidden_dropout_prob, - activation=hidden_act, - attn_dropout=attention_probs_dropout_prob, - act_dropout=0, - weight_attr=weight_attr, - normalize_before=False, + config.hidden_size, + config.num_attention_heads, + config.intermediate_size, + dropout=config.hidden_dropout_prob, + activation=config.hidden_act, + attn_dropout=config.attention_probs_dropout_prob, + act_dropout=0.0, ) - self.encoder = nn.TransformerEncoder(encoder_layer, num_hidden_layers) - self.pooler = PPMiniLMPooler(hidden_size, weight_attr) + self.encoder = nn.TransformerEncoder(encoder_layer, config.num_hidden_layers) + self.pooler = PPMiniLMPooler(config) self.apply(self.init_weights) + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + def forward(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None): r""" Args: - input_ids (Tensor, List[string]): + input_ids (Tensor): If `input_ids` is a Tensor object, it is an indices of input sequence tokens in the vocabulary. They are numerical representations of tokens that build the input sequence. It's data type should be `int64` and has a shape of [batch_size, sequence_length]. - If `input_ids` is a list of string, `self.use_faster_tokenizer` - should be True, and the network contains `faster_tokenizer` - operator. token_type_ids (Tensor, string, optional): If `token_type_ids` is a Tensor object: Segment token indices to indicate different portions of the inputs. @@ -328,9 +176,6 @@ def forward(self, input_ids, token_type_ids=None, position_ids=None, attention_m Its data type should be `int64` and it has a shape of [batch_size, sequence_length]. Defaults to `None`, which means we don't add segment embeddings. - If `token_type_ids` is a list of string: `self.use_faster_tokenizer` - should be True, and the network contains `faster_tokenizer` operator. - position_ids (Tensor, optional): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, max_position_embeddings - 1]``. 
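The hunks above swap the old keyword-argument constructor for a config-driven one: `PPMiniLMModel.__init__` now takes a single `PPMiniLMConfig`, mirroring the other refactored architectures. A minimal sketch of how the new construction path can be exercised; the tiny hyper-parameters below are illustrative placeholders, not the pretrained `ppminilm-6l-768h` values:

```python
# Illustrative sketch of the config-driven API added in this patch; the small
# hyper-parameters are placeholders, not the real ppminilm-6l-768h settings.
import paddle

from paddlenlp.transformers import PPMiniLMModel
from paddlenlp.transformers.ppminilm.configuration import PPMiniLMConfig

config = PPMiniLMConfig(
    vocab_size=1000,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=256,
)
model = PPMiniLMModel(config)
model.eval()

input_ids = paddle.randint(low=0, high=config.vocab_size, shape=[2, 16], dtype="int64")
token_type_ids = paddle.zeros_like(input_ids)
with paddle.no_grad():
    sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)

print(sequence_output.shape)  # [2, 16, 64]
print(pooled_output.shape)    # [2, 64]
```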
@@ -377,18 +222,16 @@ def forward(self, input_ids, token_type_ids=None, position_ids=None, attention_m sequence_output, pooled_output = model(**inputs) """ - # Only for saving - if self.use_faster_tokenizer: - input_ids, token_type_ids = self.tokenizer( - text=input_ids, - text_pair=token_type_ids, - max_seq_len=self.max_seq_len, - pad_to_max_seq_len=self.pad_to_max_seq_len, - ) if attention_mask is None: attention_mask = paddle.unsqueeze( (input_ids == self.pad_token_id).astype(self.pooler.dense.weight.dtype) * -1e4, axis=[1, 2] ) + else: + if attention_mask.ndim == 2: + # attention_mask [batch_size, sequence_length] -> [batch_size, 1, 1, sequence_length] + attention_mask = attention_mask.unsqueeze(axis=[1, 2]).astype(paddle.get_default_dtype()) + attention_mask = (1.0 - attention_mask) * -1e4 + embedding_output = self.embeddings( input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids ) @@ -415,12 +258,14 @@ class PPMiniLMForSequenceClassification(PPMiniLMPretrainedModel): of `paddlenlp.transformers.PPMiniLMModel` instance. Defaults to `None`. """ - def __init__(self, ppminilm, num_classes=2, dropout=None): - super(PPMiniLMForSequenceClassification, self).__init__() - self.num_classes = num_classes - self.ppminilm = ppminilm # allow ppminilm to be config - self.dropout = nn.Dropout(dropout if dropout is not None else self.ppminilm.config["hidden_dropout_prob"]) - self.classifier = nn.Linear(self.ppminilm.config["hidden_size"], num_classes) + def __init__(self, config: PPMiniLMConfig): + super(PPMiniLMForSequenceClassification, self).__init__(config) + self.ppminilm = PPMiniLMModel(config) + self.num_labels = config.num_labels + self.dropout = nn.Dropout( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.apply(self.init_weights) def forward(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None): @@ -453,7 +298,6 @@ def forward(self, input_ids, token_type_ids=None, position_ids=None, attention_m logits = model(**inputs) """ - self.ppminilm.use_faster_tokenizer = self.use_faster_tokenizer _, pooled_output = self.ppminilm( input_ids, token_type_ids=token_type_ids, position_ids=position_ids, attention_mask=attention_mask ) @@ -474,10 +318,13 @@ class PPMiniLMForQuestionAnswering(PPMiniLMPretrainedModel): An instance of `PPMiniLMModel`. """ - def __init__(self, ppminilm): - super(PPMiniLMForQuestionAnswering, self).__init__() - self.ppminilm = ppminilm # allow ppminilm to be config - self.classifier = nn.Linear(self.ppminilm.config["hidden_size"], 2) + def __init__(self, config: PPMiniLMConfig): + super(PPMiniLMForQuestionAnswering, self).__init__(config) + self.ppminilm = PPMiniLMModel(config) + self.dropout = nn.Dropout( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.classifier = nn.Linear(config.hidden_size, 2) self.apply(self.init_weights) def forward(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None): @@ -547,12 +394,14 @@ class PPMiniLMForMultipleChoice(PPMiniLMPretrainedModel): instance `ppminilm`. Defaults to None. 
""" - def __init__(self, ppminilm, num_choices=2, dropout=None): - super(PPMiniLMForMultipleChoice, self).__init__() - self.num_choices = num_choices - self.ppminilm = ppminilm - self.dropout = nn.Dropout(dropout if dropout is not None else self.ppminilm.config["hidden_dropout_prob"]) - self.classifier = nn.Linear(self.ppminilm.config["hidden_size"], 1) + def __init__(self, config: PPMiniLMConfig): + super(PPMiniLMForMultipleChoice, self).__init__(config) + self.num_choices = config.num_choices + self.ppminilm = PPMiniLMModel(config) + self.dropout = nn.Dropout( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.classifier = nn.Linear(config.hidden_size, 1) self.apply(self.init_weights) def forward(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None): diff --git a/paddlenlp/transformers/ppminilm/tokenizer.py b/paddlenlp/transformers/ppminilm/tokenizer.py index 274f43738168..8309cc64423b 100644 --- a/paddlenlp/transformers/ppminilm/tokenizer.py +++ b/paddlenlp/transformers/ppminilm/tokenizer.py @@ -85,11 +85,15 @@ def __init__( self, vocab_file, do_lower_case=True, + do_basic_tokenize=True, + never_split=None, unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]", mask_token="[MASK]", + tokenize_chinese_chars=True, + strip_accents=None, **kwargs ): @@ -101,7 +105,14 @@ def __init__( ) self.do_lower_case = do_lower_case self.vocab = self.load_vocabulary(vocab_file, unk_token=unk_token) - self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer( + do_lower_case=do_lower_case, + never_split=never_split, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + ) self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=unk_token) @property @@ -114,6 +125,9 @@ def vocab_size(self): """ return len(self.vocab) + def get_vocab(self): + return dict(self.vocab.token_to_idx, **self.added_tokens_encoder) + def _tokenize(self, text): r""" End-to-end tokenization for PPMiniM models. @@ -125,9 +139,15 @@ def _tokenize(self, text): List[str]: A list of string representing converted tokens. """ split_tokens = [] - for token in self.basic_tokenizer.tokenize(text): - for sub_token in self.wordpiece_tokenizer.tokenize(token): - split_tokens.append(sub_token) + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): + # If the token is part of the never_split set + if token in self.basic_tokenizer.never_split: + split_tokens.append(token) + else: + split_tokens += self.wordpiece_tokenizer.tokenize(token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) return split_tokens def convert_tokens_to_string(self, tokens): @@ -253,3 +273,36 @@ def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): if token_ids_1 is None: return len(_cls + token_ids_0 + _sep) * [0] return len(_cls + token_ids_0 + _sep) * [0] + len(token_ids_1 + _sep) * [1] + + def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False): + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``encode`` methods. + + Args: + token_ids_0 (List[int]): + A list of `inputs_ids` for the first sequence. 
+ token_ids_1 (List[int], optional): + Optional second list of IDs for sequence pairs. Defaults to None. + already_has_special_tokens (bool, optional): Whether or not the token list is already + formatted with special tokens for the model. Defaults to None. + + Returns: + List[int]: The list of integers either be 0 or 1: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formatted with special tokens for the model." + ) + return list(map(lambda x: 1 if x in self.all_special_ids else 0, token_ids_0)) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.vocab._idx_to_token.get(index, self.unk_token) diff --git a/tests/transformers/ppminilm/__init__.py b/tests/transformers/ppminilm/__init__.py new file mode 100644 index 000000000000..595add0aed9e --- /dev/null +++ b/tests/transformers/ppminilm/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/transformers/ppminilm/test_modeling.py b/tests/transformers/ppminilm/test_modeling.py new file mode 100644 index 000000000000..8413627f7597 --- /dev/null +++ b/tests/transformers/ppminilm/test_modeling.py @@ -0,0 +1,308 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import unittest + +import paddle + +from paddlenlp.transformers import ( + PPMiniLMForMultipleChoice, + PPMiniLMForQuestionAnswering, + PPMiniLMForSequenceClassification, + PPMiniLMModel, +) +from paddlenlp.transformers.ppminilm.configuration import PPMiniLMConfig + +from ...testing_utils import slow +from ..test_configuration_common import ConfigTester +from ..test_modeling_common import ( + ModelTesterMixin, + ModelTesterPretrainedMixin, + ids_tensor, + random_attention_mask, +) + + +class PPMiniLMModelTester: + def __init__( + self, + parent: PPMiniLMModelTest, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + initializer_range=0.02, + pad_token_id=0, + pool_act="tanh", + num_labels=3, + num_choices=4, + scope=None, + dropout=0.56, + ): + self.parent: PPMiniLMModelTest = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.pad_token_id = pad_token_id + self.pool_act = pool_act + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + self.dropout = dropout + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + config = self.get_config() + return config, input_ids, token_type_ids, input_mask + + def get_config(self) -> PPMiniLMConfig: + return PPMiniLMConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + initializer_range=self.initializer_range, + pad_token_id=self.pad_token_id, + pool_act=self.pool_act, + num_labels=self.num_labels, + num_choices=self.num_choices, + ) + + def create_and_check_model( + self, + config: PPMiniLMConfig, + input_ids, + token_type_ids, + input_mask, + ): + model = PPMiniLMModel(config) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + result = model(input_ids, token_type_ids=token_type_ids) + result = model(input_ids) + self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.hidden_size]) 
+ self.parent.assertEqual(result[1].shape, [self.batch_size, self.hidden_size]) + + def create_and_check_for_multiple_choice( + self, + config: PPMiniLMConfig, + input_ids, + token_type_ids, + input_mask, + ): + model = PPMiniLMForMultipleChoice(config) + model.eval() + multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand([-1, self.num_choices, -1]) + multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand([-1, self.num_choices, -1]) + multiple_choice_input_mask = input_mask.unsqueeze(1).expand([-1, self.num_choices, -1]) + result = model( + multiple_choice_inputs_ids, + attention_mask=multiple_choice_input_mask, + token_type_ids=multiple_choice_token_type_ids, + ) + self.parent.assertEqual(result.shape, [self.batch_size, self.num_choices]) + + def create_and_check_for_question_answering( + self, + config, + input_ids, + token_type_ids, + input_mask, + ): + model = PPMiniLMForQuestionAnswering(config) + model.eval() + result = model( + input_ids, + token_type_ids=token_type_ids, + attention_mask=input_mask, + ) + + self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length]) + self.parent.assertEqual(result[1].shape, [self.batch_size, self.seq_length]) + + def create_and_check_for_sequence_classification( + self, + config: PPMiniLMConfig, + input_ids, + token_type_ids, + input_mask, + ): + model = PPMiniLMForSequenceClassification(config) + model.eval() + result = model(input_ids, token_type_ids=token_type_ids) + self.parent.assertEqual(result.shape, [self.batch_size, self.num_labels]) + + def test_addition_params(self, config: PPMiniLMConfig, *args, **kwargs): + config.num_labels = 7 + config.classifier_dropout = 0.98 + + model = PPMiniLMForSequenceClassification(config) + model.eval() + + self.parent.assertEqual(model.classifier.weight.shape, [config.hidden_size, 7]) + self.parent.assertEqual(model.dropout.p, 0.98) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask} + return config, inputs_dict + + +class PPMiniLMModelTest(ModelTesterMixin, unittest.TestCase): + base_model_class = PPMiniLMModel + + all_model_classes = ( + PPMiniLMModel, + PPMiniLMForMultipleChoice, + PPMiniLMForQuestionAnswering, + PPMiniLMForSequenceClassification, + ) + + def setUp(self): + super().setUp() + + self.model_tester = PPMiniLMModelTester(self) + self.config_tester = ConfigTester(self, config_class=PPMiniLMConfig, vocab_size=256, hidden_size=24) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_for_multiple_choice(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs) + + def test_for_question_answering(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_question_answering(*config_and_inputs) + + def test_for_sequence_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs) + + def test_for_custom_params(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + 
+        self.model_tester.test_addition_params(*config_and_inputs)
+
+    def test_model_name_list(self):
+        config = self.model_tester.get_config()
+        model = self.base_model_class(config)
+        self.assertTrue(len(model.model_name_list) != 0)
+
+    @slow
+    def test_params_compatibility_of_init_method(self):
+        """Test initializing the model with different params."""
+        model: PPMiniLMForSequenceClassification = PPMiniLMForSequenceClassification.from_pretrained(
+            "ppminilm-6l-768h", num_labels=4, dropout=0.3
+        )
+        assert model.num_labels == 4
+        assert model.dropout.p == 0.3
+
+
+class PPMiniLMModelIntegrationTest(ModelTesterPretrainedMixin, unittest.TestCase):
+    base_model_class = PPMiniLMModel
+
+    @slow
+    def test_inference_no_attention(self):
+        model = PPMiniLMModel.from_pretrained("ppminilm-6l-768h")
+        model.eval()
+        input_ids = paddle.to_tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
+        with paddle.no_grad():
+            output = model(input_ids)[0]
+        expected_shape = [1, 11, 768]
+        self.assertEqual(output.shape, expected_shape)
+        expected_slice = paddle.to_tensor(
+            [
+                [
+                    [-0.79207015, 0.40036711, 1.18436682],
+                    [-0.85833853, 0.34584877, 0.93867993],
+                    [-0.97080499, 0.33460250, 0.69212830],
+                ]
+            ]
+        )
+        self.assertTrue(paddle.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))
+
+    @slow
+    def test_inference_with_attention(self):
+        model = PPMiniLMModel.from_pretrained("ppminilm-6l-768h")
+        model.eval()
+        input_ids = paddle.to_tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
+        attention_mask = paddle.to_tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
+        with paddle.no_grad():
+            output = model(input_ids, attention_mask=attention_mask)[0]
+        expected_shape = [1, 11, 768]
+        self.assertEqual(output.shape, expected_shape)
+        expected_slice = paddle.to_tensor(
+            [
+                [
+                    [-0.79207015, 0.40036711, 1.18436682],
+                    [-0.85833853, 0.34584877, 0.93867993],
+                    [-0.97080499, 0.33460250, 0.69212830],
+                ]
+            ]
+        )
+        self.assertTrue(paddle.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/transformers/ppminilm/test_tokenizer.py b/tests/transformers/ppminilm/test_tokenizer.py
new file mode 100644
index 000000000000..70ba9400bcab
--- /dev/null
+++ b/tests/transformers/ppminilm/test_tokenizer.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import os +import unittest + +from paddlenlp.transformers.ppminilm.tokenizer import PPMiniLMTokenizer + +from ...testing_utils import slow +from ...transformers.test_tokenizer_common import ( + TokenizerTesterMixin, + filter_non_english, +) + + +class PPMiniLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase): + + tokenizer_class = PPMiniLMTokenizer + space_between_special_tokens = True + from_pretrained_filter = filter_non_english + test_seq2seq = False + + def setUp(self): + super().setUp() + + vocab_tokens = [ + "[UNK]", + "[CLS]", + "[SEP]", + "[PAD]", + "[MASK]", + "want", + "##want", + "##ed", + "wa", + "un", + "runn", + "##ing", + ",", + "low", + "lowest", + ] + + self.vocab_file = os.path.join(self.tmpdirname, PPMiniLMTokenizer.resource_files_names["vocab_file"]) + with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer: + vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) + + def get_input_output_texts(self, tokenizer): + input_text = "UNwant\u00E9d,running" + output_text = "unwanted, running" + return input_text, output_text + + def test_full_tokenizer(self): + tokenizer = self.tokenizer_class(self.vocab_file) + + tokens = tokenizer.tokenize("UNwant\u00E9d,running") + self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) + self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [9, 6, 7, 12, 10, 11]) + + @slow + def test_sequence_builders(self): + tokenizer = self.tokenizer_class.from_pretrained("ppminilm-6l-768h") + + text = tokenizer.encode("sequence builders", return_token_type_ids=None, add_special_tokens=False)["input_ids"] + text_2 = tokenizer.encode("multi-sequence build", return_token_type_ids=None, add_special_tokens=False)[ + "input_ids" + ] + + encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) + encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) + + assert encoded_sentence == [101] + text + [102] + assert encoded_pair == [101] + text + [102] + text_2 + [102] + + def test_offsets_with_special_characters(self): + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + tokenizer = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + sentence = f"A, naïve {tokenizer.mask_token} AllenNLP sentence." 
+ tokens = tokenizer.encode( + sentence, + return_attention_mask=False, + return_token_type_ids=False, + return_offsets_mapping=True, + add_special_tokens=True, + ) + + expected_results = [ + ((0, 0), tokenizer.cls_token), + ((0, 1), "a"), + ((1, 2), ","), + ((3, 5), "na"), + ((5, 8), "##ive"), + ((9, 15), tokenizer.mask_token), + ((16, 21), "allen"), + ((21, 22), "##n"), + ((22, 24), "##lp"), + ((25, 27), "se"), + ((27, 29), "##nt"), + ((29, 33), "##ence"), + ((33, 34), "."), + ((0, 0), tokenizer.sep_token), + ] + + self.assertEqual( + [e[1] for e in expected_results], tokenizer.convert_ids_to_tokens(tokens["input_ids"]) + ) + self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"]) + + def test_change_tokenize_chinese_chars(self): + list_of_commun_chinese_char = ["的", "人", "有"] + text_with_chinese_char = "".join(list_of_commun_chinese_char) + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + + kwargs["tokenize_chinese_chars"] = True + tokenizer = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + + ids_without_spe_char_p = tokenizer.encode( + text_with_chinese_char, return_token_type_ids=None, add_special_tokens=False + )["input_ids"] + tokens_without_spe_char_p = tokenizer.convert_ids_to_tokens(ids_without_spe_char_p) + + # it is expected that each Chinese character is not preceded by "##" + self.assertListEqual(tokens_without_spe_char_p, list_of_commun_chinese_char)
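
Illustrative usage (not part of the patch above): a minimal sketch of the config-driven construction that the new unit tests exercise, assuming the default fields shipped in the new configuration.py. The module paths match the files added in this PR; the tiny sizes, label count, and random input ids are illustrative assumptions, not values taken from the patch.

    import paddle

    from paddlenlp.transformers.ppminilm.configuration import PPMiniLMConfig
    from paddlenlp.transformers.ppminilm.modeling import PPMiniLMForSequenceClassification

    # Build a deliberately small config, in the spirit of PPMiniLMModelTester;
    # fields that are not passed fall back to the defaults in configuration.py.
    config = PPMiniLMConfig(vocab_size=256, hidden_size=24)
    config.num_labels = 3  # assumed label count for this sketch

    model = PPMiniLMForSequenceClassification(config)
    model.eval()

    # Random token ids standing in for tokenized text, shape [batch_size, seq_len].
    input_ids = paddle.randint(low=5, high=config.vocab_size, shape=[2, 16], dtype="int64")
    with paddle.no_grad():
        logits = model(input_ids)  # tensor of shape [2, 3]: one score per label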