Skip to content

Commit

Permalink
improvements to measure doc-strings - format and readability
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Feb 8, 2022
1 parent 6914722 commit 48d4d02
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 23 deletions.
25 changes: 16 additions & 9 deletions src/measures/continuous.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ body=
"""
``\\text{mean absolute error} = n^{-1}∑ᵢ|yᵢ-ŷᵢ|`` or
``\\text{mean absolute error} = n^{-1}∑ᵢwᵢ|yᵢ-ŷᵢ|``
""")
""",
scitype=DOC_INFINITE)

call(::MeanAbsoluteError, ŷ::ArrMissing{<:Real}, y::ArrMissing{<:Real}) =
abs.(ŷ .- y) |> skipinvalid |> mean
Expand Down Expand Up @@ -55,7 +56,8 @@ body=
"""
``\\text{root mean squared error} = \\sqrt{n^{-1}∑ᵢ|yᵢ-ŷᵢ|^2}`` or
``\\text{root mean squared error} = \\sqrt{\\frac{∑ᵢwᵢ|yᵢ-ŷᵢ|^2}{∑ᵢwᵢ}}``
""")
""",
scitype=DOC_INFINITE)

call(::RootMeanSquaredError, ŷ::ArrMissing{<:Real}, y::ArrMissing{<:Real}) =
(y .- ŷ).^2 |> skipinvalid |> mean |> sqrt
Expand Down Expand Up @@ -89,7 +91,8 @@ The R² (also known as R-squared or coefficient of determination) is suitable fo
Let ``\\overline{y}`` denote the mean of ``y``, then
``\\text{R^2} = 1 - \\frac{∑ (\\hat{y} - y)^2}{∑ \\overline{y} - y)^2}.``
""")
""",
scitype=DOC_INFINITE)

function call(::RSquared, ŷ::ArrMissing{<:Real}, y::ArrMissing{<:Real})
num = (ŷ .- y).^2 |> skipinvalid |> sum
Expand Down Expand Up @@ -121,7 +124,8 @@ body=
"""
Constructor signature: `LPLoss(p=2)`. Reports
`|ŷ[i] - y[i]|^p` for every index `i`.
""")
""",
scitype=DOC_INFINITE)

single(m::LPLoss, ŷ::Real, y::Real) = abs(y - ŷ)^(m.p)

Expand All @@ -146,7 +150,8 @@ body=
``\\text{root mean squared log error} =
n^{-1}∑ᵢ\\log\\left({yᵢ \\over ŷᵢ}\\right)``
""",
footer="See also [`rmslp1`](@ref).")
footer="See also [`rmslp1`](@ref).",
scitype=DOC_INFINITE)

