Skip to content

Commit

Permalink
Merge pull request #595 from asomers/mutable-fnmut-arguments
Browse files Browse the repository at this point in the history
Make concretize work with FnMut arguments
  • Loading branch information
asomers authored Jul 21, 2024
2 parents 0949f8b + 836d100 commit 5aba4b0
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 27 deletions.
36 changes: 36 additions & 0 deletions mockall/tests/automock_concretize_closures.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// vim: tw=80
//! #[concretize] should work for closure arguments
#![deny(warnings)]

use mockall::*;

pub struct Foo{}
#[automock]
impl Foo {
#[concretize]
pub fn foo<F: Fn(u32) -> u32>(&self, _f: F) -> u32 {todo!()}
#[concretize]
pub fn bar<F: FnMut(&mut u32) -> u32>(&self, _f: F) -> u32 {todo!()}
}

#[test]
fn fn_() {
let mut mock = MockFoo::default();
mock.expect_foo()
.returning(|f| f(42));
assert_eq!(mock.foo(|x| x + 1), 43);
}

#[test]
fn fn_mut() {
use std::sync::{Arc, Mutex};
let x = Arc::new(Mutex::new(42u32));
{
let mut mock = MockFoo::default();
let x2 = x.clone();
mock.expect_bar()
.returning(move |f| f(&mut x2.lock().unwrap()));
assert_eq!(mock.bar(|y| {*y += 1; *y}), 43);
}
assert_eq!(*x.lock().unwrap(), 43);
}
189 changes: 166 additions & 23 deletions mockall_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,20 +73,42 @@ fn is_concretize(attr: &Attribute) -> bool {
}

