Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support first_value window function #7427

Merged
merged 10 commits into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions dbms/src/DataStreams/WindowBlockInputStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,14 @@ Block WindowBlockInputStream::readImpl()
}

// Determine whether current_partition_row is the last row of the partition within the current block.
// This is decided by comparing the partition columns of the previous partition
// with those of the newly scanned data.
bool WindowTransformAction::isDifferentFromPrevPartition(UInt64 current_partition_row)
{
// prev_frame_start refers to the data in previous partition
const Columns & reference_columns = inputAt(prev_frame_start);

// partition_end refers to the new scanned data
const Columns & compared_columns = inputAt(partition_end);

for (size_t i = 0; i < partition_column_indices.size(); ++i)
Expand Down Expand Up @@ -299,9 +304,9 @@ void WindowTransformAction::advanceFrameStart()
}
}

bool WindowTransformAction::arePeers(const RowNumber & x, const RowNumber & y) const
bool WindowTransformAction::arePeers(const RowNumber & peer_group_last_row, const RowNumber & current_row) const
{
if (x == y)
if (peer_group_last_row == current_row)
{
// For convenience, a row is always its own peer.
return true;
Expand All @@ -324,18 +329,18 @@ bool WindowTransformAction::arePeers(const RowNumber & x, const RowNumber & y) c

for (size_t i = 0; i < n; ++i)
{
const auto * column_x = inputAt(x)[order_column_indices[i]].get();
const auto * column_y = inputAt(y)[order_column_indices[i]].get();
const auto * column_peer_last = inputAt(peer_group_last_row)[order_column_indices[i]].get();
const auto * column_current = inputAt(current_row)[order_column_indices[i]].get();
if (window_description.order_by[i].collator)
{
if (column_x->compareAt(x.row, y.row, *column_y, 1 /* nan_direction_hint */, *window_description.order_by[i].collator) != 0)
if (column_peer_last->compareAt(peer_group_last_row.row, current_row.row, *column_current, 1 /* nan_direction_hint */, *window_description.order_by[i].collator) != 0)
{
return false;
}
}
else
{
if (column_x->compareAt(x.row, y.row, *column_y, 1 /* nan_direction_hint */) != 0)
if (column_peer_last->compareAt(peer_group_last_row.row, current_row.row, *column_current, 1 /* nan_direction_hint */) != 0)
{
return false;
}
Expand Down Expand Up @@ -607,8 +612,8 @@ void WindowTransformAction::tryCalculate()
partition_start = partition_end;
advanceRowNumber(partition_end);
partition_ended = false;
// We have to reset the frame and other pointers when the new partition
// starts.

// We have to reset the frame and other pointers when the new partition starts.
frame_start = partition_start;
frame_end = partition_start;
prev_frame_start = partition_start;
Expand Down
3 changes: 2 additions & 1 deletion dbms/src/DataStreams/WindowBlockInputStream.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ struct WindowTransformAction
void advancePartitionEnd();
bool isDifferentFromPrevPartition(UInt64 current_partition_row);

bool arePeers(const RowNumber & x, const RowNumber & y) const;
bool arePeers(const RowNumber & peer_group_last_row, const RowNumber & current_row) const;

void advanceFrameStart();
void advanceFrameEndCurrentRow();
Expand Down Expand Up @@ -202,6 +202,7 @@ struct WindowTransformAction

// The row for which we are now computing the window functions.
RowNumber current_row;

// The start of current peer group, needed for CURRENT ROW frame start.
// For ROWS frame, always equal to the current row, and for RANGE and GROUP
// frames may be earlier.
Expand Down
1 change: 1 addition & 0 deletions dbms/src/Debug/MockExecutor/FuncSigMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,5 +96,6 @@ std::unordered_map<String, tipb::ExprType> window_func_name_to_sig({
{"DenseRank", tipb::ExprType::DenseRank},
{"Lead", tipb::ExprType::Lead},
{"Lag", tipb::ExprType::Lag},
{"FirstValue", tipb::ExprType::FirstValue},
});
} // namespace DB::tests
18 changes: 18 additions & 0 deletions dbms/src/Debug/MockExecutor/WindowBinder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include <Debug/MockExecutor/FuncSigMap.h>
#include <Debug/MockExecutor/WindowBinder.h>
#include <Parsers/ASTFunction.h>
#include <tipb/expression.pb.h>


namespace DB::mock
{
Expand Down Expand Up @@ -73,6 +75,17 @@ bool WindowBinder::toTiPBExecutor(tipb::Executor * tipb_executor, int32_t collat
ft->set_decimal(first_arg_type.decimal());
break;
}
case tipb::ExprType::FirstValue:
{
assert(window_expr->children_size() == 1);
const auto arg_type = window_expr->children(0).field_type();
ft->set_tp(arg_type.tp());
xzhangxian1008 marked this conversation as resolved.
Show resolved Hide resolved
ft->set_flag(arg_type.flag());
ft->set_collate(arg_type.collate());
ft->set_flen(arg_type.flen());
ft->set_decimal(arg_type.decimal());
break;
}
default:
ft->set_tp(TiDB::TypeLongLong);
ft->set_flag(TiDB::ColumnFlagBinary);
Expand Down Expand Up @@ -202,6 +215,11 @@ ExecutorBinderPtr compileWindow(ExecutorBinderPtr input, size_t & executor_index
}
break;
}
case tipb::ExprType::FirstValue:
{
ci = children_ci[0];
break;
}
default:
throw Exception(fmt::format("Unsupported window function {}", func->name), ErrorCodes::LOGICAL_ERROR);
}
Expand Down
3 changes: 2 additions & 1 deletion dbms/src/Flash/Coprocessor/DAGUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ const std::unordered_map<tipb::ExprType, String> window_func_map({
{tipb::ExprType::RowNumber, "row_number"},
{tipb::ExprType::Lead, "lead"},
{tipb::ExprType::Lag, "lag"},
{tipb::ExprType::FirstValue, "first_value"},
});

const std::unordered_map<tipb::ExprType, String> agg_func_map({
Expand Down Expand Up @@ -1030,10 +1031,10 @@ bool isWindowFunctionExpr(const tipb::Expr & expr)
case tipb::ExprType::DenseRank:
case tipb::ExprType::Lead:
case tipb::ExprType::Lag:
case tipb::ExprType::FirstValue:
// case tipb::ExprType::CumeDist:
// case tipb::ExprType::PercentRank:
// case tipb::ExprType::Ntile:
// case tipb::ExprType::FirstValue:
// case tipb::ExprType::LastValue:
// case tipb::ExprType::NthValue:
return true;
Expand Down
1 change: 1 addition & 0 deletions dbms/src/TestUtils/mockExecutor.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,5 +295,6 @@ MockWindowFrame buildDefaultRowsFrame();
#define Lag1(expr) makeASTFunction("Lag", (expr))
#define Lag2(expr1, expr2) makeASTFunction("Lag", (expr1), (expr2))
#define Lag3(expr1, expr2, expr3) makeASTFunction("Lag", (expr1), (expr2), (expr3))
#define FirstValue(expr) makeASTFunction("FirstValue", (expr))
} // namespace tests
} // namespace DB
39 changes: 39 additions & 0 deletions dbms/src/WindowFunctions/IWindowFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,44 @@ struct WindowFunctionRowNumber final : public IWindowFunction
}
};

