Skip to content

Commit

Permalink
Refactoring of FnArg (#4033)
Browse files Browse the repository at this point in the history
* refactor `FnArg`

* add UI tests

* use enum variant types

* add comment

* remove dead code

* remove last FIXME

* review feedback davidhewitt
  • Loading branch information
Icxolu authored Apr 14, 2024
1 parent 721100a commit cc7e16f
Show file tree
Hide file tree
Showing 9 changed files with 348 additions and 238 deletions.
179 changes: 139 additions & 40 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use syn::{ext::IdentExt, spanned::Spanned, Ident, Result};

use crate::utils::Ctx;
use crate::{
attributes::{TextSignatureAttribute, TextSignatureAttributeValue},
attributes::{FromPyWithAttribute, TextSignatureAttribute, TextSignatureAttributeValue},
deprecations::{Deprecation, Deprecations},
params::{impl_arg_params, Holders},
pyfunction::{
Expand All @@ -17,19 +17,109 @@ use crate::{
};

#[derive(Clone, Debug)]
pub struct FnArg<'a> {
pub struct RegularArg<'a> {
pub name: &'a syn::Ident,
pub ty: &'a syn::Type,
pub optional: Option<&'a syn::Type>,
pub default: Option<syn::Expr>,
pub py: bool,
pub attrs: PyFunctionArgPyO3Attributes,
pub is_varargs: bool,
pub is_kwargs: bool,
pub is_cancel_handle: bool,
pub from_py_with: Option<FromPyWithAttribute>,
pub default_value: Option<syn::Expr>,
pub option_wrapped_type: Option<&'a syn::Type>,
}

/// Pythons *args argument
#[derive(Clone, Debug)]
pub struct VarargsArg<'a> {
pub name: &'a syn::Ident,
pub ty: &'a syn::Type,
}

/// Pythons **kwarg argument
#[derive(Clone, Debug)]
pub struct KwargsArg<'a> {
pub name: &'a syn::Ident,
pub ty: &'a syn::Type,
}

#[derive(Clone, Debug)]
pub struct CancelHandleArg<'a> {
pub name: &'a syn::Ident,
pub ty: &'a syn::Type,
}

#[derive(Clone, Debug)]
pub struct PyArg<'a> {
pub name: &'a syn::Ident,
pub ty: &'a syn::Type,
}

