diff --git a/Cargo.lock b/Cargo.lock index 9596a44b49d..74e718383c8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1779,6 +1779,15 @@ dependencies = [ "thiserror", ] +[[package]] +name = "gix-momo" +version = "0.0.0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.26", +] + [[package]] name = "gix-negotiate" version = "0.6.0" diff --git a/Cargo.toml b/Cargo.toml index 5803b2331ef..39a466716ed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -238,6 +238,7 @@ members = [ "gix-packetline", "gix-packetline-blocking", "gix-mailmap", + "gix-momo", "gix-note", "gix-negotiate", "gix-fetchhead", diff --git a/gix-momo/Cargo.toml b/gix-momo/Cargo.toml new file mode 100644 index 00000000000..193019bbc83 --- /dev/null +++ b/gix-momo/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "gix-momo" +version = "0.0.0" +edition = "2021" +description = "A gix proc_macro_attribute to outline conversions from generic functions" +authors = ["Sebastian Thiel "] +repository = "https://github.com/Byron/gitoxide" +license = "MIT OR Apache-2.0" +include = ["src/**/*", "LICENSE-*", "CHANGELOG.md"] +rust-version = "1.65" + +[lib] +proc_macro = true + +[dependencies] +syn = { version = "2.0", features = ["full", "fold"] } +quote = "1.0" +proc-macro2 = "1.0" diff --git a/gix-momo/LICENSE-APACHE b/gix-momo/LICENSE-APACHE new file mode 120000 index 00000000000..965b606f331 --- /dev/null +++ b/gix-momo/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/gix-momo/LICENSE-MIT b/gix-momo/LICENSE-MIT new file mode 120000 index 00000000000..76219eb72e8 --- /dev/null +++ b/gix-momo/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/gix-momo/src/lib.rs b/gix-momo/src/lib.rs new file mode 100644 index 00000000000..e8a3c435ae9 --- /dev/null +++ b/gix-momo/src/lib.rs @@ -0,0 +1,302 @@ +use std::collections::{HashMap, HashSet}; + +use proc_macro::TokenStream; +use quote::quote; +use syn::{fold::Fold, punctuated::Punctuated, spanned::Spanned, *}; + +#[derive(Copy, Clone)] +// All conversions we support. Check references to this type for an idea how to add more. +enum Conversion<'a> { + Into(&'a Type), + TryInto(&'a Type), + AsRef(&'a Type), + AsMut(&'a Type), +} + +impl<'a> Conversion<'a> { + fn target_type(&self) -> Type { + match *self { + Conversion::Into(ty) => ty.clone(), + Conversion::TryInto(ty) => ty.clone(), + Conversion::AsRef(ty) => parse_quote!(&#ty), + Conversion::AsMut(ty) => parse_quote!(&mut #ty), + } + } + + fn conversion_expr(&self, i: Ident) -> Expr { + match *self { + Conversion::Into(_) => parse_quote!(#i.into()), + Conversion::TryInto(_) => parse_quote!(#i.try_into()?), + Conversion::AsRef(_) => parse_quote!(#i.as_ref()), + Conversion::AsMut(_) => parse_quote!(#i.as_mut()), + } + } +} + +fn parse_bounded_type(ty: &Type) -> Option { + if let Type::Path(TypePath { qself: None, ref path }) = ty { + if path.segments.len() == 1 { + return Some(path.segments[0].ident.clone()); + } + } + None +} + +fn parse_bounds(bounds: &Punctuated) -> Option { + if bounds.len() != 1 { + return None; + } + if let TypeParamBound::Trait(ref tb) = bounds.first().unwrap() { + if let Some(seg) = tb.path.segments.iter().last() { + if let PathArguments::AngleBracketed(ref gen_args) = seg.arguments { + if let GenericArgument::Type(ref arg_ty) = gen_args.args.first().unwrap() { + if seg.ident == "Into" { + return Some(Conversion::Into(arg_ty)); + } else if seg.ident == "TryInto" { + return Some(Conversion::TryInto(arg_ty)); + } else if seg.ident == "AsRef" { + return Some(Conversion::AsRef(arg_ty)); + } else if seg.ident == "AsMut" { + return Some(Conversion::AsMut(arg_ty)); + } + } + } + } + } + None +} + +// create a map from generic type to Conversion +fn parse_generics(decl: &Signature) -> (HashMap>, Generics) { + let mut ty_conversions = HashMap::new(); + let mut params = Punctuated::new(); + for gp in decl.generics.params.iter() { + if let GenericParam::Type(ref tp) = gp { + if let Some(conversion) = parse_bounds(&tp.bounds) { + ty_conversions.insert(tp.ident.clone(), conversion); + continue; + } + } + params.push(gp.clone()); + } + let where_clause = if let Some(ref wc) = decl.generics.where_clause { + let mut idents_to_remove = HashSet::new(); + let mut predicates = Punctuated::new(); + for wp in wc.predicates.iter() { + if let WherePredicate::Type(ref pt) = wp { + if let Some(ident) = parse_bounded_type(&pt.bounded_ty) { + if let Some(conversion) = parse_bounds(&pt.bounds) { + idents_to_remove.insert(ident.clone()); + ty_conversions.insert(ident, conversion); + continue; + } + } + } + predicates.push(wp.clone()); + } + params = params + .into_iter() + .filter(|param| { + if let GenericParam::Type(type_param) = param { + !idents_to_remove.contains(&type_param.ident) + } else { + true + } + }) + .collect(); + Some(WhereClause { + predicates, + ..wc.clone() + }) + } else { + None + }; + ( + ty_conversions, + Generics { + params, + where_clause, + ..decl.generics.clone() + }, + ) +} + +fn pat_to_ident(pat: &Pat) -> Ident { + if let Pat::Ident(ref pat_ident) = *pat { + return pat_ident.ident.clone(); + } + unimplemented!("No non-ident patterns for now!"); +} + +fn pat_to_expr(pat: &Pat) -> Expr { + let ident = pat_to_ident(pat); + parse_quote!(#ident) +} + +fn convert<'a>( + inputs: &'a Punctuated, + ty_conversions: &HashMap>, +) -> ( + Punctuated, + Conversions, + Punctuated, + bool, +) { + let mut argtypes = Punctuated::new(); + let mut conversions = Conversions { + intos: Vec::new(), + try_intos: Vec::new(), + as_refs: Vec::new(), + as_muts: Vec::new(), + }; + let mut argexprs = Punctuated::new(); + let mut has_self = false; + inputs.iter().for_each(|input| match input { + FnArg::Receiver(..) => { + has_self = true; + argtypes.push(input.clone()); + } + FnArg::Typed(PatType { + ref pat, + ref ty, + ref colon_token, + .. + }) => match **ty { + Type::ImplTrait(TypeImplTrait { ref bounds, .. }) => { + if let Some(conv) = parse_bounds(bounds) { + argtypes.push(FnArg::Typed(PatType { + attrs: Vec::new(), + pat: pat.clone(), + colon_token: *colon_token, + ty: Box::new(conv.target_type()), + })); + let ident = pat_to_ident(pat); + conversions.add(ident.clone(), conv); + argexprs.push(conv.conversion_expr(ident)); + } + } + Type::Path(..) => { + if let Some(ident) = parse_bounded_type(ty) { + if let Some(conv) = ty_conversions.get(&ident) { + argtypes.push(FnArg::Typed(PatType { + attrs: Vec::new(), + pat: pat.clone(), + colon_token: *colon_token, + ty: Box::new(conv.target_type()), + })); + let ident = pat_to_ident(pat); + conversions.add(ident, *conv); + argexprs.push(conv.conversion_expr(pat_to_ident(pat))); + } + } + } + _ => { + argtypes.push(input.clone()); + argexprs.push(pat_to_expr(pat)); + } + }, + }); + (argtypes, conversions, argexprs, has_self) +} + +struct Conversions { + intos: Vec, + try_intos: Vec, + as_refs: Vec, + as_muts: Vec, +} + +impl Conversions { + fn add(&mut self, ident: Ident, conv: Conversion) { + match conv { + Conversion::Into(_) => self.intos.push(ident), + Conversion::TryInto(_) => self.try_intos.push(ident), + Conversion::AsRef(_) => self.as_refs.push(ident), + Conversion::AsMut(_) => self.as_muts.push(ident), + } + } +} + +fn has_conversion(idents: &[Ident], expr: &Expr) -> bool { + if let Expr::Path(ExprPath { ref path, .. }) = *expr { + if path.segments.len() == 1 { + let seg = path.segments.iter().last().unwrap(); + return idents.iter().any(|i| i == &seg.ident); + } + } + false +} + +#[allow(clippy::collapsible_if)] +impl Fold for Conversions { + fn fold_expr(&mut self, expr: Expr) -> Expr { + //TODO: Also catch `Expr::Call` with suitable paths & args + match expr { + Expr::MethodCall(mc) if mc.args.is_empty() => match &*mc.method.to_string() { + "into" if has_conversion(&self.intos, &mc.receiver) => *mc.receiver, + "try_into" if has_conversion(&self.try_intos, &mc.receiver) => *mc.receiver, + + "as_ref" if has_conversion(&self.as_refs, &mc.receiver) => *mc.receiver, + "as_mut" if has_conversion(&self.as_muts, &mc.receiver) => *mc.receiver, + + _ => syn::fold::fold_expr(self, Expr::MethodCall(mc)), + }, + Expr::Try(ExprTry { expr, .. }) => match *expr { + Expr::MethodCall(mc) + if mc.args.is_empty() + && mc.method == "try_into" + && has_conversion(&self.try_intos, &mc.receiver) => + { + *mc.receiver + } + expr => syn::fold::fold_expr(self, expr), + }, + _ => syn::fold::fold_expr(self, expr), + } + } +} + +/// Generate lightweight monomorphized wrapper around main implementation. +/// May be applied to functions and methods. +#[proc_macro_attribute] +pub fn momo(_attrs: TokenStream, input: TokenStream) -> TokenStream { + //TODO: alternatively parse ImplItem::Method + momo_inner(input.into()).into() +} + +fn momo_inner(code: proc_macro2::TokenStream) -> proc_macro2::TokenStream { + let fn_item: Item = match syn::parse2(code.clone()) { + Ok(input) => input, + Err(err) => return err.to_compile_error(), + }; + + if let Item::Fn(ref item_fn) = fn_item { + let inner_ident = syn::parse_str::(&format!("_{}_inner", item_fn.sig.ident)).unwrap(); + let (ty_conversions, generics) = parse_generics(&item_fn.sig); + let (argtypes, mut conversions, argexprs, has_self) = convert(&item_fn.sig.inputs, &ty_conversions); + let new_item = Item::Fn(ItemFn { + block: if has_self { + parse_quote!({ self.#inner_ident(#argexprs) }) + } else { + parse_quote!({ #inner_ident(#argexprs) }) + }, + ..item_fn.clone() + }); + let mut new_inner_item = ItemFn { + vis: Visibility::Inherited, + sig: Signature { + ident: inner_ident, + generics, + inputs: argtypes, + ..item_fn.sig.clone() + }, + ..item_fn.clone() + }; + new_inner_item.block = + Box::new(conversions.fold_block(std::mem::replace(new_inner_item.block.as_mut(), parse_quote!({})))); + let new_inner_item = Item::Fn(new_inner_item); + quote!(#new_item #[allow(unused_mut)] #new_inner_item) + } else { + Error::new(fn_item.span(), "expected a function").to_compile_error() + } +} diff --git a/gix-momo/tests/test.rs b/gix-momo/tests/test.rs new file mode 100644 index 00000000000..13963938327 --- /dev/null +++ b/gix-momo/tests/test.rs @@ -0,0 +1,143 @@ +use gix_momo::momo; + +#[momo] +fn test_fn( + a: impl Into, + b: impl AsRef, + mut c: impl AsMut, + d: impl TryInto, +) -> Result { + let mut s = a.into(); + s += b.as_ref(); + s += c.as_mut(); + s += &d.try_into()?; + + Ok(s) +} + +#[momo] +fn test_fn_where(a: A, b: B, mut c: C, d: D) -> Result +where + A: Into, + B: AsRef, + C: AsMut, + D: TryInto, +{ + let mut s = a.into(); + s += b.as_ref(); + s += c.as_mut(); + s += &d.try_into()?; + + Ok(s) +} + +struct TestStruct; + +impl TestStruct { + #[momo] + fn test_method( + self, + a: impl Into, + b: impl AsRef, + mut c: impl AsMut, + d: impl TryInto, + ) -> Result { + let mut s = a.into(); + s += b.as_ref(); + s += c.as_mut(); + s += &d.try_into()?; + + Ok(s) + } + + #[momo] + fn test_method2( + self: Self, + a: impl Into, + b: impl AsRef, + mut c: impl AsMut, + d: impl TryInto, + ) -> Result { + let mut s = a.into(); + s += b.as_ref(); + s += c.as_mut(); + s += &d.try_into()?; + + Ok(s) + } + + #[momo] + fn test_fn( + a: impl Into, + b: impl AsRef, + mut c: impl AsMut, + d: impl TryInto, + ) -> Result { + let mut s = a.into(); + s += b.as_ref(); + s += c.as_mut(); + s += &d.try_into()?; + + Ok(s) + } +} + +struct S(bool); +impl TryInto for S { + type Error = (); + + fn try_into(self) -> Result { + if self.0 { + Ok(String::from("!2345")) + } else { + Err(()) + } + } +} + +#[test] +fn test_basic_fn() { + assert_eq!( + test_fn("12345", "12345", String::from("12345"), S(true)).unwrap(), + "123451234512345!2345" + ); + + test_fn("12345", "12345", String::from("12345"), S(false)).unwrap_err(); +} + +#[test] +fn test_struct_method() { + // Test test_method + assert_eq!( + TestStruct + .test_method("12345", "12345", String::from("12345"), S(true)) + .unwrap(), + "123451234512345!2345" + ); + + TestStruct + .test_method("12345", "12345", String::from("12345"), S(false)) + .unwrap_err(); + + // Test test_method2 + assert_eq!( + TestStruct + .test_method2("12345", "12345", String::from("12345"), S(true)) + .unwrap(), + "123451234512345!2345" + ); + + TestStruct + .test_method2("12345", "12345", String::from("12345"), S(false)) + .unwrap_err(); +} + +#[test] +fn test_struct_fn() { + assert_eq!( + TestStruct::test_fn("12345", "12345", String::from("12345"), S(true)).unwrap(), + "123451234512345!2345" + ); + + TestStruct::test_fn("12345", "12345", String::from("12345"), S(false)).unwrap_err(); +}