Skip to content

Commit

Permalink
[Enhancement](doris-future) Support "REGR_" aggregation functions (PA…
Browse files Browse the repository at this point in the history
…RT II) (#41240)

## Proposed changes

Issue Number: close #38975

<!--Describe your changes.-->
```sql
mysql> select * from test;
+------+------+------+
| id   | x    | y    |
+------+------+------+
|    1 |   18 |   13 |
|    3 |   12 |    2 |
|    5 |   10 |   20 |
|    2 |   14 |   27 |
|    4 |    5 |    6 |
+------+------+------+
5 rows in set (0.07 sec)

mysql> select regr_slope(y,x) , regr_intercept(y,x) from test;
+--------------------+----------------------+
| regr_slope(y, x)   | regr_intercept(y, x) |
+--------------------+----------------------+
| 0.6853448275862069 |    5.512931034482759 |
+--------------------+----------------------+
1 row in set (0.15 sec)

```

---------

Co-authored-by: zhiqiang-hhhh <seuhezhiqiang@163.com>
  • Loading branch information
Yoruet and zhiqiang-hhhh authored Oct 8, 2024
1 parent 1fc6b35 commit 0e5fd2f
Show file tree
Hide file tree
Showing 13 changed files with 1,212 additions and 8 deletions.
93 changes: 93 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function_regr_union.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "vec/aggregate_functions/aggregate_function_regr_union.h"

#include "common/status.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/helpers.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type.h"
#include "vec/data_types/data_type_nullable.h"

namespace doris::vectorized {

template <typename T, template <typename> class StatFunctionTemplate>
AggregateFunctionPtr type_dispatch_for_aggregate_function_regr(const DataTypes& argument_types,
const bool& result_is_nullable,
bool y_nullable_input,
bool x_nullable_input) {
if (y_nullable_input) {
if (x_nullable_input) {
return creator_without_type::create_ignore_nullable<
AggregateFunctionRegrSimple<StatFunctionTemplate<T>, true, true>>(
argument_types, result_is_nullable);
} else {
return creator_without_type::create_ignore_nullable<
AggregateFunctionRegrSimple<StatFunctionTemplate<T>, true, false>>(
argument_types, result_is_nullable);
}
} else {
if (x_nullable_input) {
return creator_without_type::create_ignore_nullable<
AggregateFunctionRegrSimple<StatFunctionTemplate<T>, false, true>>(
argument_types, result_is_nullable);
} else {
return creator_without_type::create_ignore_nullable<
AggregateFunctionRegrSimple<StatFunctionTemplate<T>, false, false>>(
argument_types, result_is_nullable);
}
}
}

template <template <typename> class StatFunctionTemplate>
AggregateFunctionPtr create_aggregate_function_regr(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
if (argument_types.size() != 2) {
LOG(WARNING) << "aggregate function " << name << " requires exactly 2 arguments";
return nullptr;
}
if (!result_is_nullable) {
LOG(WARNING) << "aggregate function " << name << " requires nullable result type";
return nullptr;
}

bool y_nullable_input = argument_types[0]->is_nullable();
bool x_nullable_input = argument_types[1]->is_nullable();
WhichDataType y_type(remove_nullable(argument_types[0]));
WhichDataType x_type(remove_nullable(argument_types[1]));

#define DISPATCH(TYPE) \
if (x_type.idx == TypeIndex::TYPE && y_type.idx == TypeIndex::TYPE) \
return type_dispatch_for_aggregate_function_regr<TYPE, StatFunctionTemplate>( \
argument_types, result_is_nullable, y_nullable_input, x_nullable_input);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH

LOG(WARNING) << "unsupported input types " << argument_types[0]->get_name() << " and "
<< argument_types[1]->get_name() << " for aggregate function " << name;
return nullptr;
}

void register_aggregate_function_regr_union(AggregateFunctionSimpleFactory& factory) {
factory.register_function_both("regr_slope", create_aggregate_function_regr<RegrSlopeFunc>);
factory.register_function_both("regr_intercept",
create_aggregate_function_regr<RegrInterceptFunc>);
}
} // namespace doris::vectorized
216 changes: 216 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function_regr_union.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <cmath>
#include <cstdint>
#include <string>
#include <type_traits>

#include "common/exception.h"
#include "common/status.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/columns/column_nullable.h"
#include "vec/columns/column_vector.h"
#include "vec/common/assert_cast.h"
#include "vec/core/field.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/data_types/data_type_number.h"
#include "vec/io/io_helper.h"

namespace doris::vectorized {

template <typename T>
struct AggregateFunctionRegrData {
using Type = T;
UInt64 count = 0;
Float64 sum_x {};
Float64 sum_y {};
Float64 sum_of_x_mul_y {};
Float64 sum_of_x_squared {};

void write(BufferWritable& buf) const {
write_binary(sum_x, buf);
write_binary(sum_y, buf);
write_binary(sum_of_x_mul_y, buf);
write_binary(sum_of_x_squared, buf);
write_binary(count, buf);
}

void read(BufferReadable& buf) {
read_binary(sum_x, buf);
read_binary(sum_y, buf);
read_binary(sum_of_x_mul_y, buf);
read_binary(sum_of_x_squared, buf);
read_binary(count, buf);
}

void reset() {
sum_x = {};
sum_y = {};
sum_of_x_mul_y = {};
sum_of_x_squared = {};
count = 0;
}

void merge(const AggregateFunctionRegrData& rhs) {
if (rhs.count == 0) {
return;
}
sum_x += rhs.sum_x;
sum_y += rhs.sum_y;
sum_of_x_mul_y += rhs.sum_of_x_mul_y;
sum_of_x_squared += rhs.sum_of_x_squared;
count += rhs.count;
}

void add(T value_y, T value_x) {
sum_x += value_x;
sum_y += value_y;
sum_of_x_mul_y += value_x * value_y;
sum_of_x_squared += value_x * value_x;
count += 1;
}

Float64 get_slope() const {
Float64 denominator = count * sum_of_x_squared - sum_x * sum_x;
if (count < 2 || denominator == 0.0) {
return std::numeric_limits<Float64>::quiet_NaN();
}
Float64 slope = (count * sum_of_x_mul_y - sum_x * sum_y) / denominator;
return slope;
}
};

template <typename T>
struct RegrSlopeFunc : AggregateFunctionRegrData<T> {
static constexpr const char* name = "regr_slope";

Float64 get_result() const { return this->get_slope(); }
};

template <typename T>
struct RegrInterceptFunc : AggregateFunctionRegrData<T> {
static constexpr const char* name = "regr_intercept";

Float64 get_result() const {
auto slope = this->get_slope();
if (std::isnan(slope)) {
return slope;
} else {
Float64 intercept = (this->sum_y - slope * this->sum_x) / this->count;
return intercept;
}
}
};

template <typename RegrFunc, bool y_nullable, bool x_nullable>
class AggregateFunctionRegrSimple
: public IAggregateFunctionDataHelper<
RegrFunc, AggregateFunctionRegrSimple<RegrFunc, y_nullable, x_nullable>> {
public:
using Type = typename RegrFunc::Type;
using XInputCol = ColumnVector<Type>;
using YInputCol = ColumnVector<Type>;
using ResultCol = ColumnVector<Float64>;

explicit AggregateFunctionRegrSimple(const DataTypes& argument_types_)
: IAggregateFunctionDataHelper<
RegrFunc, AggregateFunctionRegrSimple<RegrFunc, y_nullable, x_nullable>>(
argument_types_) {
DCHECK(!argument_types_.empty());
}

String get_name() const override { return RegrFunc::name; }

DataTypePtr get_return_type() const override {
return make_nullable(std::make_shared<DataTypeFloat64>());
}

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
bool y_null = false;
bool x_null = false;
const YInputCol* y_nested_column = nullptr;
const XInputCol* x_nested_column = nullptr;

if constexpr (y_nullable) {
const ColumnNullable& y_column_nullable =
assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(*columns[0]);
y_null = y_column_nullable.is_null_at(row_num);
y_nested_column = assert_cast<const YInputCol*, TypeCheckOnRelease::DISABLE>(
y_column_nullable.get_nested_column_ptr().get());
} else {
y_nested_column = assert_cast<const YInputCol*, TypeCheckOnRelease::DISABLE>(
(*columns[0]).get_ptr().get());
}

if constexpr (x_nullable) {
const ColumnNullable& x_column_nullable =
assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(*columns[1]);
x_null = x_column_nullable.is_null_at(row_num);
x_nested_column = assert_cast<const XInputCol*, TypeCheckOnRelease::DISABLE>(
x_column_nullable.get_nested_column_ptr().get());
} else {
x_nested_column = assert_cast<const XInputCol*, TypeCheckOnRelease::DISABLE>(
(*columns[1]).get_ptr().get());
}

if (x_null || y_null) {
return;
}

Type y_value = y_nested_column->get_data()[row_num];
Type x_value = x_nested_column->get_data()[row_num];

this->data(place).add(y_value, x_value);
}

void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); }

