diff --git a/java/src/main/java/ai/rapids/cudf/Aggregation.java b/java/src/main/java/ai/rapids/cudf/Aggregation.java index d10329ca0f2..379750bb0b7 100644 --- a/java/src/main/java/ai/rapids/cudf/Aggregation.java +++ b/java/src/main/java/ai/rapids/cudf/Aggregation.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -68,7 +68,9 @@ enum Kind { DENSE_RANK(29), PERCENT_RANK(30), TDIGEST(31), // This can take a delta argument for accuracy level - MERGE_TDIGEST(32); // This can take a delta argument for accuracy level + MERGE_TDIGEST(32), // This can take a delta argument for accuracy level + HISTOGRAM(33), + MERGE_HISTOGRAM(34); final int nativeId; @@ -918,6 +920,26 @@ static TDigestAggregation mergeTDigest(int delta) { return new TDigestAggregation(Kind.MERGE_TDIGEST, delta); } + static final class HistogramAggregation extends NoParamAggregation { + private HistogramAggregation() { + super(Kind.HISTOGRAM); + } + } + + static final class MergeHistogramAggregation extends NoParamAggregation { + private MergeHistogramAggregation() { + super(Kind.MERGE_HISTOGRAM); + } + } + + static HistogramAggregation histogram() { + return new HistogramAggregation(); + } + + static MergeHistogramAggregation mergeHistogram() { + return new MergeHistogramAggregation(); + } + /** * Create one of the aggregations that only needs a kind, no other parameters. This does not * work for all types and for code safety reasons each kind is added separately. 
diff --git a/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java b/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java index 500d18f7eae..0fae33927b6 100644 --- a/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java +++ b/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -315,4 +315,26 @@ public static GroupByAggregation createTDigest(int delta) { public static GroupByAggregation mergeTDigest(int delta) { return new GroupByAggregation(Aggregation.mergeTDigest(delta)); } + + /** + * Histogram aggregation, computing the frequencies for each unique row. + * + * A histogram is given as a lists column, in which the first child stores unique rows from + * the input values and the second child stores their corresponding frequencies. + * + * @return A lists of structs column in which each list contains a histogram corresponding to + * an input key. + */ + public static GroupByAggregation histogram() { + return new GroupByAggregation(Aggregation.histogram()); + } + + /** + * MergeHistogram aggregation, to merge multiple histograms. + * + * @return A new histogram in which the frequencies of the unique rows are summed up. + */ + public static GroupByAggregation mergeHistogram() { + return new GroupByAggregation(Aggregation.mergeHistogram()); + } } diff --git a/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java b/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java index eab1c94fd2c..ba8ae379bae 100644 --- a/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java +++ b/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -286,4 +286,22 @@ public static ReductionAggregation mergeSets(NullEquality nullEquality, NaNEqual return new ReductionAggregation(Aggregation.mergeSets(nullEquality, nanEquality)); } + /** + * Create HistogramAggregation, computing the frequencies for each unique row. + * + * @return A structs column in which the first child stores unique rows from the input and the + * second child stores their corresponding frequencies. + */ + public static ReductionAggregation histogram() { + return new ReductionAggregation(Aggregation.histogram()); + } + + /** + * Create MergeHistogramAggregation, to merge multiple histograms. + * + * @return A new histogram in which the frequencies of the unique rows are summed up. + */ + public static ReductionAggregation mergeHistogram() { + return new ReductionAggregation(Aggregation.mergeHistogram()); + } } diff --git a/java/src/main/native/src/AggregationJni.cpp b/java/src/main/native/src/AggregationJni.cpp index 6ac73282615..bc62e95c36a 100644 --- a/java/src/main/native/src/AggregationJni.cpp +++ b/java/src/main/native/src/AggregationJni.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -90,6 +90,11 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createNoParamAgg(JNIEnv case 30: // ANSI SQL PERCENT_RANK return cudf::make_rank_aggregation(cudf::rank_method::MIN, {}, cudf::null_policy::INCLUDE, {}, cudf::rank_percentage::ONE_NORMALIZED); + case 33: // HISTOGRAM + return cudf::make_histogram_aggregation(); + case 34: // MERGE_HISTOGRAM + return cudf::make_merge_histogram_aggregation(); + default: throw std::logic_error("Unsupported No Parameter Aggregation Operation"); } }(); diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 59f0d180c6e..faa73ac4322 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -4129,6 +4129,115 @@ void testMergeTDigestReduction() { } } + @Test + void testGroupbyHistogram() { + StructType histogramStruct = new StructType(false, + new BasicType(false, DType.INT32), // values + new BasicType(false, DType.INT64)); // frequencies + ListType histogramList = new ListType(false, histogramStruct); + + // key = 0: values = [2, 2, -3, -2, 2] + // key = 1: values = [2, 0, 5, 2, 1] + // key = 2: values = [-3, 1, 1, 2, 2] + try (Table input = new Table.TestBuilder() + .column(2, 0, 2, 1, 1, 1, 0, 0, 0, 1, 2, 2, 1, 0, 2) + .column(-3, 2, 1, 2, 0, 5, 2, -3, -2, 2, 1, 2, 1, 2, 2) + .build(); + Table result = input.groupBy(0) + .aggregate(GroupByAggregation.histogram().onColumn(1)); + Table sortedResult = result.orderBy(OrderByArg.asc(0)); + ColumnVector sortedOutHistograms = sortedResult.getColumn(1).listSortRows(false, false); + + ColumnVector expectedKeys = ColumnVector.fromInts(0, 1, 2); + ColumnVector expectedHistograms = ColumnVector.fromLists(histogramList, + Arrays.asList(new StructData(-3, 1L), new StructData(-2, 1L), new StructData(2, 3L)), + Arrays.asList(new StructData(0, 1L), new StructData(1, 1L), new StructData(2, 2L), + new StructData(5, 1L)), + Arrays.asList(new StructData(-3, 
1L), new StructData(1, 2L), new StructData(2, 2L))) + ) { + assertColumnsAreEqual(expectedKeys, sortedResult.getColumn(0)); + assertColumnsAreEqual(expectedHistograms, sortedOutHistograms); + } + } + + @Test + void testGroupbyMergeHistogram() { + StructType histogramStruct = new StructType(false, + new BasicType(false, DType.INT32), // values + new BasicType(false, DType.INT64)); // frequencies + ListType histogramList = new ListType(false, histogramStruct); + + // key = 0: histograms = [[<-3, 1>, <-2, 1>, <2, 3>], [<0, 1>, <1, 1>], [<-3, 3>, <0, 1>, <1, 2>]] + // key = 1: histograms = [[<-2, 1>, <1, 3>, <2, 2>], [<0, 2>, <1, 1>, <2, 2>]] + try (Table input = new Table.TestBuilder() + .column(0, 1, 0, 1, 0) + .column(histogramStruct, + new StructData[]{new StructData(-3, 1L), new StructData(-2, 1L), new StructData(2, 3L)}, + new StructData[]{new StructData(-2, 1L), new StructData(1, 3L), new StructData(2, 2L)}, + new StructData[]{new StructData(0, 1L), new StructData(1, 1L)}, + new StructData[]{new StructData(0, 2L), new StructData(1, 1L), new StructData(2, 2L)}, + new StructData[]{new StructData(-3, 3L), new StructData(0, 1L), new StructData(1, 2L)}) + .build(); + Table result = input.groupBy(0) + .aggregate(GroupByAggregation.mergeHistogram().onColumn(1)); + Table sortedResult = result.orderBy(OrderByArg.asc(0)); + ColumnVector sortedOutHistograms = sortedResult.getColumn(1).listSortRows(false, false); + + ColumnVector expectedKeys = ColumnVector.fromInts(0, 1); + ColumnVector expectedHistograms = ColumnVector.fromLists(histogramList, + Arrays.asList(new StructData(-3, 4L), new StructData(-2, 1L), new StructData(0, 2L), + new StructData(1, 3L), new StructData(2, 3L)), + Arrays.asList(new StructData(-2, 1L), new StructData(0, 2L), new StructData(1, 4L), + new StructData(2, 4L))) + ) { + assertColumnsAreEqual(expectedKeys, sortedResult.getColumn(0)); + assertColumnsAreEqual(expectedHistograms, sortedOutHistograms); + } + } + + @Test + void testReductionHistogram() 
{ + StructType histogramStruct = new StructType(false, + new BasicType(false, DType.INT32), // values + new BasicType(false, DType.INT64)); // frequencies + + try (ColumnVector input = ColumnVector.fromInts(-3, 2, 1, 2, 0, 5, 2, -3, -2, 2, 1); + Scalar result = input.reduce(ReductionAggregation.histogram(), DType.LIST); + ColumnVector resultCV = result.getListAsColumnView().copyToColumnVector(); + Table resultTable = new Table(resultCV); + Table sortedResult = resultTable.orderBy(OrderByArg.asc(0)); + + ColumnVector expectedHistograms = ColumnVector.fromStructs(histogramStruct, + new StructData(-3, 2L), new StructData(-2, 1L), new StructData(0, 1L), + new StructData(1, 2L), new StructData(2, 4L), new StructData(5, 1L)) + ) { + assertColumnsAreEqual(expectedHistograms, sortedResult.getColumn(0)); + } + } + + @Test + void testReductionMergeHistogram() { + StructType histogramStruct = new StructType(false, + new BasicType(false, DType.INT32), // values + new BasicType(false, DType.INT64)); // frequencies + + try (ColumnVector input = ColumnVector.fromStructs(histogramStruct, + new StructData(-3, 2L), new StructData(2, 1L), new StructData(1, 1L), + new StructData(2, 2L), new StructData(0, 4L), new StructData(5, 1L), + new StructData(2, 2L), new StructData(-3, 3L), new StructData(-2, 5L), + new StructData(2, 3L), new StructData(1, 4L)); + Scalar result = input.reduce(ReductionAggregation.mergeHistogram(), DType.LIST); + ColumnVector resultCV = result.getListAsColumnView().copyToColumnVector(); + Table resultTable = new Table(resultCV); + Table sortedResult = resultTable.orderBy(OrderByArg.asc(0)); + + ColumnVector expectedHistograms = ColumnVector.fromStructs(histogramStruct, + new StructData(-3, 5L), new StructData(-2, 5L), new StructData(0, 4L), + new StructData(1, 5L), new StructData(2, 8L), new StructData(5, 1L)) + ) { + assertColumnsAreEqual(expectedHistograms, sortedResult.getColumn(0)); + } + } @Test void testGroupByMinMaxDecimal() { try (Table t1 = new 
Table.TestBuilder()