Add report method for merging fit reports with operational reports #160

Merged
merged 23 commits into from
Oct 4, 2022
185286f
add OrderedCollections to test deps
ablaom Aug 30, 2022
1dbae10
fix a deprecation warning
ablaom Aug 30, 2022
101ffa3
add report method for merging fit reports with operation reports
ablaom Aug 30, 2022
912f82c
add forgotten switch for `nothing`
ablaom Aug 31, 2022
e78f6ed
tweak warning for bad docstrings
ablaom Aug 31, 2022
cf5c5e3
oops
ablaom Aug 31, 2022
7ae5265
Merge pull request #161 from JuliaAI/descr-warning-tweak
ablaom Sep 1, 2022
dc1a12a
add OrderedCollections to test deps
ablaom Aug 30, 2022
b30f791
fix a deprecation warning
ablaom Aug 30, 2022
438a2d2
add report method for merging fit reports with operation reports
ablaom Aug 30, 2022
d698309
add forgotten switch for `nothing`
ablaom Aug 31, 2022
a2e6794
Merge branch 'banana' of https://github.com/alan-turing-institute/MLJ…
ablaom Sep 1, 2022
3b477ec
handle clashes in keys of reports
ablaom Sep 7, 2022
06c5971
improved key clash handling
ablaom Sep 11, 2022
c2aacba
empty merged reports should be replaced with `nothing` in `report()` …
ablaom Sep 11, 2022
d4a3419
overload fitted_params(::Static, ..) = nothing
ablaom Sep 15, 2022
49e128b
make merge fallback more robust
ablaom Sep 15, 2022
25986d2
make sure empty tuples are scrubbed to `nothing` in report return value
ablaom Sep 15, 2022
6508287
bump compat julia = "1.6" and update ci
ablaom Sep 16, 2022
ed0fd8c
trivial commit
ablaom Sep 16, 2022
b8ad6f7
rm redundant `const`
ablaom Sep 22, 2022
82aa20a
improve docstring
ablaom Sep 22, 2022
2319e85
tweak docstring
ablaom Sep 22, 2022
1 change: 0 additions & 1 deletion .github/workflows/ci.yml
@@ -17,7 +17,6 @@ jobs:
       fail-fast: false
       matrix:
         version:
-          - '1.0'
           - '1.6'
           - '1'
         os:
5 changes: 3 additions & 2 deletions Project.toml
@@ -11,17 +11,18 @@ StatisticalTraits = "64bff920-2084-43da-a3e6-9bb72801c0c9"
 [compat]
 ScientificTypesBase = "3.0"
 StatisticalTraits = "3.2"
-julia = "1"
+julia = "1.6"

 [extras]
 CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
 DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
+OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
 ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
 Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["CategoricalArrays", "DataFrames", "Distances", "InteractiveUtils", "Markdown", "ScientificTypes", "Tables", "Test"]
+test = ["CategoricalArrays", "DataFrames", "Distances", "InteractiveUtils", "Markdown", "OrderedCollections", "ScientificTypes", "Tables", "Test"]
17 changes: 11 additions & 6 deletions src/metadata_utils.jl
@@ -65,11 +65,16 @@ function _extend!(program::Expr, trait::Symbol, value, T)
 end
 end

-const DEPWARN_DOCSTRING =
-    "`metadata_model` should not be called with the keyword argument "*
-    "`descr` or `docstring`. Implementers of the MLJ model interface "*
-    "should instead create an MLJ-compliant docstring in the usual way. "*
-    "See https://alan-turing-institute.github.io/MLJ.jl/dev/adding_models_for_general_use/#Document-strings for details. "
+depwarn_docstring(T) =
+    """
+
+    Regarding $T: `metadata_model` should not be called with the keyword argument `descr`
+    or `docstring`. Implementers of the MLJ model interface should instead create an
+    MLJ-compliant docstring in the usual way. See
+    https://alan-turing-institute.github.io/MLJ.jl/dev/adding_models_for_general_use/#Document-strings
+    for details.
+
+    """

 """
     metadata_model(T; args...)
@@ -122,7 +127,7 @@ function metadata_model(
     supports_training_losses::Union{Nothing,Bool}=nothing,
     reports_feature_importances::Union{Nothing,Bool}=nothing,
 )
-    docstring === nothing || Base.depwarn(DEPWARN_DOCSTRING, :metadata_model)
+    docstring === nothing || Base.depwarn(depwarn_docstring(T), :metadata_model)

     program = quote end

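The change above swaps a constant warning string for a function of the model type, so the deprecation message names the offending model. A minimal sketch of that pattern follows, assuming hypothetical stand-ins `deprecation_note`, `DemoModel` and `register`; only `Base.depwarn` is a real Julia function here, and none of these names come from MLJModelInterface.

```julia
# Hypothetical sketch: build the deprecation message per type instead of
# storing one constant string shared by every caller.
deprecation_note(T) =
    """

    Regarding $T: the `docstring` keyword is deprecated.

    """

# Hypothetical model type, used only to exercise the message.
struct DemoModel end

# Hypothetical registration function that warns when the deprecated
# keyword is supplied. `Base.depwarn(msg, funcsym)` is silent unless
# Julia runs with deprecation warnings enabled.
function register(T; docstring=nothing)
    docstring === nothing || Base.depwarn(deprecation_note(T), :register)
    return T
end

register(DemoModel; docstring="something")
```

Running Julia with `--depwarn=yes` should surface the message, which now includes the type name.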
57 changes: 56 additions & 1 deletion src/model_api.jl
@@ -100,7 +100,7 @@ part of the tuple returned by `fit`.

 """
 fitted_params(::Model, fitresult) = (fitresult=fitresult,)
-
+fitted_params(::Static, ::Nothing) = nothing
 """

     predict(model, fitresult, new_data...)
@@ -173,6 +173,8 @@ the feature importances from the model's `fitresult` and `report` as an
 abstract vector of `feature::Symbol => importance::Real` pairs
 (e.g `[:gender =>0.23, :height =>0.7, :weight => 0.1]`).

+# New model implementations
+
 The following trait overload is also required:
 `MLJModelInterface.reports_feature_importances(::Type{<:M}) = true`

@@ -182,3 +184,56 @@ If for some reason a model is sometimes unable to report feature importances the

 """
 function feature_importances end
+
+_named_tuple(named_tuple::NamedTuple) = named_tuple
+_named_tuple(::Nothing) = NamedTuple()
+_named_tuple(something_else) = (report=something_else,)
+_scrub(x) = x
+_scrub(x::NamedTuple) = isempty(x) ? nothing : x
+_keys(named_tuple) = keys(named_tuple)
+_keys(::Nothing) = ()
+
+"""
+    MLJModelInterface.report(model, report_given_method)
+
+Merge the reports in the dictionary `report_given_method` into a single
+property-accessible object. It is supposed that each key of the dictionary is either
+`:fit` or the name of an operation, such as `:predict` or `:transform`. Each value will be
+the `report` component returned by a training method (`fit` or `update`) dispatched on the
+`model` type, in the case of `:fit`, or the report component returned by an operation that
+supports reporting.
+
+# New model implementations
+
+Overloading this method is optional, unless the model generates reports that are neither
+named tuples nor `nothing`.
+
+Assuming each value in the `report_given_method` dictionary is either a named tuple
+or `nothing`, and there are no conflicts between the keys of the dictionary values
+(the individual reports), the fallback returns the usual named tuple merge of the
+dictionary values, ignoring any `nothing` value. If there is a key conflict, all operation
+reports are first wrapped in a named
+tuple of length one, as in `(predict=predict_report,)`. A `:fit` report is never wrapped.
+
+If any dictionary `value` is neither a named tuple nor `nothing`, it is first wrapped as
+`(report=value, )` before merging.
+
+"""
+function report(model, report_given_method)
+
+    return_keys = vcat(collect.(_keys.(values(report_given_method)))...)
+
+    # Note that we want to avoid copying values in each individual report named tuple, and
+    # merge the reports in a reproducible order.
+
+    methods = collect(keys(report_given_method)) |> sort!
+    length(methods) == 1 && return _scrub(report_given_method[only(methods)])
+    need_to_wrap = return_keys != unique(return_keys)
+    reports = map(methods) do method
+        tup = _named_tuple(report_given_method[method])
+        isempty(tup) ? NamedTuple() :
+            (need_to_wrap && method !== :fit) ? NamedTuple{(method,)}((tup,)) :
+            tup
+    end
+    return _scrub(merge(reports...))
+end
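The merge rules in the `report` docstring can be tried without loading the package. The following is a standalone sketch under stated assumptions: `merge_reports`, `as_named_tuple` and `scrub` are hypothetical names chosen for illustration and are not part of MLJModelInterface, and key clashes are detected on the already-wrapped reports, which is a simplification of the fallback in the diff.

```julia
# Standalone sketch of the documented merge rules; all names are hypothetical.

# Normalise an individual report to a named tuple.
as_named_tuple(nt::NamedTuple) = nt
as_named_tuple(::Nothing) = NamedTuple()
as_named_tuple(other) = (report=other,)

# Replace an empty named tuple with `nothing`; pass anything else through.
scrub(x) = x
scrub(nt::NamedTuple) = isempty(nt) ? nothing : nt

function merge_reports(report_given_method::AbstractDict)
    # Sort the method names so the merge order is reproducible.
    methods = sort!(collect(keys(report_given_method)))

    # A lone report is returned as-is (after scrubbing), never wrapped.
    length(methods) == 1 && return scrub(report_given_method[only(methods)])

    tuples = [as_named_tuple(report_given_method[m]) for m in methods]

    # Collect every key of every report to detect clashes.
    all_keys = Symbol[]
    for t in tuples, k in keys(t)
        push!(all_keys, k)
    end
    clash = !allunique(all_keys)

    # On a clash, wrap each non-empty operation report under its method
    # name; the `:fit` report is never wrapped.
    parts = map(methods, tuples) do m, t
        isempty(t) ? NamedTuple() :
            (clash && m !== :fit) ? NamedTuple{(m,)}((t,)) : t
    end
    return scrub(merge(NamedTuple(), parts...))
end

# No key clash: a plain merge, with `:fit` first because of the sort.
merge_reports(Dict(:predict => (y=7,), :fit => (x=1, z=3), :transform => nothing))
# == (x=1, z=3, y=7)

# Clash on `y`: the `:predict` report is wrapped, the `:fit` report is not.
merge_reports(Dict(:predict => (y=7,), :fit => (y=1, z=3), :transform => nothing))
# == (y=1, z=3, predict=(y=7,))
```

Sorting the method names is what makes the result independent of dictionary iteration order, which is why a plain `Dict` works here even though the PR's tests use an `OrderedDict`.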
4 changes: 2 additions & 2 deletions test/data_utils.jl
@@ -323,9 +323,9 @@ end
 eval(:(module UserSide
 import MLJModelInterface: metadata_model, metadata_pkg
 struct A end
-descr = "something"
+human_name = "Big Foot"
 # Smoke tests.
-metadata_model(A; descr=descr)
+metadata_model(A; human_name)
 metadata_pkg(A)
 end))
 end
45 changes: 44 additions & 1 deletion test/model_api.jl
@@ -5,7 +5,6 @@ end
     f0::Int
 end

-
 mutable struct APIx1 <: Static end

 @testset "selectrows(model, data...)" begin
@@ -95,3 +94,47 @@ mutable struct UnivariateFiniteFitter <: Probabilistic end
     @test yhat == fill(DummyUnivariateFinite(), 3)

 end
+
+@testset "fallback for `report()` method" begin
+    report_given_method =
+        OrderedCollections.OrderedDict(
+            :predict=>(y=7,),
+            :fit=>(x=1, z=3),
+            :transform=>nothing,
+        )
+    @test MLJModelInterface.report(APIx0(f0=1), report_given_method) ==
+        (x=1, z=3, y=7)
+
+    report_given_method =
+        OrderedCollections.OrderedDict(
+            :predict=>(y=7,),
+            :fit=>(y=1, z=3),
+            :transform=>nothing,
+        )
+    @test MLJModelInterface.report(APIx0(f0=1), report_given_method) ==
+        (y=1, z=3, predict=(y=7,))
+
+    @test MLJModelInterface.report(
+        APIx0(f0=1),
+        OrderedCollections.OrderedDict(:fit => nothing, :transform => NamedTuple()),
+    ) |> isnothing
+
+    @test MLJModelInterface.report(
+        APIx0(f0=1),
+        OrderedCollections.OrderedDict(:fit => 42),
+    ) == 42
+
+    @test MLJModelInterface.report(
+        APIx0(f0=1),
+        OrderedCollections.OrderedDict(:fit => nothing),
+    ) |> isnothing
+
+    @test MLJModelInterface.report(
+        APIx0(f0=1),
+        OrderedCollections.OrderedDict(:fit => NamedTuple()),
+    ) |> isnothing
+
+
+end
+
+
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -3,6 +3,7 @@ using ScientificTypesBase, ScientificTypes
 using Tables, Distances, CategoricalArrays, InteractiveUtils
 import DataFrames: DataFrame
 import Markdown
+import OrderedCollections

 const M = MLJModelInterface
 const FI = M.FullInterface