From 40f67ae259f8aab3150a94aa2782997f32281ce7 Mon Sep 17 00:00:00 2001 From: tianwei Date: Tue, 9 Jan 2024 10:56:26 +0800 Subject: [PATCH] enhance(sdk): support new typing hint example for transformers 4.36.0+ (#3116) --- client/starwhale/api/_impl/argument.py | 37 +++++++++++--------------- client/tests/sdk/test_argument.py | 24 ++++++++++++++++- 2 files changed, 39 insertions(+), 22 deletions(-) diff --git a/client/starwhale/api/_impl/argument.py b/client/starwhale/api/_impl/argument.py index 525c134096..6e8d0fd1ee 100644 --- a/client/starwhale/api/_impl/argument.py +++ b/client/starwhale/api/_impl/argument.py @@ -259,32 +259,27 @@ def convert_field_to_option(field: dataclasses.Field) -> click.Option: } # reference from huggingface transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/hf_argparser.py - # only support: Union[xxx, None] or Union[EnumType, str] or [List[EnumType], str] type + # only support Union: NoneType(optional), str(optional) and other types, such as: Optional[int], Union[int], Union[int, str], Union[List[str], str] and Optional[Union[List[str], str]] origin_type = getattr(field.type, "__origin__", field.type) if origin_type is t.Union: - if ( - str not in field.type.__args__ and type(None) not in field.type.__args__ - ) or (len(field.type.__args__) != 2): + _args = list(field.type.__args__) + if type(None) in _args: + _args.remove(type(None)) + + _args_cnt = len(_args) + if (_args_cnt == 2 and str not in _args) or _args_cnt > 2 or _args_cnt == 0: raise ValueError( - f"{field.type} is not supported." - "Only support Union[xxx, None] or Union[EnumType, str] or [List[EnumType], str] type" + "Only `Union[X, str, NoneType]` (i.e., `Optional[X]`) or `Union[X, str]` is allowed for `Union` because" + " the argument parser only supports one type per argument." + f" Problem encountered in field '{field.name}'." ) - if type(None) in field.type.__args__: - # ignore None type, use another type as the field type - field.type = ( - field.type.__args__[0] - if field.type.__args__[1] == type(None) - else field.type.__args__[1] - ) + if _args_cnt == 1: + field.type = _args[0] origin_type = getattr(field.type, "__origin__", field.type) - else: - # ignore str and None type, use another type as the field type - field.type = ( - field.type.__args__[0] - if field.type.__args__[1] == str - else field.type.__args__[1] - ) + elif _args_cnt == 2: + # filter `str` in Union + field.type = _args[0] if _args[1] == str else _args[1] origin_type = getattr(field.type, "__origin__", field.type) if (origin_type is Literal) or ( @@ -300,7 +295,7 @@ def convert_field_to_option(field: dataclasses.Field) -> click.Option: kw["default"] = field.default else: kw["required"] = True - elif field.type is bool or field.type == t.Optional[bool]: + elif field.type is bool: kw["is_flag"] = True kw["type"] = bool kw["default"] = False if field.default is dataclasses.MISSING else field.default diff --git a/client/tests/sdk/test_argument.py b/client/tests/sdk/test_argument.py index 91799d30cd..78acb8f28b 100644 --- a/client/tests/sdk/test_argument.py +++ b/client/tests/sdk/test_argument.py @@ -28,6 +28,11 @@ class DebugOption(Enum): TPU_METRICS_DEBUG = "tpu_metrics_debug" +class FSDPOption(Enum): + FSDP = "fsdp" + FSDP2 = "fsdp2" + + @dataclasses.dataclass class ScalarArguments: no_field = 1 @@ -62,6 +67,13 @@ class ComposeArguments: label_names: t.Optional[t.List[str]] = dataclasses.field( default=None, metadata={"help": "label names"} ) + fsdp: t.Optional[t.Union[t.List[FSDPOption], str]] = dataclasses.field( + default="", metadata={"help": "fsdp"} + ) + fsdp2: t.Optional[t.Union[str, t.List[FSDPOption]]] = dataclasses.field( + default="", metadata={"help": "fsdp2"} + ) + tf32: t.Optional[bool] = dataclasses.field(default=None, metadata={"help": "tf32"}) class ArgumentTestCase(TestCase): @@ -212,10 +224,20 @@ def test_compose_parser(self) -> None: assert optional_list_obj.multiple assert optional_list_obj.default is None + fsdp_obj = compose_parser._long_opt["--fsdp"].obj + assert isinstance(fsdp_obj.type, click.types.FuncParamType) + assert fsdp_obj.type.func == FSDPOption + + fsdp_obj2 = compose_parser._long_opt["--fsdp2"].obj + assert fsdp_obj2.type.func == fsdp_obj.type.func + + tf32_obj = compose_parser._long_opt["--tf32"].obj + assert tf32_obj.type == click.BOOL + argument_ctx = ArgumentContext.get_current_context() assert len(argument_ctx._options) == 1 options = argument_ctx._options["tests.sdk.test_argument:ComposeArguments"] - assert len(options) == 6 + assert len(options) == 9 assert options[0].name == "debug" argument_ctx.echo_help()