diff --git a/dbms/src/AggregateFunctions/AggregateFunctionArray.h b/dbms/src/AggregateFunctions/AggregateFunctionArray.h index 12d71a390fd..a4457c226cb 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionArray.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionArray.h @@ -92,7 +92,7 @@ class AggregateFunctionArray final : public IAggregateFunctionHelper(*columns[i]).getData(); - const ColumnArray & first_array_column = static_cast(*columns[0]); + const auto & first_array_column = static_cast(*columns[0]); const IColumn::Offsets & offsets = first_array_column.getOffsets(); size_t begin = row_num == 0 ? 0 : offsets[row_num - 1]; @@ -101,7 +101,7 @@ class AggregateFunctionArray final : public IAggregateFunctionHelper(*columns[i]); + const auto & ith_column = static_cast(*columns[i]); const IColumn::Offsets & ith_offsets = ith_column.getOffsets(); if (ith_offsets[row_num] != end || (row_num != 0 && ith_offsets[row_num - 1] != begin)) @@ -132,6 +132,11 @@ class AggregateFunctionArray final : public IAggregateFunctionHelperinsertResultInto(place, to, arena); } + void insertMergeResultInto(ConstAggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override + { + nested_func->insertMergeResultInto(place, to, arena); + } + bool allocatesMemoryInArena() const override { return nested_func->allocatesMemoryInArena(); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionForEach.h b/dbms/src/AggregateFunctions/AggregateFunctionForEach.h index 48681d802e3..0c158572403 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionForEach.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionForEach.h @@ -156,7 +156,7 @@ class AggregateFunctionForEach final : public IAggregateFunctionDataHelper(*columns[i]).getData(); - const ColumnArray & first_array_column = static_cast(*columns[0]); + const auto & first_array_column = static_cast(*columns[0]); const IColumn::Offsets & offsets = first_array_column.getOffsets(); size_t begin = row_num == 0 ? 0 : offsets[row_num - 1]; @@ -165,7 +165,7 @@ class AggregateFunctionForEach final : public IAggregateFunctionDataHelper(*columns[i]); + const auto & ith_column = static_cast(*columns[i]); const IColumn::Offsets & ith_offsets = ith_column.getOffsets(); if (ith_offsets[row_num] != end || (row_num != 0 && ith_offsets[row_num - 1] != begin)) @@ -229,11 +229,12 @@ class AggregateFunctionForEach final : public IAggregateFunctionDataHelper + void insertResultIntoImpl(ConstAggregateDataPtr __restrict place, IColumn & to, Arena * arena) const { const AggregateFunctionForEachData & state = data(place); - ColumnArray & arr_to = static_cast(to); + auto & arr_to = static_cast(to); ColumnArray::Offsets & offsets_to = arr_to.getOffsets(); IColumn & elems_to = arr_to.getData(); @@ -247,6 +248,16 @@ class AggregateFunctionForEach final : public IAggregateFunctionDataHelper(place, to, arena); + } + + void insertMergeResultInto(ConstAggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override + { + insertResultIntoImpl(place, to, arena); + } + bool allocatesMemoryInArena() const override { return true; diff --git a/dbms/src/AggregateFunctions/AggregateFunctionIf.h b/dbms/src/AggregateFunctions/AggregateFunctionIf.h index b76466b8c8a..3dc8a3cc97f 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionIf.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionIf.h @@ -155,6 +155,11 @@ class AggregateFunctionIf final : public IAggregateFunctionHelperinsertResultInto(place, to, arena); } + void insertMergeResultInto(ConstAggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override + { + nested_func->insertMergeResultInto(place, to, arena); + } + bool allocatesMemoryInArena() const override { return nested_func->allocatesMemoryInArena(); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionNull.h b/dbms/src/AggregateFunctions/AggregateFunctionNull.h index 977af8117ab..7686e1a952e 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionNull.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionNull.h @@ -176,14 +176,18 @@ class AggregateFunctionNullBase : public IAggregateFunctionHelper } } - void insertResultInto(ConstAggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override + template + void insertResultIntoImpl(ConstAggregateDataPtr __restrict place, IColumn & to, Arena * arena) const { if constexpr (result_is_nullable) { auto & to_concrete = static_cast(to); if (getFlag(place)) { - nested_function->insertResultInto(nestedPlace(place), to_concrete.getNestedColumn(), arena); + if constexpr (merge) + nested_function->insertMergeResultInto(nestedPlace(place), to_concrete.getNestedColumn(), arena); + else + nested_function->insertResultInto(nestedPlace(place), to_concrete.getNestedColumn(), arena); to_concrete.getNullMapData().push_back(0); } else @@ -193,10 +197,23 @@ class AggregateFunctionNullBase : public IAggregateFunctionHelper } else { - nested_function->insertResultInto(nestedPlace(place), to, arena); + if constexpr (merge) + nested_function->insertMergeResultInto(nestedPlace(place), to, arena); + else + nested_function->insertResultInto(nestedPlace(place), to, arena); } } + void insertResultInto(ConstAggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override + { + insertResultIntoImpl(place, to, arena); + } + + void insertMergeResultInto(ConstAggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override + { + insertResultIntoImpl(place, to, arena); + } + bool allocatesMemoryInArena() const override { return nested_function->allocatesMemoryInArena(); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionState.h b/dbms/src/AggregateFunctions/AggregateFunctionState.h index da42da9fe65..a5a363dc3dd 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionState.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionState.h @@ -98,6 +98,11 @@ class AggregateFunctionState final : public IAggregateFunctionHelper(to).getData().push_back(const_cast(place)); } + void insertMergeResultInto(ConstAggregateDataPtr __restrict place, IColumn & to, Arena *) const override + { + assert_cast(to).insertFrom(place); + } + /// Aggregate function or aggregate function state. bool isState() const override { return true; } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h index 3bb2cf2747d..edbbe654988 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h @@ -88,7 +88,7 @@ class AggregateFunctionSumMap final : public IAggregateFunctionDataHelper(*columns[0]); + const auto & array_column = static_cast(*columns[0]); const IColumn::Offsets & offsets = array_column.getOffsets(); const auto & keys_vec = static_cast &>(array_column.getData()); const size_t keys_vec_offset = row_num == 0 ? 0 : offsets[row_num - 1]; @@ -99,7 +99,7 @@ class AggregateFunctionSumMap final : public IAggregateFunctionDataHelper(*columns[col + 1]); + const auto & array_column = static_cast(*columns[col + 1]); const IColumn::Offsets & offsets = array_column.getOffsets(); const size_t values_vec_offset = row_num == 0 ? 0 : offsets[row_num - 1]; const size_t values_vec_size = (offsets[row_num] - values_vec_offset); diff --git a/dbms/src/AggregateFunctions/IAggregateFunction.h b/dbms/src/AggregateFunctions/IAggregateFunction.h index 4bf308dc21f..ac65df7b6ba 100644 --- a/dbms/src/AggregateFunctions/IAggregateFunction.h +++ b/dbms/src/AggregateFunctions/IAggregateFunction.h @@ -109,6 +109,18 @@ class IAggregateFunction /// Inserts results into a column. virtual void insertResultInto(ConstAggregateDataPtr __restrict place, IColumn & to, Arena * arena) const = 0; + // Special method for aggregate functions with -State combinator, it behaves the same way as insertResultInto, + // but if we need to insert AggregateData into ColumnAggregateFunction we use special method + // insertInto that inserts default value and then performs merge with provided AggregateData + // instead of just copying pointer to this AggregateData. Used in WindowTransform. + virtual void insertMergeResultInto(ConstAggregateDataPtr __restrict place, IColumn & to, Arena * arena) const + { + if (isState()) + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Function is marked as State but method insertMergeResultInto is not implemented"); + + insertResultInto(place, to, arena); + } + /** Returns true for aggregate functions of type -State. * They are executed as other aggregate functions, but not finalized (return an aggregation state that can be combined with another). */ diff --git a/dbms/src/AggregateFunctions/IAggregateFunctionCombinator.h b/dbms/src/AggregateFunctions/IAggregateFunctionCombinator.h index 73996be5a24..ec61de36d43 100644 --- a/dbms/src/AggregateFunctions/IAggregateFunctionCombinator.h +++ b/dbms/src/AggregateFunctions/IAggregateFunctionCombinator.h @@ -62,7 +62,7 @@ class IAggregateFunctionCombinator const DataTypes & arguments, const Array & params) const = 0; - virtual ~IAggregateFunctionCombinator() {} + virtual ~IAggregateFunctionCombinator() = default; }; using AggregateFunctionCombinatorPtr = std::shared_ptr; diff --git a/dbms/src/Common/AlignedBuffer.cpp b/dbms/src/Common/AlignedBuffer.cpp new file mode 100644 index 00000000000..db4f4625f6e --- /dev/null +++ b/dbms/src/Common/AlignedBuffer.cpp @@ -0,0 +1,62 @@ +// Copyright 2023 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ +extern const int CANNOT_ALLOCATE_MEMORY; +} + + +void AlignedBuffer::alloc(size_t size, size_t alignment) +{ + void * new_buf; + int res = ::posix_memalign(&new_buf, std::max(alignment, sizeof(void *)), size); + if (0 != res) + throwFromErrno(fmt::format("Cannot allocate memory (posix_memalign), size: {}, alignment: {}.", + size, + alignment), + ErrorCodes::CANNOT_ALLOCATE_MEMORY, + res); + buf = new_buf; +} + +void AlignedBuffer::dealloc() +{ + if (buf) + ::free(buf); +} + +void AlignedBuffer::reset(size_t size, size_t alignment) +{ + dealloc(); + alloc(size, alignment); +} + +AlignedBuffer::AlignedBuffer(size_t size, size_t alignment) +{ + alloc(size, alignment); +} + +AlignedBuffer::~AlignedBuffer() +{ + dealloc(); +} + +} // namespace DB diff --git a/dbms/src/Common/AlignedBuffer.h b/dbms/src/Common/AlignedBuffer.h new file mode 100644 index 00000000000..d244760d773 --- /dev/null +++ b/dbms/src/Common/AlignedBuffer.h @@ -0,0 +1,48 @@ +// Copyright 2023 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +namespace DB +{ + +/** Aligned piece of memory. + * It can only be allocated and destroyed. + * MemoryTracker is not used. AlignedBuffer is intended for small pieces of memory. + */ +class AlignedBuffer : private boost::noncopyable +{ +private: + void * buf = nullptr; + + void alloc(size_t size, size_t alignment); + void dealloc(); + +public: + AlignedBuffer() = default; + AlignedBuffer(size_t size, size_t alignment); + AlignedBuffer(AlignedBuffer && old) noexcept { std::swap(buf, old.buf); } + ~AlignedBuffer(); + + void reset(size_t size, size_t alignment); + + char * data() { return static_cast(buf); } + const char * data() const { return static_cast(buf); } +}; + +} // namespace DB diff --git a/dbms/src/DataStreams/WindowBlockInputStream.cpp b/dbms/src/DataStreams/WindowBlockInputStream.cpp index a3d773d3cc4..a8ffb471d93 100644 --- a/dbms/src/DataStreams/WindowBlockInputStream.cpp +++ b/dbms/src/DataStreams/WindowBlockInputStream.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 PingCAP, Ltd. +// Copyright 2023 PingCAP, Ltd. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -38,7 +38,6 @@ WindowTransformAction::WindowTransformAction(const Block & input_header, const W } initialWorkspaces(); - initialPartitionAndOrderColumnIndices(); } @@ -72,6 +71,23 @@ void WindowTransformAction::initialPartitionAndOrderColumnIndices() } } +void WindowTransformAction::initialAggregateFunction(WindowFunctionWorkspace & workspace, const WindowFunctionDescription & window_function_description) +{ + if (window_function_description.aggregate_function == nullptr) + return; + + // Some initialization for aggregate function + workspace.aggregate_function = window_function_description.aggregate_function; + const auto & aggregate_function = workspace.aggregate_function; + if (!arena && aggregate_function->allocatesMemoryInArena()) + arena = std::make_unique(); + + workspace.aggregate_function_state.reset( + aggregate_function->sizeOfData(), + aggregate_function->alignOfData()); + aggregate_function->create(workspace.aggregate_function_state.data()); +} + void WindowTransformAction::initialWorkspaces() { // Initialize window function workspaces. @@ -81,7 +97,10 @@ void WindowTransformAction::initialWorkspaces() { WindowFunctionWorkspace workspace; workspace.window_function = window_function_description.window_function; - workspace.arguments = window_function_description.arguments; + workspace.argument_column_indices = window_function_description.arguments; + workspace.argument_columns.assign(workspace.argument_column_indices.size(), nullptr); + + initialAggregateFunction(workspace, window_function_description); workspaces.push_back(std::move(workspace)); } only_have_row_number = onlyHaveRowNumber(); @@ -533,7 +552,21 @@ void WindowTransformAction::writeOutCurrentRow() for (size_t wi = 0; wi < workspaces.size(); ++wi) { auto & ws = workspaces[wi]; - ws.window_function->windowInsertResultInto(*this, wi, ws.arguments); + if (ws.window_function) + ws.window_function->windowInsertResultInto(*this, wi, ws.argument_column_indices); + else + { + const auto & block = blockAt(current_row); + IColumn * result_column = block.output_columns[wi].get(); + const auto * agg_func = ws.aggregate_function.get(); + auto * buf = ws.aggregate_function_state.data(); + // FIXME does it also allocate the result on the arena? + // We'll have to pass it out with blocks then... + + // We should use insertMergeResultInto to insert result into ColumnAggregateFunction + // correctly if result contains AggregateFunction's states + agg_func->insertMergeResultInto(buf, *result_column, arena.get()); + } } } @@ -568,7 +601,7 @@ bool WindowTransformAction::onlyHaveRowNumber() { for (const auto & workspace : workspaces) { - if (workspace.window_function->getName() != "row_number") + if (workspace.window_function != nullptr && workspace.window_function->getName() != "row_number") return false; } return true; @@ -619,7 +652,12 @@ void WindowTransformAction::appendBlock(Block & current_block) // Initialize output columns and add new columns to output block. for (auto & ws : workspaces) { - MutableColumnPtr res = ws.window_function->getReturnType()->createColumn(); + MutableColumnPtr res; + + if (ws.window_function != nullptr) + res = ws.window_function->getReturnType()->createColumn(); + else + res = ws.aggregate_function->getReturnType()->createColumn(); res->reserve(window_block.rows); window_block.output_columns.push_back(std::move(res)); } @@ -627,6 +665,102 @@ void WindowTransformAction::appendBlock(Block & current_block) window_block.input_columns = current_block.getColumns(); } +// Update the aggregation states after the frame has changed. +void WindowTransformAction::updateAggregationState() +{ + // Assert that the frame boundaries are known, have proper order wrt each + // other, and have not gone back wrt the previous frame. + assert(frame_started); + assert(frame_ended); + assert(frame_start <= frame_end); + assert(prev_frame_start <= prev_frame_end); + assert(prev_frame_start <= frame_start); + assert(prev_frame_end <= frame_end); + assert(partition_start <= frame_start); + assert(frame_end <= partition_end); + + // We might have to reset aggregation state and/or add some rows to it. + // Figure out what to do. + bool reset_aggregation = false; + RowNumber rows_to_add_start; + RowNumber rows_to_add_end; + if (frame_start == prev_frame_start) + { + // The frame start didn't change, add the tail rows. + reset_aggregation = false; + rows_to_add_start = prev_frame_end; + rows_to_add_end = frame_end; + } + else + { + // The frame start changed, reset the state and aggregate over the + // entire frame. This can be made per-function after we learn to + // subtract rows from some types of aggregation states, but for now we + // always have to reset when the frame start changes. + reset_aggregation = true; + rows_to_add_start = frame_start; + rows_to_add_end = frame_end; + } + + for (auto & ws : workspaces) + { + if (ws.window_function) + continue; // No need to do anything for true window functions. + + const auto * agg_func = ws.aggregate_function.get(); + auto * buf = ws.aggregate_function_state.data(); + + if (reset_aggregation) + { + agg_func->destroy(buf); + agg_func->create(buf); + } + + // To achieve better performance, we will have to loop over blocks and + // rows manually, instead of using advanceRowNumber(). + // For this purpose, the past-the-end block can be different than the + // block of the past-the-end row (it's usually the next block). + const auto past_the_end_block = rows_to_add_end.row == 0 + ? rows_to_add_end.block + : rows_to_add_end.block + 1; + + for (auto block_number = rows_to_add_start.block; + block_number < past_the_end_block; + ++block_number) + { + auto & block = blockAt(block_number); + + if (ws.cached_block_number != block_number) + { + for (size_t i = 0; i < ws.argument_column_indices.size(); ++i) + { + ws.argument_columns[i] = block.input_columns[ws.argument_column_indices[i]].get(); + } + ws.cached_block_number = block_number; + } + + // First and last blocks may be processed partially, and other blocks + // are processed in full. + const auto first_row = block_number == rows_to_add_start.block + ? rows_to_add_start.row + : 0; + const auto past_the_end_row = block_number == rows_to_add_end.block + ? rows_to_add_end.row + : block.rows; + + // We should add an addBatch analog that can accept a starting offset. + // For now, add the values one by one. + auto * columns = ws.argument_columns.data(); + // Removing arena.get() from the loop makes it faster somehow... + auto * arena_ptr = arena.get(); + for (auto row = first_row; row < past_the_end_row; ++row) + { + agg_func->add(buf, columns, row, arena_ptr); + } + } + } +} + void WindowTransformAction::tryCalculate() { // Start the calculations. First, advance the partition end. @@ -686,11 +820,19 @@ void WindowTransformAction::tryCalculate() assert(frame_ended); assert(frame_start <= frame_end); + // Now that we know the new frame boundaries, update the aggregation + // states. Theoretically we could do this simultaneously with moving + // the frame boundaries, but it would require some care not to + // perform unnecessary work while we are still looking for the frame + // start, so do it the simple way for now. + updateAggregationState(); + // Write out the results. // TODO execute the window function by block instead of row. writeOutCurrentRow(); prev_frame_start = frame_start; + prev_frame_end = frame_end; // Move to the next row. The frame will have to be recalculated. // The peer group start is updated at the beginning of the loop, @@ -733,6 +875,43 @@ void WindowTransformAction::tryCalculate() peer_group_last = partition_start; peer_group_start_row_number = 1; peer_group_number = 1; + + reinitializeAggFuncBeforeNextPartition(); + } +} + +void WindowTransformAction::reinitializeAggFuncBeforeNextPartition() +{ + // Reinitialize the aggregate function states because the new partition + // has started. + for (auto & ws : workspaces) + { + if (ws.window_function) + continue; + + const auto * a = ws.aggregate_function.get(); + auto * buf = ws.aggregate_function_state.data(); + + a->destroy(buf); + } + + // Release the arena we use for aggregate function states, so that it + // doesn't grow without limit. Not sure if it's actually correct, maybe + // it allocates the return values in the Arena as well... + if (arena) + { + arena = std::make_unique(); + } + + for (auto & ws : workspaces) + { + if (ws.window_function) + continue; + + const auto * a = ws.aggregate_function.get(); + auto * buf = ws.aggregate_function_state.data(); + + a->create(buf); } } @@ -743,7 +922,10 @@ void WindowTransformAction::appendInfo(FmtBuffer & buffer) const window_description.window_functions_descriptions.begin(), window_description.window_functions_descriptions.end(), [&](const auto & func, FmtBuffer & b) { - b.append(func.window_function->getName()); + if (func.window_function != nullptr) + b.append(func.window_function->getName()); + else + b.append(func.aggregate_function->getName()); }, ", "); buffer.fmtAppend( diff --git a/dbms/src/DataStreams/WindowBlockInputStream.h b/dbms/src/DataStreams/WindowBlockInputStream.h index fca4fa7ea0e..394a0e0436e 100644 --- a/dbms/src/DataStreams/WindowBlockInputStream.h +++ b/dbms/src/DataStreams/WindowBlockInputStream.h @@ -1,4 +1,4 @@ -// Copyright 2022 PingCAP, Ltd. +// Copyright 2023 PingCAP, Ltd. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,6 +14,8 @@ #pragma once +#include +#include #include #include #include @@ -27,10 +29,19 @@ namespace DB // Runtime data for computing one window function. struct WindowFunctionWorkspace { - // TODO add aggregation function WindowFunctionPtr window_function = nullptr; - ColumnNumbers arguments; + AggregateFunctionPtr aggregate_function; + + // Will not be initialized for a pure window function. + mutable AlignedBuffer aggregate_function_state; + + ColumnNumbers argument_column_indices; + + // Argument columns. Be careful, this is a per-block cache. + std::vector argument_columns; + + UInt64 cached_block_number = std::numeric_limits::max(); }; struct WindowBlock @@ -80,10 +91,9 @@ struct WindowTransformAction // distance is left - right. UInt64 distance(RowNumber left, RowNumber right); -public: - WindowTransformAction(const Block & input_header, const WindowDescription & window_description_, const String & req_id); - - void cleanUp(); + void initialWorkspaces(); + void initialAggregateFunction(WindowFunctionWorkspace & workspace, const WindowFunctionDescription & window_function_description); + void initialPartitionAndOrderColumnIndices(); void advancePartitionEnd(); bool isDifferentFromPrevPartition(UInt64 current_partition_row); @@ -96,11 +106,17 @@ struct WindowTransformAction void writeOutCurrentRow(); - Block tryGetOutputBlock(); void releaseAlreadyOutputWindowBlock(); - void initialWorkspaces(); - void initialPartitionAndOrderColumnIndices(); + void updateAggregationState(); + + void reinitializeAggFuncBeforeNextPartition(); + +public: + WindowTransformAction(const Block & input_header, const WindowDescription & window_description_, const String & req_id); + + void cleanUp(); + Block tryGetOutputBlock(); Columns & inputAt(const RowNumber & x) { @@ -241,6 +257,9 @@ struct WindowTransformAction // aggregate function. We use them to determine how to update the aggregation // state after we find the new frame. RowNumber prev_frame_start; + RowNumber prev_frame_end; + + std::unique_ptr arena; //TODO: used as template parameters bool only_have_row_number = false; diff --git a/dbms/src/Debug/MockExecutor/WindowBinder.cpp b/dbms/src/Debug/MockExecutor/WindowBinder.cpp index 56f808f0950..85585bfd233 100644 --- a/dbms/src/Debug/MockExecutor/WindowBinder.cpp +++ b/dbms/src/Debug/MockExecutor/WindowBinder.cpp @@ -19,11 +19,105 @@ #include #include - namespace DB::mock { using ASTPartitionByElement = ASTOrderByElement; +void setFieldTypeForAggFunc(const DB::ASTFunction * func, tipb::Expr * expr, const tipb::ExprType agg_sig, int32_t collator_id) +{ + expr->set_tp(agg_sig); + if (agg_sig == tipb::ExprType::Count || agg_sig == tipb::ExprType::Sum) + { + auto * ft = expr->mutable_field_type(); + ft->set_tp(TiDB::TypeLongLong); + ft->set_flag(TiDB::ColumnFlagNotNull); + } + else if (agg_sig == tipb::ExprType::Min || agg_sig == tipb::ExprType::Max) + { + if (expr->children_size() != 1) + throw Exception(fmt::format("Agg function({}) only accept 1 argument", func->name)); + + auto * ft = expr->mutable_field_type(); + ft->set_tp(expr->children(0).field_type().tp()); + ft->set_decimal(expr->children(0).field_type().decimal()); + ft->set_flag(expr->children(0).field_type().flag() & (~TiDB::ColumnFlagNotNull)); + ft->set_collate(collator_id); + } + else + { + throw Exception("Window does not support this agg function"); + } + + expr->set_aggfuncmode(tipb::AggFunctionMode::FinalMode); +} + +void setFieldTypeForWindowFunc(tipb::Expr * expr, const tipb::ExprType window_sig, int32_t collator_id) +{ + expr->set_tp(window_sig); + auto * ft = expr->mutable_field_type(); + switch (window_sig) + { + case tipb::ExprType::Lead: + case tipb::ExprType::Lag: + { + // TODO handling complex situations + // like lead(col, offset, NULL), lead(data_type1, offset, data_type2) + assert(expr->children_size() >= 1 && expr->children_size() <= 3); + const auto first_arg_type = expr->children(0).field_type(); + ft->set_tp(first_arg_type.tp()); + if (expr->children_size() < 3) + { + auto field_type = TiDB::fieldTypeToColumnInfo(first_arg_type); + field_type.clearNotNullFlag(); + ft->set_flag(field_type.flag); + } + else + { + const auto third_arg_type = expr->children(2).field_type(); + assert(first_arg_type.tp() == third_arg_type.tp()); + ft->set_flag(TiDB::fieldTypeToColumnInfo(first_arg_type).hasNotNullFlag() + ? third_arg_type.flag() + : first_arg_type.flag()); + } + ft->set_collate(first_arg_type.collate()); + ft->set_flen(first_arg_type.flen()); + ft->set_decimal(first_arg_type.decimal()); + break; + } + case tipb::ExprType::FirstValue: + case tipb::ExprType::LastValue: + { + assert(expr->children_size() == 1); + const auto arg_type = expr->children(0).field_type(); + (*ft) = arg_type; + break; + } + default: + ft->set_tp(TiDB::TypeLongLong); + ft->set_flag(TiDB::ColumnFlagBinary); + ft->set_collate(collator_id); + ft->set_flen(21); + ft->set_decimal(-1); + } +} + +void setFieldType(const DB::ASTFunction * func, tipb::Expr * expr, int32_t collator_id) +{ + auto window_sig_it = tests::window_func_name_to_sig.find(func->name); + if (window_sig_it != tests::window_func_name_to_sig.end()) + { + setFieldTypeForWindowFunc(expr, window_sig_it->second, collator_id); + return; + } + + auto agg_sig_it = tests::agg_func_name_to_sig.find(func->name); + if (agg_sig_it == tests::agg_func_name_to_sig.end()) + throw Exception("Unsupported agg function: " + func->name, ErrorCodes::LOGICAL_ERROR); + + auto agg_sig = agg_sig_it->second; + setFieldTypeForAggFunc(func, expr, agg_sig, collator_id); +} + bool WindowBinder::toTiPBExecutor(tipb::Executor * tipb_executor, int32_t collator_id, const MPPInfo & mpp_info, const Context & context) { tipb_executor->set_tp(tipb::ExecType::TypeWindow); @@ -40,56 +134,8 @@ bool WindowBinder::toTiPBExecutor(tipb::Executor * tipb_executor, int32_t collat tipb::Expr * func = window_expr->add_children(); astToPB(input_schema, arg, func, collator_id, context); } - auto window_sig_it = tests::window_func_name_to_sig.find(window_func->name); - if (window_sig_it == tests::window_func_name_to_sig.end()) - throw Exception(fmt::format("Unsupported window function {}", window_func->name), ErrorCodes::LOGICAL_ERROR); - auto window_sig = window_sig_it->second; - window_expr->set_tp(window_sig); - auto * ft = window_expr->mutable_field_type(); - switch (window_sig) - { - case tipb::ExprType::Lead: - case tipb::ExprType::Lag: - { - // TODO handling complex situations - // like lead(col, offset, NULL), lead(data_type1, offset, data_type2) - assert(window_expr->children_size() >= 1 && window_expr->children_size() <= 3); - const auto first_arg_type = window_expr->children(0).field_type(); - ft->set_tp(first_arg_type.tp()); - if (window_expr->children_size() < 3) - { - auto field_type = TiDB::fieldTypeToColumnInfo(first_arg_type); - field_type.clearNotNullFlag(); - ft->set_flag(field_type.flag); - } - else - { - const auto third_arg_type = window_expr->children(2).field_type(); - assert(first_arg_type.tp() == third_arg_type.tp()); - ft->set_flag(TiDB::fieldTypeToColumnInfo(first_arg_type).hasNotNullFlag() - ? third_arg_type.flag() - : first_arg_type.flag()); - } - ft->set_collate(first_arg_type.collate()); - ft->set_flen(first_arg_type.flen()); - ft->set_decimal(first_arg_type.decimal()); - break; - } - case tipb::ExprType::FirstValue: - case tipb::ExprType::LastValue: - { - assert(window_expr->children_size() == 1); - const auto arg_type = window_expr->children(0).field_type(); - (*ft) = arg_type; - break; - } - default: - ft->set_tp(TiDB::TypeLongLong); - ft->set_flag(TiDB::ColumnFlagBinary); - ft->set_collate(collator_id); - ft->set_flen(21); - ft->set_decimal(-1); - } + + setFieldType(window_func, window_expr, collator_id); } for (const auto & child : order_by_exprs) @@ -149,6 +195,81 @@ bool WindowBinder::toTiPBExecutor(tipb::Executor * tipb_executor, int32_t collat return children[0]->toTiPBExecutor(children_executor, collator_id, mpp_info, context); } +void setColumnInfoForAgg(TiDB::ColumnInfo & ci, const DB::ASTFunction * func, const std::vector & children_ci) +{ + // TODO: Other agg func. + if (func->name == "count") + { + ci.tp = TiDB::TypeLongLong; + ci.flag = TiDB::ColumnFlagUnsigned | TiDB::ColumnFlagNotNull; + } + else if (func->name == "max" || func->name == "min" || func->name == "sum") + { + ci = children_ci[0]; + ci.flag &= ~TiDB::ColumnFlagNotNull; + } + else + { + throw Exception("Unsupported agg function: " + func->name, ErrorCodes::LOGICAL_ERROR); + } +} + +void setColumnInfoForWindowFunc(TiDB::ColumnInfo & ci, const DB::ASTFunction * func, const std::vector & children_ci, tipb::ExprType expr_type) +{ + // TODO: add more window functions + switch (expr_type) + { + case tipb::ExprType::RowNumber: + case tipb::ExprType::Rank: + case tipb::ExprType::DenseRank: + { + ci.tp = TiDB::TypeLongLong; + ci.flag = TiDB::ColumnFlagBinary; + break; + } + case tipb::ExprType::Lead: + case tipb::ExprType::Lag: + { + // TODO handling complex situations + // like lead(col, offset, NULL), lead(data_type1, offset, data_type2) + assert(!children_ci.empty() && children_ci.size() <= 3); + if (children_ci.size() < 3) + { + ci = children_ci[0]; + ci.clearNotNullFlag(); + } + else + { + assert(children_ci[0].tp == children_ci[2].tp); + ci = children_ci[0].hasNotNullFlag() ? children_ci[2] : children_ci[0]; + } + break; + } + case tipb::ExprType::FirstValue: + case tipb::ExprType::LastValue: + { + ci = children_ci[0]; + break; + } + default: + throw Exception(fmt::format("Unsupported window function {}", func->name), ErrorCodes::LOGICAL_ERROR); + } +} + +TiDB::ColumnInfo createColumnInfo(const DB::ASTFunction * func, const std::vector & children_ci) +{ + TiDB::ColumnInfo ci; + auto iter = tests::window_func_name_to_sig.find(func->name); + if (iter != tests::window_func_name_to_sig.end()) + { + setColumnInfoForWindowFunc(ci, func, children_ci, iter->second); + return ci; + } + + setColumnInfoForAgg(ci, func, children_ci); + return ci; +} + ExecutorBinderPtr compileWindow(ExecutorBinderPtr input, size_t & executor_index, ASTPtr func_desc_list, ASTPtr partition_by_expr_list, ASTPtr order_by_expr_list, mock::MockWindowFrame frame, uint64_t fine_grained_shuffle_stream_count) { std::vector partition_columns; @@ -200,45 +321,8 @@ ExecutorBinderPtr compileWindow(ExecutorBinderPtr input, size_t & executor_index { children_ci.push_back(compileExpr(input->output_schema, arg)); } - // TODO: add more window functions - TiDB::ColumnInfo ci; - switch (tests::window_func_name_to_sig[func->name]) - { - case tipb::ExprType::RowNumber: - case tipb::ExprType::Rank: - case tipb::ExprType::DenseRank: - { - ci.tp = TiDB::TypeLongLong; - ci.flag = TiDB::ColumnFlagBinary; - break; - } - case tipb::ExprType::Lead: - case tipb::ExprType::Lag: - { - // TODO handling complex situations - // like lead(col, offset, NULL), lead(data_type1, offset, data_type2) - assert(!children_ci.empty() && children_ci.size() <= 3); - if (children_ci.size() < 3) - { - ci = children_ci[0]; - ci.clearNotNullFlag(); - } - else - { - assert(children_ci[0].tp == children_ci[2].tp); - ci = children_ci[0].hasNotNullFlag() ? children_ci[2] : children_ci[0]; - } - break; - } - case tipb::ExprType::FirstValue: - case tipb::ExprType::LastValue: - { - ci = children_ci[0]; - break; - } - default: - throw Exception(fmt::format("Unsupported window function {}", func->name), ErrorCodes::LOGICAL_ERROR); - } + + TiDB::ColumnInfo ci = createColumnInfo(func, children_ci); output_schema.emplace_back(std::make_pair(func->getColumnName(), ci)); } } diff --git a/dbms/src/Debug/MockExecutor/WindowBinder.h b/dbms/src/Debug/MockExecutor/WindowBinder.h index c96775dc706..0d71e729379 100644 --- a/dbms/src/Debug/MockExecutor/WindowBinder.h +++ b/dbms/src/Debug/MockExecutor/WindowBinder.h @@ -157,6 +157,10 @@ class WindowBinder : public ExecutorBinder bool toTiPBExecutor(tipb::Executor * tipb_executor, int32_t collator_id, const MPPInfo & mpp_info, const Context & context) override; +private: + void buildWindowFunc(); + void buildAggFunc(); + private: std::vector func_descs; std::vector partition_by_exprs; diff --git a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp index ebc0180cbaa..64a85f99d4d 100644 --- a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp +++ b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp @@ -39,6 +39,7 @@ #include #include #include +#include namespace DB @@ -168,14 +169,15 @@ void appendWindowDescription( const Names & arg_names, const DataTypes & arg_types, TiDB::TiDBCollators & arg_collators, - const String & window_func_name, + const String & func_name, WindowDescription & window_description, NamesAndTypes & source_columns, - NamesAndTypes & window_columns) + NamesAndTypes & window_columns, + bool is_agg) { assert(arg_names.size() == arg_collators.size() && arg_names.size() == arg_types.size()); - String func_string = genFuncString(window_func_name, arg_names, arg_collators); + String func_string = genFuncString(func_name, arg_names, arg_collators); if (auto duplicated_return_type = findDuplicateAggWindowFunc(func_string, window_description.window_functions_descriptions)) { // window function duplicate, don't need to build again. @@ -186,12 +188,42 @@ void appendWindowDescription( WindowFunctionDescription window_function_description; window_function_description.argument_names = arg_names; window_function_description.column_name = func_string; - window_function_description.window_function = WindowFunctionFactory::instance().get(window_func_name, arg_types); - DataTypePtr result_type = window_function_description.window_function->getReturnType(); + + DataTypePtr result_type; + if (is_agg) + { + window_function_description.aggregate_function = AggregateFunctionFactory::instance().get(func_name, arg_types, {}, 0, true); + result_type = window_function_description.aggregate_function->getReturnType(); + } + else + { + window_function_description.window_function = WindowFunctionFactory::instance().get(func_name, arg_types); + result_type = window_function_description.window_function->getReturnType(); + } + window_description.window_functions_descriptions.emplace_back(std::move(window_function_description)); window_columns.emplace_back(func_string, result_type); source_columns.emplace_back(func_string, result_type); } + +bool isWindowFunction(const tipb::ExprType expr_type) +{ + switch (expr_type) + { + case tipb::ExprType::FirstValue: + case tipb::ExprType::LastValue: + case tipb::ExprType::RowNumber: + case tipb::ExprType::Rank: + case tipb::ExprType::DenseRank: + case tipb::ExprType::CumeDist: + case tipb::ExprType::PercentRank: + case tipb::ExprType::Ntile: + case tipb::ExprType::NthValue: + return true; + default: + return false; + } +} } // namespace ExpressionActionsChain::Step & DAGExpressionAnalyzer::initAndGetLastStep(ExpressionActionsChain & chain) const @@ -572,16 +604,18 @@ void DAGExpressionAnalyzer::buildLeadLag( window_func_name, window_description, source_columns, - window_columns); + window_columns, + false); } -void DAGExpressionAnalyzer::buildCommonWindowFunc( +void DAGExpressionAnalyzer::buildWindowOrAggFuncImpl( const tipb::Expr & expr, const ExpressionActionsPtr & actions, const String & window_func_name, WindowDescription & window_description, NamesAndTypes & source_columns, - NamesAndTypes & window_columns) + NamesAndTypes & window_columns, + bool is_agg) { auto child_size = expr.children_size(); Names arg_names; @@ -599,7 +633,8 @@ void DAGExpressionAnalyzer::buildCommonWindowFunc( window_func_name, window_description, source_columns, - window_columns); + window_columns, + is_agg); } // This function will add new window function culumns to source_column @@ -611,14 +646,17 @@ void DAGExpressionAnalyzer::appendWindowColumns(WindowDescription & window_descr NamesAndTypes window_columns; for (const tipb::Expr & expr : window.func_desc()) { - RUNTIME_CHECK_MSG(isWindowFunctionExpr(expr), "Now Window Operator only support window function."); if (expr.tp() == tipb::ExprType::Lead || expr.tp() == tipb::ExprType::Lag) { buildLeadLag(expr, actions, getWindowFunctionName(expr), window_description, source_columns, window_columns); } + else if (isWindowFunction(expr.tp())) + { + buildWindowOrAggFuncImpl(expr, actions, getWindowFunctionName(expr), window_description, source_columns, window_columns, false); + } else { - buildCommonWindowFunc(expr, actions, getWindowFunctionName(expr), window_description, source_columns, window_columns); + buildWindowOrAggFuncImpl(expr, actions, getAggFunctionName(expr), window_description, source_columns, window_columns, true); } } window_description.add_columns = window_columns; diff --git a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.h b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.h index 40ddd14195d..3d71eeba5f9 100644 --- a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.h +++ b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.h @@ -245,13 +245,14 @@ class DAGExpressionAnalyzer : private boost::noncopyable NamesAndTypes & source_columns, NamesAndTypes & window_columns); - void buildCommonWindowFunc( + void buildWindowOrAggFuncImpl( const tipb::Expr & expr, const ExpressionActionsPtr & actions, const String & window_func_name, WindowDescription & window_description, NamesAndTypes & source_columns, - NamesAndTypes & window_columns); + NamesAndTypes & window_columns, + bool is_agg); void fillArgumentDetail( const ExpressionActionsPtr & actions, diff --git a/dbms/src/Flash/Coprocessor/DAGUtils.cpp b/dbms/src/Flash/Coprocessor/DAGUtils.cpp index e4fb9a654b3..77b712c7975 100644 --- a/dbms/src/Flash/Coprocessor/DAGUtils.cpp +++ b/dbms/src/Flash/Coprocessor/DAGUtils.cpp @@ -51,7 +51,7 @@ const std::unordered_map agg_func_map({ {tipb::ExprType::First, "first_row"}, {tipb::ExprType::ApproxCountDistinct, uniq_raw_res_name}, {tipb::ExprType::GroupConcat, "groupArray"}, - //{tipb::ExprType::Avg, ""}, + {tipb::ExprType::Avg, "avg"}, //{tipb::ExprType::Agg_BitAnd, ""}, //{tipb::ExprType::Agg_BitOr, ""}, //{tipb::ExprType::Agg_BitXor, ""}, diff --git a/dbms/src/Interpreters/WindowDescription.h b/dbms/src/Interpreters/WindowDescription.h index e986f920f65..aa7aec9781c 100644 --- a/dbms/src/Interpreters/WindowDescription.h +++ b/dbms/src/Interpreters/WindowDescription.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include #include @@ -30,10 +31,11 @@ namespace DB struct WindowFunctionDescription { WindowFunctionPtr window_function; + AggregateFunctionPtr aggregate_function; Array parameters; ColumnNumbers arguments; Names argument_names; - std::string column_name; + String column_name; }; using WindowFunctionDescriptions = std::vector; diff --git a/dbms/src/TestUtils/WindowTestUtils.h b/dbms/src/TestUtils/WindowTestUtils.h index 5e41852c51e..74a2f2598ee 100644 --- a/dbms/src/TestUtils/WindowTestUtils.h +++ b/dbms/src/TestUtils/WindowTestUtils.h @@ -67,7 +67,7 @@ class WindowTest : public ExecutorTest static const size_t MAX_CONCURRENCY_LEVEL = 10; static constexpr auto PARTITION_COL_NAME = "partition"; static constexpr auto ORDER_COL_NAME = "order"; - static constexpr auto VALUE_COL_NAME = "first_value"; + static constexpr auto VALUE_COL_NAME = "value"; template mock::MockWindowFrameBound buildRangeFrameBound( diff --git a/dbms/src/WindowFunctions/tests/gtest_agg_func.cpp b/dbms/src/WindowFunctions/tests/gtest_agg_func.cpp new file mode 100644 index 00000000000..60576aade6d --- /dev/null +++ b/dbms/src/WindowFunctions/tests/gtest_agg_func.cpp @@ -0,0 +1,130 @@ +// Copyright 2023 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + + +namespace DB::tests +{ +class WindowAggFuncTest : public DB::tests::WindowTest +{ +public: + const ASTPtr value_col = col(VALUE_COL_NAME); + + void initializeContext() override + { + ExecutorTest::initializeContext(); + } +}; + +TEST_F(WindowAggFuncTest, windowAggSumTests) +try +{ + { + // rows frame + MockWindowFrame frame; + frame.type = tipb::WindowFrameType::Rows; + frame.start = mock::MockWindowFrameBound(tipb::WindowBoundType::Preceding, false, 0); + frame.end = mock::MockWindowFrameBound(tipb::WindowBoundType::Following, false, 3); + std::vector frame_start_offset{0, 1, 3, 10}; + + std::vector> res{ + {0, 15, 14, 12, 8, 26, 41, 38, 28, 15, 18, 32, 49, 75, 66, 51, 31}, + {0, 15, 15, 14, 12, 26, 41, 41, 38, 28, 18, 33, 52, 80, 75, 66, 51}, + {0, 15, 15, 15, 15, 26, 41, 41, 41, 41, 18, 33, 53, 84, 83, 80, 75}, + {0, 15, 15, 15, 15, 26, 41, 41, 41, 41, 18, 33, 53, 84, 84, 84, 84}}; + + for (size_t i = 0; i < frame_start_offset.size(); ++i) + { + frame.start = mock::MockWindowFrameBound(tipb::WindowBoundType::Preceding, false, frame_start_offset[i]); + + executeFunctionAndAssert( + toVec(res[i]), + Sum(value_col), + {toVec(/*partition*/ {0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3}), + toVec(/*order*/ {0, 1, 2, 4, 8, 0, 3, 10, 13, 15, 1, 3, 5, 9, 15, 20, 31}), + toVec(/*value*/ {0, 1, 2, 4, 8, 0, 3, 10, 13, 15, 1, 3, 5, 9, 15, 20, 31})}, + frame); + } + } + + // TODO uncomment these test after range frame is merged + // { + // // range frame + // MockWindowFrame frame; + // frame.type = tipb::WindowFrameType::Rows; + // frame.start = buildRangeFrameBound(tipb::WindowBoundType::Preceding, tipb::RangeCmpDataType::Int, ORDER_COL_NAME, false, 0); + // frame.end = buildRangeFrameBound(tipb::WindowBoundType::Following, tipb::RangeCmpDataType::Int, ORDER_COL_NAME, true, 3); + // std::vector frame_start_offset{0, 1, 3, 10}; + + // std::vector> res_not_null{ + // {0, 7, 6, 4, 8, 3, 3, 23, 28, 15, 4, 8, 5, 9, 15, 20, 31}, + // {0, 7, 7, 4, 8, 3, 3, 23, 28, 15, 4, 8, 5, 9, 15, 20, 31}, + // {0, 7, 7, 7, 8, 3, 3, 23, 38, 28, 4, 9, 8, 9, 15, 20, 31}, + // {0, 7, 7, 7, 15, 3, 3, 26, 41, 38, 4, 9, 9, 18, 29, 35, 31}}; + + // for (size_t i = 0; i < frame_start_offset.size(); ++i) + // { + // frame.start = buildRangeFrameBound(tipb::WindowBoundType::Preceding, tipb::RangeCmpDataType::Int, ORDER_COL_NAME, false, 0); + + // executeFunctionAndAssert( + // toVec(res_not_null[i]), + // Sum(value_col), + // {toVec(/*partition*/ {0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3}), + // toVec(/*order*/ {0, 1, 2, 4, 8, 0, 3, 10, 13, 15, 1, 3, 5, 9, 15, 20, 31}), + // toVec(/*value*/ {0, 1, 2, 4, 8, 0, 3, 10, 13, 15, 1, 3, 5, 9, 15, 20, 31})}, + // frame); + // } + // } +} +CATCH + +TEST_F(WindowAggFuncTest, windowAggCountTests) +try +{ + { + // rows frame + MockWindowFrame frame; + frame.type = tipb::WindowFrameType::Rows; + frame.start = mock::MockWindowFrameBound(tipb::WindowBoundType::Preceding, false, 0); + frame.end = mock::MockWindowFrameBound(tipb::WindowBoundType::Following, false, 3); + std::vector frame_start_offset{0, 1, 3, 10}; + + std::vector> res{ + {1, 4, 3, 2, 1, 4, 4, 3, 2, 1, 4, 4, 4, 4, 3, 2, 1}, + {1, 4, 4, 3, 2, 4, 5, 4, 3, 2, 4, 5, 5, 5, 4, 3, 2}, + {1, 4, 4, 4, 4, 4, 5, 5, 5, 4, 4, 5, 6, 7, 6, 5, 4}, + {1, 4, 4, 4, 4, 4, 5, 5, 5, 5, 4, 5, 6, 7, 7, 7, 7}}; + + for (size_t i = 0; i < frame_start_offset.size(); ++i) + { + frame.start = mock::MockWindowFrameBound(tipb::WindowBoundType::Preceding, false, frame_start_offset[i]); + + executeFunctionAndAssert( + toVec(res[i]), + Count(value_col), + {toVec(/*partition*/ {0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3}), + toVec(/*order*/ {0, 1, 2, 4, 8, 0, 3, 10, 13, 15, 1, 3, 5, 9, 15, 20, 31}), + toVec(/*value*/ {0, 1, 2, 4, 8, 0, 3, 10, 13, 15, 1, 3, 5, 9, 15, 20, 31})}, + frame); + } + } + // TODO add range frame tests after that is merged +} +CATCH +} // namespace DB::tests diff --git a/tests/fullstack-test/mpp/window_agg.test b/tests/fullstack-test/mpp/window_agg.test new file mode 100644 index 00000000000..570b21ad524 --- /dev/null +++ b/tests/fullstack-test/mpp/window_agg.test @@ -0,0 +1,139 @@ +# Copyright 2023 PingCAP, Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +mysql> drop table if exists test.agg; +mysql> create table test.agg(p int not null, o int not null, v int not null); +mysql> insert into test.agg (p, o, v) values (0, 0, 0), (1, 1, 1), (1, 2, 2), (1, 4, 4), (1, 8, 8), (2, 0, 0), (2, 3, 3), (2, 10, 10), (2, 13, 13), (2, 15, 15), (3, 1, 1), (3, 3, 3), (3, 5, 5), (3, 9, 9), (3, 15, 15), (3, 20, 20), (3, 31, 31); +mysql> alter table agg set tiflash replica 1; + +func> wait_table test test.agg + +mysql> use test; set tidb_enforce_mpp=1; + +ast.AggFuncSum, ast.AggFuncCount, ast.AggFuncAvg, ast.AggFuncMax, ast.AggFuncMin + +mysql> use test; set tidb_enforce_mpp=1; select *, sum(v) over (partition by p order by o rows between 1 preceding and 1 following) as a from test.agg; ++---+----+----+------+ +| p | o | v | a | ++---+----+----+------+ +| 0 | 0 | 0 | 0 | +| 1 | 1 | 1 | 3 | +| 1 | 2 | 2 | 7 | +| 1 | 4 | 4 | 14 | +| 1 | 8 | 8 | 12 | +| 2 | 0 | 0 | 3 | +| 2 | 3 | 3 | 13 | +| 2 | 10 | 10 | 26 | +| 2 | 13 | 13 | 38 | +| 2 | 15 | 15 | 28 | +| 3 | 1 | 1 | 4 | +| 3 | 3 | 3 | 9 | +| 3 | 5 | 5 | 17 | +| 3 | 9 | 9 | 29 | +| 3 | 15 | 15 | 44 | +| 3 | 20 | 20 | 66 | +| 3 | 31 | 31 | 51 | ++---+----+----+------+ + +mysql> use test; set tidb_enforce_mpp=1; select *, count(v) over (partition by p order by o rows between 1 preceding and 1 following) as a from test.agg; ++---+----+----+---+ +| p | o | v | a | ++---+----+----+---+ +| 0 | 0 | 0 | 1 | +| 1 | 1 | 1 | 2 | +| 1 | 2 | 2 | 3 | +| 1 | 4 | 4 | 3 | +| 1 | 8 | 8 | 2 | +| 2 | 0 | 0 | 2 | +| 2 | 3 | 3 | 3 | +| 2 | 10 | 10 | 3 | +| 2 | 13 | 13 | 3 | +| 2 | 15 | 15 | 2 | +| 3 | 1 | 1 | 2 | +| 3 | 3 | 3 | 3 | +| 3 | 5 | 5 | 3 | +| 3 | 9 | 9 | 3 | +| 3 | 15 | 15 | 3 | +| 3 | 20 | 20 | 3 | +| 3 | 31 | 31 | 2 | ++---+----+----+---+ + +mysql> use test; set tidb_enforce_mpp=1; select *, avg(v) over (partition by p order by o rows between 1 preceding and 1 following) as a from test.agg; ++---+----+----+---------+ +| p | o | v | a | ++---+----+----+---------+ +| 0 | 0 | 0 | 0.0000 | +| 1 | 1 | 1 | 1.5000 | +| 1 | 2 | 2 | 2.3333 | +| 1 | 4 | 4 | 4.6666 | +| 1 | 8 | 8 | 6.0000 | +| 2 | 0 | 0 | 1.5000 | +| 2 | 3 | 3 | 4.3333 | +| 2 | 10 | 10 | 8.6666 | +| 2 | 13 | 13 | 12.6666 | +| 2 | 15 | 15 | 14.0000 | +| 3 | 1 | 1 | 2.0000 | +| 3 | 3 | 3 | 3.0000 | +| 3 | 5 | 5 | 5.6666 | +| 3 | 9 | 9 | 9.6666 | +| 3 | 15 | 15 | 14.6666 | +| 3 | 20 | 20 | 22.0000 | +| 3 | 31 | 31 | 25.5000 | ++---+----+----+---------+ + +mysql> use test; set tidb_enforce_mpp=1; select *, min(v) over (partition by p order by o rows between 1 preceding and 1 following) as a from test.agg; ++---+----+----+------+ +| p | o | v | a | ++---+----+----+------+ +| 0 | 0 | 0 | 0 | +| 1 | 1 | 1 | 1 | +| 1 | 2 | 2 | 1 | +| 1 | 4 | 4 | 2 | +| 1 | 8 | 8 | 4 | +| 2 | 0 | 0 | 0 | +| 2 | 3 | 3 | 0 | +| 2 | 10 | 10 | 3 | +| 2 | 13 | 13 | 10 | +| 2 | 15 | 15 | 13 | +| 3 | 1 | 1 | 1 | +| 3 | 3 | 3 | 1 | +| 3 | 5 | 5 | 3 | +| 3 | 9 | 9 | 5 | +| 3 | 15 | 15 | 9 | +| 3 | 20 | 20 | 15 | +| 3 | 31 | 31 | 20 | ++---+----+----+------+ + +mysql> use test; set tidb_enforce_mpp=1; select *, max(v) over (partition by p order by o rows between 1 preceding and 1 following) as a from test.agg; ++---+----+----+------+ +| p | o | v | a | ++---+----+----+------+ +| 0 | 0 | 0 | 0 | +| 1 | 1 | 1 | 2 | +| 1 | 2 | 2 | 4 | +| 1 | 4 | 4 | 8 | +| 1 | 8 | 8 | 8 | +| 2 | 0 | 0 | 3 | +| 2 | 3 | 3 | 10 | +| 2 | 10 | 10 | 13 | +| 2 | 13 | 13 | 15 | +| 2 | 15 | 15 | 15 | +| 3 | 1 | 1 | 3 | +| 3 | 3 | 3 | 5 | +| 3 | 5 | 5 | 9 | +| 3 | 9 | 9 | 15 | +| 3 | 15 | 15 | 20 | +| 3 | 20 | 20 | 31 | +| 3 | 31 | 31 | 31 | ++---+----+----+------+