Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework of the subgroups feature + Minor Refactoring #185

Merged
merged 48 commits into from
Jan 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
fab0f66
Adding test with current WIP implementation
lebrice Dec 12, 2022
facfc99
Removed the previous subgroups implementation
lebrice Dec 12, 2022
815b2b4
Fix test collection for vscode
lebrice Dec 12, 2022
e662ebe
Making some progress (need to fix postprocessing)
lebrice Dec 12, 2022
0b2bb3b
Making more progress, almost there (ish)
lebrice Dec 12, 2022
efea1cd
Making more progress (need to fix conflicts)
lebrice Dec 12, 2022
3d51edb
Conflicts confusion is clearing up
lebrice Dec 12, 2022
690e3bb
Things are parsing, but default is slightly wrong
lebrice Dec 12, 2022
a6253a7
'subgroups' is only added on namespace when needed
lebrice Dec 12, 2022
0d5db49
Progress: Might have identified issue
lebrice Dec 12, 2022
f28ff13
Refactor _set_instances_on_namespace
lebrice Dec 12, 2022
1870903
Refactor, fix some failing basic tests
lebrice Dec 13, 2022
2b1872f
Fix small bug in _consume_constructor_arguments
lebrice Dec 13, 2022
fa7a1d4
Refactoring: Improve naming of Parser methods
lebrice Dec 13, 2022
f722394
[dirty] wip (switching machines)
lebrice Dec 15, 2022
d583366
Fix mistake with self._wrappers, few tests left
lebrice Dec 19, 2022
3650f51
Simplify `subgroups` fn, begin rewriting tests
lebrice Dec 19, 2022
9bd0069
Simple subgroup tests are passing
lebrice Dec 19, 2022
386720c
Fix issue with .setdefault and new subgroups
lebrice Dec 19, 2022
bd46ca9
Fix required arg for equivalent_argparse_code
lebrice Dec 19, 2022
2c011c1
Remove a logger.Critical call
lebrice Dec 20, 2022
40283cf
Remove hacky code from previous iteration
lebrice Dec 20, 2022
ae45549
Remove the `_print_tree` function
lebrice Dec 20, 2022
9e1539c
Make first block of logging in test a bit better
lebrice Dec 20, 2022
9a701a6
Apply pre-commit hooks to field_wrapper.py
lebrice Dec 20, 2022
0f9c2ee
Cleanup the resolve_subgroups code a tiny bit
lebrice Dec 20, 2022
235b714
Make `Dataclass` an actual Protocol
lebrice Dec 20, 2022
73e8ea6
Make DataclassWrapper use Dataclass typevar
lebrice Dec 20, 2022
4595cd6
Add test to check that subgroups are saved (#139)
lebrice Dec 20, 2022
ab81b26
Make `subgroups` sig reflect current limitations
lebrice Dec 20, 2022
81ba977
Add more tests (not passing)
lebrice Dec 20, 2022
0ccbcd4
Refactor FieldWrapper.default
lebrice Jan 5, 2023
61e0ff2
Fix required subgroups issue
lebrice Jan 5, 2023
bd2f895
Fix error in test for subgroups with a conflict
lebrice Jan 5, 2023
aef61f2
Minor esthetic change to test_subgroups.py
lebrice Jan 5, 2023
f1fdff2
Fix issues with py37 and parsing of Reused lists
lebrice Jan 5, 2023
f94a274
Very minor improvement to unrelated test
lebrice Jan 5, 2023
b8787c8
[optional] Make the "reuse" logic a bit better?
lebrice Jan 5, 2023
5cd4270
Revert "[optional] Make the "reuse" logic a bit better?"
lebrice Jan 6, 2023
30b92ae
Fix typing of the `subgroups` function
lebrice Jan 6, 2023
448d4b1
Removed unused _remove_help_action and test
lebrice Jan 6, 2023
cf6ac13
"create_dataclasses"->"instantiate_dataclasses"
lebrice Jan 6, 2023
dcc8808
Remove commented subgroup tests
lebrice Jan 6, 2023
ea1e524
Add more tests for the `subgroups` function
lebrice Jan 6, 2023
1dac691
Use dict instead of Mapping
lebrice Jan 6, 2023
e2a73bf
Add test for help string future feature
lebrice Jan 10, 2023
8a4e183
Fix small bug in test
lebrice Jan 10, 2023
73c9cd1
Merge branch 'master' into nested-subgroups
lebrice Jan 10, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[pytest]
addopts = --doctest-modules
testpaths = test simple_parsing
norecursedirs = examples .git .tox .eggs dist build docs *.egg
117 changes: 77 additions & 40 deletions simple_parsing/conflicts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import enum
from collections import defaultdict
from logging import getLogger
from typing import Dict, List, NamedTuple, Optional, Set, Union
from typing import NamedTuple

from .wrappers import DataclassWrapper, FieldWrapper

Expand Down Expand Up @@ -49,18 +51,33 @@ def __init__(self, *args, **kwargs):

class Conflict(NamedTuple):
option_string: str
wrappers: List[FieldWrapper]
wrappers: list[FieldWrapper]


def unflatten(possibly_related_wrappers: list[DataclassWrapper]) -> list[DataclassWrapper]:
    """Return only the wrappers that have no parent.

    The input may mix parent wrappers with their descendants; every wrapper
    whose ``parent`` attribute is ``None`` is kept, in its original order, and
    all others are dropped. The input list itself is not modified.
    """
    roots: list[DataclassWrapper] = []
    for candidate in possibly_related_wrappers:
        if candidate.parent is None:
            roots.append(candidate)
    return roots


class ConflictResolver:
def __init__(self, conflict_resolution=ConflictResolution.AUTO):
self.conflict_resolution = conflict_resolution

def resolve(self, wrappers: List[DataclassWrapper]) -> List[DataclassWrapper]:
wrappers_flat = []
for wrapper in wrappers:
wrappers_flat.append(wrapper)
wrappers_flat.extend(wrapper.descendants)
def resolve_and_flatten(self, wrappers: list[DataclassWrapper]) -> list[DataclassWrapper]:
"""Given the list of all dataclass wrappers, find and resolve any conflicts between fields.

Returns the new list of (possibly mutated in-place) dataclass wrappers.
This returned list is flattened, i.e. it contains all the dataclass wrappers and their
children.
"""
from simple_parsing.parsing import _assert_no_duplicates, _flatten_wrappers

wrappers = wrappers.copy()

_assert_no_duplicates(wrappers)
wrappers_flat = _flatten_wrappers(wrappers)

dests = [w.dest for w in wrappers_flat]
assert len(dests) == len(set(dests)), f"shouldn't be any duplicates: {wrappers_flat}"

conflict = self.get_conflict(wrappers_flat)

Expand Down Expand Up @@ -109,33 +126,40 @@ def resolve(self, wrappers: List[DataclassWrapper]) -> List[DataclassWrapper]:
assert not self._conflict_exists(wrappers_flat)
return wrappers_flat

def resolve(self, wrappers: list[DataclassWrapper]) -> list[DataclassWrapper]:
    """Resolve conflicts between the given wrappers and return the parentless ones.

    Delegates to `resolve_and_flatten`, then passes the flattened result
    through `unflatten`, which keeps only wrappers whose `parent` is None.
    """
    resolved_flat = self.resolve_and_flatten(wrappers)
    return unflatten(resolved_flat)

def get_conflict(
self, wrappers: Union[List[FieldWrapper], List[DataclassWrapper]]
) -> Optional[Conflict]:
field_wrappers: List[FieldWrapper] = []
self, wrappers: list[DataclassWrapper] | list[FieldWrapper]
) -> Conflict | None:
field_wrappers: list[FieldWrapper] = []
for w in wrappers:
if isinstance(w, FieldWrapper):
field_wrappers.append(w)
else:
if isinstance(w, DataclassWrapper):
field_wrappers.extend(w.fields)
logger.debug(f"Wrapper {w.dest} has fields {w.fields}")
else:
field_wrappers.append(w)

