Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add set_update_hook on SqliteConnection #3260

Merged
merged 3 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sqlx-sqlite/src/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ impl EstablishParams {
transaction_depth: 0,
log_settings: self.log_settings.clone(),
progress_handler_callback: None,
update_hook_callback: None
})
}
}
97 changes: 95 additions & 2 deletions sqlx-sqlite/src/connection/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::cmp::Ordering;
use std::ffi::CStr;
use std::fmt::Write;
use std::fmt::{self, Debug, Formatter};
use std::os::raw::{c_int, c_void};
Expand All @@ -8,7 +9,10 @@ use std::ptr::NonNull;
use futures_core::future::BoxFuture;
use futures_intrusive::sync::MutexGuard;
use futures_util::future;
use libsqlite3_sys::{sqlite3, sqlite3_progress_handler};
use libsqlite3_sys::{
sqlite3, sqlite3_progress_handler, sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT,
SQLITE_UPDATE,
};

pub(crate) use handle::ConnectionHandle;
use sqlx_core::common::StatementCache;
Expand Down Expand Up @@ -58,6 +62,34 @@ pub struct LockedSqliteHandle<'a> {
pub(crate) struct Handler(NonNull<dyn FnMut() -> bool + Send + 'static>);
unsafe impl Send for Handler {}

#[derive(Debug, PartialEq, Eq)]
pub enum SqliteOperation {
Insert,
Update,
Delete,
Unknown(i32),
}

impl From<i32> for SqliteOperation {
fn from(value: i32) -> Self {
match value {
SQLITE_INSERT => SqliteOperation::Insert,
SQLITE_UPDATE => SqliteOperation::Update,
SQLITE_DELETE => SqliteOperation::Delete,
code => SqliteOperation::Unknown(code),
}
}
}

pub struct UpdateHookResult<'a> {
pub operation: SqliteOperation,
pub database: &'a str,
pub table: &'a str,
pub rowid: i64,
}
pub(crate) struct UpdateHookHandler(NonNull<dyn FnMut(UpdateHookResult) + Send + 'static>);
unsafe impl Send for UpdateHookHandler {}

pub(crate) struct ConnectionState {
pub(crate) handle: ConnectionHandle,

Expand All @@ -71,14 +103,25 @@ pub(crate) struct ConnectionState {
/// Stores the progress handler set on the current connection. If the handler returns `false`,
/// the query is interrupted.
progress_handler_callback: Option<Handler>,

update_hook_callback: Option<UpdateHookHandler>,
}

impl ConnectionState {
/// Drops the `progress_handler_callback` if it exists.
pub(crate) fn remove_progress_handler(&mut self) {
if let Some(mut handler) = self.progress_handler_callback.take() {
unsafe {
sqlite3_progress_handler(self.handle.as_ptr(), 0, None, 0 as *mut _);
sqlite3_progress_handler(self.handle.as_ptr(), 0, None, std::ptr::null_mut());
let _ = { Box::from_raw(handler.0.as_mut()) };
}
}
}

pub(crate) fn remove_update_hook(&mut self) {
if let Some(mut handler) = self.update_hook_callback.take() {
unsafe {
sqlite3_update_hook(self.handle.as_ptr(), None, std::ptr::null_mut());
let _ = { Box::from_raw(handler.0.as_mut()) };
}
}
Expand Down Expand Up @@ -215,6 +258,31 @@ where
}
}

extern "C" fn update_hook<F>(
callback: *mut c_void,
op_code: c_int,
database: *const i8,
table: *const i8,
rowid: i64,
) where
F: FnMut(UpdateHookResult),
{
unsafe {
let _ = catch_unwind(|| {
let callback: *mut F = callback.cast::<F>();
let operation: SqliteOperation = op_code.into();
let database = CStr::from_ptr(database).to_str().unwrap_or_default();
let table = CStr::from_ptr(table).to_str().unwrap_or_default();
(*callback)(UpdateHookResult {
operation,
database,
table,
rowid,
})
});
}
}

impl LockedSqliteHandle<'_> {
/// Returns the underlying sqlite3* connection handle.
///
Expand Down Expand Up @@ -279,17 +347,42 @@ impl LockedSqliteHandle<'_> {
}
}

