diff --git a/tests/contrib/patch.py b/tests/contrib/patch.py index 5ef41339fdc..913b3895894 100644 --- a/tests/contrib/patch.py +++ b/tests/contrib/patch.py @@ -43,7 +43,7 @@ def assert_not_module_imported(self, modname): assert not self.module_imported(modname), "{} module is imported".format(modname) def is_wrapped(self, obj): - return isinstance(obj, wrapt.ObjectProxy) + return isinstance(obj, wrapt.ObjectProxy) or hasattr(obj, "__dd_wrapped__") def assert_wrapped(self, obj): """ @@ -64,7 +64,10 @@ def assert_not_double_wrapped(self, obj): This is useful for asserting idempotence. """ self.assert_wrapped(obj) - self.assert_not_wrapped(obj.__wrapped__) + + wrapped = obj.__wrapped__ if isinstance(obj, wrapt.ObjectProxy) else obj.__dd_wrapped__ + + self.assert_not_wrapped(wrapped) def raise_if_no_attrs(f): diff --git a/tests/contrib/pymongo/test.py b/tests/contrib/pymongo/test.py index 86d8c7c14a2..e33ed688d89 100644 --- a/tests/contrib/pymongo/test.py +++ b/tests/contrib/pymongo/test.py @@ -7,8 +7,8 @@ from ddtrace import Pin from ddtrace.contrib.internal.pymongo.client import normalize_filter from ddtrace.contrib.internal.pymongo.patch import _CHECKOUT_FN_NAME -from ddtrace.contrib.pymongo.patch import patch -from ddtrace.contrib.pymongo.patch import unpatch +from ddtrace.contrib.internal.pymongo.patch import patch +from ddtrace.contrib.internal.pymongo.patch import unpatch from ddtrace.ext import SpanTypes from tests.opentracer.utils import init_tracer from tests.utils import DummyTracer diff --git a/tests/contrib/pymongo/test_pymongo_patch.py b/tests/contrib/pymongo/test_pymongo_patch.py index 3596beca648..b1c51ea36a3 100644 --- a/tests/contrib/pymongo/test_pymongo_patch.py +++ b/tests/contrib/pymongo/test_pymongo_patch.py @@ -2,16 +2,27 @@ # script. If you want to make changes to it, you should make sure that you have # removed the ``_generated`` suffix from the file name, to prevent the content # from being overwritten by future re-generations. - -from ddtrace.contrib.pymongo import get_version +from ddtrace.contrib.pymongo.patch import get_version from ddtrace.contrib.pymongo.patch import patch +from ddtrace.contrib.pymongo.patch import pymongo +from ddtrace.contrib.pymongo.patch import unpatch +from tests.contrib.patch import PatchTestCase -try: - from ddtrace.contrib.pymongo.patch import unpatch -except ImportError: - unpatch = None -from tests.contrib.patch import PatchTestCase +_VERSION = pymongo.version_tuple + +if _VERSION >= (4, 9): + from pymongo.synchronous.pool import Connection + from pymongo.synchronous.server import Server + from pymongo.synchronous.topology import Topology +elif _VERSION >= (4, 5): + from pymongo.pool import Connection + from pymongo.server import Server + from pymongo.topology import Topology +else: + from pymongo.pool import SocketInfo as Connection + from pymongo.server import Server + from pymongo.topology import Topology class TestPymongoPatch(PatchTestCase.Base): @@ -22,10 +33,55 @@ class TestPymongoPatch(PatchTestCase.Base): __get_version__ = get_version def assert_module_patched(self, pymongo): - pass + self.assert_wrapped(pymongo.MongoClient.__init__) + self.assert_wrapped(Topology.select_server) + + if _VERSION >= (3, 12): + self.assert_wrapped(Server.run_operation) + elif _VERSION >= (3, 9): + self.assert_wrapped(Server.run_operation_with_response) + else: + self.assert_wrapped(Server.send_message_with_response) + + if _VERSION >= (4, 5): + self.assert_wrapped(Server.checkout) + else: + self.assert_wrapped(Server.get_socket) + self.assert_wrapped(Connection.command) + self.assert_wrapped(Connection.write_command) def assert_not_module_patched(self, pymongo): - pass + self.assert_not_wrapped(pymongo.MongoClient.__init__) + self.assert_not_wrapped(Topology.select_server) + if _VERSION >= (3, 12): + self.assert_not_wrapped(Server.run_operation) + elif _VERSION >= (3, 9): + self.assert_not_wrapped(Server.run_operation_with_response) + else: + self.assert_not_wrapped(Server.send_message_with_response) + + if _VERSION >= (4, 5): + self.assert_not_wrapped(Server.checkout) + else: + self.assert_not_wrapped(Server.get_socket) + + self.assert_not_wrapped(Connection.command) + self.assert_not_wrapped(Connection.write_command) def assert_not_module_double_patched(self, pymongo): - pass + self.assert_not_double_wrapped(pymongo.MongoClient.__init__) + self.assert_not_double_wrapped(Topology.select_server) + self.assert_not_double_wrapped(Connection.command) + self.assert_not_double_wrapped(Connection.write_command) + + if _VERSION >= (3, 12): + self.assert_not_double_wrapped(Server.run_operation) + elif _VERSION >= (3, 9): + self.assert_not_double_wrapped(Server.run_operation_with_response) + else: + self.assert_not_double_wrapped(Server.send_message_with_response) + + if _VERSION >= (4, 5): + self.assert_not_double_wrapped(Server.checkout) + else: + self.assert_not_double_wrapped(Server.get_socket)