Merge pull request #38405 from JuliaLang/vc/distributed_ts
Make Distributed.jl `Worker` struct thread-safe.
vchuravy authored Jul 19, 2021
2 parents 02807b2 + 0c073cc commit 5a16805
Showing 7 changed files with 169 additions and 47 deletions.
66 changes: 53 additions & 13 deletions stdlib/Distributed/src/cluster.jl
@@ -95,13 +95,14 @@ end
@enum WorkerState W_CREATED W_CONNECTED W_TERMINATING W_TERMINATED
mutable struct Worker
id::Int
msg_lock::Threads.ReentrantLock # Lock for del_msgs, add_msgs, and gcflag
del_msgs::Array{Any,1}
add_msgs::Array{Any,1}
gcflag::Bool
state::WorkerState
c_state::Condition # wait for state changes
ct_time::Float64 # creation time
conn_func::Any # used to setup connections lazily
c_state::Threads.Condition # wait for state changes, lock for state
ct_time::Float64 # creation time
conn_func::Any # used to setup connections lazily

r_stream::IO
w_stream::IO
@@ -133,7 +134,7 @@ mutable struct Worker
if haskey(map_pid_wrkr, id)
return map_pid_wrkr[id]
end
w=new(id, [], [], false, W_CREATED, Condition(), time(), conn_func)
w=new(id, Threads.ReentrantLock(), [], [], false, W_CREATED, Threads.Condition(), time(), conn_func)
w.initialized = Event()
register_worker(w)
w
@@ -143,12 +144,16 @@
end

function set_worker_state(w, state)
w.state = state
notify(w.c_state; all=true)
lock(w.c_state) do
w.state = state
notify(w.c_state; all=true)
end
end
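
This lock-then-notify shape is the heart of the change: unlike the single-threaded `Base.Condition` it replaces, a `Threads.Condition` must be held around both `notify` and `wait`. A minimal standalone sketch of the discipline (names are ours, not Distributed's):

const cond = Threads.Condition()
state = :created

function set_state!(s)
    lock(cond) do
        global state = s
        notify(cond; all=true)   # legal only while holding cond's lock
    end
end

function wait_until_connected()
    lock(cond) do
        while state === :created
            wait(cond)           # atomically unlocks, sleeps, relocks
        end
    end
end

t = Threads.@spawn wait_until_connected()
set_state!(:connected)
wait(t)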

function check_worker_state(w::Worker)
lock(w.c_state)
if w.state === W_CREATED
unlock(w.c_state)
if !isclusterlazy()
if PGRP.topology === :all_to_all
# Since higher pids connect with lower pids, the remote worker
@@ -168,6 +173,8 @@ function check_worker_state(w::Worker)
errormonitor(t)
wait_for_conn(w)
end
else
unlock(w.c_state)
end
end
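
`check_worker_state` holds the lock only for the state read: the `else` branch unlocks immediately, and the `W_CREATED` branch unlocks before starting connection setup so other tasks can still change the state. Roughly, under hypothetical names:

function ensure_started!(cond::Threads.Condition, is_created::Function, start!::Function)
    lock(cond)
    if is_created()
        unlock(cond)   # release before blocking connection work
        start!()
    else
        unlock(cond)   # fast path: state is already past W_CREATED
    end
end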

@@ -186,13 +193,25 @@ function exec_conn_func(w::Worker)
end

function wait_for_conn(w)
lock(w.c_state)
if w.state === W_CREATED
unlock(w.c_state)
timeout = worker_timeout() - (time() - w.ct_time)
timeout <= 0 && error("peer $(w.id) has not connected to $(myid())")

@async (sleep(timeout); notify(w.c_state; all=true))
wait(w.c_state)
w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
T = Threads.@spawn begin
sleep($timeout)
lock(w.c_state) do
notify(w.c_state; all=true)
end
end
errormonitor(T)
lock(w.c_state) do
wait(w.c_state)
w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
end
else
unlock(w.c_state)
end
nothing
end
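
The timeout works by spawning a watchdog that notifies the same condition after `timeout` seconds; the waiter then re-checks the state to tell a real connection apart from the watchdog firing. The same helper in isolation (hypothetical `wait_timeout`, not a Distributed function):

function wait_timeout(cond::Threads.Condition, done::Function, timeout::Real)
    waker = Threads.@spawn begin
        sleep(timeout)
        lock(cond) do
            notify(cond; all=true)   # wake the waiter even if nothing changed
        end
    end
    errormonitor(waker)
    lock(cond) do
        wait(cond)
        done() || error("timed out after $timeout seconds")
    end
end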
@@ -471,6 +490,10 @@ function addprocs_locked(manager::ClusterManager; kwargs...)
# The `launch` method should add an object of type WorkerConfig for every
# worker launched. It provides information required on how to connect
# to it.

# FIXME: launched should be a Channel, launch_ntfy should be a Threads.Condition
# but both are part of the public interface. This means we currently can't use
# `Threads.@spawn` in the code below.
launched = WorkerConfig[]
launch_ntfy = Condition()
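
For reference, a sketch of what that FIXME points at: a `Channel` combines the buffer and the wakeup in one thread-safe object, so neither the `Vector` nor the extra `Condition` would be needed (toy types, not part of this commit):

launched = Channel{Symbol}(Inf)      # stands in for Channel{WorkerConfig}(Inf)
producer = Threads.@spawn begin
    for w in (:w2, :w3)
        put!(launched, w)            # `launch` would push configs as they come up
    end
    close(launched)                  # replaces the istaskdone(t_launch) polling
end
@assert collect(launched) == [:w2, :w3]   # blocking iteration, no launch_ntfy
wait(producer)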

@@ -483,7 +506,10 @@ function addprocs_locked(manager::ClusterManager; kwargs...)
while true
if isempty(launched)
istaskdone(t_launch) && break
@async (sleep(1); notify(launch_ntfy))
@async begin
sleep(1)
notify(launch_ntfy)
end
wait(launch_ntfy)
end

@@ -636,7 +662,12 @@ function create_worker(manager, wconfig)
# require the value of config.connect_at which is set only upon connection completion
for jw in PGRP.workers
if (jw.id != 1) && (jw.id < w.id)
(jw.state === W_CREATED) && wait(jw.c_state)
# wait for jw to join
lock(jw.c_state) do
if jw.state === W_CREATED
wait(jw.c_state)
end
end
push!(join_list, jw)
end
end
@@ -659,7 +690,12 @@ function create_worker(manager, wconfig)
end

for wl in wlist
(wl.state === W_CREATED) && wait(wl.c_state)
if wl.state === W_CREATED
# wait for wl to join
lock(wl.c_state) do
wait(wl.c_state)
end
end
push!(join_list, wl)
end
end
@@ -676,7 +712,11 @@ function create_worker(manager, wconfig)
@async manage(w.manager, w.id, w.config, :register)
# wait for rr_ntfy_join with timeout
timedout = false
@async (sleep($timeout); timedout = true; put!(rr_ntfy_join, 1))
@async begin
sleep($timeout)
timedout = true
put!(rr_ntfy_join, 1)
end
wait(rr_ntfy_join)
if timedout
error("worker did not connect within $timeout seconds")
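
The join timeout is a race: a timer task and the joining worker both `put!` into `rr_ntfy_join`, and whichever arrives first decides the outcome. The idiom reduced to a toy (our names, shortened delays):

done = Channel{Symbol}(2)                          # room for both outcomes
Threads.@spawn (sleep(5.0); put!(done, :timeout))
Threads.@spawn (sleep(0.1); put!(done, :joined))   # stands in for the worker joining
@assert take!(done) === :joined                    # first put! wins the race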
10 changes: 3 additions & 7 deletions stdlib/Distributed/src/macros.jl
@@ -1,14 +1,10 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

let nextidx = 0
let nextidx = Threads.Atomic{Int}(0)
global nextproc
function nextproc()
p = -1
if p == -1
p = workers()[(nextidx % nworkers()) + 1]
nextidx += 1
end
p
idx = Threads.atomic_add!(nextidx, 1)
return workers()[(idx % nworkers()) + 1]
end
end

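`nextproc` now uses an atomic fetch-and-add instead of the old dead `p = -1` dance around a plain global: `Threads.atomic_add!` returns the previous value, so concurrent callers each see a distinct index without a lock. The idiom in isolation:

const counter = Threads.Atomic{Int}(0)
pick(pool) = pool[(Threads.atomic_add!(counter, 1) % length(pool)) + 1]

pool = [:a, :b, :c]
@assert [pick(pool) for _ in 1:4] == [:a, :b, :c, :a]   # lock-free round robin
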
2 changes: 1 addition & 1 deletion stdlib/Distributed/src/managers.jl
@@ -163,7 +163,7 @@ function launch(manager::SSHManager, params::Dict, launched::Array, launch_ntfy:
# Wait for all launches to complete.
@sync for (i, (machine, cnt)) in enumerate(manager.machines)
let machine=machine, cnt=cnt
@async try
@async try
launch_on_machine(manager, $machine, $cnt, params, launched, launch_ntfy)
catch e
print(stderr, "exception launching on machine $(machine) : $(e)\n")
28 changes: 13 additions & 15 deletions stdlib/Distributed/src/messages.jl
@@ -126,22 +126,20 @@ function flush_gc_msgs(w::Worker)
if !isdefined(w, :w_stream)
return
end
w.gcflag = false
new_array = Any[]
msgs = w.add_msgs
w.add_msgs = new_array
if !isempty(msgs)
remote_do(add_clients, w, msgs)
end
lock(w.msg_lock) do
w.gcflag || return # early exit if someone else got to this
w.gcflag = false
msgs = w.add_msgs
w.add_msgs = Any[]
if !isempty(msgs)
remote_do(add_clients, w, msgs)
end

# del_msgs gets populated by finalizers, so be very careful here about ordering of allocations
# XXX: threading requires this to be atomic
new_array = Any[]
msgs = w.del_msgs
w.del_msgs = new_array
if !isempty(msgs)
#print("sending delete of $msgs\n")
remote_do(del_clients, w, msgs)
msgs = w.del_msgs
w.del_msgs = Any[]
if !isempty(msgs)
remote_do(del_clients, w, msgs)
end
end
end

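With `msg_lock` held, `flush_gc_msgs` re-checks `gcflag` (another task may have already flushed) and swaps each buffer for a fresh `Any[]`, so entries pushed by finalizers never race with the send. The snapshot-and-reset idiom on its own (our names):

function drain!(lk::ReentrantLock, buf::Ref{Vector{Any}})
    lock(lk) do
        msgs = buf[]
        buf[] = Any[]   # producers now append to the fresh vector
        msgs            # snapshot for the caller to send
    end
end

lk = ReentrantLock(); buf = Ref{Vector{Any}}(Any[1, 2])
@assert drain!(lk, buf) == Any[1, 2] && isempty(buf[])
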
46 changes: 35 additions & 11 deletions stdlib/Distributed/src/remotecall.jl
@@ -247,22 +247,42 @@ function del_clients(pairs::Vector)
end
end

const any_gc_flag = Condition()
# The task below is coalescing the `flush_gc_msgs` call
# across multiple producers, see `send_del_client`,
# and `send_add_client`.
# XXX: Is this worth the additional complexity?
# `flush_gc_msgs` has to iterate over all connected workers.
const any_gc_flag = Threads.Condition()
function start_gc_msgs_task()
errormonitor(@async while true
wait(any_gc_flag)
flush_gc_msgs()
end)
errormonitor(
Threads.@spawn begin
while true
lock(any_gc_flag) do
wait(any_gc_flag)
flush_gc_msgs() # handles throws internally
end
end
end
)
end
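
The single spawned task coalesces work from every producer: `send_del_client` and `send_add_client` just set a flag and notify, and one `flush_gc_msgs` pass services all pending workers instead of each producer flushing on its own. A reduced sketch of the pair (our names):

const gc_cond = Threads.Condition()

start_consumer(service::Function) = errormonitor(Threads.@spawn begin
    while true
        lock(gc_cond) do
            wait(gc_cond)   # sleep until any producer notifies
            service()       # one pass handles every pending producer
        end
    end
end)

producer_side() = lock(gc_cond) do
    notify(gc_cond)         # cheap: no per-message task spawn
end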

# Function can be called within a finalizer
function send_del_client(rr)
if rr.where == myid()
del_client(rr)
elseif id_in_procs(rr.where) # process only if a valid worker
w = worker_from_id(rr.where)::Worker
push!(w.del_msgs, (remoteref_id(rr), myid()))
w.gcflag = true
notify(any_gc_flag)
msg = (remoteref_id(rr), myid())
# We cannot acquire locks from finalizers
Threads.@spawn begin
lock(w.msg_lock) do
push!(w.del_msgs, msg)
w.gcflag = true
end
lock(any_gc_flag) do
notify(any_gc_flag)
end
end
end
end
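
`send_del_client` can run inside a finalizer, and a finalizer that blocks on a lock already held by the interrupted task would deadlock, so all locking is deferred to a freshly spawned task. The shape of the workaround (our names):

function enqueue_later!(lk::ReentrantLock, buf::Vector{Any}, msg)
    # Called from a finalizer: no lock may be taken here. The spawned
    # task runs as ordinary scheduled code, where locking is safe again.
    Threads.@spawn lock(lk) do
        push!(buf, msg)
    end
    nothing
end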

@@ -288,9 +308,13 @@ function send_add_client(rr::AbstractRemoteRef, i)
# to the processor that owns the remote ref. it will add_client
# itself inside deserialize().
w = worker_from_id(rr.where)
push!(w.add_msgs, (remoteref_id(rr), i))
w.gcflag = true
notify(any_gc_flag)
lock(w.msg_lock) do
push!(w.add_msgs, (remoteref_id(rr), i))
w.gcflag = true
end
lock(any_gc_flag) do
notify(any_gc_flag)
end
end
end

1 change: 1 addition & 0 deletions stdlib/Distributed/test/distributed_exec.jl
@@ -1696,4 +1696,5 @@ include("splitrange.jl")
# Run topology tests last after removing all workers, since a given
# cluster at any time only supports a single topology.
rmprocs(workers())
include("threads.jl")
include("topology.jl")
63 changes: 63 additions & 0 deletions stdlib/Distributed/test/threads.jl
@@ -0,0 +1,63 @@
using Test
using Distributed, Base.Threads
using Base.Iterators: product

exeflags = ("--startup-file=no",
"--check-bounds=yes",
"--depwarn=error",
"--threads=2")

function call_on(f, wid, tid)
remotecall(wid) do
t = Task(f)
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid - 1)
schedule(t)
@assert threadid(t) == tid
t
end
end
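
`call_on` pins the task to a given thread through `jl_set_task_tid`, an internal, unexported runtime entry point that takes a zero-based thread id. Locally, under `--threads=2`, the same trick looks like this (a sketch relying on that internal call):

t = Task(() -> Threads.threadid())
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, 1)   # 0-based: pin to thread 2
schedule(t)
@assert fetch(t) == 2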

# Run function on process holding the data to only serialize the result of f.
# This becomes useful for things that cannot be serialized (e.g. running tasks)
# or that would be unnecessarily big if serialized.
fetch_from_owner(f, rr) = remotecall_fetch(f ∘ fetch, rr.where, rr)

isdone(rr) = fetch_from_owner(istaskdone, rr)
isfailed(rr) = fetch_from_owner(istaskfailed, rr)
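
`f ∘ fetch` resolves the reference on its owner and applies the query there, so only `f`'s small result travels back. Locally, for instance:

fut = Future(); put!(fut, ones(10_000))
@assert (sum ∘ fetch)(fut) == 10_000.0   # only this Float64 would travel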

@testset "RemoteChannel allows put!/take! from thread other than 1" begin
ws = ts = product(1:2, 1:2)
@testset "from worker $w1 to $w2 via 1" for (w1, w2) in ws
@testset "from thread $w1.$t1 to $w2.$t2" for (t1, t2) in ts
# We want (the default) laziness, so that we wait for `Worker.c_state`!
procs_added = addprocs(2; exeflags, lazy=true)
@everywhere procs_added using Base.Threads

p1 = procs_added[w1]
p2 = procs_added[w2]
chan_id = first(procs_added)
chan = RemoteChannel(chan_id)
send = call_on(p1, t1) do
put!(chan, nothing)
end
recv = call_on(p2, t2) do
take!(chan)
end

# Wait on the spawned tasks on the owner
@sync begin
Threads.@spawn fetch_from_owner(wait, recv)
Threads.@spawn fetch_from_owner(wait, send)
end

# Check the tasks
@test isdone(send)
@test isdone(recv)

@test !isfailed(send)
@test !isfailed(recv)

rmprocs(procs_added)
end
end
end
