From 688b152a584a0f1b28fbfe03829e120138eaa614 Mon Sep 17 00:00:00 2001 From: technillogue Date: Mon, 4 Dec 2023 11:29:22 -0700 Subject: [PATCH] route prediction logs to prediction if only one is running and implement anext for old pythons Signed-off-by: technillogue --- python/cog/server/worker.py | 7 ++++++- python/tests/server/test_worker.py | 31 ++++++++++++++++++------------ 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py index c583ad7bb..4728d3b29 100644 --- a/python/cog/server/worker.py +++ b/python/cog/server/worker.py @@ -256,9 +256,14 @@ async def _read_events(self) -> None: if id == "LOG" and "SETUP" in self._mux.outs: id = "SETUP" # also if one prediction + maybe_ids = list(set(self._mux.outs) - {"LOG", "SETUP"}) + # this can cause logs to be lost incorrectly + # once we track predictions in flight properly that should be fine + if len(maybe_ids) == 1: + id = maybe_ids[0] await self._mux.write(id, event) if isinstance(event, Done): - # prediction_ctx should do this for us + # prediction_ctx does this for us # self._state = WorkerState.IDLE self._allow_cancel = False # If we dropped off the end off the end of the loop, check if it's diff --git a/python/tests/server/test_worker.py b/python/tests/server/test_worker.py index 1f02904fe..f993f3085 100644 --- a/python/tests/server/test_worker.py +++ b/python/tests/server/test_worker.py @@ -1,8 +1,9 @@ import asyncio import inspect import os +import sys import time -from typing import Any, AsyncIterator, Awaitable, Iterator, Optional, TypeVar +from typing import Any, AsyncIterator, Awaitable, Coroutine, Iterator, Optional, TypeVar import pytest from attrs import define @@ -90,6 +91,14 @@ ) ] +T = TypeVar("T") + +# anext was added in 3.10 +if sys.version_info.minor <= 9: + + def anext(gen: "AsyncIterator[T] | Coroutine[None, None, T]") -> Awaitable[T]: + return gen.__anext__() + @define class Result: @@ -102,9 +111,6 @@ class Result: exception: Optional[Exception] = None -T = TypeVar("T") - - async def make_generator_async(it: Iterator[T]) -> AsyncIterator[T]: for item in iter(it): yield item @@ -351,7 +357,7 @@ async def test_cancel_idempotency(): result2 = await _process(w.predict({"sleep": 0.1})) - assert not result2.done.canceled + assert result2.done and not result2.done.canceled assert result2.output == "done in 0.1 seconds" finally: w.terminate() @@ -463,7 +469,7 @@ async def test_graceful_shutdown(): assert result.output == "done in 1 seconds" finally: w.terminate() - await asyncio.sleep(0) # let terminate clean up + await asyncio.sleep(0) # let terminate clean up class WorkerState(RuleBasedStateMachine): @@ -521,7 +527,7 @@ def setup(self): # @rule(n=st.integers(min_value=1, max_value=10)) # def read_setup_events(self, n): try: - while 1: # we want all the events, not just "n" events + while 1: # we want all the events, not just "n" events event = self.await_(anext(self.setup_generator)) self.setup_events.append(event) # if we wanted to keep the InvalidStateException to ignore setup assert_state(NEW) @@ -556,9 +562,11 @@ def predict(self, name, steps): # @rule(n=st.integers(min_value=1, max_value=10)) # def read_predict_events(self, n): try: - while 1: # we want all the events, not just "n" events - # trace("slep") - # self.await_(asyncio.sleep(0.01)) + while 1: # we want all the events, not just "n" events + # this sleep *shouldn't* be necessary + # but without it sometimes output is missed? + trace("slep") + self.await_(asyncio.sleep(0.01)) trace("getting predict event") event = self.await_(anext(self.predict_generator)) trace("got predict event") @@ -580,8 +588,7 @@ async def _check_predict_events(self): expected_stdout.append(f"STEP {i+1}\n") expected_stdout.append("END\n") - # logs get swallowed for now - # assert result.stdout == "".join(expected_stdout) + assert result.stdout == "".join(expected_stdout) assert result.stderr == "" assert result.output == f"NAME={payload['name']}" assert result.done == Done()