Skip to content

Commit

Permalink
Allow generic type parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
Manuel Bärenz committed Nov 9, 2023
1 parent fba9872 commit e35d771
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 2 deletions.
115 changes: 114 additions & 1 deletion rustler_codegen/src/encode_decode_templates.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{GenericArgument, PathSegment, TraitBound};

use super::context::Context;

Expand Down Expand Up @@ -55,6 +56,69 @@ pub(crate) fn decoder(ctx: &Context, inner: TokenStream) -> TokenStream {
}
}

let type_parameters: Vec<_> = generics
.params
.iter()
.filter_map(|g| match g {
syn::GenericParam::Type(t) => Some(t.clone()),
_ => None,
})
.collect();

if !type_parameters.is_empty() {
let where_clause = impl_generics.make_where_clause();

for type_parameter in type_parameters {
let mut punctuated = syn::punctuated::Punctuated::new();
punctuated.push(decode_lifetime.clone().into());
punctuated.push(syn::TypeParamBound::Trait(TraitBound {
paren_token: None,
modifier: syn::TraitBoundModifier::None,
lifetimes: None,
// path: syn::Ident::new("Decoder", Span::call_site()).into(),
path: syn::Path {
leading_colon: Some(syn::token::PathSep::default()),
segments: [
PathSegment {
ident: syn::Ident::new("rustler", Span::call_site()),
arguments: syn::PathArguments::None,
},
PathSegment {
ident: syn::Ident::new("Decoder", Span::call_site()),
arguments: syn::PathArguments::AngleBracketed(
syn::AngleBracketedGenericArguments {
colon2_token: None,
lt_token: Default::default(),
args: std::iter::once(GenericArgument::Lifetime(
decode_lifetime.clone(),
))
.collect(),
gt_token: Default::default(),
},
),
},
]
.iter()
.cloned()
.collect(),
},
}));
let predicate = syn::PredicateType {
lifetimes: None,
// bounded_ty: syn::Type::Verbatim(type_parameter.ident.to_token_stream()),
bounded_ty: syn::Type::Path(syn::TypePath {
qself: None,
path: type_parameter.ident.into(),
}),
colon_token: syn::token::Colon {
spans: [Span::call_site()],
},
bounds: punctuated,
};
where_clause.predicates.push(predicate.into());
}
}

let (impl_generics, _, where_clause) = impl_generics.split_for_impl();

quote! {
Expand All @@ -69,7 +133,56 @@ pub(crate) fn decoder(ctx: &Context, inner: TokenStream) -> TokenStream {

pub(crate) fn encoder(ctx: &Context, inner: TokenStream) -> TokenStream {
let ident = ctx.ident;
let generics = ctx.generics;
let mut generics = ctx.generics.clone();
let type_parameters: Vec<_> = generics
.params
.iter()
.filter_map(|g| match g {
syn::GenericParam::Type(t) => Some(t.clone()),
_ => None,
})
.collect();

if !type_parameters.is_empty() {
let where_clause = generics.make_where_clause();

for type_parameter in type_parameters {
let mut punctuated = syn::punctuated::Punctuated::new();
punctuated.push(syn::TypeParamBound::Trait(TraitBound {
paren_token: None,
modifier: syn::TraitBoundModifier::None,
lifetimes: None,
path: syn::Path {
leading_colon: Some(syn::token::PathSep::default()),
segments: [
PathSegment {
ident: syn::Ident::new("rustler", Span::call_site()),
arguments: syn::PathArguments::None,
},
PathSegment {
ident: syn::Ident::new("Encoder", Span::call_site()),
arguments: syn::PathArguments::None,
},
]
.iter()
.cloned()
.collect(),
},
}));
let predicate = syn::PredicateType {
lifetimes: None,
bounded_ty: syn::Type::Path(syn::TypePath {
qself: None,
path: type_parameter.ident.into(),
}),
colon_token: syn::token::Colon {
spans: [Span::call_site()],
},
bounds: punctuated,
};
where_clause.predicates.push(predicate.into());
}
}
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

quote! {
Expand Down
2 changes: 2 additions & 0 deletions rustler_tests/lib/rustler_test.ex
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ defmodule RustlerTest do
def newtype_record_echo(_), do: err()
def tuplestruct_record_echo(_), do: err()
def reserved_keywords_type_echo(_), do: err()
def generic_struct_echo(_), do: err()
def mk_generic_map(_), do: err()

def dirty_io(), do: err()
def dirty_cpu(), do: err()
Expand Down
4 changes: 3 additions & 1 deletion rustler_tests/native/rustler_test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ rustler::init!(
test_tuple::maybe_add_one_to_tuple,
test_tuple::add_i32_from_tuple,
test_tuple::greeting_person_from_tuple,
test_codegen::reserved_keywords::reserved_keywords_type_echo
test_codegen::reserved_keywords::reserved_keywords_type_echo,
test_codegen::generic_types::generic_struct_echo,
test_codegen::generic_types::mk_generic_map,
],
load = load
);
Expand Down
25 changes: 25 additions & 0 deletions rustler_tests/native/rustler_test/src/test_codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,28 @@ pub mod reserved_keywords {
reserved
}
}

pub mod generic_types {
use rustler::{NifMap, NifStruct};
#[derive(NifStruct)]
#[module = "GenericStruct"]
pub struct GenericStruct<T> {
t: T,
}

#[rustler::nif]
pub fn generic_struct_echo(value: GenericStruct<i32>) -> GenericStruct<i32> {
value
}

#[derive(NifMap)]
pub struct GenericMap<T> {
a: T,
b: T,
}

#[rustler::nif]
pub fn mk_generic_map(value: &str) -> GenericMap<&str> {
GenericMap { a: value, b: value }
}
}
12 changes: 12 additions & 0 deletions rustler_tests/test/codegen_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -434,4 +434,16 @@ defmodule RustlerTest.CodegenTest do
assert {1} == RustlerTest.reserved_keywords_type_echo({1})
assert {:record, 1} == RustlerTest.reserved_keywords_type_echo({:record, 1})
end

describe "generic types" do
test "generic struct" do
assert %{__struct__: GenericStruct, t: 1} ==
RustlerTest.generic_struct_echo(%{__struct__: GenericStruct, t: 1})
end

test "generic map" do
assert %{a: "hello", b: "hello"} ==
RustlerTest.mk_generic_map("hello")
end
end
end

0 comments on commit e35d771

Please sign in to comment.