Skip to content

Commit

Permalink
async runner (#1352)
Browse files Browse the repository at this point in the history
* have runner return asyncio.Task instead of AsyncFuture
* make tests async and fix them
* delete remaining runner thread code :)
* review changes to tests and server

(reverts commit 828eee9)

Signed-off-by: technillogue <technillogue@gmail.com>
  • Loading branch information
technillogue committed Jul 25, 2024
1 parent eae972d commit ddc2339
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 89 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ optional-dependencies = { "dev" = [
"pillow",
"pyright==1.1.347",
"pytest",
"pytest-asyncio",
"pytest-httpserver",
"pytest-rerunfailures",
"pytest-xdist",
Expand Down
21 changes: 11 additions & 10 deletions python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ async def root() -> Any:

@app.get("/health-check")
async def healthcheck() -> Any:
_check_setup_result()
await _check_setup_task()
if app.state.health == Health.READY:
health = Health.BUSY if runner.is_busy() else Health.READY
else:
Expand Down Expand Up @@ -291,7 +291,7 @@ async def predict(
with trace_context(make_trace_context(traceparent, tracestate)):
return _predict(
request=request,
respond_async=respond_async,
respond_async=respond_async
)

@limited
Expand Down Expand Up @@ -335,10 +335,9 @@ async def predict_idempotent(
respond_async=respond_async,
)

def _predict(
*,
request: Optional[PredictionRequest],
respond_async: bool = False,

async def _predict(
*, request: Optional[PredictionRequest], respond_async: bool = False
) -> Response:
# [compat] If no body is supplied, assume that this model can be run
# with empty input. This will throw a ValidationError if that's not
Expand Down Expand Up @@ -367,7 +366,8 @@ def _predict(
return JSONResponse(jsonable_encoder(initial_response), status_code=202)

try:
response = PredictionResponse(**async_result.get().dict())
prediction = await async_result
response = PredictionResponse(**prediction.dict())
except ValidationError as e:
_log_invalid_output(e)
raise HTTPException(status_code=500, detail=str(e)) from e
Expand Down Expand Up @@ -396,14 +396,15 @@ async def cancel(prediction_id: str = Path(..., title="Prediction ID")) -> Any:
else:
return JSONResponse({}, status_code=200)

def _check_setup_result() -> Any:
async def _check_setup_task() -> Any:
if app.state.setup_task is None:
return

if not app.state.setup_task.ready():
if not app.state.setup_task.done():
return

result = app.state.setup_task.get()
# this can raise CancelledError
result = app.state.setup_task.result()

if result.status == schema.Status.SUCCEEDED:
app.state.health = Health.READY
Expand Down
94 changes: 40 additions & 54 deletions python/cog/server/runner.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import asyncio
import io
import sys
import threading
import traceback
import typing # TypeAlias, py3.10
from datetime import datetime, timezone
from multiprocessing.pool import AsyncResult, ThreadPool
from typing import Any, Callable, Optional, Tuple, Union, cast

import requests
Expand Down Expand Up @@ -46,11 +45,8 @@ class SetupResult:
status: schema.Status


PredictionTask: "typing.TypeAlias" = "AsyncResult[schema.PredictionResponse]"
SetupTask: "typing.TypeAlias" = "AsyncResult[SetupResult]"
if sys.version_info < (3, 9):
PredictionTask = AsyncResult
SetupTask = AsyncResult
PredictionTask: "typing.TypeAlias" = "asyncio.Task[schema.PredictionResponse]"
SetupTask: "typing.TypeAlias" = "asyncio.Task[SetupResult]"
RunnerTask: "typing.TypeAlias" = Union[PredictionTask, SetupTask]


Expand All @@ -62,38 +58,37 @@ def __init__(
shutdown_event: Optional[threading.Event],
upload_url: Optional[str] = None,
) -> None:
self._thread = None
self._threadpool = ThreadPool(processes=1)

self._response: Optional[schema.PredictionResponse] = None
self._result: Optional[RunnerTask] = None

self._worker = Worker(predictor_ref=predictor_ref)
self._should_cancel = threading.Event()
self._should_cancel = asyncio.Event()

self._shutdown_event = shutdown_event
self._upload_url = upload_url

def setup(self) -> SetupTask:
if self.is_busy():
raise RunnerBusyError()

def handle_error(error: BaseException) -> None:
def make_error_handler(self, activity: str) -> Callable[[RunnerTask], None]:
def handle_error(task: RunnerTask) -> None:
exc = task.exception()
if not exc:
return
# Re-raise the exception in order to more easily capture exc_info,
# and then trigger shutdown, as we have no easy way to resume
# worker state if an exception was thrown.
try:
raise error
raise exc
except Exception:
log.error("caught exception while running setup", exc_info=True)
log.error(f"caught exception while running {activity}", exc_info=True)
if self._shutdown_event is not None:
self._shutdown_event.set()

self._result = self._threadpool.apply_async(
func=setup,
kwds={"worker": self._worker},
error_callback=handle_error,
)
return handle_error

def setup(self) -> SetupTask:
if self.is_busy():
raise RunnerBusyError()
self._result = asyncio.create_task(setup(worker=self._worker))
self._result.add_done_callback(self.make_error_handler("setup"))
return self._result

# TODO: Make the return type AsyncResult[schema.PredictionResponse] when we
Expand Down Expand Up @@ -127,52 +122,39 @@ def predict(
upload_url=upload_url,
)

def cleanup(_: Optional[schema.PredictionResponse] = None) -> None:
def handle_cleanup(_: Optional[schema.PredictionResponse] = None) -> None:
input = cast(Any, prediction.input)
if hasattr(input, "cleanup"):
input.cleanup()

def handle_error(error: BaseException) -> None:
# Re-raise the exception in order to more easily capture exc_info,
# and then trigger shutdown, as we have no easy way to resume
# worker state if an exception was thrown.
try:
raise error
except Exception:
log.error("caught exception while running prediction", exc_info=True)
if self._shutdown_event is not None:
self._shutdown_event.set()

self._response = event_handler.response
self._result = self._threadpool.apply_async(
func=predict,
kwds={
"worker": self._worker,
"request": prediction,
"event_handler": event_handler,
"should_cancel": self._should_cancel,
},
callback=cleanup,
error_callback=handle_error,
coro = predict(
worker=self._worker,
request=prediction,
event_handler=event_handler,
should_cancel=self._should_cancel,
)
self._result = asyncio.create_task(coro)
self._result.add_done_callback(handle_cleanup)
self._result.add_done_callback(self.make_error_handler("prediction"))

return (self._response, self._result)

def is_busy(self) -> bool:
if self._result is None:
return False

if not self._result.ready():
if not self._result.done():
return True

self._response = None
self._result = None
return False

def shutdown(self) -> None:
if self._result:
self._result.cancel()
self._worker.terminate()
self._threadpool.terminate()
self._threadpool.join()

def cancel(self, prediction_id: Optional[str] = None) -> None:
if not self.is_busy():
Expand Down Expand Up @@ -318,13 +300,15 @@ def _upload_files(self, output: Any) -> Any:
raise FileUploadError("Got error trying to upload output files") from error


def setup(*, worker: Worker) -> SetupResult:
async def setup(*, worker: Worker) -> SetupResult:
logs = []
status = None
started_at = datetime.now(tz=timezone.utc)

try:
# will be async
for event in worker.setup():
await asyncio.sleep(0)
if isinstance(event, Log):
logs.append(event.message)
elif isinstance(event, Done):
Expand Down Expand Up @@ -354,19 +338,19 @@ def setup(*, worker: Worker) -> SetupResult:
)


def predict(
async def predict(
*,
worker: Worker,
request: schema.PredictionRequest,
event_handler: PredictionEventHandler,
should_cancel: threading.Event,
should_cancel: asyncio.Event,
) -> schema.PredictionResponse:
# Set up logger context within prediction thread.
structlog.contextvars.clear_contextvars()
structlog.contextvars.bind_contextvars(prediction_id=request.id)

try:
return _predict(
return await _predict(
worker=worker,
request=request,
event_handler=event_handler,
Expand All @@ -379,12 +363,12 @@ def predict(
raise


def _predict(
async def _predict(
*,
worker: Worker,
request: schema.PredictionRequest,
event_handler: PredictionEventHandler,
should_cancel: threading.Event,
should_cancel: asyncio.Event,
) -> schema.PredictionResponse:
initial_prediction = request.dict()

Expand All @@ -408,7 +392,9 @@ def _predict(
log.warn("Failed to download url path from input", exc_info=True)
return event_handler.response

# will be async
for event in worker.predict(input_dict, poll=0.1):
await asyncio.sleep(0)
if should_cancel.is_set():
worker.cancel()
should_cancel.clear()
Expand Down
3 changes: 2 additions & 1 deletion python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ def _wait(
if send_heartbeats:
yield Heartbeat()
continue

# this needs aioprocessing.Pipe or similar
# multiprocessing.Pipe is not async
ev = self._events.recv()
yield ev

Expand Down
Loading

0 comments on commit ddc2339

Please sign in to comment.