Skip to content

Commit

Permalink
Add 'future' type
Browse files Browse the repository at this point in the history
  • Loading branch information
lukewagner committed Oct 7, 2024
1 parent 02d3636 commit 9e92684
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 55 deletions.
162 changes: 112 additions & 50 deletions design/mvp/canonical-abi/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ class BorrowType(ValType):
class StreamType(ValType):
t: ValType

@dataclass
class FutureType(ValType):
t: ValType

### CallContext

class CallContext:
Expand Down Expand Up @@ -202,7 +206,7 @@ class CanonicalOptions:

class ComponentInstance:
resources: ResourceTables
waitables: Table[Subtask|StreamHandle]
waitables: Table[Subtask|StreamHandle|FutureHandle]
num_tasks: int
may_leave: bool
backpressure: bool
Expand All @@ -211,7 +215,7 @@ class ComponentInstance:

def __init__(self):
self.resources = ResourceTables()
self.waitables = Table[Subtask|StreamHandle]()
self.waitables = Table[Subtask|StreamHandle|FutureHandle]()
self.num_tasks = 0
self.may_leave = True
self.backpressure = False
Expand Down Expand Up @@ -308,6 +312,8 @@ class EventCode(IntEnum):
YIELDED = 4
STREAM_READ = 5
STREAM_WRITE = 6
FUTURE_READ = 7
FUTURE_WRITE = 8

EventTuple = tuple[EventCode, int, int]
EventCallback = Callable[[], Optional[EventTuple]]
Expand Down Expand Up @@ -565,7 +571,7 @@ def lower(self, vs):
store_list_into_valid_range(self._cx, vs, ptr, self._t)
self._progress += len(vs)

class Stream:
class AsyncValue:
closed: Callable[[]]
read: Callable[[WritableBuffer, OnBlockCallback], Awaitable]
cancel_read: Callable[[WritableBuffer, OnBlockCallback], Awaitable]
Expand All @@ -579,38 +585,38 @@ def __init__(self, impl):
self.maybe_writer_handle_index = impl.maybe_writer_handle_index
self.close = impl.close

class StreamHandle:
stream: Stream
class AsyncValueHandle:
async_value: AsyncValue
t: ValType
cx: Optional[CallContext]
copying_buffer: Optional[Buffer]

def __init__(self, stream, t, cx):
self.stream = stream
def __init__(self, async_value, t, cx):
self.async_value = async_value
self.t = t
self.cx = cx
self.copying_buffer = None

def drop(self):
trap_if(self.copying_buffer)
if not self.stream.closed():
self.stream.close()
if not self.async_value.closed():
self.async_value.close()
if self.cx:
self.cx.task.need_to_drop -= 1

class ReadableStreamHandle(StreamHandle):
class ReadableAsyncValueHandle(AsyncValueHandle):
async def copy(self, dst, on_block):
await self.stream.read(dst, on_block)
await self.async_value.read(dst, on_block)
async def cancel_copy(self, dst, on_block):
await self.stream.cancel_read(dst, on_block)
await self.async_value.cancel_read(dst, on_block)

class WritableStreamHandle(StreamHandle):
class WritableAsyncValueHandle(AsyncValueHandle):
closed: bool
rendezvous_buffer: Optional[Buffer]
rendezvous_future: Optional[asyncio.Future]

def __init__(self, t):
super().__init__(Stream(self), t, cx = None)
super().__init__(AsyncValue(self), t, cx = None)
self.closed = False
self.rendezvous_buffer = None
self.rendezvous_future = None
Expand All @@ -623,7 +629,7 @@ async def copy(self, src, on_block):
async def read(self, dst, on_block):
await self.rendezvous('read', dst, on_block)
async def rendezvous(self, direction, buffer, on_block):
assert(not self.stream.closed())
assert(not self.async_value.closed())
if self.rendezvous_buffer:
ncopy = min(buffer.remain(), self.rendezvous_buffer.remain())
assert(ncopy > 0)
Expand All @@ -648,15 +654,15 @@ async def cancel_copy(self, src, on_block):
async def cancel_read(self, dst, on_block):
await self.cancel_rendezvous('read', dst, on_block)
async def cancel_rendezvous(self, direction, buffer, on_block):
assert(not self.stream.closed())
assert(not self.async_value.closed())
if self.rendezvous_buffer is buffer:
self.rendezvous_buffer = None
if self.rendezvous_future:
self.rendezvous_future.set_result(None)
self.rendezvous_future = None

