Skip to content

Commit

Permalink
Implement METH_FASTCALL for pyfunctions.
Browse files Browse the repository at this point in the history
  • Loading branch information
birkenfeld committed May 20, 2021
1 parent 95d7fa2 commit 088bc57
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 53 deletions.
22 changes: 22 additions & 0 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,28 @@ pub fn parse_method_receiver(arg: &syn::FnArg) -> syn::Result<SelfType> {
}

impl<'a> FnSpec<'a> {
/// Determine if the function gets passed a *args tuple or **kwargs dict.
pub fn accept_args_kwargs(&self) -> (bool, bool) {
let (mut accept_args, mut accept_kwargs) = (false, false);

for s in &self.attrs {
match s {
Argument::VarArgs(_) => accept_args = true,
Argument::KeywordArgs(_) => accept_kwargs = true,
_ => continue,
}
}

(accept_args, accept_kwargs)
}

/// Return true if the function can use METH_FASTCALL.
///
/// This is true on Py3.7+, except with the stable ABI (abi3).
pub fn can_use_fastcall(&self) -> bool {
cfg!(all(Py_3_7, not(abi3)))
}

/// Parser function signature and function attributes
pub fn parse(
sig: &'a mut syn::Signature,
Expand Down
36 changes: 34 additions & 2 deletions pyo3-macros-backend/src/pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,14 +401,19 @@ pub fn impl_wrap_pyfunction(
let wrapper = function_c_wrapper(name, &wrapper_ident, &spec, options.pass_module)?;
let methoddef = if spec.args.is_empty() {
quote!(noargs)
} else if spec.can_use_fastcall() {
quote!(fastcall_cfunction_with_keywords)
} else {
quote!(cfunction_with_keywords)
};
let cfunc = if spec.args.is_empty() {
quote!(PyCFunction)
} else if spec.can_use_fastcall() {
quote!(PyCFunctionFastWithKeywords)
} else {
quote!(PyCFunctionWithKeywords)
};

let wrapped_pyfunction = quote! {
#wrapper
pub(crate) fn #function_wrapper_ident<'a>(
Expand Down Expand Up @@ -465,8 +470,36 @@ fn function_c_wrapper(
})
}
})
} else if spec.can_use_fastcall() {
let body = impl_arg_params(spec, None, cb, &py, true)?;
Ok(quote! {
unsafe extern "C" fn #wrapper_ident(
_slf: *mut pyo3::ffi::PyObject,
_args: *const *mut pyo3::ffi::PyObject,
_nargs: pyo3::ffi::Py_ssize_t,
_kwnames: *mut pyo3::ffi::PyObject) -> *mut pyo3::ffi::PyObject
{
pyo3::callback::handle_panic(|#py| {
#slf_module
// _nargs is the number of positional arguments in the _args array,
// the number of KW args is given by the length of _kwnames
let _kwnames: Option<&pyo3::types::PyTuple> = #py.from_borrowed_ptr_or_opt(_kwnames);
let _kwargs = if let Some(kwnames) = _kwnames {
std::slice::from_raw_parts(_args.offset(_nargs), kwnames.len())
} else {
&[]
};
let _args = std::slice::from_raw_parts(_args, _nargs as usize);
// Safety: see PyTuple::as_slice
let _args = &*(_args as *const [*mut pyo3::ffi::PyObject] as *const [&pyo3::PyAny]);
let _kwargs = &*(_kwargs as *const [*mut pyo3::ffi::PyObject] as *const [&pyo3::PyAny]);
#body
})
}

})
} else {
let body = impl_arg_params(spec, None, cb, &py)?;
let body = impl_arg_params(spec, None, cb, &py, false)?;
Ok(quote! {
unsafe extern "C" fn #wrapper_ident(
_slf: *mut pyo3::ffi::PyObject,
Expand All @@ -477,7 +510,6 @@ fn function_c_wrapper(
#slf_module
let _args = #py.from_borrowed_ptr::<pyo3::types::PyTuple>(_args);
let _kwargs: Option<&pyo3::types::PyDict> = #py.from_borrowed_ptr_or_opt(_kwargs);

#body
})
}
Expand Down
42 changes: 27 additions & 15 deletions pyo3-macros-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ pub fn impl_wrap_cfunction_with_keywords(
let body = impl_call(cls, &spec);
let slf = self_ty.receiver(cls);
let py = syn::Ident::new("_py", Span::call_site());
let body = impl_arg_params(&spec, Some(cls), body, &py)?;
let body = impl_arg_params(&spec, Some(cls), body, &py, false)?;
Ok(quote! {{
unsafe extern "C" fn __wrap(
_slf: *mut pyo3::ffi::PyObject,
Expand Down Expand Up @@ -135,7 +135,7 @@ pub fn impl_wrap_new(cls: &syn::Type, spec: &FnSpec<'_>) -> Result<TokenStream>
let names: Vec<syn::Ident> = get_arg_names(&spec);
let cb = quote! { #cls::#name(#(#names),*) };
let py = syn::Ident::new("_py", Span::call_site());
let body = impl_arg_params(spec, Some(cls), cb, &py)?;
let body = impl_arg_params(spec, Some(cls), cb, &py, false)?;

Ok(quote! {{
#[allow(unused_mut)]
Expand Down Expand Up @@ -165,7 +165,7 @@ pub fn impl_wrap_class(cls: &syn::Type, spec: &FnSpec<'_>) -> Result<TokenStream
let names: Vec<syn::Ident> = get_arg_names(&spec);
let cb = quote! { pyo3::callback::convert(_py, #cls::#name(&_cls, #(#names),*)) };
let py = syn::Ident::new("_py", Span::call_site());
let body = impl_arg_params(spec, Some(cls), cb, &py)?;
let body = impl_arg_params(spec, Some(cls), cb, &py, false)?;

Ok(quote! {{
#[allow(unused_mut)]
Expand All @@ -192,7 +192,7 @@ pub fn impl_wrap_static(cls: &syn::Type, spec: &FnSpec<'_>) -> Result<TokenStrea
let names: Vec<syn::Ident> = get_arg_names(&spec);
let cb = quote! { pyo3::callback::convert(_py, #cls::#name(#(#names),*)) };
let py = syn::Ident::new("_py", Span::call_site());
let body = impl_arg_params(spec, Some(cls), cb, &py)?;
let body = impl_arg_params(spec, Some(cls), cb, &py, false)?;

Ok(quote! {{
#[allow(unused_mut)]
Expand Down Expand Up @@ -343,6 +343,7 @@ pub fn impl_arg_params(
self_: Option<&syn::Type>,
body: TokenStream,
py: &syn::Ident,
fastcall: bool,
) -> Result<TokenStream> {
if spec.args.is_empty() {
return Ok(body);
Expand Down Expand Up @@ -392,16 +393,7 @@ pub fn impl_arg_params(
)?);
}

let (mut accept_args, mut accept_kwargs) = (false, false);

for s in spec.attrs.iter() {
use crate::pyfunction::Argument;
match s {
Argument::VarArgs(_) => accept_args = true,
Argument::KeywordArgs(_) => accept_kwargs = true,
_ => continue,
}
}
let (accept_args, accept_kwargs) = spec.accept_args_kwargs();

let cls_name = if let Some(cls) = self_ {
quote! { Some(<#cls as pyo3::type_object::PyTypeInfo>::NAME) }
Expand All @@ -410,6 +402,21 @@ pub fn impl_arg_params(
};
let python_name = &spec.python_name;

let (args_to_extract, kwargs_to_extract) = if fastcall {
// _args is a &[&PyAny], _kwnames is a Option<&PyTuple> containing the
// keyword names of the keyword args in _kwargs
(
quote! { _args.iter().map(|&obj| obj) },
quote! { _kwnames.map(|kwnames| kwnames.iter().zip(_kwargs.iter().map(|&obj| obj))) },
)
} else {
// _args is a &PyTuple, _kwargs is an Option<&PyDict>
(
quote! { _args.iter() },
quote! { _kwargs.map(|dict| dict.iter()) },
)
};

// create array of arguments, and then parse
Ok(quote! {
{
Expand All @@ -426,7 +433,12 @@ pub fn impl_arg_params(
};

let mut #args_array = [None; #num_params];
let (_args, _kwargs) = DESCRIPTION.extract_arguments(_args, _kwargs, &mut #args_array)?;
let (_args, _kwargs) = DESCRIPTION.extract_arguments(
#py,
#args_to_extract,
#kwargs_to_extract,
&mut #args_array
)?;

#(#param_conversion)*

Expand Down
24 changes: 24 additions & 0 deletions src/class/methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ pub enum PyMethodDefType {
pub enum PyMethodType {
PyCFunction(PyCFunction),
PyCFunctionWithKeywords(PyCFunctionWithKeywords),
#[cfg(Py_3_7)]
PyCFunctionFastWithKeywords(PyCFunctionFastWithKeywords),
}

// These newtype structs serve no purpose other than wrapping which are function pointers - because
Expand All @@ -36,6 +38,9 @@ pub enum PyMethodType {
pub struct PyCFunction(pub ffi::PyCFunction);
#[derive(Clone, Copy, Debug)]
pub struct PyCFunctionWithKeywords(pub ffi::PyCFunctionWithKeywords);
#[cfg(Py_3_7)]
#[derive(Clone, Copy, Debug)]
pub struct PyCFunctionFastWithKeywords(pub ffi::_PyCFunctionFastWithKeywords);
#[derive(Clone, Copy, Debug)]
pub struct PyGetter(pub ffi::getter);
#[derive(Clone, Copy, Debug)]
Expand Down Expand Up @@ -105,6 +110,21 @@ impl PyMethodDef {
}
}

/// Define a function that can take `*args` and `**kwargs`.
#[cfg(Py_3_7)]
pub const fn fastcall_cfunction_with_keywords(
name: &'static str,
cfunction: PyCFunctionFastWithKeywords,
doc: &'static str,
) -> Self {
Self {
ml_name: name,
ml_meth: PyMethodType::PyCFunctionFastWithKeywords(cfunction),
ml_flags: ffi::METH_FASTCALL | ffi::METH_KEYWORDS,
ml_doc: doc,
}
}

pub const fn flags(mut self, flags: c_int) -> Self {
self.ml_flags |= flags;
self
Expand All @@ -115,6 +135,10 @@ impl PyMethodDef {
let meth = match self.ml_meth {
PyMethodType::PyCFunction(meth) => meth.0,
PyMethodType::PyCFunctionWithKeywords(meth) => unsafe { std::mem::transmute(meth.0) },
#[cfg(Py_3_7)]
PyMethodType::PyCFunctionFastWithKeywords(meth) => unsafe {
std::mem::transmute(meth.0)
},
};

Ok(ffi::PyMethodDef {
Expand Down
29 changes: 17 additions & 12 deletions src/derive_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ impl FunctionDescription {
format!("{}()", self.func_name)
}
}

/// Extracts the `args` and `kwargs` provided into `output`, according to this function
/// definition.
///
Expand All @@ -52,8 +53,9 @@ impl FunctionDescription {
/// Unexpected, duplicate or invalid arguments will cause this function to return `TypeError`.
pub fn extract_arguments<'p>(
&self,
args: &'p PyTuple,
kwargs: Option<&'p PyDict>,
py: Python<'p>,
mut args: impl ExactSizeIterator<Item = &'p PyAny>,
kwargs: Option<impl Iterator<Item = (&'p PyAny, &'p PyAny)>>,
output: &mut [Option<&'p PyAny>],
) -> PyResult<(Option<&'p PyTuple>, Option<&'p PyDict>)> {
let num_positional_parameters = self.positional_parameter_names.len();
Expand All @@ -66,33 +68,36 @@ impl FunctionDescription {
);

// Handle positional arguments
let (args_provided, varargs) = {
let args_provided = {
let args_provided = args.len();

if self.accept_varargs {
(
std::cmp::min(num_positional_parameters, args_provided),
Some(args.slice(num_positional_parameters as isize, args_provided as isize)),
)
std::cmp::min(num_positional_parameters, args_provided)
} else if args_provided > num_positional_parameters {
return Err(self.too_many_positional_arguments(args_provided));
} else {
(args_provided, None)
args_provided
}
};

// Copy positional arguments into output
for (out, arg) in output[..args_provided].iter_mut().zip(args) {
for (out, arg) in output[..args_provided].iter_mut().zip(args.by_ref()) {
*out = Some(arg);
}

// Collect varargs into tuple
let varargs = if self.accept_varargs {
Some(PyTuple::new(py, args))
} else {
None
};

// Handle keyword arguments
let varkeywords = match (kwargs, self.accept_varkeywords) {
(Some(kwargs), true) => {
let mut varkeywords = None;
self.extract_keyword_arguments(kwargs, output, |name, value| {
varkeywords
.get_or_insert_with(|| PyDict::new(kwargs.py()))
.get_or_insert_with(|| PyDict::new(py))
.set_item(name, value)
})?;
varkeywords
Expand Down Expand Up @@ -146,7 +151,7 @@ impl FunctionDescription {
#[inline]
fn extract_keyword_arguments<'p>(
&self,
kwargs: &'p PyDict,
kwargs: impl Iterator<Item = (&'p PyAny, &'p PyAny)>,
output: &mut [Option<&'p PyAny>],
mut unexpected_keyword_handler: impl FnMut(&'p PyAny, &'p PyAny) -> PyResult<()>,
) -> PyResult<()> {
Expand Down
6 changes: 6 additions & 0 deletions src/types/tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ impl<'a> Iterator for PyTupleIterator<'a> {
}
}

impl<'a> ExactSizeIterator for PyTupleIterator<'a> {
fn len(&self) -> usize {
self.length - self.index
}
}

impl<'a> IntoIterator for &'a PyTuple {
type Item = &'a PyAny;
type IntoIter = PyTupleIterator<'a>;
Expand Down
48 changes: 24 additions & 24 deletions tests/test_pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,30 +164,30 @@ fn test_function_with_custom_conversion_error() {
);
}

#[test]
fn test_raw_function() {
let gil = Python::acquire_gil();
let py = gil.python();
let raw_func = raw_pycfunction!(optional_bool);
let fun = PyCFunction::new_with_keywords(raw_func, "fun", "", py.into()).unwrap();
let res = fun.call((), None).unwrap().extract::<&str>().unwrap();
assert_eq!(res, "Some(true)");
let res = fun.call((false,), None).unwrap().extract::<&str>().unwrap();
assert_eq!(res, "Some(false)");
let no_module = fun.getattr("__module__").unwrap().is_none();
assert!(no_module);

let module = PyModule::new(py, "cool_module").unwrap();
module.add_function(fun).unwrap();
let res = module
.getattr("fun")
.unwrap()
.call((), None)
.unwrap()
.extract::<&str>()
.unwrap();
assert_eq!(res, "Some(true)");
}
// #[test]
// fn test_raw_function() {
// let gil = Python::acquire_gil();
// let py = gil.python();
// let raw_func = raw_pycfunction!(optional_bool);
// let fun = PyCFunction::new_with_keywords(raw_func, "fun", "", py.into()).unwrap();
// let res = fun.call((), None).unwrap().extract::<&str>().unwrap();
// assert_eq!(res, "Some(true)");
// let res = fun.call((false,), None).unwrap().extract::<&str>().unwrap();
// assert_eq!(res, "Some(false)");
// let no_module = fun.getattr("__module__").unwrap().is_none();
// assert!(no_module);

// let module = PyModule::new(py, "cool_module").unwrap();
// module.add_function(fun).unwrap();
// let res = module
// .getattr("fun")
// .unwrap()
// .call((), None)
// .unwrap()
// .extract::<&str>()
// .unwrap();
// assert_eq!(res, "Some(true)");
// }

#[pyfunction]
fn conversion_error(str_arg: &str, int_arg: i64, tuple_arg: (&str, f64), option_arg: Option<i64>) {
Expand Down

0 comments on commit 088bc57

Please sign in to comment.