
Add new option for exporting learning networks as stand-alone composite model types #841

Merged: 69 commits, Oct 31, 2022

Conversation

@ablaom (Member) commented Sep 16, 2022

This PR addresses #831.

It is mildly breaking: where previously report(mach) returned NamedTuple(), it will now return nothing instead.

Requires:

  • Has been tested (with the above MLJModelInterface PR) as non-breaking for MLJTuning, MLJEnsembles, and MLJIteration.
  • Has been tested with MLJTestIntegration, which applies integration tests for the majority of model-providing packages.

Summary

This PR adds a new abstract type NetworkComposite, with subtypes DeterministicNetworkComposite, etc., for the purpose of "exporting" learning networks as new stand-alone model types. It is intended to render all existing alternatives obsolete, but these are not touched in this PR. It is simpler: there are no "learning network machines", no Surrogate models, and no special return! method to call. Refer to #831 for a list of other drawbacks overcome by the new approach.

Below is a demonstration of exporting a learning network that combines the predictions of two regressors. It additionally exposes, in the new model's report, the output of a node called disagreement (which measures disagreement between the two regressors at training time):

Step 1: Define the composite type - must subtype NetworkComposite

using MLJBase

mutable struct DoubleRegressor <: DeterministicNetworkComposite
    regressor1
    regressor2
    mix::Float64
end

Step 2: Implement MLJBase.prefit (method added in this PR)

This method has the same signature as fit(model, ...) and returns a named tuple constituting an "interface" to a learning network constructed out of its arguments. A novelty is that component models are represented by the names (symbols) of the relevant fields of the composite.

function MLJBase.prefit(composite::DoubleRegressor, verbosity, X, y)
    Xs = source(X)
    ys = source(y)

    mach1 = machine(:regressor1, Xs, ys)   #  <--- symbol instead of model
    mach2 = machine(:regressor2, Xs, ys)

    yhat1 = predict(mach1, Xs)
    yhat2 = predict(mach2, Xs)

    # node returning the disagreement between the regressor predictions:
    disagreement = node((y1, y2) -> l2(y1, y2), yhat1, yhat2)

    # compute the weighted average of the regressors' predictions:
    λ = composite.mix
    yhat = (1 - λ)*yhat1 + λ*yhat2

    return (
        predict = yhat,
        report = (; training_disagreement=disagreement)
    )
end
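
For orientation, here is a hedged usage sketch (not from the PR itself): it assumes regressor1 and regressor2 are bound to any two deterministic regressor instances, e.g. loaded from the MLJ model registry.

using MLJBase

X, y = make_regression(100, 3)  # synthetic data provided by MLJBase

composite = DoubleRegressor(regressor1, regressor2, 0.5)
mach = machine(composite, X, y)
fit!(mach)

predict(mach, X)                    # weighted-average predictions
report(mach).training_disagreement  # training-time output of the exposed node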

Details

A new documentation PR at MLJ has already been posted. The relevant section is here.

There are two additional changes introduced in this PR which make the above possible:

Symbols instead of models in machines

One can bind data in a machine to a "virtual" model, which is just a symbol, as in mach = machine(:transformer, X). If my_composite is any object with :transformer as a property, then fit!(mach; composite=my_composite) will train the machine with :transformer replaced by getproperty(my_composite, :transformer). This "run-time interpolation" of models propagates to all machines in a learning network. The general user need not bother with this, but knowing about it makes the export process above less mysterious. A sketch follows.
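
Here is a hedged sketch of that mechanic, mirroring the prose above; Standardizer is an assumption (it ships with MLJModels), and any property-accessible object would do in place of the named tuple.

using MLJBase, MLJModels

my_composite = (transformer = Standardizer(),)  # any object with a :transformer property
X = (x1 = rand(100),)

mach = machine(:transformer, X)     # bind data to a "virtual" (symbolic) model
fit!(mach; composite=my_composite)  # :transformer replaced by my_composite.transformer
transform(mach, X)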

Separate training reports and operation reports in machines

In a change purely internal to MLJBase, the report field of a machine is now a dictionary d, keyed on method (:fit, :predict, etc.). When a reporting operation is called on the machine, its report is added to the dictionary; previously it was merged with the training report. Now, when you call report on a machine, the various reports in the dictionary are merged by calling a new method MLJModelInterface.report(::Model, d). This method can be overloaded by a model, but there is a fallback that does a pretty good job of it (the fallback unexpectedly "just worked" for the new NetworkComposite types).
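
A hedged sketch of the overloading hook just described; MyModel is hypothetical, and the concrete type of d is an assumption (the text above only specifies a dictionary keyed on method names).

import MLJModelInterface as MMI

struct MyModel <: MMI.Deterministic end  # hypothetical model

function MMI.report(::MyModel, d::AbstractDict)
    # keep only the training report, discarding operation reports (illustration only):
    return get(d, :fit, nothing)
end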

In the existing export process, fitted_params(::Machine, ...) and report(::Machine, ...) have to be special-cased for composite model types. With this new approach, that is no longer necessary, which restores uniformity to the interface.

To do

  • Add save and restore methods for serialising <:NetworkComposite
  • Add a facility to expose, in the composite model's fitted_params, the result of calling an internal node, as we already have for reports.

ablaom added 30 commits July 12, 2022 14:00

  • Adapt some tests for Julia nightly.
  • oops
  • generalize models(::AbstractNode) to find symbolic models
  • fix test
  • fix logic in `fit_only!` for symbolic machines
  • fix some tests
  • extra cleanup around serialization
  • fix missed tests
@ablaom (Member, Author) commented Oct 13, 2022

The MLJ doc update associated with the change proposed here has been posted. The new rewritten Learning Networks section is here.

@davnn (Contributor) commented Oct 14, 2022

> Keen to move this along soon.
>
> @davnn Have you made any progress testing this PR with OutlierDetection.jl?

Hey, sorry, I've been very busy. I'm trying to integrate it currently; my approach is as follows:

We defined wrappers for detectors (previously resulting in composite detectors, now in composite network detectors) that take a variable number of base detectors as arguments. Previously these base detectors were used in machine(detector, data...); now we should use machine(:detector_name, data...). Where does the substitution of machine(:detector_name, data...) by machine(getproperty(composite, :detector_name), data...) happen?

I used something like map(d -> augmented_transform(MLJ.machine(d, Xs), Xs), getfield(model, :detectors)) before. Replacing that with map(d -> augmented_transform(MLJ.machine(d, Xs), Xs), tuple_of_detector_symbols) doesn't work, so how am I supposed to map over an iterable of submodels?

Otherwise I don't see major hurdles. Another small thing is that we defined a fallback for fit to allow usage with unsupervised models, i.e. https://github.com/JuliaAI/MLJModelInterface.jl/blob/90875b72fd92fbde0424ea55b5df1c604cf92c27/src/model_api.jl#L21; where would such a fallback land for prefit?

@olivierlabayle (Collaborator) commented:

> @olivierlabayle Did you want to comment on this PR? I realize you might not want to review it as it is quite long, but if you're willing and interested that would be extremely welcome.

Sorry, I've just realised I missed that! I'll try to have a look in the next couple of days. It looks really promising, but please feel free to go ahead, since it's been open for some time already.

Review comment (Collaborator) on this hunk:

    kwargs...)
        return Machine(model, arg1, args...; kwargs...)
    end

    function machine(model::Symbol, arg1::AbstractNode, args::AbstractNode...;

I think this method is now redundant with the method above.

@olivierlabayle (Collaborator) commented:

Thanks @ablaom for this simplification of the learning network definition process! My only reservations are about the symbolic access of models in the definition of the network itself. For that magic to occur you've had to hack into and complexify fit_only!. Writing composite.model instead of :model doesn't seem like a huge burden to me and would also be easier to debug. Is there another benefit beyond this short keyed access?

@ablaom (Member, Author) commented Oct 17, 2022

@davnn Thanks for looking at this.

> Where does the substitution of machine(:detector_name, data...) by machine(getproperty(composite, :detector_name), data...) happen?

Whenever you call fit!(node, composite=composite), where node is a node, this will interpolate the symbols appearing in the machines of the network subtended at node with the model values in the property-accessible object composite (assuming those symbols are properties) before training the network. (See also the proposed docs.) There is a generic MMI.fit(::NetworkComposite, verbosity, data...) method which makes exactly this call on the greatest lower bound node of the network (including "report" nodes) defined by the "network interface" returned by prefit. A sketch follows.
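
A hedged sketch of that node-level call, reusing names from the DoubleRegressor example in the PR description (regressor1 and regressor2 are hypothetical deterministic regressors):

using MLJBase

X, y = make_regression(50, 2)
Xs, ys = source(X), source(y)

mach1 = machine(:regressor1, Xs, ys)
mach2 = machine(:regressor2, Xs, ys)
yhat = 0.5*predict(mach1, Xs) + 0.5*predict(mach2, Xs)

composite = DoubleRegressor(regressor1, regressor2, 0.5)
fit!(yhat, composite=composite)  # symbols interpolated from `composite`, then trained
yhat()                           # call the node: predictions on the training data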

> replacing that with map(d -> augmented_transform(MLJ.machine(d, Xs), Xs), tuple_of_detector_symbols) doesn't work

Mmm. I think I need more context to understand why this is not working, as I've by now forgotten the details. Can you please point me to the relevant code and I'll take a look. (I didn't think we had an augmented_transform in MMI, as we decided to abandon that?? Is this something you added locally?)

An alternative is to leave that particular line alone; that should still work. You don't have to do the symbol replacement. The problem will be that the reports and fitted_params for the individual detectors in your ensemble will not automatically appear in the composite model's report and fitted_params. In the case of reports, we can (probably) add them "manually" using report nodes; in the case of fitted_params, I'm now thinking of adding a facility to do that also in this PR. But I should first like to see if we can avoid this workaround.

@ablaom (Member, Author) commented Oct 17, 2022

@olivierlabayle Thanks for taking the time to look over this substantial PR and for your valuable feedback.

> Thanks @ablaom for this simplification of the learning network definition process! My only reservations are about the symbolic access of models in the definition of the network itself. For that magic to occur you've had to hack into and complexify fit_only!. Writing composite.model instead of :model doesn't seem like a huge burden to me and would also be easier to debug. Is there another benefit beyond this short keyed access?

Yes, there is some complication to fit_only!, but I think this is far less hacky than what is done now to expose component model reports and fitted_params, and so forth, no? I'm not sure what you mean by "short keyed access". The benefits are overcoming the shortcomings detailed in #831. As explained in the comment above, writing composite.model instead of :model is still allowed, but then the component's reports and fitted_params will be missing from the composite model's report/fitted_params.

> would also be easier to debug.

Yes, it's slightly more tedious to debug: when working with an unexported learning network that uses the symbol replacements, you need to add the composite=... keyword argument to your fit!(node) calls (see here).

@davnn (Contributor) commented Oct 18, 2022

> Mmm. I think I need more context to understand why this is not working, as I've by now forgotten the details. Can you please point me to the relevant code and I'll take a look. (I didn't think we had an augmented_transform in MMI, as we decided to abandon that?? Is this something you added locally?)

Yes, I've implemented augmented_transform internally for detector models, which would now have to be extended for the new network interface. The implementations of augmented_transform work on machines, but I see that the new network code works on models and fitresults, i.e. https://github.com/JuliaAI/MLJBase.jl/blob/banana/src/operations.jl#L224. I'm not sure how to translate this API to augmented_transform. I tried to add an augmented_transform node, https://github.com/OutlierDetectionJL/OutlierDetection.jl/blob/banana/src/mlj_wrappers.jl#L193, but that doesn't work (maybe because signatures only work with a hard-coded set of operations? https://github.com/JuliaAI/MLJBase.jl/blob/banana/src/composition/learning_networks/signatures.jl#L148).

Edit: At least I managed to fix the package to work with the current #dev and the latest releases (0.20.12 - 0.20.20). The fix is hacky, but better than nothing; see OutlierDetectionJL/OutlierDetection.jl@0398eba. The fix should equally work for #banana, but that appears to break other, non-report-related functionality, I think.

  • update serialization for NetworkComposite models re fitted_params
  • fixed overlooked problems
@ablaom (Member, Author) commented Oct 19, 2022

@davnn I'm going to continue the OutlierDetection-specific discussion in a new thread, which will appear shortly below.
