Skip to content

Commit

Permalink
review changes to tests and server
Browse files Browse the repository at this point in the history
Signed-off-by: technillogue <technillogue@gmail.com>
  • Loading branch information
technillogue committed Nov 22, 2023
1 parent 087f482 commit b002d54
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 9 deletions.
6 changes: 4 additions & 2 deletions python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ async def _predict(
return JSONResponse(jsonable_encoder(initial_response), status_code=202)

try:
response = PredictionResponse(**(await async_result).dict())
prediction = await async_result
response = PredictionResponse(**prediction.dict())
except ValidationError as e:
_log_invalid_output(e)
raise HTTPException(status_code=500, detail=str(e)) from e
Expand Down Expand Up @@ -246,7 +247,8 @@ async def _check_setup_result() -> Any:
if not app.state.setup_result.done():
return

result = await app.state.setup_result
# this can raise CancelledError
result = app.state.setup_result.result()

if result["status"] == schema.Status.SUCCEEDED:
app.state.health = Health.READY
Expand Down
4 changes: 3 additions & 1 deletion python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,9 @@ def is_busy(self) -> bool:
return False

def shutdown(self) -> None:
if self._result:
self._result.cancel()
self._worker.terminate()
# TODO: cancel setup or predict task

def cancel(self, prediction_id: Optional[str] = None) -> None:
if not self.is_busy():
Expand Down Expand Up @@ -280,6 +281,7 @@ async def setup(*, worker: Worker) -> Dict[str, Any]:
try:
# will be async
for event in worker.setup():
await asyncio.sleep(0)
if isinstance(event, Log):
logs.append(event.message)
elif isinstance(event, Done):
Expand Down
7 changes: 3 additions & 4 deletions python/tests/server/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,6 @@ def test_prediction_idempotent_endpoint_conflict(client, match):
json={"input": {"sleep": 1}},
headers={"Prefer": "respond-async"},
)
time.sleep(0.001)
resp2 = client.put(
"/predictions/5678efgh",
json={"input": {"sleep": 1}},
Expand Down Expand Up @@ -493,12 +492,12 @@ def test_prediction_cancel(client):
)
assert resp.status_code == 202

resp = client.post("/predictions/123/cancel")
assert resp.status_code == 200

resp = client.post("/predictions/456/cancel")
assert resp.status_code == 404

resp = client.post("/predictions/123/cancel")
assert resp.status_code == 200


@uses_predictor_with_client_options(
"setup_weights",
Expand Down
6 changes: 4 additions & 2 deletions python/tests/server/test_runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import os
import threading
from datetime import datetime
Expand Down Expand Up @@ -75,7 +76,8 @@ async def test_prediction_runner_called_while_busy(runner):

assert runner.is_busy()
with pytest.raises(RunnerBusyError):
await runner.predict(request)[1]
_, task = runner.predict(request)
await task

# Await to ensure that the first prediction is scheduled before we
# attempt to shut down the runner.
Expand All @@ -90,7 +92,7 @@ async def test_prediction_runner_called_while_busy_idempotent(runner):
runner.predict(request)
_, async_result = runner.predict(request)

response = await async_result
response = await asyncio.wait_for(async_result, timeout=1)
assert response.id == "abcd1234"
assert response.output == "done in 0.1 seconds"
assert response.status == "succeeded"
Expand Down

0 comments on commit b002d54

Please sign in to comment.