Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proposal to add support to allow non-generalizing models to contribute to a machine's report #806

Merged
merged 6 commits into from
Jul 14, 2022

Conversation

ablaom
Copy link
Member

@ablaom ablaom commented Jul 12, 2022

In response to design discussions at JuliaAI/MLJ.jl#950 and JuliaAI/MLJ.jl#852.

Requires:

Related: JuliaAI/StatisticalTraits.jl#25

To do:

  • Bump compat for MLJModelInterface
  • Create issues at MLJ to support new trait and to update the MLJ manual

Context. Some clustering models (eg, DBSCAN) and some imputing models do not generalize to new data. That is, there is no training data - only a transformation determined completely by input data. From the point-of-view of model composition, such models fit most naturally fit into the current Static model typing, which means that fit is a no-op; all the heavy lifting occurs in the transform method. However, at present, only the fit method can contribute a report, which at the level of machines is accessible via report(mach). So extra byproducts of the transformation computation (eg, point types in DBSCAN) cannot be easily exposed to the user.

It is proposed that we add a new model trait reporting_operations that lists those operations (such as :transform) which are understood to return two pieces of information, when called on a model instance: the usual output, and some report data (named tuple) . In those cases, calling the operation on a machine is only to return the output, but the report gets merged into the machine's report.

While the main use-case is Static models, such enhancements could be applied to any model, and such models can be used in composite models (eg, pipelines) with their reports accessible as usual.

The proposal in action

In implementation

mutable struct StaticKefir <: Static
    alpha::Float64 # non-zero to be invertible
end
MLJBase.reporting_operations(::Type{<:StaticKefir}) = (:transform, :inverse_transform)

# piece-wise linear function that is linear only for `alpha=1`:
kefir(x, alpha) = x > 0 ? x * alpha : x / alpha

MLJBase.transform(model::StaticKefir, _, X) = (
    broadcast(kefir, X, model.alpha),
    (; first = first(X)),                          # <-----------------   report component 
)

MLJBase.inverse_transform(model::StaticKefir, _, W) = (
    broadcast(kefir, W, 1/(model.alpha)),
    (; last = last(W)),                          # <-----------------   report component 
)

User workflow

model = StaticKefir(2)
mach = machine(StaticKefir(2))  # remember there is no training data to attach to a machine for a `Static` model
julia> transform(mach, [1, 2, 3])  # no need to `fit!` a `Static` model
3-element Vector{Float64}:
 2.0
 4.0
 6.0

julia> report(mach)
(first = 1,)

julia> inverse_transform(mach, [2, 4, 6])
3-element Vector{Float64}:
 1.0
 2.0
 3.0

julia> report(mach)
(first = 1,
 last = 6,)

If you don't care to see the report, there's the one-liner,

julia> transform(machine(StaticKefir(2)), [1, 2, 3])
3-element Vector{Float64}:
 2.0
 4.0
 6.0

@ablaom
Copy link
Member Author

ablaom commented Jul 12, 2022

@ablaom ablaom marked this pull request as draft July 12, 2022 01:52
@ablaom
Copy link
Member Author

ablaom commented Jul 12, 2022

cc @pazzo83

@ablaom
Copy link
Member Author

ablaom commented Jul 14, 2022

I've now tested this using MLJTestIntegration.jl, which tests integration with wider MLJ ecosystem, and all good.

@ablaom ablaom marked this pull request as ready for review July 14, 2022 07:28
@codecov-commenter
Copy link

codecov-commenter commented Jul 14, 2022

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 85.70%. Comparing base (ab8d12c) to head (8af858b).
Report is 420 commits behind head on dev.

Additional details and impacted files
@@            Coverage Diff             @@
##              dev     #806      +/-   ##
==========================================
+ Coverage   85.61%   85.70%   +0.08%     
==========================================
  Files          36       36              
  Lines        3477     3497      +20     
==========================================
+ Hits         2977     2997      +20     
  Misses        500      500              

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants