Skip to content

Commit

Permalink
fix: add handling of early setup exceptions on "train" path in cog predict/build (#1490)
Browse files Browse the repository at this point in the history

* fix: add handling of early setup exceptions on "train" path in cog predict


Signed-off-by: Dmitri Khokhlov <dkhokhlov@gmail.com>

* factor out add_setup_failed_routes

Signed-off-by: technillogue <technillogue@gmail.com>

---------

Signed-off-by: Dmitri Khokhlov <dkhokhlov@gmail.com>
Signed-off-by: technillogue <technillogue@gmail.com>
Co-authored-by: technillogue <technillogue@gmail.com>
  • Loading branch information
dkhokhlov and technillogue authored Jan 19, 2024
1 parent 2e57549 commit 4f2d690
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 41 deletions.
5 changes: 4 additions & 1 deletion python/cog/command/openapi_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
"""
import json

from ..errors import ConfigDoesNotExist, PredictorNotSet
from ..errors import CogError, ConfigDoesNotExist, PredictorNotSet
from ..predictor import load_config
from ..schema import Status
from ..server.http import create_app
from ..suppress_output import suppress_output

Expand All @@ -16,6 +17,8 @@
with suppress_output():
config = load_config()
app = create_app(config, shutdown_event=None)
if app.state.setup_result and app.state.setup_result.status == Status.FAILED:
raise CogError(app.state.setup_result.logs)
schema = app.openapi()
except (ConfigDoesNotExist, PredictorNotSet):
# If there is no cog.yaml or 'predict' has not been set, then there is no type signature.
Expand Down
86 changes: 50 additions & 36 deletions python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,31 @@ class MyState:
setup_task: Optional[SetupTask]
setup_result: Optional[SetupResult]


class MyFastAPI(FastAPI):
    """FastAPI application whose ``state`` attribute is annotated as ``MyState``."""

    # TODO: redeclaring ``state`` with a different type is not, strictly
    # speaking, legal (https://github.com/microsoft/pyright/issues/5933),
    # but fixing it properly would need a patch to FastAPI itself.
    state: MyState  # type: ignore


def add_setup_failed_routes(app: MyFastAPI, started_at: datetime, msg: str) -> None:
    """Record a failed setup on ``app`` and expose it via ``/health-check``.

    Prints ``msg``, stores a ``SetupResult`` with status ``FAILED`` on
    ``app.state.setup_result`` (spanning ``started_at`` to the current UTC
    time, with ``msg`` as its logs), sets ``app.state.health`` to
    ``Health.SETUP_FAILED`` and registers a GET ``/health-check`` route that
    reports both values.

    Args:
        app: Application whose state and routes are updated.
        started_at: Timestamp recorded as the start of the failed setup.
        msg: Failure message; printed and stored as the setup logs.
    """
    print(msg)
    finished_at = datetime.now(tz=timezone.utc)
    app.state.setup_result = SetupResult(
        started_at=started_at,
        completed_at=finished_at,
        logs=msg,
        status=schema.Status.FAILED,
    )
    app.state.health = Health.SETUP_FAILED

    @app.get("/health-check")
    async def healthcheck_startup_failed() -> Any:
        # State is read at request time, not captured when the route is added.
        setup_dump = attrs.asdict(app.state.setup_result)
        return jsonable_encoder(
            {"status": app.state.health.name, "setup": setup_dump}
        )


def create_app(
config: Dict[str, Any],
shutdown_event: Optional[threading.Event],
Expand All @@ -96,37 +115,23 @@ def create_app(
app.state.setup_result = None
started_at = datetime.now(tz=timezone.utc)

predictor_ref = get_predictor_ref(config, mode)
# shutdown is needed no matter what happens
@app.post("/shutdown")
async def start_shutdown() -> Any:
log.info("shutdown requested via http")
if shutdown_event is not None:
shutdown_event.set()
return JSONResponse({}, status_code=200)

try:
predictor_ref = get_predictor_ref(config, mode)
# TODO: avoid loading predictor code in this process
predictor = load_predictor_from_ref(predictor_ref)
InputType = get_input_type(predictor)
OutputType = get_output_type(predictor)
except Exception:
app.state.health = Health.SETUP_FAILED
msg = "Error while loading predictor:\n\n" + traceback.format_exc()
print(msg)
result = SetupResult(
started_at=started_at,
completed_at=datetime.now(tz=timezone.utc),
logs=msg,
status=schema.Status.FAILED,
)
app.state.setup_result = result

@app.get("/health-check")
async def healthcheck_startup_failed() -> Any:
setup = attrs.asdict(app.state.setup_result)
return jsonable_encoder({"status": app.state.health.name, "setup": setup})

@app.post("/shutdown")
async def start_shutdown_startup_failed() -> Any:
log.info("shutdown requested via http")
if shutdown_event is not None:
shutdown_event.set()
return JSONResponse({}, status_code=200)

add_setup_failed_routes(app, started_at, msg)
return app

runner = PredictionRunner(
Expand All @@ -137,6 +142,7 @@ async def start_shutdown_startup_failed() -> Any:

class PredictionRequest(schema.PredictionRequest.with_types(input_type=InputType)):
pass

PredictionResponse = schema.PredictionResponse.with_types(
input_type=InputType, output_type=OutputType
)
Expand All @@ -146,6 +152,7 @@ class PredictionRequest(schema.PredictionRequest.with_types(input_type=InputType
if TYPE_CHECKING:
P = ParamSpec("P")
T = TypeVar("T")

def limited(f: "Callable[P, Awaitable[T]]") -> "Callable[P, Awaitable[T]]":
@functools.wraps(f)
async def wrapped(*args: "P.args", **kwargs: "P.kwargs") -> "T":
Expand All @@ -155,11 +162,16 @@ async def wrapped(*args: "P.args", **kwargs: "P.kwargs") -> "T":
return wrapped

if "train" in config:
# TODO: avoid loading trainer code in this process
trainer = load_predictor_from_ref(config["train"])

TrainingInputType = get_training_input_type(trainer)
TrainingOutputType = get_training_output_type(trainer)
try:
# TODO: avoid loading trainer code in this process
trainer = load_predictor_from_ref(config["train"])
TrainingInputType = get_training_input_type(trainer)
TrainingOutputType = get_training_output_type(trainer)
except Exception:
app.state.health = Health.SETUP_FAILED
msg = "Error while loading trainer:\n\n" + traceback.format_exc()
add_setup_failed_routes(app, started_at, msg)
return app

class TrainingRequest(
schema.TrainingRequest.with_types(input_type=TrainingInputType)
Expand Down Expand Up @@ -196,7 +208,16 @@ def cancel_training(training_id: str = Path(..., title="Training ID")) -> Any:

@app.on_event("startup")
def startup() -> None:
app.state.setup_task = runner.setup()
# check for early setup failures
if (
app.state.setup_result
and app.state.setup_result.status == schema.Status.FAILED
):
if not args.await_explicit_shutdown: # signal shutdown if interactive run
if shutdown_event is not None:
shutdown_event.set()
else:
app.state.setup_task = runner.setup()

@app.on_event("shutdown")
def shutdown() -> None:
Expand Down Expand Up @@ -333,13 +354,6 @@ async def cancel(prediction_id: str = Path(..., title="Prediction ID")) -> Any:
else:
return JSONResponse({}, status_code=200)

@app.post("/shutdown")
async def start_shutdown() -> Any:
log.info("shutdown requested via http")
if shutdown_event is not None:
shutdown_event.set()
return JSONResponse({}, status_code=200)

def _check_setup_result() -> Any:
if app.state.setup_task is None:
return
Expand Down
4 changes: 2 additions & 2 deletions test-integration/test_integration/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os

import pytest

from .util import random_string
from .util import remove_docker_image
from .util import random_string, remove_docker_image


def pytest_sessionstart(session):
Expand Down
2 changes: 1 addition & 1 deletion test-integration/test_integration/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_build_without_predictor(docker_image):
capture_output=True,
)
assert build_process.returncode > 0
assert "Model schema is invalid" in build_process.stderr.decode()
assert "Can't run predictions: 'predict' option not found" in build_process.stderr.decode()


def test_build_names_uses_image_option_in_cog_yaml(tmpdir, docker_image):
Expand Down
3 changes: 2 additions & 1 deletion test-integration/test_integration/util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import random
import string
import time
import subprocess
import time


def random_string(length):
return "".join(random.choice(string.ascii_lowercase) for i in range(length))
Expand Down

0 comments on commit 4f2d690

Please sign in to comment.