Skip to content

Commit

Permalink
Slight rework of statement_cache's monomorphization prevention
Browse files Browse the repository at this point in the history
- Better encapsulate the "please don't monomorphize" internally to cached_statement
- Place the dynamic dispatch edge at the "give me the SQL String" level
  • Loading branch information
Ten0 committed Oct 23, 2023
1 parent 7042083 commit faa73fe
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 26 deletions.
79 changes: 65 additions & 14 deletions diesel/src/connection/statement_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,32 @@ where
/// parameter indicates if the constructed prepared statement will be cached or not.
/// See the [module](self) documentation for details
/// about which statements are cached and which are not cached.
// Note: This function is intentionally monomorphic over the "source" type.
#[allow(unreachable_pub)]
pub fn cached_statement(
pub fn cached_statement<T, F>(
&mut self,
source: &T,
backend: &DB,
bind_types: &[DB::TypeMetadata],
mut prepare_fn: F,
) -> QueryResult<MaybeCached<'_, Statement>>
where
T: QueryFragment<DB> + QueryId,
F: FnMut(&str, PrepareForCache) -> QueryResult<Statement>,
{
self.cached_statement_non_generic(
T::query_id(),
source,
backend,
bind_types,
&mut prepare_fn,
)
}

/// Reduce the amount of monomorphized code by factoring this via dynamic dispatch
fn cached_statement_non_generic(
&mut self,
maybe_type_id: Option<TypeId>,
source: &dyn QueryFragment<DB>,
source: &dyn QueryFragmentForCachedStatement<DB>,
backend: &DB,
bind_types: &[DB::TypeMetadata],
prepare_fn: &mut dyn FnMut(&str, PrepareForCache) -> QueryResult<Statement>,
Expand Down Expand Up @@ -212,6 +232,40 @@ where
}
}

/// Implemented for all `QueryFragment`s, dedicated to dynamic dispatch within the context of
/// `statement_cache`
///
/// We want the generated code to be as small as possible, so for each query passed to
/// [`StatementCache::cached_statement`] the generated assembly will just call a non generic
/// version with dynamic dispatch pointing to the VTABLE of this minimal trait
///
/// This preserves the opportunity for the compiler to entirely optimize the `construct_sql`
/// function as a function that simply returns a constant `String`.
#[allow(unreachable_pub)]
#[cfg_attr(
doc_cfg,
doc(cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"))
)]
pub trait QueryFragmentForCachedStatement<DB> {
fn construct_sql(&self, backend: &DB) -> QueryResult<String>;
fn is_safe_to_cache_prepared(&self, backend: &DB) -> QueryResult<bool>;
}
impl<T, DB> QueryFragmentForCachedStatement<DB> for T
where
DB: Backend,
DB::QueryBuilder: Default,
T: QueryFragment<DB>,
{
fn construct_sql(&self, backend: &DB) -> QueryResult<String> {
let mut query_builder = DB::QueryBuilder::default();
self.to_sql(&mut query_builder, backend)?;
Ok(query_builder.finish())
}
fn is_safe_to_cache_prepared(&self, backend: &DB) -> QueryResult<bool> {
<T as QueryFragment<DB>>::is_safe_to_cache_prepared(self, backend)
}
}

/// Wraps a possibly cached prepared statement
///
/// Essentially a customized version of [`Cow`]
Expand Down Expand Up @@ -290,14 +344,14 @@ where
#[allow(unreachable_pub)]
pub fn for_source(
maybe_type_id: Option<TypeId>,
source: &dyn QueryFragment<DB>,
source: &dyn QueryFragmentForCachedStatement<DB>,
bind_types: &[DB::TypeMetadata],
backend: &DB,
) -> QueryResult<Self> {
match maybe_type_id {
Some(id) => Ok(StatementCacheKey::Type(id)),
None => {
let sql = Self::construct_sql(source, backend)?;
let sql = source.construct_sql(backend)?;
Ok(StatementCacheKey::Sql {
sql,
bind_types: bind_types.into(),
Expand All @@ -312,17 +366,14 @@ where
/// twice if it's already part of the current cache key
// Note: Intentionally monomorphic over source.
#[allow(unreachable_pub)]
pub fn sql(&self, source: &dyn QueryFragment<DB>, backend: &DB) -> QueryResult<Cow<'_, str>> {
pub fn sql(
&self,
source: &dyn QueryFragmentForCachedStatement<DB>,
backend: &DB,
) -> QueryResult<Cow<'_, str>> {
match *self {
StatementCacheKey::Type(_) => Self::construct_sql(source, backend).map(Cow::Owned),
StatementCacheKey::Type(_) => source.construct_sql(backend).map(Cow::Owned),
StatementCacheKey::Sql { ref sql, .. } => Ok(Cow::Borrowed(sql)),
}
}

// Note: Intentionally monomorphic over source.
fn construct_sql(source: &dyn QueryFragment<DB>, backend: &DB) -> QueryResult<String> {
let mut query_builder = DB::QueryBuilder::default();
source.to_sql(&mut query_builder, backend)?;
Ok(query_builder.finish())
}
}
6 changes: 2 additions & 4 deletions diesel/src/mysql/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,10 +248,8 @@ fn prepared_query<'a, T: QueryFragment<Mysql> + QueryId>(
statement_cache: &'a mut StatementCache<Mysql, Statement>,
raw_connection: &'a mut RawConnection,
) -> QueryResult<MaybeCached<'a, Statement>> {
let mut stmt =
statement_cache.cached_statement(T::query_id(), source, &Mysql, &[], &mut |sql, _| {
raw_connection.prepare(sql)
})?;
let mut stmt = statement_cache
.cached_statement(source, &Mysql, &[], |sql, _| raw_connection.prepare(sql))?;
let mut bind_collector = RawBytesBindCollector::new();
source.collect_binds(&mut bind_collector, &mut (), &Mysql)?;
let binds = bind_collector
Expand Down
2 changes: 1 addition & 1 deletion diesel/src/pg/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ impl PgConnection {
let cache_len = self.statement_cache.len();
let cache = &mut self.statement_cache;
let conn = &mut self.connection_and_transaction_manager.raw_connection;
let query = cache.cached_statement(T::query_id(), source, &Pg, &metadata, &mut |sql, _| {
let query = cache.cached_statement(source, &Pg, &metadata, |sql, _| {
let query_name = if source.is_safe_to_cache_prepared(&Pg)? {
Some(format!("__diesel_stmt_{cache_len}"))
} else {
Expand Down
10 changes: 3 additions & 7 deletions diesel/src/sqlite/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,13 +308,9 @@ impl SqliteConnection {
{
let raw_connection = &self.raw_connection;
let cache = &mut self.statement_cache;
let statement = cache.cached_statement(
T::query_id(),
&source,
&Sqlite,
&[],
&mut |sql, is_cached| Statement::prepare(raw_connection, sql, is_cached),
)?;
let statement = cache.cached_statement(&source, &Sqlite, &[], |sql, is_cached| {
Statement::prepare(raw_connection, sql, is_cached)
})?;

StatementUse::bind(statement, source)
}
Expand Down

0 comments on commit faa73fe

Please sign in to comment.