diff --git a/x/programs/rust/examples/counter/src/lib.rs b/x/programs/rust/examples/counter/src/lib.rs index 1ee7745922..fa366541cd 100644 --- a/x/programs/rust/examples/counter/src/lib.rs +++ b/x/programs/rust/examples/counter/src/lib.rs @@ -13,7 +13,7 @@ pub fn initialize_address(context: Context, address: Address) -> bool { if program .state() - .get::(StateKeys::Counter(address)) + .get::(StateKeys::Counter(address)) .is_ok() { panic!("counter already initialized for address") diff --git a/x/programs/rust/examples/token/src/lib.rs b/x/programs/rust/examples/token/src/lib.rs index 23251cabad..e523308554 100644 --- a/x/programs/rust/examples/token/src/lib.rs +++ b/x/programs/rust/examples/token/src/lib.rs @@ -59,7 +59,7 @@ pub fn mint_to(context: Context, recipient: Address, amount: i64) -> bool { let Context { program } = context; let balance = program .state() - .get::(StateKey::Balance(recipient)) + .get::(StateKey::Balance(recipient)) .unwrap_or_default(); program @@ -90,14 +90,14 @@ pub fn transfer(context: Context, sender: Address, recipient: Address, amount: i // ensure the sender has adequate balance let sender_balance = program .state() - .get::(StateKey::Balance(sender)) + .get::(StateKey::Balance(sender)) .expect("failed to update balance"); assert!(amount >= 0 && sender_balance >= amount, "invalid input"); let recipient_balance = program .state() - .get::(StateKey::Balance(recipient)) + .get::(StateKey::Balance(recipient)) .unwrap_or_default(); // update balances diff --git a/x/programs/rust/sdk_macros/src/lib.rs b/x/programs/rust/sdk_macros/src/lib.rs index ff8ac84299..c6c90441a3 100644 --- a/x/programs/rust/sdk_macros/src/lib.rs +++ b/x/programs/rust/sdk_macros/src/lib.rs @@ -155,7 +155,7 @@ pub fn state_keys(_attr: TokenStream, item: TokenStream) -> TokenStream { let mut item_enum = parse_macro_input!(item as ItemEnum); // add default attributes item_enum.attrs.push(syn::parse_quote! { - #[derive(Clone, Copy, Debug)] + #[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] }); item_enum.attrs.push(syn::parse_quote! { #[repr(u8)] diff --git a/x/programs/rust/wasmlanche-sdk/src/program.rs b/x/programs/rust/wasmlanche-sdk/src/program.rs index c4598dddbe..8462f9ae4d 100644 --- a/x/programs/rust/wasmlanche-sdk/src/program.rs +++ b/x/programs/rust/wasmlanche-sdk/src/program.rs @@ -1,5 +1,8 @@ +use std::hash::Hash; + use borsh::{BorshDeserialize, BorshSerialize}; +use crate::state::Key; use crate::{memory::to_host_ptr, state::Error as StateError, state::State, Params}; /// Represents the current Program in the context of the caller. Or an external @@ -25,7 +28,10 @@ impl Program { /// Returns a State object that can be used to interact with persistent /// storage exposed by the host. #[must_use] - pub fn state(&self) -> State { + pub fn state(&self) -> State + where + K: Into + Hash + PartialEq + Eq + Clone, + { State::new(Program::new(*self.id())) } diff --git a/x/programs/rust/wasmlanche-sdk/src/state.rs b/x/programs/rust/wasmlanche-sdk/src/state.rs index eab090ec9e..0d6332c77b 100644 --- a/x/programs/rust/wasmlanche-sdk/src/state.rs +++ b/x/programs/rust/wasmlanche-sdk/src/state.rs @@ -1,6 +1,6 @@ -use crate::{memory::from_host_ptr, program::Program}; -use borsh::{BorshDeserialize, BorshSerialize}; -use std::ops::Deref; +use crate::{from_host_ptr, program::Program, state::Error as StateError}; +use borsh::{from_slice, to_vec, BorshDeserialize, BorshSerialize}; +use std::{collections::HashMap, hash::Hash, ops::Deref}; #[derive(Clone, thiserror::Error, Debug)] pub enum Error { @@ -38,14 +38,36 @@ pub enum Error { Delete, } -pub struct State { +pub struct State +where + K: Into + Hash + PartialEq + Eq + Clone, +{ program: Program, + cache: HashMap>, } -impl State { +impl Drop for State +where + K: Into + Hash + PartialEq + Eq + Clone, +{ + fn drop(&mut self) { + if !self.cache.is_empty() { + // force flush + self.flush().unwrap(); + } + } +} + +impl State +where + K: Into + Hash + PartialEq + Eq + Clone, +{ #[must_use] pub fn new(program: Program) -> Self { - Self { program } + Self { + program, + cache: HashMap::new(), + } } /// Store a key and value to the host storage. If the key already exists, @@ -53,12 +75,14 @@ impl State { /// # Errors /// Returns an [Error] if the key or value cannot be /// serialized or if the host fails to handle the operation. - pub fn store(&self, key: K, value: &V) -> Result<(), Error> + pub fn store(&mut self, key: K, value: &V) -> Result<(), Error> where V: BorshSerialize, - K: Into, { - unsafe { host::put_bytes(&self.program, &key.into(), value) } + let serialized = to_vec(&value).map_err(|_| StateError::Deserialization)?; + self.cache.insert(key, serialized); + + Ok(()) } /// Get a value from the host's storage. @@ -71,30 +95,46 @@ impl State { /// the host fails to read the key and value. /// # Panics /// Panics if the value cannot be converted from i32 to usize. - pub fn get(&self, key: K) -> Result + pub fn get(&mut self, key: K) -> Result where - K: Into, - T: BorshDeserialize, + V: BorshDeserialize, { - let val_ptr = unsafe { host::get_bytes(&self.program, &key.into())? }; - if val_ptr < 0 { - return Err(Error::Read); - } - - // Wrap in OK for now, change from_raw_ptr to return Result - from_host_ptr(val_ptr) + let val_bytes = if let Some(val) = self.cache.get(&key) { + val + } else { + let val_ptr = unsafe { host::get_bytes(&self.program, &key.clone().into())? }; + if val_ptr < 0 { + return Err(Error::Read); + } + + // TODO Wrap in OK for now, change from_raw_ptr to return Result + let bytes = from_host_ptr(val_ptr)?; + self.cache.entry(key).or_insert(bytes) + }; + + from_slice::(val_bytes).map_err(|_| StateError::Deserialization) } /// Delete a value from the hosts's storage. /// # Errors /// Returns an [Error] if the key cannot be serialized /// or if the host fails to delete the key and the associated value - pub fn delete(&self, key: K) -> Result<(), Error> - where - K: Into, - { + pub fn delete(&mut self, key: K) -> Result<(), Error> { + self.cache.remove(&key); + unsafe { host::delete_bytes(&self.program, &key.into()) } } + + /// Apply all pending operations to storage and mark the cache as flushed + fn flush(&mut self) -> Result<(), Error> { + for (key, value) in self.cache.drain() { + unsafe { + host::put_bytes(&self.program, &key.into(), &value)?; + } + } + + Ok(()) + } } /// Key is a wrapper around a `Vec` that represents a key in the host storage. diff --git a/x/programs/rust/wasmlanche-sdk/src/types.rs b/x/programs/rust/wasmlanche-sdk/src/types.rs index 0c69f8cd51..9edc9539d6 100644 --- a/x/programs/rust/wasmlanche-sdk/src/types.rs +++ b/x/programs/rust/wasmlanche-sdk/src/types.rs @@ -2,7 +2,7 @@ use borsh::{BorshDeserialize, BorshSerialize}; /// A struct that enforces a fixed length of 32 bytes which represents an address. -#[derive(Clone, Copy, PartialEq, Eq, Debug, BorshSerialize, BorshDeserialize)] +#[derive(Clone, Copy, PartialEq, Eq, Debug, BorshSerialize, BorshDeserialize, Hash)] pub struct Address([u8; Self::LEN]); impl Address {