diff --git a/ddtrace/contrib/internal/pymongo/patch.py b/ddtrace/contrib/internal/pymongo/patch.py index 59a7ad423d6..e1aa6d9f839 100644 --- a/ddtrace/contrib/internal/pymongo/patch.py +++ b/ddtrace/contrib/internal/pymongo/patch.py @@ -27,6 +27,20 @@ from .client import set_address_tags +_VERSION = pymongo.version_tuple + +if _VERSION >= (4, 9): + from pymongo.synchronous.pool import Connection + from pymongo.synchronous.pool import SocketInfo + from pymongo.synchronous.server import Server + from pymongo.synchronous.topology import Topology +else: + from pymongo.pool import Connection + from pymongo.pool import SocketInfo + from pymongo.server import Server + from pymongo.topology import Topology + + _CHECKOUT_FN_NAME = "get_socket" if pymongo.version_tuple < (4, 5) else "checkout" @@ -41,9 +55,6 @@ def get_version(): return getattr(pymongo, "__version__", "") -_VERSION = pymongo.version_tuple - - def patch(): if getattr(pymongo, "_datadog_patch", False): return @@ -60,43 +71,43 @@ def unpatch(): def patch_pymongo_module(): _w(pymongo.MongoClient.__init__, _trace_mongo_client_init) - _w(pymongo.synchronous.topology.Topology.select_server, _trace_topology_select_server) + _w(Topology.select_server, _trace_topology_select_server) if _VERSION >= (3, 12): - _w(pymongo.synchronous.server.Server.run_operation, _trace_server_run_operation_and_with_response) + _w(Server.run_operation, _trace_server_run_operation_and_with_response) elif _VERSION >= (3, 9): - _w(pymongo.synchronous.server.Server.run_operation_with_response, _trace_server_run_operation_and_with_response) + _w(Server.run_operation_with_response, _trace_server_run_operation_and_with_response) else: - _w(pymongo.synchronous.server.Server.send_message_with_response, _trace_server_send_message_with_response) + _w(Server.send_message_with_response, _trace_server_send_message_with_response) if _VERSION >= (4, 5): - _w(pymongo.synchronous.server.Server.checkout, traced_get_socket) - _w(pymongo.synchronous.pool.Connection.command, _trace_socket_command) - _w(pymongo.synchronous.pool.Connection.write_command, _trace_socket_write_command) + _w(Server.checkout, traced_get_socket) + _w(Connection.command, _trace_socket_command) + _w(Connection.write_command, _trace_socket_write_command) else: - _w(pymongo.synchronous.server.Server.get_socket, traced_get_socket) - _w(pymongo.synchronous.pool.SocketInfo.command, _trace_socket_command) - _w(pymongo.synchronous.pool.SocketInfo.write_command, _trace_socket_write_command) + _w(Server.get_socket, traced_get_socket) + _w(SocketInfo.command, _trace_socket_command) + _w(SocketInfo.write_command, _trace_socket_write_command) def unpatch_pymongo_module(): _u(pymongo.MongoClient.__init__, _trace_mongo_client_init) - _u(pymongo.synchronous.topology.Topology.select_server, _trace_topology_select_server) + _u(Topology.select_server, _trace_topology_select_server) if _VERSION >= (3, 12): - _u(pymongo.synchronous.server.Server.run_operation, _trace_server_run_operation_and_with_response) + _u(Server.run_operation, _trace_server_run_operation_and_with_response) elif _VERSION >= (3, 9): - _u(pymongo.synchronous.server.Server.run_operation_with_response, _trace_server_run_operation_and_with_response) + _u(Server.run_operation_with_response, _trace_server_run_operation_and_with_response) else: - _u(pymongo.synchronous.server.Server.send_message_with_response, _trace_server_send_message_with_response) + _u(Server.send_message_with_response, _trace_server_send_message_with_response) if _VERSION >= (4, 5): - _u(pymongo.synchronous.server.Server.checkout, traced_get_socket) - _u(pymongo.synchronous.pool.Connection.command, _trace_socket_command) - _u(pymongo.synchronous.pool.Connection.write_command, _trace_socket_write_command) + _u(Server.checkout, traced_get_socket) + _u(Connection.command, _trace_socket_command) + _u(Connection.write_command, _trace_socket_write_command) else: - _u(pymongo.synchronous.server.Server.get_socket, traced_get_socket) - _u(pymongo.synchronous.pool.SocketInfo.command, _trace_socket_command) - _u(pymongo.synchronous.pool.SocketInfo.write_command, _trace_socket_write_command) + _u(Server.get_socket, traced_get_socket) + _u(SocketInfo.command, _trace_socket_command) + _u(SocketInfo.write_command, _trace_socket_write_command) @contextlib.contextmanager