From 4bec43f0dc5e0e3436789acfb3d00f808de56697 Mon Sep 17 00:00:00 2001 From: Milen <35913314+1649759610@users.noreply.github.com> Date: Fri, 13 Jan 2023 10:38:02 +0800 Subject: [PATCH] [BigFix] Fix the bool parameter inputting problem (#4460) * Fix the bool parameter inputting problem --- applications/information_extraction/label_studio.py | 12 +++++++++++- .../unified_sentiment_extraction/label_studio.py | 4 ++-- .../unified_sentiment_extraction/utils.py | 11 +++++++++++ model_zoo/uie/doccano.py | 13 ++++++++----- model_zoo/uie/utils.py | 11 +++++++++++ 5 files changed, 43 insertions(+), 8 deletions(-) diff --git a/applications/information_extraction/label_studio.py b/applications/information_extraction/label_studio.py index 3231176b7513..6cf157705e17 100644 --- a/applications/information_extraction/label_studio.py +++ b/applications/information_extraction/label_studio.py @@ -33,6 +33,16 @@ def set_seed(seed): np.random.seed(seed) +def str2bool(v): + """Support bool type for argparse.""" + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Unsupported value encountered.") + + def do_convert(): set_seed(args.seed) @@ -125,7 +135,7 @@ def _save_examples(save_dir, file_name, examples): parser.add_argument("--task_type", choices=['ext', 'cls'], default="ext", type=str, help="Select task type, ext for the extraction task and cls for the classification task, defaults to ext.") parser.add_argument("--options", default=["正向", "负向"], type=str, nargs="+", help="Used only for the classification task, the options for classification") parser.add_argument("--prompt_prefix", default="情感倾向", type=str, help="Used only for the classification task, the prompt prefix for classification") - parser.add_argument("--is_shuffle", default=True, type=bool, help="Whether to shuffle the labeled dataset, defaults to True.") + parser.add_argument("--is_shuffle", default="True", type=str2bool, help="Whether to shuffle the labeled dataset, defaults to True.") parser.add_argument("--layout_analysis", default=False, type=bool, help="Enable layout analysis to optimize the order of OCR result.") parser.add_argument("--seed", type=int, default=1000, help="Random seed for initialization") parser.add_argument("--separator", type=str, default='##', help="Used only for entity/aspect-level classification task, separator for entity label and classification label") diff --git a/applications/sentiment_analysis/unified_sentiment_extraction/label_studio.py b/applications/sentiment_analysis/unified_sentiment_extraction/label_studio.py index a25192568303..dcd52e879c32 100644 --- a/applications/sentiment_analysis/unified_sentiment_extraction/label_studio.py +++ b/applications/sentiment_analysis/unified_sentiment_extraction/label_studio.py @@ -23,7 +23,7 @@ import numpy as np import paddle -from utils import load_txt +from utils import load_txt, str2bool from paddlenlp.utils.log import logger @@ -727,7 +727,7 @@ def _save_examples(save_dir, file_name, examples): parser.add_argument("--splits", default=[0.8, 0.1, 0.1], type=float, nargs="*", help="The ratio of samples in datasets. [0.6, 0.2, 0.2] means 60% samples used for training, 20% for evaluation and 20% for test.") parser.add_argument("--task_type", choices=['ext', 'cls'], default="ext", type=str, help="Two task types [ext, cls] are supported, ext represents the aspect-based extraction task and cls represents the sentence-level classification task, defaults to ext.") parser.add_argument("--options", type=str, nargs="+", help="Used only for the classification task, the options for classification") - parser.add_argument("--is_shuffle", default=True, type=bool, help="Whether to shuffle the labeled dataset, defaults to True.") + parser.add_argument("--is_shuffle", type=str2bool, default="True", help="Whether to shuffle the labeled dataset, defaults to True.") parser.add_argument("--seed", type=int, default=1000, help="Random seed for initialization") args = parser.parse_args() diff --git a/applications/sentiment_analysis/unified_sentiment_extraction/utils.py b/applications/sentiment_analysis/unified_sentiment_extraction/utils.py index eecc1d7ba00c..3b52a525fb2a 100644 --- a/applications/sentiment_analysis/unified_sentiment_extraction/utils.py +++ b/applications/sentiment_analysis/unified_sentiment_extraction/utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import json import random import re @@ -50,6 +51,16 @@ def write_json_file(examples, save_path): f.write(line + "\n") +def str2bool(v): + """Support bool type for argparse.""" + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Unsupported value encountered.") + + def create_data_loader(dataset, mode="train", batch_size=1, trans_fn=None): """ Create dataloader. diff --git a/model_zoo/uie/doccano.py b/model_zoo/uie/doccano.py index 274006fc9dc5..3cacd9a5d68f 100644 --- a/model_zoo/uie/doccano.py +++ b/model_zoo/uie/doccano.py @@ -13,15 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import time import argparse import json +import os +import time from decimal import Decimal + import numpy as np -from paddlenlp.utils.log import logger +from utils import convert_cls_examples, convert_ext_examples, set_seed, str2bool -from utils import set_seed, convert_ext_examples, convert_cls_examples +from paddlenlp.utils.log import logger def do_convert(): @@ -100,6 +101,8 @@ def _save_examples(save_dir, file_name, examples): indexes = np.random.permutation(len(raw_examples)) index_list = indexes.tolist() raw_examples = [raw_examples[i] for i in indexes] + else: + index_list = list(range(len(raw_examples))) i1, i2, _ = args.splits p1 = int(len(raw_examples) * i1) @@ -164,7 +167,7 @@ def _save_examples(save_dir, file_name, examples): parser.add_argument("--task_type", choices=['ext', 'cls'], default="ext", type=str, help="Select task type, ext for the extraction task and cls for the classification task, defaults to ext.") parser.add_argument("--options", default=["正向", "负向"], type=str, nargs="+", help="Used only for the classification task, the options for classification") parser.add_argument("--prompt_prefix", default="情感倾向", type=str, help="Used only for the classification task, the prompt prefix for classification") - parser.add_argument("--is_shuffle", default=True, type=bool, help="Whether to shuffle the labeled dataset, defaults to True.") + parser.add_argument("--is_shuffle", default="True", type=str2bool, help="Whether to shuffle the labeled dataset, defaults to True.") parser.add_argument("--seed", type=int, default=1000, help="Random seed for initialization") parser.add_argument("--separator", type=str, default='##', help="Used only for entity/aspect-level classification task, separator for entity label and classification label") parser.add_argument("--schema_lang", choices=["ch", "en"], default="ch", help="Select the language type for schema.") diff --git a/model_zoo/uie/utils.py b/model_zoo/uie/utils.py index 67d5598f209b..4eb7194ddc1d 100644 --- a/model_zoo/uie/utils.py +++ b/model_zoo/uie/utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import json import math import random @@ -30,6 +31,16 @@ def set_seed(seed): np.random.seed(seed) +def str2bool(v): + """Support bool type for argparse.""" + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Unsupported value encountered.") + + def create_data_loader(dataset, mode="train", batch_size=1, trans_fn=None): """ Create dataloader.