Skip to content

Commit

Permalink
Add Sync support (#40)
Browse files Browse the repository at this point in the history
Add `Sync` argument which marks the returned `Future` as `Sync`
  • Loading branch information
dcchut authored Mar 17, 2024
1 parent f01f486 commit 2d0d43b
Show file tree
Hide file tree
Showing 21 changed files with 343 additions and 28 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ jobs:
steps:
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable
with:
components: rust-src # required for consistent error messages
- run: cargo install cargo-expand
- run: cargo test --verbose

Expand Down
26 changes: 18 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,25 +52,35 @@ async fn fib(n : u32) -> u32 {
}
```

## ?Send Option
## ?Send option

The returned future has a `Send` bound to make sure it can be sent between threads.
The returned `Future` has a `Send` bound to make sure it can be sent between threads.
If this is undesirable you can mark that the bound should be left out like so:

```rust
#[async_recursion(?Send)]
async fn example() {
async fn returned_future_is_not_send() {
// ...
}
```

In detail:
## Sync option

The returned `Future` doesn't have a `Sync` bound as it is usually not required.
You can include a `Sync` bound as follows:

```rust
#[async_recursion(Sync)]
async fn returned_future_is_sync() {
// ...
}
```

- `#[async_recursion]` modifies your function to return a [`BoxFuture`], and
- `#[async_recursion(?Send)]` modifies your function to return a [`LocalBoxFuture`].
In detail:

[`BoxFuture`]: https://docs.rs/futures/0.3.19/futures/future/type.BoxFuture.html
[`LocalBoxFuture`]: https://docs.rs/futures/0.3.19/futures/future/type.LocalBoxFuture.html
- `#[async_recursion]` modifies your function to return a boxed `Future` with a `Send` bound.
- `#[async_recursion(?Send)]` modifies your function to return a boxed `Future` _without_ a `Send` bound.
- `#[async_recursion(Sync)]` modifies your function to return a boxed `Future` with a `Send` and `Sync` bound.

### License

Expand Down
8 changes: 7 additions & 1 deletion src/expand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,12 @@ fn transform_sig(sig: &mut Signature, args: &RecursionArgs) {
quote!()
};

let sync_bound: TokenStream = if args.sync_bound {
quote!(+ ::core::marker::Sync)
} else {
quote!()
};

let where_clause = sig
.generics
.where_clause
Expand All @@ -196,6 +202,6 @@ fn transform_sig(sig: &mut Signature, args: &RecursionArgs) {
// Modify the return type
sig.output = parse_quote! {
-> ::core::pin::Pin<Box<
dyn ::core::future::Future<Output = #ret> #box_lifetime #send_bound >>
dyn ::core::future::Future<Output = #ret> #box_lifetime #send_bound #sync_bound>>
};
}
27 changes: 20 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,27 +52,40 @@
//! }
//! ```
//!
//! ## ?Send Option
//! ## ?Send option
//!
//! The returned future has a [`Send`] bound to make sure it can be sent between threads.
//! The returned [`Future`] has a [`Send`] bound to make sure it can be sent between threads.
//! If this is undesirable you can mark that the bound should be left out like so:
//!
//! ```rust
//! # use async_recursion::async_recursion;
//!
//! #[async_recursion(?Send)]
//! async fn example() {
//! async fn returned_future_is_not_send() {
//! // ...
//! }
//! ```
//!
//! ## Sync option
//!
//! The returned [`Future`] doesn't have a [`Sync`] bound as it is usually not required.
//! You can include a [`Sync`] bound as follows:
//!
//! ```rust
//! # use async_recursion::async_recursion;
//!
//! #[async_recursion(Sync)]
//! async fn returned_future_is_send_and_sync() {
//! // ...
//! }
//! ```
//!
//! In detail:
//!
//! - `#[async_recursion]` modifies your function to return a [`BoxFuture`], and
//! - `#[async_recursion(?Send)]` modifies your function to return a [`LocalBoxFuture`].
//!
//! [`BoxFuture`]: https://docs.rs/futures/0.3.19/futures/future/type.BoxFuture.html
//! [`LocalBoxFuture`]: https://docs.rs/futures/0.3.19/futures/future/type.LocalBoxFuture.html
//! - `#[async_recursion]` modifies your function to return a boxed [`Future`] with a [`Send`] bound.
//! - `#[async_recursion(?Send)]` modifies your function to return a boxed [`Future`] _without_ a [`Send`] bound.
//! - `#[async_recursion(Sync)]` modifies your function to return a boxed [`Future`] with [`Send`] and [`Sync`] bounds.
//!
//! ### License
//!
Expand Down
68 changes: 56 additions & 12 deletions src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,30 +22,74 @@ impl Parse for AsyncItem {

pub struct RecursionArgs {
pub send_bound: bool,
}

impl Default for RecursionArgs {
fn default() -> Self {
RecursionArgs { send_bound: true }
}
pub sync_bound: bool,
}

/// Custom keywords for parser
mod kw {
syn::custom_keyword!(Send);
syn::custom_keyword!(Sync);
}

impl Parse for RecursionArgs {
#[derive(Debug, PartialEq, Eq)]
enum Arg {
NotSend,
Sync,
}

impl std::fmt::Display for Arg {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::NotSend => write!(f, "?Send"),
Self::Sync => write!(f, "Sync"),
}
}
}

impl Parse for Arg {
fn parse(input: ParseStream) -> Result<Self> {
// Check for the `?Send` option
if input.peek(Token![?]) {
input.parse::<Question>()?;
input.parse::<kw::Send>()?;
Ok(Self { send_bound: false })
} else if !input.is_empty() {
Err(input.error("expected `?Send` or empty"))
Ok(Arg::NotSend)
} else {
Ok(Self::default())
input.parse::<kw::Sync>()?;
Ok(Arg::Sync)
}
}
}

impl Parse for RecursionArgs {
fn parse(input: ParseStream) -> Result<Self> {
let mut send_bound: bool = true;
let mut sync_bound: bool = false;

let args_parsed: Vec<Arg> =
syn::punctuated::Punctuated::<Arg, syn::Token![,]>::parse_terminated(input)
.map_err(|e| input.error(format!("failed to parse macro arguments: {e}")))?
.into_iter()
.collect();

// Avoid sloppy input
if args_parsed.len() > 2 {
return Err(Error::new(Span::call_site(), "received too many arguments"));
} else if args_parsed.len() == 2 && args_parsed[0] == args_parsed[1] {
return Err(Error::new(
Span::call_site(),
format!("received duplicate argument: `{}`", args_parsed[0]),
));
}

for arg in args_parsed {
match arg {
Arg::NotSend => send_bound = false,
Arg::Sync => sync_bound = true,
}
}

Ok(Self {
send_bound,
sync_bound,
})
}
}
11 changes: 11 additions & 0 deletions tests/args_sync.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use async_recursion::async_recursion;

#[async_recursion(Sync)]
async fn send_and_sync() {}

fn assert_is_send_and_sync(_: impl Send + Sync) {}

#[test]
fn test_sync_argument() {
assert_is_send_and_sync(send_and_sync());
}
5 changes: 5 additions & 0 deletions tests/expand/args_not_send.expanded.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
use async_recursion::async_recursion;
#[must_use]
fn no_send_bound() -> ::core::pin::Pin<Box<dyn ::core::future::Future<Output = ()>>> {
Box::pin(async move {})
}
4 changes: 4 additions & 0 deletions tests/expand/args_not_send.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
use async_recursion::async_recursion;

#[async_recursion(?Send)]
async fn no_send_bound() {}
25 changes: 25 additions & 0 deletions tests/expand/args_punctuated.expanded.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use async_recursion::async_recursion;
#[must_use]
fn not_send_sync_1() -> ::core::pin::Pin<
Box<dyn ::core::future::Future<Output = ()> + ::core::marker::Sync>,
> {
Box::pin(async move {})
}
#[must_use]
fn not_send_sync_2() -> ::core::pin::Pin<
Box<dyn ::core::future::Future<Output = ()> + ::core::marker::Sync>,
> {
Box::pin(async move {})
}
#[must_use]
fn sync_not_send_1() -> ::core::pin::Pin<
Box<dyn ::core::future::Future<Output = ()> + ::core::marker::Sync>,
> {
Box::pin(async move {})
}
#[must_use]
fn sync_not_send_2() -> ::core::pin::Pin<
Box<dyn ::core::future::Future<Output = ()> + ::core::marker::Sync>,
> {
Box::pin(async move {})
}
13 changes: 13 additions & 0 deletions tests/expand/args_punctuated.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
use async_recursion::async_recursion;

#[async_recursion(?Send, Sync)]
async fn not_send_sync_1() {}

#[async_recursion(?Send,Sync)]
async fn not_send_sync_2() {}

#[async_recursion(Sync, ?Send)]
async fn sync_not_send_1() {}

#[async_recursion(Sync,?Send)]
async fn sync_not_send_2() {}
11 changes: 11 additions & 0 deletions tests/expand/args_sync.expanded.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use async_recursion::async_recursion;
#[must_use]
fn sync() -> ::core::pin::Pin<
Box<
dyn ::core::future::Future<
Output = (),
> + ::core::marker::Send + ::core::marker::Sync,
>,
> {
Box::pin(async move {})
}
4 changes: 4 additions & 0 deletions tests/expand/args_sync.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
use async_recursion::async_recursion;

#[async_recursion(Sync)]
async fn sync() {}
19 changes: 19 additions & 0 deletions tests/expand/lifetimes_explicit_async_recursion_bound.expanded.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use async_recursion::async_recursion;
#[must_use]
fn explicit_async_recursion_bound<'life0, 'life1, 'async_recursion>(
t: &'life0 T,
p: &'life1 [String],
prefix: Option<&'async_recursion [u8]>,
layer: Option<&'async_recursion [u8]>,
) -> ::core::pin::Pin<
Box<
dyn ::core::future::Future<Output = ()> + 'async_recursion + ::core::marker::Send,
>,
>
where
'life0: 'async_recursion,
'life1: 'async_recursion,
'async_recursion: 'async_recursion,
{
Box::pin(async move {})
}
12 changes: 12 additions & 0 deletions tests/expand/lifetimes_explicit_async_recursion_bound.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Test that an explicit `async_recursion bound is left alone.
// This is a workaround many
use async_recursion::async_recursion;


#[async_recursion]
async fn explicit_async_recursion_bound(
t: &T,
p: &[String],
prefix: Option<&'async_recursion [u8]>,
layer: Option<&'async_recursion [u8]>,
) {}
4 changes: 4 additions & 0 deletions tests/lifetimes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ async fn count_down(foo: Option<&str>) -> i32 {
0
}

#[async_recursion]
async fn explicit_async_recursion_bound(_: Option<&'async_recursion String>) {}

#[test]
fn lifetime_expansion_works() {
block_on(async move {
Expand Down Expand Up @@ -73,5 +76,6 @@ fn lifetime_expansion_works() {
assert_eq!(contains_value_2(&12, &node).await, false);

count_down(None).await;
explicit_async_recursion_bound(None).await;
});
}
15 changes: 15 additions & 0 deletions tests/ui/arg_not_sync.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use async_recursion::async_recursion;

fn assert_is_sync(_: impl Sync) {}


#[async_recursion]
async fn send_not_sync() {}

#[async_recursion(?Send)]
async fn not_send_not_sync() {}

fn main() {
assert_is_sync(send_not_sync());
assert_is_sync(not_send_not_sync());
}
Loading

0 comments on commit 2d0d43b

Please sign in to comment.