Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Database::get_or_put: insert if value does not exist, otherwise return previous value #252

Merged
merged 7 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
309 changes: 306 additions & 3 deletions heed/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1787,7 +1787,8 @@ impl<KC, DC, C> Database<KC, DC, C> {
RwCursor::new(txn, self.dbi).map(|cursor| RwRevPrefix::new(cursor, prefix_bytes))
}

/// Insert a key-value pair in this database. The entry is written with no specific flag.
/// Insert a key-value pair in this database, replacing any previous value. The entry is
/// written with no specific flag.
///
/// ```
/// # use std::fs;
Expand Down Expand Up @@ -1842,7 +1843,8 @@ impl<KC, DC, C> Database<KC, DC, C> {
Ok(())
}

/// Insert a key-value pair where the value can directly be written to disk.
/// Insert a key-value pair where the value can directly be written to disk, replacing any
/// previous value.
///
/// ```
/// # use std::fs;
Expand Down Expand Up @@ -1908,7 +1910,8 @@ impl<KC, DC, C> Database<KC, DC, C> {
}
}

/// Insert a key-value pair in this database. The entry is written with the specified flags.
/// Insert a key-value pair in this database, replacing any previous value. The entry is
/// written with the specified flags.
///
/// ```
/// # use std::fs;
Expand Down Expand Up @@ -1993,6 +1996,281 @@ impl<KC, DC, C> Database<KC, DC, C> {
Ok(())
}

/// Attempts to insert a key-value pair in this database. If a value already exists for the
/// key, the stored entry is left untouched and the existing value is returned instead.
///
/// The entry is always written with the [`NO_OVERWRITE`](PutFlags::NO_OVERWRITE) flag.
///
/// ```
/// # use heed::EnvOpenOptions;
/// use heed::Database;
/// use heed::types::*;
/// use heed::byteorder::BigEndian;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # let dir = tempfile::tempdir()?;
/// # let env = unsafe { EnvOpenOptions::new()
/// # .map_size(10 * 1024 * 1024) // 10MB
/// # .max_dbs(3000)
/// # .open(dir.path())?
/// # };
/// type BEI32 = I32<BigEndian>;
///
/// let mut wtxn = env.write_txn()?;
/// let db: Database<BEI32, Str> = env.create_database(&mut wtxn, Some("iter-i32"))?;
///
/// # db.clear(&mut wtxn)?;
/// assert_eq!(db.get_or_put(&mut wtxn, &42, "i-am-forty-two")?, None);
/// assert_eq!(db.get_or_put(&mut wtxn, &42, "the meaning of life")?, Some("i-am-forty-two"));
///
/// let ret = db.get(&mut wtxn, &42)?;
/// assert_eq!(ret, Some("i-am-forty-two"));
///
/// wtxn.commit()?;
/// # Ok(()) }
/// ```
pub fn get_or_put<'a, 'txn>(
    &'txn self,
    txn: &mut RwTxn,
    key: &'a KC::EItem,
    data: &'a DC::EItem,
) -> Result<Option<DC::DItem>>
where
    KC: BytesEncode<'a>,
    DC: BytesEncode<'a> + BytesDecode<'a>,
{
    // Convenience entry point: the caller supplies no flags of its own, so the
    // flag-taking variant is invoked with an empty flag set.
    let no_extra_flags = PutFlags::empty();
    self.get_or_put_with_flags(txn, no_extra_flags, key, data)
}

/// Attempts to insert a key-value pair in this database. If a value already exists for the
/// key, the stored entry is left untouched and the existing value is returned instead.
///
/// The entry is written with the specified flags, in addition to
/// [`NO_OVERWRITE`](PutFlags::NO_OVERWRITE) which is always used.
///
/// ```
/// # use heed::EnvOpenOptions;
/// use heed::{Database, PutFlags};
/// use heed::types::*;
/// use heed::byteorder::BigEndian;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # let dir = tempfile::tempdir()?;
/// # let env = unsafe { EnvOpenOptions::new()
/// # .map_size(10 * 1024 * 1024) // 10MB
/// # .max_dbs(3000)
/// # .open(dir.path())?
/// # };
/// type BEI32 = I32<BigEndian>;
///
/// let mut wtxn = env.write_txn()?;
/// let db: Database<BEI32, Str> = env.create_database(&mut wtxn, Some("iter-i32"))?;
///
/// # db.clear(&mut wtxn)?;
/// assert_eq!(db.get_or_put_with_flags(&mut wtxn, PutFlags::empty(), &42, "i-am-forty-two")?, None);
/// assert_eq!(db.get_or_put_with_flags(&mut wtxn, PutFlags::empty(), &42, "the meaning of life")?, Some("i-am-forty-two"));
///
/// let ret = db.get(&mut wtxn, &42)?;
/// assert_eq!(ret, Some("i-am-forty-two"));
///
/// wtxn.commit()?;
/// # Ok(()) }
/// ```
pub fn get_or_put_with_flags<'a, 'txn>(
&'txn self,
txn: &mut RwTxn,
flags: PutFlags,
key: &'a KC::EItem,
data: &'a DC::EItem,
) -> Result<Option<DC::DItem>>
where
KC: BytesEncode<'a>,
DC: BytesEncode<'a> + BytesDecode<'a>,
{
assert_eq_env_db_txn!(self, txn);

let key_bytes: Cow<[u8]> = KC::bytes_encode(key).map_err(Error::Encoding)?;
let data_bytes: Cow<[u8]> = DC::bytes_encode(data).map_err(Error::Encoding)?;

let mut key_val = unsafe { crate::into_val(&key_bytes) };
let mut data_val = unsafe { crate::into_val(&data_bytes) };
let flags = (flags | PutFlags::NO_OVERWRITE).bits();

let result = unsafe {
mdb_result(ffi::mdb_put(txn.txn.txn, self.dbi, &mut key_val, &mut data_val, flags))
};

match result {
// the value was successfully inserted
Ok(()) => Ok(None),
// the key already exists: the previous value is stored in the data parameter
Err(MdbError::KeyExist) => {
let bytes = unsafe { crate::from_val(data_val) };
let data = DC::bytes_decode(bytes).map_err(Error::Decoding)?;
Ok(Some(data))
}
Err(error) => Err(error.into()),
}
}

