Skip to content

Commit

Permalink
fix to get_default_args(instance)
Browse files Browse the repository at this point in the history
Summary:
Small config system fix. Allows get_default_args to work on an instance which has been created with a dict (instead of a DictConfig) as an args field. E.g.

```
gm = GenericModel(
        raysampler_AdaptiveRaySampler_args={"scene_extent": 4.0}
    )
    OmegaConf.structured(gm1)
```

Reviewed By: shapovalov

Differential Revision: D40341047

fbshipit-source-id: 587d0e8262e271df442a80858949a48e5d6db3df
  • Loading branch information
bottler authored and facebook-github-bot committed Oct 13, 2022
1 parent 76cddd9 commit 4d9215b
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 8 deletions.
21 changes: 13 additions & 8 deletions pytorch3d/implicitron/tools/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,8 +709,8 @@ def create_x(self):...
with
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
x_Y_args: dict = dataclasses.field(default_factory=lambda: get_default_args(Y))
x_Z_args: dict = dataclasses.field(default_factory=lambda: get_default_args(Z))
def create_x(self):
args = self.getattr(f"x_{self.x_class_type}_args")
self.create_x_impl(self.x_class_type, args)
Expand All @@ -733,8 +733,8 @@ def create_x(self):...
with
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
x_Y_args: dict = dataclasses.field(default_factory=lambda: get_default_args(Y))
x_Z_args: dict = dataclasses.field(default_factory=lambda: get_default_args(Z))
def create_x(self):
if self.x_class_type is None:
args = None
Expand Down Expand Up @@ -764,7 +764,7 @@ def create_x(self):...
will be replaced with
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
x_args: dict = dataclasses.field(default_factory=lambda: get_default_args(X))
def create_x(self):
self.create_x_impl(True, self.x_args)
Expand All @@ -786,7 +786,7 @@ def create_x(self):...
with
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
x_args: dict = dataclasses.field(default_factory=lambda: get_default_args(X))
x_enabled: bool = False
def create_x(self):
self.create_x_impl(self.x_enabled, self.x_args)
Expand Down Expand Up @@ -818,6 +818,11 @@ def x_tweak_args(cls, member_type: Type, args: DictConfig) -> None
then the default_factory of x_args will also have a call to x_tweak_args(X, x_args) and
the default_factory of x_Y_args will also have a call to x_tweak_args(Y, x_Y_args).
Note that although the *_args members are intended to have type DictConfig, they
are actually internally annotated as dicts. OmegaConf is happy to see a DictConfig
in place of a dict, but not vice-versa. Allowing dict lets a class user specify
x_args as an explicit dict without getting an incomprehensible error.
Args:
some_class: the class to be processed
_do_not_process: Internal use for get_default_args: Because get_default_args calls
Expand Down Expand Up @@ -1040,7 +1045,7 @@ def _process_member(
raise ValueError(
f"Cannot generate {args_name} because it is already present."
)
some_class.__annotations__[args_name] = DictConfig
some_class.__annotations__[args_name] = dict
if hook is not None:
hook_closed = partial(hook, derived_type)
else:
Expand All @@ -1064,7 +1069,7 @@ def _process_member(
if issubclass(type_, some_class) or type_ in _do_not_process:
raise ValueError(f"Cannot process {type_} inside {some_class}")

some_class.__annotations__[args_name] = DictConfig
some_class.__annotations__[args_name] = dict
if hook is not None:
hook_closed = partial(hook, type_)
else:
Expand Down
28 changes: 28 additions & 0 deletions tests/implicitron/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,34 @@ class MainTestWrapper(Configurable):
remove_unused_components(args)
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")

def test_get_instance_args(self):
mt1, mt2 = [
MainTest(
n_ids=0,
n_reps=909,
the_fruit_class_type="Pear",
the_second_fruit_class_type="Pear",
the_fruit_Pear_args=DictConfig({}),
the_second_fruit_Pear_args={},
)
for _ in range(2)
]
# Two equivalent ways to get the DictConfig back out of an instance.
cfg1 = OmegaConf.structured(mt1)
cfg2 = get_default_args(mt2)
self.assertEqual(cfg1, cfg2)
self.assertEqual(len(cfg1.the_second_fruit_Pear_args), 0)
self.assertEqual(len(mt2.the_second_fruit_Pear_args), 0)

from_cfg = MainTest(**cfg2)
self.assertEqual(len(from_cfg.the_second_fruit_Pear_args), 0)

# If you want the complete args, merge with the defaults.
merged_args = OmegaConf.merge(get_default_args(MainTest), cfg2)
from_merged = MainTest(**merged_args)
self.assertEqual(len(from_merged.the_second_fruit_Pear_args), 1)
self.assertEqual(from_merged.n_reps, 909)

def test_tweak_hook(self):
class A(Configurable):
n: int = 9
Expand Down

0 comments on commit 4d9215b

Please sign in to comment.