diff --git a/debug_toolbar/panels/sql/tracking.py b/debug_toolbar/panels/sql/tracking.py index 14b2cb7ab..0c53dc2c5 100644 --- a/debug_toolbar/panels/sql/tracking.py +++ b/debug_toolbar/panels/sql/tracking.py @@ -5,7 +5,6 @@ from time import perf_counter import django.test.testcases -from django.db.backends.utils import CursorWrapper from django.utils.encoding import force_str from debug_toolbar.utils import get_stack_trace, get_template_info @@ -60,34 +59,42 @@ def cursor(*args, **kwargs): cursor = connection._djdt_cursor(*args, **kwargs) if logger is None: return cursor - wrapper = NormalCursorWrapper if allow_sql.get() else ExceptionCursorWrapper - return wrapper(cursor.cursor, connection, logger) + mixin = NormalCursorMixin if allow_sql.get() else ExceptionCursorMixin + return patch_cursor_wrapper_with_mixin(cursor.__class__, mixin)( + cursor.cursor, connection, logger + ) def chunked_cursor(*args, **kwargs): # prevent double wrapping # solves https://github.com/jazzband/django-debug-toolbar/issues/1239 logger = connection._djdt_logger cursor = connection._djdt_chunked_cursor(*args, **kwargs) - if logger is not None and not isinstance(cursor, DjDTCursorWrapper): - if allow_sql.get(): - wrapper = NormalCursorWrapper - else: - wrapper = ExceptionCursorWrapper - return wrapper(cursor.cursor, connection, logger) + if logger is not None and not isinstance(cursor, DjDTCursorWrapperMixin): + mixin = NormalCursorMixin if allow_sql.get() else ExceptionCursorMixin + return patch_cursor_wrapper_with_mixin(cursor.__class__, mixin)( + cursor.cursor, connection, logger + ) return cursor connection.cursor = cursor connection.chunked_cursor = chunked_cursor -class DjDTCursorWrapper(CursorWrapper): +def patch_cursor_wrapper_with_mixin(base_wrapper, mixin): + class DjDTCursorWrapper(mixin, base_wrapper): + pass + + return DjDTCursorWrapper + + +class DjDTCursorWrapperMixin: def __init__(self, cursor, db, logger): super().__init__(cursor, db) # logger must implement a ``record`` method self.logger = logger -class ExceptionCursorWrapper(DjDTCursorWrapper): +class ExceptionCursorMixin(DjDTCursorWrapperMixin): """ Wraps a cursor and raises an exception on any operation. Used in Templates panel. @@ -97,7 +104,7 @@ def __getattr__(self, attr): raise SQLQueryTriggered() -class NormalCursorWrapper(DjDTCursorWrapper): +class NormalCursorMixin(DjDTCursorWrapperMixin): """ Wraps a cursor and logs queries. """ diff --git a/docs/changes.rst b/docs/changes.rst index e06d4c615..5b7aaf1c1 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -15,6 +15,8 @@ Pending resolving to the wrong content type. * Fixed SQL statement recording under PostgreSQL for queries encoded as byte strings. +* Patch the ``CursorWrapper`` class with a mixin class to support multiple + base wrapper classes. 4.1.0 (2023-05-15) ------------------ diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 2ab01758c..7a15d9aeb 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -30,6 +30,7 @@ memcache memcached middleware middlewares +mixin mousedown mouseup multi diff --git a/tests/panels/test_sql.py b/tests/panels/test_sql.py index d6b31ca2b..932a0dd92 100644 --- a/tests/panels/test_sql.py +++ b/tests/panels/test_sql.py @@ -2,12 +2,13 @@ import datetime import os import unittest -from unittest.mock import patch +from unittest.mock import call, patch import django from asgiref.sync import sync_to_async from django.contrib.auth.models import User from django.db import connection, transaction +from django.db.backends.utils import CursorDebugWrapper, CursorWrapper from django.db.models import Count from django.db.utils import DatabaseError from django.shortcuts import render @@ -68,39 +69,59 @@ def test_recording_chunked_cursor(self): self.assertEqual(len(self.panel._queries), 1) @patch( - "debug_toolbar.panels.sql.tracking.NormalCursorWrapper", - wraps=sql_tracking.NormalCursorWrapper, + "debug_toolbar.panels.sql.tracking.patch_cursor_wrapper_with_mixin", + wraps=sql_tracking.patch_cursor_wrapper_with_mixin, ) - def test_cursor_wrapper_singleton(self, mock_wrapper): + def test_cursor_wrapper_singleton(self, mock_patch_cursor_wrapper): sql_call() - # ensure that cursor wrapping is applied only once - self.assertEqual(mock_wrapper.call_count, 1) + self.assertIn( + mock_patch_cursor_wrapper.mock_calls, + [ + [call(CursorWrapper, sql_tracking.NormalCursorMixin)], + # CursorDebugWrapper is used if the test is called with `--debug-sql` + [call(CursorDebugWrapper, sql_tracking.NormalCursorMixin)], + ], + ) @patch( - "debug_toolbar.panels.sql.tracking.NormalCursorWrapper", - wraps=sql_tracking.NormalCursorWrapper, + "debug_toolbar.panels.sql.tracking.patch_cursor_wrapper_with_mixin", + wraps=sql_tracking.patch_cursor_wrapper_with_mixin, ) - def test_chunked_cursor_wrapper_singleton(self, mock_wrapper): + def test_chunked_cursor_wrapper_singleton(self, mock_patch_cursor_wrapper): sql_call(use_iterator=True) # ensure that cursor wrapping is applied only once - self.assertEqual(mock_wrapper.call_count, 1) + self.assertIn( + mock_patch_cursor_wrapper.mock_calls, + [ + [call(CursorWrapper, sql_tracking.NormalCursorMixin)], + # CursorDebugWrapper is used if the test is called with `--debug-sql` + [call(CursorDebugWrapper, sql_tracking.NormalCursorMixin)], + ], + ) @patch( - "debug_toolbar.panels.sql.tracking.NormalCursorWrapper", - wraps=sql_tracking.NormalCursorWrapper, + "debug_toolbar.panels.sql.tracking.patch_cursor_wrapper_with_mixin", + wraps=sql_tracking.patch_cursor_wrapper_with_mixin, ) - async def test_cursor_wrapper_async(self, mock_wrapper): + async def test_cursor_wrapper_async(self, mock_patch_cursor_wrapper): await sync_to_async(sql_call)() - self.assertEqual(mock_wrapper.call_count, 1) + self.assertIn( + mock_patch_cursor_wrapper.mock_calls, + [ + [call(CursorWrapper, sql_tracking.NormalCursorMixin)], + # CursorDebugWrapper is used if the test is called with `--debug-sql` + [call(CursorDebugWrapper, sql_tracking.NormalCursorMixin)], + ], + ) @patch( - "debug_toolbar.panels.sql.tracking.NormalCursorWrapper", - wraps=sql_tracking.NormalCursorWrapper, + "debug_toolbar.panels.sql.tracking.patch_cursor_wrapper_with_mixin", + wraps=sql_tracking.patch_cursor_wrapper_with_mixin, ) - async def test_cursor_wrapper_asyncio_ctx(self, mock_wrapper): + async def test_cursor_wrapper_asyncio_ctx(self, mock_patch_cursor_wrapper): self.assertTrue(sql_tracking.allow_sql.get()) await sync_to_async(sql_call)() @@ -116,7 +137,21 @@ async def task(): await asyncio.create_task(task()) # Because it was called in another context, it should not have affected ours self.assertTrue(sql_tracking.allow_sql.get()) - self.assertEqual(mock_wrapper.call_count, 1) + + self.assertIn( + mock_patch_cursor_wrapper.mock_calls, + [ + [ + call(CursorWrapper, sql_tracking.NormalCursorMixin), + call(CursorWrapper, sql_tracking.ExceptionCursorMixin), + ], + # CursorDebugWrapper is used if the test is called with `--debug-sql` + [ + call(CursorDebugWrapper, sql_tracking.NormalCursorMixin), + call(CursorDebugWrapper, sql_tracking.ExceptionCursorMixin), + ], + ], + ) def test_generate_server_timing(self): self.assertEqual(len(self.panel._queries), 0)