From 4624266826aa58b1fa1d4f860570ce3df4d5b522 Mon Sep 17 00:00:00 2001 From: Michael Adkins Date: Thu, 15 Dec 2022 09:51:26 -0600 Subject: [PATCH] Add `WorkerThreadPool` for running synchronous work in threads (#7875) --- .../_internal/concurrency/event_loop.py | 33 +++- .../_internal/concurrency/primitives.py | 7 +- src/prefect/_internal/concurrency/workers.py | 172 ++++++++++++++++++ tests/_internal/concurrency/test_workers.py | 111 +++++++++++ 4 files changed, 316 insertions(+), 7 deletions(-) create mode 100644 src/prefect/_internal/concurrency/workers.py create mode 100644 tests/_internal/concurrency/test_workers.py diff --git a/src/prefect/_internal/concurrency/event_loop.py b/src/prefect/_internal/concurrency/event_loop.py index 3a41bf59d25d..21763e05b585 100644 --- a/src/prefect/_internal/concurrency/event_loop.py +++ b/src/prefect/_internal/concurrency/event_loop.py @@ -25,7 +25,7 @@ def get_running_loop() -> Optional[asyncio.BaseEventLoop]: return None -def run_in_loop_thread( +def call_in_loop( __loop: asyncio.AbstractEventLoop, __fn: Callable[P, T], *args: P.args, @@ -34,6 +34,16 @@ def run_in_loop_thread( """ Run a synchronous call in event loop's thread from another thread. """ + future = call_soon_in_loop(__loop, __fn, *args, **kwargs) + return future.result() + + +def call_soon_in_loop( + __loop: asyncio.AbstractEventLoop, + __fn: Callable[P, T], + *args: P.args, + **kwargs: P.kwargs +) -> concurrent.futures.Future: future = concurrent.futures.Future() @functools.wraps(__fn) @@ -46,4 +56,23 @@ def wrapper() -> None: raise __loop.call_soon_threadsafe(wrapper) - return future.result() + return future + + +def call_soon( + __fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs +) -> concurrent.futures.Future: + future = concurrent.futures.Future() + __loop = asyncio.get_running_loop() + + @functools.wraps(__fn) + def wrapper() -> None: + try: + future.set_result(__fn(*args, **kwargs)) + except BaseException as exc: + future.set_exception(exc) + if not isinstance(exc, Exception): + raise + + __loop.call_soon(wrapper) + return future diff --git a/src/prefect/_internal/concurrency/primitives.py b/src/prefect/_internal/concurrency/primitives.py index 84cd63fc1aab..ac470486adb2 100644 --- a/src/prefect/_internal/concurrency/primitives.py +++ b/src/prefect/_internal/concurrency/primitives.py @@ -5,10 +5,7 @@ import concurrent.futures from typing import Generic, Optional, TypeVar -from prefect._internal.concurrency.event_loop import ( - get_running_loop, - run_in_loop_thread, -) +from prefect._internal.concurrency.event_loop import call_in_loop, get_running_loop T = TypeVar("T") @@ -33,7 +30,7 @@ def set(self) -> None: self._is_set = True if self._loop: if self._loop != get_running_loop(): - run_in_loop_thread(self._loop, self._event.set) + call_in_loop(self._loop, self._event.set) else: self._event.set() diff --git a/src/prefect/_internal/concurrency/workers.py b/src/prefect/_internal/concurrency/workers.py new file mode 100644 index 000000000000..9e98ad600f06 --- /dev/null +++ b/src/prefect/_internal/concurrency/workers.py @@ -0,0 +1,172 @@ +import asyncio +import contextvars +import dataclasses +import threading +import weakref +from queue import Queue +from typing import Callable, Dict, Optional, Set, Tuple, TypeVar, Union + +import anyio.abc +from typing_extensions import ParamSpec + +from prefect._internal.concurrency.primitives import Future + +T = TypeVar("T") +P = ParamSpec("P") + + +@dataclasses.dataclass +class _WorkItem: + """ + A representation of work sent to a worker thread. + """ + + future: Future + fn: Callable + args: Tuple + kwargs: Dict + context: contextvars.Context + + def run(self): + if not self.future.set_running_or_notify_cancel(): + return + try: + result = self.context.run(self.fn, *self.args, **self.kwargs) + except BaseException as exc: + self.future.set_exception(exc) + # Prevent reference cycle in `exc` + self = None + else: + self.future.set_result(result) + + +class _WorkerThread(threading.Thread): + def __init__( + self, + queue: "Queue[Union[_WorkItem, None]]", # Typing only supported in Python 3.9+ + idle: threading.Semaphore, + name: str = None, + ): + super().__init__(name=name) + self._queue = queue + self._idle = idle + + def run(self) -> None: + while True: + work_item = self._queue.get() + if work_item is None: + # Shutdown command received; forward to other workers and exit + self._queue.put_nowait(None) + return + + work_item.run() + self._idle.release() + + del work_item + + +class WorkerThreadPool: + def __init__(self, max_workers: int = 40) -> None: + self._queue: "Queue[Union[_WorkItem, None]]" = Queue() + self._workers: Set[_WorkerThread] = set() + self._max_workers = max_workers + self._idle = threading.Semaphore(0) + self._lock = asyncio.Lock() + self._shutdown = False + + # On garbage collection of the pool, signal shutdown to workers + weakref.finalize(self, self._queue.put_nowait, None) + + async def submit( + self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs + ) -> Future[T]: + """ + Submit a function to run in a worker thread. + + Returns a future which can be used to retrieve the result of the function. + """ + async with self._lock: + if self._shutdown: + raise RuntimeError("Work cannot be submitted to pool after shutdown.") + + future = Future() + + work_item = _WorkItem( + future=future, + fn=fn, + args=args, + kwargs=kwargs, + context=contextvars.copy_context(), + ) + + # Place the new work item on the work queue + self._queue.put_nowait(work_item) + + # Ensure there are workers available to run the work + self._adjust_worker_count() + + return future + + async def shutdown(self, task_status: Optional[anyio.abc.TaskStatus] = None): + """ + Shutdown the pool, waiting for all workers to complete before returning. + + If work is submitted before shutdown, they will run to completion. + After shutdown, new work may not be submitted. + + When called with `TaskGroup.start(...)`, the task will be reported as started + after signalling shutdown to workers. + """ + async with self._lock: + self._shutdown = True + self._queue.put_nowait(None) + + if task_status: + task_status.started() + + # Avoid blocking the event loop while waiting for threads to join by + # joining in another thread; we use a new instance of ourself to avoid + # reimplementing threaded work. + pool = WorkerThreadPool(max_workers=1) + futures = [await pool.submit(worker.join) for worker in self._workers] + await asyncio.gather(*[future.aresult() for future in futures]) + + self._workers.clear() + + def _adjust_worker_count(self): + """ + This method should called after work is added to the queue. + + If no workers are idle and the maximum worker count is not reached, add a new + worker. Otherwise, decrement the idle worker count since work as been added + to the queue and a worker will be busy. + + Note on cleanup of workers: + Workers are only removed on shutdown. Workers could be shutdown after a + period of idle. However, we expect usage in Prefect to generally be + incurred in a workflow that will not have idle workers once they are + created. As long as the maximum number of workers remains relatively small, + the overhead of idle workers should be negligable. + """ + if ( + # `acquire` returns false if the idle count is at zero; otherwise, it + # decrements the idle count and returns true + not self._idle.acquire(blocking=False) + and len(self._workers) < self._max_workers + ): + self._add_worker() + + def _add_worker(self): + worker = _WorkerThread( + queue=self._queue, + idle=self._idle, + name=f"PrefectWorker-{len(self._workers)}", + ) + self._workers.add(worker) + worker.start() + + async def __aenter__(self): + return self + + async def __aexit__(self, *_): + await self.shutdown() diff --git a/tests/_internal/concurrency/test_workers.py b/tests/_internal/concurrency/test_workers.py new file mode 100644 index 000000000000..b550e4774be8 --- /dev/null +++ b/tests/_internal/concurrency/test_workers.py @@ -0,0 +1,111 @@ +import asyncio +import time + +import anyio +import pytest + +from prefect._internal.concurrency.workers import WorkerThreadPool + + +def identity(x): + return x + + +async def test_submit(): + async with WorkerThreadPool() as pool: + future = await pool.submit(identity, 1) + assert await future.aresult() == 1 + + +async def test_submit_many(): + async with WorkerThreadPool() as pool: + futures = [await pool.submit(identity, i) for i in range(100)] + results = await asyncio.gather(*[future.aresult() for future in futures]) + assert results == list(range(100)) + assert len(pool._workers) == pool._max_workers + + +async def test_submit_reuses_idle_thread(): + async with WorkerThreadPool() as pool: + future = await pool.submit(identity, 1) + await future.aresult() + + # Spin until the worker is marked as idle + with anyio.fail_after(1): + while pool._idle._value == 0: + await anyio.sleep(0) + + future = await pool.submit(identity, 1) + await future.aresult() + assert len(pool._workers) == 1 + + +async def test_submit_after_shutdown(): + pool = WorkerThreadPool() + await pool.shutdown() + + with pytest.raises( + RuntimeError, match="Work cannot be submitted to pool after shutdown" + ): + await pool.submit(identity, 1) + + +async def test_submit_during_shutdown(): + async with WorkerThreadPool() as pool: + + async with anyio.create_task_group() as tg: + await tg.start(pool.shutdown) + + with pytest.raises( + RuntimeError, match="Work cannot be submitted to pool after shutdown" + ): + await pool.submit(identity, 1) + + +async def test_shutdown_no_workers(): + pool = WorkerThreadPool() + await pool.shutdown() + + +async def test_shutdown_multiple_times(): + pool = WorkerThreadPool() + await pool.submit(identity, 1) + await pool.shutdown() + await pool.shutdown() + + +async def test_shutdown_with_idle_workers(): + pool = WorkerThreadPool() + futures = [await pool.submit(identity, 1) for _ in range(5)] + await asyncio.gather(*[future.aresult() for future in futures]) + await pool.shutdown() + + +async def test_shutdown_with_active_worker(): + pool = WorkerThreadPool() + future = await pool.submit(time.sleep, 1) + await pool.shutdown() + assert await future.aresult() is None + + +async def test_shutdown_exception_during_join(): + pool = WorkerThreadPool() + future = await pool.submit(identity, 1) + await future.aresult() + + try: + async with anyio.create_task_group() as tg: + await tg.start(pool.shutdown) + raise ValueError() + except ValueError: + pass + + assert pool._shutdown is True + + +async def test_context_manager_with_outstanding_future(): + async with WorkerThreadPool() as pool: + future = await pool.submit(identity, 1) + + assert pool._shutdown is True + assert await future.aresult() == 1