Skip to content

Commit

Permalink
Merge pull request #681 from mit-ll-responsible-ai/defaultdict
Browse files Browse the repository at this point in the history
autoconfig support for defaultdict
  • Loading branch information
rsokl authored Apr 28, 2024
2 parents 85ef912 + 13ae716 commit e090e7b
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/source/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ hydra-zen provides specialized auto-config support for values of the following t
- :py:class:`bytearray`
- :py:class:`complex`
- :py:class:`collections.Counter`
- :py:class:`collections.defaultdict`
- :py:class:`collections.deque`
- :py:class:`datetime.timedelta`
- :py:func:`functools.partial` (note: not compatible with pickling)
Expand Down
3 changes: 2 additions & 1 deletion src/hydra_zen/_compatibility.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) 2024 Massachusetts Institute of Technology
# SPDX-License-Identifier: MIT
from collections import Counter, deque
from collections import Counter, defaultdict, deque
from datetime import timedelta
from enum import Enum
from functools import partial
Expand Down Expand Up @@ -70,6 +70,7 @@ def _get_version(ver_str: str) -> Version:
Counter,
range,
timedelta,
defaultdict,
}
)

Expand Down
10 changes: 10 additions & 0 deletions src/hydra_zen/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,13 @@ def zen_processing(
if _zen_partial is True:
return _functools.partial(obj, *args, **kwargs)
return obj(*args, **kwargs)


def as_default_dict(
dict_: _tp.Dict[_tp.Any, _tp.Any], *, default_factory: _tp.Any
) -> _tp.DefaultDict[_tp.Any, _tp.Any]:
from collections import defaultdict

obj = defaultdict(default_factory)
obj.update(dict_)
return obj
28 changes: 26 additions & 2 deletions src/hydra_zen/structured_configs/_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import inspect
import sys
import warnings
from collections import Counter, deque
from collections import Counter, defaultdict, deque
from collections.abc import Collection
from dataclasses import ( # use this for runtime checks
MISSING,
Expand Down Expand Up @@ -72,7 +72,7 @@
HydraZenUnsupportedPrimitiveError,
HydraZenValidationError,
)
from hydra_zen.funcs import get_obj
from hydra_zen.funcs import as_default_dict, get_obj
from hydra_zen.structured_configs import _utils
from hydra_zen.structured_configs._type_guards import safe_getattr
from hydra_zen.typing import (
Expand Down Expand Up @@ -3750,6 +3750,30 @@ def __post_init__(self, CBuildsFn: Type[BuildsFn[Any]]) -> None:
del CBuildsFn


@dataclass(unsafe_hash=True)
class ConfigFromDefaultDict:
dict_: Dict[Any, Any]
default_factory: Any = field(init=False)
CBuildsFn: InitVar[Type[BuildsFn[Any]]]
_target_: str = BuildsFn._get_obj_path(as_default_dict)

def __post_init__(self, CBuildsFn: Type[BuildsFn[Any]]) -> None:
assert isinstance(self.dict_, defaultdict)
self.default_factory = CBuildsFn.just(self.dict_.default_factory)
out = CBuildsFn._make_hydra_compatible(
dict(self.dict_),
convert_dataclass=True,
allow_zen_conversion=True,
structured_conf_permitted=True,
)
assert isinstance(out, dict)
self.dict_ = out


ZEN_VALUE_CONVERSION[defaultdict] = lambda dict_, CBuildsFn: ConfigFromDefaultDict(
dict_, CBuildsFn
)

ZEN_VALUE_CONVERSION[set] = partial(
ConfigFromTuple, _target_=BuildsFn._get_obj_path(set)
)
Expand Down
17 changes: 16 additions & 1 deletion tests/test_value_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import inspect
import pickle
import string
from collections import Counter, deque
from collections import Counter, defaultdict, deque
from dataclasses import dataclass, field
from datetime import timedelta
from enum import Enum
Expand Down Expand Up @@ -142,6 +142,9 @@ def test_value_supported_via_config_maker_functions(
Brange = pik_blds(dict, x=range(1, 10, 3), zen_dataclass={"cls_name": "Brange"})
Brange2 = pik_blds(dict, x=range(2), zen_dataclass={"cls_name": "Brange2"})
Bcounter = pik_blds(dict, x=Counter("apple"), zen_dataclass={"cls_name": "Bcounter"})
x = defaultdict(list)
x.update({1: [1, 2]})
Bdefaultdict = pik_blds(dict, x=x, zen_dataclass={"cls_name": "Bdefaultdict"})


@pytest.mark.parametrize(
Expand All @@ -164,6 +167,7 @@ def test_value_supported_via_config_maker_functions(
Brange,
Brange2,
Bcounter,
Bdefaultdict,
],
)
def test_pickle_compatibility(Config):
Expand All @@ -190,6 +194,7 @@ def test_pickle_compatibility(Config):
Brange,
Brange2,
Bcounter,
Bdefaultdict,
],
)
def test_unsafe_hash_default(Config):
Expand Down Expand Up @@ -412,3 +417,13 @@ def test_known_failcase_hydra_2350():
actual = instantiate(Conf, x={"b": 2})
expected = {"b": 2}
assert actual == expected, actual


def test_default_dict():
x = defaultdict(list)
x.update({1: [1 + 2j, 2]})
Conf = builds(dict, x=x)
actual = instantiate(Conf)["x"]
assert actual == x
assert isinstance(actual, defaultdict)
assert actual.default_factory is list

0 comments on commit e090e7b

Please sign in to comment.