From d60370e50c14eacb6638a300d45265913a65c7e3 Mon Sep 17 00:00:00 2001 From: Dmitri Khokhlov Date: Wed, 24 Jan 2024 12:40:42 -0800 Subject: [PATCH] fix: ignore "train" part of schema errors during predict runs in production for backward compatibility Signed-off-by: Dmitri Khokhlov --- python/cog/command/openapi_schema.py | 2 +- python/cog/server/http.py | 83 +++++++++++++++------------- 2 files changed, 47 insertions(+), 38 deletions(-) diff --git a/python/cog/command/openapi_schema.py b/python/cog/command/openapi_schema.py index d69dd8703..f44c3c663 100644 --- a/python/cog/command/openapi_schema.py +++ b/python/cog/command/openapi_schema.py @@ -16,7 +16,7 @@ try: with suppress_output(): config = load_config() - app = create_app(config, shutdown_event=None) + app = create_app(config, shutdown_event=None, is_build=True) if app.state.setup_result and app.state.setup_result.status == Status.FAILED: raise CogError(app.state.setup_result.logs) schema = app.openapi() diff --git a/python/cog/server/http.py b/python/cog/server/http.py index 6f1b5ca73..106f08467 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -36,6 +36,7 @@ from pydantic.error_wrappers import ErrorWrapper from .. import schema +from ..errors import PredictorNotSet from ..files import upload_file from ..json import upload_files from ..logging import setup_logging @@ -104,6 +105,7 @@ def create_app( threads: int = 1, upload_url: Optional[str] = None, mode: str = "predict", + is_build: bool = False ) -> MyFastAPI: app = MyFastAPI( title="Cog", # TODO: mention model name? @@ -163,48 +165,54 @@ async def wrapped(*args: "P.args", **kwargs: "P.kwargs") -> "T": if "train" in config: try: + trainer_ref = get_predictor_ref(config, "train") # TODO: avoid loading trainer code in this process - trainer = load_predictor_from_ref(config["train"]) + trainer = load_predictor_from_ref(trainer_ref) TrainingInputType = get_training_input_type(trainer) TrainingOutputType = get_training_output_type(trainer) - except Exception: - app.state.health = Health.SETUP_FAILED - msg = "Error while loading trainer:\n\n" + traceback.format_exc() - add_setup_failed_routes(app, started_at, msg) - return app - - class TrainingRequest( - schema.TrainingRequest.with_types(input_type=TrainingInputType) - ): - pass - - TrainingResponse = schema.TrainingResponse.with_types( - input_type=TrainingInputType, output_type=TrainingOutputType - ) - @app.post( - "/trainings", - response_model=TrainingResponse, - response_model_exclude_unset=True, - ) - def train(request: TrainingRequest = Body(default=None), prefer: Union[str, None] = Header(default=None)) -> Any: # type: ignore - return predict(request, prefer) + class TrainingRequest( + schema.TrainingRequest.with_types(input_type=TrainingInputType) + ): + pass - @app.put( - "/trainings/{training_id}", - response_model=PredictionResponse, - response_model_exclude_unset=True, - ) - def train_idempotent( - training_id: str = Path(..., title="Training ID"), - request: TrainingRequest = Body(..., title="Training Request"), - prefer: Union[str, None] = Header(default=None), - ) -> Any: - return predict_idempotent(training_id, request, prefer) + TrainingResponse = schema.TrainingResponse.with_types( + input_type=TrainingInputType, output_type=TrainingOutputType + ) - @app.post("/trainings/{training_id}/cancel") - def cancel_training(training_id: str = Path(..., title="Training ID")) -> Any: - return cancel(training_id) + @app.post( + "/trainings", + response_model=TrainingResponse, + response_model_exclude_unset=True, + ) + def train(request: TrainingRequest = Body(default=None), + prefer: Union[str, None] = Header(default=None)) -> Any: # type: ignore + return predict(request, prefer) + + @app.put( + "/trainings/{training_id}", + response_model=PredictionResponse, + response_model_exclude_unset=True, + ) + def train_idempotent( + training_id: str = Path(..., title="Training ID"), + request: TrainingRequest = Body(..., title="Training Request"), + prefer: Union[str, None] = Header(default=None), + ) -> Any: + return predict_idempotent(training_id, request, prefer) + + @app.post("/trainings/{training_id}/cancel") + def cancel_training(training_id: str = Path(..., title="Training ID")) -> Any: + return cancel(training_id) + + except Exception as e: + if isinstance(e, (PredictorNotSet, FileNotFoundError)) and not is_build: + pass # ignore missing train.py for backward compatibility with existing "bad" models in use + else: + app.state.health = Health.SETUP_FAILED + msg = "Error while loading trainer:\n\n" + traceback.format_exc() + add_setup_failed_routes(app, started_at, msg) + return app @app.on_event("startup") def startup() -> None: @@ -247,7 +255,8 @@ async def healthcheck() -> Any: response_model=PredictionResponse, response_model_exclude_unset=True, ) - async def predict(request: PredictionRequest = Body(default=None), prefer: Union[str, None] = Header(default=None)) -> Any: # type: ignore + async def predict(request: PredictionRequest = Body(default=None), + prefer: Union[str, None] = Header(default=None)) -> Any: # type: ignore """ Run a single prediction on the model """