Skip to content

Commit

Permalink
:Janky fix for #43
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Harradon authored and mharradon committed Sep 12, 2024
1 parent 3a820f3 commit 47c47dc
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions src/dict_management.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
use anyhow::Context as AContext;
use rusqlite::{functions::Context, params};
use rusqlite::{functions::Context, params, Connection, Result};
use std::sync::{Arc, RwLock};
use std::time::Duration;
use lazy_static::__Deref;

use zstd::dict::{DecoderDictionary, EncoderDictionary};

fn get_db_path(conn: &Connection) -> Result<String> {
conn.query_row("PRAGMA database_list", [], |row| row.get(2))
.map(|path: String| path)
.or_else(|_| Ok("<unknown>".to_string()))
}

// TODO: the rust interface currently requires a level when preparing a dictionary, but the zstd interface (ZSTD_CCtx_loadDictionary) does not.
// TODO: Using LruCache here isn't very smart
pub fn encoder_dict_from_ctx(
Expand All @@ -16,20 +23,23 @@ pub fn encoder_dict_from_ctx(
// we cache the instantiated encoder dictionaries keyed by (DbConnection, dict_id, compression_level)
// DbConnection would ideally be db.path() because it's the same for multiple connections to the same db, but that would be less robust (e.g. in-memory databases)
lazy_static::lazy_static! {
static ref DICTS: RwLock<LruCache<(usize, i32, i32), Arc<EncoderDictionary<'static>>>> = RwLock::new(LruCache::with_expiry_duration(Duration::from_secs(10)));
static ref DICTS: RwLock<LruCache<(usize, i32, i32, String), Arc<EncoderDictionary<'static>>>> = RwLock::new(LruCache::with_expiry_duration(Duration::from_secs(10)));

Check failure on line 26 in src/dict_management.rs

View workflow job for this annotation

GitHub Actions / Tests (ubuntu-latest)

very complex type used. Consider factoring parts into `type` definitions
}
let id: i32 = ctx.get(arg_index)?;
let db = unsafe { ctx.get_connection()? }; // SAFETY: This might be unsafe depending on how the connection is used. See https://github.com/rusqlite/rusqlite/issues/643#issuecomment-640181213
let db_handle_pointer = unsafe { db.handle() } as usize; // SAFETY: We're only getting the pointer as an int, not using the raw connection
let conn: &Connection = db.deref();
let path = get_db_path(conn)?;

let mut dicts_write = DICTS.write().unwrap();
let entry = dicts_write.entry((db_handle_pointer, id, level));
let entry = dicts_write.entry((db_handle_pointer, id, level, path.clone()));
let res = match entry {
lru_time_cache::Entry::Vacant(e) => e.insert({
log::debug!(
"loading encoder dictionary {} level {} (should only happen once per 10s)",
"loading encoder dictionary {} level {} @ {} (should only happen once per 10s)",
id,
level
level,
path
);

let dict_raw: Vec<u8> = db
Expand All @@ -56,18 +66,21 @@ pub fn decoder_dict_from_ctx(
// we cache the instantiated decoder dictionaries keyed by (DbConnection, dict_id)
// DbConnection would ideally be db.path() because it's the same for multiple connections to the same db, but that would be less robust (e.g. in-memory databases)
lazy_static::lazy_static! {
static ref DICTS: RwLock<LruCache<(usize, i32), Arc<DecoderDictionary<'static>>>> = RwLock::new(LruCache::with_expiry_duration(Duration::from_secs(10)));
static ref DICTS: RwLock<LruCache<(usize, i32, String), Arc<DecoderDictionary<'static>>>> = RwLock::new(LruCache::with_expiry_duration(Duration::from_secs(10)));
}
let id: i32 = ctx.get(arg_index)?;
let db = unsafe { ctx.get_connection()? }; // SAFETY: This might be unsafe depending on how the connection is used. See https://github.com/rusqlite/rusqlite/issues/643#issuecomment-640181213
let conn: &Connection = db.deref();
let path = get_db_path(conn)?;
let db_handle_pointer = unsafe { db.handle() } as usize; // SAFETY: We're only getting the pointer as an int, not using the raw connection
let mut dicts_write = DICTS.write().unwrap();
let entry = dicts_write.entry((db_handle_pointer, id));
let entry = dicts_write.entry((db_handle_pointer, id, path.clone()));
let res = match entry {
lru_time_cache::Entry::Vacant(e) => e.insert({
log::debug!(
"loading decoder dictionary {} (should only happen once per 10s)",
id
"loading decoder dictionary {} @ {} (should only happen once per 10s)",
id,
path
);
let db = unsafe { ctx.get_connection()? };
let dict_raw: Vec<u8> = db
Expand Down

0 comments on commit 47c47dc

Please sign in to comment.