diff --git a/src/extras.jl b/src/extras.jl index ca7f9356..9a9e2c62 100644 --- a/src/extras.jl +++ b/src/extras.jl @@ -42,8 +42,16 @@ function fill_refs!(refs::AbstractArray, X::AbstractArray{>: Missing}, end """ + default_formatter(from, to, i; closed=false) + +Provide the default label format for the `cut` function. +""" +default_formatter(from, to, i; closed) = string("[", from, ", ", to, closed ? "]" : ")") + +@doc raw""" cut(x::AbstractArray, breaks::AbstractVector; - extend::Bool=false, labels::AbstractVector=[], allow_missing::Bool=false) + labels::Union{AbstractVector{<:AbstractString},Function}, + extend::Bool=false, allow_missing::Bool=false) Cut a numeric array into intervals and return an ordered `CategoricalArray` indicating the interval into which each entry falls. Intervals are of the form `[lower, upper)`, @@ -56,14 +64,55 @@ also accept them. * `extend::Bool=false`: when `false`, an error is raised if some values in `x` fall outside of the breaks; when `true`, breaks are automatically added to include all values in `x`, and the upper bound is included in the last interval. -* `labels::AbstractVector=[]`: a vector of strings giving the names to use for the - intervals; if empty, default labels are used. +* `labels::Union{AbstractVector,Function}: a vector of strings giving the names to use for + the intervals; or a function `f(from, to, i; closed)` that generates the labels from the + left and right interval boundaries and the group index. Defaults to + `"[from, to)"` (or `"[from, to]"` for the rightmost interval if `extend == true`). * `allow_missing::Bool=true`: when `true`, values outside of breaks result in missing values. only supported when `x` accepts missing values. + +# Examples +```jldoctest +julia> cut(-1:0.5:1, [0, 1], extend=true) +5-element CategoricalArray{String,1,UInt32}: + "[-1.0, 0.0)" + "[-1.0, 0.0)" + "[0.0, 1.0]" + "[0.0, 1.0]" + "[0.0, 1.0]" + +julia> cut(-1:0.5:1, 2) +5-element CategoricalArray{String,1,UInt32}: + "[-1.0, 0.0)" + "[-1.0, 0.0)" + "[0.0, 1.0]" + "[0.0, 1.0]" + "[0.0, 1.0]" + +julia> cut(-1:0.5:1, 2, labels=["A", "B"]) +5-element CategoricalArray{String,1,UInt32}: + "A" + "A" + "B" + "B" + "B" + +julia> fmt(from, to, i; closed) = "grp $i ($from//$to)" +fmt (generic function with 1 method) + +julia> cut(-1:0.5:1, 3, labels=fmt) +5-element CategoricalArray{String,1,UInt32}: + "grp 1 (-1.0//-0.333333)" + "grp 1 (-1.0//-0.333333)" + "grp 2 (-0.333333//0.333333)" + "grp 3 (0.333333//1.0)" + "grp 3 (0.333333//1.0)" +``` """ function cut(x::AbstractArray{T, N}, breaks::AbstractVector; - extend::Bool=false, labels::AbstractVector{U}=String[], - allow_missing::Bool=false) where {T, N, U<:AbstractString} + extend::Bool=false, + labels::Union{AbstractVector{<:AbstractString},Function}=default_formatter, + allow_missing::Bool=false) where {T, N} if !issorted(breaks) breaks = sort(breaks) end @@ -92,7 +141,7 @@ function cut(x::AbstractArray{T, N}, breaks::AbstractVector; end n = length(breaks) - if isempty(labels) + if labels isa Function @static if VERSION >= v"0.7.0-DEV.4524" from = map(x -> sprint(show, x, context=:compact=>true), breaks[1:n-1]) to = map(x -> sprint(show, x, context=:compact=>true), breaks[2:n]) @@ -102,13 +151,9 @@ function cut(x::AbstractArray{T, N}, breaks::AbstractVector; end levs = Vector{String}(undef, n-1) for i in 1:n-2 - levs[i] = string("[", from[i], ", ", to[i], ")") - end - if extend - levs[end] = string("[", from[end], ", ", to[end], "]") - else - levs[end] = string("[", from[end], ", ", to[end], ")") + levs[i] = labels(from[i], to[i], i, closed=false) end + levs[end] = labels(from[end], to[end], n-1, closed=extend) else length(labels) == n-1 || throw(ArgumentError("labels must be of length $(n-1), but got length $(length(labels))")) # Levels must have element type String for type stability of the result @@ -122,11 +167,11 @@ end """ cut(x::AbstractArray, ngroups::Integer; - labels::AbstractVector=String[]) + labels::Union{AbstractVector{<:AbstractString},Function}) Cut a numeric array into `ngroups` quantiles, determined using [`quantile`](@ref). """ cut(x::AbstractArray, ngroups::Integer; - labels::AbstractVector{U}=String[]) where {U<:AbstractString} = + labels::Union{AbstractVector{<:AbstractString},Function}=default_formatter) = cut(x, Statistics.quantile(x, (1:ngroups-1)/ngroups); extend=true, labels=labels) diff --git a/test/15_extras.jl b/test/15_extras.jl index 9e8834da..6c2b7692 100644 --- a/test/15_extras.jl +++ b/test/15_extras.jl @@ -108,4 +108,14 @@ end @test levels(x) == ["[2.0, 3.5)", "[3.5, 5.0]"] end +@testset "cut with formatter function" begin + my_formatter(from, to, i; closed) = "$i: $from -- $to" + + x = 0.15:0.20:0.95 + p = [0, 0.4, 0.8, 1.0] + + @test cut(x, p, labels=my_formatter) == + ["1: 0.0 -- 0.4", "1: 0.0 -- 0.4", "2: 0.4 -- 0.8", "2: 0.4 -- 0.8", "3: 0.8 -- 1.0"] +end + end