Skip to content

Commit

Permalink
Fix NifTaggedEnum derived Encoder impl for named-field variants (#547)
Browse files Browse the repository at this point in the history
Use map_from_term_arrays in Nif{Map,Struct,Exception,TaggedEnum} encoders
  • Loading branch information
dylanburati authored and filmor committed Jun 29, 2023
1 parent 2cbc53a commit ea132ea
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 40 deletions.
25 changes: 24 additions & 1 deletion rustler/src/types/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ impl<'a> Term<'a> {
/// ```elixir
/// keys = ["foo", "bar"]
/// values = [1, 2]
/// List.zip(keys, values) |> Map.new()
/// Enum.zip(keys, values) |> Map.new()
/// ```
pub fn map_from_arrays(
env: Env<'a>,
Expand All @@ -47,6 +47,29 @@ impl<'a> Term<'a> {
}
}

/// Construct a new map from two vectors of terms.
///
/// It is identical to map_from_arrays, but requires the keys and values to
/// be encoded already - this is useful for constructing maps whose values
/// or keys are different Rust types, with the same performance as map_from_arrays.
pub fn map_from_term_arrays(
env: Env<'a>,
keys: &[Term<'a>],
values: &[Term<'a>],
) -> NifResult<Term<'a>> {
if keys.len() == values.len() {
let keys: Vec<_> = keys.iter().map(|k| k.as_c_arg()).collect();
let values: Vec<_> = values.iter().map(|v| v.as_c_arg()).collect();

unsafe {
map::make_map_from_arrays(env.as_c_arg(), &keys, &values)
.map_or_else(|| Err(Error::BadArg), |map| Ok(Term::new(env, map)))
}
} else {
Err(Error::BadArg)
}
}

/// Construct a new map from pairs of terms
///
/// It is similar to `map_from_arrays` but
Expand Down
33 changes: 15 additions & 18 deletions rustler_codegen/src/ex_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,34 +136,31 @@ fn gen_encoder(
atoms_module_name: &Ident,
add_exception: bool,
) -> TokenStream {
let field_defs: Vec<TokenStream> = fields
let mut keys = vec![quote! { ::rustler::Encoder::encode(&atom_struct(), env) }];
let mut values = vec![quote! { ::rustler::Encoder::encode(&atom_module(), env) }];
if add_exception {
keys.push(quote! { ::rustler::Encoder::encode(&atom_exception(), env) });
values.push(quote! { ::rustler::Encoder::encode(&true, env) });
}
let (mut data_keys, mut data_values): (Vec<_>, Vec<_>) = fields
.iter()
.map(|field| {
let field_ident = field.ident.as_ref().unwrap();
let atom_fun = Context::field_to_atom_fun(field);
quote_spanned! { field.span() =>
map = map.map_put(#atom_fun(), &self.#field_ident).unwrap();
}
(
quote! { ::rustler::Encoder::encode(&#atom_fun(), env) },
quote! { ::rustler::Encoder::encode(&self.#field_ident, env) },
)
})
.collect();

let exception_field = if add_exception {
quote! {
map = map.map_put(atom_exception(), true).unwrap();
}
} else {
quote! {}
};
.unzip();
keys.append(&mut data_keys);
values.append(&mut data_values);

super::encode_decode_templates::encoder(
ctx,
quote! {
use #atoms_module_name::*;
let mut map = ::rustler::types::map::map_new(env);
map = map.map_put(atom_struct(), atom_module()).unwrap();
#exception_field
#(#field_defs)*
map
::rustler::Term::map_from_term_arrays(env, &[#(#keys),*], &[#(#values),*]).unwrap()
},
)
}
Expand Down
17 changes: 7 additions & 10 deletions rustler_codegen/src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,26 +107,23 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T
}

fn gen_encoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> TokenStream {
let field_defs: Vec<TokenStream> = fields
let (keys, values): (Vec<_>, Vec<_>) = fields
.iter()
.map(|field| {
let field_ident = field.ident.as_ref().unwrap();
let atom_fun = Context::field_to_atom_fun(field);

quote_spanned! { field.span() =>
map = map.map_put(#atom_fun(), &self.#field_ident).unwrap();
}
(
quote! { ::rustler::Encoder::encode(&#atom_fun(), env) },
quote! { ::rustler::Encoder::encode(&self.#field_ident, env) },
)
})
.collect();
.unzip();

super::encode_decode_templates::encoder(
ctx,
quote! {
use #atoms_module_name::*;

let mut map = ::rustler::types::map::map_new(env);
#(#field_defs)*
map
::rustler::Term::map_from_term_arrays(env, &[#(#keys),*], &[#(#values),*]).unwrap()
},
)
}
7 changes: 5 additions & 2 deletions rustler_codegen/src/tagged_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,14 +321,17 @@ fn gen_named_encoder(
.as_ref()
.expect("Named fields must have an ident.");
let atom_fun = Context::field_to_atom_fun(field);
(atom_fun, field_ident)
(
quote! { ::rustler::Encoder::encode(&#atom_fun(), env) },
quote! { ::rustler::Encoder::encode(&#field_ident, env) },
)
})
.unzip();
quote! {
#enum_name :: #variant_ident{
#(#field_decls)*
} => {
let map = ::rustler::Term::map_from_arrays(env, &[#(#keys()),*], &[#(#values),*])
let map = ::rustler::Term::map_from_term_arrays(env, &[#(#keys),*], &[#(#values),*])
.expect("Failed to create map");
::rustler::types::tuple::make_tuple(env, &[::rustler::Encoder::encode(&#atom_fn(), env), map])
}
Expand Down
1 change: 1 addition & 0 deletions rustler_tests/lib/rustler_test.ex
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ defmodule RustlerTest do
def tagged_enum_1_echo(_), do: err()
def tagged_enum_2_echo(_), do: err()
def tagged_enum_3_echo(_), do: err()
def tagged_enum_4_echo(_), do: err()
def untagged_enum_echo(_), do: err()
def untagged_enum_with_truthy(_), do: err()
def untagged_enum_for_issue_370(_), do: err()
Expand Down
1 change: 1 addition & 0 deletions rustler_tests/native/rustler_test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ rustler::init!(
test_codegen::tagged_enum_1_echo,
test_codegen::tagged_enum_2_echo,
test_codegen::tagged_enum_3_echo,
test_codegen::tagged_enum_4_echo,
test_codegen::untagged_enum_echo,
test_codegen::untagged_enum_with_truthy,
test_codegen::untagged_enum_for_issue_370,
Expand Down
29 changes: 29 additions & 0 deletions rustler_tests/native/rustler_test/src/test_codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ pub fn record_echo(record: AddRecord) -> AddRecord {
pub struct AddMap {
lhs: i32,
rhs: i32,
loc: (u32, u32),
}

#[rustler::nif]
Expand All @@ -57,12 +58,14 @@ pub fn map_echo(map: AddMap) -> AddMap {
pub struct AddStruct {
lhs: i32,
rhs: i32,
loc: (u32, u32),
}

#[derive(Debug, NifException)]
#[module = "AddException"]
pub struct AddException {
message: String,
loc: (u32, u32),
}

#[rustler::nif]
Expand Down Expand Up @@ -125,6 +128,32 @@ pub fn tagged_enum_3_echo(tagged_enum: TaggedEnum3) -> TaggedEnum3 {
tagged_enum
}

#[derive(NifTaggedEnum)]
pub enum TaggedEnum4 {
Unit,
Unnamed(u64, bool),
Named {
size: u64,
filename: String,
},
Long {
f0: bool,
f1: u8,
f2: u8,
f3: u8,
f4: u8,
f5: Option<i32>,
f6: Option<i32>,
f7: Option<i32>,
f8: Option<i32>,
},
}

#[rustler::nif]
pub fn tagged_enum_4_echo(tagged_enum: TaggedEnum4) -> TaggedEnum4 {
tagged_enum
}

#[derive(NifUntaggedEnum)]
pub enum UntaggedEnum {
Foo(u32),
Expand Down
69 changes: 60 additions & 9 deletions rustler_tests/test/codegen_test.exs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
defmodule AddStruct do
defstruct lhs: 0, rhs: 0
defstruct lhs: 0, rhs: 0, loc: {1, 1}
end

defmodule AddException do
defexception message: ""
defexception message: "", loc: {1, 1}
end

defmodule AddRecord do
Expand Down Expand Up @@ -40,13 +40,13 @@ defmodule RustlerTest.CodegenTest do
end

describe "map" do
test "transcoder" do
value = %{lhs: 1, rhs: 2}
test "transcoder 1" do
value = %{lhs: 1, rhs: 2, loc: {52, 15}}
assert value == RustlerTest.map_echo(value)
end

test "with invalid map" do
value = %{lhs: "invalid", rhs: 2}
value = %{lhs: "invalid", rhs: 2, loc: {57, 15}}

assert_raise ErlangError, "Erlang error: \"Could not decode field :lhs on %{}\"", fn ->
assert value == RustlerTest.map_echo(value)
Expand All @@ -56,7 +56,7 @@ defmodule RustlerTest.CodegenTest do

describe "struct" do
test "transcoder" do
value = %AddStruct{lhs: 45, rhs: 123}
value = %AddStruct{lhs: 45, rhs: 123, loc: {66, 15}}
assert value == RustlerTest.struct_echo(value)

assert %ErlangError{original: :invalid_struct} ==
Expand All @@ -66,19 +66,27 @@ defmodule RustlerTest.CodegenTest do
end

test "with invalid struct" do
value = %AddStruct{lhs: "lhs", rhs: 123}
value = %AddStruct{lhs: "lhs", rhs: 123, loc: {76, 15}}

assert_raise ErlangError,
"Erlang error: \"Could not decode field :lhs on %AddStruct{}\"",
fn ->
RustlerTest.struct_echo(value)
end

value = %AddStruct{lhs: 45, rhs: 123, loc: {-76, -15}}

assert_raise ErlangError,
"Erlang error: \"Could not decode field :loc on %AddStruct{}\"",
fn ->
RustlerTest.struct_echo(value)
end
end
end

describe "exception" do
test "transcoder" do
value = %AddException{message: "testing"}
value = %AddException{message: "testing", loc: {96, 15}}
assert value == RustlerTest.exception_echo(value)

assert %ErlangError{original: :invalid_struct} ==
Expand All @@ -88,13 +96,21 @@ defmodule RustlerTest.CodegenTest do
end

test "with invalid struct" do
value = %AddException{message: 'this is a charlist'}
value = %AddException{message: 'this is a charlist', loc: {106, 15}}

assert_raise ErlangError,
"Erlang error: \"Could not decode field :message on %AddException{}\"",
fn ->
RustlerTest.exception_echo(value)
end

value = %AddException{message: "testing", loc: %{line: 114, col: 15}}

assert_raise ErlangError,
"Erlang error: \"Could not decode field :loc on %AddException{}\"",
fn ->
RustlerTest.exception_echo(value)
end
end
end

Expand Down Expand Up @@ -279,6 +295,41 @@ defmodule RustlerTest.CodegenTest do
end)
end

test "tagged enum transcoder 4" do
assert :unit == RustlerTest.tagged_enum_4_echo(:unit)

assert {:unnamed, 10_000_000_000, true} ==
RustlerTest.tagged_enum_4_echo({:unnamed, 10_000_000_000, true})

assert {:named, %{filename: "\u2200", size: 123}} ==
RustlerTest.tagged_enum_4_echo({:named, %{filename: "\u2200", size: 123}})

long_map = %{f0: true, f1: 8, f2: 5, f3: 12, f4: 12, f5: 15, f6: nil, f7: nil, f8: nil}

assert {:long, long_map} ==
RustlerTest.tagged_enum_4_echo({:long, long_map})

assert %ErlangError{original: :invalid_variant} ==
assert_raise(ErlangError, fn ->
RustlerTest.tagged_enum_4_echo(:unnamed)
end)

assert %ErlangError{original: :invalid_variant} ==
assert_raise(ErlangError, fn ->
RustlerTest.tagged_enum_4_echo({:unit, 2, false})
end)

assert %ErlangError{original: :invalid_variant} ==
assert_raise(ErlangError, fn ->
RustlerTest.tagged_enum_4_echo({:named, "@", 45})
end)

assert %ErlangError{original: :invalid_variant} ==
assert_raise(ErlangError, fn ->
RustlerTest.tagged_enum_4_echo(nil)
end)
end

test "untagged enum transcoder" do
assert 123 == RustlerTest.untagged_enum_echo(123)
assert "Hello" == RustlerTest.untagged_enum_echo("Hello")
Expand Down

0 comments on commit ea132ea

Please sign in to comment.