diff --git a/src/dict_management.rs b/src/dict_management.rs index e995fa5..a498706 100644 --- a/src/dict_management.rs +++ b/src/dict_management.rs @@ -1,10 +1,17 @@ use anyhow::Context as AContext; -use rusqlite::{functions::Context, params}; +use rusqlite::{functions::Context, params, Connection, Result}; use std::sync::{Arc, RwLock}; use std::time::Duration; +use lazy_static::__Deref; use zstd::dict::{DecoderDictionary, EncoderDictionary}; +fn get_db_path(conn: &Connection) -> Result { + conn.query_row("PRAGMA database_list", [], |row| row.get(2)) + .map(|path: String| path) + .or_else(|_| Ok("".to_string())) +} + // TODO: the rust interface currently requires a level when preparing a dictionary, but the zstd interface (ZSTD_CCtx_loadDictionary) does not. // TODO: Using LruCache here isn't very smart pub fn encoder_dict_from_ctx( @@ -16,20 +23,23 @@ pub fn encoder_dict_from_ctx( // we cache the instantiated encoder dictionaries keyed by (DbConnection, dict_id, compression_level) // DbConnection would ideally be db.path() because it's the same for multiple connections to the same db, but that would be less robust (e.g. in-memory databases) lazy_static::lazy_static! { - static ref DICTS: RwLock>>> = RwLock::new(LruCache::with_expiry_duration(Duration::from_secs(10))); + static ref DICTS: RwLock>>> = RwLock::new(LruCache::with_expiry_duration(Duration::from_secs(10))); } let id: i32 = ctx.get(arg_index)?; let db = unsafe { ctx.get_connection()? }; // SAFETY: This might be unsafe depending on how the connection is used. See https://github.com/rusqlite/rusqlite/issues/643#issuecomment-640181213 let db_handle_pointer = unsafe { db.handle() } as usize; // SAFETY: We're only getting the pointer as an int, not using the raw connection + let conn: &Connection = db.deref(); + let path = get_db_path(conn)?; let mut dicts_write = DICTS.write().unwrap(); - let entry = dicts_write.entry((db_handle_pointer, id, level)); + let entry = dicts_write.entry((db_handle_pointer, id, level, path.clone())); let res = match entry { lru_time_cache::Entry::Vacant(e) => e.insert({ log::debug!( - "loading encoder dictionary {} level {} (should only happen once per 10s)", + "loading encoder dictionary {} level {} @ {} (should only happen once per 10s)", id, - level + level, + path ); let dict_raw: Vec = db @@ -56,18 +66,21 @@ pub fn decoder_dict_from_ctx( // we cache the instantiated decoder dictionaries keyed by (DbConnection, dict_id) // DbConnection would ideally be db.path() because it's the same for multiple connections to the same db, but that would be less robust (e.g. in-memory databases) lazy_static::lazy_static! { - static ref DICTS: RwLock>>> = RwLock::new(LruCache::with_expiry_duration(Duration::from_secs(10))); + static ref DICTS: RwLock>>> = RwLock::new(LruCache::with_expiry_duration(Duration::from_secs(10))); } let id: i32 = ctx.get(arg_index)?; let db = unsafe { ctx.get_connection()? }; // SAFETY: This might be unsafe depending on how the connection is used. See https://github.com/rusqlite/rusqlite/issues/643#issuecomment-640181213 + let conn: &Connection = db.deref(); + let path = get_db_path(conn)?; let db_handle_pointer = unsafe { db.handle() } as usize; // SAFETY: We're only getting the pointer as an int, not using the raw connection let mut dicts_write = DICTS.write().unwrap(); - let entry = dicts_write.entry((db_handle_pointer, id)); + let entry = dicts_write.entry((db_handle_pointer, id, path.clone())); let res = match entry { lru_time_cache::Entry::Vacant(e) => e.insert({ log::debug!( - "loading decoder dictionary {} (should only happen once per 10s)", - id + "loading decoder dictionary {} @ {} (should only happen once per 10s)", + id, + path ); let db = unsafe { ctx.get_connection()? }; let dict_raw: Vec = db