Skip to content

Commit

Permalink
[feature](agg-func) agg function linear_histogram (apache#39546)
Browse files Browse the repository at this point in the history
add aggregate function: `linear_histogram(expr, DOUBLE interval[, DOUBLE offset)`

The linear_histogram function is used to describe the distribution of
the data, It uses an "equal width" bucking strategy, and divides the
data into buckets according to the value of the data.
  • Loading branch information
joker-star-l authored and garenshi committed Oct 16, 2024
1 parent 2989016 commit 94403f7
Show file tree
Hide file tree
Showing 19 changed files with 1,718 additions and 355 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "vec/aggregate_functions/aggregate_function_linear_histogram.h"

#include "vec/aggregate_functions/helpers.h"

namespace doris::vectorized {

const std::string AggregateFunctionLinearHistogramConsts::NAME = "linear_histogram";

template <typename T>
AggregateFunctionPtr create_agg_function_linear_histogram(const DataTypes& argument_types,
const bool result_is_nullable) {
bool has_offset = (argument_types.size() == 3);

if (has_offset) {
return creator_without_type::create<
AggregateFunctionLinearHistogram<T, AggregateFunctionLinearHistogramData<T>, true>>(
argument_types, result_is_nullable);
} else {
return creator_without_type::create<AggregateFunctionLinearHistogram<
T, AggregateFunctionLinearHistogramData<T>, false>>(argument_types,
result_is_nullable);
}
}

AggregateFunctionPtr create_aggregate_function_linear_histogram(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
WhichDataType type(remove_nullable(argument_types[0]));

#define DISPATCH(TYPE) \
if (type.idx == TypeIndex::TYPE) \
return create_agg_function_linear_histogram<TYPE>(argument_types, result_is_nullable);
FOR_NUMERIC_TYPES(DISPATCH)
FOR_DECIMAL_TYPES(DISPATCH)
#undef DISPATCH

LOG(WARNING) << fmt::format("unsupported input type {} for aggregate function {}",
argument_types[0]->get_name(), name);
return nullptr;
}

void register_aggregate_function_linear_histogram(AggregateFunctionSimpleFactory& factory) {
factory.register_function_both(AggregateFunctionLinearHistogramConsts::NAME,
create_aggregate_function_linear_histogram);
}

} // namespace doris::vectorized
258 changes: 258 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function_linear_histogram.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <rapidjson/document.h>
#include <rapidjson/prettywriter.h>
#include <rapidjson/stringbuffer.h>

#include <unordered_map>
#include <vector>

#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type_decimal.h"
#include "vec/io/io_helper.h"

// TODO: optimize count=0
// TODO: support datetime
// TODO: support foreach

namespace doris::vectorized {

template <typename T>
struct AggregateFunctionLinearHistogramData {
// bucket key limits
const static int32_t MIN_BUCKET_KEY = std::numeric_limits<int32_t>::min();
const static int32_t MAX_BUCKET_KEY = std::numeric_limits<int32_t>::max();

private:
// influxdb use double
double interval = 0;
double offset;
double lower; // not used yet
double upper; // not used yet
std::unordered_map<int32_t, size_t,
decltype([](int32_t key) { return static_cast<size_t>(key); })>
buckets;

public:
// reset
void reset() {
offset = 0;
interval = 0;
buckets.clear();
}

void set_parameters(double input_interval, double input_offset) {
interval = input_interval;
offset = input_offset;
}

// add
void add(const T& value, UInt32 scale) {
double val = 0;
if constexpr (IsDecimalNumber<T>) {
using NativeType = typename T::NativeType;
val = static_cast<double>(value.value) / decimal_scale_multiplier<NativeType>(scale);
} else {
val = static_cast<double>(value);
}
double key = std::floor((val - offset) / interval);
if (key <= MIN_BUCKET_KEY || key >= MAX_BUCKET_KEY) {
throw doris::Exception(ErrorCode::INVALID_ARGUMENT, "{} exceeds the bucket range limit",
value);
}
buckets[static_cast<int32_t>(key)]++;
}

// merge
void merge(const AggregateFunctionLinearHistogramData& rhs) {
if (rhs.interval == 0) {
return;
}

interval = rhs.interval;
offset = rhs.offset;

for (const auto& [key, count] : rhs.buckets) {
buckets[key] += count;
}
}

// write
void write(BufferWritable& buf) const {
write_binary(offset, buf);
write_binary(interval, buf);
write_binary(lower, buf);
write_binary(upper, buf);
write_binary(buckets.size(), buf);
for (const auto& [key, count] : buckets) {
write_binary(key, buf);
write_binary(count, buf);
}
}

// read
void read(BufferReadable& buf) {
read_binary(offset, buf);
read_binary(interval, buf);
read_binary(lower, buf);
read_binary(upper, buf);
size_t size;
read_binary(size, buf);
for (size_t i = 0; i < size; i++) {
int32_t key;
size_t count;
read_binary(key, buf);
read_binary(count, buf);
buckets[key] = count;
}
}

// insert_result_into
void insert_result_into(IColumn& to) const {
std::vector<std::pair<int32_t, size_t>> bucket_vector(buckets.begin(), buckets.end());
std::sort(bucket_vector.begin(), bucket_vector.end(),
[](const auto& lhs, const auto& rhs) { return lhs.first < rhs.first; });

rapidjson::Document doc;
doc.SetObject();
rapidjson::Document::AllocatorType& allocator = doc.GetAllocator();

unsigned num_buckets = bucket_vector.empty() ? 0
: bucket_vector.rbegin()->first -
bucket_vector.begin()->first + 1;
doc.AddMember("num_buckets", num_buckets, allocator);

rapidjson::Value bucket_arr(rapidjson::kArrayType);
bucket_arr.Reserve(num_buckets, allocator);

if (num_buckets > 0) {
int32_t idx = bucket_vector.begin()->first;
double left = bucket_vector.begin()->first * interval + offset;
size_t count = 0;
size_t acc_count = 0;

for (const auto& [key, count_] : bucket_vector) {
for (; idx <= key; ++idx) {
rapidjson::Value bucket_json(rapidjson::kObjectType);
bucket_json.AddMember("lower", left, allocator);
left += interval;
bucket_json.AddMember("upper", left, allocator);
count = (idx == key) ? count_ : 0;
bucket_json.AddMember("count", count, allocator);
acc_count += count;
bucket_json.AddMember("acc_count", acc_count, allocator);

bucket_arr.PushBack(bucket_json, allocator);
}
}
}

doc.AddMember("buckets", bucket_arr, allocator);

rapidjson::StringBuffer buffer;
rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
doc.Accept(writer);

auto& column = assert_cast<ColumnString&>(to);
column.insert_data(buffer.GetString(), buffer.GetSize());
}
};

class AggregateFunctionLinearHistogramConsts {
public:
const static std::string NAME;
};

template <typename T, typename Data, bool has_offset>
class AggregateFunctionLinearHistogram final
: public IAggregateFunctionDataHelper<
Data, AggregateFunctionLinearHistogram<T, Data, has_offset>> {
public:
using ColVecType = ColumnVectorOrDecimal<T>;

AggregateFunctionLinearHistogram(const DataTypes& argument_types_)
: IAggregateFunctionDataHelper<Data,
AggregateFunctionLinearHistogram<T, Data, has_offset>>(
argument_types_),
scale(get_decimal_scale(*argument_types_[0])) {}

std::string get_name() const override { return AggregateFunctionLinearHistogramConsts::NAME; }

DataTypePtr get_return_type() const override { return std::make_shared<DataTypeString>(); }

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena* arena) const override {
double interval =
assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1])
.get_data()[row_num];
if (interval <= 0) {
throw doris::Exception(
ErrorCode::INVALID_ARGUMENT,
"Invalid interval {}, row_num {}, interval should be larger than 0", interval,
row_num);
}

double offset = 0;
if constexpr (has_offset) {
offset = assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[2])
.get_data()[row_num];
if (offset < 0 || offset >= interval) {
throw doris::Exception(
ErrorCode::INVALID_ARGUMENT,
"Invalid offset {}, row_num {}, offset should be in [0, interval)", offset,
row_num);
}
}

