From 294d6032368155f61880c188a204ad99516891be Mon Sep 17 00:00:00 2001 From: technillogue Date: Tue, 5 Dec 2023 16:53:44 -0700 Subject: [PATCH] keep track of predictions in flight explicitly and use that to route logs Signed-off-by: technillogue --- python/cog/server/worker.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py index 72584cc0f..d97630c59 100644 --- a/python/cog/server/worker.py +++ b/python/cog/server/worker.py @@ -104,6 +104,7 @@ def __init__( # stop reading events self._terminating = asyncio.Event() self._mux = Mux(self._terminating) + self._predictions_in_flight = set() # @contextlib.contextmanager # def ctx(self): @@ -132,11 +133,15 @@ async def setup(self) -> AsyncIterator[_PublicEventType]: self._allow_cancel = False @contextlib.asynccontextmanager - async def prediction_ctx(self) -> AsyncIterator[None]: + async def prediction_ctx(self, input: PredictionInput) -> AsyncIterator[None]: async with self._semaphore: self._state = WorkerState.PROCESSING self._allow_cancel = True - yield + self._predictions_in_flight.add(input.id) + try: + yield + finally: + self._predictions_in_flight.remove(input.id) if self._semaphore._value == self._concurrent: self._state = WorkerState.IDLE self._allow_cancel = False @@ -148,8 +153,8 @@ async def predict( raise InvalidStateException( "cannot accept new predictions because shutdown requested" ) - async with self.prediction_ctx(): - input = PredictionInput(payload=payload) + input = PredictionInput(payload=payload) + async with self.prediction_ctx(input): self._events.send(input) self._ensure_event_reader() async for event in self._mux.read(input.id, poll=poll): @@ -224,6 +229,8 @@ async def _read_events(self) -> None: id, event = result if id == "LOG" and self._state == WorkerState.STARTING: id = "SETUP" + if id == "LOG" and len(self._predictions_in_flight) == 1: + id = list(self._predictions_in_flight)[0] await self._mux.write(id, event) # If we dropped off the end off the end of the loop, check if it's # because the child process died.