Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sqlite: use Arc<> around StatementHandle instead of Copying it #1186

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sqlx-core/src/sqlite/connection/describe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ pub(super) fn describe<'c: 'e, 'q: 'e, 'e>(
// fallback to [column_decltype]
if !stepped && stmt.read_only() {
stepped = true;
let _ = conn.worker.step(*stmt).await;
let _ = conn.worker.step(stmt).await;
}

let mut ty = stmt.column_type_info(col);
Expand Down
18 changes: 10 additions & 8 deletions sqlx-core/src/sqlite/connection/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,14 @@ fn bind(
/// A structure holding sqlite statement handle and resetting the
/// statement when it is dropped.
struct StatementResetter {
handle: StatementHandle,
handle: Arc<StatementHandle>,
}

impl StatementResetter {
fn new(handle: StatementHandle) -> Self {
Self { handle }
fn new(handle: &Arc<StatementHandle>) -> Self {
Self {
handle: Arc::clone(handle),
}
}
}

Expand Down Expand Up @@ -113,7 +115,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
// is dropped. `StatementResetter` will reliably reset the
// statement even if the stream returned from `fetch_many`
// is dropped early.
let _resetter = StatementResetter::new(*stmt);
let _resetter = StatementResetter::new(stmt);

// bind values to the statement
num_arguments += bind(stmt, &arguments, num_arguments)?;
Expand All @@ -125,7 +127,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {

// invoke [sqlite3_step] on the dedicated worker thread
// this will move us forward one row or finish the statement
let s = worker.step(*stmt).await?;
let s = worker.step(stmt).await?;

match s {
Either::Left(changes) => {
Expand All @@ -145,7 +147,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {

Either::Right(()) => {
let (row, weak_values_ref) = SqliteRow::current(
*stmt,
&stmt,
columns,
column_names
);
Expand Down Expand Up @@ -205,12 +207,12 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {

// invoke [sqlite3_step] on the dedicated worker thread
// this will move us forward one row or finish the statement
match worker.step(*stmt).await? {
match worker.step(stmt).await? {
Either::Left(_) => (),

Either::Right(()) => {
let (row, weak_values_ref) =
SqliteRow::current(*stmt, columns, column_names);
SqliteRow::current(stmt, columns, column_names);

*last_row_values = Some(weak_values_ref);

Expand Down
6 changes: 3 additions & 3 deletions sqlx-core/src/sqlite/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub struct SqliteRow {
// IF the user drops the Row before iterating the stream (so
// nearly all of our internal stream iterators), the executor moves on; otherwise,
// it actually inflates this row with a list of owned sqlite3 values.
pub(crate) statement: StatementHandle,
pub(crate) statement: Arc<StatementHandle>,

pub(crate) values: Arc<AtomicPtr<SqliteValue>>,
pub(crate) num_values: usize,
Expand All @@ -48,7 +48,7 @@ impl SqliteRow {
// returns a weak reference to an atomic list where the executor should inflate if its going
// to increment the statement with [step]
pub(crate) fn current(
statement: StatementHandle,
statement: &Arc<StatementHandle>,
columns: &Arc<Vec<SqliteColumn>>,
column_names: &Arc<HashMap<UStr, usize>>,
) -> (Self, Weak<AtomicPtr<SqliteValue>>) {
Expand All @@ -57,7 +57,7 @@ impl SqliteRow {
let size = statement.column_count();

let row = Self {
statement,
statement: Arc::clone(statement),
values,
num_values: size,
columns: Arc::clone(columns),
Expand Down
25 changes: 21 additions & 4 deletions sqlx-core/src/sqlite/statement/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@ use libsqlite3_sys::{
sqlite3_column_count, sqlite3_column_database_name, sqlite3_column_decltype,
sqlite3_column_double, sqlite3_column_int, sqlite3_column_int64, sqlite3_column_name,
sqlite3_column_origin_name, sqlite3_column_table_name, sqlite3_column_type,
sqlite3_column_value, sqlite3_db_handle, sqlite3_reset, sqlite3_sql, sqlite3_stmt,
sqlite3_stmt_readonly, sqlite3_table_column_metadata, sqlite3_value, SQLITE_OK,
SQLITE_TRANSIENT, SQLITE_UTF8,
sqlite3_column_value, sqlite3_db_handle, sqlite3_finalize, sqlite3_reset, sqlite3_sql,
sqlite3_stmt, sqlite3_stmt_readonly, sqlite3_table_column_metadata, sqlite3_value,
SQLITE_MISUSE, SQLITE_OK, SQLITE_TRANSIENT, SQLITE_UTF8,
};

use crate::error::{BoxDynError, Error};
use crate::sqlite::type_info::DataType;
use crate::sqlite::{SqliteError, SqliteTypeInfo};

#[derive(Debug, Copy, Clone)]
#[derive(Debug)]
pub(crate) struct StatementHandle(pub(super) NonNull<sqlite3_stmt>);

// access to SQLite3 statement handles are safe to send and share between threads
Expand Down Expand Up @@ -284,3 +284,20 @@ impl StatementHandle {
unsafe { sqlite3_reset(self.0.as_ptr()) };
}
}
impl Drop for StatementHandle {
fn drop(&mut self) {
unsafe {
// https://sqlite.org/c3ref/finalize.html
let status = sqlite3_finalize(self.0.as_ptr());
if status == SQLITE_MISUSE {
// Panic in case of detected misuse of SQLite API.
//
// sqlite3_finalize returns it at least in the
// case of detected double free, i.e. calling
// sqlite3_finalize on already finalized
// statement.
panic!("Detected sqlite3_finalize misuse.");
}
}
}
}
24 changes: 5 additions & 19 deletions sqlx-core/src/sqlite/statement/virtual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use crate::sqlite::{SqliteColumn, SqliteError, SqliteRow, SqliteValue};
use crate::HashMap;
use bytes::{Buf, Bytes};
use libsqlite3_sys::{
sqlite3, sqlite3_clear_bindings, sqlite3_finalize, sqlite3_prepare_v3, sqlite3_reset,
sqlite3_stmt, SQLITE_MISUSE, SQLITE_OK, SQLITE_PREPARE_PERSISTENT,
sqlite3, sqlite3_clear_bindings, sqlite3_prepare_v3, sqlite3_reset, sqlite3_stmt, SQLITE_OK,
SQLITE_PREPARE_PERSISTENT,
};
use smallvec::SmallVec;
use std::i32;
Expand All @@ -31,7 +31,7 @@ pub(crate) struct VirtualStatement {
// underlying sqlite handles for each inner statement
// a SQL query string in SQLite is broken up into N statements
// we use a [`SmallVec`] to optimize for the most likely case of a single statement
pub(crate) handles: SmallVec<[StatementHandle; 1]>,
pub(crate) handles: SmallVec<[Arc<StatementHandle>; 1]>,

// each set of columns
pub(crate) columns: SmallVec<[Arc<Vec<SqliteColumn>>; 1]>,
Expand Down Expand Up @@ -126,7 +126,7 @@ impl VirtualStatement {
conn: &mut ConnectionHandle,
) -> Result<
Option<(
&StatementHandle,
&Arc<StatementHandle>,
&mut Arc<Vec<SqliteColumn>>,
&Arc<HashMap<UStr, usize>>,
&mut Option<Weak<AtomicPtr<SqliteValue>>>,
Expand Down Expand Up @@ -159,7 +159,7 @@ impl VirtualStatement {
column_names.insert(name, i);
}

self.handles.push(statement);
self.handles.push(Arc::new(statement));
self.columns.push(Arc::new(columns));
self.column_names.push(Arc::new(column_names));
self.last_row_values.push(None);
Expand Down Expand Up @@ -198,20 +198,6 @@ impl Drop for VirtualStatement {
fn drop(&mut self) {
for (i, handle) in self.handles.drain(..).enumerate() {
SqliteRow::inflate_if_needed(&handle, &self.columns[i], self.last_row_values[i].take());

unsafe {
// https://sqlite.org/c3ref/finalize.html
let status = sqlite3_finalize(handle.0.as_ptr());
if status == SQLITE_MISUSE {
// Panic in case of detected misuse of SQLite API.
//
// sqlite3_finalize returns it at least in the
// case of detected double free, i.e. calling
// sqlite3_finalize on already finalized
// statement.
panic!("Detected sqlite3_finalize misuse.");
}
}
}
}
}
27 changes: 18 additions & 9 deletions sqlx-core/src/sqlite/statement/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crossbeam_channel::{unbounded, Sender};
use either::Either;
use futures_channel::oneshot;
use libsqlite3_sys::{sqlite3_step, SQLITE_DONE, SQLITE_ROW};
use std::sync::{Arc, Weak};
use std::thread;

// Each SQLite connection has a dedicated thread.
Expand All @@ -18,7 +19,7 @@ pub(crate) struct StatementWorker {

enum StatementWorkerCommand {
Step {
statement: StatementHandle,
statement: Weak<StatementHandle>,
tx: oneshot::Sender<Result<Either<u64, ()>, Error>>,
},
}
Expand All @@ -31,14 +32,19 @@ impl StatementWorker {
for cmd in rx {
match cmd {
StatementWorkerCommand::Step { statement, tx } => {
let status = unsafe { sqlite3_step(statement.0.as_ptr()) };
let resp = if let Some(statement) = statement.upgrade() {
let status = unsafe { sqlite3_step(statement.0.as_ptr()) };

let resp = match status {
SQLITE_ROW => Ok(Either::Right(())),
SQLITE_DONE => Ok(Either::Left(statement.changes())),
_ => Err(statement.last_error().into()),
let resp = match status {
SQLITE_ROW => Ok(Either::Right(())),
SQLITE_DONE => Ok(Either::Left(statement.changes())),
_ => Err(statement.last_error().into()),
};
resp
Comment on lines +38 to +43
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let resp = match status {
SQLITE_ROW => Ok(Either::Right(())),
SQLITE_DONE => Ok(Either::Left(statement.changes())),
_ => Err(statement.last_error().into()),
};
resp
match status {
SQLITE_ROW => Ok(Either::Right(())),
SQLITE_DONE => Ok(Either::Left(statement.changes())),
_ => Err(statement.last_error().into()),
}

} else {
// Statement is already finalized.
Err(Error::WorkerCrashed)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Strictly speaking, the worker didn't actually crash here. I'd consider a different error variant, maybe even a new one if none of the existing fit.

};

let _ = tx.send(resp);
}
}
Expand All @@ -50,12 +56,15 @@ impl StatementWorker {

pub(crate) async fn step(
&mut self,
statement: StatementHandle,
statement: &Arc<StatementHandle>,
) -> Result<Either<u64, ()>, Error> {
let (tx, rx) = oneshot::channel();

self.tx
.send(StatementWorkerCommand::Step { statement, tx })
.send(StatementWorkerCommand::Step {
statement: Arc::downgrade(statement),
tx,
})
.map_err(|_| Error::WorkerCrashed)?;

rx.await.map_err(|_| Error::WorkerCrashed)?
Expand Down