diff --git a/rustler/src/types/map.rs b/rustler/src/types/map.rs index e3ea5764..77ddf27b 100644 --- a/rustler/src/types/map.rs +++ b/rustler/src/types/map.rs @@ -27,7 +27,7 @@ impl<'a> Term<'a> { /// ```elixir /// keys = ["foo", "bar"] /// values = [1, 2] - /// List.zip(keys, values) |> Map.new() + /// Enum.zip(keys, values) |> Map.new() /// ``` pub fn map_from_arrays( env: Env<'a>, @@ -47,6 +47,29 @@ impl<'a> Term<'a> { } } + /// Construct a new map from two vectors of terms. + /// + /// It is identical to map_from_arrays, but requires the keys and values to + /// be encoded already - this is useful for constructing maps whose values + /// or keys are different Rust types, with the same performance as map_from_arrays. + pub fn map_from_term_arrays( + env: Env<'a>, + keys: &[Term<'a>], + values: &[Term<'a>], + ) -> NifResult> { + if keys.len() == values.len() { + let keys: Vec<_> = keys.iter().map(|k| k.as_c_arg()).collect(); + let values: Vec<_> = values.iter().map(|v| v.as_c_arg()).collect(); + + unsafe { + map::make_map_from_arrays(env.as_c_arg(), &keys, &values) + .map_or_else(|| Err(Error::BadArg), |map| Ok(Term::new(env, map))) + } + } else { + Err(Error::BadArg) + } + } + /// Construct a new map from pairs of terms /// /// It is similar to `map_from_arrays` but diff --git a/rustler_codegen/src/ex_struct.rs b/rustler_codegen/src/ex_struct.rs index e489ca7f..e051c494 100644 --- a/rustler_codegen/src/ex_struct.rs +++ b/rustler_codegen/src/ex_struct.rs @@ -136,34 +136,31 @@ fn gen_encoder( atoms_module_name: &Ident, add_exception: bool, ) -> TokenStream { - let field_defs: Vec = fields + let mut keys = vec![quote! { ::rustler::Encoder::encode(&atom_struct(), env) }]; + let mut values = vec![quote! { ::rustler::Encoder::encode(&atom_module(), env) }]; + if add_exception { + keys.push(quote! { ::rustler::Encoder::encode(&atom_exception(), env) }); + values.push(quote! { ::rustler::Encoder::encode(&true, env) }); + } + let (mut data_keys, mut data_values): (Vec<_>, Vec<_>) = fields .iter() .map(|field| { let field_ident = field.ident.as_ref().unwrap(); let atom_fun = Context::field_to_atom_fun(field); - quote_spanned! { field.span() => - map = map.map_put(#atom_fun(), &self.#field_ident).unwrap(); - } + ( + quote! { ::rustler::Encoder::encode(&#atom_fun(), env) }, + quote! { ::rustler::Encoder::encode(&self.#field_ident, env) }, + ) }) - .collect(); - - let exception_field = if add_exception { - quote! { - map = map.map_put(atom_exception(), true).unwrap(); - } - } else { - quote! {} - }; + .unzip(); + keys.append(&mut data_keys); + values.append(&mut data_values); super::encode_decode_templates::encoder( ctx, quote! { use #atoms_module_name::*; - let mut map = ::rustler::types::map::map_new(env); - map = map.map_put(atom_struct(), atom_module()).unwrap(); - #exception_field - #(#field_defs)* - map + ::rustler::Term::map_from_term_arrays(env, &[#(#keys),*], &[#(#values),*]).unwrap() }, ) } diff --git a/rustler_codegen/src/map.rs b/rustler_codegen/src/map.rs index 8584484e..9065d132 100644 --- a/rustler_codegen/src/map.rs +++ b/rustler_codegen/src/map.rs @@ -107,26 +107,23 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T } fn gen_encoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> TokenStream { - let field_defs: Vec = fields + let (keys, values): (Vec<_>, Vec<_>) = fields .iter() .map(|field| { let field_ident = field.ident.as_ref().unwrap(); let atom_fun = Context::field_to_atom_fun(field); - - quote_spanned! { field.span() => - map = map.map_put(#atom_fun(), &self.#field_ident).unwrap(); - } + ( + quote! { ::rustler::Encoder::encode(&#atom_fun(), env) }, + quote! { ::rustler::Encoder::encode(&self.#field_ident, env) }, + ) }) - .collect(); + .unzip(); super::encode_decode_templates::encoder( ctx, quote! { use #atoms_module_name::*; - - let mut map = ::rustler::types::map::map_new(env); - #(#field_defs)* - map + ::rustler::Term::map_from_term_arrays(env, &[#(#keys),*], &[#(#values),*]).unwrap() }, ) } diff --git a/rustler_codegen/src/tagged_enum.rs b/rustler_codegen/src/tagged_enum.rs index ecdadd45..fe78d331 100644 --- a/rustler_codegen/src/tagged_enum.rs +++ b/rustler_codegen/src/tagged_enum.rs @@ -321,14 +321,17 @@ fn gen_named_encoder( .as_ref() .expect("Named fields must have an ident."); let atom_fun = Context::field_to_atom_fun(field); - (atom_fun, field_ident) + ( + quote! { ::rustler::Encoder::encode(&#atom_fun(), env) }, + quote! { ::rustler::Encoder::encode(&#field_ident, env) }, + ) }) .unzip(); quote! { #enum_name :: #variant_ident{ #(#field_decls)* } => { - let map = ::rustler::Term::map_from_arrays(env, &[#(#keys()),*], &[#(#values),*]) + let map = ::rustler::Term::map_from_term_arrays(env, &[#(#keys),*], &[#(#values),*]) .expect("Failed to create map"); ::rustler::types::tuple::make_tuple(env, &[::rustler::Encoder::encode(&#atom_fn(), env), map]) } diff --git a/rustler_tests/lib/rustler_test.ex b/rustler_tests/lib/rustler_test.ex index 0d8bb277..d75ebaa6 100644 --- a/rustler_tests/lib/rustler_test.ex +++ b/rustler_tests/lib/rustler_test.ex @@ -95,6 +95,7 @@ defmodule RustlerTest do def tagged_enum_1_echo(_), do: err() def tagged_enum_2_echo(_), do: err() def tagged_enum_3_echo(_), do: err() + def tagged_enum_4_echo(_), do: err() def untagged_enum_echo(_), do: err() def untagged_enum_with_truthy(_), do: err() def untagged_enum_for_issue_370(_), do: err() diff --git a/rustler_tests/native/rustler_test/src/lib.rs b/rustler_tests/native/rustler_test/src/lib.rs index 03a744d3..bd22ac1f 100644 --- a/rustler_tests/native/rustler_test/src/lib.rs +++ b/rustler_tests/native/rustler_test/src/lib.rs @@ -70,6 +70,7 @@ rustler::init!( test_codegen::tagged_enum_1_echo, test_codegen::tagged_enum_2_echo, test_codegen::tagged_enum_3_echo, + test_codegen::tagged_enum_4_echo, test_codegen::untagged_enum_echo, test_codegen::untagged_enum_with_truthy, test_codegen::untagged_enum_for_issue_370, diff --git a/rustler_tests/native/rustler_test/src/test_codegen.rs b/rustler_tests/native/rustler_test/src/test_codegen.rs index 00f97c9a..d7671b68 100644 --- a/rustler_tests/native/rustler_test/src/test_codegen.rs +++ b/rustler_tests/native/rustler_test/src/test_codegen.rs @@ -44,6 +44,7 @@ pub fn record_echo(record: AddRecord) -> AddRecord { pub struct AddMap { lhs: i32, rhs: i32, + loc: (u32, u32), } #[rustler::nif] @@ -57,12 +58,14 @@ pub fn map_echo(map: AddMap) -> AddMap { pub struct AddStruct { lhs: i32, rhs: i32, + loc: (u32, u32), } #[derive(Debug, NifException)] #[module = "AddException"] pub struct AddException { message: String, + loc: (u32, u32), } #[rustler::nif] @@ -125,6 +128,32 @@ pub fn tagged_enum_3_echo(tagged_enum: TaggedEnum3) -> TaggedEnum3 { tagged_enum } +#[derive(NifTaggedEnum)] +pub enum TaggedEnum4 { + Unit, + Unnamed(u64, bool), + Named { + size: u64, + filename: String, + }, + Long { + f0: bool, + f1: u8, + f2: u8, + f3: u8, + f4: u8, + f5: Option, + f6: Option, + f7: Option, + f8: Option, + }, +} + +#[rustler::nif] +pub fn tagged_enum_4_echo(tagged_enum: TaggedEnum4) -> TaggedEnum4 { + tagged_enum +} + #[derive(NifUntaggedEnum)] pub enum UntaggedEnum { Foo(u32), diff --git a/rustler_tests/test/codegen_test.exs b/rustler_tests/test/codegen_test.exs index b3d75f2e..722891d8 100644 --- a/rustler_tests/test/codegen_test.exs +++ b/rustler_tests/test/codegen_test.exs @@ -1,9 +1,9 @@ defmodule AddStruct do - defstruct lhs: 0, rhs: 0 + defstruct lhs: 0, rhs: 0, loc: {1, 1} end defmodule AddException do - defexception message: "" + defexception message: "", loc: {1, 1} end defmodule AddRecord do @@ -40,13 +40,13 @@ defmodule RustlerTest.CodegenTest do end describe "map" do - test "transcoder" do - value = %{lhs: 1, rhs: 2} + test "transcoder 1" do + value = %{lhs: 1, rhs: 2, loc: {52, 15}} assert value == RustlerTest.map_echo(value) end test "with invalid map" do - value = %{lhs: "invalid", rhs: 2} + value = %{lhs: "invalid", rhs: 2, loc: {57, 15}} assert_raise ErlangError, "Erlang error: \"Could not decode field :lhs on %{}\"", fn -> assert value == RustlerTest.map_echo(value) @@ -56,7 +56,7 @@ defmodule RustlerTest.CodegenTest do describe "struct" do test "transcoder" do - value = %AddStruct{lhs: 45, rhs: 123} + value = %AddStruct{lhs: 45, rhs: 123, loc: {66, 15}} assert value == RustlerTest.struct_echo(value) assert %ErlangError{original: :invalid_struct} == @@ -66,19 +66,27 @@ defmodule RustlerTest.CodegenTest do end test "with invalid struct" do - value = %AddStruct{lhs: "lhs", rhs: 123} + value = %AddStruct{lhs: "lhs", rhs: 123, loc: {76, 15}} assert_raise ErlangError, "Erlang error: \"Could not decode field :lhs on %AddStruct{}\"", fn -> RustlerTest.struct_echo(value) end + + value = %AddStruct{lhs: 45, rhs: 123, loc: {-76, -15}} + + assert_raise ErlangError, + "Erlang error: \"Could not decode field :loc on %AddStruct{}\"", + fn -> + RustlerTest.struct_echo(value) + end end end describe "exception" do test "transcoder" do - value = %AddException{message: "testing"} + value = %AddException{message: "testing", loc: {96, 15}} assert value == RustlerTest.exception_echo(value) assert %ErlangError{original: :invalid_struct} == @@ -88,13 +96,21 @@ defmodule RustlerTest.CodegenTest do end test "with invalid struct" do - value = %AddException{message: 'this is a charlist'} + value = %AddException{message: 'this is a charlist', loc: {106, 15}} assert_raise ErlangError, "Erlang error: \"Could not decode field :message on %AddException{}\"", fn -> RustlerTest.exception_echo(value) end + + value = %AddException{message: "testing", loc: %{line: 114, col: 15}} + + assert_raise ErlangError, + "Erlang error: \"Could not decode field :loc on %AddException{}\"", + fn -> + RustlerTest.exception_echo(value) + end end end @@ -279,6 +295,41 @@ defmodule RustlerTest.CodegenTest do end) end + test "tagged enum transcoder 4" do + assert :unit == RustlerTest.tagged_enum_4_echo(:unit) + + assert {:unnamed, 10_000_000_000, true} == + RustlerTest.tagged_enum_4_echo({:unnamed, 10_000_000_000, true}) + + assert {:named, %{filename: "\u2200", size: 123}} == + RustlerTest.tagged_enum_4_echo({:named, %{filename: "\u2200", size: 123}}) + + long_map = %{f0: true, f1: 8, f2: 5, f3: 12, f4: 12, f5: 15, f6: nil, f7: nil, f8: nil} + + assert {:long, long_map} == + RustlerTest.tagged_enum_4_echo({:long, long_map}) + + assert %ErlangError{original: :invalid_variant} == + assert_raise(ErlangError, fn -> + RustlerTest.tagged_enum_4_echo(:unnamed) + end) + + assert %ErlangError{original: :invalid_variant} == + assert_raise(ErlangError, fn -> + RustlerTest.tagged_enum_4_echo({:unit, 2, false}) + end) + + assert %ErlangError{original: :invalid_variant} == + assert_raise(ErlangError, fn -> + RustlerTest.tagged_enum_4_echo({:named, "@", 45}) + end) + + assert %ErlangError{original: :invalid_variant} == + assert_raise(ErlangError, fn -> + RustlerTest.tagged_enum_4_echo(nil) + end) + end + test "untagged enum transcoder" do assert 123 == RustlerTest.untagged_enum_echo(123) assert "Hello" == RustlerTest.untagged_enum_echo("Hello")