diff --git a/docs/src/lib/functions.md b/docs/src/lib/functions.md index 69944899e2..11fc5cf63d 100644 --- a/docs/src/lib/functions.md +++ b/docs/src/lib/functions.md @@ -16,6 +16,8 @@ by colwise combine groupby +groupindices +groupvars join map melt diff --git a/src/DataFrames.jl b/src/DataFrames.jl index d77a97efe7..07ff57b6f1 100644 --- a/src/DataFrames.jl +++ b/src/DataFrames.jl @@ -38,6 +38,8 @@ export AbstractDataFrame, dropmissing!, eltypes, groupby, + groupindices, + groupvars, insertcols!, mapcols, melt, diff --git a/src/groupeddataframe/grouping.jl b/src/groupeddataframe/grouping.jl index 77c9cda94c..7f3ec15fb8 100644 --- a/src/groupeddataframe/grouping.jl +++ b/src/groupeddataframe/grouping.jl @@ -104,8 +104,22 @@ Base.last(gd::GroupedDataFrame) = gd[end] Base.getindex(gd::GroupedDataFrame, idx::Integer) = view(gd.parent, gd.idx[gd.starts[idx]:gd.ends[idx]], :) -Base.getindex(gd::GroupedDataFrame, idxs::AbstractArray) = - GroupedDataFrame(gd.parent, gd.cols, gd.groups, gd.idx, gd.starts[idxs], gd.ends[idxs]) + +function Base.getindex(gd::GroupedDataFrame, idxs::AbstractArray) + new_starts = gd.starts[idxs] + new_ends = gd.ends[idxs] + if !allunique(new_starts) + throw(ArgumentError("duplicates in idxs argument are not allowed")) + end + new_groups = zeros(Int, length(gd.groups)) + for idx in eachindex(new_starts) + @inbounds for j in new_starts[idx]:new_ends[idx] + new_groups[gd.idx[j]] = idx + end + end + GroupedDataFrame(gd.parent, gd.cols, new_groups, gd.idx, new_starts, new_ends) +end + Base.getindex(gd::GroupedDataFrame, idxs::Colon) = GroupedDataFrame(gd.parent, gd.cols, gd.groups, gd.idx, gd.starts, gd.ends) @@ -1125,3 +1139,21 @@ function DataFrame(gd::GroupedDataFrame) resize!(idx, doff - 1) parent(gd)[idx, :] end + +""" + groupindices(gd::GroupedDataFrame) + +Return a vector of group indices for each row of `parent(gd)`. + +Rows appearing in group `gd[i]` are attributed index `i`. 
Rows not present in +any group are attributed `missing` (this can happen if `skipmissing=true` was +passed when creating `gd`, or if `gd` is a subset from a larger `GroupedDataFrame`). +""" +groupindices(gd::GroupedDataFrame) = replace(gd.groups, 0=>missing) + +""" + groupvars(gd::GroupedDataFrame) + +Return a vector of column names in `parent(gd)` used for grouping. +""" +groupvars(gd::GroupedDataFrame) = _names(gd)[gd.cols] diff --git a/test/grouping.jl b/test/grouping.jl index b775d5aa39..cdea249122 100644 --- a/test/grouping.jl +++ b/test/grouping.jl @@ -6,12 +6,16 @@ const ≅ = isequal function groupby_checked(df::AbstractDataFrame, keys, args...; kwargs...) gd = groupby(df, keys, args...; kwargs...) - for i in 1:length(gd) - # checking that groups field is consistent with other fields - # (since == and isequal do not use it) - # and that idx is increasing per group - @assert findall(==(i), gd.groups) == gd.idx[gd.starts[i]:gd.ends[i]] + # checking that groups field is consistent with other fields + # (since == and isequal do not use it) + # and that idx is increasing per group + new_groups = zeros(Int, length(gd.groups)) + for idx in eachindex(gd.starts) + subidx = gd.idx[gd.starts[idx]:gd.ends[idx]] + @assert issorted(subidx) + new_groups[subidx] .= idx end + @assert new_groups == gd.groups if length(gd) > 0 se = sort!(collect(zip(gd.starts, gd.ends))) @@ -24,15 +28,10 @@ function groupby_checked(df::AbstractDataFrame, keys, args...; kwargs...) 
for i in eachindex(se) @assert se[i][1] <= se[i][2] if i > 1 + # the blocks returned by groupby must be contiguous @assert se[i-1][2] + 1 == se[i][1] end end - - # correct coverage of missings if dropped - @assert findall(==(0), gd.groups) == gd.idx[1:se[1][1]-1] - else - # a case when missings are dropped and nothing was left to group by - @assert all(==(0), gd.groups) end gd @@ -878,11 +877,17 @@ end else @test_throws ArgumentError gd[true] end + @test_throws ArgumentError gd[[1, 2, 1]] @test_throws MethodError gd["a"] - gd2 = gd[[true, false, false, false]] + gd2 = gd[[false, true, false, false]] @test length(gd2) == 1 - @test gd2[1] == gd[1] + @test gd2[1] == gd[2] @test_throws BoundsError gd[[true, false]] + @test gd2.groups == [0, 1, 0, 0, 0, 1, 0, 0] + @test gd2.starts == [3] + @test gd2.ends == [4] + @test gd2.idx == gd.idx + gd3 = gd[:] @test gd3 isa GroupedDataFrame @test length(gd3) == 4 @@ -890,13 +895,17 @@ end for i in 1:4 @test gd3[i] == gd[i] end - gd4 = gd[[1,2]] + gd4 = gd[[2,1]] @test gd4 isa GroupedDataFrame @test length(gd4) == 2 for i in 1:2 - @test gd4[i] == gd[i] + @test gd4[i] == gd[3-i] end @test_throws BoundsError gd[1:5] + @test gd4.groups == [2, 1, 0, 0, 2, 1, 0, 0] + @test gd4.starts == [3,1] + @test gd4.ends == [4,2] + @test gd4.idx == gd.idx end @testset "== and isequal" begin @@ -1074,10 +1083,16 @@ end @test sort(DataFrame(gd), :B) ≅ sort(df, :B) @test eltypes(DataFrame(gd)) == [Union{Missing, Symbol}, Int] + gd2 = gd[[3,2]] + @test DataFrame(gd2) == df[[3,5,2,4], :] + gd = groupby_checked(df, :A, skipmissing=true) @test sort(DataFrame(gd), :B) == sort(dropmissing(df, disallowmissing=false), :B) @test eltypes(DataFrame(gd)) == [Union{Missing, Symbol}, Int] + + gd2 = gd[[2,1]] + @test DataFrame(gd2) == df[[3,5,2,4], :] end df = DataFrame(a=Int[], b=[], c=Union{Missing, String}[]) @@ -1091,4 +1106,25 @@ end @test eltypes(DataFrame(gd)) == [Union{Missing, Symbol}, Int] end +@testset "groupindices and groupvars" begin + df = 
DataFrame(A = [missing, :A, :B, :A, :B, missing], B = 1:6) + gd = groupby_checked(df, :A) + @inferred groupindices(gd) + @test groupindices(gd) == [1, 2, 3, 2, 3, 1] + @test groupvars(gd) == [:A] + gd2 = gd[[3,2]] + @inferred groupindices(gd2) + @test groupindices(gd2) ≅ [missing, 2, 1, 2, 1, missing] + @test groupvars(gd2) == [:A] + + gd = groupby_checked(df, :A, skipmissing=true) + @inferred groupindices(gd) + @test groupindices(gd) ≅ [missing, 1, 2, 1, 2, missing] + @test groupvars(gd) == [:A] + gd2 = gd[[2,1]] + @inferred groupindices(gd2) + @test groupindices(gd2) ≅ [missing, 2, 1, 2, 1, missing] + @test groupvars(gd2) == [:A] +end + end # module