From b002d542ac68e010c20247416c3c64cd8779a12f Mon Sep 17 00:00:00 2001 From: technillogue Date: Mon, 6 Nov 2023 16:48:47 -0500 Subject: [PATCH] review changes to tests and server Signed-off-by: technillogue --- python/cog/server/http.py | 6 ++++-- python/cog/server/runner.py | 4 +++- python/tests/server/test_http.py | 7 +++---- python/tests/server/test_runner.py | 6 ++++-- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/python/cog/server/http.py b/python/cog/server/http.py index 80ddb59b6..2fe8e2f29 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -203,7 +203,8 @@ async def _predict( return JSONResponse(jsonable_encoder(initial_response), status_code=202) try: - response = PredictionResponse(**(await async_result).dict()) + prediction = await async_result + response = PredictionResponse(**prediction.dict()) except ValidationError as e: _log_invalid_output(e) raise HTTPException(status_code=500, detail=str(e)) from e @@ -246,7 +247,8 @@ async def _check_setup_result() -> Any: if not app.state.setup_result.done(): return - result = await app.state.setup_result + # this can raise CancelledError + result = app.state.setup_result.result() if result["status"] == schema.Status.SUCCEEDED: app.state.health = Health.READY diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index ccc42838e..35c27cfb8 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -129,8 +129,9 @@ def is_busy(self) -> bool: return False def shutdown(self) -> None: + if self._result: + self._result.cancel() self._worker.terminate() - # TODO: cancel setup or predict task def cancel(self, prediction_id: Optional[str] = None) -> None: if not self.is_busy(): @@ -280,6 +281,7 @@ async def setup(*, worker: Worker) -> Dict[str, Any]: try: # will be async for event in worker.setup(): + await asyncio.sleep(0) if isinstance(event, Log): logs.append(event.message) elif isinstance(event, Done): diff --git a/python/tests/server/test_http.py b/python/tests/server/test_http.py index ec3308b53..0de397a6d 100644 --- a/python/tests/server/test_http.py +++ b/python/tests/server/test_http.py @@ -425,7 +425,6 @@ def test_prediction_idempotent_endpoint_conflict(client, match): json={"input": {"sleep": 1}}, headers={"Prefer": "respond-async"}, ) - time.sleep(0.001) resp2 = client.put( "/predictions/5678efgh", json={"input": {"sleep": 1}}, @@ -493,12 +492,12 @@ def test_prediction_cancel(client): ) assert resp.status_code == 202 - resp = client.post("/predictions/123/cancel") - assert resp.status_code == 200 - resp = client.post("/predictions/456/cancel") assert resp.status_code == 404 + resp = client.post("/predictions/123/cancel") + assert resp.status_code == 200 + @uses_predictor_with_client_options( "setup_weights", diff --git a/python/tests/server/test_runner.py b/python/tests/server/test_runner.py index ed59eab32..2a2ae9a04 100644 --- a/python/tests/server/test_runner.py +++ b/python/tests/server/test_runner.py @@ -1,3 +1,4 @@ +import asyncio import os import threading from datetime import datetime @@ -75,7 +76,8 @@ async def test_prediction_runner_called_while_busy(runner): assert runner.is_busy() with pytest.raises(RunnerBusyError): - await runner.predict(request)[1] + _, task = runner.predict(request) + await task # Await to ensure that the first prediction is scheduled before we # attempt to shut down the runner. @@ -90,7 +92,7 @@ async def test_prediction_runner_called_while_busy_idempotent(runner): runner.predict(request) _, async_result = runner.predict(request) - response = await async_result + response = await asyncio.wait_for(async_result, timeout=1) assert response.id == "abcd1234" assert response.output == "done in 0.1 seconds" assert response.status == "succeeded"