Skip to content

Commit

Permalink
fixup! Various alloc reductions and optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
jpsamaroo committed Dec 16, 2024
1 parent d52b541 commit 4450769
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 80 deletions.
1 change: 1 addition & 0 deletions src/argument.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ ArgPosition() = ArgPosition(true, 0, :NULL)
# Copy constructor: build a new `ArgPosition` carrying the same `positional`,
# `idx`, and `kw` values as `pos`.
function ArgPosition(pos::ArgPosition)
    return ArgPosition(pos.positional, pos.idx, pos.kw)
end
# Return the `positional` flag stored in `pos`.
function ispositional(pos::ArgPosition)
    return pos.positional
end
# Return the negation of the `positional` flag stored in `pos`.
function iskw(pos::ArgPosition)
    return !pos.positional
end
# Return the underlying position value of `pos`: the `idx` field when `pos` is
# positional, otherwise the `kw` field.
function raw_position(pos::ArgPosition)
    if ispositional(pos)
        return pos.idx
    else
        return pos.kw
    end
end
function pos_idx(pos::ArgPosition)
@assert pos.positional
@assert pos.idx > 0
Expand Down
36 changes: 2 additions & 34 deletions src/sch/Sch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,25 +32,6 @@ include("util.jl")
include("fault-handler.jl")
include("dynamic.jl")

"""
    ProcessorCacheEntry(gproc::OSProc, proc::Processor)

Node of a linked list of processors. Each node stores an `OSProc`, one
`Processor`, and a reference to the `next` node.

The constructor only initializes `gproc` and `proc`; `next` is left unassigned
and has to be set by the caller before it is read.
"""
mutable struct ProcessorCacheEntry
    gproc::OSProc
    proc::Processor
    next::ProcessorCacheEntry

    function ProcessorCacheEntry(gproc::OSProc, proc::Processor)
        # Incomplete initialization: `next` stays undefined on purpose.
        return new(gproc, proc)
    end
end
# Two cache entries are considered equal when they wrap the very same
# processor object (identity comparison via `===`); `gproc` and `next` are
# not compared.
function Base.isequal(p1::ProcessorCacheEntry, p2::ProcessorCacheEntry)
    return p1.proc === p2.proc
end

# Keep `hash` consistent with `isequal` above: entries that are `isequal` must
# hash identically, otherwise `Dict`/`Set` lookups on entries misbehave.
# `objectid` agrees with `===`, which is exactly the comparison `isequal` uses.
function Base.hash(entry::ProcessorCacheEntry, h::UInt)
    return hash(objectid(entry.proc), h)
end
# Print a one-line summary of `entry`: the pid of its `gproc`, its processor,
# and the number of entries reachable by following `next` until the walk comes
# back to `entry`.
function Base.show(io::IO, entry::ProcessorCacheEntry)
    entries = 1
    cursor = entry
    # The constructor leaves `next` unassigned, so an entry that has not been
    # linked yet must not be dereferenced (that would throw `UndefRefError`).
    # Stop at the first unassigned link, or once the ring closes on `entry`.
    while isdefined(cursor, :next) && cursor.next !== entry
        entries += 1
        cursor = cursor.next
    end
    print(io, "ProcessorCacheEntry(pid $(entry.gproc.pid), $(entry.proc), $entries entries)")
    return nothing
end