def maybe_writer_handle_index(self, inst):
assert(not self.stream.closed())
assert(not self.async_value.closed())
if inst is self.cx.inst:
return self.cx.task.inst.waitables.array.index(self)
return None
Expand All @@ -669,11 +675,30 @@ def close(self):
self.rendezvous_future.set_result(None)
self.rendezvous_future = None

class StreamHandle: pass
class ReadableStreamHandle(StreamHandle, ReadableAsyncValueHandle): pass
class WritableStreamHandle(StreamHandle, WritableAsyncValueHandle): pass

class FutureHandle:
async def copy(self, buffer, on_block):
assert(buffer.remain() == 1)
await super().copy(buffer, on_block)
assert(buffer.remain() == 0)
if not self.async_value.closed():
self.async_value.close()

def drop(self):
trap_if(not self.async_value.closed())
super().drop()

class ReadableFutureHandle(FutureHandle, ReadableAsyncValueHandle): pass
class WritableFutureHandle(FutureHandle, WritableAsyncValueHandle): pass

### Type utilities

def contains_async(t):
match t:
case StreamType():
case StreamType() | FutureType():
return True
case PrimValType() | OwnType() | BorrowType():
return False
Expand Down Expand Up @@ -721,7 +746,7 @@ def alignment(t):
case VariantType(cases) : return alignment_variant(cases)
case FlagsType(labels) : return alignment_flags(labels)
case OwnType() | BorrowType() : return 4
case StreamType() : return 4
case StreamType() | FutureType() : return 4

def alignment_list(elem_type, maybe_length):
if maybe_length is not None:
Expand Down Expand Up @@ -778,7 +803,7 @@ def elem_size(t):
case VariantType(cases) : return elem_size_variant(cases)
case FlagsType(labels) : return elem_size_flags(labels)
case OwnType() | BorrowType() : return 4
case StreamType() : return 4
case StreamType() | FutureType() : return 4

def elem_size_list(elem_type, maybe_length):
if maybe_length is not None:
Expand Down Expand Up @@ -839,6 +864,7 @@ def load(cx, ptr, t):
case OwnType() : return lift_own(cx, load_int(cx, ptr, 4), t)
case BorrowType() : return lift_borrow(cx, load_int(cx, ptr, 4), t)
case StreamType(t) : return lift_stream(cx, load_int(cx, ptr, 4), t)
case FutureType(t) : return lift_future(cx, load_int(cx, ptr, 4), t)

def load_int(cx, ptr, nbytes, signed = False):
return int.from_bytes(cx.opts.memory[ptr : ptr+nbytes], 'little', signed=signed)
Expand Down Expand Up @@ -992,21 +1018,27 @@ def lift_borrow(cx, i, t):
cx.add_lender(h)
return h.rep

def lift_stream(cx, i, elem_type):
def lift_stream(cx, i, t):
return lift_async_value(ReadableStreamHandle, WritableStreamHandle, cx, i, t)

def lift_future(cx, i, t):
return lift_async_value(ReadableFutureHandle, WritableFutureHandle, cx, i, t)

def lift_async_value(ReadableHandleT, WritableHandleT, cx, i, t):
h = cx.inst.waitables.get(i)
trap_if(not isinstance(h, StreamHandle))
trap_if(h.t != elem_type)
trap_if(not isinstance(h, ReadableHandleT|WritableHandleT))
trap_if(h.t != t)
match h:
case ReadableStreamHandle():
case ReadableHandleT():
trap_if(h.copying_buffer)
h.cx.task.need_to_drop -= 1
cx.inst.waitables.remove(i)
case WritableStreamHandle():
case WritableHandleT():
trap_if(h.cx is not None)
assert(not h.copying_buffer)
h.cx = cx
h.cx.task.need_to_drop += 1
return h.stream
return h.async_value

