From 608d00c3ccff5fa1f04bba7c4da74893f4cdf856 Mon Sep 17 00:00:00 2001 From: Nisheeth Barthwal Date: Fri, 3 Mar 2023 13:19:23 +0100 Subject: [PATCH] rebase main --- sqlx-sqlite/src/connection/mod.rs | 57 ++++++++++++++++++++++++++++++- tests/sqlite/sqlite.rs | 16 +++++++++ 2 files changed, 72 insertions(+), 1 deletion(-) diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index 2ea1b66ed9..ce85adbd15 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -1,13 +1,14 @@ use futures_core::future::BoxFuture; use futures_intrusive::sync::MutexGuard; use futures_util::future; -use libsqlite3_sys::sqlite3; +use libsqlite3_sys::{sqlite3, sqlite3_progress_handler}; use sqlx_core::common::StatementCache; use sqlx_core::error::Error; use sqlx_core::transaction::Transaction; use std::cmp::Ordering; use std::fmt::{self, Debug, Formatter}; use std::ptr::NonNull; +use std::os::raw::{c_int, c_void}; use crate::connection::establish::EstablishParams; use crate::connection::worker::ConnectionWorker; @@ -89,6 +90,45 @@ impl SqliteConnection { Ok(LockedSqliteHandle { guard }) } + + /// Sets a progress handler that is invoked periodically during long running calls. If the progress callback + /// returns `false`, then the operation is interrupted. + /// + /// `num_ops` is the approximate number of [virtual machine instructions](https://www.sqlite.org/opcode.html) + /// that are evaluated between successive invocations of the callback. If `num_ops` is less than one then the + /// progress handler is disabled. + /// + /// Only a single progress handler may be defined at one time per database connection; setting a new progress + /// handler cancels the old one. + /// + /// The progress handler callback must not do anything that will modify the database connection that invoked + /// the progress handler. Note that sqlite3_prepare_v2() and sqlite3_step() both modify their database connections + /// in this context. + pub async fn set_progress_handler(&mut self, num_ops: i32, callback: F) + where + F: FnMut() -> bool + Send + 'static, + { + unsafe { + let callback = Box::new(callback); + if let Ok(mut lock_conn) = self.lock_handle().await { + sqlite3_progress_handler( + lock_conn.as_raw_handle().as_mut(), + num_ops, + Some(progress_callback::), + &*callback as *const F as *mut F as *mut _, + ); + } + } + } + + /// Removes a previously set progress handler on a database connection. + pub async fn remove_progress_handler(&mut self) { + unsafe { + if let Ok(mut lock_conn) = self.lock_handle().await { + sqlite3_progress_handler(lock_conn.as_raw_handle().as_mut(), 0, None, 0 as *mut _); + } + } + } } impl Debug for SqliteConnection { @@ -172,6 +212,21 @@ impl Connection for SqliteConnection { } } +/// Implements a C binding to a progress callback. The function returns `0` if the +/// user-provided callback returns `true`, and `1` otherwise to signal an interrupt. +extern "C" fn progress_callback(callback: *mut c_void) -> c_int +where + F: FnMut() -> bool, +{ + unsafe { + if (*(callback as *mut F))() { + 0 + } else { + 1 + } + } +} + impl LockedSqliteHandle<'_> { /// Returns the underlying sqlite3* connection handle. /// diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 03f2013eb4..fd2b715ca8 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -725,3 +725,19 @@ async fn concurrent_read_and_write() { read.await; write.await; } + +#[sqlx_macros::test] +async fn test_query_with_progress_handler() -> anyhow::Result<()> { + let mut conn = new::().await?; + conn.set_progress_handler(1, || false).await; + + match sqlx::query("SELECT 'hello' AS title") + .fetch_all(&mut conn) + .await + { + Err(sqlx::Error::Database(err)) => assert_eq!(err.message(), String::from("interrupted")), + _ => panic!("expected an interrupt"), + } + + Ok(()) +}