diff --git a/rustler/src/types/local_pid.rs b/rustler/src/types/local_pid.rs index 9babda8b..34ceee1d 100644 --- a/rustler/src/types/local_pid.rs +++ b/rustler/src/types/local_pid.rs @@ -1,5 +1,6 @@ use crate::wrapper::{pid, ErlNifPid}; use crate::{Decoder, Encoder, Env, Error, NifResult, Term}; +use std::cmp::Ordering; use std::mem::MaybeUninit; #[derive(Copy, Clone)] @@ -36,6 +37,27 @@ impl Encoder for LocalPid { } } +impl PartialEq for LocalPid { + fn eq(&self, other: &Self) -> bool { + unsafe { rustler_sys::enif_compare_pids(self.as_c_arg(), other.as_c_arg()) == 0 } + } +} + +impl Eq for LocalPid {} + +impl PartialOrd for LocalPid { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for LocalPid { + fn cmp(&self, other: &Self) -> Ordering { + let cmp = unsafe { rustler_sys::enif_compare_pids(self.as_c_arg(), other.as_c_arg()) }; + cmp.cmp(&0) + } +} + impl<'a> Env<'a> { /// Return the calling process's pid. /// diff --git a/rustler_sys/src/rustler_sys_api.rs b/rustler_sys/src/rustler_sys_api.rs index 4ba58089..d4f93bcf 100644 --- a/rustler_sys/src/rustler_sys_api.rs +++ b/rustler_sys/src/rustler_sys_api.rs @@ -203,6 +203,12 @@ pub unsafe fn enif_make_pid(_env: *mut ErlNifEnv, pid: ErlNifPid) -> ERL_NIF_TER pid.pid } +/// See [enif_compare_pids](http://erlang.org/doc/man/erl_nif.html#enif_compare_pids) in the Erlang docs +pub unsafe fn enif_compare_pids(pid1: *const ErlNifPid, pid2: *const ErlNifPid) -> c_int { + // Mimics the implementation of the enif_compare_pids macro + enif_compare((*pid1).pid, (*pid2).pid) +} + /// See [ErlNifSysInfo](http://www.erlang.org/doc/man/erl_nif.html#ErlNifSysInfo) in the Erlang docs. #[allow(missing_copy_implementations)] #[repr(C)] diff --git a/rustler_tests/lib/rustler_test.ex b/rustler_tests/lib/rustler_test.ex index 32326d7a..e49c37b3 100644 --- a/rustler_tests/lib/rustler_test.ex +++ b/rustler_tests/lib/rustler_test.ex @@ -33,6 +33,9 @@ defmodule RustlerTest do def sum_list(_), do: err() def make_list(), do: err() + def compare_local_pids(_, _), do: err() + def are_equal_local_pids(_, _), do: err() + def term_debug(_), do: err() def term_debug_and_reparse(term) do diff --git a/rustler_tests/native/rustler_test/src/lib.rs b/rustler_tests/native/rustler_test/src/lib.rs index 7faa7115..90f10fc8 100644 --- a/rustler_tests/native/rustler_test/src/lib.rs +++ b/rustler_tests/native/rustler_test/src/lib.rs @@ -5,6 +5,7 @@ mod test_dirty; mod test_env; mod test_error; mod test_list; +mod test_local_pid; mod test_map; mod test_nif_attrs; mod test_primitives; @@ -27,6 +28,8 @@ rustler::init!( test_primitives::echo_i128, test_list::sum_list, test_list::make_list, + test_local_pid::compare_local_pids, + test_local_pid::are_equal_local_pids, test_term::term_debug, test_term::term_eq, test_term::term_cmp, diff --git a/rustler_tests/native/rustler_test/src/test_local_pid.rs b/rustler_tests/native/rustler_test/src/test_local_pid.rs new file mode 100644 index 00000000..e133c79b --- /dev/null +++ b/rustler_tests/native/rustler_test/src/test_local_pid.rs @@ -0,0 +1,17 @@ +use std::cmp::Ordering; + +use rustler::LocalPid; + +#[rustler::nif] +pub fn compare_local_pids(lhs: LocalPid, rhs: LocalPid) -> i32 { + match lhs.cmp(&rhs) { + Ordering::Less => -1, + Ordering::Equal => 0, + Ordering::Greater => 1, + } +} + +#[rustler::nif] +pub fn are_equal_local_pids(lhs: LocalPid, rhs: LocalPid) -> bool { + lhs == rhs +} diff --git a/rustler_tests/test/local_pid_test.exs b/rustler_tests/test/local_pid_test.exs new file mode 100644 index 00000000..ec4a2ead --- /dev/null +++ b/rustler_tests/test/local_pid_test.exs @@ -0,0 +1,35 @@ +defmodule RustlerTest.LocalPidTest do + use ExUnit.Case, async: true + + def make_pid() do + {:ok, pid} = Task.start(fn -> :ok end) + pid + end + + def compare(lhs, rhs) do + cond do + lhs < rhs -> -1 + lhs == rhs -> 0 + lhs > rhs -> 1 + end + end + + test "local pid comparison" do + # We make sure that the code we have in rust code matches the comparisons + # that are performed in the BEAM code. + pids = for _ <- 1..3, do: make_pid() + + for lhs <- pids, rhs <- pids do + assert RustlerTest.compare_local_pids(lhs, rhs) == compare(lhs, rhs) + end + end + + test "local pid equality" do + pids = for _ <- 1..3, do: make_pid() + + for lhs <- pids, rhs <- pids do + expected = lhs == rhs + assert RustlerTest.are_equal_local_pids(lhs, rhs) == expected + end + end +end