struct WindowFunctionFirstValue final : public IWindowFunction
{
public:
static constexpr auto name = "first_value";

explicit WindowFunctionFirstValue(const DataTypes & argument_types_)
: IWindowFunction(argument_types_)
{
assert(argument_types_.size() == 1);
xzhangxian1008 marked this conversation as resolved.
Show resolved Hide resolved
return_type = argument_types_[0];
}

String getName() const override
{
return name;
}

DataTypePtr getReturnType() const override
{
return return_type;
}

void windowInsertResultInto(
WindowTransformAction & action,
size_t function_index,
const ColumnNumbers & arguments) override
{
assert(action.frame_started);
IColumn & to = *action.blockAt(action.current_row).output_columns[function_index];
const auto & value_column = *action.inputAt(action.frame_start)[arguments[0]];
const auto & value_field = value_column[action.frame_start.row];
to.insert(value_field);
}

private:
DataTypePtr return_type;
};

/**
LEAD/LAG(<expression>[,offset[, default_value]]) OVER (
PARTITION BY (expr)
Expand Down Expand Up @@ -319,5 +357,6 @@ void registerWindowFunctions(WindowFunctionFactory & factory)
factory.registerFunction<WindowFunctionRowNumber>();
factory.registerFunction<WindowFunctionLeadLagBase<LeadImpl>>();
factory.registerFunction<WindowFunctionLeadLagBase<LagImpl>>();
factory.registerFunction<WindowFunctionFirstValue>();
}
} // namespace DB
141 changes: 141 additions & 0 deletions dbms/src/WindowFunctions/tests/gtest_first_value.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
// Copyright 2023 PingCAP, Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <Interpreters/Context.h>
#include <TestUtils/ExecutorTestUtils.h>

namespace DB::tests
{
// TODO Tests with frame should be added
class FirstValue : public DB::tests::ExecutorTest
SeaRise marked this conversation as resolved.
Show resolved Hide resolved
{
static const size_t max_concurrency_level = 10;

public:
static constexpr auto value_col_name = "first_value";
const ASTPtr value_col = col(value_col_name);

void initializeContext() override
{
ExecutorTest::initializeContext();
}

void executeWithConcurrencyAndBlockSize(const std::shared_ptr<tipb::DAGRequest> & request, const ColumnsWithTypeAndName & expect_columns)
{
std::vector<size_t> block_sizes{1, 2, 3, 4, DEFAULT_BLOCK_SIZE};
for (auto block_size : block_sizes)
{
context.context->setSetting("max_block_size", Field(static_cast<UInt64>(block_size)));
ASSERT_COLUMNS_EQ_R(expect_columns, executeStreams(request));
ASSERT_COLUMNS_EQ_UR(expect_columns, executeStreams(request, 2));
ASSERT_COLUMNS_EQ_UR(expect_columns, executeStreams(request, max_concurrency_level));
}
}

void executeFunctionAndAssert(
const ColumnWithTypeAndName & result,
const ASTPtr & function,
const ColumnsWithTypeAndName & input)
{
ColumnsWithTypeAndName actual_input = input;
assert(actual_input.size() == 3);
TiDB::TP value_tp = dataTypeToTP(actual_input[2].type);

actual_input[0].name = "partition";
actual_input[1].name = "order";
actual_input[2].name = value_col_name;
context.addMockTable(
{"test_db", "test_table_for_lead_lag"},
{{"partition", TiDB::TP::TypeLongLong, actual_input[0].type->isNullable()},
{"order", TiDB::TP::TypeLongLong, actual_input[1].type->isNullable()},
{value_col_name, value_tp, actual_input[2].type->isNullable()}},
actual_input);

auto request = context
.scan("test_db", "test_table_for_lead_lag")
.sort({{"partition", false}, {"order", false}}, true)
.window(function, {"order", false}, {"partition", false}, MockWindowFrame{})
.build(context);

ColumnsWithTypeAndName expect = input;
expect.push_back(result);
executeWithConcurrencyAndBlockSize(request, expect);
}

template <typename IntType>
void testInt()
{
executeFunctionAndAssert(
toVec<IntType>({1, 2, 2, 2, 2, 6, 6, 6, 6, 6, 11, 11, 11}),
FirstValue(value_col),
{toVec<Int64>(/*partition*/ {0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3}),
toVec<Int64>(/*order*/ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}),
toVec<IntType>(/*value*/ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13})});

