Skip to content

Commit

Permalink
tests mostly passing
Browse files Browse the repository at this point in the history
  • Loading branch information
alecmocatta committed Aug 17, 2020
1 parent 9177e0b commit dfc009f
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 64 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ wasm-bindgen-test = "0.3"
[build-dependencies]
rustversion = "1.0"

[patch.crates-io]
tokio = {git = "https://github.com/tokio-rs/tokio", branch = "v0.2.x"}

[profile.bench]
codegen-units = 1
debug = 2
Expand Down
112 changes: 48 additions & 64 deletions src/pool/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,12 @@ impl ThreadPool {
T: Send + 'a,
{
#[cfg(not(target_arch = "wasm32"))]
return Guard::new(
self.0
.pool
.spawn_pinned_unchecked(task)
.map_err(JoinError::into_panic)
.map_err(Panicked::from),
);
return self
.0
.pool
.spawn_pinned_unchecked(task)
.map_err(JoinError::into_panic)
.map_err(Panicked::from);
#[cfg(target_arch = "wasm32")]
{
let _self = self;
Expand All @@ -104,10 +103,10 @@ impl ThreadPool {
.map_err(Into::into)
.remote_handle();
wasm_bindgen_futures::spawn_local(remote);
Guard::new(remote_handle.map_ok(|t| {
remote_handle.map_ok(|t| {
let t: *mut dyn Send = Box::into_raw(t);
*Box::from_raw(t as *mut T)
}))
})
}
}
}
Expand All @@ -125,39 +124,6 @@ impl Clone for ThreadPool {
impl UnwindSafe for ThreadPool {}
impl RefUnwindSafe for ThreadPool {}

#[pin_project(PinnedDrop)]
struct Guard<F>(#[pin] Option<F>);
impl<F> Guard<F> {
fn new(f: F) -> Self {
Self(Some(f))
}
}
impl<F> Future for Guard<F>
where
F: Future,
{
type Output = F::Output;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
match self.as_mut().project().0.as_pin_mut() {
Some(fut) => {
let output = ready!(fut.poll(cx));
self.project().0.set(None);
Poll::Ready(output)
}
None => Poll::Pending,
}
}
}
#[pinned_drop]
impl<F> PinnedDrop for Guard<F> {
fn drop(self: Pin<&mut Self>) {
if self.project().0.is_some() {
panic!("dropped before finished polling!");
}
}
}

fn _assert() {
let _ = assert_sync_and_send::<ThreadPool>;
}
Expand All @@ -166,41 +132,47 @@ fn _assert() {
#[cfg(not(target_arch = "wasm32"))]
mod pool {
use async_channel::{bounded, Sender};
use futures::{future::RemoteHandle, FutureExt};
use futures::{
future::{join_all, RemoteHandle}, FutureExt
};
use std::{any::Any, future::Future, mem, panic::AssertUnwindSafe, pin::Pin};
use tokio::{
runtime::Handle, task::{JoinError, LocalSet}
runtime::Handle, task, task::{JoinError, JoinHandle, LocalSet}
};

type Request = Box<dyn FnOnce() -> Box<dyn Future<Output = Response>> + Send>;
type Response = Result<Box<dyn Any + Send>, Box<dyn Any + Send>>;

#[derive(Debug)]
pub(super) struct Pool {
sender: Sender<(Request, Sender<RemoteHandle<Response>>)>,
sender: Option<Sender<(Request, Sender<RemoteHandle<Response>>)>>,
threads: Vec<JoinHandle<()>>,
}
impl Pool {
pub(super) fn new(threads: usize) -> Self {
let handle = Handle::current();
let handle1 = handle.clone();
let (sender, receiver) = bounded::<(Request, Sender<RemoteHandle<Response>>)>(1);
for _ in 0..threads {
let receiver = receiver.clone();
let handle = handle.clone();
let _ = handle1.spawn_blocking(move || {
let local = LocalSet::new();
handle.block_on(local.run_until(async {
while let Ok((task, sender)) = receiver.recv().await {
let _ = local.spawn_local(async move {
let (remote, remote_handle) = Pin::from(task()).remote_handle();
let _ = sender.send(remote_handle).await;
remote.await;
});
}
}))
});
}
Self { sender }
let threads = (0..threads)
.map(|_| {
let receiver = receiver.clone();
let handle = handle.clone();
handle1.spawn_blocking(move || {
let local = LocalSet::new();
handle.block_on(local.run_until(async {
while let Ok((task, sender)) = receiver.recv().await {
let _ = local.spawn_local(async move {
let (remote, remote_handle) = Pin::from(task()).remote_handle();
let _ = sender.send(remote_handle).await;
remote.await;
});
}
}))
})
})
.collect();
let sender = Some(sender);
Self { sender, threads }
}
pub(super) fn spawn_pinned<F, Fut, T>(
&self, task: F,
Expand All @@ -210,7 +182,7 @@ mod pool {
Fut: Future<Output = T> + 'static,
T: Send + 'static,
{
let sender = self.sender.clone();
let sender = self.sender.as_ref().unwrap().clone();
async move {
let task: Request = Box::new(|| {
Box::new(
Expand All @@ -236,7 +208,7 @@ mod pool {
Fut: Future<Output = T> + 'a,
T: Send + 'a,
{
let sender = self.sender.clone();
let sender = self.sender.as_ref().unwrap().clone();
async move {
let task: Box<dyn FnOnce() -> Box<dyn Future<Output = Response>> + Send> =
Box::new(|| {
Expand Down Expand Up @@ -264,6 +236,18 @@ mod pool {
}
}
}
impl Drop for Pool {
fn drop(&mut self) {
let _ = self.sender.take().unwrap();
task::block_in_place(|| {
let handle = Handle::current();
handle.block_on(join_all(mem::take(&mut self.threads)))
})
.into_iter()
.collect::<Result<(), _>>()
.unwrap();
}
}

#[cfg(test)]
mod tests {
Expand Down

0 comments on commit dfc009f

Please sign in to comment.