Skip to content

Commit

Permalink
Revert PR "async runner" (#1352)
Browse files (browse the repository at this point in the history)
Revert "review changes to tests and server"
Revert "delete remaining runner thread code :)"
Revert "make tests async and fix them"
Revert "have runner return asyncio.Task instead of AsyncFuture"

This reverts commit b002d54.
This reverts commit 087f482.
This reverts commit 6729d53.
This reverts commit dc5ef44.

Signed-off-by: technillogue <technillogue@gmail.com>
  • Loading branch information
technillogue committed Jan 16, 2024
1 parent b37e961 commit 015e16f
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 89 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ optional-dependencies = { "dev" = [
"pillow",
"pyright==1.1.345",
"pytest",
"pytest-asyncio",
"pytest-httpserver",
"pytest-rerunfailures",
"pytest-xdist",
Expand Down
20 changes: 9 additions & 11 deletions python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ async def root() -> Any:

@app.get("/health-check")
async def healthcheck() -> Any:
await _check_setup_task()
_check_setup_result()
if app.state.health == Health.READY:
health = Health.BUSY if runner.is_busy() else Health.READY
else:
Expand All @@ -236,7 +236,7 @@ async def predict(request: PredictionRequest = Body(default=None), prefer: Union
# TODO: spec-compliant parsing of Prefer header.
respond_async = prefer == "respond-async"

return await _predict(request=request, respond_async=respond_async)
return _predict(request=request, respond_async=respond_async)

@limited
@app.put(
Expand Down Expand Up @@ -271,10 +271,10 @@ async def predict_idempotent(
# TODO: spec-compliant parsing of Prefer header.
respond_async = prefer == "respond-async"

return await _predict(request=request, respond_async=respond_async)
return _predict(request=request, respond_async=respond_async)

async def _predict(
*, request: Optional[PredictionRequest], respond_async: bool = False
def _predict(
*, request: PredictionRequest, respond_async: bool = False
) -> Response:
# [compat] If no body is supplied, assume that this model can be run
# with empty input. This will throw a ValidationError if that's not
Expand Down Expand Up @@ -302,8 +302,7 @@ async def _predict(
return JSONResponse(jsonable_encoder(initial_response), status_code=202)

try:
prediction = await async_result
response = PredictionResponse(**prediction.dict())
response = PredictionResponse(**async_result.get().dict())
except ValidationError as e:
_log_invalid_output(e)
raise HTTPException(status_code=500, detail=str(e)) from e
Expand Down Expand Up @@ -339,15 +338,14 @@ async def start_shutdown() -> Any:
shutdown_event.set()
return JSONResponse({}, status_code=200)

async def _check_setup_task() -> Any:
def _check_setup_result() -> Any:
if app.state.setup_task is None:
return

if not app.state.setup_task.done():
if not app.state.setup_task.ready():
return

# this can raise CancelledError
result = app.state.setup_task.result()
result = app.state.setup_task.get()

if result.status == schema.Status.SUCCEEDED:
app.state.health = Health.READY
Expand Down
95 changes: 55 additions & 40 deletions python/cog/server/runner.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
import io
import sys
import threading
import traceback
import typing # TypeAlias, py3.10
from datetime import datetime, timezone
from multiprocessing.pool import AsyncResult, ThreadPool
from typing import Any, Callable, Optional, Tuple, Union, cast

import requests
Expand Down Expand Up @@ -44,8 +45,11 @@ class SetupResult:
status: schema.Status


PredictionTask: "typing.TypeAlias" = "asyncio.Task[schema.PredictionResponse]"
SetupTask: "typing.TypeAlias" = "asyncio.Task[SetupResult]"
PredictionTask: "typing.TypeAlias" = "AsyncResult[schema.PredictionResponse]"
SetupTask: "typing.TypeAlias" = "AsyncResult[SetupResult]"
if sys.version_info < (3, 9):
PredictionTask = AsyncResult
SetupTask = AsyncResult
RunnerTask: "typing.TypeAlias" = Union[PredictionTask, SetupTask]


Expand All @@ -57,37 +61,38 @@ def __init__(
shutdown_event: Optional[threading.Event],
upload_url: Optional[str] = None,
) -> None:
self._thread = None
self._threadpool = ThreadPool(processes=1)

self._response: Optional[schema.PredictionResponse] = None
self._result: Optional[RunnerTask] = None

self._worker = Worker(predictor_ref=predictor_ref)
self._should_cancel = asyncio.Event()
self._should_cancel = threading.Event()

self._shutdown_event = shutdown_event
self._upload_url = upload_url

def make_error_handler(self, activity: str) -> Callable[[RunnerTask], None]:
def handle_error(task: RunnerTask) -> None:
exc = task.exception()
if not exc:
return
def setup(self) -> SetupTask:
if self.is_busy():
raise RunnerBusyError()

def handle_error(error: BaseException) -> None:
# Re-raise the exception in order to more easily capture exc_info,
# and then trigger shutdown, as we have no easy way to resume
# worker state if an exception was thrown.
try:
raise exc
raise error
except Exception:
log.error(f"caught exception while running {activity}", exc_info=True)
log.error("caught exception while running setup", exc_info=True)
if self._shutdown_event is not None:
self._shutdown_event.set()

return handle_error

def setup(self) -> SetupTask:
if self.is_busy():
raise RunnerBusyError()
self._result = asyncio.create_task(setup(worker=self._worker))
self._result.add_done_callback(self.make_error_handler("setup"))
self._result = self._threadpool.apply_async(
func=setup,
kwds={"worker": self._worker},
error_callback=handle_error,
)
return self._result

# TODO: Make the return type AsyncResult[schema.PredictionResponse] when we
Expand Down Expand Up @@ -116,39 +121,52 @@ def predict(
upload_url = self._upload_url if upload else None
event_handler = create_event_handler(prediction, upload_url=upload_url)

def handle_cleanup(_: PredictionTask) -> None:
def cleanup(_: schema.PredictionResponse = None) -> None:
input = cast(Any, prediction.input)
if hasattr(input, "cleanup"):
input.cleanup()

def handle_error(error: BaseException) -> None:
# Re-raise the exception in order to more easily capture exc_info,
# and then trigger shutdown, as we have no easy way to resume
# worker state if an exception was thrown.
try:
raise error
except Exception:
log.error("caught exception while running prediction", exc_info=True)
if self._shutdown_event is not None:
self._shutdown_event.set()

self._response = event_handler.response
coro = predict(
worker=self._worker,
request=prediction,
event_handler=event_handler,
should_cancel=self._should_cancel,
self._result = self._threadpool.apply_async(
func=predict,
kwds={
"worker": self._worker,
"request": prediction,
"event_handler": event_handler,
"should_cancel": self._should_cancel,
},
callback=cleanup,
error_callback=handle_error,
)
self._result = asyncio.create_task(coro)
self._result.add_done_callback(handle_cleanup)
self._result.add_done_callback(self.make_error_handler("prediction"))

return (self._response, self._result)

def is_busy(self) -> bool:
if self._result is None:
return False

if not self._result.done():
if not self._result.ready():
return True

self._response = None
self._result = None
return False

def shutdown(self) -> None:
if self._result:
self._result.cancel()
self._worker.terminate()
self._threadpool.terminate()
self._threadpool.join()

def cancel(self, prediction_id: Optional[str] = None) -> None:
if not self.is_busy():
Expand Down Expand Up @@ -290,15 +308,13 @@ def _upload_files(self, output: Any) -> Any:
raise FileUploadError("Got error trying to upload output files") from error


async def setup(*, worker: Worker) -> SetupResult:
def setup(*, worker: Worker) -> SetupResult:
logs = []
status = None
started_at = datetime.now(tz=timezone.utc)

try:
# will be async
for event in worker.setup():
await asyncio.sleep(0)
if isinstance(event, Log):
logs.append(event.message)
elif isinstance(event, Done):
Expand Down Expand Up @@ -328,19 +344,19 @@ async def setup(*, worker: Worker) -> SetupResult:
)


async def predict(
def predict(
*,
worker: Worker,
request: schema.PredictionRequest,
event_handler: PredictionEventHandler,
should_cancel: asyncio.Event,
should_cancel: threading.Event,
) -> schema.PredictionResponse:
# Set up logger context within prediction thread.
structlog.contextvars.clear_contextvars()
structlog.contextvars.bind_contextvars(prediction_id=request.id)

try:
return await _predict(
return _predict(
worker=worker,
request=request,
event_handler=event_handler,
Expand All @@ -353,12 +369,12 @@ async def predict(
raise


async def _predict(
def _predict(
*,
worker: Worker,
request: schema.PredictionRequest,
event_handler: PredictionEventHandler,
should_cancel: asyncio.Event,
should_cancel: threading.Event,
) -> schema.PredictionResponse:
initial_prediction = request.dict()

Expand All @@ -375,9 +391,8 @@ async def _predict(
event_handler.failed(error=str(e))
log.warn("failed to download url path from input", exc_info=True)
return event_handler.response
# will be async

for event in worker.predict(input_dict, poll=0.1):
await asyncio.sleep(0)
if should_cancel.is_set():
worker.cancel()
should_cancel.clear()
Expand Down
3 changes: 1 addition & 2 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,7 @@ def _wait(
if send_heartbeats:
yield Heartbeat()
continue
# this needs aioprocessing.Pipe or similar
# multiprocessing.Pipe is not async

ev = self._events.recv()
yield ev

Expand Down
Loading

0 comments on commit 015e16f

Please sign in to comment.