this->data(place).set_parameters(interval, offset);

this->data(place).add(
assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0])
.get_data()[row_num],
scale);
}

void reset(AggregateDataPtr place) const override { this->data(place).reset(); }

void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
Arena* arena) const override {
this->data(place).merge(this->data(rhs));
}

void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
this->data(place).write(buf);
}

void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
Arena*) const override {
this->data(place).read(buf);
}

void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
this->data(place).insert_result_into(to);
}

private:
UInt32 scale;
};

} // namespace doris::vectorized
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ void register_aggregate_function_collect_list(AggregateFunctionSimpleFactory& fa
void register_aggregate_function_sequence_match(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_avg_weighted(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_histogram(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_linear_histogram(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_map_agg(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_bitmap_agg(AggregateFunctionSimpleFactory& factory);
void register_aggregate_functions_corr(AggregateFunctionSimpleFactory& factory);
Expand Down Expand Up @@ -110,6 +111,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
register_aggregate_function_sequence_match(instance);
register_aggregate_function_avg_weighted(instance);
register_aggregate_function_histogram(instance);
register_aggregate_function_linear_histogram(instance);
register_aggregate_function_map_agg(instance);
register_aggregate_function_bitmap_agg(instance);

Expand Down
Loading

0 comments on commit 94403f7

Please sign in to comment.