From e35d771173f8dd18d8ae2d2797d2019d15ff2f83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20B=C3=A4renz?= Date: Thu, 9 Nov 2023 15:44:01 +0100 Subject: [PATCH] Allow generic type parameters --- .../src/encode_decode_templates.rs | 115 +++++++++++++++++- rustler_tests/lib/rustler_test.ex | 2 + rustler_tests/native/rustler_test/src/lib.rs | 4 +- .../native/rustler_test/src/test_codegen.rs | 25 ++++ rustler_tests/test/codegen_test.exs | 12 ++ 5 files changed, 156 insertions(+), 2 deletions(-) diff --git a/rustler_codegen/src/encode_decode_templates.rs b/rustler_codegen/src/encode_decode_templates.rs index 33909718..7e26e5ef 100644 --- a/rustler_codegen/src/encode_decode_templates.rs +++ b/rustler_codegen/src/encode_decode_templates.rs @@ -1,5 +1,6 @@ use proc_macro2::{Span, TokenStream}; use quote::quote; +use syn::{GenericArgument, PathSegment, TraitBound}; use super::context::Context; @@ -55,6 +56,69 @@ pub(crate) fn decoder(ctx: &Context, inner: TokenStream) -> TokenStream { } } + let type_parameters: Vec<_> = generics + .params + .iter() + .filter_map(|g| match g { + syn::GenericParam::Type(t) => Some(t.clone()), + _ => None, + }) + .collect(); + + if !type_parameters.is_empty() { + let where_clause = impl_generics.make_where_clause(); + + for type_parameter in type_parameters { + let mut punctuated = syn::punctuated::Punctuated::new(); + punctuated.push(decode_lifetime.clone().into()); + punctuated.push(syn::TypeParamBound::Trait(TraitBound { + paren_token: None, + modifier: syn::TraitBoundModifier::None, + lifetimes: None, + // path: syn::Ident::new("Decoder", Span::call_site()).into(), + path: syn::Path { + leading_colon: Some(syn::token::PathSep::default()), + segments: [ + PathSegment { + ident: syn::Ident::new("rustler", Span::call_site()), + arguments: syn::PathArguments::None, + }, + PathSegment { + ident: syn::Ident::new("Decoder", Span::call_site()), + arguments: syn::PathArguments::AngleBracketed( + syn::AngleBracketedGenericArguments { + colon2_token: None, + lt_token: Default::default(), + args: std::iter::once(GenericArgument::Lifetime( + decode_lifetime.clone(), + )) + .collect(), + gt_token: Default::default(), + }, + ), + }, + ] + .iter() + .cloned() + .collect(), + }, + })); + let predicate = syn::PredicateType { + lifetimes: None, + // bounded_ty: syn::Type::Verbatim(type_parameter.ident.to_token_stream()), + bounded_ty: syn::Type::Path(syn::TypePath { + qself: None, + path: type_parameter.ident.into(), + }), + colon_token: syn::token::Colon { + spans: [Span::call_site()], + }, + bounds: punctuated, + }; + where_clause.predicates.push(predicate.into()); + } + } + let (impl_generics, _, where_clause) = impl_generics.split_for_impl(); quote! { @@ -69,7 +133,56 @@ pub(crate) fn decoder(ctx: &Context, inner: TokenStream) -> TokenStream { pub(crate) fn encoder(ctx: &Context, inner: TokenStream) -> TokenStream { let ident = ctx.ident; - let generics = ctx.generics; + let mut generics = ctx.generics.clone(); + let type_parameters: Vec<_> = generics + .params + .iter() + .filter_map(|g| match g { + syn::GenericParam::Type(t) => Some(t.clone()), + _ => None, + }) + .collect(); + + if !type_parameters.is_empty() { + let where_clause = generics.make_where_clause(); + + for type_parameter in type_parameters { + let mut punctuated = syn::punctuated::Punctuated::new(); + punctuated.push(syn::TypeParamBound::Trait(TraitBound { + paren_token: None, + modifier: syn::TraitBoundModifier::None, + lifetimes: None, + path: syn::Path { + leading_colon: Some(syn::token::PathSep::default()), + segments: [ + PathSegment { + ident: syn::Ident::new("rustler", Span::call_site()), + arguments: syn::PathArguments::None, + }, + PathSegment { + ident: syn::Ident::new("Encoder", Span::call_site()), + arguments: syn::PathArguments::None, + }, + ] + .iter() + .cloned() + .collect(), + }, + })); + let predicate = syn::PredicateType { + lifetimes: None, + bounded_ty: syn::Type::Path(syn::TypePath { + qself: None, + path: type_parameter.ident.into(), + }), + colon_token: syn::token::Colon { + spans: [Span::call_site()], + }, + bounds: punctuated, + }; + where_clause.predicates.push(predicate.into()); + } + } let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); quote! { diff --git a/rustler_tests/lib/rustler_test.ex b/rustler_tests/lib/rustler_test.ex index e6776d04..7de1d85a 100644 --- a/rustler_tests/lib/rustler_test.ex +++ b/rustler_tests/lib/rustler_test.ex @@ -105,6 +105,8 @@ defmodule RustlerTest do def newtype_record_echo(_), do: err() def tuplestruct_record_echo(_), do: err() def reserved_keywords_type_echo(_), do: err() + def generic_struct_echo(_), do: err() + def mk_generic_map(_), do: err() def dirty_io(), do: err() def dirty_cpu(), do: err() diff --git a/rustler_tests/native/rustler_test/src/lib.rs b/rustler_tests/native/rustler_test/src/lib.rs index 840fae1c..6b0e4ece 100644 --- a/rustler_tests/native/rustler_test/src/lib.rs +++ b/rustler_tests/native/rustler_test/src/lib.rs @@ -95,7 +95,9 @@ rustler::init!( test_tuple::maybe_add_one_to_tuple, test_tuple::add_i32_from_tuple, test_tuple::greeting_person_from_tuple, - test_codegen::reserved_keywords::reserved_keywords_type_echo + test_codegen::reserved_keywords::reserved_keywords_type_echo, + test_codegen::generic_types::generic_struct_echo, + test_codegen::generic_types::mk_generic_map, ], load = load ); diff --git a/rustler_tests/native/rustler_test/src/test_codegen.rs b/rustler_tests/native/rustler_test/src/test_codegen.rs index d7671b68..19c58c64 100644 --- a/rustler_tests/native/rustler_test/src/test_codegen.rs +++ b/rustler_tests/native/rustler_test/src/test_codegen.rs @@ -262,3 +262,28 @@ pub mod reserved_keywords { reserved } } + +pub mod generic_types { + use rustler::{NifMap, NifStruct}; + #[derive(NifStruct)] + #[module = "GenericStruct"] + pub struct GenericStruct { + t: T, + } + + #[rustler::nif] + pub fn generic_struct_echo(value: GenericStruct) -> GenericStruct { + value + } + + #[derive(NifMap)] + pub struct GenericMap { + a: T, + b: T, + } + + #[rustler::nif] + pub fn mk_generic_map(value: &str) -> GenericMap<&str> { + GenericMap { a: value, b: value } + } +} diff --git a/rustler_tests/test/codegen_test.exs b/rustler_tests/test/codegen_test.exs index 1ea7fef9..069f8471 100644 --- a/rustler_tests/test/codegen_test.exs +++ b/rustler_tests/test/codegen_test.exs @@ -434,4 +434,16 @@ defmodule RustlerTest.CodegenTest do assert {1} == RustlerTest.reserved_keywords_type_echo({1}) assert {:record, 1} == RustlerTest.reserved_keywords_type_echo({:record, 1}) end + + describe "generic types" do + test "generic struct" do + assert %{__struct__: GenericStruct, t: 1} == + RustlerTest.generic_struct_echo(%{__struct__: GenericStruct, t: 1}) + end + + test "generic map" do + assert %{a: "hello", b: "hello"} == + RustlerTest.mk_generic_map("hello") + end + end end