Skip to content

Commit

Permalink
GroupedDataFrame: fix getindex and add groupvars and groupindices
Browse files Browse the repository at this point in the history
  • Loading branch information
bkamins committed Feb 5, 2019
1 parent 8afc9e7 commit 359bf31
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 17 deletions.
2 changes: 2 additions & 0 deletions docs/src/lib/functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ by
colwise
combine
groupby
groupindices
groupvars
join
map
melt
Expand Down
2 changes: 2 additions & 0 deletions src/DataFrames.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ export AbstractDataFrame,
dropmissing!,
eltypes,
groupby,
groupindices,
groupvars,
insertcols!,
mapcols,
melt,
Expand Down
36 changes: 34 additions & 2 deletions src/groupeddataframe/grouping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,22 @@ Base.last(gd::GroupedDataFrame) = gd[end]

Base.getindex(gd::GroupedDataFrame, idx::Integer) =
view(gd.parent, gd.idx[gd.starts[idx]:gd.ends[idx]], :)
Base.getindex(gd::GroupedDataFrame, idxs::AbstractArray) =
GroupedDataFrame(gd.parent, gd.cols, gd.groups, gd.idx, gd.starts[idxs], gd.ends[idxs])

function Base.getindex(gd::GroupedDataFrame, idxs::AbstractArray)
new_starts = gd.starts[idxs]
new_ends = gd.ends[idxs]
if !allunique(new_starts)
throw(ArgumentError("duplicates in idxs argument are not allowed"))
end
new_groups = zeros(Int, length(gd.groups))
for idx in eachindex(new_starts)
@inbounds for j in new_starts[idx]:new_ends[idx]
new_groups[gd.idx[j]] = idx
end
end
GroupedDataFrame(gd.parent, gd.cols, new_groups, gd.idx, new_starts, new_ends)
end

Base.getindex(gd::GroupedDataFrame, idxs::Colon) =
GroupedDataFrame(gd.parent, gd.cols, gd.groups, gd.idx, gd.starts, gd.ends)

Expand Down Expand Up @@ -1125,3 +1139,21 @@ function DataFrame(gd::GroupedDataFrame)
resize!(idx, doff - 1)
parent(gd)[idx, :]
end

"""
groupindices(gd::GroupedDataFrame)
Return a vector of group indices for each row of `parent(gd)`.
Rows appearing in group `gd[i]` are attributed index `i`. Rows not present in
any group are attributed `missing` (this can happen if `skipmissing=true` was
passed when creating `gd`, or if `gd` is a subset from a larger `GroupedDataFrame`).
"""
groupindices(gd::GroupedDataFrame) = replace(gd.groups, 0=>missing)

"""
groupvars(gd::GroupedDataFrame)
Return a vector of column names in `parent(gd)` used for grouping.
"""
groupvars(gd::GroupedDataFrame) = _names(gd)[gd.cols]
66 changes: 51 additions & 15 deletions test/grouping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@ const ≅ = isequal
function groupby_checked(df::AbstractDataFrame, keys, args...; kwargs...)
gd = groupby(df, keys, args...; kwargs...)

for i in 1:length(gd)
# checking that groups field is consistent with other fields
# (since == and isequal do not use it)
# and that idx is increasing per group
@assert findall(==(i), gd.groups) == gd.idx[gd.starts[i]:gd.ends[i]]
# checking that groups field is consistent with other fields
# (since == and isequal do not use it)
# and that idx is increasing per group
new_groups = zeros(Int, length(gd.groups))
for idx in eachindex(gd.starts)
subidx = gd.idx[gd.starts[idx]:gd.ends[idx]]
@assert issorted(subidx)
new_groups[subidx] .= idx
end
@assert new_groups == gd.groups

if length(gd) > 0
se = sort!(collect(zip(gd.starts, gd.ends)))
Expand All @@ -24,15 +28,10 @@ function groupby_checked(df::AbstractDataFrame, keys, args...; kwargs...)
for i in eachindex(se)
@assert se[i][1] <= se[i][2]
if i > 1
# the blocks returned by groupby must be continuous
@assert se[i-1][2] + 1 == se[i][1]
end
end

# correct coverage of missings if dropped
@assert findall(==(0), gd.groups) == gd.idx[1:se[1][1]-1]
else
# a case when missings are dropped and nothing was left to group by
@assert all(==(0), gd.groups)
end

gd
Expand Down Expand Up @@ -878,25 +877,35 @@ end
else
@test_throws ArgumentError gd[true]
end
@test_throws ArgumentError gd[[1, 2, 1]]
@test_throws MethodError gd["a"]
gd2 = gd[[true, false, false, false]]
gd2 = gd[[false, true, false, false]]
@test length(gd2) == 1
@test gd2[1] == gd[1]
@test gd2[1] == gd[2]
@test_throws BoundsError gd[[true, false]]
@test gd2.groups == [0, 1, 0, 0, 0, 1, 0, 0]
@test gd2.starts == [3]
@test gd2.ends == [4]
@test gd2.idx == gd.idx

gd3 = gd[:]
@test gd3 isa GroupedDataFrame
@test length(gd3) == 4
@test gd3 == gd
for i in 1:4
@test gd3[i] == gd[i]
end
gd4 = gd[[1,2]]
gd4 = gd[[2,1]]
@test gd4 isa GroupedDataFrame
@test length(gd4) == 2
for i in 1:2
@test gd4[i] == gd[i]
@test gd4[i] == gd[3-i]
end
@test_throws BoundsError gd[1:5]
@test gd4.groups == [2, 1, 0, 0, 2, 1, 0, 0]
@test gd4.starts == [3,1]
@test gd4.ends == [4,2]
@test gd4.idx == gd.idx
end

@testset "== and isequal" begin
Expand Down Expand Up @@ -1074,10 +1083,16 @@ end
@test sort(DataFrame(gd), :B) sort(df, :B)
@test eltypes(DataFrame(gd)) == [Union{Missing, Symbol}, Int]

gd2 = gd[[3,2]]
@test DataFrame(gd2) == df[[3,5,2,4], :]

gd = groupby_checked(df, :A, skipmissing=true)
@test sort(DataFrame(gd), :B) ==
sort(dropmissing(df, disallowmissing=false), :B)
@test eltypes(DataFrame(gd)) == [Union{Missing, Symbol}, Int]

gd2 = gd[[2,1]]
@test DataFrame(gd2) == df[[3,5,2,4], :]
end

df = DataFrame(a=Int[], b=[], c=Union{Missing, String}[])
Expand All @@ -1091,4 +1106,25 @@ end
@test eltypes(DataFrame(gd)) == [Union{Missing, Symbol}, Int]
end

@testset "groupindices and groupvars" begin
df = DataFrame(A = [missing, :A, :B, :A, :B, missing], B = 1:6)
gd = groupby_checked(df, :A)
@inferred groupindices(gd)
@test groupindices(gd) == [1, 2, 3, 2, 3, 1]
@test groupvars(gd) == [:A]
gd2 = gd[[3,2]]
@inferred groupindices(gd2)
@test groupindices(gd2) [missing, 2, 1, 2, 1, missing]
@test groupvars(gd2) == [:A]

gd = groupby_checked(df, :A, skipmissing=true)
@inferred groupindices(gd)
@test groupindices(gd) [missing, 1, 2, 1, 2, missing]
@test groupvars(gd) == [:A]
gd2 = gd[[2,1]]
@inferred groupindices(gd2)
@test groupindices(gd2) [missing, 2, 1, 2, 1, missing]
@test groupvars(gd2) == [:A]
end

end # module

0 comments on commit 359bf31

Please sign in to comment.