pub fn set_update_hook<F>(&mut self, callback: F)
where
F: FnMut(UpdateHookResult) + Send + 'static,
{
unsafe {
let callback_boxed = Box::new(callback);
// SAFETY: `Box::into_raw()` always returns a non-null pointer.
let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
let handler = callback.as_ptr() as *mut _;
self.guard.remove_update_hook();
self.guard.update_hook_callback = Some(UpdateHookHandler(callback));

sqlite3_update_hook(
self.as_raw_handle().as_mut(),
Some(update_hook::<F>),
handler,
);
}
}

/// Removes the progress handler on a database connection. The method does nothing if no handler was set.
pub fn remove_progress_handler(&mut self) {
self.guard.remove_progress_handler();
}

pub fn remove_update_hook(&mut self) {
self.guard.remove_update_hook();
}
}

impl Drop for ConnectionState {
fn drop(&mut self) {
// explicitly drop statements before the connection handle is dropped
self.statements.clear();
self.remove_progress_handler();
self.remove_update_hook();
}
}

Expand Down
2 changes: 1 addition & 1 deletion sqlx-sqlite/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use std::sync::atomic::AtomicBool;

pub use arguments::{SqliteArgumentValue, SqliteArguments};
pub use column::SqliteColumn;
pub use connection::{LockedSqliteHandle, SqliteConnection};
pub use connection::{LockedSqliteHandle, SqliteConnection, SqliteOperation, UpdateHookResult};
pub use database::Sqlite;
pub use error::SqliteError;
pub use options::{
Expand Down
56 changes: 55 additions & 1 deletion tests/sqlite/sqlite.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use futures::TryStreamExt;
use rand::{Rng, SeedableRng};
use rand_xoshiro::Xoshiro256PlusPlus;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
use sqlx::sqlite::{SqliteConnectOptions, SqliteOperation, SqlitePoolOptions};
use sqlx::{
query, sqlite::Sqlite, sqlite::SqliteRow, Column, ConnectOptions, Connection, Executor, Row,
SqliteConnection, SqlitePool, Statement, TypeInfo,
Expand Down Expand Up @@ -794,3 +794,57 @@ async fn test_multiple_set_progress_handler_calls_drop_old_handler() -> anyhow::
assert_eq!(1, Arc::strong_count(&ref_counted_object));
Ok(())
}

#[sqlx_macros::test]
async fn test_query_with_update_hook() -> anyhow::Result<()> {
let mut conn = new::<Sqlite>().await?;

// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
let state = format!("test");
conn.lock_handle().await?.set_update_hook(move |result| {
assert_eq!(state, "test");
assert_eq!(result.operation, SqliteOperation::Insert);
assert_eq!(result.database, "main");
assert_eq!(result.table, "tweet");
assert_eq!(result.rowid, 3);
});

let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 3, 'Hello, World' )")
.execute(&mut conn)
.await?;

Ok(())
}

#[sqlx_macros::test]
async fn test_multiple_set_update_hook_calls_drop_old_handler() -> anyhow::Result<()> {
let ref_counted_object = Arc::new(0);
assert_eq!(1, Arc::strong_count(&ref_counted_object));

{
let mut conn = new::<Sqlite>().await?;

let o = ref_counted_object.clone();
conn.lock_handle().await?.set_update_hook(move |_| {
println!("{o:?}");
});
assert_eq!(2, Arc::strong_count(&ref_counted_object));

let o = ref_counted_object.clone();
conn.lock_handle().await?.set_update_hook(move |_| {
println!("{o:?}");
});
assert_eq!(2, Arc::strong_count(&ref_counted_object));

let o = ref_counted_object.clone();
conn.lock_handle().await?.set_update_hook(move |_| {
println!("{o:?}");
});
assert_eq!(2, Arc::strong_count(&ref_counted_object));

conn.lock_handle().await?.remove_update_hook();
}

assert_eq!(1, Arc::strong_count(&ref_counted_object));
Ok(())
}
Loading