diff --git a/borsh-derive-internal/Cargo.toml b/borsh-derive-internal/Cargo.toml index 6310b831d..0801037a4 100644 --- a/borsh-derive-internal/Cargo.toml +++ b/borsh-derive-internal/Cargo.toml @@ -17,6 +17,7 @@ Binary Object Representation Serializer for Hashing proc-macro2 = "1" syn = { version = "2", features = ["full", "fold"] } quote = "1" +syn_derive = "0.1.6" [dev-dependencies] syn = { version = "2", features = ["full", "fold", "parsing"] } diff --git a/borsh-derive-internal/src/attribute_helpers.rs b/borsh-derive-internal/src/attribute_helpers.rs index 948b3e3fd..77967d7a3 100644 --- a/borsh-derive-internal/src/attribute_helpers.rs +++ b/borsh-derive-internal/src/attribute_helpers.rs @@ -4,21 +4,27 @@ use syn::{Attribute, Field, Path, WherePredicate}; pub mod parsing_helpers; use parsing_helpers::get_where_predicates; +use self::parsing_helpers::{get_schema_attrs, SchemaParamsOverride}; + #[derive(Copy, Clone)] pub struct Symbol(pub &'static str); -/// top level prefix in nested meta attribute +/// borsh - top level prefix in nested meta attribute pub const BORSH: Symbol = Symbol("borsh"); -/// field-level only attribute, `BorshSerialize` and `BorshDeserialize` contexts +/// bound - sub-borsh nested meta, field-level only, `BorshSerialize` and `BorshDeserialize` contexts pub const BOUND: Symbol = Symbol("bound"); -/// sub-BOUND nested meta attribute +/// serialize - sub-bound nested meta attribute pub const SERIALIZE: Symbol = Symbol("serialize"); -/// sub-BOUND nested meta attribute +/// deserialize - sub-bound nested meta attribute pub const DESERIALIZE: Symbol = Symbol("deserialize"); -/// field-level only attribute, `BorshSerialize`, `BorshDeserialize`, `BorshSchema` contexts +/// borsh_skip - field-level only attribute, `BorshSerialize`, `BorshDeserialize`, `BorshSchema` contexts pub const SKIP: Symbol = Symbol("borsh_skip"); -/// item-level only attribute `BorshDeserialize` context +/// borsh_init - item-level only attribute `BorshDeserialize` context pub const INIT: Symbol = Symbol("borsh_init"); +/// schema - sub-borsh nested meta, `BorshSchema` context +pub const SCHEMA: Symbol = Symbol("schema"); +/// params - sub-schema nested meta, field-level only attribute +pub const PARAMS: Symbol = Symbol("params"); impl PartialEq for Path { fn eq(&self, word: &Symbol) -> bool { @@ -32,11 +38,11 @@ impl<'a> PartialEq for &'a Path { } } -pub fn contains_skip(attrs: &[Attribute]) -> bool { +pub(crate) fn contains_skip(attrs: &[Attribute]) -> bool { attrs.iter().any(|attr| attr.path() == SKIP) } -pub fn contains_initialize_with(attrs: &[Attribute]) -> Option { +pub(crate) fn contains_initialize_with(attrs: &[Attribute]) -> Option { for attr in attrs.iter() { if attr.path() == INIT { let mut res = None; @@ -51,9 +57,10 @@ pub fn contains_initialize_with(attrs: &[Attribute]) -> Option { None } -type Bounds = Option>; +pub(crate) type Bounds = Option>; +pub(crate) type SchemaParams = Option>; -pub fn parse_bounds(attrs: &[Attribute]) -> Result<(Bounds, Bounds), syn::Error> { +fn parse_bounds(attrs: &[Attribute]) -> Result<(Bounds, Bounds), syn::Error> { let (mut ser, mut de): (Bounds, Bounds) = (None, None); for attr in attrs { if attr.path() != BORSH { @@ -75,12 +82,32 @@ pub fn parse_bounds(attrs: &[Attribute]) -> Result<(Bounds, Bounds), syn::Error> Ok((ser, de)) } -pub enum BoundType { +pub(crate) fn parse_schema_attrs(attrs: &[Attribute]) -> Result { + let mut params: SchemaParams = None; + for attr in attrs { + if attr.path() != BORSH { + continue; + } + + attr.parse_nested_meta(|meta| { + if meta.path == SCHEMA { + // #[borsh(schema(params = "..."))] + + let params_parsed = get_schema_attrs(&meta)?; + params = params_parsed; + } + Ok(()) + })?; + } + + Ok(params) +} +pub(crate) enum BoundType { Serialize, Deserialize, } -pub fn get_bounds(field: &Field, ty: BoundType) -> Result { +pub(crate) fn get_bounds(field: &Field, ty: BoundType) -> Result { let (ser, de) = parse_bounds(&field.attrs)?; match ty { BoundType::Serialize => Ok(ser), @@ -88,7 +115,7 @@ pub fn get_bounds(field: &Field, ty: BoundType) -> Result { } } -pub fn collect_override_bounds( +pub(crate) fn collect_override_bounds( field: &Field, ty: BoundType, output: &mut Vec, @@ -109,8 +136,10 @@ mod tests { use std::fmt::Write; use syn::ItemStruct; + use crate::attribute_helpers::parse_schema_attrs; + use super::{parse_bounds, Bounds}; - fn debug_print_bounds(bounds: Bounds) -> String { + fn debug_print_bounds(bounds: Option>) -> String { let mut s = String::new(); if let Some(bounds) = bounds { for bound in bounds { @@ -180,7 +209,7 @@ mod tests { let first_field = &item_struct.fields.into_iter().collect::>()[0]; let (ser, de) = parse_bounds(&first_field.attrs).unwrap(); - insta::assert_snapshot!(debug_print_bounds(ser)); + assert_eq!(ser.unwrap().len(), 0); insta::assert_snapshot!(debug_print_bounds(de)); } @@ -197,7 +226,7 @@ mod tests { let first_field = &item_struct.fields.into_iter().collect::>()[0]; let (ser, de) = parse_bounds(&first_field.attrs).unwrap(); - insta::assert_snapshot!(debug_print_bounds(ser)); + assert!(ser.is_none()); insta::assert_snapshot!(debug_print_bounds(de)); } @@ -257,4 +286,105 @@ mod tests { }; insta::assert_debug_snapshot!(err); } + + #[test] + fn test_schema_params_parsing1() { + let item_struct: ItemStruct = syn::parse2(quote! { + struct Parametrized + where + T: TraitName, + { + #[borsh(schema(params = + "T => ::Associated" + ))] + field: ::Associated, + another: V, + } + }) + .unwrap(); + + let first_field = &item_struct.fields.into_iter().collect::>()[0]; + let schema_params = parse_schema_attrs(&first_field.attrs).unwrap(); + insta::assert_snapshot!(debug_print_bounds(schema_params)); + } + #[test] + fn test_schema_params_parsing_error() { + let item_struct: ItemStruct = syn::parse2(quote! { + struct Parametrized + where + T: TraitName, + { + #[borsh(schema(params = + "T => ::Associated" + ))] + field: ::Associated, + another: V, + } + }) + .unwrap(); + + let first_field = &item_struct.fields.into_iter().collect::>()[0]; + let err = match parse_schema_attrs(&first_field.attrs) { + Ok(..) => unreachable!("expecting error here"), + Err(err) => err, + }; + insta::assert_debug_snapshot!(err); + } + + #[test] + fn test_schema_params_parsing2() { + let item_struct: ItemStruct = syn::parse2(quote! { + struct Parametrized + where + T: TraitName, + { + #[borsh(schema(params = + "T => ::Associated, V => Vec" + ))] + field: ::Associated, + another: V, + } + }) + .unwrap(); + + let first_field = &item_struct.fields.into_iter().collect::>()[0]; + let schema_params = parse_schema_attrs(&first_field.attrs).unwrap(); + insta::assert_snapshot!(debug_print_bounds(schema_params)); + } + #[test] + fn test_schema_params_parsing3() { + let item_struct: ItemStruct = syn::parse2(quote! { + struct Parametrized + where + T: TraitName, + { + #[borsh(schema(params = "" ))] + field: ::Associated, + another: V, + } + }) + .unwrap(); + + let first_field = &item_struct.fields.into_iter().collect::>()[0]; + let schema_params = parse_schema_attrs(&first_field.attrs).unwrap(); + assert_eq!(schema_params.unwrap().len(), 0); + } + + #[test] + fn test_schema_params_parsing4() { + let item_struct: ItemStruct = syn::parse2(quote! { + struct Parametrized + where + T: TraitName, + { + field: ::Associated, + another: V, + } + }) + .unwrap(); + + let first_field = &item_struct.fields.into_iter().collect::>()[0]; + let schema_params = parse_schema_attrs(&first_field.attrs).unwrap(); + assert!(schema_params.is_none()); + } } diff --git a/borsh-derive-internal/src/attribute_helpers/parsing_helpers.rs b/borsh-derive-internal/src/attribute_helpers/parsing_helpers.rs index fa955686f..abd223903 100644 --- a/borsh-derive-internal/src/attribute_helpers/parsing_helpers.rs +++ b/borsh-derive-internal/src/attribute_helpers/parsing_helpers.rs @@ -2,9 +2,12 @@ #![allow(unused)] use std::iter::FromIterator; -use syn::{meta::ParseNestedMeta, punctuated::Punctuated, token::Paren, Expr, Lit, LitStr, Token}; +use syn::{ + meta::ParseNestedMeta, punctuated::Punctuated, token::Paren, Expr, Ident, Lit, LitStr, Token, + Type, WherePredicate, +}; -use super::{Bounds, Symbol, BOUND, DESERIALIZE, SERIALIZE}; +use super::{Bounds, SchemaParams, Symbol, BOUND, DESERIALIZE, PARAMS, SCHEMA, SERIALIZE}; fn get_lit_str2( attr_name: Symbol, meta_item_name: Symbol, @@ -31,22 +34,30 @@ fn get_lit_str2( } } -fn parse_lit_into_where( +fn parse_lit_into( attr_name: Symbol, meta_item_name: Symbol, meta: &ParseNestedMeta, -) -> syn::Result> { +) -> syn::Result> { let string = match get_lit_str2(attr_name, meta_item_name, meta)? { Some(string) => string, None => return Ok(Vec::new()), }; - match string.parse_with(Punctuated::::parse_terminated) { - Ok(predicates) => Ok(Vec::from_iter(predicates)), + match string.parse_with(Punctuated::::parse_terminated) { + Ok(elements) => Ok(Vec::from_iter(elements)), Err(err) => Err(syn::Error::new_spanned(string, err)), } } +/// struct describes entries like `K => ::Associated` +#[derive(Clone, syn_derive::Parse, syn_derive::ToTokens)] +pub(crate) struct SchemaParamsOverride { + pub order_param: Ident, + arrow_token: Token![=>], + pub override_type: Type, +} + fn get_ser_and_de( attr_name: Symbol, meta: &ParseNestedMeta, @@ -85,6 +96,45 @@ where Ok((ser_meta, de_meta)) } -pub fn get_where_predicates(meta: &ParseNestedMeta) -> syn::Result<(Bounds, Bounds)> { - get_ser_and_de(BOUND, meta, parse_lit_into_where) + +fn get_schema_nested_meta( + attr_name: Symbol, + meta: &ParseNestedMeta, + f: F, +) -> syn::Result> +where + T: Clone, + F: Fn(Symbol, Symbol, &ParseNestedMeta) -> syn::Result, + R: Into>, +{ + let mut params: Option = None; + + let lookahead = meta.input.lookahead1(); + if lookahead.peek(Paren) { + meta.parse_nested_meta(|meta| { + if meta.path == PARAMS { + if let Some(v) = f(attr_name, PARAMS, &meta)?.into() { + params = Some(v); + } + } else { + return Err(meta.error(format_args!( + "malformed {0} attribute, expected `{0}(params = ...)`", + attr_name.0, + ))); + } + Ok(()) + })?; + } else { + return Err(lookahead.error()); + } + + Ok(params) +} + +pub(crate) fn get_where_predicates(meta: &ParseNestedMeta) -> syn::Result<(Bounds, Bounds)> { + get_ser_and_de(BOUND, meta, parse_lit_into::) +} + +pub(crate) fn get_schema_attrs(meta: &ParseNestedMeta) -> syn::Result { + get_schema_nested_meta(SCHEMA, meta, parse_lit_into::) } diff --git a/borsh-derive-internal/src/generics.rs b/borsh-derive-internal/src/generics.rs index 5c602cfde..27fd8c147 100644 --- a/borsh-derive-internal/src/generics.rs +++ b/borsh-derive-internal/src/generics.rs @@ -2,7 +2,7 @@ #![allow(unused)] use std::collections::{HashMap, HashSet}; -use quote::quote; +use quote::{quote, ToTokens}; use syn::{ punctuated::Pair, Field, GenericArgument, Generics, Ident, Macro, Path, PathArguments, PathSegment, ReturnType, Type, TypeParamBound, TypePath, WherePredicate, @@ -41,6 +41,14 @@ pub fn without_defaults(generics: &Generics) -> Generics { } } +pub fn type_contains_some_param(type_: &Type, params: &HashSet) -> bool { + let mut find: FindTyParams = FindTyParams::from_params(params.iter()); + + find.visit_type_top_level(type_); + + find.at_least_one_hit() +} + /// a Visitor-like struct, which helps determine, if a type parameter is found in field #[derive(Clone)] pub struct FindTyParams { @@ -53,7 +61,7 @@ pub struct FindTyParams { relevant_type_params: HashSet, // [Param] => [Type, containing Param] mapping - associated_type_params_usage: HashMap, + associated_type_params_usage: HashMap>, } fn ungroup(mut ty: &Type) -> &Type { @@ -82,10 +90,21 @@ impl FindTyParams { associated_type_params_usage: HashMap::new(), } } + pub fn from_params<'a>(params: impl Iterator) -> Self { + let all_type_params_ordered: Vec = params.cloned().collect(); + let all_type_params = all_type_params_ordered.clone().into_iter().collect(); + FindTyParams { + all_type_params, + all_type_params_ordered, + relevant_type_params: HashSet::new(), + associated_type_params_usage: HashMap::new(), + } + } pub fn process_for_bounds(self) -> Vec { let relevant_type_params = self.relevant_type_params; let associated_type_params_usage = self.associated_type_params_usage; let mut new_predicates: Vec = vec![]; + let mut new_predicates_set: HashSet = HashSet::new(); self.all_type_params_ordered.iter().for_each(|param| { if relevant_type_params.contains(param) { @@ -93,10 +112,20 @@ impl FindTyParams { qself: None, path: param.clone().into(), }); - new_predicates.push(ty); + let ty_str_repr = ty.to_token_stream().to_string(); + if !new_predicates_set.contains(&ty_str_repr) { + new_predicates.push(ty); + new_predicates_set.insert(ty_str_repr); + } } - if let Some(type_) = associated_type_params_usage.get(param) { - new_predicates.push(type_.clone()); + if let Some(vec_type) = associated_type_params_usage.get(param) { + for type_ in vec_type { + let ty_str_repr = type_.to_token_stream().to_string(); + if !new_predicates_set.contains(&ty_str_repr) { + new_predicates.push(type_.clone()); + new_predicates_set.insert(ty_str_repr); + } + } } }); @@ -107,33 +136,47 @@ impl FindTyParams { let associated_type_params_usage = self.associated_type_params_usage; let mut params: Vec = vec![]; + let mut params_set: HashSet = HashSet::new(); self.all_type_params_ordered.iter().for_each(|param| { - if relevant_type_params.contains(param) { + if relevant_type_params.contains(param) && !params_set.contains(param) { params.push(param.clone()); + params_set.insert(param.clone()); } - if associated_type_params_usage.get(param).is_some() { + if associated_type_params_usage.get(param).is_some() && !params_set.contains(param){ params.push(param.clone()); + params_set.insert(param.clone()); } }); params } + pub fn at_least_one_hit(&self) -> bool { + !self.relevant_type_params.is_empty() || !self.associated_type_params_usage.is_empty() + } } impl FindTyParams { pub fn visit_field(&mut self, field: &Field) { - if let Type::Path(ty) = ungroup(&field.ty) { + self.visit_type_top_level(&field.ty); + } + + pub fn visit_type_top_level(&mut self, type_: &Type) { + if let Type::Path(ty) = ungroup(type_) { if let Some(Pair::Punctuated(t, _)) = ty.path.segments.pairs().next() { if self.all_type_params.contains(&t.ident) { - self.associated_type_params_usage - .insert(t.ident.clone(), field.ty.clone()); + self.param_associated_type_insert(t.ident.clone(), type_.clone()); } } } - self.visit_type(&field.ty); + self.visit_type(type_); } - pub fn insert_type(&mut self, param: Ident, type_: Type) { - self.associated_type_params_usage.insert(param, type_); + pub fn param_associated_type_insert(&mut self, param: Ident, type_: Type) { + if let Some(type_vec) = self.associated_type_params_usage.get_mut(¶m) { + type_vec.push(type_); + } else { + let type_vec = vec![type_]; + self.associated_type_params_usage.insert(param, type_vec); + } } fn visit_return_type(&mut self, return_type: &ReturnType) { diff --git a/borsh-derive-internal/src/snapshots/borsh_derive_internal__attribute_helpers__tests__bounds_parsing3-2.snap b/borsh-derive-internal/src/snapshots/borsh_derive_internal__attribute_helpers__tests__bounds_parsing3-2.snap deleted file mode 100644 index b8638b75a..000000000 --- a/borsh-derive-internal/src/snapshots/borsh_derive_internal__attribute_helpers__tests__bounds_parsing3-2.snap +++ /dev/null @@ -1,7 +0,0 @@ ---- -source: borsh-derive-internal/src/attribute_helpers.rs -expression: debug_print_bounds(de) ---- -K : Hash + Eq + borsh :: de :: BorshDeserialize -V : borsh :: de :: BorshDeserialize - diff --git a/borsh-derive-internal/src/snapshots/borsh_derive_internal__attribute_helpers__tests__bounds_parsing3.snap b/borsh-derive-internal/src/snapshots/borsh_derive_internal__attribute_helpers__tests__bounds_parsing3.snap index fdda2b867..b8638b75a 100644 --- a/borsh-derive-internal/src/snapshots/borsh_derive_internal__attribute_helpers__tests__bounds_parsing3.snap +++ b/borsh-derive-internal/src/snapshots/borsh_derive_internal__attribute_helpers__tests__bounds_parsing3.snap @@ -1,5 +1,7 @@ --- source: borsh-derive-internal/src/attribute_helpers.rs -expression: debug_print_bounds(ser) +expression: debug_print_bounds(de) --- +K : Hash + Eq + borsh :: de :: BorshDeserialize +V : borsh :: de :: BorshDeserialize diff --git a/borsh-derive-internal/src/snapshots/borsh_derive_internal__attribute_helpers__tests__bounds_parsing4.snap b/borsh-derive-internal/src/snapshots/borsh_derive_internal__attribute_helpers__tests__bounds_parsing4.snap index ce64f72be..b297807c3 100644 --- a/borsh-derive-internal/src/snapshots/borsh_derive_internal__attribute_helpers__tests__bounds_parsing4.snap +++ b/borsh-derive-internal/src/snapshots/borsh_derive_internal__attribute_helpers__tests__bounds_parsing4.snap @@ -1,5 +1,6 @@ --- source: borsh-derive-internal/src/attribute_helpers.rs -expression: debug_print_bounds(ser) +expression: debug_print_bounds(de) --- -None +K : Hash + diff --git a/borsh-derive-internal/src/snapshots/borsh_derive_internal__attribute_helpers__tests__schema_params_parsing1.snap b/borsh-derive-internal/src/snapshots/borsh_derive_internal__attribute_helpers__tests__schema_params_parsing1.snap new file mode 100644 index 000000000..073bde496 --- /dev/null +++ b/borsh-derive-internal/src/snapshots/borsh_derive_internal__attribute_helpers__tests__schema_params_parsing1.snap @@ -0,0 +1,6 @@ +--- +source: borsh-derive-internal/src/attribute_helpers.rs +expression: debug_print_bounds(schema_params) +--- +T => < T as TraitName > :: Associated + diff --git a/borsh-derive-internal/src/snapshots/borsh_derive_internal__attribute_helpers__tests__schema_params_parsing2.snap b/borsh-derive-internal/src/snapshots/borsh_derive_internal__attribute_helpers__tests__schema_params_parsing2.snap new file mode 100644 index 000000000..6fdd2ab62 --- /dev/null +++ b/borsh-derive-internal/src/snapshots/borsh_derive_internal__attribute_helpers__tests__schema_params_parsing2.snap @@ -0,0 +1,7 @@ +--- +source: borsh-derive-internal/src/attribute_helpers.rs +expression: debug_print_bounds(schema_params) +--- +T => < T as TraitName > :: Associated +V => Vec < V > + diff --git a/borsh-derive-internal/src/snapshots/borsh_derive_internal__attribute_helpers__tests__bounds_parsing4-2.snap b/borsh-derive-internal/src/snapshots/borsh_derive_internal__attribute_helpers__tests__schema_params_parsing_error.snap similarity index 58% rename from borsh-derive-internal/src/snapshots/borsh_derive_internal__attribute_helpers__tests__bounds_parsing4-2.snap rename to borsh-derive-internal/src/snapshots/borsh_derive_internal__attribute_helpers__tests__schema_params_parsing_error.snap index b297807c3..b05782256 100644 --- a/borsh-derive-internal/src/snapshots/borsh_derive_internal__attribute_helpers__tests__bounds_parsing4-2.snap +++ b/borsh-derive-internal/src/snapshots/borsh_derive_internal__attribute_helpers__tests__schema_params_parsing_error.snap @@ -1,6 +1,7 @@ --- source: borsh-derive-internal/src/attribute_helpers.rs -expression: debug_print_bounds(de) +expression: err --- -K : Hash - +Error( + "expected `>`", +) diff --git a/borsh-schema-derive-internal/Cargo.toml b/borsh-schema-derive-internal/Cargo.toml index 46365955b..ce7055612 100644 --- a/borsh-schema-derive-internal/Cargo.toml +++ b/borsh-schema-derive-internal/Cargo.toml @@ -16,6 +16,7 @@ Schema Generator for Borsh proc-macro2 = "1" syn = { version = "2", features = ["full", "fold"] } quote = "1" +syn_derive = "0.1.6" [dev-dependencies] syn = { version = "2", features = ["full", "fold", "parsing"] } diff --git a/borsh-schema-derive-internal/src/attribute_helpers.rs b/borsh-schema-derive-internal/src/attribute_helpers.rs index 976372149..b2eb8f3e4 100644 --- a/borsh-schema-derive-internal/src/attribute_helpers.rs +++ b/borsh-schema-derive-internal/src/attribute_helpers.rs @@ -4,23 +4,27 @@ use syn::{Attribute, Field, Path, WherePredicate}; pub mod parsing_helpers; use parsing_helpers::get_where_predicates; +use self::parsing_helpers::{get_schema_attrs, SchemaParamsOverride}; + #[derive(Copy, Clone)] pub struct Symbol(pub &'static str); -/// top level prefix in nested meta attribute +/// borsh - top level prefix in nested meta attribute pub const BORSH: Symbol = Symbol("borsh"); -/// sub-BORSH nested meta, field-level only attribute, `BorshSerialize` and `BorshDeserialize` contexts +/// bound - sub-borsh nested meta, field-level only, `BorshSerialize` and `BorshDeserialize` contexts pub const BOUND: Symbol = Symbol("bound"); -/// sub-BOUND nested meta attribute +/// serialize - sub-bound nested meta attribute pub const SERIALIZE: Symbol = Symbol("serialize"); -/// sub-BOUND nested meta attribute +/// deserialize - sub-bound nested meta attribute pub const DESERIALIZE: Symbol = Symbol("deserialize"); -/// field-level only attribute, `BorshSerialize`, `BorshDeserialize`, `BorshSchema` contexts +/// borsh_skip - field-level only attribute, `BorshSerialize`, `BorshDeserialize`, `BorshSchema` contexts pub const SKIP: Symbol = Symbol("borsh_skip"); -/// item-level only attribute `BorshDeserialize` context +/// borsh_init - item-level only attribute `BorshDeserialize` context pub const INIT: Symbol = Symbol("borsh_init"); -/// sub-BORSH nested meta, field-level only attribute, `BorshSchema` context -pub const SCHEMA_PARAMS: Symbol = Symbol("schema_params"); +/// schema - sub-borsh nested meta, `BorshSchema` context +pub const SCHEMA: Symbol = Symbol("schema"); +/// params - sub-schema nested meta, field-level only attribute +pub const PARAMS: Symbol = Symbol("params"); impl PartialEq for Path { fn eq(&self, word: &Symbol) -> bool { @@ -34,11 +38,11 @@ impl<'a> PartialEq for &'a Path { } } -pub fn contains_skip(attrs: &[Attribute]) -> bool { +pub(crate) fn contains_skip(attrs: &[Attribute]) -> bool { attrs.iter().any(|attr| attr.path() == SKIP) } -pub fn contains_initialize_with(attrs: &[Attribute]) -> Option { +pub(crate) fn contains_initialize_with(attrs: &[Attribute]) -> Option { for attr in attrs.iter() { if attr.path() == INIT { let mut res = None; @@ -53,9 +57,10 @@ pub fn contains_initialize_with(attrs: &[Attribute]) -> Option { None } -type Bounds = Option>; +pub(crate) type Bounds = Option>; +pub(crate) type SchemaParams = Option>; -pub fn parse_bounds(attrs: &[Attribute]) -> Result<(Bounds, Bounds), syn::Error> { +fn parse_bounds(attrs: &[Attribute]) -> Result<(Bounds, Bounds), syn::Error> { let (mut ser, mut de): (Bounds, Bounds) = (None, None); for attr in attrs { if attr.path() != BORSH { @@ -77,12 +82,32 @@ pub fn parse_bounds(attrs: &[Attribute]) -> Result<(Bounds, Bounds), syn::Error> Ok((ser, de)) } -pub enum BoundType { +pub(crate) fn parse_schema_attrs(attrs: &[Attribute]) -> Result { + let mut params: SchemaParams = None; + for attr in attrs { + if attr.path() != BORSH { + continue; + } + + attr.parse_nested_meta(|meta| { + if meta.path == SCHEMA { + // #[borsh(schema(params = "..."))] + + let params_parsed = get_schema_attrs(&meta)?; + params = params_parsed; + } + Ok(()) + })?; + } + + Ok(params) +} +pub(crate) enum BoundType { Serialize, Deserialize, } -pub fn get_bounds(field: &Field, ty: BoundType) -> Result { +pub(crate) fn get_bounds(field: &Field, ty: BoundType) -> Result { let (ser, de) = parse_bounds(&field.attrs)?; match ty { BoundType::Serialize => Ok(ser), @@ -90,7 +115,7 @@ pub fn get_bounds(field: &Field, ty: BoundType) -> Result { } } -pub fn collect_override_bounds( +pub(crate) fn collect_override_bounds( field: &Field, ty: BoundType, output: &mut Vec, diff --git a/borsh-schema-derive-internal/src/attribute_helpers/parsing_helpers.rs b/borsh-schema-derive-internal/src/attribute_helpers/parsing_helpers.rs index fa955686f..abd223903 100644 --- a/borsh-schema-derive-internal/src/attribute_helpers/parsing_helpers.rs +++ b/borsh-schema-derive-internal/src/attribute_helpers/parsing_helpers.rs @@ -2,9 +2,12 @@ #![allow(unused)] use std::iter::FromIterator; -use syn::{meta::ParseNestedMeta, punctuated::Punctuated, token::Paren, Expr, Lit, LitStr, Token}; +use syn::{ + meta::ParseNestedMeta, punctuated::Punctuated, token::Paren, Expr, Ident, Lit, LitStr, Token, + Type, WherePredicate, +}; -use super::{Bounds, Symbol, BOUND, DESERIALIZE, SERIALIZE}; +use super::{Bounds, SchemaParams, Symbol, BOUND, DESERIALIZE, PARAMS, SCHEMA, SERIALIZE}; fn get_lit_str2( attr_name: Symbol, meta_item_name: Symbol, @@ -31,22 +34,30 @@ fn get_lit_str2( } } -fn parse_lit_into_where( +fn parse_lit_into( attr_name: Symbol, meta_item_name: Symbol, meta: &ParseNestedMeta, -) -> syn::Result> { +) -> syn::Result> { let string = match get_lit_str2(attr_name, meta_item_name, meta)? { Some(string) => string, None => return Ok(Vec::new()), }; - match string.parse_with(Punctuated::::parse_terminated) { - Ok(predicates) => Ok(Vec::from_iter(predicates)), + match string.parse_with(Punctuated::::parse_terminated) { + Ok(elements) => Ok(Vec::from_iter(elements)), Err(err) => Err(syn::Error::new_spanned(string, err)), } } +/// struct describes entries like `K => ::Associated` +#[derive(Clone, syn_derive::Parse, syn_derive::ToTokens)] +pub(crate) struct SchemaParamsOverride { + pub order_param: Ident, + arrow_token: Token![=>], + pub override_type: Type, +} + fn get_ser_and_de( attr_name: Symbol, meta: &ParseNestedMeta, @@ -85,6 +96,45 @@ where Ok((ser_meta, de_meta)) } -pub fn get_where_predicates(meta: &ParseNestedMeta) -> syn::Result<(Bounds, Bounds)> { - get_ser_and_de(BOUND, meta, parse_lit_into_where) + +fn get_schema_nested_meta( + attr_name: Symbol, + meta: &ParseNestedMeta, + f: F, +) -> syn::Result> +where + T: Clone, + F: Fn(Symbol, Symbol, &ParseNestedMeta) -> syn::Result, + R: Into>, +{ + let mut params: Option = None; + + let lookahead = meta.input.lookahead1(); + if lookahead.peek(Paren) { + meta.parse_nested_meta(|meta| { + if meta.path == PARAMS { + if let Some(v) = f(attr_name, PARAMS, &meta)?.into() { + params = Some(v); + } + } else { + return Err(meta.error(format_args!( + "malformed {0} attribute, expected `{0}(params = ...)`", + attr_name.0, + ))); + } + Ok(()) + })?; + } else { + return Err(lookahead.error()); + } + + Ok(params) +} + +pub(crate) fn get_where_predicates(meta: &ParseNestedMeta) -> syn::Result<(Bounds, Bounds)> { + get_ser_and_de(BOUND, meta, parse_lit_into::) +} + +pub(crate) fn get_schema_attrs(meta: &ParseNestedMeta) -> syn::Result { + get_schema_nested_meta(SCHEMA, meta, parse_lit_into::) } diff --git a/borsh-schema-derive-internal/src/enum_schema.rs b/borsh-schema-derive-internal/src/enum_schema.rs index 2a7df3803..fca7af2ec 100644 --- a/borsh-schema-derive-internal/src/enum_schema.rs +++ b/borsh-schema-derive-internal/src/enum_schema.rs @@ -58,7 +58,7 @@ pub fn process_enum(input: &ItemEnum, cratename: Ident) -> syn::Result + where + K: TraitName, + { + B { + x: Vec, + #[borsh_skip] + #[borsh(schema(params = "K => ::Associated"))] + z: ::Associated, + }, + C(T, u16), + } + }) + .unwrap(); + + let actual = process_enum(&item_struct, Ident::new("borsh", Span::call_site())).unwrap(); + + insta::assert_snapshot!(pretty_print_syn_str(&actual).unwrap()); + } } diff --git a/borsh-schema-derive-internal/src/generics.rs b/borsh-schema-derive-internal/src/generics.rs index 314d30f75..27fd8c147 100644 --- a/borsh-schema-derive-internal/src/generics.rs +++ b/borsh-schema-derive-internal/src/generics.rs @@ -2,7 +2,7 @@ #![allow(unused)] use std::collections::{HashMap, HashSet}; -use quote::quote; +use quote::{quote, ToTokens}; use syn::{ punctuated::Pair, Field, GenericArgument, Generics, Ident, Macro, Path, PathArguments, PathSegment, ReturnType, Type, TypeParamBound, TypePath, WherePredicate, @@ -61,7 +61,7 @@ pub struct FindTyParams { relevant_type_params: HashSet, // [Param] => [Type, containing Param] mapping - associated_type_params_usage: HashMap, + associated_type_params_usage: HashMap>, } fn ungroup(mut ty: &Type) -> &Type { @@ -104,6 +104,7 @@ impl FindTyParams { let relevant_type_params = self.relevant_type_params; let associated_type_params_usage = self.associated_type_params_usage; let mut new_predicates: Vec = vec![]; + let mut new_predicates_set: HashSet = HashSet::new(); self.all_type_params_ordered.iter().for_each(|param| { if relevant_type_params.contains(param) { @@ -111,10 +112,20 @@ impl FindTyParams { qself: None, path: param.clone().into(), }); - new_predicates.push(ty); + let ty_str_repr = ty.to_token_stream().to_string(); + if !new_predicates_set.contains(&ty_str_repr) { + new_predicates.push(ty); + new_predicates_set.insert(ty_str_repr); + } } - if let Some(type_) = associated_type_params_usage.get(param) { - new_predicates.push(type_.clone()); + if let Some(vec_type) = associated_type_params_usage.get(param) { + for type_ in vec_type { + let ty_str_repr = type_.to_token_stream().to_string(); + if !new_predicates_set.contains(&ty_str_repr) { + new_predicates.push(type_.clone()); + new_predicates_set.insert(ty_str_repr); + } + } } }); @@ -125,12 +136,15 @@ impl FindTyParams { let associated_type_params_usage = self.associated_type_params_usage; let mut params: Vec = vec![]; + let mut params_set: HashSet = HashSet::new(); self.all_type_params_ordered.iter().for_each(|param| { - if relevant_type_params.contains(param) { + if relevant_type_params.contains(param) && !params_set.contains(param) { params.push(param.clone()); + params_set.insert(param.clone()); } - if associated_type_params_usage.get(param).is_some() { + if associated_type_params_usage.get(param).is_some() && !params_set.contains(param){ params.push(param.clone()); + params_set.insert(param.clone()); } }); params @@ -149,16 +163,20 @@ impl FindTyParams { if let Type::Path(ty) = ungroup(type_) { if let Some(Pair::Punctuated(t, _)) = ty.path.segments.pairs().next() { if self.all_type_params.contains(&t.ident) { - self.associated_type_params_usage - .insert(t.ident.clone(), type_.clone()); + self.param_associated_type_insert(t.ident.clone(), type_.clone()); } } } self.visit_type(type_); } - pub fn insert_type(&mut self, param: Ident, type_: Type) { - self.associated_type_params_usage.insert(param, type_); + pub fn param_associated_type_insert(&mut self, param: Ident, type_: Type) { + if let Some(type_vec) = self.associated_type_params_usage.get_mut(¶m) { + type_vec.push(type_); + } else { + let type_vec = vec![type_]; + self.associated_type_params_usage.insert(param, type_vec); + } } fn visit_return_type(&mut self, return_type: &ReturnType) { diff --git a/borsh-schema-derive-internal/src/snapshots/borsh_schema_derive_internal__enum_schema__tests__generic_associated_type_param_override.snap b/borsh-schema-derive-internal/src/snapshots/borsh_schema_derive_internal__enum_schema__tests__generic_associated_type_param_override.snap index 56f0a1c32..f103ee343 100644 --- a/borsh-schema-derive-internal/src/snapshots/borsh_schema_derive_internal__enum_schema__tests__generic_associated_type_param_override.snap +++ b/borsh-schema-derive-internal/src/snapshots/borsh_schema_derive_internal__enum_schema__tests__generic_associated_type_param_override.snap @@ -10,11 +10,13 @@ where T: Eq + Hash, T: borsh::BorshSchema, K: borsh::BorshSchema, + ::Associated: borsh::BorshSchema, V: borsh::BorshSchema, { fn declaration() -> borsh::schema::Declaration { let params = borsh::__private::maybestd::vec![ - < T > ::declaration(), < K > ::declaration(), < V > ::declaration() + < T > ::declaration(), < K > ::declaration(), < < K as TraitName > + ::Associated > ::declaration(), < V > ::declaration() ]; format!(r#"{}<{}>"#, "EnumParametrized", params.join(", ")) } diff --git a/borsh-schema-derive-internal/src/snapshots/borsh_schema_derive_internal__enum_schema__tests__generic_associated_type_param_override_ignored.snap b/borsh-schema-derive-internal/src/snapshots/borsh_schema_derive_internal__enum_schema__tests__generic_associated_type_param_override_ignored.snap new file mode 100644 index 000000000..132a34f56 --- /dev/null +++ b/borsh-schema-derive-internal/src/snapshots/borsh_schema_derive_internal__enum_schema__tests__generic_associated_type_param_override_ignored.snap @@ -0,0 +1,54 @@ +--- +source: borsh-schema-derive-internal/src/enum_schema.rs +expression: pretty_print_syn_str(&actual).unwrap() +--- +impl borsh::BorshSchema for EnumParametrized +where + K: TraitName, + T: borsh::BorshSchema, + V: borsh::BorshSchema, +{ + fn declaration() -> borsh::schema::Declaration { + let params = borsh::__private::maybestd::vec![ + < T > ::declaration(), < V > ::declaration() + ]; + format!(r#"{}<{}>"#, "EnumParametrized", params.join(", ")) + } + fn add_definitions_recursively( + definitions: &mut borsh::__private::maybestd::collections::BTreeMap< + borsh::schema::Declaration, + borsh::schema::Definition, + >, + ) { + #[allow(dead_code)] + #[derive(borsh::BorshSchema)] + struct EnumParametrizedB + where + K: TraitName, + { + x: Vec, + #[borsh_skip] + #[borsh(schema(params = "K => ::Associated"))] + z: ::Associated, + } + #[allow(dead_code)] + #[derive(borsh::BorshSchema)] + struct EnumParametrizedC(T, u16); + as borsh::BorshSchema>::add_definitions_recursively(definitions); + as borsh::BorshSchema>::add_definitions_recursively(definitions); + let variants = borsh::__private::maybestd::vec![ + ("B".to_string(), < EnumParametrizedB < K, V > > ::declaration()), ("C" + .to_string(), < EnumParametrizedC < T > > ::declaration()) + ]; + let definition = borsh::schema::Definition::Enum { + variants, + }; + Self::add_definition(Self::declaration(), definition, definitions); + } +} + diff --git a/borsh-schema-derive-internal/src/snapshots/borsh_schema_derive_internal__struct_schema__tests__generic_associated_type_param_override.snap b/borsh-schema-derive-internal/src/snapshots/borsh_schema_derive_internal__struct_schema__tests__generic_associated_type_param_override.snap index 8d386211e..303a6a90d 100644 --- a/borsh-schema-derive-internal/src/snapshots/borsh_schema_derive_internal__struct_schema__tests__generic_associated_type_param_override.snap +++ b/borsh-schema-derive-internal/src/snapshots/borsh_schema_derive_internal__struct_schema__tests__generic_associated_type_param_override.snap @@ -6,11 +6,11 @@ impl borsh::BorshSchema for Parametrized where T: TraitName, V: borsh::BorshSchema, - T: borsh::BorshSchema, + ::Associated: borsh::BorshSchema, { fn declaration() -> borsh::schema::Declaration { let params = borsh::__private::maybestd::vec![ - < V > ::declaration(), < T > ::declaration() + < V > ::declaration(), < < T as TraitName > ::Associated > ::declaration() ]; format!(r#"{}<{}>"#, "Parametrized", params.join(", ")) } diff --git a/borsh-schema-derive-internal/src/snapshots/borsh_schema_derive_internal__struct_schema__tests__generic_associated_type_param_override2.snap b/borsh-schema-derive-internal/src/snapshots/borsh_schema_derive_internal__struct_schema__tests__generic_associated_type_param_override2.snap new file mode 100644 index 000000000..3e47f73dc --- /dev/null +++ b/borsh-schema-derive-internal/src/snapshots/borsh_schema_derive_internal__struct_schema__tests__generic_associated_type_param_override2.snap @@ -0,0 +1,46 @@ +--- +source: borsh-schema-derive-internal/src/struct_schema.rs +expression: pretty_print_syn_str(&actual).unwrap() +--- +impl borsh::BorshSchema for Parametrized +where + T: TraitName, + V: borsh::BorshSchema, + T: borsh::BorshSchema, + ::Associated: borsh::BorshSchema, +{ + fn declaration() -> borsh::schema::Declaration { + let params = borsh::__private::maybestd::vec![ + < V > ::declaration(), < T > ::declaration(), < < T as TraitName > + ::Associated > ::declaration() + ]; + format!(r#"{}<{}>"#, "Parametrized", params.join(", ")) + } + fn add_definitions_recursively( + definitions: &mut borsh::__private::maybestd::collections::BTreeMap< + borsh::schema::Declaration, + borsh::schema::Definition, + >, + ) { + let fields = borsh::schema::Fields::NamedFields( + borsh::__private::maybestd::vec![ + ("field".to_string(), < (< T as TraitName > ::Associated, T) as + borsh::BorshSchema > ::declaration()), ("another".to_string(), < V as + borsh::BorshSchema > ::declaration()) + ], + ); + let definition = borsh::schema::Definition::Struct { + fields, + }; + let no_recursion_flag = definitions.get(&Self::declaration()).is_none(); + Self::add_definition(Self::declaration(), definition, definitions); + if no_recursion_flag { + <( + ::Associated, + T, + ) as borsh::BorshSchema>::add_definitions_recursively(definitions); + ::add_definitions_recursively(definitions); + } + } +} + diff --git a/borsh-schema-derive-internal/src/snapshots/borsh_schema_derive_internal__struct_schema__tests__generic_associated_type_param_override_ignored.snap b/borsh-schema-derive-internal/src/snapshots/borsh_schema_derive_internal__struct_schema__tests__generic_associated_type_param_override_ignored.snap new file mode 100644 index 000000000..f44b941ee --- /dev/null +++ b/borsh-schema-derive-internal/src/snapshots/borsh_schema_derive_internal__struct_schema__tests__generic_associated_type_param_override_ignored.snap @@ -0,0 +1,35 @@ +--- +source: borsh-schema-derive-internal/src/struct_schema.rs +expression: pretty_print_syn_str(&actual).unwrap() +--- +impl borsh::BorshSchema for Parametrized +where + T: TraitName, + V: borsh::BorshSchema, +{ + fn declaration() -> borsh::schema::Declaration { + let params = borsh::__private::maybestd::vec![< V > ::declaration()]; + format!(r#"{}<{}>"#, "Parametrized", params.join(", ")) + } + fn add_definitions_recursively( + definitions: &mut borsh::__private::maybestd::collections::BTreeMap< + borsh::schema::Declaration, + borsh::schema::Definition, + >, + ) { + let fields = borsh::schema::Fields::NamedFields( + borsh::__private::maybestd::vec![ + ("another".to_string(), < V as borsh::BorshSchema > ::declaration()) + ], + ); + let definition = borsh::schema::Definition::Struct { + fields, + }; + let no_recursion_flag = definitions.get(&Self::declaration()).is_none(); + Self::add_definition(Self::declaration(), definition, definitions); + if no_recursion_flag { + ::add_definitions_recursively(definitions); + } + } +} + diff --git a/borsh-schema-derive-internal/src/struct_schema.rs b/borsh-schema-derive-internal/src/struct_schema.rs index b22b4a2fa..740c30f9d 100644 --- a/borsh-schema-derive-internal/src/struct_schema.rs +++ b/borsh-schema-derive-internal/src/struct_schema.rs @@ -1,32 +1,51 @@ use proc_macro2::TokenStream as TokenStream2; use quote::{quote, ToTokens}; -use syn::{Fields, Ident, ItemStruct, Path, WhereClause}; +use syn::{Field, Fields, Ident, ItemStruct, Path, WhereClause}; use crate::{ - attribute_helpers::contains_skip, + attribute_helpers::{contains_skip, parse_schema_attrs, parsing_helpers::SchemaParamsOverride}, generics::{compute_predicates, without_defaults, FindTyParams}, schema_helpers::declaration, }; +fn visit_field(field: &Field, visitor: &mut FindTyParams) -> syn::Result<()> { + if !contains_skip(&field.attrs) { + // there's no need to override params when field is skipped, because when field is skipped + // derive for it doesn't attempt to add any bounds, unlike `BorshDeserialize`, which + // adds `Default` bound on any type parameters in skipped field + let schema_attrs = parse_schema_attrs(&field.attrs)?; + if let Some(schema_params) = schema_attrs { + for SchemaParamsOverride { + order_param, + override_type, + .. + } in schema_params + { + visitor.param_associated_type_insert(order_param, override_type); + } + } else { + visitor.visit_field(field); + } + } + Ok(()) +} + /// check param usage in fields with respect to `borsh_skip` attribute usage -pub fn visit_struct_fields(fields: &Fields, visitor: &mut FindTyParams) { +pub fn visit_struct_fields(fields: &Fields, visitor: &mut FindTyParams) -> syn::Result<()> { match &fields { Fields::Named(fields) => { for field in &fields.named { - if !contains_skip(&field.attrs) { - visitor.visit_field(field); - } + visit_field(field, visitor)?; } } Fields::Unnamed(fields) => { for field in &fields.unnamed { - if !contains_skip(&field.attrs) { - visitor.visit_field(field); - } + visit_field(field, visitor)?; } } Fields::Unit => {} } + Ok(()) } /// check param usage in fields @@ -66,7 +85,7 @@ pub fn process_struct(input: &ItemStruct, cratename: Ident) -> syn::Result { for field in &fields.named { @@ -437,4 +456,47 @@ mod tests { insta::assert_snapshot!(pretty_print_syn_str(&actual).unwrap()); } + + #[test] + fn generic_associated_type_param_override2() { + let item_struct: ItemStruct = syn::parse2(quote! { + struct Parametrized + where + T: TraitName, + { + #[borsh(schema(params = + "T => T, T => ::Associated" + ))] + field: (::Associated, T), + another: V, + } + }) + .unwrap(); + + let actual = process_struct(&item_struct, Ident::new("borsh", Span::call_site())).unwrap(); + + insta::assert_snapshot!(pretty_print_syn_str(&actual).unwrap()); + } + + #[test] + fn generic_associated_type_param_override_ignored() { + let item_struct: ItemStruct = syn::parse2(quote! { + struct Parametrized + where + T: TraitName, + { + #[borsh_skip] + #[borsh(schema(params = + "T => ::Associated" + ))] + field: ::Associated, + another: V, + } + }) + .unwrap(); + + let actual = process_struct(&item_struct, Ident::new("borsh", Span::call_site())).unwrap(); + + insta::assert_snapshot!(pretty_print_syn_str(&actual).unwrap()); + } } diff --git a/borsh/tests/test_schema_enums.rs b/borsh/tests/test_schema_enums.rs index 29dd6327b..1cdb21410 100644 --- a/borsh/tests/test_schema_enums.rs +++ b/borsh/tests/test_schema_enums.rs @@ -9,7 +9,7 @@ use std::collections::BTreeMap; #[cfg(not(feature = "std"))] extern crate alloc; #[cfg(not(feature = "std"))] -use alloc::{collections::BTreeMap, format, string::ToString, vec}; +use alloc::{collections::BTreeMap, format, string::{ToString, String}, vec}; use borsh::schema::*; use borsh::schema_helpers::{try_from_slice_with_schema, try_to_vec_with_schema}; @@ -252,6 +252,23 @@ pub fn complex_enum_generics() { ); } +fn common_map() -> BTreeMap { + map! { + "EnumParametrized" => Definition::Enum{ variants: vec![ + ("B".to_string(), "EnumParametrizedB".to_string()), + ("C".to_string(), "EnumParametrizedC".to_string()) + ]}, + "EnumParametrizedB" => Definition::Struct { fields: Fields::NamedFields(vec![ + ("x".to_string(), "BTreeMap".to_string()), + ("y".to_string(), "string".to_string()), + ("z".to_string(), "i8".to_string()) + ])}, + "EnumParametrizedC" => Definition::Struct{ fields: Fields::UnnamedFields(vec!["string".to_string(), "u16".to_string()])}, + "BTreeMap" => Definition::Sequence { elements: "Tuple".to_string()}, + "Tuple" => Definition::Tuple { elements: vec!["u32".to_string(), "u16".to_string()]} + } +} + #[test] pub fn generic_associated_item1() { trait TraitName { @@ -280,25 +297,14 @@ pub fn generic_associated_item1() { C(T, u16), } - // assert_eq!( - // "Parametrized".to_string(), - // >::declaration() - // ); - - // let mut defs = Default::default(); - // >::add_definitions_recursively(&mut defs); - // assert_eq!( - // map! { + assert_eq!( + "EnumParametrized".to_string(), + >::declaration() + ); - // "Parametrized" => Definition::Struct { - // fields: Fields::NamedFields(vec![ - // ("field".to_string(), "i8".to_string()), - // ("another".to_string(), "string".to_string()) - // ]) - // } - // }, - // defs - // ); + let mut defs = Default::default(); + >::add_definitions_recursively(&mut defs); + assert_eq!(common_map(), defs); } #[test] @@ -329,24 +335,8 @@ pub fn generic_associated_item2() { }, C(T, u16), } + let mut defs = Default::default(); + >::add_definitions_recursively(&mut defs); - // assert_eq!( - // "Parametrized".to_string(), - // >::declaration() - // ); - - // let mut defs = Default::default(); - // >::add_definitions_recursively(&mut defs); - // assert_eq!( - // map! { - - // "Parametrized" => Definition::Struct { - // fields: Fields::NamedFields(vec![ - // ("field".to_string(), "i8".to_string()), - // ("another".to_string(), "string".to_string()) - // ]) - // } - // }, - // defs - // ); + assert_eq!(common_map(), defs); } diff --git a/borsh/tests/test_schema_structs.rs b/borsh/tests/test_schema_structs.rs index f8e10af0a..ad7c69243 100644 --- a/borsh/tests/test_schema_structs.rs +++ b/borsh/tests/test_schema_structs.rs @@ -174,6 +174,18 @@ pub fn simple_generics() { ); } +fn common_map() -> BTreeMap { + map! { + + "Parametrized" => Definition::Struct { + fields: Fields::NamedFields(vec![ + ("field".to_string(), "i8".to_string()), + ("another".to_string(), "string".to_string()) + ]) + } + } +} + #[test] pub fn generic_associated_item() { trait TraitName { @@ -182,13 +194,13 @@ pub fn generic_associated_item() { } impl TraitName for u32 { - type Associated = String; + type Associated = i8; fn method(&self) {} } #[allow(unused)] #[derive(borsh::BorshSchema)] - struct Parametrized + struct Parametrized where T: TraitName, { @@ -197,25 +209,15 @@ pub fn generic_associated_item() { } assert_eq!( - "Parametrized".to_string(), - >::declaration() + "Parametrized".to_string(), + >::declaration() ); let mut defs = Default::default(); - >::add_definitions_recursively(&mut defs); - assert_eq!( - map! { - - "Parametrized" => Definition::Struct { - fields: Fields::NamedFields(vec![ - ("field".to_string(), "string".to_string()), - ("another".to_string(), "string".to_string()) - ]) - } - }, - defs - ); + >::add_definitions_recursively(&mut defs); + assert_eq!(common_map(), defs); } + #[test] pub fn generic_associated_item2() { trait TraitName { @@ -246,14 +248,49 @@ pub fn generic_associated_item2() { let mut defs = Default::default(); >::add_definitions_recursively(&mut defs); + assert_eq!(common_map(), defs); +} + +#[test] +pub fn generic_associated_item3() { + trait TraitName { + type Associated; + fn method(&self); + } + + impl TraitName for u32 { + type Associated = i8; + fn method(&self) {} + } + + #[allow(unused)] + #[derive(borsh::BorshSchema)] + struct Parametrized + where + T: TraitName, + { + #[borsh(schema(params = "T => T, T => ::Associated"))] + field: (::Associated, T), + another: V, + } + assert_eq!( - map! { + "Parametrized".to_string(), + >::declaration() + ); - "Parametrized" => Definition::Struct { - fields: Fields::NamedFields(vec![ - ("field".to_string(), "i8".to_string()), - ("another".to_string(), "string".to_string()) - ]) + let mut defs = Default::default(); + >::add_definitions_recursively(&mut defs); + assert_eq!( + map! { + "Parametrized" => Definition::Struct { + fields: Fields::NamedFields(vec![ + ("field".to_string(), "Tuple".to_string()), + ("another".to_string(), "string".to_string()) + ]) + }, + "Tuple" => Definition::Tuple { + elements: vec!["i8".to_string(), "u32".to_string()] } }, defs