Skip to content

Commit

Permalink
[BugFix] Fix the bool parameter inputting problem (#4460)
Browse files Browse the repository at this point in the history
* Fix the bool parameter inputting problem
  • Loading branch information
1649759610 authored Jan 13, 2023
1 parent a650918 commit 4bec43f
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 8 deletions.
12 changes: 11 additions & 1 deletion applications/information_extraction/label_studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ def set_seed(seed):
np.random.seed(seed)


def str2bool(v):
    """Convert a command-line value into a bool, for use as an argparse ``type``.

    ``type=bool`` does not work with argparse because ``bool("False")`` is
    True; this converter maps the usual textual spellings explicitly.

    Args:
        v: The raw value. Matching is case-insensitive and ignores surrounding
            whitespace. A value that is already a ``bool`` is returned as is.

    Returns:
        True for "yes"/"true"/"t"/"y"/"1", False for "no"/"false"/"f"/"n"/"0".

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised spelling.
    """
    # Pass real booleans through so the function is safe to call directly
    # (not only via argparse, which always hands over strings).
    if isinstance(v, bool):
        return v
    normalized = str(v).strip().lower()
    if normalized in ("yes", "true", "t", "y", "1"):
        return True
    if normalized in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Unsupported value encountered.")


def do_convert():
set_seed(args.seed)

Expand Down Expand Up @@ -125,7 +135,7 @@ def _save_examples(save_dir, file_name, examples):
parser.add_argument("--task_type", choices=['ext', 'cls'], default="ext", type=str, help="Select task type, ext for the extraction task and cls for the classification task, defaults to ext.")
# Defaults for the sentiment classification task: the options are
# "正向" (positive) / "负向" (negative) and the prompt prefix is "情感倾向" (sentiment).
parser.add_argument("--options", default=["正向", "负向"], type=str, nargs="+", help="Used only for the classification task, the options for classification")
parser.add_argument("--prompt_prefix", default="情感倾向", type=str, help="Used only for the classification task, the prompt prefix for classification")
parser.add_argument("--is_shuffle", default=True, type=bool, help="Whether to shuffle the labeled dataset, defaults to True.")
parser.add_argument("--is_shuffle", default="True", type=str2bool, help="Whether to shuffle the labeled dataset, defaults to True.")
# type=bool would treat any non-empty string (including "False") as True,
# so parse the flag with str2bool like --is_shuffle.
parser.add_argument("--layout_analysis", default=False, type=str2bool, help="Enable layout analysis to optimize the order of OCR result.")
parser.add_argument("--seed", type=int, default=1000, help="Random seed for initialization")
parser.add_argument("--separator", type=str, default='##', help="Used only for entity/aspect-level classification task, separator for entity label and classification label")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import numpy as np
import paddle
from utils import load_txt
from utils import load_txt, str2bool

from paddlenlp.utils.log import logger

Expand Down Expand Up @@ -727,7 +727,7 @@ def _save_examples(save_dir, file_name, examples):
# argparse %-formats help strings, so literal percent signs must be written as "%%".
parser.add_argument("--splits", default=[0.8, 0.1, 0.1], type=float, nargs="*", help="The ratio of samples in datasets. [0.6, 0.2, 0.2] means 60%% samples used for training, 20%% for evaluation and 20%% for test.")
parser.add_argument("--task_type", choices=['ext', 'cls'], default="ext", type=str, help="Two task types [ext, cls] are supported, ext represents the aspect-based extraction task and cls represents the sentence-level classification task, defaults to ext.")
parser.add_argument("--options", type=str, nargs="+", help="Used only for the classification task, the options for classification")
parser.add_argument("--is_shuffle", default=True, type=bool, help="Whether to shuffle the labeled dataset, defaults to True.")
parser.add_argument("--is_shuffle", type=str2bool, default="True", help="Whether to shuffle the labeled dataset, defaults to True.")
parser.add_argument("--seed", type=int, default=1000, help="Random seed for initialization")

args = parser.parse_args()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import random
import re
Expand Down Expand Up @@ -50,6 +51,16 @@ def write_json_file(examples, save_path):
f.write(line + "\n")


def str2bool(v):
    """Convert a command-line value into a bool, for use as an argparse ``type``.

    ``type=bool`` does not work with argparse because ``bool("False")`` is
    True; this converter maps the usual textual spellings explicitly.

    Args:
        v: The raw value. Matching is case-insensitive and ignores surrounding
            whitespace. A value that is already a ``bool`` is returned as is.

    Returns:
        True for "yes"/"true"/"t"/"y"/"1", False for "no"/"false"/"f"/"n"/"0".

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised spelling.
    """
    # Pass real booleans through so the function is safe to call directly
    # (not only via argparse, which always hands over strings).
    if isinstance(v, bool):
        return v
    normalized = str(v).strip().lower()
    if normalized in ("yes", "true", "t", "y", "1"):
        return True
    if normalized in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Unsupported value encountered.")


def create_data_loader(dataset, mode="train", batch_size=1, trans_fn=None):
"""
Create dataloader.
Expand Down
13 changes: 8 additions & 5 deletions model_zoo/uie/doccano.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
import argparse
import json
import os
import time
from decimal import Decimal

import numpy as np
from paddlenlp.utils.log import logger
from utils import convert_cls_examples, convert_ext_examples, set_seed, str2bool

from utils import set_seed, convert_ext_examples, convert_cls_examples
from paddlenlp.utils.log import logger


def do_convert():
Expand Down Expand Up @@ -100,6 +101,8 @@ def _save_examples(save_dir, file_name, examples):
indexes = np.random.permutation(len(raw_examples))
index_list = indexes.tolist()
raw_examples = [raw_examples[i] for i in indexes]
else:
index_list = list(range(len(raw_examples)))

i1, i2, _ = args.splits
p1 = int(len(raw_examples) * i1)
Expand Down Expand Up @@ -164,7 +167,7 @@ def _save_examples(save_dir, file_name, examples):
parser.add_argument("--task_type", choices=['ext', 'cls'], default="ext", type=str, help="Select task type, ext for the extraction task and cls for the classification task, defaults to ext.")
# Defaults for the sentiment classification task: the options are
# "正向" (positive) / "负向" (negative) and the prompt prefix is "情感倾向" (sentiment).
parser.add_argument("--options", default=["正向", "负向"], type=str, nargs="+", help="Used only for the classification task, the options for classification")
parser.add_argument("--prompt_prefix", default="情感倾向", type=str, help="Used only for the classification task, the prompt prefix for classification")
parser.add_argument("--is_shuffle", default=True, type=bool, help="Whether to shuffle the labeled dataset, defaults to True.")
parser.add_argument("--is_shuffle", default="True", type=str2bool, help="Whether to shuffle the labeled dataset, defaults to True.")
parser.add_argument("--seed", type=int, default=1000, help="Random seed for initialization")
parser.add_argument("--separator", type=str, default='##', help="Used only for entity/aspect-level classification task, separator for entity label and classification label")
parser.add_argument("--schema_lang", choices=["ch", "en"], default="ch", help="Select the language type for schema.")
Expand Down
11 changes: 11 additions & 0 deletions model_zoo/uie/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import math
import random
Expand All @@ -30,6 +31,16 @@ def set_seed(seed):
np.random.seed(seed)


def str2bool(v):
    """Convert a command-line value into a bool, for use as an argparse ``type``.

    ``type=bool`` does not work with argparse because ``bool("False")`` is
    True; this converter maps the usual textual spellings explicitly.

    Args:
        v: The raw value. Matching is case-insensitive and ignores surrounding
            whitespace. A value that is already a ``bool`` is returned as is.

    Returns:
        True for "yes"/"true"/"t"/"y"/"1", False for "no"/"false"/"f"/"n"/"0".

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised spelling.
    """
    # Pass real booleans through so the function is safe to call directly
    # (not only via argparse, which always hands over strings).
    if isinstance(v, bool):
        return v
    normalized = str(v).strip().lower()
    if normalized in ("yes", "true", "t", "y", "1"):
        return True
    if normalized in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Unsupported value encountered.")


def create_data_loader(dataset, mode="train", batch_size=1, trans_fn=None):
"""
Create dataloader.
Expand Down

0 comments on commit 4bec43f

Please sign in to comment.