Skip to content

Commit

Permalink
feat: implement AggCountWhere to support generic counting using `Fi…
Browse files Browse the repository at this point in the history
…lter` (deephaven#6497)

## Groovy Examples

```
Table countTable = table.aggBy(List.of(
        AggCountWhere("filter1", "intCol >= 50"),
        AggCountWhere("filter2", "intCol >= 50", "intCol != 80"),
        AggCountWhere("filter3", Filter.or(Filter.from("intCol >= 50", "intCol == 3"))),
        AggCountWhere("filter4", "true"),
        AggCountWhere("filter5", "false"),
        AggCountWhere("filter6", "intCol % 2 == 0"),
        AggCountWhere("filter7",
                Filter.and(Filter.or(Filter.from("false", "intCol % 3 == 0")),
                        Filter.or(Filter.from("false", "intCol % 2 == 0")))),
        AggCountWhere("filter8", "intCol % 2 == 0", "intCol % 3 == 0"),
        AggCountWhere("filter9",
                Filter.and(Filter.and(Filter.from("intCol > 0")),
                        Filter.and(Filter.from("intCol <= 10", "intCol >= 5")))),
        // Multiple input columns
        AggCountWhere("filter10", "intCol >= 5", "doubleCol <= 10.0"),
        AggCountWhere("filter11", "intCol >= 5 && intColNulls != 3 && doubleCol <= 10.0"),
        // DynamicWhereFilter
        AggCountWhere("filter12", new DynamicWhereFilter(setTable, true, MatchPairFactory.getExpressions("intCol")))),
        "Sym").sort("Sym");
```

## Python Examples

```
from deephaven import empty_table
from deephaven.agg import count_where

table_size = 120000000
table = empty_table(table_size).update(["int1M=randomInt(0,1000000)", "int640=randomInt(0,640)", "int250=randomInt(0,250)"]).select()

# zero-key
t_count = table.agg_by(aggs=count_where(col="count", filters=["int1M < 500000", "int1M > 499000"]))

# bucketed
t_count = table.agg_by(aggs=count_where(col="count", filters=["int1M < 500000", "int1M > 499000"]), by="int250")
```
  • Loading branch information
lbooker42 authored Dec 19, 2024
1 parent 9b260b3 commit 83e0c97
Show file tree
Hide file tree
Showing 67 changed files with 4,340 additions and 2,946 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
import io.deephaven.engine.table.impl.by.ssmminmax.SsmChunkedMinMaxOperator;
import io.deephaven.engine.table.impl.by.ssmpercentile.SsmChunkedPercentileOperator;
import io.deephaven.engine.table.impl.select.SelectColumn;
import io.deephaven.engine.table.impl.select.WhereFilter;
import io.deephaven.engine.table.impl.sources.ReinterpretUtils;
import io.deephaven.engine.table.impl.ssms.SegmentedSortedMultiSet;
import io.deephaven.engine.table.impl.util.freezeby.FreezeByCountOperator;
Expand All @@ -107,13 +108,7 @@
import java.math.BigDecimal;
import java.math.BigInteger;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.*;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
Expand Down Expand Up @@ -693,6 +688,56 @@ public void visit(@NotNull final Count count) {
addNoInputOperator(new CountAggregationOperator(count.column().name()));
}

@Override
public void visit(@NotNull final CountWhere countWhere) {
final WhereFilter[] whereFilters = WhereFilter.fromInternal(countWhere.filter());

final Map<String, RecordingInternalOperator> inputColumnRecorderMap = new HashMap<>();
final List<RecordingInternalOperator> recorderList = new ArrayList<>();
final List<RecordingInternalOperator[]> filterRecorderList = new ArrayList<>();

// Verify all the columns in the where filters are present in the table and valid for use.
for (final WhereFilter whereFilter : whereFilters) {
whereFilter.init(table.getDefinition());
if (whereFilter.isRefreshing()) {
throw new UnsupportedOperationException("AggCountWhere does not support refreshing filters");
}

// Compute which recording operators this filter will use.
final List<String> inputColumnNames = whereFilter.getColumns();
final int inputColumnCount = whereFilter.getColumns().size();
final RecordingInternalOperator[] recorders = new RecordingInternalOperator[inputColumnCount];
for (int ii = 0; ii < inputColumnCount; ++ii) {
final String inputColumnName = inputColumnNames.get(ii);
final RecordingInternalOperator recorder =
inputColumnRecorderMap.computeIfAbsent(inputColumnName, k -> {
// Create a recording operator for the column and add it to the list of operators.
final ColumnSource<?> inputSource = table.getColumnSource(inputColumnName);
final RecordingInternalOperator newRecorder =
new RecordingInternalOperator(inputColumnName, inputSource);
recorderList.add(newRecorder);
return newRecorder;
});
recorders[ii] = recorder;
}
filterRecorderList.add(recorders);
}

final RecordingInternalOperator[] recorders = recorderList.toArray(RecordingInternalOperator[]::new);
final RecordingInternalOperator[][] filterRecorders =
filterRecorderList.toArray(RecordingInternalOperator[][]::new);
final String[] inputColumnNames =
inputColumnRecorderMap.keySet().toArray(ArrayTypeUtils.EMPTY_STRING_ARRAY);

// Add the recording operators, making them dependent on all input columns so they all are populated if any
// are modified
for (final RecordingInternalOperator recorder : recorders) {
addOperator(recorder, recorder.getInputColumnSource(), inputColumnNames);
}
addOperator(new CountWhereOperator(countWhere.column().name(), whereFilters, recorders, filterRecorders),
null, inputColumnNames);
}

@Override
public void visit(@NotNull final FirstRowKey firstRowKey) {
addFirstOrLastOperators(true, firstRowKey.column().name());
Expand Down Expand Up @@ -1004,6 +1049,11 @@ public void visit(@NotNull final Count count) {
addNoInputOperator(new CountAggregationOperator(count.column().name()));
}

@Override
public void visit(@NotNull final CountWhere countWhere) {
addNoInputOperator(new CountAggregationOperator(countWhere.column().name()));
}

@Override
public void visit(@NotNull final NullColumns nullColumns) {
transformers.add(new NullColumnAggregationTransformer(nullColumns.resultColumns()));
Expand Down Expand Up @@ -1149,6 +1199,13 @@ public void visit(@NotNull final Count count) {
addOperator(makeSumOperator(resultSource.getType(), resultName, false), resultSource, resultName);
}

@Override
public void visit(@NotNull final CountWhere countWhere) {
final String resultName = countWhere.column().name();
final ColumnSource<?> resultSource = table.getColumnSource(resultName);
addOperator(makeSumOperator(resultSource.getType(), resultName, false), resultSource, resultName);
}

@Override
public void visit(@NotNull final NullColumns nullColumns) {
transformers.add(new NullColumnAggregationTransformer(nullColumns.resultColumns()));
Expand Down
Loading

0 comments on commit 83e0c97

Please sign in to comment.