fixup! Add streaming API

jpsamaroo committed Dec 3, 2024
1 parent e772af0 commit 478b2d0
Showing 4 changed files with 89 additions and 148 deletions.
34 changes: 34 additions & 0 deletions docs/src/index.md
@@ -394,3 +394,37 @@ Dagger.@spawn copyto!(C, X)

In contrast to the previous example, here the tasks are executed without argument annotations. As a result, the `copyto!` task may execute before the `sort!` task, leading to unexpected results in the output array `C`.

## Quickstart: Streaming

Dagger.jl provides a streaming API that allows you to process data as it becomes available, rather than waiting for the entire dataset to be loaded into memory.

For more details, see [Streaming](@ref).

### Syntax

The `Dagger.spawn_streaming()` function creates a streaming region, where
tasks execute continuously, processing data as it becomes available:

```julia
# Open a file to write to on this worker
f = Dagger.@mutable open("output.txt", "w")
t = Dagger.spawn_streaming() do
    # Generate random numbers continuously
    val = Dagger.@spawn rand()
    # Write each random number to a file
    Dagger.@spawn (f, val) -> begin
        if val < 0.01
            # Finish streaming when the random number is less than 0.01
            Dagger.finish_stream()
        end
        println(f, val)
    end
end
# Wait for all values to be generated and written
wait(t)
```

The above example demonstrates a streaming region that continuously generates
random numbers and writes each one to a file. The streaming region terminates
when a random number less than 0.01 is generated; at that point,
`Dagger.finish_stream()` is called, which exits the current streaming task.
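
For bounded streams, the runtime also reads a `stream_max_evals` task option
(see the `initialize_streaming!` changes in `src/stream.jl` below), which caps
how many times a streaming task executes. A minimal sketch, assuming the
option can be passed through `Dagger.with_options` like other task options:

```julia
# Hypothetical sketch: cap the stream at 100 evaluations, assuming
# `stream_max_evals` is forwarded into task options via `with_options`.
t = Dagger.with_options(; stream_max_evals=100) do
    Dagger.spawn_streaming() do
        # Each task runs at most 100 times, then the stream shuts down
        val = Dagger.@spawn rand()
        Dagger.@spawn println(val)
    end
end
wait(t)
```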
7 changes: 4 additions & 3 deletions docs/src/streaming.md
@@ -79,8 +79,9 @@ end
```

If you want to stop the streaming DAG and tear it all down, you can call
`Dagger.kill!(all_vals[1])` (or `Dagger.kill!(all_vals_written[2])`, etc.; the
kill propagates throughout the DAG).
`Dagger.cancel!.(all_vals)` and `Dagger.cancel!.(all_vals_written)` to
terminate each streaming task. In the future, a more convenient way to tear
down a full DAG will be added; for now, each task must be cancelled individually.
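
For concreteness, a brief sketch of this teardown pattern, assuming
`all_vals` and `all_vals_written` are the vectors of streaming tasks
constructed earlier in this document:

```julia
# Cancel each streaming task individually; broadcasting `cancel!` over
# the task vectors terminates the whole DAG.
Dagger.cancel!.(all_vals)
Dagger.cancel!.(all_vals_written)
```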

Alternatively, tasks can stop themselves from the inside with
`finish_stream`, optionally returning a value that can be `fetch`'d. Let's
@@ -102,4 +103,4 @@ fetch(vals)

In this example, the call to `fetch` will hang (while random numbers continue
to be drawn), until a drawn number is less than 0.001; at that point, `fetch`
will return with "Finished!", and the task `vals` will have terminated.
will return with `"Finished!"`, and the task `vals` will have terminated.
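
The example that the paragraph above refers to is elided in this diff. A
hedged reconstruction of its shape, using the `finish_stream(value; result)`
form implied by the `FinishStream` handling in `src/stream.jl` (the helper
`draw_one` is illustrative, not from the original):

```julia
# Illustrative sketch, not the elided example itself
function draw_one()
    x = rand()
    if x < 0.001
        # Stop this streaming task, recording "Finished!" as its result
        return Dagger.finish_stream(x; result="Finished!")
    end
    return x
end

vals = Dagger.spawn_streaming() do
    Dagger.@spawn draw_one()
end
fetch(vals)  # Blocks until a draw < 0.001, then returns "Finished!"
```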
59 changes: 0 additions & 59 deletions src/stream-transfer.jl
@@ -67,62 +67,3 @@ function stream_pull_values!(fetcher::RemoteChannelFetcher, T, our_store::Stream
@lock our_store.lock notify(our_store.lock)
@dagdebug our_tid :stream_pull "finished putting input value: $their_tid -> $our_tid"
end

#= TODO: Remove me
# This is a bad implementation because it wants to sleep on the remote side to
# wait for values, but this isn't semantically valid when done with MemPool.access_ref
struct RemoteFetcher end
function stream_push_values!(::Type{RemoteFetcher}, T, our_store::StreamStore, their_stream::Stream, buffer)
sleep(1)
end
function stream_pull_values!(::Type{RemoteFetcher}, T, our_store::StreamStore, their_stream::Stream, buffer)
id = our_store.uid
thunk_id = STREAM_THUNK_ID[]
@dagdebug thunk_id :stream "fetching values"
free_space = capacity(buffer) - length(buffer)
if free_space == 0
@dagdebug thunk_id :stream "waiting for drain of full input buffer"
yield()
task_may_cancel!()
wait_for_nonfull_input(our_store, their_stream.uid)
return
end
values = T[]
while isempty(values)
values, closed = MemPool.access_ref(their_stream.store_ref.handle, id, T, thunk_id, free_space) do their_store, id, T, thunk_id, free_space
@dagdebug thunk_id :stream "trying to fetch values at worker $(myid())"
STREAM_THUNK_ID[] = thunk_id
values = T[]
@dagdebug thunk_id :stream "trying to fetch with free_space: $free_space"
wait_for_nonempty_output(their_store, id)
if isempty(their_store, id) && !isopen(their_store, id)
@dagdebug thunk_id :stream "remote stream is closed, returning"
return values, true
end
while !isempty(their_store, id) && length(values) < free_space
value = take!(their_store, id)::T
@dagdebug thunk_id :stream "fetched $value"
push!(values, value)
end
return values, false
end::Tuple{Vector{T},Bool}
if closed
throw(InterruptException())
end
# We explicitly yield in the loop to allow other tasks to run. This
# matters on single-threaded instances because MemPool.access_ref()
# might not yield when accessing data locally, which can cause this loop
# to spin forever.
yield()
task_may_cancel!()
end
@dagdebug thunk_id :stream "fetched $(length(values)) values"
for value in values
put!(buffer, value)
end
end
=#
137 changes: 51 additions & 86 deletions src/stream.jl
@@ -91,47 +91,11 @@ function Base.take!(store::StreamStore, id::UInt)
return value
end
end
function wait_for_nonfull_input(store::StreamStore, id::UInt)
@lock store.lock begin
@assert haskey(store.input_streams, id)
@assert haskey(store.input_buffers, id)
buffer = store.input_buffers[id]
while isfull(buffer) && isopen(store)
@dagdebug STREAM_THUNK_ID[] :stream "waiting for space in input buffer"
wait(store.lock)
end
end
end
function wait_for_nonempty_output(store::StreamStore, id::UInt)
@lock store.lock begin
@assert haskey(store.output_streams, id)

# Wait for the output buffer to be initialized
while !haskey(store.output_buffers, id) && isopen(store, id)
@dagdebug STREAM_THUNK_ID[] :stream "waiting for output buffer to be initialized"
wait(store.lock)
end
isopen(store, id) || return

# Wait for the output buffer to be nonempty
buffer = store.output_buffers[id]
while isempty(buffer) && isopen(store, id)
@dagdebug STREAM_THUNK_ID[] :stream "waiting for output buffer to be nonempty"
wait(store.lock)
end
end
end

function Base.isempty(store::StreamStore, id::UInt)
if !haskey(store.output_buffers, id)
@assert haskey(store.output_streams, id)
return true
end
return isempty(store.output_buffers[id])
end
isfull(store::StreamStore, id::UInt) = isfull(store.output_buffers[id])

"Returns whether the store is actively open. Only check this when deciding if new values can be pushed."
"""
Returns whether the store is actively open. Only check this when deciding if
new values can be pushed.
"""
Base.isopen(store::StreamStore) = store.open

"""
@@ -153,9 +117,9 @@ function Base.isopen(store::StreamStore, id::UInt)
end

function Base.close(store::StreamStore)
store.open || return
store.open = false
@lock store.lock begin
store.open || return
store.open = false
for buffer in values(store.input_buffers)
close(buffer)
end
@@ -207,7 +171,6 @@ mutable struct Stream{T,B}
end
end

struct StreamCancelledException <: Exception end
struct StreamingValue{B}
buffer::B
end
@@ -307,8 +270,6 @@ function add_waiters!(stream::Stream, waiters::Vector{Pair{UInt,Any}})
return
end

add_waiters!(stream::Stream, waiter::Integer) = add_waiters!(stream, UInt[waiter])

function remove_waiters!(stream::Stream, waiters::Vector{UInt})
MemPool.access_ref(stream.store_ref.handle, waiters) do store, waiters
remove_waiters!(store::StreamStore, waiters)
@@ -317,8 +278,6 @@ function remove_waiters!(stream::Stream, waiters::Vector{UInt})
return
end

remove_waiters!(stream::Stream, waiter::Integer) = remove_waiters!(stream, Int[waiter])

struct StreamingFunction{F, S}
f::F
stream::S
@@ -331,12 +290,14 @@ end
function migrate_stream!(stream::Stream, w::Integer=myid())
# Perform migration of the StreamStore
# MemPool will block access to the new ref until the migration completes
# FIXME: Do this with MemPool.access_ref, in case stream was already migrated
# FIXME: Do this ownership check with MemPool.access_ref,
# in case stream was already migrated
if stream.store_ref.handle.owner != w
thunk_id = STREAM_THUNK_ID[]
@dagdebug thunk_id :stream "Beginning migration... ($(length(stream.store.input_streams)) -> $(length(stream.store.output_streams)))"

# TODO: Wire up listener to ferry cancel_token notifications to remote worker
# TODO: Wire up listener to ferry cancel_token notifications to remote
# worker once migrations occur during runtime
tls = get_tls()
@assert w == myid() "Only pull-based migration is currently supported"
#remote_cancel_token = clone_cancel_token_remote(get_tls().cancel_token, worker_id)
@@ -417,45 +378,48 @@ function enqueue!(queue::StreamingTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}
end

function initialize_streaming!(self_streams, spec, task)
if !isa(spec.f, StreamingFunction)
# Calculate the return type of the called function
T_old = Base.uniontypes(task.metadata.return_type)
T_old = map(t->(t !== Union{} && t <: FinishStream) ? first(t.parameters) : t, T_old)
# N.B. We treat non-dominating error paths as unreachable
T_old = filter(t->t !== Union{}, T_old)
T = task.metadata.return_type = !isempty(T_old) ? Union{T_old...} : Any

# Get input buffer configuration
input_buffer_amount = get(spec.options, :stream_input_buffer_amount, 1)
if input_buffer_amount <= 0
throw(ArgumentError("Input buffering is required; please specify a `stream_input_buffer_amount` greater than 0"))
end
@assert !isa(spec.f, StreamingFunction) "Task is already in streaming form"

# Get output buffer configuration
output_buffer_amount = get(spec.options, :stream_output_buffer_amount, 1)
if output_buffer_amount <= 0
throw(ArgumentError("Output buffering is required; please specify a `stream_output_buffer_amount` greater than 0"))
end
# Calculate the return type of the called function
T_old = Base.uniontypes(task.metadata.return_type)
T_old = map(t->(t !== Union{} && t <: FinishStream) ? first(t.parameters) : t, T_old)
# N.B. We treat non-dominating error paths as unreachable
T_old = filter(t->t !== Union{}, T_old)
T = task.metadata.return_type = !isempty(T_old) ? Union{T_old...} : Any

# Create the Stream
buffer_type = get(spec.options, :stream_buffer_type, ProcessRingBuffer)
stream = Stream{T,buffer_type}(task.uid, input_buffer_amount, output_buffer_amount)
self_streams[task.uid] = stream
# Get input buffer configuration
input_buffer_amount = get(spec.options, :stream_input_buffer_amount, 1)
if input_buffer_amount <= 0
throw(ArgumentError("Input buffering is required; please specify a `stream_input_buffer_amount` greater than 0"))
end

# Get max evaluation count
max_evals = get(spec.options, :stream_max_evals, -1)
if max_evals == 0
throw(ArgumentError("stream_max_evals cannot be 0"))
end
# Get output buffer configuration
output_buffer_amount = get(spec.options, :stream_output_buffer_amount, 1)
if output_buffer_amount <= 0
throw(ArgumentError("Output buffering is required; please specify a `stream_output_buffer_amount` greater than 0"))
end

spec.f = StreamingFunction(spec.f, stream, max_evals)
spec.options = merge(spec.options, (;occupancy=Dict(Any=>0)))
# Create the Stream
buffer_type = get(spec.options, :stream_buffer_type, ProcessRingBuffer)
stream = Stream{T,buffer_type}(task.uid, input_buffer_amount, output_buffer_amount)
self_streams[task.uid] = stream

# Register Stream globally
remotecall_wait(1, task.uid, stream) do uid, stream
lock(EAGER_THUNK_STREAMS) do global_streams
global_streams[uid] = stream
end
# Get max evaluation count
max_evals = get(spec.options, :stream_max_evals, -1)
if max_evals == 0
throw(ArgumentError("stream_max_evals cannot be 0"))
end

# Wrap the function in a StreamingFunction
spec.f = StreamingFunction(spec.f, stream, max_evals)

# Mark the task as non-blocking
spec.options = merge(spec.options, (;occupancy=Dict(Any=>0)))

# Register Stream globally
remotecall_wait(1, task.uid, stream) do uid, stream
lock(EAGER_THUNK_STREAMS) do global_streams
global_streams[uid] = stream
end
end
end
@@ -496,7 +460,7 @@ function (sf::StreamingFunction)(args...; kwargs...)

@label start
@dagdebug thunk_id :stream "Starting StreamingFunction"
worker_id = sf.stream.store_ref.handle.owner
worker_id = sf.stream.store_ref.handle.owner # FIXME: Not valid to access the owner directly
result = if worker_id == myid()
_run_streamingfunction(nothing, nothing, sf, args...; kwargs...)
else
@@ -620,7 +584,7 @@ function stream!(sf::StreamingFunction, uid,

# Exit streaming on eval limit
if sf.max_evals > 0 && counter >= sf.max_evals
@dagdebug STREAM_THUNK_ID[] :stream "max evals reached ($counter)"
@dagdebug STREAM_THUNK_ID[] :stream "max evals reached (eval $counter)"
return
end
end
@@ -644,6 +608,7 @@ end
return :($NT(stream_kwarg_values))
end

# Default for buffers; can be customized
initialize_stream_buffer(B, T, buffer_amount) = B{T}(buffer_amount)

const EAGER_THUNK_STREAMS = LockedObject(Dict{UInt,Any}())
@@ -707,7 +672,7 @@ function finalize_streaming!(tasks::Vector{Pair{DTaskSpec,DTask}}, self_streams)
end
end

# Adjust waiter count of Streams with dependencies
# Notify Streams of any new waiters
for (uid, waiters) in stream_waiter_changes
stream = task_to_stream(uid)
add_waiters!(stream, waiters)
