Skip to content

Commit

Permalink
set empty_result_for_aggregation_by_empty_set according to AggregateF…
Browse files Browse the repository at this point in the history
…uncMode (#3822) (#4012)

close pingcap/tidb#30923
  • Loading branch information
ti-chi-bot authored Jul 14, 2022
1 parent 82c1eae commit d5e5a83
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 89 deletions.
149 changes: 69 additions & 80 deletions dbms/src/Debug/dbgFuncCoprocessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,6 @@ extern const int LOGICAL_ERROR;
extern const int NO_SUCH_COLUMN_IN_TABLE;
} // namespace ErrorCodes

using TiDB::DatumFlat;
using TiDB::TableInfo;

using DAGColumnInfo = std::pair<String, ColumnInfo>;
using DAGSchema = std::vector<DAGColumnInfo>;
static const String ENCODE_TYPE_NAME = "encode_type";
Expand Down Expand Up @@ -129,6 +126,16 @@ std::unordered_map<String, tipb::ScalarFuncSig> func_name_to_sig({

});

std::unordered_map<String, tipb::ExprType> agg_func_name_to_sig({
{"min", tipb::ExprType::Min},
{"max", tipb::ExprType::Max},
{"count", tipb::ExprType::Count},
{"sum", tipb::ExprType::Sum},
{"first_row", tipb::ExprType::First},
{"uniqRawRes", tipb::ExprType::ApproxCountDistinct},
{"group_concat", tipb::ExprType::GroupConcat},
});

