Add streaming API
jpsamaroo authored and JamesWrigley committed Nov 15, 2024
1 parent cbac605 commit e441bd0
Showing 8 changed files with 762 additions and 1 deletion.
1 change: 0 additions & 1 deletion Project.toml
@@ -10,7 +10,6 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94"
Mmap = "a63ad114-7e13-5084-954f-fe012c677804"
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
1 change: 1 addition & 0 deletions docs/make.jl
@@ -22,6 +22,7 @@ makedocs(;
"Task Spawning" => "task-spawning.md",
"Data Management" => "data-management.md",
"Distributed Arrays" => "darray.md",
"Streaming Tasks" => "streaming.md",
"Scopes" => "scopes.md",
"Processors" => "processors.md",
"Task Queues" => "task-queues.md",
105 changes: 105 additions & 0 deletions docs/src/streaming.md
@@ -0,0 +1,105 @@
# Streaming Tasks

Dagger tasks have a limited lifetime - they are created, execute, finish, and
are eventually destroyed when they're no longer needed. Thus, to run the same
kind of computation over and over, one must re-create a similar set of tasks
for each unit of data that needs processing.

This might be fine for computations which take a long time to run (thus
dwarfing the cost of task creation, which is quite small), or when working with
a limited set of data, but this approach is not great for doing lots of small
computations on a large (or endless) amount of data: processing image frames
from a webcam, reacting to messages from a message bus, or reading samples from
a software radio, for example. All of these workloads are better suited to a
"streaming" model of data processing, where data is simply piped into a
continuously-running task (or DAG of tasks) forever, or until the data runs
out.

Thankfully, if you have a problem which is best modeled as a streaming system
of tasks, Dagger has you covered! Building on its support for
["Task Queues"](@ref), Dagger provides a means to convert an entire DAG of
tasks into a streaming DAG, where data flows into and out of each task
asynchronously, using the `spawn_streaming` function:

```julia
Dagger.spawn_streaming() do # enters a streaming region
    vals = Dagger.@spawn rand()
    print_vals = Dagger.@spawn println(vals)
end # exits the streaming region, and starts the DAG running
```

In the above example, `vals` is a Dagger task which has been transformed to run
in a streaming manner - instead of just calling `rand()` once and returning its
result, it will re-run `rand()` endlessly, continuously producing new random
values. In typical Dagger style, `print_vals` is a Dagger task which depends on
`vals`, but in streaming form - it will continuously `println` the random
values produced from `vals`. Both tasks will run forever, and will run
efficiently, only doing the work necessary to generate, transfer, and consume
values.

As the comments point out, `spawn_streaming` creates a streaming region, during
which `vals` and `print_vals` are created and configured. Both tasks are halted
until `spawn_streaming` returns, allowing large DAGs to be built all at once,
without any task losing a single value. If desired, streaming regions can be
connected, although some values might be lost while tasks are being connected:

```julia
vals = Dagger.spawn_streaming() do
    Dagger.@spawn rand()
end

# Some values might be generated by `vals` but thrown away
# before `print_vals` is fully set up and connected to it

print_vals = Dagger.spawn_streaming() do
    Dagger.@spawn println(vals)
end
```

More complicated streaming DAGs can be constructed without doing anything
differently. For example, we can generate multiple streams of random
numbers, write them all to their own files, and print the combined results:

```julia
Dagger.spawn_streaming() do
    all_vals = [Dagger.spawn(rand) for i in 1:4]
    all_vals_written = map(1:4) do i
        Dagger.spawn(all_vals[i]) do val
            open("results_$i.txt"; write=true, create=true, append=true) do io
                println(io, repr(val))
            end
            return val
        end
    end
    Dagger.spawn(all_vals_written...) do all_vals_written...
        vals_sum = sum(all_vals_written)
        println(vals_sum)
    end
end
```

If you want to stop the streaming DAG and tear it all down, you can call
`Dagger.kill!(all_vals[1])` (or `Dagger.kill!(all_vals_written[2])`, etc.);
the kill propagates throughout the DAG.
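
For example, here is a minimal sketch (reusing the single-stream example from
earlier; any task in the DAG may be passed to `kill!`):

```julia
vals = Dagger.spawn_streaming() do
    Dagger.@spawn rand()
end

# ... later, once enough values have been seen ...
Dagger.kill!(vals) # stops `vals`, and the kill propagates to connected tasks
```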

Alternatively, tasks can stop themselves from the inside with
`finish_streaming`, optionally returning a value that can be `fetch`'d. Let's
do this when our randomly-drawn number falls below some arbitrary threshold:

```julia
vals = Dagger.spawn_streaming() do
    Dagger.spawn() do
        x = rand()
        if x < 0.001
            # That's good enough, let's be done
            return Dagger.finish_streaming("Finished!")
        end
        return x
    end
end
fetch(vals)
```

In this example, the call to `fetch` will hang (while random numbers continue
to be drawn) until a drawn number is less than 0.001; at that point, `fetch`
will return "Finished!", and the task `vals` will have terminated.
5 changes: 5 additions & 0 deletions src/Dagger.jl
@@ -67,6 +67,11 @@ include("sch/Sch.jl"); using .Sch
# Data dependency task queue
include("datadeps.jl")

# Streaming
include("stream-buffers.jl")
include("stream-fetchers.jl")
include("stream.jl")

# Array computations
include("array/darray.jl")
include("array/alloc.jl")
7 changes: 7 additions & 0 deletions src/sch/eager.jl
@@ -124,6 +124,13 @@ function eager_cleanup(state, uid)
        # N.B. cache and errored expire automatically
        delete!(state.thunk_dict, tid)
    end
    # Remove this task's stream (if present) from the registry on worker 1
    remotecall_wait(1, uid) do uid
        lock(EAGER_THUNK_STREAMS) do global_streams
            if haskey(global_streams, uid)
                delete!(global_streams, uid)
            end
        end
    end
end

function _find_thunk(e::Dagger.DTask)
202 changes: 202 additions & 0 deletions src/stream-buffers.jl
@@ -0,0 +1,202 @@
"""
A buffer that drops all elements put into it. Only to be used as the output
buffer for a task - will throw if attached as an input.
"""
struct DropBuffer{T} end
DropBuffer{T}(_) where T = DropBuffer{T}()
Base.isempty(::DropBuffer) = true
isfull(::DropBuffer) = false
Base.put!(::DropBuffer, _) = nothing
Base.take!(::DropBuffer) = error("Cannot `take!` from a DropBuffer")
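
#= Usage sketch (illustrative): writes are accepted but vanish, reads throw.
db = DropBuffer{Int}()
put!(db, 42)   # accepted, then silently dropped
isempty(db)    # always true
take!(db)      # throws
=#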

"A process-local buffer backed by a `Channel{T}`."
struct ChannelBuffer{T}
channel::Channel{T}
len::Int
count::Threads.Atomic{Int}
ChannelBuffer{T}(len::Int=1024) where T =
new{T}(Channel{T}(len), len, Threads.Atomic{Int}(0))
end
Base.isempty(cb::ChannelBuffer) = isempty(cb.channel)
isfull(cb::ChannelBuffer) = cb.count[] == cb.len
function Base.put!(cb::ChannelBuffer{T}, x) where T
put!(cb.channel, convert(T, x))
Threads.atomic_add!(cb.count, 1)
end
function Base.take!(cb::ChannelBuffer)
take!(cb.channel)
Threads.atomic_sub!(cb.count, 1)
end
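
#= Usage sketch (illustrative): values pass through in FIFO order.
cb = ChannelBuffer{Int}(2)
put!(cb, 1); put!(cb, 2)
isfull(cb)   # true
take!(cb)    # returns 1
=#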

"A cross-worker buffer backed by a `RemoteChannel{T}`."
struct RemoteChannelBuffer{T}
channel::RemoteChannel{Channel{T}}
len::Int
count::Threads.Atomic{Int}
RemoteChannelBuffer{T}(len::Int=1024) where T =
new{T}(RemoteChannel(()->Channel{T}(len)), len, Threads.Atomic{Int}(0))
end
Base.isempty(cb::RemoteChannelBuffer) = isempty(cb.channel)
isfull(cb::RemoteChannelBuffer) = cb.count[] == cb.len
function Base.put!(cb::RemoteChannelBuffer{T}, x) where T
put!(cb.channel, convert(T, x))
Threads.atomic_add!(cb.count, 1)
end
function Base.take!(cb::RemoteChannelBuffer)
take!(cb.channel)
Threads.atomic_sub!(cb.count, 1)
end
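
#= Usage sketch (illustrative): same protocol as ChannelBuffer, but the
backing channel lives on the creating worker, so it can bridge workers.
rcb = RemoteChannelBuffer{Int}(8)
put!(rcb, 1)
take!(rcb)   # returns 1
=#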

"A process-local ring buffer."
mutable struct ProcessRingBuffer{T}
read_idx::Int
write_idx::Int
@atomic count::Int
buffer::Vector{T}
function ProcessRingBuffer{T}(len::Int=1024) where T
buffer = Vector{T}(undef, len)
return new{T}(1, 1, 0, buffer)
end
end
Base.isempty(rb::ProcessRingBuffer) = (@atomic rb.count) == 0
isfull(rb::ProcessRingBuffer) = (@atomic rb.count) == length(rb.buffer)
function Base.put!(rb::ProcessRingBuffer{T}, x) where T
len = length(rb.buffer)
while (@atomic rb.count) == len
yield()
end
to_write_idx = mod1(rb.write_idx, len)
rb.buffer[to_write_idx] = convert(T, x)
rb.write_idx += 1
@atomic rb.count += 1
end
function Base.take!(rb::ProcessRingBuffer)
while (@atomic rb.count) == 0
yield()
end
to_read_idx = rb.read_idx
rb.read_idx += 1
@atomic rb.count -= 1
to_read_idx = mod1(to_read_idx, length(rb.buffer))
return rb.buffer[to_read_idx]
end
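
#= Usage sketch (illustrative): a bounded FIFO whose `put!`/`take!` spin
(yielding to the scheduler) while the buffer is full/empty.
rb = ProcessRingBuffer{Int}(4)
put!(rb, 1); put!(rb, 2)
take!(rb)    # returns 1
=#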

#= TODO
"A server-local ring buffer backed by shared-memory."
mutable struct ServerRingBuffer{T}
    read_idx::Int
    write_idx::Int
    @atomic count::Int
    buffer::Vector{T}
    function ServerRingBuffer{T}(len::Int=1024) where T
        buffer = Vector{T}(undef, len)
        return new{T}(1, 1, 0, buffer)
    end
end
Base.isempty(rb::ServerRingBuffer) = (@atomic rb.count) == 0
function Base.put!(rb::ServerRingBuffer{T}, x) where T
    len = length(rb.buffer)
    while (@atomic rb.count) == len
        yield()
    end
    to_write_idx = mod1(rb.write_idx, len)
    rb.buffer[to_write_idx] = convert(T, x)
    rb.write_idx += 1
    @atomic rb.count += 1
end
function Base.take!(rb::ServerRingBuffer)
    while (@atomic rb.count) == 0
        yield()
    end
    to_read_idx = rb.read_idx
    rb.read_idx += 1
    @atomic rb.count -= 1
    to_read_idx = mod1(to_read_idx, length(rb.buffer))
    return rb.buffer[to_read_idx]
end
=#

#=
"A TCP-based ring buffer."
mutable struct TCPRingBuffer{T}
    read_idx::Int
    write_idx::Int
    @atomic count::Int
    buffer::Vector{T}
    function TCPRingBuffer{T}(len::Int=1024) where T
        buffer = Vector{T}(undef, len)
        return new{T}(1, 1, 0, buffer)
    end
end
Base.isempty(rb::TCPRingBuffer) = (@atomic rb.count) == 0
function Base.put!(rb::TCPRingBuffer{T}, x) where T
    len = length(rb.buffer)
    while (@atomic rb.count) == len
        yield()
    end
    to_write_idx = mod1(rb.write_idx, len)
    rb.buffer[to_write_idx] = convert(T, x)
    rb.write_idx += 1
    @atomic rb.count += 1
end
function Base.take!(rb::TCPRingBuffer)
    while (@atomic rb.count) == 0
        yield()
    end
    to_read_idx = rb.read_idx
    rb.read_idx += 1
    @atomic rb.count -= 1
    to_read_idx = mod1(to_read_idx, length(rb.buffer))
    return rb.buffer[to_read_idx]
end
=#

#=
"""
A flexible puller which switches to the most efficient buffer type based
on the sender and receiver locations.
"""
mutable struct UniBuffer{T}
    buffer::Union{ProcessRingBuffer{T}, Nothing}
end
function initialize_stream_buffer!(::Type{UniBuffer{T}}, T, send_proc, recv_proc, buffer_amount) where T
    if buffer_amount == 0
        error("Return NullBuffer")
    end
    send_osproc = get_parent(send_proc)
    recv_osproc = get_parent(recv_proc)
    if send_osproc.pid == recv_osproc.pid
        inner = RingBuffer{T}(buffer_amount)
    elseif system_uuid(send_osproc.pid) == system_uuid(recv_osproc.pid)
        inner = ProcessBuffer{T}(buffer_amount)
    else
        inner = RemoteBuffer{T}(buffer_amount)
    end
    return UniBuffer{T}(buffer_amount)
end
struct LocalPuller{T,B}
    buffer::B{T}
    id::UInt
    function LocalPuller{T,B}(id::UInt, buffer_amount::Integer) where {T,B}
        buffer = initialize_stream_buffer!(B, T, buffer_amount)
        return new{T,B}(buffer, id)
    end
end
function Base.take!(pull::LocalPuller{T,B}) where {T,B}
    if pull.buffer === nothing
        pull.buffer =
            error("Return NullBuffer")
    end
    value = take!(pull.buffer)
end
function initialize_input_stream!(stream::Stream{T,B}, id::UInt, send_proc::Processor, recv_proc::Processor, buffer_amount::Integer) where {T,B}
    local_buffer = remotecall_fetch(stream.ref.handle.owner, stream.ref.handle, id) do ref, id
        local_buffer, remote_buffer = initialize_stream_buffer!(B, T, send_proc, recv_proc, buffer_amount)
        ref.buffers[id] = remote_buffer
        return local_buffer
    end
    stream.buffer = local_buffer
    return stream
end
=#
24 changes: 24 additions & 0 deletions src/stream-fetchers.jl
@@ -0,0 +1,24 @@
struct RemoteFetcher end
function stream_fetch_values!(::Type{RemoteFetcher}, T, store_ref::Chunk{Store_remote}, buffer::Blocal, id::UInt) where {Store_remote, Blocal}
    if store_ref.handle.owner == myid()
        # The store lives on this process; pull values directly into the
        # local buffer until it's full
        store = fetch(store_ref)::Store_remote
        while !isfull(buffer)
            value = take!(store, id)::T
            put!(buffer, value)
        end
    else
        # The store lives on another worker; drain all currently-available
        # values in a single remote call, then fill the local buffer
        tls = Dagger.get_tls()
        values = remotecall_fetch(store_ref.handle.owner, store_ref.handle, id, T, Store_remote) do store_ref, id, T, Store_remote
            store = MemPool.poolget(store_ref)::Store_remote
            values = T[]
            while !isempty(store, id)
                value = take!(store, id)::T
                push!(values, value)
            end
            return values
        end::Vector{T}
        for value in values
            put!(buffer, value)
        end
    end
end
