From 6729d53d00e555accf023c42bee9b06eac21eaa0 Mon Sep 17 00:00:00 2001 From: technillogue Date: Fri, 27 Oct 2023 20:15:26 -0400 Subject: [PATCH] make tests async and fix them Signed-off-by: technillogue --- pyproject.toml | 1 + python/cog/server/http.py | 3 +- python/cog/server/runner.py | 5 +-- python/tests/server/test_http.py | 7 ++-- python/tests/server/test_runner.py | 57 +++++++++++++++++------------- 5 files changed, 42 insertions(+), 31 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 41c9c6e89..edad0944d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ optional-dependencies = { "dev" = [ 'numpy; python_version >= "3.8"', "pillow", "pytest", + "pytest-asyncio", "pytest-httpserver", "pytest-rerunfailures", "pytest-xdist", diff --git a/python/cog/server/http.py b/python/cog/server/http.py index ad72edcdd..80ddb59b6 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -203,8 +203,7 @@ async def _predict( return JSONResponse(jsonable_encoder(initial_response), status_code=202) try: - res = await async_result - response = PredictionResponse(res.dict()) + response = PredictionResponse(**(await async_result).dict()) except ValidationError as e: _log_invalid_output(e) raise HTTPException(status_code=500, detail=str(e)) from e diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index f2b08d662..491e1449c 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -73,7 +73,7 @@ def handle_error(task: Task) -> None: return handle_error - def setup(self) -> Task["dict[str, Any]"]: + def setup(self) -> "Task[dict[str, Any]]": if self.is_busy(): raise RunnerBusyError() self._result = asyncio.create_task(setup(worker=self._worker)) @@ -84,7 +84,7 @@ def setup(self) -> Task["dict[str, Any]"]: # no longer have to support Python 3.8 def predict( self, prediction: schema.PredictionRequest, upload: bool = True - ) -> Tuple[schema.PredictionResponse, Task[schema.PredictionResponse]]: + ) -> Tuple[schema.PredictionResponse, "Task[schema.PredictionResponse]"]: # It's the caller's responsibility to not call us if we're busy. if self.is_busy(): # If self._result is set, but self._response is not, we're still @@ -364,6 +364,7 @@ async def _predict( return event_handler.response # will be async for event in worker.predict(input_dict, poll=0.1): + await asyncio.sleep(0) if should_cancel.is_set(): worker.cancel() should_cancel.clear() diff --git a/python/tests/server/test_http.py b/python/tests/server/test_http.py index 0de397a6d..ec3308b53 100644 --- a/python/tests/server/test_http.py +++ b/python/tests/server/test_http.py @@ -425,6 +425,7 @@ def test_prediction_idempotent_endpoint_conflict(client, match): json={"input": {"sleep": 1}}, headers={"Prefer": "respond-async"}, ) + time.sleep(0.001) resp2 = client.put( "/predictions/5678efgh", json={"input": {"sleep": 1}}, @@ -492,12 +493,12 @@ def test_prediction_cancel(client): ) assert resp.status_code == 202 - resp = client.post("/predictions/456/cancel") - assert resp.status_code == 404 - resp = client.post("/predictions/123/cancel") assert resp.status_code == 200 + resp = client.post("/predictions/456/cancel") + assert resp.status_code == 404 + @uses_predictor_with_client_options( "setup_weights", diff --git a/python/tests/server/test_runner.py b/python/tests/server/test_runner.py index bca5199e8..ed59eab32 100644 --- a/python/tests/server/test_runner.py +++ b/python/tests/server/test_runner.py @@ -4,6 +4,7 @@ from unittest import mock import pytest +import pytest_asyncio from cog.schema import PredictionRequest, PredictionResponse, Status, WebhookEvent from cog.server.eventtypes import ( Done, @@ -26,24 +27,25 @@ def _fixture_path(name): return os.path.join(test_dir, f"fixtures/{name}.py") + ":Predictor" -@pytest.fixture -def runner(): +@pytest_asyncio.fixture +async def runner(): runner = PredictionRunner( predictor_ref=_fixture_path("sleep"), shutdown_event=threading.Event() ) try: - runner.setup().get(5) + await runner.setup() yield runner finally: runner.shutdown() -def test_prediction_runner_setup(): +@pytest.mark.asyncio +async def test_prediction_runner_setup(): runner = PredictionRunner( predictor_ref=_fixture_path("sleep"), shutdown_event=threading.Event() ) try: - result = runner.setup().get(5) + result = await runner.setup() assert result["status"] == Status.SUCCEEDED assert result["logs"] == "" @@ -53,10 +55,11 @@ def test_prediction_runner_setup(): runner.shutdown() -def test_prediction_runner(runner): +@pytest.mark.asyncio +async def test_prediction_runner(runner): request = PredictionRequest(input={"sleep": 0.1}) _, async_result = runner.predict(request) - response = async_result.get(timeout=1) + response = await async_result assert response.output == "done in 0.1 seconds" assert response.status == "succeeded" assert response.error is None @@ -65,33 +68,36 @@ def test_prediction_runner(runner): assert isinstance(response.completed_at, datetime) -def test_prediction_runner_called_while_busy(runner): +@pytest.mark.asyncio +async def test_prediction_runner_called_while_busy(runner): request = PredictionRequest(input={"sleep": 0.1}) _, async_result = runner.predict(request) assert runner.is_busy() with pytest.raises(RunnerBusyError): - runner.predict(request) + await runner.predict(request)[1] - # Call .get() to ensure that the first prediction is scheduled before we + # Await to ensure that the first prediction is scheduled before we # attempt to shut down the runner. - async_result.get() + await async_result -def test_prediction_runner_called_while_busy_idempotent(runner): +@pytest.mark.asyncio +async def test_prediction_runner_called_while_busy_idempotent(runner): request = PredictionRequest(id="abcd1234", input={"sleep": 0.1}) runner.predict(request) runner.predict(request) _, async_result = runner.predict(request) - response = async_result.get(timeout=1) + response = await async_result assert response.id == "abcd1234" assert response.output == "done in 0.1 seconds" assert response.status == "succeeded" -def test_prediction_runner_called_while_busy_idempotent_wrong_id(runner): +@pytest.mark.asyncio +async def test_prediction_runner_called_while_busy_idempotent_wrong_id(runner): request1 = PredictionRequest(id="abcd1234", input={"sleep": 0.1}) request2 = PredictionRequest(id="5678efgh", input={"sleep": 0.1}) @@ -99,19 +105,20 @@ def test_prediction_runner_called_while_busy_idempotent_wrong_id(runner): with pytest.raises(RunnerBusyError): runner.predict(request2) - response = async_result.get(timeout=1) + response = await async_result assert response.id == "abcd1234" assert response.output == "done in 0.1 seconds" assert response.status == "succeeded" -def test_prediction_runner_cancel(runner): +@pytest.mark.asyncio +async def test_prediction_runner_cancel(runner): request = PredictionRequest(input={"sleep": 0.5}) _, async_result = runner.predict(request) runner.cancel() - response = async_result.get(timeout=1) + response = await async_result assert response.output is None assert response.status == "canceled" assert response.error is None @@ -120,25 +127,27 @@ def test_prediction_runner_cancel(runner): assert isinstance(response.completed_at, datetime) -def test_prediction_runner_cancel_matching_id(runner): +@pytest.mark.asyncio +async def test_prediction_runner_cancel_matching_id(runner): request = PredictionRequest(id="abcd1234", input={"sleep": 0.5}) _, async_result = runner.predict(request) runner.cancel(prediction_id="abcd1234") - response = async_result.get(timeout=1) + response = await async_result assert response.output is None assert response.status == "canceled" -def test_prediction_runner_cancel_by_mismatched_id(runner): +@pytest.mark.asyncio +async def test_prediction_runner_cancel_by_mismatched_id(runner): request = PredictionRequest(id="abcd1234", input={"sleep": 0.5}) _, async_result = runner.predict(request) with pytest.raises(UnknownPredictionError): runner.cancel(prediction_id="5678efgh") - response = async_result.get(timeout=1) + response = await async_result assert response.output == "done in 0.5 seconds" assert response.status == "succeeded" @@ -188,15 +197,15 @@ def predict(self, input_, poll=None): return FakeWorker() - +@pytest.mark.asyncio @pytest.mark.parametrize("events,calls", PREDICT_TESTS) -def test_predict(events, calls): +async def test_predict(events, calls): worker = fake_worker(events) request = PredictionRequest(input={"text": "hello"}, foo="bar") event_handler = mock.Mock() should_cancel = threading.Event() - predict( + await predict( worker=worker, request=request, event_handler=event_handler,