Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide formatter for labeling categories in cut function #202

Merged
merged 10 commits into from
Sep 20, 2019
73 changes: 59 additions & 14 deletions src/extras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,16 @@ function fill_refs!(refs::AbstractArray, X::AbstractArray{>: Missing},
end

"""
default_formatter(from, to, i; closed=false)

Provide the default label format for the `cut` function.
"""
default_formatter(from, to, i; closed) = string("[", from, ", ", to, closed ? "]" : ")")

@doc raw"""
cut(x::AbstractArray, breaks::AbstractVector;
extend::Bool=false, labels::AbstractVector=[], allow_missing::Bool=false)
labels::Union{AbstractVector{<:AbstractString},Function},
extend::Bool=false, allow_missing::Bool=false)

Cut a numeric array into intervals and return an ordered `CategoricalArray` indicating
the interval into which each entry falls. Intervals are of the form `[lower, upper)`,
Expand All @@ -56,14 +64,55 @@ also accept them.
* `extend::Bool=false`: when `false`, an error is raised if some values in `x` fall
outside of the breaks; when `true`, breaks are automatically added to include all
values in `x`, and the upper bound is included in the last interval.
* `labels::AbstractVector=[]`: a vector of strings giving the names to use for the
intervals; if empty, default labels are used.
* `labels::Union{AbstractVector,Function}`: a vector of strings giving the names to use for
the intervals, or a function `f(from, to, i; closed)` that generates the labels from the
left and right interval boundaries and the group index. Defaults to
`"[from, to)"` (or `"[from, to]"` for the rightmost interval if `extend == true`).
* `allow_missing::Bool=false`: when `true`, values outside of breaks result in missing values.
Only supported when `x` accepts missing values.

# Examples
```jldoctest
julia> cut(-1:0.5:1, [0, 1], extend=true)
5-element CategoricalArray{String,1,UInt32}:
"[-1.0, 0.0)"
"[-1.0, 0.0)"
"[0.0, 1.0]"
"[0.0, 1.0]"
"[0.0, 1.0]"

julia> cut(-1:0.5:1, 2)
5-element CategoricalArray{String,1,UInt32}:
"[-1.0, 0.0)"
"[-1.0, 0.0)"
"[0.0, 1.0]"
"[0.0, 1.0]"
"[0.0, 1.0]"

julia> cut(-1:0.5:1, 2, labels=["A", "B"])
5-element CategoricalArray{String,1,UInt32}:
"A"
"A"
"B"
"B"
"B"

julia> fmt(from, to, i; closed) = "grp $i ($from//$to)"
fmt (generic function with 1 method)

julia> cut(-1:0.5:1, 3, labels=fmt)
5-element CategoricalArray{String,1,UInt32}:
"grp 1 (-1.0//-0.333333)"
"grp 1 (-1.0//-0.333333)"
"grp 2 (-0.333333//0.333333)"
"grp 3 (0.333333//1.0)"
"grp 3 (0.333333//1.0)"
```
"""
function cut(x::AbstractArray{T, N}, breaks::AbstractVector;
extend::Bool=false, labels::AbstractVector{U}=String[],
allow_missing::Bool=false) where {T, N, U<:AbstractString}
extend::Bool=false,
labels::Union{AbstractVector{<:AbstractString},Function}=default_formatter,
allow_missing::Bool=false) where {T, N}
if !issorted(breaks)
breaks = sort(breaks)
end
Expand Down Expand Up @@ -92,7 +141,7 @@ function cut(x::AbstractArray{T, N}, breaks::AbstractVector;
end

n = length(breaks)
if isempty(labels)
if labels isa Function
@static if VERSION >= v"0.7.0-DEV.4524"
from = map(x -> sprint(show, x, context=:compact=>true), breaks[1:n-1])
to = map(x -> sprint(show, x, context=:compact=>true), breaks[2:n])
Expand All @@ -102,13 +151,9 @@ function cut(x::AbstractArray{T, N}, breaks::AbstractVector;
end
levs = Vector{String}(undef, n-1)
for i in 1:n-2
levs[i] = string("[", from[i], ", ", to[i], ")")
end
if extend
levs[end] = string("[", from[end], ", ", to[end], "]")
else
levs[end] = string("[", from[end], ", ", to[end], ")")
levs[i] = labels(from[i], to[i], i, closed=false)
end
levs[end] = labels(from[end], to[end], n-1, closed=extend)
else
length(labels) == n-1 || throw(ArgumentError("labels must be of length $(n-1), but got length $(length(labels))"))
# Levels must have element type String for type stability of the result
Expand All @@ -122,11 +167,11 @@ end

"""
cut(x::AbstractArray, ngroups::Integer;
labels::AbstractVector=String[])
labels::Union{AbstractVector{<:AbstractString},Function})

Cut a numeric array into `ngroups` quantiles, determined using
[`quantile`](@ref).
"""
cut(x::AbstractArray, ngroups::Integer;
labels::AbstractVector{U}=String[]) where {U<:AbstractString} =
labels::Union{AbstractVector{<:AbstractString},Function}=default_formatter) =
cut(x, Statistics.quantile(x, (1:ngroups-1)/ngroups); extend=true, labels=labels)
10 changes: 10 additions & 0 deletions test/15_extras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,14 @@ end
@test levels(x) == ["[2.0, 3.5)", "[3.5, 5.0]"]
end

@testset "cut with formatter function" begin
    # Custom formatter: must accept the `closed` keyword even though it ignores it,
    # because `cut` always passes it when calling the labelling function.
    index_range_label(from, to, i; closed) = string(i, ": ", from, " -- ", to)

    values = 0.15:0.20:0.95
    edges = [0, 0.4, 0.8, 1.0]
    expected = ["1: 0.0 -- 0.4", "1: 0.0 -- 0.4",
                "2: 0.4 -- 0.8", "2: 0.4 -- 0.8",
                "3: 0.8 -- 1.0"]

    @test cut(values, edges, labels=index_range_label) == expected
end

end