void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
Arena*) const override {
this->data(place).merge(this->data(rhs));
}

void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
this->data(place).write(buf);
}

void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
Arena*) const override {
this->data(place).read(buf);
}

void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
const auto& data = this->data(place);
auto& dst_column_with_nullable = assert_cast<ColumnNullable&>(to);
auto& dst_column = assert_cast<ResultCol&>(dst_column_with_nullable.get_nested_column());
Float64 result = data.get_result();
if (std::isnan(result)) {
dst_column_with_nullable.get_null_map_data().push_back(1);
dst_column.insert_default();
} else {
dst_column_with_nullable.get_null_map_data().push_back(0);
dst_column.get_data().push_back(result);
}
}
};
} // namespace doris::vectorized
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& fact
void register_aggregate_function_percentile_old(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_window_funnel(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_window_funnel_old(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_regr_union(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_retention(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_orthogonal_bitmap(AggregateFunctionSimpleFactory& factory);
Expand Down Expand Up @@ -102,6 +103,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
register_aggregate_function_percentile_approx(instance);
register_aggregate_function_window_funnel(instance);
register_aggregate_function_window_funnel_old(instance);
register_aggregate_function_regr_union(instance);
register_aggregate_function_retention(instance);
register_aggregate_function_orthogonal_bitmap(instance);
register_aggregate_function_collect_list(instance);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public class AggregateFunction extends Function {
"ndv_no_finalize", "percentile_array", "histogram",
FunctionSet.SEQUENCE_COUNT, FunctionSet.MAP_AGG, FunctionSet.BITMAP_AGG, FunctionSet.ARRAY_AGG,
FunctionSet.COLLECT_LIST, FunctionSet.COLLECT_SET, FunctionSet.GROUP_ARRAY_INTERSECT,
FunctionSet.SUM0, FunctionSet.MULTI_DISTINCT_SUM0);
FunctionSet.SUM0, FunctionSet.MULTI_DISTINCT_SUM0, FunctionSet.REGR_INTERCEPT, FunctionSet.REGR_SLOPE);

public static ImmutableSet<String> ALWAYS_NULLABLE_AGGREGATE_FUNCTION_NAME_SET =
ImmutableSet.of("stddev_samp", "variance_samp", "var_samp", "percentile_approx", "first_value",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@
import org.apache.doris.nereids.trees.expressions.functions.agg.PercentileApproxWeighted;
import org.apache.doris.nereids.trees.expressions.functions.agg.PercentileArray;
import org.apache.doris.nereids.trees.expressions.functions.agg.QuantileUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.RegrIntercept;
import org.apache.doris.nereids.trees.expressions.functions.agg.RegrSlope;
import org.apache.doris.nereids.trees.expressions.functions.agg.Retention;
import org.apache.doris.nereids.trees.expressions.functions.agg.SequenceCount;
import org.apache.doris.nereids.trees.expressions.functions.agg.SequenceMatch;
Expand Down Expand Up @@ -129,13 +131,15 @@ public class BuiltinAggregateFunctions implements FunctionHelper {
agg(Ndv.class, "approx_count_distinct", "ndv"),
agg(OrthogonalBitmapIntersect.class, "orthogonal_bitmap_intersect"),
agg(OrthogonalBitmapIntersectCount.class, "orthogonal_bitmap_intersect_count"),
agg(OrthogonalBitmapUnionCount.class, "orthogonal_bitmap_union_count"),
agg(Percentile.class, "percentile"),
agg(PercentileApprox.class, "percentile_approx"),
agg(PercentileApproxWeighted.class, "percentile_approx_weighted"),
agg(PercentileArray.class, "percentile_array"),
agg(QuantileUnion.class, "quantile_union"),
agg(Retention.class, "retention"),
agg(OrthogonalBitmapUnionCount.class, "orthogonal_bitmap_union_count"),
agg(Percentile.class, "percentile"),
agg(PercentileApprox.class, "percentile_approx"),
agg(PercentileApproxWeighted.class, "percentile_approx_weighted"),
agg(PercentileArray.class, "percentile_array"),
agg(QuantileUnion.class, "quantile_union"),
agg(RegrIntercept.class, "regr_intercept"),
agg(RegrSlope.class, "regr_slope"),
agg(Retention.class, "retention"),
agg(SequenceCount.class, "sequence_count"),
agg(SequenceMatch.class, "sequence_match"),
agg(Stddev.class, "stddev_pop", "stddev"),
Expand Down
Loading

0 comments on commit 0e5fd2f

Please sign in to comment.