### Storing

Expand Down Expand Up @@ -1034,6 +1066,7 @@ def store(cx, v, t, ptr):
case OwnType() : store_int(cx, lower_own(cx, v, t), ptr, 4)
case BorrowType() : store_int(cx, lower_borrow(cx, v, t), ptr, 4)
case StreamType(t) : store_int(cx, lower_stream(cx, v, t), ptr, 4)
case FutureType(t) : store_int(cx, lower_future(cx, v, t), ptr, 4)

def store_int(cx, v, ptr, nbytes, signed = False):
cx.opts.memory[ptr : ptr+nbytes] = int.to_bytes(v, nbytes, 'little', signed=signed)
Expand Down Expand Up @@ -1296,17 +1329,23 @@ def lower_borrow(cx, rep, t):
cx.need_to_drop += 1
return cx.inst.resources.add(t.rt, h)

def lower_stream(cx, stream, elem_type):
assert(isinstance(stream, Stream))
if (i := stream.maybe_writer_handle_index(cx.inst)):
def lower_stream(cx, v, t):
return lower_async_value(ReadableStreamHandle, WritableStreamHandle, cx, v, t)

def lower_future(cx, v, t):
return lower_async_value(ReadableFutureHandle, WritableFutureHandle, cx, v, t)

def lower_async_value(ReadableHandleT, WritableHandleT, cx, v, t):
assert(isinstance(v, AsyncValue))
if (i := v.maybe_writer_handle_index(cx.inst)):
h = cx.inst.waitables.array[i]
assert(isinstance(h, WritableStreamHandle))
assert(isinstance(h, WritableHandleT))
h.cx.task.need_to_drop -= 1
h.cx = None
assert(2**31 > Table.MAX_LENGTH >= i)
return i | (2**31)
else:
h = ReadableStreamHandle(stream, elem_type, cx)
h = ReadableHandleT(v, t, cx)
cx.task.need_to_drop += 1
return cx.inst.waitables.add(h)

Expand Down Expand Up @@ -1360,7 +1399,7 @@ def flatten_type(t):
case VariantType(cases) : return flatten_variant(cases)
case FlagsType(labels) : return ['i32']
case OwnType() | BorrowType() : return ['i32']
case StreamType() : return ['i32']
case StreamType() | FutureType() : return ['i32']

def flatten_list(elem_type, maybe_length):
if maybe_length is not None:
Expand Down Expand Up @@ -1428,6 +1467,7 @@ def lift_flat(cx, vi, t):
case OwnType() : return lift_own(cx, vi.next('i32'), t)
case BorrowType() : return lift_borrow(cx, vi.next('i32'), t)
case StreamType(t) : return lift_stream(cx, vi.next('i32'), t)
case FutureType(t) : return lift_future(cx, vi.next('i32'), t)

def lift_flat_unsigned(vi, core_width, t_width):
i = vi.next('i' + str(core_width))
Expand Down Expand Up @@ -1520,6 +1560,7 @@ def lower_flat(cx, v, t):
case OwnType() : return [lower_own(cx, v, t)]
case BorrowType() : return [lower_borrow(cx, v, t)]
case StreamType(t) : return [lower_stream(cx, v, t)]
case FutureType(t) : return [lower_future(cx, v, t)]

def lower_flat_signed(i, core_bits):
if i < 0:
Expand Down Expand Up @@ -1757,39 +1798,53 @@ async def canon_task_yield(task):
await task.yield_()
return []

### 🔀 `canon stream.new`
### 🔀 `canon {stream,future}.new`

async def canon_stream_new(elem_type, task):
trap_if(not task.inst.may_leave)
h = WritableStreamHandle(elem_type)
return [ task.inst.waitables.add(h) ]

### 🔀 `canon stream.read` and `canon stream.write`
async def canon_future_new(t, task):
trap_if(not task.inst.may_leave)
h = WritableFutureHandle(t)
return [ task.inst.waitables.add(h) ]