#[derive(Clone, Debug)]
pub enum FnArg<'a> {
Regular(RegularArg<'a>),
VarArgs(VarargsArg<'a>),
KwArgs(KwargsArg<'a>),
Py(PyArg<'a>),
CancelHandle(CancelHandleArg<'a>),
}

impl<'a> FnArg<'a> {
pub fn name(&self) -> &'a syn::Ident {
match self {
FnArg::Regular(RegularArg { name, .. }) => name,
FnArg::VarArgs(VarargsArg { name, .. }) => name,
FnArg::KwArgs(KwargsArg { name, .. }) => name,
FnArg::Py(PyArg { name, .. }) => name,
FnArg::CancelHandle(CancelHandleArg { name, .. }) => name,
}
}

pub fn ty(&self) -> &'a syn::Type {
match self {
FnArg::Regular(RegularArg { ty, .. }) => ty,
FnArg::VarArgs(VarargsArg { ty, .. }) => ty,
FnArg::KwArgs(KwargsArg { ty, .. }) => ty,
FnArg::Py(PyArg { ty, .. }) => ty,
FnArg::CancelHandle(CancelHandleArg { ty, .. }) => ty,
}
}

#[allow(clippy::wrong_self_convention)]
pub fn from_py_with(&self) -> Option<&FromPyWithAttribute> {
if let FnArg::Regular(RegularArg { from_py_with, .. }) = self {
from_py_with.as_ref()
} else {
None
}
}

pub fn to_varargs_mut(&mut self) -> Result<&mut Self> {
if let Self::Regular(RegularArg {
name,
ty,
option_wrapped_type: None,
..
}) = self
{
*self = Self::VarArgs(VarargsArg { name, ty });
Ok(self)
} else {
bail_spanned!(self.name().span() => "args cannot be optional")
}
}

pub fn to_kwargs_mut(&mut self) -> Result<&mut Self> {
if let Self::Regular(RegularArg {
name,
ty,
option_wrapped_type: Some(..),
..
}) = self
{
*self = Self::KwArgs(KwargsArg { name, ty });
Ok(self)
} else {
bail_spanned!(self.name().span() => "kwargs must be Option<_>")
}
}

/// Transforms a rust fn arg parsed with syn into a method::FnArg
pub fn parse(arg: &'a mut syn::FnArg) -> Result<Self> {
match arg {
Expand All @@ -41,32 +131,43 @@ impl<'a> FnArg<'a> {
bail_spanned!(cap.ty.span() => IMPL_TRAIT_ERR);
}

let arg_attrs = PyFunctionArgPyO3Attributes::from_attrs(&mut cap.attrs)?;
let PyFunctionArgPyO3Attributes {
from_py_with,
cancel_handle,
} = PyFunctionArgPyO3Attributes::from_attrs(&mut cap.attrs)?;
let ident = match &*cap.pat {
syn::Pat::Ident(syn::PatIdent { ident, .. }) => ident,
other => return Err(handle_argument_error(other)),
};

let is_cancel_handle = arg_attrs.cancel_handle.is_some();
if utils::is_python(&cap.ty) {
return Ok(Self::Py(PyArg {
name: ident,
ty: &cap.ty,
}));
}

Ok(FnArg {
if cancel_handle.is_some() {
// `PyFunctionArgPyO3Attributes::from_attrs` validates that
// only compatible attributes are specified, either
// `cancel_handle` or `from_py_with`, dublicates and any
// combination of the two are already rejected.
return Ok(Self::CancelHandle(CancelHandleArg {
name: ident,
ty: &cap.ty,
}));
}

Ok(Self::Regular(RegularArg {
name: ident,
ty: &cap.ty,
optional: utils::option_type_argument(&cap.ty),
default: None,
py: utils::is_python(&cap.ty),
attrs: arg_attrs,
is_varargs: false,
is_kwargs: false,
is_cancel_handle,
})
from_py_with,
default_value: None,
option_wrapped_type: utils::option_type_argument(&cap.ty),
}))
}
}
}

pub fn is_regular(&self) -> bool {
!self.py && !self.is_cancel_handle && !self.is_kwargs && !self.is_varargs
}
}

fn handle_argument_error(pat: &syn::Pat) -> syn::Error {
Expand Down Expand Up @@ -492,12 +593,14 @@ impl<'a> FnSpec<'a> {
.signature
.arguments
.iter()
.filter(|arg| arg.is_cancel_handle);
.filter(|arg| matches!(arg, FnArg::CancelHandle(..)));
let cancel_handle = cancel_handle_iter.next();
if let Some(arg) = cancel_handle {
ensure_spanned!(self.asyncness.is_some(), arg.name.span() => "`cancel_handle` attribute can only be used with `async fn`");
if let Some(arg2) = cancel_handle_iter.next() {
bail_spanned!(arg2.name.span() => "`cancel_handle` may only be specified once");
if let Some(FnArg::CancelHandle(CancelHandleArg { name, .. })) = cancel_handle {
ensure_spanned!(self.asyncness.is_some(), name.span() => "`cancel_handle` attribute can only be used with `async fn`");
if let Some(FnArg::CancelHandle(CancelHandleArg { name, .. })) =
cancel_handle_iter.next()
{
bail_spanned!(name.span() => "`cancel_handle` may only be specified once");
}
}

Expand Down Expand Up @@ -605,14 +708,10 @@ impl<'a> FnSpec<'a> {
.signature
.arguments
.iter()
.map(|arg| {
if arg.py {
quote!(py)
} else if arg.is_cancel_handle {
quote!(__cancel_handle)
} else {
unreachable!()
}
.map(|arg| match arg {
FnArg::Py(..) => quote!(py),
FnArg::CancelHandle(..) => quote!(__cancel_handle),
_ => unreachable!("`CallingConvention::Noargs` should not contain any arguments (reaching Python) except for `self`, which is handled below."),
})
.collect();
let call = rust_call(args, &mut holders);
Expand All @@ -635,7 +734,7 @@ impl<'a> FnSpec<'a> {
}
CallingConvention::Fastcall => {
let mut holders = Holders::new();
let (arg_convert, args) = impl_arg_params(self, cls, true, &mut holders, ctx)?;
let (arg_convert, args) = impl_arg_params(self, cls, true, &mut holders, ctx);
let call = rust_call(args, &mut holders);
let init_holders = holders.init_holders(ctx);
let check_gil_refs = holders.check_gil_refs();
Expand All @@ -660,7 +759,7 @@ impl<'a> FnSpec<'a> {
}
CallingConvention::Varargs => {
let mut holders = Holders::new();
let (arg_convert, args) = impl_arg_params(self, cls, false, &mut holders, ctx)?;
let (arg_convert, args) = impl_arg_params(self, cls, false, &mut holders, ctx);
let call = rust_call(args, &mut holders);
let init_holders = holders.init_holders(ctx);
let check_gil_refs = holders.check_gil_refs();
Expand All @@ -684,7 +783,7 @@ impl<'a> FnSpec<'a> {
}
CallingConvention::TpNew => {
let mut holders = Holders::new();
let (arg_convert, args) = impl_arg_params(self, cls, false, &mut holders, ctx)?;
let (arg_convert, args) = impl_arg_params(self, cls, false, &mut holders, ctx);
let self_arg = self
.tp
.self_arg(cls, ExtractErrorMode::Raise, &mut holders, ctx);
Expand Down
Loading

0 comments on commit cc7e16f

Please sign in to comment.