diff --git a/src/dict_management.rs b/src/dict_management.rs index e995fa5..e714c39 100644 --- a/src/dict_management.rs +++ b/src/dict_management.rs @@ -1,7 +1,8 @@ use anyhow::Context as AContext; -use rusqlite::{functions::Context, params}; +use rusqlite::{functions::Context, params, Connection, Error as sqlite_error}; use std::sync::{Arc, RwLock}; use std::time::Duration; +use lazy_static::__Deref; use zstd::dict::{DecoderDictionary, EncoderDictionary}; @@ -16,20 +17,29 @@ pub fn encoder_dict_from_ctx( // we cache the instantiated encoder dictionaries keyed by (DbConnection, dict_id, compression_level) // DbConnection would ideally be db.path() because it's the same for multiple connections to the same db, but that would be less robust (e.g. in-memory databases) lazy_static::lazy_static! { - static ref DICTS: RwLock>>> = RwLock::new(LruCache::with_expiry_duration(Duration::from_secs(10))); + static ref DICTS: RwLock>>> = RwLock::new(LruCache::with_expiry_duration(Duration::from_secs(10))); } let id: i32 = ctx.get(arg_index)?; let db = unsafe { ctx.get_connection()? }; // SAFETY: This might be unsafe depending on how the connection is used. See https://github.com/rusqlite/rusqlite/issues/643#issuecomment-640181213 let db_handle_pointer = unsafe { db.handle() } as usize; // SAFETY: We're only getting the pointer as an int, not using the raw connection + let conn: &Connection = db.deref(); + let path_result: Result, sqlite_error> = + conn.query_row_and_then("PRAGMA database_list", [], |row| row.get(2)); + let path = match path_result { + Ok(Some(path_result)) => Ok(path_result), + Ok(None) => Ok("".to_string()), + Err(e) => Err(e), + }?; let mut dicts_write = DICTS.write().unwrap(); - let entry = dicts_write.entry((db_handle_pointer, id, level)); + let entry = dicts_write.entry((db_handle_pointer, id, level, path.clone())); let res = match entry { lru_time_cache::Entry::Vacant(e) => e.insert({ log::debug!( - "loading encoder dictionary {} level {} (should only happen once per 10s)", + "loading encoder dictionary {} level {} @ {}(should only happen once per 10s)", id, - level + level, + path ); let dict_raw: Vec = db @@ -56,18 +66,27 @@ pub fn decoder_dict_from_ctx( // we cache the instantiated decoder dictionaries keyed by (DbConnection, dict_id) // DbConnection would ideally be db.path() because it's the same for multiple connections to the same db, but that would be less robust (e.g. in-memory databases) lazy_static::lazy_static! { - static ref DICTS: RwLock>>> = RwLock::new(LruCache::with_expiry_duration(Duration::from_secs(10))); + static ref DICTS: RwLock>>> = RwLock::new(LruCache::with_expiry_duration(Duration::from_secs(10))); } let id: i32 = ctx.get(arg_index)?; - let db = unsafe { ctx.get_connection()? }; // SAFETY: This might be unsafe depending on how the connection is used. See https://github.com/rusqlite/rusqlite/issues/643#issuecomment-640181213 + let db = unsafe { ctx.get_connection()? } ; // SAFETY: This might be unsafe depending on how the connection is used. See https://github.com/rusqlite/rusqlite/issues/643#issuecomment-640181213 + let conn: &Connection = db.deref(); + let path_result: Result, sqlite_error> = + conn.query_row_and_then("PRAGMA database_list", [], |row| row.get(2)); + let path = match path_result { + Ok(Some(path_result)) => Ok(path_result), + Ok(None) => Ok("".to_string()), + Err(e) => Err(e), + }?; let db_handle_pointer = unsafe { db.handle() } as usize; // SAFETY: We're only getting the pointer as an int, not using the raw connection let mut dicts_write = DICTS.write().unwrap(); - let entry = dicts_write.entry((db_handle_pointer, id)); + let entry = dicts_write.entry((db_handle_pointer, id, path.clone())); let res = match entry { lru_time_cache::Entry::Vacant(e) => e.insert({ log::debug!( - "loading decoder dictionary {} (should only happen once per 10s)", - id + "loading decoder dictionary {} @ {} (should only happen once per 10s)", + id, + path ); let db = unsafe { ctx.get_connection()? }; let dict_raw: Vec = db