Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add filter and filter! to GroupedDataFrame #2279

Merged
merged 8 commits into from
Jun 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 26 additions & 26 deletions src/abstractdataframe/abstractdataframe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -887,20 +887,20 @@ function dropmissing!(df::AbstractDataFrame,
end

"""
filter(function, df::AbstractDataFrame)
filter(cols => function, df::AbstractDataFrame)
filter(fun, df::AbstractDataFrame)
filter(cols => fun, df::AbstractDataFrame)

Return a copy of data frame `df` containing only rows for which `function`
Return a copy of data frame `df` containing only rows for which `fun`
returns `true`.

If `cols` is not specified then the function is passed `DataFrameRow`s.
If `cols` is not specified then the predicate `fun` is passed `DataFrameRow`s.

If `cols` is specified then the function is passed elements of the corresponding
columns as separate positional arguments, unless `cols` is an `AsTable` selector,
in which case a `NamedTuple` of these arguments is passed.
`cols` can be any column selector ($COLUMNINDEX_STR; $MULTICOLUMNINDEX_STR),
and column duplicates are allowed if a vector of `Symbol`s, strings, or integers
is passed.
If `cols` is specified then the predicate `fun` is passed elements of the
corresponding columns as separate positional arguments, unless `cols` is an
`AsTable` selector, in which case a `NamedTuple` of these arguments is passed.
`cols` can be any column selector ($COLUMNINDEX_STR; $MULTICOLUMNINDEX_STR), and
column duplicates are allowed if a vector of `Symbol`s, strings, or integers is
passed.

Passing `cols` leads to a more efficient execution of the operation for large data frames.

Expand Down Expand Up @@ -960,7 +960,6 @@ Base.filter((cols, f)::Pair{<:AbstractVector{Symbol}}, df::AbstractDataFrame) =
filter([index(df)[col] for col in cols] => f, df)
# Column names given as strings: translate each name to its integer position
# and forward to the integer-vector method.
function Base.filter((cols, f)::Pair{<:AbstractVector{<:AbstractString}},
                     df::AbstractDataFrame)
    positions = [index(df)[name] for name in cols]
    return filter(positions => f, df)
end

# Fallback for any other column selector: resolve it through the data frame's
# index and re-dispatch on the resolved selector.
function Base.filter((cols, f)::Pair, df::AbstractDataFrame)
    resolved = index(df)[cols]
    return filter(resolved => f, df)
end

Expand All @@ -977,29 +976,30 @@ function _filter_helper(df::AbstractDataFrame, f, cols...)
end

function Base.filter((cols, f)::Pair{<:AsTable}, df::AbstractDataFrame)
dff = select(df, cols.cols, copycols=false)
if ncol(dff) == 0
df_tmp = select(df, cols.cols, copycols=false)
if ncol(df_tmp) == 0
throw(ArgumentError("At least one column must be passed to filter on"))
end
return _filter_helper_astable(df, Tables.namedtupleiterator(dff), f)
return _filter_helper_astable(df, Tables.namedtupleiterator(df_tmp), f)
end

# Apply the predicate `f` to every `NamedTuple` produced by `nti` and keep the
# rows of `df` for which it returned `true`. The `::Bool` assertion makes a
# non-Boolean predicate result an error.
function _filter_helper_astable(df::AbstractDataFrame, nti::Tables.NamedTupleIterator, f)
    keep = (row -> f(row)::Bool).(nti)
    return df[keep, :]
end

"""
filter!(function, df::AbstractDataFrame)
filter!(cols => function, df::AbstractDataFrame)

Remove rows from data frame `df` for which `function` returns `false`.

If `cols` is not specified then the function is passed `DataFrameRow`s.
If `cols` is specified then the function is passed elements of the corresponding
columns as separate positional arguments, unless `cols` is an `AsTable` selector,
in which case a `NamedTuple` of these arguments is passed.
`cols` can be any column selector ($COLUMNINDEX_STR; $MULTICOLUMNINDEX_STR),
and column duplicates are allowed if a vector of `Symbol`s, strings, or integers
is passed.
filter!(fun, df::AbstractDataFrame)
filter!(cols => fun, df::AbstractDataFrame)

Remove rows from data frame `df` for which `fun` returns `false`.

If `cols` is not specified then the predicate `fun` is passed `DataFrameRow`s.

If `cols` is specified then the predicate `fun` is passed elements of the
corresponding columns as separate positional arguments, unless `cols` is an
`AsTable` selector, in which case a `NamedTuple` of these arguments is passed.
`cols` can be any column selector ($COLUMNINDEX_STR; $MULTICOLUMNINDEX_STR), and
column duplicates are allowed if a vector of `Symbol`s, strings, or integers is
passed.

Passing `cols` leads to a more efficient execution of the operation for large data frames.

Expand Down
99 changes: 99 additions & 0 deletions src/groupeddataframe/groupeddataframe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -568,3 +568,102 @@ function Base.get(gd::GroupedDataFrame, key::Union{Tuple, NamedTuple}, default)
return default
end
end

"""
    filter(fun, gdf::GroupedDataFrame)
    filter(cols => fun, gdf::GroupedDataFrame)

Return a new `GroupedDataFrame` containing only groups for which `fun`
returns `true`.

If `cols` is not specified then the predicate `fun` is called with a
`SubDataFrame` for each group.

If `cols` is specified then the predicate `fun` is called for each group with
views of the corresponding columns as separate positional arguments, unless
`cols` is an `AsTable` selector, in which case a `NamedTuple` of these arguments
is passed. `cols` can be any column selector ($COLUMNINDEX_STR;
$MULTICOLUMNINDEX_STR), and column duplicates are allowed if a vector of
`Symbol`s, strings, or integers is passed.

# Examples
```
julia> df = DataFrame(g=[1, 2], x=['a', 'b']);

julia> gd = groupby(df, :g)
GroupedDataFrame with 2 groups based on key: g
First Group (1 row): g = 1
│ Row │ g     │ x    │
│     │ Int64 │ Char │
├─────┼───────┼──────┤
│ 1   │ 1     │ 'a'  │
⋮
Last Group (1 row): g = 2
│ Row │ g     │ x    │
│     │ Int64 │ Char │
├─────┼───────┼──────┤
│ 1   │ 2     │ 'b'  │

julia> filter(x -> x.x[1] == 'a', gd)
GroupedDataFrame with 1 group based on key: g
First Group (1 row): g = 1
│ Row │ g     │ x    │
│     │ Int64 │ Char │
├─────┼───────┼──────┤
│ 1   │ 1     │ 'a'  │

julia> filter(:x => x -> x[1] == 'a', gd)
GroupedDataFrame with 1 group based on key: g
First Group (1 row): g = 1
│ Row │ g     │ x    │
│     │ Int64 │ Char │
├─────┼───────┼──────┤
│ 1   │ 1     │ 'a'  │

```
"""
function Base.filter(f, gdf::GroupedDataFrame)
    # Evaluate the predicate once per group; the `::Bool` assertion turns a
    # non-Boolean result into an error instead of silently indexing with it.
    keep = [f(group)::Bool for group in gdf]
    return gdf[keep]
end
# Single-column selector: hand the whole parent column to `_filter_helper`,
# which takes a per-group view of it before calling the predicate.
function Base.filter((col, f)::Pair{<:ColumnIndex}, gdf::GroupedDataFrame)
    idx, starts, ends = gdf.idx, gdf.starts, gdf.ends
    column = parent(gdf)[!, col]
    return _filter_helper(gdf, f, idx, starts, ends, column)
end
# Column names given as `Symbol`s: translate each name to its integer position
# in the parent data frame and forward to the integer-vector method.
function Base.filter((cols, f)::Pair{<:AbstractVector{Symbol}}, gdf::GroupedDataFrame)
    positions = [index(parent(gdf))[name] for name in cols]
    return filter(positions => f, gdf)
end
# Column names given as strings: translate each name to its integer position
# in the parent data frame and forward to the integer-vector method.
function Base.filter((cols, f)::Pair{<:AbstractVector{<:AbstractString}},
                     gdf::GroupedDataFrame)
    positions = [index(parent(gdf))[name] for name in cols]
    return filter(positions => f, gdf)
end
# Fallback for any other column selector: resolve it through the parent data
# frame's index and re-dispatch on the resolved selector.
function Base.filter((cols, f)::Pair, gdf::GroupedDataFrame)
    resolved = index(parent(gdf))[cols]
    return filter(resolved => f, gdf)
end
# Integer column positions: fetch each selected parent column (without copying,
# via `!`) and pass them as separate trailing arguments to `_filter_helper`.
function Base.filter((cols, f)::Pair{<:AbstractVector{Int}}, gdf::GroupedDataFrame)
    idx, starts, ends = gdf.idx, gdf.starts, gdf.ends
    df = parent(gdf)
    columns = (df[!, i] for i in cols)
    return _filter_helper(gdf, f, idx, starts, ends, columns...)
end

# Keep the groups of `gdf` for which `f` returns `true` when called with views
# of `cols` restricted to the rows of that group, one positional argument per
# column. Rows of group `g` are `idx[starts[g]:ends[g]]`.
#
# Throws `ArgumentError` when no columns are given; a non-`Bool` result from
# `f` fails the `::Bool` assertion.
function _filter_helper(gdf::GroupedDataFrame, f, idx::Vector{Int},
                        starts::Vector{Int}, ends::Vector{Int}, cols...)
    if isempty(cols)
        throw(ArgumentError("At least one column must be passed to filter on"))
    end

    # Views of every selected column limited to the rows of group `g`.
    function group_views(g::Integer)
        rows = idx[starts[g]:ends[g]]
        return map(col -> view(col, rows), cols)
    end

    keep = [f(group_views(g)...)::Bool for g in 1:length(gdf)]
    return gdf[keep]
end

# `AsTable` selector: pick the requested columns from the parent (no copy),
# convert them to a `NamedTuple` of columns, and let the helper call the
# predicate once per group with a `NamedTuple` of per-group views.
function Base.filter((cols, f)::Pair{<:AsTable}, gdf::GroupedDataFrame)
    selected = select(parent(gdf), cols.cols, copycols=false)
    ncol(selected) == 0 &&
        throw(ArgumentError("At least one column must be passed to filter on"))
    columns = Tables.columntable(selected)
    return _filter_helper_astable(gdf, columns, f, gdf.idx, gdf.starts, gdf.ends)
end

# Keep the groups of `gdf` for which `f` returns `true` when called with a
# `NamedTuple` whose fields are views of the columns in `nt` restricted to the
# rows of that group. Rows of group `g` are `idx[starts[g]:ends[g]]`.
# A non-`Bool` result from `f` fails the `::Bool` assertion.
function _filter_helper_astable(gdf::GroupedDataFrame, nt::NamedTuple, f,
                                idx::Vector{Int}, starts::Vector{Int}, ends::Vector{Int})
    # `map` over a `NamedTuple` keeps the field names, so the predicate can
    # address columns by name.
    function group_table(g::Integer)
        rows = idx[starts[g]:ends[g]]
        return map(col -> view(col, rows), nt)
    end

    keep = [f(group_table(g))::Bool for g in 1:length(gdf)]
    return gdf[keep]
end
39 changes: 39 additions & 0 deletions test/grouping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2338,4 +2338,43 @@ end
@test eltype(df2.a) === eltype(df2.b) === Union{UInt, Missing}
end

@testset "filter" begin
    # Every check runs on a plain data frame and on a view that exposes only
    # rows 1:8 of its parent and hides column `y`.
    plain = DataFrame(g1=[1, 3, 2, 1, 4, 1, 2, 5], x1=1:8,
                      g2=[1, 3, 2, 1, 4, 1, 2, 5], x2=1:8)
    viewed = view(DataFrame(g1=[1, 3, 2, 1, 4, 1, 2, 5, 4, 5], x1=1:10,
                            g2=[1, 3, 2, 1, 4, 1, 2, 5, 4, 5], x2=1:10, y=1:10),
                  1:8, Not(:y))

    for df in (plain, viewed)
        for gcols in (:g1, [:g1, :g2]), cutoff in (1, 0, 10)
            # The same "group has more than `cutoff` rows" predicate expressed
            # through every supported selector form.
            predicates = (sdf -> nrow(sdf) > cutoff,
                          1 => v -> length(v) > cutoff,
                          :x1 => v -> length(v) > cutoff,
                          "x1" => v -> length(v) > cutoff,
                          [1, 2] => (a, b) -> length(a) > cutoff,
                          [:x1, :x2] => (a, b) -> length(a) > cutoff,
                          ["x1", "x2"] => (a, b) -> length(a) > cutoff,
                          r"x" => (a, b) -> length(a) > cutoff,
                          AsTable(:x1) => nt -> length(nt.x1) > cutoff,
                          AsTable(r"x") => nt -> length(nt.x1) > cutoff)
            for predicate in predicates
                grouped = groupby_checked(df, gcols)
                kept = filter(predicate, grouped)
                if cutoff == 0
                    # Every group has at least one row, so nothing is dropped.
                    @test grouped == kept
                elseif cutoff == 1
                    # Only the groups keyed 1 and 2 have more than one row.
                    @test getindex.(keys(kept), 1) == 1:2
                elseif cutoff == 10
                    # No group has more than ten rows.
                    @test isempty(kept)
                end
            end
        end

        # A predicate returning a non-Bool value must fail the Bool assertion.
        for nonbool in (x -> 1,
                        r"x" => (x...) -> 1,
                        AsTable(r"x") => (x...) -> 1)
            @test_throws TypeError filter(nonbool, groupby_checked(df, :g1))
        end

        # Selectors that match no column are rejected.
        for nocols in (r"y" => (x...) -> true,
                       [] => (x...) -> true,
                       AsTable(r"y") => (x...) -> true,
                       AsTable([]) => (x...) -> true)
            @test_throws ArgumentError filter(nocols, groupby_checked(df, :g1))
        end
    end
end

end # module