/// Attempts to insert a key-value pair in this database, where the value can be directly
/// written to disk. If a value already exists for the key, the stored entry is left
/// untouched and the existing value is returned instead.
///
/// The entry is always written with the [`NO_OVERWRITE`](PutFlags::NO_OVERWRITE) and
/// [`MDB_RESERVE`](lmdb_master_sys::MDB_RESERVE) flags.
///
/// ```
/// # use heed::EnvOpenOptions;
/// use std::io::Write;
/// use heed::{Database, PutFlags};
/// use heed::types::*;
/// use heed::byteorder::BigEndian;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # let dir = tempfile::tempdir()?;
/// # let env = unsafe { EnvOpenOptions::new()
/// # .map_size(10 * 1024 * 1024) // 10MB
/// # .max_dbs(3000)
/// # .open(dir.path())?
/// # };
/// type BEI32 = I32<BigEndian>;
///
/// let mut wtxn = env.write_txn()?;
/// let db = env.create_database::<BEI32, Str>(&mut wtxn, Some("number-string"))?;
///
/// # db.clear(&mut wtxn)?;
/// let long = "I am a long long long value";
/// assert_eq!(
/// db.get_or_put_reserved(&mut wtxn, &42, long.len(), |reserved| {
/// reserved.write_all(long.as_bytes())
/// })?,
/// None
/// );
///
/// let longer = "I am an even longer long long long value";
/// assert_eq!(
/// db.get_or_put_reserved(&mut wtxn, &42, longer.len(), |reserved| {
/// unreachable!()
/// })?,
/// Some(long)
/// );
///
/// let ret = db.get(&mut wtxn, &42)?;
/// assert_eq!(ret, Some(long));
///
/// wtxn.commit()?;
/// # Ok(()) }
/// ```
pub fn get_or_put_reserved<'a, 'txn, F>(
    &'txn self,
    txn: &mut RwTxn,
    key: &'a KC::EItem,
    data_size: usize,
    write_func: F,
) -> Result<Option<DC::DItem>>
where
    KC: BytesEncode<'a>,
    F: FnOnce(&mut ReservedSpace) -> io::Result<()>,
    DC: BytesDecode<'a>,
{
    // Convenience entry point: the caller supplies no flags of its own, so the
    // flag-taking variant is invoked with an empty flag set.
    let no_extra_flags = PutFlags::empty();
    self.get_or_put_reserved_with_flags(txn, no_extra_flags, key, data_size, write_func)
}

