Skip to content

Commit

Permalink
Add trio.open_channel
Browse files Browse the repository at this point in the history
  • Loading branch information
njsmith committed Aug 22, 2018
1 parent 7f0dd7d commit abca8bd
Show file tree
Hide file tree
Showing 3 changed files with 288 additions and 0 deletions.
3 changes: 3 additions & 0 deletions trio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from ._sync import *
__all__ += _sync.__all__

from ._channel import *
__all__ += _channel.__all__

from ._threads import *
__all__ += _threads.__all__

Expand Down
209 changes: 209 additions & 0 deletions trio/_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
from collections import deque, OrderedDict
from math import inf

import attr
from outcome import Error, Value

from . import _core
from ._util import aiter_compat

__all__ = ["open_channel", "EndOfChannel", "BrokenChannelError"]

# TODO:
# - introspection:
# - statistics
# - capacity, usage
# - repr
# - BrokenResourceError?
# - tests
# - docs
# - should there be a put_back method that inserts an item at the front of the
# queue, while ignoring length limits? (the idea being that you call this
# from a task that is also doing get(), and making get() block on put() is
# a ticket to deadlock city) Example use case: depth-first traversal of a
# directory tree. (Well... does this work? If you start out 10-wide then you
# won't converge on a single DFS quickly, or maybe at all... is that still
# good? do you actually want a priority queue that sorts by depth? maybe
# that is what you want. Huh.)


class EndOfChannel(Exception):
pass


class BrokenChannelError(Exception):
pass


def open_channel(capacity):
if capacity != inf and not isinstance(capacity, int):
raise TypeError("capacity must be an integer or math.inf")
if capacity < 0:
raise ValueError("capacity must be >= 0")
buf = ChannelBuf(capacity)
return PutChannel(buf), GetChannel(buf)


@attr.s(cmp=False, hash=False)
class ChannelBuf:
capacity = attr.ib()
data = attr.ib(default=attr.Factory(deque))
# counts
put_channels = attr.ib(default=0)
get_channels = attr.ib(default=0)
# {task: value}
put_tasks = attr.ib(default=attr.Factory(OrderedDict))
# {task: None}
get_tasks = attr.ib(default=attr.Factory(OrderedDict))


class PutChannel:
def __init__(self, buf):
self._buf = buf
self.closed = False
self._tasks = set()
self._buf.put_channels += 1

@_core.disable_ki_protection
def put_nowait(self, value):
if self.closed:
raise _core.ClosedResourceError
if not self._buf.get_channels:
raise BrokenChannelError
if self._buf.get_tasks:
assert not self._buf.data
task, _ = self._buf.get_tasks.popitem(last=False)
task.custom_sleep_data._tasks.remove(task)
_core.reschedule(task, Value(value))
elif len(self._buf.data) < self._buf.capacity:
self._buf.data.append(value)
else:
raise _core.WouldBlock

@_core.disable_ki_protection
async def put(self, value):
await _core.checkpoint_if_cancelled()
try:
self.put_nowait(value)
except _core.WouldBlock:
pass
else:
await _core.cancel_shielded_checkpoint()
return

task = _core.current_task()
self._tasks.add(task)
self._buf.put_tasks[task] = value
task.custom_sleep_data = self

def abort_fn(_):
self._tasks.remove(task)
del self._buf.put_tasks[task]
return _core.Abort.SUCCEEDED

await _core.wait_task_rescheduled(abort_fn)

@_core.disable_ki_protection
def clone(self):
if self.closed:
raise _core.ClosedResourceError
return PutChannel(self._buf)

@_core.disable_ki_protection
def close(self):
if self.closed:
return
self.closed = True
for task in list(self._tasks):
_core.reschedule(task, Error(ClosedResourceError()))
self._buf.put_channels -= 1
if self._buf.put_channels == 0:
assert not self._buf.put_tasks
for task in list(self._buf.get_tasks):
_core.reschedule(task, Error(EndOfChannel()))

