Skip to content

Commit

Permalink
fix(sqlite): run sqlite3_reset() in StatementWorker
Browse files Browse the repository at this point in the history
this avoids possible race conditions without using a mutex
  • Loading branch information
abonander committed Jul 28, 2021
1 parent 0562928 commit 768d9b8
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 97 deletions.
34 changes: 20 additions & 14 deletions sqlx-core/src/sqlite/connection/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::error::Error;
use crate::executor::{Execute, Executor};
use crate::logger::QueryLogger;
use crate::sqlite::connection::describe::describe;
use crate::sqlite::statement::{StatementHandle, VirtualStatement};
use crate::sqlite::statement::{StatementHandle, StatementWorker, VirtualStatement};
use crate::sqlite::{
Sqlite, SqliteArguments, SqliteConnection, SqliteQueryResult, SqliteRow, SqliteStatement,
SqliteTypeInfo,
Expand All @@ -16,7 +16,8 @@ use libsqlite3_sys::sqlite3_last_insert_rowid;
use std::borrow::Cow;
use std::sync::Arc;

fn prepare<'a>(
async fn prepare<'a>(
worker: &mut StatementWorker,
statements: &'a mut StatementCache<VirtualStatement>,
statement: &'a mut Option<VirtualStatement>,
query: &str,
Expand All @@ -39,7 +40,7 @@ fn prepare<'a>(
if exists {
// as this statement has been executed before, we reset before continuing
// this also causes any rows that are from the statement to be inflated
statement.reset();
statement.reset(worker).await?;
}

Ok(statement)
Expand All @@ -61,21 +62,25 @@ fn bind(

/// A structure holding sqlite statement handle and resetting the
/// statement when it is dropped.
struct StatementResetter {
struct StatementResetter<'a> {
handle: Arc<StatementHandle>,
worker: &'a mut StatementWorker,
}

impl StatementResetter {
fn new(handle: &Arc<StatementHandle>) -> Self {
impl<'a> StatementResetter<'a> {
fn new(worker: &'a mut StatementWorker, handle: &Arc<StatementHandle>) -> Self {
Self {
worker,
handle: Arc::clone(handle),
}
}
}

impl Drop for StatementResetter {
impl Drop for StatementResetter<'_> {
fn drop(&mut self) {
self.handle.reset();
// this method is designed to eagerly send the reset command
// so we don't need to await or spawn it
let _ = self.worker.reset(&self.handle);
}
}

Expand Down Expand Up @@ -105,7 +110,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
} = self;

// prepare statement object (or checkout from cache)
let stmt = prepare(statements, statement, sql, persistent)?;
let stmt = prepare(worker, statements, statement, sql, persistent).await?;

// keep track of how many arguments we have bound
let mut num_arguments = 0;
Expand All @@ -115,7 +120,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
// is dropped. `StatementResetter` will reliably reset the
// statement even if the stream returned from `fetch_many`
// is dropped early.
let _resetter = StatementResetter::new(stmt);
let resetter = StatementResetter::new(worker, stmt);

// bind values to the statement
num_arguments += bind(stmt, &arguments, num_arguments)?;
Expand All @@ -127,7 +132,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {

// invoke [sqlite3_step] on the dedicated worker thread
// this will move us forward one row or finish the statement
let s = worker.step(stmt).await?;
let s = resetter.worker.step(stmt).await?;

match s {
Either::Left(changes) => {
Expand Down Expand Up @@ -190,7 +195,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
} = self;

// prepare statement object (or checkout from cache)
let virtual_stmt = prepare(statements, statement, sql, persistent)?;
let virtual_stmt = prepare(worker, statements, statement, sql, persistent).await?;

// keep track of how many arguments we have bound
let mut num_arguments = 0;
Expand Down Expand Up @@ -218,7 +223,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {

logger.increment_rows();

virtual_stmt.reset();
virtual_stmt.reset(worker).await?;
return Ok(Some(row));
}
}
Expand All @@ -240,11 +245,12 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
handle: ref mut conn,
ref mut statements,
ref mut statement,
ref mut worker,
..
} = self;

// prepare statement object (or checkout from cache)
let statement = prepare(statements, statement, sql, true)?;
let statement = prepare(worker, statements, statement, sql, true).await?;

let mut parameters = 0;
let mut columns = None;
Expand Down
83 changes: 8 additions & 75 deletions sqlx-core/src/sqlite/statement/handle.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
use either::Either;
use std::ffi::c_void;
use std::ffi::CStr;
use std::hint;

use std::os::raw::{c_char, c_int};
use std::ptr;
use std::ptr::NonNull;
use std::slice::from_raw_parts;
use std::str::{from_utf8, from_utf8_unchecked};
use std::sync::atomic::{AtomicU8, Ordering};

use libsqlite3_sys::{
sqlite3, sqlite3_bind_blob64, sqlite3_bind_double, sqlite3_bind_int, sqlite3_bind_int64,
Expand All @@ -27,7 +25,7 @@ use crate::sqlite::type_info::DataType;
use crate::sqlite::{SqliteError, SqliteTypeInfo};

#[derive(Debug)]
pub(crate) struct StatementHandle(NonNull<sqlite3_stmt>, Lock);
pub(crate) struct StatementHandle(NonNull<sqlite3_stmt>);

// access to SQLite3 statement handles are safe to send and share between threads
// as long as the `sqlite3_step` call is serialized.
Expand All @@ -37,7 +35,11 @@ unsafe impl Sync for StatementHandle {}

impl StatementHandle {
pub(super) fn new(ptr: NonNull<sqlite3_stmt>) -> Self {
Self(ptr, Lock::new())
Self(ptr)
}

pub(crate) fn as_ptr(&self) -> *mut sqlite3_stmt {
self.0.as_ptr()
}

#[inline]
Expand Down Expand Up @@ -288,41 +290,13 @@ impl StatementHandle {
Ok(from_utf8(self.column_blob(index))?)
}

pub(crate) fn step(&self) -> Result<Either<u64, ()>, Error> {
self.1.enter_step();

let status = unsafe { sqlite3_step(self.0.as_ptr()) };
let result = match status {
SQLITE_ROW => Ok(Either::Right(())),
SQLITE_DONE => Ok(Either::Left(self.changes())),
_ => Err(self.last_error().into()),
};

if self.1.exit_step() {
unsafe { sqlite3_reset(self.0.as_ptr()) };
self.1.exit_reset();
}

result
}

pub(crate) fn reset(&self) {
if !self.1.enter_reset() {
// reset or step already in progress
return;
}

unsafe { sqlite3_reset(self.0.as_ptr()) };

self.1.exit_reset();
}

pub(crate) fn clear_bindings(&self) {
unsafe { sqlite3_clear_bindings(self.0.as_ptr()) };
}
}
impl Drop for StatementHandle {
fn drop(&mut self) {
// SAFETY: we have exclusive access to the `StatementHandle` here
unsafe {
// https://sqlite.org/c3ref/finalize.html
let status = sqlite3_finalize(self.0.as_ptr());
Expand All @@ -338,44 +312,3 @@ impl Drop for StatementHandle {
}
}
}

const RESET: u8 = 0b0000_0001;
const STEP: u8 = 0b0000_0010;

// Lock to synchronize calls to `step` and `reset`.
#[derive(Debug)]
struct Lock(AtomicU8);

impl Lock {
fn new() -> Self {
Self(AtomicU8::new(0))
}

// If this returns `true` reset can be performed, otherwise reset must be delayed until the
// current step finishes and `exit_step` is called.
fn enter_reset(&self) -> bool {
self.0.fetch_or(RESET, Ordering::Acquire) == 0
}

fn exit_reset(&self) {
self.0.fetch_and(!RESET, Ordering::Release);
}

fn enter_step(&self) {
// NOTE: spin loop should be fine here as we are only waiting for a `reset` to finish which
// should be quick.
while self
.0
.compare_exchange(0, STEP, Ordering::Acquire, Ordering::Relaxed)
.is_err()
{
hint::spin_loop();
}
}

// If this returns `true` it means a previous attempt to reset was delayed and must be
// performed now (followed by `exit_reset`).
fn exit_step(&self) -> bool {
self.0.fetch_and(!STEP, Ordering::Release) & RESET != 0
}
}
8 changes: 5 additions & 3 deletions sqlx-core/src/sqlite/statement/virtual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use crate::error::Error;
use crate::ext::ustr::UStr;
use crate::sqlite::connection::ConnectionHandle;
use crate::sqlite::statement::StatementHandle;
use crate::sqlite::statement::{StatementHandle, StatementWorker};
use crate::sqlite::{SqliteColumn, SqliteError, SqliteRow, SqliteValue};
use crate::HashMap;
use bytes::{Buf, Bytes};
Expand Down Expand Up @@ -176,7 +176,7 @@ impl VirtualStatement {
)))
}

pub(crate) fn reset(&mut self) {
pub(crate) async fn reset(&mut self, worker: &mut StatementWorker) -> Result<(), Error> {
self.index = 0;

for (i, handle) in self.handles.iter().enumerate() {
Expand All @@ -185,9 +185,11 @@ impl VirtualStatement {
// Reset A Prepared Statement Object
// https://www.sqlite.org/c3ref/reset.html
// https://www.sqlite.org/c3ref/clear_bindings.html
handle.reset();
worker.reset(handle).await?;
handle.clear_bindings();
}

Ok(())
}
}

Expand Down
71 changes: 66 additions & 5 deletions sqlx-core/src/sqlite/statement/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ use futures_channel::oneshot;
use std::sync::{Arc, Weak};
use std::thread;

use libsqlite3_sys::{sqlite3_reset, sqlite3_step, SQLITE_DONE, SQLITE_ROW};
use std::future::Future;

// Each SQLite connection has a dedicated thread.

// TODO: Tweak this so that we can use a thread pool per pool of SQLite3 connections to reduce
Expand All @@ -21,6 +24,10 @@ enum StatementWorkerCommand {
statement: Weak<StatementHandle>,
tx: oneshot::Sender<Result<Either<u64, ()>, Error>>,
},
Reset {
statement: Weak<StatementHandle>,
tx: oneshot::Sender<()>,
},
}

impl StatementWorker {
Expand All @@ -31,13 +38,37 @@ impl StatementWorker {
for cmd in rx {
match cmd {
StatementWorkerCommand::Step { statement, tx } => {
let resp = if let Some(statement) = statement.upgrade() {
statement.step()
let statement = if let Some(statement) = statement.upgrade() {
statement
} else {
// Statement is already finalized.
Err(Error::WorkerCrashed)
// statement is already finalized, the sender shouldn't be expecting a response
continue;
};

// SAFETY: only the `StatementWorker` calls this function
let status = unsafe { sqlite3_step(statement.as_ptr()) };
let result = match status {
SQLITE_ROW => Ok(Either::Right(())),
SQLITE_DONE => Ok(Either::Left(statement.changes())),
_ => Err(statement.last_error().into()),
};
let _ = tx.send(resp);

let _ = tx.send(result);
}
StatementWorkerCommand::Reset { statement, tx } => {
if let Some(statement) = statement.upgrade() {
// SAFETY: this must be the only place we call `sqlite3_reset`
unsafe { sqlite3_reset(statement.as_ptr()) };

// `sqlite3_reset()` always returns either `SQLITE_OK`
// or the last error code for the statement,
// which should have already been handled;
// so it's assumed the return value is safe to ignore.
//
// https://www.sqlite.org/c3ref/reset.html

let _ = tx.send(());
}
}
}
}
Expand All @@ -61,4 +92,34 @@ impl StatementWorker {

rx.await.map_err(|_| Error::WorkerCrashed)?
}

/// Send a command to the worker to execute `sqlite3_reset()` next.
///
/// This method is written to execute the sending of the command eagerly so
/// you do not need to await the returned future unless you want to.
///
/// The only error is `WorkerCrashed` as `sqlite3_reset()` returns the last error
/// in the statement execution which should have already been handled from `step()`.
pub(crate) fn reset(
&mut self,
statement: &Arc<StatementHandle>,
) -> impl Future<Output = Result<(), Error>> {
// execute the sending eagerly so we don't need to spawn the future
let (tx, rx) = oneshot::channel();

let send_res = self
.tx
.send(StatementWorkerCommand::Reset {
statement: Arc::downgrade(statement),
tx,
})
.map_err(|_| Error::WorkerCrashed);

async move {
send_res?;

// wait for the response
rx.await.map_err(|_| Error::WorkerCrashed)
}
}
}
Loading

0 comments on commit 768d9b8

Please sign in to comment.