call(::RootMeanSquaredLogError, ŷ::ArrMissing{<:Real}, y::ArrMissing{<:Real}) =
(log.(y) - log.(ŷ)).^2 |> skipinvalid |> mean |> sqrt
Expand Down Expand Up @@ -185,7 +190,8 @@ Constructor signature: `RootMeanSquaredLogProportionalError(; offset = 1.0)`.
``\\text{root mean squared log proportional error} =
n^{-1}∑ᵢ\\log\\left({yᵢ + \\text{offset} \\over ŷᵢ + \\text{offset}}\\right)``
""",
footer="See also [`rmsl`](@ref). ")
footer="See also [`rmsl`](@ref). ",
scitype=DOC_INFINITE)

call(m::RMSLP, ŷ::ArrMissing{<:Real}, y::ArrMissing{<:Real}) =
(log.(y .+ m.offset) - log.(ŷ .+ m.offset)).^2 |>
Expand Down Expand Up @@ -226,7 +232,7 @@ m^{-1}∑ᵢ \\left({yᵢ-ŷᵢ \\over yᵢ}\\right)^2``
where the sum is over indices such that `abs(yᵢ) > tol` and `m` is the number
of such indices.
""")
""", scitype=DOC_INFINITE)

function call(m::RootMeanSquaredProportionalError,
::ArrMissing{<:Real},
Expand Down Expand Up @@ -274,7 +280,7 @@ Constructor key-word arguments: `tol` (default = `eps()`).
where the sum is over indices such that `abs(yᵢ) > tol` and `m` is the number
of such indices.
""")
""", scitype=DOC_INFINITE)

function call(m::MeanAbsoluteProportionalError,
::ArrMissing{<:Real},
Expand Down Expand Up @@ -311,7 +317,8 @@ const LogCosh = LogCoshLoss
@create_aliases LogCoshLoss

@create_docs(LogCoshLoss,
body="Reports ``\\log(\\cosh(ŷᵢ-yᵢ))`` for each index `i`. ")
body="Reports ``\\log(\\cosh(ŷᵢ-yᵢ))`` for each index `i`. ",
scitype=DOC_INFINITE)

_softplus(x::T) where T<:Real = x > zero(T) ? x + log1p(exp(-x)) : log1p(exp(x))
_log_cosh(x::T) where T<:Real = x + _softplus(-2x) - log(convert(T, 2))
Expand Down
19 changes: 12 additions & 7 deletions src/measures/measures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,22 @@ const PROPER_SCORING_RULES = "[Gneiting and Raftery (2007), \"Strictly"*
"Proper Scoring Rules, Prediction, and Estimation\""*
"](https://doi.org/10.1198/016214506000001437)"
const DOC_FINITE =
"`AbstractArray{<:Finite}` (multiclass classification)"
"`AbstractArray{<:Union{Finite,Missing}` (multiclass classification)"
const DOC_FINITE_BINARY =
"`AbstractArray{<:Finite{2}}` (binary classification)"
"`AbstractArray{<:Union{Finite{2},Missing}}` (binary classification)"
const DOC_ORDERED_FACTOR =
"`AbstractArray{<:OrderedFactor}` (classification of ordered target)"
"`AbstractArray{<:Union{OrderedFactor,Missing}}` (classification of ordered target)"
const DOC_ORDERED_FACTOR_BINARY =
"`AbstractArray{<:OrderedFactor{2}}` "*
"`AbstractArray{<:Union{OrderedFactor{2},Missing}}` "*
"(binary classification where choice of \"true\" effects the measure)"
const DOC_CONTINUOUS = "`AbstractArray{Continuous}` (regression)"
const DOC_COUNT = "`AbstractArray{Count}`"
const DOC_INFINITE = "AbstractArray{<:Infinite}"
const DOC_CONTINUOUS = "`AbstractArray{<:Union{Continuous,Missing}}` (regression)"
const DOC_COUNT = "`AbstractArray{<:Union{Count,Missing}}`"
const DOC_MULTI = "`AbtractArray{<:Union{Missing,T}` where `T` is `Continuous` "*
"or `Count` (for respectively continuous or discrete Distribution.jl objects in "*
"`ŷ`) or `OrderedFactor` or `Multiclass` "*
"(for `UnivariateFinite` distributions in `ŷ`)"

const DOC_INFINITE = "`AbstractArray{<:Union{Infinite,Missing}}`"
const INVARIANT_LABEL =
"This metric is invariant to class reordering."
const VARIANT_LABEL =
Expand Down
4 changes: 2 additions & 2 deletions src/measures/meta_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function detailed_doc_string(M; typename="", body="", footer="", scitype="")
_instances = _decorate(instances(M))
human_name = MLJBase.human_name(M)
if isempty(scitype)
scitype = target_scitype(M) |> string
scitype = "`$(target_scitype(M))`"
end

if isempty(typename)
Expand Down Expand Up @@ -54,7 +54,7 @@ function detailed_doc_string(M; typename="", body="", footer="", scitype="")
(ret *= DOC_CLASS_WEIGHTS)
ret *= "\n\n"
isempty(body) || (ret *= "$body\n\n")
ret *= "Requires `scitype(y)` to be a subtype of `$scitype`; "
ret *= "Requires `scitype(y)` to be a subtype of $scitype; "
ret *= "`ŷ` must be an array of `$(prediction_type(M))` predictions. "
isempty(footer) ||(ret *= "\n\n$footer")
ret *= "\n\n"
Expand Down
14 changes: 9 additions & 5 deletions src/measures/probabilistic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ supported `Distributions.UnivariateDistribution` such as `Normal` or
`Poisson`.
See also [`LogLoss`](@ref), which differs only in sign.
""")
""",
scitype=DOC_MULTI)

# for single finite observation:
single(c::LogScore, d::UnivariateFinite{S,V,R,P}, η::Label) where {S,V,R,P} =
Expand Down Expand Up @@ -193,7 +194,8 @@ const CrossEntropy = LogLoss
body=
"""
For details, see [`LogScore`](@ref), which differs only by a sign.
""")
""",
scitype=DOC_MULTI)

# for single finite observation:
single(c::LogLoss, d::UnivariateFinite{S,V,R,P}, η::Label) where {S,V,R,P} =
Expand Down Expand Up @@ -261,7 +263,7 @@ in the `Continuous` case (`p` the probablity density function) or
in the `Count` cae (`p` the probablity mass function).
""",
scitype=DOC_FINITE)
scitype=DOC_MULTI)

# calling on single finite observation:
function single(::BrierScore,
Expand Down Expand Up @@ -317,7 +319,8 @@ metadata_measure(BrierLoss;
body=
"""
For details, see [`BrierScore`](@ref), which differs only by a sign.
""")
""",
scitype=DOC_MULTI)

# calling on single finite observations:
single(::BrierLoss, d::UnivariateFinite{S,V,R,P}, η::Label) where {S,V,R,P} =
Expand Down Expand Up @@ -370,7 +373,8 @@ where `α` is the measure parameter `alpha`.
$DOC_DISTRIBUTIONS
""")
""",
scitype=DOC_MULTI)

# calling on single observations:
function single(s::SphericalScore,
Expand Down

0 comments on commit 48d4d02

Please sign in to comment.