Skip to content

Commit

Permalink
Merge pull request #178 from Earthcomputer/try_get_or_insert_mut
Browse files Browse the repository at this point in the history
Add try_get_or_insert_mut
  • Loading branch information
jeromefroe authored Sep 6, 2023
2 parents 8f605c5 + 10311e0 commit b24f53f
Showing 1 changed file with 84 additions and 16 deletions.
100 changes: 84 additions & 16 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -531,9 +531,9 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
/// let b = ||->Result<&str, String> {Ok("b")};
/// assert_eq!(cache.try_get_or_insert(2, a), Ok(&"c"));
/// assert_eq!(cache.try_get_or_insert(3, a), Ok(&"d"));
/// assert_eq!(cache.try_get_or_insert(1, f), Err("failed".to_owned()));
/// assert_eq!(cache.try_get_or_insert(1, b), Ok(&"b"));
/// assert_eq!(cache.try_get_or_insert(1, a), Ok(&"b"));
/// assert_eq!(cache.try_get_or_insert(4, f), Err("failed".to_owned()));
/// assert_eq!(cache.try_get_or_insert(5, b), Ok(&"b"));
/// assert_eq!(cache.try_get_or_insert(5, a), Ok(&"b"));
/// ```
pub fn try_get_or_insert<F, E>(&mut self, k: K, f: F) -> Result<&V, E>
where
Expand All @@ -547,19 +547,15 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {

unsafe { Ok(&*(*node_ptr).val.as_ptr()) }
} else {
match f() {
Err(e) => Err(e),
Ok(v) => {
let (_, node) = self.replace_or_create_node(k, v);
let node_ptr: *mut LruEntry<K, V> = node.as_ptr();

self.attach(node_ptr);

let keyref = unsafe { (*node_ptr).key.as_ptr() };
self.map.insert(KeyRef { k: keyref }, node);
Ok(unsafe { &*(*node_ptr).val.as_ptr() })
}
}
let v = f()?;
let (_, node) = self.replace_or_create_node(k, v);
let node_ptr: *mut LruEntry<K, V> = node.as_ptr();

self.attach(node_ptr);

let keyref = unsafe { (*node_ptr).key.as_ptr() };
self.map.insert(KeyRef { k: keyref }, node);
Ok(unsafe { &*(*node_ptr).val.as_ptr() })
}
}

Expand Down Expand Up @@ -609,6 +605,58 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
}
}

/// Returns a mutable reference to the value of the key in the cache if it is
/// present in the cache and moves the key to the head of the LRU list.
/// If the key does not exist the provided `FnOnce` is used to populate
/// the list and a mutable reference is returned. If `FnOnce` returns `Err`,
/// returns the `Err`.
///
/// # Example
///
/// ```
/// use lru::LruCache;
/// use std::num::NonZeroUsize;
/// let mut cache = LruCache::new(NonZeroUsize::new(2).unwrap());
///
/// cache.put(1, "a");
/// cache.put(2, "b");
/// cache.put(2, "c");
///
/// let f = ||->Result<&str, String> {Err("failed".to_owned())};
/// let a = ||->Result<&str, String> {Ok("a")};
/// let b = ||->Result<&str, String> {Ok("b")};
/// if let Ok(v) = cache.try_get_or_insert_mut(2, a) {
/// *v = "d";
/// }
/// assert_eq!(cache.try_get_or_insert_mut(2, a), Ok(&mut "d"));
/// assert_eq!(cache.try_get_or_insert_mut(3, f), Err("failed".to_owned()));
/// assert_eq!(cache.try_get_or_insert_mut(4, b), Ok(&mut "b"));
/// assert_eq!(cache.try_get_or_insert_mut(4, a), Ok(&mut "b"));
/// ```
pub fn try_get_or_insert_mut<'a, F, E>(&mut self, k: K, f: F) -> Result<&'a mut V, E>
where
F: FnOnce() -> Result<V, E>,
{
if let Some(node) = self.map.get_mut(&KeyRef { k: &k }) {
let node_ptr: *mut LruEntry<K, V> = node.as_ptr();

self.detach(node_ptr);
self.attach(node_ptr);

unsafe { Ok(&mut *(*node_ptr).val.as_mut_ptr()) }
} else {
let v = f()?;
let (_, node) = self.replace_or_create_node(k, v);
let node_ptr: *mut LruEntry<K, V> = node.as_ptr();

self.attach(node_ptr);

let keyref = unsafe { (*node_ptr).key.as_ptr() };
self.map.insert(KeyRef { k: keyref }, node);
unsafe { Ok(&mut *(*node_ptr).val.as_mut_ptr()) }
}
}

/// Returns a reference to the value corresponding to the key in the cache or `None` if it is
/// not present in the cache. Unlike `get`, `peek` does not update the LRU list so the key's
/// position will be unchanged.
Expand Down Expand Up @@ -1494,6 +1542,26 @@ mod tests {
assert_eq!(cache.get_or_insert_mut("lemon", || "red"), &"orange");
}

#[test]
fn test_try_get_or_insert_mut() {
let mut cache = LruCache::new(NonZeroUsize::new(2).unwrap());

cache.put(1, "a");
cache.put(2, "b");
cache.put(2, "c");

let f = || -> Result<&str, &str> { Err("failed") };
let a = || -> Result<&str, &str> { Ok("a") };
let b = || -> Result<&str, &str> { Ok("b") };
if let Ok(v) = cache.try_get_or_insert_mut(2, a) {
*v = "d";
}
assert_eq!(cache.try_get_or_insert_mut(2, a), Ok(&mut "d"));
assert_eq!(cache.try_get_or_insert_mut(3, f), Err("failed"));
assert_eq!(cache.try_get_or_insert_mut(4, b), Ok(&mut "b"));
assert_eq!(cache.try_get_or_insert_mut(4, a), Ok(&mut "b"));
}

#[test]
fn test_put_and_get_mut() {
let mut cache = LruCache::new(NonZeroUsize::new(2).unwrap());
Expand Down

0 comments on commit b24f53f

Please sign in to comment.