From 0c073ccb19ab926e25a1ddb457472f43ef946aeb Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 16 Jul 2021 15:53:46 +0200 Subject: [PATCH] [Distributed] Allow workers to be started with threading enabled Makes the worker struct threadsafe as well as flushing the GC messages --- stdlib/Distributed/src/cluster.jl | 66 +++++++++++++++++---- stdlib/Distributed/src/macros.jl | 10 +--- stdlib/Distributed/src/managers.jl | 2 +- stdlib/Distributed/src/messages.jl | 28 ++++----- stdlib/Distributed/src/remotecall.jl | 46 ++++++++++---- stdlib/Distributed/test/distributed_exec.jl | 1 + stdlib/Distributed/test/threads.jl | 63 ++++++++++++++++++++ 7 files changed, 169 insertions(+), 47 deletions(-) create mode 100644 stdlib/Distributed/test/threads.jl diff --git a/stdlib/Distributed/src/cluster.jl b/stdlib/Distributed/src/cluster.jl index ebe4cac0f3bbe..591ce3f850551 100644 --- a/stdlib/Distributed/src/cluster.jl +++ b/stdlib/Distributed/src/cluster.jl @@ -95,13 +95,14 @@ end @enum WorkerState W_CREATED W_CONNECTED W_TERMINATING W_TERMINATED mutable struct Worker id::Int + msg_lock::Threads.ReentrantLock # Lock for del_msgs, add_msgs, and gcflag del_msgs::Array{Any,1} add_msgs::Array{Any,1} gcflag::Bool state::WorkerState - c_state::Condition # wait for state changes - ct_time::Float64 # creation time - conn_func::Any # used to setup connections lazily + c_state::Threads.Condition # wait for state changes, lock for state + ct_time::Float64 # creation time + conn_func::Any # used to setup connections lazily r_stream::IO w_stream::IO @@ -133,7 +134,7 @@ mutable struct Worker if haskey(map_pid_wrkr, id) return map_pid_wrkr[id] end - w=new(id, [], [], false, W_CREATED, Condition(), time(), conn_func) + w=new(id, Threads.ReentrantLock(), [], [], false, W_CREATED, Threads.Condition(), time(), conn_func) w.initialized = Event() register_worker(w) w @@ -143,12 +144,16 @@ mutable struct Worker end function set_worker_state(w, state) - w.state = state - notify(w.c_state; all=true) + lock(w.c_state) do + w.state = state + notify(w.c_state; all=true) + end end function check_worker_state(w::Worker) + lock(w.c_state) if w.state === W_CREATED + unlock(w.c_state) if !isclusterlazy() if PGRP.topology === :all_to_all # Since higher pids connect with lower pids, the remote worker @@ -168,6 +173,8 @@ function check_worker_state(w::Worker) errormonitor(t) wait_for_conn(w) end + else + unlock(w.c_state) end end @@ -186,13 +193,25 @@ function exec_conn_func(w::Worker) end function wait_for_conn(w) + lock(w.c_state) if w.state === W_CREATED + unlock(w.c_state) timeout = worker_timeout() - (time() - w.ct_time) timeout <= 0 && error("peer $(w.id) has not connected to $(myid())") - @async (sleep(timeout); notify(w.c_state; all=true)) - wait(w.c_state) - w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds") + T = Threads.@spawn begin + sleep($timeout) + lock(w.c_state) do + notify(w.c_state; all=true) + end + end + errormonitor(T) + lock(w.c_state) do + wait(w.c_state) + w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds") + end + else + unlock(w.c_state) end nothing end @@ -471,6 +490,10 @@ function addprocs_locked(manager::ClusterManager; kwargs...) # The `launch` method should add an object of type WorkerConfig for every # worker launched. It provides information required on how to connect # to it. + + # FIXME: launched should be a Channel, launch_ntfy should be a Threads.Condition + # but both are part of the public interface. This means we currently can't use + # `Threads.@spawn` in the code below. launched = WorkerConfig[] launch_ntfy = Condition() @@ -483,7 +506,10 @@ function addprocs_locked(manager::ClusterManager; kwargs...) while true if isempty(launched) istaskdone(t_launch) && break - @async (sleep(1); notify(launch_ntfy)) + @async begin + sleep(1) + notify(launch_ntfy) + end wait(launch_ntfy) end @@ -636,7 +662,12 @@ function create_worker(manager, wconfig) # require the value of config.connect_at which is set only upon connection completion for jw in PGRP.workers if (jw.id != 1) && (jw.id < w.id) - (jw.state === W_CREATED) && wait(jw.c_state) + # wait for wl to join + lock(jw.c_state) do + if jw.state === W_CREATED + wait(jw.c_state) + end + end push!(join_list, jw) end end @@ -659,7 +690,12 @@ function create_worker(manager, wconfig) end for wl in wlist - (wl.state === W_CREATED) && wait(wl.c_state) + if wl.state === W_CREATED + # wait for wl to join + lock(wl.c_state) do + wait(wl.c_state) + end + end push!(join_list, wl) end end @@ -676,7 +712,11 @@ function create_worker(manager, wconfig) @async manage(w.manager, w.id, w.config, :register) # wait for rr_ntfy_join with timeout timedout = false - @async (sleep($timeout); timedout = true; put!(rr_ntfy_join, 1)) + @async begin + sleep($timeout) + timedout = true + put!(rr_ntfy_join, 1) + end wait(rr_ntfy_join) if timedout error("worker did not connect within $timeout seconds") diff --git a/stdlib/Distributed/src/macros.jl b/stdlib/Distributed/src/macros.jl index 6603d627c3409..24a24f4c08ed4 100644 --- a/stdlib/Distributed/src/macros.jl +++ b/stdlib/Distributed/src/macros.jl @@ -1,14 +1,10 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license -let nextidx = 0 +let nextidx = Threads.Atomic{Int}(0) global nextproc function nextproc() - p = -1 - if p == -1 - p = workers()[(nextidx % nworkers()) + 1] - nextidx += 1 - end - p + idx = Threads.atomic_add!(nextidx, 1) + return workers()[(idx % nworkers()) + 1] end end diff --git a/stdlib/Distributed/src/managers.jl b/stdlib/Distributed/src/managers.jl index 08686fc2a0b87..5b4f016c63a78 100644 --- a/stdlib/Distributed/src/managers.jl +++ b/stdlib/Distributed/src/managers.jl @@ -163,7 +163,7 @@ function launch(manager::SSHManager, params::Dict, launched::Array, launch_ntfy: # Wait for all launches to complete. @sync for (i, (machine, cnt)) in enumerate(manager.machines) let machine=machine, cnt=cnt - @async try + @async try launch_on_machine(manager, $machine, $cnt, params, launched, launch_ntfy) catch e print(stderr, "exception launching on machine $(machine) : $(e)\n") diff --git a/stdlib/Distributed/src/messages.jl b/stdlib/Distributed/src/messages.jl index 47f70e044a2c0..fcba709b4db4a 100644 --- a/stdlib/Distributed/src/messages.jl +++ b/stdlib/Distributed/src/messages.jl @@ -126,22 +126,20 @@ function flush_gc_msgs(w::Worker) if !isdefined(w, :w_stream) return end - w.gcflag = false - new_array = Any[] - msgs = w.add_msgs - w.add_msgs = new_array - if !isempty(msgs) - remote_do(add_clients, w, msgs) - end + lock(w.msg_lock) do + w.gcflag || return # early exit if someone else got to this + w.gcflag = false + msgs = w.add_msgs + w.add_msgs = Any[] + if !isempty(msgs) + remote_do(add_clients, w, msgs) + end - # del_msgs gets populated by finalizers, so be very careful here about ordering of allocations - # XXX: threading requires this to be atomic - new_array = Any[] - msgs = w.del_msgs - w.del_msgs = new_array - if !isempty(msgs) - #print("sending delete of $msgs\n") - remote_do(del_clients, w, msgs) + msgs = w.del_msgs + w.del_msgs = Any[] + if !isempty(msgs) + remote_do(del_clients, w, msgs) + end end end diff --git a/stdlib/Distributed/src/remotecall.jl b/stdlib/Distributed/src/remotecall.jl index 088b7416f4488..5ac397656ce44 100644 --- a/stdlib/Distributed/src/remotecall.jl +++ b/stdlib/Distributed/src/remotecall.jl @@ -247,22 +247,42 @@ function del_clients(pairs::Vector) end end -const any_gc_flag = Condition() +# The task below is coalescing the `flush_gc_msgs` call +# across multiple producers, see `send_del_client`, +# and `send_add_client`. +# XXX: Is this worth the additional complexity? +# `flush_gc_msgs` has to iterate over all connected workers. +const any_gc_flag = Threads.Condition() function start_gc_msgs_task() - errormonitor(@async while true - wait(any_gc_flag) - flush_gc_msgs() - end) + errormonitor( + Threads.@spawn begin + while true + lock(any_gc_flag) do + wait(any_gc_flag) + flush_gc_msgs() # handles throws internally + end + end + end + ) end +# Function can be called within a finalizer function send_del_client(rr) if rr.where == myid() del_client(rr) elseif id_in_procs(rr.where) # process only if a valid worker w = worker_from_id(rr.where)::Worker - push!(w.del_msgs, (remoteref_id(rr), myid())) - w.gcflag = true - notify(any_gc_flag) + msg = (remoteref_id(rr), myid()) + # We cannot acquire locks from finalizers + Threads.@spawn begin + lock(w.msg_lock) do + push!(w.del_msgs, msg) + w.gcflag = true + end + lock(any_gc_flag) do + notify(any_gc_flag) + end + end end end @@ -288,9 +308,13 @@ function send_add_client(rr::AbstractRemoteRef, i) # to the processor that owns the remote ref. it will add_client # itself inside deserialize(). w = worker_from_id(rr.where) - push!(w.add_msgs, (remoteref_id(rr), i)) - w.gcflag = true - notify(any_gc_flag) + lock(w.msg_lock) do + push!(w.add_msgs, (remoteref_id(rr), i)) + w.gcflag = true + end + lock(any_gc_flag) do + notify(any_gc_flag) + end end end diff --git a/stdlib/Distributed/test/distributed_exec.jl b/stdlib/Distributed/test/distributed_exec.jl index 749c18f6b61f0..3b99afac8cc15 100644 --- a/stdlib/Distributed/test/distributed_exec.jl +++ b/stdlib/Distributed/test/distributed_exec.jl @@ -1696,4 +1696,5 @@ include("splitrange.jl") # Run topology tests last after removing all workers, since a given # cluster at any time only supports a single topology. rmprocs(workers()) +include("threads.jl") include("topology.jl") diff --git a/stdlib/Distributed/test/threads.jl b/stdlib/Distributed/test/threads.jl new file mode 100644 index 0000000000000..57d99b7ea056c --- /dev/null +++ b/stdlib/Distributed/test/threads.jl @@ -0,0 +1,63 @@ +using Test +using Distributed, Base.Threads +using Base.Iterators: product + +exeflags = ("--startup-file=no", + "--check-bounds=yes", + "--depwarn=error", + "--threads=2") + +function call_on(f, wid, tid) + remotecall(wid) do + t = Task(f) + ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid - 1) + schedule(t) + @assert threadid(t) == tid + t + end +end + +# Run function on process holding the data to only serialize the result of f. +# This becomes useful for things that cannot be serialized (e.g. running tasks) +# or that would be unnecessarily big if serialized. +fetch_from_owner(f, rr) = remotecall_fetch(f ∘ fetch, rr.where, rr) + +isdone(rr) = fetch_from_owner(istaskdone, rr) +isfailed(rr) = fetch_from_owner(istaskfailed, rr) + +@testset "RemoteChannel allows put!/take! from thread other than 1" begin + ws = ts = product(1:2, 1:2) + @testset "from worker $w1 to $w2 via 1" for (w1, w2) in ws + @testset "from thread $w1.$t1 to $w2.$t2" for (t1, t2) in ts + # We want (the default) lazyness, so that we wait for `Worker.c_state`! + procs_added = addprocs(2; exeflags, lazy=true) + @everywhere procs_added using Base.Threads + + p1 = procs_added[w1] + p2 = procs_added[w2] + chan_id = first(procs_added) + chan = RemoteChannel(chan_id) + send = call_on(p1, t1) do + put!(chan, nothing) + end + recv = call_on(p2, t2) do + take!(chan) + end + + # Wait on the spawned tasks on the owner + @sync begin + Threads.@spawn fetch_from_owner(wait, recv) + Threads.@spawn fetch_from_owner(wait, send) + end + + # Check the tasks + @test isdone(send) + @test isdone(recv) + + @test !isfailed(send) + @test !isfailed(recv) + + rmprocs(procs_added) + end + end +end