From 9f6bcdb4cbee541c02effdd4302fe508b02df5f9 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 28 Jan 2022 10:27:19 +0000 Subject: [PATCH 01/18] migrate current state from MLJSerialization --- Project.toml | 1 + src/MLJBase.jl | 1 + src/composition/learning_networks/machines.jl | 106 +++++++++++++ src/machines.jl | 146 ++++++++++++++++++ .../composition/learning_networks/machines.jl | 107 +++++++++++++ test/machines.jl | 34 ++++ test/test_utilities.jl | 35 +++++ 7 files changed, 430 insertions(+) diff --git a/Project.toml b/Project.toml index df3f19da..6e7d5f5c 100644 --- a/Project.toml +++ b/Project.toml @@ -23,6 +23,7 @@ PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81" +Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" StatisticalTraits = "64bff920-2084-43da-a3e6-9bb72801c0c9" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" diff --git a/src/MLJBase.jl b/src/MLJBase.jl index 3101cd94..da31010e 100644 --- a/src/MLJBase.jl +++ b/src/MLJBase.jl @@ -59,6 +59,7 @@ import MLJModelInterface: fit, update, update_data, transform, using Parameters # Containers & data manipulation +using Serialization using Tables import PrettyTables using DelimitedFiles diff --git a/src/composition/learning_networks/machines.jl b/src/composition/learning_networks/machines.jl index 882e12c8..4a123705 100644 --- a/src/composition/learning_networks/machines.jl +++ b/src/composition/learning_networks/machines.jl @@ -531,3 +531,109 @@ function Base.replace(mach::Machine{<:Surrogate}, return machine(mach.model, newsources...; newsignature...) end + + +############################################################################### +##### SAVE AND RESTORE FOR COMPOSITES ##### +############################################################################### + + +""" + save(model::Composite, fitresult) + +Returns a new `CompositeFitresult` that is a shallow copy of the original one. +To do so, we build a copy of the learning network where each machine contained +in it needs to be called `serializable` upon. + +Ideally this method should "reuse" as much as possible `Base.replace`. +""" +function save(model::Composite, fitresult) + # THIS IS WIP: NOT WORKING + signature = MLJBase.signature(fitresult) + + operation_nodes = values(MLJBase._operation_part(signature)) + report_nodes = values(MLJBase._report_part(signature)) + + W = glb(operation_nodes..., report_nodes...) + + nodes_ = filter(x -> !(x isa Source), nodes(W)) + + # instantiate node dictionary with source nodes and exception nodes + # This supposes that exception nodes only occur in the signature otherwise we need + # to to this differently + newnode_given_old = + IdDict{AbstractNode,AbstractNode}([old => source() for old in sources(W)]) + # Other useful mappings + newoperation_node_given_old = + IdDict{AbstractNode,AbstractNode}() + newreport_node_given_old = + IdDict{AbstractNode,AbstractNode}() + newmach_given_old = IdDict{Machine,Machine}() + + # build the new network, nodes are nicely ordered + for N in nodes_ + # Retrieve the future node's ancestors + args = [newnode_given_old[arg] for arg in N.args] + if N.machine === nothing + newnode_given_old[N] = node(N.operation, args...) + else + # The same machine can be associated with multiple nodes + if N.machine in keys(newmach_given_old) + m = newmach_given_old[N.machine] + else + m = serializable(N.machine) + m.args = Tuple(newnode_given_old[s] for s in N.machine.args) + newmach_given_old[N.machine] = m + end + newnode_given_old[N] = N.operation(m, args...) + end + # Sort nodes according to: operation_node, report_node + if N in operation_nodes + newoperation_node_given_old[N] = newnode_given_old[N] + elseif N in report_nodes + newreport_node_given_old[N] = newnode_given_old[N] + end + end + + newoperation_nodes = Tuple(newoperation_node_given_old[N] for N in + operation_nodes) + newreport_nodes = Tuple(newreport_node_given_old[N] for N in + report_nodes) + report_tuple = + NamedTuple{keys(MLJBase._report_part(signature))}(newreport_nodes) + operation_tuple = + NamedTuple{keys(MLJBase._operation_part(signature))}(newoperation_nodes) + + newsignature = if isempty(report_tuple) + operation_tuple + else + merge(operation_tuple, (report=report_tuple,)) + end + + + newfitresult = MLJBase.CompositeFitresult(newsignature) + setfield!(newfitresult, :report_additions, report_tuple) + + return newfitresult +end + +""" + restore!(mach::Machine{<:Composite}) + +Restores a machine of a composite model by restoring all +submachines contained in it. +""" +function restore!(mach::Machine{<:Composite}) + glb_node = glb(mach) + for submach in machines(glb_node) + restore!(submach) + end + return mach +end + + +function setreport!(mach::Machine{<:Composite}, report) + basereport = MLJBase.report(glb(mach)) + report_additions = Base.structdiff(report, basereport) + mach.report = merge(basereport, report_additions) +end diff --git a/src/machines.jl b/src/machines.jl index c251b22b..2816cad1 100644 --- a/src/machines.jl +++ b/src/machines.jl @@ -378,6 +378,22 @@ function machine(model::Model, arg1::AbstractNode, args::AbstractNode...; return Machine(model, arg1, args...; kwargs...) end +""" + machine(file::Union{String, IO}, raw_arg1=nothing, raw_args...) + +Rebuild from a file a machine that has been serialized using the default +Serialization module. +""" +function machine(file::Union{String, IO}, raw_arg1=nothing, raw_args...) + smach = deserialize(file) + restore!(smach) + if raw_arg1 !== nothing + args = source.((raw_arg1, raw_args...)) + MLJBase.check(smach.model, args...; full=true) + smach.args = args + end + return smach +end ## INSPECTION AND MINOR MANIPULATION OF FIELDS @@ -796,3 +812,133 @@ function training_losses(mach::Machine) throw(NotTrainedError(mach, :training_losses)) end end + + +############################################################################### +##### SERIALIZABLE, RESTORE!, SAVE AND A FEW UTILITY FUNCTIONS ##### +############################################################################### + + +""" + serializable(mach::Machine) + +Returns a shallow copy of the machine to make it serializable, in particular: + - Removes all data from caches, args and data fields + - Makes all `fitresults` serializable + - Annotates the state as -1 +""" +function serializable(mach::Machine{<:Any, C}) where C + copymach = machine(mach.model, mach.args..., cache=C) + + for fieldname in fieldnames(Machine) + if fieldname ∈ (:model, :report) + continue + elseif fieldname == :state + setfield!(copymach, :state, -1) + # Wipe data from cache + elseif fieldname == :cache + setfield!(copymach, :cache, serializable_cache(mach.cache)) + elseif fieldname == :args + setfield!(copymach, fieldname, ()) + # Let those fields undefined + elseif fieldname ∈ (:data, :resampled_data, :old_rows) + continue + # Make fitresult ready for serialization + elseif fieldname == :fitresult + copymach.fitresult = save(mach.model, getfield(mach, fieldname)) + else + setfield!(copymach, fieldname, getfield(mach, fieldname)) + end + end + + setreport!(copymach, mach.report) + + return copymach +end + +""" + restore!(mach::Machine) + +Default method to restores the state of a machine that is currently serializable. +Such a machine is annotated with `state=-1` +""" +function restore!(mach::Machine) + mach.fitresult = restore(mach.model, mach.fitresult) + return mach +end + + +""" + MLJ.save(filename, mach::Machine) + MLJ.save(io, mach::Machine) + + MLJBase.save(filename, mach::Machine) + MLJBase.save(io, mach::Machine) + +Serialize the machine `mach` to a file with path `filename`, or to an +input/output stream `io` (at least `IOBuffer` instances are +supported) using the Serialization module. + +Machines are de-serialized using the `machine` constructor as shown in +the example below. Data (or nodes) may be optionally passed to the +constructor for retraining on new data using the saved model. + + +### Example + + using MLJ + tree = @load DecisionTreeClassifier + X, y = @load_iris + mach = fit!(machine(tree, X, y)) + + MLJ.save("tree.jlso", mach) + mach_predict_only = machine("tree.jlso") + predict(mach_predict_only, X) + + mach2 = machine("tree.jlso", selectrows(X, 1:100), y[1:100]) + predict(mach2, X) # same as above + + fit!(mach2) # saved learned parameters are over-written + predict(mach2, X) # not same as above + + # using a buffer: + io = IOBuffer() + MLJ.save(io, mach) + seekstart(io) + predict_only_mach = machine(io) + predict(predict_only_mach, X) + +!!! warning "Only load files from trusted sources" + Maliciously constructed JLSO files, like pickles, and most other + general purpose serialization formats, can allow for arbitrary code + execution during loading. This means it is possible for someone + to use a JLSO file that looks like a serialized MLJ machine as a + [Trojan + horse](https://en.wikipedia.org/wiki/Trojan_horse_(computing)). + +""" +function save(file::Union{String,IO}, + mach::Machine) + isdefined(mach, :fitresult) || + error("Cannot save an untrained machine. ") + + smach = serializable(mach) + + serialize(file, smach) +end + + +setreport!(mach::Machine, report) = + setfield!(mach, :report, report) + + +maybe_serializable(val) = val +maybe_serializable(val::Machine) = serializable(val) + + +serializable_cache(cache) = cache +serializable_cache(cache::Tuple) = Tuple(maybe_serializable(val) for val in cache) +function serializable_cache(cache::NamedTuple) + new_keys = filter(!=(:data), keys(cache)) + return NamedTuple{new_keys}([maybe_serializable(cache[key]) for key in new_keys]) +end \ No newline at end of file diff --git a/test/composition/learning_networks/machines.jl b/test/composition/learning_networks/machines.jl index 74329bf3..95875a43 100644 --- a/test/composition/learning_networks/machines.jl +++ b/test/composition/learning_networks/machines.jl @@ -6,6 +6,7 @@ using ..TestUtilities using MLJBase using Tables using StableRNGs +using Serialization rng = StableRNG(616161) # A dummy clustering model: @@ -230,6 +231,112 @@ enode = @node mae(ys, yhat) end +@testset "Test serializable of pipeline" begin + # Composite model with some C inside + filename = "pipe_mach.jls" + X, y = TestUtilities.simpledata() + pipe = (X -> coerce(X, :x₁=>Continuous)) |> DecisionTreeRegressor() + mach = machine(pipe, X, y) + fit!(mach, verbosity=0) + + # Check serializable function + smach = MLJBase.serializable(mach) + TestUtilities.generic_tests(mach, smach) + @test MLJBase.predict(smach, X) == MLJBase.predict(mach, X) + @test keys(fitted_params(smach)) == keys(fitted_params(mach)) + @test keys(report(smach)) == keys(report(mach)) + # Check data has been wiped out from models at the first level of composition + @test length(machines(glb(smach))) == length(machines(glb(mach))) + for submach in machines(glb(smach)) + TestUtilities.test_data(submach) + end + + # End to end + MLJBase.save(filename, mach) + smach = machine(filename) + @test predict(smach, X) == predict(mach, X) + + rm(filename) +end + + +@testset "Test serializable of composite machines" begin + # Composite model with some C inside + filename = "stack_mach.jls" + X, y = TestUtilities.simpledata() + model = Stack( + metalearner = DecisionTreeRegressor(), + tree1 = DecisionTreeRegressor(min_samples_split=3), + tree2 = DecisionTreeRegressor(), + measures=rmse) + mach = machine(model, X, y) + fit!(mach, verbosity=0) + + # Check serializable function + smach = MLJBase.serializable(mach) + TestUtilities.generic_tests(mach, smach) + # Check data has been wiped out from models at the first level of composition + @test length(machines(glb(smach))) == length(machines(glb(mach))) + for submach in machines(glb(smach)) + @test !isdefined(submach, :data) + @test !isdefined(submach, :resampled_data) + @test submach.cache isa Nothing || :data ∉ keys(submach.cache) + end + + # Testing extra report field : it is a deepcopy + @test smach.report.cv_report === mach.report.cv_report + + @test smach.fitresult isa MLJBase.CompositeFitresult + + Serialization.serialize(filename, smach) + smach = Serialization.deserialize(filename) + MLJBase.restore!(smach) + + @test MLJBase.predict(smach, X) == MLJBase.predict(mach, X) + @test keys(fitted_params(smach)) == keys(fitted_params(mach)) + @test keys(report(smach)) == keys(report(mach)) + + rm(filename) + + # End to end + MLJBase.save(filename, mach) + smach = machine(filename) + @test predict(smach, X) == predict(mach, X) + + rm(filename) +end + +@testset "Test serializable of nested composite machines" begin + # Composite model with some C inside + filename = "nested stack_mach.jls" + X, y = TestUtilities.simpledata() + + pipe = (X -> coerce(X, :x₁=>Continuous)) |> DecisionTreeRegressor() + model = Stack( + metalearner = DecisionTreeRegressor(), + pipe = pipe) + mach = machine(model, X, y) + fit!(mach, verbosity=0) + + MLJBase.save(filename, mach) + smach = machine(filename) + + @test predict(smach, X) == predict(mach, X) + + # Test data as been erased at the first and second level of composition + for submach in machines(glb(smach)) + TestUtilities.test_data(submach) + if submach isa Machine{<:Composite} + for subsubmach in machines(glb(submach)) + TestUtilities.test_data(subsubmach) + end + end + end + + rm(filename) + + +end mutable struct DummyComposite <: DeterministicComposite stand1 diff --git a/test/machines.jl b/test/machines.jl index e697487a..b5d34176 100644 --- a/test/machines.jl +++ b/test/machines.jl @@ -6,6 +6,7 @@ using Statistics using ..Models const MLJModelInterface = MLJBase.MLJModelInterface using StableRNGs +using Serialization using ..TestUtilities DecisionTreeRegressor() @@ -262,6 +263,39 @@ end end +@testset "Test serializable method of simple machines" begin + X, y = TestUtilities.simpledata() + filename = "decisiontree.jls" + mach = machine(DecisionTreeRegressor(), X, y) + fit!(mach, verbosity=0) + # Check serializable function + smach = MLJBase.serializable(mach) + @test smach.report == mach.report + @test smach.fitresult == mach.fitresult + TestUtilities.generic_tests(mach, smach) + # Check restore! function + Serialization.serialize(filename, smach) + smach = Serialization.deserialize(filename) + MLJBase.restore!(smach) + + @test MLJBase.predict(smach, X) == MLJBase.predict(mach, X) + @test fitted_params(smach) isa NamedTuple + @test report(smach) == report(mach) + + rm(filename) + + # End to end save and reload + MLJBase.save(filename, mach) + smach = machine(filename) + @test predict(smach, X) == predict(mach, X) + + # Try to reset the data + smach = machine(filename, X, y) + fit!(smach, verbosity=0) + @test predict(smach) == predict(mach) + + rm(filename) +end end # module diff --git a/test/test_utilities.jl b/test/test_utilities.jl index d7120460..ab35616c 100644 --- a/test/test_utilities.jl +++ b/test/test_utilities.jl @@ -93,4 +93,39 @@ macro test_model_sequence(fit_ex, sequence_exs...) end) end +############################################################################### +##### THE FOLLOWINGS ARE USED TO TEST SERIALIZATION CAPACITIES ##### +############################################################################### + + +function test_args(mach) + # Check source nodes are empty if any + for arg in mach.args + if arg isa Source + @test arg == source() + end + end +end + +function test_data(mach) + @test !isdefined(mach, :old_rows) + @test !isdefined(mach, :data) + @test !isdefined(mach, :resampled_data) + if mach isa NamedTuple + @test :data ∉ keys(mach.cache) + end +end + +function generic_tests(mach₁, mach₂) + test_args(mach₂) + test_data(mach₂) + @test mach₂.state == -1 + for field in (:frozen, :model, :old_model, :old_upstream_state, :fit_okay) + @test getfield(mach₁, field) == getfield(mach₂, field) + end +end + +simpledata(;n=100) = (x₁=rand(n),), rand(n) + + end From 02a42b516506ba01f44dd136b18c9a3f73c456e7 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 28 Jan 2022 11:00:02 +0000 Subject: [PATCH 02/18] add filesize check for composite --- .../composition/learning_networks/machines.jl | 31 ++++++++++++++++--- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/test/composition/learning_networks/machines.jl b/test/composition/learning_networks/machines.jl index 95875a43..de60b0bc 100644 --- a/test/composition/learning_networks/machines.jl +++ b/test/composition/learning_networks/machines.jl @@ -232,7 +232,6 @@ end @testset "Test serializable of pipeline" begin - # Composite model with some C inside filename = "pipe_mach.jls" X, y = TestUtilities.simpledata() pipe = (X -> coerce(X, :x₁=>Continuous)) |> DecisionTreeRegressor() @@ -261,7 +260,6 @@ end @testset "Test serializable of composite machines" begin - # Composite model with some C inside filename = "stack_mach.jls" X, y = TestUtilities.simpledata() model = Stack( @@ -307,8 +305,7 @@ end end @testset "Test serializable of nested composite machines" begin - # Composite model with some C inside - filename = "nested stack_mach.jls" + filename = "nested_stack_mach.jls" X, y = TestUtilities.simpledata() pipe = (X -> coerce(X, :x₁=>Continuous)) |> DecisionTreeRegressor() @@ -335,7 +332,33 @@ end rm(filename) +end +@testset "Test serialized filesize does not increase with datasize" begin + model = Stack( + metalearner = FooBarRegressor(lambda=1.), + model_1 = DeterministicConstantRegressor(), + model_2=ConstantRegressor()) + + filesizes = [] + for n in [100, 500, 1000] + filename = "serialized_temp_$n.jls" + X, y = TestUtilities.simpledata(n=n) + mach = machine(model, X, y) + fit!(mach, verbosity=0) + MLJBase.save(filename, mach) + push!(filesizes, filesize(filename)) + rm(filename) + end + @test all(x==filesizes[1] for x in filesizes) + # What if no serializable procedure had happened + filename = "full_of_data.jls" + X, y = TestUtilities.simpledata(n=1000) + mach = machine(model, X, y) + fit!(mach, verbosity=0) + serialize(filename, mach) + @test filesize(filename) > filesizes[1] + rm(filename) end mutable struct DummyComposite <: DeterministicComposite From 10d91949418c4af2311c54bced553aa08b8b667d Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 28 Jan 2022 11:19:58 +0000 Subject: [PATCH 03/18] add some Misc tests --- test/machines.jl | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/test/machines.jl b/test/machines.jl index b5d34176..e21e562e 100644 --- a/test/machines.jl +++ b/test/machines.jl @@ -4,12 +4,11 @@ using MLJBase using Test using Statistics using ..Models -const MLJModelInterface = MLJBase.MLJModelInterface using StableRNGs using Serialization using ..TestUtilities -DecisionTreeRegressor() +const MLJModelInterface = MLJBase.MLJModelInterface N=50 X = (a=rand(N), b=rand(N), c=rand(N)); @@ -297,6 +296,39 @@ end rm(filename) end +@testset "Test Misc functions used in `serializable`" begin + X, y = TestUtilities.simpledata() + mach = machine(DeterministicConstantRegressor(), X, y) + fit!(mach, verbosity=0) + # setreport! default + @test mach.report isa NamedTuple + MLJBase.setreport!(mach, "toto") + @test mach.report == "toto" + + # serializable_cache + # The default is to return the original cache + @test MLJBase.serializable_cache(mach.cache) === mach.cache + # For Tuples and NamedTuples, a machine might live in the cache + # (as it is the case in MLJTuned model) + # and therefore has to be called `serializable` upon + # if a `data` field lives in the cache, it is removed, this is for + # learning networks. However this might become useless if the + # data anomynization process in fit! is removed + submach = machine(DeterministicConstantRegressor(), X, y) + fit!(submach, verbosity=0) + # Tuple type in cache + mach.cache = (submach,) + newcache = MLJBase.serializable_cache(mach.cache) + @test newcache[1] isa Machine + @test !isdefined(newcache[1], :data) + # NamedTuple type in cache + mach.cache = (machine=submach, data=(X,y)) + newcache = MLJBase.serializable_cache(mach.cache) + @test :data ∉ keys(newcache) + @test !isdefined(newcache[1], :data) +end + + end # module true From 4b5848845481221017815957b56899faefd3d016 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 28 Jan 2022 11:23:58 +0000 Subject: [PATCH 04/18] add some docstrings --- src/machines.jl | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/src/machines.jl b/src/machines.jl index 2816cad1..bb86af30 100644 --- a/src/machines.jl +++ b/src/machines.jl @@ -935,9 +935,30 @@ setreport!(mach::Machine, report) = maybe_serializable(val) = val maybe_serializable(val::Machine) = serializable(val) +""" + serializable_cache(cache) +Default fallbacks to return the original cache. +""" serializable_cache(cache) = cache -serializable_cache(cache::Tuple) = Tuple(maybe_serializable(val) for val in cache) + +""" + serializable_cache(cache::Tuple) + +If the cache is a Tuple, any machine in the cache is called +`serializable` upon. This is to address TunedModels. A dispatch on +TunedModel would have been possible but would require a new api function. +""" +serializable_cache(cache::Tuple) = + Tuple(maybe_serializable(val) for val in cache) + +""" + serializable_cache(cache::NamedTuple) + +If the cache is a NamedTuple, any data field is filtered, this is to address +the current learning networks cache. Any machine in the cache is also called +`serializable` upon. +""" function serializable_cache(cache::NamedTuple) new_keys = filter(!=(:data), keys(cache)) return NamedTuple{new_keys}([maybe_serializable(cache[key]) for key in new_keys]) From 6dd72707470aa95b3533dc7a71008ddf9ba787df Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 28 Jan 2022 11:29:50 +0000 Subject: [PATCH 05/18] add UnupervisedModel test --- test/machines.jl | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/test/machines.jl b/test/machines.jl index e21e562e..0f26c8a4 100644 --- a/test/machines.jl +++ b/test/machines.jl @@ -262,7 +262,7 @@ end end -@testset "Test serializable method of simple machines" begin +@testset "Test serializable method of Supervised Machine" begin X, y = TestUtilities.simpledata() filename = "decisiontree.jls" mach = machine(DecisionTreeRegressor(), X, y) @@ -296,6 +296,20 @@ end rm(filename) end +@testset "Test serializable method of Unsupervised Machine" begin + X, _ = TestUtilities.simpledata() + filename = "standardizer.jls" + mach = machine(Standardizer(), X) + fit!(mach, verbosity=0) + + MLJBase.save(filename, mach) + smach = machine(filename) + + @test transform(mach, X) == transform(smach, X) + + rm(filename) +end + @testset "Test Misc functions used in `serializable`" begin X, y = TestUtilities.simpledata() mach = machine(DeterministicConstantRegressor(), X, y) From 25100d3666803f15394d5b21f2100455362462de Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 28 Jan 2022 11:36:15 +0000 Subject: [PATCH 06/18] add warn message when loading machine with state != -1 --- src/machines.jl | 3 +++ test/composition/learning_networks/machines.jl | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/src/machines.jl b/src/machines.jl index bb86af30..3fc04eb0 100644 --- a/src/machines.jl +++ b/src/machines.jl @@ -386,6 +386,9 @@ Serialization module. """ function machine(file::Union{String, IO}, raw_arg1=nothing, raw_args...) smach = deserialize(file) + smach.state == -1 || + @warn "Deserialized machine state is not -1 (=$(smach.state)). "* + "It means that the machine has not been saved by a conventional MLJ routine." restore!(smach) if raw_arg1 !== nothing args = source.((raw_arg1, raw_args...)) diff --git a/test/composition/learning_networks/machines.jl b/test/composition/learning_networks/machines.jl index de60b0bc..895d4a25 100644 --- a/test/composition/learning_networks/machines.jl +++ b/test/composition/learning_networks/machines.jl @@ -358,6 +358,12 @@ end fit!(mach, verbosity=0) serialize(filename, mach) @test filesize(filename) > filesizes[1] + + @test_logs (:warn, "Deserialized machine state is"* + " not -1 (=1). It means that the"* + " machine has not been saved by a"* + " conventional MLJ routine.") machine(filename) + rm(filename) end From 3581923c2a744303fc9b2d77476a4b62548e7231 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 28 Jan 2022 12:59:10 +0000 Subject: [PATCH 07/18] refactor Base.replace to use in save --- src/composition/learning_networks/machines.jl | 202 ++++++++---------- .../composition/learning_networks/machines.jl | 1 + 2 files changed, 91 insertions(+), 112 deletions(-) diff --git a/src/composition/learning_networks/machines.jl b/src/composition/learning_networks/machines.jl index 4a123705..91ef5dcf 100644 --- a/src/composition/learning_networks/machines.jl +++ b/src/composition/learning_networks/machines.jl @@ -415,8 +415,92 @@ network_model_names(model::Nothing, mach::Machine{<:Surrogate}) = nothing +function copy_or_replace_machine(N::AbstractNode, newmodel_given_old, newnode_given_old) + train_args = [newnode_given_old[arg] for arg in N.machine.args] + return Machine(newmodel_given_old[N.machine.model], + train_args...) +end + +function copy_or_replace_machine(N::AbstractNode, newmodel_given_old::Nothing, newnode_given_old) + m = serializable(N.machine) + m.args = Tuple(newnode_given_old[s] for s in N.machine.args) + return m +end + ## DUPLICATING AND REPLACING PARTS OF A LEARNING NETWORK MACHINE +function update_mappings_with_node!( + newnode_given_old, + newmach_given_old, + newmodel_given_old, + N::AbstractNode) + args = [newnode_given_old[arg] for arg in N.args] + if N.machine === nothing + newnode_given_old[N] = node(N.operation, args...) + else + if N.machine in keys(newmach_given_old) + m = newmach_given_old[N.machine] + else + m = copy_or_replace_machine(N, newmodel_given_old, newnode_given_old) + newmach_given_old[N.machine] = m + end + newnode_given_old[N] = N.operation(m, args...) + end +end + +update_mappings_with_node!( + newnode_given_old, + newmach_given_old, + newmodel_given_old, + N::Source) = nothing + +function copysignature(signature, newnode_given_old; newmodel_given_old=nothing) + operation_nodes = values(_operation_part(signature)) + report_nodes = values(_report_part(signature)) + W = glb(operation_nodes..., report_nodes...) + # Note: We construct nodes of the new network as values of a + # dictionary keyed on the nodes of the old network. Additionally, + # there are dictionaries of models keyed on old models and + # machines keyed on old machines. The node and machine + # dictionaries must be built simultaneously. + + # instantiate node and machine dictionaries: + newoperation_node_given_old = + IdDict{AbstractNode,AbstractNode}() + newreport_node_given_old = + IdDict{AbstractNode,AbstractNode}() + newmach_given_old = IdDict{Machine,Machine}() + + # build the new network: + for N in nodes(W) + update_mappings_with_node!( + newnode_given_old, + newmach_given_old, + newmodel_given_old, + N + ) + if N in operation_nodes # could be `Source` + newoperation_node_given_old[N] = newnode_given_old[N] + elseif N in report_nodes + newreport_node_given_old[N] = newnode_given_old[N] + end + end + newoperation_nodes = Tuple(newoperation_node_given_old[N] for N in + operation_nodes) + newreport_nodes = Tuple(newreport_node_given_old[N] for N in + report_nodes) + report_tuple = + NamedTuple{keys(_report_part(signature))}(newreport_nodes) + operation_tuple = + NamedTuple{keys(_operation_part(signature))}(newoperation_nodes) + newsignature = if isempty(report_tuple) + operation_tuple + else + merge(operation_tuple, (report=report_tuple,)) + end + + return newsignature +end """ replace(mach, a1=>b1, a2=>b2, ...; empty_unspecified_sources=false) @@ -439,13 +523,7 @@ function Base.replace(mach::Machine{<:Surrogate}, W = glb(operation_nodes..., report_nodes...) - # Note: We construct nodes of the new network as values of a - # dictionary keyed on the nodes of the old network. Additionally, - # there are dictionaries of models keyed on old models and - # machines keyed on old machines. The node and machine - # dictionaries must be built simultaneously. - - # build model dict: + # Instantiate model dictionary: model_pairs = filter(collect(pairs)) do pair first(pair) isa Model end @@ -453,7 +531,6 @@ function Base.replace(mach::Machine{<:Surrogate}, models_to_copy = setdiff(models_, first.(model_pairs)) model_copy_pairs = [model=>deepcopy(model) for model in models_to_copy] newmodel_given_old = IdDict(vcat(model_pairs, model_copy_pairs)) - # build complete source replacement pairs: sources_ = sources(W) specified_source_pairs = filter(collect(pairs)) do pair @@ -476,57 +553,12 @@ function Base.replace(mach::Machine{<:Surrogate}, end all_source_pairs = vcat(specified_source_pairs, unspecified_source_pairs) - - nodes_ = nodes(W) - - # instantiate node and machine dictionaries: newnode_given_old = IdDict{AbstractNode,AbstractNode}(all_source_pairs) - newsources = [newnode_given_old[s] for s in sources_] - newoperation_node_given_old = - IdDict{AbstractNode,AbstractNode}() - newreport_node_given_old = - IdDict{AbstractNode,AbstractNode}() - newmach_given_old = IdDict{Machine,Machine}() + newsources = [newnode_given_old[s] for s in sources(W)] - # build the new network: - for N in nodes_ - if N isa Node # ie, not a `Source` - args = [newnode_given_old[arg] for arg in N.args] - if N.machine === nothing - newnode_given_old[N] = node(N.operation, args...) - else - if N.machine in keys(newmach_given_old) - m = newmach_given_old[N.machine] - else - train_args = [newnode_given_old[arg] for arg in N.machine.args] - m = Machine(newmodel_given_old[N.machine.model], - train_args...) - newmach_given_old[N.machine] = m - end - newnode_given_old[N] = N.operation(m, args...) - end - end - if N in operation_nodes # could be `Source` - newoperation_node_given_old[N] = newnode_given_old[N] - elseif N in report_nodes - newreport_node_given_old[N] = newnode_given_old[N] - end - end - newoperation_nodes = Tuple(newoperation_node_given_old[N] for N in - operation_nodes) - newreport_nodes = Tuple(newreport_node_given_old[N] for N in - report_nodes) - report_tuple = - NamedTuple{keys(_report_part(signature))}(newreport_nodes) - operation_tuple = - NamedTuple{keys(_operation_part(signature))}(newoperation_nodes) - - newsignature = if isempty(report_tuple) - operation_tuple - else - merge(operation_tuple, (report=report_tuple,)) - end + newsignature = copysignature(signature, newnode_given_old, newmodel_given_old=newmodel_given_old) + return machine(mach.model, newsources...; newsignature...) @@ -548,71 +580,17 @@ in it needs to be called `serializable` upon. Ideally this method should "reuse" as much as possible `Base.replace`. """ function save(model::Composite, fitresult) - # THIS IS WIP: NOT WORKING signature = MLJBase.signature(fitresult) - operation_nodes = values(MLJBase._operation_part(signature)) report_nodes = values(MLJBase._report_part(signature)) - W = glb(operation_nodes..., report_nodes...) - - nodes_ = filter(x -> !(x isa Source), nodes(W)) - - # instantiate node dictionary with source nodes and exception nodes - # This supposes that exception nodes only occur in the signature otherwise we need - # to to this differently newnode_given_old = IdDict{AbstractNode,AbstractNode}([old => source() for old in sources(W)]) - # Other useful mappings - newoperation_node_given_old = - IdDict{AbstractNode,AbstractNode}() - newreport_node_given_old = - IdDict{AbstractNode,AbstractNode}() - newmach_given_old = IdDict{Machine,Machine}() - - # build the new network, nodes are nicely ordered - for N in nodes_ - # Retrieve the future node's ancestors - args = [newnode_given_old[arg] for arg in N.args] - if N.machine === nothing - newnode_given_old[N] = node(N.operation, args...) - else - # The same machine can be associated with multiple nodes - if N.machine in keys(newmach_given_old) - m = newmach_given_old[N.machine] - else - m = serializable(N.machine) - m.args = Tuple(newnode_given_old[s] for s in N.machine.args) - newmach_given_old[N.machine] = m - end - newnode_given_old[N] = N.operation(m, args...) - end - # Sort nodes according to: operation_node, report_node - if N in operation_nodes - newoperation_node_given_old[N] = newnode_given_old[N] - elseif N in report_nodes - newreport_node_given_old[N] = newnode_given_old[N] - end - end - newoperation_nodes = Tuple(newoperation_node_given_old[N] for N in - operation_nodes) - newreport_nodes = Tuple(newreport_node_given_old[N] for N in - report_nodes) - report_tuple = - NamedTuple{keys(MLJBase._report_part(signature))}(newreport_nodes) - operation_tuple = - NamedTuple{keys(MLJBase._operation_part(signature))}(newoperation_nodes) - - newsignature = if isempty(report_tuple) - operation_tuple - else - merge(operation_tuple, (report=report_tuple,)) - end - + newsignature = copysignature(signature, newnode_given_old; newmodel_given_old=nothing) newfitresult = MLJBase.CompositeFitresult(newsignature) - setfield!(newfitresult, :report_additions, report_tuple) + setfield!(newfitresult, :report_additions, getfield(fitresult, :report_additions)) return newfitresult end diff --git a/test/composition/learning_networks/machines.jl b/test/composition/learning_networks/machines.jl index 895d4a25..3ea6de2a 100644 --- a/test/composition/learning_networks/machines.jl +++ b/test/composition/learning_networks/machines.jl @@ -285,6 +285,7 @@ end @test smach.report.cv_report === mach.report.cv_report @test smach.fitresult isa MLJBase.CompositeFitresult + @test getfield(smach.fitresult, :report_additions) === getfield(mach.fitresult, :report_additions) Serialization.serialize(filename, smach) smach = Serialization.deserialize(filename) From a8c260c35418d790bbbe22e874dcba1ca4d8e746 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 28 Jan 2022 14:11:20 +0000 Subject: [PATCH 08/18] add docstrings --- src/composition/learning_networks/machines.jl | 36 ++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/src/composition/learning_networks/machines.jl b/src/composition/learning_networks/machines.jl index 91ef5dcf..737b8ccd 100644 --- a/src/composition/learning_networks/machines.jl +++ b/src/composition/learning_networks/machines.jl @@ -414,6 +414,7 @@ end network_model_names(model::Nothing, mach::Machine{<:Surrogate}) = nothing +## DUPLICATING/REPLACING PARTS OF A LEARNING NETWORK MACHINE function copy_or_replace_machine(N::AbstractNode, newmodel_given_old, newnode_given_old) train_args = [newnode_given_old[arg] for arg in N.machine.args] @@ -421,13 +422,31 @@ function copy_or_replace_machine(N::AbstractNode, newmodel_given_old, newnode_gi train_args...) end +""" + copy_or_replace_machine(N::AbstractNode, newmodel_given_old::Nothing, newnode_given_old) + +For now, two top functions will lead to a call of this function: `Base.replace` and +`save`. If `save` is the calling function, the argument `newmodel_given_old` will be nothing +and the goal is to make the machine in the current learning network serializable. +This method will be called. If `Base.replace` is the calling function, then `newmodel_given_old` +will be defined and the other method called, a new Machine will be built with training data. +""" function copy_or_replace_machine(N::AbstractNode, newmodel_given_old::Nothing, newnode_given_old) m = serializable(N.machine) m.args = Tuple(newnode_given_old[s] for s in N.machine.args) return m end -## DUPLICATING AND REPLACING PARTS OF A LEARNING NETWORK MACHINE +""" + update_mappings_with_node!( + newnode_given_old, + newmach_given_old, + newmodel_given_old, + N::AbstractNode) + +For Nodes that are not sources, update the appropriate mappings +between elements of the learning to be copied and the copy itself. +""" function update_mappings_with_node!( newnode_given_old, newmach_given_old, @@ -453,6 +472,21 @@ update_mappings_with_node!( newmodel_given_old, N::Source) = nothing +""" + copysignature(signature, newnode_given_old; newmodel_given_old=nothing) + +Copies the given signature of a learning network. + +# Arguments: +- `signature`: signature of the learning network to be copied +- `newnode_given_old`: initialized mapping between nodes of the +learning network to be copied and the new one. At this stage it should +contain only source nodes. +- `newmodel_given_old`: initialized mapping between models of the +learning network to be copied and the new one. This is `nothing` if `save` was +the calling function which will result in a different behaviour of +`update_mappings_with_node!` +""" function copysignature(signature, newnode_given_old; newmodel_given_old=nothing) operation_nodes = values(_operation_part(signature)) report_nodes = values(_report_part(signature)) From 5fde0b1219d757e48c410f77e972c68b97392d96 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 31 Jan 2022 09:43:15 +0000 Subject: [PATCH 09/18] remove simple_data --- src/composition/learning_networks/machines.jl | 21 ++++++++++++------- .../composition/learning_networks/machines.jl | 6 +++--- test/machines.jl | 6 +++--- test/test_utilities.jl | 2 -- 4 files changed, 19 insertions(+), 16 deletions(-) diff --git a/src/composition/learning_networks/machines.jl b/src/composition/learning_networks/machines.jl index 737b8ccd..89c3f150 100644 --- a/src/composition/learning_networks/machines.jl +++ b/src/composition/learning_networks/machines.jl @@ -416,6 +416,13 @@ network_model_names(model::Nothing, mach::Machine{<:Surrogate}) = ## DUPLICATING/REPLACING PARTS OF A LEARNING NETWORK MACHINE +""" + copy_or_replace_machine(N::AbstractNode, newmodel_given_old, newnode_given_old) + +For now, two top functions will lead to a call of this function: `Base.replace(::Machine, ...)` and +`save(::Machine, ...)`. A call from `Base.replace` with given `newmodel_given_old` will dispatch to this method. +A new Machine is built with training data from node N. +""" function copy_or_replace_machine(N::AbstractNode, newmodel_given_old, newnode_given_old) train_args = [newnode_given_old[arg] for arg in N.machine.args] return Machine(newmodel_given_old[N.machine.model], @@ -425,11 +432,11 @@ end """ copy_or_replace_machine(N::AbstractNode, newmodel_given_old::Nothing, newnode_given_old) -For now, two top functions will lead to a call of this function: `Base.replace` and -`save`. If `save` is the calling function, the argument `newmodel_given_old` will be nothing -and the goal is to make the machine in the current learning network serializable. -This method will be called. If `Base.replace` is the calling function, then `newmodel_given_old` -will be defined and the other method called, a new Machine will be built with training data. +For now, two top functions will lead to a call of this function: `Base.replace(::Machine, ...)` and +`save(::Machine, ...)`. A call from `save` will set `newmodel_given_old` to `nothing` which will +then dispatch to this method. +In this circumstance, the purpose is to make the machine attached to node N serializable (see `serializable(::Machine)`). + """ function copy_or_replace_machine(N::AbstractNode, newmodel_given_old::Nothing, newnode_given_old) m = serializable(N.machine) @@ -445,7 +452,7 @@ end N::AbstractNode) For Nodes that are not sources, update the appropriate mappings -between elements of the learning to be copied and the copy itself. +between elements of the learning networks to be copied and the copy itself. """ function update_mappings_with_node!( newnode_given_old, @@ -610,8 +617,6 @@ end Returns a new `CompositeFitresult` that is a shallow copy of the original one. To do so, we build a copy of the learning network where each machine contained in it needs to be called `serializable` upon. - -Ideally this method should "reuse" as much as possible `Base.replace`. """ function save(model::Composite, fitresult) signature = MLJBase.signature(fitresult) diff --git a/test/composition/learning_networks/machines.jl b/test/composition/learning_networks/machines.jl index 3ea6de2a..e0804fe9 100644 --- a/test/composition/learning_networks/machines.jl +++ b/test/composition/learning_networks/machines.jl @@ -233,7 +233,7 @@ end @testset "Test serializable of pipeline" begin filename = "pipe_mach.jls" - X, y = TestUtilities.simpledata() + X, y = make_regression(100, 1) pipe = (X -> coerce(X, :x₁=>Continuous)) |> DecisionTreeRegressor() mach = machine(pipe, X, y) fit!(mach, verbosity=0) @@ -261,7 +261,7 @@ end @testset "Test serializable of composite machines" begin filename = "stack_mach.jls" - X, y = TestUtilities.simpledata() + X, y = make_regression(100, 1) model = Stack( metalearner = DecisionTreeRegressor(), tree1 = DecisionTreeRegressor(min_samples_split=3), @@ -307,7 +307,7 @@ end @testset "Test serializable of nested composite machines" begin filename = "nested_stack_mach.jls" - X, y = TestUtilities.simpledata() + X, y = make_regression(100, 1) pipe = (X -> coerce(X, :x₁=>Continuous)) |> DecisionTreeRegressor() model = Stack( diff --git a/test/machines.jl b/test/machines.jl index 0f26c8a4..5b58afc2 100644 --- a/test/machines.jl +++ b/test/machines.jl @@ -263,7 +263,7 @@ end end @testset "Test serializable method of Supervised Machine" begin - X, y = TestUtilities.simpledata() + X, y = make_regression(100, 1) filename = "decisiontree.jls" mach = machine(DecisionTreeRegressor(), X, y) fit!(mach, verbosity=0) @@ -297,7 +297,7 @@ end end @testset "Test serializable method of Unsupervised Machine" begin - X, _ = TestUtilities.simpledata() + X, _ = make_regression(100, 1) filename = "standardizer.jls" mach = machine(Standardizer(), X) fit!(mach, verbosity=0) @@ -311,7 +311,7 @@ end end @testset "Test Misc functions used in `serializable`" begin - X, y = TestUtilities.simpledata() + X, y = make_regression(100, 1) mach = machine(DeterministicConstantRegressor(), X, y) fit!(mach, verbosity=0) # setreport! default diff --git a/test/test_utilities.jl b/test/test_utilities.jl index ab35616c..b56df8a6 100644 --- a/test/test_utilities.jl +++ b/test/test_utilities.jl @@ -125,7 +125,5 @@ function generic_tests(mach₁, mach₂) end end -simpledata(;n=100) = (x₁=rand(n),), rand(n) - end From dbca31c1f3bb44b92b01ceb8b2ed1605244ba99a Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 31 Jan 2022 10:11:08 +0000 Subject: [PATCH 10/18] update some docstrings --- src/composition/learning_networks/machines.jl | 12 ++-- src/machines.jl | 62 +++++++++++++++---- .../composition/learning_networks/machines.jl | 9 +-- 3 files changed, 60 insertions(+), 23 deletions(-) diff --git a/src/composition/learning_networks/machines.jl b/src/composition/learning_networks/machines.jl index 89c3f150..389d44b8 100644 --- a/src/composition/learning_networks/machines.jl +++ b/src/composition/learning_networks/machines.jl @@ -480,9 +480,11 @@ update_mappings_with_node!( N::Source) = nothing """ - copysignature(signature, newnode_given_old; newmodel_given_old=nothing) + copysignature!(signature, newnode_given_old; newmodel_given_old=nothing) -Copies the given signature of a learning network. +Copies the given signature of a learning network. Contrary to Julia's convention, +this method is actually mutating `newnode_given_old`` and `newmodel_given_old`` and not +the first `signature` argument. # Arguments: - `signature`: signature of the learning network to be copied @@ -494,7 +496,7 @@ learning network to be copied and the new one. This is `nothing` if `save` was the calling function which will result in a different behaviour of `update_mappings_with_node!` """ -function copysignature(signature, newnode_given_old; newmodel_given_old=nothing) +function copysignature!(signature, newnode_given_old; newmodel_given_old=nothing) operation_nodes = values(_operation_part(signature)) report_nodes = values(_report_part(signature)) W = glb(operation_nodes..., report_nodes...) @@ -598,7 +600,7 @@ function Base.replace(mach::Machine{<:Surrogate}, IdDict{AbstractNode,AbstractNode}(all_source_pairs) newsources = [newnode_given_old[s] for s in sources(W)] - newsignature = copysignature(signature, newnode_given_old, newmodel_given_old=newmodel_given_old) + newsignature = copysignature!(signature, newnode_given_old, newmodel_given_old=newmodel_given_old) return machine(mach.model, newsources...; newsignature...) @@ -626,7 +628,7 @@ function save(model::Composite, fitresult) newnode_given_old = IdDict{AbstractNode,AbstractNode}([old => source() for old in sources(W)]) - newsignature = copysignature(signature, newnode_given_old; newmodel_given_old=nothing) + newsignature = copysignature!(signature, newnode_given_old; newmodel_given_old=nothing) newfitresult = MLJBase.CompositeFitresult(newsignature) setfield!(newfitresult, :report_additions, getfield(fitresult, :report_additions)) diff --git a/src/machines.jl b/src/machines.jl index 3fc04eb0..8ebbbaff 100644 --- a/src/machines.jl +++ b/src/machines.jl @@ -378,6 +378,12 @@ function machine(model::Model, arg1::AbstractNode, args::AbstractNode...; return Machine(model, arg1, args...; kwargs...) end + +warn_bad_deserialization(state) = + "Deserialized machine state is not -1 (got $state). "* + "This means that the machine has not been saved by a conventional MLJ routine.\n" + "For example, it's possible original training data is accessible from the deserialised object. " + """ machine(file::Union{String, IO}, raw_arg1=nothing, raw_args...) @@ -387,8 +393,7 @@ Serialization module. function machine(file::Union{String, IO}, raw_arg1=nothing, raw_args...) smach = deserialize(file) smach.state == -1 || - @warn "Deserialized machine state is not -1 (=$(smach.state)). "* - "It means that the machine has not been saved by a conventional MLJ routine." + @warn warn_bad_deserialization(smach.state) restore!(smach) if raw_arg1 !== nothing args = source.((raw_arg1, raw_args...)) @@ -825,12 +830,43 @@ end """ serializable(mach::Machine) -Returns a shallow copy of the machine to make it serializable, in particular: - - Removes all data from caches, args and data fields - - Makes all `fitresults` serializable - - Annotates the state as -1 +Returns a shallow copy of the machine to make it serializable. In particular, +all training data is removed and, if necessary, learned parameters are replaced +with persistent representations. + +Any general purpose Julia serialization may be applied to the output of +serializable (eg, JLSO, BSON, JLD) but you must call restore!(mach) on +the deserialised object mach before using it. See the example below. + +If using Julia's standard Serialization library, a shorter workflow is +available using the [`save`](@ref) method. + +A machine returned by serializable is characterized by the property mach.state == -1. + +### Example using [JLSO](https://invenia.github.io/JLSO.jl/stable/) + + using MLJ + using JLSO + tree = @load DecisionTreeClassifier + X, y = @load_iris + mach = fit!(machine(tree, X, y)) + + # This machine can now be serialized + smach = serializable(mach) + JLSO.save("machine.jlso", machine => smach) + + # Some fitresults may have to be restored + loaded_mach = JLSO.load("machine.jlso")[:machine] + restore!(loaded_mach) + + predict(loaded_mach, X) + predict(mach, X) """ function serializable(mach::Machine{<:Any, C}) where C + # Returns a shallow copy of the machine to make it serializable, in particular: + # - Removes all data from caches, args and data fields + # - Makes all `fitresults` serializable + # - Annotates the state as -1 copymach = machine(mach.model, mach.args..., cache=C) for fieldname in fieldnames(Machine) @@ -880,7 +916,9 @@ end Serialize the machine `mach` to a file with path `filename`, or to an input/output stream `io` (at least `IOBuffer` instances are -supported) using the Serialization module. +supported) using the Serialization module. + +To serialise using a different format, see `serializable`. Machines are de-serialized using the `machine` constructor as shown in the example below. Data (or nodes) may be optionally passed to the @@ -894,11 +932,11 @@ constructor for retraining on new data using the saved model. X, y = @load_iris mach = fit!(machine(tree, X, y)) - MLJ.save("tree.jlso", mach) - mach_predict_only = machine("tree.jlso") + MLJ.save("tree.jls", mach) + mach_predict_only = machine("tree.jls") predict(mach_predict_only, X) - mach2 = machine("tree.jlso", selectrows(X, 1:100), y[1:100]) + mach2 = machine("tree.jls", selectrows(X, 1:100), y[1:100]) predict(mach2, X) # same as above fit!(mach2) # saved learned parameters are over-written @@ -912,10 +950,10 @@ constructor for retraining on new data using the saved model. predict(predict_only_mach, X) !!! warning "Only load files from trusted sources" - Maliciously constructed JLSO files, like pickles, and most other + Maliciously constructed JLS files, like pickles, and most other general purpose serialization formats, can allow for arbitrary code execution during loading. This means it is possible for someone - to use a JLSO file that looks like a serialized MLJ machine as a + to use a JLS file that looks like a serialized MLJ machine as a [Trojan horse](https://en.wikipedia.org/wiki/Trojan_horse_(computing)). diff --git a/test/composition/learning_networks/machines.jl b/test/composition/learning_networks/machines.jl index e0804fe9..da607fb1 100644 --- a/test/composition/learning_networks/machines.jl +++ b/test/composition/learning_networks/machines.jl @@ -344,7 +344,7 @@ end filesizes = [] for n in [100, 500, 1000] filename = "serialized_temp_$n.jls" - X, y = TestUtilities.simpledata(n=n) + X, y = make_regression(n, 1) mach = machine(model, X, y) fit!(mach, verbosity=0) MLJBase.save(filename, mach) @@ -354,16 +354,13 @@ end @test all(x==filesizes[1] for x in filesizes) # What if no serializable procedure had happened filename = "full_of_data.jls" - X, y = TestUtilities.simpledata(n=1000) + X, y = make_regression(1000, 1) mach = machine(model, X, y) fit!(mach, verbosity=0) serialize(filename, mach) @test filesize(filename) > filesizes[1] - @test_logs (:warn, "Deserialized machine state is"* - " not -1 (=1). It means that the"* - " machine has not been saved by a"* - " conventional MLJ routine.") machine(filename) + @test_logs (:warn, MLJBase.warn_bad_deserialization(mach.state)) machine(filename) rm(filename) end From 8270e33b8f94e63767cdf1fa436f83ccf5a61e82 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 2 Feb 2022 15:21:12 +0000 Subject: [PATCH 11/18] move network_model_names to CompositeFitresult and remove report_additions field from it --- src/composition/learning_networks/machines.jl | 23 +++---- src/composition/models/inspection.jl | 4 +- src/composition/models/methods.jl | 11 ++-- .../composition/learning_networks/machines.jl | 4 +- test/composition/models/methods.jl | 64 +++++++++---------- 5 files changed, 49 insertions(+), 57 deletions(-) diff --git a/src/composition/learning_networks/machines.jl b/src/composition/learning_networks/machines.jl index 389d44b8..f66775f2 100644 --- a/src/composition/learning_networks/machines.jl +++ b/src/composition/learning_networks/machines.jl @@ -101,18 +101,14 @@ end mutable struct CompositeFitresult signature glb - report_additions + network_model_names function CompositeFitresult(signature) - glb = MLJBase.glb(_nodes(signature)...) - new(signature, glb) + signature_node = glb(_nodes(signature)...) + new(signature, signature_node) end end signature(c::CompositeFitresult) = getfield(c, :signature) glb(c::CompositeFitresult) = getfield(c, :glb) -report_additions(c::CompositeFitresult) = getfield(c, :report_additions) - -update!(c::CompositeFitresult) = - setfield!(c, :report_additions, _call(_report_part(signature(c)))) # To accommodate pre-existing design (operations.jl) arrange # that `fitresult.predict` returns the predict node, etc: @@ -251,9 +247,9 @@ See also [`machine`](@ref) function fit!(mach::Machine{<:Surrogate}; kwargs...) glb_node = glb(mach) fit!(glb_node; kwargs...) - update!(mach.fitresult) # updates `report_additions` mach.state += 1 - mach.report = merge(report(glb_node), report_additions(mach.fitresult)) + report_additions_ = _call(_report_part(signature(mach.fitresult))) + mach.report = merge(report(glb_node), report_additions_) return mach end @@ -390,8 +386,6 @@ function return!(mach::Machine{<:Surrogate}, model::Union{Model,Nothing}, verbosity) - _network_model_names = network_model_names(model, mach) - verbosity isa Nothing || fit!(mach, verbosity=verbosity) # anonymize the data @@ -404,9 +398,12 @@ function return!(mach::Machine{<:Surrogate}, cache = (sources = sources, data=data, - network_model_names=_network_model_names, old_model=old_model) + setfield!(mach.fitresult, + :network_model_names, + network_model_names(model, mach)) + return mach.fitresult, cache, mach.report end @@ -631,7 +628,7 @@ function save(model::Composite, fitresult) newsignature = copysignature!(signature, newnode_given_old; newmodel_given_old=nothing) newfitresult = MLJBase.CompositeFitresult(newsignature) - setfield!(newfitresult, :report_additions, getfield(fitresult, :report_additions)) + setfield!(newfitresult, :network_model_names, getfield(fitresult, :network_model_names)) return newfitresult end diff --git a/src/composition/models/inspection.jl b/src/composition/models/inspection.jl index 7fce1486..b4950629 100644 --- a/src/composition/models/inspection.jl +++ b/src/composition/models/inspection.jl @@ -3,7 +3,7 @@ try_scalarize(v) = length(v) == 1 ? v[1] : v function machines_given_model_name(mach::Machine{M}) where M<:Composite - network_model_names = mach.cache.network_model_names + network_model_names = getfield(mach.fitresult, :network_model_names) names = unique(filter(name->!(name === nothing), network_model_names)) network_models = MLJBase.models(glb(mach)) network_machines = MLJBase.machines(glb(mach)) @@ -27,14 +27,12 @@ function tuple_keyed_on_model_names(item_given_machine, mach) end function report(mach::Machine{<:Composite}) - machines = mach.report.machines dict = mach.report.report_given_machine return merge(tuple_keyed_on_model_names(dict, mach), mach.report) end function fitted_params(mach::Machine{<:Composite}) fp = fitted_params(mach.model, mach.fitresult) - _machines = fp.machines dict = fp.fitted_params_given_machine return merge(MLJBase.tuple_keyed_on_model_names(dict, mach), fp) end diff --git a/src/composition/models/methods.jl b/src/composition/models/methods.jl index 94f1c4be..0544523f 100644 --- a/src/composition/models/methods.jl +++ b/src/composition/models/methods.jl @@ -31,7 +31,7 @@ function update(model::M, # underlying learning network machine. For this it is necessary to # temporarily "de-anonymize" the source nodes. - network_model_names = cache.network_model_names + network_model_names = getfield(fitresult, :network_model_names) old_model = cache.old_model glb_node = glb(fitresult) # greatest lower bound @@ -47,7 +47,8 @@ function update(model::M, end fit!(glb_node; verbosity=verbosity) - update!(fitresult) # updates report_additions + # Retrieve additional report values + report_additions_ = _call(_report_part(signature(fitresult))) # anonymize data again: for s in sources @@ -57,13 +58,11 @@ function update(model::M, # record current model state: cache = (sources=cache.sources, data=cache.data, - network_model_names=cache.network_model_names, old_model = deepcopy(model)) - + return (fitresult, cache, - merge(report(glb_node), - report_additions(fitresult))) + merge(report(glb_node), report_additions_)) end diff --git a/test/composition/learning_networks/machines.jl b/test/composition/learning_networks/machines.jl index da607fb1..173a7796 100644 --- a/test/composition/learning_networks/machines.jl +++ b/test/composition/learning_networks/machines.jl @@ -185,8 +185,7 @@ enode = @node mae(ys, yhat) fit!(mach, verbosity=0) fit!(mach2, verbosity=0) @test predict(mach, X) ≈ predict(mach2, X) - @test MLJBase.report_additions(mach.fitresult).mae ≈ - MLJBase.report_additions(mach2.fitresult).mae + @test report(mach).mae ≈ report(mach2).mae @test mach2.args[1]() == Xs() @test mach2.args[2]() == ys() @@ -285,7 +284,6 @@ end @test smach.report.cv_report === mach.report.cv_report @test smach.fitresult isa MLJBase.CompositeFitresult - @test getfield(smach.fitresult, :report_additions) === getfield(mach.fitresult, :report_additions) Serialization.serialize(filename, smach) smach = Serialization.deserialize(filename) diff --git a/test/composition/models/methods.jl b/test/composition/models/methods.jl index 41bf798f..3959c618 100644 --- a/test/composition/models/methods.jl +++ b/test/composition/models/methods.jl @@ -30,9 +30,9 @@ X, y = make_regression(10, 2) yhat = predict(mach1, W) mach = machine(Deterministic(), Xs, ys; predict=yhat) fitresult, cache, _ = return!(mach, model, 0) - @test cache.network_model_names == [:model_in_network, nothing] + network_model_names = getfield(fitresult, :network_model_names) + @test network_model_names == [:model_in_network, nothing] old_model = cache.old_model - network_model_names = cache.network_model_names glb_node = MLJBase.glb(mach) @test !MLJBase.fallback(model, old_model, network_model_names, glb_node) @@ -191,41 +191,41 @@ end # julia bug? If I return the following test to a @testset block, then # the test marked with ******* fails (bizarre!) #@testset "second test of hand-exported network" begin - function MLJBase.fit(model::WrappedRidge, verbosity::Integer, X, y) - Xs = source(X) - ys = source(y) - - stand = Standardizer() - standM = machine(stand, Xs) - W = transform(standM, Xs) +function MLJBase.fit(model::WrappedRidge, verbosity::Integer, X, y) + Xs = source(X) + ys = source(y) - boxcox = UnivariateBoxCoxTransformer() - boxcoxM = machine(boxcox, ys) - z = transform(boxcoxM, ys) + stand = Standardizer() + standM = machine(stand, Xs) + W = transform(standM, Xs) - ridgeM = machine(model.ridge, W, z) - zhat = predict(ridgeM, W) - yhat = inverse_transform(boxcoxM, zhat) + boxcox = UnivariateBoxCoxTransformer() + boxcoxM = machine(boxcox, ys) + z = transform(boxcoxM, ys) - mach = machine(Deterministic(), Xs, ys; predict=yhat) - return!(mach, model, verbosity) - end + ridgeM = machine(model.ridge, W, z) + zhat = predict(ridgeM, W) + yhat = inverse_transform(boxcoxM, zhat) - MLJBase.input_scitype(::Type{<:WrappedRidge}) = - Table(Continuous) - MLJBase.target_scitype(::Type{<:WrappedRidge}) = - AbstractVector{<:Continuous} + mach = machine(Deterministic(), Xs, ys; predict=yhat) + return!(mach, model, verbosity) +end - ridge = FooBarRegressor(lambda=0.1) - model_ = WrappedRidge(ridge) - mach = machine(model_, Xin, yin) - id = objectid(mach) - fit!(mach, verbosity=0) - @test objectid(mach) == id # ********* - yhat=predict(mach, Xin); - ridge.lambda = 1.0 - fit!(mach, verbosity=0) - @test predict(mach, Xin) != yhat +MLJBase.input_scitype(::Type{<:WrappedRidge}) = + Table(Continuous) +MLJBase.target_scitype(::Type{<:WrappedRidge}) = + AbstractVector{<:Continuous} + +ridge = FooBarRegressor(lambda=0.1) +model_ = WrappedRidge(ridge) +mach = machine(model_, Xin, yin) +id = objectid(mach) +fit!(mach, verbosity=0) +@test objectid(mach) == id # ********* +yhat=predict(mach, Xin); +ridge.lambda = 1.0 +fit!(mach, verbosity=0) +@test predict(mach, Xin) != yhat #end From f5429ab57fc313d64367f15211215ca8ebec9d69 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 2 Feb 2022 16:21:53 +0000 Subject: [PATCH 12/18] update backquotes and old syntax --- src/composition/learning_networks/machines.jl | 2 +- src/machines.jl | 15 ++++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/composition/learning_networks/machines.jl b/src/composition/learning_networks/machines.jl index f66775f2..a30acd49 100644 --- a/src/composition/learning_networks/machines.jl +++ b/src/composition/learning_networks/machines.jl @@ -480,7 +480,7 @@ update_mappings_with_node!( copysignature!(signature, newnode_given_old; newmodel_given_old=nothing) Copies the given signature of a learning network. Contrary to Julia's convention, -this method is actually mutating `newnode_given_old`` and `newmodel_given_old`` and not +this method is actually mutating `newnode_given_old` and `newmodel_given_old` and not the first `signature` argument. # Arguments: diff --git a/src/machines.jl b/src/machines.jl index 8ebbbaff..8e8c8a57 100644 --- a/src/machines.jl +++ b/src/machines.jl @@ -835,19 +835,20 @@ all training data is removed and, if necessary, learned parameters are replaced with persistent representations. Any general purpose Julia serialization may be applied to the output of -serializable (eg, JLSO, BSON, JLD) but you must call restore!(mach) on -the deserialised object mach before using it. See the example below. +`serializable` (eg, JLSO, BSON, JLD) but you must call `restore!(mach)` on +the deserialised object `mach` before using it. See the example below. If using Julia's standard Serialization library, a shorter workflow is available using the [`save`](@ref) method. -A machine returned by serializable is characterized by the property mach.state == -1. +A machine returned by serializable is characterized by the property `mach.state == -1`. ### Example using [JLSO](https://invenia.github.io/JLSO.jl/stable/) using MLJ using JLSO - tree = @load DecisionTreeClassifier + Tree = @load DecisionTreeClassifier + tree = Tree() X, y = @load_iris mach = fit!(machine(tree, X, y)) @@ -855,7 +856,7 @@ A machine returned by serializable is characterized by the property mach.state = smach = serializable(mach) JLSO.save("machine.jlso", machine => smach) - # Some fitresults may have to be restored + # Deserialize and restore learned parameters to useable form: loaded_mach = JLSO.load("machine.jlso")[:machine] restore!(loaded_mach) @@ -928,9 +929,9 @@ constructor for retraining on new data using the saved model. ### Example using MLJ - tree = @load DecisionTreeClassifier + Tree = @load DecisionTreeClassifier X, y = @load_iris - mach = fit!(machine(tree, X, y)) + mach = fit!(machine(Tree(), X, y)) MLJ.save("tree.jls", mach) mach_predict_only = machine("tree.jls") From d48871754fad5c9f012186f1e206ac844d7f9dee Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 2 Feb 2022 16:37:49 +0000 Subject: [PATCH 13/18] switching to a undefined cache for serializable machines --- src/machines.jl | 41 +------------------ .../composition/learning_networks/machines.jl | 4 +- test/machines.jl | 22 ---------- test/test_utilities.jl | 4 +- 4 files changed, 3 insertions(+), 68 deletions(-) diff --git a/src/machines.jl b/src/machines.jl index 8e8c8a57..98cf4055 100644 --- a/src/machines.jl +++ b/src/machines.jl @@ -871,18 +871,12 @@ function serializable(mach::Machine{<:Any, C}) where C copymach = machine(mach.model, mach.args..., cache=C) for fieldname in fieldnames(Machine) - if fieldname ∈ (:model, :report) + if fieldname ∈ (:model, :report, :cache, :data, :resampled_data, :old_rows) continue elseif fieldname == :state setfield!(copymach, :state, -1) - # Wipe data from cache - elseif fieldname == :cache - setfield!(copymach, :cache, serializable_cache(mach.cache)) elseif fieldname == :args setfield!(copymach, fieldname, ()) - # Let those fields undefined - elseif fieldname ∈ (:data, :resampled_data, :old_rows) - continue # Make fitresult ready for serialization elseif fieldname == :fitresult copymach.fitresult = save(mach.model, getfield(mach, fieldname)) @@ -972,36 +966,3 @@ end setreport!(mach::Machine, report) = setfield!(mach, :report, report) - - -maybe_serializable(val) = val -maybe_serializable(val::Machine) = serializable(val) - -""" - serializable_cache(cache) - -Default fallbacks to return the original cache. -""" -serializable_cache(cache) = cache - -""" - serializable_cache(cache::Tuple) - -If the cache is a Tuple, any machine in the cache is called -`serializable` upon. This is to address TunedModels. A dispatch on -TunedModel would have been possible but would require a new api function. -""" -serializable_cache(cache::Tuple) = - Tuple(maybe_serializable(val) for val in cache) - -""" - serializable_cache(cache::NamedTuple) - -If the cache is a NamedTuple, any data field is filtered, this is to address -the current learning networks cache. Any machine in the cache is also called -`serializable` upon. -""" -function serializable_cache(cache::NamedTuple) - new_keys = filter(!=(:data), keys(cache)) - return NamedTuple{new_keys}([maybe_serializable(cache[key]) for key in new_keys]) -end \ No newline at end of file diff --git a/test/composition/learning_networks/machines.jl b/test/composition/learning_networks/machines.jl index 173a7796..a31174f8 100644 --- a/test/composition/learning_networks/machines.jl +++ b/test/composition/learning_networks/machines.jl @@ -275,9 +275,7 @@ end # Check data has been wiped out from models at the first level of composition @test length(machines(glb(smach))) == length(machines(glb(mach))) for submach in machines(glb(smach)) - @test !isdefined(submach, :data) - @test !isdefined(submach, :resampled_data) - @test submach.cache isa Nothing || :data ∉ keys(submach.cache) + TestUtilities.test_data(submach) end # Testing extra report field : it is a deepcopy diff --git a/test/machines.jl b/test/machines.jl index 5b58afc2..db517665 100644 --- a/test/machines.jl +++ b/test/machines.jl @@ -318,28 +318,6 @@ end @test mach.report isa NamedTuple MLJBase.setreport!(mach, "toto") @test mach.report == "toto" - - # serializable_cache - # The default is to return the original cache - @test MLJBase.serializable_cache(mach.cache) === mach.cache - # For Tuples and NamedTuples, a machine might live in the cache - # (as it is the case in MLJTuned model) - # and therefore has to be called `serializable` upon - # if a `data` field lives in the cache, it is removed, this is for - # learning networks. However this might become useless if the - # data anomynization process in fit! is removed - submach = machine(DeterministicConstantRegressor(), X, y) - fit!(submach, verbosity=0) - # Tuple type in cache - mach.cache = (submach,) - newcache = MLJBase.serializable_cache(mach.cache) - @test newcache[1] isa Machine - @test !isdefined(newcache[1], :data) - # NamedTuple type in cache - mach.cache = (machine=submach, data=(X,y)) - newcache = MLJBase.serializable_cache(mach.cache) - @test :data ∉ keys(newcache) - @test !isdefined(newcache[1], :data) end diff --git a/test/test_utilities.jl b/test/test_utilities.jl index b56df8a6..eef0f8a8 100644 --- a/test/test_utilities.jl +++ b/test/test_utilities.jl @@ -111,9 +111,7 @@ function test_data(mach) @test !isdefined(mach, :old_rows) @test !isdefined(mach, :data) @test !isdefined(mach, :resampled_data) - if mach isa NamedTuple - @test :data ∉ keys(mach.cache) - end + @test !isdefined(mach, :cache) end function generic_tests(mach₁, mach₂) From c91a81042827b852e11fe0eaa4adba24867f6b64 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 2 Feb 2022 16:41:36 +0000 Subject: [PATCH 14/18] remove passing args to machine constructor from file --- src/machines.jl | 9 ++------- test/machines.jl | 5 ----- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/src/machines.jl b/src/machines.jl index 98cf4055..ef555324 100644 --- a/src/machines.jl +++ b/src/machines.jl @@ -385,21 +385,16 @@ warn_bad_deserialization(state) = "For example, it's possible original training data is accessible from the deserialised object. " """ - machine(file::Union{String, IO}, raw_arg1=nothing, raw_args...) + machine(file::Union{String, IO}) Rebuild from a file a machine that has been serialized using the default Serialization module. """ -function machine(file::Union{String, IO}, raw_arg1=nothing, raw_args...) +function machine(file::Union{String, IO}) smach = deserialize(file) smach.state == -1 || @warn warn_bad_deserialization(smach.state) restore!(smach) - if raw_arg1 !== nothing - args = source.((raw_arg1, raw_args...)) - MLJBase.check(smach.model, args...; full=true) - smach.args = args - end return smach end diff --git a/test/machines.jl b/test/machines.jl index db517665..721e0ef5 100644 --- a/test/machines.jl +++ b/test/machines.jl @@ -288,11 +288,6 @@ end smach = machine(filename) @test predict(smach, X) == predict(mach, X) - # Try to reset the data - smach = machine(filename, X, y) - fit!(smach, verbosity=0) - @test predict(smach) == predict(mach) - rm(filename) end From 6d83802484df56b94e9626ffbbbfa4f2532ae493 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Thu, 3 Feb 2022 09:03:30 +0000 Subject: [PATCH 15/18] add restore! now sets state to 1 --- src/composition/learning_networks/machines.jl | 1 + src/machines.jl | 3 ++- test/test_utilities.jl | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/composition/learning_networks/machines.jl b/src/composition/learning_networks/machines.jl index a30acd49..583eaf83 100644 --- a/src/composition/learning_networks/machines.jl +++ b/src/composition/learning_networks/machines.jl @@ -644,6 +644,7 @@ function restore!(mach::Machine{<:Composite}) for submach in machines(glb_node) restore!(submach) end + mach.state = 1 return mach end diff --git a/src/machines.jl b/src/machines.jl index ef555324..5b6ecb44 100644 --- a/src/machines.jl +++ b/src/machines.jl @@ -889,10 +889,11 @@ end restore!(mach::Machine) Default method to restores the state of a machine that is currently serializable. -Such a machine is annotated with `state=-1` +Such a machine is annotated with `state=1` """ function restore!(mach::Machine) mach.fitresult = restore(mach.model, mach.fitresult) + mach.state = 1 return mach end diff --git a/test/test_utilities.jl b/test/test_utilities.jl index eef0f8a8..37c0df60 100644 --- a/test/test_utilities.jl +++ b/test/test_utilities.jl @@ -117,7 +117,7 @@ end function generic_tests(mach₁, mach₂) test_args(mach₂) test_data(mach₂) - @test mach₂.state == -1 + @test mach₂.state == 1 for field in (:frozen, :model, :old_model, :old_upstream_state, :fit_okay) @test getfield(mach₁, field) == getfield(mach₂, field) end From 03a74e039073fd7185161acd5f8464b4c113d815 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Thu, 3 Feb 2022 09:27:38 +0000 Subject: [PATCH 16/18] add warning deserialized for operations with args --- src/operations.jl | 5 +++++ test/composition/learning_networks/machines.jl | 1 - test/machines.jl | 10 ++++++++++ test/test_utilities.jl | 2 +- 4 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/operations.jl b/src/operations.jl index 85882afc..7a4c4557 100644 --- a/src/operations.jl +++ b/src/operations.jl @@ -33,6 +33,10 @@ _err_serialized(operation) = "deserialized machine with no data "* "bound to it. ")) +warn_serializable_mach(operation) = "The operation $operation has been called on a "* + "deserialised machine mach whose learned parameters "* + "may be unusable. To be sure, first run restore!(mach)." + # 0. operations on machine, given rows=...: for operation in OPERATIONS @@ -77,6 +81,7 @@ for operation in OPERATIONS # 1. operations on machines, given *concrete* data: function $operation(mach::Machine, Xraw) if mach.state != 0 + mach.state == -1 && @warn warn_serializable_mach($operation) return $(operation)(mach.model, mach.fitresult, reformat(mach.model, Xraw)...) diff --git a/test/composition/learning_networks/machines.jl b/test/composition/learning_networks/machines.jl index a31174f8..aefcfd96 100644 --- a/test/composition/learning_networks/machines.jl +++ b/test/composition/learning_networks/machines.jl @@ -240,7 +240,6 @@ end # Check serializable function smach = MLJBase.serializable(mach) TestUtilities.generic_tests(mach, smach) - @test MLJBase.predict(smach, X) == MLJBase.predict(mach, X) @test keys(fitted_params(smach)) == keys(fitted_params(mach)) @test keys(report(smach)) == keys(report(mach)) # Check data has been wiped out from models at the first level of composition diff --git a/test/machines.jl b/test/machines.jl index 721e0ef5..10d2e227 100644 --- a/test/machines.jl +++ b/test/machines.jl @@ -271,12 +271,16 @@ end smach = MLJBase.serializable(mach) @test smach.report == mach.report @test smach.fitresult == mach.fitresult + @test_throws(ArgumentError, predict(smach)) + @test_logs (:warn, MLJBase.warn_serializable_mach(predict)) predict(smach, X) + TestUtilities.generic_tests(mach, smach) # Check restore! function Serialization.serialize(filename, smach) smach = Serialization.deserialize(filename) MLJBase.restore!(smach) + @test smach.state == 1 @test MLJBase.predict(smach, X) == MLJBase.predict(mach, X) @test fitted_params(smach) isa NamedTuple @test report(smach) == report(mach) @@ -286,6 +290,7 @@ end # End to end save and reload MLJBase.save(filename, mach) smach = machine(filename) + @test smach.state == 1 @test predict(smach, X) == predict(mach, X) rm(filename) @@ -301,6 +306,11 @@ end smach = machine(filename) @test transform(mach, X) == transform(smach, X) + @test_throws(ArgumentError, transform(smach)) + + # warning on non-restored machine + smach = deserialize(filename) + @test_logs (:warn, MLJBase.warn_serializable_mach(transform)) transform(smach, X) rm(filename) end diff --git a/test/test_utilities.jl b/test/test_utilities.jl index 37c0df60..eef0f8a8 100644 --- a/test/test_utilities.jl +++ b/test/test_utilities.jl @@ -117,7 +117,7 @@ end function generic_tests(mach₁, mach₂) test_args(mach₂) test_data(mach₂) - @test mach₂.state == 1 + @test mach₂.state == -1 for field in (:frozen, :model, :old_model, :old_upstream_state, :fit_okay) @test getfield(mach₁, field) == getfield(mach₂, field) end From 69303a075ac222af0beebb62a1e70eda11a8dbcf Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Thu, 3 Feb 2022 12:56:48 +0000 Subject: [PATCH 17/18] correct code to unbreak test but I believe the test itself is broken --- src/composition/learning_networks/machines.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/composition/learning_networks/machines.jl b/src/composition/learning_networks/machines.jl index 583eaf83..b96e17fd 100644 --- a/src/composition/learning_networks/machines.jl +++ b/src/composition/learning_networks/machines.jl @@ -386,7 +386,10 @@ function return!(mach::Machine{<:Surrogate}, model::Union{Model,Nothing}, verbosity) + network_model_names_ = network_model_names(model, mach) + verbosity isa Nothing || fit!(mach, verbosity=verbosity) + setfield!(mach.fitresult, :network_model_names, network_model_names_) # anonymize the data sources = MLJBase.sources(glb(mach)) From 6a32f25284534ca6e55f409b875bb9729030347c Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Thu, 3 Feb 2022 13:00:38 +0000 Subject: [PATCH 18/18] add another warning test --- test/composition/learning_networks/machines.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/composition/learning_networks/machines.jl b/test/composition/learning_networks/machines.jl index aefcfd96..3a35e1c3 100644 --- a/test/composition/learning_networks/machines.jl +++ b/test/composition/learning_networks/machines.jl @@ -239,6 +239,9 @@ end # Check serializable function smach = MLJBase.serializable(mach) + @test_logs((:warn, MLJBase.warn_serializable_mach(predict), + (:warn, MLJBase.warn_serializable_mach(predict)), + fit!(mach, verbosity=0))) TestUtilities.generic_tests(mach, smach) @test keys(fitted_params(smach)) == keys(fitted_params(mach)) @test keys(report(smach)) == keys(report(mach))