Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Serialization #733

Merged
merged 18 commits into from
Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
StatisticalTraits = "64bff920-2084-43da-a3e6-9bb72801c0c9"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand Down
1 change: 1 addition & 0 deletions src/MLJBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ import MLJModelInterface: fit, update, update_data, transform,
using Parameters

# Containers & data manipulation
using Serialization
using Tables
import PrettyTables
using DelimitedFiles
Expand Down
228 changes: 173 additions & 55 deletions src/composition/learning_networks/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -414,9 +414,127 @@ end
network_model_names(model::Nothing, mach::Machine{<:Surrogate}) =
nothing

## DUPLICATING/REPLACING PARTS OF A LEARNING NETWORK MACHINE

## DUPLICATING AND REPLACING PARTS OF A LEARNING NETWORK MACHINE
"""
    copy_or_replace_machine(N::AbstractNode, newmodel_given_old, newnode_given_old)

Build a new `Machine` for the machine-bearing node `N`. Its model is the value
of `newmodel_given_old` at the old machine's model, and its training arguments
are the values of `newnode_given_old` at the old machine's training arguments.
"""
function copy_or_replace_machine(N::AbstractNode, newmodel_given_old, newnode_given_old)
    oldmach = N.machine
    # translate the training arguments first, then the model
    newargs = map(arg -> newnode_given_old[arg], oldmach.args)
    newmodel = newmodel_given_old[oldmach.model]
    return Machine(newmodel, newargs...)
end

"""
copy_or_replace_machine(N::AbstractNode, newmodel_given_old::Nothing, newnode_given_old)

For now, two top functions will lead to a call of this function: `Base.replace` and
olivierlabayle marked this conversation as resolved.
Show resolved Hide resolved
`save`. If `save` is the calling function, the argument `newmodel_given_old` will be nothing
olivierlabayle marked this conversation as resolved.
Show resolved Hide resolved
and the goal is to make the machine in the current learning network serializable.
olivierlabayle marked this conversation as resolved.
Show resolved Hide resolved
This method will be called. If `Base.replace` is the calling function, then `newmodel_given_old`
ablaom marked this conversation as resolved.
Show resolved Hide resolved
will be defined and the other method called, a new Machine will be built with training data.
"""
function copy_or_replace_machine(N::AbstractNode, newmodel_given_old::Nothing, newnode_given_old)
m = serializable(N.machine)
m.args = Tuple(newnode_given_old[s] for s in N.machine.args)
return m
end

"""
update_mappings_with_node!(
newnode_given_old,
newmach_given_old,
newmodel_given_old,
N::AbstractNode)

For Nodes that are not sources, update the appropriate mappings
between elements of the learning to be copied and the copy itself.
olivierlabayle marked this conversation as resolved.
Show resolved Hide resolved
"""
function update_mappings_with_node!(
newnode_given_old,
newmach_given_old,
newmodel_given_old,
N::AbstractNode)
args = [newnode_given_old[arg] for arg in N.args]
if N.machine === nothing
newnode_given_old[N] = node(N.operation, args...)
else
if N.machine in keys(newmach_given_old)
m = newmach_given_old[N.machine]
else
m = copy_or_replace_machine(N, newmodel_given_old, newnode_given_old)
newmach_given_old[N.machine] = m
end
newnode_given_old[N] = N.operation(m, args...)
end
end

# Method for sources: no mapping is touched and `nothing` is returned. The
# images of sources are expected to be present in `newnode_given_old` before
# the network is traversed.
function update_mappings_with_node!(
    newnode_given_old,
    newmach_given_old,
    newmodel_given_old,
    N::Source)
    return nothing
end

"""
copysignature(signature, newnode_given_old; newmodel_given_old=nothing)

olivierlabayle marked this conversation as resolved.
Show resolved Hide resolved
Copies the given signature of a learning network.

# Arguments:
- `signature`: signature of the learning network to be copied
- `newnode_given_old`: initialized mapping between nodes of the
learning network to be copied and the new one. At this stage it should
contain only source nodes.
- `newmodel_given_old`: initialized mapping between models of the
learning network to be copied and the new one. This is `nothing` if `save` was
the calling function which will result in a different behaviour of
`update_mappings_with_node!`
"""
function copysignature(signature, newnode_given_old; newmodel_given_old=nothing)
operation_nodes = values(_operation_part(signature))
report_nodes = values(_report_part(signature))
W = glb(operation_nodes..., report_nodes...)
# Note: We construct nodes of the new network as values of a
# dictionary keyed on the nodes of the old network. Additionally,
# there are dictionaries of models keyed on old models and
# machines keyed on old machines. The node and machine
# dictionaries must be built simultaneously.

# instantiate node and machine dictionaries:
newoperation_node_given_old =
IdDict{AbstractNode,AbstractNode}()
newreport_node_given_old =
IdDict{AbstractNode,AbstractNode}()
newmach_given_old = IdDict{Machine,Machine}()

# build the new network:
for N in nodes(W)
update_mappings_with_node!(
newnode_given_old,
newmach_given_old,
newmodel_given_old,
N
)
if N in operation_nodes # could be `Source`
newoperation_node_given_old[N] = newnode_given_old[N]
elseif N in report_nodes
newreport_node_given_old[N] = newnode_given_old[N]
end
end
newoperation_nodes = Tuple(newoperation_node_given_old[N] for N in
operation_nodes)
newreport_nodes = Tuple(newreport_node_given_old[N] for N in
report_nodes)
report_tuple =
NamedTuple{keys(_report_part(signature))}(newreport_nodes)
operation_tuple =
NamedTuple{keys(_operation_part(signature))}(newoperation_nodes)

