Skip to content

Commit

Permalink
route prediction logs to prediction if only one is running and implem…
Browse files Browse the repository at this point in the history
…ent anext for old pythons

Signed-off-by: technillogue <technillogue@gmail.com>
  • Loading branch information
technillogue committed Dec 5, 2023
1 parent d9ac267 commit 688b152
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
7 changes: 6 additions & 1 deletion python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,14 @@ async def _read_events(self) -> None:
if id == "LOG" and "SETUP" in self._mux.outs:
id = "SETUP"
# also if one prediction
maybe_ids = list(set(self._mux.outs) - {"LOG", "SETUP"})
# this can cause logs to be lost incorrectly
# once we track predictions in flight properly that should be fine
if len(maybe_ids) == 1:
id = maybe_ids[0]
await self._mux.write(id, event)
if isinstance(event, Done):
# prediction_ctx should do this for us
# prediction_ctx does this for us
# self._state = WorkerState.IDLE
self._allow_cancel = False
# If we dropped off the end off the end of the loop, check if it's
Expand Down
31 changes: 19 additions & 12 deletions python/tests/server/test_worker.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio
import inspect
import os
import sys
import time
from typing import Any, AsyncIterator, Awaitable, Iterator, Optional, TypeVar
from typing import Any, AsyncIterator, Awaitable, Coroutine, Iterator, Optional, TypeVar

import pytest
from attrs import define
Expand Down Expand Up @@ -90,6 +91,14 @@
)
]

T = TypeVar("T")

# anext was added in 3.10
if sys.version_info.minor <= 9:

def anext(gen: "AsyncIterator[T] | Coroutine[None, None, T]") -> Awaitable[T]:
return gen.__anext__()


@define
class Result:
Expand All @@ -102,9 +111,6 @@ class Result:
exception: Optional[Exception] = None


T = TypeVar("T")


async def make_generator_async(it: Iterator[T]) -> AsyncIterator[T]:
for item in iter(it):
yield item
Expand Down Expand Up @@ -351,7 +357,7 @@ async def test_cancel_idempotency():

result2 = await _process(w.predict({"sleep": 0.1}))

assert not result2.done.canceled
assert result2.done and not result2.done.canceled
assert result2.output == "done in 0.1 seconds"
finally:
w.terminate()
Expand Down Expand Up @@ -463,7 +469,7 @@ async def test_graceful_shutdown():
assert result.output == "done in 1 seconds"
finally:
w.terminate()
await asyncio.sleep(0) # let terminate clean up
await asyncio.sleep(0) # let terminate clean up


class WorkerState(RuleBasedStateMachine):
Expand Down Expand Up @@ -521,7 +527,7 @@ def setup(self):
# @rule(n=st.integers(min_value=1, max_value=10))
# def read_setup_events(self, n):
try:
while 1: # we want all the events, not just "n" events
while 1: # we want all the events, not just "n" events
event = self.await_(anext(self.setup_generator))
self.setup_events.append(event)
# if we wanted to keep the InvalidStateException to ignore setup assert_state(NEW)
Expand Down Expand Up @@ -556,9 +562,11 @@ def predict(self, name, steps):
# @rule(n=st.integers(min_value=1, max_value=10))
# def read_predict_events(self, n):
try:
while 1: # we want all the events, not just "n" events
# trace("slep")
# self.await_(asyncio.sleep(0.01))
while 1: # we want all the events, not just "n" events
# this sleep *shouldn't* be necessary
# but without it sometimes output is missed?
trace("slep")
self.await_(asyncio.sleep(0.01))
trace("getting predict event")
event = self.await_(anext(self.predict_generator))
trace("got predict event")
Expand All @@ -580,8 +588,7 @@ async def _check_predict_events(self):
expected_stdout.append(f"STEP {i+1}\n")
expected_stdout.append("END\n")

# logs get swallowed for now
# assert result.stdout == "".join(expected_stdout)
assert result.stdout == "".join(expected_stdout)
assert result.stderr == ""
assert result.output == f"NAME={payload['name']}"
assert result.done == Done()
Expand Down

0 comments on commit 688b152

Please sign in to comment.