From 608d00c3ccff5fa1f04bba7c4da74893f4cdf856 Mon Sep 17 00:00:00 2001 From: Nisheeth Barthwal Date: Fri, 3 Mar 2023 13:19:23 +0100 Subject: [PATCH 1/6] rebase main --- sqlx-sqlite/src/connection/mod.rs | 57 ++++++++++++++++++++++++++++++- tests/sqlite/sqlite.rs | 16 +++++++++ 2 files changed, 72 insertions(+), 1 deletion(-) diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index 2ea1b66ed9..ce85adbd15 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -1,13 +1,14 @@ use futures_core::future::BoxFuture; use futures_intrusive::sync::MutexGuard; use futures_util::future; -use libsqlite3_sys::sqlite3; +use libsqlite3_sys::{sqlite3, sqlite3_progress_handler}; use sqlx_core::common::StatementCache; use sqlx_core::error::Error; use sqlx_core::transaction::Transaction; use std::cmp::Ordering; use std::fmt::{self, Debug, Formatter}; use std::ptr::NonNull; +use std::os::raw::{c_int, c_void}; use crate::connection::establish::EstablishParams; use crate::connection::worker::ConnectionWorker; @@ -89,6 +90,45 @@ impl SqliteConnection { Ok(LockedSqliteHandle { guard }) } + + /// Sets a progress handler that is invoked periodically during long running calls. If the progress callback + /// returns `false`, then the operation is interrupted. + /// + /// `num_ops` is the approximate number of [virtual machine instructions](https://www.sqlite.org/opcode.html) + /// that are evaluated between successive invocations of the callback. If `num_ops` is less than one then the + /// progress handler is disabled. + /// + /// Only a single progress handler may be defined at one time per database connection; setting a new progress + /// handler cancels the old one. + /// + /// The progress handler callback must not do anything that will modify the database connection that invoked + /// the progress handler. Note that sqlite3_prepare_v2() and sqlite3_step() both modify their database connections + /// in this context. + pub async fn set_progress_handler(&mut self, num_ops: i32, callback: F) + where + F: FnMut() -> bool + Send + 'static, + { + unsafe { + let callback = Box::new(callback); + if let Ok(mut lock_conn) = self.lock_handle().await { + sqlite3_progress_handler( + lock_conn.as_raw_handle().as_mut(), + num_ops, + Some(progress_callback::), + &*callback as *const F as *mut F as *mut _, + ); + } + } + } + + /// Removes a previously set progress handler on a database connection. + pub async fn remove_progress_handler(&mut self) { + unsafe { + if let Ok(mut lock_conn) = self.lock_handle().await { + sqlite3_progress_handler(lock_conn.as_raw_handle().as_mut(), 0, None, 0 as *mut _); + } + } + } } impl Debug for SqliteConnection { @@ -172,6 +212,21 @@ impl Connection for SqliteConnection { } } +/// Implements a C binding to a progress callback. The function returns `0` if the +/// user-provided callback returns `true`, and `1` otherwise to signal an interrupt. +extern "C" fn progress_callback(callback: *mut c_void) -> c_int +where + F: FnMut() -> bool, +{ + unsafe { + if (*(callback as *mut F))() { + 0 + } else { + 1 + } + } +} + impl LockedSqliteHandle<'_> { /// Returns the underlying sqlite3* connection handle. /// diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 03f2013eb4..fd2b715ca8 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -725,3 +725,19 @@ async fn concurrent_read_and_write() { read.await; write.await; } + +#[sqlx_macros::test] +async fn test_query_with_progress_handler() -> anyhow::Result<()> { + let mut conn = new::().await?; + conn.set_progress_handler(1, || false).await; + + match sqlx::query("SELECT 'hello' AS title") + .fetch_all(&mut conn) + .await + { + Err(sqlx::Error::Database(err)) => assert_eq!(err.message(), String::from("interrupted")), + _ => panic!("expected an interrupt"), + } + + Ok(()) +} From 866131af1f11ff208494d2f0e3fb0020a4da685e Mon Sep 17 00:00:00 2001 From: Nisheeth Barthwal Date: Sat, 4 Mar 2023 13:31:50 +0100 Subject: [PATCH 2/6] fmt --- sqlx-sqlite/src/connection/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index ce85adbd15..e19432a232 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -7,8 +7,8 @@ use sqlx_core::error::Error; use sqlx_core::transaction::Transaction; use std::cmp::Ordering; use std::fmt::{self, Debug, Formatter}; -use std::ptr::NonNull; use std::os::raw::{c_int, c_void}; +use std::ptr::NonNull; use crate::connection::establish::EstablishParams; use crate::connection::worker::ConnectionWorker; From 765f820db21c730a7c0d07d8fedd4906451fabba Mon Sep 17 00:00:00 2001 From: Nisheeth Barthwal Date: Wed, 15 Mar 2023 14:25:11 +0100 Subject: [PATCH 3/6] use NonNull to fix UB --- sqlx-sqlite/src/connection/establish.rs | 1 + sqlx-sqlite/src/connection/mod.rs | 109 ++++++++++++++---------- tests/sqlite/sqlite.rs | 5 +- 3 files changed, 70 insertions(+), 45 deletions(-) diff --git a/sqlx-sqlite/src/connection/establish.rs b/sqlx-sqlite/src/connection/establish.rs index c5425dd19b..91a3aff05e 100644 --- a/sqlx-sqlite/src/connection/establish.rs +++ b/sqlx-sqlite/src/connection/establish.rs @@ -282,6 +282,7 @@ impl EstablishParams { statements: Statements::new(self.statement_cache_capacity), transaction_depth: 0, log_settings: self.log_settings.clone(), + progress_handler_callback: None, }) } } diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index e19432a232..9ec833bf83 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -8,6 +8,7 @@ use sqlx_core::transaction::Transaction; use std::cmp::Ordering; use std::fmt::{self, Debug, Formatter}; use std::os::raw::{c_int, c_void}; +use std::panic::catch_unwind; use std::ptr::NonNull; use crate::connection::establish::EstablishParams; @@ -52,6 +53,11 @@ pub struct LockedSqliteHandle<'a> { pub(crate) guard: MutexGuard<'a, ConnectionState>, } +/// Represents a callback handler that will be shared with the underlying sqlite3 connection. +pub(crate) struct Handler(NonNull bool>); +unsafe impl Send for Handler {} +unsafe impl Sync for Handler {} + pub(crate) struct ConnectionState { pub(crate) handle: ConnectionHandle, @@ -61,6 +67,19 @@ pub(crate) struct ConnectionState { pub(crate) statements: Statements, log_settings: LogSettings, + + /// Stores the progress handler set on the current connection. If the handler returns `false`, + /// the query is interrupted. + progress_handler_callback: Option, +} + +impl ConnectionState { + /// Drops the `progress_handler_callback` if it exists. + pub(crate) fn drop_progress_handler_callback(&mut self) { + if let Some(mut handler) = self.progress_handler_callback.take() { + let _ = unsafe { Box::from_raw(handler.0.as_mut()) }; + } + } } pub(crate) struct Statements { @@ -90,45 +109,6 @@ impl SqliteConnection { Ok(LockedSqliteHandle { guard }) } - - /// Sets a progress handler that is invoked periodically during long running calls. If the progress callback - /// returns `false`, then the operation is interrupted. - /// - /// `num_ops` is the approximate number of [virtual machine instructions](https://www.sqlite.org/opcode.html) - /// that are evaluated between successive invocations of the callback. If `num_ops` is less than one then the - /// progress handler is disabled. - /// - /// Only a single progress handler may be defined at one time per database connection; setting a new progress - /// handler cancels the old one. - /// - /// The progress handler callback must not do anything that will modify the database connection that invoked - /// the progress handler. Note that sqlite3_prepare_v2() and sqlite3_step() both modify their database connections - /// in this context. - pub async fn set_progress_handler(&mut self, num_ops: i32, callback: F) - where - F: FnMut() -> bool + Send + 'static, - { - unsafe { - let callback = Box::new(callback); - if let Ok(mut lock_conn) = self.lock_handle().await { - sqlite3_progress_handler( - lock_conn.as_raw_handle().as_mut(), - num_ops, - Some(progress_callback::), - &*callback as *const F as *mut F as *mut _, - ); - } - } - } - - /// Removes a previously set progress handler on a database connection. - pub async fn remove_progress_handler(&mut self) { - unsafe { - if let Ok(mut lock_conn) = self.lock_handle().await { - sqlite3_progress_handler(lock_conn.as_raw_handle().as_mut(), 0, None, 0 as *mut _); - } - } - } } impl Debug for SqliteConnection { @@ -219,11 +199,11 @@ where F: FnMut() -> bool, { unsafe { - if (*(callback as *mut F))() { - 0 - } else { - 1 - } + let r = catch_unwind(|| { + let callback: *mut F = callback.cast::(); + (*callback)() + }); + c_int::from(!r.unwrap_or_default()) } } @@ -256,12 +236,53 @@ impl LockedSqliteHandle<'_> { ) -> Result<(), Error> { collation::create_collation(&mut self.guard.handle, name, compare) } + + /// Sets a progress handler that is invoked periodically during long running calls. If the progress callback + /// returns `false`, then the operation is interrupted. + /// + /// `num_ops` is the approximate number of [virtual machine instructions](https://www.sqlite.org/opcode.html) + /// that are evaluated between successive invocations of the callback. If `num_ops` is less than one then the + /// progress handler is disabled. + /// + /// Only a single progress handler may be defined at one time per database connection; setting a new progress + /// handler cancels the old one. + /// + /// The progress handler callback must not do anything that will modify the database connection that invoked + /// the progress handler. Note that sqlite3_prepare_v2() and sqlite3_step() both modify their database connections + /// in this context. + pub async fn set_progress_handler(&mut self, num_ops: i32, mut callback: F) + where + F: FnMut() -> bool + Send + 'static, + { + unsafe { + let callback = NonNull::new_unchecked(&mut callback as *mut _); + let handler = callback.as_ptr() as *mut _; + self.guard.drop_progress_handler_callback(); + self.guard.progress_handler_callback = Some(Handler(callback)); + + sqlite3_progress_handler( + self.as_raw_handle().as_mut(), + num_ops, + Some(progress_callback::), + handler, + ); + } + } + + /// Removes a previously set progress handler on a database connection. + pub async fn remove_progress_handler(&mut self) { + unsafe { + sqlite3_progress_handler(self.as_raw_handle().as_mut(), 0, None, 0 as *mut _); + self.guard.drop_progress_handler_callback(); + } + } } impl Drop for ConnectionState { fn drop(&mut self) { // explicitly drop statements before the connection handle is dropped self.statements.clear(); + self.drop_progress_handler_callback(); } } diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index fd2b715ca8..0f84293e8b 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -729,7 +729,10 @@ async fn concurrent_read_and_write() { #[sqlx_macros::test] async fn test_query_with_progress_handler() -> anyhow::Result<()> { let mut conn = new::().await?; - conn.set_progress_handler(1, || false).await; + conn.lock_handle() + .await? + .set_progress_handler(1, || false) + .await; match sqlx::query("SELECT 'hello' AS title") .fetch_all(&mut conn) From 1b5382457c540fae4d5632d61c9f511129841f94 Mon Sep 17 00:00:00 2001 From: Nisheeth Barthwal Date: Thu, 16 Mar 2023 09:38:42 +0100 Subject: [PATCH 4/6] apply code suggestions --- sqlx-sqlite/src/connection/mod.rs | 28 ++++++++++++---------------- tests/sqlite/sqlite.rs | 11 +++++++---- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index 9ec833bf83..cb4609a4a2 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -54,9 +54,8 @@ pub struct LockedSqliteHandle<'a> { } /// Represents a callback handler that will be shared with the underlying sqlite3 connection. -pub(crate) struct Handler(NonNull bool>); +pub(crate) struct Handler(NonNull bool + Send>); unsafe impl Send for Handler {} -unsafe impl Sync for Handler {} pub(crate) struct ConnectionState { pub(crate) handle: ConnectionHandle, @@ -75,9 +74,12 @@ pub(crate) struct ConnectionState { impl ConnectionState { /// Drops the `progress_handler_callback` if it exists. - pub(crate) fn drop_progress_handler_callback(&mut self) { + pub(crate) fn remove_progress_handler(&mut self) { if let Some(mut handler) = self.progress_handler_callback.take() { - let _ = unsafe { Box::from_raw(handler.0.as_mut()) }; + unsafe { + sqlite3_progress_handler(self.handle.as_ptr(), 0, None, 0 as *mut _); + let _ = { Box::from_raw(handler.0.as_mut()) }; + } } } } @@ -250,14 +252,16 @@ impl LockedSqliteHandle<'_> { /// The progress handler callback must not do anything that will modify the database connection that invoked /// the progress handler. Note that sqlite3_prepare_v2() and sqlite3_step() both modify their database connections /// in this context. - pub async fn set_progress_handler(&mut self, num_ops: i32, mut callback: F) + pub fn set_progress_handler(&mut self, num_ops: i32, mut callback: F) where F: FnMut() -> bool + Send + 'static, { unsafe { - let callback = NonNull::new_unchecked(&mut callback as *mut _); + let callback_boxed = Box::new(callback); + // SAFETY: `Box::into_raw()` always returns a non-null pointer. + let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed)); let handler = callback.as_ptr() as *mut _; - self.guard.drop_progress_handler_callback(); + self.guard.remove_progress_handler(); self.guard.progress_handler_callback = Some(Handler(callback)); sqlite3_progress_handler( @@ -268,21 +272,13 @@ impl LockedSqliteHandle<'_> { ); } } - - /// Removes a previously set progress handler on a database connection. - pub async fn remove_progress_handler(&mut self) { - unsafe { - sqlite3_progress_handler(self.as_raw_handle().as_mut(), 0, None, 0 as *mut _); - self.guard.drop_progress_handler_callback(); - } - } } impl Drop for ConnectionState { fn drop(&mut self) { // explicitly drop statements before the connection handle is dropped self.statements.clear(); - self.drop_progress_handler_callback(); + self.remove_progress_handler(); } } diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 0f84293e8b..b3ee896ff8 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -729,10 +729,13 @@ async fn concurrent_read_and_write() { #[sqlx_macros::test] async fn test_query_with_progress_handler() -> anyhow::Result<()> { let mut conn = new::().await?; - conn.lock_handle() - .await? - .set_progress_handler(1, || false) - .await; + + // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. + let state = format!("test"); + conn.lock_handle().await?.set_progress_handler(1, move || { + assert_eq!(state, "test"); + false + }); match sqlx::query("SELECT 'hello' AS title") .fetch_all(&mut conn) From 9501f73baaeb36c711f4c4e3e0e78e258be312e2 Mon Sep 17 00:00:00 2001 From: Nisheeth Barthwal Date: Thu, 16 Mar 2023 10:10:25 +0100 Subject: [PATCH 5/6] add test for multiple handler drops --- sqlx-sqlite/src/connection/mod.rs | 5 +++ tests/sqlite/sqlite.rs | 62 +++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+) diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index cb4609a4a2..66253f19db 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -272,6 +272,11 @@ impl LockedSqliteHandle<'_> { ); } } + + /// Removes the progress handler on a database connection. The method does nothing if no handler was set. + pub fn remove_progress_handler(&mut self) { + self.guard.remove_progress_handler(); + } } impl Drop for ConnectionState { diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index b3ee896ff8..c3a9beb91a 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -1,3 +1,6 @@ +#![feature(unboxed_closures)] +#![feature(fn_traits)] + use futures::TryStreamExt; use rand::{Rng, SeedableRng}; use rand_xoshiro::Xoshiro256PlusPlus; @@ -7,6 +10,7 @@ use sqlx::{ SqliteConnection, SqlitePool, Statement, TypeInfo, }; use sqlx_test::new; +use std::sync::atomic::{AtomicUsize, Ordering}; #[sqlx_macros::test] async fn it_connects() -> anyhow::Result<()> { @@ -747,3 +751,61 @@ async fn test_query_with_progress_handler() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn test_multiple_set_progress_handler_calls_drop_old_handler() -> anyhow::Result<()> { + static OBJECTS_DROPPED: AtomicUsize = AtomicUsize::new(0); + + struct Handler(pub &'static str); + impl FnOnce<()> for Handler { + type Output = bool; + + extern "rust-call" fn call_once(mut self, args: ()) -> bool { + self.call_mut(args) + } + } + impl FnMut<()> for Handler { + extern "rust-call" fn call_mut(&mut self, _args: ()) -> bool { + assert_eq!(3, self.0.len()); + false + } + } + impl Drop for Handler { + fn drop(&mut self) { + OBJECTS_DROPPED.fetch_add(1, Ordering::Relaxed); + } + } + + { + let mut conn = new::().await?; + + conn.lock_handle() + .await? + .set_progress_handler(1, Handler("foo")); + conn.lock_handle() + .await? + .set_progress_handler(1, Handler("bar")); + conn.lock_handle() + .await? + .set_progress_handler(1, Handler("baz")); + + match sqlx::query("SELECT 'hello' AS title") + .fetch_all(&mut conn) + .await + { + Err(sqlx::Error::Database(err)) => { + assert_eq!(err.message(), String::from("interrupted")) + } + _ => panic!("expected an interrupt"), + } + + conn.lock_handle().await?.remove_progress_handler(); + } + + assert_eq!( + 3, + OBJECTS_DROPPED.load(Ordering::Relaxed), + "expected all handlers to be dropped" + ); + Ok(()) +} From 459a9f339aba237491bf8513881ad4aa4fa3d66c Mon Sep 17 00:00:00 2001 From: Nisheeth Barthwal Date: Wed, 22 Mar 2023 12:21:29 +0100 Subject: [PATCH 6/6] remove nightly features for test --- sqlx-sqlite/src/connection/mod.rs | 2 +- tests/sqlite/sqlite.rs | 61 ++++++++++++------------------- 2 files changed, 24 insertions(+), 39 deletions(-) diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index 66253f19db..36aa345431 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -54,7 +54,7 @@ pub struct LockedSqliteHandle<'a> { } /// Represents a callback handler that will be shared with the underlying sqlite3 connection. -pub(crate) struct Handler(NonNull bool + Send>); +pub(crate) struct Handler(NonNull bool + Send + 'static>); unsafe impl Send for Handler {} pub(crate) struct ConnectionState { diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index c3a9beb91a..0c79bec1f3 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -1,6 +1,3 @@ -#![feature(unboxed_closures)] -#![feature(fn_traits)] - use futures::TryStreamExt; use rand::{Rng, SeedableRng}; use rand_xoshiro::Xoshiro256PlusPlus; @@ -10,7 +7,7 @@ use sqlx::{ SqliteConnection, SqlitePool, Statement, TypeInfo, }; use sqlx_test::new; -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; #[sqlx_macros::test] async fn it_connects() -> anyhow::Result<()> { @@ -754,40 +751,32 @@ async fn test_query_with_progress_handler() -> anyhow::Result<()> { #[sqlx_macros::test] async fn test_multiple_set_progress_handler_calls_drop_old_handler() -> anyhow::Result<()> { - static OBJECTS_DROPPED: AtomicUsize = AtomicUsize::new(0); + let ref_counted_object = Arc::new(0); + assert_eq!(1, Arc::strong_count(&ref_counted_object)); - struct Handler(pub &'static str); - impl FnOnce<()> for Handler { - type Output = bool; + { + let mut conn = new::().await?; - extern "rust-call" fn call_once(mut self, args: ()) -> bool { - self.call_mut(args) - } - } - impl FnMut<()> for Handler { - extern "rust-call" fn call_mut(&mut self, _args: ()) -> bool { - assert_eq!(3, self.0.len()); + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_progress_handler(1, move || { + println!("{:?}", o); false - } - } - impl Drop for Handler { - fn drop(&mut self) { - OBJECTS_DROPPED.fetch_add(1, Ordering::Relaxed); - } - } + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); - { - let mut conn = new::().await?; + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_progress_handler(1, move || { + println!("{:?}", o); + false + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); - conn.lock_handle() - .await? - .set_progress_handler(1, Handler("foo")); - conn.lock_handle() - .await? - .set_progress_handler(1, Handler("bar")); - conn.lock_handle() - .await? - .set_progress_handler(1, Handler("baz")); + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_progress_handler(1, move || { + println!("{:?}", o); + false + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); match sqlx::query("SELECT 'hello' AS title") .fetch_all(&mut conn) @@ -802,10 +791,6 @@ async fn test_multiple_set_progress_handler_calls_drop_old_handler() -> anyhow:: conn.lock_handle().await?.remove_progress_handler(); } - assert_eq!( - 3, - OBJECTS_DROPPED.load(Ordering::Relaxed), - "expected all handlers to be dropped" - ); + assert_eq!(1, Arc::strong_count(&ref_counted_object)); Ok(()) }