diff --git a/Cargo.toml b/Cargo.toml
index 3a3ab054..5e9160ad 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,5 +1,5 @@
 [workspace]
-members = ["askar-bbs", "askar-crypto"]
+members = ["askar-bbs", "askar-crypto", "askar-storage"]
 
 [package]
 name = "aries-askar"
@@ -26,65 +26,39 @@ rustdoc-args = ["--cfg", "docsrs"]
 
 [features]
 default = ["all_backends", "ffi", "logger", "migration"]
-all_backends = ["any", "postgres", "sqlite"]
-any = []
-ffi = ["any", "ffi-support", "logger"]
+all_backends = ["postgres", "sqlite"]
+ffi = ["ffi-support", "logger"]
 jemalloc = ["jemallocator"]
-logger = ["env_logger", "log"]
-postgres = ["sqlx", "sqlx/postgres", "sqlx/tls"]
-sqlite = ["num_cpus", "sqlx", "sqlx/sqlite"]
-pg_test = ["postgres"]
-migration = ["rmp", "rmp-serde"]
-
-[dev-dependencies]
-hex-literal = "0.3"
+logger = ["env_logger", "log", "askar-storage/log"]
+postgres = ["askar-storage/postgres"]
+sqlite = ["askar-storage/sqlite"]
+pg_test = ["askar-storage/pg_test"]
+migration = ["askar-storage/migration"]
 
 [dependencies]
-arc-swap = "1.6"
 async-lock = "2.5"
-async-stream = "0.3"
-bs58 = "0.4"
-chrono = "0.4"
-digest = "0.10"
 env_logger = { version = "0.9", optional = true }
 ffi-support = { version = "0.4", optional = true }
-futures-lite = "1.11"
-hex = "0.4"
-hmac = "0.12"
-itertools = "0.10"
 jemallocator = { version = "0.5", optional = true }
 log = { version = "0.4", optional = true }
-num_cpus = { version = "1.0", optional = true }
 once_cell = "1.5"
-percent-encoding = "2.0"
-rand = { version = "0.8", default-features = false }
-rmp = { version = "0.8.11", optional = true }
-rmp-serde = { version = "1.1.1", optional = true }
 serde = { version = "1.0", features = ["derive"] }
-serde_bytes = "0.11"
 serde_cbor = "0.11"
 serde_json = "1.0"
-sha2 = "0.10"
-tokio = { version = "1.5", features = ["time"] }
-url = { version = "2.1", default-features = false }
-uuid = { version = "1.2", features = ["v4"] }
-zeroize = "1.4"
+zeroize = "1.5"
 
 [dependencies.askar-crypto]
 version = "0.2.5"
 path = "./askar-crypto"
 features = ["all_keys", "any_key", "argon2", "crypto_box", "std"]
 
-[dependencies.sqlx]
-version = "0.6.2"
+[dependencies.askar-storage]
+version = "0.1.0"
+path = "./askar-storage"
 default-features = false
-features = ["chrono", "runtime-tokio-rustls", "macros"]
-optional = true
+features = ["any"]
 
 [profile.release]
 codegen-units = 1
 lto = true
 panic = "abort"
-
-[[test]]
-name = "backends"
diff --git a/askar-bbs/Cargo.toml b/askar-bbs/Cargo.toml
index 6b0819e4..99f9fff5 100644
--- a/askar-bbs/Cargo.toml
+++ b/askar-bbs/Cargo.toml
@@ -33,7 +33,7 @@ subtle = "2.4"
 criterion = "0.3"
 # override transitive dependency from criterion to support rust versions older than 1.60
 csv = "=1.1"
-hex-literal = "0.3"
+hex-literal = "0.4"
 serde-json-core = { version = "0.4", default-features = false, features = ["std"] }
 
diff --git a/askar-crypto/Cargo.toml b/askar-crypto/Cargo.toml
index bd60b1b4..c308eb34 100644
--- a/askar-crypto/Cargo.toml
+++ b/askar-crypto/Cargo.toml
@@ -33,7 +33,7 @@ std_rng = ["getrandom", "rand/std", "rand/std_rng"]
 [dev-dependencies]
 base64 = { version = "0.13", default-features = false, features = ["alloc"] }
 criterion = "0.4"
-hex-literal = "0.3"
+hex-literal = "0.4"
 serde_cbor = "0.11"
 serde-json-core = { version = "0.5", default-features = false, features = ["std"] }
 
@@ -55,8 +55,8 @@ base64 = { version = "0.13", default-features = false }
 blake2 = { version = "0.10", default-features = false }
 block-modes = { version = "0.8", default-features = false, optional = true }
 bls12_381 = { version = "0.6", default-features = false, features = ["groups", "zeroize"], optional = true }
-chacha20 = { version = "0.7" } # should match chacha20poly1305
-chacha20poly1305 = { version = "0.8", default-features = false, optional = true }
+chacha20 = { version = "0.8" } # should match dependency of chacha20poly1305
+chacha20poly1305 = { version = "0.9", default-features = false, optional = true }
 crypto_box_rs = { package = "crypto_box", version = "0.6", default-features = false, features = ["u64_backend"], optional = true }
 curve25519-dalek = { version = "3.1", default-features = false, features = ["u64_backend"], optional = true }
 ed25519-dalek = { version = "1.0", default-features = false, features = ["u64_backend"], optional = true }
@@ -73,4 +73,4 @@ serde-json-core = { version = "0.5", default-features = false }
 subtle = "2.4"
 sha2 = { version = "0.10", default-features = false }
 x25519-dalek = { version = "=1.1", default-features = false, features = ["u64_backend"], optional = true }
-zeroize = { version = "1.4", features = ["zeroize_derive"] }
+zeroize = { version = "1.5", features = ["zeroize_derive"] }
diff --git a/askar-crypto/src/buffer/array.rs b/askar-crypto/src/buffer/array.rs
index 31b9bdec..67e44a9c 100644
--- a/askar-crypto/src/buffer/array.rs
+++ b/askar-crypto/src/buffer/array.rs
@@ -129,7 +129,7 @@ impl<L: ArrayLength<u8>> From<GenericArray<u8, L>> for ArrayKey<L> {
 impl<L: ArrayLength<u8>> Debug for ArrayKey<L> {
     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
         if cfg!(test) {
-            f.debug_tuple("ArrayKey").field(&*self).finish()
+            f.debug_tuple("ArrayKey").field(&self.0).finish()
         } else {
             f.debug_tuple("ArrayKey").field(&"<secret>").finish()
         }
diff --git a/askar-crypto/src/jwk/mod.rs b/askar-crypto/src/jwk/mod.rs
index 3373be29..7cc7648f 100644
--- a/askar-crypto/src/jwk/mod.rs
+++ b/askar-crypto/src/jwk/mod.rs
@@ -71,7 +71,7 @@ pub fn write_jwk_thumbprint<K: ToJwk, H: Digest>(
     buf.finalize()?;
     let hash = hasher.finalize();
     let mut buf = [0u8; 43];
-    let len = base64::encode_config_slice(&hash, base64::URL_SAFE_NO_PAD, &mut buf);
+    let len = base64::encode_config_slice(hash, base64::URL_SAFE_NO_PAD, &mut buf);
     output.buffer_write(&buf[..len])?;
     Ok(())
 }
diff --git a/askar-storage/Cargo.toml b/askar-storage/Cargo.toml
new file mode 100644
index 00000000..46521985
--- /dev/null
+++ b/askar-storage/Cargo.toml
@@ -0,0 +1,70 @@
+[package]
+name = "askar-storage"
+version = "0.1.0"
+authors = ["Hyperledger Aries Contributors <aries@lists.hyperledger.org>"]
+edition = "2021"
+description = "Hyperledger Aries Askar secure storage"
+license = "MIT OR Apache-2.0"
+readme = "README.md"
+repository = "https://github.com/hyperledger/aries-askar/"
+categories = ["cryptography", "database"]
+keywords = ["hyperledger", "aries", "ssi", "verifiable", "credentials"]
+rust-version = "1.58"
+
+[package.metadata.docs.rs]
+features = ["all_backends"]
+no-default-features = true
+rustdoc-args = ["--cfg", "docsrs"]
+
+[features]
+default = ["all_backends", "log"]
+all_backends = ["any", "postgres", "sqlite"]
+any = []
+migration = ["rmp-serde", "sqlx/macros"]
+postgres = ["sqlx", "sqlx/postgres", "sqlx/tls"]
+sqlite = ["sqlx", "sqlx/sqlite"]
+pg_test = ["postgres"]
+
+[dependencies]
+arc-swap = "1.6"
+async-lock = "2.5"
+async-stream = "0.3"
+bs58 = "0.4"
+chrono = "0.4"
+digest = "0.10"
+futures-lite = "1.11"
+hex = "0.4"
+hmac = "0.12"
+itertools = "0.10"
+log = { version = "0.4", optional = true }
+once_cell = "1.5"
+percent-encoding = "2.0"
+rmp-serde = { version = "1.1", optional = true }
+serde = { version = "1.0", features = ["derive"] }
+serde_cbor = "0.11"
+serde_json = "1.0"
+sha2 = "0.10"
+tokio = { version = "1.5", features = ["time"] }
+url = { version = "2.1", default-features = false }
+uuid = { version = "1.2", features = ["v4"] }
+zeroize = "1.5"
+
+[dependencies.askar-crypto]
+version = "0.2.5"
+path = "../askar-crypto"
+default-features = false
+features = ["alloc", "argon2", "chacha", "std_rng"]
+
+[dependencies.sqlx]
+version = "0.6.2"
+default-features = false
+features = ["chrono", "runtime-tokio-rustls"]
+optional = true
+
+[dev-dependencies]
+env_logger = "0.9"
+hex-literal = "0.4"
+rand = { version = "0.8" }
+
+[[test]]
+name = "backends"
diff --git a/askar-storage/src/any.rs b/askar-storage/src/any.rs
new file mode 100644
index 00000000..85ecf1d4
--- /dev/null
+++ b/askar-storage/src/any.rs
@@ -0,0 +1,254 @@
+//! Generic backend support
+
+use std::{fmt::Debug, sync::Arc};
+
+use super::{Backend, BackendSession, ManageBackend};
+use crate::{
+    entry::{Entry, EntryKind, EntryOperation, EntryTag, Scan, TagFilter},
+    error::Error,
+    future::BoxFuture,
+    options::IntoOptions,
+    protect::{PassKey, StoreKeyMethod},
+};
+
+#[cfg(feature = "postgres")]
+use super::postgres;
+
+#[cfg(feature = "sqlite")]
+use super::sqlite;
+
+/// A dynamic store backend instance
+pub type AnyBackend = Arc<dyn Backend<Session = AnyBackendSession>>;
+
+/// Wrap a backend instance into an AnyBackend
+pub fn into_any_backend(inst: impl Backend + 'static) -> AnyBackend {
+    Arc::new(WrapBackend(inst))
+}
+
+#[derive(Debug)]
+struct WrapBackend<B: Backend>(B);
+
+impl<B: Backend> Backend for WrapBackend<B> {
+    type Session = AnyBackendSession;
+
+    #[inline]
+    fn create_profile(&self, name: Option<String>) -> BoxFuture<'_, Result<String, Error>> {
+        self.0.create_profile(name)
+    }
+
+    #[inline]
+    fn get_profile_name(&self) -> &str {
+        self.0.get_profile_name()
+    }
+
+    #[inline]
+    fn remove_profile(&self, name: String) -> BoxFuture<'_, Result<bool, Error>> {
+        self.0.remove_profile(name)
+    }
+
+    #[inline]
+    fn scan(
+        &self,
+        profile: Option<String>,
+        kind: Option<EntryKind>,
+        category: Option<String>,
+        tag_filter: Option<TagFilter>,
+        offset: Option<i64>,
+        limit: Option<i64>,
+    ) -> BoxFuture<'_, Result<Scan<'static, Entry>, Error>> {
+        self.0
+            .scan(profile, kind, category, tag_filter, offset, limit)
+    }
+
+    #[inline]
+    fn session(&self, profile: Option<String>, transaction: bool) -> Result<Self::Session, Error> {
+        Ok(AnyBackendSession(Box::new(
+            self.0.session(profile, transaction)?,
+        )))
+    }
+
+    #[inline]
+    fn rekey(
+        &mut self,
+        method: StoreKeyMethod,
+        key: PassKey<'_>,
+    ) -> BoxFuture<'_, Result<(), Error>> {
+        self.0.rekey(method, key)
+    }
+
+    #[inline]
+    fn close(&self) -> BoxFuture<'_, Result<(), Error>> {
+        self.0.close()
+    }
+}
+
+/// A dynamic store session instance
+#[derive(Debug)]
+pub struct AnyBackendSession(Box<dyn BackendSession>);
+
+impl BackendSession for AnyBackendSession {
+    /// Count the number of matching records in the store
+    fn count<'q>(
+        &'q mut self,
+        kind: Option<EntryKind>,
+        category: Option<&'q str>,
+        tag_filter: Option<TagFilter>,
+    ) -> BoxFuture<'q, Result<i64, Error>> {
+        self.0.count(kind, category, tag_filter)
+    }
+
+    /// Fetch a single record from the store by category and name
+    fn fetch<'q>(
+        &'q mut self,
+        kind: EntryKind,
+        category: &'q str,
+        name: &'q str,
+        for_update: bool,
+    ) -> BoxFuture<'q, Result<Option<Entry>, Error>> {
+        self.0.fetch(kind, category, name, for_update)
+    }
+
+    /// Fetch all matching records from the store
+    fn fetch_all<'q>(
+        &'q mut self,
+        kind: Option<EntryKind>,
+        category: Option<&'q str>,
+        tag_filter: Option<TagFilter>,
+        limit: Option<i64>,
+        for_update: bool,
+    ) -> BoxFuture<'q, Result<Vec<Entry>, Error>> {
+        self.0
+            .fetch_all(kind, category, tag_filter, limit, for_update)
+    }
+
+    /// Remove all matching records from the store
+    fn remove_all<'q>(
+        &'q mut self,
+        kind: Option<EntryKind>,
+        category: Option<&'q str>,
+        tag_filter: Option<TagFilter>,
+    ) -> BoxFuture<'q, Result<i64, Error>> {
+        self.0.remove_all(kind, category, tag_filter)
+    }
+
+    /// Insert or replace a record in the store
+    #[allow(clippy::too_many_arguments)]
+    fn update<'q>(
+        &'q mut self,
+        kind: EntryKind,
+        operation: EntryOperation,
+        category: &'q str,
+        name: &'q str,
+        value: Option<&'q [u8]>,
+        tags: Option<&'q [EntryTag]>,
+        expiry_ms: Option<i64>,
+    ) -> BoxFuture<'q, Result<(), Error>> {
+        self.0
+            .update(kind, operation, category, name, value, tags, expiry_ms)
+    }
+
+    /// Close the current store session
+    fn close(&mut self, commit: bool) -> BoxFuture<'_, Result<(), Error>> {
+        self.0.close(commit)
+    }
+}
+
+impl<'a> ManageBackend<'a> for &'a str {
+    type Backend = AnyBackend;
+
+    fn open_backend(
+        self,
+        method: Option<StoreKeyMethod>,
+        pass_key: PassKey<'a>,
+        profile: Option<&'a str>,
+    ) -> BoxFuture<'a, Result<Self::Backend, Error>> {
+        Box::pin(async move {
+            let opts = self.into_options()?;
+            debug!("Open store with options: {:?}", &opts);
+
+            match opts.schema.as_ref() {
+                #[cfg(feature = "postgres")]
+                "postgres" => {
+                    let opts = postgres::PostgresStoreOptions::new(opts)?;
+                    let mgr = opts.open(method, pass_key, profile).await?;
+                    Ok(into_any_backend(mgr))
+                }
+
+                #[cfg(feature = "sqlite")]
+                "sqlite" => {
+                    let opts = sqlite::SqliteStoreOptions::new(opts)?;
+                    let mgr = opts.open(method, pass_key, profile).await?;
+                    Ok(into_any_backend(mgr))
+                }
+
+                _ => Err(err_msg!(
+                    Unsupported,
+                    "Unsupported backend: {}",
+                    &opts.schema
+                )),
+            }
+        })
+    }
+
+    fn provision_backend(
+        self,
+        method: StoreKeyMethod,
+        pass_key: PassKey<'a>,
+        profile: Option<&'a str>,
+        recreate: bool,
+    ) -> BoxFuture<'a, Result<Self::Backend, Error>> {
+        Box::pin(async move {
+            let opts = self.into_options()?;
+            debug!("Provision store with options: {:?}", &opts);
+
+            match opts.schema.as_ref() {
+                #[cfg(feature = "postgres")]
+                "postgres" => {
+                    let opts = postgres::PostgresStoreOptions::new(opts)?;
+                    let mgr = opts.provision(method, pass_key, profile, recreate).await?;
+                    Ok(into_any_backend(mgr))
+                }
+
+                #[cfg(feature = "sqlite")]
+                "sqlite" => {
+                    let opts = sqlite::SqliteStoreOptions::new(opts)?;
+                    let mgr = opts.provision(method, pass_key, profile, recreate).await?;
+                    Ok(into_any_backend(mgr))
+                }
+
+                _ => Err(err_msg!(
+                    Unsupported,
+                    "Unsupported backend: {}",
+                    &opts.schema
+                )),
+            }
+        })
+    }
+
+    fn remove_backend(self) -> BoxFuture<'a, Result<bool, Error>> {
+        Box::pin(async move {
+            let opts = self.into_options()?;
+            debug!("Remove store with options: {:?}", &opts);
+
+            match opts.schema.as_ref() {
+                #[cfg(feature = "postgres")]
+                "postgres" => {
+                    let opts = postgres::PostgresStoreOptions::new(opts)?;
+                    Ok(opts.remove().await?)
+                }
+
+                #[cfg(feature = "sqlite")]
+                "sqlite" => {
+                    let opts = sqlite::SqliteStoreOptions::new(opts)?;
+                    Ok(opts.remove().await?)
+                }
+
+                _ => Err(err_msg!(
+                    Unsupported,
+                    "Unsupported backend: {}",
+                    &opts.schema
+                )),
+            }
+        })
+    }
+}
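Reviewer note (not part of the diff): a minimal sketch of how a consumer would drive the new `ManageBackend` impl on `&str` and the type-erased `AnyBackend`. The URI and key handling are illustrative only, and the `any` and `sqlite` features are assumed enabled.

```rust
use askar_storage::{
    any::AnyBackend, generate_raw_store_key, Backend, BackendSession, ManageBackend,
    StoreKeyMethod,
};

// Provision an in-memory SQLite store through the string-based manager,
// then open a session on the type-erased backend.
async fn demo() -> Result<(), askar_storage::Error> {
    let key = generate_raw_store_key(None)?;
    let backend: AnyBackend = "sqlite://:memory:"
        .provision_backend(StoreKeyMethod::RawKey, key, None, false)
        .await?;
    // Sessions come back as AnyBackendSession (a boxed BackendSession).
    let mut session = backend.session(None, false)?;
    assert_eq!(session.count(None, None, None).await?, 0);
    session.close(true).await?;
    backend.close().await
}
```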
diff --git a/src/backend/db_utils.rs b/askar-storage/src/backend/db_utils.rs
similarity index 97%
rename from src/backend/db_utils.rs
rename to askar-storage/src/backend/db_utils.rs
index cd8403d7..a12b419d 100644
--- a/src/backend/db_utils.rs
+++ b/askar-storage/src/backend/db_utils.rs
@@ -8,15 +8,13 @@ use sqlx::{
 };
 
 use crate::{
+    entry::{EncEntryTag, Entry, EntryTag, TagFilter},
     error::Error,
     future::BoxFuture,
     protect::{EntryEncryptor, KeyCache, PassKey, ProfileId, ProfileKey, StoreKey, StoreKeyMethod},
-    storage::{
-        wql::{
-            sql::TagSqlEncoder,
-            tags::{tag_query, TagQueryEncoder},
-        },
-        {EncEntryTag, Entry, EntryTag, TagFilter},
+    wql::{
+        sql::TagSqlEncoder,
+        tags::{tag_query, TagQueryEncoder},
     },
 };
 
@@ -138,7 +136,7 @@ impl<DB: ExtDatabase> DbSession<DB> {
         DbSessionRef::Owned(self)
     }
 
-    pub(crate) async fn close(mut self, commit: bool) -> Result<(), Error> {
+    pub(crate) async fn close(&mut self, commit: bool) -> Result<(), Error> {
         if self.txn_depth > 0 {
             self.txn_depth = 0;
             if let Some(conn) = self.connection_mut() {
@@ -253,6 +251,7 @@ impl<'q, DB: ExtDatabase> DbSessionActive<'q, DB> {
         self.inner.connection_mut().unwrap()
     }
 
+    #[allow(unused)]
     pub fn in_transaction(&self) -> bool {
         self.inner.in_transaction()
     }
@@ -345,6 +344,7 @@ where
 }
 
 pub struct EncScanEntry {
+    pub category: Vec<u8>,
     pub name: Vec<u8>,
     pub value: Vec<u8>,
     pub tags: Vec<u8>,
@@ -508,22 +508,26 @@ pub(crate) fn decode_tags(tags: Vec<u8>) -> Result<Vec<EncEntryTag>, ()> {
 }
 
 pub fn decrypt_scan_batch(
-    category: String,
+    category: Option<String>,
     enc_rows: Vec<EncScanEntry>,
     key: &ProfileKey,
) -> Result<Vec<Entry>, Error> {
     let mut batch = Vec::with_capacity(enc_rows.len());
     for enc_entry in enc_rows {
-        batch.push(decrypt_scan_entry(category.clone(), enc_entry, key)?);
+        batch.push(decrypt_scan_entry(category.as_deref(), enc_entry, key)?);
     }
     Ok(batch)
 }
 
 pub fn decrypt_scan_entry(
-    category: String,
+    category: Option<&str>,
     enc_entry: EncScanEntry,
     key: &ProfileKey,
 ) -> Result<Entry, Error> {
+    let category = match category {
+        Some(c) => c.to_owned(),
+        None => key.decrypt_entry_category(enc_entry.category)?,
+    };
     let name = key.decrypt_entry_name(enc_entry.name)?;
     let value = key.decrypt_entry_value(category.as_bytes(), name.as_bytes(), enc_entry.value)?;
     let tags = key.decrypt_entry_tags(
diff --git a/src/backend/types.rs b/askar-storage/src/backend/mod.rs
similarity index 73%
rename from src/backend/types.rs
rename to askar-storage/src/backend/mod.rs
index 4ede8fee..31342f74 100644
--- a/src/backend/types.rs
+++ b/askar-storage/src/backend/mod.rs
@@ -1,14 +1,31 @@
+//! Storage backends supported by aries-askar
+
+use std::fmt::Debug;
+
 use crate::{
+    entry::{Entry, EntryKind, EntryOperation, EntryTag, Scan, TagFilter},
     error::Error,
     future::BoxFuture,
     protect::{PassKey, StoreKeyMethod},
-    storage::{Entry, EntryKind, EntryOperation, EntryTag, Scan, TagFilter},
 };
 
+#[cfg(any(feature = "postgres", feature = "sqlite"))]
+pub(crate) mod db_utils;
+
+#[cfg(feature = "postgres")]
+#[cfg_attr(docsrs, doc(cfg(feature = "postgres")))]
+/// Postgres database support
+pub mod postgres;
+
+#[cfg(feature = "sqlite")]
+#[cfg_attr(docsrs, doc(cfg(feature = "sqlite")))]
+/// Sqlite database support
+pub mod sqlite;
+
 /// Represents a generic backend implementation
-pub trait Backend: Send + Sync {
+pub trait Backend: Debug + Send + Sync {
     /// The type of session managed by this backend
-    type Session: QueryBackend;
+    type Session: BackendSession + 'static;
 
     /// Create a new profile
     fn create_profile(&self, name: Option<String>) -> BoxFuture<'_, Result<String, Error>>;
@@ -23,8 +40,8 @@ pub trait Backend: Send + Sync {
     fn scan(
         &self,
         profile: Option<String>,
-        kind: EntryKind,
-        category: String,
+        kind: Option<EntryKind>,
+        category: Option<String>,
         tag_filter: Option<TagFilter>,
         offset: Option<i64>,
         limit: Option<i64>,
@@ -34,7 +51,7 @@ pub trait Backend: Send + Sync {
     fn session(&self, profile: Option<String>, transaction: bool) -> Result<Self::Session, Error>;
 
     /// Replace the wrapping key of the store
-    fn rekey_backend(
+    fn rekey(
        &mut self,
         method: StoreKeyMethod,
         key: PassKey<'_>,
@@ -46,8 +63,8 @@ pub trait Backend: Send + Sync {
 
 /// Create, open, or remove a generic backend implementation
 pub trait ManageBackend<'a> {
-    /// The type of store being managed
-    type Store;
+    /// The type of backend being managed
+    type Backend;
 
     /// Open an existing store
     fn open_backend(
@@ -55,7 +72,7 @@ pub trait ManageBackend<'a> {
         method: Option<StoreKeyMethod>,
         pass_key: PassKey<'a>,
         profile: Option<&'a str>,
-    ) -> BoxFuture<'a, Result<Self::Store, Error>>;
+    ) -> BoxFuture<'a, Result<Self::Backend, Error>>;
 
     /// Provision a new store
     fn provision_backend(
@@ -64,19 +81,19 @@ pub trait ManageBackend<'a> {
         pass_key: PassKey<'a>,
         profile: Option<&'a str>,
         recreate: bool,
-    ) -> BoxFuture<'a, Result<Self::Store, Error>>;
+    ) -> BoxFuture<'a, Result<Self::Backend, Error>>;
 
     /// Remove an existing store
     fn remove_backend(self) -> BoxFuture<'a, Result<bool, Error>>;
 }
 
 /// Query from a generic backend implementation
-pub trait QueryBackend: Send {
+pub trait BackendSession: Debug + Send {
     /// Count the number of matching records in the store
     fn count<'q>(
         &'q mut self,
-        kind: EntryKind,
-        category: &'q str,
+        kind: Option<EntryKind>,
+        category: Option<&'q str>,
         tag_filter: Option<TagFilter>,
     ) -> BoxFuture<'q, Result<i64, Error>>;
 
@@ -92,8 +109,8 @@ pub trait QueryBackend: Send {
     /// Fetch all matching records from the store
     fn fetch_all<'q>(
         &'q mut self,
-        kind: EntryKind,
-        category: &'q str,
+        kind: Option<EntryKind>,
+        category: Option<&'q str>,
         tag_filter: Option<TagFilter>,
         limit: Option<i64>,
         for_update: bool,
@@ -102,8 +119,8 @@ pub trait QueryBackend: Send {
     /// Remove all matching records from the store
     fn remove_all<'q>(
         &'q mut self,
-        kind: EntryKind,
-        category: &'q str,
+        kind: Option<EntryKind>,
+        category: Option<&'q str>,
         tag_filter: Option<TagFilter>,
     ) -> BoxFuture<'q, Result<i64, Error>>;
 
@@ -121,5 +138,5 @@ pub trait QueryBackend: Send {
     ) -> BoxFuture<'q, Result<(), Error>>;
 
     /// Close the current store session
-    fn close(self, commit: bool) -> BoxFuture<'static, Result<(), Error>>;
+    fn close(&mut self, commit: bool) -> BoxFuture<'_, Result<(), Error>>;
 }
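Reviewer note (not part of the diff): with `kind` and `category` now optional across `Backend::scan` and the `BackendSession` methods, one call can span all categories of a kind, or everything at once. A hedged sketch against the new trait surface:

```rust
use askar_storage::{any::AnyBackend, entry::EntryKind, Backend, Error};

// Scan every category of general items in the default profile; passing
// None for both kind and category would cover all records.
async fn dump_items(backend: &AnyBackend) -> Result<(), Error> {
    let mut scan = backend
        .scan(None, Some(EntryKind::Item), None, None, None, None)
        .await?;
    while let Some(batch) = scan.fetch_next().await? {
        for entry in batch {
            println!("{}: {}", entry.category, entry.name);
        }
    }
    Ok(())
}
```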
diff --git a/src/backend/postgres/mod.rs b/askar-storage/src/backend/postgres/mod.rs
similarity index 86%
rename from src/backend/postgres/mod.rs
rename to askar-storage/src/backend/postgres/mod.rs
index e087f344..868f0342 100644
--- a/src/backend/postgres/mod.rs
+++ b/askar-storage/src/backend/postgres/mod.rs
@@ -15,24 +15,34 @@ use sqlx::{
     Row,
 };
 
-use crate::{
-    backend::{
-        db_utils::{
-            decode_tags, decrypt_scan_batch, encode_profile_key, encode_tag_filter,
-            expiry_timestamp, extend_query, prepare_tags, random_profile_name,
-            replace_arg_placeholders, DbSession, DbSessionActive, DbSessionRef, DbSessionTxn,
-            EncScanEntry, ExtDatabase, QueryParams, QueryPrepare, PAGE_SIZE,
-        },
-        types::{Backend, QueryBackend},
+use super::{
+    db_utils::{
+        decode_tags, decrypt_scan_batch, encode_profile_key, encode_tag_filter, expiry_timestamp,
+        extend_query, prepare_tags, random_profile_name, replace_arg_placeholders, DbSession,
+        DbSessionActive, DbSessionRef, DbSessionTxn, EncScanEntry, ExtDatabase, QueryParams,
+        QueryPrepare, PAGE_SIZE,
     },
+    Backend, BackendSession,
+};
+use crate::{
+    entry::{EncEntryTag, Entry, EntryKind, EntryOperation, EntryTag, Scan, TagFilter},
     error::Error,
     future::{unblock, BoxFuture},
     protect::{EntryEncryptor, KeyCache, PassKey, ProfileId, ProfileKey, StoreKeyMethod},
-    storage::{EncEntryTag, Entry, EntryKind, EntryOperation, EntryTag, Scan, TagFilter},
 };
 
+mod provision;
+pub use self::provision::PostgresStoreOptions;
+
+#[cfg(any(test, feature = "pg_test"))]
+mod test_db;
+#[cfg(any(test, feature = "pg_test"))]
+pub use self::test_db::TestDB;
+
 const COUNT_QUERY: &str = "SELECT COUNT(*) FROM items i
-    WHERE profile_id = $1 AND kind = $2 AND category = $3
+    WHERE profile_id = $1
+    AND (kind = $2 OR $2 IS NULL)
+    AND (category = $3 OR $3 IS NULL)
     AND (expiry IS NULL OR expiry > CURRENT_TIMESTAMP)";
 const DELETE_QUERY: &str = "DELETE FROM items
     WHERE profile_id = $1 AND kind = $2 AND category = $3 AND name = $4";
@@ -56,27 +66,25 @@ const INSERT_QUERY: &str = "INSERT INTO items (profile_id, kind, category, name,
 const UPDATE_QUERY: &str = "UPDATE items SET value=$5, expiry=$6
     WHERE profile_id=$1 AND kind=$2 AND category=$3 AND name=$4
     RETURNING id";
-const SCAN_QUERY: &str = "SELECT id, name, value,
+const SCAN_QUERY: &str = "SELECT id, category, name, value,
     (SELECT ARRAY_TO_STRING(ARRAY_AGG(it.plaintext || ':'
         || ENCODE(it.name, 'hex') || ':' || ENCODE(it.value, 'hex')), ',')
         FROM items_tags it WHERE it.item_id = i.id) tags
-    FROM items i WHERE profile_id = $1 AND kind = $2 AND category = $3
+    FROM items i WHERE profile_id = $1
+    AND (kind = $2 OR $2 IS NULL)
+    AND (category = $3 OR $3 IS NULL)
     AND (expiry IS NULL OR expiry > CURRENT_TIMESTAMP)";
 const DELETE_ALL_QUERY: &str = "DELETE FROM items i
-    WHERE i.profile_id = $1 AND i.kind = $2 AND i.category = $3";
+    WHERE profile_id = $1
+    AND (kind = $2 OR $2 IS NULL)
+    AND (category = $3 OR $3 IS NULL)";
 const TAG_INSERT_QUERY: &str = "INSERT INTO items_tags
     (item_id, name, value, plaintext) VALUES ($1, $2, $3, $4)";
 const TAG_DELETE_QUERY: &str = "DELETE FROM items_tags
     WHERE item_id=$1";
 
-mod provision;
-pub use provision::PostgresStoreOptions;
-
-#[cfg(any(test, feature = "pg_test"))]
-pub mod test_db;
-
 /// A PostgreSQL database store
-pub struct PostgresStore {
+pub struct PostgresBackend {
     conn_pool: PgPool,
     default_profile: String,
     key_cache: Arc<KeyCache>,
     host: String,
     name: String,
 }
 
-impl PostgresStore {
+impl PostgresBackend {
     pub(crate) fn new(
         conn_pool: PgPool,
         default_profile: String,
@@ -102,7 +110,7 @@
     }
 }
 
-impl Backend for PostgresStore {
+impl Backend for PostgresBackend {
     type Session = DbSession<Postgres>;
 
     fn create_profile(&self, name: Option<String>) -> BoxFuture<'_, Result<String, Error>> {
@@ -151,7 +159,7 @@ impl Backend for PostgresStore {
         })
     }
 
-    fn rekey_backend(
+    fn rekey(
         &mut self,
         method: StoreKeyMethod,
         pass_key: PassKey<'_>,
@@ -206,8 +214,8 @@ impl Backend for PostgresStore {
     fn scan(
         &self,
         profile: Option<String>,
-        kind: EntryKind,
-        category: String,
+        kind: Option<EntryKind>,
+        category: Option<String>,
         tag_filter: Option<TagFilter>,
         offset: Option<i64>,
         limit: Option<i64>,
@@ -215,7 +223,7 @@ impl Backend for PostgresStore {
         Box::pin(async move {
             let session = self.session(profile, false)?;
             let mut active = session.owned_ref();
-            let (profile_id, key) = acquire_key(&mut *active).await?;
+            let (profile_id, key) = acquire_key(&mut active).await?;
             let scan = perform_scan(
                 active,
                 profile_id,
@@ -253,7 +261,7 @@ impl Backend for PostgresStore {
     }
 }
 
-impl Debug for PostgresStore {
+impl Debug for PostgresBackend {
     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
         f.debug_struct("PostgresStore")
             .field("default_profile", &self.default_profile)
@@ -263,33 +271,35 @@ impl Debug for PostgresStore {
     }
 }
 
-impl QueryBackend for DbSession<Postgres> {
+impl BackendSession for DbSession<Postgres> {
     fn count<'q>(
         &'q mut self,
-        kind: EntryKind,
-        category: &'q str,
+        kind: Option<EntryKind>,
+        category: Option<&'q str>,
         tag_filter: Option<TagFilter>,
     ) -> BoxFuture<'q, Result<i64, Error>> {
-        let category = ProfileKey::prepare_input(category.as_bytes());
+        let enc_category = category.map(|c| ProfileKey::prepare_input(c.as_bytes()));
         Box::pin(async move {
             let (profile_id, key) = acquire_key(&mut *self).await?;
             let mut params = QueryParams::new();
             params.push(profile_id);
-            params.push(kind as i16);
+            params.push(kind.map(|k| k as i16));
             let (enc_category, tag_filter) = unblock({
                 let params_len = params.len() + 1; // plus category
                 move || {
                     Result::<_, Error>::Ok((
-                        key.encrypt_entry_category(category)?,
-                        encode_tag_filter::<PostgresStore>(tag_filter, &key, params_len)?,
+                        enc_category
+                            .map(|c| key.encrypt_entry_category(c))
+                            .transpose()?,
+                        encode_tag_filter::<PostgresBackend>(tag_filter, &key, params_len)?,
                     ))
                 }
             })
             .await?;
             params.push(enc_category);
             let query =
-                extend_query::<PostgresStore>(COUNT_QUERY, &mut params, tag_filter, None, None)?;
+                extend_query::<PostgresBackend>(COUNT_QUERY, &mut params, tag_filter, None, None)?;
             let mut active = acquire_session(&mut *self).await?;
             let count = sqlx::query_scalar_with(query.as_str(), params)
                 .fetch_one(active.connection_mut())
@@ -359,17 +369,17 @@
 
     fn fetch_all<'q>(
         &'q mut self,
-        kind: EntryKind,
-        category: &'q str,
+        kind: Option<EntryKind>,
+        category: Option<&'q str>,
         tag_filter: Option<TagFilter>,
         limit: Option<i64>,
         for_update: bool,
     ) -> BoxFuture<'q, Result<Vec<Entry>, Error>> {
-        let category = category.to_string();
+        let category = category.map(|c| c.to_string());
         Box::pin(async move {
             let for_update = for_update && self.in_transaction();
             let mut active = self.borrow_mut();
-            let (profile_id, key) = acquire_key(&mut *active).await?;
+            let (profile_id, key) = acquire_key(&mut active).await?;
             let scan = perform_scan(
                 active,
                 profile_id,
@@ -392,29 +402,31 @@
 
     fn remove_all<'q>(
         &'q mut self,
-        kind: EntryKind,
-        category: &'q str,
+        kind: Option<EntryKind>,
+        category: Option<&'q str>,
         tag_filter: Option<TagFilter>,
     ) -> BoxFuture<'q, Result<i64, Error>> {
-        let category = ProfileKey::prepare_input(category.as_bytes());
+        let enc_category = category.map(|c| ProfileKey::prepare_input(c.as_bytes()));
         Box::pin(async move {
             let (profile_id, key) = acquire_key(&mut *self).await?;
             let mut params = QueryParams::new();
             params.push(profile_id);
-            params.push(kind as i16);
+            params.push(kind.map(|k| k as i16));
             let (enc_category, tag_filter) = unblock({
                 let params_len = params.len() + 1; // plus category
                 move || {
                     Result::<_, Error>::Ok((
-                        key.encrypt_entry_category(category)?,
-                        encode_tag_filter::<PostgresStore>(tag_filter, &key, params_len)?,
+                        enc_category
+                            .map(|c| key.encrypt_entry_category(c))
+                            .transpose()?,
+                        encode_tag_filter::<PostgresBackend>(tag_filter, &key, params_len)?,
                     ))
                 }
             })
            .await?;
             params.push(enc_category);
-            let query = extend_query::<PostgresStore>(
+            let query = extend_query::<PostgresBackend>(
                 DELETE_ALL_QUERY,
                 &mut params,
                 tag_filter,
@@ -446,7 +458,7 @@
 
         match operation {
             op @ EntryOperation::Insert | op @ EntryOperation::Replace => {
-                let value = ProfileKey::prepare_input(value.unwrap());
+                let value = ProfileKey::prepare_input(value.unwrap_or_default());
                 let tags = tags.map(prepare_tags);
                 Box::pin(async move {
                     let (_, key) = acquire_key(&mut *self).await?;
@@ -496,14 +508,14 @@
         }
     }
 
-    fn close(self, commit: bool) -> BoxFuture<'static, Result<(), Error>> {
-        Box::pin(DbSession::close(self, commit))
+    fn close(&mut self, commit: bool) -> BoxFuture<'_, Result<(), Error>> {
+        Box::pin(self.close(commit))
     }
 }
 
 impl ExtDatabase for Postgres {}
 
-impl QueryPrepare for PostgresStore {
+impl QueryPrepare for PostgresBackend {
     type DB = Postgres;
 
     fn placeholder(index: i64) -> String {
@@ -650,8 +662,8 @@ fn perform_scan(
     mut active: DbSessionRef<'_, Postgres>,
     profile_id: ProfileId,
     key: Arc<ProfileKey>,
-    kind: EntryKind,
-    category: String,
+    kind: Option<EntryKind>,
+    category: Option<String>,
     tag_filter: Option<TagFilter>,
     offset: Option<i64>,
     limit: Option<i64>,
@@ -660,31 +672,33 @@ fn perform_scan(
     try_stream! {
         let mut params = QueryParams::new();
         params.push(profile_id);
-        params.push(kind as i16);
+        params.push(kind.map(|k| k as i16));
        let (enc_category, tag_filter) = unblock({
            let key = key.clone();
-            let category = ProfileKey::prepare_input(category.as_bytes());
+            let enc_category = category.map(|c| ProfileKey::prepare_input(c.as_bytes()));
            let params_len = params.len() + 1; // plus category
            move || {
                Result::<_, Error>::Ok((
-                    key.encrypt_entry_category(category)?,
-                    encode_tag_filter::<PostgresStore>(tag_filter, &key, params_len)?
+                    enc_category
+                        .map(|c| key.encrypt_entry_category(c))
+                        .transpose()?,
+                    encode_tag_filter::<PostgresBackend>(tag_filter, &key, params_len)?
                ))
            }
        }).await?;
        params.push(enc_category);
-        let mut query = extend_query::<PostgresStore>(SCAN_QUERY, &mut params, tag_filter, offset, limit)?;
+        let mut query = extend_query::<PostgresBackend>(SCAN_QUERY, &mut params, tag_filter, offset, limit)?;
        if for_update {
            query.push_str(" FOR NO KEY UPDATE");
        }
        let mut batch = Vec::with_capacity(PAGE_SIZE);
-        let mut acquired = acquire_session(&mut *active).await?;
+        let mut acquired = acquire_session(&mut active).await?;
        let mut rows = sqlx::query_with(query.as_str(), params).fetch(acquired.connection_mut());
        while let Some(row) = rows.try_next().await? {
-            let tags = row.try_get::<Option<String>, _>(3)?.map(String::into_bytes).unwrap_or_default();
+            let tags = row.try_get::<Option<String>, _>(4)?.map(String::into_bytes).unwrap_or_default();
            batch.push(EncScanEntry {
-                name: row.try_get(1)?, value: row.try_get(2)?, tags
+                category: row.try_get(1)?, name: row.try_get(2)?, value: row.try_get(3)?, tags
            });
            if batch.len() == PAGE_SIZE {
                yield batch.split_off(0);
@@ -693,7 +707,7 @@
        drop(rows);
        drop(active);
 
-        if batch.len() > 0 {
+        if !batch.is_empty() {
            yield batch;
        }
    }
@@ -707,7 +721,7 @@ mod tests {
     #[test]
     fn postgres_simple_and_convert_args_works() {
         assert_eq!(
-            &replace_arg_placeholders::<PostgresStore>("This $$ is $10 a $$ string!", 3),
+            &replace_arg_placeholders::<PostgresBackend>("This $$ is $10 a $$ string!", 3),
             "This $3 is $12 a $5 string!",
         );
     }
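Reviewer note (not part of the diff): the rewritten queries rely on binding `Option` values so a single prepared statement serves both the filtered and unfiltered cases, e.g. `(kind = $2 OR $2 IS NULL)`. A standalone sketch of the pattern with sqlx (table name illustrative):

```rust
use sqlx::postgres::PgPool;

// Passing Some(kind) filters by kind; passing None binds SQL NULL and the
// predicate collapses to true, matching every row.
async fn count_rows(pool: &PgPool, kind: Option<i16>) -> Result<i64, sqlx::Error> {
    sqlx::query_scalar("SELECT COUNT(*) FROM items WHERE (kind = $1 OR $1 IS NULL)")
        .bind(kind)
        .fetch_one(pool)
        .await
}
```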
diff --git a/src/backend/postgres/provision.rs b/askar-storage/src/backend/postgres/provision.rs
similarity index 95%
rename from src/backend/postgres/provision.rs
rename to askar-storage/src/backend/postgres/provision.rs
index b66ce9d4..336afd9b 100644
--- a/src/backend/postgres/provision.rs
+++ b/askar-storage/src/backend/postgres/provision.rs
@@ -10,15 +10,15 @@ use sqlx::{
 
 use crate::{
     backend::{
         db_utils::{init_keys, random_profile_name},
-        types::ManageBackend,
+        ManageBackend,
     },
     error::Error,
     future::{unblock, BoxFuture},
+    options::IntoOptions,
     protect::{KeyCache, PassKey, ProfileId, StoreKeyMethod, StoreKeyReference},
-    storage::{IntoOptions, Store},
 };
 
-use super::PostgresStore;
+use super::PostgresBackend;
 
 const DEFAULT_CONNECT_TIMEOUT: u64 = 30;
 const DEFAULT_IDLE_TIMEOUT: u64 = 300;
@@ -89,7 +89,7 @@ impl PostgresStoreOptions {
         if path.len() < 2 {
             return Err(err_msg!(Input, "Missing database name"));
         }
-        let name = (&path[1..]).to_string();
+        let name = path[1..].to_string();
         if name.find(|c| c == '"' || c == '\0').is_some() {
             return Err(err_msg!(
                 Input,
@@ -162,7 +162,7 @@ impl PostgresStoreOptions {
                 admin_conn.close().await?;
                 Ok(self.pool().await?)
             }
-            Err(err) => return Err(err_msg!(Backend, "Error opening database").with_cause(err)),
+            Err(err) => Err(err_msg!(Backend, "Error opening database").with_cause(err)),
         }
     }
 
@@ -173,13 +173,13 @@ impl PostgresStoreOptions {
         pass_key: PassKey<'_>,
         profile: Option<&str>,
         recreate: bool,
-    ) -> Result<Store<PostgresStore>, Error> {
+    ) -> Result<PostgresBackend, Error> {
         let conn_pool = self.create_db_pool().await?;
         let mut txn = conn_pool.begin().await?;
 
         if recreate {
             // remove expected tables
-            reset_db(&mut *txn).await?;
+            reset_db(&mut txn).await?;
         } else if sqlx::query_scalar::<_, i64>(
             "SELECT COUNT(*) FROM information_schema.tables
             WHERE table_schema='public' AND table_name='config'",
@@ -213,13 +213,13 @@ impl PostgresStoreOptions {
         let mut key_cache = KeyCache::new(store_key);
         key_cache.add_profile_mut(default_profile.clone(), profile_id, profile_key);
 
-        Ok(Store::new(PostgresStore::new(
+        Ok(PostgresBackend::new(
             conn_pool,
             default_profile,
             key_cache,
             self.host,
             self.name,
-        )))
+        ))
     }
 
     /// Open an existing Postgres store from this set of configuration options
@@ -228,7 +228,7 @@ impl PostgresStoreOptions {
         method: Option<StoreKeyMethod>,
         pass_key: PassKey<'_>,
         profile: Option<&str>,
-    ) -> Result<Store<PostgresStore>, Error> {
+    ) -> Result<PostgresBackend, Error> {
         let pool = match self.pool().await {
             Ok(p) => Ok(p),
             Err(SqlxError::Database(db_err)) if db_err.code() == Some(Cow::Borrowed("3D000")) => {
@@ -263,14 +263,14 @@ impl PostgresStoreOptions {
 }
 
 impl<'a> ManageBackend<'a> for PostgresStoreOptions {
-    type Store = Store<PostgresStore>;
+    type Backend = PostgresBackend;
 
     fn open_backend(
         self,
         method: Option<StoreKeyMethod>,
         pass_key: PassKey<'_>,
         profile: Option<&'a str>,
-    ) -> BoxFuture<'a, Result<Store<PostgresStore>, Error>> {
+    ) -> BoxFuture<'a, Result<PostgresBackend, Error>> {
         let pass_key = pass_key.into_owned();
         Box::pin(self.open(method, pass_key, profile))
     }
@@ -281,7 +281,7 @@ impl<'a> ManageBackend<'a> for PostgresStoreOptions {
         pass_key: PassKey<'_>,
         profile: Option<&'a str>,
         recreate: bool,
-    ) -> BoxFuture<'a, Result<Store<PostgresStore>, Error>> {
+    ) -> BoxFuture<'a, Result<PostgresBackend, Error>> {
         let pass_key = pass_key.into_owned();
         Box::pin(self.provision(method, pass_key, profile, recreate))
     }
@@ -389,7 +389,7 @@ pub(crate) async fn open_db(
     profile: Option<&str>,
     host: String,
     name: String,
-) -> Result<Store<PostgresStore>, Error> {
+) -> Result<PostgresBackend, Error> {
     let mut conn = conn_pool.acquire().await?;
     let mut ver_ok = false;
     let mut default_profile: Option<String> = None;
@@ -450,9 +450,9 @@ pub(crate) async fn open_db(
     let profile_key = key_cache.load_key(row.try_get(1)?).await?;
     key_cache.add_profile_mut(profile.clone(), profile_id, profile_key);
 
-    Ok(Store::new(PostgresStore::new(
+    Ok(PostgresBackend::new(
         conn_pool, profile, key_cache, host, name,
-    )))
+    ))
 }
 
 #[cfg(test)]
diff --git a/src/backend/postgres/test_db.rs b/askar-storage/src/backend/postgres/test_db.rs
similarity index 84%
rename from src/backend/postgres/test_db.rs
rename to askar-storage/src/backend/postgres/test_db.rs
index 8dccf404..00295f16 100644
--- a/src/backend/postgres/test_db.rs
+++ b/askar-storage/src/backend/postgres/test_db.rs
@@ -7,37 +7,37 @@ use sqlx::{
 use std::time::Duration;
 
 use super::provision::{init_db, reset_db, PostgresStoreOptions};
-use super::PostgresStore;
+use super::PostgresBackend;
 use crate::{
+    any::{into_any_backend, AnyBackend},
     backend::db_utils::{init_keys, random_profile_name},
     error::Error,
     future::{sleep, spawn_ok, timeout, unblock},
     protect::{generate_raw_store_key, KeyCache, StoreKeyMethod},
-    storage::Store,
 };
 
 #[derive(Debug)]
 /// Postgres test database wrapper instance
 pub struct TestDB {
-    inst: Option<Store<PostgresStore>>,
+    inst: Option<AnyBackend>,
     lock_txn: Option<PgConnection>,
 }
 
 impl TestDB {
+    /// Access the backend instance
+    pub fn backend(&self) -> AnyBackend {
+        self.inst.clone().expect("Database not opened")
+    }
+
     /// Provision a new instance of the test database.
     /// This method blocks until the database lock can be acquired.
-    pub async fn provision() -> Result<TestDB, Error> {
-        let path = match std::env::var("POSTGRES_URL") {
-            Ok(p) if !p.is_empty() => p,
-            _ => panic!("'POSTGRES_URL' must be defined"),
-        };
-
+    pub async fn provision(db_url: &str) -> Result<TestDB, Error> {
         let key = generate_raw_store_key(None)?;
         let (profile_key, enc_profile_key, store_key, store_key_ref) =
             unblock(|| init_keys(StoreKeyMethod::RawKey, key)).await?;
         let default_profile = random_profile_name();
 
-        let opts = PostgresStoreOptions::new(path.as_str())?;
+        let opts = PostgresStoreOptions::new(db_url)?;
         let conn_pool = opts.create_db_pool().await?;
 
         // we hold a transaction open with a fixed advisory lock value.
@@ -60,7 +60,7 @@ impl TestDB {
         let mut init_txn = conn_pool.begin().await?;
 
         // delete existing tables
-        reset_db(&mut *init_txn).await?;
+        reset_db(&mut init_txn).await?;
 
         // create tables and add default profile
         let profile_id =
@@ -68,7 +68,7 @@ impl TestDB {
         let mut key_cache = KeyCache::new(store_key);
         key_cache.add_profile_mut(default_profile.clone(), profile_id, profile_key);
 
-        let inst = Store::new(PostgresStore::new(
+        let inst = into_any_backend(PostgresBackend::new(
             conn_pool,
             default_profile,
             key_cache,
@@ -84,7 +84,7 @@ impl TestDB {
 
     async fn close_internal(
         mut lock_txn: Option<PgConnection>,
-        mut inst: Option<Store<PostgresStore>>,
+        mut inst: Option<AnyBackend>,
     ) -> Result<(), Error> {
         if let Some(lock_txn) = lock_txn.take() {
             lock_txn.close().await?;
@@ -109,14 +109,6 @@ impl TestDB {
     }
 }
 
-impl std::ops::Deref for TestDB {
-    type Target = Store<PostgresStore>;
-
-    fn deref(&self) -> &Self::Target {
-        self.inst.as_ref().unwrap()
-    }
-}
-
 impl Drop for TestDB {
     fn drop(&mut self) {
         if self.lock_txn.is_some() || self.inst.is_some() {
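Reviewer note (not part of the diff): `TestDB::provision` now takes the connection URL from the caller instead of reading `POSTGRES_URL` itself, and the backend is reached through `TestDB::backend()` rather than the removed `Deref` to `Store`. A hedged sketch of what a `pg_test` consumer might look like:

```rust
use askar_storage::{backend::postgres::TestDB, Backend, BackendSession, Error};

// The environment lookup moves to the test itself.
async fn pg_smoke_test() -> Result<(), Error> {
    let url = std::env::var("POSTGRES_URL").expect("'POSTGRES_URL' must be defined");
    let db = TestDB::provision(url.as_str()).await?;
    let mut session = db.backend().session(None, false)?;
    assert_eq!(session.count(None, None, None).await?, 0);
    session.close(true).await?;
    Ok(()) // advisory lock and pool are released when `db` is dropped
}
```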
diff --git a/src/backend/sqlite/mod.rs b/askar-storage/src/backend/sqlite/mod.rs
similarity index 86%
rename from src/backend/sqlite/mod.rs
rename to askar-storage/src/backend/sqlite/mod.rs
index 01384fab..32b2336a 100644
--- a/src/backend/sqlite/mod.rs
+++ b/askar-storage/src/backend/sqlite/mod.rs
@@ -14,27 +14,28 @@ use sqlx::{
     Database, Error as SqlxError, Row, TransactionManager,
 };
 
-use crate::{
-    backend::{
-        db_utils::{
-            decode_tags, decrypt_scan_batch, encode_profile_key, encode_tag_filter,
-            expiry_timestamp, extend_query, prepare_tags, random_profile_name, DbSession,
-            DbSessionActive, DbSessionRef, DbSessionTxn, EncScanEntry, ExtDatabase, QueryParams,
-            QueryPrepare, PAGE_SIZE,
-        },
-        types::{Backend, QueryBackend},
+use super::{
+    db_utils::{
+        decode_tags, decrypt_scan_batch, encode_profile_key, encode_tag_filter, expiry_timestamp,
+        extend_query, prepare_tags, random_profile_name, DbSession, DbSessionActive, DbSessionRef,
+        DbSessionTxn, EncScanEntry, ExtDatabase, QueryParams, QueryPrepare, PAGE_SIZE,
     },
+    Backend, BackendSession,
+};
+use crate::{
+    entry::{EncEntryTag, Entry, EntryKind, EntryOperation, EntryTag, Scan, TagFilter},
     error::Error,
     future::{unblock, BoxFuture},
     protect::{EntryEncryptor, KeyCache, PassKey, ProfileId, ProfileKey, StoreKeyMethod},
-    storage::{EncEntryTag, Entry, EntryKind, EntryOperation, EntryTag, Scan, TagFilter},
 };
 
 mod provision;
 pub use provision::SqliteStoreOptions;
 
 const COUNT_QUERY: &str = "SELECT COUNT(*) FROM items i
-    WHERE profile_id = ?1 AND kind = ?2 AND category = ?3
+    WHERE profile_id = ?1
+    AND (kind = ?2 OR ?2 IS NULL)
+    AND (category = ?3 OR ?3 IS NULL)
     AND (expiry IS NULL OR expiry > DATETIME('now'))";
 const DELETE_QUERY: &str = "DELETE FROM items
     WHERE profile_id = ?1 AND kind = ?2 AND category = ?3 AND name = ?4";
@@ -49,27 +50,31 @@ const INSERT_QUERY: &str =
     VALUES (?1, ?2, ?3, ?4, ?5, ?6)";
 const UPDATE_QUERY: &str = "UPDATE items SET value=?5, expiry=?6
     WHERE profile_id=?1 AND kind=?2 AND category=?3 AND name=?4 RETURNING id";
-const SCAN_QUERY: &str = "SELECT i.id, i.name, i.value,
+const SCAN_QUERY: &str = "SELECT i.id, i.category, i.name, i.value,
     (SELECT GROUP_CONCAT(it.plaintext || ':' || HEX(it.name) || ':' || HEX(it.value))
         FROM items_tags it WHERE it.item_id = i.id) AS tags
-    FROM items i WHERE i.profile_id = ?1 AND i.kind = ?2 AND i.category = ?3
+    FROM items i WHERE i.profile_id = ?1
+    AND (i.kind = ?2 OR ?2 IS NULL)
+    AND (i.category = ?3 OR ?3 IS NULL)
     AND (i.expiry IS NULL OR i.expiry > DATETIME('now'))";
 const DELETE_ALL_QUERY: &str = "DELETE FROM items AS i
-    WHERE i.profile_id = ?1 AND i.kind = ?2 AND i.category = ?3";
+    WHERE i.profile_id = ?1
+    AND (i.kind = ?2 OR ?2 IS NULL)
+    AND (i.category = ?3 OR ?3 IS NULL)";
 const TAG_INSERT_QUERY: &str = "INSERT INTO items_tags
     (item_id, name, value, plaintext) VALUES (?1, ?2, ?3, ?4)";
 const TAG_DELETE_QUERY: &str = "DELETE FROM items_tags
     WHERE item_id=?1";
 
 /// A Sqlite database store
-pub struct SqliteStore {
+pub struct SqliteBackend {
     conn_pool: SqlitePool,
     default_profile: String,
     key_cache: Arc<KeyCache>,
     path: String,
 }
 
-impl SqliteStore {
+impl SqliteBackend {
     pub(crate) fn new(
         conn_pool: SqlitePool,
         default_profile: String,
@@ -85,7 +90,7 @@ impl SqliteStore {
     }
 }
 
-impl Debug for SqliteStore {
+impl Debug for SqliteBackend {
     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
         f.debug_struct("SqliteStore")
             .field("default_profile", &self.default_profile)
@@ -94,11 +99,11 @@ impl Debug for SqliteStore {
     }
 }
 
-impl QueryPrepare for SqliteStore {
+impl QueryPrepare for SqliteBackend {
     type DB = Sqlite;
 }
 
-impl Backend for SqliteStore {
+impl Backend for SqliteBackend {
     type Session = DbSession<Sqlite>;
 
     fn create_profile(&self, name: Option<String>) -> BoxFuture<'_, Result<String, Error>> {
@@ -148,7 +153,7 @@ impl Backend for SqliteStore {
         })
     }
 
-    fn rekey_backend(
+    fn rekey(
         &mut self,
         method: StoreKeyMethod,
         pass_key: PassKey<'_>,
@@ -203,8 +208,8 @@ impl Backend for SqliteStore {
     fn scan(
         &self,
         profile: Option<String>,
-        kind: EntryKind,
-        category: String,
+        kind: Option<EntryKind>,
+        category: Option<String>,
         tag_filter: Option<TagFilter>,
         offset: Option<i64>,
         limit: Option<i64>,
@@ -212,7 +217,7 @@ impl Backend for SqliteStore {
         Box::pin(async move {
             let session = self.session(profile, false)?;
             let mut active = session.owned_ref();
-            let (profile_id, key) = acquire_key(&mut *active).await?;
+            let (profile_id, key) = acquire_key(&mut active).await?;
             let scan = perform_scan(
                 active,
                 profile_id,
@@ -249,33 +254,35 @@ impl Backend for SqliteStore {
     }
 }
 
-impl QueryBackend for DbSession<Sqlite> {
+impl BackendSession for DbSession<Sqlite> {
     fn count<'q>(
         &'q mut self,
-        kind: EntryKind,
-        category: &'q str,
+        kind: Option<EntryKind>,
+        category: Option<&'q str>,
         tag_filter: Option<TagFilter>,
     ) -> BoxFuture<'q, Result<i64, Error>> {
-        let category = ProfileKey::prepare_input(category.as_bytes());
+        let enc_category = category.map(|c| ProfileKey::prepare_input(c.as_bytes()));
         Box::pin(async move {
             let (profile_id, key) = acquire_key(&mut *self).await?;
             let mut params = QueryParams::new();
             params.push(profile_id);
-            params.push(kind as i16);
+            params.push(kind.map(|k| k as i16));
             let (enc_category, tag_filter) = unblock({
                 let params_len = params.len() + 1; // plus category
                 move || {
                     Result::<_, Error>::Ok((
-                        key.encrypt_entry_category(category)?,
-                        encode_tag_filter::<SqliteStore>(tag_filter, &key, params_len)?,
+                        enc_category
+                            .map(|c| key.encrypt_entry_category(c))
+                            .transpose()?,
+                        encode_tag_filter::<SqliteBackend>(tag_filter, &key, params_len)?,
                     ))
                 }
             })
             .await?;
             params.push(enc_category);
             let query =
-                extend_query::<SqliteStore>(COUNT_QUERY, &mut params, tag_filter, None, None)?;
+                extend_query::<SqliteBackend>(COUNT_QUERY, &mut params, tag_filter, None, None)?;
             let mut active = acquire_session(&mut *self).await?;
             let count = sqlx::query_scalar_with(query.as_str(), params)
                 .fetch_one(active.connection_mut())
@@ -336,16 +343,16 @@
 
     fn fetch_all<'q>(
         &'q mut self,
-        kind: EntryKind,
-        category: &'q str,
+        kind: Option<EntryKind>,
+        category: Option<&'q str>,
         tag_filter: Option<TagFilter>,
         limit: Option<i64>,
         _for_update: bool,
     ) -> BoxFuture<'q, Result<Vec<Entry>, Error>> {
-        let category = category.to_string();
+        let category = category.map(|c| c.to_string());
         Box::pin(async move {
             let mut active = self.borrow_mut();
-            let (profile_id, key) = acquire_key(&mut *active).await?;
+            let (profile_id, key) = acquire_key(&mut active).await?;
             let scan = perform_scan(
                 active,
                 profile_id,
@@ -367,30 +374,37 @@
 
     fn remove_all<'q>(
         &'q mut self,
-        kind: EntryKind,
-        category: &'q str,
+        kind: Option<EntryKind>,
+        category: Option<&'q str>,
         tag_filter: Option<TagFilter>,
     ) -> BoxFuture<'q, Result<i64, Error>> {
-        let category = ProfileKey::prepare_input(category.as_bytes());
+        let enc_category = category.map(|c| ProfileKey::prepare_input(c.as_bytes()));
         Box::pin(async move {
             let (profile_id, key) = acquire_key(&mut *self).await?;
             let mut params = QueryParams::new();
             params.push(profile_id);
-            params.push(kind as i16);
+            params.push(kind.map(|k| k as i16));
             let (enc_category, tag_filter) = unblock({
                 let params_len = params.len() + 1; // plus category
                 move || {
                     Result::<_, Error>::Ok((
-                        key.encrypt_entry_category(category)?,
-                        encode_tag_filter::<SqliteStore>(tag_filter, &key, params_len)?,
+                        enc_category
+                            .map(|c| key.encrypt_entry_category(c))
+                            .transpose()?,
+                        encode_tag_filter::<SqliteBackend>(tag_filter, &key, params_len)?,
                     ))
                 }
             })
             .await?;
             params.push(enc_category);
-            let query =
-                extend_query::<SqliteStore>(DELETE_ALL_QUERY, &mut params, tag_filter, None, None)?;
+            let query = extend_query::<SqliteBackend>(
+                DELETE_ALL_QUERY,
+                &mut params,
+                tag_filter,
+                None,
+                None,
+            )?;
 
             let mut active = acquire_session(&mut *self).await?;
             let removed = sqlx::query_with(query.as_str(), params)
@@ -416,7 +430,7 @@
 
         match operation {
             op @ EntryOperation::Insert | op @ EntryOperation::Replace => {
-                let value = ProfileKey::prepare_input(value.unwrap());
+                let value = ProfileKey::prepare_input(value.unwrap_or_default());
                 let tags = tags.map(prepare_tags);
                 Box::pin(async move {
                     let (_, key) = acquire_key(&mut *self).await?;
@@ -466,8 +480,8 @@
         }
     }
 
-    fn close(self, commit: bool) -> BoxFuture<'static, Result<(), Error>> {
-        Box::pin(DbSession::close(self, commit))
+    fn close(&mut self, commit: bool) -> BoxFuture<'_, Result<(), Error>> {
+        Box::pin(self.close(commit))
     }
 }
 
@@ -615,8 +629,8 @@ fn perform_scan(
     mut active: DbSessionRef<'_, Sqlite>,
     profile_id: ProfileId,
     key: Arc<ProfileKey>,
-    kind: EntryKind,
-    category: String,
+    kind: Option<EntryKind>,
+    category: Option<String>,
     tag_filter: Option<TagFilter>,
     offset: Option<i64>,
     limit: Option<i64>,
@@ -624,28 +638,28 @@ fn perform_scan(
     try_stream! {
         let mut params = QueryParams::new();
         params.push(profile_id);
-        params.push(kind as i16);
+        params.push(kind.map(|k| k as i16));
        let (enc_category, tag_filter) = unblock({
            let key = key.clone();
-            let category = ProfileKey::prepare_input(category.as_bytes());
+            let enc_category = category.as_ref().map(|c| ProfileKey::prepare_input(c.as_bytes()));
            let params_len = params.len() + 1; // plus category
            move || {
                Result::<_, Error>::Ok((
-                    key.encrypt_entry_category(category)?,
-                    encode_tag_filter::<SqliteStore>(tag_filter, &key, params_len)?
+                    enc_category.map(|c| key.encrypt_entry_category(c)).transpose()?,
+                    encode_tag_filter::<SqliteBackend>(tag_filter, &key, params_len)?
                ))
            }
        }).await?;
        params.push(enc_category);
-        let query = extend_query::<SqliteStore>(SCAN_QUERY, &mut params, tag_filter, offset, limit)?;
+        let query = extend_query::<SqliteBackend>(SCAN_QUERY, &mut params, tag_filter, offset, limit)?;
        let mut batch = Vec::with_capacity(PAGE_SIZE);
-        let mut acquired = acquire_session(&mut *active).await?;
+        let mut acquired = acquire_session(&mut active).await?;
        let mut rows = sqlx::query_with(query.as_str(), params).fetch(acquired.connection_mut());
        while let Some(row) = rows.try_next().await? {
            batch.push(EncScanEntry {
-                name: row.try_get(1)?, value: row.try_get(2)?, tags: row.try_get(3)?
+                category: row.try_get(1)?, name: row.try_get(2)?, value: row.try_get(3)?, tags: row.try_get(4)?
            });
            if batch.len() == PAGE_SIZE {
                yield batch.split_off(0);
@@ -677,7 +691,7 @@ mod tests {
         let ts = expiry_timestamp(1000).unwrap();
         let check = sqlx::query("SELECT datetime('now'), ?1, ?1 > datetime('now')")
             .bind(ts)
-            .fetch_one(&db.inner().conn_pool)
+            .fetch_one(&db.conn_pool)
             .await?;
         let now: String = check.try_get(0)?;
         let cmp_ts: String = check.try_get(1)?;
@@ -693,11 +707,11 @@ mod tests {
     #[test]
     fn sqlite_query_placeholders() {
         assert_eq!(
-            &replace_arg_placeholders::<SqliteStore>("This $$ is $10 a $$ string!", 3),
+            &replace_arg_placeholders::<SqliteBackend>("This $$ is $10 a $$ string!", 3),
             "This ?3 is ?12 a ?5 string!",
         );
         assert_eq!(
-            &replace_arg_placeholders::<SqliteStore>("This $a is a string!", 1),
+            &replace_arg_placeholders::<SqliteBackend>("This $a is a string!", 1),
             "This $a is a string!",
         );
     }
diff --git a/src/backend/sqlite/provision.rs b/askar-storage/src/backend/sqlite/provision.rs
similarity index 93%
rename from src/backend/sqlite/provision.rs
rename to askar-storage/src/backend/sqlite/provision.rs
index 27abc4be..54bb06bd 100644
--- a/src/backend/sqlite/provision.rs
+++ b/askar-storage/src/backend/sqlite/provision.rs
@@ -1,5 +1,6 @@
 use std::{
-    borrow::Cow, fs::remove_file, io::ErrorKind as IoErrorKind, str::FromStr, time::Duration,
+    borrow::Cow, fs::remove_file, io::ErrorKind as IoErrorKind, str::FromStr,
+    thread::available_parallelism, time::Duration,
 };
 
 use sqlx::{
@@ -10,21 +11,21 @@ use sqlx::{
     ConnectOptions, Error as SqlxError, Row,
 };
 
-use super::SqliteStore;
+use super::SqliteBackend;
 use crate::{
     backend::{
         db_utils::{init_keys, random_profile_name},
-        types::ManageBackend,
+        ManageBackend,
     },
     error::Error,
     future::{unblock, BoxFuture},
+    options::{IntoOptions, Options},
     protect::{KeyCache, PassKey, StoreKeyMethod, StoreKeyReference},
-    storage::{IntoOptions, Options, Store},
 };
 
-const DEFAULT_MIN_CONNECTIONS: u32 = 1;
-const DEFAULT_LOWER_MAX_CONNECTIONS: u32 = 2;
-const DEFAULT_UPPER_MAX_CONNECTIONS: u32 = 8;
+const DEFAULT_MIN_CONNECTIONS: usize = 1;
+const DEFAULT_LOWER_MAX_CONNECTIONS: usize = 2;
+const DEFAULT_UPPER_MAX_CONNECTIONS: usize = 8;
 const DEFAULT_BUSY_TIMEOUT: Duration = Duration::from_secs(5);
 const DEFAULT_JOURNAL_MODE: SqliteJournalMode = SqliteJournalMode::Wal;
 const DEFAULT_LOCKING_MODE: SqliteLockingMode = SqliteLockingMode::Normal;
@@ -69,16 +70,21 @@ impl SqliteStoreOptions {
                 .parse()
                 .map_err(err_map!(Input, "Error parsing 'max_connections' parameter"))?
         } else {
-            (num_cpus::get() as u32)
+            available_parallelism()
+                .map_err(err_map!(
+                    Unexpected,
+                    "Error determining available parallelism"
+                ))?
+                .get()
                 .max(DEFAULT_LOWER_MAX_CONNECTIONS)
-                .min(DEFAULT_UPPER_MAX_CONNECTIONS)
+                .min(DEFAULT_UPPER_MAX_CONNECTIONS) as u32
         };
         let min_connections = if let Some(min_conn) = opts.query.remove("min_connections") {
             min_conn
                 .parse()
                 .map_err(err_map!(Input, "Error parsing 'min_connections' parameter"))?
         } else {
-            DEFAULT_MIN_CONNECTIONS
+            DEFAULT_MIN_CONNECTIONS as u32
         };
         let journal_mode = if let Some(mode) = opts.query.remove("journal_mode") {
             SqliteJournalMode::from_str(&mode)
@@ -105,7 +111,7 @@ impl SqliteStoreOptions {
         };
 
         let mut path = opts.host.to_string();
-        path.push_str(&*opts.path);
+        path.push_str(&opts.path);
         Ok(Self {
             in_memory: path == ":memory:",
             path,
@@ -152,7 +158,7 @@ impl SqliteStoreOptions {
         pass_key: PassKey<'_>,
         profile: Option<&'_ str>,
         recreate: bool,
-    ) -> Result<Store<SqliteStore>, Error> {
+    ) -> Result<SqliteBackend, Error> {
         if recreate && !self.in_memory {
             try_remove_file(self.path.to_string()).await?;
         }
@@ -182,12 +188,12 @@ impl SqliteStoreOptions {
             .unwrap_or_else(random_profile_name);
         let key_cache = init_db(&conn_pool, &default_profile, method, pass_key).await?;
 
-        Ok(Store::new(SqliteStore::new(
+        Ok(SqliteBackend::new(
             conn_pool,
             default_profile,
             key_cache,
             self.path.to_string(),
-        )))
+        ))
     }
 
     /// Open an existing Sqlite store from this set of configuration options
@@ -196,7 +202,7 @@ impl SqliteStoreOptions {
         method: Option<StoreKeyMethod>,
         pass_key: PassKey<'_>,
         profile: Option<&'_ str>,
-    ) -> Result<Store<SqliteStore>, Error> {
+    ) -> Result<SqliteBackend, Error> {
         let conn_pool = match self.pool(false).await {
             Ok(pool) => Ok(pool),
             Err(SqlxError::Database(db_err)) => {
@@ -240,14 +246,14 @@ impl SqliteStoreOptions {
 }
 
 impl<'a> ManageBackend<'a> for SqliteStoreOptions {
-    type Store = Store<SqliteStore>;
+    type Backend = SqliteBackend;
 
     fn open_backend(
         self,
         method: Option<StoreKeyMethod>,
         pass_key: PassKey<'a>,
         profile: Option<&'a str>,
-    ) -> BoxFuture<'a, Result<Store<SqliteStore>, Error>> {
+    ) -> BoxFuture<'a, Result<SqliteBackend, Error>> {
         Box::pin(self.open(method, pass_key, profile))
     }
 
@@ -257,7 +263,7 @@ impl<'a> ManageBackend<'a> for SqliteStoreOptions {
         pass_key: PassKey<'a>,
         profile: Option<&'a str>,
         recreate: bool,
-    ) -> BoxFuture<'a, Result<Store<SqliteStore>, Error>> {
+    ) -> BoxFuture<'a, Result<SqliteBackend, Error>> {
         Box::pin(self.provision(method, pass_key, profile, recreate))
     }
 
@@ -361,7 +367,7 @@ async fn open_db(
     pass_key: PassKey<'_>,
     profile: Option<&str>,
     path: String,
-) -> Result<Store<SqliteStore>, Error> {
+) -> Result<SqliteBackend, Error> {
     let mut conn = conn_pool.acquire().await?;
     let mut ver_ok = false;
     let mut default_profile: Option<String> = None;
@@ -422,9 +428,7 @@ async fn open_db(
     let profile_key = key_cache.load_key(row.try_get(1)?).await?;
     key_cache.add_profile_mut(profile.clone(), profile_id, profile_key);
 
-    Ok(Store::new(SqliteStore::new(
-        conn_pool, profile, key_cache, path,
-    )))
+    Ok(SqliteBackend::new(conn_pool, profile, key_cache, path))
 }
 
 async fn try_remove_file(path: String) -> Result<bool, Error> {
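Reviewer note (not part of the diff): the default `max_connections` computation now uses `std::thread::available_parallelism` in place of the dropped `num_cpus` dependency. A self-contained sketch of the same clamping logic:

```rust
use std::thread::available_parallelism;

// Detected parallelism clamped into [DEFAULT_LOWER_MAX_CONNECTIONS,
// DEFAULT_UPPER_MAX_CONNECTIONS] = [2, 8], then cast to the u32 sqlx expects.
fn default_max_connections() -> std::io::Result<u32> {
    Ok(available_parallelism()?.get().max(2).min(8) as u32)
}
```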
Entry type definitions + use std::{ - borrow::Cow, fmt::{self, Debug, Formatter}, pin::Pin, str::FromStr, }; use futures_lite::stream::{Stream, StreamExt}; -use serde::{ - de::{Error as SerdeError, MapAccess, SeqAccess, Visitor}, - ser::SerializeMap, - Deserialize, Deserializer, Serialize, Serializer, -}; use zeroize::Zeroize; use super::wql; @@ -73,9 +69,12 @@ impl PartialEq for Entry { } } +/// Set of distinct entry kinds for separating records. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum EntryKind { + /// Key manager entry Kms = 1, + /// General stored item Item = 2, } @@ -107,7 +106,8 @@ impl EntryTag { } } - pub(crate) fn map_ref(&self, f: impl FnOnce(&str, &str) -> (String, String)) -> Self { + /// Create a new EntryTag using references to the name and value + pub fn map_ref(&self, f: impl FnOnce(&str, &str) -> (String, String)) -> Self { match self { Self::Encrypted(name, val) => { let (name, val) = f(name.as_str(), val.as_str()); @@ -121,7 +121,7 @@ impl EntryTag { } /// Setter for the tag name - pub(crate) fn update_name(&mut self, f: impl FnOnce(&mut String)) { + pub fn update_name(&mut self, f: impl FnOnce(&mut String)) { match self { Self::Encrypted(name, _) | Self::Plaintext(name, _) => f(name), } @@ -135,7 +135,7 @@ impl EntryTag { } /// Unwrap the tag value - pub(crate) fn into_value(self) -> String { + pub fn into_value(self) -> String { match self { Self::Encrypted(_, value) | Self::Plaintext(_, value) => value, } @@ -159,184 +159,6 @@ impl Debug for EntryTag { } } -/// A wrapper type used for managing (de)serialization of tags -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] -pub(crate) struct EntryTagSet<'e>(Cow<'e, [EntryTag]>); - -impl EntryTagSet<'_> { - #[inline] - pub fn into_vec(self) -> Vec { - self.0.into_owned() - } -} - -impl<'e> From<&'e [EntryTag]> for EntryTagSet<'e> { - fn from(tags: &'e [EntryTag]) -> Self { - Self(Cow::Borrowed(tags)) - } -} - -impl From> for EntryTagSet<'static> { - fn from(tags: Vec) -> Self { - Self(Cow::Owned(tags)) - } -} - -impl<'de> Deserialize<'de> for EntryTagSet<'static> { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - struct TagSetVisitor; - - impl<'d> Visitor<'d> for TagSetVisitor { - type Value = EntryTagSet<'static>; - - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - formatter.write_str("an object containing zero or more entry tags") - } - - fn visit_map(self, mut access: M) -> Result - where - M: MapAccess<'d>, - { - let mut v = Vec::with_capacity(access.size_hint().unwrap_or_default()); - - while let Some((key, values)) = access.next_entry::<&str, EntryTagValues>()? 
{ - let (tag, enc) = match key.chars().next() { - Some('~') => (key[1..].to_owned(), false), - None => return Err(M::Error::custom("invalid tag name: empty string")), - _ => (key.to_owned(), true), - }; - match (values, enc) { - (EntryTagValues::Single(value), true) => { - v.push(EntryTag::Encrypted(tag, value)) - } - (EntryTagValues::Single(value), false) => { - v.push(EntryTag::Plaintext(tag, value)) - } - (EntryTagValues::Multiple(values), true) => { - for value in values { - v.push(EntryTag::Encrypted(tag.clone(), value)) - } - } - (EntryTagValues::Multiple(values), false) => { - for value in values { - v.push(EntryTag::Plaintext(tag.clone(), value)) - } - } - } - } - - Ok(EntryTagSet(Cow::Owned(v))) - } - } - - deserializer.deserialize_map(TagSetVisitor) - } -} - -enum EntryTagValues { - Single(String), - Multiple(Vec), -} - -impl<'de> Deserialize<'de> for EntryTagValues { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - struct TagValuesVisitor; - - impl<'d> Visitor<'d> for TagValuesVisitor { - type Value = EntryTagValues; - - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - formatter.write_str("a string or list of strings") - } - - fn visit_str(self, value: &str) -> Result - where - E: SerdeError, - { - Ok(EntryTagValues::Single(value.to_owned())) - } - - fn visit_string(self, value: String) -> Result - where - E: SerdeError, - { - Ok(EntryTagValues::Single(value)) - } - - fn visit_seq(self, mut access: S) -> Result - where - S: SeqAccess<'d>, - { - let mut v = Vec::with_capacity(access.size_hint().unwrap_or_default()); - while let Some(value) = access.next_element()? { - v.push(value) - } - Ok(EntryTagValues::Multiple(v)) - } - } - - deserializer.deserialize_any(TagValuesVisitor) - } -} - -impl Serialize for EntryTagSet<'_> { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - use std::collections::BTreeMap; - - #[derive(PartialOrd, Ord)] - struct TagName<'a>(&'a str, bool); - - impl<'a> PartialEq for TagName<'a> { - fn eq(&self, other: &Self) -> bool { - self.1 == other.1 && self.0 == other.0 - } - } - - impl<'a> Eq for TagName<'a> {} - - impl Serialize for TagName<'_> { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - if self.1 { - serializer.serialize_str(self.0) - } else { - serializer.collect_str(&format_args!("~{}", self.0)) - } - } - } - - let mut tags = BTreeMap::new(); - for tag in self.0.iter() { - let (name, value) = match tag { - EntryTag::Encrypted(name, val) => (TagName(name.as_str(), true), val.as_str()), - EntryTag::Plaintext(name, val) => (TagName(name.as_str(), false), val.as_str()), - }; - tags.entry(name).or_insert_with(Vec::new).push(value); - } - - let mut map = serializer.serialize_map(Some(tags.len()))?; - for (tag_name, values) in tags.into_iter() { - if values.len() > 1 { - map.serialize_entry(&tag_name, &values)?; - } else { - map.serialize_entry(&tag_name, &values[0])?; - } - } - map.end() - } -} - #[derive(Clone, Debug, PartialEq, Eq)] pub(crate) struct EncEntryTag { pub name: Vec, @@ -452,6 +274,11 @@ impl TagFilter { pub fn to_string(&self) -> Result { serde_json::to_string(&self.query).map_err(err_map!("Error encoding tag filter")) } + + /// Unwrap into a wql::Query + pub fn into_query(self) -> wql::Query { + self.query + } } impl From for TagFilter { @@ -512,21 +339,3 @@ impl Debug for Scan<'_, S> { .finish() } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn serialize_tags() { - let tags = EntryTagSet::from(vec![ - 
EntryTag::Encrypted("a".to_owned(), "aval".to_owned()), - EntryTag::Plaintext("b".to_owned(), "bval".to_owned()), - EntryTag::Plaintext("b".to_owned(), "bval-2".to_owned()), - ]); - let ser = serde_json::to_string(&tags).unwrap(); - assert_eq!(ser, r#"{"a":"aval","~b":["bval","bval-2"]}"#); - let tags2 = serde_json::from_str(&ser).unwrap(); - assert_eq!(tags, tags2); - } -} diff --git a/askar-storage/src/error.rs b/askar-storage/src/error.rs new file mode 100644 index 00000000..6d0232ec --- /dev/null +++ b/askar-storage/src/error.rs @@ -0,0 +1,190 @@ +use std::error::Error as StdError; +use std::fmt::{self, Display, Formatter}; + +use crate::crypto::{Error as CryptoError, ErrorKind as CryptoErrorKind}; + +/// The possible kinds of error produced by the crate +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum ErrorKind { + /// An unexpected error from the store backend + Backend, + + /// The store backend was too busy to handle the request + Busy, + + /// A custom error type for external integrations + Custom, + + /// An insert operation failed due to a unique key conflict + Duplicate, + + /// An encryption or decryption operation failed + Encryption, + + /// The input parameters to the method were incorrect + Input, + + /// The requested record was not found + NotFound, + + /// An unexpected error occurred + Unexpected, + + /// An unsupported operation was requested + Unsupported, +} + +impl ErrorKind { + /// Convert the error kind to a string reference + pub fn as_str(&self) -> &'static str { + match self { + Self::Backend => "Backend error", + Self::Busy => "Busy", + Self::Custom => "Custom error", + Self::Duplicate => "Duplicate", + Self::Encryption => "Encryption error", + Self::Input => "Input error", + Self::NotFound => "Not found", + Self::Unexpected => "Unexpected error", + Self::Unsupported => "Unsupported", + } + } +} + +impl Display for ErrorKind { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +/// The standard crate error type +#[derive(Debug)] +pub struct Error { + pub(crate) kind: ErrorKind, + pub(crate) cause: Option>, + pub(crate) message: Option, +} + +impl Error { + pub(crate) fn from_msg>(kind: ErrorKind, msg: T) -> Self { + Self { + kind, + cause: None, + message: Some(msg.into()), + } + } + + /// Accessor for the error kind + pub fn kind(&self) -> ErrorKind { + self.kind + } + + /// Accessor for the error message + pub fn message(&self) -> Option<&str> { + self.message.as_deref() + } + + /// Split the error into its components + pub fn into_parts( + self, + ) -> ( + ErrorKind, + Option>, + Option, + ) { + (self.kind, self.cause, self.message) + } + + pub(crate) fn with_cause>>( + mut self, + err: T, + ) -> Self { + self.cause = Some(err.into()); + self + } +} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + if let Some(msg) = self.message.as_ref() { + f.write_str(msg)?; + } else { + f.write_str(self.kind.as_str())?; + } + if let Some(cause) = self.cause.as_ref() { + write!(f, "\nCaused by: {}", cause)?; + } + Ok(()) + } +} + +impl StdError for Error { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + self.cause + .as_ref() + .map(|err| &**err as &(dyn StdError + 'static)) + } +} + +impl PartialEq for Error { + fn eq(&self, other: &Self) -> bool { + self.kind == other.kind && self.message == other.message + } +} + +impl From for Error { + fn from(kind: ErrorKind) -> Self { + Self { + kind, + cause: None, + message: None, + } + } +} + +// FIXME would be preferable to 
remove this auto-conversion and handle +// all sqlx errors manually, to ensure there is some context around the error +#[cfg(any(feature = "indy_compat", feature = "postgres", feature = "sqlite"))] +impl From for Error { + fn from(err: sqlx::Error) -> Self { + Error::from(ErrorKind::Backend).with_cause(err) + } +} + +impl From for Error { + fn from(err: CryptoError) -> Self { + let kind = match err.kind() { + CryptoErrorKind::Custom => ErrorKind::Custom, + CryptoErrorKind::Encryption => ErrorKind::Encryption, + CryptoErrorKind::ExceededBuffer | CryptoErrorKind::Unexpected => ErrorKind::Unexpected, + CryptoErrorKind::Invalid + | CryptoErrorKind::InvalidKeyData + | CryptoErrorKind::InvalidNonce + | CryptoErrorKind::MissingSecretKey + | CryptoErrorKind::Usage => ErrorKind::Input, + CryptoErrorKind::Unsupported => ErrorKind::Unsupported, + }; + Error::from_msg(kind, err.message()) + } +} + +macro_rules! err_msg { + () => { + $crate::error::Error::from($crate::error::ErrorKind::Input) + }; + ($kind:ident) => { + $crate::error::Error::from($crate::error::ErrorKind::$kind) + }; + ($kind:ident, $($args:tt)+) => { + $crate::error::Error::from_msg($crate::error::ErrorKind::$kind, format!($($args)+)) + }; + ($($args:tt)+) => { + $crate::error::Error::from_msg($crate::error::ErrorKind::Input, format!($($args)+)) + }; +} + +macro_rules! err_map { + ($($params:tt)*) => { + |err| err_msg!($($params)*).with_cause(err) + }; +} diff --git a/src/future.rs b/askar-storage/src/future.rs similarity index 100% rename from src/future.rs rename to askar-storage/src/future.rs diff --git a/askar-storage/src/lib.rs b/askar-storage/src/lib.rs new file mode 100644 index 00000000..a744a9f6 --- /dev/null +++ b/askar-storage/src/lib.rs @@ -0,0 +1,57 @@ +//! Secure storage designed for Hyperledger Aries agents + +#![cfg_attr(docsrs, feature(doc_cfg))] +#![deny(missing_docs, missing_debug_implementations, rust_2018_idioms)] + +pub use askar_crypto as crypto; + +#[macro_use] +mod error; +pub use self::error::{Error, ErrorKind}; + +#[cfg(test)] +#[macro_use] +extern crate hex_literal; + +#[macro_use] +mod macros; + +#[cfg(any(test, feature = "log"))] +#[macro_use] +extern crate log; + +#[cfg(feature = "migration")] +#[macro_use] +extern crate serde; + +pub mod backend; +pub use self::backend::{Backend, BackendSession, ManageBackend}; + +#[cfg(feature = "any")] +pub mod any; + +#[cfg(feature = "postgres")] +pub use self::backend::postgres; + +#[cfg(feature = "sqlite")] +pub use self::backend::sqlite; + +pub mod entry; + +#[doc(hidden)] +pub mod future; + +#[cfg(all(feature = "migration", feature = "sqlite"))] +pub mod migration; + +mod options; +pub use options::{IntoOptions, Options}; + +mod protect; +pub use protect::{ + generate_raw_store_key, + kdf::{Argon2Level, KdfMethod}, + PassKey, StoreKeyMethod, +}; + +mod wql; diff --git a/src/macros.rs b/askar-storage/src/macros.rs similarity index 100% rename from src/macros.rs rename to askar-storage/src/macros.rs diff --git a/src/migration/mod.rs b/askar-storage/src/migration/mod.rs similarity index 98% rename from src/migration/mod.rs rename to askar-storage/src/migration/mod.rs index 8005ddf0..be0a5809 100644 --- a/src/migration/mod.rs +++ b/askar-storage/src/migration/mod.rs @@ -8,18 +8,19 @@ use std::str::FromStr; use self::strategy::Strategy; use crate::backend::sqlite::SqliteStoreOptions; +use crate::backend::Backend; use crate::crypto::alg::chacha20::{Chacha20Key, C20P}; use crate::crypto::generic_array::typenum::U32; +use crate::entry::EncEntryTag; use 
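Note: err_msg! and err_map! are crate-internal macros, made visible across the crate by the #[macro_use] on mod error in lib.rs below. A minimal sketch of their intended use inside askar-storage (parse_limit is a hypothetical helper, not part of the patch):

    use crate::error::Error;

    fn parse_limit(input: &str) -> Result<i64, Error> {
        let limit: i64 = input
            .parse()
            // err_map! expands to |err| err_msg!(...).with_cause(err)
            .map_err(err_map!(Input, "Error parsing limit value"))?;
        if limit < 0 {
            return Err(err_msg!(Input, "Limit must be non-negative: {}", limit));
        }
        Ok(limit)
    }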
diff --git a/src/future.rs b/askar-storage/src/future.rs
similarity index 100%
rename from src/future.rs
rename to askar-storage/src/future.rs
diff --git a/askar-storage/src/lib.rs b/askar-storage/src/lib.rs
new file mode 100644
index 00000000..a744a9f6
--- /dev/null
+++ b/askar-storage/src/lib.rs
@@ -0,0 +1,57 @@
+//! Secure storage designed for Hyperledger Aries agents
+
+#![cfg_attr(docsrs, feature(doc_cfg))]
+#![deny(missing_docs, missing_debug_implementations, rust_2018_idioms)]
+
+pub use askar_crypto as crypto;
+
+#[macro_use]
+mod error;
+pub use self::error::{Error, ErrorKind};
+
+#[cfg(test)]
+#[macro_use]
+extern crate hex_literal;
+
+#[macro_use]
+mod macros;
+
+#[cfg(any(test, feature = "log"))]
+#[macro_use]
+extern crate log;
+
+#[cfg(feature = "migration")]
+#[macro_use]
+extern crate serde;
+
+pub mod backend;
+pub use self::backend::{Backend, BackendSession, ManageBackend};
+
+#[cfg(feature = "any")]
+pub mod any;
+
+#[cfg(feature = "postgres")]
+pub use self::backend::postgres;
+
+#[cfg(feature = "sqlite")]
+pub use self::backend::sqlite;
+
+pub mod entry;
+
+#[doc(hidden)]
+pub mod future;
+
+#[cfg(all(feature = "migration", feature = "sqlite"))]
+pub mod migration;
+
+mod options;
+pub use options::{IntoOptions, Options};
+
+mod protect;
+pub use protect::{
+    generate_raw_store_key,
+    kdf::{Argon2Level, KdfMethod},
+    PassKey, StoreKeyMethod,
+};
+
+mod wql;
diff --git a/src/macros.rs b/askar-storage/src/macros.rs
similarity index 100%
rename from src/macros.rs
rename to askar-storage/src/macros.rs
diff --git a/src/migration/mod.rs b/askar-storage/src/migration/mod.rs
similarity index 98%
rename from src/migration/mod.rs
rename to askar-storage/src/migration/mod.rs
index 8005ddf0..be0a5809 100644
--- a/src/migration/mod.rs
+++ b/askar-storage/src/migration/mod.rs
@@ -8,18 +8,19 @@ use std::str::FromStr;
 
 use self::strategy::Strategy;
 use crate::backend::sqlite::SqliteStoreOptions;
+use crate::backend::Backend;
 use crate::crypto::alg::chacha20::{Chacha20Key, C20P};
 use crate::crypto::generic_array::typenum::U32;
+use crate::entry::EncEntryTag;
 use crate::error::Error;
 use crate::protect::kdf::Argon2Level;
 use crate::protect::{ProfileKey, StoreKey, StoreKeyReference};
-use crate::storage::EncEntryTag;
 
 mod strategy;
 
 const CHACHAPOLY_NONCE_LEN: u8 = 12;
 
-#[derive(Serialize, Deserialize, Debug, Default)]
+#[derive(Deserialize, Debug, Default)]
 pub(crate) struct IndyKeyMetadata {
     keys: Vec<u8>,
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -30,14 +31,16 @@ pub(crate) type EncryptionKey = Chacha20Key<C20P>;
 
 pub(crate) type MacKey = crate::protect::hmac_key::HmacKey<sha2::Sha256, U32>;
 
 /// Copies: https://github.com/hyperledger/indy-sdk/blob/83547c4c01162f6323cf138f8b071da2e15f0c90/libindy/indy-wallet/src/wallet.rs#L18
-#[derive(Serialize, Deserialize)]
+#[derive(Deserialize)]
 pub(crate) struct IndyKey {
     type_key: EncryptionKey,
     name_key: EncryptionKey,
     value_key: EncryptionKey,
+    #[allow(unused)]
     item_hmac_key: MacKey,
     tag_name_key: EncryptionKey,
     tag_value_key: EncryptionKey,
+    #[allow(unused)]
     tag_hmac_key: MacKey,
 }
diff --git a/src/migration/strategy.rs b/askar-storage/src/migration/strategy.rs
similarity index 96%
rename from src/migration/strategy.rs
rename to askar-storage/src/migration/strategy.rs
index b7be6b49..b8c64bfd 100644
--- a/src/migration/strategy.rs
+++ b/askar-storage/src/migration/strategy.rs
@@ -5,8 +5,8 @@ use super::{
 use crate::crypto::buffer::SecretBytes;
 use crate::crypto::encrypt::KeyAeadInPlace;
 use crate::crypto::repr::KeySecretBytes;
+use crate::entry::EntryTag;
 use crate::protect::EntryEncryptor;
-use crate::storage::EntryTag;
 use crate::Error;
 
 #[derive(Default)]
@@ -139,8 +139,8 @@ impl Strategy {
             Some(rows) => {
                 let mut upd = vec![];
                 for row in rows {
-                    let result = Self::decrypt_item(row, &indy_key)?;
-                    upd.push(Self::update_item(result, &profile_key)?);
+                    let result = Self::decrypt_item(row, indy_key)?;
+                    upd.push(Self::update_item(result, profile_key)?);
                 }
                 conn.update_items_in_db(upd).await?;
             }
diff --git a/src/storage/options.rs b/askar-storage/src/options.rs
similarity index 92%
rename from src/storage/options.rs
rename to askar-storage/src/options.rs
index 37aa5bbd..9b3f739c 100644
--- a/src/storage/options.rs
+++ b/askar-storage/src/options.rs
@@ -6,17 +6,26 @@ use percent_encoding::{percent_decode_str, utf8_percent_encode, NON_ALPHANUMERIC};
 use crate::error::Error;
 
 #[derive(Clone, Debug, Default, PartialEq, Eq)]
+/// Parsed representation of a database connection URI
 pub struct Options<'a> {
+    /// The URI scheme
     pub schema: Cow<'a, str>,
+    /// The authenticating user name
     pub user: Cow<'a, str>,
+    /// The authenticating user password
     pub password: Cow<'a, str>,
+    /// The host name
     pub host: Cow<'a, str>,
+    /// The path component
     pub path: Cow<'a, str>,
+    /// The query component
     pub query: HashMap<String, String>,
+    /// The fragment component
     pub fragment: Cow<'a, str>,
 }
 
 impl<'a> Options<'a> {
+    /// Parse a URI string into an Options structure
     pub fn parse_uri(uri: &str) -> Result<Options<'_>, Error> {
         let mut fragment_and_remain = uri.splitn(2, '#');
         let uri = fragment_and_remain.next().unwrap_or_default();
@@ -79,6 +88,7 @@ impl<'a> Options<'a> {
         })
     }
 
+    /// Convert an options structure back into a URI string
     pub fn into_uri(self) -> String {
         let mut uri = String::new();
         if !self.schema.is_empty() {
@@ -126,7 +136,9 @@ fn percent_encode_into(result: &mut String, s: &str) {
     push_iter_str(result, utf8_percent_encode(s, NON_ALPHANUMERIC))
 }
 
+/// A trait implemented by types that can be converted into Options
 pub trait IntoOptions<'a> {
+    /// Try to convert self into an Options structure
     fn into_options(self) -> Result<Options<'a>, Error>;
 }
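Note: Options is re-exported from the crate root (see lib.rs above), so all backends share one URI parser. A rough usage sketch; illustrative only, and the exact decoding of any particular URI is an assumption here rather than documented behaviour:

    use askar_storage::{Error, Options};

    fn options_example() -> Result<(), Error> {
        // split a connection URI into its components
        let opts = Options::parse_uri("postgres://user:secret@host:5432/askar?admin_account=postgres")?;
        assert_eq!(opts.schema, "postgres");
        assert_eq!(opts.user, "user");
        // query parameters are collected into the `query` map
        assert_eq!(opts.query.get("admin_account").map(String::as_str), Some("postgres"));
        // into_uri() reassembles a percent-encoded URI string
        let _uri = opts.into_uri();
        Ok(())
    }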
diff --git a/src/protect/hmac_key.rs b/askar-storage/src/protect/hmac_key.rs
similarity index 98%
rename from src/protect/hmac_key.rs
rename to askar-storage/src/protect/hmac_key.rs
index 611640b7..c1ad6a64 100644
--- a/src/protect/hmac_key.rs
+++ b/askar-storage/src/protect/hmac_key.rs
@@ -54,7 +54,7 @@ impl<H, L: ArrayLength<u8>> AsRef<GenericArray<u8, L>> for HmacKey<H, L> {
 impl<H, L: ArrayLength<u8>> Debug for HmacKey<H, L> {
     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
         if cfg!(test) {
-            f.debug_tuple("HmacKey").field(&*self).finish()
+            f.debug_tuple("HmacKey").field(&self.0).finish()
         } else {
             f.debug_tuple("HmacKey").field(&"<secret>").finish()
         }
diff --git a/src/protect/kdf/argon2.rs b/askar-storage/src/protect/kdf/argon2.rs
similarity index 100%
rename from src/protect/kdf/argon2.rs
rename to askar-storage/src/protect/kdf/argon2.rs
diff --git a/src/protect/kdf/mod.rs b/askar-storage/src/protect/kdf/mod.rs
similarity index 98%
rename from src/protect/kdf/mod.rs
rename to askar-storage/src/protect/kdf/mod.rs
index 0d36265b..2eb75d9b 100644
--- a/src/protect/kdf/mod.rs
+++ b/askar-storage/src/protect/kdf/mod.rs
@@ -1,8 +1,10 @@
+//! Key derivations
+
 use super::store_key::{StoreKey, PREFIX_KDF};
 use crate::{
     crypto::{buffer::ArrayKey, generic_array::ArrayLength},
     error::Error,
-    storage::Options,
+    options::Options,
 };
 
 mod argon2;
diff --git a/src/protect/mod.rs b/askar-storage/src/protect/mod.rs
similarity index 98%
rename from src/protect/mod.rs
rename to askar-storage/src/protect/mod.rs
index d4c4941e..9ec5d942 100644
--- a/src/protect/mod.rs
+++ b/askar-storage/src/protect/mod.rs
@@ -1,3 +1,5 @@
+//! Storage encryption
+
 use std::{collections::HashMap, sync::Arc};
 
 use async_lock::RwLock;
@@ -17,9 +19,9 @@ pub use self::store_key::{generate_raw_store_key, StoreKey, StoreKeyMethod, StoreKeyReference};
 
 use crate::{
     crypto::buffer::SecretBytes,
+    entry::{EncEntryTag, EntryTag},
     error::Error,
     future::unblock,
-    storage::{EncEntryTag, EntryTag},
 };
 
 pub type ProfileId = i64;
diff --git a/src/protect/pass_key.rs b/askar-storage/src/protect/pass_key.rs
similarity index 92%
rename from src/protect/pass_key.rs
rename to askar-storage/src/protect/pass_key.rs
index ad1a4aef..e5fbb789 100644
--- a/src/protect/pass_key.rs
+++ b/askar-storage/src/protect/pass_key.rs
@@ -26,7 +26,8 @@ impl PassKey<'_> {
         self.0.is_none()
     }
 
-    pub(crate) fn into_owned(self) -> PassKey<'static> {
+    /// Convert to an owned instance, allocating if necessary
+    pub fn into_owned(self) -> PassKey<'static> {
         let mut slf = ManuallyDrop::new(self);
         let val = slf.0.take();
         PassKey(match val {
@@ -40,7 +41,7 @@ impl Debug for PassKey<'_> {
     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
         if cfg!(test) {
-            f.debug_tuple("PassKey").field(&*self).finish()
+            f.debug_tuple("PassKey").field(&self.0).finish()
         } else {
             f.debug_tuple("PassKey").field(&"<secret>").finish()
         }
diff --git a/src/protect/profile_key.rs b/askar-storage/src/protect/profile_key.rs
similarity index 99%
rename from src/protect/profile_key.rs
rename to askar-storage/src/protect/profile_key.rs
index e281da09..fba692e3 100644
--- a/src/protect/profile_key.rs
+++ b/askar-storage/src/protect/profile_key.rs
@@ -12,8 +12,8 @@ use crate::{
         kdf::FromKeyDerivation,
         repr::KeyGen,
     },
+    entry::{EncEntryTag, EntryTag},
     error::Error,
-    storage::{EncEntryTag, EntryTag},
 };
 
 pub type ProfileKey = ProfileKeyImpl<Chacha20Key<C20P>, HmacKey<Sha256, U32>>;
@@ -250,7 +250,7 @@ fn decode_utf8(value: Vec<u8>) -> Result<String, Error> {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::storage::Entry;
+    use crate::entry::Entry;
 
     #[test]
     fn encrypt_entry_round_trip() {
diff --git a/src/protect/store_key.rs
b/askar-storage/src/protect/store_key.rs similarity index 97% rename from src/protect/store_key.rs rename to askar-storage/src/protect/store_key.rs index 11d4e36c..a70842e3 100644 --- a/src/protect/store_key.rs +++ b/askar-storage/src/protect/store_key.rs @@ -114,7 +114,8 @@ pub enum StoreKeyMethod { } impl StoreKeyMethod { - pub(crate) fn parse_uri(uri: &str) -> Result { + /// Parse a URI string into a store key method + pub fn parse_uri(uri: &str) -> Result { let mut prefix_and_detail = uri.splitn(2, ':'); let prefix = prefix_and_detail.next().unwrap_or_default(); // let detail = prefix_and_detail.next().unwrap_or_default(); @@ -138,7 +139,7 @@ impl StoreKeyMethod { // Self::ExistingManagedKey(String) => unimplemented!(), Self::DeriveKey(method) => { if !pass_key.is_none() { - let (key, detail) = method.derive_new_key(&*pass_key)?; + let (key, detail) = method.derive_new_key(&pass_key)?; let key_ref = StoreKeyReference::DeriveKey(*method, detail); Ok((key, key_ref)) } else { @@ -147,7 +148,7 @@ impl StoreKeyMethod { } Self::RawKey => { let key = if !pass_key.is_empty() { - parse_raw_store_key(&*pass_key)? + parse_raw_store_key(&pass_key)? } else { StoreKey::random()? }; @@ -225,14 +226,14 @@ impl StoreKeyReference { // Self::ManagedKey(_key_ref) => unimplemented!(), Self::DeriveKey(method, detail) => { if !pass_key.is_none() { - method.derive_key(&*pass_key, detail) + method.derive_key(&pass_key, detail) } else { Err(err_msg!(Input, "Key derivation password not provided")) } } Self::RawKey => { if !pass_key.is_empty() { - parse_raw_store_key(&*pass_key) + parse_raw_store_key(&pass_key) } else { Err(err_msg!(Input, "Encoded raw key not provided")) } diff --git a/src/storage/wql/mod.rs b/askar-storage/src/wql/mod.rs similarity index 100% rename from src/storage/wql/mod.rs rename to askar-storage/src/wql/mod.rs diff --git a/src/storage/wql/query.rs b/askar-storage/src/wql/query.rs similarity index 100% rename from src/storage/wql/query.rs rename to askar-storage/src/wql/query.rs diff --git a/src/storage/wql/sql.rs b/askar-storage/src/wql/sql.rs similarity index 98% rename from src/storage/wql/sql.rs rename to askar-storage/src/wql/sql.rs index 810fbe7c..f644cc23 100644 --- a/src/storage/wql/sql.rs +++ b/askar-storage/src/wql/sql.rs @@ -84,7 +84,7 @@ where op.as_sql_str(), idx + 2, op_prefix.as_str(), - if is_plaintext { 1 } else { 0 } + i32::from(is_plaintext) ); Ok(Some(query)) } @@ -102,7 +102,7 @@ where "i.id {} (SELECT item_id FROM items_tags WHERE name = $$ AND value IN ({}) AND plaintext = {})", if negate { "NOT IN" } else { "IN" }, args_in, - if is_plaintext { 1 } else { 0 } + i32::from(is_plaintext) ); self.arguments.push(enc_name); self.arguments.extend(enc_values); @@ -118,7 +118,7 @@ where let query = format!( "i.id {} (SELECT item_id FROM items_tags WHERE name = $$ AND plaintext = {})", if negate { "NOT IN" } else { "IN" }, - if is_plaintext { 1 } else { 0 } + i32::from(is_plaintext) ); self.arguments.push(enc_name); Ok(Some(query)) diff --git a/src/storage/wql/tags.rs b/askar-storage/src/wql/tags.rs similarity index 100% rename from src/storage/wql/tags.rs rename to askar-storage/src/wql/tags.rs diff --git a/tests/.gitignore b/askar-storage/tests/.gitignore similarity index 100% rename from tests/.gitignore rename to askar-storage/tests/.gitignore diff --git a/tests/backends.rs b/askar-storage/tests/backends.rs similarity index 59% rename from tests/backends.rs rename to askar-storage/tests/backends.rs index b6b123bc..1c6f0a55 100644 --- a/tests/backends.rs +++ 
b/askar-storage/tests/backends.rs @@ -5,215 +5,102 @@ mod utils; const ERR_CLOSE: &str = "Error closing database"; macro_rules! backend_tests { - ($init:expr) => { - use aries_askar::future::block_on; - use std::sync::Arc; - use $crate::utils::TestStore; - + ($run:expr) => { #[test] fn init() { - block_on(async { - let db = $init.await; - db.close().await.expect(ERR_CLOSE); - }); + $run(|db| async move { + let _ = db; + }) } #[test] fn create_remove_profile() { - block_on(async { - let db = $init.await; - super::utils::db_create_remove_profile(db.clone()).await; - db.close().await.expect(ERR_CLOSE); - }) + $run(super::utils::db_create_remove_profile) } #[test] fn fetch_fail() { - block_on(async { - let db = $init.await; - super::utils::db_fetch_fail(db.clone()).await; - db.close().await.expect(ERR_CLOSE); - }) + $run(super::utils::db_fetch_fail) } #[test] fn insert_fetch() { - block_on(async { - let db = $init.await; - super::utils::db_insert_fetch(db.clone()).await; - db.close().await.expect(ERR_CLOSE); - }) + $run(super::utils::db_insert_fetch) } #[test] fn insert_duplicate() { - block_on(async { - let db = $init.await; - super::utils::db_insert_duplicate(db.clone()).await; - db.close().await.expect(ERR_CLOSE); - }) + $run(super::utils::db_insert_duplicate) } #[test] fn insert_remove() { - block_on(async { - let db = $init.await; - super::utils::db_insert_remove(db.clone()).await; - db.close().await.expect(ERR_CLOSE); - }) + $run(super::utils::db_insert_remove) } #[test] fn remove_missing() { - block_on(async { - let db = $init.await; - super::utils::db_remove_missing(db.clone()).await; - db.close().await.expect(ERR_CLOSE); - }) + $run(super::utils::db_remove_missing) } #[test] fn replace_fetch() { - block_on(async { - let db = $init.await; - super::utils::db_replace_fetch(db.clone()).await; - db.close().await.expect(ERR_CLOSE); - }) + $run(super::utils::db_replace_fetch) } #[test] fn replace_missing() { - block_on(async { - let db = $init.await; - super::utils::db_replace_missing(db.clone()).await; - db.close().await.expect(ERR_CLOSE); - }) + $run(super::utils::db_replace_missing) } #[test] fn count() { - block_on(async { - let db = $init.await; - super::utils::db_count(db.clone()).await; - db.close().await.expect(ERR_CLOSE); - }) + $run(super::utils::db_count) } #[test] fn count_exist() { - block_on(async { - let db = $init.await; - super::utils::db_count_exist(db.clone()).await; - db.close().await.expect(ERR_CLOSE); - }) + $run(super::utils::db_count_exist) } #[test] fn scan() { - block_on(async { - let db = $init.await; - super::utils::db_scan(db.clone()).await; - db.close().await.expect(ERR_CLOSE); - }) + $run(super::utils::db_scan) } #[test] fn remove_all() { - block_on(async { - let db = $init.await; - super::utils::db_remove_all(db.clone()).await; - db.close().await.expect(ERR_CLOSE); - }) - } - - #[test] - fn keypair_create_fetch() { - block_on(async { - let db = $init.await; - super::utils::db_keypair_insert_fetch(db.clone()).await; - db.close().await.expect(ERR_CLOSE); - }) + $run(super::utils::db_remove_all) } - // #[test] - // fn keypair_sign_verify() { - // block_on(async { - // let db = $init.await; - // super::utils::db_keypair_sign_verify(db.clone()).await; - // db.close().await.expect(ERR_CLOSE); - // }) - // } - - // #[test] - // fn keypair_pack_unpack_anon() { - // block_on(async { - // let db = $init.await; - // super::utils::db_keypair_pack_unpack_anon(db.clone()).await; - // db.close().await.expect(ERR_CLOSE); - // }) - // } - - // #[test] - // fn 
keypair_pack_unpack_auth() { - // block_on(async { - // let db = $init.await; - // super::utils::db_keypair_pack_unpack_auth(db).await; - // db.close().await.expect(ERR_CLOSE); - // }) - // } - #[test] fn txn_rollback() { - block_on(async { - let db = $init.await; - super::utils::db_txn_rollback(db.clone()).await; - db.close().await.expect(ERR_CLOSE); - }) + $run(super::utils::db_txn_rollback) } #[test] fn txn_drop() { - block_on(async { - let db = $init.await; - super::utils::db_txn_drop(db.clone()).await; - db.close().await.expect(ERR_CLOSE); - }) + $run(super::utils::db_txn_drop) } #[test] fn session_drop() { - block_on(async { - let db = $init.await; - super::utils::db_session_drop(db.clone()).await; - db.close().await.expect(ERR_CLOSE); - }) + $run(super::utils::db_session_drop) } #[test] fn txn_commit() { - block_on(async { - let db = $init.await; - super::utils::db_txn_commit(db.clone()).await; - db.close().await.expect(ERR_CLOSE); - }) + $run(super::utils::db_txn_commit) } #[test] fn txn_fetch_for_update() { - block_on(async { - let db = $init.await; - super::utils::db_txn_fetch_for_update(db.clone()).await; - db.close().await.expect(ERR_CLOSE); - }) + $run(super::utils::db_txn_fetch_for_update) } #[test] fn txn_contention() { - block_on(async { - let db = $init.await; - super::utils::db_txn_contention(db.clone()).await; - db.close().await.expect(ERR_CLOSE); - }) + $run(super::utils::db_txn_contention) } }; } @@ -224,9 +111,11 @@ fn log_init() { #[cfg(feature = "sqlite")] mod sqlite { - use aries_askar::backend::sqlite::{SqliteStore, SqliteStoreOptions}; - use aries_askar::{generate_raw_store_key, ManageBackend, Store, StoreKeyMethod}; - use std::path::Path; + use askar_storage::any::{into_any_backend, AnyBackend}; + use askar_storage::backend::sqlite::SqliteStoreOptions; + use askar_storage::future::block_on; + use askar_storage::{generate_raw_store_key, Backend, ManageBackend, StoreKeyMethod}; + use std::{future::Future, path::Path}; use super::*; @@ -332,7 +221,7 @@ mod sqlite { .await .expect("Error provisioning sqlite store"); - let db = std::sync::Arc::new(store); + let db = into_any_backend(store); super::utils::db_txn_contention(db.clone()).await; db.close().await.expect("Error closing sqlite store"); @@ -413,18 +302,26 @@ mod sqlite { }); } - async fn init_db() -> Arc> { + fn with_sqlite_in_memory(f: F) + where + F: FnOnce(AnyBackend) -> G, + G: Future, + { log_init(); - let key = generate_raw_store_key(None).expect("Error creating raw key"); - Arc::new( - SqliteStoreOptions::in_memory() - .provision(StoreKeyMethod::RawKey, key, None, false) - .await - .expect("Error provisioning sqlite store"), - ) + let key = generate_raw_store_key(None).expect("Error generating store key"); + block_on(async move { + let db = into_any_backend( + SqliteStoreOptions::in_memory() + .provision(StoreKeyMethod::RawKey, key, None, false) + .await + .expect("Error provisioning sqlite store"), + ); + f(db.clone()).await; + db.close().await.expect(ERR_CLOSE); + }) } - backend_tests!(init_db()); + backend_tests!(with_sqlite_in_memory); #[test] fn provision_from_str() { @@ -450,39 +347,31 @@ mod sqlite { #[cfg(feature = "pg_test")] mod postgres { - use aries_askar::{backend::postgres::test_db::TestDB, postgres::PostgresStore, Store}; - use std::{future::Future, ops::Deref, pin::Pin}; + use askar_storage::any::AnyBackend; + use askar_storage::backend::postgres::TestDB; + use askar_storage::future::block_on; + use std::future::Future; use super::*; - #[derive(Clone, Debug)] - struct Wrap(Arc); - - 
impl Deref for Wrap {
-        type Target = Store<PostgresStore>;
-
-        fn deref(&self) -> &Self::Target {
-            &**self.0
-        }
-    }
-
-    impl TestStore for Wrap {
-        type DB = PostgresStore;
-
-        fn close(self) -> Pin<Box<dyn Future<Output = Result<(), Error>>>> {
-            let db = Arc::try_unwrap(self.0).unwrap();
-            Box::pin(db.close())
-        }
-    }
-
-    async fn init_db() -> Wrap {
+    fn with_postgres<F, G>(f: F)
+    where
+        F: FnOnce(AnyBackend) -> G,
+        G: Future<Output = ()>,
+    {
+        let db_url = match std::env::var("POSTGRES_URL") {
+            Ok(p) if !p.is_empty() => p,
+            _ => panic!("'POSTGRES_URL' must be defined"),
+        };
         log_init();
-        Wrap(Arc::new(
-            TestDB::provision()
+        block_on(async move {
+            let db = TestDB::provision(db_url.as_str())
                 .await
-                .expect("Error provisioning postgres test database"),
-        ))
+                .expect("Error provisioning postgres test database");
+            f(db.backend()).await;
+            db.close().await.expect(ERR_CLOSE);
+        })
     }
 
-    backend_tests!(init_db());
+    backend_tests!(with_postgres);
 }
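Note: each backend module now supplies one runner with this closure-based shape, and the backend_tests! macro near the top of this file generates the individual cases; backend_tests!(with_postgres) expands fn insert_fetch to roughly:

    #[test]
    fn insert_fetch() {
        with_postgres(super::utils::db_insert_fetch)
    }

so provisioning, execution and cleanup are centralized in the runner instead of being repeated in per-test block_on blocks.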
diff --git a/tests/docker_pg.sh b/askar-storage/tests/docker_pg.sh
similarity index 100%
rename from tests/docker_pg.sh
rename to askar-storage/tests/docker_pg.sh
diff --git a/tests/indy_wallet_sqlite.db b/askar-storage/tests/indy_wallet_sqlite.db
similarity index 100%
rename from tests/indy_wallet_sqlite.db
rename to askar-storage/tests/indy_wallet_sqlite.db
diff --git a/tests/migration.rs b/askar-storage/tests/migration.rs
similarity index 84%
rename from tests/migration.rs
rename to askar-storage/tests/migration.rs
index d1c20add..102ee761 100644
--- a/tests/migration.rs
+++ b/askar-storage/tests/migration.rs
@@ -1,7 +1,10 @@
+#![cfg(all(feature = "sqlite", feature = "migration"))]
+
 use std::path::PathBuf;
 
-use aries_askar::migration::IndySdkToAriesAskarMigration;
-use aries_askar::{future::block_on, Error};
+use askar_storage::future::block_on;
+use askar_storage::migration::IndySdkToAriesAskarMigration;
+use askar_storage::Error;
 
 const DB_TEMPLATE_PATH: &str = "./tests/indy_wallet_sqlite.db";
 const DB_UPGRADE_PATH: &str = "./tests/indy_wallet_sqlite_upgraded.db";
@@ -28,14 +31,14 @@ fn prepare_db() {
 }
 
 #[test]
-fn test_migration() {
+fn test_sqlite_migration() {
     prepare_db();
 
     let res = block_on(async {
         let wallet_name = "walletwallet.0";
         let wallet_key = "GfwU1DC7gEZNs3w41tjBiZYj7BNToDoFEqKY6wZXqs1A";
 
         let migrator =
-            IndySdkToAriesAskarMigration::connect(DB_UPGRADE_PATH, wallet_name, &wallet_key, "RAW")
+            IndySdkToAriesAskarMigration::connect(DB_UPGRADE_PATH, wallet_name, wallet_key, "RAW")
                 .await?;
         migrator.migrate().await?;
         Result::<_, Error>::Ok(())
diff --git a/tests/utils/mod.rs b/askar-storage/tests/utils/mod.rs
similarity index 61%
rename from tests/utils/mod.rs
rename to askar-storage/tests/utils/mod.rs
index 04afabf8..3d5d65aa 100644
--- a/tests/utils/mod.rs
+++ b/askar-storage/tests/utils/mod.rs
@@ -1,8 +1,7 @@
-use std::{fmt::Debug, future::Future, ops::Deref, pin::Pin, sync::Arc};
-
-use aries_askar::{
-    kms::{KeyAlg, LocalKey},
-    Backend, Entry, EntryTag, Error, ErrorKind, Store, TagFilter,
+use askar_storage::{
+    any::AnyBackend,
+    entry::{Entry, EntryKind, EntryOperation, EntryTag, TagFilter},
+    BackendSession, ErrorKind,
 };
 use tokio::task::spawn;
 
@@ -21,27 +20,8 @@ const ERR_REPLACE: &str = "Error replacing test row";
 const ERR_REMOVE_ALL: &str = "Error removing test rows";
 const ERR_SCAN: &str = "Error starting scan";
 const ERR_SCAN_NEXT: &str = "Error fetching scan rows";
-const ERR_CREATE_KEYPAIR: &str = "Error creating keypair";
-const ERR_INSERT_KEY: &str = "Error inserting key";
-const ERR_FETCH_KEY: &str = "Error fetching key";
-const ERR_LOAD_KEY: &str = "Error loading key";
-
-pub trait TestStore: Clone + Deref<Target = Store<Self::DB>> + Send + Sync {
-    type DB: Backend + Debug + 'static;
-
-    fn close(self) -> Pin<Box<dyn Future<Output = Result<(), Error>>>>;
-}
-
-impl<B: Backend + Debug + 'static> TestStore for Arc<Store<B>> {
-    type DB = B;
-
-    fn close(self) -> Pin<Box<dyn Future<Output = Result<(), Error>>>> {
-        let db = Arc::try_unwrap(self).unwrap();
-        Box::pin(db.close())
-    }
-}
-
-pub async fn db_create_remove_profile(db: impl TestStore) {
+pub async fn db_create_remove_profile(db: AnyBackend) {
     let profile = db.create_profile(None).await.expect(ERR_PROFILE);
     assert!(db
         .remove_profile(profile)
         .await
@@ -53,13 +33,16 @@ pub async fn db_create_remove_profile(db: impl TestStore) {
         .expect("Error removing profile"),);
 }
 
-pub async fn db_fetch_fail(db: impl TestStore) {
-    let mut conn = db.session(None).await.expect(ERR_SESSION);
-    let result = conn.fetch("cat", "name", false).await.expect(ERR_FETCH);
+pub async fn db_fetch_fail(db: AnyBackend) {
+    let mut conn = db.session(None, false).expect(ERR_SESSION);
+    let result = conn
+        .fetch(EntryKind::Item, "cat", "name", false)
+        .await
+        .expect(ERR_FETCH);
     assert!(result.is_none());
 }
 
-pub async fn db_insert_fetch(db: impl TestStore) {
+pub async fn db_insert_fetch(db: AnyBackend) {
     let test_row = Entry::new(
         "category",
         "name",
         "value",
         vec![
             EntryTag::Encrypted("t1".to_string(), "v1".to_string()),
             EntryTag::Plaintext("t2".to_string(), "v2".to_string()),
         ],
     );
 
-    let mut conn = db.session(None).await.expect(ERR_SESSION);
+    let mut conn = db.session(None, false).expect(ERR_SESSION);
 
-    conn.insert(
+    conn.update(
+        EntryKind::Item,
+        EntryOperation::Insert,
         &test_row.category,
         &test_row.name,
-        &test_row.value,
+        Some(&test_row.value),
         Some(test_row.tags.as_slice()),
         None,
     )
    .await
.expect(ERR_INSERT); - conn.remove(&test_row.category, &test_row.name) - .await - .expect(ERR_REQ_ROW); + conn.update( + EntryKind::Item, + EntryOperation::Remove, + &test_row.category, + &test_row.name, + None, + None, + None, + ) + .await + .expect(ERR_REQ_ROW); } -pub async fn db_remove_missing(db: impl TestStore) { - let mut conn = db.session(None).await.expect(ERR_SESSION); +pub async fn db_remove_missing(db: AnyBackend) { + let mut conn = db.session(None, false).expect(ERR_SESSION); - let err = conn.remove("cat", "name").await.expect_err(ERR_REQ_ERR); + let err = conn + .update( + EntryKind::Item, + EntryOperation::Remove, + "cat", + "name", + None, + None, + None, + ) + .await + .expect_err(ERR_REQ_ERR); assert_eq!(err.kind(), ErrorKind::NotFound); } -pub async fn db_replace_fetch(db: impl TestStore) { +pub async fn db_replace_fetch(db: AnyBackend) { let test_row = Entry::new("category", "name", "value", Vec::new()); - let mut conn = db.session(None).await.expect(ERR_SESSION); + let mut conn = db.session(None, false).expect(ERR_SESSION); - conn.insert( + conn.update( + EntryKind::Item, + EntryOperation::Insert, &test_row.category, &test_row.name, - &test_row.value, + Some(&test_row.value), Some(test_row.tags.as_slice()), None, ) @@ -169,10 +187,12 @@ pub async fn db_replace_fetch(db: impl TestStore) { let mut replace_row = test_row.clone(); replace_row.value = "new value".into(); - conn.replace( + conn.update( + EntryKind::Item, + EntryOperation::Replace, &replace_row.category, &replace_row.name, - &replace_row.value, + Some(&replace_row.value), Some(replace_row.tags.as_slice()), None, ) @@ -180,23 +200,30 @@ pub async fn db_replace_fetch(db: impl TestStore) { .expect(ERR_REPLACE); let row = conn - .fetch(&replace_row.category, &replace_row.name, false) + .fetch( + EntryKind::Item, + &replace_row.category, + &replace_row.name, + false, + ) .await .expect(ERR_FETCH) .expect(ERR_REQ_ROW); assert_eq!(row, replace_row); } -pub async fn db_replace_missing(db: impl TestStore) { +pub async fn db_replace_missing(db: AnyBackend) { let test_row = Entry::new("category", "name", "value", Vec::new()); - let mut conn = db.session(None).await.expect(ERR_SESSION); + let mut conn = db.session(None, false).expect(ERR_SESSION); let err = conn - .replace( + .update( + EntryKind::Item, + EntryOperation::Replace, &test_row.category, &test_row.name, - &test_row.value, + Some(&test_row.value), Some(test_row.tags.as_slice()), None, ) @@ -205,17 +232,19 @@ pub async fn db_replace_missing(db: impl TestStore) { assert_eq!(err.kind(), ErrorKind::NotFound); } -pub async fn db_count(db: impl TestStore) { +pub async fn db_count(db: AnyBackend) { let category = "category".to_string(); let test_rows = vec![Entry::new(&category, "name", "value", Vec::new())]; - let mut conn = db.session(None).await.expect(ERR_SESSION); + let mut conn = db.session(None, false).expect(ERR_SESSION); for upd in test_rows.iter() { - conn.insert( + conn.update( + EntryKind::Item, + EntryOperation::Insert, &upd.category, &upd.name, - &upd.value, + Some(&upd.value), Some(upd.tags.as_slice()), None, ) @@ -224,15 +253,21 @@ pub async fn db_count(db: impl TestStore) { } let tag_filter = None; - let count = conn.count(&category, tag_filter).await.expect(ERR_COUNT); + let count = conn + .count(Some(EntryKind::Item), Some(&category), tag_filter) + .await + .expect(ERR_COUNT); assert_eq!(count, 1); let tag_filter = Some(TagFilter::is_eq("sometag", "someval")); - let count = conn.count(&category, tag_filter).await.expect(ERR_COUNT); + let count = 
conn + .count(Some(EntryKind::Item), Some(&category), tag_filter) + .await + .expect(ERR_COUNT); assert_eq!(count, 0); } -pub async fn db_count_exist(db: impl TestStore) { +pub async fn db_count_exist(db: AnyBackend) { let test_row = Entry::new( "category", "name", @@ -243,21 +278,38 @@ pub async fn db_count_exist(db: impl TestStore) { ], ); - let mut conn = db.session(None).await.expect(ERR_SESSION); + let mut conn = db.session(None, false).expect(ERR_SESSION); - conn.insert( + conn.update( + EntryKind::Item, + EntryOperation::Insert, &test_row.category, &test_row.name, - &test_row.value, + Some(&test_row.value), Some(test_row.tags.as_slice()), None, ) .await .expect(ERR_INSERT); + assert_eq!( + conn.count(Some(EntryKind::Item), Some(&test_row.category), None) + .await + .expect(ERR_COUNT), + 1 + ); + + assert_eq!( + conn.count(Some(EntryKind::Kms), Some(&test_row.category), None) + .await + .expect(ERR_COUNT), + 0 + ); + assert_eq!( conn.count( - &test_row.category, + Some(EntryKind::Item), + Some(&test_row.category), Some(TagFilter::exist(vec!["enc".to_string()])) ) .await @@ -267,7 +319,8 @@ pub async fn db_count_exist(db: impl TestStore) { assert_eq!( conn.count( - &test_row.category, + Some(EntryKind::Item), + Some(&test_row.category), Some(TagFilter::exist(vec!["~plain".to_string()])) ) .await @@ -277,7 +330,8 @@ pub async fn db_count_exist(db: impl TestStore) { assert_eq!( conn.count( - &test_row.category, + Some(EntryKind::Item), + Some(&test_row.category), Some(TagFilter::exist(vec!["~enc".to_string()])) ) .await @@ -287,7 +341,8 @@ pub async fn db_count_exist(db: impl TestStore) { assert_eq!( conn.count( - &test_row.category, + Some(EntryKind::Item), + Some(&test_row.category), Some(TagFilter::exist(vec!["plain".to_string()])) ) .await @@ -297,7 +352,8 @@ pub async fn db_count_exist(db: impl TestStore) { assert_eq!( conn.count( - &test_row.category, + Some(EntryKind::Item), + Some(&test_row.category), Some(TagFilter::exist(vec!["other".to_string()])) ) .await @@ -307,7 +363,8 @@ pub async fn db_count_exist(db: impl TestStore) { assert_eq!( conn.count( - &test_row.category, + Some(EntryKind::Item), + Some(&test_row.category), Some(TagFilter::exist(vec![ "enc".to_string(), "other".to_string() @@ -320,7 +377,8 @@ pub async fn db_count_exist(db: impl TestStore) { assert_eq!( conn.count( - &test_row.category, + Some(EntryKind::Item), + Some(&test_row.category), Some(TagFilter::all_of(vec![ TagFilter::exist(vec!["enc".to_string()]), TagFilter::exist(vec!["~plain".to_string()]) @@ -333,7 +391,8 @@ pub async fn db_count_exist(db: impl TestStore) { assert_eq!( conn.count( - &test_row.category, + Some(EntryKind::Item), + Some(&test_row.category), Some(TagFilter::any_of(vec![ TagFilter::exist(vec!["~enc".to_string()]), TagFilter::exist(vec!["~plain".to_string()]) @@ -346,7 +405,8 @@ pub async fn db_count_exist(db: impl TestStore) { assert_eq!( conn.count( - &test_row.category, + Some(EntryKind::Item), + Some(&test_row.category), Some(TagFilter::all_of(vec![ TagFilter::exist(vec!["~enc".to_string()]), TagFilter::exist(vec!["~plain".to_string()]) @@ -359,7 +419,8 @@ pub async fn db_count_exist(db: impl TestStore) { assert_eq!( conn.count( - &test_row.category, + Some(EntryKind::Item), + Some(&test_row.category), Some(TagFilter::negate(TagFilter::exist(vec![ "enc".to_string(), "other".to_string() @@ -371,7 +432,7 @@ pub async fn db_count_exist(db: impl TestStore) { ); } -pub async fn db_scan(db: impl TestStore) { +pub async fn db_scan(db: AnyBackend) { let category = "category".to_string(); 
let test_rows = vec![Entry::new( &category, @@ -383,13 +444,15 @@ pub async fn db_scan(db: impl TestStore) { ], )]; - let mut conn = db.session(None).await.expect(ERR_SESSION); + let mut conn = db.session(None, false).expect(ERR_SESSION); for upd in test_rows.iter() { - conn.insert( + conn.update( + EntryKind::Item, + EntryOperation::Insert, &upd.category, &upd.name, - &upd.value, + Some(&upd.value), Some(upd.tags.as_slice()), None, ) @@ -402,7 +465,14 @@ pub async fn db_scan(db: impl TestStore) { let offset = None; let limit = None; let mut scan = db - .scan(None, category.clone(), tag_filter, offset, limit) + .scan( + None, + Some(EntryKind::Item), + Some(category.clone()), + tag_filter, + offset, + limit, + ) .await .expect(ERR_SCAN); let rows = scan.fetch_next().await.expect(ERR_SCAN_NEXT); @@ -412,14 +482,21 @@ pub async fn db_scan(db: impl TestStore) { let tag_filter = Some(TagFilter::is_eq("sometag", "someval")); let mut scan = db - .scan(None, category.clone(), tag_filter, offset, limit) + .scan( + None, + Some(EntryKind::Item), + Some(category.clone()), + tag_filter, + offset, + limit, + ) .await .expect(ERR_SCAN); let rows = scan.fetch_next().await.expect(ERR_SCAN_NEXT); assert_eq!(rows, None); } -pub async fn db_remove_all(db: impl TestStore) { +pub async fn db_remove_all(db: AnyBackend) { let test_rows = vec![ Entry::new( "category", @@ -450,13 +527,15 @@ pub async fn db_remove_all(db: impl TestStore) { ), ]; - let mut conn = db.session(None).await.expect(ERR_SESSION); + let mut conn = db.session(None, false).expect(ERR_SESSION); for test_row in test_rows.iter() { - conn.insert( + conn.update( + EntryKind::Item, + EntryOperation::Insert, &test_row.category, &test_row.name, - &test_row.value, + Some(&test_row.value), Some(test_row.tags.as_slice()), None, ) @@ -468,7 +547,8 @@ pub async fn db_remove_all(db: impl TestStore) { // depends on the backend. just checking that no SQL errors occur for now. 
let removed = conn .remove_all( - "category", + Some(EntryKind::Item), + Some("category"), Some(TagFilter::all_of(vec![ TagFilter::is_eq("t1", "del"), TagFilter::is_eq("~t2", "del"), @@ -479,69 +559,49 @@ pub async fn db_remove_all(db: impl TestStore) { assert_eq!(removed, 2); } -pub async fn db_keypair_insert_fetch(db: impl TestStore) { - let keypair = LocalKey::generate(KeyAlg::Ed25519, false).expect(ERR_CREATE_KEYPAIR); - - let mut conn = db.session(None).await.expect(ERR_SESSION); - - let key_name = "testkey"; - let metadata = "meta"; - conn.insert_key(key_name, &keypair, Some(metadata), None, None) - .await - .expect(ERR_INSERT_KEY); - - let found = conn - .fetch_key(key_name, false) - .await - .expect(ERR_FETCH_KEY) - .expect(ERR_REQ_ROW); - assert_eq!(found.algorithm(), Some(KeyAlg::Ed25519.as_str())); - assert_eq!(found.name(), key_name); - assert_eq!(found.metadata(), Some(metadata)); - assert!(found.is_local()); - found.load_local_key().expect(ERR_LOAD_KEY); -} - -pub async fn db_txn_rollback(db: impl TestStore) { +pub async fn db_txn_rollback(db: AnyBackend) { let test_row = Entry::new("category", "name", "value", Vec::new()); - let mut conn = db.transaction(None).await.expect(ERR_TRANSACTION); + let mut conn = db.session(None, true).expect(ERR_TRANSACTION); - conn.insert( + conn.update( + EntryKind::Item, + EntryOperation::Insert, &test_row.category, &test_row.name, - &test_row.value, + Some(&test_row.value), Some(test_row.tags.as_slice()), None, ) .await .expect(ERR_INSERT); - conn.rollback() + conn.close(false) .await .expect("Error rolling back transaction"); - let mut conn = db.session(None).await.expect("Error starting new session"); + let mut conn = db.session(None, false).expect("Error starting new session"); let row = conn - .fetch(&test_row.category, &test_row.name, false) + .fetch(EntryKind::Item, &test_row.category, &test_row.name, false) .await .expect("Error fetching test row"); assert_eq!(row, None); } -pub async fn db_txn_drop(db: impl TestStore) { +pub async fn db_txn_drop(db: AnyBackend) { let test_row = Entry::new("category", "name", "value", Vec::new()); let mut conn = db - .transaction(None) - .await + .session(None, true) .expect("Error starting new transaction"); - conn.insert( + conn.update( + EntryKind::Item, + EntryOperation::Insert, &test_row.category, &test_row.name, - &test_row.value, + Some(&test_row.value), Some(test_row.tags.as_slice()), None, ) @@ -550,25 +610,27 @@ pub async fn db_txn_drop(db: impl TestStore) { drop(conn); - let mut conn = db.session(None).await.expect("Error starting new session"); + let mut conn = db.session(None, false).expect("Error starting new session"); let row = conn - .fetch(&test_row.category, &test_row.name, false) + .fetch(EntryKind::Item, &test_row.category, &test_row.name, false) .await .expect("Error fetching test row"); assert_eq!(row, None); } // test that session does NOT have transaction rollback behaviour -pub async fn db_session_drop(db: impl TestStore) { +pub async fn db_session_drop(db: AnyBackend) { let test_row = Entry::new("category", "name", "value", Vec::new()); - let mut conn = db.session(None).await.expect(ERR_SESSION); + let mut conn = db.session(None, false).expect(ERR_SESSION); - conn.insert( + conn.update( + EntryKind::Item, + EntryOperation::Insert, &test_row.category, &test_row.name, - &test_row.value, + Some(&test_row.value), Some(test_row.tags.as_slice()), None, ) @@ -577,50 +639,54 @@ pub async fn db_session_drop(db: impl TestStore) { drop(conn); - let mut conn = 
db.session(None).await.expect(ERR_SESSION); + let mut conn = db.session(None, false).expect(ERR_SESSION); let row = conn - .fetch(&test_row.category, &test_row.name, false) + .fetch(EntryKind::Item, &test_row.category, &test_row.name, false) .await .expect(ERR_FETCH); assert_eq!(row, Some(test_row)); } -pub async fn db_txn_commit(db: impl TestStore) { +pub async fn db_txn_commit(db: AnyBackend) { let test_row = Entry::new("category", "name", "value", Vec::new()); - let mut conn = db.transaction(None).await.expect(ERR_TRANSACTION); + let mut conn = db.session(None, true).expect(ERR_TRANSACTION); - conn.insert( + conn.update( + EntryKind::Item, + EntryOperation::Insert, &test_row.category, &test_row.name, - &test_row.value, + Some(&test_row.value), Some(test_row.tags.as_slice()), None, ) .await .expect(ERR_INSERT); - conn.commit().await.expect(ERR_COMMIT); + conn.close(true).await.expect(ERR_COMMIT); - let mut conn = db.session(None).await.expect(ERR_SESSION); + let mut conn = db.session(None, false).expect(ERR_SESSION); let row = conn - .fetch(&test_row.category, &test_row.name, false) + .fetch(EntryKind::Item, &test_row.category, &test_row.name, false) .await .expect(ERR_FETCH); assert_eq!(row, Some(test_row)); } -pub async fn db_txn_fetch_for_update(db: impl TestStore) { +pub async fn db_txn_fetch_for_update(db: AnyBackend) { let test_row = Entry::new("category", "name", "value", Vec::new()); - let mut conn = db.transaction(None).await.expect(ERR_TRANSACTION); + let mut conn = db.session(None, true).expect(ERR_TRANSACTION); - conn.insert( + conn.update( + EntryKind::Item, + EntryOperation::Insert, &test_row.category, &test_row.name, - &test_row.value, + Some(&test_row.value), Some(test_row.tags.as_slice()), None, ) @@ -630,23 +696,29 @@ pub async fn db_txn_fetch_for_update(db: impl TestStore) { // could detect that a second transaction would block here? // depends on the backend. just checking that no SQL errors occur for now. 
let row = conn - .fetch(&test_row.category, &test_row.name, true) + .fetch(EntryKind::Item, &test_row.category, &test_row.name, true) .await .expect(ERR_FETCH) .expect(ERR_REQ_ROW); assert_eq!(row, test_row); let rows = conn - .fetch_all(&test_row.category, None, Some(2), true) + .fetch_all( + Some(EntryKind::Item), + Some(&test_row.category), + None, + Some(2), + true, + ) .await .expect(ERR_FETCH_ALL); assert_eq!(rows.len(), 1); assert_eq!(rows[0], test_row); - conn.commit().await.expect(ERR_COMMIT); + conn.close(true).await.expect(ERR_COMMIT); } -pub async fn db_txn_contention(db: impl TestStore + 'static) { +pub async fn db_txn_contention(db: AnyBackend) { let test_row = Entry::new( "category", "count", @@ -657,29 +729,31 @@ pub async fn db_txn_contention(db: impl TestStore + 'static) { ], ); - let mut conn = db.transaction(None).await.expect(ERR_TRANSACTION); + let mut conn = db.session(None, true).expect(ERR_TRANSACTION); - conn.insert( + conn.update( + EntryKind::Item, + EntryOperation::Insert, &test_row.category, &test_row.name, - &test_row.value, + Some(&test_row.value), Some(test_row.tags.as_slice()), None, ) .await .expect(ERR_INSERT); - conn.commit().await.expect(ERR_COMMIT); + conn.close(true).await.expect(ERR_COMMIT); const TASKS: usize = 10; const INC: usize = 1000; - async fn inc(db: impl TestStore, category: String, name: String) -> Result<(), &'static str> { + async fn inc(db: AnyBackend, category: String, name: String) -> Result<(), &'static str> { // try to avoid panics in this section, as they will be raised on a tokio worker thread for _ in 0..INC { - let mut conn = db.transaction(None).await.expect(ERR_TRANSACTION); + let mut conn = db.session(None, true).expect(ERR_TRANSACTION); let row = conn - .fetch(&category, &name, true) + .fetch(EntryKind::Item, &category, &name, true) .await .map_err(|e| { log::error!("{:?}", e); @@ -688,10 +762,12 @@ pub async fn db_txn_contention(db: impl TestStore + 'static) { .ok_or(ERR_REQ_ROW)?; let val: usize = str::parse(row.value.as_opt_str().ok_or("Non-string counter value")?) 
.map_err(|_| "Error parsing counter value")?; - conn.replace( + conn.update( + EntryKind::Item, + EntryOperation::Replace, &category, &name, - format!("{}", val + 1).as_bytes(), + Some(format!("{}", val + 1).as_bytes()), Some(row.tags.as_slice()), None, ) @@ -700,7 +776,7 @@ pub async fn db_txn_contention(db: impl TestStore + 'static) { log::error!("{:?}", e); ERR_REPLACE })?; - conn.commit().await.map_err(|_| ERR_COMMIT)?; + conn.close(true).await.map_err(|_| ERR_COMMIT)?; } Ok(()) } @@ -722,9 +798,9 @@ pub async fn db_txn_contention(db: impl TestStore + 'static) { } // check the total - let mut conn = db.session(None).await.expect(ERR_SESSION); + let mut conn = db.session(None, false).expect(ERR_SESSION); let row = conn - .fetch(&test_row.category, &test_row.name, false) + .fetch(EntryKind::Item, &test_row.category, &test_row.name, false) .await .expect(ERR_FETCH) .expect(ERR_REQ_ROW); diff --git a/src/backend/any.rs b/src/backend/any.rs deleted file mode 100644 index b09065fe..00000000 --- a/src/backend/any.rs +++ /dev/null @@ -1,343 +0,0 @@ -use super::{Backend, ManageBackend, QueryBackend}; -use crate::{ - error::Error, - future::BoxFuture, - protect::{PassKey, StoreKeyMethod}, - storage::{ - Entry, EntryKind, EntryOperation, EntryTag, IntoOptions, Scan, Session, Store, TagFilter, - }, -}; - -#[cfg(feature = "postgres")] -use super::postgres::{self, PostgresStore}; - -#[cfg(feature = "sqlite")] -use super::sqlite::{self, SqliteStore}; - -/// A generic `Store` implementation for any supported backend -pub type AnyStore = Store; - -/// A generic `Session` implementation for any supported backend -pub type AnySession = Session; - -/// An enumeration of supported store backends -#[derive(Debug)] -pub enum AnyBackend { - /// A PostgreSQL store - #[cfg(feature = "postgres")] - Postgres(PostgresStore), - - /// A Sqlite store - #[cfg(feature = "sqlite")] - Sqlite(SqliteStore), - - #[allow(unused)] - #[doc(hidden)] - Other, -} - -macro_rules! 
with_backend { - ($slf:ident, $ident:ident, $body:expr) => { - match $slf { - #[cfg(feature = "postgres")] - Self::Postgres($ident) => $body, - - #[cfg(feature = "sqlite")] - Self::Sqlite($ident) => $body, - - _ => unreachable!(), - } - }; -} - -impl Backend for AnyBackend { - type Session = AnyQueryBackend; - - fn create_profile(&self, name: Option) -> BoxFuture<'_, Result> { - with_backend!(self, store, store.create_profile(name)) - } - - fn get_profile_name(&self) -> &str { - with_backend!(self, store, store.get_profile_name()) - } - - fn remove_profile(&self, name: String) -> BoxFuture<'_, Result> { - with_backend!(self, store, store.remove_profile(name)) - } - - fn scan( - &self, - profile: Option, - kind: EntryKind, - category: String, - tag_filter: Option, - offset: Option, - limit: Option, - ) -> BoxFuture<'_, Result, Error>> { - with_backend!( - self, - store, - store.scan(profile, kind, category, tag_filter, offset, limit) - ) - } - - fn session(&self, profile: Option, transaction: bool) -> Result { - match self { - #[cfg(feature = "postgres")] - Self::Postgres(store) => { - let session = store.session(profile, transaction)?; - Ok(AnyQueryBackend::PostgresSession(Box::new(session))) - } - - #[cfg(feature = "sqlite")] - Self::Sqlite(store) => { - let session = store.session(profile, transaction)?; - Ok(AnyQueryBackend::SqliteSession(Box::new(session))) - } - - _ => unreachable!(), - } - } - - fn rekey_backend( - &mut self, - method: StoreKeyMethod, - pass_key: PassKey<'_>, - ) -> BoxFuture<'_, Result<(), Error>> { - with_backend!(self, store, store.rekey_backend(method, pass_key)) - } - - fn close(&self) -> BoxFuture<'_, Result<(), Error>> { - with_backend!(self, store, store.close()) - } -} - -/// An enumeration of supported backend session types -#[derive(Debug)] -pub enum AnyQueryBackend { - /// A PostgreSQL store session - #[cfg(feature = "postgres")] - PostgresSession(Box<::Session>), - - /// A Sqlite store session - #[cfg(feature = "sqlite")] - SqliteSession(Box<::Session>), - - #[allow(unused)] - #[doc(hidden)] - Other, -} - -impl QueryBackend for AnyQueryBackend { - fn count<'q>( - &'q mut self, - kind: EntryKind, - category: &'q str, - tag_filter: Option, - ) -> BoxFuture<'q, Result> { - match self { - #[cfg(feature = "postgres")] - Self::PostgresSession(session) => session.count(kind, category, tag_filter), - - #[cfg(feature = "sqlite")] - Self::SqliteSession(session) => session.count(kind, category, tag_filter), - - _ => unreachable!(), - } - } - - fn fetch<'q>( - &'q mut self, - kind: EntryKind, - category: &'q str, - name: &'q str, - for_update: bool, - ) -> BoxFuture<'q, Result, Error>> { - match self { - #[cfg(feature = "postgres")] - Self::PostgresSession(session) => session.fetch(kind, category, name, for_update), - - #[cfg(feature = "sqlite")] - Self::SqliteSession(session) => session.fetch(kind, category, name, for_update), - - _ => unreachable!(), - } - } - - fn fetch_all<'q>( - &'q mut self, - kind: EntryKind, - category: &'q str, - tag_filter: Option, - limit: Option, - for_update: bool, - ) -> BoxFuture<'q, Result, Error>> { - match self { - #[cfg(feature = "postgres")] - Self::PostgresSession(session) => { - session.fetch_all(kind, category, tag_filter, limit, for_update) - } - - #[cfg(feature = "sqlite")] - Self::SqliteSession(session) => { - session.fetch_all(kind, category, tag_filter, limit, for_update) - } - - _ => unreachable!(), - } - } - - fn remove_all<'q>( - &'q mut self, - kind: EntryKind, - category: &'q str, - tag_filter: Option, - ) -> 
BoxFuture<'q, Result> { - match self { - #[cfg(feature = "postgres")] - Self::PostgresSession(session) => session.remove_all(kind, category, tag_filter), - - #[cfg(feature = "sqlite")] - Self::SqliteSession(session) => session.remove_all(kind, category, tag_filter), - - _ => unreachable!(), - } - } - - fn update<'q>( - &'q mut self, - kind: EntryKind, - operation: EntryOperation, - category: &'q str, - name: &'q str, - value: Option<&'q [u8]>, - tags: Option<&'q [EntryTag]>, - expiry_ms: Option, - ) -> BoxFuture<'q, Result<(), Error>> { - match self { - #[cfg(feature = "postgres")] - Self::PostgresSession(session) => { - session.update(kind, operation, category, name, value, tags, expiry_ms) - } - - #[cfg(feature = "sqlite")] - Self::SqliteSession(session) => { - session.update(kind, operation, category, name, value, tags, expiry_ms) - } - - _ => unreachable!(), - } - } - - fn close(self, commit: bool) -> BoxFuture<'static, Result<(), Error>> { - match self { - #[cfg(feature = "postgres")] - Self::PostgresSession(session) => Box::pin(session.close(commit)), - - #[cfg(feature = "sqlite")] - Self::SqliteSession(session) => Box::pin(session.close(commit)), - - _ => unreachable!(), - } - } -} - -impl<'a> ManageBackend<'a> for &'a str { - type Store = AnyStore; - - fn open_backend( - self, - method: Option, - pass_key: PassKey<'a>, - profile: Option<&'a str>, - ) -> BoxFuture<'a, Result> { - Box::pin(async move { - let opts = self.into_options()?; - debug!("Open store with options: {:?}", &opts); - - match opts.schema.as_ref() { - #[cfg(feature = "postgres")] - "postgres" => { - let opts = postgres::PostgresStoreOptions::new(opts)?; - let mgr = opts.open(method, pass_key, profile).await?; - Ok(Store::new(AnyBackend::Postgres(mgr.into_inner()))) - } - - #[cfg(feature = "sqlite")] - "sqlite" => { - let opts = sqlite::SqliteStoreOptions::new(opts)?; - let mgr = opts.open(method, pass_key, profile).await?; - Ok(Store::new(AnyBackend::Sqlite(mgr.into_inner()))) - } - - _ => Err(err_msg!( - Unsupported, - "Unsupported backend: {}", - &opts.schema - )), - } - }) - } - - fn provision_backend( - self, - method: StoreKeyMethod, - pass_key: PassKey<'a>, - profile: Option<&'a str>, - recreate: bool, - ) -> BoxFuture<'a, Result> { - Box::pin(async move { - let opts = self.into_options()?; - debug!("Provision store with options: {:?}", &opts); - - match opts.schema.as_ref() { - #[cfg(feature = "postgres")] - "postgres" => { - let opts = postgres::PostgresStoreOptions::new(opts)?; - let mgr = opts.provision(method, pass_key, profile, recreate).await?; - Ok(Store::new(AnyBackend::Postgres(mgr.into_inner()))) - } - - #[cfg(feature = "sqlite")] - "sqlite" => { - let opts = sqlite::SqliteStoreOptions::new(opts)?; - let mgr = opts.provision(method, pass_key, profile, recreate).await?; - Ok(Store::new(AnyBackend::Sqlite(mgr.into_inner()))) - } - - _ => Err(err_msg!( - Unsupported, - "Unsupported backend: {}", - &opts.schema - )), - } - }) - } - - fn remove_backend(self) -> BoxFuture<'a, Result> { - Box::pin(async move { - let opts = self.into_options()?; - debug!("Remove store with options: {:?}", &opts); - - match opts.schema.as_ref() { - #[cfg(feature = "postgres")] - "postgres" => { - let opts = postgres::PostgresStoreOptions::new(opts)?; - Ok(opts.remove().await?) - } - - #[cfg(feature = "sqlite")] - "sqlite" => { - let opts = sqlite::SqliteStoreOptions::new(opts)?; - Ok(opts.remove().await?) 
- } - - _ => Err(err_msg!( - Unsupported, - "Unsupported backend: {}", - &opts.schema - )), - } - }) - } -} diff --git a/src/backend/mod.rs b/src/backend/mod.rs deleted file mode 100644 index 537e24ea..00000000 --- a/src/backend/mod.rs +++ /dev/null @@ -1,22 +0,0 @@ -//! Storage backends supported by aries-askar - -#[cfg(feature = "any")] -#[cfg_attr(docsrs, doc(cfg(feature = "any")))] -/// Generic backend (from URI) support -pub mod any; - -#[cfg(any(feature = "postgres", feature = "sqlite"))] -pub(crate) mod db_utils; - -#[cfg(feature = "postgres")] -#[cfg_attr(docsrs, doc(cfg(feature = "postgres")))] -/// Postgres database support -pub mod postgres; - -#[cfg(feature = "sqlite")] -#[cfg_attr(docsrs, doc(cfg(feature = "sqlite")))] -/// Sqlite database support -pub mod sqlite; - -mod types; -pub use self::types::{Backend, ManageBackend, QueryBackend}; diff --git a/src/error.rs b/src/error.rs index 781233a0..0d0f805e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,6 +2,7 @@ use std::error::Error as StdError; use std::fmt::{self, Display, Formatter}; use crate::crypto::{Error as CryptoError, ErrorKind as CryptoErrorKind}; +use crate::storage::{Error as StorageError, ErrorKind as StorageErrorKind}; /// The possible kinds of error produced by the crate #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -131,15 +132,6 @@ impl From for Error { } } -// FIXME would be preferable to remove this auto-conversion and handle -// all sqlx errors manually, to ensure there is some context around the error -#[cfg(any(feature = "indy_compat", feature = "postgres", feature = "sqlite"))] -impl From for Error { - fn from(err: sqlx::Error) -> Self { - Error::from(ErrorKind::Backend).with_cause(err) - } -} - impl From for Error { fn from(err: CryptoError) -> Self { let kind = match err.kind() { @@ -157,6 +149,28 @@ impl From for Error { } } +impl From for Error { + fn from(err: StorageError) -> Self { + let (kind, cause, message) = err.into_parts(); + let kind = match kind { + StorageErrorKind::Backend => ErrorKind::Backend, + StorageErrorKind::Busy => ErrorKind::Busy, + StorageErrorKind::Custom => ErrorKind::Custom, + StorageErrorKind::Duplicate => ErrorKind::Duplicate, + StorageErrorKind::Encryption => ErrorKind::Encryption, + StorageErrorKind::Input => ErrorKind::Input, + StorageErrorKind::NotFound => ErrorKind::NotFound, + StorageErrorKind::Unexpected => ErrorKind::Unexpected, + StorageErrorKind::Unsupported => ErrorKind::Unsupported, + }; + Error { + kind, + cause, + message, + } + } +} + macro_rules! 
err_msg { () => { $crate::error::Error::from($crate::error::ErrorKind::Input) diff --git a/src/ffi/error.rs b/src/ffi/error.rs index 1ab29e6c..cf20bc8e 100644 --- a/src/ffi/error.rs +++ b/src/ffi/error.rs @@ -9,7 +9,7 @@ use once_cell::sync::Lazy; static LAST_ERROR: Lazy>> = Lazy::new(|| RwLock::new(None)); -#[derive(Debug, PartialEq, Copy, Clone, Serialize)] +#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize)] #[repr(i64)] pub enum ErrorCode { Success = 0, @@ -60,11 +60,16 @@ pub extern "C" fn askar_get_current_error(error_json_p: *mut *const c_char) -> E } pub fn get_current_error_json() -> String { + #[derive(Serialize)] + struct ErrorJson { + code: usize, + message: String, + } + if let Some(err) = Option::take(&mut *LAST_ERROR.write().unwrap()) { let message = err.to_string(); let code = ErrorCode::from(err.kind()) as usize; - // let extra = err.extra(); - json!({"code": code, "message": message}).to_string() + serde_json::json!(&ErrorJson { code, message }).to_string() } else { r#"{"code":0,"message":null}"#.to_owned() } diff --git a/src/ffi/handle.rs b/src/ffi/handle.rs index a469427c..33d1dafd 100644 --- a/src/ffi/handle.rs +++ b/src/ffi/handle.rs @@ -1,7 +1,12 @@ -use std::{fmt::Display, mem, ptr, sync::Arc}; +use std::{ + fmt::{Debug, Display}, + mem, ptr, + sync::Arc, +}; use crate::error::Error; +#[derive(Debug)] #[repr(C)] pub struct ArcHandle(*const T); @@ -18,6 +23,7 @@ impl ArcHandle { pub fn load(&self) -> Result, Error> { self.validate()?; let result = unsafe { mem::ManuallyDrop::new(Arc::from_raw(self.0)) }; + #[allow(clippy::needless_borrow)] Ok((&*result).clone()) } @@ -46,7 +52,7 @@ impl std::fmt::Display for ArcHandle { } } -pub trait ResourceHandle: Copy + Ord + From + Display { +pub trait ResourceHandle: Copy + Eq + Ord + From + Debug + Display { fn invalid() -> Self { Self::from(0) } diff --git a/src/ffi/key.rs b/src/ffi/key.rs index 49ab97ad..099570de 100644 --- a/src/ffi/key.rs +++ b/src/ffi/key.rs @@ -383,7 +383,7 @@ pub extern "C" fn askar_key_wrap_key( check_useful_c_ptr!(out); let key = handle.load()?; let other = other.load()?; - let result = key.wrap_key(&*other, nonce.as_slice())?; + let result = key.wrap_key(&other, nonce.as_slice())?; unsafe { *out = EncryptedBuffer::from_encrypted(result) }; Ok(ErrorCode::Success) } @@ -435,8 +435,8 @@ pub extern "C" fn askar_key_crypto_box( let recip_key = recip_key.load()?; let sender_key = sender_key.load()?; let message = crypto_box( - &*recip_key, - &*sender_key, + &recip_key, + &sender_key, message.as_slice(), nonce.as_slice() )?; @@ -459,8 +459,8 @@ pub extern "C" fn askar_key_crypto_box_open( let recip_key = recip_key.load()?; let sender_key = sender_key.load()?; let message = crypto_box_open( - &*recip_key, - &*sender_key, + &recip_key, + &sender_key, message.as_slice(), nonce.as_slice() )?; diff --git a/src/ffi/migration.rs b/src/ffi/migration.rs index c520527f..c397aa84 100644 --- a/src/ffi/migration.rs +++ b/src/ffi/migration.rs @@ -1,6 +1,7 @@ use ffi_support::FfiStr; -use crate::{future::spawn_ok, migration::IndySdkToAriesAskarMigration}; +use crate::storage::future::spawn_ok; +use crate::storage::migration::IndySdkToAriesAskarMigration; use super::{ error::{set_last_error, ErrorCode}, diff --git a/src/ffi/mod.rs b/src/ffi/mod.rs index 806bf845..baa78d32 100644 --- a/src/ffi/mod.rs +++ b/src/ffi/mod.rs @@ -23,6 +23,7 @@ mod log; mod result_list; mod secret; mod store; +mod tags; #[cfg(all(feature = "migration", feature = "sqlite"))] mod migration; diff --git a/src/ffi/result_list.rs 
b/src/ffi/result_list.rs index 38b15a9c..b5c0a4a3 100644 --- a/src/ffi/result_list.rs +++ b/src/ffi/result_list.rs @@ -1,11 +1,9 @@ use std::{ffi::CString, os::raw::c_char, ptr}; -use super::{handle::ArcHandle, key::LocalKeyHandle, secret::SecretBuffer, ErrorCode}; -use crate::{ - error::Error, - kms::KeyEntry, - storage::{Entry, EntryTagSet}, +use super::{ + handle::ArcHandle, key::LocalKeyHandle, secret::SecretBuffer, tags::EntryTagSet, ErrorCode, }; +use crate::{entry::Entry, error::Error, kms::KeyEntry}; pub enum FfiResultList { Single(R), @@ -28,7 +26,7 @@ impl FfiResultList { } } } - return Err(err_msg!(Input, "Invalid index for result set")); + Err(err_msg!(Input, "Invalid index for result set")) } pub fn len(&self) -> i32 { diff --git a/src/ffi/store.rs b/src/ffi/store.rs index ebddc91f..dfb124bd 100644 --- a/src/ffi/store.rs +++ b/src/ffi/store.rs @@ -8,39 +8,36 @@ use super::{ error::set_last_error, key::LocalKeyHandle, result_list::{EntryListHandle, FfiEntryList, FfiKeyEntryList, KeyEntryListHandle}, + tags::EntryTagSet, CallbackId, EnsureCallback, ErrorCode, ResourceHandle, }; use crate::{ - backend::{ - any::{AnySession, AnyStore}, - ManageBackend, - }, + entry::{Entry, EntryOperation, Scan, TagFilter}, error::Error, future::spawn_ok, - protect::{generate_raw_store_key, PassKey, StoreKeyMethod}, - storage::{Entry, EntryOperation, EntryTagSet, Scan, TagFilter}, + store::{PassKey, Session, Store, StoreKeyMethod}, }; new_sequence_handle!(StoreHandle, FFI_STORE_COUNTER); new_sequence_handle!(SessionHandle, FFI_SESSION_COUNTER); new_sequence_handle!(ScanHandle, FFI_SCAN_COUNTER); -static FFI_STORES: Lazy>>> = +static FFI_STORES: Lazy>> = Lazy::new(|| RwLock::new(BTreeMap::new())); -static FFI_SESSIONS: Lazy> = +static FFI_SESSIONS: Lazy> = Lazy::new(StoreResourceMap::new); static FFI_SCANS: Lazy>> = Lazy::new(StoreResourceMap::new); impl StoreHandle { - pub async fn create(value: AnyStore) -> Self { + pub async fn create(value: Store) -> Self { let handle = Self::next(); let mut repo = FFI_STORES.write().await; - repo.insert(handle, Arc::new(value)); + repo.insert(handle, value); handle } - pub async fn load(&self) -> Result, Error> { + pub async fn load(&self) -> Result { FFI_STORES .read() .await @@ -49,7 +46,7 @@ impl StoreHandle { .ok_or_else(|| err_msg!("Invalid store handle")) } - pub async fn remove(&self) -> Result, Error> { + pub async fn remove(&self) -> Result { FFI_STORES .write() .await @@ -57,7 +54,7 @@ impl StoreHandle { .ok_or_else(|| err_msg!("Invalid store handle")) } - pub async fn replace(&self, store: Arc) { + pub async fn replace(&self, store: Store) { FFI_STORES.write().await.insert(*self, store); } } @@ -138,7 +135,7 @@ pub extern "C" fn askar_store_generate_raw_key( s if s.is_empty() => None, s => Some(s) }; - let key = generate_raw_store_key(seed)?; + let key = Store::new_raw_key(seed)?; unsafe { *out = rust_string_to_c(key.to_string()); } Ok(ErrorCode::Success) } @@ -175,7 +172,8 @@ pub extern "C" fn askar_store_provision( ); spawn_ok(async move { let result = async { - let store = spec_uri.provision_backend( + let store = Store::provision( + spec_uri.as_str(), key_method, pass_key, profile.as_deref(), @@ -219,7 +217,8 @@ pub extern "C" fn askar_store_open( ); spawn_ok(async move { let result = async { - let store = spec_uri.open_backend( + let store = Store::open( + spec_uri.as_str(), key_method, pass_key, profile.as_deref() @@ -249,10 +248,7 @@ pub extern "C" fn askar_store_remove( } ); spawn_ok(async move { - let result = async { let removed = 
spec_uri.remove_backend().await?; - Ok(removed) - }.await; + let result = Store::remove(spec_uri.as_str()).await; cb.resolve(result); }); Ok(ErrorCode::Success) @@ -366,18 +362,10 @@ pub extern "C" fn askar_store_rekey( ); spawn_ok(async move { let result = async { - let store = handle.remove().await?; - match Arc::try_unwrap(store) { - Ok(mut store) => { - store.rekey(key_method, pass_key.as_ref()).await?; - handle.replace(Arc::new(store)).await; - Ok(()) - } - Err(arc_store) => { - handle.replace(arc_store).await; - Err(err_msg!("Cannot re-key store with multiple references")) - } - } + let mut store = handle.remove().await?; + let result = store.rekey(key_method, pass_key.as_ref()).await; + handle.replace(store).await; + result }.await; cb.resolve(result); }); @@ -409,7 +397,7 @@ pub extern "C" fn askar_store_close( // been dropped yet (this will invalidate associated handles) FFI_SESSIONS.remove_all(handle).await?; FFI_SCANS.remove_all(handle).await?; - store.arc_close().await?; + store.close().await?; info!("Closed store {}", handle); Ok(()) }.await; @@ -439,7 +427,7 @@ pub extern "C" fn askar_scan_start( trace!("Scan store start"); let cb = cb.ok_or_else(|| err_msg!("No callback provided"))?; let profile = profile.into_opt_string(); - let category = category.into_opt_string().ok_or_else(|| err_msg!("Category not provided"))?; + let category = category.into_opt_string(); let tag_filter = tag_filter.as_opt_str().map(TagFilter::from_str).transpose()?; let cb = EnsureCallback::new(move |result: Result| match result { @@ -558,7 +546,7 @@ pub extern "C" fn askar_session_count( catch_err! { trace!("Count from store"); let cb = cb.ok_or_else(|| err_msg!("No callback provided"))?; - let category = category.into_opt_string().ok_or_else(|| err_msg!("Category not provided"))?; + let category = category.into_opt_string(); let tag_filter = tag_filter.as_opt_str().map(TagFilter::from_str).transpose()?; let cb = EnsureCallback::new(move |result: Result| match result { @@ -569,8 +557,7 @@ pub extern "C" fn askar_session_count( spawn_ok(async move { let result = async { let mut session = FFI_SESSIONS.borrow(handle).await?; - let count = session.count(&category, tag_filter).await; - count + session.count(category.as_deref(), tag_filter).await }.await; cb.resolve(result); }); @@ -605,8 +592,7 @@ pub extern "C" fn askar_session_fetch( spawn_ok(async move { let result = async { let mut session = FFI_SESSIONS.borrow(handle).await?; - let found = session.fetch(&category, &name, for_update != 0).await; - found + session.fetch(&category, &name, for_update != 0).await }.await; cb.resolve(result); }); @@ -627,7 +613,7 @@ pub extern "C" fn askar_session_fetch_all( catch_err! { trace!("Count from store"); let cb = cb.ok_or_else(|| err_msg!("No callback provided"))?; - let category = category.into_opt_string().ok_or_else(|| err_msg!("Category not provided"))?; + let category = category.into_opt_string(); let tag_filter = tag_filter.as_opt_str().map(TagFilter::from_str).transpose()?; let limit = if limit < 0 { None } else {Some(limit)}; let cb = EnsureCallback::new(move |result| @@ -642,8 +628,7 @@ pub extern "C" fn askar_session_fetch_all( spawn_ok(async move { let result = async { let mut session = FFI_SESSIONS.borrow(handle).await?; - let found = session.fetch_all(&category, tag_filter, limit, for_update != 0).await; - found + session.fetch_all(category.as_deref(), tag_filter, limit, for_update != 0).await }.await; cb.resolve(result); }); @@ -662,7 +647,7 @@ pub extern "C" fn askar_session_remove_all( catch_err! 
{ trace!("Count from store"); let cb = cb.ok_or_else(|| err_msg!("No callback provided"))?; - let category = category.into_opt_string().ok_or_else(|| err_msg!("Category not provided"))?; + let category = category.into_opt_string(); let tag_filter = tag_filter.as_opt_str().map(TagFilter::from_str).transpose()?; let cb = EnsureCallback::new(move |result| match result { @@ -675,8 +660,7 @@ pub extern "C" fn askar_session_remove_all( spawn_ok(async move { let result = async { let mut session = FFI_SESSIONS.borrow(handle).await?; - let removed = session.remove_all(&category, tag_filter).await; - removed + session.remove_all(category.as_deref(), tag_filter).await }.await; cb.resolve(result); }); @@ -731,8 +715,7 @@ pub extern "C" fn askar_session_update( spawn_ok(async move { let result = async { let mut session = FFI_SESSIONS.borrow(handle).await?; - let result = session.update(operation, &category, &name, Some(value.as_slice()), tags.as_deref(), expiry_ms).await; - result + session.update(operation, &category, &name, Some(value.as_slice()), tags.as_deref(), expiry_ms).await }.await; cb.resolve(result); }); @@ -783,14 +766,13 @@ pub extern "C" fn askar_session_insert_key( spawn_ok(async move { let result = async { let mut session = FFI_SESSIONS.borrow(handle).await?; - let result = session.insert_key( + session.insert_key( name.as_str(), &key, metadata.as_deref(), tags.as_deref(), expiry_ms, - ).await; - result + ).await }.await; cb.resolve(result); }); @@ -827,11 +809,10 @@ pub extern "C" fn askar_session_fetch_key( spawn_ok(async move { let result = async { let mut session = FFI_SESSIONS.borrow(handle).await?; - let result = session.fetch_key( + session.fetch_key( name.as_str(), for_update != 0 - ).await; - result + ).await }.await; cb.resolve(result); }); @@ -871,14 +852,13 @@ pub extern "C" fn askar_session_fetch_all_keys( spawn_ok(async move { let result = async { let mut session = FFI_SESSIONS.borrow(handle).await?; - let result = session.fetch_all_keys( + session.fetch_all_keys( alg.as_deref(), thumbprint.as_deref(), tag_filter, limit, for_update != 0 - ).await; - result + ).await }.await; cb.resolve(result); }); @@ -927,14 +907,13 @@ pub extern "C" fn askar_session_update_key( spawn_ok(async move { let result = async { let mut session = FFI_SESSIONS.borrow(handle).await?; - let result = session.update_key( + session.update_key( &name, metadata.as_deref(), tags.as_deref(), expiry_ms, - ).await; - result + ).await }.await; cb.resolve(result); }); @@ -965,10 +944,9 @@ pub extern "C" fn askar_session_remove_key( spawn_ok(async move { let result = async { let mut session = FFI_SESSIONS.borrow(handle).await?; - let result = session.remove_key( + session.remove_key( &name, - ).await; - result + ).await }.await; cb.resolve(result); }); diff --git a/src/ffi/tags.rs b/src/ffi/tags.rs new file mode 100644 index 00000000..82ce242d --- /dev/null +++ b/src/ffi/tags.rs @@ -0,0 +1,210 @@ +use std::{borrow::Cow, fmt}; + +use serde::{ + de::{Error as SerdeError, MapAccess, SeqAccess, Visitor}, + ser::SerializeMap, + Deserialize, Deserializer, Serialize, Serializer, +}; + +use crate::entry::EntryTag; + +/// A wrapper type used for managing (de)serialization of tags +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub(crate) struct EntryTagSet<'e>(Cow<'e, [EntryTag]>); + +impl EntryTagSet<'_> { + pub fn into_vec(self) -> Vec { + self.into() + } +} + +impl<'e> From<&'e [EntryTag]> for EntryTagSet<'e> { + fn from(tags: &'e [EntryTag]) -> Self { + Self(Cow::Borrowed(tags)) + } +} + +impl From> for 
EntryTagSet<'static> { + fn from(tags: Vec) -> Self { + Self(Cow::Owned(tags)) + } +} + +impl<'e> From> for Vec { + fn from(set: EntryTagSet<'e>) -> Self { + set.0.into_owned() + } +} + +impl<'de> Deserialize<'de> for EntryTagSet<'static> { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct TagSetVisitor; + + impl<'d> Visitor<'d> for TagSetVisitor { + type Value = EntryTagSet<'static>; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str("an object containing zero or more entry tags") + } + + fn visit_map(self, mut access: M) -> Result + where + M: MapAccess<'d>, + { + let mut v = Vec::with_capacity(access.size_hint().unwrap_or_default()); + + while let Some((key, values)) = access.next_entry::<&str, EntryTagValues>()? { + let (tag, enc) = match key.chars().next() { + Some('~') => (key[1..].to_owned(), false), + None => return Err(M::Error::custom("invalid tag name: empty string")), + _ => (key.to_owned(), true), + }; + match (values, enc) { + (EntryTagValues::Single(value), true) => { + v.push(EntryTag::Encrypted(tag, value)) + } + (EntryTagValues::Single(value), false) => { + v.push(EntryTag::Plaintext(tag, value)) + } + (EntryTagValues::Multiple(values), true) => { + for value in values { + v.push(EntryTag::Encrypted(tag.clone(), value)) + } + } + (EntryTagValues::Multiple(values), false) => { + for value in values { + v.push(EntryTag::Plaintext(tag.clone(), value)) + } + } + } + } + + Ok(EntryTagSet(Cow::Owned(v))) + } + } + + deserializer.deserialize_map(TagSetVisitor) + } +} + +enum EntryTagValues { + Single(String), + Multiple(Vec), +} + +impl<'de> Deserialize<'de> for EntryTagValues { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct TagValuesVisitor; + + impl<'d> Visitor<'d> for TagValuesVisitor { + type Value = EntryTagValues; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str("a string or list of strings") + } + + fn visit_str(self, value: &str) -> Result + where + E: SerdeError, + { + Ok(EntryTagValues::Single(value.to_owned())) + } + + fn visit_string(self, value: String) -> Result + where + E: SerdeError, + { + Ok(EntryTagValues::Single(value)) + } + + fn visit_seq(self, mut access: S) -> Result + where + S: SeqAccess<'d>, + { + let mut v = Vec::with_capacity(access.size_hint().unwrap_or_default()); + while let Some(value) = access.next_element()? 
{ + v.push(value) + } + Ok(EntryTagValues::Multiple(v)) + } + } + + deserializer.deserialize_any(TagValuesVisitor) + } +} + +impl Serialize for EntryTagSet<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + use std::collections::BTreeMap; + + #[derive(PartialOrd, Ord)] + struct TagName<'a>(&'a str, bool); + + impl<'a> PartialEq for TagName<'a> { + fn eq(&self, other: &Self) -> bool { + self.1 == other.1 && self.0 == other.0 + } + } + + impl<'a> Eq for TagName<'a> {} + + impl Serialize for TagName<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + if self.1 { + serializer.serialize_str(self.0) + } else { + serializer.collect_str(&format_args!("~{}", self.0)) + } + } + } + + let mut tags = BTreeMap::new(); + for tag in self.0.iter() { + let (name, value) = match tag { + EntryTag::Encrypted(name, val) => (TagName(name.as_str(), true), val.as_str()), + EntryTag::Plaintext(name, val) => (TagName(name.as_str(), false), val.as_str()), + }; + tags.entry(name).or_insert_with(Vec::new).push(value); + } + + let mut map = serializer.serialize_map(Some(tags.len()))?; + for (tag_name, values) in tags.into_iter() { + if values.len() > 1 { + map.serialize_entry(&tag_name, &values)?; + } else { + map.serialize_entry(&tag_name, &values[0])?; + } + } + map.end() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn serialize_tags() { + let tags = EntryTagSet::from(vec![ + EntryTag::Encrypted("a".to_owned(), "aval".to_owned()), + EntryTag::Plaintext("b".to_owned(), "bval".to_owned()), + EntryTag::Plaintext("b".to_owned(), "bval-2".to_owned()), + ]); + let ser = serde_json::to_string(&tags).unwrap(); + assert_eq!(ser, r#"{"a":"aval","~b":["bval","bval-2"]}"#); + let tags2 = serde_json::from_str(&ser).unwrap(); + assert_eq!(tags, tags2); + } +} diff --git a/src/kms/entry.rs b/src/kms/entry.rs index 16f6b17c..f4135df8 100644 --- a/src/kms/entry.rs +++ b/src/kms/entry.rs @@ -1,8 +1,8 @@ use super::local_key::LocalKey; use crate::{ crypto::{alg::AnyKey, buffer::SecretBytes, jwk::FromJwk}, + entry::{Entry, EntryTag}, error::Error, - storage::{Entry, EntryTag}, }; /// Parameters defining a stored key diff --git a/src/kms/envelope.rs b/src/kms/envelope.rs index 095c2517..f5407a53 100644 --- a/src/kms/envelope.rs +++ b/src/kms/envelope.rs @@ -88,14 +88,7 @@ pub fn derive_key_ecdh_1pu( receive: bool, ) -> Result { let derive = Ecdh1PU::new( - &*ephem_key, - &*sender_key, - &*recip_key, - alg_id, - apu, - apv, - cc_tag, - receive, + ephem_key, sender_key, recip_key, alg_id, apu, apv, cc_tag, receive, ); LocalKey::from_key_derivation(key_alg, derive) } @@ -110,6 +103,6 @@ pub fn derive_key_ecdh_es( apv: &[u8], receive: bool, ) -> Result { - let derive = EcdhEs::new(&*ephem_key, &*recip_key, alg_id, apu, apv, receive); + let derive = EcdhEs::new(ephem_key, recip_key, alg_id, apu, apv, receive); LocalKey::from_key_derivation(key_alg, derive) } diff --git a/src/lib.rs b/src/lib.rs index c5f518e9..b15b6160 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,13 +7,6 @@ mod error; pub use self::error::{Error, ErrorKind}; -#[cfg(test)] -#[macro_use] -extern crate hex_literal; - -#[macro_use] -mod macros; - #[cfg(any(test, feature = "log"))] #[macro_use] extern crate log; @@ -21,41 +14,17 @@ extern crate log; #[macro_use] extern crate serde; -pub mod backend; -pub use self::backend::{Backend, ManageBackend}; - -#[cfg(feature = "any")] -pub use self::backend::any; - -#[cfg(feature = "postgres")] -pub use self::backend::postgres; - -#[cfg(feature = 
"sqlite")] -pub use self::backend::sqlite; - +#[doc(hidden)] pub use askar_crypto as crypto; - #[doc(hidden)] -pub mod future; - -#[cfg(feature = "ffi")] -#[macro_use] -extern crate serde_json; +pub use askar_storage as storage; +#[doc(hidden)] +pub use askar_storage::future; #[cfg(feature = "ffi")] mod ffi; -#[cfg(all(feature = "migration", feature = "sqlite"))] -pub mod migration; - pub mod kms; -mod protect; -pub use protect::{ - generate_raw_store_key, - kdf::{Argon2Level, KdfMethod}, - PassKey, StoreKeyMethod, -}; - -mod storage; -pub use storage::{Entry, EntryTag, Scan, Store, TagFilter}; +mod store; +pub use store::{entry, PassKey, Session, Store, StoreKeyMethod}; diff --git a/src/storage/mod.rs b/src/storage/mod.rs deleted file mode 100644 index f7d00491..00000000 --- a/src/storage/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -mod entry; -pub(crate) use self::entry::{EncEntryTag, EntryTagSet}; -pub use self::entry::{Entry, EntryKind, EntryOperation, EntryTag, Scan, TagFilter}; - -mod options; -pub(crate) use self::options::{IntoOptions, Options}; - -mod store; -pub use self::store::{Session, Store}; - -pub(crate) mod wql; diff --git a/src/storage/store.rs b/src/store.rs similarity index 75% rename from src/storage/store.rs rename to src/store.rs index 2f42d17a..11a8d57a 100644 --- a/src/storage/store.rs +++ b/src/store.rs @@ -1,34 +1,62 @@ use std::sync::Arc; -use super::entry::{Entry, EntryKind, EntryOperation, EntryTag, Scan, TagFilter}; use crate::{ - backend::{Backend, QueryBackend}, error::Error, kms::{KeyEntry, KeyParams, KmsCategory, LocalKey}, - protect::{PassKey, StoreKeyMethod}, + storage::{ + any::{AnyBackend, AnyBackendSession}, + backend::{BackendSession, ManageBackend}, + entry::{Entry, EntryKind, EntryOperation, EntryTag, Scan, TagFilter}, + generate_raw_store_key, + }, }; -#[derive(Debug)] +pub use crate::storage::{entry, PassKey, StoreKeyMethod}; + +#[derive(Debug, Clone)] /// An instance of an opened store -pub struct Store(B); +pub struct Store(AnyBackend); -impl Store { - pub(crate) fn new(inner: B) -> Self { +impl Store { + pub(crate) fn new(inner: AnyBackend) -> Self { Self(inner) } - #[cfg(test)] - #[allow(unused)] - pub(crate) fn inner(&self) -> &B { - &self.0 + /// Provision a new store instance using a database URL + pub async fn provision( + db_url: &str, + key_method: StoreKeyMethod, + pass_key: PassKey<'_>, + profile: Option<&str>, + recreate: bool, + ) -> Result { + let backend = db_url + .provision_backend(key_method, pass_key, profile, recreate) + .await?; + Ok(Self::new(backend)) } - pub(crate) fn into_inner(self) -> B { - self.0 + /// Open a store instance from a database URL + pub async fn open( + db_url: &str, + key_method: Option, + pass_key: PassKey<'_>, + profile: Option<&str>, + ) -> Result { + let backend = db_url.open_backend(key_method, pass_key, profile).await?; + Ok(Self::new(backend)) + } + + /// Remove a store instance using a database URL + pub async fn remove(db_url: &str) -> Result { + Ok(db_url.remove_backend().await?) + } + + /// Generate a new raw store key + pub fn new_raw_key(seed: Option<&[u8]>) -> Result, Error> { + Ok(generate_raw_store_key(seed)?) 
} -} -impl Store { /// Get the default profile name used when starting a scan or a session pub fn get_profile_name(&self) -> &str { self.0.get_profile_name() @@ -40,17 +68,20 @@ impl Store { method: StoreKeyMethod, pass_key: PassKey<'_>, ) -> Result<(), Error> { - self.0.rekey_backend(method, pass_key).await + match Arc::get_mut(&mut self.0) { + Some(inner) => Ok(inner.rekey(method, pass_key).await?), + None => Err(err_msg!("Cannot re-key a store with multiple references")), + } } /// Create a new profile with the given profile name pub async fn create_profile(&self, name: Option) -> Result { - self.0.create_profile(name).await + Ok(self.0.create_profile(name).await?) } /// Remove an existing profile with the given profile name pub async fn remove_profile(&self, name: String) -> Result { - self.0.remove_profile(name).await + Ok(self.0.remove_profile(name).await?) } /// Create a new scan instance against the store @@ -59,62 +90,66 @@ impl Store { pub async fn scan( &self, profile: Option, - category: String, + category: Option, tag_filter: Option, offset: Option, limit: Option, ) -> Result, Error> { - self.0 + Ok(self + .0 .scan( profile, - EntryKind::Item, + Some(EntryKind::Item), category, tag_filter, offset, limit, ) - .await + .await?) } /// Create a new session against the store - pub async fn session(&self, profile: Option) -> Result, Error> { + pub async fn session(&self, profile: Option) -> Result { // FIXME - add 'immediate' flag Ok(Session::new(self.0.session(profile, false)?)) } /// Create a new transaction session against the store - pub async fn transaction(&self, profile: Option) -> Result, Error> { + pub async fn transaction(&self, profile: Option) -> Result { Ok(Session::new(self.0.session(profile, true)?)) } /// Close the store instance, waiting for any shutdown procedures to complete. pub async fn close(self) -> Result<(), Error> { - self.0.close().await + Ok(self.0.close().await?) } +} - pub(crate) async fn arc_close(self: Arc) -> Result<(), Error> { - self.0.close().await +impl From for Store { + fn from(backend: AnyBackend) -> Self { + Self::new(backend) } } /// An active connection to the store backend #[derive(Debug)] -pub struct Session(Q); +pub struct Session(AnyBackendSession); -impl Session { - pub(crate) fn new(inner: Q) -> Self { +impl Session { + pub(crate) fn new(inner: AnyBackendSession) -> Self { Self(inner) } -} -impl Session { /// Count the number of entries for a given record category pub async fn count( &mut self, - category: &str, + category: Option<&str>, tag_filter: Option, ) -> Result { - self.0.count(EntryKind::Item, category, tag_filter).await + Ok(self + .0 + .count(Some(EntryKind::Item), category, tag_filter) + .await?) } /// Retrieve the current record at `(category, name)`. @@ -127,9 +162,10 @@ impl Session { name: &str, for_update: bool, ) -> Result, Error> { - self.0 + Ok(self + .0 .fetch(EntryKind::Item, category, name, for_update) - .await + .await?) } /// Retrieve all records matching the given `category` and `tag_filter`. @@ -139,14 +175,21 @@ impl Session { /// requirements pub async fn fetch_all( &mut self, - category: &str, + category: Option<&str>, tag_filter: Option, limit: Option, for_update: bool, ) -> Result, Error> { - self.0 - .fetch_all(EntryKind::Item, category, tag_filter, limit, for_update) - .await + Ok(self + .0 + .fetch_all( + Some(EntryKind::Item), + category, + tag_filter, + limit, + for_update, + ) + .await?) 
} /// Insert a new record into the store @@ -158,7 +201,8 @@ impl Session { tags: Option<&[EntryTag]>, expiry_ms: Option, ) -> Result<(), Error> { - self.0 + Ok(self + .0 .update( EntryKind::Item, EntryOperation::Insert, @@ -168,12 +212,13 @@ impl Session { tags, expiry_ms, ) - .await + .await?) } /// Remove a record from the store pub async fn remove(&mut self, category: &str, name: &str) -> Result<(), Error> { - self.0 + Ok(self + .0 .update( EntryKind::Item, EntryOperation::Remove, @@ -183,7 +228,7 @@ impl Session { None, None, ) - .await + .await?) } /// Replace the value and tags of a record in the store @@ -195,7 +240,8 @@ impl Session { tags: Option<&[EntryTag]>, expiry_ms: Option, ) -> Result<(), Error> { - self.0 + Ok(self + .0 .update( EntryKind::Item, EntryOperation::Replace, @@ -205,18 +251,19 @@ impl Session { tags, expiry_ms, ) - .await + .await?) } /// Remove all records in the store matching a given `category` and `tag_filter` pub async fn remove_all( &mut self, - category: &str, + category: Option<&str>, tag_filter: Option, ) -> Result { - self.0 - .remove_all(EntryKind::Item, category, tag_filter) - .await + Ok(self + .0 + .remove_all(Some(EntryKind::Item), category, tag_filter) + .await?) } /// Perform a record update @@ -232,7 +279,8 @@ impl Session { tags: Option<&[EntryTag]>, expiry_ms: Option, ) -> Result<(), Error> { - self.0 + Ok(self + .0 .update( EntryKind::Item, operation, @@ -242,7 +290,7 @@ impl Session { tags, expiry_ms, ) - .await + .await?) } /// Insert a local key instance into the store @@ -326,7 +374,7 @@ impl Session { for_update: bool, ) -> Result, Error> { let mut query_parts = Vec::with_capacity(3); - if let Some(query) = tag_filter.map(|f| f.query) { + if let Some(query) = tag_filter.map(|f| f.into_query()) { query_parts.push(TagFilter::from( query .map_names(|mut k| { @@ -350,8 +398,8 @@ impl Session { let rows = self .0 .fetch_all( - EntryKind::Kms, - KmsCategory::CryptoKey.as_str(), + Some(EntryKind::Kms), + Some(KmsCategory::CryptoKey.as_str()), tag_filter, limit, for_update, @@ -366,7 +414,8 @@ impl Session { /// Remove an existing key from the store pub async fn remove_key(&mut self, name: &str) -> Result<(), Error> { - self.0 + Ok(self + .0 .update( EntryKind::Kms, EntryOperation::Remove, @@ -376,7 +425,7 @@ impl Session { None, None, ) - .await + .await?) } /// Replace the metadata and tags on an existing key in the store @@ -425,12 +474,12 @@ impl Session { } /// Commit the pending transaction - pub async fn commit(self) -> Result<(), Error> { - self.0.close(true).await + pub async fn commit(mut self) -> Result<(), Error> { + Ok(self.0.close(true).await?) } /// Roll back the pending transaction - pub async fn rollback(self) -> Result<(), Error> { - self.0.close(false).await + pub async fn rollback(mut self) -> Result<(), Error> { + Ok(self.0.close(false).await?) 
} } diff --git a/tests/local_key.rs b/tests/local_key.rs index 90581bed..92280982 100644 --- a/tests/local_key.rs +++ b/tests/local_key.rs @@ -6,7 +6,8 @@ const ERR_CREATE_KEYPAIR: &str = "Error creating keypair"; const ERR_SIGN: &str = "Error signing message"; const ERR_VERIFY: &str = "Error verifying signature"; -pub async fn localkey_sign_verify() { +#[test] +pub fn localkey_sign_verify() { let keypair = LocalKey::generate(KeyAlg::Ed25519, true).expect(ERR_CREATE_KEYPAIR); let message = b"message".to_vec(); @@ -38,8 +39,8 @@ pub async fn localkey_sign_verify() { assert_eq!( keypair .verify_signature(&message, b"bad sig", None) - .is_err(), - true + .expect(ERR_VERIFY), + false ); assert_eq!( diff --git a/tests/store_key.rs b/tests/store_key.rs new file mode 100644 index 00000000..90ca2267 --- /dev/null +++ b/tests/store_key.rs @@ -0,0 +1,51 @@ +use aries_askar::{ + future::block_on, + kms::{KeyAlg, LocalKey}, + Store, StoreKeyMethod, +}; + +const ERR_RAW_KEY: &str = "Error creating raw store key"; +const ERR_SESSION: &str = "Error creating store session"; +const ERR_OPEN: &str = "Error opening test store instance"; +const ERR_REQ_ROW: &str = "Row required"; +const ERR_CLOSE: &str = "Error closing test store instance"; + +#[test] +fn keypair_create_fetch() { + block_on(async { + let pass_key = Store::new_raw_key(None).expect(ERR_RAW_KEY); + let db = Store::provision( + "sqlite://:memory:", + StoreKeyMethod::RawKey, + pass_key, + None, + true, + ) + .await + .expect(ERR_OPEN); + + let keypair = LocalKey::generate(KeyAlg::Ed25519, false).expect("Error creating keypair"); + + let mut conn = db.session(None).await.expect(ERR_SESSION); + + let key_name = "testkey"; + let metadata = "meta"; + conn.insert_key(key_name, &keypair, Some(metadata), None, None) + .await + .expect("Error inserting key"); + + let found = conn + .fetch_key(key_name, false) + .await + .expect("Error fetching key") + .expect(ERR_REQ_ROW); + assert_eq!(found.algorithm(), Some(KeyAlg::Ed25519.as_str())); + assert_eq!(found.name(), key_name); + assert_eq!(found.metadata(), Some(metadata)); + assert!(found.is_local()); + found.load_local_key().expect("Error loading key"); + + drop(conn); + db.close().await.expect(ERR_CLOSE); + }) +} diff --git a/wrappers/python/aries_askar/bindings/__init__.py b/wrappers/python/aries_askar/bindings/__init__.py index a0a945d9..d7d382bc 100644 --- a/wrappers/python/aries_askar/bindings/__init__.py +++ b/wrappers/python/aries_askar/bindings/__init__.py @@ -195,7 +195,7 @@ async def session_start( async def session_count( - handle: SessionHandle, category: str, tag_filter: Union[str, dict] = None + handle: SessionHandle, category: str = None, tag_filter: Union[str, dict] = None ) -> int: """Count rows in the Store.""" return int( @@ -227,7 +227,7 @@ async def session_fetch( async def session_fetch_all( handle: SessionHandle, - category: str, + category: str = None, tag_filter: Union[str, dict] = None, limit: int = None, for_update: bool = False, @@ -247,7 +247,7 @@ async def session_fetch_all( async def session_remove_all( handle: SessionHandle, - category: str, + category: str = None, tag_filter: Union[str, dict] = None, ) -> int: """Remove all matching rows in the Store.""" @@ -373,7 +373,7 @@ async def session_remove_key(handle: SessionHandle, name: str): async def scan_start( handle: StoreHandle, profile: Optional[str], - category: str, + category: str = None, tag_filter: Union[str, dict] = None, offset: int = None, limit: int = None, diff --git a/wrappers/python/aries_askar/store.py 
b/wrappers/python/aries_askar/store.py index 11e24643..7cfc93f3 100644 --- a/wrappers/python/aries_askar/store.py +++ b/wrappers/python/aries_askar/store.py @@ -59,8 +59,8 @@ def tags(self) -> dict: """Accessor for the entry tags.""" return self._list.get_tags(self._pos) - def keys(self): - """Mapping keys.""" + def keys(self) -> Sequence[str]: + """Accessor for the list of mapping keys.""" return Entry._KEYS def __getitem__(self, key): @@ -69,6 +69,10 @@ def __getitem__(self, key): return getattr(self, key) return KeyError + def __contains__(self, key) -> bool: + """Check if a key is defined.""" + return key in Entry._KEYS + def __repr__(self) -> str: """Format entry handle as a string.""" return ( @@ -95,27 +99,35 @@ def handle(self) -> EntryListHandle: return self._handle def __getitem__(self, index) -> Entry: + """Fetch an entry by index.""" if not isinstance(index, int) or index < 0 or index >= self._len: return IndexError() return Entry(self._handle, index) def __iter__(self): + """Iterate the entry list.""" return IterEntryList(self) def __len__(self) -> int: + """Accessor for the length of the list.""" return self._len def __repr__(self) -> str: + """Format entry list as a string.""" return f"" class IterEntryList: + """Iterator for the records in an entry list.""" + def __init__(self, list: EntryList): + """Create a new entry list iterator.""" self._handle = list._handle self._len = list._len self._pos = 0 def __next__(self): + """Fetch the next entry from the iterator.""" if self._pos < self._len: entry = Entry(self._handle, self._pos) self._pos += 1 @@ -185,29 +197,37 @@ def handle(self) -> KeyEntryListHandle: return self._handle def __getitem__(self, index) -> KeyEntry: + """Fetch the key entry at a specific index.""" if not isinstance(index, int) or index < 0 or index >= self._len: return IndexError() return KeyEntry(self._handle, index) def __iter__(self): + """Create an iterator over the key entry list.""" return IterKeyEntryList(self) def __len__(self) -> int: + """Accessor for the number of key entries.""" return self._len def __repr__(self) -> str: + """Format this key entry list as a string.""" return ( f"" ) class IterKeyEntryList: + """Iterator for a list of key entries.""" + def __init__(self, list: KeyEntryList): + """Create a new key entry iterator.""" self._handle = list._handle self._len = list._len self._pos = 0 def __next__(self): + """Fetch the next key entry from the iterator.""" if self._pos < self._len: entry = KeyEntry(self._handle, self._pos) self._pos += 1 @@ -223,7 +243,7 @@ def __init__( self, store: "Store", profile: Optional[str], - category: Union[str, bytes], + category: Optional[str], tag_filter: Union[str, dict] = None, offset: int = None, limit: int = None, @@ -239,9 +259,11 @@ def handle(self) -> ScanHandle: return self._handle def __aiter__(self): + """Async iterator for the scan results.""" return self async def __anext__(self): + """Fetch the next scan result during async iteration.""" if self._handle is None: (store, profile, category, tag_filter, offset, limit) = self._params self._params = None @@ -264,12 +286,14 @@ async def __anext__(self): self._buffer = iter(EntryList(list_handle)) if list_handle else None async def fetch_all(self) -> Sequence[Entry]: + """Fetch all remaining rows.""" rows = [] async for row in self: rows.append(row) return rows def __repr__(self) -> str: + """Format the scan instance as a string.""" return f"" @@ -307,6 +331,7 @@ async def provision( cls, uri: str, key_method: str = None, pass_key: str = None, *, profile: str = None, recreate: bool = False, ) -> "Store": + 
"""Provision a new store.""" return Store( await bindings.store_provision( uri, key_method, pass_key, profile, recreate @@ -323,29 +348,41 @@ async def open( *, profile: str = None, ) -> "Store": + """Open an existing store.""" return Store(await bindings.store_open(uri, key_method, pass_key, profile), uri) @classmethod async def remove(cls, uri: str) -> bool: + """Remove an existing store.""" return await bindings.store_remove(uri) async def __aenter__(self) -> "Session": + """Start a new session when used as an async context.""" if not self._opener: self._opener = OpenSession(self._handle, None, False) return await self._opener.__aenter__() async def __aexit__(self, exc_type, exc, tb): + """Async context termination.""" opener = self._opener self._opener = None return await opener.__aexit__(exc_type, exc, tb) async def create_profile(self, name: str = None) -> str: + """ + Create a new profile in the store. + + Returns the name of the profile, which is automatically + generated if not provided. + """ return await bindings.store_create_profile(self._handle, name) async def get_profile_name(self) -> str: + """Accessor for the currently defined profile name.""" return await bindings.store_get_profile_name(self._handle) async def remove_profile(self, name: str) -> bool: + """Remove a profile from the store.""" return await bindings.store_remove_profile(self._handle, name) async def rekey( @@ -353,22 +390,26 @@ async def rekey( key_method: str = None, pass_key: str = None, ): + """Update the master encryption key of the store.""" await bindings.store_rekey(self._handle, key_method, pass_key) def scan( self, - category: str, + category: str = None, tag_filter: Union[str, dict] = None, offset: int = None, limit: int = None, profile: str = None, ) -> Scan: + """Start a new record scan.""" return Scan(self, profile, category, tag_filter, offset, limit) def session(self, profile: str = None) -> "OpenSession": + """Open a new session on the store without starting a transaction.""" return OpenSession(self._handle, profile, False) def transaction(self, profile: str = None) -> "OpenSession": + """Open a new transactional session on the store.""" return OpenSession(self._handle, profile, True) async def close(self, *, remove: bool = False) -> bool: @@ -383,6 +424,7 @@ async def close(self, *, remove: bool = False) -> bool: return False def __repr__(self) -> str: + """Format the store instance as a string.""" return f"" @@ -405,7 +447,10 @@ def handle(self) -> SessionHandle: """Accessor for the SessionHandle instance.""" return self._handle - async def count(self, category: str, tag_filter: Union[str, dict] = None) -> int: + async def count( + self, category: str = None, tag_filter: Union[str, dict] = None + ) -> int: + """Count the records matching a category and tag filter.""" if not self._handle: raise AskarError(AskarErrorCode.WRAPPER, "Cannot count from closed session") return await bindings.session_count(self._handle, category, tag_filter) @@ -413,6 +458,7 @@ async def count(self, category: str, tag_filter: Union[str, dict] = None) -> int async def fetch( self, category: str, name: str, *, for_update: bool = False ) -> Optional[Entry]: + """Fetch a record from the store by category and name.""" if not self._handle: raise AskarError(AskarErrorCode.WRAPPER, "Cannot fetch from closed session") result_handle = await bindings.session_fetch( @@ -422,12 +468,13 @@ async def fetch( async def fetch_all( self, - category: str, + category: str = None, tag_filter: Union[str, dict] = None, limit: int = None, 
*, for_update: bool = False, ) -> EntryList: + """Fetch all records matching a category and tag filter.""" if not self._handle: raise AskarError(AskarErrorCode.WRAPPER, "Cannot fetch from closed session") return EntryList( @@ -445,6 +492,7 @@ async def insert( self, category: str, name: str, value: Union[str, bytes] = None, tags: dict = None, expiry_ms: int = None, value_json=None, ): + """Insert a new record into the store.""" if not self._handle: raise AskarError(AskarErrorCode.WRAPPER, "Cannot update closed session") if value is None and value_json is not None: @@ -462,6 +510,7 @@ async def replace( self, category: str, name: str, value: Union[str, bytes] = None, tags: dict = None, expiry_ms: int = None, value_json=None, ): + """Replace a record in the store matching a category and name.""" if not self._handle: raise AskarError(AskarErrorCode.WRAPPER, "Cannot update closed session") if value is None and value_json is not None: @@ -475,6 +524,7 @@ async def remove( self, category: str, name: str, ): + """Remove a record by category and name.""" if not self._handle: raise AskarError(AskarErrorCode.WRAPPER, "Cannot update closed session") await bindings.session_update( @@ -483,9 +533,10 @@ async def remove_all( self, - category: str, + category: str = None, tag_filter: Union[str, dict] = None, ) -> int: + """Remove all records matching a category and tag filter.""" if not self._handle: raise AskarError( AskarErrorCode.WRAPPER, "Cannot remove all for closed session" @@ -501,6 +552,7 @@ async def insert_key( self, name: str, key: Key, *, metadata: str = None, tags: dict = None, expiry_ms: int = None, ) -> str: + """Insert a new key into the store.""" if not self._handle: raise AskarError( AskarErrorCode.WRAPPER, "Cannot insert key with closed session" @@ -514,6 +566,7 @@ async def fetch_key( self, name: str, *, for_update: bool = False ) -> Optional[KeyEntry]: + """Fetch a key in the store by name.""" if not self._handle: raise AskarError( AskarErrorCode.WRAPPER, "Cannot fetch key from closed session" @@ -532,6 +585,7 @@ async def fetch_all_keys( self, *, alg: Union[str, KeyAlg] = None, thumbprint: str = None, tag_filter: Union[str, dict] = None, limit: int = None, for_update: bool = False, ) -> KeyEntryList: + """Fetch a set of keys in the store.""" if not self._handle: raise AskarError( AskarErrorCode.WRAPPER, "Cannot fetch key from closed session" @@ -549,6 +603,7 @@ async def update_key( self, name: str, *, metadata: str = None, tags: dict = None, expiry_ms: int = None, ): + """Update details of a key in the store.""" if not self._handle: raise AskarError( AskarErrorCode.WRAPPER, "Cannot update key with closed session" @@ -556,6 +611,7 @@ async def remove_key(self, name: str): await bindings.session_update_key(self._handle, name, metadata, tags, expiry_ms) async def remove_key(self, name: str): + """Remove a key from the store.""" if not self._handle: raise AskarError( AskarErrorCode.WRAPPER, "Cannot remove key with closed session" @@ -563,6 +619,7 @@ async def remove_key(self, name: str): await bindings.session_remove_key(self._handle, name) async def commit(self): + """Commit the current transaction and close the session.""" if not self._is_txn: raise AskarError(AskarErrorCode.WRAPPER, "Session is not a transaction") if not self._handle: @@ -571,6 +628,7 @@ self._handle = None async def rollback(self): + """Roll back the current transaction and close the session.""" if not self._is_txn: raise AskarError(AskarErrorCode.WRAPPER, "Session is not a transaction") if not self._handle: @@ -581,6 +639,7 @@ async def rollback(self): self._handle = None async def close(self): + """Close the session without specifying the commit behaviour.""" if self._handle: await self._handle.close(commit=False) self._handle = None @@ -590,6 +649,8 @@ def __repr__(self) -> str: class OpenSession: + """A pending session 
instance.""" + def __init__(self, store: StoreHandle, profile: Optional[str], is_txn: bool): """Initialize the OpenSession instance.""" self._store = store @@ -599,9 +660,11 @@ def __init__(self, store: StoreHandle, profile: Optional[str], is_txn: bool): @property def is_transaction(self) -> bool: + """Determine if this instance would begin a transaction.""" return self._is_txn async def _open(self) -> Session: + """Open this pending session.""" if not self._store: raise AskarError( AskarErrorCode.WRAPPER, "Cannot start session from closed store" @@ -615,13 +678,16 @@ async def _open(self) -> Session: ) def __await__(self) -> Session: + """Open this pending session.""" return self._open().__await__() async def __aenter__(self) -> Session: + """Use this pending session as an async context manager, opening the session.""" self._session = await self._open() return self._session async def __aexit__(self, exc_type, exc, tb): + """Terminate the async context and close the session.""" session = self._session self._session = None await session.close() diff --git a/wrappers/python/tests/test_store.py b/wrappers/python/tests/test_store.py index 36a6ea45..de3fb0c4 100644 --- a/wrappers/python/tests/test_store.py +++ b/wrappers/python/tests/test_store.py @@ -123,6 +123,18 @@ async def test_scan(store: Store): ).fetch_all() assert len(rows) == 1 and dict(rows[0]) == TEST_ENTRY + # Scan entries with non-matching category + rows = await store.scan("not the category").fetch_all() + assert len(rows) == 0 + + # Scan entries with non-matching tag filter + rows = await store.scan(TEST_ENTRY["category"], {"~plaintag": "X"}).fetch_all() + assert len(rows) == 0 + + # Scan entries with no category filter + rows = await store.scan(None, {"~plaintag": "a", "enctag": "b"}).fetch_all() + assert len(rows) == 1 and dict(rows[0]) == TEST_ENTRY + @mark.asyncio async def test_txn_basic(store: Store):