def __enter__(self):
return self

def __exit__(self, *args):
self.close()


class GetChannel:
def __init__(self, buf):
self._buf = buf
self.closed = False
self._tasks = set()
self._buf.get_channels += 1

@_core.disable_ki_protection
def get_nowait(self):
if self.closed:
raise _core.ClosedResourceError
buf = self._buf
if buf.put_tasks:
task, value = buf.put_tasks.popitem(last=False)
task.custom_sleep_data._tasks.remove(task)
_core.reschedule(task)
buf.data.append(value)
# Fall through
if buf.data:
return buf.data.popleft()
if not buf.put_channels:
raise EndOfChannel
raise _core.WouldBlock

@_core.disable_ki_protection
async def get(self):
await _core.checkpoint_if_cancelled()
try:
value = self.get_nowait()
except _core.WouldBlock:
pass
else:
await _core.cancel_shielded_checkpoint()
return value

task = _core.current_task()
self._tasks.add(task)
self._buf.get_tasks[task] = None
task.custom_sleep_data = self

def abort_fn(_):
self._tasks.remove(task)
del self._buf.get_tasks[task]
return _core.Abort.SUCCEEDED

return await _core.wait_task_rescheduled(abort_fn)

@_core.disable_ki_protection
def close(self):
if self.closed:
return
self.closed = True
for task in list(self._tasks):
_core.reschedule(task, Error(ClosedResourceError()))
self._buf.get_channels -= 1
if self._buf.get_channels == 0:
assert not self._buf.get_tasks
for task in list(self._buf.put_tasks):
_core.reschedule(task, Error(BrokenChannelError()))
# XX: or if we're losing data, maybe we should raise a
# BrokenChannelError here?
self._buf.data.clear()

@aiter_compat
def __aiter__(self):
return self

async def __anext__(self):
try:
return await self.get()
except EndOfChannel:
raise StopAsyncIteration

def __enter__(self):
return self

def __exit__(self, *args):
self.close()
76 changes: 76 additions & 0 deletions trio/tests/test_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import pytest

from ..testing import wait_all_tasks_blocked, assert_checkpoints
import trio
from trio import open_channel, EndOfChannel, BrokenChannelError

async def test_channel():
with pytest.raises(TypeError):
open_channel(1.0)
with pytest.raises(ValueError):
open_channel(-1)

p, g = open_channel(2)
repr(p) # smoke test
repr(g) # smoke test

p.put_nowait(1)
with assert_checkpoints():
await p.put(2)
with pytest.raises(trio.WouldBlock):
p.put_nowait(None)

with assert_checkpoints():
assert await g.get() == 1
assert g.get_nowait() == 2
with pytest.raises(trio.WouldBlock):
g.get_nowait()

p.put_nowait("last")
p.close()
with pytest.raises(trio.ClosedResourceError):
await p.put("too late")

assert g.get_nowait() == "last"
with pytest.raises(EndOfChannel):
await g.get()


async def test_553(autojump_clock):
p, g = open_channel(1)
with trio.move_on_after(10) as timeout_scope:
await g.get()
assert timeout_scope.cancelled_caught
await p.put("Test for PR #553")


async def test_channel_fan_in():
async def producer(put_channel, i):
# We close our handle when we're done with it
with put_channel:
for j in range(3 * i, 3 * (i + 1)):
await put_channel.put(j)

put_channel, get_channel = open_channel(0)
async with trio.open_nursery() as nursery:
# We hand out clones to all the new producers, and then close the
# original.
with put_channel:
for i in range(10):
nursery.start_soon(producer, put_channel.clone(), i)

got = []
async for value in get_channel:
got.append(value)

got.sort()
assert got == list(range(30))


# tests to add:
# - put/get close wakes other puts/gets on close
# - for put, only wakes the ones on the same handle
# - get close -> also wakes puts
# - and future puts raise
# - all the queue tests, including e.g. fairness tests
# - statistics

0 comments on commit abca8bd

Please sign in to comment.