executeFunctionAndAssert(
toNullableVec<IntType>({{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}}),
FirstValue(value_col),
{toNullableVec<Int64>(/*partition*/ {0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3}),
toNullableVec<Int64>(/*order*/ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}),
toNullableVec<IntType>(/*value*/ {{}, {}, 3, 4, 5, {}, 7, 8, 9, 10, {}, 12, 13})});
}

template <typename FloatType>
void testFloat()
{
executeFunctionAndAssert(
toVec<FloatType>({1, 2, 2, 2, 2, 6, 6, 6, 6, 6, 11, 11, 11}),
FirstValue(value_col),
{toVec<Int64>(/*partition*/ {0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3}),
toVec<Int64>(/*order*/ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}),
toVec<FloatType>(/*value*/ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13})});

executeFunctionAndAssert(
toNullableVec<FloatType>({{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}}),
FirstValue(value_col),
{toNullableVec<Int64>(/*partition*/ {0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3}),
toNullableVec<Int64>(/*order*/ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}),
toNullableVec<FloatType>(/*value*/ {{}, {}, 3, 4, 5, {}, 7, 8, 9, 10, {}, 12, 13})});
}
};

TEST_F(FirstValue, firstValue)
try
{
    // String values over non-nullable columns: each row gets the first value of its partition.
    {
        const auto partition_keys = toVec<Int64>({0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3});
        const auto order_keys = toVec<Int64>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
        const auto values = toVec<String>({"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13"});
        const auto expected = toVec<String>({"1", "2", "2", "2", "2", "6", "6", "6", "6", "6", "11", "11", "11"});

        executeFunctionAndAssert(expected, FirstValue(value_col), {partition_keys, order_keys, values});
    }

    // String values over nullable columns: every partition starts with NULL, so all results are NULL.
    {
        const auto partition_keys = toNullableVec<Int64>({0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3});
        const auto order_keys = toNullableVec<Int64>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
        const auto values = toNullableVec<String>({{}, {}, "3", "4", "5", {}, "7", "8", "9", "10", {}, "12", "13"});
        const auto expected = toNullableVec<String>({{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}});

        executeFunctionAndAssert(expected, FirstValue(value_col), {partition_keys, order_keys, values});
    }

    // Signed integer value columns.
    // TODO support unsigned int.
    testInt<Int8>();
    testInt<Int16>();
    testInt<Int32>();
    testInt<Int64>();

    // Floating-point value columns.
    testFloat<Float32>();
    testFloat<Float64>();
}
CATCH

} // namespace DB::tests
7 changes: 4 additions & 3 deletions dbms/src/WindowFunctions/tests/gtest_lead_lag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ template <typename T>
using Limits = std::numeric_limits<T>;

// TODO Support more convenient testing framework for Window Function.
// TODO Tests with frame should be added
class LeadLag : public DB::tests::ExecutorTest
{
static const size_t max_concurrency_level = 10;
Expand Down Expand Up @@ -60,9 +61,9 @@ class LeadLag : public DB::tests::ExecutorTest
actual_input[2].name = value_col_name;
context.addMockTable(
{"test_db", "test_table_for_lead_lag"},
{{"partition", TiDB::TP::TypeLongLong},
{"order", TiDB::TP::TypeLongLong},
{value_col_name, value_tp}},
{{"partition", TiDB::TP::TypeLongLong, actual_input[0].type->isNullable()},
{"order", TiDB::TP::TypeLongLong, actual_input[1].type->isNullable()},
{value_col_name, value_tp, actual_input[2].type->isNullable()}},
actual_input);

auto request = context
Expand Down