diff --git a/src/lib.rs b/src/lib.rs index cd57a96..11b1ed9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -531,9 +531,9 @@ impl LruCache { /// let b = ||->Result<&str, String> {Ok("b")}; /// assert_eq!(cache.try_get_or_insert(2, a), Ok(&"c")); /// assert_eq!(cache.try_get_or_insert(3, a), Ok(&"d")); - /// assert_eq!(cache.try_get_or_insert(1, f), Err("failed".to_owned())); - /// assert_eq!(cache.try_get_or_insert(1, b), Ok(&"b")); - /// assert_eq!(cache.try_get_or_insert(1, a), Ok(&"b")); + /// assert_eq!(cache.try_get_or_insert(4, f), Err("failed".to_owned())); + /// assert_eq!(cache.try_get_or_insert(5, b), Ok(&"b")); + /// assert_eq!(cache.try_get_or_insert(5, a), Ok(&"b")); /// ``` pub fn try_get_or_insert(&mut self, k: K, f: F) -> Result<&V, E> where @@ -547,19 +547,15 @@ impl LruCache { unsafe { Ok(&*(*node_ptr).val.as_ptr()) } } else { - match f() { - Err(e) => Err(e), - Ok(v) => { - let (_, node) = self.replace_or_create_node(k, v); - let node_ptr: *mut LruEntry = node.as_ptr(); - - self.attach(node_ptr); - - let keyref = unsafe { (*node_ptr).key.as_ptr() }; - self.map.insert(KeyRef { k: keyref }, node); - Ok(unsafe { &*(*node_ptr).val.as_ptr() }) - } - } + let v = f()?; + let (_, node) = self.replace_or_create_node(k, v); + let node_ptr: *mut LruEntry = node.as_ptr(); + + self.attach(node_ptr); + + let keyref = unsafe { (*node_ptr).key.as_ptr() }; + self.map.insert(KeyRef { k: keyref }, node); + Ok(unsafe { &*(*node_ptr).val.as_ptr() }) } } @@ -609,6 +605,58 @@ impl LruCache { } } + /// Returns a mutable reference to the value of the key in the cache if it is + /// present in the cache and moves the key to the head of the LRU list. + /// If the key does not exist the provided `FnOnce` is used to populate + /// the list and a mutable reference is returned. If `FnOnce` returns `Err`, + /// returns the `Err`. + /// + /// # Example + /// + /// ``` + /// use lru::LruCache; + /// use std::num::NonZeroUsize; + /// let mut cache = LruCache::new(NonZeroUsize::new(2).unwrap()); + /// + /// cache.put(1, "a"); + /// cache.put(2, "b"); + /// cache.put(2, "c"); + /// + /// let f = ||->Result<&str, String> {Err("failed".to_owned())}; + /// let a = ||->Result<&str, String> {Ok("a")}; + /// let b = ||->Result<&str, String> {Ok("b")}; + /// if let Ok(v) = cache.try_get_or_insert_mut(2, a) { + /// *v = "d"; + /// } + /// assert_eq!(cache.try_get_or_insert_mut(2, a), Ok(&mut "d")); + /// assert_eq!(cache.try_get_or_insert_mut(3, f), Err("failed".to_owned())); + /// assert_eq!(cache.try_get_or_insert_mut(4, b), Ok(&mut "b")); + /// assert_eq!(cache.try_get_or_insert_mut(4, a), Ok(&mut "b")); + /// ``` + pub fn try_get_or_insert_mut<'a, F, E>(&mut self, k: K, f: F) -> Result<&'a mut V, E> + where + F: FnOnce() -> Result, + { + if let Some(node) = self.map.get_mut(&KeyRef { k: &k }) { + let node_ptr: *mut LruEntry = node.as_ptr(); + + self.detach(node_ptr); + self.attach(node_ptr); + + unsafe { Ok(&mut *(*node_ptr).val.as_mut_ptr()) } + } else { + let v = f()?; + let (_, node) = self.replace_or_create_node(k, v); + let node_ptr: *mut LruEntry = node.as_ptr(); + + self.attach(node_ptr); + + let keyref = unsafe { (*node_ptr).key.as_ptr() }; + self.map.insert(KeyRef { k: keyref }, node); + unsafe { Ok(&mut *(*node_ptr).val.as_mut_ptr()) } + } + } + /// Returns a reference to the value corresponding to the key in the cache or `None` if it is /// not present in the cache. Unlike `get`, `peek` does not update the LRU list so the key's /// position will be unchanged. @@ -1494,6 +1542,26 @@ mod tests { assert_eq!(cache.get_or_insert_mut("lemon", || "red"), &"orange"); } + #[test] + fn test_try_get_or_insert_mut() { + let mut cache = LruCache::new(NonZeroUsize::new(2).unwrap()); + + cache.put(1, "a"); + cache.put(2, "b"); + cache.put(2, "c"); + + let f = || -> Result<&str, &str> { Err("failed") }; + let a = || -> Result<&str, &str> { Ok("a") }; + let b = || -> Result<&str, &str> { Ok("b") }; + if let Ok(v) = cache.try_get_or_insert_mut(2, a) { + *v = "d"; + } + assert_eq!(cache.try_get_or_insert_mut(2, a), Ok(&mut "d")); + assert_eq!(cache.try_get_or_insert_mut(3, f), Err("failed")); + assert_eq!(cache.try_get_or_insert_mut(4, b), Ok(&mut "b")); + assert_eq!(cache.try_get_or_insert_mut(4, a), Ok(&mut "b")); + } + #[test] fn test_put_and_get_mut() { let mut cache = LruCache::new(NonZeroUsize::new(2).unwrap());