Skip to content

Commit

Permalink
feat: Add set_update_hook on SqliteConnection (launchbadge#3260)
Browse files Browse the repository at this point in the history
* feat: Add set_update_hook on SqliteConnection

* refactor: Address PR comments

* fix: Expose UpdateHookResult for public use

---------

Co-authored-by: John Smith <asserta4@gmail.com>
  • Loading branch information
gridbox and John Smith authored Jun 6, 2024
1 parent 8b7f352 commit 0ea9088
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 4 deletions.
1 change: 1 addition & 0 deletions sqlx-sqlite/src/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ impl EstablishParams {
transaction_depth: 0,
log_settings: self.log_settings.clone(),
progress_handler_callback: None,
update_hook_callback: None
})
}
}
97 changes: 95 additions & 2 deletions sqlx-sqlite/src/connection/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::cmp::Ordering;
use std::ffi::CStr;
use std::fmt::Write;
use std::fmt::{self, Debug, Formatter};
use std::os::raw::{c_int, c_void};
Expand All @@ -8,7 +9,10 @@ use std::ptr::NonNull;
use futures_core::future::BoxFuture;
use futures_intrusive::sync::MutexGuard;
use futures_util::future;
use libsqlite3_sys::{sqlite3, sqlite3_progress_handler};
use libsqlite3_sys::{
sqlite3, sqlite3_progress_handler, sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT,
SQLITE_UPDATE,
};

pub(crate) use handle::ConnectionHandle;
use sqlx_core::common::StatementCache;
Expand Down Expand Up @@ -58,6 +62,34 @@ pub struct LockedSqliteHandle<'a> {
pub(crate) struct Handler(NonNull<dyn FnMut() -> bool + Send + 'static>);
unsafe impl Send for Handler {}

#[derive(Debug, PartialEq, Eq)]
pub enum SqliteOperation {
Insert,
Update,
Delete,
Unknown(i32),
}

impl From<i32> for SqliteOperation {
fn from(value: i32) -> Self {
match value {
SQLITE_INSERT => SqliteOperation::Insert,
SQLITE_UPDATE => SqliteOperation::Update,
SQLITE_DELETE => SqliteOperation::Delete,
code => SqliteOperation::Unknown(code),
}
}
}

pub struct UpdateHookResult<'a> {
pub operation: SqliteOperation,
pub database: &'a str,
pub table: &'a str,
pub rowid: i64,
}
pub(crate) struct UpdateHookHandler(NonNull<dyn FnMut(UpdateHookResult) + Send + 'static>);
unsafe impl Send for UpdateHookHandler {}

pub(crate) struct ConnectionState {
pub(crate) handle: ConnectionHandle,

Expand All @@ -71,14 +103,25 @@ pub(crate) struct ConnectionState {
/// Stores the progress handler set on the current connection. If the handler returns `false`,
/// the query is interrupted.
progress_handler_callback: Option<Handler>,

update_hook_callback: Option<UpdateHookHandler>,
}

impl ConnectionState {
/// Drops the `progress_handler_callback` if it exists.
pub(crate) fn remove_progress_handler(&mut self) {
if let Some(mut handler) = self.progress_handler_callback.take() {
unsafe {
sqlite3_progress_handler(self.handle.as_ptr(), 0, None, 0 as *mut _);
sqlite3_progress_handler(self.handle.as_ptr(), 0, None, std::ptr::null_mut());
let _ = { Box::from_raw(handler.0.as_mut()) };
}
}
}

pub(crate) fn remove_update_hook(&mut self) {
if let Some(mut handler) = self.update_hook_callback.take() {
unsafe {
sqlite3_update_hook(self.handle.as_ptr(), None, std::ptr::null_mut());
let _ = { Box::from_raw(handler.0.as_mut()) };
}
}
Expand Down Expand Up @@ -215,6 +258,31 @@ where
}
}

extern "C" fn update_hook<F>(
callback: *mut c_void,
op_code: c_int,
database: *const i8,
table: *const i8,
rowid: i64,
) where
F: FnMut(UpdateHookResult),
{
unsafe {
let _ = catch_unwind(|| {
let callback: *mut F = callback.cast::<F>();
let operation: SqliteOperation = op_code.into();
let database = CStr::from_ptr(database).to_str().unwrap_or_default();
let table = CStr::from_ptr(table).to_str().unwrap_or_default();
(*callback)(UpdateHookResult {
operation,
database,
table,
rowid,
})
});
}
}

impl LockedSqliteHandle<'_> {
/// Returns the underlying sqlite3* connection handle.
///
Expand Down Expand Up @@ -279,17 +347,42 @@ impl LockedSqliteHandle<'_> {
}
}

