diff --git a/pkg/config/config.go b/pkg/config/config.go
index 6069b88de..c56051513 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -56,16 +56,21 @@ type Build struct {
pythonRequirementsContent []string
}
+// Concurrency holds the settings from the "concurrency" section of the
+// config.
+type Concurrency struct {
+ // Max is read from the "max" key; the config schema describes it as the
+ // maximum number of concurrent predictions.
+ Max int `json:"max,omitempty" yaml:"max"`
+}
+
type Example struct {
Input map[string]string `json:"input" yaml:"input"`
Output string `json:"output" yaml:"output"`
}
type Config struct {
- Build *Build `json:"build" yaml:"build"`
- Image string `json:"image,omitempty" yaml:"image"`
- Predict string `json:"predict,omitempty" yaml:"predict"`
- Train string `json:"train,omitempty" yaml:"train"`
+ Build *Build `json:"build" yaml:"build"`
+ Image string `json:"image,omitempty" yaml:"image"`
+ Predict string `json:"predict,omitempty" yaml:"predict"`
+ Train string `json:"train,omitempty" yaml:"train"`
+ Concurrency *Concurrency `json:"concurrency,omitempty" yaml:"concurrency"`
}
func DefaultConfig() *Config {
diff --git a/pkg/config/data/config_schema_v1.0.json b/pkg/config/data/config_schema_v1.0.json
index 958fd6889..48ae781a3 100644
--- a/pkg/config/data/config_schema_v1.0.json
+++ b/pkg/config/data/config_schema_v1.0.json
@@ -154,11 +154,6 @@
"$id": "#/properties/concurrency/properties/max",
"type": "integer",
"description": "The maximum number of concurrent predictions."
- },
- "default_target": {
- "$id": "#/properties/concurrency/properties/default_target",
- "type": "integer",
- "description": "The default target for number of concurrent predictions. This setting can be used by an autoscaler to determine when to scale a deployment of a model up or down."
}
}
}
diff --git a/python/cog/command/ast_openapi_schema.py b/python/cog/command/ast_openapi_schema.py
index 89322debb..c491c031e 100644
--- a/python/cog/command/ast_openapi_schema.py
+++ b/python/cog/command/ast_openapi_schema.py
@@ -147,6 +147,24 @@
"summary": "Healthcheck"
}
},
+ "/ready": {
+ "get": {
+ "summary": "Ready",
+ "operationId": "ready_ready_get",
+ "responses": {
+ "200": {
+ "description": "Successful Response",
+ "content": {
+ "application/json": {
+ "schema": {
+ "title": "Response Ready Ready Get"
+ }
+ }
+ }
+ }
+ }
+ }
+ },
"/predictions": {
"post": {
"description": "Run a single prediction on the model",
diff --git a/python/cog/predictor.py b/python/cog/predictor.py
index 6a4316385..c546cdda8 100644
--- a/python/cog/predictor.py
+++ b/python/cog/predictor.py
@@ -63,6 +63,13 @@ def predict(self, **kwargs: Any) -> Any:
"""
pass
+ def log(self, *messages: str) -> None:
+ """
+ Write a log message that will be tagged with the current prediction
+ even during concurrent predictions. At runtime this method is overridden;
+ this default implementation simply prints the messages.
+ """
+ print(*messages)
+
def run_setup(predictor: BasePredictor) -> None:
weights = get_weights_argument(predictor)
diff --git a/python/cog/server/eventtypes.py b/python/cog/server/eventtypes.py
index a1df38fff..ee3e68997 100644
--- a/python/cog/server/eventtypes.py
+++ b/python/cog/server/eventtypes.py
@@ -20,6 +20,11 @@ def from_request(cls, request: schema.PredictionRequest) -> "PredictionInput":
return cls(payload=payload, id=request.id)
+# Event carrying the id of the prediction that should be cancelled.
+@define
+class Cancel:
+ id: str
+
+
@define
class Shutdown:
pass
diff --git a/python/cog/server/http.py b/python/cog/server/http.py
index 81825a6a8..43cac5835 100644
--- a/python/cog/server/http.py
+++ b/python/cog/server/http.py
@@ -27,11 +27,10 @@
import attrs
import structlog
import uvicorn
-from fastapi import Body, FastAPI, Header, HTTPException, Path, Response
+from fastapi import Body, FastAPI, Header, Path, Response
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
-from pydantic import ValidationError
from pydantic.error_wrappers import ErrorWrapper
from .. import schema
@@ -133,10 +132,13 @@ async def start_shutdown() -> Any:
add_setup_failed_routes(app, started_at, msg)
return app
+ concurrency = config.get("concurrency", {}).get("max", "1")
+
runner = PredictionRunner(
predictor_ref=predictor_ref,
shutdown_event=shutdown_event,
upload_url=upload_url,
+ concurrency=int(concurrency),
)
class PredictionRequest(schema.PredictionRequest.with_types(input_type=InputType)):
@@ -261,7 +263,22 @@ async def healthcheck() -> Any:
else:
health = app.state.health
setup = attrs.asdict(app.state.setup_result) if app.state.setup_result else {}
- return jsonable_encoder({"status": health.name, "setup": setup})
+ activity = runner.activity_info()
+ return jsonable_encoder(
+ {"status": health.name, "setup": setup, "concurrency": activity}
+ )
+
+    # Readiness probe: it only returns 200 when work can be accepted.
+    # runner.is_busy() is True whenever the runner cannot take another
+    # prediction, so a busy runner must report 503, not 200.
+    @app.get("/ready")
+    async def ready() -> Any:
+        """
+        Report whether the runner can accept a new prediction.
+
+        Returns a 200 response with status "ready" when the runner is not
+        busy, and a 503 response with status "not ready" otherwise. Both
+        bodies include the runner's current activity info.
+        """
+        activity = runner.activity_info()
+        if runner.is_busy():
+            return JSONResponse(
+                {"status": "not ready", "activity": activity}, status_code=503
+            )
+        return JSONResponse(
+            {"status": "ready", "activity": activity}, status_code=200
+        )
@limited
@app.post(
@@ -348,27 +365,23 @@ async def shared_predict(
if respond_async:
return JSONResponse(jsonable_encoder(initial_response), status_code=202)
- # by now, output Path and File are already converted to str
- # so when we validate the schema, those urls get cast back to Path and File
- # in the previous implementation those would then get encoded as strings
- # however the changes to Path and File break this and return the filename instead
- try:
- prediction = await async_result
- # we're only doing this to catch validation errors
- response = PredictionResponse(**prediction.dict())
- del response
- except ValidationError as e:
- _log_invalid_output(e)
- raise HTTPException(status_code=500, detail=str(e)) from e
-
- # dict_resp = response.dict()
- # output = await runner.client_manager.upload_files(
- # dict_resp["output"], upload_url
- # )
- # dict_resp["output"] = output
- # encoded_response = jsonable_encoder(dict_resp)
-
- # return *prediction* and not *response* to preserve urls
+ # # by now, output Path and File are already converted to str
+ # # so when we validate the schema, those urls get cast back to Path and File
+ # # in the previous implementation those would then get encoded as strings
+ # # however the changes to Path and File break this and return the filename instead
+ #
+ # # moreover, validating outputs can be a bottleneck with enough volume
+ # # since it's not strictly needed, we can comment it out
+ # try:
+ # prediction = await async_result
+ # # we're only doing this to catch validation errors
+ # response = PredictionResponse(**prediction.dict())
+ # del response
+ # except ValidationError as e:
+ # _log_invalid_output(e)
+ # raise HTTPException(status_code=500, detail=str(e)) from e
+
+ prediction = await async_result
encoded_response = jsonable_encoder(prediction.dict())
return JSONResponse(content=encoded_response)
@@ -377,8 +390,7 @@ async def cancel(prediction_id: str = Path(..., title="Prediction ID")) -> Any:
"""
Cancel a running prediction
"""
- if not runner.is_busy():
- return JSONResponse({}, status_code=404)
+ # no need to check whether or not we're busy
try:
runner.cancel(prediction_id)
except UnknownPredictionError:
@@ -433,7 +445,6 @@ def start(self) -> None:
def stop(self) -> None:
log.info("stopping server")
- self.should_exit = True
self._thread.join(timeout=5)
if not self._thread.is_alive():
diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py
index 07ff48762..ada6c10dc 100644
--- a/python/cog/server/runner.py
+++ b/python/cog/server/runner.py
@@ -1,10 +1,16 @@
import asyncio
+import contextlib
+import logging
import multiprocessing
+import os
+import signal
+import sys
import threading
import traceback
import typing # TypeAlias, py3.10
from datetime import datetime, timezone
-from typing import Any, AsyncIterator, Awaitable, Callable, Optional, Tuple, Union, cast
+from enum import Enum, auto, unique
+from typing import Any, AsyncIterator, Awaitable, Iterator, Optional, Union
import httpx
import structlog
@@ -12,7 +18,9 @@
from .. import schema, types
from .clients import SKIP_START_EVENT, ClientManager
+from .connection import AsyncConnection
from .eventtypes import (
+ Cancel,
Done,
Heartbeat,
Log,
@@ -20,9 +28,14 @@
PredictionOutput,
PredictionOutputType,
PublicEventType,
+ Shutdown,
+)
+from .exceptions import (
+ FatalWorkerException,
+ InvalidStateException,
)
from .probes import ProbeHelper
-from .worker import Worker
+from .worker import Mux, _ChildWorker
log = structlog.get_logger("cog.server.runner")
_spawn = multiprocessing.get_context("spawn")
@@ -40,6 +53,16 @@ class UnknownPredictionError(Exception):
pass
+# States of the worker as tracked by PredictionRunner.
+@unique
+class WorkerState(Enum):
+ NEW = auto()  # initial state, before setup() is called
+ STARTING = auto()  # set by setup(); replaced by IDLE once setup succeeds
+ IDLE = auto()  # no predictions in flight
+ PROCESSING = auto()  # predictions in flight, but fewer than the concurrency limit
+ BUSY = auto()  # predictions in flight equal the concurrency limit
+ DEFUNCT = auto()  # set by terminate()
+
+
@define
class SetupResult:
started_at: datetime
@@ -47,12 +70,20 @@ class SetupResult:
logs: str
status: schema.Status
+ # TODO: maybe collect events into a result here
+
PredictionTask: "typing.TypeAlias" = "asyncio.Task[schema.PredictionResponse]"
SetupTask: "typing.TypeAlias" = "asyncio.Task[SetupResult]"
RunnerTask: "typing.TypeAlias" = Union[PredictionTask, SetupTask]
+# TODO: we might prefer to move this back to worker
+# runner would still need to do PredictionEventHandler
+# if it's not inline, we would need to make sure {enter,exit}_predict is handled correctly
+# this is a major outstanding piece of work for merging into main
+
+
class PredictionRunner:
def __init__(
self,
@@ -60,21 +91,103 @@ def __init__(
predictor_ref: str,
shutdown_event: Optional[threading.Event],
upload_url: Optional[str] = None,
+ concurrency: int = 1,
+ tee_output: bool = True,
) -> None:
- self._response: Optional[schema.PredictionResponse] = None
- self._result: Optional[RunnerTask] = None
+ self._shutdown_event = shutdown_event # __main__ waits for this event
- self._worker = Worker(predictor_ref=predictor_ref)
-
- self._shutdown_event = shutdown_event
self._upload_url = upload_url
+ self._predictions: "dict[str, tuple[schema.PredictionResponse, PredictionTask]]" = {}
+ self._predictions_in_flight: "set[str]" = set()
+ # it would be lovely to merge these but it's not fully clear how best to handle it
+ # since idempotent requests can kinda come whenever?
+ # p: dict[str, PredictionTask]
+ # p: dict[str, PredictionEventHandler]
+ # p: dict[str, schema.PredictionResponse]
self.client_manager = ClientManager()
+ # TODO: perhaps this could go back into worker, if we could get the interface right
+ # (unclear how to do the tests)
+ #
+ self._state = WorkerState.NEW
+ self._semaphore = asyncio.Semaphore(concurrency)
+ self._concurrency = concurrency
+
+ # A pipe with which to communicate with the child worker.
+ events, child_events = _spawn.Pipe()
+ self._child = _ChildWorker(predictor_ref, child_events, tee_output)
+ self._events: "AsyncConnection[tuple[str, PublicEventType]]" = AsyncConnection(
+ events
+ )
+ # shutdown requested
+ self._shutting_down = False
+ # stop reading events
+ self._terminating = asyncio.Event()
+ self._mux = Mux(self._terminating)
+ #
# bind logger instead of the module-level logger proxy for performance
self.log = log.bind()
- def make_error_handler(self, activity: str) -> Callable[[RunnerTask], None]:
+    def activity_info(self) -> "dict[str, int]":
+        """Return the concurrency limit ("max") and the in-flight count ("current")."""
+        in_flight = len(self._predictions_in_flight)
+        return {"max": self._concurrency, "current": in_flight}
+
+ def setup(self) -> SetupTask:
+ if self._state != WorkerState.NEW:
+ raise RunnerBusyError
+ self._state = WorkerState.STARTING
+
+ # app is allowed to respond to requests and poll the state of this task
+ # while it is running
+        async def inner() -> SetupResult:
+            """
+            Start the child worker, consume its "SETUP" events and summarise them.
+
+            Log events are accumulated into the result's logs. A Done event
+            without an error marks setup as SUCCEEDED and moves the runner to
+            IDLE. A Done event with an error, or any other exception, marks
+            setup as FAILED and appends the traceback to the logs. If the
+            event stream ends without a Done event, setup is also FAILED.
+            """
+            logs = []
+            status = None
+            started_at = datetime.now(tz=timezone.utc)
+
+            # in 3.10 Event started doing get_running_loop
+            # previously it stored the loop when created, which causes an error in tests
+            if sys.version_info < (3, 10):
+                self._terminating = self._mux.terminating = asyncio.Event()
+
+            self._child.start()
+            await self._events.async_init()
+            self._start_event_reader()
+
+            try:
+                async for event in self._mux.read("SETUP", poll=0.1):
+                    if isinstance(event, Log):
+                        logs.append(event.message)
+                    elif isinstance(event, Done):
+                        if event.error:
+                            # Caught by the handler just below, which records
+                            # the traceback and marks setup as FAILED.
+                            raise FatalWorkerException(
+                                "Predictor errored during setup: " + event.error_detail
+                            )
+                        status = schema.Status.SUCCEEDED
+                        self._state = WorkerState.IDLE
+            except Exception:
+                logs.append(traceback.format_exc())
+                status = schema.Status.FAILED
+
+            if status is None:
+                logs.append("Error: did not receive 'done' event from setup!")
+                status = schema.Status.FAILED
+
+            completed_at = datetime.now(tz=timezone.utc)
+
+            # Only if setup succeeded, mark the container as "ready".
+            if status == schema.Status.SUCCEEDED:
+                probes = ProbeHelper()
+                probes.ready()
+
+            return SetupResult(
+                started_at=started_at,
+                completed_at=completed_at,
+                logs="".join(logs),
+                status=status,
+            )
+
+
def handle_error(task: RunnerTask) -> None:
exc = task.exception()
if not exc:
@@ -85,34 +198,64 @@ def handle_error(task: RunnerTask) -> None:
try:
raise exc
except Exception:
- self.log.error(f"caught exception while running {activity}", exc_info=True)
+ self.log.error("caught exception while running setup", exc_info=True)
if self._shutdown_event is not None:
self._shutdown_event.set()
- return handle_error
+ result = asyncio.create_task(inner())
+ result.add_done_callback(handle_error)
+ return result
+
+    def state_from_predictions_in_flight(self) -> WorkerState:
+        """
+        Derive the worker state from the number of predictions in flight.
+
+        Returns BUSY when the in-flight count equals the concurrency limit,
+        IDLE when nothing is in flight, and PROCESSING otherwise. Raises
+        InvalidStateException unless the current state is already one of
+        those three.
+        """
+        derivable = (WorkerState.IDLE, WorkerState.PROCESSING, WorkerState.BUSY)
+        if self._state not in derivable:
+            raise InvalidStateException(
+                f"Invalid operation: state is {self._state} (must be IDLE, PROCESSING, or BUSY)"
+            )
+        in_flight = len(self._predictions_in_flight)
+        # the limit check comes first, matching the order callers rely on
+        if in_flight == self._concurrency:
+            return WorkerState.BUSY
+        return WorkerState.IDLE if in_flight == 0 else WorkerState.PROCESSING
- def setup(self) -> SetupTask:
+    def is_busy(self) -> bool:
+        """Return True unless the runner state is IDLE or PROCESSING."""
+        accepting = self._state in (WorkerState.IDLE, WorkerState.PROCESSING)
+        return not accepting
+
+ def enter_predict(self, id: str) -> None:
+ """
+ Register prediction `id` as in flight and recompute the runner state.
+
+ Raises InvalidStateException if the runner is busy or if a shutdown
+ has been requested.
+ """
if self.is_busy():
- raise RunnerBusyError()
- self._result = asyncio.create_task(setup(worker=self._worker))
- self._result.add_done_callback(self.make_error_handler("setup"))
- return self._result
+ raise InvalidStateException(
+ f"Invalid operation: state is {self._state} (must be processing or idle)"
+ )
+ if self._shutting_down:
+ raise InvalidStateException(
+ "cannot accept new predictions because shutdown requested"
+ )
+ self.log.info(
+ "accepted prediction %s in flight %s", id, self._predictions_in_flight
+ )
+ self._predictions_in_flight.add(id)
+ self._state = self.state_from_predictions_in_flight()
+
+ def exit_predict(self, id: str) -> None:
+ """
+ Remove prediction `id` from the in-flight set and recompute the runner state.
+
+ The set's remove() raises KeyError if `id` is not currently in flight.
+ """
+ self._predictions_in_flight.remove(id)
+ self._state = self.state_from_predictions_in_flight()
+
+ @contextlib.contextmanager
+ def prediction_ctx(self, id: str) -> Iterator[None]:
+ """
+ Context manager that calls enter_predict(id) on entry and always calls
+ exit_predict(id) on exit, even if the body raises.
+ """
+ self.enter_predict(id)
+ try:
+ yield
+ finally:
+ self.exit_predict(id)
# TODO: Make the return type AsyncResult[schema.PredictionResponse] when we
# no longer have to support Python 3.8
def predict(
- self, request: schema.PredictionRequest, upload: bool = True
- ) -> Tuple[schema.PredictionResponse, PredictionTask]:
- # It's the caller's responsibility to not call us if we're busy.
+ self, request: schema.PredictionRequest, poll: Optional[float] = None
+ ) -> "tuple[schema.PredictionResponse, PredictionTask]":
if self.is_busy():
- # If self._result is set, but self._response is not, we're still
- # doing setup.
- if self._response is None:
- raise RunnerBusyError()
- assert self._result is not None
- if request.id is not None and request.id == self._response.id: # type: ignore
- result = cast(PredictionTask, self._result)
- return (self._response, result)
+ if request.id in self._predictions:
+ return self._predictions[request.id]
raise RunnerBusyError()
# Set up logger context for main thread. The same thing happens inside
@@ -122,10 +265,18 @@ def predict(
# if upload url was not set, we can respect output_file_prefix
# but maybe we should just throw an error
upload_url = request.output_file_prefix or self._upload_url
- event_handler = PredictionEventHandler(request, self.client_manager, upload_url, self.log)
- self._response = event_handler.response
+ # this is supposed to send START, but we're trapped in a sync function
+ # this sends START in a task, which calls jsonable_encoder on the input,
+ # which calls iter(io.BytesIO) with data uris that are File
+ # that breaks one of the tests, but happens Rarely in production,
+ # so let's ignore it for now
+ event_handler = PredictionEventHandler(
+ request, self.client_manager, upload_url, self.log
+ )
+ response = event_handler.response
prediction_input = PredictionInput.from_request(request)
+ self.enter_predict(request.id)
async def async_predict_handling_errors() -> schema.PredictionResponse:
try:
@@ -137,9 +288,11 @@ async def async_predict_handling_errors() -> schema.PredictionResponse:
if isinstance(v, types.URLTempFile):
real_path = await v.convert(self.client_manager.download_client)
prediction_input.payload[k] = real_path
- event_stream = self._worker.predict(prediction_input.payload, poll=0.1)
- result = await event_handler.handle_event_stream(event_stream)
- return result
+ async with self._semaphore:
+ self._events.send(prediction_input)
+ event_stream = self._mux.read(prediction_input.id, poll=poll)
+ result = await event_handler.handle_event_stream(event_stream)
+ return result
except httpx.HTTPError as e:
tb = traceback.format_exc()
await event_handler.append_logs(tb)
@@ -150,44 +303,110 @@ async def async_predict_handling_errors() -> schema.PredictionResponse:
tb = traceback.format_exc()
await event_handler.append_logs(tb)
await event_handler.failed(error=str(e))
- self.log.error("caught exception while running prediction", exc_info=True)
+ self.log.error(
+ "caught exception while running prediction", exc_info=True
+ )
if self._shutdown_event is not None:
self._shutdown_event.set()
raise # we don't actually want to raise anymore but w/e
finally:
+ # mark the prediction as done and update state
+ # ... actually, we might want to mark that part earlier
+ # even if we're still uploading files we can accept new work
+ self.exit_predict(prediction_input.id)
# FIXME: use isinstance(BaseInput)
if hasattr(request.input, "cleanup"):
request.input.cleanup() # type: ignore
+ # this might also, potentially, be too early
+ # since this is just before this coroutine exits
+ self._predictions.pop(request.id)
# this is still a little silly
- self._result = asyncio.create_task(async_predict_handling_errors())
- self._result.add_done_callback(self.make_error_handler("prediction"))
+ result = asyncio.create_task(async_predict_handling_errors())
+ # result.add_done_callback(self.make_error_handler("prediction"))
# even after inlining we might still need a callback to surface remaining exceptions/results
- return (self._response, self._result)
+ self._predictions[request.id] = (response, result)
- def is_busy(self) -> bool:
- if self._result is None:
- return False
-
- if not self._result.done():
- return True
-
- self._response = None
- self._result = None
- return False
+ return (response, result)
def shutdown(self) -> None:
+ """
+ Request a shutdown of the child worker.
+
+ Does nothing if the runner is already DEFUNCT. Otherwise it sets the
+ shutting-down flag, which makes enter_predict reject new predictions,
+ and sends a Shutdown event to the child if it is still alive.
+ """
- if self._result:
- self._result.cancel()
- self._worker.terminate()
+ if self._state == WorkerState.DEFUNCT:
+ return
+ # shutdown requested, but keep reading events
+ self._shutting_down = True
- def cancel(self, prediction_id: Optional[str] = None) -> None:
- if not self.is_busy():
+ if self._child.is_alive():
+ self._events.send(Shutdown())
+
+ def terminate(self) -> None:
+ """
+ Stop the runner: cancel every tracked prediction task, then, unless the
+ runner is already DEFUNCT, set the terminating event, mark the state
+ DEFUNCT, stop the child process, close the event connection and cancel
+ the event-reader task.
+ """
+ for _, task in self._predictions.values():
+ task.cancel()
+ if self._state == WorkerState.DEFUNCT:
+ return
- assert self._response is not None
- if prediction_id is not None and prediction_id != self._response.id:
+
+ self._terminating.set()
+ self._state = WorkerState.DEFUNCT
+
+ if self._child.is_alive():
+ self._child.terminate()
+ self._child.join()
+ self._events.close()
+
+ if self._read_events_task:
+ self._read_events_task.cancel()
+
+ def cancel(self, prediction_id: str) -> None:
+ """
+ Ask the child worker to cancel the in-flight prediction `prediction_id`.
+
+ Raises UnknownPredictionError if the id is not currently in flight.
+ If the COG_DISABLE_CANCEL environment variable is set to a non-empty
+ value, a warning is logged and nothing is cancelled.
+ """
+ if prediction_id not in self._predictions_in_flight:
+ self.log.warn(
+ "can't cancel %s (%s)", prediction_id, self._predictions_in_flight
+ )
raise UnknownPredictionError()
- self._worker.cancel()
+ if os.getenv("COG_DISABLE_CANCEL"):
+ self.log.warn("cancelling is disabled for this model")
+ return
+ maybe_pid = self._child.pid
+ if self._child.is_alive() and maybe_pid is not None:
+ # since we don't know if the predictor is sync or async, we both send
+ # the signal (honored only if sync) and the event (honored only if async)
+ os.kill(maybe_pid, signal.SIGUSR1)
+ self.log.info("sent cancel")
+ self._events.send(Cancel(prediction_id))
+ # maybe this should probably check self._semaphore._value == self._concurrent
+
+ _read_events_task: "Optional[asyncio.Task[None]]" = None
+
+    def _start_event_reader(self) -> None:
+        """Create the background task that runs _read_events, unless one already exists."""
+        if self._read_events_task:
+            return
+
+        def report_failure(finished: "asyncio.Task[None]") -> None:
+            # exception() raises on a cancelled task, so check for that first
+            if finished.cancelled():
+                return
+            error = finished.exception()
+            if error:
+                logging.error("caught exception", exc_info=error)
+
+        reader = asyncio.create_task(self._read_events())
+        reader.add_done_callback(report_failure)
+        self._read_events_task = reader
+
+ async def _read_events(self) -> None:
+ """
+ Forward (id, event) pairs received from the child worker to the mux
+ until the child dies or termination is requested. If the child died
+ without termination being requested, record a FatalWorkerException on
+ the mux and wake up its readers.
+ """
+ while self._child.is_alive() and not self._terminating.is_set():
+ # in tests this can still be running when the task is destroyed
+ result = await self._events.recv()
+ id, event = result
+ # events tagged "LOG" during setup are routed to the "SETUP" stream
+ if id == "LOG" and self._state == WorkerState.STARTING:
+ id = "SETUP"
+ # with exactly one prediction in flight, "LOG" events are routed to it
+ if id == "LOG" and len(self._predictions_in_flight) == 1:
+ id = list(self._predictions_in_flight)[0]
+ await self._mux.write(id, event)
+ # If we dropped off the end of the loop, check if it's
+ # because the child process died.
+ if not self._child.is_alive() and not self._terminating.is_set():
+ exitcode = self._child.exitcode
+ self._mux.fatal = FatalWorkerException(
+ f"Prediction failed for an unknown reason. It might have run out of memory? (exitcode {exitcode})"
+ )
+ # this is the same event as self._terminating
+ # we need to set it so mux.reads wake up and throw an error if needed
+ self._mux.terminating.set()
class PredictionEventHandler:
@@ -219,8 +438,7 @@ def __init__(
# latency (this guarantees that the first output webhook won't be
# throttled.)
if not SKIP_START_EVENT:
- # idk
- # this is pretty wrong
+ # sending it in a coroutine is kind of wrong in some ways
asyncio.create_task(self._send_webhook(schema.WebhookEvent.START))
@property
@@ -309,6 +527,7 @@ def event_to_handle_future(self, event: PublicEventType) -> Awaitable[None]:
return self.noop()
if isinstance(event, Log):
return self.append_logs(event.message)
+
if isinstance(event, PredictionOutputType):
if self._output_type is not None:
return self.failed(error="Predictor returned unexpected output")
@@ -328,41 +547,5 @@ def event_to_handle_future(self, event: PublicEventType) -> Awaitable[None]:
if event.error:
return self.failed(error=str(event.error_detail))
return self.succeeded()
- log.warn("received unexpected event from worker", data=event)
+ self.logger.warn("received unexpected event from worker", data=event)
return self.noop()
-
-
-async def setup(*, worker: Worker) -> SetupResult:
- logs = []
- status = None
- started_at = datetime.now(tz=timezone.utc)
-
- try:
- async for event in worker.setup():
- if isinstance(event, Log):
- logs.append(event.message)
- elif isinstance(event, Done):
- status = (
- schema.Status.FAILED if event.error else schema.Status.SUCCEEDED
- )
- except Exception:
- logs.append(traceback.format_exc())
- status = schema.Status.FAILED
-
- if status is None:
- logs.append("Error: did not receive 'done' event from setup!")
- status = schema.Status.FAILED
-
- completed_at = datetime.now(tz=timezone.utc)
-
- # Only if setup succeeded, mark the container as "ready".
- if status == schema.Status.SUCCEEDED:
- probes = ProbeHelper()
- probes.ready()
-
- return SetupResult(
- started_at=started_at,
- completed_at=completed_at,
- logs="".join(logs),
- status=status,
- )
diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py
index 941690c1d..1beea2a6e 100644
--- a/python/cog/server/worker.py
+++ b/python/cog/server/worker.py
@@ -1,17 +1,16 @@
import asyncio
import contextlib
import inspect
-import logging
import multiprocessing
-import os
import signal
import sys
import traceback
import types
from collections import defaultdict
+from contextvars import ContextVar
from enum import Enum, auto, unique
from multiprocessing.connection import Connection
-from typing import Any, AsyncIterator, Callable, Dict, Iterator, Optional, TextIO, Union
+from typing import Any, AsyncIterator, Callable, Iterator, Optional, TextIO
from ..json import make_encodeable
from ..predictor import (
@@ -23,25 +22,24 @@
)
from .connection import AsyncConnection
from .eventtypes import (
+ Cancel,
Done,
Heartbeat,
Log,
PredictionInput,
PredictionOutput,
PredictionOutputType,
+ PublicEventType,
Shutdown,
)
from .exceptions import (
CancelationException,
FatalWorkerException,
- InvalidStateException,
)
from .helpers import StreamRedirector, WrappedStream
_spawn = multiprocessing.get_context("spawn")
-_PublicEventType = Union[Done, Heartbeat, Log, PredictionOutput, PredictionOutputType]
-
@unique
class WorkerState(Enum):
@@ -55,18 +53,18 @@ class WorkerState(Enum):
class Mux:
def __init__(self, terminating: asyncio.Event) -> None:
- self.outs: "defaultdict[str, asyncio.Queue[_PublicEventType]]" = defaultdict(
+ self.outs: "defaultdict[str, asyncio.Queue[PublicEventType]]" = defaultdict(
asyncio.Queue
)
self.terminating = terminating
self.fatal: "Optional[FatalWorkerException]" = None
- async def write(self, id: str, item: _PublicEventType) -> None:
+ async def write(self, id: str, item: PublicEventType) -> None:
await self.outs[id].put(item)
async def read(
self, id: str, poll: Optional[float] = None
- ) -> AsyncIterator[_PublicEventType]:
+ ) -> AsyncIterator[PublicEventType]:
if poll:
send_heartbeats = True
else:
@@ -87,187 +85,6 @@ async def read(
raise self.fatal
-class Worker:
- def __init__(
- self, predictor_ref: str, tee_output: bool = True, concurrent: int = 1
- ) -> None:
- self._state = WorkerState.NEW
- self._allow_cancel = False
- self._semaphore = asyncio.Semaphore(concurrent)
- self._concurrent = concurrent
-
- # A pipe with which to communicate with the child worker.
- events, child_events = _spawn.Pipe()
- self._child = _ChildWorker(predictor_ref, child_events, tee_output)
- self._events: "AsyncConnection[tuple[str, _PublicEventType]]" = AsyncConnection(
- events
- )
- # shutdown requested
- self._shutting_down = False
- # stop reading events
- self._terminating = asyncio.Event()
- self._mux = Mux(self._terminating)
- self._predictions_in_flight = set()
-
- def setup(self) -> AsyncIterator[_PublicEventType]:
- self._assert_state(WorkerState.NEW)
- self._state = WorkerState.STARTING
-
- async def inner() -> AsyncIterator[_PublicEventType]:
- # in 3.10 Event started doing get_running_loop
- # previously it stored the loop when created, which causes an error in tests
- if sys.version_info < (3, 10):
- self._terminating = self._mux.terminating = asyncio.Event()
-
- self._child.start()
- self._ensure_event_reader()
- async for event in self._mux.read("SETUP", poll=0.1):
- yield event
- if isinstance(event, Done):
- if event.error:
- raise FatalWorkerException(
- "Predictor errored during setup: " + event.error_detail
- )
- self._state = WorkerState.IDLE
-
- return inner()
-
- def state_from_predictions_in_flight(self) -> WorkerState:
- valid_states = {WorkerState.IDLE, WorkerState.PROCESSING, WorkerState.BUSY}
- if self._state not in valid_states:
- raise InvalidStateException(
- f"Invalid operation: state is {self._state} (must be IDLE, PROCESSING, or BUSY)"
- )
- # changing _allow_cancel like this is a little bit weird
- # because this is kind of a pure function
- # however, we'll remove all of this cancel logic soon, dwabi
- if len(self._predictions_in_flight) == self._concurrent:
- self._allow_cancel = True
- return WorkerState.BUSY
- if len(self._predictions_in_flight) == 0:
- self._allow_cancel = False
- return WorkerState.IDLE
- self._allow_cancel = True
- return WorkerState.PROCESSING
-
- def is_busy(self) -> bool:
- return self._state not in {WorkerState.PROCESSING, WorkerState.IDLE}
-
- @contextlib.asynccontextmanager
- async def _prediction_ctx(self, input: PredictionInput) -> AsyncIterator[None]:
- async with self._semaphore:
- self._predictions_in_flight.add(input.id) # idempotent ig
- self._state = self.state_from_predictions_in_flight()
- try:
- yield
- finally:
- self._predictions_in_flight.remove(input.id)
- self._state = self.state_from_predictions_in_flight()
-
- def predict(
- self, payload: Dict[str, Any], poll: Optional[float] = None
- ) -> AsyncIterator[_PublicEventType]:
- # this has to be eager for hypothesis
- if self.is_busy():
- raise InvalidStateException(
- f"Invalid operation: state is {self._state} (must be processing or idle)"
- )
- if self._shutting_down:
- raise InvalidStateException(
- "cannot accept new predictions because shutdown requested"
- )
- input = PredictionInput(payload=payload)
- self._predictions_in_flight.add(input.id) # idempotent ig
- self._state = self.state_from_predictions_in_flight()
-
- async def inner() -> AsyncIterator[_PublicEventType]:
- async with self._prediction_ctx(input):
- self._events.send(input)
- async for event in self._mux.read(input.id, poll=poll):
- yield event
-
- return inner()
-
- def shutdown(self) -> None:
- if self._state == WorkerState.DEFUNCT:
- return
- # shutdown requested, but keep reading events
- self._shutting_down = True
-
- if self._child.is_alive():
- self._events.send(Shutdown())
-
- def terminate(self) -> None:
- if self._state == WorkerState.DEFUNCT:
- return
-
- self._terminating.set()
- self._state = WorkerState.DEFUNCT
-
- if self._child.is_alive():
- self._child.terminate()
- self._child.join()
- if self._read_events_task:
- self._read_events_task.cancel()
- self._events.close()
-
- # FIXME: this will need to use a combination
- # of signals and Cancel events on the pipe
- def cancel(self) -> None:
- if (
- self._allow_cancel
- and self._child.is_alive()
- and self._child.pid is not None
- ):
- os.kill(self._child.pid, signal.SIGUSR1)
- # this should probably check self._semaphore._value == self._concurrent
- self._allow_cancel = False
-
- def _assert_state(self, state: WorkerState) -> None:
- if self._state != state:
- raise InvalidStateException(
- f"Invalid operation: state is {self._state} (must be {state})"
- )
-
- _read_events_task: "Optional[asyncio.Task[None]]" = None
-
- def _ensure_event_reader(self) -> None:
- def handle_error(task: "asyncio.Task[None]") -> None:
- if task.cancelled():
- return
- exc = task.exception()
- if exc:
- logging.error("caught exception", exc_info=exc)
-
- if not self._read_events_task:
- self._read_events_task = asyncio.create_task(self._read_events())
- self._read_events_task.add_done_callback(handle_error)
-
- async def _read_events(self) -> None:
- while self._child.is_alive() and not self._terminating.is_set():
- # this can still be running when the task is destroyed
- result = await self._events.recv() # this might be kind of risky
- # event loop closed or child died
- if result is None: # type: ignore
- break
- id, event = result
- if id == "LOG" and self._state == WorkerState.STARTING:
- id = "SETUP"
- if id == "LOG" and len(self._predictions_in_flight) == 1:
- id = list(self._predictions_in_flight)[0]
- await self._mux.write(id, event)
- # If we dropped off the end off the end of the loop, check if it's
- # because the child process died.
- if not self._child.is_alive() and not self._terminating.is_set():
- exitcode = self._child.exitcode
- self._mux.fatal = FatalWorkerException(
- f"Prediction failed for an unknown reason. It might have run out of memory? (exitcode {exitcode})"
- )
- # this is the same event as _terminating
- # we need to set it so mux.reads wake up and throw an error if needed
- self._mux.terminating.set()
-
-
class _ChildWorker(_spawn.Process): # type: ignore
def __init__(
self,
@@ -292,6 +109,9 @@ def run(self) -> None:
# We use SIGUSR1 to signal an interrupt for cancelation.
signal.signal(signal.SIGUSR1, self._signal_handler)
+ self.prediction_id_context: ContextVar[str] = ContextVar("prediction_context")
+
+ #
ws_stdout = WrappedStream("stdout", sys.stdout)
ws_stderr = WrappedStream("stderr", sys.stderr)
ws_stdout.wrap()
@@ -303,6 +123,8 @@ def run(self) -> None:
[ws_stdout, ws_stderr], self._stream_write_hook
)
self._stream_redirector.start()
+ #
+
self._setup()
self._loop()
self._stream_redirector.shutdown()
@@ -312,6 +134,7 @@ def _setup(self) -> None:
with self._handle_setup_error():
# we need to load the predictor to know if setup is async
self._predictor = load_predictor_from_ref(self._predictor_ref)
+ self._predictor.log = self._log
# if users want to access the same event loop from setup and predict,
# both have to be async. if setup isn't async, it doesn't matter if we
# create the event loop here or after setup
@@ -355,26 +178,38 @@ def _loop_sync(self) -> None:
break
if isinstance(ev, PredictionInput):
self._predict_sync(ev)
+ elif isinstance(ev, Cancel):
+ # in sync mode, Cancel events are ignored
+ # only signals are respected
+ pass
else:
print(f"Got unexpected event: {ev}", file=sys.stderr)
async def _loop_async(self) -> None:
- events: "AsyncConnection[tuple[str, _PublicEventType]]" = AsyncConnection(
+ events: "AsyncConnection[tuple[str, PublicEventType]]" = AsyncConnection(
self._events
)
- while True:
- try:
- ev = await events.recv()
- except asyncio.CancelledError:
- return
- if isinstance(ev, Shutdown):
- return
- if isinstance(ev, PredictionInput):
- # keep track of these so they can be cancelled
- await self._predict_async(ev)
- # handle Cancel
- else:
- print(f"Got unexpected event: {ev}", file=sys.stderr)
+ with events:
+ tasks: "dict[str, asyncio.Task[None]]" = {}
+ while True:
+ try:
+ ev = await events.recv()
+ except asyncio.CancelledError:
+ return
+ if isinstance(ev, Shutdown):
+ return
+ if isinstance(ev, PredictionInput):
+ # run each prediction as its own task, keyed by prediction id, so a later Cancel event can cancel it
+ tasks[ev.id] = asyncio.create_task(self._predict_async(ev))
+ elif isinstance(ev, Cancel):
+ # in async mode, cancel signals (SIGUSR1) are ignored;
+ # only Cancel events are respected
+ if ev.id in tasks:
+ tasks[ev.id].cancel()
+ else:
+ print(f"Got unexpected cancellation: {ev}", file=sys.stderr)
+ else:
+ print(f"Got unexpected event: {ev}", file=sys.stderr)
def _loop(self) -> None:
if is_async(get_predict(self._predictor)):
@@ -387,21 +222,25 @@ def _handle_predict_error(self, id: str) -> Iterator[None]:
assert self._predictor
done = Done()
self._cancelable = True
+ token = self.prediction_id_context.set(id)
try:
yield
except CancelationException:
done.canceled = True
+ except asyncio.CancelledError:
+ done.canceled = True
except Exception as e:
traceback.print_exc()
done.error = True
done.error_detail = str(e)
finally:
+ self.prediction_id_context.reset(token)
self._cancelable = False
self._stream_redirector.drain()
self._events.send((id, done))
- def _mk_send(self, id: str) -> Callable[[_PublicEventType], None]:
- def send(event: _PublicEventType) -> None:
+ def _mk_send(self, id: str) -> Callable[[PublicEventType], None]:
+ def send(event: PublicEventType) -> None:
self._events.send((id, event))
return send
@@ -417,9 +256,9 @@ async def _predict_async(self, input: PredictionInput) -> None:
async for r in result:
send(PredictionOutput(payload=make_encodeable(r)))
elif inspect.isawaitable(result):
- result = await result
+ output = await result
send(PredictionOutputType(multi=False))
- send(PredictionOutput(payload=make_encodeable(result)))
+ send(PredictionOutput(payload=make_encodeable(output)))
def _predict_sync(self, input: PredictionInput) -> None:
with self._handle_predict_error(input.id):
@@ -436,16 +275,26 @@ def _predict_sync(self, input: PredictionInput) -> None:
send(PredictionOutput(payload=make_encodeable(result)))
def _signal_handler(self, signum: int, frame: Optional[types.FrameType]) -> None:
+ if self._predictor and is_async(get_predict(self._predictor)):
+ # async predictions are cancelled via Cancel events in _loop_async, so signals are
+ # ignored in async mode for now (we could also try canceling the async task here)
+ return
+ # sync mode: SIGUSR1 only interrupts while a prediction is cancelable; may need refining
if signum == signal.SIGUSR1 and self._cancelable:
raise CancelationException()
+ def _log(self, *messages: str, source: str = "stderr") -> None:
+ id = self.prediction_id_context.get("LOG")
+ self._events.send((id, Log(" ".join(messages), source=source)))
+
def _stream_write_hook(
self, stream_name: str, original_stream: TextIO, data: str
) -> None:
if self._tee_output:
original_stream.write(data)
original_stream.flush()
- self._events.send(("LOG", Log(data, source=stream_name)))
+ # FIXME: per-prediction ids won't work here: this fn gets called from a thread, not the async task
+ self._log(data, source=stream_name)
def get_loop() -> asyncio.AbstractEventLoop:
diff --git a/python/tests/server/fixtures/async_sleep.py b/python/tests/server/fixtures/async_sleep.py
new file mode 100644
index 000000000..7c5f73407
--- /dev/null
+++ b/python/tests/server/fixtures/async_sleep.py
@@ -0,0 +1,9 @@
+import asyncio
+
+from cog import BasePredictor
+
+
+class Predictor(BasePredictor):
+ async def predict(self, sleep: float = 0) -> str:
+ await asyncio.sleep(sleep)
+ return f"done in {sleep} seconds"
diff --git a/python/tests/server/test_http_output.py b/python/tests/server/test_http_output.py
index bb19de66c..ad7bba586 100644
--- a/python/tests/server/test_http_output.py
+++ b/python/tests/server/test_http_output.py
@@ -7,11 +7,11 @@
from .conftest import uses_predictor, uses_predictor_with_client_options
-
-@uses_predictor("output_wrong_type")
-def test_return_wrong_type(client):
- resp = client.post("/predictions")
- assert resp.status_code == 500
+# it's not the worst idea to validate outputs but it's slow and not required
+# @uses_predictor("output_wrong_type")
+# def test_return_wrong_type(client):
+# resp = client.post("/predictions")
+# assert resp.status_code == 500
@uses_predictor("output_file")
diff --git a/python/tests/server/test_runner.py b/python/tests/server/test_runner.py
index b8abf7b22..146cc8802 100644
--- a/python/tests/server/test_runner.py
+++ b/python/tests/server/test_runner.py
@@ -1,9 +1,9 @@
import asyncio
import os
import threading
+import time
from datetime import datetime
from unittest import mock
-
import pytest
import pytest_asyncio
from cog.schema import PredictionRequest, PredictionResponse, Status, WebhookEvent
@@ -23,6 +23,11 @@
)
+# TODO
+# - setup logs
+# - file inputs being converted
+
+
def _fixture_path(name):
test_dir = os.path.dirname(os.path.realpath(__file__))
return os.path.join(test_dir, f"fixtures/{name}.py") + ":Predictor"
@@ -69,6 +74,31 @@ async def test_prediction_runner(runner):
assert isinstance(response.completed_at, datetime)
+@pytest.mark.asyncio
+async def test_prediction_runner_async():
+ "verify that predictions are not run back to back"
+ runner = PredictionRunner(
+ predictor_ref=_fixture_path("async_sleep"), shutdown_event=None, concurrency=10
+ )
+ await runner.setup()
+ results = []
+ st = time.time()
+ for i in range(10):
+ _, result = runner.predict(PredictionRequest(input={"sleep": 0.1}))
+ results.append(result)
+ with pytest.raises(RunnerBusyError):
+ runner.predict(PredictionRequest(input={"sleep": 0.1}))
+ responses = await asyncio.gather(*results)
+ assert time.time() - st < 0.5
+ for response in responses:
+ assert response.output == "done in 0.1 seconds"
+ assert response.status == "succeeded"
+ assert response.error is None
+ assert response.logs == ""
+ assert isinstance(response.started_at, datetime)
+ assert isinstance(response.completed_at, datetime)
+
+
@pytest.mark.asyncio
async def test_prediction_runner_called_while_busy(runner):
request = PredictionRequest(input={"sleep": 1})
diff --git a/python/tests/server/test_worker.py b/python/tests/server/test_worker.py
index f77e6f06a..f0d04f485 100644
--- a/python/tests/server/test_worker.py
+++ b/python/tests/server/test_worker.py
@@ -12,6 +12,7 @@
Done,
Heartbeat,
Log,
+ PredictionInput,
PredictionOutput,
PredictionOutputType,
)
@@ -301,19 +302,24 @@ async def test_cancel_is_safe():
try:
for _ in range(50):
- w.cancel()
+ with pytest.raises(KeyError):
+ w.cancel("1")
await _process(w.setup())
for _ in range(50):
- w.cancel()
+ with pytest.raises(KeyError):
+ w.cancel("1")
- result1 = await _process(w.predict({"sleep": 0.5}))
+ input1 = PredictionInput({"sleep": 0.5})
+ result1 = await _process(w.predict(input1))
for _ in range(50):
- w.cancel()
+ with pytest.raises(KeyError):
+ w.cancel(input1.id)
- result2 = await _process(w.predict({"sleep": 0.1}))
+ input2 = {"sleep": 0.1}
+ result2 = await _process(w.predict(input2))
assert not result1.done.canceled
assert not result2.done.canceled
@@ -335,21 +341,22 @@ async def test_cancel_idempotency():
await _process(w.setup())
p1_done = None
+ input1 = PredictionInput({"sleep": 0.5})
- async for event in w.predict({"sleep": 0.5}, poll=0.01):
+ async for event in w.predict(input1, poll=0.01):
# We call cancel a WHOLE BUNCH to make sure that we don't propagate
# any of those cancelations to subsequent predictions, regardless
# of the internal implementation of exceptions raised inside signal
# handlers.
for _ in range(100):
- w.cancel()
+ w.cancel(input1.id)
if isinstance(event, Done):
p1_done = event
assert p1_done.canceled
- result2 = await _process(w.predict({"sleep": 0.1}))
+ result2 = await _process(w.predict(PredictionInput({"sleep": 0.1})))
assert result2.done and not result2.done.canceled
assert result2.output == "done in 0.1 seconds"
@@ -374,10 +381,11 @@ async def test_cancel_multiple_predictions():
for _ in range(5):
canceled = False
+ input = PredictionInput({"sleep": 0.5})
- async for event in w.predict({"sleep": 0.5}, poll=0.01):
+ async for event in w.predict(input, poll=0.01):
if not canceled:
- w.cancel()
+ w.cancel(input.id)
canceled = True
if isinstance(event, Done):
@@ -423,12 +431,13 @@ async def test_heartbeats_cancel():
start = time.time()
canceled = False
- async for event in w.predict({"sleep": 10}, poll=0.1):
+ input = PredictionInput({"sleep": 10})
+ async for event in w.predict(input, poll=0.1):
if isinstance(event, Heartbeat):
heartbeat_count += 1
if time.time() - start > 0.5:
if not canceled:
- w.cancel()
+ w.cancel(input.id)
canceled = True
elapsed = time.time() - start
@@ -530,8 +539,10 @@ def _check_setup_events(self):
def predict(self, name, steps):
try:
payload = {"name": name, "steps": steps}
- self.predict_generator = self.worker.predict(payload)
- self.predict_payload = payload
+ input = PredictionInput(payload)
+ self.worker.enter_predict(input.id)
+ self.predict_generator = self.worker.predict(input)
+ self.predict_payload = input
self.predict_events = []
except InvalidStateException:
pass
@@ -548,9 +559,10 @@ def read_predict_events(self, n):
self._check_predict_events()
def _check_predict_events(self):
+ self.worker.exit_predict(self.predict_payload.id)
assert isinstance(self.predict_events[-1], Done)
- payload = self.predict_payload
+ payload = self.predict_payload.payload
result = _sync_process(self.predict_events)
expected_stdout = ["START\n"]
@@ -568,7 +580,7 @@ def cancel(self, r):
if isinstance(r, InvalidStateException):
return
- self.worker.cancel()
+ self.worker.cancel(self.predict_payload.id)
result = self.await_(_process(r))
# We'd love to be able to assert result.done.canceled here, but we
diff --git a/test-integration/test_integration/test_run.py b/test-integration/test_integration/test_run.py
index ce35f805c..812b3acca 100644
--- a/test-integration/test_integration/test_run.py
+++ b/test-integration/test_integration/test_run.py
@@ -6,7 +6,7 @@ def test_run(tmpdir_factory):
with open(tmpdir / "cog.yaml", "w") as f:
cog_yaml = """
build:
- python_version: "3.8"
+ python_version: "3.9"
"""
f.write(cog_yaml)
@@ -24,7 +24,7 @@ def test_run_with_secret(tmpdir_factory):
with open(tmpdir / "cog.yaml", "w") as f:
cog_yaml = """
build:
- python_version: "3.8"
+ python_version: "3.9"
run:
- echo hello world
- command: >-
diff --git a/tools/compatgen/internal/torch.go b/tools/compatgen/internal/torch.go
index 6e236a3b0..519ed7c93 100644
--- a/tools/compatgen/internal/torch.go
+++ b/tools/compatgen/internal/torch.go
@@ -208,7 +208,7 @@ func parseTorchInstallString(s string, defaultVersions map[string]string, cuda *
torchaudio := libVersions["torchaudio"]
// TODO: this could be determined from https://download.pytorch.org/whl/torch/
- pythons := []string{"3.7", "3.8", "3.9", "3.10", "3.11"}
+ pythons := []string{"3.8", "3.9", "3.10", "3.11"}
return &config.TorchCompatibility{
Torch: torch,