Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX ImportError for _DictWithDeprecationKeys #188

Merged
merged 3 commits into from
Oct 17, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions skops/io/_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
from typing import Any

from sklearn.cluster import Birch
from sklearn.covariance._graph_lasso import _DictWithDeprecatedKeys

try:
# TODO: remove once support for sklearn<1.2 is dropped. See #187
from sklearn.covariance._graph_lasso import _DictWithDeprecatedKeys
except ImportError:
_DictWithDeprecatedKeys = None
from sklearn.linear_model._sgd_fast import (
EpsilonInsensitive,
Hinge,
Expand Down Expand Up @@ -122,6 +127,7 @@ def bunch_get_instance(state, src):
return Bunch(**content)


# TODO: remove once support for sklearn<1.2 is dropped.
def _DictWithDeprecatedKeys_get_state(
obj: Any, save_state: SaveState
) -> dict[str, Any]:
Expand All @@ -138,6 +144,7 @@ def _DictWithDeprecatedKeys_get_state(
return res


# TODO: remove once support for sklearn<1.2 is dropped.
def _DictWithDeprecatedKeys_get_instance(state, src):
# _DictWithDeprecatedKeys is just a wrapper for dict
content = dict_get_instance(state["content"]["main"], src)
Expand All @@ -153,7 +160,6 @@ def _DictWithDeprecatedKeys_get_instance(state, src):
GET_STATE_DISPATCH_FUNCTIONS = [
(LossFunction, reduce_get_state),
(Tree, reduce_get_state),
(_DictWithDeprecatedKeys, _DictWithDeprecatedKeys_get_state),
]
for type_ in UNSUPPORTED_TYPES:
GET_STATE_DISPATCH_FUNCTIONS.append((type_, unsupported_get_state))
Expand All @@ -163,5 +169,15 @@ def _DictWithDeprecatedKeys_get_instance(state, src):
(LossFunction, sgd_loss_get_instance),
(Tree, Tree_get_instance),
(Bunch, bunch_get_instance),
(_DictWithDeprecatedKeys, _DictWithDeprecatedKeys_get_instance),
]

# TODO: remove once support for sklearn<1.2 is dropped.
# Starting from sklearn 1.2, _DictWithDeprecatedKeys is removed as it's no
# longer needed for GraphicalLassoCV, see #187.
if _DictWithDeprecatedKeys is not None:
GET_STATE_DISPATCH_FUNCTIONS.append(
(_DictWithDeprecatedKeys, _DictWithDeprecatedKeys_get_state)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should also add the same note to _DictWithDeprecatedKeys_get_state and _DictWithDeprecatedKeys_get_instance

)
GET_INSTANCE_DISPATCH_FUNCTIONS.append(
(_DictWithDeprecatedKeys, _DictWithDeprecatedKeys_get_instance)
)