/// Attempts to insert a key-value pair in this database, where the value can be directly
/// written to disk. If a value already exists for the key, the stored entry is left
/// untouched and the existing value is returned instead.
///
/// The entry is written with the specified flags, in addition to
/// [`NO_OVERWRITE`](PutFlags::NO_OVERWRITE) and [`MDB_RESERVE`](lmdb_master_sys::MDB_RESERVE)
/// which are always used.
///
/// ```
/// # use heed::EnvOpenOptions;
/// use std::io::Write;
/// use heed::{Database, PutFlags};
/// use heed::types::*;
/// use heed::byteorder::BigEndian;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # let dir = tempfile::tempdir()?;
/// # let env = unsafe { EnvOpenOptions::new()
/// # .map_size(10 * 1024 * 1024) // 10MB
/// # .max_dbs(3000)
/// # .open(dir.path())?
/// # };
/// type BEI32 = I32<BigEndian>;
///
/// let mut wtxn = env.write_txn()?;
/// let db = env.create_database::<BEI32, Str>(&mut wtxn, Some("number-string"))?;
///
/// # db.clear(&mut wtxn)?;
/// let long = "I am a long long long value";
/// assert_eq!(
/// db.get_or_put_reserved_with_flags(&mut wtxn, PutFlags::empty(), &42, long.len(), |reserved| {
/// reserved.write_all(long.as_bytes())
/// })?,
/// None
/// );
///
/// let longer = "I am an even longer long long long value";
/// assert_eq!(
/// db.get_or_put_reserved_with_flags(&mut wtxn, PutFlags::empty(), &42, longer.len(), |reserved| {
/// unreachable!()
/// })?,
/// Some(long)
/// );
///
/// let ret = db.get(&mut wtxn, &42)?;
/// assert_eq!(ret, Some(long));
///
/// wtxn.commit()?;
/// # Ok(()) }
/// ```
pub fn get_or_put_reserved_with_flags<'a, 'txn, F>(
&'txn self,
txn: &mut RwTxn,
flags: PutFlags,
key: &'a KC::EItem,
data_size: usize,
write_func: F,
) -> Result<Option<DC::DItem>>
where
KC: BytesEncode<'a>,
F: FnOnce(&mut ReservedSpace) -> io::Result<()>,
DC: BytesDecode<'a>,
{
assert_eq_env_db_txn!(self, txn);

let key_bytes: Cow<[u8]> = KC::bytes_encode(key).map_err(Error::Encoding)?;

let mut key_val = unsafe { crate::into_val(&key_bytes) };
let mut reserved = ffi::reserve_size_val(data_size);
let flags = (flags | PutFlags::NO_OVERWRITE).bits() | lmdb_master_sys::MDB_RESERVE;

let result = unsafe {
mdb_result(ffi::mdb_put(txn.txn.txn, self.dbi, &mut key_val, &mut reserved, flags))
};

match result {
// value was inserted: fill the reserved space
Ok(()) => {
let mut reserved = unsafe { ReservedSpace::from_val(reserved) };
write_func(&mut reserved)?;
if reserved.remaining() == 0 {
Ok(None)
} else {
Err(io::Error::from(io::ErrorKind::UnexpectedEof).into())
}
}
// the key already exists: the previous value is stored in the data parameter
Err(MdbError::KeyExist) => {
let bytes = unsafe { crate::from_val(reserved) };
let data = DC::bytes_decode(bytes).map_err(Error::Decoding)?;
Ok(Some(data))
}
Err(error) => Err(error.into()),
}
}

/// Deletes an entry, or all duplicate data items of a key
/// if the database supports duplicate data items.
///
Expand Down Expand Up @@ -2355,3 +2633,28 @@ pub struct DatabaseStat {
/// Number of data items.
pub entries: usize,
}

#[cfg(test)]
mod tests {
    use heed_types::*;

    use super::*;

    // Checks that `put` on an existing key replaces the stored value: each
    // successive write must be what a following `get` observes.
    #[test]
    fn put_overwrite() -> Result<()> {
        let dir = tempfile::tempdir()?;
        let env = unsafe { EnvOpenOptions::new().open(dir.path())? };
        let mut wtxn = env.write_txn()?;
        let db = env.create_database::<Bytes, Bytes>(&mut wtxn, None)?;

        let key: &[u8] = b"hello";

        // The freshly created database holds nothing under the key yet.
        assert_eq!(db.get(&wtxn, key).unwrap(), None);

        // First write inserts, second write overwrites.
        for value in [&b"hi"[..], &b"bye"[..]] {
            db.put(&mut wtxn, key, value).unwrap();
            assert_eq!(db.get(&wtxn, key).unwrap(), Some(value));
        }

        Ok(())
    }
}
6 changes: 3 additions & 3 deletions heed/src/iterator/prefix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::*;
/// defined by the `C` comparator. If no successor exists (i.e. `bytes` is the maximal
/// value), it remains unchanged and the function returns `false`. Otherwise, updates
/// `bytes` and returns `true`.
fn advance_prefix<C: LexicographicComparator>(bytes: &mut Vec<u8>) -> bool {
fn advance_prefix<C: LexicographicComparator>(bytes: &mut [u8]) -> bool {
let mut idx = bytes.len();
while idx > 0 && bytes[idx - 1] == C::max_elem() {
idx -= 1;
Expand All @@ -32,7 +32,7 @@ fn advance_prefix<C: LexicographicComparator>(bytes: &mut Vec<u8>) -> bool {
/// defined by the `C` comparator. If no predecessor exists (i.e. `bytes` is the minimum
/// value), it remains unchanged and the function returns `false`. Otherwise, updates
/// `bytes` and returns `true`.
fn retreat_prefix<C: LexicographicComparator>(bytes: &mut Vec<u8>) -> bool {
fn retreat_prefix<C: LexicographicComparator>(bytes: &mut [u8]) -> bool {
let mut idx = bytes.len();
while idx > 0 && bytes[idx - 1] == C::min_elem() {
idx -= 1;
Expand All @@ -49,7 +49,7 @@ fn retreat_prefix<C: LexicographicComparator>(bytes: &mut Vec<u8>) -> bool {

fn move_on_prefix_end<'txn, C: LexicographicComparator>(
cursor: &mut RoCursor<'txn>,
prefix: &mut Vec<u8>,
prefix: &mut [u8],
) -> Result<Option<(&'txn [u8], &'txn [u8])>> {
if advance_prefix::<C>(prefix) {
let result = cursor
Expand Down
Loading