Skip to content

Commit

Permalink
bpo-45138: Expand traced SQL statements in sqlite3 trace callback (G…
Browse files Browse the repository at this point in the history
  • Loading branch information
Erlend Egeberg Aasland committed Mar 9, 2022
1 parent b33a1ae commit d177751
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 15 deletions.
6 changes: 6 additions & 0 deletions Doc/library/sqlite3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,9 @@ Connection Objects

Passing :const:`None` as *trace_callback* will disable the trace callback.

For SQLite 3.14.0 and newer, bound parameters are expanded in the passed
statement string.

.. note::
Exceptions raised in the trace callback are not propagated. As a
development and debugging aid, use
Expand All @@ -568,6 +571,9 @@ Connection Objects

.. versionadded:: 3.3

.. versionchanged:: 3.11
Added support for expanded SQL statements.


.. method:: enable_load_extension(enabled)

Expand Down
4 changes: 4 additions & 0 deletions Doc/whatsnew/3.11.rst
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,10 @@ sqlite3
Instead we leave it to the SQLite library to handle these cases.
(Contributed by Erlend E. Aasland in :issue:`44092`.)

* For SQLite 3.14.0 and newer, bound parameters are expanded in the statement
string passed to the trace callback. See :meth:`~sqlite3.Connection.set_trace_callback`.
(Contributed by Erlend E. Aasland in :issue:`45138`.)


sys
---
Expand Down
61 changes: 60 additions & 1 deletion Lib/test/test_sqlite3/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,16 @@
# misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.

import unittest
import contextlib
import sqlite3 as sqlite
import unittest

from test.support.os_helper import TESTFN, unlink

from test.test_sqlite3.test_dbapi import memory_database, cx_limit
from test.test_sqlite3.test_userfunctions import with_tracebacks


class CollationTests(unittest.TestCase):
def test_create_collation_not_string(self):
con = sqlite.connect(":memory:")
Expand Down Expand Up @@ -224,6 +228,16 @@ def bad_progress():


class TraceCallbackTests(unittest.TestCase):
@contextlib.contextmanager
def check_stmt_trace(self, cx, expected):
try:
traced = []
cx.set_trace_callback(lambda stmt: traced.append(stmt))
yield
finally:
self.assertEqual(traced, expected)
cx.set_trace_callback(None)

def test_trace_callback_used(self):
"""
Test that the trace callback is invoked once it is set.
Expand Down Expand Up @@ -289,6 +303,51 @@ def trace(statement):
con2.close()
self.assertEqual(traced_statements, queries)

@unittest.skipIf(sqlite.sqlite_version_info < (3, 14, 0),
"Requires SQLite 3.14.0 or newer")
def test_trace_expanded_sql(self):
expected = [
"create table t(t)",
"BEGIN ",
"insert into t values(0)",
"insert into t values(1)",
"insert into t values(2)",
"COMMIT",
]
with memory_database() as cx, self.check_stmt_trace(cx, expected):
with cx:
cx.execute("create table t(t)")
cx.executemany("insert into t values(?)", ((v,) for v in range(3)))

@with_tracebacks(
sqlite.DataError,
regex="Expanded SQL string exceeds the maximum string length"
)
def test_trace_too_much_expanded_sql(self):
# If the expanded string is too large, we'll fall back to the
# unexpanded SQL statement. The resulting string length is limited by
# SQLITE_LIMIT_LENGTH.
template = "select 'b' as \"a\" from sqlite_master where \"a\"="
category = sqlite.SQLITE_LIMIT_LENGTH
with memory_database() as cx, cx_limit(cx, category=category) as lim:
nextra = lim - (len(template) + 2) - 1
ok_param = "a" * nextra
bad_param = "a" * (nextra + 1)

unexpanded_query = template + "?"
with self.check_stmt_trace(cx, [unexpanded_query]):
cx.execute(unexpanded_query, (bad_param,))

expanded_query = f"{template}'{ok_param}'"
with self.check_stmt_trace(cx, [expanded_query]):
cx.execute(unexpanded_query, (ok_param,))

@with_tracebacks(ZeroDivisionError, regex="division by zero")
def test_trace_bad_handler(self):
with memory_database() as cx:
cx.set_trace_callback(lambda stmt: 5/0)
cx.execute("select 1")


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
For SQLite 3.14.0 and newer, bound parameters are expanded in the statement
string passed to the :mod:`sqlite3` trace callback. Patch by Erlend E.
Aasland.
49 changes: 35 additions & 14 deletions Modules/_sqlite/connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -1079,11 +1079,10 @@ progress_callback(void *ctx)
* to ensure future compatibility.
*/
static int
trace_callback(unsigned int type, void *ctx, void *prepared_statement,
void *statement_string)
trace_callback(unsigned int type, void *ctx, void *stmt, void *sql)
#else
static void
trace_callback(void *ctx, const char *statement_string)
trace_callback(void *ctx, const char *sql)
#endif
{
#ifdef HAVE_TRACE_V2
Expand All @@ -1094,24 +1093,46 @@ trace_callback(void *ctx, const char *statement_string)

PyGILState_STATE gilstate = PyGILState_Ensure();

PyObject *py_statement = NULL;
PyObject *ret = NULL;
py_statement = PyUnicode_DecodeUTF8(statement_string,
strlen(statement_string), "replace");
assert(ctx != NULL);
PyObject *py_statement = NULL;
#ifdef HAVE_TRACE_V2
assert(stmt != NULL);
const char *expanded_sql = sqlite3_expanded_sql((sqlite3_stmt *)stmt);
if (expanded_sql == NULL) {
sqlite3 *db = sqlite3_db_handle((sqlite3_stmt *)stmt);
if (sqlite3_errcode(db) == SQLITE_NOMEM) {
(void)PyErr_NoMemory();
goto exit;
}

pysqlite_state *state = ((callback_context *)ctx)->state;
assert(state != NULL);
PyErr_SetString(state->DataError,
"Expanded SQL string exceeds the maximum string "
"length");
print_or_clear_traceback((callback_context *)ctx);

// Fall back to unexpanded sql
py_statement = PyUnicode_FromString((const char *)sql);
}
else {
py_statement = PyUnicode_FromString(expanded_sql);
sqlite3_free((void *)expanded_sql);
}
#else
py_statement = PyUnicode_FromString(sql);
#endif
if (py_statement) {
PyObject *callable = ((callback_context *)ctx)->callable;
ret = PyObject_CallOneArg(callable, py_statement);
PyObject *ret = PyObject_CallOneArg(callable, py_statement);
Py_DECREF(py_statement);
Py_XDECREF(ret);
}

if (ret) {
Py_DECREF(ret);
}
else {
print_or_clear_traceback(ctx);
exit:
if (PyErr_Occurred()) {
print_or_clear_traceback((callback_context *)ctx);
}

PyGILState_Release(gilstate);
#ifdef HAVE_TRACE_V2
return 0;
Expand Down

0 comments on commit d177751

Please sign in to comment.