Skip to content

Commit

Permalink
enhance(sdk): support new typing hint example for transformers 4.36.0+ (
Browse files Browse the repository at this point in the history
  • Loading branch information
tianweidut authored Jan 9, 2024
1 parent a5da552 commit 40f67ae
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 22 deletions.
37 changes: 16 additions & 21 deletions client/starwhale/api/_impl/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,32 +259,27 @@ def convert_field_to_option(field: dataclasses.Field) -> click.Option:
}

# reference from huggingface transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/hf_argparser.py
# only support: Union[xxx, None] or Union[EnumType, str] or [List[EnumType], str] type
# only support Union: NoneType(optional), str(optional) and other types, such as: Optional[int], Union[int], Union[int, str], Union[List[str], str] and Optional[Union[List[str], str]]
origin_type = getattr(field.type, "__origin__", field.type)
if origin_type is t.Union:
if (
str not in field.type.__args__ and type(None) not in field.type.__args__
) or (len(field.type.__args__) != 2):
_args = list(field.type.__args__)
if type(None) in _args:
_args.remove(type(None))

_args_cnt = len(_args)
if (_args_cnt == 2 and str not in _args) or _args_cnt > 2 or _args_cnt == 0:
raise ValueError(
f"{field.type} is not supported."
"Only support Union[xxx, None] or Union[EnumType, str] or [List[EnumType], str] type"
"Only `Union[X, str, NoneType]` (i.e., `Optional[X]`) or `Union[X, str]` is allowed for `Union` because"
" the argument parser only supports one type per argument."
f" Problem encountered in field '{field.name}'."
)

if type(None) in field.type.__args__:
# ignore None type, use another type as the field type
field.type = (
field.type.__args__[0]
if field.type.__args__[1] == type(None)
else field.type.__args__[1]
)
if _args_cnt == 1:
field.type = _args[0]
origin_type = getattr(field.type, "__origin__", field.type)
else:
# ignore str and None type, use another type as the field type
field.type = (
field.type.__args__[0]
if field.type.__args__[1] == str
else field.type.__args__[1]
)
elif _args_cnt == 2:
# filter `str` in Union
field.type = _args[0] if _args[1] == str else _args[1]
origin_type = getattr(field.type, "__origin__", field.type)

if (origin_type is Literal) or (
Expand All @@ -300,7 +295,7 @@ def convert_field_to_option(field: dataclasses.Field) -> click.Option:
kw["default"] = field.default
else:
kw["required"] = True
elif field.type is bool or field.type == t.Optional[bool]:
elif field.type is bool:
kw["is_flag"] = True
kw["type"] = bool
kw["default"] = False if field.default is dataclasses.MISSING else field.default
Expand Down
24 changes: 23 additions & 1 deletion client/tests/sdk/test_argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ class DebugOption(Enum):
TPU_METRICS_DEBUG = "tpu_metrics_debug"


class FSDPOption(Enum):
FSDP = "fsdp"
FSDP2 = "fsdp2"


@dataclasses.dataclass
class ScalarArguments:
no_field = 1
Expand Down Expand Up @@ -62,6 +67,13 @@ class ComposeArguments:
label_names: t.Optional[t.List[str]] = dataclasses.field(
default=None, metadata={"help": "label names"}
)
fsdp: t.Optional[t.Union[t.List[FSDPOption], str]] = dataclasses.field(
default="", metadata={"help": "fsdp"}
)
fsdp2: t.Optional[t.Union[str, t.List[FSDPOption]]] = dataclasses.field(
default="", metadata={"help": "fsdp2"}
)
tf32: t.Optional[bool] = dataclasses.field(default=None, metadata={"help": "tf32"})


class ArgumentTestCase(TestCase):
Expand Down Expand Up @@ -212,10 +224,20 @@ def test_compose_parser(self) -> None:
assert optional_list_obj.multiple
assert optional_list_obj.default is None

fsdp_obj = compose_parser._long_opt["--fsdp"].obj
assert isinstance(fsdp_obj.type, click.types.FuncParamType)
assert fsdp_obj.type.func == FSDPOption

fsdp_obj2 = compose_parser._long_opt["--fsdp2"].obj
assert fsdp_obj2.type.func == fsdp_obj.type.func

tf32_obj = compose_parser._long_opt["--tf32"].obj
assert tf32_obj.type == click.BOOL

argument_ctx = ArgumentContext.get_current_context()
assert len(argument_ctx._options) == 1
options = argument_ctx._options["tests.sdk.test_argument:ComposeArguments"]
assert len(options) == 6
assert len(options) == 9
assert options[0].name == "debug"
argument_ctx.echo_help()

Expand Down

0 comments on commit 40f67ae

Please sign in to comment.