diff --git a/src/extras.jl b/src/extras.jl index e5d66d5b..5381039d 100644 --- a/src/extras.jl +++ b/src/extras.jl @@ -41,6 +41,13 @@ function fill_refs!(refs::AbstractArray, X::AbstractArray{>: Missing}, end end +""" + _default_formatter_ + +Provide the default label format for the `cut` function. +""" +_default_formatter_(from, to, i; extend=false) = string("[", from, ", ", to, extend ? "]" : ")") + """ cut(x::AbstractArray, breaks::AbstractVector; extend::Bool=false, labels::AbstractVector=[], allow_missing::Bool=false) @@ -56,16 +63,16 @@ also accept them. * `extend::Bool=false`: when `false`, an error is raised if some values in `x` fall outside of the breaks; when `true`, breaks are automatically added to include all values in `x`, and the upper bound is included in the last interval. -* `labels::AbstractVector=[]`: a vector of strings giving the names to use for the - intervals; if empty, default labels are used. -* `label_formatter::Function`: a function `f(from,to;extend=false)` that generates the labels from the left and right interval boundaries. Defaults to `string("[", from, ", ", to, extend ? "]" : ")")`, e.g. `"[1, 5)"`. +* `labels::Union{AbstractVector,Function}=_default_formatter_`: a vector of strings giving the names to use for the + intervals; or a function `f(from,to,i;extend=false)` that generates the labels from the left and right interval boundaries and the group index. Defaults to `string("[", from, ", ", to, extend ? "]" : ")")`, e.g. `"[1, 5)"`. * `allow_missing::Bool=true`: when `true`, values outside of breaks result in missing values. only supported when `x` accepts missing values. """ function cut(x::AbstractArray{T, N}, breaks::AbstractVector; - extend::Bool=false, labels::AbstractVector{U}=String[], - label_formatter=_default_formatter_, + extend::Bool=false, labels=_default_formatter_, allow_missing::Bool=false) where {T, N, U<:AbstractString} + (labels isa AbstractVector) || (labels isa Function) || throw(ArgumentError("labels must be a formatter function or an AbstractVector")) + if !issorted(breaks) breaks = sort(breaks) end @@ -94,7 +101,7 @@ function cut(x::AbstractArray{T, N}, breaks::AbstractVector; end n = length(breaks) - if isempty(labels) + if labels isa Function @static if VERSION >= v"0.7.0-DEV.4524" from = map(x -> sprint(show, x, context=:compact=>true), breaks[1:n-1]) to = map(x -> sprint(show, x, context=:compact=>true), breaks[2:n]) @@ -104,12 +111,12 @@ function cut(x::AbstractArray{T, N}, breaks::AbstractVector; end levs = Vector{String}(undef, n-1) for i in 1:n-2 - levs[i] = label_formatter(from[i], to[i]) + levs[i] = labels(from[i], to[i], i) end if extend - levs[end] = label_formatter(from[end], to[end], extend=extend) + levs[end] = labels(from[end], to[end], n-1, extend=extend) else - levs[end] = label_formatter(from[end], to[end]) + levs[end] = labels(from[end], to[end], n-1) end else length(labels) == n-1 || throw(ArgumentError("labels must be of length $(n-1), but got length $(length(labels))")) @@ -130,8 +137,8 @@ Cut a numeric array into `ngroups` quantiles, determined using [`quantile`](@ref). """ cut(x::AbstractArray, ngroups::Integer; - labels::AbstractVector{U}=String[], label_formatter=_default_formatter_) where {U<:AbstractString} = - cut(x, Statistics.quantile(x, (1:ngroups-1)/ngroups); extend=true, labels=labels, label_formatter=label_formatter) + labels=_default_formatter_) = + cut(x, Statistics.quantile(x, (1:ngroups-1)/ngroups); extend=true, labels=labels) + -_default_formatter_(from, to; extend=false) = string("[", from, ", ", to, extend ? "]" : ")") diff --git a/test/15_extras.jl b/test/15_extras.jl index 9e8834da..e83a194f 100644 --- a/test/15_extras.jl +++ b/test/15_extras.jl @@ -108,4 +108,20 @@ end @test levels(x) == ["[2.0, 3.5)", "[3.5, 5.0]"] end +@testset "formatter function" begin + my_formatter1(from,to,i;extend=false) = "group $i" + my_formatter2(from,to,i;extend=false) = "$i: $from -- $to" + function my_formatter3(from,to,i;extend=true) + percentile(x) = Int(round(100 * parse.(Float64,x),digits=0)) + string("P",percentile(from),"P",percentile(to)) + end + + x = collect(0.15:0.20:0.95) + p = [0, 0.4, 0.8, 1.0] + + @test cut(x, p, labels=my_formatter1) == ["group 1", "group 1", "group 2", "group 2", "group 3"] + @test cut(x, p, labels=my_formatter2) == ["1: 0.0 -- 0.4", "1: 0.0 -- 0.4", "2: 0.4 -- 0.8", "2: 0.4 -- 0.8", "3: 0.8 -- 1.0"] + @test cut(x, p, labels=my_formatter3) == ["P0P40" , "P0P40" , "P40P80" , "P40P80" , "P80P100"] +end + end