Skip to content

Commit

Permalink
Add WorkerThreadPool for running synchronous work in threads (#7875)
Browse files Browse the repository at this point in the history
  • Loading branch information
zanieb authored and github-actions[bot] committed Jan 12, 2023
1 parent 2cec0c4 commit 0401d1e
Show file tree
Hide file tree
Showing 3 changed files with 293 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/prefect/_internal/concurrency/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,16 @@ def call_soon_in_loop(
Returns a future that can be used to retrieve the result of the call.
"""
future = call_soon_in_loop(__loop, __fn, *args, **kwargs)
return future.result()


def call_soon_in_loop(
__loop: asyncio.AbstractEventLoop,
__fn: Callable[P, T],
*args: P.args,
**kwargs: P.kwargs
) -> concurrent.futures.Future:
future = concurrent.futures.Future()

@functools.wraps(__fn)
Expand Down
172 changes: 172 additions & 0 deletions src/prefect/_internal/concurrency/workers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import asyncio
import contextvars
import dataclasses
import threading
import weakref
from queue import Queue
from typing import Callable, Dict, Optional, Set, Tuple, TypeVar, Union

import anyio.abc
from typing_extensions import ParamSpec

from prefect._internal.concurrency.primitives import Future

T = TypeVar("T")
P = ParamSpec("P")


@dataclasses.dataclass
class _WorkItem:
    """
    A single unit of work handed to a worker thread.

    Bundles the callable and its arguments with the future that reports the
    outcome and the context the callable is executed in.
    """

    future: Future
    fn: Callable
    args: Tuple
    kwargs: Dict
    context: contextvars.Context

    def run(self):
        """
        Execute the callable inside the stored context and resolve the future.

        Returns without calling anything when `set_running_or_notify_cancel`
        returns a falsy value. Any exception raised by the callable, including
        `BaseException` subclasses, is stored on the future rather than raised.
        """
        may_run = self.future.set_running_or_notify_cancel()
        if not may_run:
            return

        try:
            outcome = self.context.run(self.fn, *self.args, **self.kwargs)
        except BaseException as error:
            self.future.set_exception(error)
            # Drop this frame's reference to the work item to prevent a
            # reference cycle in `error`
            self = None
        else:
            self.future.set_result(outcome)


class _WorkerThread(threading.Thread):
    """
    A thread that runs work items pulled from a shared queue until told to stop.
    """

    def __init__(
        self,
        queue: "Queue[Union[_WorkItem, None]]",  # Typing only supported in Python 3.9+
        idle: threading.Semaphore,
        name: str = None,
    ):
        """
        Args:
            queue: Source of work items; a `None` entry is the shutdown signal.
            idle: Semaphore released once after each completed work item.
            name: Optional name for the thread.
        """
        super().__init__(name=name)
        self._queue = queue
        self._idle = idle

    def run(self) -> None:
        """
        Process work items until a `None` entry is read from the queue.

        The `None` entry is placed back on the queue before exiting so that any
        other thread reading the same queue receives it as well.
        """
        while True:
            item = self._queue.get()

            if item is not None:
                item.run()
                self._idle.release()
                # Release the finished item before blocking on the queue again
                del item
                continue

            # Shutdown command received; forward to other workers and exit
            self._queue.put_nowait(None)
            return


class WorkerThreadPool:
    """
    A pool of worker threads for running synchronous functions from async code.

    Worker threads are started lazily as work is submitted, up to `max_workers`.
    They are stopped when `shutdown` is called or when the pool is garbage
    collected. The pool can be used as an async context manager, in which case
    `shutdown` is awaited on exit.
    """

    def __init__(self, max_workers: int = 40) -> None:
        """
        Args:
            max_workers: The maximum number of worker threads to start.

        Raises:
            ValueError: If `max_workers` is less than one. Such a pool would
                accept work but never start a thread to run it.
        """
        if max_workers <= 0:
            raise ValueError("max_workers must be greater than 0")

        self._queue: "Queue[Union[_WorkItem, None]]" = Queue()
        self._workers: Set[_WorkerThread] = set()
        self._max_workers = max_workers
        # Counts completed work items that have not yet been matched with a new
        # submission; a non-zero value means a worker is free to take new work
        self._idle = threading.Semaphore(0)
        self._lock = asyncio.Lock()
        self._shutdown = False

        # On garbage collection of the pool, signal shutdown to workers
        weakref.finalize(self, self._queue.put_nowait, None)

    async def submit(
        self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs
    ) -> Future[T]:
        """
        Submit a function to run in a worker thread.

        The function is run in a copy of the context that is current at the time
        of submission.

        Returns a future which can be used to retrieve the result of the function.

        Raises:
            RuntimeError: If the pool has been shut down.
        """
        async with self._lock:
            if self._shutdown:
                raise RuntimeError("Work cannot be submitted to pool after shutdown.")

            future = Future()

            work_item = _WorkItem(
                future=future,
                fn=fn,
                args=args,
                kwargs=kwargs,
                context=contextvars.copy_context(),
            )

            # Place the new work item on the work queue
            self._queue.put_nowait(work_item)

            # Ensure there are workers available to run the work
            self._adjust_worker_count()

            return future

    async def shutdown(self, task_status: Optional[anyio.abc.TaskStatus] = None):
        """
        Shutdown the pool, waiting for all workers to complete before returning.

        Work submitted before shutdown runs to completion since the shutdown
        signal is queued behind it. After shutdown, new work may not be submitted.

        When called with `TaskGroup.start(...)`, the task will be reported as
        started after signalling shutdown to workers.
        """
        async with self._lock:
            self._shutdown = True
            self._queue.put_nowait(None)

            if task_status:
                task_status.started()

            # Avoid blocking the event loop while waiting for threads to join by
            # joining in another thread; we use a new instance of ourself to avoid
            # reimplementing threaded work.
            pool = WorkerThreadPool(max_workers=1)
            try:
                futures = [await pool.submit(worker.join) for worker in self._workers]
                await asyncio.gather(*[future.aresult() for future in futures])
            finally:
                # Stop the helper thread explicitly instead of relying on garbage
                # collection of the helper pool, which may be delayed when this
                # frame is kept alive by an exception or cancellation. The signal
                # is queued behind any pending joins so they still complete.
                pool._queue.put_nowait(None)

            self._workers.clear()

    def _adjust_worker_count(self):
        """
        This method should be called after work is added to the queue.

        If no workers are idle and the maximum worker count is not reached, add a
        new worker. Otherwise, decrement the idle worker count since work has been
        added to the queue and a worker will be busy.

        Note on cleanup of workers:
            Workers are only removed on shutdown. Workers could be shut down after
            a period of idle. However, we expect usage in Prefect to generally be
            incurred in a workflow that will not have idle workers once they are
            created. As long as the maximum number of workers remains relatively
            small, the overhead of idle workers should be negligible.
        """
        if (
            # `acquire` returns false if the idle count is at zero; otherwise, it
            # decrements the idle count and returns true
            not self._idle.acquire(blocking=False)
            and len(self._workers) < self._max_workers
        ):
            self._add_worker()

    def _add_worker(self):
        """
        Start a new worker thread reading from the pool's queue and track it.
        """
        worker = _WorkerThread(
            queue=self._queue,
            idle=self._idle,
            name=f"PrefectWorker-{len(self._workers)}",
        )
        self._workers.add(worker)
        worker.start()

    async def __aenter__(self):
        return self

    async def __aexit__(self, *_):
        await self.shutdown()
111 changes: 111 additions & 0 deletions tests/_internal/concurrency/test_workers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import asyncio
import time

import anyio
import pytest

from prefect._internal.concurrency.workers import WorkerThreadPool


def identity(x):
    """Return the argument unchanged; trivial work to hand to the pool."""
    return x


async def test_submit():
    """A submitted function runs and its return value is available on the future."""
    pool = WorkerThreadPool()
    async with pool:
        submitted = await pool.submit(identity, 1)
        value = await submitted.aresult()
        assert value == 1


async def test_submit_many():
    """
    Submitting more work than `max_workers` starts exactly `max_workers` threads
    and every item still produces its result.
    """
    import threading

    all_submitted = threading.Event()

    def wait_then_identity(x):
        # Hold each worker until every item is queued. If the work returned
        # immediately, a worker could be marked idle before the next submission
        # and be reused, so fewer than `max_workers` threads would be started and
        # the worker count assertion below would fail intermittently.
        all_submitted.wait()
        return x

    async with WorkerThreadPool() as pool:
        try:
            futures = [await pool.submit(wait_then_identity, i) for i in range(100)]
        finally:
            # Always unblock the workers so a failed submission cannot hang the
            # pool shutdown
            all_submitted.set()

        results = await asyncio.gather(*[future.aresult() for future in futures])
        assert results == list(range(100))
        assert len(pool._workers) == pool._max_workers


async def test_submit_reuses_idle_thread():
    """Work submitted once a worker is idle runs on it instead of a new thread."""
    async with WorkerThreadPool() as pool:
        first = await pool.submit(identity, 1)
        await first.aresult()

        # Yield to the event loop until the worker is marked as idle
        with anyio.fail_after(1):
            while not pool._idle._value:
                await anyio.sleep(0)

        second = await pool.submit(identity, 1)
        await second.aresult()

        worker_count = len(pool._workers)
        assert worker_count == 1


async def test_submit_after_shutdown():
    """Submitting to a pool that has been shut down raises a `RuntimeError`."""
    pool = WorkerThreadPool()
    await pool.shutdown()

    expected_message = "Work cannot be submitted to pool after shutdown"
    with pytest.raises(RuntimeError, match=expected_message):
        await pool.submit(identity, 1)


async def test_submit_during_shutdown():
    """Submitting while a shutdown task has been started raises a `RuntimeError`."""
    expected_message = "Work cannot be submitted to pool after shutdown"

    async with WorkerThreadPool() as pool:
        async with anyio.create_task_group() as task_group:
            # `start` returns once the shutdown task reports that it has started
            await task_group.start(pool.shutdown)

            with pytest.raises(RuntimeError, match=expected_message):
                await pool.submit(identity, 1)


async def test_shutdown_no_workers():
    """Shutting down a pool that never received work completes without error."""
    unused_pool = WorkerThreadPool()
    await unused_pool.shutdown()


async def test_shutdown_multiple_times():
    """Calling `shutdown` again on an already shut down pool does not raise."""
    pool = WorkerThreadPool()
    await pool.submit(identity, 1)

    for _ in range(2):
        await pool.shutdown()


async def test_shutdown_with_idle_workers():
    """Shutdown completes when the workers have finished all of their work."""
    pool = WorkerThreadPool()

    futures = []
    for _ in range(5):
        futures.append(await pool.submit(identity, 1))

    pending_results = [future.aresult() for future in futures]
    await asyncio.gather(*pending_results)

    await pool.shutdown()


async def test_shutdown_with_active_worker():
    """Work that is still running when shutdown is requested runs to completion."""
    pool = WorkerThreadPool()
    sleeping = await pool.submit(time.sleep, 1)

    await pool.shutdown()

    result = await sleeping.aresult()
    assert result is None


async def test_shutdown_exception_during_join():
    """The pool is marked as shut down even if the shutdown task is interrupted."""
    pool = WorkerThreadPool()
    finished = await pool.submit(identity, 1)
    await finished.aresult()

    try:
        async with anyio.create_task_group() as task_group:
            await task_group.start(pool.shutdown)
            # Raise inside the task group while the shutdown task is in it
            raise ValueError()
    except ValueError:
        pass

    assert pool._shutdown is True


async def test_context_manager_with_outstanding_future():
    """
    Exiting the context manager shuts the pool down, and a future that was not
    awaited inside the block still resolves afterwards.
    """
    async with WorkerThreadPool() as pool:
        outstanding = await pool.submit(identity, 1)

    assert pool._shutdown is True

    value = await outstanding.aresult()
    assert value == 1

0 comments on commit 0401d1e

Please sign in to comment.