Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: compare local pids #611

Merged
merged 5 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions rustler/src/types/local_pid.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::wrapper::{pid, ErlNifPid};
use crate::{Decoder, Encoder, Env, Error, NifResult, Term};
use std::cmp::Ordering;
use std::mem::MaybeUninit;

#[derive(Copy, Clone)]
Expand Down Expand Up @@ -36,6 +37,27 @@ impl Encoder for LocalPid {
}
}

impl PartialEq for LocalPid {
fn eq(&self, other: &Self) -> bool {
unsafe { rustler_sys::enif_compare_pids(self.as_c_arg(), other.as_c_arg()) == 0 }
}
}

impl Eq for LocalPid {}

impl PartialOrd for LocalPid {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

impl Ord for LocalPid {
fn cmp(&self, other: &Self) -> Ordering {
let cmp = unsafe { rustler_sys::enif_compare_pids(self.as_c_arg(), other.as_c_arg()) };
cmp.cmp(&0)
}
}

impl<'a> Env<'a> {
/// Return the calling process's pid.
///
Expand Down
6 changes: 6 additions & 0 deletions rustler_sys/src/rustler_sys_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,12 @@ pub unsafe fn enif_make_pid(_env: *mut ErlNifEnv, pid: ErlNifPid) -> ERL_NIF_TER
pid.pid
}

/// See [enif_compare_pids](http://erlang.org/doc/man/erl_nif.html#enif_compare_pids) in the Erlang docs
pub unsafe fn enif_compare_pids(pid1: *const ErlNifPid, pid2: *const ErlNifPid) -> c_int {
// Mimics the implementation of the enif_compare_pids macro
enif_compare((*pid1).pid, (*pid2).pid)
}

/// See [ErlNifSysInfo](http://www.erlang.org/doc/man/erl_nif.html#ErlNifSysInfo) in the Erlang docs.
#[allow(missing_copy_implementations)]
#[repr(C)]
Expand Down
3 changes: 3 additions & 0 deletions rustler_tests/lib/rustler_test.ex
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ defmodule RustlerTest do
def sum_list(_), do: err()
def make_list(), do: err()

def compare_local_pids(_, _), do: err()
def are_equal_local_pids(_, _), do: err()

def term_debug(_), do: err()

def term_debug_and_reparse(term) do
Expand Down
3 changes: 3 additions & 0 deletions rustler_tests/native/rustler_test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod test_dirty;
mod test_env;
mod test_error;
mod test_list;
mod test_local_pid;
mod test_map;
mod test_nif_attrs;
mod test_primitives;
Expand All @@ -27,6 +28,8 @@ rustler::init!(
test_primitives::echo_i128,
test_list::sum_list,
test_list::make_list,
test_local_pid::compare_local_pids,
test_local_pid::are_equal_local_pids,
test_term::term_debug,
test_term::term_eq,
test_term::term_cmp,
Expand Down
17 changes: 17 additions & 0 deletions rustler_tests/native/rustler_test/src/test_local_pid.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
use std::cmp::Ordering;

use rustler::LocalPid;

#[rustler::nif]
pub fn compare_local_pids(lhs: LocalPid, rhs: LocalPid) -> i32 {
match lhs.cmp(&rhs) {
Ordering::Less => -1,
Ordering::Equal => 0,
Ordering::Greater => 1,
}
}

#[rustler::nif]
pub fn are_equal_local_pids(lhs: LocalPid, rhs: LocalPid) -> bool {
lhs == rhs
}
35 changes: 35 additions & 0 deletions rustler_tests/test/local_pid_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
defmodule RustlerTest.LocalPidTest do
use ExUnit.Case, async: true

def make_pid() do
{:ok, pid} = Task.start(fn -> :ok end)
pid
end

def compare(lhs, rhs) do
cond do
lhs < rhs -> -1
lhs == rhs -> 0
lhs > rhs -> 1
end
end

test "local pid comparison" do
# We make sure that the code we have in rust code matches the comparisons
# that are performed in the BEAM code.
pids = for _ <- 1..3, do: make_pid()

for lhs <- pids, rhs <- pids do
assert RustlerTest.compare_local_pids(lhs, rhs) == compare(lhs, rhs)
end
end

test "local pid equality" do
pids = for _ <- 1..3, do: make_pid()

for lhs <- pids, rhs <- pids do
expected = lhs == rhs
assert RustlerTest.are_equal_local_pids(lhs, rhs) == expected
end
end
end
Loading