diff --git a/executor/executor.go b/executor/executor.go
index 4c2bab5b644f1..c477f25c435d8 100644
--- a/executor/executor.go
+++ b/executor/executor.go
@@ -2083,6 +2083,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
 	if explainStmt, ok := s.(*ast.ExplainStmt); ok {
 		sc.InExplainStmt = true
 		sc.ExplainFormat = explainStmt.Format
+		sc.InExplainAnalyzeStmt = explainStmt.Analyze
 		sc.IgnoreExplainIDSuffix = strings.ToLower(explainStmt.Format) == types.ExplainFormatBrief
 		sc.InVerboseExplain = strings.ToLower(explainStmt.Format) == types.ExplainFormatVerbose
 		s = explainStmt.Stmt
@@ -2091,6 +2092,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
 	}
 	if explainForStmt, ok := s.(*ast.ExplainForStmt); ok {
 		sc.InExplainStmt = true
+		sc.InExplainAnalyzeStmt = true
 		sc.InVerboseExplain = strings.ToLower(explainForStmt.Format) == types.ExplainFormatVerbose
 	}
 
diff --git a/expression/expression.go b/expression/expression.go
index eea3c851db9a5..39afe4a1a914b 100644
--- a/expression/expression.go
+++ b/expression/expression.go
@@ -49,6 +49,7 @@ const (
 	columnFlag         byte = 1
 	scalarFunctionFlag byte = 3
 	parameterFlag      byte = 4
+	ScalarSubQFlag     byte = 5
 )
 
 // EvalAstExpr evaluates ast expression directly.
diff --git a/planner/core/BUILD.bazel b/planner/core/BUILD.bazel
index 527b86baa42c5..e4f110dd6d0d6 100644
--- a/planner/core/BUILD.bazel
+++ b/planner/core/BUILD.bazel
@@ -70,6 +70,7 @@ go_library(
         "rule_topn_push_down.go",
         "runtime_filter.go",
         "runtime_filter_generator.go",
+        "scalar_subq_expression.go",
         "show_predicate_extractor.go",
         "stats.go",
         "stringer.go",
diff --git a/planner/core/casetest/scalarsubquery/BUILD.bazel b/planner/core/casetest/scalarsubquery/BUILD.bazel
new file mode 100644
index 0000000000000..2ae01097534c6
--- /dev/null
+++ b/planner/core/casetest/scalarsubquery/BUILD.bazel
@@ -0,0 +1,20 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
+go_test(
+    name = "scalarsubquery_test",
+    timeout = "short",
+    srcs = [
+        "cases_test.go",
+        "main_test.go",
+    ],
+    data = glob(["testdata/**"]),
+    flaky = True,
+    deps = [
+        "//testkit",
+        "//testkit/testdata",
+        "//testkit/testmain",
+        "//testkit/testsetup",
+        "@com_github_stretchr_testify//require",
+        "@org_uber_go_goleak//:goleak",
+    ],
+)
diff --git a/planner/core/casetest/scalarsubquery/cases_test.go b/planner/core/casetest/scalarsubquery/cases_test.go
new file mode 100644
index 0000000000000..5fd6a3fceeb79
--- /dev/null
+++ b/planner/core/casetest/scalarsubquery/cases_test.go
@@ -0,0 +1,92 @@
+// Copyright 2023 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package scalarsubquery
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/pingcap/tidb/testkit"
+	"github.com/pingcap/tidb/testkit/testdata"
+	"github.com/stretchr/testify/require"
+)
+
+func TestExplainNonEvaledSubquery(t *testing.T) {
+	var (
+		input []struct {
+			SQL              string
+			IsExplainAnalyze bool
+			HasErr           bool
+		}
+		output []struct {
+			SQL   string
+			Plan  []string
+			Error string
+		}
+	)
+	planSuiteData := GetPlanSuiteData()
+	planSuiteData.LoadTestCases(t, &input, &output)
+	store := testkit.CreateMockStore(t)
+	tk := testkit.NewTestKit(t, store)
+
+	tk.MustExec("use test")
+	tk.MustExec("create table t1(a int, b int, c int)")
+	tk.MustExec("create table t2(a int, b int, c int)")
+	tk.MustExec("create table t3(a varchar(5), b varchar(5), c varchar(5))")
+	tk.MustExec("set @@tidb_opt_enable_non_eval_scalar_subquery=true")
+
+	cutExecutionInfoFromExplainAnalyzeOutput := func(rows [][]interface{}) [][]interface{} {
+		// The columns are id, estRows, actRows, task type, access object, execution info, operator info, memory, disk
+		// We need to cut the unstable output of execution info, memory and disk.
+		for i := range rows {
+			rows[i] = rows[i][:6] // cut the final memory and disk.
+			rows[i] = append(rows[i][:5], rows[i][6:]...)
+		}
+		return rows
+	}
+
+	for i, ts := range input {
+		testdata.OnRecord(func() {
+			output[i].SQL = ts.SQL
+			if ts.HasErr {
+				err := tk.ExecToErr(ts.SQL)
+				require.NotNil(t, err, fmt.Sprintf("Failed at case #%d", i))
+				output[i].Error = err.Error()
+				output[i].Plan = nil
+			} else {
+				rows := tk.MustQuery(ts.SQL).Rows()
+				if ts.IsExplainAnalyze {
+					rows = cutExecutionInfoFromExplainAnalyzeOutput(rows)
+				}
+				output[i].Plan = testdata.ConvertRowsToStrings(rows)
+				output[i].Error = ""
+			}
+		})
+		if ts.HasErr {
+			err := tk.ExecToErr(ts.SQL)
+			require.NotNil(t, err, fmt.Sprintf("Failed at case #%d", i))
+		} else {
+			rows := tk.MustQuery(ts.SQL).Rows()
+			if ts.IsExplainAnalyze {
+				rows = cutExecutionInfoFromExplainAnalyzeOutput(rows)
+			}
+			require.Equal(t,
+				testdata.ConvertRowsToStrings(testkit.Rows(output[i].Plan...)),
+				testdata.ConvertRowsToStrings(rows),
+				fmt.Sprintf("Failed at case #%d, SQL: %v", i, ts.SQL),
+			)
+		}
+	}
+}
diff --git a/planner/core/casetest/scalarsubquery/main_test.go b/planner/core/casetest/scalarsubquery/main_test.go
new file mode 100644
index 0000000000000..c790cd0b2264f
--- /dev/null
+++ b/planner/core/casetest/scalarsubquery/main_test.go
@@ -0,0 +1,52 @@
+// Copyright 2023 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package scalarsubquery
+
+import (
+	"flag"
+	"testing"
+
+	"github.com/pingcap/tidb/testkit/testdata"
+	"github.com/pingcap/tidb/testkit/testmain"
+	"github.com/pingcap/tidb/testkit/testsetup"
+	"go.uber.org/goleak"
+)
+
+var testDataMap = make(testdata.BookKeeper)
+
+func TestMain(m *testing.M) {
+	testsetup.SetupForCommonTest()
+	flag.Parse()
+	testDataMap.LoadTestSuiteData("testdata", "plan_suite")
+	opts := []goleak.Option{
+		goleak.IgnoreTopFunction("github.com/golang/glog.(*fileSink).flushDaemon"),
+		goleak.IgnoreTopFunction("github.com/lestrrat-go/httprc.runFetchWorker"),
+		goleak.IgnoreTopFunction("go.etcd.io/etcd/client/pkg/v3/logutil.(*MergeLogger).outputLoop"),
+		goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).millRun"),
+		goleak.IgnoreTopFunction("github.com/tikv/client-go/v2/txnkv/transaction.keepAlive"),
+		goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
+	}
+
+	callback := func(i int) int {
+		testDataMap.GenerateOutputIfNeeded()
+		return i
+	}
+
+	goleak.VerifyTestMain(testmain.WrapTestingM(m, callback), opts...)
+}
+
+func GetPlanSuiteData() testdata.TestData {
+	return testDataMap["plan_suite"]
+}
diff --git a/planner/core/casetest/scalarsubquery/testdata/plan_suite_in.json b/planner/core/casetest/scalarsubquery/testdata/plan_suite_in.json
new file mode 100644
index 0000000000000..cb673dd95c4d4
--- /dev/null
+++ b/planner/core/casetest/scalarsubquery/testdata/plan_suite_in.json
@@ -0,0 +1,61 @@
+[
+  {
+    "name": "TestExplainNonEvaledSubquery",
+    "cases": [
+      // Test normal non-correlated scalar sub query.
+      {
+        "SQL": "explain format = 'brief' select * from t1 where a = (select a from t2 limit 1)",
+        "IsExplainAnalyze": false,
+        "HasErr": false
+      },
+      {
+        "SQL": "explain analyze format = 'brief' select * from t1 where a = (select a from t2 limit 1)",
+        "IsExplainAnalyze": true,
+        "HasErr": false
+      },
+      // Test EXISTS non-correlated scalar sub query.
+      {
+        "SQL": "explain format = 'brief' select * from t1 where exists(select 1 from t2 where a = 1)",
+        "IsExplainAnalyze": false,
+        "HasErr": false
+      },
+      {
+        "SQL": "explain analyze format = 'brief' select * from t1 where exists(select 1 from t2 where a = 1)",
+        "IsExplainAnalyze": true,
+        "HasErr": false
+      },
+      {
+        "SQL": "explain format = 'brief' select * from t1 where not exists(select 1 from t2 where a = 1)",
+        "IsExplainAnalyze": false,
+        "HasErr": false
+      },
+      {
+        "SQL": "explain analyze format = 'brief' select * from t1 where not exists(select 1 from t2 where a = 1)",
+        "IsExplainAnalyze": true,
+        "HasErr": false
+      },
+      // Test with constant propagation.
+      {
+        "SQL": "explain format = 'brief' select * from t1 where exists(select 1 from t2 where a = (select a from t3 limit 1) and b = a);",
+        "IsExplainAnalyze": false,
+        "HasErr": false
+      },
+      {
+        "SQL": "explain analyze format = 'brief' select * from t1 where exists(select 1 from t2 where a = (select a from t3 limit 1) and b = a);",
+        "IsExplainAnalyze": true,
+        "HasErr": false
+      },
+      // Test multiple returns.
+ { + "SQL": "explain format = 'brief' select * from t1 where (a, b) = (select a, b from t2 limit 1)", + "IsExplainAnalyze": false, + "HasErr": false + }, + { + "SQL": "explain analyze format = 'brief' select * from t1 where (a, b) = (select a, b from t2 limit 1)", + "IsExplainAnalyze": true, + "HasErr": false + } + ] + } +] \ No newline at end of file diff --git a/planner/core/casetest/scalarsubquery/testdata/plan_suite_out.json b/planner/core/casetest/scalarsubquery/testdata/plan_suite_out.json new file mode 100644 index 0000000000000..435d8405a9287 --- /dev/null +++ b/planner/core/casetest/scalarsubquery/testdata/plan_suite_out.json @@ -0,0 +1,119 @@ +[ + { + "Name": "TestExplainNonEvaledSubquery", + "Cases": [ + { + "SQL": "explain format = 'brief' select * from t1 where a = (select a from t2 limit 1)", + "Plan": [ + "Selection 8000.00 root eq(test.t1.a, ScalarQueryCol#9)", + "└─TableReader 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo", + "ScalarSubQuery N/A root Output: ScalarQueryCol#9", + "└─MaxOneRow 1.00 root ", + " └─Limit 1.00 root offset:0, count:1", + " └─TableReader 1.00 root data:Limit", + " └─Limit 1.00 cop[tikv] offset:0, count:1", + " └─TableFullScan 1.00 cop[tikv] table:t2 keep order:false, stats:pseudo" + ], + "Error": "" + }, + { + "SQL": "explain analyze format = 'brief' select * from t1 where a = (select a from t2 limit 1)", + "Plan": [ + "TableDual 0.00 0 root " + ], + "Error": "" + }, + { + "SQL": "explain format = 'brief' select * from t1 where exists(select 1 from t2 where a = 1)", + "Plan": [ + "Selection 8000.00 root ScalarQueryCol#10", + "└─TableReader 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo", + "ScalarSubQuery N/A root Output: ScalarQueryCol#10", + "└─TableReader 10.00 root data:Selection", + " └─Selection 10.00 cop[tikv] eq(test.t2.a, 1)", + " └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo" + ], + "Error": "" + }, + { + "SQL": "explain analyze format = 'brief' select * from t1 where exists(select 1 from t2 where a = 1)", + "Plan": [ + "TableDual 0.00 0 root " + ], + "Error": "" + }, + { + "SQL": "explain format = 'brief' select * from t1 where not exists(select 1 from t2 where a = 1)", + "Plan": [ + "Selection 8000.00 root not(ScalarQueryCol#10)", + "└─TableReader 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo", + "ScalarSubQuery N/A root Output: ScalarQueryCol#10", + "└─TableReader 10.00 root data:Selection", + " └─Selection 10.00 cop[tikv] eq(test.t2.a, 1)", + " └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo" + ], + "Error": "" + }, + { + "SQL": "explain analyze format = 'brief' select * from t1 where not exists(select 1 from t2 where a = 1)", + "Plan": [ + "TableReader 10000.00 0 root ", + "└─TableFullScan 10000.00 0 cop[tikv] table:t1" + ], + "Error": "" + }, + { + "SQL": "explain format = 'brief' select * from t1 where exists(select 1 from t2 where a = (select a from t3 limit 1) and b = a);", + "Plan": [ + "Selection 8000.00 root ScalarQueryCol#15", + "└─TableReader 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo", + "ScalarSubQuery N/A root Output: ScalarQueryCol#13", + "└─MaxOneRow 1.00 root ", + " └─Limit 1.00 root offset:0, count:1", + " └─TableReader 1.00 root data:Limit", + " └─Limit 1.00 cop[tikv] offset:0, count:1", + " 
+          "ScalarSubQuery N/A root Output: ScalarQueryCol#15",
+          "└─Selection 6400.00 root eq(cast(test.t2.a, double BINARY), cast(ScalarQueryCol#13, double BINARY)), eq(cast(test.t2.b, double BINARY), cast(ScalarQueryCol#13, double BINARY))",
+          "  └─TableReader 8000.00 root data:Selection",
+          "    └─Selection 8000.00 cop[tikv] eq(test.t2.b, test.t2.a)",
+          "      └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo"
+        ],
+        "Error": ""
+      },
+      {
+        "SQL": "explain analyze format = 'brief' select * from t1 where exists(select 1 from t2 where a = (select a from t3 limit 1) and b = a);",
+        "Plan": [
+          "TableDual 0.00 0 root "
+        ],
+        "Error": ""
+      },
+      {
+        "SQL": "explain format = 'brief' select * from t1 where (a, b) = (select a, b from t2 limit 1)",
+        "Plan": [
+          "Selection 8000.00 root eq(test.t1.a, ScalarQueryCol#9), eq(test.t1.b, ScalarQueryCol#10)",
+          "└─TableReader 10000.00 root data:TableFullScan",
+          "  └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo",
+          "ScalarSubQuery N/A root Output: ScalarQueryCol#9,ScalarQueryCol#10",
+          "└─MaxOneRow 1.00 root ",
+          "  └─Limit 1.00 root offset:0, count:1",
+          "    └─TableReader 1.00 root data:Limit",
+          "      └─Limit 1.00 cop[tikv] offset:0, count:1",
+          "        └─TableFullScan 1.00 cop[tikv] table:t2 keep order:false, stats:pseudo"
+        ],
+        "Error": ""
+      },
+      {
+        "SQL": "explain analyze format = 'brief' select * from t1 where (a, b) = (select a, b from t2 limit 1)",
+        "Plan": [
+          "TableDual 0.00 0 root "
+        ],
+        "Error": ""
+      }
+    ]
+  }
+]
\ No newline at end of file
diff --git a/planner/core/common_plans.go b/planner/core/common_plans.go
index 3c87fcb88b91c..3dc784172d40f 100644
--- a/planner/core/common_plans.go
+++ b/planner/core/common_plans.go
@@ -969,6 +969,11 @@ func (e *Explain) explainFlatPlanInRowFormat(flat *FlatPhysicalPlan) {
 			e.explainFlatOpInRowFormat(flatOp)
 		}
 	}
+	for _, subQ := range flat.ScalarSubQueries {
+		for _, flatOp := range subQ {
+			e.explainFlatOpInRowFormat(flatOp)
+		}
+	}
 }
 
 func (e *Explain) explainFlatPlanInJSONFormat(flat *FlatPhysicalPlan) (encodes []*ExplainInfoForEncode) {
@@ -981,6 +986,9 @@ func (e *Explain) explainFlatPlanInJSONFormat(flat *FlatPhysicalPlan) (encodes [
 	for _, cte := range flat.CTEs {
 		encodes = append(encodes, e.explainOpRecursivelyInJSONFormat(cte[0], cte))
 	}
+	for _, subQ := range flat.ScalarSubQueries {
+		encodes = append(encodes, e.explainOpRecursivelyInJSONFormat(subQ[0], subQ))
+	}
 	return
 }
 
diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go
index f6d4cfd162878..fef3e8bb1e461 100644
--- a/planner/core/expression_rewriter.go
+++ b/planner/core/expression_rewriter.go
@@ -875,6 +875,33 @@ func (er *expressionRewriter) handleExistSubquery(ctx context.Context, v *ast.Ex
 		er.err = err
 		return v, true
 	}
+	if er.b.ctx.GetSessionVars().StmtCtx.InExplainStmt && !er.b.ctx.GetSessionVars().StmtCtx.InExplainAnalyzeStmt && er.b.ctx.GetSessionVars().ExplainNonEvaledSubQuery {
+		newColID := er.b.ctx.GetSessionVars().AllocPlanColumnID()
+		subqueryCtx := ScalarSubqueryEvalCtx{
+			scalarSubQuery: physicalPlan,
+			ctx:            ctx,
+			is:             er.b.is,
+			outputColIDs:   []int64{newColID},
+		}.Init(er.b.ctx, np.SelectBlockOffset())
+		scalarSubQ := &ScalarSubQueryExpr{
+			scalarSubqueryColID: newColID,
+			evalCtx:             subqueryCtx,
+		}
+		scalarSubQ.RetType = np.Schema().Columns[0].GetType()
+		scalarSubQ.SetCoercibility(np.Schema().Columns[0].Coercibility())
+		er.b.ctx.GetSessionVars().RegisterScalarSubQ(subqueryCtx)
+		if v.Not {
+			notWrapped, err := expression.NewFunction(er.b.ctx, ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), scalarSubQ)
+			if err != nil {
+				er.err = err
+				return v, true
+			}
+			er.ctxStackAppend(notWrapped, types.EmptyName)
+			return v, true
+		}
+		er.ctxStackAppend(scalarSubQ, types.EmptyName)
+		return v, true
+	}
 	row, err := EvalSubqueryFirstRow(ctx, physicalPlan, er.b.is, er.b.ctx)
 	if err != nil {
 		er.err = err
@@ -1086,6 +1113,40 @@ func (er *expressionRewriter) handleScalarSubquery(ctx context.Context, v *ast.S
 		er.err = err
 		return v, true
 	}
+	if er.b.ctx.GetSessionVars().StmtCtx.InExplainStmt && !er.b.ctx.GetSessionVars().StmtCtx.InExplainAnalyzeStmt && er.b.ctx.GetSessionVars().ExplainNonEvaledSubQuery {
+		subqueryCtx := ScalarSubqueryEvalCtx{
+			scalarSubQuery: physicalPlan,
+			ctx:            ctx,
+			is:             er.b.is,
+		}.Init(er.b.ctx, np.SelectBlockOffset())
+		newColIDs := make([]int64, 0, np.Schema().Len())
+		newScalarSubQueryExprs := make([]expression.Expression, 0, np.Schema().Len())
+		for _, col := range np.Schema().Columns {
+			newColID := er.b.ctx.GetSessionVars().AllocPlanColumnID()
+			scalarSubQ := &ScalarSubQueryExpr{
+				scalarSubqueryColID: newColID,
+				evalCtx:             subqueryCtx,
+			}
+			scalarSubQ.RetType = col.RetType
+			scalarSubQ.SetCoercibility(col.Coercibility())
+			newColIDs = append(newColIDs, newColID)
+			newScalarSubQueryExprs = append(newScalarSubQueryExprs, scalarSubQ)
+		}
+		subqueryCtx.outputColIDs = newColIDs
+
+		er.b.ctx.GetSessionVars().RegisterScalarSubQ(subqueryCtx)
+		if len(newScalarSubQueryExprs) == 1 {
+			er.ctxStackAppend(newScalarSubQueryExprs[0], types.EmptyName)
+		} else {
+			rowFunc, err := er.newFunction(ast.RowFunc, newScalarSubQueryExprs[0].GetType(), newScalarSubQueryExprs...)
+			if err != nil {
+				er.err = err
+				return v, true
+			}
+			er.ctxStack = append(er.ctxStack, rowFunc)
+		}
+		return v, true
+	}
 	row, err := EvalSubqueryFirstRow(ctx, physicalPlan, er.b.is, er.b.ctx)
 	if err != nil {
 		er.err = err
diff --git a/planner/core/flat_plan.go b/planner/core/flat_plan.go
index f1c6e9104b0d1..c35857df679da 100644
--- a/planner/core/flat_plan.go
+++ b/planner/core/flat_plan.go
@@ -15,17 +15,21 @@
 package core
 
 import (
+	"fmt"
 	"sort"
 
 	"github.com/pingcap/tidb/kv"
+	"github.com/pingcap/tidb/util/logutil"
 	"github.com/pingcap/tidb/util/texttree"
+	"go.uber.org/zap"
 )
 
 // FlatPhysicalPlan provides an easier structure to traverse a plan and collect needed information.
 // Note: Although it's named FlatPhysicalPlan, there also could be Insert, Delete and Update at the beginning of Main.
 type FlatPhysicalPlan struct {
-	Main FlatPlanTree
-	CTEs []FlatPlanTree
+	Main             FlatPlanTree
+	CTEs             []FlatPlanTree
+	ScalarSubQueries []FlatPlanTree
 
 	// InExecute and InExplain are expected to handle some special cases. Usually you don't need to use them.
@@ -197,6 +201,18 @@ func FlattenPhysicalPlan(p Plan, buildSideFirst bool) *FlatPhysicalPlan {
 		res.CTEs = append(res.CTEs, cteExplained)
 		flattenedCTEPlan[cteDef.CTE.IDForStorage] = struct{}{}
 	}
+	if p.SCtx() == nil || p.SCtx().GetSessionVars() == nil {
+		return res
+	}
+	for _, scalarSubQ := range p.SCtx().GetSessionVars().MapScalarSubQ {
+		castedScalarSubQ, ok := scalarSubQ.(*ScalarSubqueryEvalCtx)
+		if !ok {
+			logutil.BgLogger().Debug("Wrong item registered as scalar subquery", zap.String("the wrong item", fmt.Sprintf("%T", scalarSubQ)))
+			continue
+		}
+		subQExplained := res.flattenScalarSubQRecursively(castedScalarSubQ, initInfo, nil)
+		res.ScalarSubQueries = append(res.ScalarSubQueries, subQExplained)
+	}
 	return res
 }
 
@@ -499,3 +515,26 @@ func (f *FlatPhysicalPlan) flattenCTERecursively(cteDef *CTEDefinition, info *op
 	}
 	return target
 }
+
+func (f *FlatPhysicalPlan) flattenScalarSubQRecursively(scalarSubQ *ScalarSubqueryEvalCtx, info *operatorCtx, target FlatPlanTree) FlatPlanTree {
+	flat := f.flattenSingle(scalarSubQ, info)
+	if flat != nil {
+		target = append(target, flat)
+	}
+	childIdxs := make([]int, 0)
+	var childIdx int
+	childInfo := &operatorCtx{
+		depth:       info.depth + 1,
+		label:       Empty,
+		isRoot:      true,
+		storeType:   kv.TiDB,
+		indent:      texttree.Indent4Child(info.indent, info.isLastChild),
+		isLastChild: true,
+	}
+	target, childIdx = f.flattenRecursively(scalarSubQ.scalarSubQuery, childInfo, target)
+	childIdxs = append(childIdxs, childIdx)
+	if flat != nil {
+		flat.ChildrenIdx = childIdxs
+	}
+	return target
+}
diff --git a/planner/core/initialize.go b/planner/core/initialize.go
index 3b9a9f9371c94..8c111a7a4a959 100644
--- a/planner/core/initialize.go
+++ b/planner/core/initialize.go
@@ -677,3 +677,9 @@ func (p PhysicalSequence) Init(ctx sessionctx.Context, stats *property.StatsInfo
 	p.childrenReqProps = props
 	return &p
 }
+
+// Init initializes ScalarSubqueryEvalCtx.
+func (p ScalarSubqueryEvalCtx) Init(ctx sessionctx.Context, offset int) *ScalarSubqueryEvalCtx {
+	p.basePlan = newBasePlan(ctx, plancodec.TypeScalarSubQuery, offset)
+	return &p
+}
diff --git a/planner/core/scalar_subq_expression.go b/planner/core/scalar_subq_expression.go
new file mode 100644
index 0000000000000..9d9f8c478eaef
--- /dev/null
+++ b/planner/core/scalar_subq_expression.go
@@ -0,0 +1,345 @@
+// Copyright 2023 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package core
+
+import (
+	"context"
+	"fmt"
+	"strings"
+
+	"github.com/pingcap/errors"
+	"github.com/pingcap/tidb/expression"
+	"github.com/pingcap/tidb/infoschema"
+	"github.com/pingcap/tidb/sessionctx"
+	"github.com/pingcap/tidb/sessionctx/stmtctx"
+	"github.com/pingcap/tidb/types"
+	"github.com/pingcap/tidb/util/chunk"
+	"github.com/pingcap/tidb/util/codec"
+)
+
+// ScalarSubqueryEvalCtx stores the plan for the subquery, used by ScalarSubQueryExpr.
+type ScalarSubqueryEvalCtx struct {
+	basePlan
+
+	// The context for evaluating the subquery.
+	scalarSubQuery PhysicalPlan
+	ctx            context.Context
+	is             infoschema.InfoSchema
+	evalErr        error
+	evaled         bool
+
+	outputColIDs []int64
+	colsData     []types.Datum
+}
+
+func (ssctx *ScalarSubqueryEvalCtx) getColVal(colID int64) (*types.Datum, error) {
+	err := ssctx.selfEval()
+	if err != nil {
+		return nil, err
+	}
+	for i, id := range ssctx.outputColIDs {
+		if id == colID {
+			return &ssctx.colsData[i], nil
+		}
+	}
+	return nil, errors.Errorf("Could not find the ScalarSubQueryExpr#%d in the ScalarSubquery_%d", colID, ssctx.ID())
+}
+
+func (ssctx *ScalarSubqueryEvalCtx) selfEval() error {
+	if ssctx.evaled {
+		return ssctx.evalErr
+	}
+	ssctx.evaled = true
+	row, err := EvalSubqueryFirstRow(ssctx.ctx, ssctx.scalarSubQuery, ssctx.is, ssctx.SCtx())
+	if err != nil {
+		ssctx.evalErr = err
+		return err
+	}
+	ssctx.colsData = row
+	return nil
+}
+
+// ScalarSubQueryExpr is an expression placeholder for a non-correlated scalar subquery which can be evaluated during the optimizing phase.
+// TODO: The evaluation-related methods will be revised in a later step.
+type ScalarSubQueryExpr struct {
+	scalarSubqueryColID int64
+
+	// The context for evaluating the subquery.
+	evalCtx *ScalarSubqueryEvalCtx
+	evalErr error
+	evaled  bool
+
+	hashcode []byte
+
+	expression.Constant
+}
+
+// Traverse implements the TraverseDown interface.
+func (s *ScalarSubQueryExpr) Traverse(_ expression.TraverseAction) expression.Expression {
+	return s
+}
+
+func (s *ScalarSubQueryExpr) selfEvaluate() error {
+	colVal, err := s.evalCtx.getColVal(s.scalarSubqueryColID)
+	if err != nil {
+		s.evalErr = err
+		s.Constant = *expression.NewNull()
+		return err
+	}
+	s.Constant.Value = *colVal
+	s.evaled = true
+	return nil
+}
+
+// Eval implements the Expression interface.
+func (s *ScalarSubQueryExpr) Eval(_ chunk.Row) (types.Datum, error) {
+	if s.evaled {
+		return s.Value, nil
+	}
+	if s.evalErr != nil {
+		return s.Value, s.evalErr
+	}
+	err := s.selfEvaluate()
+	return s.Value, err
+}
+
+// EvalInt returns the int64 representation of expression.
+func (*ScalarSubQueryExpr) EvalInt(_ sessionctx.Context, _ chunk.Row) (val int64, isNull bool, err error) {
+	return 0, false, errors.Errorf("Evaluation method is not implemented for ScalarSubQueryExpr")
+}
+
+// EvalReal returns the float64 representation of expression.
+func (*ScalarSubQueryExpr) EvalReal(_ sessionctx.Context, _ chunk.Row) (val float64, isNull bool, err error) {
+	return 0, false, errors.Errorf("Evaluation method is not implemented for ScalarSubQueryExpr")
+}
+
+// EvalString returns the string representation of expression.
+func (*ScalarSubQueryExpr) EvalString(_ sessionctx.Context, _ chunk.Row) (val string, isNull bool, err error) {
+	return "", false, errors.Errorf("Evaluation method is not implemented for ScalarSubQueryExpr")
+}
+
+// EvalDecimal returns the decimal representation of expression.
+func (*ScalarSubQueryExpr) EvalDecimal(_ sessionctx.Context, _ chunk.Row) (val *types.MyDecimal, isNull bool, err error) {
+	return nil, false, errors.Errorf("Evaluation method is not implemented for ScalarSubQueryExpr")
+}
+
+// EvalTime returns the DATE/DATETIME/TIMESTAMP representation of expression.
+func (*ScalarSubQueryExpr) EvalTime(_ sessionctx.Context, _ chunk.Row) (val types.Time, isNull bool, err error) {
+	return types.ZeroTime, false, errors.Errorf("Evaluation method is not implemented for ScalarSubQueryExpr")
+}
+
+// EvalDuration returns the duration representation of expression.
+func (*ScalarSubQueryExpr) EvalDuration(_ sessionctx.Context, _ chunk.Row) (val types.Duration, isNull bool, err error) {
+	return types.ZeroDuration, false, errors.Errorf("Evaluation method is not implemented for ScalarSubQueryExpr")
+}
+
+// EvalJSON returns the JSON representation of expression.
+func (*ScalarSubQueryExpr) EvalJSON(_ sessionctx.Context, _ chunk.Row) (val types.BinaryJSON, isNull bool, err error) {
+	return types.BinaryJSON{}, false, errors.Errorf("Evaluation method is not implemented for ScalarSubQueryExpr")
+}
+
+// GetType implements the Expression interface.
+func (s *ScalarSubQueryExpr) GetType() *types.FieldType {
+	return s.RetType
+}
+
+// Clone copies an expression totally.
+func (s *ScalarSubQueryExpr) Clone() expression.Expression {
+	ret := *s
+	ret.RetType = s.RetType.Clone()
+	return &ret
+}
+
+// Equal implements the Expression interface.
+func (s *ScalarSubQueryExpr) Equal(_ sessionctx.Context, e expression.Expression) bool {
+	anotherS, ok := e.(*ScalarSubQueryExpr)
+	if !ok {
+		return false
+	}
+	if s.scalarSubqueryColID == anotherS.scalarSubqueryColID {
+		return true
+	}
+	return false
+}
+
+// IsCorrelated implements the Expression interface.
+func (*ScalarSubQueryExpr) IsCorrelated() bool {
+	return false
+}
+
+// ConstItem implements the Expression interface.
+func (*ScalarSubQueryExpr) ConstItem(_ *stmtctx.StatementContext) bool {
+	return true
+}
+
+// Decorrelate implements the Expression interface.
+func (s *ScalarSubQueryExpr) Decorrelate(_ *expression.Schema) expression.Expression {
+	return s
+}
+
+// resolveIndices implements the Expression interface.
+func (*ScalarSubQueryExpr) resolveIndices(_ *expression.Schema) error {
+	return nil
+}
+
+// ResolveIndices implements the Expression interface.
+func (s *ScalarSubQueryExpr) ResolveIndices(_ *expression.Schema) (expression.Expression, error) {
+	return s, nil
+}
+
+// ResolveIndicesByVirtualExpr implements the Expression interface.
+func (s *ScalarSubQueryExpr) ResolveIndicesByVirtualExpr(_ *expression.Schema) (expression.Expression, bool) {
+	return s, false
+}
+
+// resolveIndicesByVirtualExpr implements the Expression interface.
+func (*ScalarSubQueryExpr) resolveIndicesByVirtualExpr(_ *expression.Schema) bool {
+	return false
+}
+
+// RemapColumn implements the Expression interface.
+func (s *ScalarSubQueryExpr) RemapColumn(_ map[int64]*expression.Column) (expression.Expression, error) {
+	return s, nil
+}
+
+// ExplainInfo implements the Expression interface.
+func (s *ScalarSubQueryExpr) ExplainInfo() string {
+	return s.String()
+}
+
+// ExplainNormalizedInfo implements the Expression interface.
+func (s *ScalarSubQueryExpr) ExplainNormalizedInfo() string {
+	return s.String()
+}
+
+// HashCode implements the Expression interface.
+func (s *ScalarSubQueryExpr) HashCode(_ *stmtctx.StatementContext) []byte {
+	if len(s.hashcode) != 0 {
+		return s.hashcode
+	}
+	s.hashcode = make([]byte, 0, 9)
+	s.hashcode = append(s.hashcode, expression.ScalarSubQFlag)
+	s.hashcode = codec.EncodeInt(s.hashcode, s.scalarSubqueryColID)
+	return s.hashcode
+}
+
+// MemoryUsage implements the Expression interface.
+func (s *ScalarSubQueryExpr) MemoryUsage() int64 {
+	ret := int64(0)
+	if s.evaled {
+		ret += s.Constant.MemoryUsage()
+	}
+	return ret
+}
+
+// String implements the Stringer interface.
+func (s *ScalarSubQueryExpr) String() string {
+	builder := &strings.Builder{}
+	fmt.Fprintf(builder, "ScalarQueryCol#%d", s.scalarSubqueryColID)
+	return builder.String()
+}
+
+// MarshalJSON implements the goJSON.Marshaler interface.
+func (s *ScalarSubQueryExpr) MarshalJSON() ([]byte, error) {
+	if s.evalErr != nil {
+		return nil, s.evalErr
+	}
+	if s.evaled {
+		return s.Constant.MarshalJSON()
+	}
+	err := s.selfEvaluate()
+	if err != nil {
+		return nil, err
+	}
+	return s.Constant.MarshalJSON()
+}
+
+// ReverseEval evaluates the only column value with the given function result.
+func (s *ScalarSubQueryExpr) ReverseEval(_ *stmtctx.StatementContext, _ types.Datum, _ types.RoundingType) (val types.Datum, err error) {
+	if s.evalErr != nil {
+		return s.Value, s.evalErr
+	}
+	if s.evaled {
+		return s.Value, nil
+	}
+	err = s.selfEvaluate()
+	if err != nil {
+		return s.Value, err
+	}
+	return s.Value, nil
+}
+
+// SupportReverseEval implements the Expression interface.
+func (*ScalarSubQueryExpr) SupportReverseEval() bool {
+	return true
+}
+
+// VecEvalInt evaluates this expression in a vectorized manner.
+func (*ScalarSubQueryExpr) VecEvalInt(_ sessionctx.Context, _ *chunk.Chunk, _ *chunk.Column) error {
+	return errors.Errorf("ScalarSubQueryExpr doesn't implement the vec eval yet")
+}
+
+// VecEvalReal evaluates this expression in a vectorized manner.
+func (*ScalarSubQueryExpr) VecEvalReal(_ sessionctx.Context, _ *chunk.Chunk, _ *chunk.Column) error {
+	return errors.Errorf("ScalarSubQueryExpr doesn't implement the vec eval yet")
+}
+
+// VecEvalString evaluates this expression in a vectorized manner.
+func (*ScalarSubQueryExpr) VecEvalString(_ sessionctx.Context, _ *chunk.Chunk, _ *chunk.Column) error {
+	return errors.Errorf("ScalarSubQueryExpr doesn't implement the vec eval yet")
+}
+
+// VecEvalDecimal evaluates this expression in a vectorized manner.
+func (*ScalarSubQueryExpr) VecEvalDecimal(_ sessionctx.Context, _ *chunk.Chunk, _ *chunk.Column) error {
+	return errors.Errorf("ScalarSubQueryExpr doesn't implement the vec eval yet")
+}
+
+// VecEvalTime evaluates this expression in a vectorized manner.
+func (*ScalarSubQueryExpr) VecEvalTime(_ sessionctx.Context, _ *chunk.Chunk, _ *chunk.Column) error {
+	return errors.Errorf("ScalarSubQueryExpr doesn't implement the vec eval yet")
+}
+
+// VecEvalDuration evaluates this expression in a vectorized manner.
+func (*ScalarSubQueryExpr) VecEvalDuration(_ sessionctx.Context, _ *chunk.Chunk, _ *chunk.Column) error {
+	return errors.Errorf("ScalarSubQueryExpr doesn't implement the vec eval yet")
+}
+
+// VecEvalJSON evaluates this expression in a vectorized manner.
+func (*ScalarSubQueryExpr) VecEvalJSON(_ sessionctx.Context, _ *chunk.Chunk, _ *chunk.Column) error {
+	return errors.Errorf("ScalarSubQueryExpr doesn't implement the vec eval yet")
+}
+
+// Vectorized returns whether the expression can be vectorized.
+func (*ScalarSubQueryExpr) Vectorized() bool {
+	return true
+}
+
+// Schema implements the Plan interface.
+func (*ScalarSubqueryEvalCtx) Schema() *expression.Schema {
+	return nil
+}
+
+// ExplainInfo implements the Plan interface.
+func (ssctx *ScalarSubqueryEvalCtx) ExplainInfo() string {
+	builder := &strings.Builder{}
+	fmt.Fprintf(builder, "Output: ")
+	for i, id := range ssctx.outputColIDs {
+		fmt.Fprintf(builder, "ScalarQueryCol#%d", id)
+		if i+1 != len(ssctx.outputColIDs) {
+			fmt.Fprintf(builder, ",")
+		}
+	}
+	return builder.String()
+}
diff --git a/planner/optimize.go b/planner/optimize.go
index 1b8d798c6c085..b0c5b0070d677 100644
--- a/planner/optimize.go
+++ b/planner/optimize.go
@@ -546,6 +546,7 @@ func OptimizeExecStmt(ctx context.Context, sctx sessionctx.Context,
 func buildLogicalPlan(ctx context.Context, sctx sessionctx.Context, node ast.Node, builder *core.PlanBuilder) (core.Plan, error) {
 	sctx.GetSessionVars().PlanID.Store(0)
 	sctx.GetSessionVars().PlanColumnID.Store(0)
+	sctx.GetSessionVars().MapScalarSubQ = nil
 	sctx.GetSessionVars().MapHashCode2UniqueID4ExtendedCol = nil
 	failpoint.Inject("mockRandomPlanID", func() {
diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go
index ee5f99b600181..81c07b29f1c4c 100644
--- a/sessionctx/stmtctx/stmtctx.go
+++ b/sessionctx/stmtctx/stmtctx.go
@@ -159,6 +159,7 @@ type StatementContext struct {
 	InSelectStmt              bool
 	InLoadDataStmt            bool
 	InExplainStmt             bool
+	InExplainAnalyzeStmt      bool
 	ExplainFormat             string
 	InCreateOrAlterStmt       bool
 	InSetSessionStatesStmt    bool
diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go
index 91206ff8df2c5..2cbca84a997fd 100644
--- a/sessionctx/variable/session.go
+++ b/sessionctx/variable/session.go
@@ -744,6 +744,9 @@ type SessionVars struct {
 	// PlanColumnID is the unique id for column when building plan.
 	PlanColumnID atomic.Int64
 
+	// MapScalarSubQ maps the scalar subqueries from their IDs to their structs.
+	MapScalarSubQ []interface{}
+
 	// MapHashCode2UniqueID4ExtendedCol map the expr's hash code to specified unique ID.
 	MapHashCode2UniqueID4ExtendedCol map[string]int
 
@@ -822,6 +825,8 @@ type SessionVars struct {
 	// Enable3StageMultiDistinctAgg indicates whether to allow 3 stage multi distinct aggregate
 	Enable3StageMultiDistinctAgg bool
 
+	ExplainNonEvaledSubQuery bool
+
 	// MultiStatementMode permits incorrect client library usage. Not recommended to be turned on.
 	MultiStatementMode int
 
@@ -2149,6 +2154,11 @@ func (s *SessionVars) AllocPlanColumnID() int64 {
 	return s.PlanColumnID.Add(1)
 }
 
+// RegisterScalarSubQ registers a scalar subquery into the map. This will be used for EXPLAIN.
+func (s *SessionVars) RegisterScalarSubQ(scalarSubQ interface{}) {
+	s.MapScalarSubQ = append(s.MapScalarSubQ, scalarSubQ)
+}
+
 // GetCharsetInfo gets charset and collation for current context.
 // What character set should the server translate a statement to after receiving it?
 // For this, the server uses the character_set_connection and collation_connection system variables.
diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go
index 4b40520cc4393..4beb1d39b7f6e 100644
--- a/sessionctx/variable/sysvar.go
+++ b/sessionctx/variable/sysvar.go
@@ -211,6 +211,10 @@ var defaultSysVars = []*SysVar{
 		s.Enable3StageMultiDistinctAgg = TiDBOptOn(val)
 		return nil
 	}},
+	{Scope: ScopeGlobal | ScopeSession, Name: TiDBOptExplainNoEvaledSubQuery, Value: BoolToOnOff(DefTiDBOptExplainEvaledSubquery), Type: TypeBool, SetSession: func(s *SessionVars, val string) error {
+		s.ExplainNonEvaledSubQuery = TiDBOptOn(val)
+		return nil
+	}},
 	{Scope: ScopeSession, Name: TiDBOptWriteRowID, Value: BoolToOnOff(DefOptWriteRowID), Type: TypeBool, SetSession: func(s *SessionVars, val string) error {
 		s.AllowWriteRowID = TiDBOptOn(val)
 		return nil
diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go
index 29e307c520bd8..bc8c556cdbd3c 100644
--- a/sessionctx/variable/tidb_vars.go
+++ b/sessionctx/variable/tidb_vars.go
@@ -71,6 +71,8 @@ const (
 	// TiDBOptEnable3StageMultiDistinctAgg is used to indicate whether to plan and execute the multi distinct agg in 3 stages
 	TiDBOptEnable3StageMultiDistinctAgg = "tidb_opt_enable_three_stage_multi_distinct_agg"
 
+	TiDBOptExplainNoEvaledSubQuery = "tidb_opt_enable_non_eval_scalar_subquery"
+
 	// TiDBBCJThresholdSize is used to limit the size of small table for mpp broadcast join.
 	// Its unit is bytes, if the size of small table is larger than it, we will not use bcj.
 	TiDBBCJThresholdSize = "tidb_broadcast_join_threshold_size"
@@ -1262,6 +1264,7 @@ const (
 	DefTiDBSkewDistinctAgg          = false
 	DefTiDB3StageDistinctAgg        = true
 	DefTiDB3StageMultiDistinctAgg   = false
+	DefTiDBOptExplainEvaledSubquery = false
 	DefTiDBReadStaleness            = 0
 	DefTiDBGCMaxWaitTime            = 24 * 60 * 60
 	DefMaxAllowedPacket uint64      = 67108864
diff --git a/util/plancodec/id.go b/util/plancodec/id.go
index d1e59edceb07f..096f0acd9c70b 100644
--- a/util/plancodec/id.go
+++ b/util/plancodec/id.go
@@ -137,6 +137,8 @@ const (
 	TypeImportInto = "ImportInto"
 	// TypeSequence is the type of Sequence
 	TypeSequence = "Sequence"
+	// TypeScalarSubQuery is the type of ScalarSubQuery
+	TypeScalarSubQuery = "ScalarSubQuery"
 )
 
 // plan id.
@@ -201,6 +203,7 @@ const (
 	typeForeignKeyCascade int = 57
 	typeExpandID          int = 58
 	typeImportIntoID      int = 59
+	TypeScalarSubQueryID  int = 60
 )
 
 // TypeStringToPhysicalID converts the plan type string to plan id.
@@ -324,6 +327,8 @@ func TypeStringToPhysicalID(tp string) int {
 		return typeExpandID
 	case TypeImportInto:
 		return typeImportIntoID
+	case TypeScalarSubQuery:
+		return TypeScalarSubQueryID
 	}
 	// Should never reach here.
 	return 0
@@ -450,6 +455,8 @@ func PhysicalIDToTypeString(id int) string {
 		return TypeExpand
 	case typeImportIntoID:
 		return TypeImportInto
+	case TypeScalarSubQueryID:
+		return TypeScalarSubQuery
 	}
 	// Should never reach here.
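
A minimal usage sketch, assembled from the new test cases above (the variable name and the statements are copied from this diff; plan details such as estimated row counts depend on your schema and statistics):

    create table t1(a int, b int, c int);
    create table t2(a int, b int, c int);
    set @@tidb_opt_enable_non_eval_scalar_subquery=true;
    -- Plain EXPLAIN now keeps the subquery as a ScalarQueryCol# placeholder and prints its
    -- plan under a separate ScalarSubQuery operator instead of evaluating it while planning.
    explain format = 'brief' select * from t1 where a = (select a from t2 limit 1);
    -- EXPLAIN ANALYZE still evaluates the subquery (InExplainAnalyzeStmt is set in
    -- ResetContextOfStmt), so its output is unchanged by this variable.
    explain analyze format = 'brief' select * from t1 where a = (select a from t2 limit 1);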