newsignature = if isempty(report_tuple)
operation_tuple
else
merge(operation_tuple, (report=report_tuple,))
end

return newsignature
end

"""
replace(mach, a1=>b1, a2=>b2, ...; empty_unspecified_sources=false)
Expand All @@ -439,21 +557,14 @@ function Base.replace(mach::Machine{<:Surrogate},

W = glb(operation_nodes..., report_nodes...)

# Note: We construct nodes of the new network as values of a
# dictionary keyed on the nodes of the old network. Additionally,
# there are dictionaries of models keyed on old models and
# machines keyed on old machines. The node and machine
# dictionaries must be built simultaneously.

# build model dict:
# Instantiate model dictionary:
model_pairs = filter(collect(pairs)) do pair
first(pair) isa Model
end
models_ = models(W)
models_to_copy = setdiff(models_, first.(model_pairs))
model_copy_pairs = [model=>deepcopy(model) for model in models_to_copy]
newmodel_given_old = IdDict(vcat(model_pairs, model_copy_pairs))

# build complete source replacement pairs:
sources_ = sources(W)
specified_source_pairs = filter(collect(pairs)) do pair
Expand All @@ -476,58 +587,65 @@ function Base.replace(mach::Machine{<:Surrogate},
end

all_source_pairs = vcat(specified_source_pairs, unspecified_source_pairs)
newnode_given_old =
IdDict{AbstractNode,AbstractNode}(all_source_pairs)
newsources = [newnode_given_old[s] for s in sources(W)]

nodes_ = nodes(W)
newsignature = copysignature(signature, newnode_given_old, newmodel_given_old=newmodel_given_old)


# instantiate node and machine dictionaries:
return machine(mach.model, newsources...; newsignature...)

end


###############################################################################
##### SAVE AND RESTORE FOR COMPOSITES #####
###############################################################################


"""
save(model::Composite, fitresult)

Returns a new `CompositeFitresult` that is a shallow copy of the original one.
To do so, we build a copy of the learning network where each machine contained
in it needs to be called `serializable` upon.

Ideally this method should "reuse" as much as possible `Base.replace`.
olivierlabayle marked this conversation as resolved.
Show resolved Hide resolved
"""
function save(model::Composite, fitresult)
signature = MLJBase.signature(fitresult)
operation_nodes = values(MLJBase._operation_part(signature))
report_nodes = values(MLJBase._report_part(signature))
W = glb(operation_nodes..., report_nodes...)
newnode_given_old =
IdDict{AbstractNode,AbstractNode}(all_source_pairs)
newsources = [newnode_given_old[s] for s in sources_]
newoperation_node_given_old =
IdDict{AbstractNode,AbstractNode}()
newreport_node_given_old =
IdDict{AbstractNode,AbstractNode}()
newmach_given_old = IdDict{Machine,Machine}()
IdDict{AbstractNode,AbstractNode}([old => source() for old in sources(W)])

# build the new network:
for N in nodes_
if N isa Node # ie, not a `Source`
args = [newnode_given_old[arg] for arg in N.args]
if N.machine === nothing
newnode_given_old[N] = node(N.operation, args...)
else
if N.machine in keys(newmach_given_old)
m = newmach_given_old[N.machine]
else
train_args = [newnode_given_old[arg] for arg in N.machine.args]
m = Machine(newmodel_given_old[N.machine.model],
train_args...)
newmach_given_old[N.machine] = m
end
newnode_given_old[N] = N.operation(m, args...)
end
end
if N in operation_nodes # could be `Source`
newoperation_node_given_old[N] = newnode_given_old[N]
elseif N in report_nodes
newreport_node_given_old[N] = newnode_given_old[N]
end
end
newoperation_nodes = Tuple(newoperation_node_given_old[N] for N in
operation_nodes)
newreport_nodes = Tuple(newreport_node_given_old[N] for N in
report_nodes)
report_tuple =
NamedTuple{keys(_report_part(signature))}(newreport_nodes)
operation_tuple =
NamedTuple{keys(_operation_part(signature))}(newoperation_nodes)
newsignature = copysignature(signature, newnode_given_old; newmodel_given_old=nothing)

newsignature = if isempty(report_tuple)
operation_tuple
else
merge(operation_tuple, (report=report_tuple,))
newfitresult = MLJBase.CompositeFitresult(newsignature)
setfield!(newfitresult, :report_additions, getfield(fitresult, :report_additions))

return newfitresult
end

"""
restore!(mach::Machine{<:Composite})

Restores a machine of a composite model by restoring all
submachines contained in it.
"""
function restore!(mach::Machine{<:Composite})
glb_node = glb(mach)
for submach in machines(glb_node)
restore!(submach)
end
return mach
end

return machine(mach.model, newsources...; newsignature...)

# Set `mach.report` to the report of `glb(mach)` merged with those entries of
# `report` whose keys are absent from it. The merged report is also returned.
function setreport!(mach::Machine{<:Composite}, report)
    network_report = MLJBase.report(glb(mach))
    # keep only the entries of `report` not already provided by the network
    extras = Base.structdiff(report, network_report)
    return mach.report = merge(network_report, extras)
end
Loading