From 85db85042821396f7d5151a965d1e8233ac7738f Mon Sep 17 00:00:00 2001 From: Gus Caplan Date: Tue, 5 Mar 2024 08:45:40 -0800 Subject: [PATCH 1/2] add reverse map iteration --- rustler/src/types/map.rs | 136 +++++++++++++++--- rustler/src/wrapper/map.rs | 21 ++- rustler_tests/lib/rustler_test.ex | 3 +- rustler_tests/native/rustler_test/src/lib.rs | 3 +- .../native/rustler_test/src/test_map.rs | 18 ++- rustler_tests/test/map_test.exs | 7 +- 6 files changed, 163 insertions(+), 25 deletions(-) diff --git a/rustler/src/types/map.rs b/rustler/src/types/map.rs index 77ddf27b..72b1b55b 100644 --- a/rustler/src/types/map.rs +++ b/rustler/src/types/map.rs @@ -193,40 +193,140 @@ impl<'a> Term<'a> { } } -pub struct MapIterator<'a> { - env: Env<'a>, - iter: map::ErlNifMapIterator, +struct SimpleMapIterator<'a> { + map: Term<'a>, + entry: map::MapIteratorEntry, + iter: Option, + last_key: Option>, + done: bool, } -impl<'a> MapIterator<'a> { - pub fn new(map: Term<'a>) -> Option> { - let env = map.get_env(); - unsafe { map::map_iterator_create(env.as_c_arg(), map.as_c_arg()) } - .map(|iter| MapIterator { env, iter }) +impl<'a> SimpleMapIterator<'a> { + fn next(&mut self) -> Option<(Term<'a>, Term<'a>)> { + if self.done { + return None; + } + + let iter = loop { + match self.iter.as_mut() { + None => { + match unsafe { + map::map_iterator_create( + self.map.get_env().as_c_arg(), + self.map.as_c_arg(), + self.entry, + ) + } { + Some(iter) => { + self.iter = Some(iter); + continue; + } + None => { + self.done = true; + return None; + } + } + } + Some(iter) => { + break iter; + } + } + }; + + let env = self.map.get_env(); + + unsafe { + match map::map_iterator_get_pair(env.as_c_arg(), iter) { + Some((key, value)) => { + match self.entry { + map::MapIteratorEntry::First => { + map::map_iterator_next(env.as_c_arg(), iter); + } + map::MapIteratorEntry::Last => { + map::map_iterator_prev(env.as_c_arg(), iter); + } + } + let key = Term::new(env, key); + self.last_key = Some(key); + Some((key, Term::new(env, value))) + } + None => { + self.done = true; + None + } + } + } } } -impl<'a> Drop for MapIterator<'a> { +impl<'a> Drop for SimpleMapIterator<'a> { fn drop(&mut self) { - unsafe { - map::map_iterator_destroy(self.env.as_c_arg(), &mut self.iter); + if let Some(iter) = self.iter.as_mut() { + unsafe { + map::map_iterator_destroy(self.map.get_env().as_c_arg(), iter); + } } } } -impl<'a> Iterator for MapIterator<'a> { - type Item = (Term<'a>, Term<'a>); +pub struct MapIterator<'a> { + forward: SimpleMapIterator<'a>, + reverse: SimpleMapIterator<'a>, +} - fn next(&mut self) -> Option<(Term<'a>, Term<'a>)> { - unsafe { - map::map_iterator_get_pair(self.env.as_c_arg(), &mut self.iter).map(|(key, value)| { - map::map_iterator_next(self.env.as_c_arg(), &mut self.iter); - (Term::new(self.env, key), Term::new(self.env, value)) +impl<'a> MapIterator<'a> { + pub fn new(map: Term<'a>) -> Option> { + if map.is_map() { + Some(MapIterator { + forward: SimpleMapIterator { + map, + entry: map::MapIteratorEntry::First, + iter: None, + last_key: None, + done: false, + }, + reverse: SimpleMapIterator { + map, + entry: map::MapIteratorEntry::Last, + iter: None, + last_key: None, + done: false, + }, }) + } else { + None } } } +impl<'a> Iterator for MapIterator<'a> { + type Item = (Term<'a>, Term<'a>); + + fn next(&mut self) -> Option { + self.forward.next().and_then(|(key, value)| { + if self.reverse.last_key == Some(key) { + self.forward.done = true; + self.reverse.done = true; + return None; + } + Some((key, value)) + }) + } +} + +impl<'a> DoubleEndedIterator for MapIterator<'a> { + fn next_back(&mut self) -> Option { + self.reverse.next().and_then(|(key, value)| { + if self.forward.last_key == Some(key) { + self.forward.done = true; + self.reverse.done = true; + return None; + } + Some((key, value)) + }) + } +} + impl<'a> Decoder<'a> for MapIterator<'a> { fn decode(term: Term<'a>) -> NifResult { match MapIterator::new(term) { diff --git a/rustler/src/wrapper/map.rs b/rustler/src/wrapper/map.rs index 0009b828..9a0a81e9 100644 --- a/rustler/src/wrapper/map.rs +++ b/rustler/src/wrapper/map.rs @@ -66,13 +66,26 @@ pub unsafe fn map_update( Some(result.assume_init()) } -pub unsafe fn map_iterator_create(env: NIF_ENV, map: NIF_TERM) -> Option { +#[derive(Clone, Copy, Debug)] +pub enum MapIteratorEntry { + First, + Last, +} + +pub unsafe fn map_iterator_create( + env: NIF_ENV, + map: NIF_TERM, + entry: MapIteratorEntry, +) -> Option { let mut iter = MaybeUninit::uninit(); let success = rustler_sys::enif_map_iterator_create( env, map, iter.as_mut_ptr(), - ErlNifMapIteratorEntry::ERL_NIF_MAP_ITERATOR_HEAD, + match entry { + MapIteratorEntry::First => ErlNifMapIteratorEntry::ERL_NIF_MAP_ITERATOR_HEAD, + MapIteratorEntry::Last => ErlNifMapIteratorEntry::ERL_NIF_MAP_ITERATOR_TAIL, + }, ); if success == 0 { None @@ -103,6 +116,10 @@ pub unsafe fn map_iterator_next(env: NIF_ENV, iter: &mut ErlNifMapIterator) { rustler_sys::enif_map_iterator_next(env, iter); } +pub unsafe fn map_iterator_prev(env: NIF_ENV, iter: &mut ErlNifMapIterator) { + rustler_sys::enif_map_iterator_prev(env, iter); +} + pub unsafe fn make_map_from_arrays( env: NIF_ENV, keys: &[NIF_TERM], diff --git a/rustler_tests/lib/rustler_test.ex b/rustler_tests/lib/rustler_test.ex index ce424d40..ce63a45d 100644 --- a/rustler_tests/lib/rustler_test.ex +++ b/rustler_tests/lib/rustler_test.ex @@ -50,7 +50,8 @@ defmodule RustlerTest do def term_type(_term), do: err() def sum_map_values(_), do: err() - def map_entries_sorted(_), do: err() + def map_entries(_), do: err() + def map_entries_reversed(_), do: err() def map_from_arrays(_keys, _values), do: err() def map_from_pairs(_pairs), do: err() def map_generic(_), do: err() diff --git a/rustler_tests/native/rustler_test/src/lib.rs b/rustler_tests/native/rustler_test/src/lib.rs index efe2372e..a7a4f700 100644 --- a/rustler_tests/native/rustler_test/src/lib.rs +++ b/rustler_tests/native/rustler_test/src/lib.rs @@ -32,7 +32,8 @@ rustler::init!( test_term::term_phash2_hash, test_term::term_type, test_map::sum_map_values, - test_map::map_entries_sorted, + test_map::map_entries, + test_map::map_entries_reversed, test_map::map_from_arrays, test_map::map_from_pairs, test_map::map_generic, diff --git a/rustler_tests/native/rustler_test/src/test_map.rs b/rustler_tests/native/rustler_test/src/test_map.rs index e33293b9..d4e1d5ed 100644 --- a/rustler_tests/native/rustler_test/src/test_map.rs +++ b/rustler_tests/native/rustler_test/src/test_map.rs @@ -11,14 +11,28 @@ pub fn sum_map_values(iter: MapIterator) -> NifResult { } #[rustler::nif] -pub fn map_entries_sorted<'a>(env: Env<'a>, iter: MapIterator<'a>) -> NifResult>> { +pub fn map_entries<'a>(env: Env<'a>, iter: MapIterator<'a>) -> NifResult>> { let mut vec = vec![]; for (key, value) in iter { let key_string = key.decode::()?; vec.push((key_string, value)); } - vec.sort_by_key(|pair| pair.0.clone()); + let erlang_pairs: Vec = vec + .into_iter() + .map(|(key, value)| make_tuple(env, &[key.encode(env), value])) + .collect(); + Ok(erlang_pairs) +} + +#[rustler::nif] +pub fn map_entries_reversed<'a>(env: Env<'a>, iter: MapIterator<'a>) -> NifResult>> { + let mut vec = vec![]; + for (key, value) in iter.rev() { + let key_string = key.decode::()?; + vec.push((key_string, value)); + } + let erlang_pairs: Vec = vec .into_iter() .map(|(key, value)| make_tuple(env, &[key.encode(env), value])) diff --git a/rustler_tests/test/map_test.exs b/rustler_tests/test/map_test.exs index 9a80039d..2473b7cf 100644 --- a/rustler_tests/test/map_test.exs +++ b/rustler_tests/test/map_test.exs @@ -7,8 +7,13 @@ defmodule RustlerTest.MapTest do end test "map iteration with keys" do + entries = RustlerTest.map_entries(%{"d" => 0, "a" => 1, "b" => 7, "e" => 4, "c" => 6}) + assert [{"a", 1}, {"b", 7}, {"c", 6}, {"d", 0}, {"e", 4}] == - RustlerTest.map_entries_sorted(%{"d" => 0, "a" => 1, "b" => 7, "e" => 4, "c" => 6}) + Enum.sort_by(entries, &elem(&1, 0)) + + assert Enum.reverse(entries) == + RustlerTest.map_entries_reversed(%{"d" => 0, "a" => 1, "b" => 7, "e" => 4, "c" => 6}) end test "map from arrays" do From fa5e63dba4faf21f790a5a73b07b5217c19f782e Mon Sep 17 00:00:00 2001 From: Benedikt Reinartz Date: Sat, 9 Mar 2024 09:20:33 +0100 Subject: [PATCH 2/2] Update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index cdb89d49..53edfa14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ versions. ### Added +- Map iterators are now [DoubleEndedIterators](https://doc.rust-lang.org/std/iter/trait.DoubleEndedIterator.html) + (#598), thus allowing being iterated in reverse using `.rev()` + ### Fixed ### Changed