Skip to content

Commit

Permalink
Merge pull request #2419 from jsiirola/dill-compatibility
Browse files Browse the repository at this point in the history
Resolve dill incompatibility with `attempt_import`.
  • Loading branch information
mrmundt authored May 24, 2022
2 parents 42ce167 + 3a8e8f5 commit fc02a1c
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 5 deletions.
37 changes: 33 additions & 4 deletions pyomo/common/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,16 @@ class ModuleUnavailable(object):
The module name that originally attempted the import
"""

# We need special handling for Sphinx here, as it will look for the
# __sphinx_mock__ attribute on all module-level objects, and we need
# that to raise an AttributeError and not a DeferredImportError
_getattr_raises_attributeerror = {'__sphinx_mock__',}
_getattr_raises_attributeerror = {
# We need special handling for Sphinx here, as it will look for the
# __sphinx_mock__ attribute on all module-level objects, and we need
# that to raise an AttributeError and not a DeferredImportError
'__sphinx_mock__',
# We need special handling for dill as well, as dill attempts to
# pickle module globals by looking for the '_dill' attribute on
# all global objects.
'_dill',
}

def __init__(self, name, message, version_error, import_error, package):
self.__name__ = name
Expand All @@ -68,6 +74,17 @@ def __getattr__(self, attr):
% (type(self).__name__, attr))
raise DeferredImportError(self._moduleunavailable_message())

def __getstate__(self):
return (self.__name__, self._moduleunavailable_info_)

def __setstate__(self, state):
self.__name__, self._moduleunavailable_info_ = state

# Included because recent dill picklers look for the mro() when
# detecting numpy types
def mro(self):
return [ModuleUnavailable, object]

def _moduleunavailable_message(self, msg=None):
_err, _ver, _imp, _package = self._moduleunavailable_info_
if msg is None:
Expand Down Expand Up @@ -145,6 +162,18 @@ def __getattr__(self, attr):
_mod = getattr(_mod, _sub)
return getattr(_mod, attr)

def __getstate__(self):
return self.__dict__

def __setstate__(self, state):
for k, v in state.items():
super().__setattr__(k, v)

# Included because recent dill picklers look for the mro() when
# detecting numpy types
def mro(self):
return [DeferredImportModule, object]


class _DeferredImportIndicatorBase(object):
def __and__(self, other):
Expand Down
5 changes: 5 additions & 0 deletions pyomo/common/tests/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@
bogus, bogus_available \
= attempt_import('nonexisting.module.bogus', defer_check=True)

pkl_test, pkl_available = attempt_import(
'nonexisting.module.pickle_test',
deferred_submodules=['submod'], defer_check=True
)

pyo, pyo_available = attempt_import(
'pyomo', alt_names=['pyo'],
deferred_submodules={'version': None,
Expand Down
23 changes: 22 additions & 1 deletion pyomo/common/tests/test_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from pyomo.common.dependencies import (
attempt_import, ModuleUnavailable, DeferredImportModule,
DeferredImportIndicator, DeferredImportError,
_DeferredAnd, _DeferredOr, check_min_version
_DeferredAnd, _DeferredOr, check_min_version,
dill, dill_available
)

import pyomo.common.tests.dep_mod as dep_mod
Expand Down Expand Up @@ -48,6 +49,26 @@ def test_import_error(self):
"attribute '__sphinx_mock__'"):
module_obj.__sphinx_mock__

@unittest.skipUnless(dill_available, "Test requires dill module")
def test_pickle(self):
self.assertIs(deps.pkl_test.__class__, DeferredImportModule)
# Pickle the DeferredImportModule class
pkl = dill.dumps(deps.pkl_test)
deps.new_pkl_test = dill.loads(pkl)
self.assertIs(deps.pkl_test.__class__, deps.new_pkl_test.__class__)
self.assertIs(deps.new_pkl_test.__class__, DeferredImportModule)
self.assertIsNot(deps.pkl_test, deps.new_pkl_test)
self.assertIn('submod', deps.new_pkl_test.__dict__)
with self.assertRaisesRegex(
DeferredImportError, 'nonexisting.module.pickle_test module'):
deps.new_pkl_test.try_to_call_a_method()
# Pickle the ModuleUnavailable class
self.assertIs(deps.new_pkl_test.__class__, ModuleUnavailable)
pkl = dill.dumps(deps.new_pkl_test)
new_pkl_test_2 = dill.loads(pkl)
self.assertIs(deps.new_pkl_test.__class__, new_pkl_test_2.__class__)
self.assertIsNot(deps.new_pkl_test, new_pkl_test_2)
self.assertIs(new_pkl_test_2.__class__, ModuleUnavailable)

def test_import_success(self):
module_obj, module_available = attempt_import(
Expand Down

0 comments on commit fc02a1c

Please sign in to comment.