From 6d781cf11b6e078cfa9f0e2bfa0a89a459aa8727 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Fri, 14 Oct 2022 14:49:22 +0200 Subject: [PATCH 1/3] Bugfix for _DictWithDeprecatedKeys not importable In sklearn 1.2.0, _DictWithDeprecatedKeys will be removed. Currently, skops will fail to import it and raise an error. In the future, we try to import it and, if that fails, do not add it as a dispatch type. Implementation The bugfix should work. There is, however, trouble with testing it, because we need to patch top-level imports. This does not work properly because the order of patching vs importing is wrong, resulting in the patch being applied too late. I have added a test to demonstrate it. If it is run in isolation, it passes: $ python -m pytest skops/io/tests/test_deprecation.py If run in conjunction with the other tests, it fails: $ python -m pytest -m skops E assert not True E + where True = any(<generator object test_dictwithdeprecatedkeys_cannot_be_imported.<locals>.<genexpr> at 0x7fa727c2e3b0>) This is because the other tests will trigger the import of _DictWithDeprecatedKeys _before_ the patch is applied. This is bad; ideally, we want our tests not to affect each other. I see these solutions: 1. remove the test and just assume that the bugfix works 2. CI matrix extended to run with sklearn<1.2 and with sklearn>=1.2 3. change imports such that the tests don't influence each other I think it goes without saying why 1 and 2 are not very desirable. I haven't actually tested 3, but the idea would be to make all skops imports in the tests local instead of importing at the root level. So, for example, if we want to test skops.io.load, instead of having: from skops.io import load def test_it(): load(...) we should have: @pytest.fixture def load(): from skops.io import load return load def test_it(load): load(...) If I'm not mistaken, this should resolve this issue and most similar issues in the future (it's not a cure-all solution though). As a practical example, in skorch we handle it exactly like this. 
Below is an example of testing the skorch to_tensor function. See how the function is only imported locally, not on a module level: https://github.com/skorch-dev/skorch/blob/cfe568b58f150730e7269b5a2a1b2bac2228dbd2/skorch/tests/test_utils.py#L15-L34 Hopefully, there is a better way than any of the suggestions, but I couldn't come up with any. --- skops/io/_sklearn.py | 16 ++++++++-- skops/io/tests/test_deprecation.py | 48 ++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 3 deletions(-) create mode 100644 skops/io/tests/test_deprecation.py diff --git a/skops/io/_sklearn.py b/skops/io/_sklearn.py index da32b7e5..0fea20c9 100644 --- a/skops/io/_sklearn.py +++ b/skops/io/_sklearn.py @@ -3,7 +3,11 @@ from typing import Any from sklearn.cluster import Birch -from sklearn.covariance._graph_lasso import _DictWithDeprecatedKeys + +try: + from sklearn.covariance._graph_lasso import _DictWithDeprecatedKeys +except ImportError: + _DictWithDeprecatedKeys = None from sklearn.linear_model._sgd_fast import ( EpsilonInsensitive, Hinge, @@ -153,7 +157,6 @@ def _DictWithDeprecatedKeys_get_instance(state, src): GET_STATE_DISPATCH_FUNCTIONS = [ (LossFunction, reduce_get_state), (Tree, reduce_get_state), - (_DictWithDeprecatedKeys, _DictWithDeprecatedKeys_get_state), ] for type_ in UNSUPPORTED_TYPES: GET_STATE_DISPATCH_FUNCTIONS.append((type_, unsupported_get_state)) @@ -163,5 +166,12 @@ def _DictWithDeprecatedKeys_get_instance(state, src): (LossFunction, sgd_loss_get_instance), (Tree, Tree_get_instance), (Bunch, bunch_get_instance), - (_DictWithDeprecatedKeys, _DictWithDeprecatedKeys_get_instance), ] + +if _DictWithDeprecatedKeys is not None: + GET_STATE_DISPATCH_FUNCTIONS.append( + (_DictWithDeprecatedKeys, _DictWithDeprecatedKeys_get_state) + ) + GET_INSTANCE_DISPATCH_FUNCTIONS.append( + (_DictWithDeprecatedKeys, _DictWithDeprecatedKeys_get_instance) + ) diff --git a/skops/io/tests/test_deprecation.py b/skops/io/tests/test_deprecation.py new file mode 100644 index 
00000000..e4badde2 --- /dev/null +++ b/skops/io/tests/test_deprecation.py @@ -0,0 +1,48 @@ +from unittest.mock import patch + +orig_import = __import__ + + +def test_dictwithdeprecatedkeys_cannot_be_imported(tmp_path): + # _DictWithDeprecatedKeys is removed in sklearn 1.2.0 + # see bug reported in #187 + + # mock the loading of + # sklearn.covariance._graph_lasso._DictWithDeprecatedKeys to raise an + # ImportError + def import_mock(name, *args, **kwargs): + if name == "sklearn.covariance._graph_lasso": + if args and ("_DictWithDeprecatedKeys" in args[2]): + raise ImportError("mock import error") + return orig_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=import_mock): + # important: skops.io has to be loaded _after_ mocking the import, + # otherwise it's too late, as the dispatch methods are added to their + # respective lists on root level of their respective modules, so + # patching after that is too late. + from sklearn.covariance import GraphicalLassoCV + + from skops.io import load, save + + f_name = tmp_path / "file.skops" + estimator = GraphicalLassoCV() + save(file=f_name, obj=estimator) + load(file=f_name) + + # sanity check: make sure that the import really raised an error and + # thus there is no dispatch for _DictWithDeprecatedKeys, or else this + # test would pass trivially + from skops.io._sklearn import ( + GET_INSTANCE_DISPATCH_FUNCTIONS, + GET_STATE_DISPATCH_FUNCTIONS, + ) + + assert not any( + t.__name__ == "_DictWithDeprecatedKeys" + for (t, _) in GET_STATE_DISPATCH_FUNCTIONS + ) + assert not any( + t.__name__ == "_DictWithDeprecatedKeys" + for (t, _) in GET_INSTANCE_DISPATCH_FUNCTIONS + ) From 3a72cb485656d7702946a7579ae2e412b3b6ec7c Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 17 Oct 2022 13:38:01 +0200 Subject: [PATCH 2/3] Address reviewer feedback: - Add comments to clarify change and TODO about when it can be removed - Remove test as discussed, will be covered by CI --- skops/io/_sklearn.py | 4 
+++ skops/io/tests/test_deprecation.py | 48 ------------------------------ 2 files changed, 4 insertions(+), 48 deletions(-) delete mode 100644 skops/io/tests/test_deprecation.py diff --git a/skops/io/_sklearn.py b/skops/io/_sklearn.py index 0fea20c9..0e7bd680 100644 --- a/skops/io/_sklearn.py +++ b/skops/io/_sklearn.py @@ -5,6 +5,7 @@ from sklearn.cluster import Birch try: + # TODO: remove once support for sklearn<1.2 is dropped. See #187 from sklearn.covariance._graph_lasso import _DictWithDeprecatedKeys except ImportError: _DictWithDeprecatedKeys = None @@ -168,6 +169,9 @@ def _DictWithDeprecatedKeys_get_instance(state, src): (Bunch, bunch_get_instance), ] +# TODO: remove once support for sklearn<1.2 is dropped. +# Starting from sklearn 1.2, _DictWithDeprecatedKeys is removed as it's no +# longer needed for GraphicalLassoCV, see #187. if _DictWithDeprecatedKeys is not None: GET_STATE_DISPATCH_FUNCTIONS.append( (_DictWithDeprecatedKeys, _DictWithDeprecatedKeys_get_state) diff --git a/skops/io/tests/test_deprecation.py b/skops/io/tests/test_deprecation.py deleted file mode 100644 index e4badde2..00000000 --- a/skops/io/tests/test_deprecation.py +++ /dev/null @@ -1,48 +0,0 @@ -from unittest.mock import patch - -orig_import = __import__ - - -def test_dictwithdeprecatedkeys_cannot_be_imported(tmp_path): - # _DictWithDeprecatedKeys is removed in sklearn 1.2.0 - # see bug reported in #187 - - # mock the loading of - # sklearn.covariance._graph_lasso._DictWithDeprecatedKeys to raise an - # ImportError - def import_mock(name, *args, **kwargs): - if name == "sklearn.covariance._graph_lasso": - if args and ("_DictWithDeprecatedKeys" in args[2]): - raise ImportError("mock import error") - return orig_import(name, *args, **kwargs) - - with patch("builtins.__import__", side_effect=import_mock): - # important: skops.io has to be loaded _after_ mocking the import, - # otherwise it's too late, as the dispatch methods are added to their - # respective lists on root level of their 
respective modules, so - # patching after that is too late. - from sklearn.covariance import GraphicalLassoCV - - from skops.io import load, save - - f_name = tmp_path / "file.skops" - estimator = GraphicalLassoCV() - save(file=f_name, obj=estimator) - load(file=f_name) - - # sanity check: make sure that the import really raised an error and - # thus there is no dispatch for _DictWithDeprecatedKeys, or else this - # test would pass trivially - from skops.io._sklearn import ( - GET_INSTANCE_DISPATCH_FUNCTIONS, - GET_STATE_DISPATCH_FUNCTIONS, - ) - - assert not any( - t.__name__ == "_DictWithDeprecatedKeys" - for (t, _) in GET_STATE_DISPATCH_FUNCTIONS - ) - assert not any( - t.__name__ == "_DictWithDeprecatedKeys" - for (t, _) in GET_INSTANCE_DISPATCH_FUNCTIONS - ) From 2b7ee3c08d0d072df79eb669b189526783603226 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 17 Oct 2022 16:50:42 +0200 Subject: [PATCH 3/3] Reviewer: Add missing TODO comments --- skops/io/_sklearn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/skops/io/_sklearn.py b/skops/io/_sklearn.py index 0e7bd680..d81c54d1 100644 --- a/skops/io/_sklearn.py +++ b/skops/io/_sklearn.py @@ -127,6 +127,7 @@ def bunch_get_instance(state, src): return Bunch(**content) +# TODO: remove once support for sklearn<1.2 is dropped. def _DictWithDeprecatedKeys_get_state( obj: Any, save_state: SaveState ) -> dict[str, Any]: @@ -143,6 +144,7 @@ def _DictWithDeprecatedKeys_get_state( return res +# TODO: remove once support for sklearn<1.2 is dropped. def _DictWithDeprecatedKeys_get_instance(state, src): # _DictWithDeprecatedKeys is just a wrapper for dict content = dict_get_instance(state["content"]["main"], src)