Skip to content

Commit

Permalink
make tests async and fix them
Browse files Browse the repository at this point in the history
Signed-off-by: technillogue <technillogue@gmail.com>
  • Loading branch information
technillogue committed Nov 22, 2023
1 parent dc5ef44 commit 6729d53
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 31 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ optional-dependencies = { "dev" = [
'numpy; python_version >= "3.8"',
"pillow",
"pytest",
"pytest-asyncio",
"pytest-httpserver",
"pytest-rerunfailures",
"pytest-xdist",
Expand Down
3 changes: 1 addition & 2 deletions python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,7 @@ async def _predict(
return JSONResponse(jsonable_encoder(initial_response), status_code=202)

try:
res = await async_result
response = PredictionResponse(res.dict())
response = PredictionResponse(**(await async_result).dict())
except ValidationError as e:
_log_invalid_output(e)
raise HTTPException(status_code=500, detail=str(e)) from e
Expand Down
5 changes: 3 additions & 2 deletions python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def handle_error(task: Task) -> None:

return handle_error

def setup(self) -> Task["dict[str, Any]"]:
def setup(self) -> "Task[dict[str, Any]]":
if self.is_busy():
raise RunnerBusyError()
self._result = asyncio.create_task(setup(worker=self._worker))
Expand All @@ -84,7 +84,7 @@ def setup(self) -> Task["dict[str, Any]"]:
# no longer have to support Python 3.8
def predict(
self, prediction: schema.PredictionRequest, upload: bool = True
) -> Tuple[schema.PredictionResponse, Task[schema.PredictionResponse]]:
) -> Tuple[schema.PredictionResponse, "Task[schema.PredictionResponse]"]:
# It's the caller's responsibility to not call us if we're busy.
if self.is_busy():
# If self._result is set, but self._response is not, we're still
Expand Down Expand Up @@ -364,6 +364,7 @@ async def _predict(
return event_handler.response
# will be async
for event in worker.predict(input_dict, poll=0.1):
await asyncio.sleep(0)
if should_cancel.is_set():
worker.cancel()
should_cancel.clear()
Expand Down
7 changes: 4 additions & 3 deletions python/tests/server/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ def test_prediction_idempotent_endpoint_conflict(client, match):
json={"input": {"sleep": 1}},
headers={"Prefer": "respond-async"},
)
time.sleep(0.001)
resp2 = client.put(
"/predictions/5678efgh",
json={"input": {"sleep": 1}},
Expand Down Expand Up @@ -492,12 +493,12 @@ def test_prediction_cancel(client):
)
assert resp.status_code == 202

resp = client.post("/predictions/456/cancel")
assert resp.status_code == 404

resp = client.post("/predictions/123/cancel")
assert resp.status_code == 200

resp = client.post("/predictions/456/cancel")
assert resp.status_code == 404


@uses_predictor_with_client_options(
"setup_weights",
Expand Down
57 changes: 33 additions & 24 deletions python/tests/server/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from unittest import mock

import pytest
import pytest_asyncio
from cog.schema import PredictionRequest, PredictionResponse, Status, WebhookEvent
from cog.server.eventtypes import (
Done,
Expand All @@ -26,24 +27,25 @@ def _fixture_path(name):
return os.path.join(test_dir, f"fixtures/{name}.py") + ":Predictor"


@pytest.fixture
def runner():
@pytest_asyncio.fixture
async def runner():
runner = PredictionRunner(
predictor_ref=_fixture_path("sleep"), shutdown_event=threading.Event()
)
try:
runner.setup().get(5)
await runner.setup()
yield runner
finally:
runner.shutdown()


def test_prediction_runner_setup():
@pytest.mark.asyncio
async def test_prediction_runner_setup():
runner = PredictionRunner(
predictor_ref=_fixture_path("sleep"), shutdown_event=threading.Event()
)
try:
result = runner.setup().get(5)
result = await runner.setup()

assert result["status"] == Status.SUCCEEDED
assert result["logs"] == ""
Expand All @@ -53,10 +55,11 @@ def test_prediction_runner_setup():
runner.shutdown()


def test_prediction_runner(runner):
@pytest.mark.asyncio
async def test_prediction_runner(runner):
request = PredictionRequest(input={"sleep": 0.1})
_, async_result = runner.predict(request)
response = async_result.get(timeout=1)
response = await async_result
assert response.output == "done in 0.1 seconds"
assert response.status == "succeeded"
assert response.error is None
Expand All @@ -65,53 +68,57 @@ def test_prediction_runner(runner):
assert isinstance(response.completed_at, datetime)


def test_prediction_runner_called_while_busy(runner):
@pytest.mark.asyncio
async def test_prediction_runner_called_while_busy(runner):
request = PredictionRequest(input={"sleep": 0.1})
_, async_result = runner.predict(request)

assert runner.is_busy()
with pytest.raises(RunnerBusyError):
runner.predict(request)
await runner.predict(request)[1]

# Call .get() to ensure that the first prediction is scheduled before we
# Await to ensure that the first prediction is scheduled before we
# attempt to shut down the runner.
async_result.get()
await async_result


def test_prediction_runner_called_while_busy_idempotent(runner):
@pytest.mark.asyncio
async def test_prediction_runner_called_while_busy_idempotent(runner):
request = PredictionRequest(id="abcd1234", input={"sleep": 0.1})

runner.predict(request)
runner.predict(request)
_, async_result = runner.predict(request)

response = async_result.get(timeout=1)
response = await async_result
assert response.id == "abcd1234"
assert response.output == "done in 0.1 seconds"
assert response.status == "succeeded"


def test_prediction_runner_called_while_busy_idempotent_wrong_id(runner):
@pytest.mark.asyncio
async def test_prediction_runner_called_while_busy_idempotent_wrong_id(runner):
request1 = PredictionRequest(id="abcd1234", input={"sleep": 0.1})
request2 = PredictionRequest(id="5678efgh", input={"sleep": 0.1})

_, async_result = runner.predict(request1)
with pytest.raises(RunnerBusyError):
runner.predict(request2)

response = async_result.get(timeout=1)
response = await async_result
assert response.id == "abcd1234"
assert response.output == "done in 0.1 seconds"
assert response.status == "succeeded"


def test_prediction_runner_cancel(runner):
@pytest.mark.asyncio
async def test_prediction_runner_cancel(runner):
request = PredictionRequest(input={"sleep": 0.5})
_, async_result = runner.predict(request)

runner.cancel()

response = async_result.get(timeout=1)
response = await async_result
assert response.output is None
assert response.status == "canceled"
assert response.error is None
Expand All @@ -120,25 +127,27 @@ def test_prediction_runner_cancel(runner):
assert isinstance(response.completed_at, datetime)


def test_prediction_runner_cancel_matching_id(runner):
@pytest.mark.asyncio
async def test_prediction_runner_cancel_matching_id(runner):
request = PredictionRequest(id="abcd1234", input={"sleep": 0.5})
_, async_result = runner.predict(request)

runner.cancel(prediction_id="abcd1234")

response = async_result.get(timeout=1)
response = await async_result
assert response.output is None
assert response.status == "canceled"


def test_prediction_runner_cancel_by_mismatched_id(runner):
@pytest.mark.asyncio
async def test_prediction_runner_cancel_by_mismatched_id(runner):
request = PredictionRequest(id="abcd1234", input={"sleep": 0.5})
_, async_result = runner.predict(request)

with pytest.raises(UnknownPredictionError):
runner.cancel(prediction_id="5678efgh")

response = async_result.get(timeout=1)
response = await async_result
assert response.output == "done in 0.5 seconds"
assert response.status == "succeeded"

Expand Down Expand Up @@ -188,15 +197,15 @@ def predict(self, input_, poll=None):

return FakeWorker()


@pytest.mark.asyncio
@pytest.mark.parametrize("events,calls", PREDICT_TESTS)
def test_predict(events, calls):
async def test_predict(events, calls):
worker = fake_worker(events)
request = PredictionRequest(input={"text": "hello"}, foo="bar")
event_handler = mock.Mock()
should_cancel = threading.Event()

predict(
await predict(
worker=worker,
request=request,
event_handler=event_handler,
Expand Down

0 comments on commit 6729d53

Please sign in to comment.