Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide formatter for labeling categories in cut function #202

Merged
merged 10 commits into from
Sep 20, 2019
94 changes: 80 additions & 14 deletions src/extras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,16 @@ function fill_refs!(refs::AbstractArray, X::AbstractArray{>: Missing},
end

"""
default_formatter(from, to, i; closed=false)

Provide the default label format for the `cut` function.
"""
function default_formatter(from, to, i; closed=false)
    # The interval index `i` is part of the formatter interface but does not
    # appear in the default label.
    closing = closed ? "]" : ")"
    return string("[", from, ", ", to, closing)
end

@doc raw"""
cut(x::AbstractArray, breaks::AbstractVector;
extend::Bool=false, labels::AbstractVector=[], allow_missing::Bool=false)
labels::Union{AbstractVector{<:AbstractString},Function},
extend::Bool=false, allow_missing::Bool=false)

Cut a numeric array into intervals and return an ordered `CategoricalArray` indicating
the interval into which each entry falls. Intervals are of the form `[lower, upper)`,
Expand All @@ -56,14 +64,76 @@ also accept them.
* `extend::Bool=false`: when `false`, an error is raised if some values in `x` fall
outside of the breaks; when `true`, breaks are automatically added to include all
values in `x`, and the upper bound is included in the last interval.
* `labels::AbstractVector=[]`: a vector of strings giving the names to use for the
intervals; if empty, default labels are used.
* `labels::Union{AbstractVector,Function}`: a vector of strings giving the names to use for
the intervals; or a function `f(from, to, i; closed)` that generates the labels from the
greimel marked this conversation as resolved.
Show resolved Hide resolved
left and right interval boundaries and the group index. Defaults to
`string("[", from, ", ", to, closed ? "]" : ")")`, e.g. `"[1, 5)"`.
* `allow_missing::Bool=false`: when `true`, values outside of breaks result in missing values.
Only supported when `x` accepts missing values.

# Examples
```jldoctest
julia> using CategoricalArrays
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this actually needed? There's probably a way to avoid it as it takes a lot of space.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please clarify what you mean by "this"?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean this `using` line.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks - I dropped it. Let's see what CI says


julia> x = collect(-1:0.5:1);
greimel marked this conversation as resolved.
Show resolved Hide resolved

julia> cut(x, [0, 1], extend=true)
5-element CategoricalArray{String,1,UInt32}:
"[-1.0, 0.0)"
"[-1.0, 0.0)"
"[0.0, 1.0]"
"[0.0, 1.0]"
"[0.0, 1.0]"

julia> cut(x, 2)
5-element CategoricalArray{String,1,UInt32}:
"[-1.0, 0.0)"
"[-1.0, 0.0)"
"[0.0, 1.0]"
"[0.0, 1.0]"
"[0.0, 1.0]"

julia> cut(x, 2, labels=["A", "B"])
5-element CategoricalArray{String,1,UInt32}:
"A"
"A"
"B"
"B"
"B"

julia> fmt(from, to, i; closed) = "grp $i ($from//$to)"
fmt (generic function with 1 method)

julia> cut(x, 3, labels=fmt)
5-element CategoricalArray{String,1,UInt32}:
"grp 1 (-1.0//-0.333333)"
"grp 1 (-1.0//-0.333333)"
"grp 2 (-0.333333//0.333333)"
"grp 3 (0.333333//1.0)"
"grp 3 (0.333333//1.0)"

julia> using StatsBase: ecdf
greimel marked this conversation as resolved.
Show resolved Hide resolved

julia> percentile(x) = round(Int,100*parse(Float64,x))
percentile (generic function with 1 method)

julia> fmt2(from, to, i; closed) = "P$(percentile(from))P$(percentile(to))"
fmt2 (generic function with 1 method)

julia> cut(ecdf(x)(x), 3, labels=fmt2)
5-element CategoricalArray{String,1,UInt32}:
"P20P47"
"P20P47"
"P47P73"
"P73P100"
"P73P100"
```
"""
function cut(x::AbstractArray{T, N}, breaks::AbstractVector;
extend::Bool=false, labels::AbstractVector{U}=String[],
allow_missing::Bool=false) where {T, N, U<:AbstractString}
extend::Bool=false,
labels::Union{AbstractVector{<:AbstractString},Function}=default_formatter,
allow_missing::Bool=false) where {T, N}

greimel marked this conversation as resolved.
Show resolved Hide resolved
if !issorted(breaks)
breaks = sort(breaks)
end
Expand Down Expand Up @@ -92,7 +162,7 @@ function cut(x::AbstractArray{T, N}, breaks::AbstractVector;
end

n = length(breaks)
if isempty(labels)
if labels isa Function
@static if VERSION >= v"0.7.0-DEV.4524"
from = map(x -> sprint(show, x, context=:compact=>true), breaks[1:n-1])
to = map(x -> sprint(show, x, context=:compact=>true), breaks[2:n])
Expand All @@ -102,13 +172,9 @@ function cut(x::AbstractArray{T, N}, breaks::AbstractVector;
end
levs = Vector{String}(undef, n-1)
for i in 1:n-2
levs[i] = string("[", from[i], ", ", to[i], ")")
end
if extend
levs[end] = string("[", from[end], ", ", to[end], "]")
else
levs[end] = string("[", from[end], ", ", to[end], ")")
levs[i] = labels(from[i], to[i], i, closed=false)
end
levs[end] = labels(from[end], to[end], n-1, closed=extend)
else
length(labels) == n-1 || throw(ArgumentError("labels must be of length $(n-1), but got length $(length(labels))"))
# Levels must have element type String for type stability of the result
Expand All @@ -122,11 +188,11 @@ end

"""
cut(x::AbstractArray, ngroups::Integer;
labels::AbstractVector=String[])
labels::Union{AbstractVector{<:AbstractString},Function})

Cut a numeric array into `ngroups` quantiles, determined using
[`quantile`](@ref).
"""
cut(x::AbstractArray, ngroups::Integer;
labels::AbstractVector{U}=String[]) where {U<:AbstractString} =
labels::Union{AbstractVector{<:AbstractString},Function}=default_formatter) =
cut(x, Statistics.quantile(x, (1:ngroups-1)/ngroups); extend=true, labels=labels)
9 changes: 9 additions & 0 deletions test/15_extras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,13 @@ end
@test levels(x) == ["[2.0, 3.5)", "[3.5, 5.0]"]
end

@testset "formatter function" begin
    # `cut` invokes a custom formatter as `labels(from, to, i, closed=...)`,
    # so the formatter must accept the `closed` keyword even when it ignores it.
    my_formatter(from, to, i; closed) = "$i: $from -- $to"

    x = collect(0.15:0.20:0.95)
    p = [0, 0.4, 0.8, 1.0]

    # `from` and `to` are passed to the formatter as compact string
    # representations of the breaks, hence "0.0" rather than "0".
    @test cut(x, p, labels=my_formatter) == ["1: 0.0 -- 0.4", "1: 0.0 -- 0.4", "2: 0.4 -- 0.8", "2: 0.4 -- 0.8", "3: 0.8 -- 1.0"]
end

end