std::pair<String, String> splitQualifiedName(String s)
{
std::pair<String, String> ret;
Expand Down Expand Up @@ -277,11 +284,11 @@ BlockInputStreamPtr executeQuery(Context & context, RegionID region_id, const DA
{
/// contains a table scan
auto regions = context.getTMTContext().getRegionTable().getRegionsByTable(table_id);
if (regions.size() < (size_t)properties.mpp_partition_num)
if (regions.size() < static_cast<size_t>(properties.mpp_partition_num))
throw Exception("Not supported: table region num less than mpp partition num");
for (size_t i = 0; i < regions.size(); i++)
{
if (i % properties.mpp_partition_num != (size_t)task.partition_id)
if (i % properties.mpp_partition_num != static_cast<size_t>(task.partition_id))
continue;
auto * region = req->add_regions();
region->set_region_id(regions[i].first);
Expand All @@ -300,12 +307,12 @@ BlockInputStreamPtr executeQuery(Context & context, RegionID region_id, const DA
throw Exception("Meet error while dispatch mpp task: " + call.getResp()->error().msg());
}
tipb::ExchangeReceiver tipb_exchange_receiver;
for (size_t i = 0; i < root_task_ids.size(); i++)
for (const auto root_task_id : root_task_ids)
{
mpp::TaskMeta tm;
tm.set_start_ts(properties.start_ts);
tm.set_address(LOCAL_HOST);
tm.set_task_id(root_task_ids[i]);
tm.set_task_id(root_task_id);
tm.set_partition_id(-1);
auto * tm_string = tipb_exchange_receiver.add_encoded_task_meta();
tm.AppendToString(tm_string);
Expand Down Expand Up @@ -373,15 +380,15 @@ BlockInputStreamPtr executeQuery(Context & context, RegionID region_id, const DA

BlockInputStreamPtr dbgFuncTiDBQuery(Context & context, const ASTs & args)
{
if (args.size() < 1 || args.size() > 3)
if (args.empty() || args.size() > 3)
throw Exception("Args not matched, should be: query[, region-id, dag_prop_string]", ErrorCodes::BAD_ARGUMENTS);

String query = safeGet<String>(typeid_cast<const ASTLiteral &>(*args[0]).value);
RegionID region_id = InvalidRegionID;
if (args.size() >= 2)
region_id = safeGet<RegionID>(typeid_cast<const ASTLiteral &>(*args[1]).value);

String prop_string = "";
String prop_string;
if (args.size() == 3)
prop_string = safeGet<String>(typeid_cast<const ASTLiteral &>(*args[2]).value);
DAGProperties properties = getDAGProperties(prop_string);
Expand Down Expand Up @@ -417,7 +424,7 @@ BlockInputStreamPtr dbgFuncMockTiDBQuery(Context & context, const ASTs & args)
if (start_ts == 0)
start_ts = context.getTMTContext().getPDClient()->getTS();

String prop_string = "";
String prop_string;
if (args.size() == 4)
prop_string = safeGet<String>(typeid_cast<const ASTLiteral &>(*args[3]).value);
DAGProperties properties = getDAGProperties(prop_string);
Expand Down Expand Up @@ -525,7 +532,7 @@ void foldConstant(tipb::Expr * expr, uint32_t collator_id, const Context & conte
if (expr->tp() == tipb::ScalarFunc)
{
bool all_const = true;
for (auto c : expr->children())
for (const auto & c : expr->children())
{
if (!isLiteralExpr(c))
{
Expand All @@ -537,7 +544,7 @@ void foldConstant(tipb::Expr * expr, uint32_t collator_id, const Context & conte
return;
DataTypes arguments_types;
ColumnsWithTypeAndName argument_columns;
for (auto & c : expr->children())
for (const auto & c : expr->children())
{
Field value = decodeLiteral(c);
DataTypePtr flash_type = applyVisitor(FieldToDataType(), value);
Expand Down Expand Up @@ -821,7 +828,7 @@ void collectUsedColumnsFromExpr(const DAGSchema & input, ASTPtr ast, std::unorde
else
{
bool found = false;
for (auto & field : input)
for (const auto & field : input)
{
auto field_name = splitQualifiedName(field.first);
if (field_name.second == column_name.second)
Expand Down Expand Up @@ -867,14 +874,10 @@ void collectUsedColumnsFromExpr(const DAGSchema & input, ASTPtr ast, std::unorde
struct MPPCtx
{
Timestamp start_ts;
Int64 partition_num;
Int64 next_task_id;
std::vector<Int64> sender_target_task_ids;
std::vector<Int64> current_task_ids;
std::vector<Int64> partition_keys;
MPPCtx(Timestamp start_ts_, size_t partition_num_)
explicit MPPCtx(Timestamp start_ts_)
: start_ts(start_ts_)
, partition_num(partition_num_)
, next_task_id(1)
{}
};
Expand Down Expand Up @@ -930,7 +933,7 @@ struct Executor
{
children[0]->toMPPSubPlan(executor_index, properties, exchange_map);
}
virtual ~Executor() {}
virtual ~Executor() = default;
};

struct ExchangeSender : Executor
Expand Down Expand Up @@ -1047,7 +1050,7 @@ struct TableScan : public Executor
ci->set_decimal(info.second.decimal);
if (!info.second.elems.empty())
{
for (auto & pair : info.second.elems)
for (const auto & pair : info.second.elems)
{
ci->add_elems(pair.first);
}
Expand Down Expand Up @@ -1186,61 +1189,43 @@ struct Aggregation : public Executor
tipb::Expr * arg_expr = agg_func->add_children();
astToPB(input_schema, arg, arg_expr, collator_id, context);
}
auto agg_sig_it = agg_func_name_to_sig.find(func->name);
if (agg_sig_it == agg_func_name_to_sig.end())
throw Exception("Unsupported agg function " + func->name, ErrorCodes::LOGICAL_ERROR);
auto agg_sig = agg_sig_it->second;
agg_func->set_tp(agg_sig);

if (func->name == "count")
{
agg_func->set_tp(tipb::Count);
auto ft = agg_func->mutable_field_type();
ft->set_tp(TiDB::TypeLongLong);
ft->set_flag(TiDB::ColumnFlagUnsigned | TiDB::ColumnFlagNotNull);
}
else if (func->name == "sum")
if (agg_sig == tipb::ExprType::Count || agg_sig == tipb::ExprType::Sum)
{
agg_func->set_tp(tipb::Sum);
auto ft = agg_func->mutable_field_type();
auto * ft = agg_func->mutable_field_type();
ft->set_tp(TiDB::TypeLongLong);
ft->set_flag(TiDB::ColumnFlagUnsigned | TiDB::ColumnFlagNotNull);
}
else if (func->name == "max")
{
agg_func->set_tp(tipb::Max);
if (agg_func->children_size() != 1)
throw Exception("udaf max only accept 1 argument");
auto ft = agg_func->mutable_field_type();
ft->set_tp(agg_func->children(0).field_type().tp());
ft->set_decimal(agg_func->children(0).field_type().decimal());
ft->set_flag(agg_func->children(0).field_type().flag());
ft->set_collate(collator_id);
}
else if (func->name == "min")
else if (agg_sig == tipb::ExprType::Min || agg_sig == tipb::ExprType::Max || agg_sig == tipb::ExprType::First)
{
agg_func->set_tp(tipb::Min);
if (agg_func->children_size() != 1)
throw Exception("udaf min only accept 1 argument");
auto ft = agg_func->mutable_field_type();
throw Exception("udaf " + func->name + " only accept 1 argument");
auto * ft = agg_func->mutable_field_type();
ft->set_tp(agg_func->children(0).field_type().tp());
ft->set_decimal(agg_func->children(0).field_type().decimal());
ft->set_flag(agg_func->children(0).field_type().flag());
ft->set_flag(agg_func->children(0).field_type().flag() & (~TiDB::ColumnFlagNotNull));
ft->set_collate(collator_id);
}
else if (func->name == uniq_raw_res_name)
else if (agg_sig == tipb::ExprType::ApproxCountDistinct)
{
agg_func->set_tp(tipb::ApproxCountDistinct);
auto ft = agg_func->mutable_field_type();
auto * ft = agg_func->mutable_field_type();
ft->set_tp(TiDB::TypeString);
ft->set_flag(1);
}
else if (func->name == "group_concat")
else if (agg_sig == tipb::ExprType::GroupConcat)
{
agg_func->set_tp(tipb::GroupConcat);
auto ft = agg_func->mutable_field_type();
auto * ft = agg_func->mutable_field_type();
ft->set_tp(TiDB::TypeString);
}
// TODO: Other agg func.
if (is_final_mode)
agg_func->set_aggfuncmode(tipb::AggFunctionMode::FinalMode);
else
{
throw Exception("Unsupported agg function " + func->name, ErrorCodes::LOGICAL_ERROR);
}
agg_func->set_aggfuncmode(tipb::AggFunctionMode::Partial1Mode);
}

for (const auto & child : gby_exprs)
Expand Down Expand Up @@ -1289,8 +1274,6 @@ struct Aggregation : public Executor
// todo support avg
if (has_uniq_raw_res)
throw Exception("uniq raw res not supported in mpp query");
if (gby_exprs.size() == 0)
throw Exception("agg without group by columns not supported in mpp query");
std::shared_ptr<Aggregation> partial_agg = std::make_shared<Aggregation>(
executor_index,
output_schema_for_partial_agg,
Expand All @@ -1307,7 +1290,7 @@ struct Aggregation : public Executor
partition_keys.push_back(i + agg_func_num);
}
std::shared_ptr<ExchangeSender> exchange_sender
= std::make_shared<ExchangeSender>(executor_index, output_schema_for_partial_agg, tipb::Hash, partition_keys);
= std::make_shared<ExchangeSender>(executor_index, output_schema_for_partial_agg, partition_keys.empty() ? tipb::PassThrough : tipb::Hash, partition_keys);
exchange_sender->children.push_back(partial_agg);

std::shared_ptr<ExchangeReceiver> exchange_receiver
Expand Down Expand Up @@ -1421,14 +1404,14 @@ struct Join : Executor

std::unordered_set<String> left_used_columns;
std::unordered_set<String> right_used_columns;
for (auto & s : used_columns)
for (const auto & s : used_columns)
{
if (left_columns.find(s) != left_columns.end())
left_used_columns.emplace(s);
else
right_used_columns.emplace(s);
}
for (auto child : join_params.using_expression_list->children)
for (const auto & child : join_params.using_expression_list->children)
{
if (auto * identifier = typeid_cast<ASTIdentifier *>(child.get()))
{
Expand Down Expand Up @@ -1475,7 +1458,7 @@ struct Join : Executor
}
}

void fillJoinKeyAndFieldType(
static void fillJoinKeyAndFieldType(
ASTPtr key,
const DAGSchema & schema,
tipb::Expr * tipb_key,
Expand All @@ -1485,7 +1468,7 @@ struct Join : Executor
auto * identifier = typeid_cast<ASTIdentifier *>(key.get());
for (size_t index = 0; index < schema.size(); index++)
{
auto & field = schema[index];
const auto & field = schema[index];
if (splitQualifiedName(field.first).second == identifier->getColumnName())
{
auto tipb_type = TiDB::columnInfoToFieldType(field.second);
Expand Down Expand Up @@ -1834,9 +1817,9 @@ ExecutorPtr compileTopN(ExecutorPtr input, size_t & executor_index, ASTPtr order
compileExpr(input->output_schema, elem->children[0]);
}
auto limit = safeGet<UInt64>(typeid_cast<ASTLiteral &>(*limit_expr).value);
auto topN = std::make_shared<mock::TopN>(executor_index, input->output_schema, std::move(order_columns), limit);
topN->children.push_back(input);
return topN;
auto top_n = std::make_shared<mock::TopN>(executor_index, input->output_schema, std::move(order_columns), limit);
top_n->children.push_back(input);
return top_n;
}

ExecutorPtr compileLimit(ExecutorPtr input, size_t & executor_index, ASTPtr limit_expr)
Expand Down Expand Up @@ -1879,9 +1862,10 @@ ExecutorPtr compileAggregation(ExecutorPtr input, size_t & executor_index, ASTPt
ci.tp = TiDB::TypeLongLong;
ci.flag = TiDB::ColumnFlagUnsigned | TiDB::ColumnFlagNotNull;
}
else if (func->name == "max" || func->name == "min")
else if (func->name == "max" || func->name == "min" || func->name == "first_row")
{
ci = children_ci[0];
ci.flag &= ~TiDB::ColumnFlagNotNull;
}
else if (func->name == uniq_raw_res_name)
{
Expand Down Expand Up @@ -1971,7 +1955,7 @@ ExecutorPtr compileProject(ExecutorPtr input, size_t & executor_index, ASTPtr se
ExecutorPtr compileJoin(size_t & executor_index, ExecutorPtr left, ExecutorPtr right, ASTPtr params)
{
DAGSchema output_schema;
auto & join_params = (static_cast<const ASTTableJoin &>(*params));
const auto & join_params = (static_cast<const ASTTableJoin &>(*params));
for (auto & field : left->output_schema)
{
if (join_params.kind == ASTTableJoin::Kind::Right && field.second.hasNotNullFlag())
Expand Down Expand Up @@ -2035,7 +2019,7 @@ struct QueryFragment
dag_request.add_output_offsets(i);
auto * root_tipb_executor = dag_request.mutable_root_executor();
root_executor->toTiPBExecutor(root_tipb_executor, properties.collator, mpp_info, context);
return QueryTask(dag_request_ptr, table_id, root_executor->output_schema, mpp_info.sender_target_task_ids.size() == 0 ? DAG : MPP_DISPATCH, mpp_info.task_id, mpp_info.partition_id, is_top_fragment);
return QueryTask(dag_request_ptr, table_id, root_executor->output_schema, mpp_info.sender_target_task_ids.empty() ? DAG : MPP_DISPATCH, mpp_info.task_id, mpp_info.partition_id, is_top_fragment);
}

QueryTasks toQueryTasks(const DAGProperties & properties, const Context & context)
Expand Down Expand Up @@ -2071,7 +2055,7 @@ TableID findTableIdForQueryFragment(ExecutorPtr root_executor, bool must_have_ta
while (!current_executor->children.empty())
{
ExecutorPtr non_exchange_child;
for (auto c : current_executor->children)
for (const auto & c : current_executor->children)
{
if (dynamic_cast<mock::ExchangeReceiver *>(c.get()))
continue;
Expand Down Expand Up @@ -2109,17 +2093,24 @@ QueryFragments mppQueryToQueryFragments(
root_executor->toMPPSubPlan(executor_index, properties, exchange_map);
TableID table_id = findTableIdForQueryFragment(root_executor, exchange_map.empty());
std::vector<Int64> sender_target_task_ids = mpp_ctx->sender_target_task_ids;
std::vector<Int64> current_task_ids = mpp_ctx->current_task_ids;
std::unordered_map<String, std::vector<Int64>> receiver_source_task_ids_map;
size_t current_task_num = properties.mpp_partition_num;
for (auto & exchange : exchange_map)
{
if (exchange.second.second->type == tipb::ExchangeType::PassThrough)
{
current_task_num = 1;
break;
}
}
std::vector<Int64> current_task_ids;
for (size_t i = 0; i < current_task_num; i++)
current_task_ids.push_back(mpp_ctx->next_task_id++);
for (auto & exchange : exchange_map)
{
std::vector<Int64> task_ids;
for (size_t i = 0; i < (size_t)mpp_ctx->partition_num; i++)
task_ids.push_back(mpp_ctx->next_task_id++);
mpp_ctx->sender_target_task_ids = current_task_ids;
mpp_ctx->current_task_ids = task_ids;
receiver_source_task_ids_map[exchange.first] = task_ids;
auto sub_fragments = mppQueryToQueryFragments(exchange.second.second, executor_index, properties, false, mpp_ctx);
receiver_source_task_ids_map[exchange.first] = sub_fragments.cbegin()->task_ids;
fragments.insert(fragments.end(), sub_fragments.begin(), sub_fragments.end());
}
fragments.emplace_back(root_executor, table_id, for_root_fragment, std::move(sender_target_task_ids), std::move(receiver_source_task_ids_map), std::move(current_task_ids));
Expand All @@ -2134,10 +2125,8 @@ QueryFragments queryPlanToQueryFragments(const DAGProperties & properties, Execu
= std::make_shared<mock::ExchangeSender>(executor_index, root_executor->output_schema, tipb::PassThrough);
root_exchange_sender->children.push_back(root_executor);
root_executor = root_exchange_sender;
MPPCtxPtr mpp_ctx = std::make_shared<MPPCtx>(properties.start_ts, properties.mpp_partition_num);
MPPCtxPtr mpp_ctx = std::make_shared<MPPCtx>(properties.start_ts);
mpp_ctx->sender_target_task_ids.emplace_back(-1);
for (size_t i = 0; i < (size_t)properties.mpp_partition_num; i++)
mpp_ctx->current_task_ids.push_back(mpp_ctx->next_task_id++);
return mppQueryToQueryFragments(root_executor, executor_index, properties, true, mpp_ctx);
}
else
Expand Down Expand Up @@ -2195,7 +2184,7 @@ std::pair<ExecutorPtr, bool> compileQueryBlock(
const DAGProperties & properties,
ASTSelectQuery & ast_query)
{
auto joined_table = getJoin(ast_query);
const auto * joined_table = getJoin(ast_query);
/// uniq_raw is used to test `ApproxCountDistinct`, when testing `ApproxCountDistinct` in mock coprocessor
/// the return value of `ApproxCountDistinct` is just the raw result, we need to convert it to a readable
/// value when decoding the result(using `UniqRawResReformatBlockOutputStream`)
Expand Down
Loading

0 comments on commit d5e5a83

Please sign in to comment.