Add `WorkerThreadPool` for running synchronous work in threads (#7875)
Showing 4 changed files with 316 additions and 7 deletions.
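The pool exposes a small async API: `submit(fn, *args, **kwargs)` returns a `Future` whose result can be awaited with `aresult()`, and the pool itself is an async context manager that shuts down on exit. A minimal usage sketch based on that API (the `blocking_work` function and `main` coroutine are illustrative, not part of the commit):

```python
import asyncio
import time

from prefect._internal.concurrency.workers import WorkerThreadPool


def blocking_work(n: int) -> int:
    # Stand-in for synchronous work that would otherwise block the event loop
    time.sleep(0.1)
    return n * 2


async def main() -> None:
    async with WorkerThreadPool(max_workers=4) as pool:
        # Each submit runs the function in a worker thread and returns a future
        futures = [await pool.submit(blocking_work, i) for i in range(10)]

        # Await the results without blocking the event loop
        results = await asyncio.gather(*[future.aresult() for future in futures])
        print(results)


if __name__ == "__main__":
    asyncio.run(main())
```

On exit from the `async with` block the pool shuts down, letting any already-submitted work run to completion before the worker threads are joined.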
`prefect._internal.concurrency.workers` (new file)
@@ -0,0 +1,172 @@
```python
import asyncio
import contextvars
import dataclasses
import threading
import weakref
from queue import Queue
from typing import Callable, Dict, Optional, Set, Tuple, TypeVar, Union

import anyio.abc
from typing_extensions import ParamSpec

from prefect._internal.concurrency.primitives import Future

T = TypeVar("T")
P = ParamSpec("P")


@dataclasses.dataclass
class _WorkItem:
    """
    A representation of work sent to a worker thread.
    """

    future: Future
    fn: Callable
    args: Tuple
    kwargs: Dict
    context: contextvars.Context

    def run(self):
        if not self.future.set_running_or_notify_cancel():
            return
        try:
            result = self.context.run(self.fn, *self.args, **self.kwargs)
        except BaseException as exc:
            self.future.set_exception(exc)
            # Prevent reference cycle in `exc`
            self = None
        else:
            self.future.set_result(result)


class _WorkerThread(threading.Thread):
    def __init__(
        self,
        queue: "Queue[Union[_WorkItem, None]]",  # Typing only supported in Python 3.9+
        idle: threading.Semaphore,
        name: str = None,
    ):
        super().__init__(name=name)
        self._queue = queue
        self._idle = idle

    def run(self) -> None:
        while True:
            work_item = self._queue.get()
            if work_item is None:
                # Shutdown command received; forward to other workers and exit
                self._queue.put_nowait(None)
                return

            work_item.run()
            self._idle.release()

            del work_item


class WorkerThreadPool:
    def __init__(self, max_workers: int = 40) -> None:
        self._queue: "Queue[Union[_WorkItem, None]]" = Queue()
        self._workers: Set[_WorkerThread] = set()
        self._max_workers = max_workers
        self._idle = threading.Semaphore(0)
        self._lock = asyncio.Lock()
        self._shutdown = False

        # On garbage collection of the pool, signal shutdown to workers
        weakref.finalize(self, self._queue.put_nowait, None)

    async def submit(
        self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs
    ) -> Future[T]:
        """
        Submit a function to run in a worker thread.

        Returns a future which can be used to retrieve the result of the function.
        """
        async with self._lock:
            if self._shutdown:
                raise RuntimeError("Work cannot be submitted to pool after shutdown.")

            future = Future()

            work_item = _WorkItem(
                future=future,
                fn=fn,
                args=args,
                kwargs=kwargs,
                context=contextvars.copy_context(),
            )

            # Place the new work item on the work queue
            self._queue.put_nowait(work_item)

            # Ensure there are workers available to run the work
            self._adjust_worker_count()

            return future

    async def shutdown(self, task_status: Optional[anyio.abc.TaskStatus] = None):
        """
        Shut down the pool, waiting for all workers to complete before returning.

        If work is submitted before shutdown, it will run to completion.
        After shutdown, new work may not be submitted.

        When called with `TaskGroup.start(...)`, the task will be reported as started
        after signalling shutdown to workers.
        """
        async with self._lock:
            self._shutdown = True
            self._queue.put_nowait(None)

            if task_status:
                task_status.started()

            # Avoid blocking the event loop while waiting for threads to join by
            # joining in another thread; we use a new instance of ourself to avoid
            # reimplementing threaded work.
            pool = WorkerThreadPool(max_workers=1)
            futures = [await pool.submit(worker.join) for worker in self._workers]
            await asyncio.gather(*[future.aresult() for future in futures])

            self._workers.clear()

    def _adjust_worker_count(self):
        """
        This method should be called after work is added to the queue.

        If no workers are idle and the maximum worker count is not reached, add a new
        worker. Otherwise, decrement the idle worker count since work has been added
        to the queue and a worker will be busy.

        Note on cleanup of workers:
            Workers are only removed on shutdown. Workers could be shut down after a
            period of idleness. However, we expect usage in Prefect to generally be
            incurred in a workflow that will not have idle workers once they are
            created. As long as the maximum number of workers remains relatively
            small, the overhead of idle workers should be negligible.
        """
        if (
            # `acquire` returns false if the idle count is at zero; otherwise, it
            # decrements the idle count and returns true
            not self._idle.acquire(blocking=False)
            and len(self._workers) < self._max_workers
        ):
            self._add_worker()

    def _add_worker(self):
        worker = _WorkerThread(
            queue=self._queue,
            idle=self._idle,
            name=f"PrefectWorker-{len(self._workers)}",
        )
        self._workers.add(worker)
        worker.start()

    async def __aenter__(self):
        return self

    async def __aexit__(self, *_):
        await self.shutdown()
```
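Two details of `shutdown` are worth noting. First, the worker threads are joined via a temporary single-worker pool so the event loop is never blocked on `Thread.join`. Second, the coroutine accepts anyio's `task_status`, so it can be launched with `TaskGroup.start(...)` and report itself as started once shutdown has been signalled, as the tests below exercise. A sketch of that pattern (the wrapper coroutine is illustrative, not part of the commit):

```python
import anyio

from prefect._internal.concurrency.workers import WorkerThreadPool


async def shutdown_in_background(pool: WorkerThreadPool) -> None:
    async with anyio.create_task_group() as tg:
        # `start` returns once shutdown has been signalled to the workers;
        # the task group then waits for the worker threads to be joined
        await tg.start(pool.shutdown)

        # New submissions are rejected from this point on
        assert pool._shutdown is True
```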
Tests for `WorkerThreadPool` (new file)
@@ -0,0 +1,111 @@
```python
import asyncio
import time

import anyio
import pytest

from prefect._internal.concurrency.workers import WorkerThreadPool


def identity(x):
    return x


async def test_submit():
    async with WorkerThreadPool() as pool:
        future = await pool.submit(identity, 1)
        assert await future.aresult() == 1


async def test_submit_many():
    async with WorkerThreadPool() as pool:
        futures = [await pool.submit(identity, i) for i in range(100)]
        results = await asyncio.gather(*[future.aresult() for future in futures])
        assert results == list(range(100))
        assert len(pool._workers) == pool._max_workers


async def test_submit_reuses_idle_thread():
    async with WorkerThreadPool() as pool:
        future = await pool.submit(identity, 1)
        await future.aresult()

        # Spin until the worker is marked as idle
        with anyio.fail_after(1):
            while pool._idle._value == 0:
                await anyio.sleep(0)

        future = await pool.submit(identity, 1)
        await future.aresult()
        assert len(pool._workers) == 1


async def test_submit_after_shutdown():
    pool = WorkerThreadPool()
    await pool.shutdown()

    with pytest.raises(
        RuntimeError, match="Work cannot be submitted to pool after shutdown"
    ):
        await pool.submit(identity, 1)


async def test_submit_during_shutdown():
    async with WorkerThreadPool() as pool:

        async with anyio.create_task_group() as tg:
            await tg.start(pool.shutdown)

            with pytest.raises(
                RuntimeError, match="Work cannot be submitted to pool after shutdown"
            ):
                await pool.submit(identity, 1)


async def test_shutdown_no_workers():
    pool = WorkerThreadPool()
    await pool.shutdown()


async def test_shutdown_multiple_times():
    pool = WorkerThreadPool()
    await pool.submit(identity, 1)
    await pool.shutdown()
    await pool.shutdown()


async def test_shutdown_with_idle_workers():
    pool = WorkerThreadPool()
    futures = [await pool.submit(identity, 1) for _ in range(5)]
    await asyncio.gather(*[future.aresult() for future in futures])
    await pool.shutdown()


async def test_shutdown_with_active_worker():
    pool = WorkerThreadPool()
    future = await pool.submit(time.sleep, 1)
    await pool.shutdown()
    assert await future.aresult() is None


async def test_shutdown_exception_during_join():
    pool = WorkerThreadPool()
    future = await pool.submit(identity, 1)
    await future.aresult()

    try:
        async with anyio.create_task_group() as tg:
            await tg.start(pool.shutdown)
            raise ValueError()
    except ValueError:
        pass

    assert pool._shutdown is True


async def test_context_manager_with_outstanding_future():
    async with WorkerThreadPool() as pool:
        future = await pool.submit(identity, 1)

    assert pool._shutdown is True
    assert await future.aresult() == 1
```