Skip to content

Commit

Permalink
[optional] Make the "reuse" logic a bit better?
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>
  • Loading branch information
lebrice committed Jan 5, 2023
1 parent f94a274 commit b8787c8
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 52 deletions.
2 changes: 1 addition & 1 deletion simple_parsing/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,7 @@ def _fill_constructor_arguments_with_fields(
# strategy), then we store the multiple values in the `dest` of the first field.
# They are they distributed in `constructor_arguments` using the
# `field.destinations`, which gives the destination for each value.
values = parsed_arg_values.pop(field.dest, field.default)
values = parsed_arg_values.pop(field.dest, field.defaults)
deleted_values[field.dest] = values

# call the "action" for the given attribute. This sets the right
Expand Down
139 changes: 88 additions & 51 deletions simple_parsing/wrappers/field_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,12 +241,13 @@ def get_arg_options(self) -> dict[str, Any]:
# For positional arguments that aren't required we need to set
# nargs='?' to make them optional.
_arg_options["nargs"] = "?"
_arg_options["default"] = self.default

_arg_options["default"] = self.default if not self.is_reused else self.defaults
_arg_options["metavar"] = get_metavar(self.type)

if self.help:
_arg_options["help"] = self.help
elif self.default is not None:
elif not self.is_reused and self.default is not None:
# issue 64: Need to add a temporary 'help' string, so that the formatter
# automatically adds the (default: '123'). We then remove it.
_arg_options["help"] = TEMPORARY_TOKEN
Expand Down Expand Up @@ -713,17 +714,19 @@ def default(self) -> Any:
2. the value of the corresponding attribute on the parent,
if it has a default value
"""

if self.is_reused:
raise RuntimeError(
"Reused fields (those created through a conflict, when using the "
"ConflictResolution.ALWAYS_MERGE option) don't have a .default attribute. "
"Use .defaults instead."
)
if self._default is not None:
# If a default value was set manually from the outside (e.g. from the DataclassWrapper)
# then use that value.
default = self._default
elif self.is_subgroup:
default = self.subgroup_default
elif any(
parent_default not in (None, argparse.SUPPRESS)
for parent_default in self.parent.defaults
):
return self._default
if self.is_subgroup:
return self.subgroup_default
if self.parent.default not in (None, argparse.SUPPRESS):
# if the dataclass with this field has a default value - either when a value was
# passed for the `default` argument of `add_arguments` or when the parent is a nested
# dataclass field with a default factory - we use the corresponding attribute on that
Expand All @@ -733,19 +736,12 @@ def _get_value(dataclass_default: utils.Dataclass | dict, name: str) -> Any:
return dataclass_default.get(name)
return getattr(dataclass_default, name)

defaults = [
_get_value(parent_default, self.field.name)
for parent_default in self.parent.defaults
if parent_default not in (None, argparse.SUPPRESS)
]
if len(self.parent.defaults) == 1:
default = defaults[0]
else:
default = defaults
return _get_value(self.parent.default, self.field.name)

# Try to get the default from the field, if possible.
elif self.field.default is not dataclasses.MISSING:
default = self.field.default
elif self.field.default_factory is not dataclasses.MISSING:
if self.field.default is not dataclasses.MISSING:
return self.field.default
if self.field.default_factory is not dataclasses.MISSING:
# Use the _default attribute to keep the result, so we can avoid calling the default
# factory another time.
# TODO: If the default factory is a function that returns None, it will still get
Expand All @@ -754,40 +750,75 @@ def _get_value(dataclass_default: utils.Dataclass | dict, name: str) -> Any:
# the default_factory before.
if self._default is None:
self._default = self.field.default_factory()
default = self._default
return self._default
# field doesn't have a default value set.
elif self.action == "store_true":
default = False
elif self.action == "store_false":
if self.action == "store_true":
return False
if self.action == "store_false":
# NOTE: The boolean parsing when default is `True` is really un-intuitive, and should
# change in the future. See https://github.com/lebrice/SimpleParsing/issues/68
default = True
else:
default = None

# If this field is being reused, then we package up the `default` in a list.
# TODO: Get rid of this. makes the code way uglier for no good reason.
if self.is_reused and default is not None:
n_destinations = len(self.destinations)
assert n_destinations >= 1
# BUG: This second part (the `or` part) is weird. Probably only applies when using
# Lists of lists with the Reuse option, which is most likely not even supported..
if utils.is_tuple_or_list(self.field.type) and len(default) != n_destinations:
# The field is of a list type field,
default = [default] * n_destinations
elif not isinstance(default, list):
default = [default] * n_destinations
assert len(default) == n_destinations, (
f"Not the same number of default values and destinations. "
f"(default: {default}, # of destinations: {n_destinations})"
)

return default
return True
return None

def set_default(self, value: Any):
logger.debug(f"The field {self.name} has its default manually set to a value of {value}.")
self._default = value

@property
def defaults(self) -> list[Any]:
"""Returns the default values for this field. Only applies to reused fields (fields created
by a conflict, when using the ConflictResolution.ALWAYS_MERGE option).
"""
if not self.is_reused:
# TODO: Choose between the strict option:

# raise RuntimeError(
# "Only reused fields (those created by a conflict, when using the "
# "ConflictResolution.ALWAYS_MERGE option) can have a .defaults."
# )
# or the more relaxed option:
return [self.default]
n_destinations = len(self.destinations)
assert n_destinations >= 1
# Do all the weird shit that we want to get rid of.

def _get_value(dataclass_default: utils.Dataclass | dict, name: str) -> Any:
if isinstance(dataclass_default, dict):
return dataclass_default.get(name)
return getattr(dataclass_default, name)

if self.parent.defaults:
defaults = [
_get_value(parent_default, self.field.name)
for parent_default in self.parent.defaults
if parent_default not in (None, argparse.SUPPRESS)
]
else:
# NOTE: Hacky: temporarily make this a non-reused field, to extract the '.default'
# attribute.
destinations = self.parent.destinations
self.parent.destinations = [self.parent.destinations[0]]
assert not self.is_reused
defaults = [self.default] * n_destinations

self.parent.destinations = destinations
assert self.is_reused

# BUG: This second part (the `or` part) is weird. Probably only applies when using
# Lists of lists with the Reuse option, which is most likely not even supported..
if utils.is_tuple_or_list(self.field.type) and len(defaults) != n_destinations:
# The field is of a list type field,
defaults = [defaults] * n_destinations
elif not isinstance(defaults, list):
defaults = [defaults] * n_destinations

# TODO: Required arguments with reuse doesn't currently work.
assert len(defaults) == n_destinations, (
f"Not the same number of default values and destinations. "
f"(default: {defaults}, # of destinations: {n_destinations})"
)
return defaults

@property
def required(self) -> bool:
if self._required is not None:
Expand All @@ -807,12 +838,18 @@ def required(self) -> bool:
return False
if self.nargs == "+":
return True
if self.default is None and argparse.SUPPRESS not in self.parent.defaults:
return True
if self.is_reused:
# if we're reusing this argument, the default value might be a list
# of `MISSING` values.
return any(v == dataclasses.MISSING for v in self.default)
# FIXME: THis should only check against dataclasses.MISSING, not None!
return all(v in (None, dataclasses.MISSING) for v in self.defaults)
if (
not self.is_reused
and self.default is None
and argparse.SUPPRESS not in self.parent.defaults
):
return True

return False

@required.setter
Expand Down

0 comments on commit b8787c8

Please sign in to comment.