# TODO: #49: Also consider the conflicts with regular argparse arguments.
assert len(field_wrappers) == len(set(field_wrappers)), "duplicates?"

conflicts: Dict[str, List[FieldWrapper]] = defaultdict(list)
# TODO: #49: Also consider the conflicts with regular argparse arguments.
conflicts: dict[str, list[FieldWrapper]] = defaultdict(list)
for field_wrapper in field_wrappers:
for option_string in field_wrapper.option_strings:
conflicts[option_string].append(field_wrapper)
# logger.debug(f"conflicts[{option_string}].append({repr(field_wrapper)})")

for option_string, wrappers in conflicts.items():
if len(wrappers) > 1:
return Conflict(option_string, wrappers)
for option_string, field_wrappers in conflicts.items():
if len(field_wrappers) > 1:
return Conflict(option_string, field_wrappers)
return None

def _add(
self,
wrapper: Union[DataclassWrapper, FieldWrapper],
wrappers: List[DataclassWrapper],
) -> DataclassWrapper:
wrapper: DataclassWrapper | FieldWrapper,
wrappers: list[DataclassWrapper],
) -> list[DataclassWrapper]:
"""Add the given wrapper and all its descendants to the list of wrappers."""
if isinstance(wrapper, FieldWrapper):
wrapper = wrapper.parent
assert isinstance(wrapper, DataclassWrapper)
Expand All @@ -147,18 +171,22 @@ def _add(

def _remove(
self,
wrapper: Union[DataclassWrapper, FieldWrapper],
wrappers: List[DataclassWrapper],
wrapper: DataclassWrapper | FieldWrapper,
wrappers: list[DataclassWrapper],
):
"""Remove the given wrapper and all its descendants from the list of wrappers."""
if isinstance(wrapper, FieldWrapper):
wrapper = wrapper.parent
assert isinstance(wrapper, DataclassWrapper)
logger.debug(f"Removing DataclassWrapper {wrapper}")
wrappers.remove(wrapper)
for child in wrapper.descendants:
logger.debug(f"\tAlso Removing Child DataclassWrapper {child}")
logger.debug(f"\tAlso removing Child DataclassWrapper {child}")
wrappers.remove(child)

# TODO: Should we also remove the reference to this wrapper from its parent?
for other_wrapper in wrappers:
if wrapper in other_wrapper._children:
other_wrapper._children.remove(wrapper)
return wrappers

def _fix_conflict_explicit(self, conflict: Conflict):
Expand Down Expand Up @@ -233,12 +261,13 @@ def _fix_conflict_auto(self, conflict: Conflict):
"""
field_wrappers = sorted(conflict.wrappers, key=lambda w: w.nesting_level)
logger.debug(f"Conflict with options string '{conflict.option_string}':")
for field in field_wrappers:
logger.debug(f"Field wrapper: {field} nesting level: {field.nesting_level}.")
for i, field in enumerate(field_wrappers):
logger.debug(f"Field wrapper #{i+1}: {field} nesting level: {field.nesting_level}.")

assert (
len(set(field_wrappers)) >= 2
), "Need at least 2 (distinct) FieldWrappers to have a conflict..."

first_wrapper = field_wrappers[0]
second_wrapper = field_wrappers[1]
if first_wrapper.nesting_level < second_wrapper.nesting_level:
Expand Down Expand Up @@ -286,7 +315,7 @@ def _fix_conflict_auto(self, conflict: Conflict):
field_wrapper.prefix = word_to_add + "." + current_prefix
logger.debug(f"New prefix: {field_wrapper.prefix}")

def _fix_conflict_merge(self, conflict: Conflict, wrappers_flat: List[DataclassWrapper]):
def _fix_conflict_merge(self, conflict: Conflict, wrappers_flat: list[DataclassWrapper]):
"""Fix conflicts using the merging approach.

The first wrapper is kept, and the rest of the wrappers are absorbed
Expand All @@ -305,24 +334,32 @@ def _fix_conflict_merge(self, conflict: Conflict, wrappers_flat: List[DataclassW
logger.debug(f"Field wrapper: {field} nesting level: {field.nesting_level}.")

assert len(conflict.wrappers) > 1
first_wrapper: FieldWrapper = fields[0]
wrappers_flat = self._remove(first_wrapper.parent, wrappers_flat)

for wrapper in conflict.wrappers[1:]:
wrappers_flat = self._remove(wrapper.parent, wrappers_flat)
first_wrapper.parent.merge(wrapper.parent)
# Merge all the fields into the first one.
first_wrapper: FieldWrapper = fields[0]
wrappers = wrappers_flat.copy()

assert first_wrapper.parent.multiple
wrappers_flat = self._add(first_wrapper.parent, wrappers_flat)
first_containing_dataclass: DataclassWrapper = first_wrapper.parent
original_parent = first_containing_dataclass.parent
wrappers = self._remove(first_containing_dataclass, wrappers)

return wrappers_flat
for wrapper in conflict.wrappers[1:]:
containing_dataclass = wrapper.parent
wrappers = self._remove(containing_dataclass, wrappers)
first_containing_dataclass.merge(containing_dataclass)

assert first_containing_dataclass.multiple
wrappers = self._add(first_containing_dataclass, wrappers)
if original_parent:
original_parent._children.append(first_containing_dataclass)
return wrappers

def _get_conflicting_group(self, all_wrappers: List[DataclassWrapper]) -> Optional[Conflict]:
def _get_conflicting_group(self, all_wrappers: list[DataclassWrapper]) -> Conflict | None:
"""Return the conflicting DataclassWrappers which share argument names.

TODO: maybe return the list of fields, rather than the dataclasses?
"""
conflicts: Dict[str, List[FieldWrapper]] = defaultdict(list)
conflicts: dict[str, list[FieldWrapper]] = defaultdict(list)
for wrapper in all_wrappers:
for field in wrapper.fields:
for option in field.option_strings:
Expand All @@ -338,9 +375,9 @@ def _get_conflicting_group(self, all_wrappers: List[DataclassWrapper]) -> Option
return Conflict(option_string, fields)
return None

def _conflict_exists(self, all_wrappers: List[DataclassWrapper]) -> bool:
def _conflict_exists(self, all_wrappers: list[DataclassWrapper]) -> bool:
"""Return True whenever a conflict exists. (option strings overlap)."""
arg_names: Set[str] = set()
arg_names: set[str] = set()
for wrapper in all_wrappers:
for field in wrapper.fields:
for option in field.option_strings:
Expand Down
6 changes: 3 additions & 3 deletions simple_parsing/help_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ def _format_args(self, action: Action, default_metavar: str):
formats = ["%s" for _ in range(action.nargs)]
result = " ".join(formats) % _get_metavar(action.nargs)

logger.debug(
f"action type: {action_type}, Result: {result}, nargs: {action.nargs}, default metavar: {default_metavar}"
)
# logger.debug(
# f"action type: {action_type}, Result: {result}, nargs: {action.nargs}, default metavar: {default_metavar}"
# )
return result

def _get_default_metavar_for_optional(self, action: argparse.Action):
Expand Down
Loading