diff --git a/src/abstractdataframe/abstractdataframe.jl b/src/abstractdataframe/abstractdataframe.jl index 96b8c97494..c8e87d13de 100644 --- a/src/abstractdataframe/abstractdataframe.jl +++ b/src/abstractdataframe/abstractdataframe.jl @@ -887,20 +887,20 @@ function dropmissing!(df::AbstractDataFrame, end """ - filter(function, df::AbstractDataFrame) - filter(cols => function, df::AbstractDataFrame) + filter(fun, df::AbstractDataFrame) + filter(cols => fun, df::AbstractDataFrame) -Return a copy of data frame `df` containing only rows for which `function` +Return a copy of data frame `df` containing only rows for which `fun` returns `true`. -If `cols` is not specified then the function is passed `DataFrameRow`s. +If `cols` is not specified then the predicate `fun` is passed `DataFrameRow`s. -If `cols` is specified then the function is passed elements of the corresponding -columns as separate positional arguments, unless `cols` is an `AsTable` selector, -in which case a `NamedTuple` of these arguments is passed. -`cols` can be any column selector ($COLUMNINDEX_STR; $MULTICOLUMNINDEX_STR), -and column duplicates are allowed if a vector of `Symbol`s, strings, or integers -is passed. +If `cols` is specified then the predicate `fun` is passed elements of the +corresponding columns as separate positional arguments, unless `cols` is an +`AsTable` selector, in which case a `NamedTuple` of these arguments is passed. +`cols` can be any column selector ($COLUMNINDEX_STR; $MULTICOLUMNINDEX_STR), and +column duplicates are allowed if a vector of `Symbol`s, strings, or integers is +passed. Passing `cols` leads to a more efficient execution of the operation for large data frames. @@ -960,7 +960,6 @@ Base.filter((cols, f)::Pair{<:AbstractVector{Symbol}}, df::AbstractDataFrame) = filter([index(df)[col] for col in cols] => f, df) Base.filter((cols, f)::Pair{<:AbstractVector{<:AbstractString}}, df::AbstractDataFrame) = filter([index(df)[col] for col in cols] => f, df) - Base.filter((cols, f)::Pair, df::AbstractDataFrame) = filter(index(df)[cols] => f, df) @@ -977,29 +976,30 @@ function _filter_helper(df::AbstractDataFrame, f, cols...) end function Base.filter((cols, f)::Pair{<:AsTable}, df::AbstractDataFrame) - dff = select(df, cols.cols, copycols=false) - if ncol(dff) == 0 + df_tmp = select(df, cols.cols, copycols=false) + if ncol(df_tmp) == 0 throw(ArgumentError("At least one column must be passed to filter on")) end - return _filter_helper_astable(df, Tables.namedtupleiterator(dff), f) + return _filter_helper_astable(df, Tables.namedtupleiterator(df_tmp), f) end _filter_helper_astable(df::AbstractDataFrame, nti::Tables.NamedTupleIterator, f) = df[(x -> f(x)::Bool).(nti), :] """ - filter!(function, df::AbstractDataFrame) - filter!(cols => function, df::AbstractDataFrame) - -Remove rows from data frame `df` for which `function` returns `false`. - -If `cols` is not specified then the function is passed `DataFrameRow`s. -If `cols` is specified then the function is passed elements of the corresponding -columns as separate positional arguments, unless `cols` is an `AsTable` selector, -in which case a `NamedTuple` of these arguments is passed. -`cols` can be any column selector ($COLUMNINDEX_STR; $MULTICOLUMNINDEX_STR), -and column duplicates are allowed if a vector of `Symbol`s, strings, or integers -is passed. + filter!(fun, df::AbstractDataFrame) + filter!(cols => fun, df::AbstractDataFrame) + +Remove rows from data frame `df` for which `fun` returns `false`. + +If `cols` is not specified then the predicate `fun` is passed `DataFrameRow`s. + +If `cols` is specified then the predicate `fun` is passed elements of the +corresponding columns as separate positional arguments, unless `cols` is an +`AsTable` selector, in which case a `NamedTuple` of these arguments is passed. +`cols` can be any column selector ($COLUMNINDEX_STR; $MULTICOLUMNINDEX_STR), and +column duplicates are allowed if a vector of `Symbol`s, strings, or integers is +passed. Passing `cols` leads to a more efficient execution of the operation for large data frames. diff --git a/src/groupeddataframe/groupeddataframe.jl b/src/groupeddataframe/groupeddataframe.jl index 5a4349845e..49bb618979 100644 --- a/src/groupeddataframe/groupeddataframe.jl +++ b/src/groupeddataframe/groupeddataframe.jl @@ -568,3 +568,102 @@ function Base.get(gd::GroupedDataFrame, key::Union{Tuple, NamedTuple}, default) return default end end + +""" + filter(fun, gdf::GroupedDataFrame) + filter(cols => fun, gdf::GroupedDataFrame) + +Return a new `GroupedDataFrame` containing only groups for which `fun` +returns `true`. + +If `cols` is not specified then the predicate `fun` is called with a +`SubDataFrame` for each group. + +If `cols` is specified then the predicate `fun` is called for each group with +views of the corresponding columns as separate positional arguments, unless +`cols` is an `AsTable` selector, in which case a `NamedTuple` of these arguments +is passed. `cols` can be any column selector ($COLUMNINDEX_STR; +$MULTICOLUMNINDEX_STR), and column duplicates are allowed if a vector of +`Symbol`s, strings, or integers is passed. + +# Examples +``` +julia> df = DataFrame(g=[1, 2], x=['a', 'b']); + +julia> gd = groupby(df, :g) +GroupedDataFrame with 2 groups based on key: g +First Group (1 row): g = 1 +│ Row │ g │ x │ +│ │ Int64 │ Char │ +├─────┼───────┼──────┤ +│ 1 │ 1 │ 'a' │ +⋮ +Last Group (1 row): g = 2 +│ Row │ g │ x │ +│ │ Int64 │ Char │ +├─────┼───────┼──────┤ +│ 1 │ 2 │ 'b' │ + +julia> filter(x -> x.x[1] == 'a', gd) +GroupedDataFrame with 1 group based on key: g +First Group (1 row): g = 1 +│ Row │ g │ x │ +│ │ Int64 │ Char │ +├─────┼───────┼──────┤ +│ 1 │ 1 │ 'a' │ + +julia> filter(:x => x -> x[1] == 'a', gd) +GroupedDataFrame with 1 group based on key: g +First Group (1 row): g = 1 +│ Row │ g │ x │ +│ │ Int64 │ Char │ +├─────┼───────┼──────┤ +│ 1 │ 1 │ 'a' │ + +``` +""" +Base.filter(f, gdf::GroupedDataFrame) = + gdf[[f(sdf)::Bool for sdf in gdf]] +Base.filter((col, f)::Pair{<:ColumnIndex}, gdf::GroupedDataFrame) = + _filter_helper(gdf, f, gdf.idx, gdf.starts, gdf.ends, parent(gdf)[!, col]) +Base.filter((cols, f)::Pair{<:AbstractVector{Symbol}}, gdf::GroupedDataFrame) = + filter([index(parent(gdf))[col] for col in cols] => f, gdf) +Base.filter((cols, f)::Pair{<:AbstractVector{<:AbstractString}}, gdf::GroupedDataFrame) = + filter([index(parent(gdf))[col] for col in cols] => f, gdf) +Base.filter((cols, f)::Pair, gdf::GroupedDataFrame) = + filter(index(parent(gdf))[cols] => f, gdf) +Base.filter((cols, f)::Pair{<:AbstractVector{Int}}, gdf::GroupedDataFrame) = + _filter_helper(gdf, f, gdf.idx, gdf.starts, gdf.ends, (parent(gdf)[!, i] for i in cols)...) + +function _filter_helper(gdf::GroupedDataFrame, f, idx::Vector{Int}, + starts::Vector{Int}, ends::Vector{Int}, cols...) + function mapper(i::Integer) + idxs = idx[starts[i]:ends[i]] + return map(x -> view(x, idxs), cols) + end + + if length(cols) == 0 + throw(ArgumentError("At least one column must be passed to filter on")) + end + sel = [f(mapper(i)...)::Bool for i in 1:length(gdf)] + return gdf[sel] +end + +function Base.filter((cols, f)::Pair{<:AsTable}, gdf::GroupedDataFrame) + df_tmp = select(parent(gdf), cols.cols, copycols=false) + if ncol(df_tmp) == 0 + throw(ArgumentError("At least one column must be passed to filter on")) + end + return _filter_helper_astable(gdf, Tables.columntable(df_tmp), f, + gdf.idx, gdf.starts, gdf.ends) +end + +function _filter_helper_astable(gdf::GroupedDataFrame, nt::NamedTuple, f, + idx::Vector{Int}, starts::Vector{Int}, ends::Vector{Int}) + function mapper(i::Integer) + idxs = idx[starts[i]:ends[i]] + return map(x -> view(x, idxs), nt) + end + + return gdf[[f(mapper(i))::Bool for i in 1:length(gdf)]] +end diff --git a/test/grouping.jl b/test/grouping.jl index 66315532a9..0b0c849b17 100644 --- a/test/grouping.jl +++ b/test/grouping.jl @@ -2338,4 +2338,43 @@ end @test eltype(df2.a) === eltype(df2.b) === Union{UInt, Missing} end +@testset "filter" begin + for df in (DataFrame(g1=[1, 3, 2, 1, 4, 1, 2, 5], x1=1:8, + g2=[1, 3, 2, 1, 4, 1, 2, 5], x2=1:8), + view(DataFrame(g1=[1, 3, 2, 1, 4, 1, 2, 5, 4, 5], x1=1:10, + g2=[1, 3, 2, 1, 4, 1, 2, 5, 4, 5], x2=1:10, y=1:10), + 1:8, Not(:y))) + for gcols in (:g1, [:g1, :g2]), cutoff in (1, 0, 10), + predicate in (x -> nrow(x) > cutoff, + 1 => x -> length(x) > cutoff, + :x1 => x -> length(x) > cutoff, + "x1" => x -> length(x) > cutoff, + [1, 2] => (x1, x2) -> length(x1) > cutoff, + [:x1, :x2] => (x1, x2) -> length(x1) > cutoff, + ["x1", "x2"] => (x1, x2) -> length(x1) > cutoff, + r"x" => (x1, x2) -> length(x1) > cutoff, + AsTable(:x1) => x -> length(x.x1) > cutoff, + AsTable(r"x") => x -> length(x.x1) > cutoff) + gdf1 = groupby_checked(df, gcols) + gdf2 = filter(predicate, gdf1) + if cutoff == 1 + @test getindex.(keys(gdf2), 1) == 1:2 + elseif cutoff == 0 + @test gdf1 == gdf2 + elseif cutoff == 10 + @test isempty(gdf2) + end + end + + @test_throws TypeError filter(x -> 1, groupby_checked(df, :g1)) + @test_throws TypeError filter(r"x" => (x...) -> 1, groupby_checked(df, :g1)) + @test_throws TypeError filter(AsTable(r"x") => (x...) -> 1, groupby_checked(df, :g1)) + + @test_throws ArgumentError filter(r"y" => (x...) -> true, groupby_checked(df, :g1)) + @test_throws ArgumentError filter([] => (x...) -> true, groupby_checked(df, :g1)) + @test_throws ArgumentError filter(AsTable(r"y") => (x...) -> true, groupby_checked(df, :g1)) + @test_throws ArgumentError filter(AsTable([]) => (x...) -> true, groupby_checked(df, :g1)) + end +end + end # module