Skip to content

Commit

Permalink
keep track of predictions in flight explicitly and use that to route …
Browse files Browse the repository at this point in the history
…logs

Signed-off-by: technillogue <technillogue@gmail.com>
  • Loading branch information
technillogue committed Dec 7, 2023
1 parent 0f09a9e commit 294d603
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def __init__(
# stop reading events
self._terminating = asyncio.Event()
self._mux = Mux(self._terminating)
self._predictions_in_flight = set()

# @contextlib.contextmanager
# def ctx(self):
Expand Down Expand Up @@ -132,11 +133,15 @@ async def setup(self) -> AsyncIterator[_PublicEventType]:
self._allow_cancel = False

@contextlib.asynccontextmanager
async def prediction_ctx(self) -> AsyncIterator[None]:
async def prediction_ctx(self, input: PredictionInput) -> AsyncIterator[None]:
async with self._semaphore:
self._state = WorkerState.PROCESSING
self._allow_cancel = True
yield
self._predictions_in_flight.add(input.id)
try:
yield
finally:
self._predictions_in_flight.remove(input.id)
if self._semaphore._value == self._concurrent:
self._state = WorkerState.IDLE
self._allow_cancel = False
Expand All @@ -148,8 +153,8 @@ async def predict(
raise InvalidStateException(
"cannot accept new predictions because shutdown requested"
)
async with self.prediction_ctx():
input = PredictionInput(payload=payload)
input = PredictionInput(payload=payload)
async with self.prediction_ctx(input):
self._events.send(input)
self._ensure_event_reader()
async for event in self._mux.read(input.id, poll=poll):
Expand Down Expand Up @@ -224,6 +229,8 @@ async def _read_events(self) -> None:
id, event = result
if id == "LOG" and self._state == WorkerState.STARTING:
id = "SETUP"
if id == "LOG" and len(self._predictions_in_flight) == 1:
id = list(self._predictions_in_flight)[0]
await self._mux.write(id, event)
# If we dropped off the end off the end of the loop, check if it's
# because the child process died.
Expand Down

0 comments on commit 294d603

Please sign in to comment.