diff --git a/python/cog/server/http.py b/python/cog/server/http.py index 99a5393e4..1cacb9f10 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -196,7 +196,7 @@ def train( @app.put( "/trainings/{training_id}", - response_model=PredictionResponse, + response_model=TrainingResponse, response_model_exclude_unset=True, ) def train_idempotent(