Skip to content

Commit

Permalink
Use labels kw, group index, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
greimel committed Sep 15, 2019
1 parent bc50544 commit d1a3461
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 12 deletions.
31 changes: 19 additions & 12 deletions src/extras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ function fill_refs!(refs::AbstractArray, X::AbstractArray{>: Missing},
end
end

"""
_default_formatter_
Provide the default label format for the `cut` function.
"""
_default_formatter_(from, to, i; extend=false) = string("[", from, ", ", to, extend ? "]" : ")")

"""
cut(x::AbstractArray, breaks::AbstractVector;
extend::Bool=false, labels::AbstractVector=[], allow_missing::Bool=false)
Expand All @@ -56,16 +63,16 @@ also accept them.
* `extend::Bool=false`: when `false`, an error is raised if some values in `x` fall
outside of the breaks; when `true`, breaks are automatically added to include all
values in `x`, and the upper bound is included in the last interval.
* `labels::AbstractVector=[]`: a vector of strings giving the names to use for the
intervals; if empty, default labels are used.
* `label_formatter::Function`: a function `f(from,to;extend=false)` that generates the labels from the left and right interval boundaries. Defaults to `string("[", from, ", ", to, extend ? "]" : ")")`, e.g. `"[1, 5)"`.
* `labels::Union{AbstractVector,Function}=_default_formatter_`: a vector of strings giving the names to use for the
intervals; or a function `f(from,to,i;extend=false)` that generates the labels from the left and right interval boundaries and the group index. Defaults to `string("[", from, ", ", to, extend ? "]" : ")")`, e.g. `"[1, 5)"`.
* `allow_missing::Bool=true`: when `true`, values outside of breaks result in missing values.
only supported when `x` accepts missing values.
"""
function cut(x::AbstractArray{T, N}, breaks::AbstractVector;
extend::Bool=false, labels::AbstractVector{U}=String[],
label_formatter=_default_formatter_,
extend::Bool=false, labels=_default_formatter_,
allow_missing::Bool=false) where {T, N, U<:AbstractString}
(labels isa AbstractVector) || (labels isa Function) || throw(ArgumentError("labels must be a formatter function or an AbstractVector"))

if !issorted(breaks)
breaks = sort(breaks)
end
Expand Down Expand Up @@ -94,7 +101,7 @@ function cut(x::AbstractArray{T, N}, breaks::AbstractVector;
end

n = length(breaks)
if isempty(labels)
if labels isa Function
@static if VERSION >= v"0.7.0-DEV.4524"
from = map(x -> sprint(show, x, context=:compact=>true), breaks[1:n-1])
to = map(x -> sprint(show, x, context=:compact=>true), breaks[2:n])
Expand All @@ -104,12 +111,12 @@ function cut(x::AbstractArray{T, N}, breaks::AbstractVector;
end
levs = Vector{String}(undef, n-1)
for i in 1:n-2
levs[i] = label_formatter(from[i], to[i])
levs[i] = labels(from[i], to[i], i)
end
if extend
levs[end] = label_formatter(from[end], to[end], extend=extend)
levs[end] = labels(from[end], to[end], n-1, extend=extend)
else
levs[end] = label_formatter(from[end], to[end])
levs[end] = labels(from[end], to[end], n-1)
end
else
length(labels) == n-1 || throw(ArgumentError("labels must be of length $(n-1), but got length $(length(labels))"))
Expand All @@ -130,8 +137,8 @@ Cut a numeric array into `ngroups` quantiles, determined using
[`quantile`](@ref).
"""
cut(x::AbstractArray, ngroups::Integer;
labels::AbstractVector{U}=String[], label_formatter=_default_formatter_) where {U<:AbstractString} =
cut(x, Statistics.quantile(x, (1:ngroups-1)/ngroups); extend=true, labels=labels, label_formatter=label_formatter)
labels=_default_formatter_) =
cut(x, Statistics.quantile(x, (1:ngroups-1)/ngroups); extend=true, labels=labels)


_default_formatter_(from, to; extend=false) = string("[", from, ", ", to, extend ? "]" : ")")

16 changes: 16 additions & 0 deletions test/15_extras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,20 @@ end
@test levels(x) == ["[2.0, 3.5)", "[3.5, 5.0]"]
end

@testset "formatter function" begin
my_formatter1(from,to,i;extend=false) = "group $i"
my_formatter2(from,to,i;extend=false) = "$i: $from -- $to"
function my_formatter3(from,to,i;extend=true)
percentile(x) = Int(round(100 * parse.(Float64,x),digits=0))
string("P",percentile(from),"P",percentile(to))
end

x = collect(0.15:0.20:0.95)
p = [0, 0.4, 0.8, 1.0]

@test cut(x, p, labels=my_formatter1) == ["group 1", "group 1", "group 2", "group 2", "group 3"]
@test cut(x, p, labels=my_formatter2) == ["1: 0.0 -- 0.4", "1: 0.0 -- 0.4", "2: 0.4 -- 0.8", "2: 0.4 -- 0.8", "3: 0.8 -- 1.0"]
@test cut(x, p, labels=my_formatter3) == ["P0P40" , "P0P40" , "P40P80" , "P40P80" , "P80P100"]
end

end

0 comments on commit d1a3461

Please sign in to comment.