pub fn set_update_hook<F>(&mut self, callback: F)
where
F: FnMut(UpdateHookResult) + Send + 'static,
{
unsafe {
let callback_boxed = Box::new(callback);
// SAFETY: `Box::into_raw()` always returns a non-null pointer.
let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
let handler = callback.as_ptr() as *mut _;
self.guard.remove_update_hook();
self.guard.update_hook_callback = Some(UpdateHookHandler(callback));

sqlite3_update_hook(
self.as_raw_handle().as_mut(),
Some(update_hook::<F>),
handler,
);
}
}

/// Removes the progress handler on a database connection. The method does nothing if no handler was set.
pub fn remove_progress_handler(&mut self) {
self.guard.remove_progress_handler();
}

pub fn remove_update_hook(&mut self) {
self.guard.remove_update_hook();
}
}

impl Drop for ConnectionState {
fn drop(&mut self) {
// explicitly drop statements before the connection handle is dropped
self.statements.clear();
self.remove_progress_handler();
self.remove_update_hook();
}
}

Expand Down
2 changes: 1 addition & 1 deletion sqlx-sqlite/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use std::sync::atomic::AtomicBool;

pub use arguments::{SqliteArgumentValue, SqliteArguments};
pub use column::SqliteColumn;
pub use connection::{LockedSqliteHandle, SqliteConnection};
pub use connection::{LockedSqliteHandle, SqliteConnection, SqliteOperation, UpdateHookResult};
pub use database::Sqlite;
pub use error::SqliteError;
pub use options::{
Expand Down
56 changes: 55 additions & 1 deletion tests/sqlite/sqlite.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use futures::TryStreamExt;
use rand::{Rng, SeedableRng};
use rand_xoshiro::Xoshiro256PlusPlus;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
use sqlx::sqlite::{SqliteConnectOptions, SqliteOperation, SqlitePoolOptions};
use sqlx::{
query, sqlite::Sqlite, sqlite::SqliteRow, Column, ConnectOptions, Connection, Executor, Row,
SqliteConnection, SqlitePool, Statement, TypeInfo,
Expand Down Expand Up @@ -794,3 +794,57 @@ async fn test_multiple_set_progress_handler_calls_drop_old_handler() -> anyhow::
assert_eq!(1, Arc::strong_count(&ref_counted_object));
Ok(())
}

#[sqlx_macros::test]
async fn test_query_with_update_hook() -> anyhow::Result<()> {
let mut conn = new::<Sqlite>().await?;

// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
let state = format!("test");
conn.lock_handle().await?.set_update_hook(move |result| {
assert_eq!(state, "test");
assert_eq!(result.operation, SqliteOperation::Insert);
assert_eq!(result.database, "main");
assert_eq!(result.table, "tweet");
assert_eq!(result.rowid, 3);
});

let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 3, 'Hello, World' )")
.execute(&mut conn)
.await?;

Ok(())
}

#[sqlx_macros::test]
async fn test_multiple_set_update_hook_calls_drop_old_handler() -> anyhow::Result<()> {
let ref_counted_object = Arc::new(0);
assert_eq!(1, Arc::strong_count(&ref_counted_object));

{
let mut conn = new::<Sqlite>().await?;

let o = ref_counted_object.clone();
conn.lock_handle().await?.set_update_hook(move |_| {
println!("{o:?}");
});
assert_eq!(2, Arc::strong_count(&ref_counted_object));

let o = ref_counted_object.clone();
conn.lock_handle().await?.set_update_hook(move |_| {
println!("{o:?}");
});
assert_eq!(2, Arc::strong_count(&ref_counted_object));

let o = ref_counted_object.clone();
conn.lock_handle().await?.set_update_hook(move |_| {
println!("{o:?}");
});
assert_eq!(2, Arc::strong_count(&ref_counted_object));

conn.lock_handle().await?.remove_update_hook();
}

assert_eq!(1, Arc::strong_count(&ref_counted_object));
Ok(())
}

0 comments on commit 0ea9088

Please sign in to comment.