Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reverse map iteration #596

Merged
merged 2 commits into from
Mar 9, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 118 additions & 18 deletions rustler/src/types/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,40 +193,140 @@ impl<'a> Term<'a> {
}
}

pub struct MapIterator<'a> {
env: Env<'a>,
iter: map::ErlNifMapIterator,
struct SimpleMapIterator<'a> {
map: Term<'a>,
entry: map::MapIteratorEntry,
iter: Option<map::ErlNifMapIterator>,
last_key: Option<Term<'a>>,
done: bool,
}

impl<'a> MapIterator<'a> {
pub fn new(map: Term<'a>) -> Option<MapIterator<'a>> {
let env = map.get_env();
unsafe { map::map_iterator_create(env.as_c_arg(), map.as_c_arg()) }
.map(|iter| MapIterator { env, iter })
impl<'a> SimpleMapIterator<'a> {
fn next(&mut self) -> Option<(Term<'a>, Term<'a>)> {
if self.done {
return None;
}

let iter = loop {
match self.iter.as_mut() {
None => {
match unsafe {
map::map_iterator_create(
self.map.get_env().as_c_arg(),
self.map.as_c_arg(),
self.entry,
)
} {
Some(iter) => {
self.iter = Some(iter);
continue;
}
None => {
self.done = true;
return None;
}
}
}
Some(iter) => {
break iter;
}
}
};

let env = self.map.get_env();

unsafe {
match map::map_iterator_get_pair(env.as_c_arg(), iter) {
Some((key, value)) => {
match self.entry {
map::MapIteratorEntry::First => {
map::map_iterator_next(env.as_c_arg(), iter);
}
map::MapIteratorEntry::Last => {
map::map_iterator_prev(env.as_c_arg(), iter);
}
}
let key = Term::new(env, key);
self.last_key = Some(key);
Some((key, Term::new(env, value)))
}
None => {
self.done = true;
None
}
}
}
}
}

impl<'a> Drop for MapIterator<'a> {
impl<'a> Drop for SimpleMapIterator<'a> {
fn drop(&mut self) {
unsafe {
map::map_iterator_destroy(self.env.as_c_arg(), &mut self.iter);
if let Some(iter) = self.iter.as_mut() {
unsafe {
map::map_iterator_destroy(self.map.get_env().as_c_arg(), iter);
}
}
}
}

impl<'a> Iterator for MapIterator<'a> {
type Item = (Term<'a>, Term<'a>);
pub struct MapIterator<'a> {
forward: SimpleMapIterator<'a>,
reverse: SimpleMapIterator<'a>,
}

fn next(&mut self) -> Option<(Term<'a>, Term<'a>)> {
unsafe {
map::map_iterator_get_pair(self.env.as_c_arg(), &mut self.iter).map(|(key, value)| {
map::map_iterator_next(self.env.as_c_arg(), &mut self.iter);
(Term::new(self.env, key), Term::new(self.env, value))
impl<'a> MapIterator<'a> {
pub fn new(map: Term<'a>) -> Option<MapIterator<'a>> {
if map.is_map() {
Some(MapIterator {
forward: SimpleMapIterator {
map,
entry: map::MapIteratorEntry::First,
iter: None,
last_key: None,
done: false,
},
reverse: SimpleMapIterator {
map,
entry: map::MapIteratorEntry::Last,
iter: None,
last_key: None,
done: false,
},
})
} else {
None
}
}
}

impl<'a> Iterator for MapIterator<'a> {
type Item = (Term<'a>, Term<'a>);

fn next(&mut self) -> Option<Self::Item> {
self.forward.next().and_then(|(key, value)| {
if self.reverse.last_key == Some(key) {
self.forward.done = true;
self.reverse.done = true;
return None;
}
Some((key, value))
})
}
}

impl<'a> DoubleEndedIterator for MapIterator<'a> {
fn next_back(&mut self) -> Option<Self::Item> {
self.reverse.next().and_then(|(key, value)| {
if self.forward.last_key == Some(key) {
self.forward.done = true;
self.reverse.done = true;
return None;
}
Some((key, value))
})
}
}

impl<'a> Decoder<'a> for MapIterator<'a> {
fn decode(term: Term<'a>) -> NifResult<Self> {
match MapIterator::new(term) {
Expand Down
21 changes: 19 additions & 2 deletions rustler/src/wrapper/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,26 @@ pub unsafe fn map_update(
Some(result.assume_init())
}

pub unsafe fn map_iterator_create(env: NIF_ENV, map: NIF_TERM) -> Option<ErlNifMapIterator> {
#[derive(Clone, Copy, Debug)]
pub enum MapIteratorEntry {
First,
Last,
}

pub unsafe fn map_iterator_create(
env: NIF_ENV,
map: NIF_TERM,
entry: MapIteratorEntry,
) -> Option<ErlNifMapIterator> {
let mut iter = MaybeUninit::uninit();
let success = rustler_sys::enif_map_iterator_create(
env,
map,
iter.as_mut_ptr(),
ErlNifMapIteratorEntry::ERL_NIF_MAP_ITERATOR_HEAD,
match entry {
MapIteratorEntry::First => ErlNifMapIteratorEntry::ERL_NIF_MAP_ITERATOR_HEAD,
MapIteratorEntry::Last => ErlNifMapIteratorEntry::ERL_NIF_MAP_ITERATOR_TAIL,
},
);
if success == 0 {
None
Expand Down Expand Up @@ -103,6 +116,10 @@ pub unsafe fn map_iterator_next(env: NIF_ENV, iter: &mut ErlNifMapIterator) {
rustler_sys::enif_map_iterator_next(env, iter);
}

pub unsafe fn map_iterator_prev(env: NIF_ENV, iter: &mut ErlNifMapIterator) {
rustler_sys::enif_map_iterator_prev(env, iter);
}

pub unsafe fn make_map_from_arrays(
env: NIF_ENV,
keys: &[NIF_TERM],
Expand Down
3 changes: 2 additions & 1 deletion rustler_tests/lib/rustler_test.ex
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ defmodule RustlerTest do
def term_type(_term), do: err()

def sum_map_values(_), do: err()
def map_entries_sorted(_), do: err()
def map_entries(_), do: err()
def map_entries_reversed(_), do: err()
def map_from_arrays(_keys, _values), do: err()
def map_from_pairs(_pairs), do: err()
def map_generic(_), do: err()
Expand Down
3 changes: 2 additions & 1 deletion rustler_tests/native/rustler_test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ rustler::init!(
test_term::term_phash2_hash,
test_term::term_type,
test_map::sum_map_values,
test_map::map_entries_sorted,
test_map::map_entries,
test_map::map_entries_reversed,
test_map::map_from_arrays,
test_map::map_from_pairs,
test_map::map_generic,
Expand Down
18 changes: 16 additions & 2 deletions rustler_tests/native/rustler_test/src/test_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,28 @@ pub fn sum_map_values(iter: MapIterator) -> NifResult<i64> {
}

#[rustler::nif]
pub fn map_entries_sorted<'a>(env: Env<'a>, iter: MapIterator<'a>) -> NifResult<Vec<Term<'a>>> {
pub fn map_entries<'a>(env: Env<'a>, iter: MapIterator<'a>) -> NifResult<Vec<Term<'a>>> {
let mut vec = vec![];
for (key, value) in iter {
let key_string = key.decode::<String>()?;
vec.push((key_string, value));
}

vec.sort_by_key(|pair| pair.0.clone());
let erlang_pairs: Vec<Term> = vec
.into_iter()
.map(|(key, value)| make_tuple(env, &[key.encode(env), value]))
.collect();
Ok(erlang_pairs)
}

#[rustler::nif]
pub fn map_entries_reversed<'a>(env: Env<'a>, iter: MapIterator<'a>) -> NifResult<Vec<Term<'a>>> {
let mut vec = vec![];
for (key, value) in iter.rev() {
let key_string = key.decode::<String>()?;
vec.push((key_string, value));
}

let erlang_pairs: Vec<Term> = vec
.into_iter()
.map(|(key, value)| make_tuple(env, &[key.encode(env), value]))
Expand Down
7 changes: 6 additions & 1 deletion rustler_tests/test/map_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@ defmodule RustlerTest.MapTest do
end

test "map iteration with keys" do
entries = RustlerTest.map_entries(%{"d" => 0, "a" => 1, "b" => 7, "e" => 4, "c" => 6})

assert [{"a", 1}, {"b", 7}, {"c", 6}, {"d", 0}, {"e", 4}] ==
RustlerTest.map_entries_sorted(%{"d" => 0, "a" => 1, "b" => 7, "e" => 4, "c" => 6})
Enum.sort_by(entries, &elem(&1, 0))

assert Enum.reverse(entries) ==
RustlerTest.map_entries_reversed(%{"d" => 0, "a" => 1, "b" => 7, "e" => 4, "c" => 6})
end

test "map from arrays" do
Expand Down
Loading