diff --git a/pkg/config/config.go b/pkg/config/config.go
index 76af8a5a7..49e0f698f 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -57,16 +57,21 @@ type Build struct {
 	pythonRequirementsContent []string
 }
 
+type Concurrency struct {
+	Max int `json:"max,omitempty" yaml:"max"`
+}
+
 type Example struct {
 	Input  map[string]string `json:"input" yaml:"input"`
 	Output string            `json:"output" yaml:"output"`
 }
 
 type Config struct {
-	Build   *Build `json:"build" yaml:"build"`
-	Image   string `json:"image,omitempty" yaml:"image"`
-	Predict string `json:"predict,omitempty" yaml:"predict"`
-	Train   string `json:"train,omitempty" yaml:"train"`
+	Build       *Build       `json:"build" yaml:"build"`
+	Image       string       `json:"image,omitempty" yaml:"image"`
+	Predict     string       `json:"predict,omitempty" yaml:"predict"`
+	Train       string       `json:"train,omitempty" yaml:"train"`
+	Concurrency *Concurrency `json:"concurrency,omitempty" yaml:"concurrency"`
 }
 
 func DefaultConfig() *Config {
diff --git a/pkg/config/data/config_schema_v1.0.json b/pkg/config/data/config_schema_v1.0.json
index 958fd6889..48ae781a3 100644
--- a/pkg/config/data/config_schema_v1.0.json
+++ b/pkg/config/data/config_schema_v1.0.json
@@ -154,11 +154,6 @@
           "$id": "#/properties/concurrency/properties/max",
           "type": "integer",
           "description": "The maximum number of concurrent predictions."
-        },
-        "default_target": {
-          "$id": "#/properties/concurrency/properties/default_target",
-          "type": "integer",
-          "description": "The default target for number of concurrent predictions. This setting can be used by an autoscaler to determine when to scale a deployment of a model up or down."
         }
       }
     }
diff --git a/python/cog/command/ast_openapi_schema.py b/python/cog/command/ast_openapi_schema.py
index 89322debb..c491c031e 100644
--- a/python/cog/command/ast_openapi_schema.py
+++ b/python/cog/command/ast_openapi_schema.py
@@ -147,6 +147,24 @@
             "summary": "Healthcheck"
         }
     },
+    "/ready": {
+        "get": {
+            "summary": "Ready",
+            "operationId": "ready_ready_get",
+            "responses": {
+                "200": {
+                    "description": "Successful Response",
+                    "content": {
+                        "application/json": {
+                            "schema": {
+                                "title": "Response Ready Ready Get"
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    },
     "/predictions": {
         "post": {
             "description": "Run a single prediction on the model",
diff --git a/python/cog/predictor.py b/python/cog/predictor.py
index a0fa937bf..b31d7ff17 100644
--- a/python/cog/predictor.py
+++ b/python/cog/predictor.py
@@ -71,6 +71,13 @@ def predict(self, **kwargs: Any) -> Any:
         """
         pass
 
+    def log(self, *messages: str) -> None:
+        """
+        Write a log message that will be tagged with the current prediction
+        even during concurrent predictions. At runtime this method is overridden.
+        """
+        print(*messages)
+
 
 def run_setup(predictor: BasePredictor) -> None:
     weights = get_weights_argument(predictor)
diff --git a/python/cog/server/eventtypes.py b/python/cog/server/eventtypes.py
index a1df38fff..ee3e68997 100644
--- a/python/cog/server/eventtypes.py
+++ b/python/cog/server/eventtypes.py
@@ -20,6 +20,11 @@ def from_request(cls, request: schema.PredictionRequest) -> "PredictionInput":
         return cls(payload=payload, id=request.id)
 
 
+@define
+class Cancel:
+    id: str
+
+
 @define
 class Shutdown:
     pass
diff --git a/python/cog/server/http.py b/python/cog/server/http.py
index 062d8d2a1..9a8ed4b26 100644
--- a/python/cog/server/http.py
+++ b/python/cog/server/http.py
@@ -26,11 +26,10 @@
 import attrs
 import structlog
 import uvicorn
-from fastapi import Body, FastAPI, Header, HTTPException, Path, Response
+from fastapi import Body, FastAPI, Header, Path, Response
 from fastapi.encoders import jsonable_encoder
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse
-from pydantic import ValidationError
 from pydantic.error_wrappers import ErrorWrapper
 
 from .. import schema
@@ -133,10 +132,13 @@ async def start_shutdown() -> Any:
             add_setup_failed_routes(app, started_at, msg)
             return app
 
+    concurrency = config.get("concurrency", {}).get("max", "1")
+
     runner = PredictionRunner(
         predictor_ref=predictor_ref,
         shutdown_event=shutdown_event,
         upload_url=upload_url,
+        concurrency=int(concurrency),
     )
 
     class PredictionRequest(schema.PredictionRequest.with_types(input_type=InputType)):
@@ -261,7 +263,22 @@ async def healthcheck() -> Any:
         else:
             health = app.state.health
         setup = attrs.asdict(app.state.setup_result) if app.state.setup_result else {}
-        return jsonable_encoder({"status": health.name, "setup": setup})
+        activity = runner.activity_info()
+        return jsonable_encoder(
+            {"status": health.name, "setup": setup, "concurrency": activity}
+        )
+
+    # this is a readiness probe, it only returns 200 when work can be accepted
+    @app.get("/ready")
+    async def ready() -> Any:
+        activity = runner.activity_info()
+        if runner.is_busy():
+            return JSONResponse(
+                {"status": "not ready", "activity": activity}, status_code=503
+            )
+        return JSONResponse(
+            {"status": "ready", "activity": activity}, status_code=200
+        )
 
     @limited
     @app.post(
@@ -348,27 +365,23 @@ async def shared_predict(
         if respond_async:
             return JSONResponse(jsonable_encoder(initial_response), status_code=202)
 
-        # by now, output Path and File are already converted to str
-        # so when we validate the schema, those urls get cast back to Path and File
-        # in the previous implementation those would then get encoded as strings
-        # however the changes to Path and File break this and return the filename instead
-        try:
-            prediction = await async_result
-            # we're only doing this to catch validation errors
-            response = PredictionResponse(**prediction.dict())
-            del response
-        except ValidationError as e:
-            _log_invalid_output(e)
-            raise HTTPException(status_code=500, detail=str(e)) from e
-
-        # dict_resp = response.dict()
-        # output = await runner.client_manager.upload_files(
-        #     dict_resp["output"], upload_url
-        # )
-        # dict_resp["output"] = output
-        # encoded_response = jsonable_encoder(dict_resp)
-
-        # return *prediction* and not *response* to preserve urls
+        # # by now, output Path and File are already converted to str
+        # # so when we validate the schema, those urls get cast back to Path and File
+        # # in the previous implementation those would then get encoded as strings
+        # # however the changes to Path and File break this and return the filename instead
+        #
+        # # moreover, validating outputs can be a bottleneck with enough volume
+        # # since it's not strictly needed, we can comment it out
+        # try:
+        #     prediction = await async_result
+        #     # we're only doing this to catch validation errors
+        #     response = PredictionResponse(**prediction.dict())
+        #     del response
+        # except ValidationError as e:
+        #     _log_invalid_output(e)
+        #     raise HTTPException(status_code=500, detail=str(e)) from e
+
+        prediction = await async_result
         encoded_response = jsonable_encoder(prediction.dict())
         return JSONResponse(content=encoded_response)
 
@@ -377,8 +390,7 @@ async def cancel(prediction_id: str = Path(..., title="Prediction ID")) -> Any:
         """
         Cancel a running prediction
         """
-        if not runner.is_busy():
-            return JSONResponse({}, status_code=404)
+        # no need to check whether or not we're busy
         try:
             runner.cancel(prediction_id)
         except UnknownPredictionError:
diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py
index 07ff48762..ada6c10dc 100644
--- a/python/cog/server/runner.py
+++ b/python/cog/server/runner.py
@@ -1,10 +1,16 @@
 import asyncio
+import contextlib
+import logging
 import multiprocessing
+import os
+import signal
+import sys
 import threading
 import traceback
 import typing  # TypeAlias, py3.10
 from datetime import datetime, timezone
-from typing import Any, AsyncIterator, Awaitable, Callable, Optional, Tuple, Union, cast
+from enum import Enum, auto, unique
+from typing import Any, AsyncIterator, Awaitable, Iterator, Optional, Union
 
 import httpx
 import structlog
@@ -12,7 +18,9 @@
 
 from .. import schema, types
 from .clients import SKIP_START_EVENT, ClientManager
+from .connection import AsyncConnection
 from .eventtypes import (
+    Cancel,
     Done,
     Heartbeat,
     Log,
@@ -20,9 +28,14 @@
     PredictionOutput,
     PredictionOutputType,
     PublicEventType,
+    Shutdown,
+)
+from .exceptions import (
+    FatalWorkerException,
+    InvalidStateException,
 )
 from .probes import ProbeHelper
-from .worker import Worker
+from .worker import Mux, _ChildWorker
 
 log = structlog.get_logger("cog.server.runner")
 _spawn = multiprocessing.get_context("spawn")
@@ -40,6 +53,16 @@ class UnknownPredictionError(Exception):
     pass
 
 
+@unique
+class WorkerState(Enum):
+    NEW = auto()
+    STARTING = auto()
+    IDLE = auto()
+    PROCESSING = auto()
+    BUSY = auto()
+    DEFUNCT = auto()
+
+
 @define
 class SetupResult:
     started_at: datetime
@@ -47,12 +70,20 @@ class SetupResult:
     logs: str
     status: schema.Status
 
+    # TODO: maybe collect events into a result here
+
 
 PredictionTask: "typing.TypeAlias" = "asyncio.Task[schema.PredictionResponse]"
 SetupTask: "typing.TypeAlias" = "asyncio.Task[SetupResult]"
 RunnerTask: "typing.TypeAlias" = Union[PredictionTask, SetupTask]
 
 
+# TODO: we might prefer to move this back to worker
+# runner would still need to do PredictionEventHandler
+# if it's not inline, we would need to make sure {enter,exit}_predict is handled correctly
+# this is a major outstanding piece of work for merging into main
+
+
 class PredictionRunner:
     def __init__(
         self,
@@ -60,21 +91,103 @@ def __init__(
         predictor_ref: str,
         shutdown_event: Optional[threading.Event],
         upload_url: Optional[str] = None,
+        concurrency: int = 1,
+        tee_output: bool = True,
     ) -> None:
-        self._response: Optional[schema.PredictionResponse] = None
-        self._result: Optional[RunnerTask] = None
+        self._shutdown_event = shutdown_event  # __main__ waits for this event
 
-        self._worker = Worker(predictor_ref=predictor_ref)
-
-        self._shutdown_event = shutdown_event
         self._upload_url = upload_url
+        self._predictions: "dict[str, tuple[schema.PredictionResponse, PredictionTask]]" = {}
+        self._predictions_in_flight: "set[str]" = set()
+        # it would be lovely to merge these, but it's not fully clear how best
+        # to handle it, since idempotent requests can arrive at any time
+        # p: dict[str, PredictionTask]
+        # p: dict[str, PredictionEventHandler]
+        # p: dict[str, schema.PredictionResponse]
 
         self.client_manager = ClientManager()
 
+        # TODO: perhaps this could go back into worker, if we could get the interface right
+        # (unclear how to do the tests)
+        self._state = WorkerState.NEW
+        self._semaphore = asyncio.Semaphore(concurrency)
+        self._concurrency = concurrency
+
+        # A pipe with which to communicate with the child worker.
+        events, child_events = _spawn.Pipe()
+        self._child = _ChildWorker(predictor_ref, child_events, tee_output)
+        self._events: "AsyncConnection[tuple[str, PublicEventType]]" = AsyncConnection(
+            events
+        )
+        # shutdown requested
+        self._shutting_down = False
+        # stop reading events
+        self._terminating = asyncio.Event()
+        self._mux = Mux(self._terminating)
+
         # bind logger instead of the module-level logger proxy for performance
         self.log = log.bind()
 
-    def make_error_handler(self, activity: str) -> Callable[[RunnerTask], None]:
+    def activity_info(self) -> "dict[str, int]":
+        return {"max": self._concurrency, "current": len(self._predictions_in_flight)}
+
+    def setup(self) -> SetupTask:
+        if self._state != WorkerState.NEW:
+            raise RunnerBusyError
+        self._state = WorkerState.STARTING
+
+        # app is allowed to respond to requests and poll the state of this task
+        # while it is running
+        async def inner() -> SetupResult:
+            logs = []
+            status = None
+            started_at = datetime.now(tz=timezone.utc)
+
+            # in 3.10 asyncio.Event started calling get_running_loop
+            # previously it stored the loop when created, which causes an error in tests
+            if sys.version_info < (3, 10):
+                self._terminating = self._mux.terminating = asyncio.Event()
+
+            self._child.start()
+            await self._events.async_init()
+            self._start_event_reader()
+
+            try:
+                async for event in self._mux.read("SETUP", poll=0.1):
+                    if isinstance(event, Log):
+                        logs.append(event.message)
+                    elif isinstance(event, Done):
+                        if event.error:
+                            raise FatalWorkerException(
+                                "Predictor errored during setup: " + event.error_detail
+                            )
+                        status = schema.Status.SUCCEEDED
+                        self._state = WorkerState.IDLE
+            except Exception:
+                logs.append(traceback.format_exc())
+                status = schema.Status.FAILED
+
+            if status is None:
+                logs.append("Error: did not receive 'done' event from setup!")
+                status = schema.Status.FAILED
+
+            completed_at = datetime.now(tz=timezone.utc)
+
+            # Only if setup succeeded, mark the container as "ready".
+            if status == schema.Status.SUCCEEDED:
+                probes = ProbeHelper()
+                probes.ready()
+
+            return SetupResult(
+                started_at=started_at,
+                completed_at=completed_at,
+                logs="".join(logs),
+                status=status,
+            )
+
         def handle_error(task: RunnerTask) -> None:
             exc = task.exception()
             if not exc:
@@ -85,34 +198,64 @@ def handle_error(task: RunnerTask) -> None:
             try:
                 raise exc
             except Exception:
-                self.log.error(f"caught exception while running {activity}", exc_info=True)
+                self.log.error("caught exception while running setup", exc_info=True)
                 if self._shutdown_event is not None:
                     self._shutdown_event.set()
 
-        return handle_error
+        result = asyncio.create_task(inner())
+        result.add_done_callback(handle_error)
+        return result
+
+    def state_from_predictions_in_flight(self) -> WorkerState:
+        valid_states = {WorkerState.IDLE, WorkerState.PROCESSING, WorkerState.BUSY}
+        if self._state not in valid_states:
+            raise InvalidStateException(
+                f"Invalid operation: state is {self._state} (must be IDLE, PROCESSING, or BUSY)"
+            )
+        if len(self._predictions_in_flight) == self._concurrency:
+            return WorkerState.BUSY
+        if len(self._predictions_in_flight) == 0:
+            return WorkerState.IDLE
+        return WorkerState.PROCESSING
 
-    def setup(self) -> SetupTask:
+    def is_busy(self) -> bool:
+        return self._state not in {WorkerState.PROCESSING, WorkerState.IDLE}
+
+    def enter_predict(self, id: str) -> None:
         if self.is_busy():
-            raise RunnerBusyError()
-        self._result = asyncio.create_task(setup(worker=self._worker))
-        self._result.add_done_callback(self.make_error_handler("setup"))
-        return self._result
+            raise InvalidStateException(
+                f"Invalid operation: state is {self._state} (must be processing or idle)"
+            )
+        if self._shutting_down:
+            raise InvalidStateException(
+                "cannot accept new predictions because shutdown requested"
+            )
+        self.log.info(
+            "accepted prediction %s in flight %s", id, self._predictions_in_flight
+        )
+        self._predictions_in_flight.add(id)
+        self._state = self.state_from_predictions_in_flight()
+
+    def exit_predict(self, id: str) -> None:
+        self._predictions_in_flight.remove(id)
+        self._state = self.state_from_predictions_in_flight()
+
+    @contextlib.contextmanager
+    def prediction_ctx(self, id: str) -> Iterator[None]:
+        self.enter_predict(id)
+        try:
+            yield
+        finally:
+            self.exit_predict(id)
 
     # TODO: Make the return type AsyncResult[schema.PredictionResponse] when we
     # no longer have to support Python 3.8
     def predict(
-        self, request: schema.PredictionRequest, upload: bool = True
-    ) -> Tuple[schema.PredictionResponse, PredictionTask]:
-        # It's the caller's responsibility to not call us if we're busy.
+        self, request: schema.PredictionRequest, poll: Optional[float] = None
+    ) -> "tuple[schema.PredictionResponse, PredictionTask]":
         if self.is_busy():
-            # If self._result is set, but self._response is not, we're still
-            # doing setup.
-            if self._response is None:
-                raise RunnerBusyError()
-            assert self._result is not None
-            if request.id is not None and request.id == self._response.id:  # type: ignore
-                result = cast(PredictionTask, self._result)
-                return (self._response, result)
+            if request.id in self._predictions:
+                return self._predictions[request.id]
             raise RunnerBusyError()
 
         # Set up logger context for main thread. The same thing happens inside
@@ -122,10 +265,18 @@ def predict(
         # if upload url was not set, we can respect output_file_prefix
         # but maybe we should just throw an error
         upload_url = request.output_file_prefix or self._upload_url
-        event_handler = PredictionEventHandler(request, self.client_manager, upload_url, self.log)
-        self._response = event_handler.response
+        # this is supposed to send START, but we're trapped in a sync function
+        # this sends START in a task, which calls jsonable_encoder on the input,
+        # which calls iter(io.BytesIO) with data uris that are File
+        # that breaks one of the tests, but happens rarely in production,
+        # so let's ignore it for now
+        event_handler = PredictionEventHandler(
+            request, self.client_manager, upload_url, self.log
+        )
+        response = event_handler.response
 
         prediction_input = PredictionInput.from_request(request)
+        self.enter_predict(request.id)
 
         async def async_predict_handling_errors() -> schema.PredictionResponse:
             try:
@@ -137,9 +288,11 @@ async def async_predict_handling_errors() -> schema.PredictionResponse:
                     if isinstance(v, types.URLTempFile):
                         real_path = await v.convert(self.client_manager.download_client)
                         prediction_input.payload[k] = real_path
-                event_stream = self._worker.predict(prediction_input.payload, poll=0.1)
-                result = await event_handler.handle_event_stream(event_stream)
-                return result
+                async with self._semaphore:
+                    self._events.send(prediction_input)
+                    event_stream = self._mux.read(prediction_input.id, poll=poll)
+                    result = await event_handler.handle_event_stream(event_stream)
+                    return result
             except httpx.HTTPError as e:
                 tb = traceback.format_exc()
                 await event_handler.append_logs(tb)
@@ -150,44 +303,110 @@ async def async_predict_handling_errors() -> schema.PredictionResponse:
                 tb = traceback.format_exc()
                 await event_handler.append_logs(tb)
                 await event_handler.failed(error=str(e))
-                self.log.error("caught exception while running prediction", exc_info=True)
+                self.log.error(
+                    "caught exception while running prediction", exc_info=True
+                )
                 if self._shutdown_event is not None:
                     self._shutdown_event.set()
                 raise  # we don't actually want to raise anymore but w/e
             finally:
+                # mark the prediction as done and update state; we might want
+                # to do this earlier, though: even while files are still
+                # uploading we could accept new work
+                self.exit_predict(prediction_input.id)
                 # FIXME: use isinstance(BaseInput)
                 if hasattr(request.input, "cleanup"):
                     request.input.cleanup()  # type: ignore
+                # this might also be too early,
+                # since this is just before this coroutine exits
+                self._predictions.pop(request.id)
 
         # this is still a little silly
-        self._result = asyncio.create_task(async_predict_handling_errors())
-        self._result.add_done_callback(self.make_error_handler("prediction"))
+        result = asyncio.create_task(async_predict_handling_errors())
+        # result.add_done_callback(self.make_error_handler("prediction"))
+        # even after inlining we might still need a callback to surface remaining exceptions/results
 
-        return (self._response, self._result)
+        self._predictions[request.id] = (response, result)
 
-    def is_busy(self) -> bool:
-        if self._result is None:
-            return False
-
-        if not self._result.done():
-            return True
-
-        self._response = None
-        self._result = None
-        return False
+        return (response, result)
 
     def shutdown(self) -> None:
-        if self._result:
-            self._result.cancel()
-        self._worker.terminate()
+        if self._state == WorkerState.DEFUNCT:
+            return
+        # shutdown requested, but keep reading events
+        self._shutting_down = True
 
-    def cancel(self, prediction_id: Optional[str] = None) -> None:
-        if not self.is_busy():
+        if self._child.is_alive():
+            self._events.send(Shutdown())
+
+    def terminate(self) -> None:
+        for _, task in self._predictions.values():
+            task.cancel()
+        if self._state == WorkerState.DEFUNCT:
             return
-        assert self._response is not None
-        if prediction_id is not None and prediction_id != self._response.id:
+
+        self._terminating.set()
+        self._state = WorkerState.DEFUNCT
+
+        if self._child.is_alive():
+            self._child.terminate()
+            self._child.join()
+        self._events.close()
+
+        if self._read_events_task:
+            self._read_events_task.cancel()
+
+    def cancel(self, prediction_id: str) -> None:
+        if prediction_id not in self._predictions_in_flight:
+            self.log.warn(
+                "can't cancel %s (%s)", prediction_id, self._predictions_in_flight
+            )
             raise UnknownPredictionError()
-        self._worker.cancel()
+        if os.getenv("COG_DISABLE_CANCEL"):
+            self.log.warn("cancelling is disabled for this model")
+            return
+        maybe_pid = self._child.pid
+        if self._child.is_alive() and maybe_pid is not None:
+            # since we don't know if the predictor is sync or async, we both send
+            # the signal (honored only if sync) and the event (honored only if async)
+            os.kill(maybe_pid, signal.SIGUSR1)
+            self.log.info("sent cancel")
+            self._events.send(Cancel(prediction_id))
+        # this should probably check self._semaphore._value == self._concurrency
+
+    _read_events_task: "Optional[asyncio.Task[None]]" = None
+
+    def _start_event_reader(self) -> None:
+        def handle_error(task: "asyncio.Task[None]") -> None:
+            if task.cancelled():
+                return
+            exc = task.exception()
+            if exc:
+                logging.error("caught exception", exc_info=exc)
+
+        if not self._read_events_task:
+            self._read_events_task = asyncio.create_task(self._read_events())
+            self._read_events_task.add_done_callback(handle_error)
+
+    async def _read_events(self) -> None:
+        while self._child.is_alive() and not self._terminating.is_set():
+            # in tests this can still be running when the task is destroyed
+            result = await self._events.recv()
+            id, event = result
+            if id == "LOG" and self._state == WorkerState.STARTING:
+                id = "SETUP"
+            if id == "LOG" and len(self._predictions_in_flight) == 1:
+                id = list(self._predictions_in_flight)[0]
+            await self._mux.write(id, event)
+        # If we dropped off the end of the loop, check if it's
+        # because the child process died.
+        if not self._child.is_alive() and not self._terminating.is_set():
+            exitcode = self._child.exitcode
+            self._mux.fatal = FatalWorkerException(
+                f"Prediction failed for an unknown reason. It might have run out of memory? (exitcode {exitcode})"
+            )
+        # this is the same event as self._terminating
+        # we need to set it so mux.reads wake up and throw an error if needed
+        self._mux.terminating.set()
 
 
 class PredictionEventHandler:
@@ -219,8 +438,7 @@ def __init__(
         # latency (this guarantees that the first output webhook won't be
         # throttled.)
         if not SKIP_START_EVENT:
-            # idk
-            # this is pretty wrong
+            # sending it in a coroutine is kind of wrong in some ways
            asyncio.create_task(self._send_webhook(schema.WebhookEvent.START))
 
     @property
@@ -309,6 +527,7 @@ def event_to_handle_future(self, event: PublicEventType) -> Awaitable[None]:
             return self.noop()
         if isinstance(event, Log):
             return self.append_logs(event.message)
+
         if isinstance(event, PredictionOutputType):
             if self._output_type is not None:
                 return self.failed(error="Predictor returned unexpected output")
@@ -328,41 +547,5 @@ def event_to_handle_future(self, event: PublicEventType) -> Awaitable[None]:
             if event.error:
                 return self.failed(error=str(event.error_detail))
             return self.succeeded()
-        log.warn("received unexpected event from worker", data=event)
+        self.logger.warn("received unexpected event from worker", data=event)
         return self.noop()
-
-
-async def setup(*, worker: Worker) -> SetupResult:
-    logs = []
-    status = None
-    started_at = datetime.now(tz=timezone.utc)
-
-    try:
-        async for event in worker.setup():
-            if isinstance(event, Log):
-                logs.append(event.message)
-            elif isinstance(event, Done):
-                status = (
-                    schema.Status.FAILED if event.error else schema.Status.SUCCEEDED
-                )
-    except Exception:
-        logs.append(traceback.format_exc())
-        status = schema.Status.FAILED
-
-    if status is None:
-        logs.append("Error: did not receive 'done' event from setup!")
-        status = schema.Status.FAILED
-
-    completed_at = datetime.now(tz=timezone.utc)
-
-    # Only if setup succeeded, mark the container as "ready".
- if status == schema.Status.SUCCEEDED: - probes = ProbeHelper() - probes.ready() - - return SetupResult( - started_at=started_at, - completed_at=completed_at, - logs="".join(logs), - status=status, - ) diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py index 941690c1d..1beea2a6e 100644 --- a/python/cog/server/worker.py +++ b/python/cog/server/worker.py @@ -1,17 +1,16 @@ import asyncio import contextlib import inspect -import logging import multiprocessing -import os import signal import sys import traceback import types from collections import defaultdict +from contextvars import ContextVar from enum import Enum, auto, unique from multiprocessing.connection import Connection -from typing import Any, AsyncIterator, Callable, Dict, Iterator, Optional, TextIO, Union +from typing import Any, AsyncIterator, Callable, Iterator, Optional, TextIO from ..json import make_encodeable from ..predictor import ( @@ -23,25 +22,24 @@ ) from .connection import AsyncConnection from .eventtypes import ( + Cancel, Done, Heartbeat, Log, PredictionInput, PredictionOutput, PredictionOutputType, + PublicEventType, Shutdown, ) from .exceptions import ( CancelationException, FatalWorkerException, - InvalidStateException, ) from .helpers import StreamRedirector, WrappedStream _spawn = multiprocessing.get_context("spawn") -_PublicEventType = Union[Done, Heartbeat, Log, PredictionOutput, PredictionOutputType] - @unique class WorkerState(Enum): @@ -55,18 +53,18 @@ class WorkerState(Enum): class Mux: def __init__(self, terminating: asyncio.Event) -> None: - self.outs: "defaultdict[str, asyncio.Queue[_PublicEventType]]" = defaultdict( + self.outs: "defaultdict[str, asyncio.Queue[PublicEventType]]" = defaultdict( asyncio.Queue ) self.terminating = terminating self.fatal: "Optional[FatalWorkerException]" = None - async def write(self, id: str, item: _PublicEventType) -> None: + async def write(self, id: str, item: PublicEventType) -> None: await self.outs[id].put(item) async def read( self, id: str, poll: Optional[float] = None - ) -> AsyncIterator[_PublicEventType]: + ) -> AsyncIterator[PublicEventType]: if poll: send_heartbeats = True else: @@ -87,187 +85,6 @@ async def read( raise self.fatal -class Worker: - def __init__( - self, predictor_ref: str, tee_output: bool = True, concurrent: int = 1 - ) -> None: - self._state = WorkerState.NEW - self._allow_cancel = False - self._semaphore = asyncio.Semaphore(concurrent) - self._concurrent = concurrent - - # A pipe with which to communicate with the child worker. 
- events, child_events = _spawn.Pipe() - self._child = _ChildWorker(predictor_ref, child_events, tee_output) - self._events: "AsyncConnection[tuple[str, _PublicEventType]]" = AsyncConnection( - events - ) - # shutdown requested - self._shutting_down = False - # stop reading events - self._terminating = asyncio.Event() - self._mux = Mux(self._terminating) - self._predictions_in_flight = set() - - def setup(self) -> AsyncIterator[_PublicEventType]: - self._assert_state(WorkerState.NEW) - self._state = WorkerState.STARTING - - async def inner() -> AsyncIterator[_PublicEventType]: - # in 3.10 Event started doing get_running_loop - # previously it stored the loop when created, which causes an error in tests - if sys.version_info < (3, 10): - self._terminating = self._mux.terminating = asyncio.Event() - - self._child.start() - self._ensure_event_reader() - async for event in self._mux.read("SETUP", poll=0.1): - yield event - if isinstance(event, Done): - if event.error: - raise FatalWorkerException( - "Predictor errored during setup: " + event.error_detail - ) - self._state = WorkerState.IDLE - - return inner() - - def state_from_predictions_in_flight(self) -> WorkerState: - valid_states = {WorkerState.IDLE, WorkerState.PROCESSING, WorkerState.BUSY} - if self._state not in valid_states: - raise InvalidStateException( - f"Invalid operation: state is {self._state} (must be IDLE, PROCESSING, or BUSY)" - ) - # changing _allow_cancel like this is a little bit weird - # because this is kind of a pure function - # however, we'll remove all of this cancel logic soon, dwabi - if len(self._predictions_in_flight) == self._concurrent: - self._allow_cancel = True - return WorkerState.BUSY - if len(self._predictions_in_flight) == 0: - self._allow_cancel = False - return WorkerState.IDLE - self._allow_cancel = True - return WorkerState.PROCESSING - - def is_busy(self) -> bool: - return self._state not in {WorkerState.PROCESSING, WorkerState.IDLE} - - @contextlib.asynccontextmanager - async def _prediction_ctx(self, input: PredictionInput) -> AsyncIterator[None]: - async with self._semaphore: - self._predictions_in_flight.add(input.id) # idempotent ig - self._state = self.state_from_predictions_in_flight() - try: - yield - finally: - self._predictions_in_flight.remove(input.id) - self._state = self.state_from_predictions_in_flight() - - def predict( - self, payload: Dict[str, Any], poll: Optional[float] = None - ) -> AsyncIterator[_PublicEventType]: - # this has to be eager for hypothesis - if self.is_busy(): - raise InvalidStateException( - f"Invalid operation: state is {self._state} (must be processing or idle)" - ) - if self._shutting_down: - raise InvalidStateException( - "cannot accept new predictions because shutdown requested" - ) - input = PredictionInput(payload=payload) - self._predictions_in_flight.add(input.id) # idempotent ig - self._state = self.state_from_predictions_in_flight() - - async def inner() -> AsyncIterator[_PublicEventType]: - async with self._prediction_ctx(input): - self._events.send(input) - async for event in self._mux.read(input.id, poll=poll): - yield event - - return inner() - - def shutdown(self) -> None: - if self._state == WorkerState.DEFUNCT: - return - # shutdown requested, but keep reading events - self._shutting_down = True - - if self._child.is_alive(): - self._events.send(Shutdown()) - - def terminate(self) -> None: - if self._state == WorkerState.DEFUNCT: - return - - self._terminating.set() - self._state = WorkerState.DEFUNCT - - if self._child.is_alive(): - 
self._child.terminate() - self._child.join() - if self._read_events_task: - self._read_events_task.cancel() - self._events.close() - - # FIXME: this will need to use a combination - # of signals and Cancel events on the pipe - def cancel(self) -> None: - if ( - self._allow_cancel - and self._child.is_alive() - and self._child.pid is not None - ): - os.kill(self._child.pid, signal.SIGUSR1) - # this should probably check self._semaphore._value == self._concurrent - self._allow_cancel = False - - def _assert_state(self, state: WorkerState) -> None: - if self._state != state: - raise InvalidStateException( - f"Invalid operation: state is {self._state} (must be {state})" - ) - - _read_events_task: "Optional[asyncio.Task[None]]" = None - - def _ensure_event_reader(self) -> None: - def handle_error(task: "asyncio.Task[None]") -> None: - if task.cancelled(): - return - exc = task.exception() - if exc: - logging.error("caught exception", exc_info=exc) - - if not self._read_events_task: - self._read_events_task = asyncio.create_task(self._read_events()) - self._read_events_task.add_done_callback(handle_error) - - async def _read_events(self) -> None: - while self._child.is_alive() and not self._terminating.is_set(): - # this can still be running when the task is destroyed - result = await self._events.recv() # this might be kind of risky - # event loop closed or child died - if result is None: # type: ignore - break - id, event = result - if id == "LOG" and self._state == WorkerState.STARTING: - id = "SETUP" - if id == "LOG" and len(self._predictions_in_flight) == 1: - id = list(self._predictions_in_flight)[0] - await self._mux.write(id, event) - # If we dropped off the end off the end of the loop, check if it's - # because the child process died. - if not self._child.is_alive() and not self._terminating.is_set(): - exitcode = self._child.exitcode - self._mux.fatal = FatalWorkerException( - f"Prediction failed for an unknown reason. It might have run out of memory? (exitcode {exitcode})" - ) - # this is the same event as _terminating - # we need to set it so mux.reads wake up and throw an error if needed - self._mux.terminating.set() - - class _ChildWorker(_spawn.Process): # type: ignore def __init__( self, @@ -292,6 +109,9 @@ def run(self) -> None: # We use SIGUSR1 to signal an interrupt for cancelation. signal.signal(signal.SIGUSR1, self._signal_handler) + self.prediction_id_context: ContextVar[str] = ContextVar("prediction_context") + + # ws_stdout = WrappedStream("stdout", sys.stdout) ws_stderr = WrappedStream("stderr", sys.stderr) ws_stdout.wrap() @@ -303,6 +123,8 @@ def run(self) -> None: [ws_stdout, ws_stderr], self._stream_write_hook ) self._stream_redirector.start() + # + self._setup() self._loop() self._stream_redirector.shutdown() @@ -312,6 +134,7 @@ def _setup(self) -> None: with self._handle_setup_error(): # we need to load the predictor to know if setup is async self._predictor = load_predictor_from_ref(self._predictor_ref) + self._predictor.log = self._log # if users want to access the same event loop from setup and predict, # both have to be async. 
if setup isn't async, it doesn't matter if we
        # create the event loop here or after setup
@@ -355,26 +178,38 @@ def _loop_sync(self) -> None:
                 break
             if isinstance(ev, PredictionInput):
                 self._predict_sync(ev)
+            elif isinstance(ev, Cancel):
+                # in sync mode, Cancel events are ignored
+                # only signals are respected
+                pass
             else:
                 print(f"Got unexpected event: {ev}", file=sys.stderr)
 
     async def _loop_async(self) -> None:
-        events: "AsyncConnection[tuple[str, _PublicEventType]]" = AsyncConnection(
+        events: "AsyncConnection[tuple[str, PublicEventType]]" = AsyncConnection(
             self._events
         )
-        while True:
-            try:
-                ev = await events.recv()
-            except asyncio.CancelledError:
-                return
-            if isinstance(ev, Shutdown):
-                return
-            if isinstance(ev, PredictionInput):
-                # keep track of these so they can be cancelled
-                await self._predict_async(ev)
-            # handle Cancel
-            else:
-                print(f"Got unexpected event: {ev}", file=sys.stderr)
+        with events:
+            tasks: "dict[str, asyncio.Task[None]]" = {}
+            while True:
+                try:
+                    ev = await events.recv()
+                except asyncio.CancelledError:
+                    return
+                if isinstance(ev, Shutdown):
+                    return
+                if isinstance(ev, PredictionInput):
+                    # keep track of these so they can be cancelled
+                    tasks[ev.id] = asyncio.create_task(self._predict_async(ev))
+                elif isinstance(ev, Cancel):
+                    # in async mode, cancel signals are ignored
+                    # only Cancel events are respected
+                    if ev.id in tasks:
+                        tasks[ev.id].cancel()
+                    else:
+                        print(f"Got unexpected cancellation: {ev}", file=sys.stderr)
+                else:
+                    print(f"Got unexpected event: {ev}", file=sys.stderr)
 
     def _loop(self) -> None:
         if is_async(get_predict(self._predictor)):
@@ -387,21 +222,25 @@ def _handle_predict_error(self, id: str) -> Iterator[None]:
         assert self._predictor
         done = Done()
         self._cancelable = True
+        token = self.prediction_id_context.set(id)
         try:
             yield
         except CancelationException:
             done.canceled = True
+        except asyncio.CancelledError:
+            done.canceled = True
         except Exception as e:
             traceback.print_exc()
             done.error = True
             done.error_detail = str(e)
         finally:
+            self.prediction_id_context.reset(token)
             self._cancelable = False
             self._stream_redirector.drain()
             self._events.send((id, done))
 
-    def _mk_send(self, id: str) -> Callable[[_PublicEventType], None]:
-        def send(event: _PublicEventType) -> None:
+    def _mk_send(self, id: str) -> Callable[[PublicEventType], None]:
+        def send(event: PublicEventType) -> None:
             self._events.send((id, event))
 
         return send
@@ -417,9 +256,9 @@ async def _predict_async(self, input: PredictionInput) -> None:
                 async for r in result:
                     send(PredictionOutput(payload=make_encodeable(r)))
             elif inspect.isawaitable(result):
-                result = await result
+                output = await result
                 send(PredictionOutputType(multi=False))
-                send(PredictionOutput(payload=make_encodeable(result)))
+                send(PredictionOutput(payload=make_encodeable(output)))
 
     def _predict_sync(self, input: PredictionInput) -> None:
         with self._handle_predict_error(input.id):
@@ -436,16 +275,26 @@ def _predict_sync(self, input: PredictionInput) -> None:
             send(PredictionOutput(payload=make_encodeable(result)))
 
     def _signal_handler(self, signum: int, frame: Optional[types.FrameType]) -> None:
+        if self._predictor and is_async(get_predict(self._predictor)):
+            # we could try also canceling the async task around here
+            # but for now in async mode signals are ignored
+            return
+        # this logic might need to be refined
         if signum == signal.SIGUSR1 and self._cancelable:
             raise CancelationException()
 
+    def _log(self, *messages: str, source: str = "stderr") -> None:
+        id = self.prediction_id_context.get("LOG")
+        self._events.send((id, Log(" ".join(messages), source=source)))
+
     def _stream_write_hook(
         self, stream_name: str, original_stream: TextIO, data: str
     ) -> None:
         if self._tee_output:
            original_stream.write(data)
            original_stream.flush()
-        self._events.send(("LOG", Log(data, source=stream_name)))
+        # note: this runs on the stream redirector thread, not the async task,
+        # so the prediction id contextvar is not set here and these logs fall
+        # back to the generic "LOG" id
+        self._log(data, source=stream_name)
 
 
 def get_loop() -> asyncio.AbstractEventLoop:
diff --git a/python/tests/server/fixtures/async_sleep.py b/python/tests/server/fixtures/async_sleep.py
new file mode 100644
index 000000000..7c5f73407
--- /dev/null
+++ b/python/tests/server/fixtures/async_sleep.py
@@ -0,0 +1,9 @@
+import asyncio
+
+from cog import BasePredictor
+
+
+class Predictor(BasePredictor):
+    async def predict(self, sleep: float = 0) -> str:
+        await asyncio.sleep(sleep)
+        return f"done in {sleep} seconds"
diff --git a/python/tests/server/test_http_output.py b/python/tests/server/test_http_output.py
index bb19de66c..ad7bba586 100644
--- a/python/tests/server/test_http_output.py
+++ b/python/tests/server/test_http_output.py
@@ -7,11 +7,11 @@
 from .conftest import uses_predictor, uses_predictor_with_client_options
 
-
-@uses_predictor("output_wrong_type")
-def test_return_wrong_type(client):
-    resp = client.post("/predictions")
-    assert resp.status_code == 500
+# it's not the worst idea to validate outputs but it's slow and not required
+# @uses_predictor("output_wrong_type")
+# def test_return_wrong_type(client):
+#     resp = client.post("/predictions")
+#     assert resp.status_code == 500
 
 
 @uses_predictor("output_file")
diff --git a/python/tests/server/test_runner.py b/python/tests/server/test_runner.py
index e23a89164..bbbb532f0 100644
--- a/python/tests/server/test_runner.py
+++ b/python/tests/server/test_runner.py
@@ -1,9 +1,9 @@
 import asyncio
 import os
 import threading
+import time
 from datetime import datetime
 from unittest import mock
-
 import pytest
 import pytest_asyncio
 from cog.schema import PredictionRequest, PredictionResponse, Status, WebhookEvent
@@ -23,6 +23,11 @@
 )
 
 
+# TODO
+# - setup logs
+# - file inputs being converted
+
+
 def _fixture_path(name):
     test_dir = os.path.dirname(os.path.realpath(__file__))
     return os.path.join(test_dir, f"fixtures/{name}.py") + ":Predictor"
@@ -69,6 +74,31 @@ async def test_prediction_runner(runner):
     assert isinstance(response.completed_at, datetime)
 
 
+@pytest.mark.asyncio
+async def test_prediction_runner_async():
+    "verify that predictions are run concurrently, not back to back"
+    runner = PredictionRunner(
+        predictor_ref=_fixture_path("async_sleep"), shutdown_event=None, concurrency=10
+    )
+    await runner.setup()
+    results = []
+    st = time.time()
+    for i in range(10):
+        _, result = runner.predict(PredictionRequest(input={"sleep": 0.1}))
+        results.append(result)
+    with pytest.raises(RunnerBusyError):
+        runner.predict(PredictionRequest(input={"sleep": 0.1}))
+    responses = await asyncio.gather(*results)
+    # ten 0.1s sleeps finishing in under 0.5s means they overlapped
+    assert time.time() - st < 0.5
+    for response in responses:
+        assert response.output == "done in 0.1 seconds"
+        assert response.status == "succeeded"
+        assert response.error is None
+        assert response.logs == ""
+        assert isinstance(response.started_at, datetime)
+        assert isinstance(response.completed_at, datetime)
+
+
 @pytest.mark.asyncio
 async def test_prediction_runner_called_while_busy(runner):
     request = PredictionRequest(input={"sleep": 1})
diff --git a/python/tests/server/test_worker.py b/python/tests/server/test_worker.py
index 0e49c0921..b0138d792 100644
--- a/python/tests/server/test_worker.py
+++ 
b/python/tests/server/test_worker.py @@ -12,6 +12,7 @@ Done, Heartbeat, Log, + PredictionInput, PredictionOutput, PredictionOutputType, ) @@ -327,19 +328,24 @@ async def test_cancel_is_safe(): try: for _ in range(50): - w.cancel() + with pytest.raises(KeyError): + w.cancel("1") await _process(w.setup()) for _ in range(50): - w.cancel() + with pytest.raises(KeyError): + w.cancel("1") - result1 = await _process(w.predict({"sleep": 0.5})) + input1 = PredictionInput({"sleep": 0.5}) + result1 = await _process(w.predict(input1)) for _ in range(50): - w.cancel() + with pytest.raises(KeyError): + w.cancel(input1.id) - result2 = await _process(w.predict({"sleep": 0.1})) + input2 = {"sleep": 0.1} + result2 = await _process(w.predict(input2)) assert not result1.done.canceled assert not result2.done.canceled @@ -361,21 +367,22 @@ async def test_cancel_idempotency(): await _process(w.setup()) p1_done = None + input1 = PredictionInput({"sleep": 0.5}) - async for event in w.predict({"sleep": 0.5}, poll=0.01): + async for event in w.predict(input1, poll=0.01): # We call cancel a WHOLE BUNCH to make sure that we don't propagate # any of those cancelations to subsequent predictions, regardless # of the internal implementation of exceptions raised inside signal # handlers. for _ in range(100): - w.cancel() + w.cancel(input1.id) if isinstance(event, Done): p1_done = event assert p1_done.canceled - result2 = await _process(w.predict({"sleep": 0.1})) + result2 = await _process(w.predict(PredictionInput({"sleep": 0.1}))) assert result2.done and not result2.done.canceled assert result2.output == "done in 0.1 seconds" @@ -400,10 +407,11 @@ async def test_cancel_multiple_predictions(): for _ in range(5): canceled = False + input = PredictionInput({"sleep": 0.5}) - async for event in w.predict({"sleep": 0.5}, poll=0.01): + async for event in w.predict(input, poll=0.01): if not canceled: - w.cancel() + w.cancel(input.id) canceled = True if isinstance(event, Done): @@ -449,12 +457,13 @@ async def test_heartbeats_cancel(): start = time.time() canceled = False - async for event in w.predict({"sleep": 10}, poll=0.1): + input = PredictionInput({"sleep": 10}) + async for event in w.predict(input, poll=0.1): if isinstance(event, Heartbeat): heartbeat_count += 1 if time.time() - start > 0.5: if not canceled: - w.cancel() + w.cancel(input.id) canceled = True elapsed = time.time() - start @@ -556,8 +565,10 @@ def _check_setup_events(self): def predict(self, name, steps): try: payload = {"name": name, "steps": steps} - self.predict_generator = self.worker.predict(payload) - self.predict_payload = payload + input = PredictionInput(payload) + self.worker.enter_predict(input.id) + self.predict_generator = self.worker.predict(input) + self.predict_payload = input self.predict_events = [] except InvalidStateException: pass @@ -574,9 +585,10 @@ def read_predict_events(self, n): self._check_predict_events() def _check_predict_events(self): + self.worker.exit_predict(self.predict_payload.id) assert isinstance(self.predict_events[-1], Done) - payload = self.predict_payload + payload = self.predict_payload.payload result = _sync_process(self.predict_events) expected_stdout = ["START\n"] @@ -594,7 +606,7 @@ def cancel(self, r): if isinstance(r, InvalidStateException): return - self.worker.cancel() + self.worker.cancel(self.predict_payload.id) result = self.await_(_process(r)) # We'd love to be able to assert result.done.canceled here, but we diff --git a/test-integration/test_integration/test_run.py 
b/test-integration/test_integration/test_run.py index ce35f805c..812b3acca 100644 --- a/test-integration/test_integration/test_run.py +++ b/test-integration/test_integration/test_run.py @@ -6,7 +6,7 @@ def test_run(tmpdir_factory): with open(tmpdir / "cog.yaml", "w") as f: cog_yaml = """ build: - python_version: "3.8" + python_version: "3.9" """ f.write(cog_yaml) @@ -24,7 +24,7 @@ def test_run_with_secret(tmpdir_factory): with open(tmpdir / "cog.yaml", "w") as f: cog_yaml = """ build: - python_version: "3.8" + python_version: "3.9" run: - echo hello world - command: >- diff --git a/tools/compatgen/internal/torch.go b/tools/compatgen/internal/torch.go index 55239b726..ecc52695a 100644 --- a/tools/compatgen/internal/torch.go +++ b/tools/compatgen/internal/torch.go @@ -209,7 +209,7 @@ func parseTorchInstallString(s string, defaultVersions map[string]string, cuda * torchaudio := libVersions["torchaudio"] // TODO: this could be determined from https://download.pytorch.org/whl/torch/ - pythons := []string{"3.7", "3.8", "3.9", "3.10", "3.11"} + pythons := []string{"3.8", "3.9", "3.10", "3.11"} return &config.TorchCompatibility{ Torch: torch,
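

Reviewer note — a sketch of the model-side pattern this patch enables; none of
the code below is part of the diff, and the predictor shown is hypothetical.
With the new setting in cog.yaml (e.g. `concurrency: {max: 4}`, per the
Concurrency struct and schema property above), create_app passes the value to
PredictionRunner, whose asyncio.Semaphore allows that many predictions in
flight. An async predictor in the style of the async_sleep fixture, using the
new BasePredictor.log hook:

import asyncio

from cog import BasePredictor


class Predictor(BasePredictor):
    async def predict(self, sleep: float = 0) -> str:
        # self.log is plain print() by default; inside the child worker it is
        # rebound to _ChildWorker._log, which tags each line with the current
        # prediction id from prediction_id_context, so logs from interleaved
        # predictions stay attributable
        self.log(f"sleeping for {sleep} seconds")
        await asyncio.sleep(sleep)  # stand-in for real async I/O
        return f"done in {sleep} seconds"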
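On the serving side, /ready is the new readiness probe: it returns 200 while
enter_predict would accept another prediction, and 503 once in-flight
predictions reach the configured max (or while setup is still running, or
after shutdown has been requested). A hypothetical client-side poll, assuming
a container listening on localhost:5000 and the existing health-check route
(httpx is already a dependency of the server):

import httpx

with httpx.Client(base_url="http://localhost:5000") as client:
    # the health check response now carries the runner's activity info
    health = client.get("/health-check").json()
    print(health["status"], health.get("concurrency"))  # e.g. READY {'max': 4, 'current': 1}

    # /ready is 200 only when more work can be accepted
    resp = client.get("/ready")
    if resp.status_code == 200:
        print("can accept work:", resp.json()["activity"])
    else:  # 503: at capacity, still setting up, or shutting down
        print("not ready:", resp.json()["activity"])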