Skip to content

Commit

Permalink
add multithreading of learning network
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed Feb 10, 2022
1 parent 48d4d02 commit 36788c6
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions src/composition/learning_networks/nodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,9 @@ order. These machines are those returned by
"""
fit!(y::Node; acceleration=CPU1(), kwargs...) =
fit!(y::Node; acceleration=default_resource(), kwargs...) =
fit!(y::Node, acceleration; kwargs...)

fit!(y::Node, ::AbstractResource; kwargs...) =
error("Only `acceleration=CPU1()` currently supported")

function fit!(y::Node, ::CPU1; kwargs...)

_machines = machines(y)
Expand All @@ -204,6 +201,27 @@ function fit!(y::Node, ::CPU1; kwargs...)

return y
end

function fit!(y::Node, ::CPUThreads; kwargs...)
_machines = machines(y)

# flush the fit_okay channels:
@sync for mach in _machines
Threads.@spawn flush!(mach.fit_okay)
end

# fit the machines in Multithreading mode
@sync for mach in _machines
Threads.@spawn fit_only!(mach, true; kwargs...)
end

return y

end

fit!(y::Node, ::AbstractResource; kwargs...) =
error("Only `acceleration=CPU1()` currently supported")

fit!(S::Source; args...) = S

# allow arguments of `Nodes` and `Machine`s to appear
Expand Down

0 comments on commit 36788c6

Please sign in to comment.