### 🔀 `canon {stream,future}.{read,write}`

async def canon_stream_read(task, i, ptr, n):
return await stream_copy(ReadableStreamHandle, WritableBuffer,
task, i, ptr, n, EventCode.STREAM_READ)
return await canon_copy(ReadableStreamHandle, WritableBuffer,
task, i, ptr, n, EventCode.STREAM_READ)

async def canon_stream_write(task, i, ptr, n):
return await stream_copy(WritableStreamHandle, ReadableBuffer,
task, i, ptr, n, EventCode.STREAM_WRITE)
return await canon_copy(WritableStreamHandle, ReadableBuffer,
task, i, ptr, n, EventCode.STREAM_WRITE)

async def stream_copy(StreamHandleT, BufferT, task, i, ptr, n, event_code):
async def canon_future_read(task, i, ptr):
return await canon_copy(ReadableFutureHandle, WritableBuffer,
task, i, ptr, 1, EventCode.FUTURE_READ)

async def canon_future_write(task, i, ptr):
return await canon_copy(WritableFutureHandle, ReadableBuffer,
task, i, ptr, 1, EventCode.FUTURE_WRITE)

async def canon_copy(HandleT, BufferT, task, i, ptr, n, event_code):
trap_if(not task.inst.may_leave)
h = task.inst.waitables.get(i)
trap_if(not isinstance(h, StreamHandleT))
trap_if(not isinstance(h, HandleT))
trap_if(not h.cx)
trap_if(h.copying_buffer)
buffer = BufferT(h.cx, h.t, ptr, n)
if h.stream.closed():
if h.async_value.closed():
trap_if(issubclass(HandleT, FutureHandle))
flat_results = [CLOSED]
else:
async def do_copy(on_block):
await h.copy(buffer, on_block)
def stream_event():
if h.copying_buffer is buffer:
h.copying_buffer = None
return (event_code, i, pack_stream_result(buffer, h))
return (event_code, i, copy_result(HandleT, buffer, h))
else:
return None
h.cx.task.notify(stream_event)
Expand All @@ -1798,31 +1853,38 @@ def stream_event():
h.copying_buffer = buffer
flat_results = [BLOCKED]
case Returned():
flat_results = [pack_stream_result(buffer, h)]
flat_results = [copy_result(HandleT, buffer, h)]
return flat_results

def pack_stream_result(buffer, h):
def copy_result(HandleT, buffer, h):
if buffer.progress():
return buffer.progress()
assert(h.stream.closed())
assert(h.async_value.closed())
assert(not issubclass(HandleT, FutureHandle))
return CLOSED

BLOCKED = 0xffff_ffff
CLOSED = 0x8000_0000
assert(Buffer.MAX_LENGTH < CLOSED < BLOCKED)

### 🔀 `canon stream.cancel-read` and `canon stream.cancel-writing`
### 🔀 `canon {stream,future}.cancel-{read,write}`

async def canon_stream_cancel_read(sync, task, i):
return await stream_cancel_copy(ReadableStreamHandle, sync, task, i)
return await canon_cancel_copy(ReadableStreamHandle, sync, task, i)

async def canon_stream_cancel_write(sync, task, i):
return await stream_cancel_copy(WritableStreamHandle, sync, task, i)
return await canon_cancel_copy(WritableStreamHandle, sync, task, i)

async def canon_future_cancel_read(sync, task, i):
return await canon_cancel_copy(ReadableFutureHandle, sync, task, i)

async def canon_future_cancel_write(sync, task, i):
return await canon_cancel_copy(WritableFutureHandle, sync, task, i)

async def stream_cancel_copy(StreamHandleT, sync, task, i):
async def canon_cancel_copy(HandleT, sync, task, i):
trap_if(not task.inst.may_leave)
h = task.inst.waitables.get(i)
trap_if(not isinstance(h, StreamHandleT))
trap_if(not isinstance(h, HandleT))
trap_if(not h.copying_buffer)
if sync:
await task.call_sync(h.cancel_copy, h.copying_buffer)
Expand Down
Loading

0 comments on commit 9e92684

Please sign in to comment.