diff --git a/sqlx-core/src/sqlite/statement/handle.rs b/sqlx-core/src/sqlite/statement/handle.rs index cbdeba030e..8b5b9fc772 100644 --- a/sqlx-core/src/sqlite/statement/handle.rs +++ b/sqlx-core/src/sqlite/statement/handle.rs @@ -1,21 +1,25 @@ +use either::Either; use std::ffi::c_void; use std::ffi::CStr; +use std::hint; use std::os::raw::{c_char, c_int}; use std::ptr; use std::ptr::NonNull; use std::slice::from_raw_parts; use std::str::{from_utf8, from_utf8_unchecked}; +use std::sync::atomic::{AtomicU8, Ordering}; use libsqlite3_sys::{ sqlite3, sqlite3_bind_blob64, sqlite3_bind_double, sqlite3_bind_int, sqlite3_bind_int64, sqlite3_bind_null, sqlite3_bind_parameter_count, sqlite3_bind_parameter_name, - sqlite3_bind_text64, sqlite3_changes, sqlite3_column_blob, sqlite3_column_bytes, - sqlite3_column_count, sqlite3_column_database_name, sqlite3_column_decltype, - sqlite3_column_double, sqlite3_column_int, sqlite3_column_int64, sqlite3_column_name, - sqlite3_column_origin_name, sqlite3_column_table_name, sqlite3_column_type, - sqlite3_column_value, sqlite3_db_handle, sqlite3_finalize, sqlite3_reset, sqlite3_sql, - sqlite3_stmt, sqlite3_stmt_readonly, sqlite3_table_column_metadata, sqlite3_value, - SQLITE_MISUSE, SQLITE_OK, SQLITE_TRANSIENT, SQLITE_UTF8, + sqlite3_bind_text64, sqlite3_changes, sqlite3_clear_bindings, sqlite3_column_blob, + sqlite3_column_bytes, sqlite3_column_count, sqlite3_column_database_name, + sqlite3_column_decltype, sqlite3_column_double, sqlite3_column_int, sqlite3_column_int64, + sqlite3_column_name, sqlite3_column_origin_name, sqlite3_column_table_name, + sqlite3_column_type, sqlite3_column_value, sqlite3_db_handle, sqlite3_finalize, sqlite3_reset, + sqlite3_sql, sqlite3_step, sqlite3_stmt, sqlite3_stmt_readonly, sqlite3_table_column_metadata, + sqlite3_value, SQLITE_DONE, SQLITE_MISUSE, SQLITE_OK, SQLITE_ROW, SQLITE_TRANSIENT, + SQLITE_UTF8, }; use crate::error::{BoxDynError, Error}; @@ -23,7 +27,7 @@ use crate::sqlite::type_info::DataType; use crate::sqlite::{SqliteError, SqliteTypeInfo}; #[derive(Debug)] -pub(crate) struct StatementHandle(pub(super) NonNull); +pub(crate) struct StatementHandle(NonNull, Lock); // access to SQLite3 statement handles are safe to send and share between threads // as long as the `sqlite3_step` call is serialized. @@ -32,6 +36,10 @@ unsafe impl Send for StatementHandle {} unsafe impl Sync for StatementHandle {} impl StatementHandle { + pub(super) fn new(ptr: NonNull) -> Self { + Self(ptr, Lock::new()) + } + #[inline] pub(super) unsafe fn db_handle(&self) -> *mut sqlite3 { // O(c) access to the connection handle for this statement handle @@ -280,8 +288,37 @@ impl StatementHandle { Ok(from_utf8(self.column_blob(index))?) } + pub(crate) fn step(&self) -> Result, Error> { + self.1.enter_step(); + + let status = unsafe { sqlite3_step(self.0.as_ptr()) }; + let result = match status { + SQLITE_ROW => Ok(Either::Right(())), + SQLITE_DONE => Ok(Either::Left(self.changes())), + _ => Err(self.last_error().into()), + }; + + if self.1.exit_step() { + unsafe { sqlite3_reset(self.0.as_ptr()) }; + self.1.exit_reset(); + } + + result + } + pub(crate) fn reset(&self) { + if !self.1.enter_reset() { + // reset or step already in progress + return; + } + unsafe { sqlite3_reset(self.0.as_ptr()) }; + + self.1.exit_reset(); + } + + pub(crate) fn clear_bindings(&self) { + unsafe { sqlite3_clear_bindings(self.0.as_ptr()) }; } } impl Drop for StatementHandle { @@ -301,3 +338,44 @@ impl Drop for StatementHandle { } } } + +const RESET: u8 = 0b0000_0001; +const STEP: u8 = 0b0000_0010; + +// Lock to synchronize calls to `step` and `reset`. +#[derive(Debug)] +struct Lock(AtomicU8); + +impl Lock { + fn new() -> Self { + Self(AtomicU8::new(0)) + } + + // If this returns `true` reset can be performed, otherwise reset must be delayed until the + // current step finishes and `exit_step` is called. + fn enter_reset(&self) -> bool { + self.0.fetch_or(RESET, Ordering::Acquire) == 0 + } + + fn exit_reset(&self) { + self.0.fetch_and(!RESET, Ordering::Release); + } + + fn enter_step(&self) { + // NOTE: spin loop should be fine here as we are only waiting for a `reset` to finish which + // should be quick. + while self + .0 + .compare_exchange(0, STEP, Ordering::Acquire, Ordering::Relaxed) + .is_err() + { + hint::spin_loop(); + } + } + + // If this returns `true` it means a previous attempt to reset was delayed and must be + // performed now (followed by `exit_reset`). + fn exit_step(&self) -> bool { + self.0.fetch_and(!STEP, Ordering::Release) & RESET != 0 + } +} diff --git a/sqlx-core/src/sqlite/statement/virtual.rs b/sqlx-core/src/sqlite/statement/virtual.rs index 89dac81376..85141337b5 100644 --- a/sqlx-core/src/sqlite/statement/virtual.rs +++ b/sqlx-core/src/sqlite/statement/virtual.rs @@ -8,8 +8,7 @@ use crate::sqlite::{SqliteColumn, SqliteError, SqliteRow, SqliteValue}; use crate::HashMap; use bytes::{Buf, Bytes}; use libsqlite3_sys::{ - sqlite3, sqlite3_clear_bindings, sqlite3_prepare_v3, sqlite3_reset, sqlite3_stmt, SQLITE_OK, - SQLITE_PREPARE_PERSISTENT, + sqlite3, sqlite3_prepare_v3, sqlite3_stmt, SQLITE_OK, SQLITE_PREPARE_PERSISTENT, }; use smallvec::SmallVec; use std::i32; @@ -92,7 +91,7 @@ fn prepare( query.advance(n); if let Some(handle) = NonNull::new(statement_handle) { - return Ok(Some(StatementHandle(handle))); + return Ok(Some(StatementHandle::new(handle))); } } @@ -183,13 +182,11 @@ impl VirtualStatement { for (i, handle) in self.handles.iter().enumerate() { SqliteRow::inflate_if_needed(&handle, &self.columns[i], self.last_row_values[i].take()); - unsafe { - // Reset A Prepared Statement Object - // https://www.sqlite.org/c3ref/reset.html - // https://www.sqlite.org/c3ref/clear_bindings.html - sqlite3_reset(handle.0.as_ptr()); - sqlite3_clear_bindings(handle.0.as_ptr()); - } + // Reset A Prepared Statement Object + // https://www.sqlite.org/c3ref/reset.html + // https://www.sqlite.org/c3ref/clear_bindings.html + handle.reset(); + handle.clear_bindings(); } } } diff --git a/sqlx-core/src/sqlite/statement/worker.rs b/sqlx-core/src/sqlite/statement/worker.rs index 1503952133..60e44a1115 100644 --- a/sqlx-core/src/sqlite/statement/worker.rs +++ b/sqlx-core/src/sqlite/statement/worker.rs @@ -3,7 +3,6 @@ use crate::sqlite::statement::StatementHandle; use crossbeam_channel::{unbounded, Sender}; use either::Either; use futures_channel::oneshot; -use libsqlite3_sys::{sqlite3_step, SQLITE_DONE, SQLITE_ROW}; use std::sync::{Arc, Weak}; use std::thread; @@ -33,14 +32,7 @@ impl StatementWorker { match cmd { StatementWorkerCommand::Step { statement, tx } => { let resp = if let Some(statement) = statement.upgrade() { - let status = unsafe { sqlite3_step(statement.0.as_ptr()) }; - - let resp = match status { - SQLITE_ROW => Ok(Either::Right(())), - SQLITE_DONE => Ok(Either::Left(statement.changes())), - _ => Err(statement.last_error().into()), - }; - resp + statement.step() } else { // Statement is already finalized. Err(Error::WorkerCrashed)