Skip to content

Commit

Permalink
set empty_result_for_aggregation_by_empty_set according to AggregateF…
Browse files Browse the repository at this point in the history
…uncMode (#3822) (#4013)

close pingcap/tidb#30923
  • Loading branch information
ti-chi-bot authored Jul 7, 2022
1 parent b595172 commit 26557c2
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 59 deletions.
113 changes: 60 additions & 53 deletions dbms/src/Debug/dbgFuncCoprocessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,16 @@ std::unordered_map<String, tipb::ScalarFuncSig> func_name_to_sig({

});

std::unordered_map<String, tipb::ExprType> agg_func_name_to_sig({
{"min", tipb::ExprType::Min},
{"max", tipb::ExprType::Max},
{"count", tipb::ExprType::Count},
{"sum", tipb::ExprType::Sum},
{"first_row", tipb::ExprType::First},
{"uniqRawRes", tipb::ExprType::ApproxCountDistinct},
{"group_concat", tipb::ExprType::GroupConcat},
});

std::pair<String, String> splitQualifiedName(String s)
{
std::pair<String, String> ret;
Expand Down Expand Up @@ -332,12 +342,12 @@ BlockInputStreamPtr executeQuery(Context & context, RegionID region_id, const DA
throw Exception("Meet error while dispatch mpp task: " + call.getResp()->error().msg());
}
tipb::ExchangeReceiver tipb_exchange_receiver;
for (size_t i = 0; i < root_task_ids.size(); i++)
for (const auto root_task_id : root_task_ids)
{
mpp::TaskMeta tm;
tm.set_start_ts(properties.start_ts);
tm.set_address(LOCAL_HOST);
tm.set_task_id(root_task_ids[i]);
tm.set_task_id(root_task_id);
tm.set_partition_id(-1);
auto * tm_string = tipb_exchange_receiver.add_encoded_task_meta();
tm.AppendToString(tm_string);
Expand Down Expand Up @@ -842,12 +852,9 @@ void collectUsedColumnsFromExpr(const DAGSchema & input, ASTPtr ast, std::unorde
/// Bookkeeping shared (through MPPCtxPtr) while a mock query plan is turned into MPP query fragments.
struct MPPCtx
{
/// Start timestamp of the query; always set from the constructor argument.
Timestamp start_ts;
/// NOTE(review): only the two-argument constructor initializes this member; the
/// one-argument constructor leaves it without a value — confirm nothing reads it on that path.
Int64 partition_num;
/// Next task id to hand out; both constructors start it at 1.
Int64 next_task_id;
/// Task ids that the sender of the fragment currently being built should target.
std::vector<Int64> sender_target_task_ids;
/// Task ids assigned to the fragment currently being built.
std::vector<Int64> current_task_ids;
/// NOTE(review): no reader or writer of this member is visible here — presumably partition column indexes; confirm.
std::vector<Int64> partition_keys;
/// Creates a context with an explicit partition count; task ids start at 1.
MPPCtx(Timestamp start_ts_, size_t partition_num_) : start_ts(start_ts_), partition_num(partition_num_), next_task_id(1) {}
/// Creates a context from the start timestamp only; task ids start at 1.
explicit MPPCtx(Timestamp start_ts_) : start_ts(start_ts_), next_task_id(1) {}
};

using MPPCtxPtr = std::shared_ptr<MPPCtx>;
Expand Down Expand Up @@ -1118,6 +1125,7 @@ struct Aggregation : public Executor
std::vector<ASTPtr> agg_exprs;
std::vector<ASTPtr> gby_exprs;
bool is_final_mode;
DAGSchema output_schema_for_partial_agg;
Aggregation(size_t & index_, const DAGSchema & output_schema_, bool has_uniq_raw_res_, bool need_append_project_,
std::vector<ASTPtr> && agg_exprs_, std::vector<ASTPtr> && gby_exprs_, bool is_final_mode_)
: Executor(index_, "aggregation_" + std::to_string(index_), output_schema_),
Expand Down Expand Up @@ -1146,51 +1154,43 @@ struct Aggregation : public Executor
tipb::Expr * arg_expr = agg_func->add_children();
astToPB(input_schema, arg, arg_expr, collator_id, context);
}
auto agg_sig_it = agg_func_name_to_sig.find(func->name);
if (agg_sig_it == agg_func_name_to_sig.end())
throw Exception("Unsupported agg function " + func->name, ErrorCodes::LOGICAL_ERROR);
auto agg_sig = agg_sig_it->second;
agg_func->set_tp(agg_sig);

if (func->name == "count")
{
agg_func->set_tp(tipb::Count);
auto ft = agg_func->mutable_field_type();
ft->set_tp(TiDB::TypeLongLong);
ft->set_flag(TiDB::ColumnFlagUnsigned | TiDB::ColumnFlagNotNull);
}
else if (func->name == "sum")
if (agg_sig == tipb::ExprType::Count || agg_sig == tipb::ExprType::Sum)
{
agg_func->set_tp(tipb::Sum);
auto ft = agg_func->mutable_field_type();
auto * ft = agg_func->mutable_field_type();
ft->set_tp(TiDB::TypeLongLong);
ft->set_flag(TiDB::ColumnFlagUnsigned | TiDB::ColumnFlagNotNull);
}
else if (func->name == "max")
{
agg_func->set_tp(tipb::Max);
if (agg_func->children_size() != 1)
throw Exception("udaf max only accept 1 argument");
auto ft = agg_func->mutable_field_type();
ft->set_tp(agg_func->children(0).field_type().tp());
ft->set_collate(collator_id);
}
else if (func->name == "min")
else if (agg_sig == tipb::ExprType::Min || agg_sig == tipb::ExprType::Max || agg_sig == tipb::ExprType::First)
{
agg_func->set_tp(tipb::Min);
if (agg_func->children_size() != 1)
throw Exception("udaf min only accept 1 argument");
auto ft = agg_func->mutable_field_type();
throw Exception("udaf " + func->name + " only accept 1 argument");
auto * ft = agg_func->mutable_field_type();
ft->set_tp(agg_func->children(0).field_type().tp());
ft->set_decimal(agg_func->children(0).field_type().decimal());
ft->set_flag(agg_func->children(0).field_type().flag() & (~TiDB::ColumnFlagNotNull));
ft->set_collate(collator_id);
}
else if (func->name == UniqRawResName)
else if (agg_sig == tipb::ExprType::ApproxCountDistinct)
{
agg_func->set_tp(tipb::ApproxCountDistinct);
auto ft = agg_func->mutable_field_type();
auto * ft = agg_func->mutable_field_type();
ft->set_tp(TiDB::TypeString);
ft->set_flag(1);
}
// TODO: Other agg func.
else
else if (agg_sig == tipb::ExprType::GroupConcat)
{
throw Exception("Unsupported agg function " + func->name, ErrorCodes::LOGICAL_ERROR);
auto * ft = agg_func->mutable_field_type();
ft->set_tp(TiDB::TypeString);
}
if (is_final_mode)
agg_func->set_aggfuncmode(tipb::AggFunctionMode::FinalMode);
else
agg_func->set_aggfuncmode(tipb::AggFunctionMode::Partial1Mode);
}

for (const auto & child : gby_exprs)
Expand All @@ -1204,6 +1204,8 @@ struct Aggregation : public Executor
}
void columnPrune(std::unordered_set<String> & used_columns) override
{
/// output schema for partial agg is the original agg's output schema
output_schema_for_partial_agg = output_schema;
output_schema.erase(std::remove_if(output_schema.begin(), output_schema.end(),
[&](const auto & field) { return used_columns.count(field.first) == 0; }),
output_schema.end());
Expand Down Expand Up @@ -1239,22 +1241,21 @@ struct Aggregation : public Executor
// todo support avg
if (has_uniq_raw_res)
throw Exception("uniq raw res not supported in mpp query");
if (gby_exprs.size() == 0)
throw Exception("agg without group by columns not supported in mpp query");
std::shared_ptr<Aggregation> partial_agg = std::make_shared<Aggregation>(
executor_index, output_schema, has_uniq_raw_res, false, std::move(agg_exprs), std::move(gby_exprs), false);
executor_index, output_schema_for_partial_agg, has_uniq_raw_res, false, std::move(agg_exprs), std::move(gby_exprs), false);
partial_agg->children.push_back(children[0]);
std::vector<size_t> partition_keys;
size_t agg_func_num = partial_agg->agg_exprs.size();
for (size_t i = 0; i < partial_agg->gby_exprs.size(); i++)
{
partition_keys.push_back(i + agg_func_num);
}
std::shared_ptr<ExchangeSender> exchange_sender
= std::make_shared<ExchangeSender>(executor_index, output_schema, tipb::Hash, partition_keys);
std::shared_ptr<ExchangeSender> exchange_sender = std::make_shared<ExchangeSender>(
executor_index, output_schema_for_partial_agg, partition_keys.empty() ? tipb::PassThrough : tipb::Hash, partition_keys);
exchange_sender->children.push_back(partial_agg);

std::shared_ptr<ExchangeReceiver> exchange_receiver = std::make_shared<ExchangeReceiver>(executor_index, output_schema);
std::shared_ptr<ExchangeReceiver> exchange_receiver
= std::make_shared<ExchangeReceiver>(executor_index, output_schema_for_partial_agg);
exchange_map[exchange_receiver->name] = std::make_pair(exchange_receiver, exchange_sender);
/// re-construct agg_exprs and gby_exprs in final_agg
for (size_t i = 0; i < partial_agg->agg_exprs.size(); i++)
Expand All @@ -1265,12 +1266,12 @@ struct Aggregation : public Executor
if (agg_func->name == "count")
update_agg_func->name = "sum";
update_agg_func->arguments->children.clear();
update_agg_func->arguments->children.push_back(std::make_shared<ASTIdentifier>(output_schema[i].first));
update_agg_func->arguments->children.push_back(std::make_shared<ASTIdentifier>(output_schema_for_partial_agg[i].first));
agg_exprs.push_back(update_agg_expr);
}
for (size_t i = 0; i < partial_agg->gby_exprs.size(); i++)
{
gby_exprs.push_back(std::make_shared<ASTIdentifier>(output_schema[agg_func_num + i].first));
gby_exprs.push_back(std::make_shared<ASTIdentifier>(output_schema_for_partial_agg[agg_func_num + i].first));
}
children[0] = exchange_receiver;
}
Expand Down Expand Up @@ -1800,9 +1801,10 @@ ExecutorPtr compileAggregation(ExecutorPtr input, size_t & executor_index, ASTPt
ci.tp = TiDB::TypeLongLong;
ci.flag = TiDB::ColumnFlagUnsigned | TiDB::ColumnFlagNotNull;
}
else if (func->name == "max" || func->name == "min")
else if (func->name == "max" || func->name == "min" || func->name == "first_row")
{
ci = children_ci[0];
ci.flag &= ~TiDB::ColumnFlagNotNull;
}
else if (func->name == UniqRawResName)
{
Expand Down Expand Up @@ -2018,17 +2020,24 @@ QueryFragments mppQueryToQueryFragments(
root_executor->toMPPSubPlan(executor_index, properties, exchange_map);
TableID table_id = findTableIdForQueryFragment(root_executor, exchange_map.empty());
std::vector<Int64> sender_target_task_ids = mpp_ctx->sender_target_task_ids;
std::vector<Int64> current_task_ids = mpp_ctx->current_task_ids;
std::unordered_map<String, std::vector<Int64>> receiver_source_task_ids_map;
size_t current_task_num = properties.mpp_partition_num;
for (auto & exchange : exchange_map)
{
if (exchange.second.second->type == tipb::ExchangeType::PassThrough)
{
current_task_num = 1;
break;
}
}
std::vector<Int64> current_task_ids;
for (size_t i = 0; i < current_task_num; i++)
current_task_ids.push_back(mpp_ctx->next_task_id++);
for (auto & exchange : exchange_map)
{
std::vector<Int64> task_ids;
for (size_t i = 0; i < (size_t)mpp_ctx->partition_num; i++)
task_ids.push_back(mpp_ctx->next_task_id++);
mpp_ctx->sender_target_task_ids = current_task_ids;
mpp_ctx->current_task_ids = task_ids;
receiver_source_task_ids_map[exchange.first] = task_ids;
auto sub_fragments = mppQueryToQueryFragments(exchange.second.second, executor_index, properties, false, mpp_ctx);
receiver_source_task_ids_map[exchange.first] = sub_fragments.cbegin()->task_ids;
fragments.insert(fragments.end(), sub_fragments.begin(), sub_fragments.end());
}
fragments.emplace_back(root_executor, table_id, for_root_fragment, std::move(sender_target_task_ids),
Expand All @@ -2044,10 +2053,8 @@ QueryFragments queryPlanToQueryFragments(const DAGProperties & properties, Execu
= std::make_shared<mock::ExchangeSender>(executor_index, root_executor->output_schema, tipb::PassThrough);
root_exchange_sender->children.push_back(root_executor);
root_executor = root_exchange_sender;
MPPCtxPtr mpp_ctx = std::make_shared<MPPCtx>(properties.start_ts, properties.mpp_partition_num);
MPPCtxPtr mpp_ctx = std::make_shared<MPPCtx>(properties.start_ts);
mpp_ctx->sender_target_task_ids.emplace_back(-1);
for (size_t i = 0; i < (size_t)properties.mpp_partition_num; i++)
mpp_ctx->current_task_ids.push_back(mpp_ctx->next_task_id++);
return mppQueryToQueryFragments(root_executor, executor_index, properties, true, mpp_ctx);
}
else
Expand Down
33 changes: 29 additions & 4 deletions dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,15 @@ DAGQueryBlockInterpreter::DAGQueryBlockInterpreter(Context & context_, const std
}
}

/// Tells whether the aggregate function expression `expr` should be executed as a final aggregation.
/// Returns true when the aggregation mode is FinalMode or CompleteMode, or when no mode is set at all.
bool isFinalAgg(const tipb::Expr & expr)
{
    if (expr.has_aggfuncmode())
    {
        const auto mode = expr.aggfuncmode();
        return mode == tipb::AggFunctionMode::FinalMode || mode == tipb::AggFunctionMode::CompleteMode;
    }
    /// No mode on the expression: default to true to stay compatible with old versions of TiDB,
    /// because before this field was honoured every aggregation in TiFlash was treated as final.
    return true;
}

static std::tuple<std::optional<::tipb::DAGRequest>, std::optional<DAGSchema>> //
buildRemoteTS(const RegionRetryList & region_retry, const DAGQueryBlock & query_block, const tipb::TableScan & ts,
const String & handle_column_name, const TableStructureLockHolder &, const ManageableStoragePtr & storage, Context & context,
Expand Down Expand Up @@ -799,6 +808,17 @@ AnalysisResult DAGQueryBlockInterpreter::analyzeExpressions()
// There will be either Agg...
if (query_block.aggregation)
{
/// set default value to true to make it compatible with old version of TiDB since before this
/// change, all the aggregation in TiFlash is treated as final aggregation
res.is_final_agg = true;
const auto & aggregation = query_block.aggregation->aggregation();
if (aggregation.agg_func_size() > 0 && !isFinalAgg(aggregation.agg_func(0)))
res.is_final_agg = false;
for (int i = 1; i < aggregation.agg_func_size(); i++)
{
if (res.is_final_agg != isFinalAgg(aggregation.agg_func(i)))
throw TiFlashException("Different aggregation mode detected", Errors::Coprocessor::BadRequest);
}
/// collation sensitive group by is slower than normal group by, use normal group by by default
// todo better to let TiDB decide whether group by is collation sensitive or not
analyzer->appendAggregation(chain, query_block.aggregation->aggregation(), res.aggregation_keys, res.aggregation_collators,
Expand Down Expand Up @@ -847,8 +867,12 @@ void DAGQueryBlockInterpreter::executeWhere(DAGPipeline & pipeline, const Expres
pipeline.transform([&](auto & stream) { stream = std::make_shared<FilterBlockInputStream>(stream, expr, filter_column); });
}

void DAGQueryBlockInterpreter::executeAggregation(DAGPipeline & pipeline, const ExpressionActionsPtr & expr, Names & key_names,
TiDB::TiDBCollators & collators, AggregateDescriptions & aggregates)
void DAGQueryBlockInterpreter::executeAggregation(DAGPipeline & pipeline,
const ExpressionActionsPtr & expr,
Names & key_names,
TiDB::TiDBCollators & collators,
AggregateDescriptions & aggregates,
bool is_final_agg)
{
pipeline.transform([&](auto & stream) { stream = std::make_shared<ExpressionBlockInputStream>(stream, expr); });

Expand Down Expand Up @@ -890,7 +914,7 @@ void DAGQueryBlockInterpreter::executeAggregation(DAGPipeline & pipeline, const
settings.compile && !has_collator ? &context.getCompiler() : nullptr, settings.min_count_to_compile,
allow_to_use_two_level_group_by ? settings.group_by_two_level_threshold : SettingUInt64(0),
allow_to_use_two_level_group_by ? settings.group_by_two_level_threshold_bytes : SettingUInt64(0),
settings.max_bytes_before_external_group_by, settings.empty_result_for_aggregation_by_empty_set, context.getTemporaryPath(),
settings.max_bytes_before_external_group_by, !is_final_agg, context.getTemporaryPath(),
has_collator ? collators : TiDB::dummy_collators);

/// If there are several sources, then we perform parallel aggregation
Expand Down Expand Up @@ -1385,7 +1409,8 @@ void DAGQueryBlockInterpreter::executeImpl(DAGPipeline & pipeline)
if (res.need_aggregate)
{
// execute aggregation
executeAggregation(pipeline, res.before_aggregation, res.aggregation_keys, res.aggregation_collators, res.aggregate_descriptions);
executeAggregation(pipeline, res.before_aggregation, res.aggregation_keys, res.aggregation_collators, res.aggregate_descriptions,
res.is_final_agg);
recordProfileStreams(pipeline, query_block.aggregation_name);
}
if (res.has_having)
Expand Down
9 changes: 7 additions & 2 deletions dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ struct AnalysisResult
bool need_aggregate = false;
bool has_having = false;
bool has_order_by = false;
bool is_final_agg = true;

ExpressionActionsPtr timezone_cast;
ExpressionActionsPtr before_where;
Expand Down Expand Up @@ -117,8 +118,12 @@ class DAGQueryBlockInterpreter
void executeExpression(DAGPipeline & pipeline, const ExpressionActionsPtr & expressionActionsPtr);
void executeOrder(DAGPipeline & pipeline, std::vector<NameAndTypePair> & order_columns);
void executeLimit(DAGPipeline & pipeline);
void executeAggregation(DAGPipeline & pipeline, const ExpressionActionsPtr & expressionActionsPtr, Names & aggregation_keys,
TiDB::TiDBCollators & collators, AggregateDescriptions & aggregate_descriptions);
void executeAggregation(DAGPipeline & pipeline,
const ExpressionActionsPtr & expression_actions_ptr,
Names & key_names,
TiDB::TiDBCollators & collators,
AggregateDescriptions & aggregate_descriptions,
bool is_final_agg);
void executeProject(DAGPipeline & pipeline, NamesWithAliases & project_cols);

void readFromLocalStorage( //
Expand Down
43 changes: 43 additions & 0 deletions tests/delta-merge-test/query/mpp/aggregation_empty_input.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Preparation.
=> DBGInvoke __enable_schema_sync_service('true')

=> DBGInvoke __drop_tidb_table(default, test)
=> drop table if exists default.test

=> DBGInvoke __set_flush_threshold(1000000, 1000000)

# Data.
=> DBGInvoke __mock_tidb_table(default, test, 'col_1 String, col_2 Int64')
=> DBGInvoke __refresh_schemas()
=> DBGInvoke __put_region(4, 0, 100, default, test)
=> DBGInvoke __put_region(5, 100, 200, default, test)
=> DBGInvoke __put_region(6, 200, 300, default, test)

# shuffle agg with empty table
=> DBGInvoke tidb_query('select count(col_1) from default.test', 4,'mpp_query:true,mpp_partition_num:3')
┌─exchange_receiver_0─┐
│ 0 │
└─────────────────────┘

=> DBGInvoke __raft_insert_row(default, test, 4, 50, 'test1', 666)
=> DBGInvoke __raft_insert_row(default, test, 4, 51, 'test2', 666)
=> DBGInvoke __raft_insert_row(default, test, 4, 52, 'test3', 777)
=> DBGInvoke __raft_insert_row(default, test, 4, 53, 'test4', 888)
=> DBGInvoke __raft_insert_row(default, test, 5, 150, 'test1', 666)
=> DBGInvoke __raft_insert_row(default, test, 5, 151, 'test2', 666)
=> DBGInvoke __raft_insert_row(default, test, 5, 152, 'test3', 777)
=> DBGInvoke __raft_insert_row(default, test, 5, 153, 'test4', 888)
=> DBGInvoke __raft_insert_row(default, test, 6, 250, 'test1', 666)
=> DBGInvoke __raft_insert_row(default, test, 6, 251, 'test2', 666)
=> DBGInvoke __raft_insert_row(default, test, 6, 252, 'test3', 777)
=> DBGInvoke __raft_insert_row(default, test, 6, 253, 'test4', 999)

# shuffle agg
=> DBGInvoke tidb_query('select count(col_1), first_row(col_2) from default.test where col_2 = 999', 4,'mpp_query:true,mpp_partition_num:3')
┌─exchange_receiver_0─┬─exchange_receiver_1─┐
│ 1 │ 999 │
└─────────────────────┴─────────────────────┘

# Clean up.
=> DBGInvoke __drop_tidb_table(default, test)
=> drop table if exists default.test
Loading

0 comments on commit 26557c2

Please sign in to comment.