diff --git a/testing/test_workermanage.py b/testing/test_workermanage.py index fae06011..b19c524a 100644 --- a/testing/test_workermanage.py +++ b/testing/test_workermanage.py @@ -3,9 +3,12 @@ import pytest import shutil import textwrap +import warnings from pathlib import Path +from util import generate_warning from xdist import workermanage -from xdist.workermanage import HostRSync, NodeManager +from xdist.remote import serialize_warning_message +from xdist.workermanage import HostRSync, NodeManager, unserialize_warning_message pytest_plugins = "pytester" @@ -326,3 +329,80 @@ def test_one(): ) (rep,) = reprec.getreports("pytest_runtest_logreport") assert rep.passed + + +class MyWarning(UserWarning): + pass + + +@pytest.mark.parametrize( + "w_cls", + [ + UserWarning, + MyWarning, + "Imported", + pytest.param( + "Nested", + marks=pytest.mark.xfail(reason="Nested warning classes are not supported."), + ), + ], +) +def test_unserialize_warning_msg(w_cls): + """Test that warning serialization process works well""" + + # Create a test warning message + with pytest.warns(UserWarning) as w: + if not isinstance(w_cls, str): + warnings.warn("hello", w_cls) + elif w_cls == "Imported": + generate_warning() + elif w_cls == "Nested": + # dynamic creation + class MyWarning2(UserWarning): + pass + + warnings.warn("hello", MyWarning2) + + # Unpack + assert len(w) == 1 + w_msg = w[0] + + # Serialize and deserialize + data = serialize_warning_message(w_msg) + w_msg2 = unserialize_warning_message(data) + + # Compare the two objects + all_keys = set(vars(w_msg).keys()).union(set(vars(w_msg2).keys())) + for k in all_keys: + v1 = getattr(w_msg, k) + v2 = getattr(w_msg2, k) + if k == "message": + assert type(v1) == type(v2) + assert v1.args == v2.args + else: + assert v1 == v2 + + +class MyWarningUnknown(UserWarning): + # Changing the __module__ attribute is only safe if class can be imported + # from there + __module__ = "unknown" + + +def test_warning_serialization_tweaked_module(): + """Test for GH#404""" + + # Create a test warning message + with pytest.warns(UserWarning) as w: + warnings.warn("hello", MyWarningUnknown) + + # Unpack + assert len(w) == 1 + w_msg = w[0] + + # Serialize and deserialize + data = serialize_warning_message(w_msg) + + # __module__ cannot be found! + with pytest.raises(ModuleNotFoundError): + unserialize_warning_message(data) diff --git a/testing/util.py b/testing/util.py new file mode 100644 index 00000000..c7bcc552 --- /dev/null +++ b/testing/util.py @@ -0,0 +1,9 @@ +import warnings + + +class MyWarning2(UserWarning): + pass + + +def generate_warning(): + warnings.warn(MyWarning2("hello"))