diff --git a/instrumentation/opentelemetry-instrumentation-dbapi/src/opentelemetry/instrumentation/dbapi/__init__.py b/instrumentation/opentelemetry-instrumentation-dbapi/src/opentelemetry/instrumentation/dbapi/__init__.py index d2bb76061a..6d7e37a45f 100644 --- a/instrumentation/opentelemetry-instrumentation-dbapi/src/opentelemetry/instrumentation/dbapi/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-dbapi/src/opentelemetry/instrumentation/dbapi/__init__.py @@ -203,7 +203,7 @@ def instrument_connection( Returns: An instrumented connection. """ - if isinstance(connection, _TracedConnectionProxy): + if isinstance(connection, wrapt.ObjectProxy): _logger.warning("Connection already instrumented") return connection @@ -230,8 +230,8 @@ def uninstrument_connection(connection): Returns: An uninstrumented connection. """ - if isinstance(connection, _TracedConnectionProxy): - return connection._connection + if isinstance(connection, wrapt.ObjectProxy): + return connection.__wrapped__ _logger.warning("Connection is not instrumented") return connection @@ -320,22 +320,14 @@ def get_connection_attributes(self, connection): self.span_attributes[SpanAttributes.NET_PEER_PORT] = port -class _TracedConnectionProxy: - pass - - def get_traced_connection_proxy( connection, db_api_integration, *args, **kwargs ): # pylint: disable=abstract-method - class TracedConnectionProxy(type(connection), _TracedConnectionProxy): - def __init__(self, connection): - self._connection = connection - - def __getattr__(self, name): - return object.__getattribute__( - object.__getattribute__(self, "_connection"), name - ) + class TracedConnectionProxy(wrapt.ObjectProxy): + # pylint: disable=unused-argument + def __init__(self, connection, *args, **kwargs): + wrapt.ObjectProxy.__init__(self, connection) def __getattribute__(self, name): if object.__getattribute__(self, name): @@ -347,16 +339,17 @@ def __getattribute__(self, name): def cursor(self, *args, **kwargs): return get_traced_cursor_proxy( - self._connection.cursor(*args, **kwargs), db_api_integration + self.__wrapped__.cursor(*args, **kwargs), db_api_integration ) - # For some reason this is necessary as trying to access the close - # method of self._connection via __getattr__ leads to unexplained - # errors. - def close(self): - self._connection.close() + def __enter__(self): + self.__wrapped__.__enter__() + return self + + def __exit__(self, *args, **kwargs): + self.__wrapped__.__exit__(*args, **kwargs) - return TracedConnectionProxy(connection) + return TracedConnectionProxy(connection, *args, **kwargs) class CursorTracer: diff --git a/instrumentation/opentelemetry-instrumentation-dbapi/tests/test_dbapi_integration.py b/instrumentation/opentelemetry-instrumentation-dbapi/tests/test_dbapi_integration.py index a7fb608c27..0d19ce8373 100644 --- a/instrumentation/opentelemetry-instrumentation-dbapi/tests/test_dbapi_integration.py +++ b/instrumentation/opentelemetry-instrumentation-dbapi/tests/test_dbapi_integration.py @@ -325,14 +325,14 @@ def test_callproc(self): @mock.patch("opentelemetry.instrumentation.dbapi") def test_wrap_connect(self, mock_dbapi): - dbapi.wrap_connect(self.tracer, MockConnectionEmpty(), "connect", "-") + dbapi.wrap_connect(self.tracer, mock_dbapi, "connect", "-") connection = mock_dbapi.connect() self.assertEqual(mock_dbapi.connect.call_count, 1) - self.assertIsInstance(connection._connection, mock.Mock) + self.assertIsInstance(connection.__wrapped__, mock.Mock) @mock.patch("opentelemetry.instrumentation.dbapi") def test_unwrap_connect(self, mock_dbapi): - dbapi.wrap_connect(self.tracer, MockConnectionEmpty(), "connect", "-") + dbapi.wrap_connect(self.tracer, mock_dbapi, "connect", "-") connection = mock_dbapi.connect() self.assertEqual(mock_dbapi.connect.call_count, 1) @@ -342,21 +342,19 @@ def test_unwrap_connect(self, mock_dbapi): self.assertIsInstance(connection, mock.Mock) def test_instrument_connection(self): - connection = MockConnectionEmpty() + connection = mock.Mock() # Avoid get_attributes failing because can't concatenate mock - # pylint: disable=attribute-defined-outside-init connection.database = "-" connection2 = dbapi.instrument_connection(self.tracer, connection, "-") - self.assertIs(connection2._connection, connection) + self.assertIs(connection2.__wrapped__, connection) def test_uninstrument_connection(self): - connection = MockConnectionEmpty() + connection = mock.Mock() # Set connection.database to avoid a failure because mock can't # be concatenated - # pylint: disable=attribute-defined-outside-init connection.database = "-" connection2 = dbapi.instrument_connection(self.tracer, connection, "-") - self.assertIs(connection2._connection, connection) + self.assertIs(connection2.__wrapped__, connection) connection3 = dbapi.uninstrument_connection(connection2) self.assertIs(connection3, connection) @@ -372,12 +370,10 @@ def mock_connect(*args, **kwargs): server_host = kwargs.get("server_host") server_port = kwargs.get("server_port") user = kwargs.get("user") - return MockConnectionWithAttributes( - database, server_port, server_host, user - ) + return MockConnection(database, server_port, server_host, user) -class MockConnectionWithAttributes: +class MockConnection: def __init__(self, database, server_port, server_host, user): self.database = database self.server_port = server_port @@ -410,7 +406,3 @@ def executemany(self, query, params=None, throw_exception=False): def callproc(self, query, params=None, throw_exception=False): if throw_exception: raise Exception("Test Exception") - - -class MockConnectionEmpty: - pass diff --git a/instrumentation/opentelemetry-instrumentation-mysql/tests/test_mysql_integration.py b/instrumentation/opentelemetry-instrumentation-mysql/tests/test_mysql_integration.py index 8274851ff1..3614febffd 100644 --- a/instrumentation/opentelemetry-instrumentation-mysql/tests/test_mysql_integration.py +++ b/instrumentation/opentelemetry-instrumentation-mysql/tests/test_mysql_integration.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock, patch +from unittest import mock import mysql.connector @@ -23,15 +23,6 @@ from opentelemetry.test.test_base import TestBase -def mock_connect(*args, **kwargs): - class MockConnection: - def cursor(self): - # pylint: disable=no-self-use - return Mock() - - return MockConnection() - - def connect_and_execute_query(): cnx = mysql.connector.connect(database="test") cursor = cnx.cursor() @@ -47,9 +38,9 @@ def tearDown(self): with self.disable_logging(): MySQLInstrumentor().uninstrument() - @patch("mysql.connector.connect", new=mock_connect) + @mock.patch("mysql.connector.connect") # pylint: disable=unused-argument - def test_instrumentor(self): + def test_instrumentor(self, mock_connect): MySQLInstrumentor().instrument() connect_and_execute_query() @@ -71,8 +62,9 @@ def test_instrumentor(self): spans_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans_list), 1) - @patch("mysql.connector.connect", new=mock_connect) - def test_custom_tracer_provider(self): + @mock.patch("mysql.connector.connect") + # pylint: disable=unused-argument + def test_custom_tracer_provider(self, mock_connect): resource = resources.Resource.create({}) result = self.create_tracer_provider(resource=resource) tracer_provider, exporter = result @@ -86,9 +78,9 @@ def test_custom_tracer_provider(self): self.assertIs(span.resource, resource) - @patch("mysql.connector.connect", new=mock_connect) + @mock.patch("mysql.connector.connect") # pylint: disable=unused-argument - def test_instrument_connection(self): + def test_instrument_connection(self, mock_connect): cnx, query = connect_and_execute_query() spans_list = self.memory_exporter.get_finished_spans() @@ -101,8 +93,8 @@ def test_instrument_connection(self): spans_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans_list), 1) - @patch("mysql.connector.connect", new=mock_connect) - def test_instrument_connection_no_op_tracer_provider(self): + @mock.patch("mysql.connector.connect") + def test_instrument_connection_no_op_tracer_provider(self, mock_connect): tracer_provider = trace_api.NoOpTracerProvider() MySQLInstrumentor().instrument(tracer_provider=tracer_provider) connect_and_execute_query() @@ -110,9 +102,9 @@ def test_instrument_connection_no_op_tracer_provider(self): spans_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans_list), 0) - @patch("mysql.connector.connect", new=mock_connect) + @mock.patch("mysql.connector.connect") # pylint: disable=unused-argument - def test_uninstrument_connection(self): + def test_uninstrument_connection(self, mock_connect): MySQLInstrumentor().instrument() cnx, query = connect_and_execute_query() diff --git a/instrumentation/opentelemetry-instrumentation-pymysql/tests/test_pymysql_integration.py b/instrumentation/opentelemetry-instrumentation-pymysql/tests/test_pymysql_integration.py index 42dd94f2da..587ebc1b53 100644 --- a/instrumentation/opentelemetry-instrumentation-pymysql/tests/test_pymysql_integration.py +++ b/instrumentation/opentelemetry-instrumentation-pymysql/tests/test_pymysql_integration.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock, patch +from unittest import mock import pymysql @@ -22,24 +22,15 @@ from opentelemetry.test.test_base import TestBase -def mock_connect(*args, **kwargs): - class MockConnection: - def cursor(self): - # pylint: disable=no-self-use - return Mock() - - return MockConnection() - - class TestPyMysqlIntegration(TestBase): def tearDown(self): super().tearDown() with self.disable_logging(): PyMySQLInstrumentor().uninstrument() - @patch("pymysql.connect", new=mock_connect) + @mock.patch("pymysql.connect") # pylint: disable=unused-argument - def test_instrumentor(self): + def test_instrumentor(self, mock_connect): PyMySQLInstrumentor().instrument() cnx = pymysql.connect(database="test") @@ -67,9 +58,9 @@ def test_instrumentor(self): spans_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans_list), 1) - @patch("pymysql.connect", new=mock_connect) + @mock.patch("pymysql.connect") # pylint: disable=unused-argument - def test_custom_tracer_provider(self): + def test_custom_tracer_provider(self, mock_connect): resource = resources.Resource.create({}) result = self.create_tracer_provider(resource=resource) tracer_provider, exporter = result @@ -87,9 +78,9 @@ def test_custom_tracer_provider(self): self.assertIs(span.resource, resource) - @patch("pymysql.connect", new=mock_connect) + @mock.patch("pymysql.connect") # pylint: disable=unused-argument - def test_instrument_connection(self): + def test_instrument_connection(self, mock_connect): cnx = pymysql.connect(database="test") query = "SELECT * FROM test" cursor = cnx.cursor() @@ -105,9 +96,9 @@ def test_instrument_connection(self): spans_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans_list), 1) - @patch("pymysql.connect", new=mock_connect) + @mock.patch("pymysql.connect") # pylint: disable=unused-argument - def test_uninstrument_connection(self): + def test_uninstrument_connection(self, mock_connect): PyMySQLInstrumentor().instrument() cnx = pymysql.connect(database="test") query = "SELECT * FROM test"