diff --git a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp index 8d91b8b23e9..4d8faffde6c 100644 --- a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp +++ b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp @@ -33,9 +33,7 @@ #include #include #include -#include #include -#include #include #include #include @@ -43,6 +41,7 @@ #include #include #include +#include #include #include #include @@ -50,10 +49,6 @@ #include #include #include -#include -#include -#include - namespace DB { @@ -91,7 +86,7 @@ struct AnalysisResult Names aggregation_keys; TiDB::TiDBCollators aggregation_collators; AggregateDescriptions aggregate_descriptions; - bool is_final_agg; + bool is_final_agg = false; }; AnalysisResult analyzeExpressions( @@ -184,322 +179,117 @@ void DAGQueryBlockInterpreter::handleTableScan(const TiDBTableScan & table_scan, analyzer = std::move(storage_interpreter.analyzer); } -void DAGQueryBlockInterpreter::prepareJoin( - const google::protobuf::RepeatedPtrField & keys, - const DataTypes & key_types, - DAGPipeline & pipeline, - Names & key_names, - bool left, - bool is_right_out_join, - const google::protobuf::RepeatedPtrField & filters, - String & filter_column_name) -{ - NamesAndTypes source_columns; - for (auto const & p : pipeline.firstStream()->getHeader().getNamesAndTypesList()) - source_columns.emplace_back(p.name, p.type); - DAGExpressionAnalyzer dag_analyzer(std::move(source_columns), context); - ExpressionActionsChain chain; - if (dag_analyzer.appendJoinKeyAndJoinFilters(chain, keys, key_types, key_names, left, is_right_out_join, filters, filter_column_name)) - { - pipeline.transform([&](auto & stream) { stream = std::make_shared(stream, chain.getLastActions(), log->identifier()); }); - } -} - -ExpressionActionsPtr DAGQueryBlockInterpreter::genJoinOtherConditionAction( - const tipb::Join & join, - NamesAndTypes & source_columns, - String & filter_column_for_other_condition, - String & filter_column_for_other_eq_condition) -{ - if (join.other_conditions_size() == 0 && join.other_eq_conditions_from_in_size() == 0) - return nullptr; - DAGExpressionAnalyzer dag_analyzer(source_columns, context); - ExpressionActionsChain chain; - std::vector condition_vector; - if (join.other_conditions_size() > 0) - { - for (const auto & c : join.other_conditions()) - { - condition_vector.push_back(&c); - } - filter_column_for_other_condition = dag_analyzer.appendWhere(chain, condition_vector); - } - if (join.other_eq_conditions_from_in_size() > 0) - { - condition_vector.clear(); - for (const auto & c : join.other_eq_conditions_from_in()) - { - condition_vector.push_back(&c); - } - filter_column_for_other_eq_condition = dag_analyzer.appendWhere(chain, condition_vector); - } - return chain.getLastActions(); -} - -/// ClickHouse require join key to be exactly the same type -/// TiDB only require the join key to be the same category -/// for example decimal(10,2) join decimal(20,0) is allowed in -/// TiDB and will throw exception in ClickHouse -void getJoinKeyTypes(const tipb::Join & join, DataTypes & key_types) -{ - for (int i = 0; i < join.left_join_keys().size(); i++) - { - if (!exprHasValidFieldType(join.left_join_keys(i)) || !exprHasValidFieldType(join.right_join_keys(i))) - throw TiFlashException("Join key without field type", Errors::Coprocessor::BadRequest); - DataTypes types; - types.emplace_back(getDataTypeByFieldTypeForComputingLayer(join.left_join_keys(i).field_type())); - types.emplace_back(getDataTypeByFieldTypeForComputingLayer(join.right_join_keys(i).field_type())); - DataTypePtr common_type = getLeastSupertype(types); - key_types.emplace_back(common_type); - } -} - void DAGQueryBlockInterpreter::handleJoin(const tipb::Join & join, DAGPipeline & pipeline, SubqueryForSet & right_query) { - // build - static const std::unordered_map equal_join_type_map{ - {tipb::JoinType::TypeInnerJoin, ASTTableJoin::Kind::Inner}, - {tipb::JoinType::TypeLeftOuterJoin, ASTTableJoin::Kind::Left}, - {tipb::JoinType::TypeRightOuterJoin, ASTTableJoin::Kind::Right}, - {tipb::JoinType::TypeSemiJoin, ASTTableJoin::Kind::Inner}, - {tipb::JoinType::TypeAntiSemiJoin, ASTTableJoin::Kind::Anti}, - {tipb::JoinType::TypeLeftOuterSemiJoin, ASTTableJoin::Kind::LeftSemi}, - {tipb::JoinType::TypeAntiLeftOuterSemiJoin, ASTTableJoin::Kind::LeftAnti}}; - static const std::unordered_map cartesian_join_type_map{ - {tipb::JoinType::TypeInnerJoin, ASTTableJoin::Kind::Cross}, - {tipb::JoinType::TypeLeftOuterJoin, ASTTableJoin::Kind::Cross_Left}, - {tipb::JoinType::TypeRightOuterJoin, ASTTableJoin::Kind::Cross_Right}, - {tipb::JoinType::TypeSemiJoin, ASTTableJoin::Kind::Cross}, - {tipb::JoinType::TypeAntiSemiJoin, ASTTableJoin::Kind::Cross_Anti}, - {tipb::JoinType::TypeLeftOuterSemiJoin, ASTTableJoin::Kind::Cross_LeftSemi}, - {tipb::JoinType::TypeAntiLeftOuterSemiJoin, ASTTableJoin::Kind::Cross_LeftAnti}}; - - if (input_streams_vec.size() != 2) + if (unlikely(input_streams_vec.size() != 2)) { throw TiFlashException("Join query block must have 2 input streams", Errors::BroadcastJoin::Internal); } - const auto & join_type_map = join.left_join_keys_size() == 0 ? cartesian_join_type_map : equal_join_type_map; - auto join_type_it = join_type_map.find(join.join_type()); - if (join_type_it == join_type_map.end()) - throw TiFlashException("Unknown join type in dag request", Errors::Coprocessor::BadRequest); - - /// (cartesian) (anti) left semi join. - const bool is_left_semi_family = join.join_type() == tipb::JoinType::TypeLeftOuterSemiJoin || join.join_type() == tipb::JoinType::TypeAntiLeftOuterSemiJoin; - - ASTTableJoin::Kind kind = join_type_it->second; - const bool is_semi_join = join.join_type() == tipb::JoinType::TypeSemiJoin || join.join_type() == tipb::JoinType::TypeAntiSemiJoin || is_left_semi_family; - ASTTableJoin::Strictness strictness = ASTTableJoin::Strictness::All; - if (is_semi_join) - strictness = ASTTableJoin::Strictness::Any; - - /// in DAG request, inner part is the build side, however for TiFlash implementation, - /// the build side must be the right side, so need to swap the join side if needed - /// 1. for (cross) inner join, there is no problem in this swap. - /// 2. for (cross) semi/anti-semi join, the build side is always right, needn't swap. - /// 3. for non-cross left/right join, there is no problem in this swap. - /// 4. for cross left join, the build side is always right, needn't and can't swap. - /// 5. for cross right join, the build side is always left, so it will always swap and change to cross left join. - /// note that whatever the build side is, we can't support cross-right join now. - - bool swap_join_side; - if (kind == ASTTableJoin::Kind::Cross_Right) - swap_join_side = true; - else if (kind == ASTTableJoin::Kind::Cross_Left) - swap_join_side = false; - else - swap_join_side = join.inner_idx() == 0; + JoinInterpreterHelper::TiFlashJoin tiflash_join{join}; - DAGPipeline left_pipeline; - DAGPipeline right_pipeline; + DAGPipeline probe_pipeline; + DAGPipeline build_pipeline; + probe_pipeline.streams = input_streams_vec[1 - tiflash_join.build_side_index]; + build_pipeline.streams = input_streams_vec[tiflash_join.build_side_index]; - if (swap_join_side) - { - if (kind == ASTTableJoin::Kind::Left) - kind = ASTTableJoin::Kind::Right; - else if (kind == ASTTableJoin::Kind::Right) - kind = ASTTableJoin::Kind::Left; - else if (kind == ASTTableJoin::Kind::Cross_Right) - kind = ASTTableJoin::Kind::Cross_Left; - left_pipeline.streams = input_streams_vec[1]; - right_pipeline.streams = input_streams_vec[0]; - } - else - { - left_pipeline.streams = input_streams_vec[0]; - right_pipeline.streams = input_streams_vec[1]; - } - - NamesAndTypes join_output_columns; - /// columns_for_other_join_filter is a vector of columns used - /// as the input columns when compiling other join filter. - /// Note the order in the column vector is very important: - /// first the columns in input_streams_vec[0], then followed - /// by the columns in input_streams_vec[1], if there are other - /// columns generated before compile other join filter, then - /// append the extra columns afterwards. In order to figure out - /// whether a given column is already in the column vector or - /// not quickly, we use another set to store the column names - NamesAndTypes columns_for_other_join_filter; - std::unordered_set column_set_for_other_join_filter; - bool make_nullable = join.join_type() == tipb::JoinType::TypeRightOuterJoin; - for (auto const & p : input_streams_vec[0][0]->getHeader().getNamesAndTypesList()) - { - join_output_columns.emplace_back(p.name, make_nullable ? makeNullable(p.type) : p.type); - columns_for_other_join_filter.emplace_back(p.name, make_nullable ? makeNullable(p.type) : p.type); - column_set_for_other_join_filter.emplace(p.name); - } - make_nullable = join.join_type() == tipb::JoinType::TypeLeftOuterJoin; - for (auto const & p : input_streams_vec[1][0]->getHeader().getNamesAndTypesList()) - { - if (!is_semi_join) - /// for semi join, the columns from right table will be ignored - join_output_columns.emplace_back(p.name, make_nullable ? makeNullable(p.type) : p.type); - /// however, when compiling join's other condition, we still need the columns from right table - columns_for_other_join_filter.emplace_back(p.name, make_nullable ? makeNullable(p.type) : p.type); - column_set_for_other_join_filter.emplace(p.name); - } + RUNTIME_ASSERT(!input_streams_vec[0].empty(), log, "left input streams cannot be empty"); + const Block & left_input_header = input_streams_vec[0].back()->getHeader(); - bool is_tiflash_left_join = kind == ASTTableJoin::Kind::Left || kind == ASTTableJoin::Kind::Cross_Left; - /// Cross_Right join will be converted to Cross_Left join, so no need to check Cross_Right - bool is_tiflash_right_join = kind == ASTTableJoin::Kind::Right; - /// all the columns from right table should be added after join, even for the join key - NamesAndTypesList columns_added_by_join; - make_nullable = is_tiflash_left_join; - for (auto const & p : right_pipeline.streams[0]->getHeader().getNamesAndTypesList()) - { - columns_added_by_join.emplace_back(p.name, make_nullable ? makeNullable(p.type) : p.type); - } - - String match_helper_name; - if (is_left_semi_family) - { - const auto & left_block = input_streams_vec[0][0]->getHeader(); - const auto & right_block = input_streams_vec[1][0]->getHeader(); + RUNTIME_ASSERT(!input_streams_vec[1].empty(), log, "right input streams cannot be empty"); + const Block & right_input_header = input_streams_vec[1].back()->getHeader(); - match_helper_name = Join::match_helper_prefix; - for (int i = 1; left_block.has(match_helper_name) || right_block.has(match_helper_name); ++i) - { - match_helper_name = Join::match_helper_prefix + std::to_string(i); - } - - columns_added_by_join.emplace_back(match_helper_name, Join::match_helper_type); - join_output_columns.emplace_back(match_helper_name, Join::match_helper_type); - } - - DataTypes join_key_types; - getJoinKeyTypes(join, join_key_types); - TiDB::TiDBCollators collators; - size_t join_key_size = join_key_types.size(); - if (join.probe_types_size() == static_cast(join_key_size) && join.build_types_size() == join.probe_types_size()) - for (size_t i = 0; i < join_key_size; i++) - { - if (removeNullable(join_key_types[i])->isString()) - { - if (join.probe_types(i).collate() != join.build_types(i).collate()) - throw TiFlashException("Join with different collators on the join key", Errors::Coprocessor::BadRequest); - collators.push_back(getCollatorFromFieldType(join.probe_types(i))); - } - else - collators.push_back(nullptr); - } - - Names left_key_names, right_key_names; - String left_filter_column_name, right_filter_column_name; + String match_helper_name = tiflash_join.genMatchHelperName(left_input_header, right_input_header); + NamesAndTypesList columns_added_by_join = tiflash_join.genColumnsAddedByJoin(build_pipeline.firstStream()->getHeader(), match_helper_name); + NamesAndTypes join_output_columns = tiflash_join.genJoinOutputColumns(left_input_header, right_input_header, match_helper_name); /// add necessary transformation if the join key is an expression - prepareJoin( - swap_join_side ? join.right_join_keys() : join.left_join_keys(), - join_key_types, - left_pipeline, - left_key_names, + bool is_tiflash_right_join = tiflash_join.isTiFlashRightJoin(); + + // prepare probe side + auto [probe_side_prepare_actions, probe_key_names, probe_filter_column_name] = JoinInterpreterHelper::prepareJoin( + context, + probe_pipeline.firstStream()->getHeader(), + tiflash_join.getProbeJoinKeys(), + tiflash_join.join_key_types, true, is_tiflash_right_join, - swap_join_side ? join.right_conditions() : join.left_conditions(), - left_filter_column_name); - - prepareJoin( - swap_join_side ? join.left_join_keys() : join.right_join_keys(), - join_key_types, - right_pipeline, - right_key_names, + tiflash_join.getProbeConditions()); + RUNTIME_ASSERT(probe_side_prepare_actions, log, "probe_side_prepare_actions cannot be nullptr"); + + // prepare build side + auto [build_side_prepare_actions, build_key_names, build_filter_column_name] = JoinInterpreterHelper::prepareJoin( + context, + build_pipeline.firstStream()->getHeader(), + tiflash_join.getBuildJoinKeys(), + tiflash_join.join_key_types, false, is_tiflash_right_join, - swap_join_side ? join.left_conditions() : join.right_conditions(), - right_filter_column_name); + tiflash_join.getBuildConditions()); + RUNTIME_ASSERT(build_side_prepare_actions, log, "build_side_prepare_actions cannot be nullptr"); - String other_filter_column_name, other_eq_filter_from_in_column_name; - for (auto const & p : left_pipeline.streams[0]->getHeader().getNamesAndTypesList()) - { - if (column_set_for_other_join_filter.find(p.name) == column_set_for_other_join_filter.end()) - columns_for_other_join_filter.emplace_back(p.name, p.type); - } - for (auto const & p : right_pipeline.streams[0]->getHeader().getNamesAndTypesList()) - { - if (column_set_for_other_join_filter.find(p.name) == column_set_for_other_join_filter.end()) - columns_for_other_join_filter.emplace_back(p.name, p.type); - } - - ExpressionActionsPtr other_condition_expr - = genJoinOtherConditionAction(join, columns_for_other_join_filter, other_filter_column_name, other_eq_filter_from_in_column_name); + auto [other_condition_expr, other_filter_column_name, other_eq_filter_from_in_column_name] + = tiflash_join.genJoinOtherConditionAction(context, left_input_header, right_input_header, probe_side_prepare_actions); const Settings & settings = context.getSettingsRef(); - size_t join_build_concurrency = settings.join_concurrent_build ? std::min(max_streams, right_pipeline.streams.size()) : 1; size_t max_block_size_for_cross_join = settings.max_block_size; fiu_do_on(FailPoints::minimum_block_size_for_cross_join, { max_block_size_for_cross_join = 1; }); JoinPtr join_ptr = std::make_shared( - left_key_names, - right_key_names, + probe_key_names, + build_key_names, true, SizeLimits(settings.max_rows_in_join, settings.max_bytes_in_join, settings.join_overflow_mode), - kind, - strictness, + tiflash_join.kind, + tiflash_join.strictness, log->identifier(), - join_build_concurrency, - collators, - left_filter_column_name, - right_filter_column_name, + tiflash_join.join_key_collators, + probe_filter_column_name, + build_filter_column_name, other_filter_column_name, other_eq_filter_from_in_column_name, other_condition_expr, max_block_size_for_cross_join, match_helper_name); - recordJoinExecuteInfo(swap_join_side ? 0 : 1, join_ptr); + recordJoinExecuteInfo(tiflash_join.build_side_index, join_ptr); + + size_t join_build_concurrency = settings.join_concurrent_build ? std::min(max_streams, build_pipeline.streams.size()) : 1; + /// build side streams + executeExpression(build_pipeline, build_side_prepare_actions); // add a HashJoinBuildBlockInputStream to build a shared hash table - size_t concurrency_build_index = 0; - auto get_concurrency_build_index = [&concurrency_build_index, &join_build_concurrency]() { - return (concurrency_build_index++) % join_build_concurrency; - }; - right_pipeline.transform([&](auto & stream) { + auto get_concurrency_build_index = JoinInterpreterHelper::concurrencyBuildIndexGenerator(join_build_concurrency); + build_pipeline.transform([&](auto & stream) { stream = std::make_shared(stream, join_ptr, get_concurrency_build_index(), log->identifier()); }); - executeUnion(right_pipeline, max_streams, log, /*ignore_block=*/true); + executeUnion(build_pipeline, max_streams, log, /*ignore_block=*/true); - right_query.source = right_pipeline.firstStream(); + right_query.source = build_pipeline.firstStream(); right_query.join = join_ptr; - right_query.join->setSampleBlock(right_query.source->getHeader()); + join_ptr->init(right_query.source->getHeader(), join_build_concurrency); + /// probe side streams + executeExpression(probe_pipeline, probe_side_prepare_actions); NamesAndTypes source_columns; - for (const auto & p : left_pipeline.streams[0]->getHeader().getNamesAndTypesList()) + for (const auto & p : probe_pipeline.firstStream()->getHeader()) source_columns.emplace_back(p.name, p.type); DAGExpressionAnalyzer dag_analyzer(std::move(source_columns), context); ExpressionActionsChain chain; dag_analyzer.appendJoin(chain, right_query, columns_added_by_join); - pipeline.streams = left_pipeline.streams; + pipeline.streams = probe_pipeline.streams; /// add join input stream if (is_tiflash_right_join) { auto & join_execute_info = dagContext().getJoinExecuteInfoMap()[query_block.source_name]; - for (size_t i = 0; i < join_build_concurrency; i++) + size_t not_joined_concurrency = join_ptr->getNotJoinedStreamConcurrency(); + for (size_t i = 0; i < not_joined_concurrency; ++i) { - auto non_joined_stream = chain.getLastActions()->createStreamWithNonJoinedDataIfFullOrRightJoin( + auto non_joined_stream = join_ptr->createStreamWithNonJoinedRows( pipeline.firstStream()->getHeader(), i, - join_build_concurrency, + not_joined_concurrency, settings.max_block_size); pipeline.streams_with_non_joined_data.push_back(non_joined_stream); join_execute_info.non_joined_streams.push_back(non_joined_stream); diff --git a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h index 69bac9c3ba9..9b95a5c3e93 100644 --- a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h +++ b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h @@ -61,25 +61,11 @@ class DAGQueryBlockInterpreter void handleMockTableScan(const TiDBTableScan & table_scan, DAGPipeline & pipeline); void handleTableScan(const TiDBTableScan & table_scan, DAGPipeline & pipeline); void handleJoin(const tipb::Join & join, DAGPipeline & pipeline, SubqueryForSet & right_query); - void prepareJoin( - const google::protobuf::RepeatedPtrField & keys, - const DataTypes & key_types, - DAGPipeline & pipeline, - Names & key_names, - bool left, - bool is_right_out_join, - const google::protobuf::RepeatedPtrField & filters, - String & filter_column_name); void handleExchangeReceiver(DAGPipeline & pipeline); void handleMockExchangeReceiver(DAGPipeline & pipeline); void handleProjection(DAGPipeline & pipeline, const tipb::Projection & projection); void handleWindow(DAGPipeline & pipeline, const tipb::Window & window); void handleWindowOrder(DAGPipeline & pipeline, const tipb::Sort & window_sort); - ExpressionActionsPtr genJoinOtherConditionAction( - const tipb::Join & join, - NamesAndTypes & source_columns, - String & filter_column_for_other_condition, - String & filter_column_for_other_eq_condition); void executeWhere(DAGPipeline & pipeline, const ExpressionActionsPtr & expressionActionsPtr, String & filter_column); void executeExpression(DAGPipeline & pipeline, const ExpressionActionsPtr & expressionActionsPtr); void executeWindowOrder(DAGPipeline & pipeline, SortDescription sort_desc); diff --git a/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.cpp b/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.cpp new file mode 100644 index 00000000000..2582a84ac46 --- /dev/null +++ b/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.cpp @@ -0,0 +1,356 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace DB::JoinInterpreterHelper +{ +namespace +{ +std::pair getJoinKindAndBuildSideIndex(const tipb::Join & join) +{ + static const std::unordered_map equal_join_type_map{ + {tipb::JoinType::TypeInnerJoin, ASTTableJoin::Kind::Inner}, + {tipb::JoinType::TypeLeftOuterJoin, ASTTableJoin::Kind::Left}, + {tipb::JoinType::TypeRightOuterJoin, ASTTableJoin::Kind::Right}, + {tipb::JoinType::TypeSemiJoin, ASTTableJoin::Kind::Inner}, + {tipb::JoinType::TypeAntiSemiJoin, ASTTableJoin::Kind::Anti}, + {tipb::JoinType::TypeLeftOuterSemiJoin, ASTTableJoin::Kind::LeftSemi}, + {tipb::JoinType::TypeAntiLeftOuterSemiJoin, ASTTableJoin::Kind::LeftAnti}}; + static const std::unordered_map cartesian_join_type_map{ + {tipb::JoinType::TypeInnerJoin, ASTTableJoin::Kind::Cross}, + {tipb::JoinType::TypeLeftOuterJoin, ASTTableJoin::Kind::Cross_Left}, + {tipb::JoinType::TypeRightOuterJoin, ASTTableJoin::Kind::Cross_Right}, + {tipb::JoinType::TypeSemiJoin, ASTTableJoin::Kind::Cross}, + {tipb::JoinType::TypeAntiSemiJoin, ASTTableJoin::Kind::Cross_Anti}, + {tipb::JoinType::TypeLeftOuterSemiJoin, ASTTableJoin::Kind::Cross_LeftSemi}, + {tipb::JoinType::TypeAntiLeftOuterSemiJoin, ASTTableJoin::Kind::Cross_LeftAnti}}; + + const auto & join_type_map = join.left_join_keys_size() == 0 ? cartesian_join_type_map : equal_join_type_map; + auto join_type_it = join_type_map.find(join.join_type()); + if (unlikely(join_type_it == join_type_map.end())) + throw TiFlashException("Unknown join type in dag request", Errors::Coprocessor::BadRequest); + + ASTTableJoin::Kind kind = join_type_it->second; + + /// in DAG request, inner part is the build side, however for TiFlash implementation, + /// the build side must be the right side, so need to swap the join side if needed + /// 1. for (cross) inner join, there is no problem in this swap. + /// 2. for (cross) semi/anti-semi join, the build side is always right, needn't swap. + /// 3. for non-cross left/right join, there is no problem in this swap. + /// 4. for cross left join, the build side is always right, needn't and can't swap. + /// 5. for cross right join, the build side is always left, so it will always swap and change to cross left join. + /// note that whatever the build side is, we can't support cross-right join now. + + size_t build_side_index = 0; + switch (kind) + { + case ASTTableJoin::Kind::Cross_Right: + build_side_index = 0; + break; + case ASTTableJoin::Kind::Cross_Left: + build_side_index = 1; + break; + default: + build_side_index = join.inner_idx(); + } + assert(build_side_index == 0 || build_side_index == 1); + + // should swap join side. + if (build_side_index != 1) + { + switch (kind) + { + case ASTTableJoin::Kind::Left: + kind = ASTTableJoin::Kind::Right; + break; + case ASTTableJoin::Kind::Right: + kind = ASTTableJoin::Kind::Left; + break; + case ASTTableJoin::Kind::Cross_Right: + kind = ASTTableJoin::Kind::Cross_Left; + default:; // just `default`, for other kinds, don't need to change kind. + } + } + + return {kind, build_side_index}; +} + +DataTypes getJoinKeyTypes(const tipb::Join & join) +{ + if (unlikely(join.left_join_keys_size() != join.right_join_keys_size())) + throw TiFlashException("size of join.left_join_keys != size of join.right_join_keys", Errors::Coprocessor::BadRequest); + DataTypes key_types; + for (int i = 0; i < join.left_join_keys_size(); ++i) + { + if (unlikely(!exprHasValidFieldType(join.left_join_keys(i)) || !exprHasValidFieldType(join.right_join_keys(i)))) + throw TiFlashException("Join key without field type", Errors::Coprocessor::BadRequest); + DataTypes types; + types.emplace_back(getDataTypeByFieldTypeForComputingLayer(join.left_join_keys(i).field_type())); + types.emplace_back(getDataTypeByFieldTypeForComputingLayer(join.right_join_keys(i).field_type())); + DataTypePtr common_type = getLeastSupertype(types); + key_types.emplace_back(common_type); + } + return key_types; +} + +TiDB::TiDBCollators getJoinKeyCollators(const tipb::Join & join, const DataTypes & join_key_types) +{ + TiDB::TiDBCollators collators; + size_t join_key_size = join_key_types.size(); + if (join.probe_types_size() == static_cast(join_key_size) && join.build_types_size() == join.probe_types_size()) + for (size_t i = 0; i < join_key_size; ++i) + { + if (removeNullable(join_key_types[i])->isString()) + { + if (unlikely(join.probe_types(i).collate() != join.build_types(i).collate())) + throw TiFlashException("Join with different collators on the join key", Errors::Coprocessor::BadRequest); + collators.push_back(getCollatorFromFieldType(join.probe_types(i))); + } + else + collators.push_back(nullptr); + } + return collators; +} + +std::tuple doGenJoinOtherConditionAction( + const Context & context, + const tipb::Join & join, + const NamesAndTypes & source_columns) +{ + if (join.other_conditions_size() == 0 && join.other_eq_conditions_from_in_size() == 0) + return {nullptr, "", ""}; + + DAGExpressionAnalyzer dag_analyzer(source_columns, context); + ExpressionActionsChain chain; + + String filter_column_for_other_condition; + if (join.other_conditions_size() > 0) + { + std::vector condition_vector; + for (const auto & c : join.other_conditions()) + { + condition_vector.push_back(&c); + } + filter_column_for_other_condition = dag_analyzer.appendWhere(chain, condition_vector); + } + + String filter_column_for_other_eq_condition; + if (join.other_eq_conditions_from_in_size() > 0) + { + std::vector condition_vector; + for (const auto & c : join.other_eq_conditions_from_in()) + { + condition_vector.push_back(&c); + } + filter_column_for_other_eq_condition = dag_analyzer.appendWhere(chain, condition_vector); + } + + return {chain.getLastActions(), std::move(filter_column_for_other_condition), std::move(filter_column_for_other_eq_condition)}; +} +} // namespace + +TiFlashJoin::TiFlashJoin(const tipb::Join & join_) // NOLINT(cppcoreguidelines-pro-type-member-init) + : join(join_) + , join_key_types(getJoinKeyTypes(join_)) + , join_key_collators(getJoinKeyCollators(join_, join_key_types)) +{ + std::tie(kind, build_side_index) = getJoinKindAndBuildSideIndex(join); + strictness = isSemiJoin() ? ASTTableJoin::Strictness::Any : ASTTableJoin::Strictness::All; +} + +String TiFlashJoin::genMatchHelperName(const Block & header1, const Block & header2) const +{ + if (!isLeftSemiFamily()) + { + return ""; + } + + size_t i = 0; + String match_helper_name = fmt::format("{}{}", Join::match_helper_prefix, i); + while (header1.has(match_helper_name) || header2.has(match_helper_name)) + { + match_helper_name = fmt::format("{}{}", Join::match_helper_prefix, ++i); + } + return match_helper_name; +} + +NamesAndTypes TiFlashJoin::genColumnsForOtherJoinFilter( + const Block & left_input_header, + const Block & right_input_header, + const ExpressionActionsPtr & probe_prepare_join_actions) const +{ +#ifndef NDEBUG + auto is_prepare_actions_valid = [](const Block & origin_block, const ExpressionActionsPtr & prepare_actions) { + const Block & prepare_sample_block = prepare_actions->getSampleBlock(); + for (const auto & p : origin_block) + { + if (!prepare_sample_block.has(p.name)) + return false; + } + return true; + }; + if (unlikely(!is_prepare_actions_valid(build_side_index == 1 ? left_input_header : right_input_header, probe_prepare_join_actions))) + { + throw TiFlashException("probe_prepare_join_actions isn't valid", Errors::Coprocessor::Internal); + } +#endif + + /// columns_for_other_join_filter is a vector of columns used + /// as the input columns when compiling other join filter. + /// Note the order in the column vector is very important: + /// first the columns in left_input_header, then followed + /// by the columns in right_input_header, if there are other + /// columns generated before compile other join filter, then + /// append the extra columns afterwards. In order to figure out + /// whether a given column is already in the column vector or + /// not quickly, we use another set to store the column names. + + /// The order of columns must be {left_input, right_input, extra columns}, + /// because tidb requires the input schema of join to be {left_input, right_input}. + /// Extra columns are appended to prevent extra columns from being repeatedly generated. + + NamesAndTypes columns_for_other_join_filter; + std::unordered_set column_set_for_origin_columns; + + auto append_origin_columns = [&columns_for_other_join_filter, &column_set_for_origin_columns](const Block & header, bool make_nullable) { + for (const auto & p : header) + { + columns_for_other_join_filter.emplace_back(p.name, make_nullable ? makeNullable(p.type) : p.type); + column_set_for_origin_columns.emplace(p.name); + } + }; + append_origin_columns(left_input_header, join.join_type() == tipb::JoinType::TypeRightOuterJoin); + append_origin_columns(right_input_header, join.join_type() == tipb::JoinType::TypeLeftOuterJoin); + + /// append the columns generated by probe side prepare join actions. + /// the new columns are + /// - filter_column and related temporary columns + /// - join keys and related temporary columns + auto append_new_columns = [&columns_for_other_join_filter, &column_set_for_origin_columns](const Block & header, bool make_nullable) { + for (const auto & p : header) + { + if (column_set_for_origin_columns.find(p.name) == column_set_for_origin_columns.end()) + columns_for_other_join_filter.emplace_back(p.name, make_nullable ? makeNullable(p.type) : p.type); + } + }; + bool make_nullable = build_side_index == 1 + ? join.join_type() == tipb::JoinType::TypeRightOuterJoin + : join.join_type() == tipb::JoinType::TypeLeftOuterJoin; + append_new_columns(probe_prepare_join_actions->getSampleBlock(), make_nullable); + + return columns_for_other_join_filter; +} + +/// all the columns from build side streams should be added after join, even for the join key. +NamesAndTypesList TiFlashJoin::genColumnsAddedByJoin( + const Block & build_side_header, + const String & match_helper_name) const +{ + NamesAndTypesList columns_added_by_join; + bool make_nullable = isTiFlashLeftJoin(); + for (auto const & p : build_side_header) + { + columns_added_by_join.emplace_back(p.name, make_nullable ? makeNullable(p.type) : p.type); + } + if (!match_helper_name.empty()) + { + columns_added_by_join.emplace_back(match_helper_name, Join::match_helper_type); + } + return columns_added_by_join; +} + +NamesAndTypes TiFlashJoin::genJoinOutputColumns( + const Block & left_input_header, + const Block & right_input_header, + const String & match_helper_name) const +{ + NamesAndTypes join_output_columns; + auto append_output_columns = [&join_output_columns](const Block & header, bool make_nullable) { + for (auto const & p : header) + { + join_output_columns.emplace_back(p.name, make_nullable ? makeNullable(p.type) : p.type); + } + }; + + append_output_columns(left_input_header, join.join_type() == tipb::JoinType::TypeRightOuterJoin); + if (!isSemiJoin()) + { + /// for semi join, the columns from right table will be ignored + append_output_columns(right_input_header, join.join_type() == tipb::JoinType::TypeLeftOuterJoin); + } + + if (!match_helper_name.empty()) + { + join_output_columns.emplace_back(match_helper_name, Join::match_helper_type); + } + + return join_output_columns; +} + +std::tuple TiFlashJoin::genJoinOtherConditionAction( + const Context & context, + const Block & left_input_header, + const Block & right_input_header, + const ExpressionActionsPtr & probe_side_prepare_join) const +{ + auto columns_for_other_join_filter + = genColumnsForOtherJoinFilter( + left_input_header, + right_input_header, + probe_side_prepare_join); + + return doGenJoinOtherConditionAction(context, join, columns_for_other_join_filter); +} + +std::tuple prepareJoin( + const Context & context, + const Block & input_header, + const google::protobuf::RepeatedPtrField & keys, + const DataTypes & key_types, + bool left, + bool is_right_out_join, + const google::protobuf::RepeatedPtrField & filters) +{ + NamesAndTypes source_columns; + for (auto const & p : input_header) + source_columns.emplace_back(p.name, p.type); + DAGExpressionAnalyzer dag_analyzer(std::move(source_columns), context); + ExpressionActionsChain chain; + Names key_names; + String filter_column_name; + dag_analyzer.appendJoinKeyAndJoinFilters(chain, keys, key_types, key_names, left, is_right_out_join, filters, filter_column_name); + return {chain.getLastActions(), std::move(key_names), std::move(filter_column_name)}; +} + +std::function concurrencyBuildIndexGenerator(size_t join_build_concurrency) +{ + size_t init_value = 0; + return [init_value, join_build_concurrency]() mutable { + return (init_value++) % join_build_concurrency; + }; +} +} // namespace DB::JoinInterpreterHelper \ No newline at end of file diff --git a/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.h b/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.h new file mode 100644 index 00000000000..d84c03d572d --- /dev/null +++ b/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.h @@ -0,0 +1,133 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace DB +{ +class Context; + +namespace JoinInterpreterHelper +{ +struct TiFlashJoin +{ + explicit TiFlashJoin(const tipb::Join & join_); + + const tipb::Join & join; + + ASTTableJoin::Kind kind; + size_t build_side_index = 0; + + DataTypes join_key_types; + TiDB::TiDBCollators join_key_collators; + + ASTTableJoin::Strictness strictness; + + /// (cartesian) (anti) left semi join. + bool isLeftSemiFamily() const { return join.join_type() == tipb::JoinType::TypeLeftOuterSemiJoin || join.join_type() == tipb::JoinType::TypeAntiLeftOuterSemiJoin; } + + bool isSemiJoin() const { return join.join_type() == tipb::JoinType::TypeSemiJoin || join.join_type() == tipb::JoinType::TypeAntiSemiJoin || isLeftSemiFamily(); } + + const google::protobuf::RepeatedPtrField & getBuildJoinKeys() const + { + return build_side_index == 1 ? join.right_join_keys() : join.left_join_keys(); + } + + const google::protobuf::RepeatedPtrField & getProbeJoinKeys() const + { + return build_side_index == 0 ? join.right_join_keys() : join.left_join_keys(); + } + + const google::protobuf::RepeatedPtrField & getBuildConditions() const + { + return build_side_index == 1 ? join.right_conditions() : join.left_conditions(); + } + + const google::protobuf::RepeatedPtrField & getProbeConditions() const + { + return build_side_index == 0 ? join.right_conditions() : join.left_conditions(); + } + + bool isTiFlashLeftJoin() const { return kind == ASTTableJoin::Kind::Left || kind == ASTTableJoin::Kind::Cross_Left; } + + /// Cross_Right join will be converted to Cross_Left join, so no need to check Cross_Right + bool isTiFlashRightJoin() const { return kind == ASTTableJoin::Kind::Right; } + + /// return a name that is unique in header1 and header2 for left semi family join, + /// return "" for everything else. + String genMatchHelperName(const Block & header1, const Block & header2) const; + + /// columns_added_by_join + /// = join_output_columns - probe_side_columns + /// = build_side_columns + match_helper_name + NamesAndTypesList genColumnsAddedByJoin( + const Block & build_side_header, + const String & match_helper_name) const; + + /// The columns output by join will be: + /// {columns of left_input, columns of right_input, match_helper_name} + NamesAndTypes genJoinOutputColumns( + const Block & left_input_header, + const Block & right_input_header, + const String & match_helper_name) const; + + /// @other_condition_expr: generates other_filter_column and other_eq_filter_from_in_column + /// @other_filter_column_name: column name of `and(other_cond1, other_cond2, ...)` + /// @other_eq_filter_from_in_column_name: column name of `and(other_eq_cond1_from_in, other_eq_cond2_from_in, ...)` + /// such as + /// `select * from t where col1 in (select col2 from t2 where t1.col2 = t2.col3)` + /// - other_filter is `t1.col2 = t2.col3` + /// - other_eq_filter_from_in_column is `t1.col1 = t2.col2` + /// + /// new columns from build side prepare join actions cannot be appended. + /// because the input that other filter accepts is + /// {left_input_columns, right_input_columns, new_columns_from_probe_side_prepare, match_helper_name}. + std::tuple genJoinOtherConditionAction( + const Context & context, + const Block & left_input_header, + const Block & right_input_header, + const ExpressionActionsPtr & probe_side_prepare_join) const; + + NamesAndTypes genColumnsForOtherJoinFilter( + const Block & left_input_header, + const Block & right_input_header, + const ExpressionActionsPtr & probe_prepare_join_actions) const; +}; + +/// @join_prepare_expr_actions: generates join key columns and join filter column +/// @key_names: column names of keys. +/// @filter_column_name: column name of `and(filters)` +std::tuple prepareJoin( + const Context & context, + const Block & input_header, + const google::protobuf::RepeatedPtrField & keys, + const DataTypes & key_types, + bool left, + bool is_right_out_join, + const google::protobuf::RepeatedPtrField & filters); + +std::function concurrencyBuildIndexGenerator(size_t join_build_concurrency); +} // namespace JoinInterpreterHelper +} // namespace DB diff --git a/dbms/src/Flash/tests/exchange_perftest.cpp b/dbms/src/Flash/tests/exchange_perftest.cpp index 45dbac4a7f6..c2e047bec62 100644 --- a/dbms/src/Flash/tests/exchange_perftest.cpp +++ b/dbms/src/Flash/tests/exchange_perftest.cpp @@ -462,7 +462,7 @@ struct ReceiverHelper SizeLimits(0, 0, OverflowMode::THROW), ASTTableJoin::Kind::Inner, ASTTableJoin::Strictness::All, - concurrency, + /*req_id=*/"", TiDB::TiDBCollators{nullptr}, "", "", @@ -471,7 +471,7 @@ struct ReceiverHelper nullptr, 65536); - join_ptr->setSampleBlock(receiver_header); + join_ptr->init(receiver_header, concurrency); for (int i = 0; i < concurrency; ++i) streams[i] = std::make_shared(streams[i], join_ptr, i, /*req_id=*/""); diff --git a/dbms/src/Functions/FunctionsTiDBConversion.cpp b/dbms/src/Functions/FunctionsTiDBConversion.cpp index 75c015c4bad..74daca2b7fe 100644 --- a/dbms/src/Functions/FunctionsTiDBConversion.cpp +++ b/dbms/src/Functions/FunctionsTiDBConversion.cpp @@ -46,4 +46,19 @@ void registerFunctionsTiDBConversion(FunctionFactory & factory) factory.registerFunction(); } +FunctionBasePtr FunctionBuilderTiDBCast::buildImpl( + const ColumnsWithTypeAndName & arguments, + const DataTypePtr & return_type, + const TiDB::TiDBCollatorPtr &) const +{ + DataTypes data_types(arguments.size()); + + for (size_t i = 0; i < arguments.size(); ++i) + data_types[i] = arguments[i].type; + + auto monotonicity = getMonotonicityInformation(arguments.front().type, return_type.get()); + return std::make_shared>(context, name, std::move(monotonicity), data_types, return_type, in_union, tidb_tp); +} + + } // namespace DB diff --git a/dbms/src/Functions/FunctionsTiDBConversion.h b/dbms/src/Functions/FunctionsTiDBConversion.h index 30251aac36d..bcd7856ee71 100644 --- a/dbms/src/Functions/FunctionsTiDBConversion.h +++ b/dbms/src/Functions/FunctionsTiDBConversion.h @@ -1743,6 +1743,7 @@ inline bool numberToDateTime(Int64 number, MyDateTime & result, DAGContext * ctx return getDatetime(number, result, ctx); } +template class ExecutableFunctionTiDBCast : public IExecutableFunction { public: @@ -1782,13 +1783,15 @@ class ExecutableFunctionTiDBCast : public IExecutableFunction const Context & context; }; +using MonotonicityForRange = std::function; + /// FunctionTiDBCast implements SQL cast function in TiDB /// The basic idea is to dispatch according to combinations of parameter types +template class FunctionTiDBCast final : public IFunctionBase { public: using WrapperType = std::function; - using MonotonicityForRange = std::function; FunctionTiDBCast(const Context & context, const char * name, MonotonicityForRange && monotonicity_for_range, const DataTypes & argument_types, const DataTypePtr & return_type, bool in_union_, const tipb::FieldType & tidb_tp_) : context(context) @@ -1805,7 +1808,7 @@ class FunctionTiDBCast final : public IFunctionBase ExecutableFunctionPtr prepare(const Block & /*sample_block*/) const override { - return std::make_shared( + return std::make_shared>( prepare(getArgumentTypes()[0], getReturnType()), name, in_union, @@ -2341,8 +2344,6 @@ class FunctionTiDBCast final : public IFunctionBase class FunctionBuilderTiDBCast : public IFunctionBuilder { public: - using MonotonicityForRange = FunctionTiDBCast::MonotonicityForRange; - static constexpr auto name = "tidb_cast"; static FunctionBuilderPtr create(const Context & context) { @@ -2369,16 +2370,7 @@ class FunctionBuilderTiDBCast : public IFunctionBuilder FunctionBasePtr buildImpl( const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type, - const TiDB::TiDBCollatorPtr &) const override - { - DataTypes data_types(arguments.size()); - - for (size_t i = 0; i < arguments.size(); ++i) - data_types[i] = arguments[i].type; - - auto monotonicity = getMonotonicityInformation(arguments.front().type, return_type.get()); - return std::make_shared(context, name, std::move(monotonicity), data_types, return_type, in_union, tidb_tp); - } + const TiDB::TiDBCollatorPtr &) const override; // use the last const string column's value as the return type name, in string representation like "Float64" DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override diff --git a/dbms/src/Functions/tests/gtest_tidb_conversion.cpp b/dbms/src/Functions/tests/gtest_tidb_conversion.cpp index d67ef49e108..5f885c2716f 100644 --- a/dbms/src/Functions/tests/gtest_tidb_conversion.cpp +++ b/dbms/src/Functions/tests/gtest_tidb_conversion.cpp @@ -1474,76 +1474,76 @@ TEST_F(TestTidbConversion, skipCheckOverflowIntToDeciaml) const ScaleType scale = 0; // int8(max_prec: 3) -> decimal32(max_prec: 9) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int8_ptr, prec_decimal32, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(int8_ptr, prec_decimal32, scale)); // int16(max_prec: 5) -> decimal32(max_prec: 9) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int16_ptr, prec_decimal32, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(int16_ptr, prec_decimal32, scale)); // int32(max_prec: 10) -> decimal32(max_prec: 9) - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int32_ptr, prec_decimal32, scale)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(int32_ptr, prec_decimal32, scale)); // int64(max_prec: 20) -> decimal32(max_prec: 9) - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int64_ptr, prec_decimal32, scale)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(int64_ptr, prec_decimal32, scale)); // uint8(max_prec: 3) -> decimal32(max_prec: 9) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint8_ptr, prec_decimal32, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(uint8_ptr, prec_decimal32, scale)); // uint16(max_prec: 5) -> decimal32(max_prec: 9) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint16_ptr, prec_decimal32, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(uint16_ptr, prec_decimal32, scale)); // uint32(max_prec: 10) -> decimal32(max_prec: 9) - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint32_ptr, prec_decimal32, scale)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(uint32_ptr, prec_decimal32, scale)); // uint64(max_prec: 20) -> decimal32(max_prec: 9) - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint64_ptr, prec_decimal32, scale)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(uint64_ptr, prec_decimal32, scale)); // int8(max_prec: 3) -> decimal64(max_prec: 18) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int8_ptr, prec_decimal64, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(int8_ptr, prec_decimal64, scale)); // int16(max_prec: 5) -> decimal64(max_prec: 18) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int16_ptr, prec_decimal64, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(int16_ptr, prec_decimal64, scale)); // int32(max_prec: 10) -> decimal64(max_prec: 18) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int32_ptr, prec_decimal64, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(int32_ptr, prec_decimal64, scale)); // int64(max_prec: 20) -> decimal64(max_prec: 18) - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int64_ptr, prec_decimal64, scale)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(int64_ptr, prec_decimal64, scale)); // uint8(max_prec: 3) -> decimal64(max_prec: 18) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint8_ptr, prec_decimal64, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(uint8_ptr, prec_decimal64, scale)); // uint16(max_prec: 5) -> decimal64(max_prec: 18) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint16_ptr, prec_decimal64, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(uint16_ptr, prec_decimal64, scale)); // uint32(max_prec: 10) -> decimal64(max_prec: 18) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint32_ptr, prec_decimal64, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(uint32_ptr, prec_decimal64, scale)); // uint64(max_prec: 20) -> decimal64(max_prec: 18) - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint64_ptr, prec_decimal64, scale)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(uint64_ptr, prec_decimal64, scale)); // int8(max_prec: 3) -> decimal128(max_prec: 38) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int8_ptr, prec_decimal128, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(int8_ptr, prec_decimal128, scale)); // int16(max_prec: 5) -> decimal128(max_prec: 38) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int16_ptr, prec_decimal128, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(int16_ptr, prec_decimal128, scale)); // int32(max_prec: 10) -> decimal128(max_prec: 38) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int32_ptr, prec_decimal128, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(int32_ptr, prec_decimal128, scale)); // int64(max_prec: 20) -> decimal128(max_prec: 38) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int64_ptr, prec_decimal128, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(int64_ptr, prec_decimal128, scale)); // uint8(max_prec: 3) -> decimal128(max_prec: 38) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint8_ptr, prec_decimal128, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(uint8_ptr, prec_decimal128, scale)); // uint16(max_prec: 5) -> decimal128(max_prec: 38) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint16_ptr, prec_decimal128, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(uint16_ptr, prec_decimal128, scale)); // uint32(max_prec: 10) -> decimal128(max_prec: 38) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint32_ptr, prec_decimal128, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(uint32_ptr, prec_decimal128, scale)); // uint64(max_prec: 20) -> decimal128(max_prec: 38) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint64_ptr, prec_decimal128, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(uint64_ptr, prec_decimal128, scale)); // int8(max_prec: 3) -> decimal256(max_prec: 65) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int8_ptr, prec_decimal256, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(int8_ptr, prec_decimal256, scale)); // int16(max_prec: 5) -> decimal256(max_prec: 65) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int16_ptr, prec_decimal256, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(int16_ptr, prec_decimal256, scale)); // int32(max_prec: 10) -> decimal256(max_prec: 65) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int32_ptr, prec_decimal256, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(int32_ptr, prec_decimal256, scale)); // int64(max_prec: 20) -> decimal256(max_prec: 65) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int64_ptr, prec_decimal256, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(int64_ptr, prec_decimal256, scale)); // uint8(max_prec: 3) -> decimal256(max_prec: 65) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint8_ptr, prec_decimal256, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(uint8_ptr, prec_decimal256, scale)); // uint16(max_prec: 5) -> decimal256(max_prec: 65) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint16_ptr, prec_decimal256, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(uint16_ptr, prec_decimal256, scale)); // uint32(max_prec: 10) -> decimal256(max_prec: 65) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint32_ptr, prec_decimal256, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(uint32_ptr, prec_decimal256, scale)); // uint64(max_prec: 20) -> decimal256(max_prec: 65) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(uint64_ptr, prec_decimal256, scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(uint64_ptr, prec_decimal256, scale)); } TEST_F(TestTidbConversion, skipCheckOverflowDecimalToDeciaml) @@ -1551,24 +1551,24 @@ TEST_F(TestTidbConversion, skipCheckOverflowDecimalToDeciaml) DataTypePtr decimal32_ptr_8_3 = createDecimal(8, 3); DataTypePtr decimal32_ptr_8_2 = createDecimal(8, 2); - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(decimal32_ptr_8_2, 8, 3)); - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(decimal32_ptr_8_3, 8, 2)); - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(decimal32_ptr_8_2, 7, 5)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(decimal32_ptr_8_2, 8, 3)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(decimal32_ptr_8_3, 8, 2)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(decimal32_ptr_8_2, 7, 5)); - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(decimal32_ptr_8_2, 9, 3)); - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(decimal32_ptr_8_2, 9, 1)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(decimal32_ptr_8_2, 9, 3)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(decimal32_ptr_8_2, 9, 1)); DataTypePtr decimal32_ptr_6_4 = createDecimal(6, 4); // decimal(6, 4) -> decimal(5, 3) // because select cast(99.9999 as decimal(5, 3)); -> 100.000 is greater than 99.999. - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(decimal32_ptr_6_4, 5, 3)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(decimal32_ptr_6_4, 5, 3)); // decimal(6, 4) -> decimal(7, 5) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(decimal32_ptr_6_4, 7, 5)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(decimal32_ptr_6_4, 7, 5)); // decimal(6, 4) -> decimal(6, 5) - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(decimal32_ptr_6_4, 6, 5)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(decimal32_ptr_6_4, 6, 5)); // decimal(6, 4) -> decimal(8, 5) - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(decimal32_ptr_6_4, 8, 5)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(decimal32_ptr_6_4, 8, 5)); } TEST_F(TestTidbConversion, skipCheckOverflowEnumToDecimal) @@ -1583,15 +1583,15 @@ TEST_F(TestTidbConversion, skipCheckOverflowEnumToDecimal) enum16_values.push_back({"b1", 2000}); DataTypePtr enum16_ptr = std::make_shared(enum16_values); - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(enum8_ptr, 3, 0)); - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(enum8_ptr, 4, 1)); - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(enum8_ptr, 2, 0)); - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(enum8_ptr, 4, 2)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(enum8_ptr, 3, 0)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(enum8_ptr, 4, 1)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(enum8_ptr, 2, 0)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(enum8_ptr, 4, 2)); - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(enum16_ptr, 5, 0)); - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(enum16_ptr, 6, 1)); - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(enum16_ptr, 4, 0)); - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(enum16_ptr, 6, 2)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(enum16_ptr, 5, 0)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(enum16_ptr, 6, 1)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(enum16_ptr, 4, 0)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(enum16_ptr, 6, 2)); } TEST_F(TestTidbConversion, skipCheckOverflowMyDateTimeToDeciaml) @@ -1600,18 +1600,18 @@ TEST_F(TestTidbConversion, skipCheckOverflowMyDateTimeToDeciaml) DataTypePtr datetime_ptr_fsp_5 = std::make_shared(5); // rule for no fsp: 14 + to_scale <= to_prec. - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(datetime_ptr_no_fsp, 5, 3)); - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(datetime_ptr_no_fsp, 18, 3)); - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(datetime_ptr_no_fsp, 17, 3)); - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(datetime_ptr_no_fsp, 18, 4)); - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(datetime_ptr_no_fsp, 14, 0)); - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(datetime_ptr_no_fsp, 14, 1)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(datetime_ptr_no_fsp, 5, 3)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(datetime_ptr_no_fsp, 18, 3)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(datetime_ptr_no_fsp, 17, 3)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(datetime_ptr_no_fsp, 18, 4)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(datetime_ptr_no_fsp, 14, 0)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(datetime_ptr_no_fsp, 14, 1)); // rule for fsp: 20 + scale_diff <= to_prec. // 20 + (3 - 6 + 1) = 18 - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(datetime_ptr_fsp_5, 19, 3)); - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(datetime_ptr_fsp_5, 18, 3)); - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(datetime_ptr_fsp_5, 17, 3)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(datetime_ptr_fsp_5, 19, 3)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(datetime_ptr_fsp_5, 18, 3)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(datetime_ptr_fsp_5, 17, 3)); } TEST_F(TestTidbConversion, skipCheckOverflowMyDateToDeciaml) @@ -1619,30 +1619,30 @@ TEST_F(TestTidbConversion, skipCheckOverflowMyDateToDeciaml) DataTypePtr date_ptr = std::make_shared(); // rule: 8 + to_scale <= to_prec. - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(date_ptr, 11, 3)); - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(date_ptr, 11, 4)); - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(date_ptr, 10, 3)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(date_ptr, 11, 3)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(date_ptr, 11, 4)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(date_ptr, 10, 3)); } TEST_F(TestTidbConversion, skipCheckOverflowOtherToDecimal) { // float and string not support skip overflow check. DataTypePtr string_ptr = std::make_shared(); - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(string_ptr, 1, 0)); - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(string_ptr, 60, 1)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(string_ptr, 1, 0)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(string_ptr, 60, 1)); DataTypePtr float32_ptr = std::make_shared(); - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(float32_ptr, 1, 0)); - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(float32_ptr, 60, 1)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(float32_ptr, 1, 0)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(float32_ptr, 60, 1)); DataTypePtr float64_ptr = std::make_shared(); - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(float64_ptr, 1, 0)); - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(float64_ptr, 60, 1)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(float64_ptr, 1, 0)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(float64_ptr, 60, 1)); // cast duration to decimal is not supported to push down to tiflash for now. DataTypePtr duration_ptr = std::make_shared(); - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(duration_ptr, 1, 0)); - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(duration_ptr, 60, 1)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(duration_ptr, 1, 0)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(duration_ptr, 60, 1)); } // check if template argument of CastInternalType is correct or not. @@ -1654,7 +1654,7 @@ try ScaleType to_scale = 3; DataTypePtr int8_ptr = std::make_shared(); // from_prec(3) + to_scale(3) <= Decimal32::prec(9), so we **CAN** skip check overflow. - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int8_ptr, to_prec, to_scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(int8_ptr, to_prec, to_scale)); // from_prec(3) + to_scale(3) <= Int32::real_prec(10) - 1, so CastInternalType should be **Int32**. ASSERT_COLUMN_EQ( @@ -1669,7 +1669,7 @@ try to_prec = 9; to_scale = 7; // from_prec(3) + to_scale(7) > Decimal32::prec(9), so we **CANNOT** skip check overflow. - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int8_ptr, to_prec, to_scale)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(int8_ptr, to_prec, to_scale)); // from_prec(3) + to_scale(7) > Int32::real_prec(10) - 1, so CastInternalType should be **Int64**. DAGContext * dag_context = context.getDAGContext(); @@ -1690,7 +1690,7 @@ try to_prec = 40; to_scale = 20; DataTypePtr int64_ptr = std::make_shared(); - ASSERT_TRUE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int64_ptr, to_prec, to_scale)); + ASSERT_TRUE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(int64_ptr, to_prec, to_scale)); // from_prec(19) + to_scale(20) > Int128::real_prec(39) - 1, so CastInternalType should be **Int256**. ASSERT_COLUMN_EQ( @@ -1705,7 +1705,7 @@ try // from_prec(19) + to_scale(20) > Decimal256::prec(38), so we **CANNOT** skip check overflow. to_prec = 38; to_scale = 20; - ASSERT_FALSE(FunctionTiDBCast::canSkipCheckOverflowForDecimal(int64_ptr, to_prec, to_scale)); + ASSERT_FALSE(FunctionTiDBCast<>::canSkipCheckOverflowForDecimal(int64_ptr, to_prec, to_scale)); // from_prec(19) + to_scale(20) > Int128::real_prec(39) - 1, so CastInternalType should be **Int256**. ASSERT_COLUMN_EQ( diff --git a/dbms/src/Interpreters/ExpressionAnalyzer.cpp b/dbms/src/Interpreters/ExpressionAnalyzer.cpp index eb07e2d541e..a532ed8a8e0 100644 --- a/dbms/src/Interpreters/ExpressionAnalyzer.cpp +++ b/dbms/src/Interpreters/ExpressionAnalyzer.cpp @@ -247,16 +247,16 @@ void ExpressionAnalyzer::translateQualifiedNames() if (!select_query || !select_query->tables || select_query->tables->children.empty()) return; - ASTTablesInSelectQueryElement & element = static_cast(*select_query->tables->children[0]); + auto & element = static_cast(*select_query->tables->children[0]); if (!element.table_expression) /// This is ARRAY JOIN without a table at the left side. return; - ASTTableExpression & table_expression = static_cast(*element.table_expression); + auto & table_expression = static_cast(*element.table_expression); if (table_expression.database_and_table_name) { - const ASTIdentifier & identifier = static_cast(*table_expression.database_and_table_name); + const auto & identifier = static_cast(*table_expression.database_and_table_name); alias = identifier.tryGetAlias(); @@ -291,7 +291,7 @@ void ExpressionAnalyzer::translateQualifiedNames() void ExpressionAnalyzer::translateQualifiedNamesImpl(ASTPtr & ast, const String & database_name, const String & table_name, const String & alias) { - if (ASTIdentifier * ident = typeid_cast(ast.get())) + if (auto * ident = typeid_cast(ast.get())) { if (ident->kind == ASTIdentifier::Column) { @@ -352,7 +352,7 @@ void ExpressionAnalyzer::translateQualifiedNamesImpl(ASTPtr & ast, const String if (ast->children.size() != 1) throw Exception("Logical error: qualified asterisk must have exactly one child", ErrorCodes::LOGICAL_ERROR); - ASTIdentifier * ident = typeid_cast(ast->children[0].get()); + auto * ident = typeid_cast(ast->children[0].get()); if (!ident) throw Exception("Logical error: qualified asterisk must have identifier as its child", ErrorCodes::LOGICAL_ERROR); @@ -396,7 +396,7 @@ void ExpressionAnalyzer::optimizeIfWithConstantCondition() bool ExpressionAnalyzer::tryExtractConstValueFromCondition(const ASTPtr & condition, bool & value) const { /// numeric constant in condition - if (const ASTLiteral * literal = typeid_cast(condition.get())) + if (const auto * literal = typeid_cast(condition.get())) { if (literal->value.getType() == Field::Types::Int64 || literal->value.getType() == Field::Types::UInt64) { @@ -406,14 +406,14 @@ bool ExpressionAnalyzer::tryExtractConstValueFromCondition(const ASTPtr & condit } /// cast of numeric constant in condition to UInt8 - if (const ASTFunction * function = typeid_cast(condition.get())) + if (const auto * function = typeid_cast(condition.get())) { if (function->name == "CAST") { - if (ASTExpressionList * expr_list = typeid_cast(function->arguments.get())) + if (auto * expr_list = typeid_cast(function->arguments.get())) { const ASTPtr & type_ast = expr_list->children.at(1); - if (const ASTLiteral * type_literal = typeid_cast(type_ast.get())) + if (const auto * type_literal = typeid_cast(type_ast.get())) { if (type_literal->value.getType() == Field::Types::String && type_literal->value.get() == "UInt8") return tryExtractConstValueFromCondition(expr_list->children.at(0), value); @@ -432,7 +432,7 @@ void ExpressionAnalyzer::optimizeIfWithConstantConditionImpl(ASTPtr & current_as for (ASTPtr & child : current_ast->children) { - ASTFunction * function_node = typeid_cast(child.get()); + auto * function_node = typeid_cast(child.get()); if (!function_node || function_node->name != "if") { optimizeIfWithConstantConditionImpl(child, aliases); @@ -440,7 +440,7 @@ void ExpressionAnalyzer::optimizeIfWithConstantConditionImpl(ASTPtr & current_as } optimizeIfWithConstantConditionImpl(function_node->arguments, aliases); - ASTExpressionList * args = typeid_cast(function_node->arguments.get()); + auto * args = typeid_cast(function_node->arguments.get()); ASTPtr condition_expr = args->children.at(0); ASTPtr then_expr = args->children.at(1); @@ -603,13 +603,13 @@ void ExpressionAnalyzer::initGlobalSubqueries(ASTPtr & ast) /// Bottom-up actions. - if (ASTFunction * node = typeid_cast(ast.get())) + if (auto * node = typeid_cast(ast.get())) { /// For GLOBAL IN. if (do_global && (node->name == "globalIn" || node->name == "globalNotIn")) addExternalStorage(node->arguments->children.at(1)); } - else if (ASTTablesInSelectQueryElement * node = typeid_cast(ast.get())) + else if (auto * node = typeid_cast(ast.get())) { /// For GLOBAL JOIN. if (do_global && node->table_join @@ -628,7 +628,7 @@ void ExpressionAnalyzer::findExternalTables(ASTPtr & ast) /// If table type identifier StoragePtr external_storage; - if (ASTIdentifier * node = typeid_cast(ast.get())) + if (auto * node = typeid_cast(ast.get())) if (node->kind == ASTIdentifier::Table) if ((external_storage = context.tryGetExternalTable(node->name))) external_tables[node->name] = external_storage; @@ -658,8 +658,8 @@ static std::shared_ptr interpretSubquery( const Names & required_source_columns) { /// Subquery or table name. The name of the table is similar to the subquery `SELECT * FROM t`. - const ASTSubquery * subquery = typeid_cast(subquery_or_table_name.get()); - const ASTIdentifier * table = typeid_cast(subquery_or_table_name.get()); + const auto * subquery = typeid_cast(subquery_or_table_name.get()); + const auto * table = typeid_cast(subquery_or_table_name.get()); if (!subquery && !table) throw Exception("IN/JOIN supports only SELECT subqueries.", ErrorCodes::BAD_ARGUMENTS); @@ -721,9 +721,9 @@ static std::shared_ptr interpretSubquery( std::set all_column_names; std::set assigned_column_names; - if (ASTSelectWithUnionQuery * select_with_union = typeid_cast(query.get())) + if (auto * select_with_union = typeid_cast(query.get())) { - if (ASTSelectQuery * select = typeid_cast(select_with_union->list_of_selects->children.at(0).get())) + if (auto * select = typeid_cast(select_with_union->list_of_selects->children.at(0).get())) { for (auto & expr : select->select_expression_list->children) all_column_names.insert(expr->getAliasOrColumnName()); @@ -973,7 +973,7 @@ void ExpressionAnalyzer::normalizeTreeImpl( { /// `IN t` can be specified, where t is a table, which is equivalent to `IN (SELECT * FROM t)`. if (functionIsInOrGlobalInOperator(func_node->name)) - if (ASTIdentifier * right = typeid_cast(func_node->arguments->children.at(1).get())) + if (auto * right = typeid_cast(func_node->arguments->children.at(1).get())) if (!aliases.count(right->name)) right->kind = ASTIdentifier::Table; @@ -1030,7 +1030,7 @@ void ExpressionAnalyzer::normalizeTreeImpl( } } } - else if (ASTExpressionList * node = typeid_cast(ast.get())) + else if (auto * node = typeid_cast(ast.get())) { // Get hidden column names of mutable storage OrderedNameSet filtered_names; @@ -1068,14 +1068,14 @@ void ExpressionAnalyzer::normalizeTreeImpl( } } } - else if (ASTTablesInSelectQueryElement * node = typeid_cast(ast.get())) + else if (auto * node = typeid_cast(ast.get())) { if (node->table_expression) { auto & database_and_table_name = static_cast(*node->table_expression).database_and_table_name; if (database_and_table_name) { - if (ASTIdentifier * right = typeid_cast(database_and_table_name.get())) + if (auto * right = typeid_cast(database_and_table_name.get())) { right->kind = ASTIdentifier::Table; } @@ -1127,7 +1127,7 @@ void ExpressionAnalyzer::normalizeTreeImpl( } /// If the WHERE clause or HAVING consists of a single alias, the reference must be replaced not only in children, but also in where_expression and having_expression. - if (ASTSelectQuery * select = typeid_cast(ast.get())) + if (auto * select = typeid_cast(ast.get())) { if (select->prewhere_expression) normalizeTreeImpl(select->prewhere_expression, finished_asts, current_asts, current_alias, level + 1); @@ -1211,7 +1211,7 @@ void ExpressionAnalyzer::executeScalarSubqueriesImpl(ASTPtr & ast) * The request is sent to remote servers with already substituted constants. */ - if (ASTSubquery * subquery = typeid_cast(ast.get())) + if (auto * subquery = typeid_cast(ast.get())) { Context subquery_context = context; Settings subquery_settings = context.getSettings(); @@ -1283,7 +1283,7 @@ void ExpressionAnalyzer::executeScalarSubqueriesImpl(ASTPtr & ast) /** Don't descend into subqueries in arguments of IN operator. * But if an argument is not subquery, than deeper may be scalar subqueries and we need to descend in them. */ - ASTFunction * func = typeid_cast(ast.get()); + auto * func = typeid_cast(ast.get()); if (func && functionIsInOrGlobalInOperator(func->name)) { @@ -1424,7 +1424,7 @@ void ExpressionAnalyzer::optimizeOrderBy() for (const auto & elem : elems) { String name = elem->children.front()->getColumnName(); - const ASTOrderByElement & order_by_elem = typeid_cast(*elem); + const auto & order_by_elem = typeid_cast(*elem); if (elems_set.emplace(name, order_by_elem.collation ? order_by_elem.collation->getColumnName() : "").second) unique_elems.emplace_back(elem); @@ -1496,14 +1496,14 @@ void ExpressionAnalyzer::makeSetsForIndexImpl(const ASTPtr & node, const Block & continue; /// Don't dive into lambda functions - const ASTFunction * func = typeid_cast(child.get()); + const auto * func = typeid_cast(child.get()); if (func && func->name == "lambda") continue; makeSetsForIndexImpl(child, sample_block); } - const ASTFunction * func = typeid_cast(node.get()); + const auto * func = typeid_cast(node.get()); if (func && functionIsInOperator(func->name)) { const IAST & args = *func->arguments; @@ -1551,7 +1551,7 @@ void ExpressionAnalyzer::makeSet(const ASTFunction * node, const Block & sample_ return; /// If the subquery or table name for SELECT. - const ASTIdentifier * identifier = typeid_cast(arg.get()); + const auto * identifier = typeid_cast(arg.get()); if (typeid_cast(arg.get()) || identifier) { /// We get the stream of blocks for the subquery. Create Set and put it in place of the subquery. @@ -1566,7 +1566,7 @@ void ExpressionAnalyzer::makeSet(const ASTFunction * node, const Block & sample_ if (table) { - StorageSet * storage_set = dynamic_cast(table.get()); + auto * storage_set = dynamic_cast(table.get()); if (storage_set) { @@ -1650,7 +1650,7 @@ void ExpressionAnalyzer::makeExplicitSet(const ASTFunction * node, const Block & DataTypes set_element_types; const ASTPtr & left_arg = args.children.at(0); - const ASTFunction * left_arg_tuple = typeid_cast(left_arg.get()); + const auto * left_arg_tuple = typeid_cast(left_arg.get()); /** NOTE If tuple in left hand side specified non-explicitly * Example: identity((a, b)) IN ((1, 2), (3, 4)) @@ -1672,7 +1672,7 @@ void ExpressionAnalyzer::makeExplicitSet(const ASTFunction * node, const Block & bool single_value = false; ASTPtr elements_ast = arg; - if (ASTFunction * set_func = typeid_cast(arg.get())) + if (auto * set_func = typeid_cast(arg.get())) { if (set_func->name == "tuple") { @@ -1684,7 +1684,7 @@ void ExpressionAnalyzer::makeExplicitSet(const ASTFunction * node, const Block & else { /// Distinguish the case `(x, y) in ((1, 2), (3, 4))` from the case `(x, y) in (1, 2)`. - ASTFunction * any_element = typeid_cast(set_func->arguments->children.at(0).get()); + auto * any_element = typeid_cast(set_func->arguments->children.at(0).get()); if (set_element_types.size() >= 2 && (!any_element || any_element->name != "tuple")) single_value = true; else @@ -1902,7 +1902,7 @@ void ExpressionAnalyzer::getArrayJoinedColumnsImpl(const ASTPtr & ast) if (typeid_cast(ast.get())) return; - if (ASTIdentifier * node = typeid_cast(ast.get())) + if (auto * node = typeid_cast(ast.get())) { if (node->kind == ASTIdentifier::Column) { @@ -1955,7 +1955,7 @@ void ExpressionAnalyzer::getActionsImpl(const ASTPtr & ast, bool no_subqueries, && actions_stack.getSampleBlock().has(ast->getColumnName())) return; - if (ASTIdentifier * node = typeid_cast(ast.get())) + if (auto * node = typeid_cast(ast.get())) { std::string name = node->getColumnName(); if (!only_consts && !actions_stack.getSampleBlock().has(name)) @@ -1973,7 +1973,7 @@ void ExpressionAnalyzer::getActionsImpl(const ASTPtr & ast, bool no_subqueries, ErrorCodes::NOT_AN_AGGREGATE); } } - else if (ASTFunction * node = typeid_cast(ast.get())) + else if (auto * node = typeid_cast(ast.get())) { if (node->name == "lambda") throw Exception("Unexpected lambda expression", ErrorCodes::UNEXPECTED_EXPRESSION); @@ -2049,14 +2049,14 @@ void ExpressionAnalyzer::getActionsImpl(const ASTPtr & ast, bool no_subqueries, for (auto & child : node->arguments->children) { - ASTFunction * lambda = typeid_cast(child.get()); + auto * lambda = typeid_cast(child.get()); if (lambda && lambda->name == "lambda") { /// If the argument is a lambda expression, just remember its approximate type. if (lambda->arguments->children.size() != 2) throw Exception("lambda requires two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - ASTFunction * lambda_args_tuple = typeid_cast(lambda->arguments->children.at(0).get()); + auto * lambda_args_tuple = typeid_cast(lambda->arguments->children.at(0).get()); if (!lambda_args_tuple || lambda_args_tuple->name != "tuple") throw Exception("First argument of lambda must be a tuple", ErrorCodes::TYPE_MISMATCH); @@ -2126,17 +2126,17 @@ void ExpressionAnalyzer::getActionsImpl(const ASTPtr & ast, bool no_subqueries, { ASTPtr child = node->arguments->children[i]; - ASTFunction * lambda = typeid_cast(child.get()); + auto * lambda = typeid_cast(child.get()); if (lambda && lambda->name == "lambda") { - const DataTypeFunction * lambda_type = typeid_cast(argument_types[i].get()); - ASTFunction * lambda_args_tuple = typeid_cast(lambda->arguments->children.at(0).get()); + const auto * lambda_type = typeid_cast(argument_types[i].get()); + auto * lambda_args_tuple = typeid_cast(lambda->arguments->children.at(0).get()); ASTs lambda_arg_asts = lambda_args_tuple->arguments->children; NamesAndTypesList lambda_arguments; for (size_t j = 0; j < lambda_arg_asts.size(); ++j) { - ASTIdentifier * identifier = typeid_cast(lambda_arg_asts[j].get()); + auto * identifier = typeid_cast(lambda_arg_asts[j].get()); if (!identifier) throw Exception("lambda argument declarations must be identifiers", ErrorCodes::TYPE_MISMATCH); @@ -2192,7 +2192,7 @@ void ExpressionAnalyzer::getActionsImpl(const ASTPtr & ast, bool no_subqueries, if (arguments_present) actions_stack.addAction(ExpressionAction::applyFunction(function_builder, argument_names, node->getColumnName())); } - else if (ASTLiteral * node = typeid_cast(ast.get())) + else if (auto * node = typeid_cast(ast.get())) { DataTypePtr type = applyVisitor(FieldToDataType(), node->value); @@ -2232,7 +2232,7 @@ void ExpressionAnalyzer::getAggregates(const ASTPtr & ast, ExpressionActionsPtr return; } - const ASTFunction * node = typeid_cast(ast.get()); + const auto * node = typeid_cast(ast.get()); if (node && AggregateFunctionFactory::instance().isAggregateFunctionName(node->name)) { has_aggregation = true; @@ -2276,7 +2276,7 @@ void ExpressionAnalyzer::getAggregates(const ASTPtr & ast, ExpressionActionsPtr void ExpressionAnalyzer::assertNoAggregates(const ASTPtr & ast, const char * description) { - const ASTFunction * node = typeid_cast(ast.get()); + const auto * node = typeid_cast(ast.get()); if (node && AggregateFunctionFactory::instance().isAggregateFunctionName(node->name)) throw Exception("Aggregate function " + node->getColumnName() @@ -2365,9 +2365,9 @@ bool ExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain, bool only_ty initChain(chain, source_columns); ExpressionActionsChain::Step & step = chain.steps.back(); - const ASTTablesInSelectQueryElement & join_element = static_cast(*select_query->join()); - const ASTTableJoin & join_params = static_cast(*join_element.table_join); - const ASTTableExpression & table_to_join = static_cast(*join_element.table_expression); + const auto & join_element = static_cast(*select_query->join()); + const auto & join_params = static_cast(*join_element.table_join); + const auto & table_to_join = static_cast(*join_element.table_expression); if (join_params.using_expression_list) getRootActions(join_params.using_expression_list, only_types, false, step.actions); @@ -2386,7 +2386,7 @@ bool ExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain, bool only_ty if (table) { - StorageJoin * storage_join = dynamic_cast(table.get()); + auto * storage_join = dynamic_cast(table.get()); if (storage_join) { @@ -2435,7 +2435,7 @@ bool ExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain, bool only_ty /// TODO You do not need to set this up when JOIN is only needed on remote servers. subquery_for_set.join = join; - subquery_for_set.join->setSampleBlock(subquery_for_set.source->getHeader()); + subquery_for_set.join->init(subquery_for_set.source->getHeader()); } addJoinAction(step.actions, false); @@ -2544,7 +2544,7 @@ bool ExpressionAnalyzer::appendOrderBy(ExpressionActionsChain & chain, bool only ASTs asts = select_query->order_expression_list->children; for (const auto & i : asts) { - ASTOrderByElement * ast = typeid_cast(i.get()); + auto * ast = typeid_cast(i.get()); if (!ast || ast->children.empty()) throw Exception("Bad order expression AST", ErrorCodes::UNKNOWN_TYPE_OF_AST_NODE); ASTPtr order_expression = ast->children.at(0); @@ -2598,7 +2598,7 @@ void ExpressionAnalyzer::appendProjectResult(ExpressionActionsChain & chain) con void ExpressionAnalyzer::getActionsBeforeAggregation(const ASTPtr & ast, ExpressionActionsPtr & actions, bool no_subqueries) { - ASTFunction * node = typeid_cast(ast.get()); + auto * node = typeid_cast(ast.get()); if (node && AggregateFunctionFactory::instance().isAggregateFunctionName(node->name)) for (auto & argument : node->arguments->children) @@ -2714,7 +2714,7 @@ void ExpressionAnalyzer::collectUsedColumns() NameSet required_joined_columns; getRequiredSourceColumnsImpl(ast, available_columns, required, ignored, available_joined_columns, required_joined_columns); - for (NamesAndTypesList::iterator it = columns_added_by_join.begin(); it != columns_added_by_join.end();) + for (auto it = columns_added_by_join.begin(); it != columns_added_by_join.end();) { if (required_joined_columns.count(it->name)) ++it; @@ -2737,7 +2737,7 @@ void ExpressionAnalyzer::collectUsedColumns() NameSet unknown_required_source_columns = required; - for (NamesAndTypesList::iterator it = source_columns.begin(); it != source_columns.end();) + for (auto it = source_columns.begin(); it != source_columns.end();) { unknown_required_source_columns.erase(it->name); @@ -2777,8 +2777,8 @@ void ExpressionAnalyzer::collectJoinedColumns(NameSet & joined_columns, NamesAnd if (!node) return; - const ASTTableJoin & table_join = static_cast(*node->table_join); - const ASTTableExpression & table_expression = static_cast(*node->table_expression); + const auto & table_join = static_cast(*node->table_join); + const auto & table_expression = static_cast(*node->table_expression); Block nested_result_sample; if (table_expression.database_and_table_name) @@ -2847,7 +2847,7 @@ void ExpressionAnalyzer::getRequiredSourceColumnsImpl(const ASTPtr & ast, * - we put identifiers available from JOIN in required_joined_columns. */ - if (ASTIdentifier * node = typeid_cast(ast.get())) + if (auto * node = typeid_cast(ast.get())) { if (node->kind == ASTIdentifier::Column && !ignored_names.count(node->name) @@ -2863,14 +2863,14 @@ void ExpressionAnalyzer::getRequiredSourceColumnsImpl(const ASTPtr & ast, return; } - if (ASTFunction * node = typeid_cast(ast.get())) + if (auto * node = typeid_cast(ast.get())) { if (node->name == "lambda") { if (node->arguments->children.size() != 2) throw Exception("lambda requires two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - ASTFunction * lambda_args_tuple = typeid_cast(node->arguments->children.at(0).get()); + auto * lambda_args_tuple = typeid_cast(node->arguments->children.at(0).get()); if (!lambda_args_tuple || lambda_args_tuple->name != "tuple") throw Exception("First argument of lambda must be a tuple", ErrorCodes::TYPE_MISMATCH); @@ -2879,7 +2879,7 @@ void ExpressionAnalyzer::getRequiredSourceColumnsImpl(const ASTPtr & ast, Names added_ignored; for (auto & child : lambda_args_tuple->arguments->children) { - ASTIdentifier * identifier = typeid_cast(child.get()); + auto * identifier = typeid_cast(child.get()); if (!identifier) throw Exception("lambda argument declarations must be identifiers", ErrorCodes::TYPE_MISMATCH); @@ -2926,7 +2926,7 @@ void ExpressionAnalyzer::getRequiredSourceColumnsImpl(const ASTPtr & ast, static bool hasArrayJoin(const ASTPtr & ast) { - if (const ASTFunction * function = typeid_cast(&*ast)) + if (const auto * function = typeid_cast(&*ast)) if (function->name == "arrayJoin") return true; diff --git a/dbms/src/Interpreters/Join.cpp b/dbms/src/Interpreters/Join.cpp index f1275d8e88e..820618a6e8b 100644 --- a/dbms/src/Interpreters/Join.cpp +++ b/dbms/src/Interpreters/Join.cpp @@ -26,10 +26,6 @@ #include #include #include -#include - -#include "executeQuery.h" - namespace DB { @@ -42,40 +38,67 @@ extern const int TYPE_MISMATCH; extern const int ILLEGAL_COLUMN; } // namespace ErrorCodes +namespace +{ /// Do I need to use the hash table maps_*_full, in which we remember whether the row was joined. -static bool getFullness(ASTTableJoin::Kind kind) +bool getFullness(ASTTableJoin::Kind kind) { return kind == ASTTableJoin::Kind::Right || kind == ASTTableJoin::Kind::Cross_Right || kind == ASTTableJoin::Kind::Full; } -static bool isLeftJoin(ASTTableJoin::Kind kind) +bool isLeftJoin(ASTTableJoin::Kind kind) { return kind == ASTTableJoin::Kind::Left || kind == ASTTableJoin::Kind::Cross_Left; } -static bool isRightJoin(ASTTableJoin::Kind kind) +bool isRightJoin(ASTTableJoin::Kind kind) { return kind == ASTTableJoin::Kind::Right || kind == ASTTableJoin::Kind::Cross_Right; } -static bool isInnerJoin(ASTTableJoin::Kind kind) +bool isInnerJoin(ASTTableJoin::Kind kind) { return kind == ASTTableJoin::Kind::Inner || kind == ASTTableJoin::Kind::Cross; } -static bool isAntiJoin(ASTTableJoin::Kind kind) +bool isAntiJoin(ASTTableJoin::Kind kind) { return kind == ASTTableJoin::Kind::Anti || kind == ASTTableJoin::Kind::Cross_Anti; } -static bool isCrossJoin(ASTTableJoin::Kind kind) +bool isCrossJoin(ASTTableJoin::Kind kind) { return kind == ASTTableJoin::Kind::Cross || kind == ASTTableJoin::Kind::Cross_Left || kind == ASTTableJoin::Kind::Cross_Right || kind == ASTTableJoin::Kind::Cross_Anti || kind == ASTTableJoin::Kind::Cross_LeftSemi || kind == ASTTableJoin::Kind::Cross_LeftAnti; } /// (cartesian) (anti) left semi join. -static bool isLeftSemiFamily(ASTTableJoin::Kind kind) +bool isLeftSemiFamily(ASTTableJoin::Kind kind) { return kind == ASTTableJoin::Kind::LeftSemi || kind == ASTTableJoin::Kind::LeftAnti || kind == ASTTableJoin::Kind::Cross_LeftSemi || kind == ASTTableJoin::Kind::Cross_LeftAnti; } +void convertColumnToNullable(ColumnWithTypeAndName & column) +{ + column.type = makeNullable(column.type); + if (column.column) + column.column = makeNullable(column.column); +} + +ColumnRawPtrs getKeyColumns(const Names & key_names, const Block & block) +{ + size_t keys_size = key_names.size(); + ColumnRawPtrs key_columns(keys_size); + + for (size_t i = 0; i < keys_size; ++i) + { + key_columns[i] = block.getByName(key_names[i]).column.get(); + + /// We will join only keys, where all components are not NULL. + if (key_columns[i]->isColumnNullable()) + key_columns[i] = &static_cast(*key_columns[i]).getNestedColumn(); + } + + return key_columns; +} +} // namespace + const std::string Join::match_helper_prefix = "__left-semi-join-match-helper"; const DataTypePtr Join::match_helper_type = makeNullable(std::make_shared()); @@ -88,7 +111,6 @@ Join::Join( ASTTableJoin::Kind kind_, ASTTableJoin::Strictness strictness_, const String & req_id, - size_t build_concurrency_, const TiDB::TiDBCollators & collators_, const String & left_filter_column_, const String & right_filter_column_, @@ -103,7 +125,8 @@ Join::Join( , key_names_left(key_names_left_) , key_names_right(key_names_right_) , use_nulls(use_nulls_) - , build_concurrency(std::max(1, build_concurrency_)) + , build_concurrency(0) + , build_set_exceeded(false) , collators(collators_) , left_filter_column(left_filter_column_) , right_filter_column(right_filter_column_) @@ -116,9 +139,6 @@ Join::Join( , log(Logger::get("Join", req_id)) , limits(limits) { - build_set_exceeded.store(false); - for (size_t i = 0; i < build_concurrency; i++) - pools.emplace_back(std::make_shared()); if (other_condition_ptr != nullptr) { /// if there is other_condition, then should keep all the valid rows during probe stage @@ -127,14 +147,9 @@ Join::Join( strictness = ASTTableJoin::Strictness::All; } } - if (getFullness(kind)) - { - for (size_t i = 0; i < build_concurrency; i++) - rows_not_inserted_to_map.push_back(std::make_unique()); - } - if (!left_filter_column.empty() && !isLeftJoin(kind)) + if (unlikely(!left_filter_column.empty() && !isLeftJoin(kind))) throw Exception("Not supported: non left join with left conditions"); - if (!right_filter_column.empty() && !isRightJoin(kind)) + if (unlikely(!right_filter_column.empty() && !isRightJoin(kind))) throw Exception("Not supported: non right join with right conditions"); } @@ -328,7 +343,7 @@ struct KeyGetterForType using Type = typename KeyGetterForTypeImpl::Type; }; -void Join::init(Type type_) +void Join::initMapImpl(Type type_) { type = type_; @@ -338,16 +353,16 @@ void Join::init(Type type_) if (!getFullness(kind)) { if (strictness == ASTTableJoin::Strictness::Any) - initImpl(maps_any, type, build_concurrency); + initImpl(maps_any, type, getBuildConcurrencyInternal()); else - initImpl(maps_all, type, build_concurrency); + initImpl(maps_all, type, getBuildConcurrencyInternal()); } else { if (strictness == ASTTableJoin::Strictness::Any) - initImpl(maps_any_full, type, build_concurrency); + initImpl(maps_any_full, type, getBuildConcurrencyInternal()); else - initImpl(maps_all_full, type, build_concurrency); + initImpl(maps_all_full, type, getBuildConcurrencyInternal()); } } @@ -396,37 +411,24 @@ size_t Join::getTotalByteCount() const return res; } - -static void convertColumnToNullable(ColumnWithTypeAndName & column) +void Join::setBuildConcurrencyAndInitPool(size_t build_concurrency_) { - column.type = makeNullable(column.type); - if (column.column) - column.column = makeNullable(column.column); -} + if (unlikely(build_concurrency > 0)) + throw Exception("Logical error: `setBuildConcurrencyAndInitPool` shouldn't be called more than once", ErrorCodes::LOGICAL_ERROR); + build_concurrency = std::max(1, build_concurrency_); - -void Join::setSampleBlock(const Block & block) -{ - std::unique_lock lock(rwlock); - - if (!empty()) - return; - - size_t keys_size = key_names_right.size(); - ColumnRawPtrs key_columns(keys_size); - - for (size_t i = 0; i < keys_size; ++i) + for (size_t i = 0; i < getBuildConcurrencyInternal(); ++i) + pools.emplace_back(std::make_shared()); + // init for non-joined-streams. + if (getFullness(kind)) { - key_columns[i] = block.getByName(key_names_right[i]).column.get(); - - /// We will join only keys, where all components are not NULL. - if (key_columns[i]->isColumnNullable()) - key_columns[i] = &static_cast(*key_columns[i]).getNestedColumn(); + for (size_t i = 0; i < getNotJoinedStreamConcurrencyInternal(); ++i) + rows_not_inserted_to_map.push_back(std::make_unique()); } +} - /// Choose data structure to use for JOIN. - init(chooseMethod(key_columns, key_sizes)); - +void Join::setSampleBlock(const Block & block) +{ sample_block_with_columns_to_add = materializeBlock(block); /// Move from `sample_block_with_columns_to_add` key columns to `sample_block_with_keys`, keeping the order. @@ -461,6 +463,18 @@ void Join::setSampleBlock(const Block & block) sample_block_with_columns_to_add.insert(ColumnWithTypeAndName(Join::match_helper_type, match_helper_name)); } +void Join::init(const Block & sample_block, size_t build_concurrency_) +{ + std::unique_lock lock(rwlock); + if (unlikely(initialized)) + throw Exception("Logical error: Join has been initialized", ErrorCodes::LOGICAL_ERROR); + initialized = true; + setBuildConcurrencyAndInitPool(build_concurrency_); + /// Choose data structure to use for JOIN. + initMapImpl(chooseMethod(getKeyColumns(key_names_right, sample_block), key_sizes)); + setSampleBlock(sample_block); +} + namespace { @@ -725,7 +739,7 @@ void recordFilteredRows(const Block & block, const String & filter_column, Colum column = column->convertToFullColumnIfConst(); if (column->isColumnNullable()) { - const ColumnNullable & column_nullable = static_cast(*column); + const auto & column_nullable = static_cast(*column); if (!null_map_holder) { null_map_holder = column_nullable.getNullMapColumnPtr(); @@ -761,9 +775,9 @@ void recordFilteredRows(const Block & block, const String & filter_column, Colum bool Join::insertFromBlock(const Block & block) { - if (empty()) - throw Exception("Logical error: Join was not initialized", ErrorCodes::LOGICAL_ERROR); std::unique_lock lock(rwlock); + if (unlikely(!initialized)) + throw Exception("Logical error: Join was not initialized", ErrorCodes::LOGICAL_ERROR); blocks.push_back(block); Block * stored_block = &blocks.back(); return insertFromBlockInternal(stored_block, 0); @@ -772,11 +786,12 @@ bool Join::insertFromBlock(const Block & block) /// the block should be valid. void Join::insertFromBlock(const Block & block, size_t stream_index) { - assert(stream_index < build_concurrency); + std::shared_lock lock(rwlock); + assert(stream_index < getBuildConcurrencyInternal()); + assert(stream_index < getNotJoinedStreamConcurrencyInternal()); - if (empty()) + if (unlikely(!initialized)) throw Exception("Logical error: Join was not initialized", ErrorCodes::LOGICAL_ERROR); - std::shared_lock lock(rwlock); Block * stored_block = nullptr; { std::lock_guard lk(blocks_lock); @@ -872,16 +887,16 @@ bool Join::insertFromBlockInternal(Block * stored_block, size_t stream_index) if (!getFullness(kind)) { if (strictness == ASTTableJoin::Strictness::Any) - insertFromBlockImpl(type, maps_any, rows, key_columns, key_sizes, collators, stored_block, null_map, nullptr, stream_index, build_concurrency, *pools[stream_index]); + insertFromBlockImpl(type, maps_any, rows, key_columns, key_sizes, collators, stored_block, null_map, nullptr, stream_index, getBuildConcurrencyInternal(), *pools[stream_index]); else - insertFromBlockImpl(type, maps_all, rows, key_columns, key_sizes, collators, stored_block, null_map, nullptr, stream_index, build_concurrency, *pools[stream_index]); + insertFromBlockImpl(type, maps_all, rows, key_columns, key_sizes, collators, stored_block, null_map, nullptr, stream_index, getBuildConcurrencyInternal(), *pools[stream_index]); } else { if (strictness == ASTTableJoin::Strictness::Any) - insertFromBlockImpl(type, maps_any_full, rows, key_columns, key_sizes, collators, stored_block, null_map, rows_not_inserted_to_map[stream_index].get(), stream_index, build_concurrency, *pools[stream_index]); + insertFromBlockImpl(type, maps_any_full, rows, key_columns, key_sizes, collators, stored_block, null_map, rows_not_inserted_to_map[stream_index].get(), stream_index, getBuildConcurrencyInternal(), *pools[stream_index]); else - insertFromBlockImpl(type, maps_all_full, rows, key_columns, key_sizes, collators, stored_block, null_map, rows_not_inserted_to_map[stream_index].get(), stream_index, build_concurrency, *pools[stream_index]); + insertFromBlockImpl(type, maps_all_full, rows, key_columns, key_sizes, collators, stored_block, null_map, rows_not_inserted_to_map[stream_index].get(), stream_index, getBuildConcurrencyInternal(), *pools[stream_index]); } } @@ -1958,7 +1973,8 @@ class NonJoinedBlockInputStream : public IProfilingBlockInputStream , max_block_size(max_block_size_) , add_not_mapped_rows(true) { - if (step > parent.build_concurrency || index >= parent.build_concurrency) + size_t build_concurrency = parent.getBuildConcurrency(); + if (unlikely(step > build_concurrency || index >= build_concurrency)) throw Exception("The concurrency of NonJoinedBlockInputStream should not be larger than join build concurrency"); /** left_sample_block contains keys and "left" columns. @@ -2048,7 +2064,7 @@ class NonJoinedBlockInputStream : public IProfilingBlockInputStream MutableColumns columns_right; std::unique_ptr> position; /// type erasure - size_t current_segment; + size_t current_segment = 0; Join::RowRefList * current_not_mapped_row = nullptr; void setNextCurrentNotMappedRow() diff --git a/dbms/src/Interpreters/Join.h b/dbms/src/Interpreters/Join.h index 89dad0d1ca6..01916aa1dcc 100644 --- a/dbms/src/Interpreters/Join.h +++ b/dbms/src/Interpreters/Join.h @@ -99,7 +99,6 @@ class Join ASTTableJoin::Kind kind_, ASTTableJoin::Strictness strictness_, const String & req_id, - size_t build_concurrency = 1, const TiDB::TiDBCollators & collators_ = TiDB::dummy_collators, const String & left_filter_column = "", const String & right_filter_column = "", @@ -109,17 +108,10 @@ class Join size_t max_block_size = 0, const String & match_helper_name = ""); - bool empty() { return type == Type::EMPTY; } - - /** Set information about structure of right hand of JOIN (joined data). + /** Call `setBuildConcurrencyAndInitPool`, `initMapImpl` and `setSampleBlock`. * You must call this method before subsequent calls to insertFromBlock. */ - void setSampleBlock(const Block & block); - - /** Add block of data from right hand of JOIN to the map. - * Returns false, if some limit was exceeded and you should not insert more data. - */ - bool insertFromBlockInternal(Block * stored_block, size_t stream_index); + void init(const Block & sample_block, size_t build_concurrency_ = 1); bool insertFromBlock(const Block & block); @@ -153,9 +145,19 @@ class Join bool useNulls() const { return use_nulls; } const Names & getLeftJoinKeys() const { return key_names_left; } - size_t getBuildConcurrency() const { return build_concurrency; } + + size_t getBuildConcurrency() const + { + std::shared_lock lock(rwlock); + return getBuildConcurrencyInternal(); + } + size_t getNotJoinedStreamConcurrency() const + { + std::shared_lock lock(rwlock); + return getNotJoinedStreamConcurrencyInternal(); + } + bool isBuildSetExceeded() const { return build_set_exceeded.load(); } - size_t getNotJoinedStreamConcurrency() const { return build_concurrency; }; enum BuildTableState { @@ -171,7 +173,7 @@ class Join const Block * block; size_t row_num; - RowRef() {} + RowRef() = default; RowRef(const Block * block_, size_t row_num_) : block(block_) , row_num(row_num_) @@ -183,7 +185,7 @@ class Join { RowRefList * next = nullptr; - RowRefList() {} + RowRefList() = default; RowRefList(const Block * block_, size_t row_num_) : RowRef(block_, row_num_) {} @@ -342,11 +344,40 @@ class Join */ mutable std::shared_mutex rwlock; - void init(Type type_); + bool initialized = false; + + size_t getBuildConcurrencyInternal() const + { + if (unlikely(build_concurrency == 0)) + throw Exception("Logical error: `setBuildConcurrencyAndInitPool` has not been called", ErrorCodes::LOGICAL_ERROR); + return build_concurrency; + } + size_t getNotJoinedStreamConcurrencyInternal() const + { + return getBuildConcurrencyInternal(); + } + + /// Initialize map implementations for various join types. + void initMapImpl(Type type_); + + /** Set information about structure of right hand of JOIN (joined data). + * You must call this method before subsequent calls to insertFromBlock. + */ + void setSampleBlock(const Block & block); + + /** Set Join build concurrency and init hash map. + * You must call this method before subsequent calls to insertFromBlock. + */ + void setBuildConcurrencyAndInitPool(size_t build_concurrency_); /// Throw an exception if blocks have different types of key columns. void checkTypesOfKeys(const Block & block_left, const Block & block_right) const; + /** Add block of data from right hand of JOIN to the map. + * Returns false, if some limit was exceeded and you should not insert more data. + */ + bool insertFromBlockInternal(Block * stored_block, size_t stream_index); + template void joinBlockImpl(Block & block, const Maps & maps) const; diff --git a/dbms/src/Storages/Page/V3/PageDirectoryFactory.cpp b/dbms/src/Storages/Page/V3/PageDirectoryFactory.cpp index 0592d1ddaa8..9d20e0a64ab 100644 --- a/dbms/src/Storages/Page/V3/PageDirectoryFactory.cpp +++ b/dbms/src/Storages/Page/V3/PageDirectoryFactory.cpp @@ -113,21 +113,13 @@ void PageDirectoryFactory::loadEdit(const PageDirectoryPtr & dir, const PageEntr if (max_applied_ver < r.version) max_applied_ver = r.version; - // We can not avoid page id from being reused under some corner situation. Try to do gcInMemEntries - // and apply again to resolve the error. - if (bool ok = applyRecord(dir, r, /*throw_on_error*/ false); unlikely(!ok)) - { - dir->gcInMemEntries(); - applyRecord(dir, r, /*throw_on_error*/ true); - LOG_FMT_INFO(DB::Logger::get("PageDirectoryFactory"), "resolve from error status done, continue to restore"); - } + applyRecord(dir, r); } } -bool PageDirectoryFactory::applyRecord( +void PageDirectoryFactory::applyRecord( const PageDirectoryPtr & dir, - const PageEntriesEdit::EditRecord & r, - bool throw_on_error) + const PageEntriesEdit::EditRecord & r) { auto [iter, created] = dir->mvcc_table_directory.insert(std::make_pair(r.page_id, nullptr)); if (created) @@ -189,14 +181,8 @@ bool PageDirectoryFactory::applyRecord( catch (DB::Exception & e) { e.addMessage(fmt::format(" [type={}] [page_id={}] [ver={}]", r.type, r.page_id, restored_version)); - if (throw_on_error || e.code() != ErrorCodes::PS_DIR_APPLY_INVALID_STATUS) - { - throw e; - } - LOG_FMT_WARNING(DB::Logger::get("PageDirectoryFactory"), "try to resolve error during restore: {}", e.message()); - return false; + throw e; } - return true; } void PageDirectoryFactory::loadFromDisk(const PageDirectoryPtr & dir, WALStoreReaderPtr && reader) diff --git a/dbms/src/Storages/Page/V3/PageDirectoryFactory.h b/dbms/src/Storages/Page/V3/PageDirectoryFactory.h index e4b76bfba0d..185e8fd19a5 100644 --- a/dbms/src/Storages/Page/V3/PageDirectoryFactory.h +++ b/dbms/src/Storages/Page/V3/PageDirectoryFactory.h @@ -60,10 +60,9 @@ class PageDirectoryFactory private: void loadFromDisk(const PageDirectoryPtr & dir, WALStoreReaderPtr && reader); void loadEdit(const PageDirectoryPtr & dir, const PageEntriesEdit & edit); - static bool applyRecord( + static void applyRecord( const PageDirectoryPtr & dir, - const PageEntriesEdit::EditRecord & r, - bool throw_on_error); + const PageEntriesEdit::EditRecord & r); BlobStore::BlobStats * blob_stats = nullptr; }; diff --git a/dbms/src/Storages/Page/V3/tests/gtest_page_directory.cpp b/dbms/src/Storages/Page/V3/tests/gtest_page_directory.cpp index dfa33824473..151b3b50657 100644 --- a/dbms/src/Storages/Page/V3/tests/gtest_page_directory.cpp +++ b/dbms/src/Storages/Page/V3/tests/gtest_page_directory.cpp @@ -2214,191 +2214,6 @@ try } CATCH -TEST_F(PageDirectoryGCTest, RestoreWithDuplicateID) -try -{ - auto restore_from_edit = [](const PageEntriesEdit & edit) { - auto ctx = ::DB::tests::TiFlashTestEnv::getContext(); - auto provider = ctx.getFileProvider(); - auto path = getTemporaryPath(); - PSDiskDelegatorPtr delegator = std::make_shared(path); - PageDirectoryFactory factory; - auto d = factory.createFromEdit(getCurrentTestName(), provider, delegator, edit); - return d; - }; - - const PageId target_id = 100; - // ========= 1 =======// - // Reuse same id: PUT_EXT/DEL/REF - { - PageEntryV3 entry_50{.file_id = 1, .size = 7890, .tag = 0, .offset = 0x123, .checksum = 0x4567}; - PageEntriesEdit edit; - edit.put(50, entry_50); - edit.putExternal(target_id); - edit.del(target_id); - // restart and reuse id=100 as ref to replace put_ext - edit.ref(target_id, 50); - - auto restored_dir = restore_from_edit(edit); - auto snap = restored_dir->createSnapshot(); - ASSERT_EQ(restored_dir->getNormalPageId(target_id, snap).low, 50); - } - // Reuse same id: PUT_EXT/DEL/PUT - { - PageEntryV3 entry_50{.file_id = 1, .size = 7890, .tag = 0, .offset = 0x123, .checksum = 0x4567}; - PageEntryV3 entry_100{.file_id = 100, .size = 7890, .tag = 0, .offset = 0x123, .checksum = 0x4567}; - PageEntriesEdit edit; - edit.put(50, entry_50); - edit.putExternal(target_id); - edit.del(target_id); - // restart and reuse id=100 as put to replace put_ext - edit.put(target_id, entry_100); - - auto restored_dir = restore_from_edit(edit); - auto snap = restored_dir->createSnapshot(); - ASSERT_SAME_ENTRY(restored_dir->get(target_id, snap).second, entry_100); - } - - // ========= 1-invalid =======// - // Reuse same id: PUT_EXT/BEING REF/DEL/REF - { - PageEntryV3 entry_50{.file_id = 1, .size = 7890, .tag = 0, .offset = 0x123, .checksum = 0x4567}; - PageEntriesEdit edit; - edit.put(50, entry_50); - edit.putExternal(target_id); - edit.ref(101, target_id); - edit.del(target_id); - // restart and reuse id=100 as ref. Should not happen because 101 still ref to 100 - edit.ref(target_id, 50); - - ASSERT_THROW(restore_from_edit(edit);, DB::Exception); - } - // Reuse same id: PUT_EXT/BEING REF/DEL/PUT - { - PageEntryV3 entry_50{.file_id = 1, .size = 7890, .tag = 0, .offset = 0x123, .checksum = 0x4567}; - PageEntryV3 entry_100{.file_id = 100, .size = 7890, .tag = 0, .offset = 0x123, .checksum = 0x4567}; - PageEntriesEdit edit; - edit.put(50, entry_50); - edit.putExternal(target_id); - edit.ref(101, target_id); - edit.del(target_id); - // restart and reuse id=100 as put. Should not happen because 101 still ref to 100 - edit.put(target_id, entry_100); - - ASSERT_THROW(restore_from_edit(edit);, DB::Exception); - } - - // ========= 2 =======// - // Reuse same id: PUT/DEL/REF - { - PageEntryV3 entry_50{.file_id = 1, .size = 7890, .tag = 0, .offset = 0x123, .checksum = 0x4567}; - PageEntriesEdit edit; - edit.put(50, entry_50); - edit.put(target_id, entry_50); - edit.del(target_id); - // restart and reuse id=100 as ref to replace put - edit.ref(target_id, 50); - - auto restored_dir = restore_from_edit(edit); - auto snap = restored_dir->createSnapshot(); - ASSERT_EQ(restored_dir->getNormalPageId(target_id, snap).low, 50); - } - // Reuse same id: PUT/DEL/PUT_EXT - { - PageEntryV3 entry_50{.file_id = 1, .size = 7890, .tag = 0, .offset = 0x123, .checksum = 0x4567}; - PageEntriesEdit edit; - edit.put(50, entry_50); - edit.put(target_id, entry_50); - edit.del(target_id); - // restart and reuse id=100 as external to replace put - edit.putExternal(target_id); - - auto restored_dir = restore_from_edit(edit); - auto snap = restored_dir->createSnapshot(); - auto ext_ids = restored_dir->getAliveExternalIds(TEST_NAMESPACE_ID); - ASSERT_EQ(ext_ids.size(), 1); - ASSERT_EQ(*ext_ids.begin(), target_id); - } - - // ========= 2-invalid =======// - // Reuse same id: PUT/BEING REF/DEL/REF - { - PageEntryV3 entry_50{.file_id = 1, .size = 7890, .tag = 0, .offset = 0x123, .checksum = 0x4567}; - PageEntriesEdit edit; - edit.put(50, entry_50); - edit.put(target_id, entry_50); - edit.ref(101, target_id); - edit.del(target_id); - // restart and reuse id=100 as ref to replace put - edit.ref(target_id, 50); - - ASSERT_THROW(restore_from_edit(edit);, DB::Exception); - } - // Reuse same id: PUT/BEING REF/DEL/PUT_EXT - { - PageEntryV3 entry_50{.file_id = 1, .size = 7890, .tag = 0, .offset = 0x123, .checksum = 0x4567}; - PageEntriesEdit edit; - edit.put(50, entry_50); - edit.put(target_id, entry_50); - edit.ref(101, target_id); - edit.del(target_id); - // restart and reuse id=100 as external to replace put - edit.putExternal(target_id); - - ASSERT_THROW(restore_from_edit(edit);, DB::Exception); - } - - // ========= 3 =======// - // Reuse same id: REF/DEL/PUT - { - PageEntryV3 entry_50{.file_id = 1, .size = 7890, .tag = 0, .offset = 0x123, .checksum = 0x4567}; - PageEntryV3 entry_100{.file_id = 100, .size = 7890, .tag = 0, .offset = 0x123, .checksum = 0x4567}; - PageEntriesEdit edit; - edit.put(50, entry_50); - edit.ref(target_id, 50); - edit.del(target_id); - // restart and reuse id=100 as put to replace ref - edit.put(target_id, entry_100); - - auto restored_dir = restore_from_edit(edit); - auto snap = restored_dir->createSnapshot(); - ASSERT_SAME_ENTRY(restored_dir->get(target_id, snap).second, entry_100); - } - // Reuse same id: REF/DEL/PUT_EXT - { - PageEntryV3 entry_50{.file_id = 1, .size = 7890, .tag = 0, .offset = 0x123, .checksum = 0x4567}; - PageEntriesEdit edit; - edit.put(50, entry_50); - edit.ref(target_id, 50); - edit.del(target_id); - // restart and reuse id=100 as external to replace ref - edit.putExternal(target_id); - - auto restored_dir = restore_from_edit(edit); - auto snap = restored_dir->createSnapshot(); - auto ext_ids = restored_dir->getAliveExternalIds(TEST_NAMESPACE_ID); - ASSERT_EQ(ext_ids.size(), 1); - ASSERT_EQ(*ext_ids.begin(), target_id); - } - // Reuse same id: REF/DEL/REF another id - { - PageEntryV3 entry_50{.file_id = 1, .size = 7890, .tag = 0, .offset = 0x123, .checksum = 0x4567}; - PageEntryV3 entry_51{.file_id = 2, .size = 7890, .tag = 0, .offset = 0x123, .checksum = 0x4567}; - PageEntriesEdit edit; - edit.put(50, entry_50); - edit.put(51, entry_51); - edit.ref(target_id, 50); - edit.del(target_id); - // restart and reuse id=target_id as external to replace put - edit.ref(target_id, 51); - - auto restored_dir = restore_from_edit(edit); - auto snap = restored_dir->createSnapshot(); - ASSERT_EQ(restored_dir->getNormalPageId(target_id, snap).low, 51); - } -} -CATCH - #undef INSERT_ENTRY_TO #undef INSERT_ENTRY #undef INSERT_ENTRY_ACQ_SNAP diff --git a/dbms/src/Storages/StorageJoin.cpp b/dbms/src/Storages/StorageJoin.cpp index 47907b3e94e..4ca3e79a7ab 100644 --- a/dbms/src/Storages/StorageJoin.cpp +++ b/dbms/src/Storages/StorageJoin.cpp @@ -52,7 +52,7 @@ StorageJoin::StorageJoin( /// NOTE StorageJoin doesn't use join_use_nulls setting. join = std::make_shared(key_names, key_names, false /* use_nulls */, SizeLimits(), kind, strictness, /*req_id=*/""); - join->setSampleBlock(getSampleBlock().sortColumns()); + join->init(getSampleBlock().sortColumns()); restore(); } @@ -87,7 +87,7 @@ void registerStorageJoin(StorageFactory & factory) "Storage Join requires at least 3 parameters: Join(ANY|ALL, LEFT|INNER, keys...).", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - const ASTIdentifier * strictness_id = typeid_cast(engine_args[0].get()); + const auto * strictness_id = typeid_cast(engine_args[0].get()); if (!strictness_id) throw Exception("First parameter of storage Join must be ANY or ALL (without quotes).", ErrorCodes::BAD_ARGUMENTS); @@ -100,7 +100,7 @@ void registerStorageJoin(StorageFactory & factory) else throw Exception("First parameter of storage Join must be ANY or ALL (without quotes).", ErrorCodes::BAD_ARGUMENTS); - const ASTIdentifier * kind_id = typeid_cast(engine_args[1].get()); + const auto * kind_id = typeid_cast(engine_args[1].get()); if (!kind_id) throw Exception("Second parameter of storage Join must be LEFT or INNER (without quotes).", ErrorCodes::BAD_ARGUMENTS); @@ -121,7 +121,7 @@ void registerStorageJoin(StorageFactory & factory) key_names.reserve(engine_args.size() - 2); for (size_t i = 2, size = engine_args.size(); i < size; ++i) { - const ASTIdentifier * key = typeid_cast(engine_args[i].get()); + const auto * key = typeid_cast(engine_args[i].get()); if (!key) throw Exception("Parameter №" + toString(i + 1) + " of storage Join don't look like column name.", ErrorCodes::BAD_ARGUMENTS); diff --git a/tests/fullstack-test/mpp/misc_join.test b/tests/fullstack-test/mpp/misc_join.test new file mode 100644 index 00000000000..61a1de49925 --- /dev/null +++ b/tests/fullstack-test/mpp/misc_join.test @@ -0,0 +1,41 @@ +# Copyright 2022 PingCAP, Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Preparation. +mysql> drop table if exists test.t1; +mysql> create table test.t1 (id decimal(5,2), value bigint(20)); +mysql> insert into test.t1 values(1, 1),(2, 2); +mysql> drop table if exists test.t2; +mysql> create table test.t2 (id decimal(5,2), value bigint(20)); +mysql> insert into test.t2 values(1, 1),(2, 2),(3, 3),(4, 4); + +mysql> alter table test.t1 set tiflash replica 1 +mysql> alter table test.t2 set tiflash replica 1 +mysql> analyze table test.t1 +mysql> analyze table test.t2 + +func> wait_table test t1 +func> wait_table test t2 + +mysql> use test; set tidb_allow_mpp=1; set tidb_enforce_mpp=1; set tidb_isolation_read_engines='tiflash'; select * from t1 left join t2 on cast(t1.id as decimal(7,2)) = cast(t2.id as decimal(7,2)) and t1.id + cast(t2.id as decimal(7,2)) + t1.id > 10; ++------+-------+------+-------+ +| id | value | id | value | ++------+-------+------+-------+ +| 1.00 | 1 | NULL | NULL | +| 2.00 | 2 | NULL | NULL | ++------+-------+------+-------+ + +# Clean up. +mysql> drop table if exists test.t1 +mysql> drop table if exists test.t2