Skip to content

Commit

Permalink
feat: implement borsh schema params override
Browse files Browse the repository at this point in the history
  • Loading branch information
dj8yf0μl committed Jul 14, 2023
1 parent 4c0d41f commit 5aff41e
Show file tree
Hide file tree
Showing 23 changed files with 736 additions and 159 deletions.
1 change: 1 addition & 0 deletions borsh-derive-internal/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Binary Object Representation Serializer for Hashing
proc-macro2 = "1"
syn = { version = "2", features = ["full", "fold"] }
quote = "1"
syn_derive = "0.1.6"

[dev-dependencies]
syn = { version = "2", features = ["full", "fold", "parsing"] }
Expand Down
162 changes: 146 additions & 16 deletions borsh-derive-internal/src/attribute_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,27 @@ use syn::{Attribute, Field, Path, WherePredicate};
pub mod parsing_helpers;
use parsing_helpers::get_where_predicates;

use self::parsing_helpers::{get_schema_attrs, SchemaParamsOverride};

#[derive(Copy, Clone)]
pub struct Symbol(pub &'static str);

/// top level prefix in nested meta attribute
/// borsh - top level prefix in nested meta attribute
pub const BORSH: Symbol = Symbol("borsh");
/// field-level only attribute, `BorshSerialize` and `BorshDeserialize` contexts
/// bound - sub-borsh nested meta, field-level only, `BorshSerialize` and `BorshDeserialize` contexts
pub const BOUND: Symbol = Symbol("bound");
/// sub-BOUND nested meta attribute
/// serialize - sub-bound nested meta attribute
pub const SERIALIZE: Symbol = Symbol("serialize");
/// sub-BOUND nested meta attribute
/// deserialize - sub-bound nested meta attribute
pub const DESERIALIZE: Symbol = Symbol("deserialize");
/// field-level only attribute, `BorshSerialize`, `BorshDeserialize`, `BorshSchema` contexts
/// borsh_skip - field-level only attribute, `BorshSerialize`, `BorshDeserialize`, `BorshSchema` contexts
pub const SKIP: Symbol = Symbol("borsh_skip");
/// item-level only attribute `BorshDeserialize` context
/// borsh_init - item-level only attribute `BorshDeserialize` context
pub const INIT: Symbol = Symbol("borsh_init");
/// schema - sub-borsh nested meta, `BorshSchema` context
pub const SCHEMA: Symbol = Symbol("schema");
/// params - sub-schema nested meta, field-level only attribute
pub const PARAMS: Symbol = Symbol("params");

impl PartialEq<Symbol> for Path {
fn eq(&self, word: &Symbol) -> bool {
Expand All @@ -32,11 +38,11 @@ impl<'a> PartialEq<Symbol> for &'a Path {
}
}

pub fn contains_skip(attrs: &[Attribute]) -> bool {
pub(crate) fn contains_skip(attrs: &[Attribute]) -> bool {
attrs.iter().any(|attr| attr.path() == SKIP)
}

pub fn contains_initialize_with(attrs: &[Attribute]) -> Option<Path> {
pub(crate) fn contains_initialize_with(attrs: &[Attribute]) -> Option<Path> {
for attr in attrs.iter() {
if attr.path() == INIT {
let mut res = None;
Expand All @@ -51,9 +57,10 @@ pub fn contains_initialize_with(attrs: &[Attribute]) -> Option<Path> {
None
}

type Bounds = Option<Vec<WherePredicate>>;
pub(crate) type Bounds = Option<Vec<WherePredicate>>;
pub(crate) type SchemaParams = Option<Vec<SchemaParamsOverride>>;

pub fn parse_bounds(attrs: &[Attribute]) -> Result<(Bounds, Bounds), syn::Error> {
fn parse_bounds(attrs: &[Attribute]) -> Result<(Bounds, Bounds), syn::Error> {
let (mut ser, mut de): (Bounds, Bounds) = (None, None);
for attr in attrs {
if attr.path() != BORSH {
Expand All @@ -75,20 +82,40 @@ pub fn parse_bounds(attrs: &[Attribute]) -> Result<(Bounds, Bounds), syn::Error>
Ok((ser, de))
}

pub enum BoundType {
pub(crate) fn parse_schema_attrs(attrs: &[Attribute]) -> Result<SchemaParams, syn::Error> {
let mut params: SchemaParams = None;
for attr in attrs {
if attr.path() != BORSH {
continue;
}

attr.parse_nested_meta(|meta| {
if meta.path == SCHEMA {
// #[borsh(schema(params = "..."))]

let params_parsed = get_schema_attrs(&meta)?;
params = params_parsed;
}
Ok(())
})?;
}

Ok(params)
}
pub(crate) enum BoundType {
Serialize,
Deserialize,
}

pub fn get_bounds(field: &Field, ty: BoundType) -> Result<Bounds, syn::Error> {
pub(crate) fn get_bounds(field: &Field, ty: BoundType) -> Result<Bounds, syn::Error> {
let (ser, de) = parse_bounds(&field.attrs)?;
match ty {
BoundType::Serialize => Ok(ser),
BoundType::Deserialize => Ok(de),
}
}

pub fn collect_override_bounds(
pub(crate) fn collect_override_bounds(
field: &Field,
ty: BoundType,
output: &mut Vec<WherePredicate>,
Expand All @@ -109,8 +136,10 @@ mod tests {
use std::fmt::Write;
use syn::ItemStruct;

use crate::attribute_helpers::parse_schema_attrs;

use super::{parse_bounds, Bounds};
fn debug_print_bounds(bounds: Bounds) -> String {
fn debug_print_bounds<T: ToTokens>(bounds: Option<Vec<T>>) -> String {
let mut s = String::new();
if let Some(bounds) = bounds {
for bound in bounds {
Expand Down Expand Up @@ -180,7 +209,7 @@ mod tests {

let first_field = &item_struct.fields.into_iter().collect::<Vec<_>>()[0];
let (ser, de) = parse_bounds(&first_field.attrs).unwrap();
insta::assert_snapshot!(debug_print_bounds(ser));
assert_eq!(ser.unwrap().len(), 0);
insta::assert_snapshot!(debug_print_bounds(de));
}

Expand All @@ -197,7 +226,7 @@ mod tests {

let first_field = &item_struct.fields.into_iter().collect::<Vec<_>>()[0];
let (ser, de) = parse_bounds(&first_field.attrs).unwrap();
insta::assert_snapshot!(debug_print_bounds(ser));
assert!(ser.is_none());
insta::assert_snapshot!(debug_print_bounds(de));
}

Expand Down Expand Up @@ -257,4 +286,105 @@ mod tests {
};
insta::assert_debug_snapshot!(err);
}

#[test]
fn test_schema_params_parsing1() {
let item_struct: ItemStruct = syn::parse2(quote! {
struct Parametrized<V, T>
where
T: TraitName,
{
#[borsh(schema(params =
"T => <T as TraitName>::Associated"
))]
field: <T as TraitName>::Associated,
another: V,
}
})
.unwrap();

let first_field = &item_struct.fields.into_iter().collect::<Vec<_>>()[0];
let schema_params = parse_schema_attrs(&first_field.attrs).unwrap();
insta::assert_snapshot!(debug_print_bounds(schema_params));
}
#[test]
fn test_schema_params_parsing_error() {
let item_struct: ItemStruct = syn::parse2(quote! {
struct Parametrized<V, T>
where
T: TraitName,
{
#[borsh(schema(params =
"T => <T as TraitName, W>::Associated"
))]
field: <T as TraitName>::Associated,
another: V,
}
})
.unwrap();

let first_field = &item_struct.fields.into_iter().collect::<Vec<_>>()[0];
let err = match parse_schema_attrs(&first_field.attrs) {
Ok(..) => unreachable!("expecting error here"),
Err(err) => err,
};
insta::assert_debug_snapshot!(err);
}

#[test]
fn test_schema_params_parsing2() {
let item_struct: ItemStruct = syn::parse2(quote! {
struct Parametrized<V, T>
where
T: TraitName,
{
#[borsh(schema(params =
"T => <T as TraitName>::Associated, V => Vec<V>"
))]
field: <T as TraitName>::Associated,
another: V,
}
})
.unwrap();

let first_field = &item_struct.fields.into_iter().collect::<Vec<_>>()[0];
let schema_params = parse_schema_attrs(&first_field.attrs).unwrap();
insta::assert_snapshot!(debug_print_bounds(schema_params));
}
#[test]
fn test_schema_params_parsing3() {
let item_struct: ItemStruct = syn::parse2(quote! {
struct Parametrized<V, T>
where
T: TraitName,
{
#[borsh(schema(params = "" ))]
field: <T as TraitName>::Associated,
another: V,
}
})
.unwrap();

let first_field = &item_struct.fields.into_iter().collect::<Vec<_>>()[0];
let schema_params = parse_schema_attrs(&first_field.attrs).unwrap();
assert_eq!(schema_params.unwrap().len(), 0);
}

#[test]
fn test_schema_params_parsing4() {
let item_struct: ItemStruct = syn::parse2(quote! {
struct Parametrized<V, T>
where
T: TraitName,
{
field: <T as TraitName>::Associated,
another: V,
}
})
.unwrap();

let first_field = &item_struct.fields.into_iter().collect::<Vec<_>>()[0];
let schema_params = parse_schema_attrs(&first_field.attrs).unwrap();
assert!(schema_params.is_none());
}
}
66 changes: 58 additions & 8 deletions borsh-derive-internal/src/attribute_helpers/parsing_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
#![allow(unused)]
use std::iter::FromIterator;

use syn::{meta::ParseNestedMeta, punctuated::Punctuated, token::Paren, Expr, Lit, LitStr, Token};
use syn::{
meta::ParseNestedMeta, punctuated::Punctuated, token::Paren, Expr, Ident, Lit, LitStr, Token,
Type, WherePredicate,
};

use super::{Bounds, Symbol, BOUND, DESERIALIZE, SERIALIZE};
use super::{Bounds, SchemaParams, Symbol, BOUND, DESERIALIZE, PARAMS, SCHEMA, SERIALIZE};
fn get_lit_str2(
attr_name: Symbol,
meta_item_name: Symbol,
Expand All @@ -31,22 +34,30 @@ fn get_lit_str2(
}
}

fn parse_lit_into_where(
fn parse_lit_into<T: syn::parse::Parse>(
attr_name: Symbol,
meta_item_name: Symbol,
meta: &ParseNestedMeta,
) -> syn::Result<Vec<syn::WherePredicate>> {
) -> syn::Result<Vec<T>> {
let string = match get_lit_str2(attr_name, meta_item_name, meta)? {
Some(string) => string,
None => return Ok(Vec::new()),
};

match string.parse_with(Punctuated::<syn::WherePredicate, Token![,]>::parse_terminated) {
Ok(predicates) => Ok(Vec::from_iter(predicates)),
match string.parse_with(Punctuated::<T, Token![,]>::parse_terminated) {
Ok(elements) => Ok(Vec::from_iter(elements)),
Err(err) => Err(syn::Error::new_spanned(string, err)),
}
}

/// struct describes entries like `K => <K as TraitName>::Associated`
#[derive(Clone, syn_derive::Parse, syn_derive::ToTokens)]
pub(crate) struct SchemaParamsOverride {
pub order_param: Ident,
arrow_token: Token![=>],
pub override_type: Type,
}

fn get_ser_and_de<T, F, R>(
attr_name: Symbol,
meta: &ParseNestedMeta,
Expand Down Expand Up @@ -85,6 +96,45 @@ where

Ok((ser_meta, de_meta))
}
pub fn get_where_predicates(meta: &ParseNestedMeta) -> syn::Result<(Bounds, Bounds)> {
get_ser_and_de(BOUND, meta, parse_lit_into_where)

fn get_schema_nested_meta<T, F, R>(
attr_name: Symbol,
meta: &ParseNestedMeta,
f: F,
) -> syn::Result<Option<T>>
where
T: Clone,
F: Fn(Symbol, Symbol, &ParseNestedMeta) -> syn::Result<R>,
R: Into<Option<T>>,
{
let mut params: Option<T> = None;

let lookahead = meta.input.lookahead1();
if lookahead.peek(Paren) {
meta.parse_nested_meta(|meta| {
if meta.path == PARAMS {
if let Some(v) = f(attr_name, PARAMS, &meta)?.into() {
params = Some(v);
}
} else {
return Err(meta.error(format_args!(
"malformed {0} attribute, expected `{0}(params = ...)`",
attr_name.0,
)));
}
Ok(())
})?;
} else {
return Err(lookahead.error());
}

Ok(params)
}

pub(crate) fn get_where_predicates(meta: &ParseNestedMeta) -> syn::Result<(Bounds, Bounds)> {
get_ser_and_de(BOUND, meta, parse_lit_into::<WherePredicate>)
}

pub(crate) fn get_schema_attrs(meta: &ParseNestedMeta) -> syn::Result<SchemaParams> {
get_schema_nested_meta(SCHEMA, meta, parse_lit_into::<SchemaParamsOverride>)
}
Loading

0 comments on commit 5aff41e

Please sign in to comment.