Skip to content

Commit

Permalink
relax explicit lifetime requirement
Browse files Browse the repository at this point in the history
  • Loading branch information
sarah committed Nov 19, 2023
1 parent 1b5aa54 commit 7e96a0b
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 71 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ Only available with the `macro` feature.

Requires the first non-lifetime generic parameter, as well as the function's
first input parameter to be the SIMD type.
Also currently requires that all the lifetimes be explicitly specified.

```rust
#[pulp::with_simd(sum = pulp::Arch::new())]
Expand Down
2 changes: 1 addition & 1 deletion pulp-macro/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pulp-macro"
version = "0.1.0"
version = "0.1.1"
edition = "2021"
authors = ["sarah <>"]
description = "Safe generic simd"
Expand Down
178 changes: 111 additions & 67 deletions pulp-macro/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use proc_macro2::TokenStream;
use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;
use syn::punctuated::Punctuated;
use syn::token::{Brace, Colon};
use syn::{FieldsNamed, FnArg, ItemFn};
use syn::token::{Colon, PathSep};
use syn::{
ConstParam, FnArg, GenericParam, ItemFn, LifetimeParam, Pat, PatIdent, PatType, Path,
PathSegment, Type, TypeParam, TypePath,
};

#[proc_macro_attribute]
pub fn with_simd(
Expand Down Expand Up @@ -40,20 +43,13 @@ pub fn with_simd(
block,
} = item.clone();

let mut struct_generics = Punctuated::new();
let mut struct_generics_lifetimes = Vec::new();
let mut struct_generics_names = Vec::new();
let mut struct_generics = Vec::new();
let mut struct_field_names = Vec::new();
let mut struct_fields = FieldsNamed {
brace_token: Brace {
span: sig.paren_token.span,
},
named: Punctuated::new(),
};
let mut struct_field_types = Vec::new();

let mut first_non_lifetime = usize::MAX;
for (idx, param) in sig.generics.params.clone().into_pairs().enumerate() {
let (param, comma) = param.into_tuple();
let (param, _) = param.into_tuple();
match &param {
syn::GenericParam::Lifetime(_) => {}
_ => {
Expand All @@ -63,17 +59,7 @@ pub fn with_simd(
}
}
}
match &param {
syn::GenericParam::Type(ty) => struct_generics_names.push(ty.ident.clone()),
syn::GenericParam::Lifetime(lt) => struct_generics_lifetimes.push(lt.lifetime.clone()),
syn::GenericParam::Const(const_) => struct_generics_names.push(const_.ident.clone()),
};
struct_generics.push_value(param);
if let Some(comma) = comma {
struct_generics.push_punct(comma);
}
}

let mut new_fn_sig = sig.clone();
new_fn_sig.generics.params = new_fn_sig
.generics
Expand All @@ -83,49 +69,61 @@ pub fn with_simd(
.filter(|(idx, _)| *idx != first_non_lifetime)
.map(|(_, arg)| arg)
.collect();
new_fn_sig.inputs = new_fn_sig.inputs.into_iter().skip(1).collect();
new_fn_sig.inputs = new_fn_sig
.inputs
.into_iter()
.skip(1)
.enumerate()
.map(|(idx, arg)| {
FnArg::Typed(PatType {
attrs: Vec::new(),
pat: Box::new(Pat::Ident(PatIdent {
attrs: Vec::new(),
by_ref: None,
mutability: None,
ident: Ident::new(&format!("__{idx}"), Span::call_site()),
subpat: None,
})),
colon_token: Colon {
spans: [Span::call_site()],
},
ty: match arg {
FnArg::Typed(ty) => ty.ty,
FnArg::Receiver(_) => panic!(),
},
})
})
.collect();
new_fn_sig.ident = name.clone();
let mut param_ty = Vec::new();

for param in sig.inputs.clone().into_pairs().skip(1) {
let (param, comma) = param.into_tuple();
for (idx, param) in new_fn_sig.inputs.clone().into_pairs().enumerate() {
let (param, _) = param.into_tuple();
let FnArg::Typed(param) = param.clone() else {
return quote! {
::core::compile_error!(::core::concat!(
"pulp::with_simd only accepts free functions"
));
#item
}
.into();
panic!();
};

let name = *param.pat;
let syn::Pat::Ident(name) = name else {
return quote! {
::core::compile_error!(::core::concat!(
"pulp::with_simd requires function parameters to be idents"
));
#item
}
.into();
panic!();
};

let anon_ty = Ident::new(&format!("__T{idx}"), Span::call_site());

struct_field_names.push(name.ident.clone());
let field = syn::Field {
attrs: param.attrs,
vis: syn::Visibility::Public(syn::token::Pub {
span: proc_macro2::Span::call_site(),
}),
mutability: syn::FieldMutability::None,
ident: Some(name.ident),
colon_token: Some(Colon {
spans: [proc_macro2::Span::call_site()],
}),
ty: *param.ty,
};
struct_fields.named.push_value(field);
if let Some(comma) = comma {
struct_fields.named.push_punct(comma);
}
let mut ty = Punctuated::<_, PathSep>::new();
ty.push_value(PathSegment {
ident: anon_ty.clone(),
arguments: syn::PathArguments::None,
});
struct_field_types.push(Type::Path(TypePath {
qself: None,
path: Path {
leading_colon: None,
segments: ty,
},
}));
struct_generics.push(anon_ty);
param_ty.push(*param.ty);
}

let output_ty = match sig.output.clone() {
Expand All @@ -136,33 +134,79 @@ pub fn with_simd(
let fn_name = sig.ident.clone();

let arch = attr.value;
let new_fn_generics = new_fn_sig.generics.clone();
let params = new_fn_generics.params.clone();
let generics = params.into_iter().collect::<Vec<_>>();
let non_lt_generics_names = generics
.iter()
.map(|p| match p {
GenericParam::Type(TypeParam { ident, .. })
| GenericParam::Const(ConstParam { ident, .. }) => {
quote! { #ident, }
}
_ => quote! {},
})
.collect::<Vec<_>>();
let generics_decl = generics
.iter()
.map(|p| match p {
GenericParam::Lifetime(LifetimeParam {
lifetime,
colon_token,
bounds,
..
}) => {
quote! { #lifetime #colon_token #bounds }
}
GenericParam::Type(TypeParam {
ident,
colon_token,
bounds,
..
}) => {
quote! { #ident #colon_token #bounds }
}
GenericParam::Const(ConstParam {
const_token,
ident,
colon_token,
ty,
..
}) => {
quote! { #const_token #ident #colon_token #ty }
}
})
.collect::<Vec<_>>();
let generics_where_clause = new_fn_generics.where_clause;

quote! {
let code = quote! {
#(#attrs)*
#vis #new_fn_sig {
#[allow(non_camel_case_types)]
struct #name<#struct_generics> #struct_fields
struct #name<#(#struct_generics,)*> (#(#struct_field_types,)*);

impl<#struct_generics> ::pulp::WithSimd for #name<#(#struct_generics_lifetimes,)*
#(#struct_generics_names,)*> { type Output = #output_ty;
impl<#(#generics_decl,)*> ::pulp::WithSimd for #name<
#(#param_ty,)*
> #generics_where_clause {
type Output = #output_ty;

#[inline(always)]
fn with_simd<__S: ::pulp::Simd>(self, __simd: __S) -> <Self as
::pulp::WithSimd>::Output { let Self { #(#struct_field_names,)* } = self;
fn with_simd<__S: ::pulp::Simd>(self, __simd: __S) -> <Self as ::pulp::WithSimd>::Output {
let Self ( #(#struct_field_names,)* ) = self;
#[allow(unused_unsafe)]
unsafe {
#fn_name::<__S,
#(#struct_generics_names,)*
#(#non_lt_generics_names)*
>(__simd, #(#struct_field_names,)*)
}
}
}

(#arch).dispatch( #name::<#(#struct_generics_names,)*> { #(#struct_field_names,)* } )
(#arch).dispatch( #name ( #(#struct_field_names,)* ) )
}

#(#attrs)*
#vis #sig #block
}
.into()
};
code.into()
}
4 changes: 2 additions & 2 deletions pulp/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pulp"
version = "0.18.5"
version = "0.18.6"
edition = "2021"
authors = ["sarah <>"]
description = "Safe generic simd"
Expand All @@ -10,7 +10,7 @@ license = "MIT"
keywords = ["simd"]

[dependencies]
pulp-macro = { version = "0.1.0", path = "../pulp-macro", optional = true }
pulp-macro = { version = "0.1.1", path = "../pulp-macro", optional = true }
bytemuck = "1"
num-complex = { version = "0.4.4", default-features = false, features = ["bytemuck"] }
libm = { version = "0.2", default-features = false }
Expand Down

0 comments on commit 7e96a0b

Please sign in to comment.