From faa73fe0507bef7d6a3a73228c9428e8415b665e Mon Sep 17 00:00:00 2001 From: Thomas BESSOU Date: Mon, 23 Oct 2023 23:24:37 +0200 Subject: [PATCH] Slight rework of statement_cache's monomorphization prevention - Better encapsulate the "please don't monomorphize" internally to cached_statement - Place the dynamic dispatch edge at the "give me the SQL String" level --- diesel/src/connection/statement_cache.rs | 79 +++++++++++++++++++----- diesel/src/mysql/connection/mod.rs | 6 +- diesel/src/pg/connection/mod.rs | 2 +- diesel/src/sqlite/connection/mod.rs | 10 +-- 4 files changed, 71 insertions(+), 26 deletions(-) diff --git a/diesel/src/connection/statement_cache.rs b/diesel/src/connection/statement_cache.rs index 04e481dd41f0..caf38bb9936b 100644 --- a/diesel/src/connection/statement_cache.rs +++ b/diesel/src/connection/statement_cache.rs @@ -177,12 +177,32 @@ where /// parameter indicates if the constructed prepared statement will be cached or not. /// See the [module](self) documentation for details /// about which statements are cached and which are not cached. - // Note: This function is intentionally monomorphic over the "source" type. #[allow(unreachable_pub)] - pub fn cached_statement( + pub fn cached_statement( + &mut self, + source: &T, + backend: &DB, + bind_types: &[DB::TypeMetadata], + mut prepare_fn: F, + ) -> QueryResult> + where + T: QueryFragment + QueryId, + F: FnMut(&str, PrepareForCache) -> QueryResult, + { + self.cached_statement_non_generic( + T::query_id(), + source, + backend, + bind_types, + &mut prepare_fn, + ) + } + + /// Reduce the amount of monomorphized code by factoring this via dynamic dispatch + fn cached_statement_non_generic( &mut self, maybe_type_id: Option, - source: &dyn QueryFragment, + source: &dyn QueryFragmentForCachedStatement, backend: &DB, bind_types: &[DB::TypeMetadata], prepare_fn: &mut dyn FnMut(&str, PrepareForCache) -> QueryResult, @@ -212,6 +232,40 @@ where } } +/// Implemented for all `QueryFragment`s, dedicated to dynamic dispatch within the context of +/// `statement_cache` +/// +/// We want the generated code to be as small as possible, so for each query passed to +/// [`StatementCache::cached_statement`] the generated assembly will just call a non generic +/// version with dynamic dispatch pointing to the VTABLE of this minimal trait +/// +/// This preserves the opportunity for the compiler to entirely optimize the `construct_sql` +/// function as a function that simply returns a constant `String`. +#[allow(unreachable_pub)] +#[cfg_attr( + doc_cfg, + doc(cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes")) +)] +pub trait QueryFragmentForCachedStatement { + fn construct_sql(&self, backend: &DB) -> QueryResult; + fn is_safe_to_cache_prepared(&self, backend: &DB) -> QueryResult; +} +impl QueryFragmentForCachedStatement for T +where + DB: Backend, + DB::QueryBuilder: Default, + T: QueryFragment, +{ + fn construct_sql(&self, backend: &DB) -> QueryResult { + let mut query_builder = DB::QueryBuilder::default(); + self.to_sql(&mut query_builder, backend)?; + Ok(query_builder.finish()) + } + fn is_safe_to_cache_prepared(&self, backend: &DB) -> QueryResult { + >::is_safe_to_cache_prepared(self, backend) + } +} + /// Wraps a possibly cached prepared statement /// /// Essentially a customized version of [`Cow`] @@ -290,14 +344,14 @@ where #[allow(unreachable_pub)] pub fn for_source( maybe_type_id: Option, - source: &dyn QueryFragment, + source: &dyn QueryFragmentForCachedStatement, bind_types: &[DB::TypeMetadata], backend: &DB, ) -> QueryResult { match maybe_type_id { Some(id) => Ok(StatementCacheKey::Type(id)), None => { - let sql = Self::construct_sql(source, backend)?; + let sql = source.construct_sql(backend)?; Ok(StatementCacheKey::Sql { sql, bind_types: bind_types.into(), @@ -312,17 +366,14 @@ where /// twice if it's already part of the current cache key // Note: Intentionally monomorphic over source. #[allow(unreachable_pub)] - pub fn sql(&self, source: &dyn QueryFragment, backend: &DB) -> QueryResult> { + pub fn sql( + &self, + source: &dyn QueryFragmentForCachedStatement, + backend: &DB, + ) -> QueryResult> { match *self { - StatementCacheKey::Type(_) => Self::construct_sql(source, backend).map(Cow::Owned), + StatementCacheKey::Type(_) => source.construct_sql(backend).map(Cow::Owned), StatementCacheKey::Sql { ref sql, .. } => Ok(Cow::Borrowed(sql)), } } - - // Note: Intentionally monomorphic over source. - fn construct_sql(source: &dyn QueryFragment, backend: &DB) -> QueryResult { - let mut query_builder = DB::QueryBuilder::default(); - source.to_sql(&mut query_builder, backend)?; - Ok(query_builder.finish()) - } } diff --git a/diesel/src/mysql/connection/mod.rs b/diesel/src/mysql/connection/mod.rs index e2ecf780882e..f61770e005af 100644 --- a/diesel/src/mysql/connection/mod.rs +++ b/diesel/src/mysql/connection/mod.rs @@ -248,10 +248,8 @@ fn prepared_query<'a, T: QueryFragment + QueryId>( statement_cache: &'a mut StatementCache, raw_connection: &'a mut RawConnection, ) -> QueryResult> { - let mut stmt = - statement_cache.cached_statement(T::query_id(), source, &Mysql, &[], &mut |sql, _| { - raw_connection.prepare(sql) - })?; + let mut stmt = statement_cache + .cached_statement(source, &Mysql, &[], |sql, _| raw_connection.prepare(sql))?; let mut bind_collector = RawBytesBindCollector::new(); source.collect_binds(&mut bind_collector, &mut (), &Mysql)?; let binds = bind_collector diff --git a/diesel/src/pg/connection/mod.rs b/diesel/src/pg/connection/mod.rs index 67c6fdb2e731..bcf9b18f1cfe 100644 --- a/diesel/src/pg/connection/mod.rs +++ b/diesel/src/pg/connection/mod.rs @@ -356,7 +356,7 @@ impl PgConnection { let cache_len = self.statement_cache.len(); let cache = &mut self.statement_cache; let conn = &mut self.connection_and_transaction_manager.raw_connection; - let query = cache.cached_statement(T::query_id(), source, &Pg, &metadata, &mut |sql, _| { + let query = cache.cached_statement(source, &Pg, &metadata, |sql, _| { let query_name = if source.is_safe_to_cache_prepared(&Pg)? { Some(format!("__diesel_stmt_{cache_len}")) } else { diff --git a/diesel/src/sqlite/connection/mod.rs b/diesel/src/sqlite/connection/mod.rs index aca2a93275c4..e21234086b86 100644 --- a/diesel/src/sqlite/connection/mod.rs +++ b/diesel/src/sqlite/connection/mod.rs @@ -308,13 +308,9 @@ impl SqliteConnection { { let raw_connection = &self.raw_connection; let cache = &mut self.statement_cache; - let statement = cache.cached_statement( - T::query_id(), - &source, - &Sqlite, - &[], - &mut |sql, is_cached| Statement::prepare(raw_connection, sql, is_cached), - )?; + let statement = cache.cached_statement(&source, &Sqlite, &[], |sql, is_cached| { + Statement::prepare(raw_connection, sql, is_cached) + })?; StatementUse::bind(statement, source) }