Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proposal to add support to allow non-generalizing models to contribute to a machine's report #806

Merged
merged 6 commits into from
Jul 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ ComputationalResources = "0.3"
Distributions = "0.25.3"
InvertedIndices = "1"
LossFunctions = "0.5, 0.6, 0.7, 0.8"
MLJModelInterface = "1.5"
MLJModelInterface = "1.6"
Missings = "0.4, 1"
OrderedCollections = "1.1"
Parameters = "0.12"
PrettyTables = "1"
ProgressMeter = "1.7.1"
ScientificTypes = "3"
StatisticalTraits = "3"
StatisticalTraits = "3.2"
StatsBase = "0.32, 0.33"
Tables = "0.2, 1.0"
julia = "1.6"
35 changes: 25 additions & 10 deletions src/composition/learning_networks/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,22 @@ $DOC_SIGNATURES
"""
glb(mach::Machine{<:Union{Composite,Surrogate}}) = glb(mach.fitresult) # delegate to the wrapped `CompositeFitresult`

"""
report(fitresult::CompositeFitresult)

Return a tuple combining the report from `fitresult.glb` (a `Node` report) with the
additions coming from nodes declared as report nodes in `fitresult.signature`, but without
merging the two.

$DOC_SIGNATURES

**Private method**
"""
# Combine the network's basic report with the report-node additions, keeping
# the two under separate `basic`/`additions` keys (no merging).
function report(fitresult::CompositeFitresult)
    glb_report = report(glb(fitresult))
    report_additions = _call(_report_part(signature(fitresult)))
    return (basic = glb_report, additions = report_additions)
end

"""
fit!(mach::Machine{<:Surrogate};
Expand All @@ -245,11 +261,10 @@ See also [`machine`](@ref)

"""
# Train the learning network wrapped by the surrogate machine `mach`.
#
# Fits the network's greatest lower bound node, increments the machine's
# training state, and refreshes `mach.report` from the composite fitresult
# (a named tuple with `basic` and `additions` parts).
#
# NOTE(review): the diff-scrape had the pre-PR statements fused in here
# (network fitted twice, report computed twice); only the final version is kept.
function fit!(mach::Machine{<:Surrogate}; kwargs...)
    glb = MLJBase.glb(mach)
    fit!(glb; kwargs...)
    mach.state += 1
    mach.report = MLJBase.report(mach.fitresult)
    return mach
end

Expand Down Expand Up @@ -347,7 +362,7 @@ the following:

- Calls `fit!(mach, verbosity=verbosity, acceleration=acceleration)`.

- Records (among other things) a copy of `model` in a variable called `cache`.

- Returns `cache` and outcomes of training in an appropriate form
(specifically, `(mach.fitresult, cache, mach.report)`; see [Adding
Expand Down Expand Up @@ -396,6 +411,7 @@ function return!(mach::Machine{<:Surrogate},
# record the current hyper-parameter values:
old_model = deepcopy(model)

glb = MLJBase.glb(mach)
cache = (; old_model)

setfield!(mach.fitresult,
Expand Down Expand Up @@ -647,9 +663,8 @@ function restore!(mach::Machine{<:Composite})
return mach
end


# Install a report on the serializable copy `copymach`: the `basic` part is
# recomputed from `copymach`'s own fitresult, while the report-node `additions`
# are carried over from the original machine `mach`.
#
# NOTE(review): the removed pre-PR method's lines were fused into this
# definition in the diff scrape (two headers, one `end`); only the final
# method is kept.
function setreport!(copymach::Machine{<:Composite}, mach)
    basic = report(glb(copymach.fitresult))
    additions = mach.report.additions
    copymach.report = (; basic, additions)
end
26 changes: 16 additions & 10 deletions src/composition/models/inspection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ try_scalarize(v) = length(v) == 1 ? v[1] : v
function machines_given_model_name(mach::Machine{M}) where M<:Composite
network_model_names = getfield(mach.fitresult, :network_model_names)
names = unique(filter(name->!(name === nothing), network_model_names))
network_models = MLJBase.models(glb(mach))
network_machines = MLJBase.machines(glb(mach))
glb = MLJBase.glb(mach)
network_models = MLJBase.models(glb)
network_machines = MLJBase.machines(glb)
ret = LittleDict{Symbol,Any}()
for name in names
mask = map(==(name), network_model_names)
Expand All @@ -17,22 +18,27 @@ function machines_given_model_name(mach::Machine{M}) where M<:Composite
return ret
end

# Return a named tuple keyed on the network's model (field) names, whose value
# for each name is `f` applied to the machine(s) associated with that name
# (scalarized when a name has exactly one machine).
#
# The `machines` argument is not read in this body but is retained for
# call-site compatibility (see `report`/`fitted_params` callers) — TODO confirm
# against the rest of the file.
function tuple_keyed_on_model_names(machines, mach, f)
    dict = MLJBase.machines_given_model_name(mach)
    names = tuple(keys(dict)...)
    named_tuple_values = map(names) do name
        [f(m) for m in dict[name]] |> try_scalarize
    end
    return NamedTuple{names}(named_tuple_values)
end

# User-facing report for a composite (or surrogate) machine: merges the
# per-model-name components, the basic network report, and the report-node
# additions. Surrogates get no per-model-name components.
#
# NOTE(review): the removed pre-PR method was fused into this definition in
# the diff scrape; only the final method is kept.
function report(mach::Machine{<:Union{Composite,Surrogate}})
    report_additions = mach.report.additions
    report_basic = mach.report.basic
    report_components = mach isa Machine{<:Surrogate} ? NamedTuple() :
        MLJBase.tuple_keyed_on_model_names(report_basic.machines, mach, MLJBase.report)
    return merge(report_components, report_basic, report_additions)
end

# User-facing fitted parameters for a composite machine: merges the
# per-model-name components with the basic fitted params of the network.
#
# NOTE(review): the pre-PR body (including an early `return`) was fused in
# here by the diff scrape, making the new lines unreachable; only the final
# version is kept.
function fitted_params(mach::Machine{<:Composite})
    fp_basic = fitted_params(mach.model, mach.fitresult)
    machines = fp_basic.machines
    fp_components =
        MLJBase.tuple_keyed_on_model_names(machines, mach, MLJBase.fitted_params)
    return merge(fp_components, fp_basic)
end
13 changes: 7 additions & 6 deletions src/composition/models/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,25 @@ function update(model::M,
# underlying learning network machine.

network_model_names = getfield(fitresult, :network_model_names)
old_model = cache.old_model

glb_node = glb(fitresult) # greatest lower bound
old_model = cache.old_model
glb = MLJBase.glb(fitresult) # greatest lower bound of network, a node

if fallback(model, old_model, network_model_names, glb_node)
if fallback(model, old_model, network_model_names, glb)
return fit(model, verbosity, args...)
end

fit!(glb_node; verbosity=verbosity)
fit!(glb; verbosity=verbosity)

# Retrieve additional report values
report_additions_ = _call(_report_part(signature(fitresult)))
report = MLJBase.report(fitresult)

# record current model state:
cache = (; old_model = deepcopy(model))

return (fitresult,
cache,
merge(report(glb_node), report_additions_))
report)

end

Expand Down
2 changes: 1 addition & 1 deletion src/composition/models/pipelines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ MMI.target_scitype(p::SupervisedPipeline) = target_scitype(supervised_component(
# ## Training losses

# Training losses for a supervised pipeline, extracted from the supervised
# component's own report inside the pipeline's composite report.
#
# NOTE(review): the stale pre-PR line `supervised(pipe_report.machines)` was
# fused in by the diff scrape; with the new report layout that field access
# would throw, so only the final line is kept.
function MMI.training_losses(pipe::SupervisedPipeline, pipe_report)
    mach = supervised(pipe_report.basic.machines)
    _report = report(mach)
    return training_losses(mach.model, _report)
end
2 changes: 1 addition & 1 deletion src/composition/models/transformed_target_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ end
# # TRAINING LOSSES

# Training losses for a transformed-target model: delegate to the (single)
# atomic machine recorded in the composite report's `basic` part.
#
# NOTE(review): the stale pre-PR line `first(tt_report.machines)` was fused in
# by the diff scrape; only the final line is kept.
function training_losses(model::SomeTT, tt_report)
    mach = first(tt_report.basic.machines)
    return training_losses(mach)
end

Expand Down
18 changes: 11 additions & 7 deletions src/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -469,10 +469,11 @@ end
# Not one, but *two*, fit methods are defined for machines here,
# `fit!` and `fit_only!`.

# - `fit_only!`: trains a machine without touching the learned parameters (`fitresult`) of
#   any other machine. It may error if another machine on which it depends (through its node
#   training arguments `N1, N2, ...`) has not been trained. It's possible that a dependent
#   machine `mach` may have its report mutated if `reporting_operations(mach.model)` is
#   non-empty.

# - `fit!`: trains a machine after first progressively training all
# machines on which the machine depends. Implicitly this involves
Expand Down Expand Up @@ -909,13 +910,14 @@ function serializable(mach::Machine{<:Any, C}) where C
setfield!(copymach, fieldname, ())
# Make fitresult ready for serialization
elseif fieldname == :fitresult
# this `save` does the actual emptying of fields
copymach.fitresult = save(mach.model, getfield(mach, fieldname))
else
setfield!(copymach, fieldname, getfield(mach, fieldname))
end
end

setreport!(copymach, mach.report)
setreport!(copymach, mach)

return copymach
end
Expand Down Expand Up @@ -997,6 +999,8 @@ function save(file::Union{String,IO},
serialize(file, smach)
end

# Generic fallback: copy the whole report from `mach` onto `copymach`.
#
# NOTE(review): the diff scrape also retained the removed pre-PR method
# `setreport!(mach::Machine, report)`; keeping it would shadow this fallback
# for `Machine` arguments, so only the final method is kept.
setreport!(copymach, mach) =
    setfield!(copymach, :report, mach.report)
# NOTE. there is also a specialization for `setreport!` for `Composite` models, defined in
# /src/composition/learning_networks/machines/
78 changes: 55 additions & 23 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,37 +37,59 @@ warn_serializable_mach(operation) = "The operation $operation has been called on
"deserialised machine mach whose learned parameters "*
"may be unusable. To be sure, first run restore!(mach)."

# Given the return value `ret` of an operation with symbol `operation` (eg, `:predict`),
# return `ret` unchanged in the ordinary case that the operation does not include a
# "report" component; otherwise update `mach.report` with that component and return the
# non-report part of `ret`:
# If `operation` is one of the model's reporting operations, fold the report
# component (last element) of `ret` into `mach.report` and return the
# non-report part (first element); otherwise return `ret` untouched.
function get!(ret, operation, mach)
    operation in reporting_operations(mach.model) || return ret
    new_report = last(ret)
    if isnothing(mach.report) || isempty(mach.report)
        mach.report = new_report
    else
        mach.report = merge(mach.report, new_report)
    end
    return first(ret)
end

# 0. operations on machine, given rows=...:

# For each operation, define the `rows=...` convenience methods on machines
# (cached and uncached data variants). `get!` routes any report component of
# the raw return value into `mach.report`.
#
# NOTE(review): the diff scrape fused the pre-PR loop body (duplicate `ex`
# blocks, orphaned `if`); only the final version is kept.
for operation in OPERATIONS

    quoted_operation = QuoteNode(operation) # eg, :(:predict)

    # `inverse_transform` gets no `rows=...` form:
    operation == :inverse_transform && continue

    ex = quote
        function $(operation)(mach::Machine{<:Model,false}; rows=:)
            # catch deserialized machine with no data:
            isempty(mach.args) && _err_serialized($operation)
            ret = ($operation)(mach, mach.args[1](rows=rows))
            return get!(ret, $quoted_operation, mach)
        end
        function $(operation)(mach::Machine{<:Model,true}; rows=:)
            # catch deserialized machine with no data:
            isempty(mach.args) && _err_serialized($operation)
            model = mach.model
            ret = ($operation)(
                model,
                mach.fitresult,
                selectrows(model, rows, mach.data[1])...,
            )
            return get!(ret, $quoted_operation, mach)
        end
    end
    eval(ex)

end

# special case of Static models (no training arguments): a `rows=...` call has
# nothing to index into, so it is always an error:
transform(mach::Machine{<:Static}; rows=:) = _err_rows_not_allowed()

inverse_transform(mach::Machine; rows=:) =
throw(ArgumentError("`inverse_transform()(mach)` and "*
throw(ArgumentError("`inverse_transform(mach)` and "*
"`inverse_transform(mach, rows=...)` are "*
"not supported. Data or nodes "*
"must be explictly specified, "*
Expand All @@ -77,22 +99,32 @@ _symbol(f) = Base.Core.Typeof(f).name.mt.name

for operation in OPERATIONS

quoted_operation = QuoteNode(operation) # eg, :(:predict)

ex = quote
# 1. operations on machines, given *concrete* data:
function $operation(mach::Machine, Xraw)
if mach.state != 0
mach.state == -1 && @warn warn_serializable_mach($operation)
return $(operation)(mach.model,
mach.fitresult,
reformat(mach.model, Xraw)...)
ret = $(operation)(
mach.model,
mach.fitresult,
reformat(mach.model, Xraw)...,
)
get!(ret, $quoted_operation, mach)
else
error("$mach has not been trained.")
end
end

function $operation(mach::Machine{<:Static}, Xraw, Xraw_more...)
return $(operation)(mach.model, mach.fitresult,
Xraw, Xraw_more...)
ret = $(operation)(
mach.model,
mach.fitresult,
Xraw,
Xraw_more...,
)
get!(ret, $quoted_operation, mach)
end

# 2. operations on machines, given *dynamic* data (nodes):
Expand Down
6 changes: 3 additions & 3 deletions test/composition/learning_networks/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ end
@test Θ.transform == Wout
Θ.report.some_stuff == rnode
@test report(mach).some_stuff == :stuff

@test report(mach).machines == fitted_params(mach).machines

# supervised
Expand Down Expand Up @@ -281,7 +280,7 @@ end
end

# Testing extra report field : it is a deepcopy
@test smach.report.cv_report === mach.report.cv_report
@test report(smach).cv_report === report(mach).cv_report

@test smach.fitresult isa MLJBase.CompositeFitresult

Expand Down Expand Up @@ -356,7 +355,8 @@ end
metalearner = FooBarRegressor(lambda=1.),
resampling = dcv,
model_1 = DeterministicConstantRegressor(),
model_2=ConstantRegressor())
model_2=ConstantRegressor()
)

filesizes = []
for n in [100, 500, 1000]
Expand Down
Loading