/// replace generic arguments with concrete trait object arguments
fn concretize_args(gen: &Generics, args: &Punctuated<FnArg, Token![,]>) ->
(Generics, Vec<FnArg>, Vec<TokenStream>)
///
/// # Return
///
/// * A Generics object with the concretized types removed
/// * An array of transformed argument types, suitable for matchers and
/// returners
/// * An array of expressions that should be passed to the `call` function.
fn concretize_args(gen: &Generics, sig: &Signature) ->
(Generics, Punctuated<FnArg, Token![,]>, Vec<TokenStream>, Signature)
{
let args = &sig.inputs;
let mut hm = HashMap::default();
let mut needs_muts = HashMap::default();

let mut save_types = |ident: &Ident, tpb: &Punctuated<TypeParamBound, Token![+]>| {
if !tpb.is_empty() {
if let Ok(newty) = parse2::<Type>(quote!(&(dyn #tpb))) {
let mut pat = quote!(&(dyn #tpb));
let mut needs_mut = false;
if let Some(TypeParamBound::Trait(t)) = tpb.first() {
if t.path.segments.first().map(|seg| &seg.ident == "FnMut")
.unwrap_or(false)
{
// For FnMut arguments, the rfunc needs a mutable reference
pat = quote!(&mut (dyn #tpb));
needs_mut = true;
}
}
if let Ok(newty) = parse2::<Type>(pat) {
// substitute T arguments
let subst_ty: Type = parse2(quote!(#ident)).unwrap();
needs_muts.insert(subst_ty.clone(), needs_mut);
hm.insert(subst_ty, (newty.clone(), None));

// substitute &T arguments
let subst_ty: Type = parse2(quote!(&#ident)).unwrap();
needs_muts.insert(subst_ty.clone(), needs_mut);
hm.insert(subst_ty, (newty, None));
} else {
compile_error(tpb.span(),
Expand All @@ -96,6 +118,7 @@ fn concretize_args(gen: &Generics, args: &Punctuated<FnArg, Token![,]>) ->
if let Ok(newty) = parse2::<Type>(quote!(&mut (dyn #tpb))) {
// substitute &mut T arguments
let subst_ty: Type = parse2(quote!(&mut #ident)).unwrap();
needs_muts.insert(subst_ty.clone(), needs_mut);
hm.insert(subst_ty, (newty, None));
} else {
compile_error(tpb.span(),
Expand All @@ -106,6 +129,7 @@ fn concretize_args(gen: &Generics, args: &Punctuated<FnArg, Token![,]>) ->
// for the mock method to turn &[T] into &[&dyn T].
if let Ok(newty) = parse2::<Type>(quote!(&[&(dyn #tpb)])) {
let subst_ty: Type = parse2(quote!(&[#ident])).unwrap();
needs_muts.insert(subst_ty.clone(), needs_mut);
hm.insert(subst_ty, (newty, Some(tpb.clone())));
} else {
compile_error(tpb.span(),
Expand Down Expand Up @@ -139,21 +163,21 @@ fn concretize_args(gen: &Generics, args: &Punctuated<FnArg, Token![,]>) ->
params: Punctuated::new(),
where_clause: None
};
let outargs: Vec<FnArg> = args.iter().map(|arg| {
let outargs = args.iter().map(|arg| {
if let FnArg::Typed(pt) = arg {
let mut immutable_pt = pt.clone();
demutify_arg(&mut immutable_pt);
let mut call_pt = pt.clone();
demutify_arg(&mut call_pt);
if let Some((newty, _)) = hm.get(&pt.ty) {
FnArg::Typed(PatType {
attrs: Vec::default(),
pat: immutable_pt.pat,
pat: call_pt.pat,
colon_token: pt.colon_token,
ty: Box::new(newty.clone())
})
} else {
FnArg::Typed(PatType {
attrs: Vec::default(),
pat: immutable_pt.pat,
pat: call_pt.pat,
colon_token: pt.colon_token,
ty: pt.ty.clone()
})
Expand Down Expand Up @@ -186,6 +210,8 @@ fn concretize_args(gen: &Generics, args: &Punctuated<FnArg, Token![,]>) ->
} else {
Some(quote!(#pat))
}
} else if needs_muts.get(&pt.ty).cloned().unwrap_or(false) {
Some(quote!(&mut #pat))
} else {
Some(quote!(&#pat))
}
Expand All @@ -196,7 +222,23 @@ fn concretize_args(gen: &Generics, args: &Punctuated<FnArg, Token![,]>) ->
FnArg::Receiver(_) => None,
}
}).collect();
(outg, outargs, call_exprs)

// Add any necessary "mut" qualifiers to the Signature
let mut altsig = sig.clone();
for arg in altsig.inputs.iter_mut() {
if let FnArg::Typed(pt) = arg {
if needs_muts.get(&pt.ty).cloned().unwrap_or(false) {
if let Pat::Ident(pi) = &mut *pt.pat {
pi.mutability = Some(Token![mut](pi.mutability.span()));
} else {
compile_error(pt.pat.span(),
"This Pat type is not yet supported by Mockall when used as an argument to a concretized function.")
}
}
}
}

(outg, outargs, call_exprs, altsig)
}

fn deanonymize_lifetime(lt: &mut Lifetime) {
Expand Down Expand Up @@ -273,7 +315,7 @@ fn deanonymize(literal_type: &mut Type) {
// If there are any closures in the argument list, turn them into boxed
// functions
fn declosurefy(gen: &Generics, args: &Punctuated<FnArg, Token![,]>) ->
(Generics, Vec<FnArg>, Vec<TokenStream>)
(Generics, Punctuated<FnArg, Token![,]>, Vec<TokenStream>)
{
let mut hm = HashMap::default();

Expand Down Expand Up @@ -1123,7 +1165,7 @@ fn lifetimes_to_generics(lv: &Punctuated<LifetimeParam, Token![,]>)-> Generics {
/// only, and one for lifetimes that relate to the return type only.
fn split_lifetimes(
generics: Generics,
args: &[FnArg],
args: &Punctuated<FnArg, Token![,]>,
rt: &ReturnType)
-> (Generics,
Punctuated<LifetimeParam, token::Comma>,
Expand Down Expand Up @@ -1534,14 +1576,15 @@ mod automock {
mod concretize_args {
use super::*;

#[allow(clippy::needless_range_loop)] // Clippy's suggestion is worse
fn check_concretize(
sig: TokenStream,
expected_inputs: &[TokenStream],
expected_call_exprs: &[TokenStream])
expected_call_exprs: &[TokenStream],
expected_sig_inputs: &[TokenStream])
{
let f: Signature = parse2(sig).unwrap();
let (generics, inputs, call_exprs) =
concretize_args(&f.generics, &f.inputs);
let (generics, inputs, call_exprs, altsig) = concretize_args(&f.generics, &f);
assert!(generics.params.is_empty());
assert_eq!(inputs.len(), expected_inputs.len());
assert_eq!(call_exprs.len(), expected_call_exprs.len());
Expand All @@ -1555,14 +1598,34 @@ mod concretize_args {
let exp = &expected_call_exprs[i];
assert_eq!(quote!(#actual).to_string(), quote!(#exp).to_string());
}
for i in 0..altsig.inputs.len() {
let actual = &altsig.inputs[i];
let exp = &expected_sig_inputs[i];
assert_eq!(quote!(#actual).to_string(), quote!(#exp).to_string());
}
}

#[test]
fn bystanders() {
check_concretize(
quote!(fn foo<P: AsRef<Path>>(x: i32, p: P, y: &f64)),
&[quote!(x: i32), quote!(p: &(dyn AsRef<Path>)), quote!(y: &f64)],
&[quote!(x), quote!(&p), quote!(y)]
&[quote!(x), quote!(&p), quote!(y)],
&[quote!(x: i32), quote!(p: P), quote!(y: &f64)]
);
}

#[test]
fn function_args() {
check_concretize(
quote!(fn foo<F1: Fn(u32) -> u32,
F2: FnMut(&mut u32) -> u32,
F3: FnOnce(u32) -> u32>(f1: F1, f2: F2, f3: F3)),
&[quote!(f1: &(dyn Fn(u32) -> u32)),
quote!(f2: &mut(dyn FnMut(&mut u32) -> u32)),
quote!(f3: &(dyn FnOnce(u32) -> u32))],
&[quote!(&f1), quote!(&mut f2), quote!(&f3)],
&[quote!(f1: F1), quote!(mut f2: F2), quote!(f3: F3)]
);
}

Expand All @@ -1571,7 +1634,8 @@ mod concretize_args {
check_concretize(
quote!(fn foo<P: AsRef<String> + AsMut<String>>(p: P)),
&[quote!(p: &(dyn AsRef<String> + AsMut<String>))],
&[quote!(&p)]
&[quote!(&p)],
&[quote!(p: P)],
);
}

Expand All @@ -1580,7 +1644,8 @@ mod concretize_args {
check_concretize(
quote!(fn foo<P: AsMut<Path>>(p: &mut P)),
&[quote!(p: &mut (dyn AsMut<Path>))],
&[quote!(p)]
&[quote!(p)],
&[quote!(p: &mut P)],
);
}

Expand All @@ -1589,7 +1654,8 @@ mod concretize_args {
check_concretize(
quote!(fn foo<P: AsRef<String> + AsMut<String>>(p: &mut P)),
&[quote!(p: &mut (dyn AsRef<String> + AsMut<String>))],
&[quote!(p)]
&[quote!(p)],
&[quote!(p: &mut P)]
);
}

Expand All @@ -1598,7 +1664,8 @@ mod concretize_args {
check_concretize(
quote!(fn foo<P: AsRef<Path>>(p: &P)),
&[quote!(p: &(dyn AsRef<Path>))],
&[quote!(p)]
&[quote!(p)],
&[quote!(p: &P)]
);
}

Expand All @@ -1607,7 +1674,8 @@ mod concretize_args {
check_concretize(
quote!(fn foo<P: AsRef<Path>>(p: P)),
&[quote!(p: &(dyn AsRef<Path>))],
&[quote!(&p)]
&[quote!(&p)],
&[quote!(p: P)],
);
}

Expand All @@ -1616,7 +1684,8 @@ mod concretize_args {
check_concretize(
quote!(fn foo<P: AsRef<Path>>(p: &[P])),
&[quote!(p: &[&(dyn AsRef<Path>)])],
&[quote!(&(0..p.len()).map(|__mockall_i| &p[__mockall_i] as &(dyn AsRef<Path>)).collect::<Vec<_>>())]
&[quote!(&(0..p.len()).map(|__mockall_i| &p[__mockall_i] as &(dyn AsRef<Path>)).collect::<Vec<_>>())],
&[quote!(p: &[P])]
);
}

Expand All @@ -1625,7 +1694,8 @@ mod concretize_args {
check_concretize(
quote!(fn foo<P: AsRef<Path> + AsMut<String>>(p: &[P])),
&[quote!(p: &[&(dyn AsRef<Path> + AsMut<String>)])],
&[quote!(&(0..p.len()).map(|__mockall_i| &p[__mockall_i] as &(dyn AsRef<Path> + AsMut<String>)).collect::<Vec<_>>())]
&[quote!(&(0..p.len()).map(|__mockall_i| &p[__mockall_i] as &(dyn AsRef<Path> + AsMut<String>)).collect::<Vec<_>>())],
&[quote!(p: &[P])]
);
}

Expand All @@ -1634,7 +1704,80 @@ mod concretize_args {
check_concretize(
quote!(fn foo<P>(p: P) where P: AsRef<Path>),
&[quote!(p: &(dyn AsRef<Path>))],
&[quote!(&p)]
&[quote!(&p)],
&[quote!(p: P)]
);
}
}

mod declosurefy {
use super::*;

fn check_declosurefy(
sig: TokenStream,
expected_inputs: &[TokenStream],
expected_call_exprs: &[TokenStream])
{
let f: Signature = parse2(sig).unwrap();
let (generics, inputs, call_exprs) =
declosurefy(&f.generics, &f.inputs);
assert!(generics.params.is_empty());
assert_eq!(inputs.len(), expected_inputs.len());
assert_eq!(call_exprs.len(), expected_call_exprs.len());
for i in 0..inputs.len() {
let actual = &inputs[i];
let exp = &expected_inputs[i];
assert_eq!(quote!(#actual).to_string(), quote!(#exp).to_string());
}
for i in 0..call_exprs.len() {
let actual = &call_exprs[i];
let exp = &expected_call_exprs[i];
assert_eq!(quote!(#actual).to_string(), quote!(#exp).to_string());
}
}

#[test]
fn r#fn() {
check_declosurefy(
quote!(fn foo<F: Fn(u32) -> u32>(f: F)),
&[quote!(f: Box<dyn Fn(u32) -> u32>)],
&[quote!(Box::new(f))]
);
}

#[test]
fn fn_mut() {
check_declosurefy(
quote!(fn foo<F: FnMut(u32) -> u32>(f: F)),
&[quote!(f: Box<dyn FnMut(u32) -> u32>)],
&[quote!(Box::new(f))]
);
}

#[test]
fn fn_once() {
check_declosurefy(
quote!(fn foo<F: FnOnce(u32) -> u32>(f: F)),
&[quote!(f: Box<dyn FnOnce(u32) -> u32>)],
&[quote!(Box::new(f))]
);
}

#[test]
fn mutable_pattern() {
check_declosurefy(
quote!(fn foo<F: FnMut(u32) -> u32>(mut f: F)),
&[quote!(f: Box<dyn FnMut(u32) -> u32>)],
&[quote!(Box::new(f))]
);
}

#[test]
fn where_clause() {
check_declosurefy(
quote!(fn foo<F>(f: F) where F: Fn(u32) -> u32),
&[quote!(f: Box<dyn Fn(u32) -> u32>)],
&[quote!(Box::new(f))]
);
}
}
Expand Down
Loading

0 comments on commit 5aba4b0

Please sign in to comment.