Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patch CursorWrapper dynamically to allow multiple base classes. #1820

Merged
merged 4 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions debug_toolbar/panels/sql/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from time import perf_counter

import django.test.testcases
from django.db.backends.utils import CursorWrapper
from django.utils.encoding import force_str

from debug_toolbar.utils import get_stack_trace, get_template_info
Expand Down Expand Up @@ -60,34 +59,45 @@ def cursor(*args, **kwargs):
cursor = connection._djdt_cursor(*args, **kwargs)
if logger is None:
return cursor
wrapper = NormalCursorWrapper if allow_sql.get() else ExceptionCursorWrapper
return wrapper(cursor.cursor, connection, logger)
mixin = NormalCursorWrapper if allow_sql.get() else ExceptionCursorWrapper
return patch_cursor_wrapper_with_mixin(cursor.__class__, mixin)(
cursor.cursor, connection, logger
)

def chunked_cursor(*args, **kwargs):
# prevent double wrapping
# solves https://github.com/jazzband/django-debug-toolbar/issues/1239
logger = connection._djdt_logger
cursor = connection._djdt_chunked_cursor(*args, **kwargs)
if logger is not None and not isinstance(cursor, DjDTCursorWrapper):
if logger is not None and not isinstance(cursor, DjDTCursorWrapperMixin):
if allow_sql.get():
wrapper = NormalCursorWrapper
mixin = NormalCursorWrapper
else:
wrapper = ExceptionCursorWrapper
return wrapper(cursor.cursor, connection, logger)
mixin = ExceptionCursorWrapper
return patch_cursor_wrapper_with_mixin(cursor.__class__, mixin)(
cursor.cursor, connection, logger
)
return cursor

connection.cursor = cursor
connection.chunked_cursor = chunked_cursor


class DjDTCursorWrapper(CursorWrapper):
def patch_cursor_wrapper_with_mixin(base_wrapper, mixin):
class DjDTCursorWrapper(mixin, base_wrapper):
pass

return DjDTCursorWrapper


class DjDTCursorWrapperMixin:
def __init__(self, cursor, db, logger):
super().__init__(cursor, db)
# logger must implement a ``record`` method
self.logger = logger


class ExceptionCursorWrapper(DjDTCursorWrapper):
class ExceptionCursorWrapper(DjDTCursorWrapperMixin):
"""
Wraps a cursor and raises an exception on any operation.
Used in Templates panel.
Expand All @@ -97,7 +107,7 @@ def __getattr__(self, attr):
raise SQLQueryTriggered()


class NormalCursorWrapper(DjDTCursorWrapper):
class NormalCursorWrapper(DjDTCursorWrapperMixin):
"""
Wraps a cursor and logs queries.
"""
Expand Down
2 changes: 2 additions & 0 deletions docs/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ Pending
resolving to the wrong content type.
* Fixed SQL statement recording under PostgreSQL for queries encoded as byte
strings.
* Patch the ``CursorWrapper`` class with a mixin class to support multiple
base wrapper classes.

4.1.0 (2023-05-15)
------------------
Expand Down
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ memcache
memcached
middleware
middlewares
mixin
mousedown
mouseup
multi
Expand Down
48 changes: 30 additions & 18 deletions tests/panels/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import datetime
import os
import unittest
from unittest.mock import patch
from unittest.mock import call, patch

import django
from asgiref.sync import sync_to_async
from django.contrib.auth.models import User
from django.db import connection, transaction
from django.db.backends.utils import CursorWrapper
from django.db.models import Count
from django.db.utils import DatabaseError
from django.shortcuts import render
Expand Down Expand Up @@ -68,39 +69,44 @@ def test_recording_chunked_cursor(self):
self.assertEqual(len(self.panel._queries), 1)

@patch(
"debug_toolbar.panels.sql.tracking.NormalCursorWrapper",
wraps=sql_tracking.NormalCursorWrapper,
"debug_toolbar.panels.sql.tracking.patch_cursor_wrapper_with_mixin",
wraps=sql_tracking.patch_cursor_wrapper_with_mixin,
)
def test_cursor_wrapper_singleton(self, mock_wrapper):
def test_cursor_wrapper_singleton(self, mock_patch_cursor_wrapper):
sql_call()

# ensure that cursor wrapping is applied only once
self.assertEqual(mock_wrapper.call_count, 1)
mock_patch_cursor_wrapper.assert_called_once_with(
CursorWrapper, sql_tracking.NormalCursorWrapper
)

@patch(
"debug_toolbar.panels.sql.tracking.NormalCursorWrapper",
wraps=sql_tracking.NormalCursorWrapper,
"debug_toolbar.panels.sql.tracking.patch_cursor_wrapper_with_mixin",
wraps=sql_tracking.patch_cursor_wrapper_with_mixin,
)
def test_chunked_cursor_wrapper_singleton(self, mock_wrapper):
def test_chunked_cursor_wrapper_singleton(self, mock_patch_cursor_wrapper):
sql_call(use_iterator=True)

# ensure that cursor wrapping is applied only once
self.assertEqual(mock_wrapper.call_count, 1)
mock_patch_cursor_wrapper.assert_called_once_with(
CursorWrapper, sql_tracking.NormalCursorWrapper
)

@patch(
"debug_toolbar.panels.sql.tracking.NormalCursorWrapper",
wraps=sql_tracking.NormalCursorWrapper,
"debug_toolbar.panels.sql.tracking.patch_cursor_wrapper_with_mixin",
wraps=sql_tracking.patch_cursor_wrapper_with_mixin,
)
async def test_cursor_wrapper_async(self, mock_wrapper):
async def test_cursor_wrapper_async(self, mock_patch_cursor_wrapper):
await sync_to_async(sql_call)()

self.assertEqual(mock_wrapper.call_count, 1)
mock_patch_cursor_wrapper.assert_called_once_with(
CursorWrapper, sql_tracking.NormalCursorWrapper
)

@patch(
"debug_toolbar.panels.sql.tracking.NormalCursorWrapper",
wraps=sql_tracking.NormalCursorWrapper,
"debug_toolbar.panels.sql.tracking.patch_cursor_wrapper_with_mixin",
wraps=sql_tracking.patch_cursor_wrapper_with_mixin,
)
async def test_cursor_wrapper_asyncio_ctx(self, mock_wrapper):
async def test_cursor_wrapper_asyncio_ctx(self, mock_patch_cursor_wrapper):
self.assertTrue(sql_tracking.allow_sql.get())
await sync_to_async(sql_call)()

Expand All @@ -116,7 +122,13 @@ async def task():
await asyncio.create_task(task())
# Because it was called in another context, it should not have affected ours
self.assertTrue(sql_tracking.allow_sql.get())
self.assertEqual(mock_wrapper.call_count, 1)
self.assertEqual(
mock_patch_cursor_wrapper.call_args_list,
[
call(CursorWrapper, sql_tracking.NormalCursorWrapper),
call(CursorWrapper, sql_tracking.ExceptionCursorWrapper),
],
)

def test_generate_server_timing(self):
self.assertEqual(len(self.panel._queries), 0)
Expand Down