struct TaskResult
pid::Int
proc::Processor
Expand Down Expand Up @@ -82,7 +63,6 @@ Fields:
- `worker_storage_capacity::Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}}` - Maps from worker ID to storage resource capacity
- `worker_loadavg::Dict{Int,NTuple{3,Float64}}` - Worker load average
- `worker_chans::Dict{Int, Tuple{RemoteChannel,RemoteChannel}}` - Communication channels between the scheduler and each worker
- `procs_cache_list::Base.RefValue{Union{ProcessorCacheEntry,Nothing}}` - Cached linked list of processors ready to be used
- `signature_time_cost::Dict{Signature,UInt64}` - Cache of estimated CPU time (in nanoseconds) required to compute calls with the given signature
- `signature_alloc_cost::Dict{Signature,UInt64}` - Cache of estimated CPU RAM (in bytes) required to compute calls with the given signature
- `transfer_rate::Ref{UInt64}` - Estimate of the network transfer rate in bytes per second
Expand All @@ -109,7 +89,6 @@ struct ComputeState
worker_storage_capacity::Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}}
worker_loadavg::Dict{Int,NTuple{3,Float64}}
worker_chans::Dict{Int, Tuple{RemoteChannel,RemoteChannel}}
procs_cache_list::Base.RefValue{Union{ProcessorCacheEntry,Nothing}}
signature_time_cost::Dict{Signature,UInt64}
signature_alloc_cost::Dict{Signature,UInt64}
transfer_rate::Ref{UInt64}
Expand Down Expand Up @@ -139,7 +118,6 @@ function start_state(deps::Dict, node_order, chan)
Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}}(),
Dict{Int,NTuple{3,Float64}}(),
Dict{Int, Tuple{RemoteChannel,RemoteChannel}}(),
Ref{Union{ProcessorCacheEntry,Nothing}}(nothing),
Dict{Signature,UInt64}(),
Dict{Signature,UInt64}(),
Ref{UInt64}(1_000_000),
Expand Down Expand Up @@ -553,8 +531,6 @@ function schedule!(ctx, state, sch_options, procs=procs_to_use(ctx, sch_options)
# Remove processors that aren't yet initialized
procs = filter(p -> haskey(state.worker_chans, Dagger.root_worker_id(p)), procs)

populate_processor_cache_list!(state, procs)

# Schedule tasks
to_fire = @reusable_dict :schedule!_to_fire ScheduleTaskLocation Vector{ScheduleTaskSpec} ScheduleTaskLocation(OSProc(), OSProc()) ScheduleTaskSpec[] 1024
failed_scheduling = @reusable_vector :schedule!_failed_scheduling Union{Thunk,Nothing} nothing 32
Expand Down Expand Up @@ -633,6 +609,7 @@ function schedule!(ctx, state, sch_options, procs=procs_to_use(ctx, sch_options)
costs = @reusable_dict :schedule!_costs Processor Float64 OSProc() 0.0 32
estimate_task_costs!(sorted_procs, costs, state, input_procs, task)
empty!(costs) # We don't use costs here
empty!(input_procs)
scheduled = false

# Move our corresponding ThreadProc to be the last considered
Expand Down Expand Up @@ -710,22 +687,14 @@ function monitor_procs_changed!(ctx, state, options)
for p in diffps
init_proc(state, p, ctx.log_sink)

# Empty the processor cache list and force reschedule
lock(state.lock) do
state.procs_cache_list[] = nothing
end
# Force reschedule
put!(state.chan, RescheduleSignal())
end

# Cleanup removed procs
diffps = setdiff(old_ps, new_ps)
for p in diffps
cleanup_proc(state, p, ctx.log_sink)

# Empty the processor cache list
lock(state.lock) do
state.procs_cache_list[] = nothing
end
end

@maybelog ctx timespan_finish(ctx, :assign_procs, (;uid=state.uid), nothing)
Expand All @@ -741,7 +710,6 @@ function remove_dead_proc!(ctx, state, proc, options)
delete!(state.worker_storage_capacity, proc.pid)
delete!(state.worker_loadavg, proc.pid)
delete!(state.worker_chans, proc.pid)
state.procs_cache_list[] = nothing
end

function finish_task!(ctx, state, node, thunk_failed)
Expand Down
35 changes: 7 additions & 28 deletions src/sch/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ function can_use_proc(state, task, gproc, proc, opts, scope)
scope = constrain(scope, Dagger.ExactScope(proc))
elseif opts.proclist isa Vector
if !(typeof(proc) in opts.proclist)
@dagdebug task :scope "Rejected $proc: !(typeof(proc) in proclist)"
@dagdebug task :scope "Rejected $proc: !($(typeof(proc)) in proclist)"
return false, scope
end
scope = constrain(scope,
Expand Down Expand Up @@ -437,18 +437,18 @@ function can_use_proc(state, task, gproc, proc, opts, scope)
return false, scope
end

# Check against f/args
# Check against function and arguments
Tf = chunktype(task.f)
if !Dagger.iscompatible_func(proc, opts, Tf)
@dagdebug task :scope "Rejected $proc: Not compatible with function type ($Tf)"
return false, scope
end
for (_, arg) in task.inputs
arg = unwrap_weak_checked(arg)
if arg isa Thunk
arg = state.cache[arg]
for arg in task.inputs[2:end]
value = unwrap_weak_checked(Dagger.value(arg))
if value isa Thunk
value = load_result(state, value)
end
Targ = chunktype(arg)
Targ = chunktype(value)
if !Dagger.iscompatible_arg(proc, opts, Targ)
@dagdebug task :scope "Rejected $proc: Not compatible with argument type ($Targ)"
return false, scope
Expand Down Expand Up @@ -498,27 +498,6 @@ function has_capacity(state, p, gp, time_util, alloc_util, occupancy, sig)
return true, est_time_util, est_alloc_util, est_occupancy
end

"Fill `state.procs_cache_list` with a circular linked list of `ProcessorCacheEntry` nodes, one per processor of each entry in `procs`, if the list is currently `nothing`."
function populate_processor_cache_list!(state, procs)
    # Nothing to do when the cache has already been populated
    state.procs_cache_list[] === nothing || return nothing

    tail = nothing
    pids = map(x->x.pid, procs)
    for pid in pids
        for proc in get_processors(OSProc(pid))
            entry = ProcessorCacheEntry(OSProc(pid), proc)
            if tail === nothing
                # First entry: it is the head and points at itself
                entry.next = entry
                state.procs_cache_list[] = entry
            else
                # Append after the current tail and close the ring at the head
                tail.next = entry
                entry.next = state.procs_cache_list[]
            end
            tail = entry
        end
    end
    return nothing
end

"Like `sum`, but replaces `nothing` entries with the average of non-`nothing` entries."
function impute_sum(xs)
total = 0
Expand Down
6 changes: 3 additions & 3 deletions src/utils/logging-events.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ function (ta::TaskArguments)(ev::Event{:finish})
if ev.category == :move
args = Pair{Union{Symbol,Int},Dagger.LoggedMutableObject}[]
thunk_id = ev.id.thunk_id::Int
pos = ev.id.position::Union{Symbol,Int}
pos = Dagger.raw_position(ev.id.position::Dagger.ArgPosition)::Union{Symbol,Int}
arg = ev.timeline.data
if ismutable(arg)
push!(args, pos => Dagger.objectid_or_chunkid(arg))
Expand All @@ -174,7 +174,7 @@ function (ta::TaskArgumentMoves)(ev::Event{:start})
data = ev.timeline.data
if ismutable(data)
thunk_id = ev.id.thunk_id::Int
position = ev.id.position::Union{Symbol,Int}
position = Dagger.raw_position(ev.id.position::Dagger.ArgPosition)::Union{Symbol,Int}
d = get!(Dict{Union{Int,Symbol},Dagger.LoggedMutableObject}, ta.pre_move_args, thunk_id)
d[position] = Dagger.objectid_or_chunkid(data)
end
Expand All @@ -186,7 +186,7 @@ function (ta::TaskArgumentMoves)(ev::Event{:finish})
post_data = ev.timeline.data
if ismutable(post_data)
thunk_id = ev.id.thunk_id::Int
position = ev.id.position::Union{Symbol,Int}
position = Dagger.raw_position(ev.id.position::Dagger.ArgPosition)::Union{Symbol,Int}
if haskey(ta.pre_move_args, thunk_id)
d = ta.pre_move_args[thunk_id]
if haskey(d, position)
Expand Down
4 changes: 2 additions & 2 deletions src/utils/reuse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ function maybetake!(cache::ReusableCache{T}, len=nothing) where T
for idx in 1:length(cache.used)
cache.used[idx] && continue
if cache.sized && isassigned(cache.cache, idx) && length(cache.cache[idx]) != len
@debug "Skipping length $(length(cache.cache[idx])) (want length $len) @ $idx"
@dagdebug nothing :reuse "Skipping length $(length(cache.cache[idx])) (want length $len) @ $idx"
continue
end
cache.used[idx] = true
if !isassigned(cache.cache, idx)
if cache.sized
@debug "Allocating length $len @ $idx"
@dagdebug nothing :reuse "Allocating length $len @ $idx"
cache.cache[idx] = alloc!(T, len)
else
cache.cache[idx] = alloc!(T)
Expand Down
10 changes: 0 additions & 10 deletions test/logging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,16 +134,6 @@ import Colors, GraphViz, DataFrames, Plots, JSON3
@test any(e->haskey(e, :fire), esat)
@test any(e->haskey(e, :take), esat)
@test any(e->haskey(e, :finish), esat)
if Threads.nthreads() == 1
if nprocs() > 1
# Note: May one day be true as scheduler evolves
@test !any(e->haskey(e, :compute), esat)
@test !any(e->haskey(e, :move), esat)
psat = l1[:psat]
# Note: May become false
@test all(e->length(e) == 0, psat)
end
end

had_psat_proc = 0
for wo in filter(w->w != 1, keys(logs))
Expand Down
2 changes: 1 addition & 1 deletion test/thunk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ end
t1 = Dagger.@spawn 1+"fail"
Dagger.@spawn t1+1
end
@test_throws_unwrap (Dagger.ThunkFailedException, MethodError) fetch(t2)
@test_throws_unwrap (Dagger.DTaskFailedException, MethodError) fetch(t2)
end
end
if nprocs() > 1
Expand Down
4 changes: 2 additions & 2 deletions test/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ function _test_throws_unwrap(terr, ex; to_match=[])
match_expr = Expr(:block)
for m in to_match
if m.head == :(=)
lhs, rhs = replace_obj!(m.args[1], oerr), m.args[2]
lhs, rhs = replace_obj!(m.args[1], rerr), m.args[2]
push!(match_expr.args, :(@test $lhs == $rhs))
elseif m.head == :call
fn = m.args[1]
lhs, rhs = replace_obj!(m.args[2], oerr), m.args[3]
lhs, rhs = replace_obj!(m.args[2], rerr), m.args[3]
if fn == :(<)
push!(match_expr.args, :(@test startswith($lhs, $rhs)))
elseif fn == :(>)
Expand Down

0 comments on commit 4450769

Please sign in to comment.