Skip to content

Commit

Permalink
Reorder some of the methods of ArgumentParser (#202)
Browse files Browse the repository at this point in the history
* Reorder some of the methods of `ArgumentParser`

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

* Move `_add_arguments` and remove unused arguments

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>
  • Loading branch information
lebrice authored Jan 17, 2023
1 parent 99a336d commit 97ebedb
Showing 1 changed file with 118 additions and 125 deletions.
243 changes: 118 additions & 125 deletions simple_parsing/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def add_arguments(
Parameters
----------
dataclass : Union[Dataclass, Type[Dataclass]]
dataclass : Union[Type[Dataclass], Dataclass]
The dataclass whose fields are to be parsed from the command-line.
If an instance of a dataclass is given, it is used as the default
value if none is provided.
Expand All @@ -215,7 +215,8 @@ def add_arguments(
default None
dataclass_wrapper_class : Type[DataclassWrapper], optional
The type of `DataclassWrapper` to use for this dataclass. This can be used to customize
how the arguments are generated.
how the arguments are generated. However, I'd suggest making a GitHub issue if you find
yourself using this often.
Returns
-------
Expand All @@ -232,63 +233,6 @@ def add_arguments(
self._wrappers.append(new_wrapper)
return new_wrapper

def _add_arguments(
    self,
    dataclass: type[DataclassT] | DataclassT,
    name: str,
    *,
    prefix: str = "",
    dataclass_fn: Callable[..., DataclassT] | None = None,
    default: Dataclass | None = None,
    dataclass_wrapper_class: type[DataclassWrapperType] = DataclassWrapper,
    parent: DataclassWrapper | None = None,
    _field: dataclasses.Field | None = None,
    field_wrapper_class: type[FieldWrapper] = FieldWrapper,
) -> DataclassWrapper[DataclassT] | DataclassWrapperType:
    """Build the wrapper object for `dataclass` and seed it with known defaults.

    This method only creates and returns the wrapper; it does not append it to
    `self._wrappers` itself.

    Parameters
    ----------
    dataclass : type[DataclassT] | DataclassT
        The dataclass type, or an instance of it. When an instance is given, its
        type is used as the dataclass and, if `default` is None, the instance
        itself becomes the default.
    name : str
        Destination name; compared against the `dest` of the existing wrappers.
    prefix : str, optional
        Forwarded as-is to the wrapper constructor.
    dataclass_fn : Callable[..., DataclassT], optional
        Forwarded to the wrapper constructor. Falls back to the dataclass type.
    default : Dataclass, optional
        Default value forwarded to the wrapper constructor.
    dataclass_wrapper_class : type[DataclassWrapperType], optional
        The class that is instantiated to create the wrapper.
    parent : DataclassWrapper, optional
        Forwarded as-is to the wrapper constructor.
    _field : dataclasses.Field, optional
        Forwarded as-is to the wrapper constructor.
    field_wrapper_class : type[FieldWrapper], optional
        Forwarded as-is to the wrapper constructor.

    Returns
    -------
    DataclassWrapper[DataclassT] | DataclassWrapperType
        The newly created wrapper.

    Raises
    ------
    argparse.ArgumentError
        If an existing wrapper already has `name` as its `dest` and holds the
        same `dataclass`.
    """
    # NOTE(review): this check runs before an instance argument is converted to
    # its type below, so it compares against `dataclass` exactly as passed in.
    # Confirm whether instances are meant to be caught here as well.
    for wrapper in self._wrappers:
        if wrapper.dest == name and wrapper.dataclass == dataclass:
            raise argparse.ArgumentError(
                argument=None,
                message=f"Destination attribute {name} is already used for "
                f"dataclass of type {dataclass}. Make sure all destinations"
                f" are unique. (new dataclass type: {dataclass})",
            )

    if not isinstance(dataclass, type):
        # An instance was passed: use it as the default and work with its type.
        if default is None:
            default = dataclass
        dataclass = type(dataclass)

    dataclass_fn = dataclass_fn or dataclass

    new_wrapper = dataclass_wrapper_class(
        dataclass=dataclass,
        name=name,
        prefix=prefix,
        default=default,
        parent=parent,
        _field=_field,
        dataclass_fn=dataclass_fn,
        field_wrapper_class=field_wrapper_class,
    )

    if new_wrapper.dest in self._defaults:
        new_wrapper.set_default(self._defaults[new_wrapper.dest])
    if self.nested_mode == NestedMode.WITHOUT_ROOT and all(
        field.name in self._defaults for field in new_wrapper.fields
    ):
        # If we did .set_defaults before we knew what dataclass we're using, then we try to
        # still make use of those defaults.
        # The field names are collected once, rather than once per entry of `self._defaults`.
        field_names = {f.name for f in dataclasses.fields(new_wrapper.dataclass)}
        new_wrapper.set_default(
            {k: v for k, v in self._defaults.items() if k in field_names}
        )

    return new_wrapper

def parse_known_args(
self,
args: Sequence[str] | None = None,
Expand Down Expand Up @@ -346,7 +290,6 @@ def parse_known_args(
self._preprocessing(args=args, namespace=namespace)

logger.debug(f"Parser {id(self)} is parsing args: {args}, namespace: {namespace}")

parsed_args, unparsed_args = super().parse_known_args(args, namespace)

if unparsed_args and self._subparsers and attempt_to_reorder:
Expand All @@ -363,6 +306,22 @@ def parse_known_args(
parsed_args = self._postprocessing(parsed_args)
return parsed_args, unparsed_args

def add_argument_group(
    self,
    title: str | None = None,
    description: str | None = None,
    prefix_chars=None,
    argument_default=None,
    conflict_handler=None,
) -> argparse._ArgumentGroup:
    """Create an argument group, filling unset options with this parser's own values.

    Parameters
    ----------
    title : str, optional
        Forwarded to `super().add_argument_group`.
    description : str, optional
        Forwarded to `super().add_argument_group`.
    prefix_chars : optional
        Falls back to `self.prefix_chars` when not given (or empty).
    argument_default : optional
        Falls back to `self.argument_default` only when it is None, so that an
        explicit falsy default such as `0`, `False` or `""` is kept.
    conflict_handler : optional
        Falls back to `self.conflict_handler` when not given (or empty).

    Returns
    -------
    argparse._ArgumentGroup
        Whatever `super().add_argument_group` returns.
    """
    # `or` would silently replace a legitimate falsy default (0, False, "") with
    # the parser-level one, so only None is treated as "not provided" here.
    if argument_default is None:
        argument_default = self.argument_default
    return super().add_argument_group(
        title=title,
        description=description,
        prefix_chars=prefix_chars or self.prefix_chars,
        argument_default=argument_default,
        conflict_handler=conflict_handler or self.conflict_handler,
    )

def print_help(self, file=None, args: Sequence[str] | None = None):
    """Run the preprocessing step, then delegate to the parent `print_help`.

    Parameters
    ----------
    file : optional
        Forwarded to `super().print_help`.
    args : Sequence[str], optional
        Arguments handed to `self._preprocessing` as a list. An empty list is
        used when `args` is None or empty.
    """
    cli_args = [] if not args else list(args)
    self._preprocessing(args=cli_args)
    return super().print_help(file)
Expand Down Expand Up @@ -442,8 +401,60 @@ def equivalent_argparse_code(self, args: Sequence[str] | None = None) -> str:
code += "print(args)\n"
return code

def _resolve_conflicts(self) -> None:
    """Replace `self._wrappers` with the conflict resolver's resolved-and-flattened result."""
    resolver = self._conflict_resolver
    current_wrappers = self._wrappers
    self._wrappers = resolver.resolve_and_flatten(current_wrappers)
def _add_arguments(
    self,
    dataclass: type[DataclassT] | DataclassT,
    name: str,
    *,
    prefix: str = "",
    dataclass_fn: Callable[..., DataclassT] | None = None,
    default: Dataclass | None = None,
    dataclass_wrapper_class: type[DataclassWrapperType] = DataclassWrapper,
    parent: DataclassWrapper | None = None,
) -> DataclassWrapper[DataclassT] | DataclassWrapperType:
    """Build the wrapper object for `dataclass` and seed it with known defaults.

    This method only creates and returns the wrapper; it does not append it to
    `self._wrappers` itself.

    Parameters
    ----------
    dataclass : type[DataclassT] | DataclassT
        The dataclass type, or an instance of it. When an instance is given, its
        type is used as the dataclass and, if `default` is None, the instance
        itself becomes the default.
    name : str
        Destination name; compared against the `dest` of the existing wrappers.
    prefix : str, optional
        Forwarded as-is to the wrapper constructor.
    dataclass_fn : Callable[..., DataclassT], optional
        Forwarded to the wrapper constructor. Falls back to the dataclass type.
    default : Dataclass, optional
        Default value forwarded to the wrapper constructor.
    dataclass_wrapper_class : type[DataclassWrapperType], optional
        The class that is instantiated to create the wrapper.
    parent : DataclassWrapper, optional
        Forwarded as-is to the wrapper constructor.

    Returns
    -------
    DataclassWrapper[DataclassT] | DataclassWrapperType
        The newly created wrapper.

    Raises
    ------
    argparse.ArgumentError
        If an existing wrapper already has `name` as its `dest` and holds the
        same `dataclass`.
    """
    # NOTE(review): this check runs before an instance argument is converted to
    # its type below, so it compares against `dataclass` exactly as passed in.
    # Confirm whether instances are meant to be caught here as well.
    for wrapper in self._wrappers:
        if wrapper.dest == name and wrapper.dataclass == dataclass:
            raise argparse.ArgumentError(
                argument=None,
                message=f"Destination attribute {name} is already used for "
                f"dataclass of type {dataclass}. Make sure all destinations"
                f" are unique. (new dataclass type: {dataclass})",
            )

    if not isinstance(dataclass, type):
        # An instance was passed: use it as the default and work with its type.
        if default is None:
            default = dataclass
        dataclass = type(dataclass)

    dataclass_fn = dataclass_fn or dataclass

    # Create this object that holds the dataclass we will create arguments for and the
    # arguments that were passed.
    new_wrapper = dataclass_wrapper_class(
        dataclass=dataclass,
        name=name,
        prefix=prefix,
        default=default,
        parent=parent,
        dataclass_fn=dataclass_fn,
    )

    if new_wrapper.dest in self._defaults:
        new_wrapper.set_default(self._defaults[new_wrapper.dest])
    if self.nested_mode == NestedMode.WITHOUT_ROOT and all(
        field.name in self._defaults for field in new_wrapper.fields
    ):
        # If we did .set_defaults before we knew what dataclass we're using, then we try to
        # still make use of those defaults.
        # The field names are collected once, rather than once per entry of `self._defaults`.
        field_names = {f.name for f in dataclasses.fields(new_wrapper.dataclass)}
        new_wrapper.set_default(
            {k: v for k, v in self._defaults.items() if k in field_names}
        )

    return new_wrapper

def _preprocessing(self, args: Sequence[str] = (), namespace: Namespace | None = None) -> None:
"""Resolve potential conflicts, resolve subgroups, and add all the arguments."""
Expand All @@ -458,8 +469,6 @@ def _preprocessing(self, args: Sequence[str] = (), namespace: Namespace | None =
# Fix the potential conflicts between dataclass fields with the same names.
wrapped_dataclasses = self._conflict_resolver.resolve_and_flatten(wrapped_dataclasses)

# TODO: We're using the `self._conflict_resolver` inside `self._resolve_subgroups`, but
# what if the resolver is the ALWAYS_MERGE one? Would that cause issues if used repeatedly?
wrapped_dataclasses, chosen_subgroups = self._resolve_subgroups(
wrappers=wrapped_dataclasses, args=args, namespace=namespace
)
Expand All @@ -480,6 +489,49 @@ def _preprocessing(self, args: Sequence[str] = (), namespace: Namespace | None =
# Save this so we don't re-add all the arguments.
self._preprocessing_done = True

def _postprocessing(self, parsed_args: Namespace) -> Namespace:
    """Turn the raw parsed namespace into one that holds the dataclass instances.

    The subgroup entries are first removed from `parsed_args`. Then an (initially
    empty) dict of constructor arguments is prepared for every destination of every
    flattened wrapper, those dicts are filled from the parsed fields, and finally the
    dataclasses are instantiated.

    Parameters
    ----------
    parsed_args : Namespace
        the result of calling `super().parse_args(...)` or
        `super().parse_known_args(...)`.

    Returns
    -------
    Namespace
        The original Namespace, with all the arguments corresponding to the
        dataclass fields removed, and with the added dataclass instances.
        Also keeps whatever arguments were added in the traditional fashion,
        i.e. with `parser.add_argument(...)`.

    TODO: Try and maybe return a nicer, typed version of parsed_args.
    """
    logger.debug("\nPOST PROCESSING\n")
    logger.debug(f"(raw) parsed args: {parsed_args}")

    self._remove_subgroups_from_namespace(parsed_args)

    flat_wrappers = _flatten_wrappers(self._wrappers)

    # Start from a copy of the stored constructor arguments and make sure that every
    # destination has an entry before the fields are consumed from `parsed_args`.
    ctor_kwargs = self.constructor_arguments.copy()
    for flat_wrapper in flat_wrappers:
        for dest in flat_wrapper.destinations:
            ctor_kwargs.setdefault(dest, {})

    parsed_args, ctor_kwargs = self._fill_constructor_arguments_with_fields(
        parsed_args,
        wrappers=flat_wrappers,
        initial_constructor_arguments=ctor_kwargs,
    )
    return self._instantiate_dataclasses(
        parsed_args,
        wrappers=flat_wrappers,
        constructor_arguments=ctor_kwargs,
    )

def _resolve_subgroups(
self,
wrappers: list[DataclassWrapper],
Expand Down Expand Up @@ -618,8 +670,8 @@ def _resolve_subgroups(
# TODO: What if a name conflict occurs between a subgroup field and one of the new
# fields below it? For example, something like --model model_a (and inside the `ModelA`
# dataclass, there's a field called `model`. Then, this will cause a conflict!)
# For now, I'm just going to wait and see how this plays out. I'm thinking that the
# auto conflict resolution shouldn't run into any issues with this here.
# For now, I'm just going to wait and see how this plays out. I'm hoping that the
# auto conflict resolution won't run into any issues in this case.

wrappers = self._conflict_resolver.resolve(wrappers)

Expand All @@ -642,65 +694,6 @@ def _resolve_subgroups(
)
return wrappers, resolved_subgroups

def add_argument_group(
    self,
    title: str | None = None,
    description: str | None = None,
    prefix_chars=None,
    argument_default=None,
    conflict_handler=None,
) -> argparse._ArgumentGroup:
    """Create an argument group, filling unset options with this parser's own values.

    Parameters
    ----------
    title : str, optional
        Forwarded to `super().add_argument_group`.
    description : str, optional
        Forwarded to `super().add_argument_group`.
    prefix_chars : optional
        Falls back to `self.prefix_chars` when not given (or empty).
    argument_default : optional
        Falls back to `self.argument_default` only when it is None, so that an
        explicit falsy default such as `0`, `False` or `""` is kept.
    conflict_handler : optional
        Falls back to `self.conflict_handler` when not given (or empty).

    Returns
    -------
    argparse._ArgumentGroup
        Whatever `super().add_argument_group` returns.
    """
    # `or` would silently replace a legitimate falsy default (0, False, "") with
    # the parser-level one, so only None is treated as "not provided" here.
    if argument_default is None:
        argument_default = self.argument_default
    return super().add_argument_group(
        title=title,
        description=description,
        prefix_chars=prefix_chars or self.prefix_chars,
        argument_default=argument_default,
        conflict_handler=conflict_handler or self.conflict_handler,
    )

def _postprocessing(self, parsed_args: Namespace) -> Namespace:
    """Turn the raw parsed namespace into one that holds the dataclass instances.

    The subgroup entries are first removed from `parsed_args`. Then an (initially
    empty) dict of constructor arguments is prepared for every destination of every
    flattened wrapper, those dicts are filled from the parsed fields, and finally the
    dataclasses are instantiated.

    Parameters
    ----------
    parsed_args : Namespace
        the result of calling `super().parse_args(...)` or
        `super().parse_known_args(...)`.

    Returns
    -------
    Namespace
        The original Namespace, with all the arguments corresponding to the
        dataclass fields removed, and with the added dataclass instances.
        Also keeps whatever arguments were added in the traditional fashion,
        i.e. with `parser.add_argument(...)`.

    TODO: Try and maybe return a nicer, typed version of parsed_args.
    """
    logger.debug("\nPOST PROCESSING\n")
    logger.debug(f"(raw) parsed args: {parsed_args}")

    self._remove_subgroups_from_namespace(parsed_args)

    flat_wrappers = _flatten_wrappers(self._wrappers)

    # Start from a copy of the stored constructor arguments and make sure that every
    # destination has an entry before the fields are consumed from `parsed_args`.
    ctor_kwargs = self.constructor_arguments.copy()
    for flat_wrapper in flat_wrappers:
        for dest in flat_wrapper.destinations:
            ctor_kwargs.setdefault(dest, {})

    parsed_args, ctor_kwargs = self._fill_constructor_arguments_with_fields(
        parsed_args,
        wrappers=flat_wrappers,
        initial_constructor_arguments=ctor_kwargs,
    )
    return self._instantiate_dataclasses(
        parsed_args,
        wrappers=flat_wrappers,
        constructor_arguments=ctor_kwargs,
    )

def _remove_subgroups_from_namespace(self, parsed_args: argparse.Namespace) -> None:
"""Removes the subgroup choice results from the namespace.
Modifies the namespace in-place.
Expand Down

0 comments on commit 97ebedb

Please sign in to comment.