Skip to content

Commit

Permalink
Use aioprocessing so reading worker events does not block the event loop
Browse files Browse the repository at this point in the history
  • Loading branch information
technillogue committed Nov 29, 2023
1 parent 2dcdceb commit b835dbe
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 12 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies = [
'typing-compat; python_version < "3.8"',
"typing_extensions>=4.1.0",
"uvicorn[standard]>=0.12,<1",
"aioprocessing"
]

optional-dependencies = { "dev" = [
Expand Down
29 changes: 29 additions & 0 deletions python/cog/server/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,32 @@ def run(self) -> None:
if drain_tokens_seen >= drain_tokens_needed:
self.drain_event.set()
drain_tokens_seen = 0


# class AsyncPipe:
# def __init__(self, conn: Connection) -> None:
# self.conn = conn
# self.recv = conn.recv
# self.send = conn.send
# self.poll = conn.poll
# self.data_available = asyncio.Event()
# # self.data_available = asyncio.Semaphore(0)

# started = False

# def start_reader(self) -> None:
# if not self.started:
# print("started reader")
# self.started = True
# loop = asyncio.get_running_loop()
# loop.add_reader(self.conn.fileno(), self.data_available.set)
# # loop.add_reader(conn.fileno(), self.data_available.release)

# async def async_recv(self) -> Any:
# self.start_reader()
# # await self.data_available.acquire()
# if not self.conn.poll():
# await self.data_available.wait()
# data = self.conn.recv()
# self.data_available.clear()
# return data
25 changes: 13 additions & 12 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from enum import Enum, auto, unique
from multiprocessing.connection import Connection
from typing import Any, AsyncIterator, Callable, Dict, Iterator, Optional, TextIO, Union
import aioprocessing

from ..json import make_encodeable
from ..predictor import (
Expand All @@ -34,7 +35,7 @@
FatalWorkerException,
InvalidStateException,
)
from .helpers import StreamRedirector, WrappedStream
from .helpers import StreamRedirector, WrappedStream # , AsyncPipe

_spawn = multiprocessing.get_context("spawn")

Expand All @@ -59,7 +60,9 @@ class WorkerState(Enum):

class Mux:
def __init__(self) -> None:
    """Create the event mux: one asyncio.Queue of public events per prediction id.

    Queues are created lazily on first access (defaultdict), so a reader
    and a writer may touch a given id in either order.
    """
    # The annotation is a string so it is never evaluated at runtime:
    # annotations on attribute targets (``self.outs: ...``) are evaluated
    # when __init__ runs, and _PublicEventType is defined elsewhere in
    # the module.
    self.outs: "defaultdict[str, asyncio.Queue[_PublicEventType]]" = defaultdict(
        asyncio.Queue
    )

async def write(self, id: str, item: _PublicEventType) -> None:
    """Enqueue *item* on the queue belonging to prediction *id*.

    The per-id queue is created on demand by the defaultdict, so writing
    to a brand-new id always succeeds.
    """
    destination = self.outs[id]
    await destination.put(item)
Expand All @@ -79,7 +82,9 @@ def __init__(self, predictor_ref: str, tee_output: bool = True) -> None:
self._allow_cancel = False

# A pipe with which to communicate with the child worker.
self._events, child_events = _spawn.Pipe()
# events, child_events = _spawn.Pipe()
# self._events = AsyncPipe(events)
self._events, child_events = aioprocessing.AioPipe()
self._child = _ChildWorker(predictor_ref, child_events, tee_output)
self._terminating = False
self._mux = Mux()
Expand All @@ -100,6 +105,7 @@ async def predict(
self._allow_cancel = True
input = PredictionInput(payload=payload)
self._events.send(input)
self._ensure_event_reader()
async for event in self._mux.read(input.id):
yield event

Expand Down Expand Up @@ -147,10 +153,7 @@ def _ensure_event_reader(self) -> None:

async def _read_events(self, poll: Optional[float] = None) -> None:
    """Forward events from the child worker process to the mux.

    Runs until the child process exits. Uses aioprocessing's ``coro_recv``
    so waiting for the next event suspends this coroutine instead of
    blocking the event loop (which a plain ``Connection.poll``/``recv``
    busy-loop would do).

    Args:
        poll: retained for backward compatibility with the previous
            polling implementation; no longer used now that receiving is
            fully asynchronous.
    """
    while self._child.is_alive():
        # NOTE(review): if the child dies between is_alive() and
        # coro_recv(), this await may hang — confirm shutdown sends a
        # final event or closes the pipe.
        id, event = await self._events.coro_recv()
        await self._mux.write(id, event)

def _wait(
Expand All @@ -169,8 +172,6 @@ def _wait(
if send_heartbeats:
yield Heartbeat()
continue
# this needs aioprocessing.Pipe or similar
# multiprocessing.Pipe is not async
id, ev = self._events.recv()
yield ev

Expand Down Expand Up @@ -201,7 +202,7 @@ def __init__(
) -> None:
self._predictor_ref = predictor_ref
self._predictor: Optional[BasePredictor] = None
self._events = events
self._events = events # AsyncPipe(events)
self._tee_output = tee_output
self._cancelable = False

Expand Down Expand Up @@ -281,7 +282,7 @@ def _loop_sync(self) -> None:

async def _loop_async(self) -> None:
while True:
ev = self._events.recv()
ev = await self._events.coro_recv()
if isinstance(ev, Shutdown):
break
if isinstance(ev, PredictionInput):
Expand Down Expand Up @@ -358,7 +359,7 @@ def _stream_write_hook(
if self._tee_output:
original_stream.write(data)
original_stream.flush()
self._events.send(("", Log(data, source=stream_name)))
self._events.send(("LOG", Log(data, source=stream_name)))


def get_loop() -> asyncio.AbstractEventLoop:
Expand Down

0 comments on commit b835dbe

Please sign in to comment.