diff --git a/model_zoo/uie/README.md b/model_zoo/uie/README.md index 164e940fa1e0..7dd54b7ded62 100644 --- a/model_zoo/uie/README.md +++ b/model_zoo/uie/README.md @@ -640,7 +640,7 @@ python finetune.py \ --device gpu ``` -多卡启动: +如果在GPU环境中使用,可以指定``gpus``参数进行多卡训练: ```shell python -u -m paddle.distributed.launch --gpus "0,1" finetune.py \ @@ -701,18 +701,24 @@ python evaluate.py \ 输出打印示例: ```text -[2022-06-23 08:25:23,017] [ INFO] - ----------------------------- -[2022-06-23 08:25:23,017] [ INFO] - Class name: 时间 -[2022-06-23 08:25:23,018] [ INFO] - Evaluation precision: 1.00000 | recall: 1.00000 | F1: 1.00000 -[2022-06-23 08:25:23,145] [ INFO] - ----------------------------- -[2022-06-23 08:25:23,146] [ INFO] - Class name: 目的地 -[2022-06-23 08:25:23,146] [ INFO] - Evaluation precision: 0.64286 | recall: 0.90000 | F1: 0.75000 -[2022-06-23 08:25:23,272] [ INFO] - ----------------------------- -[2022-06-23 08:25:23,273] [ INFO] - Class name: 费用 -[2022-06-23 08:25:23,273] [ INFO] - Evaluation precision: 0.11111 | recall: 0.10000 | F1: 0.10526 -[2022-06-23 08:25:23,399] [ INFO] - ----------------------------- -[2022-06-23 08:25:23,399] [ INFO] - Class name: 出发地 -[2022-06-23 08:25:23,400] [ INFO] - Evaluation precision: 1.00000 | recall: 1.00000 | F1: 1.00000 +[2022-09-14 03:13:58,877] [ INFO] - ----------------------------- +[2022-09-14 03:13:58,877] [ INFO] - Class Name: 疾病 +[2022-09-14 03:13:58,877] [ INFO] - Evaluation Precision: 0.89744 | Recall: 0.83333 | F1: 0.86420 +[2022-09-14 03:13:59,145] [ INFO] - ----------------------------- +[2022-09-14 03:13:59,145] [ INFO] - Class Name: 手术治疗 +[2022-09-14 03:13:59,145] [ INFO] - Evaluation Precision: 0.90000 | Recall: 0.85714 | F1: 0.87805 +[2022-09-14 03:13:59,439] [ INFO] - ----------------------------- +[2022-09-14 03:13:59,440] [ INFO] - Class Name: 检查 +[2022-09-14 03:13:59,440] [ INFO] - Evaluation Precision: 0.77778 | Recall: 0.56757 | F1: 0.65625 +[2022-09-14 03:13:59,708] [ INFO] - ----------------------------- +[2022-09-14 03:13:59,709] [ INFO] - Class Name: X的手术治疗 +[2022-09-14 03:13:59,709] [ INFO] - Evaluation Precision: 0.90000 | Recall: 0.85714 | F1: 0.87805 +[2022-09-14 03:13:59,893] [ INFO] - ----------------------------- +[2022-09-14 03:13:59,893] [ INFO] - Class Name: X的实验室检查 +[2022-09-14 03:13:59,894] [ INFO] - Evaluation Precision: 0.71429 | Recall: 0.55556 | F1: 0.62500 +[2022-09-14 03:14:00,057] [ INFO] - ----------------------------- +[2022-09-14 03:14:00,058] [ INFO] - Class Name: X的影像学检查 +[2022-09-14 03:14:00,058] [ INFO] - Evaluation Precision: 0.69231 | Recall: 0.45000 | F1: 0.54545 ``` 可配置参数说明: diff --git a/model_zoo/uie/data_distill/README.md b/model_zoo/uie/data_distill/README.md index 225767b96303..8f0d034e55f3 100644 --- a/model_zoo/uie/data_distill/README.md +++ b/model_zoo/uie/data_distill/README.md @@ -146,13 +146,6 @@ python train.py \ 'text': '登革热'}]}] ``` -## 效果验证 - -| 模型 | Entity-F1 | SPO-F1 | -| :---: | :--------: | :--------: | -| UIE-Finetune | 78.57 | 56.25 | -| GPLinker-ernie-3.0-mini-zh | 68.18 | 47.06 | -| GPLinker-ernie-3.0-mini-zh + UIE数据蒸馏 | 76.38 | 50.42 | # References diff --git a/model_zoo/uie/data_distill/data_distill.py b/model_zoo/uie/data_distill/data_distill.py index 1be16b1f5857..74d0045470f8 100644 --- a/model_zoo/uie/data_distill/data_distill.py +++ b/model_zoo/uie/data_distill/data_distill.py @@ -85,7 +85,7 @@ def do_data_distill(): for text in tqdm(infer_texts, desc="Predicting: ", leave=False): infer_results.extend(uie(text)) - train_synthetic_lines = synthetic2distill(texts, infer_results, + train_synthetic_lines = synthetic2distill(infer_texts, infer_results, args.task_type) # Concat origin and synthetic data diff --git a/model_zoo/uie/evaluate.py b/model_zoo/uie/evaluate.py index 5cffcfa9da0d..61fdd5fb2602 100644 --- a/model_zoo/uie/evaluate.py +++ b/model_zoo/uie/evaluate.py @@ -23,7 +23,7 @@ from paddlenlp.utils.log import logger from model import UIE -from utils import convert_example, reader, unify_prompt_name +from utils import convert_example, reader, unify_prompt_name, get_relation_type_dict, create_data_loader @paddle.no_grad() @@ -60,28 +60,34 @@ def do_eval(): max_seq_len=args.max_seq_len, lazy=False) class_dict = {} + relation_data = [] if args.debug: for data in test_ds: class_name = unify_prompt_name(data['prompt']) # Only positive examples are evaluated in debug mode if len(data['result_list']) != 0: - class_dict.setdefault(class_name, []).append(data) + if "的" not in data['prompt']: + class_dict.setdefault(class_name, []).append(data) + else: + relation_data.append((data['prompt'], data)) + relation_type_dict = get_relation_type_dict(relation_data) else: class_dict["all_classes"] = test_ds + + trans_fn = partial(convert_example, + tokenizer=tokenizer, + max_seq_len=args.max_seq_len) + for key in class_dict.keys(): if args.debug: test_ds = MapDataset(class_dict[key]) else: test_ds = class_dict[key] - test_ds = test_ds.map( - partial(convert_example, - tokenizer=tokenizer, - max_seq_len=args.max_seq_len)) - test_batch_sampler = paddle.io.BatchSampler(dataset=test_ds, - batch_size=args.batch_size, - shuffle=False) - test_data_loader = paddle.io.DataLoader( - dataset=test_ds, batch_sampler=test_batch_sampler, return_list=True) + + test_data_loader = create_data_loader(test_ds, + mode="test", + batch_size=args.batch_size, + trans_fn=trans_fn) metric = SpanEvaluator() precision, recall, f1 = evaluate(model, metric, test_data_loader) @@ -90,6 +96,22 @@ def do_eval(): logger.info("Evaluation Precision: %.5f | Recall: %.5f | F1: %.5f" % (precision, recall, f1)) + if args.debug and len(relation_type_dict.keys()) != 0: + for key in relation_type_dict.keys(): + test_ds = MapDataset(relation_type_dict[key]) + + test_data_loader = create_data_loader(test_ds, + mode="test", + batch_size=args.batch_size, + trans_fn=trans_fn) + + metric = SpanEvaluator() + precision, recall, f1 = evaluate(model, metric, test_data_loader) + logger.info("-----------------------------") + logger.info("Class Name: X的%s" % key) + logger.info("Evaluation Precision: %.5f | Recall: %.5f | F1: %.5f" % + (precision, recall, f1)) + if __name__ == "__main__": # yapf: disable diff --git a/model_zoo/uie/finetune.py b/model_zoo/uie/finetune.py index 73e3fe6d5885..a53b6d6a5c8d 100644 --- a/model_zoo/uie/finetune.py +++ b/model_zoo/uie/finetune.py @@ -26,7 +26,7 @@ from model import UIE from evaluate import evaluate -from utils import set_seed, convert_example, reader, MODEL_MAP +from utils import set_seed, convert_example, reader, MODEL_MAP, create_data_loader def do_train(): @@ -57,28 +57,18 @@ def do_train(): max_seq_len=args.max_seq_len, lazy=False) - train_ds = train_ds.map( - partial(convert_example, - tokenizer=tokenizer, - max_seq_len=args.max_seq_len)) - dev_ds = dev_ds.map( - partial(convert_example, - tokenizer=tokenizer, - max_seq_len=args.max_seq_len)) - - train_batch_sampler = paddle.io.BatchSampler(dataset=train_ds, - batch_size=args.batch_size, - shuffle=True) - train_data_loader = paddle.io.DataLoader(dataset=train_ds, - batch_sampler=train_batch_sampler, - return_list=True) - - dev_batch_sampler = paddle.io.BatchSampler(dataset=dev_ds, - batch_size=args.batch_size, - shuffle=False) - dev_data_loader = paddle.io.DataLoader(dataset=dev_ds, - batch_sampler=dev_batch_sampler, - return_list=True) + trans_fn = partial(convert_example, + tokenizer=tokenizer, + max_seq_len=args.max_seq_len) + + train_data_loader = create_data_loader(train_ds, + mode="train", + batch_size=args.batch_size, + trans_fn=trans_fn) + dev_data_loader = create_data_loader(dev_ds, + mode="dev", + batch_size=args.batch_size, + trans_fn=trans_fn) if args.init_from_ckpt and os.path.isfile(args.init_from_ckpt): state_dict = paddle.load(args.init_from_ckpt) @@ -95,7 +85,6 @@ def do_train(): loss_list = [] global_step = 0 - best_step = 0 best_f1 = 0 tic_train = time.time() for epoch in range(1, args.num_epochs + 1): diff --git a/model_zoo/uie/utils.py b/model_zoo/uie/utils.py index ab220e81ae6f..a157fa1994cb 100644 --- a/model_zoo/uie/utils.py +++ b/model_zoo/uie/utils.py @@ -118,6 +118,35 @@ def set_seed(seed): np.random.seed(seed) +def create_data_loader(dataset, mode="train", batch_size=1, trans_fn=None): + """ + Create dataloader. + Args: + dataset(obj:`paddle.io.Dataset`): Dataset instance. + mode(obj:`str`, optional, defaults to obj:`train`): If mode is 'train', it will shuffle the dataset randomly. + batch_size(obj:`int`, optional, defaults to 1): The sample number of a mini-batch. + trans_fn(obj:`callable`, optional, defaults to `None`): function to convert a data sample to input ids, etc. + Returns: + dataloader(obj:`paddle.io.DataLoader`): The dataloader which generates batches. + """ + if trans_fn: + dataset = dataset.map(trans_fn) + + shuffle = True if mode == 'train' else False + if mode == "train": + sampler = paddle.io.DistributedBatchSampler(dataset=dataset, + batch_size=batch_size, + shuffle=shuffle) + else: + sampler = paddle.io.BatchSampler(dataset=dataset, + batch_size=batch_size, + shuffle=shuffle) + dataloader = paddle.io.DataLoader(dataset, + batch_sampler=sampler, + return_list=True) + return dataloader + + def convert_example(example, tokenizer, max_seq_len): """ example: { @@ -267,6 +296,48 @@ def unify_prompt_name(prompt): return prompt +def get_relation_type_dict(relation_data): + + def compare(a, b): + a = a[::-1] + b = b[::-1] + res = '' + for i in range(min(len(a), len(b))): + if a[i] == b[i]: + res += a[i] + else: + break + if res == "": + return res + elif res[::-1][0] == "的": + return res[::-1][1:] + return "" + + relation_type_dict = {} + added_list = [] + for i in range(len(relation_data)): + added = False + if relation_data[i][0] not in added_list: + for j in range(i + 1, len(relation_data)): + match = compare(relation_data[i][0], relation_data[j][0]) + if match != "": + match = unify_prompt_name(match) + if relation_data[i][0] not in added_list: + added_list.append(relation_data[i][0]) + relation_type_dict.setdefault(match, []).append( + relation_data[i][1]) + added_list.append(relation_data[j][0]) + relation_type_dict.setdefault(match, []).append( + relation_data[j][1]) + added = True + if not added: + added_list.append(relation_data[i][0]) + suffix = relation_data[i][0].rsplit("的", 1)[1] + suffix = unify_prompt_name(suffix) + relation_type_dict[suffix] = relation_data[i][1] + return relation_type_dict + + def add_entity_negative_example(examples, texts, prompts, label_set, negative_ratio): negative_examples = [] @@ -610,26 +681,31 @@ def _sep_cls_label(label, separator): redundants1 = inverse_relation_list[i] # 2. entity_name_set ^ subject_goldens[i] - nonentity_list = list( - set(entity_name_set) ^ set(subject_goldens[i])) - nonentity_list.sort() - - redundants2 = [ - nonentity + "的" + predicate_list[i][random.randrange( - len(predicate_list[i]))] - for nonentity in nonentity_list - ] + redundants2 = [] + if len(predicate_list[i]) != 0: + nonentity_list = list( + set(entity_name_set) ^ set(subject_goldens[i])) + nonentity_list.sort() + + redundants2 = [ + nonentity + "的" + + predicate_list[i][random.randrange( + len(predicate_list[i]))] + for nonentity in nonentity_list + ] # 3. entity_label_set ^ entity_prompts[i] - non_ent_label_list = list( - set(entity_label_set) ^ set(entity_prompts[i])) - non_ent_label_list.sort() - - redundants3 = [ - subject_goldens[i][random.randrange( - len(subject_goldens[i]))] + "的" + non_ent_label - for non_ent_label in non_ent_label_list - ] + redundants3 = [] + if len(subject_goldens[i]) != 0: + non_ent_label_list = list( + set(entity_label_set) ^ set(entity_prompts[i])) + non_ent_label_list.sort() + + redundants3 = [ + subject_goldens[i][random.randrange( + len(subject_goldens[i]))] + "的" + non_ent_label + for non_ent_label in non_ent_label_list + ] redundants_list = [redundants1, redundants2, redundants3]