Skip to content

Commit

Permalink
fix: ignore "train" part of schema errors during predict runs in production for backward compatibility
Browse files Browse the repository at this point in the history

Signed-off-by: Dmitri Khokhlov <dkhokhlov@gmail.com>
  • Loading branch information
dkhokhlov committed Jan 24, 2024
1 parent 4f2d690 commit d60370e
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 38 deletions.
2 changes: 1 addition & 1 deletion python/cog/command/openapi_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
try:
with suppress_output():
config = load_config()
app = create_app(config, shutdown_event=None)
app = create_app(config, shutdown_event=None, is_build=True)
if app.state.setup_result and app.state.setup_result.status == Status.FAILED:
raise CogError(app.state.setup_result.logs)
schema = app.openapi()
Expand Down
83 changes: 46 additions & 37 deletions python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from pydantic.error_wrappers import ErrorWrapper

from .. import schema
from ..errors import PredictorNotSet
from ..files import upload_file
from ..json import upload_files
from ..logging import setup_logging
Expand Down Expand Up @@ -104,6 +105,7 @@ def create_app(
threads: int = 1,
upload_url: Optional[str] = None,
mode: str = "predict",
is_build: bool = False
) -> MyFastAPI:
app = MyFastAPI(
title="Cog", # TODO: mention model name?
Expand Down Expand Up @@ -163,48 +165,54 @@ async def wrapped(*args: "P.args", **kwargs: "P.kwargs") -> "T":

if "train" in config:
try:
trainer_ref = get_predictor_ref(config, "train")
# TODO: avoid loading trainer code in this process
trainer = load_predictor_from_ref(config["train"])
trainer = load_predictor_from_ref(trainer_ref)
TrainingInputType = get_training_input_type(trainer)
TrainingOutputType = get_training_output_type(trainer)
except Exception:
app.state.health = Health.SETUP_FAILED
msg = "Error while loading trainer:\n\n" + traceback.format_exc()
add_setup_failed_routes(app, started_at, msg)
return app

class TrainingRequest(
schema.TrainingRequest.with_types(input_type=TrainingInputType)
):
pass

TrainingResponse = schema.TrainingResponse.with_types(
input_type=TrainingInputType, output_type=TrainingOutputType
)

@app.post(
"/trainings",
response_model=TrainingResponse,
response_model_exclude_unset=True,
)
def train(request: TrainingRequest = Body(default=None), prefer: Union[str, None] = Header(default=None)) -> Any: # type: ignore
return predict(request, prefer)
class TrainingRequest(
schema.TrainingRequest.with_types(input_type=TrainingInputType)
):
pass

@app.put(
"/trainings/{training_id}",
response_model=PredictionResponse,
response_model_exclude_unset=True,
)
def train_idempotent(
training_id: str = Path(..., title="Training ID"),
request: TrainingRequest = Body(..., title="Training Request"),
prefer: Union[str, None] = Header(default=None),
) -> Any:
return predict_idempotent(training_id, request, prefer)
TrainingResponse = schema.TrainingResponse.with_types(
input_type=TrainingInputType, output_type=TrainingOutputType
)

@app.post("/trainings/{training_id}/cancel")
def cancel_training(training_id: str = Path(..., title="Training ID")) -> Any:
return cancel(training_id)
@app.post(
"/trainings",
response_model=TrainingResponse,
response_model_exclude_unset=True,
)
def train(request: TrainingRequest = Body(default=None),
prefer: Union[str, None] = Header(default=None)) -> Any: # type: ignore
return predict(request, prefer)

@app.put(
"/trainings/{training_id}",
response_model=PredictionResponse,
response_model_exclude_unset=True,
)
def train_idempotent(
training_id: str = Path(..., title="Training ID"),
request: TrainingRequest = Body(..., title="Training Request"),
prefer: Union[str, None] = Header(default=None),
) -> Any:
return predict_idempotent(training_id, request, prefer)

@app.post("/trainings/{training_id}/cancel")
def cancel_training(training_id: str = Path(..., title="Training ID")) -> Any:
return cancel(training_id)

except Exception as e:
if isinstance(e, (PredictorNotSet, FileNotFoundError)) and not is_build:
pass # ignore missing train.py for backward compatibility with existing "bad" models in use
else:
app.state.health = Health.SETUP_FAILED
msg = "Error while loading trainer:\n\n" + traceback.format_exc()
add_setup_failed_routes(app, started_at, msg)
return app

@app.on_event("startup")
def startup() -> None:
Expand Down Expand Up @@ -247,7 +255,8 @@ async def healthcheck() -> Any:
response_model=PredictionResponse,
response_model_exclude_unset=True,
)
async def predict(request: PredictionRequest = Body(default=None), prefer: Union[str, None] = Header(default=None)) -> Any: # type: ignore
async def predict(request: PredictionRequest = Body(default=None),
prefer: Union[str, None] = Header(default=None)) -> Any: # type: ignore
"""
Run a single prediction on the model
"""
Expand Down

0 comments on commit d60370e

Please sign in to comment.