Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix NifTaggedEnum derived Encoder impl for named-field variants #547

Merged
merged 9 commits into from
Jun 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion rustler/src/types/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ impl<'a> Term<'a> {
/// ```elixir
/// keys = ["foo", "bar"]
/// values = [1, 2]
/// List.zip(keys, values) |> Map.new()
/// Enum.zip(keys, values) |> Map.new()
/// ```
pub fn map_from_arrays(
env: Env<'a>,
Expand All @@ -47,6 +47,29 @@ impl<'a> Term<'a> {
}
}

/// Construct a new map from two vectors of terms.
///
/// It is identical to map_from_arrays, but requires the keys and values to
/// be encoded already - this is useful for constructing maps whose values
/// or keys are different Rust types, with the same performance as map_from_arrays.
pub fn map_from_term_arrays(
env: Env<'a>,
keys: &[Term<'a>],
values: &[Term<'a>],
) -> NifResult<Term<'a>> {
if keys.len() == values.len() {
let keys: Vec<_> = keys.iter().map(|k| k.as_c_arg()).collect();
let values: Vec<_> = values.iter().map(|v| v.as_c_arg()).collect();

unsafe {
map::make_map_from_arrays(env.as_c_arg(), &keys, &values)
.map_or_else(|| Err(Error::BadArg), |map| Ok(Term::new(env, map)))
}
} else {
Err(Error::BadArg)
}
}

/// Construct a new map from pairs of terms
///
/// It is similar to `map_from_arrays` but
Expand Down
33 changes: 15 additions & 18 deletions rustler_codegen/src/ex_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,34 +136,31 @@ fn gen_encoder(
atoms_module_name: &Ident,
add_exception: bool,
) -> TokenStream {
let field_defs: Vec<TokenStream> = fields
let mut keys = vec![quote! { ::rustler::Encoder::encode(&atom_struct(), env) }];
let mut values = vec![quote! { ::rustler::Encoder::encode(&atom_module(), env) }];
if add_exception {
keys.push(quote! { ::rustler::Encoder::encode(&atom_exception(), env) });
values.push(quote! { ::rustler::Encoder::encode(&true, env) });
}
let (mut data_keys, mut data_values): (Vec<_>, Vec<_>) = fields
.iter()
.map(|field| {
let field_ident = field.ident.as_ref().unwrap();
let atom_fun = Context::field_to_atom_fun(field);
quote_spanned! { field.span() =>
map = map.map_put(#atom_fun(), &self.#field_ident).unwrap();
}
(
quote! { ::rustler::Encoder::encode(&#atom_fun(), env) },
quote! { ::rustler::Encoder::encode(&self.#field_ident, env) },
)
})
.collect();

let exception_field = if add_exception {
quote! {
map = map.map_put(atom_exception(), true).unwrap();
}
} else {
quote! {}
};
.unzip();
keys.append(&mut data_keys);
values.append(&mut data_values);

super::encode_decode_templates::encoder(
ctx,
quote! {
use #atoms_module_name::*;
let mut map = ::rustler::types::map::map_new(env);
map = map.map_put(atom_struct(), atom_module()).unwrap();
#exception_field
#(#field_defs)*
map
::rustler::Term::map_from_term_arrays(env, &[#(#keys),*], &[#(#values),*]).unwrap()
},
)
}
Expand Down
17 changes: 7 additions & 10 deletions rustler_codegen/src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,26 +107,23 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T
}

fn gen_encoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> TokenStream {
let field_defs: Vec<TokenStream> = fields
let (keys, values): (Vec<_>, Vec<_>) = fields
.iter()
.map(|field| {
let field_ident = field.ident.as_ref().unwrap();
let atom_fun = Context::field_to_atom_fun(field);

quote_spanned! { field.span() =>
map = map.map_put(#atom_fun(), &self.#field_ident).unwrap();
}
(
quote! { ::rustler::Encoder::encode(&#atom_fun(), env) },
quote! { ::rustler::Encoder::encode(&self.#field_ident, env) },
)
})
.collect();
.unzip();

super::encode_decode_templates::encoder(
ctx,
quote! {
use #atoms_module_name::*;

let mut map = ::rustler::types::map::map_new(env);
#(#field_defs)*
map
::rustler::Term::map_from_term_arrays(env, &[#(#keys),*], &[#(#values),*]).unwrap()
},
)
}
7 changes: 5 additions & 2 deletions rustler_codegen/src/tagged_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,14 +321,17 @@ fn gen_named_encoder(
.as_ref()
.expect("Named fields must have an ident.");
let atom_fun = Context::field_to_atom_fun(field);
(atom_fun, field_ident)
(
quote! { ::rustler::Encoder::encode(&#atom_fun(), env) },
quote! { ::rustler::Encoder::encode(&#field_ident, env) },
)
})
.unzip();
quote! {
#enum_name :: #variant_ident{
#(#field_decls)*
} => {
let map = ::rustler::Term::map_from_arrays(env, &[#(#keys()),*], &[#(#values),*])
let map = ::rustler::Term::map_from_term_arrays(env, &[#(#keys),*], &[#(#values),*])
.expect("Failed to create map");
::rustler::types::tuple::make_tuple(env, &[::rustler::Encoder::encode(&#atom_fn(), env), map])
}
Expand Down
1 change: 1 addition & 0 deletions rustler_tests/lib/rustler_test.ex
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ defmodule RustlerTest do
def tagged_enum_1_echo(_), do: err()
def tagged_enum_2_echo(_), do: err()
def tagged_enum_3_echo(_), do: err()
def tagged_enum_4_echo(_), do: err()
def untagged_enum_echo(_), do: err()
def untagged_enum_with_truthy(_), do: err()
def untagged_enum_for_issue_370(_), do: err()
Expand Down
1 change: 1 addition & 0 deletions rustler_tests/native/rustler_test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ rustler::init!(
test_codegen::tagged_enum_1_echo,
test_codegen::tagged_enum_2_echo,
test_codegen::tagged_enum_3_echo,
test_codegen::tagged_enum_4_echo,
test_codegen::untagged_enum_echo,
test_codegen::untagged_enum_with_truthy,
test_codegen::untagged_enum_for_issue_370,
Expand Down
29 changes: 29 additions & 0 deletions rustler_tests/native/rustler_test/src/test_codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ pub fn record_echo(record: AddRecord) -> AddRecord {
pub struct AddMap {
lhs: i32,
rhs: i32,
loc: (u32, u32),
}

#[rustler::nif]
Expand All @@ -57,12 +58,14 @@ pub fn map_echo(map: AddMap) -> AddMap {
pub struct AddStruct {
lhs: i32,
rhs: i32,
loc: (u32, u32),
}

#[derive(Debug, NifException)]
#[module = "AddException"]
pub struct AddException {
message: String,
loc: (u32, u32),
}

#[rustler::nif]
Expand Down Expand Up @@ -125,6 +128,32 @@ pub fn tagged_enum_3_echo(tagged_enum: TaggedEnum3) -> TaggedEnum3 {
tagged_enum
}

#[derive(NifTaggedEnum)]
pub enum TaggedEnum4 {
Unit,
Unnamed(u64, bool),
Named {
size: u64,
filename: String,
},
Long {
f0: bool,
f1: u8,
f2: u8,
f3: u8,
f4: u8,
f5: Option<i32>,
f6: Option<i32>,
f7: Option<i32>,
f8: Option<i32>,
},
}

#[rustler::nif]
pub fn tagged_enum_4_echo(tagged_enum: TaggedEnum4) -> TaggedEnum4 {
tagged_enum
}

#[derive(NifUntaggedEnum)]
pub enum UntaggedEnum {
Foo(u32),
Expand Down
69 changes: 60 additions & 9 deletions rustler_tests/test/codegen_test.exs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
defmodule AddStruct do
defstruct lhs: 0, rhs: 0
defstruct lhs: 0, rhs: 0, loc: {1, 1}
end

defmodule AddException do
defexception message: ""
defexception message: "", loc: {1, 1}
end

defmodule AddRecord do
Expand Down Expand Up @@ -40,13 +40,13 @@ defmodule RustlerTest.CodegenTest do
end

describe "map" do
test "transcoder" do
value = %{lhs: 1, rhs: 2}
test "transcoder 1" do
value = %{lhs: 1, rhs: 2, loc: {52, 15}}
assert value == RustlerTest.map_echo(value)
end

test "with invalid map" do
value = %{lhs: "invalid", rhs: 2}
value = %{lhs: "invalid", rhs: 2, loc: {57, 15}}

assert_raise ErlangError, "Erlang error: \"Could not decode field :lhs on %{}\"", fn ->
assert value == RustlerTest.map_echo(value)
Expand All @@ -56,7 +56,7 @@ defmodule RustlerTest.CodegenTest do

describe "struct" do
test "transcoder" do
value = %AddStruct{lhs: 45, rhs: 123}
value = %AddStruct{lhs: 45, rhs: 123, loc: {66, 15}}
assert value == RustlerTest.struct_echo(value)

assert %ErlangError{original: :invalid_struct} ==
Expand All @@ -66,19 +66,27 @@ defmodule RustlerTest.CodegenTest do
end

test "with invalid struct" do
value = %AddStruct{lhs: "lhs", rhs: 123}
value = %AddStruct{lhs: "lhs", rhs: 123, loc: {76, 15}}

assert_raise ErlangError,
"Erlang error: \"Could not decode field :lhs on %AddStruct{}\"",
fn ->
RustlerTest.struct_echo(value)
end

value = %AddStruct{lhs: 45, rhs: 123, loc: {-76, -15}}

assert_raise ErlangError,
"Erlang error: \"Could not decode field :loc on %AddStruct{}\"",
fn ->
RustlerTest.struct_echo(value)
end
end
end

describe "exception" do
test "transcoder" do
value = %AddException{message: "testing"}
value = %AddException{message: "testing", loc: {96, 15}}
assert value == RustlerTest.exception_echo(value)

assert %ErlangError{original: :invalid_struct} ==
Expand All @@ -88,13 +96,21 @@ defmodule RustlerTest.CodegenTest do
end

test "with invalid struct" do
value = %AddException{message: 'this is a charlist'}
value = %AddException{message: 'this is a charlist', loc: {106, 15}}

assert_raise ErlangError,
"Erlang error: \"Could not decode field :message on %AddException{}\"",
fn ->
RustlerTest.exception_echo(value)
end

value = %AddException{message: "testing", loc: %{line: 114, col: 15}}

assert_raise ErlangError,
"Erlang error: \"Could not decode field :loc on %AddException{}\"",
fn ->
RustlerTest.exception_echo(value)
end
end
end

Expand Down Expand Up @@ -279,6 +295,41 @@ defmodule RustlerTest.CodegenTest do
end)
end

test "tagged enum transcoder 4" do
assert :unit == RustlerTest.tagged_enum_4_echo(:unit)

assert {:unnamed, 10_000_000_000, true} ==
RustlerTest.tagged_enum_4_echo({:unnamed, 10_000_000_000, true})

assert {:named, %{filename: "\u2200", size: 123}} ==
RustlerTest.tagged_enum_4_echo({:named, %{filename: "\u2200", size: 123}})

long_map = %{f0: true, f1: 8, f2: 5, f3: 12, f4: 12, f5: 15, f6: nil, f7: nil, f8: nil}

assert {:long, long_map} ==
RustlerTest.tagged_enum_4_echo({:long, long_map})

assert %ErlangError{original: :invalid_variant} ==
assert_raise(ErlangError, fn ->
RustlerTest.tagged_enum_4_echo(:unnamed)
end)

assert %ErlangError{original: :invalid_variant} ==
assert_raise(ErlangError, fn ->
RustlerTest.tagged_enum_4_echo({:unit, 2, false})
end)

assert %ErlangError{original: :invalid_variant} ==
assert_raise(ErlangError, fn ->
RustlerTest.tagged_enum_4_echo({:named, "@", 45})
end)

assert %ErlangError{original: :invalid_variant} ==
assert_raise(ErlangError, fn ->
RustlerTest.tagged_enum_4_echo(nil)
end)
end

test "untagged enum transcoder" do
assert 123 == RustlerTest.untagged_enum_echo(123)
assert "Hello" == RustlerTest.untagged_enum_echo("Hello")
Expand Down