diff --git a/pkg/domain/BUILD.bazel b/pkg/domain/BUILD.bazel index 8bd86d46c8785..82951c64c6684 100644 --- a/pkg/domain/BUILD.bazel +++ b/pkg/domain/BUILD.bazel @@ -67,6 +67,7 @@ go_library( "//pkg/types", "//pkg/util", "//pkg/util/chunk", + "//pkg/util/context", "//pkg/util/dbterror", "//pkg/util/disttask", "//pkg/util/domainutil", diff --git a/pkg/domain/domainctx.go b/pkg/domain/domainctx.go index c541b289365cb..55e96cd48d920 100644 --- a/pkg/domain/domainctx.go +++ b/pkg/domain/domainctx.go @@ -15,7 +15,7 @@ package domain import ( - "github.com/pingcap/tidb/pkg/sessionctx" + contextutil "github.com/pingcap/tidb/pkg/util/context" ) // domainKeyType is a dummy type to avoid naming collision in context. @@ -29,12 +29,12 @@ func (domainKeyType) String() string { const domainKey domainKeyType = 0 // BindDomain binds domain to context. -func BindDomain(ctx sessionctx.Context, domain *Domain) { +func BindDomain(ctx contextutil.ValueStoreContext, domain *Domain) { ctx.SetValue(domainKey, domain) } // GetDomain gets domain from context. -func GetDomain(ctx sessionctx.Context) *Domain { +func GetDomain(ctx contextutil.ValueStoreContext) *Domain { v, ok := ctx.Value(domainKey).(*Domain) if !ok { return nil diff --git a/pkg/executor/BUILD.bazel b/pkg/executor/BUILD.bazel index c8b086896022a..c0cc78906c295 100644 --- a/pkg/executor/BUILD.bazel +++ b/pkg/executor/BUILD.bazel @@ -150,6 +150,7 @@ go_library( "//pkg/parser/types", "//pkg/planner", "//pkg/planner/cardinality", + "//pkg/planner/context", "//pkg/planner/core", "//pkg/planner/util", "//pkg/planner/util/fixcontrol", diff --git a/pkg/executor/compact_table.go b/pkg/executor/compact_table.go index 110d03b54a4b2..33b4b19cd4348 100644 --- a/pkg/executor/compact_table.go +++ b/pkg/executor/compact_table.go @@ -51,7 +51,7 @@ const ( func getTiFlashStores(ctx sessionctx.Context) ([]infoschema.ServerInfo, error) { // TODO: Don't use infoschema, to preserve StoreID information. aliveTiFlashStores := make([]infoschema.ServerInfo, 0) - stores, err := infoschema.GetStoreServerInfo(ctx) + stores, err := infoschema.GetStoreServerInfo(ctx.GetStore()) if err != nil { return nil, err } diff --git a/pkg/executor/executor.go b/pkg/executor/executor.go index 351d4f1b97b32..c49af7fdca008 100644 --- a/pkg/executor/executor.go +++ b/pkg/executor/executor.go @@ -52,6 +52,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" + planctx "github.com/pingcap/tidb/pkg/planner/context" plannercore "github.com/pingcap/tidb/pkg/planner/core" "github.com/pingcap/tidb/pkg/planner/util/fixcontrol" "github.com/pingcap/tidb/pkg/privilege" @@ -74,6 +75,7 @@ import ( "github.com/pingcap/tidb/pkg/util/deadlockhistory" "github.com/pingcap/tidb/pkg/util/disk" "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/intest" "github.com/pingcap/tidb/pkg/util/logutil" "github.com/pingcap/tidb/pkg/util/logutil/consistency" "github.com/pingcap/tidb/pkg/util/memory" @@ -1457,9 +1459,9 @@ func init() { // While doing optimization in the plan package, we need to execute uncorrelated subquery, // but the plan package cannot import the executor package because of the dependency cycle. // So we assign a function implemented in the executor package to the plan package to avoid the dependency cycle. 
- plannercore.EvalSubqueryFirstRow = func(ctx context.Context, p plannercore.PhysicalPlan, is infoschema.InfoSchema, sctx sessionctx.Context) ([]types.Datum, error) { + plannercore.EvalSubqueryFirstRow = func(ctx context.Context, p plannercore.PhysicalPlan, is infoschema.InfoSchema, pctx planctx.PlanContext) ([]types.Datum, error) { defer func(begin time.Time) { - s := sctx.GetSessionVars() + s := pctx.GetSessionVars() s.StmtCtx.SetSkipPlanCache(errors.NewNoStackError("query has uncorrelated sub-queries is un-cacheable")) s.RewritePhaseInfo.PreprocessSubQueries++ s.RewritePhaseInfo.DurationPreprocessSubQuery += time.Since(begin) @@ -1468,6 +1470,12 @@ func init() { r, ctx := tracing.StartRegionEx(ctx, "executor.EvalSubQuery") defer r.End() + sctx, ok := pctx.(sessionctx.Context) + intest.Assert(ok) + if !ok { + return nil, errors.New("plan context should be sessionctx.Context to EvalSubqueryFirstRow") + } + e := newExecutorBuilder(sctx, is, nil) executor := e.build(p) if e.err != nil { diff --git a/pkg/executor/infoschema_reader.go b/pkg/executor/infoschema_reader.go index 3f045feb222f3..af39130847b8f 100644 --- a/pkg/executor/infoschema_reader.go +++ b/pkg/executor/infoschema_reader.go @@ -3130,7 +3130,7 @@ func (e *TiFlashSystemTableRetriever) retrieve(ctx context.Context, sctx session } func (e *TiFlashSystemTableRetriever) initialize(sctx sessionctx.Context, tiflashInstances set.StringSet) error { - storeInfo, err := infoschema.GetStoreServerInfo(sctx) + storeInfo, err := infoschema.GetStoreServerInfo(sctx.GetStore()) if err != nil { return err } diff --git a/pkg/executor/memtable_reader.go b/pkg/executor/memtable_reader.go index 8c68df436f1cf..773057202c0c2 100644 --- a/pkg/executor/memtable_reader.go +++ b/pkg/executor/memtable_reader.go @@ -319,7 +319,7 @@ func (e *clusterServerInfoRetriever) retrieve(ctx context.Context, sctx sessionc return nil, err } serversInfo = infoschema.FilterClusterServerInfo(serversInfo, e.extractor.NodeTypes, e.extractor.Instances) - return infoschema.FetchClusterServerInfoWithoutPrivilegeCheck(ctx, sctx, serversInfo, e.serverInfoType, true) + return infoschema.FetchClusterServerInfoWithoutPrivilegeCheck(ctx, sctx.GetSessionVars(), serversInfo, e.serverInfoType, true) } func parseFailpointServerInfo(s string) []infoschema.ServerInfo { diff --git a/pkg/executor/metrics_reader.go b/pkg/executor/metrics_reader.go index 7637fc51e6734..bdda123b1546a 100644 --- a/pkg/executor/metrics_reader.go +++ b/pkg/executor/metrics_reader.go @@ -120,7 +120,7 @@ func (e *MetricRetriever) queryMetric(ctx context.Context, sctx sessionctx.Conte promQLAPI := promv1.NewAPI(promClient) ctx, cancel := context.WithTimeout(ctx, promReadTimeout) defer cancel() - promQL := e.tblDef.GenPromQL(sctx, e.extractor.LabelConditions, quantile) + promQL := e.tblDef.GenPromQL(sctx.GetSessionVars().MetricSchemaRangeDuration, e.extractor.LabelConditions, quantile) // Add retry to avoid network error. 
for i := 0; i < 5; i++ { diff --git a/pkg/infoschema/metrics_schema.go b/pkg/infoschema/metrics_schema.go index 75ab92f1d065f..b8cbfd4992719 100644 --- a/pkg/infoschema/metrics_schema.go +++ b/pkg/infoschema/metrics_schema.go @@ -25,7 +25,6 @@ import ( "github.com/pingcap/tidb/pkg/meta/autoid" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/table" "github.com/pingcap/tidb/pkg/util" "github.com/pingcap/tidb/pkg/util/set" @@ -102,11 +101,11 @@ func (def *MetricTableDef) genColumnInfos() []columnInfo { } // GenPromQL generates the promQL. -func (def *MetricTableDef) GenPromQL(sctx sessionctx.Context, labels map[string]set.StringSet, quantile float64) string { +func (def *MetricTableDef) GenPromQL(metricsSchemaRangeDuration int64, labels map[string]set.StringSet, quantile float64) string { promQL := def.PromQL promQL = strings.ReplaceAll(promQL, promQLQuantileKey, strconv.FormatFloat(quantile, 'f', -1, 64)) promQL = strings.ReplaceAll(promQL, promQLLabelConditionKey, def.genLabelCondition(labels)) - promQL = strings.ReplaceAll(promQL, promQRangeDurationKey, strconv.FormatInt(sctx.GetSessionVars().MetricSchemaRangeDuration, 10)+"s") + promQL = strings.ReplaceAll(promQL, promQRangeDurationKey, strconv.FormatInt(metricsSchemaRangeDuration, 10)+"s") return promQL } diff --git a/pkg/infoschema/perfschema/tables.go b/pkg/infoschema/perfschema/tables.go index e369f21309bec..aea648135f298 100644 --- a/pkg/infoschema/perfschema/tables.go +++ b/pkg/infoschema/perfschema/tables.go @@ -302,7 +302,7 @@ func dataForRemoteProfile(ctx sessionctx.Context, nodeType, uri string, isGorout ) switch nodeType { case "tikv": - servers, err = infoschema.GetStoreServerInfo(ctx) + servers, err = infoschema.GetStoreServerInfo(ctx.GetStore()) case "pd": servers, err = infoschema.GetPDServerInfo(ctx) default: diff --git a/pkg/infoschema/tables.go b/pkg/infoschema/tables.go index fd80896bd5361..715cd3be29e4a 100644 --- a/pkg/infoschema/tables.go +++ b/pkg/infoschema/tables.go @@ -1806,7 +1806,9 @@ func GetClusterServerInfo(ctx sessionctx.Context) ([]ServerInfo, error) { type retriever func(ctx sessionctx.Context) ([]ServerInfo, error) //nolint: prealloc var servers []ServerInfo - for _, r := range []retriever{GetTiDBServerInfo, GetPDServerInfo, GetStoreServerInfo, GetTiProxyServerInfo} { + for _, r := range []retriever{GetTiDBServerInfo, GetPDServerInfo, func(ctx sessionctx.Context) ([]ServerInfo, error) { + return GetStoreServerInfo(ctx.GetStore()) + }, GetTiProxyServerInfo} { nodes, err := r(ctx) if err != nil { return nil, err @@ -1968,7 +1970,7 @@ func isTiFlashWriteNode(store *metapb.Store) bool { } // GetStoreServerInfo returns all store nodes(TiKV or TiFlash) cluster information -func GetStoreServerInfo(ctx sessionctx.Context) ([]ServerInfo, error) { +func GetStoreServerInfo(store kv.Storage) ([]ServerInfo, error) { failpoint.Inject("mockStoreServerInfo", func(val failpoint.Value) { if s := val.(string); len(s) > 0 { var servers []ServerInfo @@ -1987,7 +1989,6 @@ func GetStoreServerInfo(ctx sessionctx.Context) ([]ServerInfo, error) { } }) - store := ctx.GetStore() // Get TiKV servers info. 
tikvStore, ok := store.(tikv.Storage) if !ok { @@ -2051,7 +2052,7 @@ func GetTiFlashStoreCount(ctx sessionctx.Context) (cnt uint64, err error) { } }) - stores, err := GetStoreServerInfo(ctx) + stores, err := GetStoreServerInfo(ctx.GetStore()) if err != nil { return cnt, err } @@ -2426,11 +2427,11 @@ func (vt *VirtualTable) Type() table.Type { } // GetTiFlashServerInfo returns all TiFlash server infos -func GetTiFlashServerInfo(sctx sessionctx.Context) ([]ServerInfo, error) { +func GetTiFlashServerInfo(store kv.Storage) ([]ServerInfo, error) { if config.GetGlobalConfig().DisaggregatedTiFlash { return nil, table.ErrUnsupportedOp } - serversInfo, err := GetStoreServerInfo(sctx) + serversInfo, err := GetStoreServerInfo(store) if err != nil { return nil, err } @@ -2439,7 +2440,7 @@ func GetTiFlashServerInfo(sctx sessionctx.Context) ([]ServerInfo, error) { } // FetchClusterServerInfoWithoutPrivilegeCheck fetches cluster server information -func FetchClusterServerInfoWithoutPrivilegeCheck(ctx context.Context, sctx sessionctx.Context, serversInfo []ServerInfo, serverInfoType diagnosticspb.ServerInfoType, recordWarningInStmtCtx bool) ([][]types.Datum, error) { +func FetchClusterServerInfoWithoutPrivilegeCheck(ctx context.Context, vars *variable.SessionVars, serversInfo []ServerInfo, serverInfoType diagnosticspb.ServerInfoType, recordWarningInStmtCtx bool) ([][]types.Datum, error) { type result struct { idx int rows [][]types.Datum @@ -2476,7 +2477,7 @@ func FetchClusterServerInfoWithoutPrivilegeCheck(ctx context.Context, sctx sessi for result := range ch { if result.err != nil { if recordWarningInStmtCtx { - sctx.GetSessionVars().StmtCtx.AppendWarning(result.err) + vars.StmtCtx.AppendWarning(result.err) } else { log.Warn(result.err.Error()) } diff --git a/pkg/planner/cardinality/BUILD.bazel b/pkg/planner/cardinality/BUILD.bazel index 5cc11520d3f3c..8010960fb1f5f 100644 --- a/pkg/planner/cardinality/BUILD.bazel +++ b/pkg/planner/cardinality/BUILD.bazel @@ -22,6 +22,7 @@ go_library( "//pkg/parser/format", "//pkg/parser/model", "//pkg/parser/mysql", + "//pkg/planner/context", "//pkg/planner/property", "//pkg/planner/util", "//pkg/planner/util/debugtrace", diff --git a/pkg/planner/cardinality/cross_estimation.go b/pkg/planner/cardinality/cross_estimation.go index c1430d07d9deb..d045aef70ed93 100644 --- a/pkg/planner/cardinality/cross_estimation.go +++ b/pkg/planner/cardinality/cross_estimation.go @@ -18,9 +18,9 @@ import ( "math" "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/planner/context" "github.com/pingcap/tidb/pkg/planner/property" "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/statistics" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/ranger" @@ -33,7 +33,7 @@ const SelectionFactor = 0.8 // AdjustRowCountForTableScanByLimit will adjust the row count for table scan by limit. // For a query like `select pk from t using index(primary) where pk > 10 limit 1`, the row count of the table scan // should be adjusted by the limit number 1, because only one row is returned. 
-func AdjustRowCountForTableScanByLimit(sctx sessionctx.Context, +func AdjustRowCountForTableScanByLimit(sctx context.PlanContext, dsStatsInfo, dsTableStats *property.StatsInfo, dsStatisticTable *statistics.Table, path *util.AccessPath, expectedCnt float64, desc bool) float64 { rowCount := path.CountAfterAccess @@ -67,7 +67,7 @@ func AdjustRowCountForTableScanByLimit(sctx sessionctx.Context, // `select * from tbl where a = 1 order by pk limit 1` // if order of column `a` is strictly correlated with column `pk`, the row count of table scan should be: // `1 + row_count(a < 1 or a is null)` -func crossEstimateTableRowCount(sctx sessionctx.Context, +func crossEstimateTableRowCount(sctx context.PlanContext, dsStatsInfo, dsTableStats *property.StatsInfo, dsStatisticTable *statistics.Table, path *util.AccessPath, expectedCnt float64, desc bool) (float64, bool, float64) { if dsStatisticTable.Pseudo || len(path.TableFilters) == 0 || !sctx.GetSessionVars().EnableCorrelationAdjustment { @@ -80,7 +80,7 @@ func crossEstimateTableRowCount(sctx sessionctx.Context, // AdjustRowCountForIndexScanByLimit will adjust the row count for table scan by limit. // For a query like `select k from t using index(k) where k > 10 limit 1`, the row count of the index scan // should be adjusted by the limit number 1, because only one row is returned. -func AdjustRowCountForIndexScanByLimit(sctx sessionctx.Context, +func AdjustRowCountForIndexScanByLimit(sctx context.PlanContext, dsStatsInfo, dsTableStats *property.StatsInfo, dsStatisticTable *statistics.Table, path *util.AccessPath, expectedCnt float64, desc bool) float64 { rowCount := path.CountAfterAccess @@ -114,7 +114,7 @@ func AdjustRowCountForIndexScanByLimit(sctx sessionctx.Context, // `select * from tbl where a = 1 order by b limit 1` // if order of column `a` is strictly correlated with column `b`, the row count of IndexScan(b) should be: // `1 + row_count(a < 1 or a is null)` -func crossEstimateIndexRowCount(sctx sessionctx.Context, +func crossEstimateIndexRowCount(sctx context.PlanContext, dsStatsInfo, dsTableStats *property.StatsInfo, dsStatisticTable *statistics.Table, path *util.AccessPath, expectedCnt float64, desc bool) (float64, bool, float64) { filtersLen := len(path.TableFilters) + len(path.IndexFilters) @@ -130,7 +130,7 @@ func crossEstimateIndexRowCount(sctx sessionctx.Context, } // crossEstimateRowCount is the common logic of crossEstimateTableRowCount and crossEstimateIndexRowCount. -func crossEstimateRowCount(sctx sessionctx.Context, +func crossEstimateRowCount(sctx context.PlanContext, dsStatsInfo, dsTableStats *property.StatsInfo, path *util.AccessPath, conds []expression.Expression, col *expression.Column, corr, expectedCnt float64, desc bool) (float64, bool, float64) { @@ -182,7 +182,7 @@ func crossEstimateRowCount(sctx sessionctx.Context, } // getColumnRangeCounts estimates row count for each range respectively. 
-func getColumnRangeCounts(sctx sessionctx.Context, colID int64, ranges []*ranger.Range, histColl *statistics.HistColl, idxID int64) ([]float64, bool) { +func getColumnRangeCounts(sctx context.PlanContext, colID int64, ranges []*ranger.Range, histColl *statistics.HistColl, idxID int64) ([]float64, bool) { var err error var count float64 rangeCounts := make([]float64, len(ranges)) diff --git a/pkg/planner/cardinality/join.go b/pkg/planner/cardinality/join.go index 40dc24db8caf7..063682f5c033d 100644 --- a/pkg/planner/cardinality/join.go +++ b/pkg/planner/cardinality/join.go @@ -18,12 +18,12 @@ import ( "math" "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/planner/context" "github.com/pingcap/tidb/pkg/planner/property" - "github.com/pingcap/tidb/pkg/sessionctx" ) // EstimateFullJoinRowCount estimates the row count of a full join. -func EstimateFullJoinRowCount(sctx sessionctx.Context, +func EstimateFullJoinRowCount(sctx context.PlanContext, isCartesian bool, leftProfile, rightProfile *property.StatsInfo, leftJoinKeys, rightJoinKeys []*expression.Column, diff --git a/pkg/planner/cardinality/pseudo.go b/pkg/planner/cardinality/pseudo.go index 7a8dba62a17fc..18c7a6d1d11aa 100644 --- a/pkg/planner/cardinality/pseudo.go +++ b/pkg/planner/cardinality/pseudo.go @@ -21,7 +21,6 @@ import ( "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/statistics" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/ranger" @@ -163,7 +162,7 @@ func getPseudoRowCountByUnsignedIntRanges(intRanges []*ranger.Range, tableRowCou return rowCount } -func getPseudoRowCountByIndexRanges(sc *stmtctx.StatementContext, indexRanges []*ranger.Range, +func getPseudoRowCountByIndexRanges(tc types.Context, indexRanges []*ranger.Range, tableRowCount float64, colsLen int) (float64, error) { if tableRowCount == 0 { return 0, nil @@ -171,7 +170,7 @@ func getPseudoRowCountByIndexRanges(sc *stmtctx.StatementContext, indexRanges [] var totalCount float64 for _, indexRange := range indexRanges { count := tableRowCount - i, err := indexRange.PrefixEqualLen(sc) + i, err := indexRange.PrefixEqualLen(tc) if err != nil { return 0, errors.Trace(err) } @@ -182,7 +181,7 @@ func getPseudoRowCountByIndexRanges(sc *stmtctx.StatementContext, indexRanges [] if i >= len(indexRange.LowVal) { i = len(indexRange.LowVal) - 1 } - rowCount, err := getPseudoRowCountByColumnRanges(sc, tableRowCount, []*ranger.Range{indexRange}, i) + rowCount, err := getPseudoRowCountByColumnRanges(tc, tableRowCount, []*ranger.Range{indexRange}, i) if err != nil { return 0, errors.Trace(err) } @@ -201,7 +200,7 @@ func getPseudoRowCountByIndexRanges(sc *stmtctx.StatementContext, indexRanges [] } // getPseudoRowCountByColumnRanges calculate the row count by the ranges if there's no statistics information for this column. 
-func getPseudoRowCountByColumnRanges(sc *stmtctx.StatementContext, tableRowCount float64, columnRanges []*ranger.Range, colIdx int) (float64, error) { +func getPseudoRowCountByColumnRanges(tc types.Context, tableRowCount float64, columnRanges []*ranger.Range, colIdx int) (float64, error) { var rowCount float64 for _, ran := range columnRanges { if ran.LowVal[colIdx].Kind() == types.KindNull && ran.HighVal[colIdx].Kind() == types.KindMaxValue { @@ -217,7 +216,7 @@ func getPseudoRowCountByColumnRanges(sc *stmtctx.StatementContext, tableRowCount } else if ran.HighVal[colIdx].Kind() == types.KindMaxValue { rowCount += tableRowCount / pseudoLessRate } else { - compare, err := ran.LowVal[colIdx].Compare(sc.TypeCtx(), &ran.HighVal[colIdx], ran.Collators[colIdx]) + compare, err := ran.LowVal[colIdx].Compare(tc, &ran.HighVal[colIdx], ran.Collators[colIdx]) if err != nil { return 0, errors.Trace(err) } diff --git a/pkg/planner/cardinality/row_count_column.go b/pkg/planner/cardinality/row_count_column.go index e5fcde0af830b..7a01e399d12fd 100644 --- a/pkg/planner/cardinality/row_count_column.go +++ b/pkg/planner/cardinality/row_count_column.go @@ -16,6 +16,7 @@ package cardinality import ( "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/planner/context" "github.com/pingcap/tidb/pkg/planner/util/debugtrace" "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/statistics" @@ -33,7 +34,7 @@ func init() { } // GetRowCountByColumnRanges estimates the row count by a slice of Range. -func GetRowCountByColumnRanges(sctx sessionctx.Context, coll *statistics.HistColl, colID int64, colRanges []*ranger.Range) (result float64, err error) { +func GetRowCountByColumnRanges(sctx context.PlanContext, coll *statistics.HistColl, colID int64, colRanges []*ranger.Range) (result float64, err error) { var name string if sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { debugtrace.EnterContextCommon(sctx) @@ -50,7 +51,7 @@ func GetRowCountByColumnRanges(sctx sessionctx.Context, coll *statistics.HistCol name = c.Info.Name.O } if !ok || c.IsInvalid(sctx, coll.Pseudo) { - result, err = getPseudoRowCountByColumnRanges(sc, float64(coll.RealtimeCount), colRanges, 0) + result, err = getPseudoRowCountByColumnRanges(sc.TypeCtx(), float64(coll.RealtimeCount), colRanges, 0) if err == nil && sc.EnableOptimizerCETrace && ok { ceTraceRange(sctx, coll.PhysicalID, []string{c.Info.Name.O}, colRanges, "Column Stats-Pseudo", uint64(result)) } @@ -71,7 +72,7 @@ func GetRowCountByColumnRanges(sctx sessionctx.Context, coll *statistics.HistCol } // GetRowCountByIntColumnRanges estimates the row count by a slice of IntColumnRange. -func GetRowCountByIntColumnRanges(sctx sessionctx.Context, coll *statistics.HistColl, colID int64, intRanges []*ranger.Range) (result float64, err error) { +func GetRowCountByIntColumnRanges(sctx context.PlanContext, coll *statistics.HistColl, colID int64, intRanges []*ranger.Range) (result float64, err error) { var name string if sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { debugtrace.EnterContextCommon(sctx) @@ -116,7 +117,7 @@ func GetRowCountByIntColumnRanges(sctx sessionctx.Context, coll *statistics.Hist } // equalRowCountOnColumn estimates the row count by a slice of Range and a Datum. 
-func equalRowCountOnColumn(sctx sessionctx.Context, c *statistics.Column, val types.Datum, encodedVal []byte, realtimeRowCount int64) (result float64, err error) { +func equalRowCountOnColumn(sctx context.PlanContext, c *statistics.Column, val types.Datum, encodedVal []byte, realtimeRowCount int64) (result float64, err error) { if sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { debugtrace.EnterContextCommon(sctx) debugtrace.RecordAnyValuesWithNames(sctx, "Value", val.String(), "Encoded", encodedVal) @@ -170,7 +171,7 @@ func equalRowCountOnColumn(sctx sessionctx.Context, c *statistics.Column, val ty } // GetColumnRowCount estimates the row count by a slice of Range. -func GetColumnRowCount(sctx sessionctx.Context, c *statistics.Column, ranges []*ranger.Range, realtimeRowCount, modifyCount int64, pkIsHandle bool) (float64, error) { +func GetColumnRowCount(sctx context.PlanContext, c *statistics.Column, ranges []*ranger.Range, realtimeRowCount, modifyCount int64, pkIsHandle bool) (float64, error) { sc := sctx.GetSessionVars().StmtCtx debugTrace := sc.EnableOptimizerDebugTrace if debugTrace { @@ -303,7 +304,7 @@ func GetColumnRowCount(sctx sessionctx.Context, c *statistics.Column, ranges []* } // betweenRowCountOnColumn estimates the row count for interval [l, r). -func betweenRowCountOnColumn(sctx sessionctx.Context, c *statistics.Column, l, r types.Datum, lowEncoded, highEncoded []byte) float64 { +func betweenRowCountOnColumn(sctx context.PlanContext, c *statistics.Column, l, r types.Datum, lowEncoded, highEncoded []byte) float64 { histBetweenCnt := c.Histogram.BetweenRowCount(sctx, l, r) if c.StatsVer <= statistics.Version1 { return histBetweenCnt diff --git a/pkg/planner/cardinality/row_count_index.go b/pkg/planner/cardinality/row_count_index.go index 4fa07a72a1bd7..738fa5580fe92 100644 --- a/pkg/planner/cardinality/row_count_index.go +++ b/pkg/planner/cardinality/row_count_index.go @@ -24,8 +24,8 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/planner/context" "github.com/pingcap/tidb/pkg/planner/util/debugtrace" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/statistics" "github.com/pingcap/tidb/pkg/types" @@ -37,7 +37,7 @@ import ( ) // GetRowCountByIndexRanges estimates the row count by a slice of Range. 
-func GetRowCountByIndexRanges(sctx sessionctx.Context, coll *statistics.HistColl, idxID int64, indexRanges []*ranger.Range) (result float64, err error) { +func GetRowCountByIndexRanges(sctx context.PlanContext, coll *statistics.HistColl, idxID int64, indexRanges []*ranger.Range) (result float64, err error) { var name string if sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { debugtrace.EnterContextCommon(sctx) @@ -64,7 +64,7 @@ func GetRowCountByIndexRanges(sctx sessionctx.Context, coll *statistics.HistColl if idx != nil && idx.Info.Unique { colsLen = len(idx.Info.Columns) } - result, err = getPseudoRowCountByIndexRanges(sc, indexRanges, float64(coll.RealtimeCount), colsLen) + result, err = getPseudoRowCountByIndexRanges(sc.TypeCtx(), indexRanges, float64(coll.RealtimeCount), colsLen) if err == nil && sc.EnableOptimizerCETrace && ok { ceTraceRange(sctx, coll.PhysicalID, colNames, indexRanges, "Index Stats-Pseudo", uint64(result)) } @@ -89,7 +89,7 @@ func GetRowCountByIndexRanges(sctx sessionctx.Context, coll *statistics.HistColl return result, errors.Trace(err) } -func getIndexRowCountForStatsV1(sctx sessionctx.Context, coll *statistics.HistColl, idxID int64, indexRanges []*ranger.Range) (float64, error) { +func getIndexRowCountForStatsV1(sctx context.PlanContext, coll *statistics.HistColl, idxID int64, indexRanges []*ranger.Range) (float64, error) { sc := sctx.GetSessionVars().StmtCtx debugTrace := sc.EnableOptimizerDebugTrace if debugTrace { @@ -214,7 +214,7 @@ func isSingleColIdxNullRange(idx *statistics.Index, ran *ranger.Range) bool { } // It uses the modifyCount to adjust the influence of modifications on the table. -func getIndexRowCountForStatsV2(sctx sessionctx.Context, idx *statistics.Index, coll *statistics.HistColl, indexRanges []*ranger.Range, realtimeRowCount, modifyCount int64) (float64, error) { +func getIndexRowCountForStatsV2(sctx context.PlanContext, idx *statistics.Index, coll *statistics.HistColl, indexRanges []*ranger.Range, realtimeRowCount, modifyCount int64) (float64, error) { sc := sctx.GetSessionVars().StmtCtx debugTrace := sc.EnableOptimizerDebugTrace if debugTrace { @@ -347,7 +347,7 @@ func getIndexRowCountForStatsV2(sctx sessionctx.Context, idx *statistics.Index, var nullKeyBytes, _ = codec.EncodeKey(time.UTC, nil, types.NewDatum(nil)) -func equalRowCountOnIndex(sctx sessionctx.Context, idx *statistics.Index, b []byte, realtimeRowCount int64) (result float64) { +func equalRowCountOnIndex(sctx context.PlanContext, idx *statistics.Index, b []byte, realtimeRowCount int64) (result float64) { if sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { debugtrace.EnterContextCommon(sctx) debugtrace.RecordAnyValuesWithNames(sctx, "Encoded Value", b) @@ -394,7 +394,7 @@ func equalRowCountOnIndex(sctx sessionctx.Context, idx *statistics.Index, b []by } // expBackoffEstimation estimate the multi-col cases following the Exponential Backoff. See comment below for details. 
-func expBackoffEstimation(sctx sessionctx.Context, idx *statistics.Index, coll *statistics.HistColl, indexRange *ranger.Range) (sel float64, success bool, err error) { +func expBackoffEstimation(sctx context.PlanContext, idx *statistics.Index, coll *statistics.HistColl, indexRange *ranger.Range) (sel float64, success bool, err error) { if sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { debugtrace.EnterContextCommon(sctx) defer func() { @@ -529,7 +529,7 @@ func matchPrefix(row chunk.Row, colIdx int, ad *types.Datum) bool { // betweenRowCountOnIndex estimates the row count for interval [l, r). // The input sctx is just for debug trace, you can pass nil safely if that's not needed. -func betweenRowCountOnIndex(sctx sessionctx.Context, idx *statistics.Index, l, r types.Datum) float64 { +func betweenRowCountOnIndex(sctx context.PlanContext, idx *statistics.Index, l, r types.Datum) float64 { histBetweenCnt := idx.Histogram.BetweenRowCount(sctx, l, r) if idx.StatsVer == statistics.Version1 { return histBetweenCnt diff --git a/pkg/planner/cardinality/row_size.go b/pkg/planner/cardinality/row_size.go index e908e73c7d955..b5b63205d3db1 100644 --- a/pkg/planner/cardinality/row_size.go +++ b/pkg/planner/cardinality/row_size.go @@ -20,7 +20,7 @@ import ( "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/planner/context" "github.com/pingcap/tidb/pkg/statistics" "github.com/pingcap/tidb/pkg/tablecodec" "github.com/pingcap/tidb/pkg/util/chunk" @@ -29,7 +29,7 @@ import ( const pseudoColSize = 8.0 // GetIndexAvgRowSize computes average row size for a index scan. -func GetIndexAvgRowSize(ctx sessionctx.Context, coll *statistics.HistColl, cols []*expression.Column, isUnique bool) (size float64) { +func GetIndexAvgRowSize(ctx context.PlanContext, coll *statistics.HistColl, cols []*expression.Column, isUnique bool) (size float64) { size = GetAvgRowSize(ctx, coll, cols, true, true) // tablePrefix(1) + tableID(8) + indexPrefix(2) + indexID(8) // Because the cols for index scan always contain the handle, so we don't add the rowID here. @@ -42,7 +42,7 @@ func GetIndexAvgRowSize(ctx sessionctx.Context, coll *statistics.HistColl, cols } // GetTableAvgRowSize computes average row size for a table scan, exclude the index key-value pairs. -func GetTableAvgRowSize(ctx sessionctx.Context, coll *statistics.HistColl, cols []*expression.Column, storeType kv.StoreType, handleInCols bool) (size float64) { +func GetTableAvgRowSize(ctx context.PlanContext, coll *statistics.HistColl, cols []*expression.Column, storeType kv.StoreType, handleInCols bool) (size float64) { size = GetAvgRowSize(ctx, coll, cols, false, true) switch storeType { case kv.TiKV: @@ -58,7 +58,7 @@ func GetTableAvgRowSize(ctx sessionctx.Context, coll *statistics.HistColl, cols } // GetAvgRowSize computes average row size for given columns. 
-func GetAvgRowSize(ctx sessionctx.Context, coll *statistics.HistColl, cols []*expression.Column, isEncodedKey bool, isForScan bool) (size float64) { +func GetAvgRowSize(ctx context.PlanContext, coll *statistics.HistColl, cols []*expression.Column, isEncodedKey bool, isForScan bool) (size float64) { sessionVars := ctx.GetSessionVars() if coll.Pseudo || len(coll.Columns) == 0 || coll.RealtimeCount == 0 { size = pseudoColSize * float64(len(cols)) diff --git a/pkg/planner/cardinality/selectivity.go b/pkg/planner/cardinality/selectivity.go index 023d56ef60025..1bb595b233daa 100644 --- a/pkg/planner/cardinality/selectivity.go +++ b/pkg/planner/cardinality/selectivity.go @@ -24,9 +24,9 @@ import ( "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/planner/context" planutil "github.com/pingcap/tidb/pkg/planner/util" "github.com/pingcap/tidb/pkg/planner/util/debugtrace" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/statistics" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" @@ -49,7 +49,7 @@ var ( // should be held when you call this. // Currently the time complexity is o(n^2). func Selectivity( - ctx sessionctx.Context, + ctx context.PlanContext, coll *statistics.HistColl, exprs []expression.Expression, filledPaths []*planutil.AccessPath, @@ -660,7 +660,7 @@ func isColEqCorCol(filter expression.Expression) *expression.Column { // findPrefixOfIndexByCol will find columns in index by checking the unique id or the virtual expression. // So it will return at once no matching column is found. -func findPrefixOfIndexByCol(ctx sessionctx.Context, cols []*expression.Column, idxColIDs []int64, +func findPrefixOfIndexByCol(ctx context.PlanContext, cols []*expression.Column, idxColIDs []int64, cachedPath *planutil.AccessPath) []*expression.Column { if cachedPath != nil { idxCols := cachedPath.IdxCols @@ -681,7 +681,7 @@ func findPrefixOfIndexByCol(ctx sessionctx.Context, cols []*expression.Column, i return expression.FindPrefixOfIndex(cols, idxColIDs) } -func getMaskAndRanges(ctx sessionctx.Context, exprs []expression.Expression, rangeType ranger.RangeType, +func getMaskAndRanges(ctx context.PlanContext, exprs []expression.Expression, rangeType ranger.RangeType, lengths []int, cachedPath *planutil.AccessPath, cols ...*expression.Column) ( mask int64, ranges []*ranger.Range, partCover bool, err error) { isDNF := false @@ -725,7 +725,7 @@ func getMaskAndRanges(ctx sessionctx.Context, exprs []expression.Expression, ran } func getMaskAndSelectivityForMVIndex( - ctx sessionctx.Context, + ctx context.PlanContext, coll *statistics.HistColl, id int64, exprs []expression.Expression, @@ -756,7 +756,7 @@ func getMaskAndSelectivityForMVIndex( // GetSelectivityByFilter try to estimate selectivity of expressions by evaluate the expressions using TopN, Histogram buckets boundaries and NULL. // Currently, this method can only handle expressions involving a single column. -func GetSelectivityByFilter(sctx sessionctx.Context, coll *statistics.HistColl, filters []expression.Expression) (ok bool, selectivity float64, err error) { +func GetSelectivityByFilter(sctx context.PlanContext, coll *statistics.HistColl, filters []expression.Expression) (ok bool, selectivity float64, err error) { // 1. 
Make sure the expressions // (1) are safe to be evaluated here, // (2) involve only one column, @@ -904,7 +904,7 @@ func GetSelectivityByFilter(sctx sessionctx.Context, coll *statistics.HistColl, return true, res, err } -func findAvailableStatsForCol(sctx sessionctx.Context, coll *statistics.HistColl, uniqueID int64) (isIndex bool, idx int64) { +func findAvailableStatsForCol(sctx context.PlanContext, coll *statistics.HistColl, uniqueID int64) (isIndex bool, idx int64) { // try to find available stats in column stats if colStats, ok := coll.Columns[uniqueID]; ok && colStats != nil && !colStats.IsInvalid(sctx, coll.Pseudo) && colStats.IsFullLoad() { return false, uniqueID @@ -925,7 +925,7 @@ func findAvailableStatsForCol(sctx sessionctx.Context, coll *statistics.HistColl } // getEqualCondSelectivity gets the selectivity of the equal conditions. -func getEqualCondSelectivity(sctx sessionctx.Context, coll *statistics.HistColl, idx *statistics.Index, bytes []byte, +func getEqualCondSelectivity(sctx context.PlanContext, coll *statistics.HistColl, idx *statistics.Index, bytes []byte, usedColsLen int, idxPointRange *ranger.Range) (result float64, err error) { if sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { debugtrace.EnterContextCommon(sctx) @@ -990,7 +990,7 @@ func getEqualCondSelectivity(sctx sessionctx.Context, coll *statistics.HistColl, // and has the same distribution with analyzed rows, which means each unique value should have the // same number of rows(Tot/NDV) of it. // The input sctx is just for debug trace, you can pass nil safely if that's not needed. -func outOfRangeEQSelectivity(sctx sessionctx.Context, ndv, realtimeRowCount, columnRowCount int64) (result float64) { +func outOfRangeEQSelectivity(sctx context.PlanContext, ndv, realtimeRowCount, columnRowCount int64) (result float64) { if sctx != nil && sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { debugtrace.EnterContextCommon(sctx) defer func() { @@ -1014,7 +1014,7 @@ func outOfRangeEQSelectivity(sctx sessionctx.Context, ndv, realtimeRowCount, col // crossValidationSelectivity gets the selectivity of multi-column equal conditions by cross validation. func crossValidationSelectivity( - sctx sessionctx.Context, + sctx context.PlanContext, coll *statistics.HistColl, idx *statistics.Index, usedColsLen int, @@ -1082,7 +1082,7 @@ func crossValidationSelectivity( // defined in planner/core package and hard to move here. So we use this trick to avoid the import cycle. 
var ( CollectFilters4MVIndex func( - sctx sessionctx.Context, + sctx context.PlanContext, filters []expression.Expression, idxCols []*expression.Column, ) ( @@ -1090,7 +1090,7 @@ var ( remainingFilters []expression.Expression, ) BuildPartialPaths4MVIndex func( - sctx sessionctx.Context, + sctx context.PlanContext, accessFilters []expression.Expression, idxCols []*expression.Column, mvIndex *model.IndexInfo, diff --git a/pkg/planner/cardinality/trace.go b/pkg/planner/cardinality/trace.go index b3065b89038ef..a067adbf1e34b 100644 --- a/pkg/planner/cardinality/trace.go +++ b/pkg/planner/cardinality/trace.go @@ -23,8 +23,8 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/format" "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/planner/context" "github.com/pingcap/tidb/pkg/planner/util/debugtrace" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/statistics" driver "github.com/pingcap/tidb/pkg/types/parser_driver" @@ -36,7 +36,7 @@ import ( ) // ceTraceExpr appends an expression and related information into CE trace -func ceTraceExpr(sctx sessionctx.Context, tableID int64, tp string, expr expression.Expression, rowCount float64) { +func ceTraceExpr(sctx context.PlanContext, tableID int64, tp string, expr expression.Expression, rowCount float64) { exprStr, err := exprToString(sctx, expr) if err != nil { logutil.BgLogger().Debug("Failed to trace CE of an expression", zap.String("category", "OptimizerTrace"), @@ -64,7 +64,7 @@ func ceTraceExpr(sctx sessionctx.Context, tableID int64, tp string, expr express // It may be more appropriate to put this in expression package. But currently we only use it for CE trace, // // and it may not be general enough to handle all possible expressions. So we put it here for now. -func exprToString(ctx sessionctx.Context, e expression.Expression) (string, error) { +func exprToString(ctx context.PlanContext, e expression.Expression) (string, error) { switch expr := e.(type) { case *expression.ScalarFunction: var buffer bytes.Buffer @@ -125,7 +125,7 @@ type getRowCountInput struct { } func debugTraceGetRowCountInput( - s sessionctx.Context, + s context.PlanContext, id int64, ranges ranger.Ranges, ) { @@ -141,10 +141,10 @@ func debugTraceGetRowCountInput( } // GetTblInfoForUsedStatsByPhysicalID get table name, partition name and TableInfo that will be used to record used stats. -var GetTblInfoForUsedStatsByPhysicalID func(sctx sessionctx.Context, id int64) (fullName string, tblInfo *model.TableInfo) +var GetTblInfoForUsedStatsByPhysicalID func(sctx context.PlanContext, id int64) (fullName string, tblInfo *model.TableInfo) // recordUsedItemStatsStatus only records un-FullLoad item load status during user query -func recordUsedItemStatsStatus(sctx sessionctx.Context, stats any, tableID, id int64) { +func recordUsedItemStatsStatus(sctx context.PlanContext, stats any, tableID, id int64) { // Sometimes we try to use stats on _tidb_rowid (id == -1), which must be empty, we ignore this case here. 
if id <= 0 { return @@ -205,7 +205,7 @@ func recordUsedItemStatsStatus(sctx sessionctx.Context, stats any, tableID, id i } // ceTraceRange appends a list of ranges and related information into CE trace -func ceTraceRange(sctx sessionctx.Context, tableID int64, colNames []string, ranges []*ranger.Range, tp string, rowCount uint64) { +func ceTraceRange(sctx context.PlanContext, tableID int64, colNames []string, ranges []*ranger.Range, tp string, rowCount uint64) { sc := sctx.GetSessionVars().StmtCtx tc := sc.TypeCtx() allPoint := true @@ -249,7 +249,7 @@ type startEstimateRangeInfo struct { } func debugTraceStartEstimateRange( - s sessionctx.Context, + s context.PlanContext, r *ranger.Range, lowBytes, highBytes []byte, currentCount float64, @@ -294,7 +294,7 @@ type endEstimateRangeInfo struct { } func debugTraceEndEstimateRange( - s sessionctx.Context, + s context.PlanContext, count float64, addType debugTraceAddRowCountType, ) { diff --git a/pkg/planner/cascades/BUILD.bazel b/pkg/planner/cascades/BUILD.bazel index dc475db286bc2..8c75d382124c7 100644 --- a/pkg/planner/cascades/BUILD.bazel +++ b/pkg/planner/cascades/BUILD.bazel @@ -17,6 +17,7 @@ go_library( "//pkg/kv", "//pkg/parser/ast", "//pkg/parser/mysql", + "//pkg/planner/context", "//pkg/planner/core", "//pkg/planner/implementation", "//pkg/planner/memo", diff --git a/pkg/planner/cascades/transformation_rules.go b/pkg/planner/cascades/transformation_rules.go index 1fbe083551036..9977527f3f675 100644 --- a/pkg/planner/cascades/transformation_rules.go +++ b/pkg/planner/cascades/transformation_rules.go @@ -22,10 +22,10 @@ import ( "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner/context" plannercore "github.com/pingcap/tidb/pkg/planner/core" "github.com/pingcap/tidb/pkg/planner/memo" "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/ranger" "github.com/pingcap/tidb/pkg/util/set" @@ -848,7 +848,7 @@ type pushDownJoin struct { } func (*pushDownJoin) predicatePushDown( - sctx sessionctx.Context, + sctx context.PlanContext, predicates []expression.Expression, join *plannercore.LogicalJoin, leftSchema *expression.Schema, @@ -970,7 +970,7 @@ func (r *PushSelDownJoin) Match(expr *memo.ExprIter) bool { // buildChildSelectionGroup builds a new childGroup if the pushed down condition is not empty. func buildChildSelectionGroup( - sctx sessionctx.Context, + sctx context.PlanContext, qbOffset int, conditions []expression.Expression, childGroup *memo.Group) *memo.Group { @@ -2137,7 +2137,7 @@ func (*TransformAggregateCaseToSelection) allowsSelection(aggFuncName string) bo return aggFuncName != ast.AggFuncFirstRow } -func (*TransformAggregateCaseToSelection) isOnlyOneNotNull(ctx sessionctx.Context, args []expression.Expression, argsNum int, outputIdx int) bool { +func (*TransformAggregateCaseToSelection) isOnlyOneNotNull(ctx expression.EvalContext, args []expression.Expression, argsNum int, outputIdx int) bool { return !args[outputIdx].Equal(ctx, expression.NewNull()) && (argsNum == 2 || args[3-outputIdx].Equal(ctx, expression.NewNull())) } @@ -2541,7 +2541,7 @@ func (*MergeAdjacentWindow) Match(expr *memo.ExprIter) bool { ctx := expr.GetExpr().ExprNode.SCtx() // Whether Partition, OrderBy and Frame parts are the same. 
- if !(curWinPlan.EqualPartitionBy(ctx, nextWinPlan) && + if !(curWinPlan.EqualPartitionBy(nextWinPlan) && curWinPlan.EqualOrderBy(ctx, nextWinPlan) && curWinPlan.EqualFrame(ctx, nextWinPlan)) { return false diff --git a/pkg/planner/context/BUILD.bazel b/pkg/planner/context/BUILD.bazel new file mode 100644 index 0000000000000..097786e3b6bf9 --- /dev/null +++ b/pkg/planner/context/BUILD.bazel @@ -0,0 +1,15 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "context", + srcs = ["context.go"], + importpath = "github.com/pingcap/tidb/pkg/planner/context", + visibility = ["//visibility:public"], + deps = [ + "//pkg/expression", + "//pkg/kv", + "//pkg/parser/model", + "//pkg/sessionctx/variable", + "//pkg/util/context", + ], +) diff --git a/pkg/planner/context/context.go b/pkg/planner/context/context.go new file mode 100644 index 0000000000000..0298fb554e7fa --- /dev/null +++ b/pkg/planner/context/context.go @@ -0,0 +1,37 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + contextutil "github.com/pingcap/tidb/pkg/util/context" +) + +// PlanContext is the context for building plan. +type PlanContext interface { + expression.BuildContext + contextutil.ValueStoreContext + // GetSessionVars gets the session variables. + GetSessionVars() *variable.SessionVars + // UpdateColStatsUsage updates the column stats usage. + UpdateColStatsUsage(predicateColumns []model.TableItemID) + // GetClient gets a kv.Client. + GetClient() kv.Client + // GetMPPClient gets a kv.MPPClient. + GetMPPClient() kv.MPPClient +} diff --git a/pkg/planner/core/BUILD.bazel b/pkg/planner/core/BUILD.bazel index b8e70921ad80b..7488c4b1c2eb3 100644 --- a/pkg/planner/core/BUILD.bazel +++ b/pkg/planner/core/BUILD.bazel @@ -108,6 +108,7 @@ go_library( "//pkg/parser/terror", "//pkg/parser/types", "//pkg/planner/cardinality", + "//pkg/planner/context", "//pkg/planner/core/internal", "//pkg/planner/core/internal/base", "//pkg/planner/core/metrics", diff --git a/pkg/planner/core/access_object.go b/pkg/planner/core/access_object.go index 2c7ae016aa1f9..195495099e3d6 100644 --- a/pkg/planner/core/access_object.go +++ b/pkg/planner/core/access_object.go @@ -21,7 +21,6 @@ import ( "github.com/pingcap/tidb/pkg/infoschema" "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/table" "github.com/pingcap/tipb/go-tipb" ) @@ -39,7 +38,7 @@ type dataAccesser interface { } type partitionAccesser interface { - accessObject(sessionctx.Context) AccessObject + accessObject(PlanContext) AccessObject } // AccessObject represents what is accessed by an operator. 
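For readers of this patch, the new pkg/planner/context.PlanContext introduced above is the contract that planner-side helpers are being migrated onto. A minimal sketch of how a helper can consume it instead of the full sessionctx.Context follows; the package, function name, warning text, and divisor are hypothetical and not part of this change, and any session context that already provides these methods keeps satisfying the interface.

package example

import (
	"github.com/pingcap/errors"
	planctx "github.com/pingcap/tidb/pkg/planner/context"
)

// pseudoRowCountWithWarning is a hypothetical helper: it only needs the
// session variables and the statement context, so the narrower PlanContext
// is enough and the helper no longer depends on sessionctx.
func pseudoRowCountWithWarning(sctx planctx.PlanContext, realtimeCount int64) float64 {
	vars := sctx.GetSessionVars()
	if realtimeCount == 0 {
		// StmtCtx hangs off SessionVars, matching how the call sites in this diff use it.
		vars.StmtCtx.AppendWarning(errors.New("no realtime row count, falling back to pseudo estimation"))
		return 0
	}
	// The divisor is a made-up placeholder rate for the sketch.
	return float64(realtimeCount) / 3
}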
@@ -353,7 +352,7 @@ func (p *BatchPointGetPlan) AccessObject() AccessObject { return res } -func getDynamicAccessPartition(sctx sessionctx.Context, tblInfo *model.TableInfo, physPlanPartInfo *PhysPlanPartInfo, asName string) (res *DynamicPartitionAccessObject) { +func getDynamicAccessPartition(sctx PlanContext, tblInfo *model.TableInfo, physPlanPartInfo *PhysPlanPartInfo, asName string) (res *DynamicPartitionAccessObject) { pi := tblInfo.GetPartitionInfo() if pi == nil || !sctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { return nil @@ -394,7 +393,7 @@ func getDynamicAccessPartition(sctx sessionctx.Context, tblInfo *model.TableInfo return res } -func (p *PhysicalTableReader) accessObject(sctx sessionctx.Context) AccessObject { +func (p *PhysicalTableReader) accessObject(sctx PlanContext) AccessObject { if !sctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { return DynamicPartitionAccessObjects(nil) } @@ -448,7 +447,7 @@ func (p *PhysicalTableReader) accessObject(sctx sessionctx.Context) AccessObject return res } -func (p *PhysicalIndexReader) accessObject(sctx sessionctx.Context) AccessObject { +func (p *PhysicalIndexReader) accessObject(sctx PlanContext) AccessObject { if !sctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { return DynamicPartitionAccessObjects(nil) } @@ -464,7 +463,7 @@ func (p *PhysicalIndexReader) accessObject(sctx sessionctx.Context) AccessObject return DynamicPartitionAccessObjects{res} } -func (p *PhysicalIndexLookUpReader) accessObject(sctx sessionctx.Context) AccessObject { +func (p *PhysicalIndexLookUpReader) accessObject(sctx PlanContext) AccessObject { if !sctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { return DynamicPartitionAccessObjects(nil) } @@ -480,7 +479,7 @@ func (p *PhysicalIndexLookUpReader) accessObject(sctx sessionctx.Context) Access return DynamicPartitionAccessObjects{res} } -func (p *PhysicalIndexMergeReader) accessObject(sctx sessionctx.Context) AccessObject { +func (p *PhysicalIndexMergeReader) accessObject(sctx PlanContext) AccessObject { if !sctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { return DynamicPartitionAccessObjects(nil) } diff --git a/pkg/planner/core/common_plans.go b/pkg/planner/core/common_plans.go index 4bf6682dc831e..3b9af0eea6650 100644 --- a/pkg/planner/core/common_plans.go +++ b/pkg/planner/core/common_plans.go @@ -1027,7 +1027,7 @@ func (e *Explain) explainFlatOpInRowFormat(flatOp *FlatOperator) { e.prepareOperatorInfo(flatOp.Origin, taskTp, textTreeExplainID) } -func getRuntimeInfoStr(ctx sessionctx.Context, p Plan, runtimeStatsColl *execdetails.RuntimeStatsColl) (actRows, analyzeInfo, memoryInfo, diskInfo string) { +func getRuntimeInfoStr(ctx PlanContext, p Plan, runtimeStatsColl *execdetails.RuntimeStatsColl) (actRows, analyzeInfo, memoryInfo, diskInfo string) { if runtimeStatsColl == nil { runtimeStatsColl = ctx.GetSessionVars().StmtCtx.RuntimeStatsColl if runtimeStatsColl == nil { @@ -1058,7 +1058,7 @@ func getRuntimeInfoStr(ctx sessionctx.Context, p Plan, runtimeStatsColl *execdet return } -func getRuntimeInfo(ctx sessionctx.Context, p Plan, runtimeStatsColl *execdetails.RuntimeStatsColl) ( +func getRuntimeInfo(ctx PlanContext, p Plan, runtimeStatsColl *execdetails.RuntimeStatsColl) ( rootStats *execdetails.RootRuntimeStats, copStats *execdetails.CopRuntimeStats, memTracker *memory.Tracker, @@ -1180,7 +1180,7 @@ func (e *Explain) getOperatorInfo(p Plan, id string) (estRows, estCost, costForm } // BinaryPlanStrFromFlatPlan generates the compressed and encoded binary plan from 
a FlatPhysicalPlan. -func BinaryPlanStrFromFlatPlan(explainCtx sessionctx.Context, flat *FlatPhysicalPlan) string { +func BinaryPlanStrFromFlatPlan(explainCtx PlanContext, flat *FlatPhysicalPlan) string { binary := binaryDataFromFlatPlan(explainCtx, flat) if binary == nil { return "" @@ -1193,7 +1193,7 @@ func BinaryPlanStrFromFlatPlan(explainCtx sessionctx.Context, flat *FlatPhysical return str } -func binaryDataFromFlatPlan(explainCtx sessionctx.Context, flat *FlatPhysicalPlan) *tipb.ExplainData { +func binaryDataFromFlatPlan(explainCtx PlanContext, flat *FlatPhysicalPlan) *tipb.ExplainData { if len(flat.Main) == 0 { return nil } @@ -1218,7 +1218,7 @@ func binaryDataFromFlatPlan(explainCtx sessionctx.Context, flat *FlatPhysicalPla return res } -func binaryOpTreeFromFlatOps(explainCtx sessionctx.Context, ops FlatPlanTree) *tipb.ExplainOperator { +func binaryOpTreeFromFlatOps(explainCtx PlanContext, ops FlatPlanTree) *tipb.ExplainOperator { s := make([]tipb.ExplainOperator, len(ops)) for i, op := range ops { binaryOpFromFlatOp(explainCtx, op, &s[i]) @@ -1229,7 +1229,7 @@ func binaryOpTreeFromFlatOps(explainCtx sessionctx.Context, ops FlatPlanTree) *t return &s[0] } -func binaryOpFromFlatOp(explainCtx sessionctx.Context, op *FlatOperator, out *tipb.ExplainOperator) { +func binaryOpFromFlatOp(explainCtx PlanContext, op *FlatOperator, out *tipb.ExplainOperator) { out.Name = op.Origin.ExplainID().String() switch op.Label { case BuildSide: diff --git a/pkg/planner/core/debugtrace.go b/pkg/planner/core/debugtrace.go index b1f3cf5d4d7f5..7262f38cac1fc 100644 --- a/pkg/planner/core/debugtrace.go +++ b/pkg/planner/core/debugtrace.go @@ -172,7 +172,7 @@ type getStatsTblInfo struct { } func debugTraceGetStatsTbl( - s sessionctx.Context, + s PlanContext, tblInfo *model.TableInfo, pid int64, handleIsNil, @@ -249,7 +249,7 @@ func convertAccessPathForDebugTrace(path *util.AccessPath, out *accessPathForDeb } } -func debugTraceAccessPaths(s sessionctx.Context, paths []*util.AccessPath) { +func debugTraceAccessPaths(s PlanContext, paths []*util.AccessPath) { root := debugtrace.GetOrInitDebugTraceRoot(s) traceInfo := make([]accessPathForDebugTrace, len(paths)) for i, partialPath := range paths { diff --git a/pkg/planner/core/exhaust_physical_plans.go b/pkg/planner/core/exhaust_physical_plans.go index e659e28986291..0980eaa3bdba4 100644 --- a/pkg/planner/core/exhaust_physical_plans.go +++ b/pkg/planner/core/exhaust_physical_plans.go @@ -1976,7 +1976,7 @@ func (ijHelper *indexJoinBuildHelper) buildTemplateRange(matchedKeyCnt int, eqAn return } -func filterIndexJoinBySessionVars(sc sessionctx.Context, indexJoins []PhysicalPlan) []PhysicalPlan { +func filterIndexJoinBySessionVars(sc PlanContext, indexJoins []PhysicalPlan) []PhysicalPlan { if sc.GetSessionVars().EnableIndexMergeJoin { return indexJoins } diff --git a/pkg/planner/core/exhaust_physical_plans_test.go b/pkg/planner/core/exhaust_physical_plans_test.go index 3a1a78b0038ab..967689c573303 100644 --- a/pkg/planner/core/exhaust_physical_plans_test.go +++ b/pkg/planner/core/exhaust_physical_plans_test.go @@ -26,14 +26,13 @@ import ( "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/planner/property" "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/statistics" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/ranger" "github.com/stretchr/testify/require" ) -func rewriteSimpleExpr(ctx sessionctx.Context, str string, schema *expression.Schema, names 
types.NameSlice) ([]expression.Expression, error) { +func rewriteSimpleExpr(ctx expression.BuildContext, str string, schema *expression.Schema, names types.NameSlice) ([]expression.Expression, error) { if str == "" { return nil, nil } @@ -340,7 +339,7 @@ func TestIndexJoinAnalyzeLookUpFilters(t *testing.T) { } } -func checkRangeFallbackAndReset(t *testing.T, ctx sessionctx.Context, expectedRangeFallback bool) { +func checkRangeFallbackAndReset(t *testing.T, ctx PlanContext, expectedRangeFallback bool) { require.Equal(t, expectedRangeFallback, ctx.GetSessionVars().StmtCtx.RangeFallback) ctx.GetSessionVars().StmtCtx.RangeFallback = false } diff --git a/pkg/planner/core/explain.go b/pkg/planner/core/explain.go index 53f9a482063da..5ea0e7ed8ee8f 100644 --- a/pkg/planner/core/explain.go +++ b/pkg/planner/core/explain.go @@ -27,7 +27,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/planner/property" "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/statistics" "github.com/pingcap/tidb/pkg/types" @@ -1002,7 +1001,7 @@ func (p *LogicalUnionScan) ExplainInfo() string { return buffer.String() } -func explainByItems(ctx sessionctx.Context, buffer *bytes.Buffer, byItems []*util.ByItems) *bytes.Buffer { +func explainByItems(ctx expression.EvalContext, buffer *bytes.Buffer, byItems []*util.ByItems) *bytes.Buffer { for i, item := range byItems { if item.Desc { fmt.Fprintf(buffer, "%s:desc", item.Expr.ExplainInfo(ctx)) diff --git a/pkg/planner/core/expression_rewriter.go b/pkg/planner/core/expression_rewriter.go index 4f684263b420d..ef6fdfb01ec70 100644 --- a/pkg/planner/core/expression_rewriter.go +++ b/pkg/planner/core/expression_rewriter.go @@ -49,7 +49,7 @@ import ( ) // EvalSubqueryFirstRow evaluates incorrelated subqueries once, and get first row. -var EvalSubqueryFirstRow func(ctx context.Context, p PhysicalPlan, is infoschema.InfoSchema, sctx sessionctx.Context) (row []types.Datum, err error) +var EvalSubqueryFirstRow func(ctx context.Context, p PhysicalPlan, is infoschema.InfoSchema, sctx PlanContext) (row []types.Datum, err error) // evalAstExprWithPlanCtx evaluates ast expression with plan context. // Different with expression.EvalSimpleAst, it uses planner context and is more powerful to build some special expressions diff --git a/pkg/planner/core/find_best_task.go b/pkg/planner/core/find_best_task.go index fd08db711d07f..16b6aafc50476 100644 --- a/pkg/planner/core/find_best_task.go +++ b/pkg/planner/core/find_best_task.go @@ -32,7 +32,6 @@ import ( "github.com/pingcap/tidb/pkg/planner/property" "github.com/pingcap/tidb/pkg/planner/util" "github.com/pingcap/tidb/pkg/planner/util/fixcontrol" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/statistics" "github.com/pingcap/tidb/pkg/types" tidbutil "github.com/pingcap/tidb/pkg/util" @@ -282,7 +281,7 @@ func (p *baseLogicalPlan) enumeratePhysicalPlans4Task( } opt.appendCandidate(p, curTask.plan(), prop) // Get the most efficient one. - if curIsBetter, err := compareTaskCost(p.SCtx(), curTask, bestTask, opt); err != nil { + if curIsBetter, err := compareTaskCost(curTask, bestTask, opt); err != nil { return nil, 0, err } else if curIsBetter { bestTask = curTask @@ -383,7 +382,7 @@ func (p *LogicalSequence) iterateChildPlan( } // compareTaskCost compares cost of curTask and bestTask and returns whether curTask's cost is smaller than bestTask's. 
-func compareTaskCost(_ sessionctx.Context, curTask, bestTask task, op *physicalOptimizeOp) (curIsBetter bool, err error) { +func compareTaskCost(curTask, bestTask task, op *physicalOptimizeOp) (curIsBetter bool, err error) { curCost, curInvalid, err := getTaskPlanCost(curTask, op) if err != nil { return false, err @@ -635,7 +634,7 @@ func (p *baseLogicalPlan) findBestTask(prop *property.PhysicalProperty, planCoun goto END } opt.appendCandidate(p, curTask.plan(), prop) - if curIsBetter, err := compareTaskCost(p.SCtx(), curTask, bestTask, opt); err != nil { + if curIsBetter, err := compareTaskCost(curTask, bestTask, opt); err != nil { return nil, 0, err } else if curIsBetter { bestTask = curTask @@ -747,7 +746,7 @@ func compareIndexBack(lhs, rhs *candidatePath) (int, bool) { // compareCandidates is the core of skyline pruning, which is used to decide which candidate path is better. // The return value is 1 if lhs is better, -1 if rhs is better, 0 if they are equivalent or not comparable. -func compareCandidates(sctx sessionctx.Context, prop *property.PhysicalProperty, lhs, rhs *candidatePath) int { +func compareCandidates(sctx PlanContext, prop *property.PhysicalProperty, lhs, rhs *candidatePath) int { // This rule is empirical but not always correct. // If x's range row count is significantly lower than y's, for example, 1000 times, we think x is better. if lhs.path.CountAfterAccess > 100 && rhs.path.CountAfterAccess > 100 && // to prevent some extreme cases, e.g. 0.01 : 10 @@ -1096,7 +1095,7 @@ func (ds *DataSource) findBestTask(prop *property.PhysicalProperty, planCounter } if unenforcedTask != nil && !unenforcedTask.invalid() { - curIsBest, cerr := compareTaskCost(ds.SCtx(), unenforcedTask, t, opt) + curIsBest, cerr := compareTaskCost(unenforcedTask, t, opt) if cerr != nil { err = cerr return @@ -1147,7 +1146,7 @@ func (ds *DataSource) findBestTask(prop *property.PhysicalProperty, planCounter } appendCandidate(ds, idxMergeTask, prop, opt) - curIsBetter, err := compareTaskCost(ds.SCtx(), idxMergeTask, t, opt) + curIsBetter, err := compareTaskCost(idxMergeTask, t, opt) if err != nil { return nil, 0, err } @@ -1250,7 +1249,7 @@ func (ds *DataSource) findBestTask(prop *property.PhysicalProperty, planCounter cntPlan++ planCounter.Dec(1) } - curIsBetter, cerr := compareTaskCost(ds.SCtx(), pointGetTask, t, opt) + curIsBetter, cerr := compareTaskCost(pointGetTask, t, opt) if cerr != nil { return nil, 0, cerr } @@ -1284,7 +1283,7 @@ func (ds *DataSource) findBestTask(prop *property.PhysicalProperty, planCounter planCounter.Dec(1) } appendCandidate(ds, tblTask, prop, opt) - curIsBetter, err := compareTaskCost(ds.SCtx(), tblTask, t, opt) + curIsBetter, err := compareTaskCost(tblTask, t, opt) if err != nil { return nil, 0, err } @@ -1309,7 +1308,7 @@ func (ds *DataSource) findBestTask(prop *property.PhysicalProperty, planCounter planCounter.Dec(1) } appendCandidate(ds, idxTask, prop, opt) - curIsBetter, err := compareTaskCost(ds.SCtx(), idxTask, t, opt) + curIsBetter, err := compareTaskCost(idxTask, t, opt) if err != nil { return nil, 0, err } @@ -1621,7 +1620,7 @@ func (ds *DataSource) buildIndexMergeTableScan(tableFilters []expression.Express // // But the new Selection should exclude the exprs that can NOT be pushed to ALL the storage engines. // Because these exprs have already been put in another Selection(check rule_predicate_push_down). 
-func extractFiltersForIndexMerge(ctx sessionctx.Context, client kv.Client, filters []expression.Expression) (pushed []expression.Expression, remaining []expression.Expression) { +func extractFiltersForIndexMerge(ctx expression.BuildContext, client kv.Client, filters []expression.Expression) (pushed []expression.Expression, remaining []expression.Expression) { for _, expr := range filters { if expression.CanExprsPushDown(ctx, []expression.Expression{expr}, client, kv.TiKV) { pushed = append(pushed, expr) @@ -1634,7 +1633,7 @@ func extractFiltersForIndexMerge(ctx sessionctx.Context, client kv.Client, filte return } -func isIndexColsCoveringCol(sctx sessionctx.Context, col *expression.Column, indexCols []*expression.Column, idxColLens []int, ignoreLen bool) bool { +func isIndexColsCoveringCol(sctx expression.EvalContext, col *expression.Column, indexCols []*expression.Column, idxColLens []int, ignoreLen bool) bool { for i, indexCol := range indexCols { if indexCol == nil || !col.EqualByExprAndID(sctx, indexCol) { continue @@ -2011,7 +2010,7 @@ func SplitSelCondsWithVirtualColumn(conds []expression.Expression) (withoutVirt return withoutVirt, withVirt } -func matchIndicesProp(sctx sessionctx.Context, idxCols []*expression.Column, colLens []int, propItems []property.SortItem) bool { +func matchIndicesProp(sctx PlanContext, idxCols []*expression.Column, colLens []int, propItems []property.SortItem) bool { if len(idxCols) < len(propItems) { return false } @@ -2641,7 +2640,7 @@ func appendCandidate(lp LogicalPlan, task task, prop *property.PhysicalProperty, // PushDownNot here can convert condition 'not (a != 1)' to 'a = 1'. When we build range from conds, the condition like // 'not (a != 1)' would not be handled so we need to convert it to 'a = 1', which can be handled when building range. 
-func pushDownNot(ctx sessionctx.Context, conds []expression.Expression) []expression.Expression { +func pushDownNot(ctx expression.BuildContext, conds []expression.Expression) []expression.Expression { for i, cond := range conds { conds[i] = expression.PushDownNot(ctx, cond) } diff --git a/pkg/planner/core/find_best_task_test.go b/pkg/planner/core/find_best_task_test.go index 713d61b455931..deae441c29068 100644 --- a/pkg/planner/core/find_best_task_test.go +++ b/pkg/planner/core/find_best_task_test.go @@ -21,7 +21,6 @@ import ( "github.com/pingcap/tidb/pkg/domain" "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/planner/property" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/stretchr/testify/require" ) @@ -29,7 +28,7 @@ type mockDataSource struct { baseLogicalPlan } -func (ds mockDataSource) Init(ctx sessionctx.Context) *mockDataSource { +func (ds mockDataSource) Init(ctx PlanContext) *mockDataSource { ds.baseLogicalPlan = newBaseLogicalPlan(ctx, "mockDS", &ds, 0) return &ds } @@ -66,7 +65,7 @@ type mockLogicalPlan4Test struct { costOverflow bool } -func (p mockLogicalPlan4Test) Init(ctx sessionctx.Context) *mockLogicalPlan4Test { +func (p mockLogicalPlan4Test) Init(ctx PlanContext) *mockLogicalPlan4Test { p.baseLogicalPlan = newBaseLogicalPlan(ctx, "mockPlan", &p, 0) return &p } @@ -118,7 +117,7 @@ type mockPhysicalPlan4Test struct { planType int } -func (p mockPhysicalPlan4Test) Init(ctx sessionctx.Context) *mockPhysicalPlan4Test { +func (p mockPhysicalPlan4Test) Init(ctx PlanContext) *mockPhysicalPlan4Test { p.basePhysicalPlan = newBasePhysicalPlan(ctx, "mockPlan", &p, 0) return &p } diff --git a/pkg/planner/core/hint_utils.go b/pkg/planner/core/hint_utils.go index 1b0c6526eaab4..1399a9420c886 100644 --- a/pkg/planner/core/hint_utils.go +++ b/pkg/planner/core/hint_utils.go @@ -18,7 +18,6 @@ import ( "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/sessionctx" h "github.com/pingcap/tidb/pkg/util/hint" ) @@ -97,7 +96,7 @@ func extractTableAsName(p PhysicalPlan) (*model.CIStr, *model.CIStr) { return nil, nil } -func getJoinHints(sctx sessionctx.Context, joinType string, parentOffset int, nodeType h.NodeType, children ...PhysicalPlan) (res []*ast.TableOptimizerHint) { +func getJoinHints(sctx PlanContext, joinType string, parentOffset int, nodeType h.NodeType, children ...PhysicalPlan) (res []*ast.TableOptimizerHint) { if parentOffset == -1 { return res } diff --git a/pkg/planner/core/indexmerge_path.go b/pkg/planner/core/indexmerge_path.go index e09838d263014..b3b2184a72fe3 100644 --- a/pkg/planner/core/indexmerge_path.go +++ b/pkg/planner/core/indexmerge_path.go @@ -28,9 +28,9 @@ import ( "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/planner/cardinality" + "github.com/pingcap/tidb/pkg/planner/context" "github.com/pingcap/tidb/pkg/planner/util" "github.com/pingcap/tidb/pkg/planner/util/debugtrace" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/statistics" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" @@ -355,7 +355,7 @@ func (ds *DataSource) accessPathsForConds(conditions []expression.Expression, us continue } // If we have point or empty range, just remove other possible paths. 
- if len(path.Ranges) == 0 || path.OnlyPointRange(ds.SCtx()) { + if len(path.Ranges) == 0 || path.OnlyPointRange(ds.SCtx().GetSessionVars().StmtCtx.TypeCtx()) { if len(results) == 0 { results = append(results, path) } else { @@ -380,7 +380,7 @@ func (ds *DataSource) accessPathsForConds(conditions []expression.Expression, us continue } // If we have empty range, or point range on unique index, just remove other possible paths. - if len(path.Ranges) == 0 || (path.OnlyPointRange(ds.SCtx()) && path.Index.Unique) { + if len(path.Ranges) == 0 || (path.OnlyPointRange(ds.SCtx().GetSessionVars().StmtCtx.TypeCtx()) && path.Index.Unique) { if len(results) == 0 { results = append(results, path) } else { @@ -1235,7 +1235,7 @@ func (*DataSource) buildPartialPathUp4MVIndex( // The accessFilters must be corresponding to these idxCols. // OK indicates whether it builds successfully. These partial paths should be ignored if ok==false. func buildPartialPaths4MVIndex( - sctx sessionctx.Context, + sctx context.PlanContext, accessFilters []expression.Expression, idxCols []*expression.Column, mvIndex *model.IndexInfo, @@ -1350,7 +1350,7 @@ func isSafeTypeConversion4MVIndexRange(valType, mvIndexType *types.FieldType) (s // buildPartialPath4MVIndex builds a partial path on this MVIndex with these accessFilters. func buildPartialPath4MVIndex( - sctx sessionctx.Context, + sctx context.PlanContext, accessFilters []expression.Expression, idxCols []*expression.Column, mvIndex *model.IndexInfo, @@ -1412,7 +1412,7 @@ func PrepareCols4MVIndex( // collectFilters4MVIndex splits these filters into 2 parts where accessFilters can be used to access this index directly. // For idx(x, cast(a as array), z), `x=1 and (2 member of a) and z=1 and x+z>0` is split to: // accessFilters: `x=1 and (2 member of a) and z=1`, remaining: `x+z>0`. -func collectFilters4MVIndex(sctx sessionctx.Context, filters []expression.Expression, idxCols []*expression.Column) (accessFilters, remainingFilters []expression.Expression) { +func collectFilters4MVIndex(sctx context.PlanContext, filters []expression.Expression, idxCols []*expression.Column) (accessFilters, remainingFilters []expression.Expression) { usedAsAccess := make([]bool, len(filters)) for _, col := range idxCols { found := false @@ -1476,7 +1476,7 @@ func collectFilters4MVIndex(sctx sessionctx.Context, filters []expression.Expres // accessFilters: [x=1, (2 member of a), z=1], remainingFilters: [x+z>0], mvColOffset: 1, mvFilterMutations[(2 member of a), (1 member of a)] // // the outer usage will be: accessFilter[mvColOffset] = each element of mvFilterMutations to get the mv access filters mutation combination. -func CollectFilters4MVIndexMutations(sctx sessionctx.Context, filters []expression.Expression, +func CollectFilters4MVIndexMutations(sctx PlanContext, filters []expression.Expression, idxCols []*expression.Column) (accessFilters, remainingFilters []expression.Expression, mvColOffset int, mvFilterMutations []expression.Expression) { usedAsAccess := make([]bool, len(filters)) // accessFilters [x, a, z] @@ -1575,7 +1575,7 @@ func indexMergeContainSpecificIndex(path *util.AccessPath, indexSet map[int64]st } // checkFilter4MVIndexColumn checks whether this filter can be used as an accessFilter to access the MVIndex column. 
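Note (not part of the patch): the OnlyPointRange call sites above no longer hand over the whole planner context; they extract the one piece the callee needs (the type context held by the statement context) and pass that. A toy version of this "pass the value, not the bag" move, with made-up types:

package main

import (
	"fmt"
	"time"
)

// typeCtx stands in for the small value the callee actually needs.
type typeCtx struct {
	loc *time.Location
}

// sessionVars and planContext stand in for the wide context objects that used
// to be passed around wholesale.
type sessionVars struct {
	tc typeCtx
}

type planContext struct {
	vars *sessionVars
}

// onlyPointRange depends on typeCtx only, which makes the requirement explicit
// and the function trivial to test in isolation.
func onlyPointRange(tc typeCtx, ranges []string) bool {
	_ = tc.loc // the real code would use it when comparing datums
	return len(ranges) == 1
}

func main() {
	ctx := &planContext{vars: &sessionVars{tc: typeCtx{loc: time.UTC}}}
	// The caller digs the narrow value out of the wide context, mirroring
	// ds.SCtx().GetSessionVars().StmtCtx.TypeCtx() in the patch.
	fmt.Println(onlyPointRange(ctx.vars.tc, []string{"[1,1]"}))
}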
-func checkFilter4MVIndexColumn(sctx sessionctx.Context, filter expression.Expression, idxCol *expression.Column) bool { +func checkFilter4MVIndexColumn(sctx PlanContext, filter expression.Expression, idxCol *expression.Column) bool { sf, ok := filter.(*expression.ScalarFunction) if !ok { return false @@ -1623,7 +1623,7 @@ func checkFilter4MVIndexColumn(sctx sessionctx.Context, filter expression.Expres } // jsonArrayExpr2Exprs converts a JsonArray expression to expression list: cast('[1, 2, 3]' as JSON) --> []expr{1, 2, 3} -func jsonArrayExpr2Exprs(sctx sessionctx.Context, jsonArrayExpr expression.Expression, targetType *types.FieldType) ([]expression.Expression, bool) { +func jsonArrayExpr2Exprs(sctx expression.EvalContext, jsonArrayExpr expression.Expression, targetType *types.FieldType) ([]expression.Expression, bool) { if !expression.IsInmutableExpr(jsonArrayExpr) || jsonArrayExpr.GetType().EvalType() != types.ETJson { return nil, false } diff --git a/pkg/planner/core/initialize.go b/pkg/planner/core/initialize.go index d50455d30e99e..d14615ea2c0fd 100644 --- a/pkg/planner/core/initialize.go +++ b/pkg/planner/core/initialize.go @@ -19,7 +19,6 @@ import ( "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/planner/core/internal/base" "github.com/pingcap/tidb/pkg/planner/property" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/codec" "github.com/pingcap/tidb/pkg/util/plancodec" @@ -27,55 +26,55 @@ import ( ) // Init initializes LogicalAggregation. -func (la LogicalAggregation) Init(ctx sessionctx.Context, offset int) *LogicalAggregation { +func (la LogicalAggregation) Init(ctx PlanContext, offset int) *LogicalAggregation { la.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeAgg, &la, offset) return &la } // Init initializes LogicalJoin. -func (p LogicalJoin) Init(ctx sessionctx.Context, offset int) *LogicalJoin { +func (p LogicalJoin) Init(ctx PlanContext, offset int) *LogicalJoin { p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeJoin, &p, offset) return &p } // Init initializes DataSource. -func (ds DataSource) Init(ctx sessionctx.Context, offset int) *DataSource { +func (ds DataSource) Init(ctx PlanContext, offset int) *DataSource { ds.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeDataSource, &ds, offset) return &ds } // Init initializes TiKVSingleGather. -func (sg TiKVSingleGather) Init(ctx sessionctx.Context, offset int) *TiKVSingleGather { +func (sg TiKVSingleGather) Init(ctx PlanContext, offset int) *TiKVSingleGather { sg.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeTiKVSingleGather, &sg, offset) return &sg } // Init initializes LogicalTableScan. -func (ts LogicalTableScan) Init(ctx sessionctx.Context, offset int) *LogicalTableScan { +func (ts LogicalTableScan) Init(ctx PlanContext, offset int) *LogicalTableScan { ts.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeTableScan, &ts, offset) return &ts } // Init initializes LogicalIndexScan. -func (is LogicalIndexScan) Init(ctx sessionctx.Context, offset int) *LogicalIndexScan { +func (is LogicalIndexScan) Init(ctx PlanContext, offset int) *LogicalIndexScan { is.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeIdxScan, &is, offset) return &is } // Init initializes LogicalApply. 
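Note (not part of the patch): most signatures in these files now ask for the narrowest thing that suffices - an eval-style context when expressions are only evaluated, a build-style context when they are constructed, and the full plan context otherwise. The following stand-alone sketch shows that layering; the interface and type names are invented, only the idea matches the patch.

package main

import "fmt"

// evalCtx is the smallest capability: just enough to evaluate expressions.
type evalCtx interface {
	Location() string
}

// planCtx offers everything evalCtx does plus planner-only facilities, so any
// planCtx value can be used where an evalCtx is expected.
type planCtx interface {
	evalCtx
	PlanID() int
}

// session is a concrete "wide" object that happens to satisfy both interfaces.
type session struct{ id int }

func (s *session) Location() string { return "UTC" }
func (s *session) PlanID() int      { return s.id }

// explainByItems only formats values, so it declares the smallest dependency.
func explainByItems(ctx evalCtx, items []string) string {
	return fmt.Sprintf("%v in %s", items, ctx.Location())
}

// matchProperties needs planner facilities, so it asks for the wider interface.
func matchProperties(ctx planCtx) string {
	return fmt.Sprintf("plan %d", ctx.PlanID())
}

func main() {
	s := &session{id: 7}
	fmt.Println(explainByItems(s, []string{"a:desc"})) // wide value, narrow parameter
	fmt.Println(matchProperties(s))
}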
-func (la LogicalApply) Init(ctx sessionctx.Context, offset int) *LogicalApply { +func (la LogicalApply) Init(ctx PlanContext, offset int) *LogicalApply { la.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeApply, &la, offset) return &la } // Init initializes LogicalSelection. -func (p LogicalSelection) Init(ctx sessionctx.Context, qbOffset int) *LogicalSelection { +func (p LogicalSelection) Init(ctx PlanContext, qbOffset int) *LogicalSelection { p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeSel, &p, qbOffset) return &p } // Init initializes PhysicalSelection. -func (p PhysicalSelection) Init(ctx sessionctx.Context, stats *property.StatsInfo, qbOffset int, props ...*property.PhysicalProperty) *PhysicalSelection { +func (p PhysicalSelection) Init(ctx PlanContext, stats *property.StatsInfo, qbOffset int, props ...*property.PhysicalProperty) *PhysicalSelection { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeSel, &p, qbOffset) p.childrenReqProps = props p.SetStats(stats) @@ -83,25 +82,25 @@ func (p PhysicalSelection) Init(ctx sessionctx.Context, stats *property.StatsInf } // Init initializes LogicalUnionScan. -func (p LogicalUnionScan) Init(ctx sessionctx.Context, qbOffset int) *LogicalUnionScan { +func (p LogicalUnionScan) Init(ctx PlanContext, qbOffset int) *LogicalUnionScan { p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeUnionScan, &p, qbOffset) return &p } // Init initializes LogicalProjection. -func (p LogicalProjection) Init(ctx sessionctx.Context, qbOffset int) *LogicalProjection { +func (p LogicalProjection) Init(ctx PlanContext, qbOffset int) *LogicalProjection { p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeProj, &p, qbOffset) return &p } // Init initializes LogicalProjection. -func (p LogicalExpand) Init(ctx sessionctx.Context, offset int) *LogicalExpand { +func (p LogicalExpand) Init(ctx PlanContext, offset int) *LogicalExpand { p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeExpand, &p, offset) return &p } // Init initializes PhysicalProjection. -func (p PhysicalProjection) Init(ctx sessionctx.Context, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalProjection { +func (p PhysicalProjection) Init(ctx PlanContext, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalProjection { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeProj, &p, offset) p.childrenReqProps = props p.SetStats(stats) @@ -109,19 +108,19 @@ func (p PhysicalProjection) Init(ctx sessionctx.Context, stats *property.StatsIn } // Init initializes LogicalUnionAll. -func (p LogicalUnionAll) Init(ctx sessionctx.Context, offset int) *LogicalUnionAll { +func (p LogicalUnionAll) Init(ctx PlanContext, offset int) *LogicalUnionAll { p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeUnion, &p, offset) return &p } // Init initializes LogicalPartitionUnionAll. -func (p LogicalPartitionUnionAll) Init(ctx sessionctx.Context, offset int) *LogicalPartitionUnionAll { +func (p LogicalPartitionUnionAll) Init(ctx PlanContext, offset int) *LogicalPartitionUnionAll { p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypePartitionUnion, &p, offset) return &p } // Init initializes PhysicalUnionAll. 
-func (p PhysicalUnionAll) Init(ctx sessionctx.Context, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalUnionAll { +func (p PhysicalUnionAll) Init(ctx PlanContext, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalUnionAll { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeUnion, &p, offset) p.childrenReqProps = props p.SetStats(stats) @@ -129,13 +128,13 @@ func (p PhysicalUnionAll) Init(ctx sessionctx.Context, stats *property.StatsInfo } // Init initializes LogicalSort. -func (ls LogicalSort) Init(ctx sessionctx.Context, offset int) *LogicalSort { +func (ls LogicalSort) Init(ctx PlanContext, offset int) *LogicalSort { ls.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeSort, &ls, offset) return &ls } // Init initializes PhysicalSort. -func (p PhysicalSort) Init(ctx sessionctx.Context, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalSort { +func (p PhysicalSort) Init(ctx PlanContext, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalSort { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeSort, &p, offset) p.childrenReqProps = props p.SetStats(stats) @@ -143,7 +142,7 @@ func (p PhysicalSort) Init(ctx sessionctx.Context, stats *property.StatsInfo, of } // Init initializes NominalSort. -func (p NominalSort) Init(ctx sessionctx.Context, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *NominalSort { +func (p NominalSort) Init(ctx PlanContext, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *NominalSort { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeSort, &p, offset) p.childrenReqProps = props p.SetStats(stats) @@ -151,13 +150,13 @@ func (p NominalSort) Init(ctx sessionctx.Context, stats *property.StatsInfo, off } // Init initializes LogicalTopN. -func (lt LogicalTopN) Init(ctx sessionctx.Context, offset int) *LogicalTopN { +func (lt LogicalTopN) Init(ctx PlanContext, offset int) *LogicalTopN { lt.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeTopN, &lt, offset) return &lt } // Init initializes PhysicalTopN. -func (p PhysicalTopN) Init(ctx sessionctx.Context, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalTopN { +func (p PhysicalTopN) Init(ctx PlanContext, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalTopN { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeTopN, &p, offset) p.childrenReqProps = props p.SetStats(stats) @@ -165,13 +164,13 @@ func (p PhysicalTopN) Init(ctx sessionctx.Context, stats *property.StatsInfo, of } // Init initializes LogicalLimit. -func (p LogicalLimit) Init(ctx sessionctx.Context, offset int) *LogicalLimit { +func (p LogicalLimit) Init(ctx PlanContext, offset int) *LogicalLimit { p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeLimit, &p, offset) return &p } // Init initializes PhysicalLimit.
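Note (not part of the patch): all of these Init methods share one idiom - the receiver is a value, the method fills in the embedded base plan, and it returns the address of that local copy (hence return &p, return &lt, and so on), so each call hands back a fresh, fully initialized plan node. A minimal sketch of the idiom outside TiDB, with invented types:

package main

import "fmt"

type basePlan struct {
	tp string
	id int
}

type logicalTopN struct {
	basePlan
	count int
}

var nextID int

// Init uses a value receiver on purpose: lt is a copy of the argument, so the
// method can finish wiring it up and return a pointer to that copy without
// mutating the caller's value.
func (lt logicalTopN) Init(tp string) *logicalTopN {
	nextID++
	lt.basePlan = basePlan{tp: tp, id: nextID}
	return &lt
}

func main() {
	p := logicalTopN{count: 10}.Init("TopN")
	fmt.Printf("%s#%d limit=%d\n", p.tp, p.id, p.count)
}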
-func (p PhysicalLimit) Init(ctx sessionctx.Context, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalLimit { +func (p PhysicalLimit) Init(ctx PlanContext, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalLimit { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeLimit, &p, offset) p.childrenReqProps = props p.SetStats(stats) @@ -179,26 +178,26 @@ func (p PhysicalLimit) Init(ctx sessionctx.Context, stats *property.StatsInfo, o } // Init initializes LogicalTableDual. -func (p LogicalTableDual) Init(ctx sessionctx.Context, offset int) *LogicalTableDual { +func (p LogicalTableDual) Init(ctx PlanContext, offset int) *LogicalTableDual { p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeDual, &p, offset) return &p } // Init initializes PhysicalTableDual. -func (p PhysicalTableDual) Init(ctx sessionctx.Context, stats *property.StatsInfo, offset int) *PhysicalTableDual { +func (p PhysicalTableDual) Init(ctx PlanContext, stats *property.StatsInfo, offset int) *PhysicalTableDual { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeDual, &p, offset) p.SetStats(stats) return &p } // Init initializes LogicalMaxOneRow. -func (p LogicalMaxOneRow) Init(ctx sessionctx.Context, offset int) *LogicalMaxOneRow { +func (p LogicalMaxOneRow) Init(ctx PlanContext, offset int) *LogicalMaxOneRow { p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeMaxOneRow, &p, offset) return &p } // Init initializes PhysicalMaxOneRow. -func (p PhysicalMaxOneRow) Init(ctx sessionctx.Context, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalMaxOneRow { +func (p PhysicalMaxOneRow) Init(ctx PlanContext, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalMaxOneRow { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeMaxOneRow, &p, offset) p.childrenReqProps = props p.SetStats(stats) @@ -206,13 +205,13 @@ func (p PhysicalMaxOneRow) Init(ctx sessionctx.Context, stats *property.StatsInf } // Init initializes LogicalWindow. -func (p LogicalWindow) Init(ctx sessionctx.Context, offset int) *LogicalWindow { +func (p LogicalWindow) Init(ctx PlanContext, offset int) *LogicalWindow { p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeWindow, &p, offset) return &p } // Init initializes PhysicalWindow. -func (p PhysicalWindow) Init(ctx sessionctx.Context, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalWindow { +func (p PhysicalWindow) Init(ctx PlanContext, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalWindow { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeWindow, &p, offset) p.childrenReqProps = props p.SetStats(stats) @@ -220,7 +219,7 @@ func (p PhysicalWindow) Init(ctx sessionctx.Context, stats *property.StatsInfo, } // Init initializes PhysicalShuffle. -func (p PhysicalShuffle) Init(ctx sessionctx.Context, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalShuffle { +func (p PhysicalShuffle) Init(ctx PlanContext, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalShuffle { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeShuffle, &p, offset) p.childrenReqProps = props p.SetStats(stats) @@ -228,7 +227,7 @@ func (p PhysicalShuffle) Init(ctx sessionctx.Context, stats *property.StatsInfo, } // Init initializes PhysicalShuffleReceiverStub. 
-func (p PhysicalShuffleReceiverStub) Init(ctx sessionctx.Context, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalShuffleReceiverStub { +func (p PhysicalShuffleReceiverStub) Init(ctx PlanContext, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalShuffleReceiverStub { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeShuffleReceiver, &p, offset) p.childrenReqProps = props p.SetStats(stats) @@ -236,49 +235,49 @@ func (p PhysicalShuffleReceiverStub) Init(ctx sessionctx.Context, stats *propert } // Init initializes Update. -func (p Update) Init(ctx sessionctx.Context) *Update { +func (p Update) Init(ctx PlanContext) *Update { p.Plan = base.NewBasePlan(ctx, plancodec.TypeUpdate, 0) return &p } // Init initializes Delete. -func (p Delete) Init(ctx sessionctx.Context) *Delete { +func (p Delete) Init(ctx PlanContext) *Delete { p.Plan = base.NewBasePlan(ctx, plancodec.TypeDelete, 0) return &p } // Init initializes Insert. -func (p Insert) Init(ctx sessionctx.Context) *Insert { +func (p Insert) Init(ctx PlanContext) *Insert { p.Plan = base.NewBasePlan(ctx, plancodec.TypeInsert, 0) return &p } // Init initializes LoadData. -func (p LoadData) Init(ctx sessionctx.Context) *LoadData { +func (p LoadData) Init(ctx PlanContext) *LoadData { p.Plan = base.NewBasePlan(ctx, plancodec.TypeLoadData, 0) return &p } // Init initializes ImportInto. -func (p ImportInto) Init(ctx sessionctx.Context) *ImportInto { +func (p ImportInto) Init(ctx PlanContext) *ImportInto { p.Plan = base.NewBasePlan(ctx, plancodec.TypeImportInto, 0) return &p } // Init initializes LogicalShow. -func (p LogicalShow) Init(ctx sessionctx.Context) *LogicalShow { +func (p LogicalShow) Init(ctx PlanContext) *LogicalShow { p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeShow, &p, 0) return &p } // Init initializes LogicalShowDDLJobs. -func (p LogicalShowDDLJobs) Init(ctx sessionctx.Context) *LogicalShowDDLJobs { +func (p LogicalShowDDLJobs) Init(ctx PlanContext) *LogicalShowDDLJobs { p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeShowDDLJobs, &p, 0) return &p } // Init initializes PhysicalShow. -func (p PhysicalShow) Init(ctx sessionctx.Context) *PhysicalShow { +func (p PhysicalShow) Init(ctx PlanContext) *PhysicalShow { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeShow, &p, 0) // Just use pseudo stats to avoid panic. p.SetStats(&property.StatsInfo{RowCount: 1}) @@ -286,7 +285,7 @@ func (p PhysicalShow) Init(ctx sessionctx.Context) *PhysicalShow { } // Init initializes PhysicalShowDDLJobs. -func (p PhysicalShowDDLJobs) Init(ctx sessionctx.Context) *PhysicalShowDDLJobs { +func (p PhysicalShowDDLJobs) Init(ctx PlanContext) *PhysicalShowDDLJobs { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeShowDDLJobs, &p, 0) // Just use pseudo stats to avoid panic. p.SetStats(&property.StatsInfo{RowCount: 1}) @@ -294,13 +293,13 @@ func (p PhysicalShowDDLJobs) Init(ctx sessionctx.Context) *PhysicalShowDDLJobs { } // Init initializes LogicalLock. -func (p LogicalLock) Init(ctx sessionctx.Context) *LogicalLock { +func (p LogicalLock) Init(ctx PlanContext) *LogicalLock { p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeLock, &p, 0) return &p } // Init initializes PhysicalLock. 
-func (p PhysicalLock) Init(ctx sessionctx.Context, stats *property.StatsInfo, props ...*property.PhysicalProperty) *PhysicalLock { +func (p PhysicalLock) Init(ctx PlanContext, stats *property.StatsInfo, props ...*property.PhysicalProperty) *PhysicalLock { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeLock, &p, 0) p.childrenReqProps = props p.SetStats(stats) @@ -308,32 +307,32 @@ func (p PhysicalLock) Init(ctx sessionctx.Context, stats *property.StatsInfo, pr } // Init initializes PhysicalTableScan. -func (p PhysicalTableScan) Init(ctx sessionctx.Context, offset int) *PhysicalTableScan { +func (p PhysicalTableScan) Init(ctx PlanContext, offset int) *PhysicalTableScan { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeTableScan, &p, offset) return &p } // Init initializes PhysicalIndexScan. -func (p PhysicalIndexScan) Init(ctx sessionctx.Context, offset int) *PhysicalIndexScan { +func (p PhysicalIndexScan) Init(ctx PlanContext, offset int) *PhysicalIndexScan { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeIdxScan, &p, offset) return &p } // Init initializes LogicalMemTable. -func (p LogicalMemTable) Init(ctx sessionctx.Context, offset int) *LogicalMemTable { +func (p LogicalMemTable) Init(ctx PlanContext, offset int) *LogicalMemTable { p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeMemTableScan, &p, offset) return &p } // Init initializes PhysicalMemTable. -func (p PhysicalMemTable) Init(ctx sessionctx.Context, stats *property.StatsInfo, offset int) *PhysicalMemTable { +func (p PhysicalMemTable) Init(ctx PlanContext, stats *property.StatsInfo, offset int) *PhysicalMemTable { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeMemTableScan, &p, offset) p.SetStats(stats) return &p } // Init initializes PhysicalHashJoin. -func (p PhysicalHashJoin) Init(ctx sessionctx.Context, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalHashJoin { +func (p PhysicalHashJoin) Init(ctx PlanContext, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalHashJoin { tp := plancodec.TypeHashJoin p.basePhysicalPlan = newBasePhysicalPlan(ctx, tp, &p, offset) p.childrenReqProps = props @@ -342,20 +341,20 @@ func (p PhysicalHashJoin) Init(ctx sessionctx.Context, stats *property.StatsInfo } // Init initializes PhysicalMergeJoin. -func (p PhysicalMergeJoin) Init(ctx sessionctx.Context, stats *property.StatsInfo, offset int) *PhysicalMergeJoin { +func (p PhysicalMergeJoin) Init(ctx PlanContext, stats *property.StatsInfo, offset int) *PhysicalMergeJoin { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeMergeJoin, &p, offset) p.SetStats(stats) return &p } // Init initializes basePhysicalAgg. 
-func (base basePhysicalAgg) Init(ctx sessionctx.Context, stats *property.StatsInfo, offset int) *basePhysicalAgg { +func (base basePhysicalAgg) Init(ctx PlanContext, stats *property.StatsInfo, offset int) *basePhysicalAgg { base.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeHashAgg, &base, offset) base.SetStats(stats) return &base } -func (base basePhysicalAgg) initForHash(ctx sessionctx.Context, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalHashAgg { +func (base basePhysicalAgg) initForHash(ctx PlanContext, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalHashAgg { p := &PhysicalHashAgg{base} p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeHashAgg, p, offset) p.childrenReqProps = props @@ -363,7 +362,7 @@ func (base basePhysicalAgg) initForHash(ctx sessionctx.Context, stats *property. return p } -func (base basePhysicalAgg) initForStream(ctx sessionctx.Context, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalStreamAgg { +func (base basePhysicalAgg) initForStream(ctx PlanContext, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalStreamAgg { p := &PhysicalStreamAgg{base} p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeStreamAgg, p, offset) p.childrenReqProps = props @@ -372,7 +371,7 @@ func (base basePhysicalAgg) initForStream(ctx sessionctx.Context, stats *propert } // Init initializes PhysicalApply. -func (p PhysicalApply) Init(ctx sessionctx.Context, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalApply { +func (p PhysicalApply) Init(ctx PlanContext, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalApply { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeApply, &p, offset) p.childrenReqProps = props p.SetStats(stats) @@ -380,7 +379,7 @@ func (p PhysicalApply) Init(ctx sessionctx.Context, stats *property.StatsInfo, o } // Init initializes PhysicalUnionScan. -func (p PhysicalUnionScan) Init(ctx sessionctx.Context, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalUnionScan { +func (p PhysicalUnionScan) Init(ctx PlanContext, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalUnionScan { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeUnionScan, &p, offset) p.childrenReqProps = props p.SetStats(stats) @@ -388,7 +387,7 @@ func (p PhysicalUnionScan) Init(ctx sessionctx.Context, stats *property.StatsInf } // Init initializes PhysicalIndexLookUpReader. -func (p PhysicalIndexLookUpReader) Init(ctx sessionctx.Context, offset int) *PhysicalIndexLookUpReader { +func (p PhysicalIndexLookUpReader) Init(ctx PlanContext, offset int) *PhysicalIndexLookUpReader { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeIndexLookUp, &p, offset) p.TablePlans = flattenPushDownPlan(p.tablePlan) p.IndexPlans = flattenPushDownPlan(p.indexPlan) @@ -397,7 +396,7 @@ func (p PhysicalIndexLookUpReader) Init(ctx sessionctx.Context, offset int) *Phy } // Init initializes PhysicalIndexMergeReader. 
-func (p PhysicalIndexMergeReader) Init(ctx sessionctx.Context, offset int) *PhysicalIndexMergeReader { +func (p PhysicalIndexMergeReader) Init(ctx PlanContext, offset int) *PhysicalIndexMergeReader { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeIndexMerge, &p, offset) if p.tablePlan != nil { p.SetStats(p.tablePlan.StatsInfo()) @@ -438,7 +437,7 @@ func (p PhysicalIndexMergeReader) Init(ctx sessionctx.Context, offset int) *Phys return &p } -func (p *PhysicalTableReader) adjustReadReqType(ctx sessionctx.Context) { +func (p *PhysicalTableReader) adjustReadReqType(ctx PlanContext) { if p.StoreType == kv.TiFlash { _, ok := p.tablePlan.(*PhysicalExchangeSender) if ok { @@ -474,7 +473,7 @@ func (p *PhysicalTableReader) adjustReadReqType(ctx sessionctx.Context) { } // Init initializes PhysicalTableReader. -func (p PhysicalTableReader) Init(ctx sessionctx.Context, offset int) *PhysicalTableReader { +func (p PhysicalTableReader) Init(ctx PlanContext, offset int) *PhysicalTableReader { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeTableReader, &p, offset) p.ReadReqType = Cop if p.tablePlan == nil { @@ -490,7 +489,7 @@ func (p PhysicalTableReader) Init(ctx sessionctx.Context, offset int) *PhysicalT } // Init initializes PhysicalTableSample. -func (p PhysicalTableSample) Init(ctx sessionctx.Context, offset int) *PhysicalTableSample { +func (p PhysicalTableSample) Init(ctx PlanContext, offset int) *PhysicalTableSample { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeTableSample, &p, offset) p.SetStats(&property.StatsInfo{RowCount: 1}) return &p @@ -510,14 +509,14 @@ func (p *PhysicalTableSample) MemoryUsage() (sum int64) { } // Init initializes PhysicalIndexReader. -func (p PhysicalIndexReader) Init(ctx sessionctx.Context, offset int) *PhysicalIndexReader { +func (p PhysicalIndexReader) Init(ctx PlanContext, offset int) *PhysicalIndexReader { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeIndexReader, &p, offset) p.SetSchema(nil) return &p } // Init initializes PhysicalIndexJoin. -func (p PhysicalIndexJoin) Init(ctx sessionctx.Context, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalIndexJoin { +func (p PhysicalIndexJoin) Init(ctx PlanContext, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalIndexJoin { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeIndexJoin, &p, offset) p.childrenReqProps = props p.SetStats(stats) @@ -525,7 +524,7 @@ func (p PhysicalIndexJoin) Init(ctx sessionctx.Context, stats *property.StatsInf } // Init initializes PhysicalIndexMergeJoin. -func (p PhysicalIndexMergeJoin) Init(ctx sessionctx.Context) *PhysicalIndexMergeJoin { +func (p PhysicalIndexMergeJoin) Init(ctx PlanContext) *PhysicalIndexMergeJoin { p.SetTP(plancodec.TypeIndexMergeJoin) p.SetID(int(ctx.GetSessionVars().PlanID.Add(1))) p.SetSCtx(ctx) @@ -534,7 +533,7 @@ func (p PhysicalIndexMergeJoin) Init(ctx sessionctx.Context) *PhysicalIndexMerge } // Init initializes PhysicalIndexHashJoin. -func (p PhysicalIndexHashJoin) Init(ctx sessionctx.Context) *PhysicalIndexHashJoin { +func (p PhysicalIndexHashJoin) Init(ctx PlanContext) *PhysicalIndexHashJoin { p.SetTP(plancodec.TypeIndexHashJoin) p.SetID(int(ctx.GetSessionVars().PlanID.Add(1))) p.SetSCtx(ctx) @@ -543,7 +542,7 @@ func (p PhysicalIndexHashJoin) Init(ctx sessionctx.Context) *PhysicalIndexHashJo } // Init initializes BatchPointGetPlan. 
-func (p *BatchPointGetPlan) Init(ctx sessionctx.Context, stats *property.StatsInfo, schema *expression.Schema, names []*types.FieldName, offset int) *BatchPointGetPlan { +func (p *BatchPointGetPlan) Init(ctx PlanContext, stats *property.StatsInfo, schema *expression.Schema, names []*types.FieldName, offset int) *BatchPointGetPlan { p.Plan = base.NewBasePlan(ctx, plancodec.TypeBatchPointGet, offset) p.schema = schema p.names = names @@ -598,7 +597,7 @@ func (p *BatchPointGetPlan) Init(ctx sessionctx.Context, stats *property.StatsIn } // Init initializes PointGetPlan. -func (p PointGetPlan) Init(ctx sessionctx.Context, stats *property.StatsInfo, offset int, _ ...*property.PhysicalProperty) *PointGetPlan { +func (p PointGetPlan) Init(ctx PlanContext, stats *property.StatsInfo, offset int, _ ...*property.PhysicalProperty) *PointGetPlan { p.Plan = base.NewBasePlan(ctx, plancodec.TypePointGet, offset) p.SetStats(stats) p.Columns = ExpandVirtualColumn(p.Columns, p.schema, p.TblInfo.Columns) @@ -606,14 +605,14 @@ func (p PointGetPlan) Init(ctx sessionctx.Context, stats *property.StatsInfo, of } // Init only assigns type and context. -func (p PhysicalExchangeSender) Init(ctx sessionctx.Context, stats *property.StatsInfo) *PhysicalExchangeSender { +func (p PhysicalExchangeSender) Init(ctx PlanContext, stats *property.StatsInfo) *PhysicalExchangeSender { p.Plan = base.NewBasePlan(ctx, plancodec.TypeExchangeSender, 0) p.SetStats(stats) return &p } // Init only assigns type and context. -func (p PhysicalExchangeReceiver) Init(ctx sessionctx.Context, stats *property.StatsInfo) *PhysicalExchangeReceiver { +func (p PhysicalExchangeReceiver) Init(ctx PlanContext, stats *property.StatsInfo) *PhysicalExchangeReceiver { p.Plan = base.NewBasePlan(ctx, plancodec.TypeExchangeReceiver, 0) p.SetStats(stats) return &p @@ -639,53 +638,53 @@ func flattenPushDownPlan(p PhysicalPlan) []PhysicalPlan { } // Init only assigns type and context. -func (p LogicalCTE) Init(ctx sessionctx.Context, offset int) *LogicalCTE { +func (p LogicalCTE) Init(ctx PlanContext, offset int) *LogicalCTE { p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeCTE, &p, offset) return &p } // Init only assigns type and context. -func (p PhysicalCTE) Init(ctx sessionctx.Context, stats *property.StatsInfo) *PhysicalCTE { +func (p PhysicalCTE) Init(ctx PlanContext, stats *property.StatsInfo) *PhysicalCTE { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeCTE, &p, 0) p.SetStats(stats) return &p } // Init only assigns type and context. -func (p LogicalCTETable) Init(ctx sessionctx.Context, offset int) *LogicalCTETable { +func (p LogicalCTETable) Init(ctx PlanContext, offset int) *LogicalCTETable { p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeCTETable, &p, offset) return &p } // Init only assigns type and context. -func (p PhysicalCTETable) Init(ctx sessionctx.Context, stats *property.StatsInfo) *PhysicalCTETable { +func (p PhysicalCTETable) Init(ctx PlanContext, stats *property.StatsInfo) *PhysicalCTETable { p.Plan = base.NewBasePlan(ctx, plancodec.TypeCTETable, 0) p.SetStats(stats) return &p } // Init initializes FKCheck. 
-func (p FKCheck) Init(ctx sessionctx.Context) *FKCheck { +func (p FKCheck) Init(ctx PlanContext) *FKCheck { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeForeignKeyCheck, &p, 0) p.SetStats(&property.StatsInfo{}) return &p } // Init initializes FKCascade -func (p FKCascade) Init(ctx sessionctx.Context) *FKCascade { +func (p FKCascade) Init(ctx PlanContext) *FKCascade { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeForeignKeyCascade, &p, 0) p.SetStats(&property.StatsInfo{}) return &p } // Init initializes LogicalSequence -func (p LogicalSequence) Init(ctx sessionctx.Context, offset int) *LogicalSequence { +func (p LogicalSequence) Init(ctx PlanContext, offset int) *LogicalSequence { p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeSequence, &p, offset) return &p } // Init initializes PhysicalSequence -func (p PhysicalSequence) Init(ctx sessionctx.Context, stats *property.StatsInfo, blockOffset int, props ...*property.PhysicalProperty) *PhysicalSequence { +func (p PhysicalSequence) Init(ctx PlanContext, stats *property.StatsInfo, blockOffset int, props ...*property.PhysicalProperty) *PhysicalSequence { p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeSequence, &p, blockOffset) p.SetStats(stats) p.childrenReqProps = props @@ -693,7 +692,7 @@ func (p PhysicalSequence) Init(ctx sessionctx.Context, stats *property.StatsInfo } // Init initializes ScalarSubqueryEvalCtx -func (p ScalarSubqueryEvalCtx) Init(ctx sessionctx.Context, offset int) *ScalarSubqueryEvalCtx { +func (p ScalarSubqueryEvalCtx) Init(ctx PlanContext, offset int) *ScalarSubqueryEvalCtx { p.Plan = base.NewBasePlan(ctx, plancodec.TypeScalarSubQuery, offset) return &p } diff --git a/pkg/planner/core/internal/BUILD.bazel b/pkg/planner/core/internal/BUILD.bazel index 359447cea4297..3c16228db707e 100644 --- a/pkg/planner/core/internal/BUILD.bazel +++ b/pkg/planner/core/internal/BUILD.bazel @@ -10,9 +10,9 @@ go_library( visibility = ["//pkg/planner/core:__subpackages__"], deps = [ "//pkg/domain", + "//pkg/expression", "//pkg/expression/aggregation", "//pkg/parser/model", - "//pkg/sessionctx", "//pkg/store/mockstore", "//pkg/store/mockstore/unistore", "@com_github_pingcap_kvproto//pkg/metapb", diff --git a/pkg/planner/core/internal/base/BUILD.bazel b/pkg/planner/core/internal/base/BUILD.bazel index 131e657804dd9..c326999fe8fcc 100644 --- a/pkg/planner/core/internal/base/BUILD.bazel +++ b/pkg/planner/core/internal/base/BUILD.bazel @@ -7,8 +7,8 @@ go_library( visibility = ["//pkg/planner/core:__subpackages__"], deps = [ "//pkg/expression", + "//pkg/planner/context", "//pkg/planner/property", - "//pkg/sessionctx", "//pkg/types", "//pkg/util/stringutil", "//pkg/util/tracing", diff --git a/pkg/planner/core/internal/base/plan.go b/pkg/planner/core/internal/base/plan.go index 00be1c4bd92c2..2e97883c39f5a 100644 --- a/pkg/planner/core/internal/base/plan.go +++ b/pkg/planner/core/internal/base/plan.go @@ -20,8 +20,8 @@ import ( "unsafe" "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/planner/context" "github.com/pingcap/tidb/pkg/planner/property" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/stringutil" "github.com/pingcap/tidb/pkg/util/tracing" @@ -29,7 +29,7 @@ import ( // Plan Should be used as embedded struct in Plan implementations. 
type Plan struct { - ctx sessionctx.Context + ctx context.PlanContext stats *property.StatsInfo tp string id int @@ -37,7 +37,7 @@ type Plan struct { } // NewBasePlan creates a new base plan. -func NewBasePlan(ctx sessionctx.Context, tp string, qbBlock int) Plan { +func NewBasePlan(ctx context.PlanContext, tp string, qbBlock int) Plan { id := ctx.GetSessionVars().PlanID.Add(1) return Plan{ tp: tp, @@ -48,12 +48,12 @@ func NewBasePlan(ctx sessionctx.Context, tp string, qbBlock int) Plan { } // SCtx is to get the sessionctx from the plan. -func (p *Plan) SCtx() sessionctx.Context { +func (p *Plan) SCtx() context.PlanContext { return p.ctx } // SetSCtx is to set the sessionctx for the plan. -func (p *Plan) SetSCtx(ctx sessionctx.Context) { +func (p *Plan) SetSCtx(ctx context.PlanContext) { p.ctx = ctx } diff --git a/pkg/planner/core/internal/util.go b/pkg/planner/core/internal/util.go index 849a921fe2318..b5af5a6b63ec9 100644 --- a/pkg/planner/core/internal/util.go +++ b/pkg/planner/core/internal/util.go @@ -15,14 +15,14 @@ package internal import ( + "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/expression/aggregation" - "github.com/pingcap/tidb/pkg/sessionctx" ) // WrapCastForAggFuncs wraps the args of an aggregate function with a cast function. // If the mode is FinalMode or Partial2Mode, we do not need to wrap cast upon the args, // since the types of the args are already the expected. -func WrapCastForAggFuncs(sctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc) { +func WrapCastForAggFuncs(sctx expression.BuildContext, aggFuncs []*aggregation.AggFuncDesc) { for i := range aggFuncs { if aggFuncs[i].Mode != aggregation.FinalMode && aggFuncs[i].Mode != aggregation.Partial2Mode { aggFuncs[i].WrapCastForAggArgs(sctx) diff --git a/pkg/planner/core/logical_plan_builder.go b/pkg/planner/core/logical_plan_builder.go index 8259f5bfa33c7..580f654b6c216 100644 --- a/pkg/planner/core/logical_plan_builder.go +++ b/pkg/planner/core/logical_plan_builder.go @@ -4441,7 +4441,7 @@ func (ds *DataSource) AddExtraPhysTblIDColumn() *expression.Column { // 2. table row count from statistics is zero. // 3. statistics is outdated. // Note: please also update getLatestVersionFromStatsTable() when logic in this function changes. -func getStatsTable(ctx sessionctx.Context, tblInfo *model.TableInfo, pid int64) *statistics.Table { +func getStatsTable(ctx PlanContext, tblInfo *model.TableInfo, pid int64) *statistics.Table { statsHandle := domain.GetDomain(ctx).StatsHandle() var usePartitionStats, countIs0, pseudoStatsForUninitialized, pseudoStatsForOutdated bool var statsTbl *statistics.Table diff --git a/pkg/planner/core/logical_plans.go b/pkg/planner/core/logical_plans.go index 8af6127c033c8..47a3bffeb53bb 100644 --- a/pkg/planner/core/logical_plans.go +++ b/pkg/planner/core/logical_plans.go @@ -1190,7 +1190,7 @@ func extractNotNullFromConds(conditions []expression.Expression, p LogicalPlan) return notnullColsUniqueIDs } -func extractConstantCols(conditions []expression.Expression, sctx sessionctx.Context, fds *fd.FDSet) intset.FastIntSet { +func extractConstantCols(conditions []expression.Expression, sctx PlanContext, fds *fd.FDSet) intset.FastIntSet { // extract constant cols // eg: where a=1 and b is null and (1+c)=5. // TODO: Some columns can only be determined to be constant from multiple constraints (e.g. 
x <= 1 AND x >= 1) @@ -1217,7 +1217,7 @@ func extractConstantCols(conditions []expression.Expression, sctx sessionctx.Con return constUniqueIDs } -func extractEquivalenceCols(conditions []expression.Expression, sctx sessionctx.Context, fds *fd.FDSet) [][]intset.FastIntSet { +func extractEquivalenceCols(conditions []expression.Expression, sctx PlanContext, fds *fd.FDSet) [][]intset.FastIntSet { var equivObjsPair [][]expression.Expression equivObjsPair = expression.ExtractEquivalenceColumns(equivObjsPair, conditions) equivUniqueIDs := make([][]intset.FastIntSet, 0, len(equivObjsPair)) @@ -1617,7 +1617,7 @@ func (ds *DataSource) Convert2Gathers() (gathers []LogicalPlan) { } func detachCondAndBuildRangeForPath( - sctx sessionctx.Context, + sctx PlanContext, path *util.AccessPath, conds []expression.Expression, histColl *statistics.HistColl, @@ -2090,7 +2090,7 @@ func (p *LogicalWindow) GetPartitionBy() []property.SortItem { } // EqualPartitionBy checks whether two LogicalWindow.Partitions are equal. -func (p *LogicalWindow) EqualPartitionBy(_ sessionctx.Context, newWindow *LogicalWindow) bool { +func (p *LogicalWindow) EqualPartitionBy(newWindow *LogicalWindow) bool { if len(p.PartitionBy) != len(newWindow.PartitionBy) { return false } @@ -2107,7 +2107,7 @@ func (p *LogicalWindow) EqualPartitionBy(_ sessionctx.Context, newWindow *Logica } // EqualOrderBy checks whether two LogicalWindow.OrderBys are equal. -func (p *LogicalWindow) EqualOrderBy(ctx sessionctx.Context, newWindow *LogicalWindow) bool { +func (p *LogicalWindow) EqualOrderBy(ctx expression.EvalContext, newWindow *LogicalWindow) bool { if len(p.OrderBy) != len(newWindow.OrderBy) { return false } @@ -2121,7 +2121,7 @@ func (p *LogicalWindow) EqualOrderBy(ctx sessionctx.Context, newWindow *LogicalW } // EqualFrame checks whether two LogicalWindow.Frames are equal. -func (p *LogicalWindow) EqualFrame(ctx sessionctx.Context, newWindow *LogicalWindow) bool { +func (p *LogicalWindow) EqualFrame(ctx expression.EvalContext, newWindow *LogicalWindow) bool { if (p.Frame == nil && newWindow.Frame != nil) || (p.Frame != nil && newWindow.Frame == nil) { return false diff --git a/pkg/planner/core/memtable_predicate_extractor.go b/pkg/planner/core/memtable_predicate_extractor.go index 406e53c068b78..e00b80c787c50 100644 --- a/pkg/planner/core/memtable_predicate_extractor.go +++ b/pkg/planner/core/memtable_predicate_extractor.go @@ -31,7 +31,6 @@ import ( "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/tablecodec" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" @@ -59,7 +58,7 @@ import ( // 4. 
Executor sends requests to the target components instead of all of the components type MemTablePredicateExtractor interface { // Extracts predicates which can be pushed down and returns the remained predicates - Extract(sessionctx.Context, *expression.Schema, []*types.FieldName, []expression.Expression) (remained []expression.Expression) + Extract(PlanContext, *expression.Schema, []*types.FieldName, []expression.Expression) (remained []expression.Expression) explainInfo(p *PhysicalMemTable) string } @@ -411,7 +410,7 @@ func (extractHelper) getStringFunctionName(fn *expression.ScalarFunction) string // SELECT * FROM t WHERE time='2019-10-10 10:10:10' // SELECT * FROM t WHERE time>'2019-10-10 10:10:10' AND time<'2019-10-11 10:10:10' func (helper extractHelper) extractTimeRange( - ctx sessionctx.Context, + ctx PlanContext, schema *expression.Schema, names []*types.FieldName, predicates []expression.Expression, @@ -600,7 +599,7 @@ type ClusterTableExtractor struct { } // Extract implements the MemTablePredicateExtractor Extract interface -func (e *ClusterTableExtractor) Extract(_ sessionctx.Context, +func (e *ClusterTableExtractor) Extract(_ PlanContext, schema *expression.Schema, names []*types.FieldName, predicates []expression.Expression, @@ -666,8 +665,7 @@ type ClusterLogTableExtractor struct { } // Extract implements the MemTablePredicateExtractor Extract interface -func (e *ClusterLogTableExtractor) Extract( - ctx sessionctx.Context, +func (e *ClusterLogTableExtractor) Extract(ctx PlanContext, schema *expression.Schema, names []*types.FieldName, predicates []expression.Expression, @@ -780,8 +778,7 @@ type HotRegionsHistoryTableExtractor struct { } // Extract implements the MemTablePredicateExtractor Extract interface -func (e *HotRegionsHistoryTableExtractor) Extract( - ctx sessionctx.Context, +func (e *HotRegionsHistoryTableExtractor) Extract(ctx PlanContext, schema *expression.Schema, names []*types.FieldName, predicates []expression.Expression, @@ -899,8 +896,7 @@ func newMetricTableExtractor() *MetricTableExtractor { } // Extract implements the MemTablePredicateExtractor Extract interface -func (e *MetricTableExtractor) Extract( - ctx sessionctx.Context, +func (e *MetricTableExtractor) Extract(ctx PlanContext, schema *expression.Schema, names []*types.FieldName, predicates []expression.Expression, @@ -970,7 +966,7 @@ func (e *MetricTableExtractor) explainInfo(p *PhysicalMemTable) string { } // GetMetricTablePromQL uses to get the promQL of metric table. 
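Note (not part of the patch): changing the Extract method of MemTablePredicateExtractor to take PlanContext forces every extractor in the file to update its signature, which is why the hunks above are so repetitive. The toy sketch below shows why Go makes this an all-or-nothing change: an implementation whose method kept the old parameter type would simply stop satisfying the interface. All names are invented.

package main

import "fmt"

type planCtx interface{ PlanID() int }

// extractor is the interface whose method signature was narrowed; every
// implementation must now accept planCtx, or it no longer implements extractor.
type extractor interface {
	Extract(ctx planCtx, predicates []string) (remained []string)
}

type clusterTableExtractor struct{}

func (clusterTableExtractor) Extract(_ planCtx, predicates []string) []string {
	// A real extractor would consume some predicates; this one keeps them all.
	return predicates
}

type fakeCtx struct{}

func (fakeCtx) PlanID() int { return 1 }

func main() {
	var e extractor = clusterTableExtractor{}
	fmt.Println(e.Extract(fakeCtx{}, []string{"type='tikv'"}))
}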
-func (e *MetricTableExtractor) GetMetricTablePromQL(sctx sessionctx.Context, lowerTableName string) string { +func (e *MetricTableExtractor) GetMetricTablePromQL(sctx PlanContext, lowerTableName string) string { quantiles := e.Quantiles def, err := infoschema.GetMetricTableDef(lowerTableName) if err != nil { @@ -981,7 +977,7 @@ func (e *MetricTableExtractor) GetMetricTablePromQL(sctx sessionctx.Context, low } var buf bytes.Buffer for i, quantile := range quantiles { - promQL := def.GenPromQL(sctx, e.LabelConditions, quantile) + promQL := def.GenPromQL(sctx.GetSessionVars().MetricSchemaRangeDuration, e.LabelConditions, quantile) if i > 0 { buf.WriteByte(',') } @@ -1000,8 +996,7 @@ type MetricSummaryTableExtractor struct { } // Extract implements the MemTablePredicateExtractor Extract interface -func (e *MetricSummaryTableExtractor) Extract( - _ sessionctx.Context, +func (e *MetricSummaryTableExtractor) Extract(_ PlanContext, schema *expression.Schema, names []*types.FieldName, predicates []expression.Expression, @@ -1033,8 +1028,7 @@ type InspectionResultTableExtractor struct { } // Extract implements the MemTablePredicateExtractor Extract interface -func (e *InspectionResultTableExtractor) Extract( - _ sessionctx.Context, +func (e *InspectionResultTableExtractor) Extract(_ PlanContext, schema *expression.Schema, names []*types.FieldName, predicates []expression.Expression, @@ -1071,8 +1065,7 @@ type InspectionSummaryTableExtractor struct { } // Extract implements the MemTablePredicateExtractor Extract interface -func (e *InspectionSummaryTableExtractor) Extract( - _ sessionctx.Context, +func (e *InspectionSummaryTableExtractor) Extract(_ PlanContext, schema *expression.Schema, names []*types.FieldName, predicates []expression.Expression, @@ -1130,8 +1123,7 @@ type InspectionRuleTableExtractor struct { } // Extract implements the MemTablePredicateExtractor Extract interface -func (e *InspectionRuleTableExtractor) Extract( - _ sessionctx.Context, +func (e *InspectionRuleTableExtractor) Extract(_ PlanContext, schema *expression.Schema, names []*types.FieldName, predicates []expression.Expression, @@ -1175,8 +1167,7 @@ type TimeRange struct { } // Extract implements the MemTablePredicateExtractor Extract interface -func (e *SlowQueryExtractor) Extract( - ctx sessionctx.Context, +func (e *SlowQueryExtractor) Extract(ctx PlanContext, schema *expression.Schema, names []*types.FieldName, predicates []expression.Expression, @@ -1279,8 +1270,7 @@ type TableStorageStatsExtractor struct { } // Extract implements the MemTablePredicateExtractor Extract interface. 
-func (e *TableStorageStatsExtractor) Extract( - _ sessionctx.Context, +func (e *TableStorageStatsExtractor) Extract(_ PlanContext, schema *expression.Schema, names []*types.FieldName, predicates []expression.Expression, @@ -1350,7 +1340,7 @@ type TiFlashSystemTableExtractor struct { } // Extract implements the MemTablePredicateExtractor Extract interface -func (e *TiFlashSystemTableExtractor) Extract(_ sessionctx.Context, +func (e *TiFlashSystemTableExtractor) Extract(_ PlanContext, schema *expression.Schema, names []*types.FieldName, predicates []expression.Expression, @@ -1412,8 +1402,7 @@ type StatementsSummaryExtractor struct { } // Extract implements the MemTablePredicateExtractor Extract interface -func (e *StatementsSummaryExtractor) Extract( - sctx sessionctx.Context, +func (e *StatementsSummaryExtractor) Extract(sctx PlanContext, schema *expression.Schema, names []*types.FieldName, predicates []expression.Expression, @@ -1466,7 +1455,7 @@ func (e *StatementsSummaryExtractor) explainInfo(p *PhysicalMemTable) string { } func (e *StatementsSummaryExtractor) findCoarseTimeRange( - sctx sessionctx.Context, + sctx PlanContext, schema *expression.Schema, names []*types.FieldName, predicates []expression.Expression, @@ -1514,7 +1503,7 @@ type TikvRegionPeersExtractor struct { } // Extract implements the MemTablePredicateExtractor Extract interface -func (e *TikvRegionPeersExtractor) Extract(_ sessionctx.Context, +func (e *TikvRegionPeersExtractor) Extract(_ PlanContext, schema *expression.Schema, names []*types.FieldName, predicates []expression.Expression, @@ -1572,7 +1561,7 @@ type ColumnsTableExtractor struct { } // Extract implements the MemTablePredicateExtractor Extract interface -func (e *ColumnsTableExtractor) Extract(_ sessionctx.Context, +func (e *ColumnsTableExtractor) Extract(_ PlanContext, schema *expression.Schema, names []*types.FieldName, predicates []expression.Expression, @@ -1635,7 +1624,7 @@ type TiKVRegionStatusExtractor struct { } // Extract implements the MemTablePredicateExtractor Extract interface -func (e *TiKVRegionStatusExtractor) Extract(_ sessionctx.Context, +func (e *TiKVRegionStatusExtractor) Extract(_ PlanContext, schema *expression.Schema, names []*types.FieldName, predicates []expression.Expression, @@ -1689,7 +1678,7 @@ type InfoSchemaTablesExtractor struct { } // Extract implements the MemTablePredicateExtractor Extract interface -func (e *InfoSchemaTablesExtractor) Extract(_ sessionctx.Context, +func (e *InfoSchemaTablesExtractor) Extract(_ PlanContext, schema *expression.Schema, names []*types.FieldName, predicates []expression.Expression, diff --git a/pkg/planner/core/optimizer.go b/pkg/planner/core/optimizer.go index 1095aedf2320d..e25333e8c429f 100644 --- a/pkg/planner/core/optimizer.go +++ b/pkg/planner/core/optimizer.go @@ -300,7 +300,7 @@ func CheckTableLock(ctx sessionctx.Context, is infoschema.InfoSchema, vs []visit return nil } -func checkStableResultMode(sctx sessionctx.Context) bool { +func checkStableResultMode(sctx PlanContext) bool { s := sctx.GetSessionVars() st := s.StmtCtx return s.EnableStableResultMode && (!st.InInsertStmt && !st.InUpdateStmt && !st.InDeleteStmt && !st.InLoadDataStmt) @@ -311,7 +311,7 @@ func checkStableResultMode(sctx sessionctx.Context) bool { // The returned logical plan is necessary for generating plans for Common Table Expressions (CTEs). 
func doOptimize( ctx context.Context, - sctx sessionctx.Context, + sctx PlanContext, flag uint64, logic LogicalPlan, ) (LogicalPlan, PhysicalPlan, float64, error) { @@ -368,7 +368,7 @@ func adjustOptimizationFlags(flag uint64, logic LogicalPlan) uint64 { // DoOptimize optimizes a logical plan to a physical plan. func DoOptimize( ctx context.Context, - sctx sessionctx.Context, + sctx PlanContext, flag uint64, logic LogicalPlan, ) (PhysicalPlan, float64, error) { @@ -383,7 +383,7 @@ func DoOptimize( // refineCETrace will adjust the content of CETrace. // Currently, it will (1) deduplicate trace records, (2) sort the trace records (to make it easier in the tests) and (3) fill in the table name. -func refineCETrace(sctx sessionctx.Context) { +func refineCETrace(sctx PlanContext) { stmtCtx := sctx.GetSessionVars().StmtCtx stmtCtx.OptimizerCETrace = tracing.DedupCETrace(stmtCtx.OptimizerCETrace) slices.SortFunc(stmtCtx.OptimizerCETrace, func(i, j *tracing.CETraceRecord) int { @@ -442,7 +442,7 @@ func mergeContinuousSelections(p PhysicalPlan) { } } -func postOptimize(ctx context.Context, sctx sessionctx.Context, plan PhysicalPlan) (PhysicalPlan, error) { +func postOptimize(ctx context.Context, sctx PlanContext, plan PhysicalPlan) (PhysicalPlan, error) { // some cases from update optimize will require avoiding projection elimination. // see comments ahead of call of DoOptimize in function of buildUpdate(). err := prunePhysicalColumns(sctx, plan) @@ -463,7 +463,7 @@ func postOptimize(ctx context.Context, sctx sessionctx.Context, plan PhysicalPla return plan, nil } -func generateRuntimeFilter(sctx sessionctx.Context, plan PhysicalPlan) { +func generateRuntimeFilter(sctx PlanContext, plan PhysicalPlan) { if !sctx.GetSessionVars().IsRuntimeFilterEnabled() || sctx.GetSessionVars().InRestrictedSQL { return } @@ -482,7 +482,7 @@ func generateRuntimeFilter(sctx sessionctx.Context, plan PhysicalPlan) { // prunePhysicalColumns currently only work for MPP(HashJoin<-Exchange). // Here add projection instead of pruning columns directly for safety considerations. // And projection is cheap here for it saves the network cost and work in memory. 
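Note (not part of the patch): doOptimize, DoOptimize and the helpers after them now take PlanContext, yet existing callers that hold a full session context need no adapters - Go interface satisfaction is implicit, so the wide context can be passed straight into the narrowed parameter while the optimizer sees only the capabilities it declares. A self-contained illustration with invented names:

package main

import "fmt"

// planContext is the narrowed parameter type used by the optimizer entry points.
type planContext interface {
	StableResultMode() bool
}

// sessionContext models the wide session object; it satisfies planContext just
// by having the method, with no explicit declaration or wrapper.
type sessionContext struct {
	stable bool
}

func (s *sessionContext) StableResultMode() bool { return s.stable }
func (s *sessionContext) CommitTxn() error       { return nil } // extra capability the optimizer never sees

// doOptimize only sees the narrow interface, so it cannot accidentally reach
// session-only state such as transaction control.
func doOptimize(ctx planContext, plan string) string {
	if ctx.StableResultMode() {
		return "Sort(" + plan + ")"
	}
	return plan
}

func main() {
	sess := &sessionContext{stable: true}
	fmt.Println(doOptimize(sess, "TableScan")) // the wide session is accepted as-is
}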
-func prunePhysicalColumns(sctx sessionctx.Context, plan PhysicalPlan) error { +func prunePhysicalColumns(sctx PlanContext, plan PhysicalPlan) error { if tableReader, ok := plan.(*PhysicalTableReader); ok { if _, isExchangeSender := tableReader.tablePlan.(*PhysicalExchangeSender); isExchangeSender { err := prunePhysicalColumnsInternal(sctx, tableReader.tablePlan) @@ -526,7 +526,7 @@ func (p *PhysicalHashJoin) extractUsedCols(parentUsedCols []*expression.Column) return leftCols, rightCols } -func prunePhysicalColumnForHashJoinChild(sctx sessionctx.Context, hashJoin *PhysicalHashJoin, joinUsedCols []*expression.Column, sender *PhysicalExchangeSender) error { +func prunePhysicalColumnForHashJoinChild(sctx PlanContext, hashJoin *PhysicalHashJoin, joinUsedCols []*expression.Column, sender *PhysicalExchangeSender) error { var err error joinUsed := expression.GetUsedList(sctx, joinUsedCols, sender.Schema()) hashCols := make([]*expression.Column, len(sender.HashCols)) @@ -574,7 +574,7 @@ func prunePhysicalColumnForHashJoinChild(sctx sessionctx.Context, hashJoin *Phys return err } -func prunePhysicalColumnsInternal(sctx sessionctx.Context, plan PhysicalPlan) error { +func prunePhysicalColumnsInternal(sctx PlanContext, plan PhysicalPlan) error { var err error switch x := plan.(type) { case *PhysicalHashJoin: @@ -626,7 +626,7 @@ func prunePhysicalColumnsInternal(sctx sessionctx.Context, plan PhysicalPlan) er // - Only the filter conditions with high selectivity should be pushed down. // - The filter conditions which contain heavy cost functions should not be pushed down. // - Filter conditions that apply to the same column are either pushed down or not pushed down at all. -func tryEnableLateMaterialization(sctx sessionctx.Context, plan PhysicalPlan) { +func tryEnableLateMaterialization(sctx PlanContext, plan PhysicalPlan) { // check if EnableLateMaterialization is set if sctx.GetSessionVars().EnableLateMaterialization && !sctx.GetSessionVars().TiFlashFastScan { predicatePushDownToTableScan(sctx, plan) @@ -756,7 +756,7 @@ func rewriteTableScanAndAggArgs(physicalTableScan *PhysicalTableScan, aggFuncs [ // < 0: fine grained shuffle is disabled. // > 0: use TiFlashFineGrainedShuffleStreamCount as stream count. // == 0: use TiFlashMaxThreads as stream count when it's greater than 0. Otherwise set status as uninitialized. 
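The comment right above spells out how the stream count is chosen from the two session variables. A standalone sketch of just that decision table (an illustrative reading of the comment; the real handleFineGrainedShuffle below also walks the plan tree and may consult cluster hardware info):

package shufflesketch

// pickStreamCount encodes the rule described above:
//
//	count < 0  -> fine grained shuffle disabled
//	count > 0  -> use the configured stream count directly
//	count == 0 -> fall back to maxThreads when it is positive,
//	              otherwise leave the choice to later cluster probing
func pickStreamCount(count, maxThreads int64) (enabled bool, streams uint64) {
	switch {
	case count < 0:
		return false, 0
	case count > 0:
		return true, uint64(count)
	case maxThreads > 0:
		return true, uint64(maxThreads)
	default:
		return false, 0 // uninitialized: decided later from server info
	}
}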
-func handleFineGrainedShuffle(ctx context.Context, sctx sessionctx.Context, plan PhysicalPlan) { +func handleFineGrainedShuffle(ctx context.Context, sctx PlanContext, plan PhysicalPlan) { streamCount := sctx.GetSessionVars().TiFlashFineGrainedShuffleStreamCount if streamCount < 0 { return @@ -776,7 +776,7 @@ func handleFineGrainedShuffle(ctx context.Context, sctx sessionctx.Context, plan setupFineGrainedShuffle(ctx, sctx, &streamCountInfo, &tiflashServerCountInfo, plan) } -func setupFineGrainedShuffle(ctx context.Context, sctx sessionctx.Context, streamCountInfo *tiflashClusterInfo, tiflashServerCountInfo *tiflashClusterInfo, plan PhysicalPlan) { +func setupFineGrainedShuffle(ctx context.Context, sctx PlanContext, streamCountInfo *tiflashClusterInfo, tiflashServerCountInfo *tiflashClusterInfo, plan PhysicalPlan) { if tableReader, ok := plan.(*PhysicalTableReader); ok { if _, isExchangeSender := tableReader.tablePlan.(*PhysicalExchangeSender); isExchangeSender { helper := fineGrainedShuffleHelper{shuffleTarget: unknown, plans: make([]*basePhysicalPlan, 1)} @@ -830,7 +830,7 @@ func (h *fineGrainedShuffleHelper) updateTarget(t shuffleTarget, p *basePhysical // calculateTiFlashStreamCountUsingMinLogicalCores uses minimal logical cpu cores among tiflash servers, and divide by 2 // return false, 0 if any err happens -func calculateTiFlashStreamCountUsingMinLogicalCores(ctx context.Context, sctx sessionctx.Context, serversInfo []infoschema.ServerInfo) (bool, uint64) { +func calculateTiFlashStreamCountUsingMinLogicalCores(ctx context.Context, sctx PlanContext, serversInfo []infoschema.ServerInfo) (bool, uint64) { failpoint.Inject("mockTiFlashStreamCountUsingMinLogicalCores", func(val failpoint.Value) { intVal, err := strconv.Atoi(val.(string)) if err == nil { @@ -839,7 +839,7 @@ func calculateTiFlashStreamCountUsingMinLogicalCores(ctx context.Context, sctx s failpoint.Return(false, 0) } }) - rows, err := infoschema.FetchClusterServerInfoWithoutPrivilegeCheck(ctx, sctx, serversInfo, diagnosticspb.ServerInfoType_HardwareInfo, false) + rows, err := infoschema.FetchClusterServerInfoWithoutPrivilegeCheck(ctx, sctx.GetSessionVars(), serversInfo, diagnosticspb.ServerInfoType_HardwareInfo, false) if err != nil { return false, 0 } @@ -867,7 +867,7 @@ func calculateTiFlashStreamCountUsingMinLogicalCores(ctx context.Context, sctx s return false, 0 } -func checkFineGrainedShuffleForJoinAgg(ctx context.Context, sctx sessionctx.Context, streamCountInfo *tiflashClusterInfo, tiflashServerCountInfo *tiflashClusterInfo, exchangeColCount int, splitLimit uint64) (applyFlag bool, streamCount uint64) { +func checkFineGrainedShuffleForJoinAgg(ctx context.Context, sctx PlanContext, streamCountInfo *tiflashClusterInfo, tiflashServerCountInfo *tiflashClusterInfo, exchangeColCount int, splitLimit uint64) (applyFlag bool, streamCount uint64) { switch (*streamCountInfo).itemStatus { case unInitialized: streamCount = 4 // assume 8c node in cluster as minimal, stream count is 8 / 2 = 4 @@ -880,7 +880,7 @@ func checkFineGrainedShuffleForJoinAgg(ctx context.Context, sctx sessionctx.Cont var tiflashServerCount uint64 switch (*tiflashServerCountInfo).itemStatus { case unInitialized: - serversInfo, err := infoschema.GetTiFlashServerInfo(sctx) + serversInfo, err := infoschema.GetTiFlashServerInfo(sctx.GetStore()) if err != nil { (*tiflashServerCountInfo).itemStatus = failed (*tiflashServerCountInfo).itemValue = 0 @@ -908,7 +908,7 @@ func checkFineGrainedShuffleForJoinAgg(ctx context.Context, sctx sessionctx.Cont return true, 
streamCount } - serversInfo, err := infoschema.GetTiFlashServerInfo(sctx) + serversInfo, err := infoschema.GetTiFlashServerInfo(sctx.GetStore()) if err != nil { (*tiflashServerCountInfo).itemStatus = failed (*tiflashServerCountInfo).itemValue = 0 @@ -927,7 +927,7 @@ func checkFineGrainedShuffleForJoinAgg(ctx context.Context, sctx sessionctx.Cont return applyFlag, streamCount } -func inferFineGrainedShuffleStreamCountForWindow(ctx context.Context, sctx sessionctx.Context, streamCountInfo *tiflashClusterInfo, tiflashServerCountInfo *tiflashClusterInfo) (streamCount uint64) { +func inferFineGrainedShuffleStreamCountForWindow(ctx context.Context, sctx PlanContext, streamCountInfo *tiflashClusterInfo, tiflashServerCountInfo *tiflashClusterInfo) (streamCount uint64) { switch (*streamCountInfo).itemStatus { case unInitialized: if (*tiflashServerCountInfo).itemStatus == failed { @@ -936,7 +936,7 @@ func inferFineGrainedShuffleStreamCountForWindow(ctx context.Context, sctx sessi break } - serversInfo, err := infoschema.GetTiFlashServerInfo(sctx) + serversInfo, err := infoschema.GetTiFlashServerInfo(sctx.GetStore()) if err != nil { setDefaultStreamCount(streamCountInfo) streamCount = (*streamCountInfo).itemValue @@ -973,7 +973,7 @@ func setDefaultStreamCount(streamCountInfo *tiflashClusterInfo) { (*streamCountInfo).itemValue = variable.DefStreamCountWhenMaxThreadsNotSet } -func setupFineGrainedShuffleInternal(ctx context.Context, sctx sessionctx.Context, plan PhysicalPlan, helper *fineGrainedShuffleHelper, streamCountInfo *tiflashClusterInfo, tiflashServerCountInfo *tiflashClusterInfo) { +func setupFineGrainedShuffleInternal(ctx context.Context, sctx PlanContext, plan PhysicalPlan, helper *fineGrainedShuffleHelper, streamCountInfo *tiflashClusterInfo, tiflashServerCountInfo *tiflashClusterInfo) { switch x := plan.(type) { case *PhysicalWindow: // Do not clear the plans because window executor will keep the data partition. @@ -1108,7 +1108,7 @@ func propagateProbeParents(plan PhysicalPlan, probeParents []PhysicalPlan) { } } -func enableParallelApply(sctx sessionctx.Context, plan PhysicalPlan) PhysicalPlan { +func enableParallelApply(sctx PlanContext, plan PhysicalPlan) PhysicalPlan { if !sctx.GetSessionVars().EnableParallelApply { return plan } @@ -1265,7 +1265,7 @@ func physicalOptimize(logic LogicalPlan, planCounter *PlanCounterTp) (plan Physi } // eliminateUnionScanAndLock set lock property for PointGet and BatchPointGet and eliminates UnionScan and Lock. -func eliminateUnionScanAndLock(sctx sessionctx.Context, p PhysicalPlan) PhysicalPlan { +func eliminateUnionScanAndLock(sctx PlanContext, p PhysicalPlan) PhysicalPlan { var pointGet *PointGetPlan var batchPointGet *BatchPointGetPlan var physLock *PhysicalLock @@ -1356,7 +1356,7 @@ func init() { DefaultDisabledLogicalRulesList.Store(set.NewStringSet()) } -func disableReuseChunkIfNeeded(sctx sessionctx.Context, plan PhysicalPlan) { +func disableReuseChunkIfNeeded(sctx PlanContext, plan PhysicalPlan) { if !sctx.GetSessionVars().IsAllocValid() { return } @@ -1371,7 +1371,7 @@ func disableReuseChunkIfNeeded(sctx sessionctx.Context, plan PhysicalPlan) { } // checkOverlongColType Check if read field type is long field. 
-func checkOverlongColType(sctx sessionctx.Context, plan PhysicalPlan) bool {
+func checkOverlongColType(sctx PlanContext, plan PhysicalPlan) bool {
 	if plan == nil {
 		return false
 	}
diff --git a/pkg/planner/core/partition_prune.go b/pkg/planner/core/partition_prune.go
index 29defea381f68..957bceccf94b1 100644
--- a/pkg/planner/core/partition_prune.go
+++ b/pkg/planner/core/partition_prune.go
@@ -17,7 +17,6 @@ package core
 import (
 	"github.com/pingcap/tidb/pkg/expression"
 	"github.com/pingcap/tidb/pkg/parser/model"
-	"github.com/pingcap/tidb/pkg/sessionctx"
 	"github.com/pingcap/tidb/pkg/table"
 	"github.com/pingcap/tidb/pkg/types"
 )
@@ -25,7 +24,7 @@ import (
 // PartitionPruning finds all used partitions according to query conditions, it will
 // return nil if condition match none of partitions. The return value is a array of the
 // idx in the partition definitions array, use pi.Definitions[idx] to get the partition ID
-func PartitionPruning(ctx sessionctx.Context, tbl table.PartitionedTable, conds []expression.Expression, partitionNames []model.CIStr,
+func PartitionPruning(ctx PlanContext, tbl table.PartitionedTable, conds []expression.Expression, partitionNames []model.CIStr,
 	columns []*expression.Column, names types.NameSlice) ([]int, error) {
 	s := partitionProcessor{}
 	pi := tbl.Meta().Partition
diff --git a/pkg/planner/core/physical_plans.go b/pkg/planner/core/physical_plans.go
index c4ecd60502101..14e474ace7f1c 100644
--- a/pkg/planner/core/physical_plans.go
+++ b/pkg/planner/core/physical_plans.go
@@ -804,7 +804,7 @@ func (p *PhysicalIndexScan) MemoryUsage() (sum int64) {
 // For keepOrder with partition table,
 // we need use partitionHandle to distinct two handles,
 // the `_tidb_rowid` in different partitions can have the same value.
-func AddExtraPhysTblIDColumn(sctx sessionctx.Context, columns []*model.ColumnInfo, schema *expression.Schema) ([]*model.ColumnInfo, *expression.Schema, bool) {
+func AddExtraPhysTblIDColumn(sctx PlanContext, columns []*model.ColumnInfo, schema *expression.Schema) ([]*model.ColumnInfo, *expression.Schema, bool) {
 	// Not adding the ExtraPhysTblID if already exists
 	if FindColumnInfoByID(columns, model.ExtraPhysTblID) != nil {
 		return columns, schema, false
@@ -1620,7 +1620,7 @@ type PhysicalExpand struct {
 }
 
 // Init only assigns type and context.
-func (p PhysicalExpand) Init(ctx sessionctx.Context, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalExpand {
+func (p PhysicalExpand) Init(ctx PlanContext, stats *property.StatsInfo, offset int, props ...*property.PhysicalProperty) *PhysicalExpand {
 	p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeExpand, &p, offset)
 	p.childrenReqProps = props
 	p.SetStats(stats)
diff --git a/pkg/planner/core/plan.go b/pkg/planner/core/plan.go
index c599d12531d01..3a1294412ede2 100644
--- a/pkg/planner/core/plan.go
+++ b/pkg/planner/core/plan.go
@@ -22,6 +22,7 @@ import (
 	"github.com/pingcap/tidb/pkg/expression"
 	"github.com/pingcap/tidb/pkg/kv"
 	"github.com/pingcap/tidb/pkg/planner/cardinality"
+	"github.com/pingcap/tidb/pkg/planner/context"
 	"github.com/pingcap/tidb/pkg/planner/core/internal/base"
 	fd "github.com/pingcap/tidb/pkg/planner/funcdep"
 	"github.com/pingcap/tidb/pkg/planner/property"
@@ -34,6 +35,9 @@ import (
 	"github.com/pingcap/tipb/go-tipb"
 )
 
+// PlanContext is the context for building plan.
+type PlanContext = context.PlanContext
+
 // Plan is the description of an execution flow.
// It is created from ast.Node first, then optimized by the optimizer, // finally used by the executor to create a Cursor which executes the statement. @@ -56,7 +60,7 @@ type Plan interface { // ReplaceExprColumns replace all the column reference in the plan's expression node. ReplaceExprColumns(replace map[string]*expression.Column) - SCtx() sessionctx.Context + SCtx() PlanContext // StatsInfo will return the property.StatsInfo for this plan. StatsInfo() *property.StatsInfo @@ -76,7 +80,7 @@ type Plan interface { BuildPlanTrace() *tracing.PlanTrace } -func enforceProperty(p *property.PhysicalProperty, tsk task, ctx sessionctx.Context) task { +func enforceProperty(p *property.PhysicalProperty, tsk task, ctx PlanContext) task { if p.TaskTp == property.MppTaskType { mpp, ok := tsk.(*mppTask) if !ok || mpp.invalid() { @@ -108,7 +112,7 @@ func enforceProperty(p *property.PhysicalProperty, tsk task, ctx sessionctx.Cont } // optimizeByShuffle insert `PhysicalShuffle` to optimize performance by running in a parallel manner. -func optimizeByShuffle(tsk task, ctx sessionctx.Context) task { +func optimizeByShuffle(tsk task, ctx PlanContext) task { if tsk.plan() == nil { return tsk } @@ -130,7 +134,7 @@ func optimizeByShuffle(tsk task, ctx sessionctx.Context) task { return tsk } -func optimizeByShuffle4Window(pp *PhysicalWindow, ctx sessionctx.Context) *PhysicalShuffle { +func optimizeByShuffle4Window(pp *PhysicalWindow, ctx PlanContext) *PhysicalShuffle { concurrency := ctx.GetSessionVars().WindowConcurrency() if concurrency <= 1 { return nil @@ -169,7 +173,7 @@ func optimizeByShuffle4Window(pp *PhysicalWindow, ctx sessionctx.Context) *Physi return shuffle } -func optimizeByShuffle4StreamAgg(pp *PhysicalStreamAgg, ctx sessionctx.Context) *PhysicalShuffle { +func optimizeByShuffle4StreamAgg(pp *PhysicalStreamAgg, ctx PlanContext) *PhysicalShuffle { concurrency := ctx.GetSessionVars().StreamAggConcurrency() if concurrency <= 1 { return nil @@ -206,7 +210,7 @@ func optimizeByShuffle4StreamAgg(pp *PhysicalStreamAgg, ctx sessionctx.Context) return shuffle } -func optimizeByShuffle4MergeJoin(pp *PhysicalMergeJoin, ctx sessionctx.Context) *PhysicalShuffle { +func optimizeByShuffle4MergeJoin(pp *PhysicalMergeJoin, ctx PlanContext) *PhysicalShuffle { concurrency := ctx.GetSessionVars().MergeJoinConcurrency() if concurrency <= 1 { return nil @@ -740,7 +744,7 @@ func (p *logicalSchemaProducer) BuildKeyInfo(selfSchema *expression.Schema, chil } } -func newBaseLogicalPlan(ctx sessionctx.Context, tp string, self LogicalPlan, qbOffset int) baseLogicalPlan { +func newBaseLogicalPlan(ctx PlanContext, tp string, self LogicalPlan, qbOffset int) baseLogicalPlan { return baseLogicalPlan{ taskMap: make(map[string]task), taskMapBak: make([]string, 0, 10), @@ -750,7 +754,7 @@ func newBaseLogicalPlan(ctx sessionctx.Context, tp string, self LogicalPlan, qbO } } -func newBasePhysicalPlan(ctx sessionctx.Context, tp string, self PhysicalPlan, offset int) basePhysicalPlan { +func newBasePhysicalPlan(ctx PlanContext, tp string, self PhysicalPlan, offset int) basePhysicalPlan { return basePhysicalPlan{ Plan: base.NewBasePlan(ctx, tp, offset), self: self, diff --git a/pkg/planner/core/plan_cache.go b/pkg/planner/core/plan_cache.go index 2866546a0de44..5b56c9c621a80 100644 --- a/pkg/planner/core/plan_cache.go +++ b/pkg/planner/core/plan_cache.go @@ -618,7 +618,7 @@ func rebuildRange(p Plan) error { return nil } -func convertConstant2Datum(ctx sessionctx.Context, con *expression.Constant, target *types.FieldType) (*types.Datum, 
error) { +func convertConstant2Datum(ctx PlanContext, con *expression.Constant, target *types.FieldType) (*types.Datum, error) { val, err := con.Eval(ctx, chunk.Row{}) if err != nil { return nil, err @@ -636,7 +636,7 @@ func convertConstant2Datum(ctx sessionctx.Context, con *expression.Constant, tar return &dVal, nil } -func buildRangeForTableScan(sctx sessionctx.Context, ts *PhysicalTableScan) (err error) { +func buildRangeForTableScan(sctx PlanContext, ts *PhysicalTableScan) (err error) { if ts.Table.IsCommonHandle { pk := tables.FindPrimaryIndex(ts.Table) pkCols := make([]*expression.Column, 0, len(pk.Columns)) @@ -701,7 +701,7 @@ func buildRangeForTableScan(sctx sessionctx.Context, ts *PhysicalTableScan) (err return } -func buildRangeForIndexScan(sctx sessionctx.Context, is *PhysicalIndexScan) (err error) { +func buildRangeForIndexScan(sctx PlanContext, is *PhysicalIndexScan) (err error) { if len(is.IdxCols) == 0 { if ranger.HasFullRange(is.Ranges, false) { // the original range is already a full-range. is.Ranges = ranger.FullRange() diff --git a/pkg/planner/core/plan_cache_utils.go b/pkg/planner/core/plan_cache_utils.go index e42d369a61a2a..bf1dc82131cf6 100644 --- a/pkg/planner/core/plan_cache_utils.go +++ b/pkg/planner/core/plan_cache_utils.go @@ -562,7 +562,7 @@ func checkTypesCompatibility4PC(tpsExpected, tpsActual []*types.FieldType) bool return true } -func isSafePointGetPath4PlanCache(sctx sessionctx.Context, path *util.AccessPath) bool { +func isSafePointGetPath4PlanCache(sctx PlanContext, path *util.AccessPath) bool { // PointGet might contain some over-optimized assumptions, like `a>=1 and a<=1` --> `a=1`, but // these assumptions may be broken after parameters change. diff --git a/pkg/planner/core/plan_stats.go b/pkg/planner/core/plan_stats.go index a58ea188776b9..fe2a5d7e732c0 100644 --- a/pkg/planner/core/plan_stats.go +++ b/pkg/planner/core/plan_stats.go @@ -21,9 +21,7 @@ import ( "github.com/pingcap/tidb/pkg/domain" "github.com/pingcap/tidb/pkg/infoschema" "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/sessiontxn" "github.com/pingcap/tidb/pkg/statistics" "github.com/pingcap/tidb/pkg/table" "github.com/pingcap/tidb/pkg/util/logutil" @@ -49,7 +47,7 @@ func (collectPredicateColumnsPoint) optimize(_ context.Context, plan LogicalPlan } // Prepare the table metadata to avoid repeatedly fetching from the infoSchema below. - is := sessiontxn.GetTxnManager(plan.SCtx()).GetTxnInfoSchema() + is := plan.SCtx().GetInfoSchema().(infoschema.InfoSchema) tblID2Tbl := make(map[int64]table.Table) for _, neededCol := range histNeededColumns { tbl, _ := infoschema.FindTableByTblOrPartID(is, neededCol.TableID) @@ -99,7 +97,7 @@ func (syncWaitStatsLoadPoint) name() string { const maxDuration = 1<<63 - 1 // RequestLoadStats send load column/index stats requests to stats handle -func RequestLoadStats(ctx sessionctx.Context, neededHistItems []model.TableItemID, syncWait int64) error { +func RequestLoadStats(ctx PlanContext, neededHistItems []model.TableItemID, syncWait int64) error { stmtCtx := ctx.GetSessionVars().StmtCtx hintMaxExecutionTime := int64(stmtCtx.MaxExecutionTime) if hintMaxExecutionTime <= 0 { @@ -229,7 +227,7 @@ func CollectDependingVirtualCols(tblID2Tbl map[int64]table.Table, neededItems [] // 1. 
the indices contained the any one of histNeededColumns, eg: histNeededColumns contained A,B columns, and idx_a is // composed up by A column, then we thought the idx_a should be collected // 2. The stats condition of idx_a can't meet IsFullLoad, which means its stats was evicted previously -func collectSyncIndices(ctx sessionctx.Context, +func collectSyncIndices(ctx PlanContext, histNeededColumns []model.TableItemID, tblID2Tbl map[int64]table.Table, ) map[model.TableItemID]struct{} { @@ -276,7 +274,7 @@ func collectHistNeededItems(histNeededColumns []model.TableItemID, histNeededInd return } -func recordTableRuntimeStats(sctx sessionctx.Context, tbls map[int64]struct{}) { +func recordTableRuntimeStats(sctx PlanContext, tbls map[int64]struct{}) { tblStats := sctx.GetSessionVars().StmtCtx.TableStats if tblStats == nil { tblStats = map[int64]any{} @@ -294,7 +292,7 @@ func recordTableRuntimeStats(sctx sessionctx.Context, tbls map[int64]struct{}) { sctx.GetSessionVars().StmtCtx.TableStats = tblStats } -func recordSingleTableRuntimeStats(sctx sessionctx.Context, tblID int64) (stats *statistics.Table, skip bool, err error) { +func recordSingleTableRuntimeStats(sctx PlanContext, tblID int64) (stats *statistics.Table, skip bool, err error) { dom := domain.GetDomain(sctx) statsHandle := dom.StatsHandle() is := sctx.GetDomainInfoSchema().(infoschema.InfoSchema) diff --git a/pkg/planner/core/planbuilder.go b/pkg/planner/core/planbuilder.go index 4f37d7ab8134f..3895d92df56eb 100644 --- a/pkg/planner/core/planbuilder.go +++ b/pkg/planner/core/planbuilder.go @@ -1034,7 +1034,7 @@ func isForUpdateReadSelectLock(lock *ast.SelectLockInfo) bool { // getLatestIndexInfo gets the index info of latest schema version from given table id, // it returns nil if the schema version is not changed -func getLatestIndexInfo(ctx sessionctx.Context, id int64, startVer int64) (map[int64]*model.IndexInfo, bool, error) { +func getLatestIndexInfo(ctx PlanContext, id int64, startVer int64) (map[int64]*model.IndexInfo, bool, error) { dom := domain.GetDomain(ctx) if dom == nil { return nil, false, errors.New("domain not found for ctx") @@ -1055,7 +1055,7 @@ func getLatestIndexInfo(ctx sessionctx.Context, id int64, startVer int64) (map[i return latestIndexes, true, nil } -func getPossibleAccessPaths(ctx sessionctx.Context, tableHints *hint.PlanHints, indexHints []*ast.IndexHint, tbl table.Table, dbName, tblName model.CIStr, check bool, hasFlagPartitionProcessor bool) ([]*util.AccessPath, error) { +func getPossibleAccessPaths(ctx PlanContext, tableHints *hint.PlanHints, indexHints []*ast.IndexHint, tbl table.Table, dbName, tblName model.CIStr, check bool, hasFlagPartitionProcessor bool) ([]*util.AccessPath, error) { tblInfo := tbl.Meta() publicPaths := make([]*util.AccessPath, 0, len(tblInfo.Indices)+2) tp := kv.TiKV @@ -1234,7 +1234,7 @@ func getPossibleAccessPaths(ctx sessionctx.Context, tableHints *hint.PlanHints, return available, nil } -func filterPathByIsolationRead(ctx sessionctx.Context, paths []*util.AccessPath, tblName model.CIStr, dbName model.CIStr) ([]*util.AccessPath, error) { +func filterPathByIsolationRead(ctx PlanContext, paths []*util.AccessPath, tblName model.CIStr, dbName model.CIStr) ([]*util.AccessPath, error) { // TODO: filter paths with isolation read locations. 
if dbName.L == mysql.SystemDB { return paths, nil diff --git a/pkg/planner/core/point_get_plan.go b/pkg/planner/core/point_get_plan.go index bc524a73e108e..418963e1034ac 100644 --- a/pkg/planner/core/point_get_plan.go +++ b/pkg/planner/core/point_get_plan.go @@ -83,7 +83,7 @@ type PointGetPlan struct { IdxCols []*expression.Column IdxColLens []int AccessConditions []expression.Expression - ctx sessionctx.Context + ctx PlanContext UnsignedHandle bool IsTableDual bool Lock bool @@ -322,7 +322,7 @@ func (p *PointGetPlan) LoadTableStats(ctx sessionctx.Context) { type BatchPointGetPlan struct { baseSchemaProducer - ctx sessionctx.Context + ctx PlanContext dbName string TblInfo *model.TableInfo IndexInfo *model.IndexInfo @@ -628,7 +628,7 @@ func IsSelectForUpdateLockType(lockType ast.SelectLockType) bool { return false } -func getLockWaitTime(ctx sessionctx.Context, lockInfo *ast.SelectLockInfo) (lock bool, waitTime int64) { +func getLockWaitTime(ctx PlanContext, lockInfo *ast.SelectLockInfo) (lock bool, waitTime int64) { if lockInfo != nil { if IsSelectForUpdateLockType(lockInfo.LockType) { // Locking of rows for update using SELECT FOR UPDATE only applies when autocommit @@ -2060,7 +2060,7 @@ func getColumnPosInIndex(idx *model.IndexInfo, colName *model.CIStr) int { panic("unique index must include all partition columns") } -func getPartitionExpr(ctx sessionctx.Context, tbl *model.TableInfo) *tables.PartitionExpr { +func getPartitionExpr(ctx PlanContext, tbl *model.TableInfo) *tables.PartitionExpr { is := ctx.GetInfoSchema().(infoschema.InfoSchema) table, ok := is.TableByID(tbl.ID) if !ok { @@ -2076,7 +2076,7 @@ func getPartitionExpr(ctx sessionctx.Context, tbl *model.TableInfo) *tables.Part return partTable.PartitionExpr() } -func getHashOrKeyPartitionColumnName(ctx sessionctx.Context, tbl *model.TableInfo) *model.CIStr { +func getHashOrKeyPartitionColumnName(ctx PlanContext, tbl *model.TableInfo) *model.CIStr { pi := tbl.GetPartitionInfo() if pi == nil { return nil diff --git a/pkg/planner/core/rule_aggregation_elimination.go b/pkg/planner/core/rule_aggregation_elimination.go index 2da4cc5a90cd0..ebf689ac9ceca 100644 --- a/pkg/planner/core/rule_aggregation_elimination.go +++ b/pkg/planner/core/rule_aggregation_elimination.go @@ -23,7 +23,6 @@ import ( "github.com/pingcap/tidb/pkg/expression/aggregation" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/types" ) @@ -196,7 +195,7 @@ func ConvertAggToProj(agg *LogicalAggregation, schema *expression.Schema) (bool, } // rewriteExpr will rewrite the aggregate function to expression doesn't contain aggregate function. -func rewriteExpr(ctx sessionctx.Context, aggFunc *aggregation.AggFuncDesc) (bool, expression.Expression) { +func rewriteExpr(ctx expression.BuildContext, aggFunc *aggregation.AggFuncDesc) (bool, expression.Expression) { switch aggFunc.Name { case ast.AggFuncCount: if aggFunc.Mode == aggregation.FinalMode && @@ -214,7 +213,7 @@ func rewriteExpr(ctx sessionctx.Context, aggFunc *aggregation.AggFuncDesc) (bool } } -func rewriteCount(ctx sessionctx.Context, exprs []expression.Expression, targetTp *types.FieldType) expression.Expression { +func rewriteCount(ctx expression.BuildContext, exprs []expression.Expression, targetTp *types.FieldType) expression.Expression { // If is count(expr), we will change it to if(isnull(expr), 0, 1). // If is count(distinct x, y, z), we will change it to if(isnull(x) or isnull(y) or isnull(z), 0, 1). 
// If is count(expr not null), we will change it to constant 1. @@ -233,7 +232,7 @@ func rewriteCount(ctx sessionctx.Context, exprs []expression.Expression, targetT return newExpr } -func rewriteBitFunc(ctx sessionctx.Context, funcType string, arg expression.Expression, targetTp *types.FieldType) expression.Expression { +func rewriteBitFunc(ctx expression.BuildContext, funcType string, arg expression.Expression, targetTp *types.FieldType) expression.Expression { // For not integer type. We need to cast(cast(arg as signed) as unsigned) to make the bit function work. innerCast := expression.WrapWithCastAsInt(ctx, arg) outerCast := wrapCastFunction(ctx, innerCast, targetTp) @@ -247,7 +246,7 @@ func rewriteBitFunc(ctx sessionctx.Context, funcType string, arg expression.Expr } // wrapCastFunction will wrap a cast if the targetTp is not equal to the arg's. -func wrapCastFunction(ctx sessionctx.Context, arg expression.Expression, targetTp *types.FieldType) expression.Expression { +func wrapCastFunction(ctx expression.BuildContext, arg expression.Expression, targetTp *types.FieldType) expression.Expression { if arg.GetType().Equal(targetTp) { return arg } diff --git a/pkg/planner/core/rule_aggregation_push_down.go b/pkg/planner/core/rule_aggregation_push_down.go index bd1a51a4991f6..1c439c0fc99d4 100644 --- a/pkg/planner/core/rule_aggregation_push_down.go +++ b/pkg/planner/core/rule_aggregation_push_down.go @@ -24,7 +24,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/types" ) @@ -193,7 +192,7 @@ func (a *aggregationPushDownSolver) splitAggFuncsAndGbyCols(agg *LogicalAggregat } // addGbyCol adds a column to gbyCols. If a group by column has existed, it will not be added repeatedly. -func (*aggregationPushDownSolver) addGbyCol(ctx sessionctx.Context, gbyCols []*expression.Column, cols ...*expression.Column) []*expression.Column { +func (*aggregationPushDownSolver) addGbyCol(ctx PlanContext, gbyCols []*expression.Column, cols ...*expression.Column) []*expression.Column { for _, c := range cols { duplicate := false for _, gbyCol := range gbyCols { @@ -216,7 +215,7 @@ func (*aggregationPushDownSolver) checkValidJoin(join *LogicalJoin) bool { // decompose splits an aggregate function to two parts: a final mode function and a partial mode function. Currently // there are no differences between partial mode and complete mode, so we can confuse them. -func (*aggregationPushDownSolver) decompose(ctx sessionctx.Context, aggFunc *aggregation.AggFuncDesc, +func (*aggregationPushDownSolver) decompose(ctx PlanContext, aggFunc *aggregation.AggFuncDesc, schema *expression.Schema, nullGenerating bool) ([]*aggregation.AggFuncDesc, *expression.Schema) { // Result is a slice because avg should be decomposed to sum and count. Currently we don't process this case. result := []*aggregation.AggFuncDesc{aggFunc.Clone()} @@ -325,7 +324,7 @@ func (*aggregationPushDownSolver) checkAllArgsColumn(fun *aggregation.AggFuncDes // TODO: // 1. https://github.com/pingcap/tidb/issues/16355, push avg & distinct functions across join // 2. remove this method and use splitPartialAgg instead for clean code. 
-func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc, +func (a *aggregationPushDownSolver) makeNewAgg(ctx PlanContext, aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, preferAggType uint, preferAggToCop bool, blockOffset int, nullGenerating bool) (*LogicalAggregation, error) { agg := LogicalAggregation{ GroupByItems: expression.Column2Exprs(gbyCols), diff --git a/pkg/planner/core/rule_join_reorder.go b/pkg/planner/core/rule_join_reorder.go index 0a409e3ff7866..0fd1d54940982 100644 --- a/pkg/planner/core/rule_join_reorder.go +++ b/pkg/planner/core/rule_join_reorder.go @@ -22,7 +22,6 @@ import ( "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/sessionctx" h "github.com/pingcap/tidb/pkg/util/hint" "github.com/pingcap/tidb/pkg/util/plancodec" "github.com/pingcap/tidb/pkg/util/tracing" @@ -234,7 +233,7 @@ func (s *joinReOrderSolver) optimize(_ context.Context, p LogicalPlan, opt *logi } // optimizeRecursive recursively collects join groups and applies join reorder algorithm for each group. -func (s *joinReOrderSolver) optimizeRecursive(ctx sessionctx.Context, p LogicalPlan, tracer *joinReorderTrace) (LogicalPlan, error) { +func (s *joinReOrderSolver) optimizeRecursive(ctx PlanContext, p LogicalPlan, tracer *joinReorderTrace) (LogicalPlan, error) { if _, ok := p.(*LogicalCTE); ok { return p, nil } @@ -393,7 +392,7 @@ type joinGroupResult struct { // nolint:structcheck type baseSingleGroupJoinOrderSolver struct { - ctx sessionctx.Context + ctx PlanContext curJoinGroup []*jrNode leadingJoinGroup LogicalPlan *basicJoinGroupInfo diff --git a/pkg/planner/core/rule_join_reorder_dp_test.go b/pkg/planner/core/rule_join_reorder_dp_test.go index ca48f11908ec5..443346520d0e7 100644 --- a/pkg/planner/core/rule_join_reorder_dp_test.go +++ b/pkg/planner/core/rule_join_reorder_dp_test.go @@ -36,7 +36,7 @@ type mockLogicalJoin struct { JoinType JoinType } -func (mj mockLogicalJoin) init(ctx sessionctx.Context) *mockLogicalJoin { +func (mj mockLogicalJoin) init(ctx PlanContext) *mockLogicalJoin { mj.baseLogicalPlan = newBaseLogicalPlan(ctx, "MockLogicalJoin", &mj, 0) return &mj } diff --git a/pkg/planner/core/rule_partition_processor.go b/pkg/planner/core/rule_partition_processor.go index 622e21b8d16a4..ef21fb84a09b0 100644 --- a/pkg/planner/core/rule_partition_processor.go +++ b/pkg/planner/core/rule_partition_processor.go @@ -29,7 +29,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/table" "github.com/pingcap/tidb/pkg/table/tables" "github.com/pingcap/tidb/pkg/types" @@ -118,7 +117,7 @@ type partitionTable interface { PartitionExpr() *tables.PartitionExpr } -func generateHashPartitionExpr(ctx sessionctx.Context, pi *model.PartitionInfo, columns []*expression.Column, names types.NameSlice) (expression.Expression, error) { +func generateHashPartitionExpr(ctx PlanContext, pi *model.PartitionInfo, columns []*expression.Column, names types.NameSlice) (expression.Expression, error) { schema := expression.NewSchema(columns...) // Increase the PlanID to make sure some tests will pass. The old implementation to rewrite AST builds a `TableDual` // that causes the `PlanID` increases, and many test cases hardcoded the output plan ID in the expected result. 
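The aggregation-elimination hunks above keep the rewrite rules themselves intact; they are easy to lose in the signature noise, so here is the count case from those comments as a sketch over a toy string-based expression type (illustrative only; the real rewriteCount builds expression.Expression nodes):

package aggsketch

import "fmt"

// expr is a toy expression rendered as text.
type expr string

func isNull(e expr) expr { return expr(fmt.Sprintf("isnull(%s)", e)) }
func or(a, b expr) expr  { return expr(fmt.Sprintf("or(%s, %s)", a, b)) }

// rewriteCount mirrors the rule: once each group is known to hold exactly one
// row, count(args...) folds into if(isnull(...), 0, 1); for a multi-argument
// count(distinct x, y, z) the null checks are OR-ed together.
func rewriteCount(args []expr) expr {
	cond := isNull(args[0])
	for _, a := range args[1:] {
		cond = or(cond, isNull(a))
	}
	return expr(fmt.Sprintf("if(%s, 0, 1)", cond))
}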
@@ -144,7 +143,7 @@ func getPartColumnsForHashPartition(hashExpr expression.Expression) ([]*expressi return partCols, colLen } -func (s *partitionProcessor) getUsedHashPartitions(ctx sessionctx.Context, +func (s *partitionProcessor) getUsedHashPartitions(ctx PlanContext, tbl table.Table, partitionNames []model.CIStr, columns []*expression.Column, conds []expression.Expression, names types.NameSlice) ([]int, []expression.Expression, error) { pi := tbl.Meta().Partition @@ -262,7 +261,7 @@ func (s *partitionProcessor) getUsedHashPartitions(ctx sessionctx.Context, return used, detachedResult.RemainedConds, nil } -func (s *partitionProcessor) getUsedKeyPartitions(ctx sessionctx.Context, +func (s *partitionProcessor) getUsedKeyPartitions(ctx PlanContext, tbl table.Table, partitionNames []model.CIStr, columns []*expression.Column, conds []expression.Expression, _ types.NameSlice) ([]int, []expression.Expression, error) { pi := tbl.Meta().Partition @@ -369,7 +368,7 @@ func (s *partitionProcessor) getUsedKeyPartitions(ctx sessionctx.Context, } // getUsedPartitions is used to get used partitions for hash or key partition tables -func (s *partitionProcessor) getUsedPartitions(ctx sessionctx.Context, tbl table.Table, +func (s *partitionProcessor) getUsedPartitions(ctx PlanContext, tbl table.Table, partitionNames []model.CIStr, columns []*expression.Column, conds []expression.Expression, names types.NameSlice, partType model.PartitionType) ([]int, []expression.Expression, error) { if partType == model.PartitionTypeHash { @@ -381,7 +380,7 @@ func (s *partitionProcessor) getUsedPartitions(ctx sessionctx.Context, tbl table // findUsedPartitions is used to get used partitions for hash or key partition tables. // The first returning is the used partition index set pruned by `conds`. // The second returning is the filter conditions which should be kept after pruning. -func (s *partitionProcessor) findUsedPartitions(ctx sessionctx.Context, +func (s *partitionProcessor) findUsedPartitions(ctx PlanContext, tbl table.Table, partitionNames []model.CIStr, conds []expression.Expression, columns []*expression.Column, names types.NameSlice) ([]int, []expression.Expression, error) { pi := tbl.Meta().Partition @@ -434,7 +433,7 @@ func convertToRangeOr(used []int, pi *model.PartitionInfo) partitionRangeOR { } // pruneHashOrKeyPartition is used to prune hash or key partition tables -func (s *partitionProcessor) pruneHashOrKeyPartition(ctx sessionctx.Context, tbl table.Table, partitionNames []model.CIStr, +func (s *partitionProcessor) pruneHashOrKeyPartition(ctx PlanContext, tbl table.Table, partitionNames []model.CIStr, conds []expression.Expression, columns []*expression.Column, names types.NameSlice) ([]int, error) { used, _, err := s.findUsedPartitions(ctx, tbl, partitionNames, conds, columns, names) if err != nil { @@ -519,14 +518,14 @@ func (s *partitionProcessor) processHashOrKeyPartition(ds *DataSource, pi *model // listPartitionPruner uses to prune partition for list partition. 
type listPartitionPruner struct { *partitionProcessor - ctx sessionctx.Context + ctx PlanContext pi *model.PartitionInfo partitionNames []model.CIStr fullRange map[int]struct{} listPrune *tables.ForListPruning } -func newListPartitionPruner(ctx sessionctx.Context, tbl table.Table, partitionNames []model.CIStr, s *partitionProcessor, pruneList *tables.ForListPruning, columns []*expression.Column) *listPartitionPruner { +func newListPartitionPruner(ctx PlanContext, tbl table.Table, partitionNames []model.CIStr, s *partitionProcessor, pruneList *tables.ForListPruning, columns []*expression.Column) *listPartitionPruner { pruneList = pruneList.Clone() for i := range pruneList.PruneExprCols { for j := range columns { @@ -781,7 +780,7 @@ func (l *listPartitionPruner) findUsedListPartitions(conds []expression.Expressi return used, nil } -func (s *partitionProcessor) findUsedListPartitions(ctx sessionctx.Context, tbl table.Table, partitionNames []model.CIStr, +func (s *partitionProcessor) findUsedListPartitions(ctx PlanContext, tbl table.Table, partitionNames []model.CIStr, conds []expression.Expression, columns []*expression.Column) ([]int, error) { pi := tbl.Meta().Partition partExpr := tbl.(partitionTable).PartitionExpr() @@ -809,7 +808,7 @@ func (s *partitionProcessor) findUsedListPartitions(ctx sessionctx.Context, tbl return ret, nil } -func (s *partitionProcessor) pruneListPartition(ctx sessionctx.Context, tbl table.Table, partitionNames []model.CIStr, +func (s *partitionProcessor) pruneListPartition(ctx PlanContext, tbl table.Table, partitionNames []model.CIStr, conds []expression.Expression, columns []*expression.Column) ([]int, error) { used, err := s.findUsedListPartitions(ctx, tbl, partitionNames, conds, columns) if err != nil { @@ -992,7 +991,7 @@ func intersectionRange(start, end, newStart, newEnd int) (s int, e int) { return s, e } -func (s *partitionProcessor) pruneRangePartition(ctx sessionctx.Context, pi *model.PartitionInfo, tbl table.PartitionedTable, conds []expression.Expression, +func (s *partitionProcessor) pruneRangePartition(ctx PlanContext, pi *model.PartitionInfo, tbl table.PartitionedTable, conds []expression.Expression, columns []*expression.Column, names types.NameSlice) (partitionRangeOR, error) { partExpr := tbl.(partitionTable).PartitionExpr() @@ -1051,7 +1050,7 @@ func (s *partitionProcessor) processListPartition(ds *DataSource, pi *model.Part } // makePartitionByFnCol extracts the column and function information in 'partition by ... fn(col)'. -func makePartitionByFnCol(sctx sessionctx.Context, columns []*expression.Column, names types.NameSlice, partitionExpr string) (*expression.Column, *expression.ScalarFunction, monotoneMode, error) { +func makePartitionByFnCol(sctx PlanContext, columns []*expression.Column, names types.NameSlice, partitionExpr string) (*expression.Column, *expression.ScalarFunction, monotoneMode, error) { monotonous := monotoneModeInvalid schema := expression.NewSchema(columns...) // Increase the PlanID to make sure some tests will pass. 
The old implementation to rewrite AST builds a `TableDual` @@ -1097,7 +1096,7 @@ func makePartitionByFnCol(sctx sessionctx.Context, columns []*expression.Column, return col, fn, monotonous, nil } -func minCmp(ctx sessionctx.Context, lowVal []types.Datum, columnsPruner *rangeColumnsPruner, comparer []collate.Collator, lowExclude bool, gotError *bool) func(i int) bool { +func minCmp(ctx PlanContext, lowVal []types.Datum, columnsPruner *rangeColumnsPruner, comparer []collate.Collator, lowExclude bool, gotError *bool) func(i int) bool { return func(i int) bool { for j := range lowVal { expr := columnsPruner.lessThan[i][j] @@ -1177,7 +1176,7 @@ func minCmp(ctx sessionctx.Context, lowVal []types.Datum, columnsPruner *rangeCo } } -func maxCmp(ctx sessionctx.Context, hiVal []types.Datum, columnsPruner *rangeColumnsPruner, comparer []collate.Collator, hiExclude bool, gotError *bool) func(i int) bool { +func maxCmp(ctx PlanContext, hiVal []types.Datum, columnsPruner *rangeColumnsPruner, comparer []collate.Collator, hiExclude bool, gotError *bool) func(i int) bool { return func(i int) bool { for j := range hiVal { expr := columnsPruner.lessThan[i][j] @@ -1216,7 +1215,7 @@ func maxCmp(ctx sessionctx.Context, hiVal []types.Datum, columnsPruner *rangeCol } } -func multiColumnRangeColumnsPruner(sctx sessionctx.Context, exprs []expression.Expression, +func multiColumnRangeColumnsPruner(sctx PlanContext, exprs []expression.Expression, columnsPruner *rangeColumnsPruner, result partitionRangeOR) partitionRangeOR { lens := make([]int, 0, len(columnsPruner.partCols)) for i := range columnsPruner.partCols { @@ -1265,7 +1264,7 @@ func multiColumnRangeColumnsPruner(sctx sessionctx.Context, exprs []expression.E return result.intersection(rangeOr).simplify() } -func partitionRangeForCNFExpr(sctx sessionctx.Context, exprs []expression.Expression, +func partitionRangeForCNFExpr(sctx PlanContext, exprs []expression.Expression, pruner partitionRangePruner, result partitionRangeOR) partitionRangeOR { // TODO: When the ranger/detacher handles varchar_col_general_ci cmp constant bin collation // remove the check for single column RANGE COLUMNS and remove the single column implementation @@ -1279,7 +1278,7 @@ func partitionRangeForCNFExpr(sctx sessionctx.Context, exprs []expression.Expres } // partitionRangeForExpr calculate the partitions for the expression. -func partitionRangeForExpr(sctx sessionctx.Context, expr expression.Expression, +func partitionRangeForExpr(sctx PlanContext, expr expression.Expression, pruner partitionRangePruner, result partitionRangeOR) partitionRangeOR { // Handle AND, OR respectively. if op, ok := expr.(*expression.ScalarFunction); ok { @@ -1311,7 +1310,7 @@ func partitionRangeForExpr(sctx sessionctx.Context, expr expression.Expression, } type partitionRangePruner interface { - partitionRangeForExpr(sessionctx.Context, expression.Expression) (start, end int, succ bool) + partitionRangeForExpr(PlanContext, expression.Expression) (start, end int, succ bool) fullRange() partitionRangeOR } @@ -1326,7 +1325,7 @@ type rangePruner struct { monotonous monotoneMode } -func (p *rangePruner) partitionRangeForExpr(sctx sessionctx.Context, expr expression.Expression) (start int, end int, ok bool) { +func (p *rangePruner) partitionRangeForExpr(sctx PlanContext, expr expression.Expression) (start int, end int, ok bool) { if constExpr, ok := expr.(*expression.Constant); ok { if b, err := constExpr.Value.ToBool(sctx.GetSessionVars().StmtCtx.TypeCtx()); err == nil && b == 0 { // A constant false expression. 
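partitionRangeForExpr and partitionRangeForCNFExpr above implement a simple recursion: a leaf comparison prunes to a range of partition indexes, AND intersects the child results and OR unions them. A compact sketch of that shape with toy types (the real code tracks a union of ranges, partitionRangeOR, rather than a single interval):

package prunesketch

// interval is a half-open [start, end) range of partition indexes.
type interval struct{ start, end int }

// leafPruner turns a single comparison into a partition interval; ok is false
// when the predicate cannot be used for pruning.
type leafPruner func(pred string) (interval, bool)

// node is a toy predicate tree: op is "and", "or", or "" for a leaf.
type node struct {
	op       string
	children []node
	pred     string
}

func rangeForExpr(prune leafPruner, n node, full interval) interval {
	switch n.op {
	case "and":
		res := full
		for _, c := range n.children {
			res = intersect(res, rangeForExpr(prune, c, full))
		}
		return res
	case "or":
		res := interval{}
		for _, c := range n.children {
			res = union(res, rangeForExpr(prune, c, full))
		}
		return res
	default:
		if r, ok := prune(n.pred); ok {
			return r
		}
		return full // unusable predicate: cannot prune anything
	}
}

func intersect(a, b interval) interval {
	s, e := max(a.start, b.start), min(a.end, b.end)
	if s >= e {
		return interval{}
	}
	return interval{s, e}
}

// union over-approximates by covering both intervals; good enough for a sketch.
func union(a, b interval) interval {
	if a.start == a.end {
		return b
	}
	if b.start == b.end {
		return a
	}
	return interval{min(a.start, b.start), max(a.end, b.end)}
}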
@@ -1348,14 +1347,14 @@ func (p *rangePruner) fullRange() partitionRangeOR { } // partitionRangeForOrExpr calculate the partitions for or(expr1, expr2) -func partitionRangeForOrExpr(sctx sessionctx.Context, expr1, expr2 expression.Expression, +func partitionRangeForOrExpr(sctx PlanContext, expr1, expr2 expression.Expression, pruner partitionRangePruner) partitionRangeOR { tmp1 := partitionRangeForExpr(sctx, expr1, pruner, pruner.fullRange()) tmp2 := partitionRangeForExpr(sctx, expr2, pruner, pruner.fullRange()) return tmp1.union(tmp2) } -func partitionRangeColumnForInExpr(sctx sessionctx.Context, args []expression.Expression, +func partitionRangeColumnForInExpr(sctx PlanContext, args []expression.Expression, pruner *rangeColumnsPruner) partitionRangeOR { col, ok := args[0].(*expression.Column) if !ok || col.ID != pruner.partCols[0].ID { @@ -1392,7 +1391,7 @@ func partitionRangeColumnForInExpr(sctx sessionctx.Context, args []expression.Ex return result.simplify() } -func partitionRangeForInExpr(sctx sessionctx.Context, args []expression.Expression, +func partitionRangeForInExpr(sctx PlanContext, args []expression.Expression, pruner *rangePruner) partitionRangeOR { col, ok := args[0].(*expression.Column) if !ok || col.ID != pruner.col.ID { @@ -1469,7 +1468,7 @@ type dataForPrune struct { // extractDataForPrune extracts data from the expression for pruning. // The expression should have this form: 'f(x) op const', otherwise it can't be pruned. -func (p *rangePruner) extractDataForPrune(sctx sessionctx.Context, expr expression.Expression) (dataForPrune, bool) { +func (p *rangePruner) extractDataForPrune(sctx PlanContext, expr expression.Expression) (dataForPrune, bool) { var ret dataForPrune op, ok := expr.(*expression.ScalarFunction) if !ok { @@ -1740,7 +1739,7 @@ func checkTableHintsApplicableForPartition(partitions []model.CIStr, partitionSe return unknownPartitions } -func appendWarnForUnknownPartitions(ctx sessionctx.Context, hintName string, unknownPartitions []string) { +func appendWarnForUnknownPartitions(ctx PlanContext, hintName string, unknownPartitions []string) { if len(unknownPartitions) == 0 { return } @@ -1819,7 +1818,7 @@ func (s *partitionProcessor) makeUnionAllChildren(ds *DataSource, pi *model.Part return unionAll, nil } -func (*partitionProcessor) pruneRangeColumnsPartition(ctx sessionctx.Context, conds []expression.Expression, pi *model.PartitionInfo, pe *tables.PartitionExpr, columns []*expression.Column) (partitionRangeOR, error) { +func (*partitionProcessor) pruneRangeColumnsPartition(ctx PlanContext, conds []expression.Expression, pi *model.PartitionInfo, pe *tables.PartitionExpr, columns []*expression.Column) (partitionRangeOR, error) { result := fullRange(len(pi.Definitions)) if len(pi.Columns) < 1 { @@ -1879,7 +1878,7 @@ func (p *rangeColumnsPruner) getPartCol(colID int64) *expression.Column { return nil } -func (p *rangeColumnsPruner) partitionRangeForExpr(sctx sessionctx.Context, expr expression.Expression) (start int, end int, ok bool) { +func (p *rangeColumnsPruner) partitionRangeForExpr(sctx PlanContext, expr expression.Expression) (start int, end int, ok bool) { op, ok := expr.(*expression.ScalarFunction) if !ok { return 0, len(p.lessThan), false @@ -1936,7 +1935,7 @@ func (p *rangeColumnsPruner) partitionRangeForExpr(sctx sessionctx.Context, expr // pruneUseBinarySearch returns the start and end of which partitions will match. // If no match (i.e. value > last partition) the start partition will be the number of partition, not the first partition! 
-func (p *rangeColumnsPruner) pruneUseBinarySearch(sctx sessionctx.Context, op string, data *expression.Constant) (start int, end int) { +func (p *rangeColumnsPruner) pruneUseBinarySearch(sctx PlanContext, op string, data *expression.Constant) (start int, end int) { var savedError error var isNull bool if len(p.partCols) > 1 { diff --git a/pkg/planner/core/rule_predicate_push_down.go b/pkg/planner/core/rule_predicate_push_down.go index 209ba4ea20608..105e2754aa035 100644 --- a/pkg/planner/core/rule_predicate_push_down.go +++ b/pkg/planner/core/rule_predicate_push_down.go @@ -24,7 +24,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/logutil" "github.com/pingcap/tidb/pkg/util/ranger" @@ -36,7 +35,7 @@ type ppdSolver struct{} // exprPrefixAdder is the wrapper struct to add tidb_shard(x) = val for `OrigConds` // `cols` is the index columns for a unique shard index type exprPrefixAdder struct { - sctx sessionctx.Context + sctx PlanContext OrigConds []expression.Expression cols []*expression.Column lengths []int @@ -429,7 +428,7 @@ func simplifyOuterJoin(p *LogicalJoin, predicates []expression.Expression) { // If it is a predicate containing a reference to an inner table that evaluates to UNKNOWN or FALSE when one of its arguments is NULL. // If it is a conjunction containing a null-rejected condition as a conjunct. // If it is a disjunction of null-rejected conditions. -func isNullRejected(ctx sessionctx.Context, schema *expression.Schema, expr expression.Expression) bool { +func isNullRejected(ctx PlanContext, schema *expression.Schema, expr expression.Expression) bool { expr = expression.PushDownNot(ctx, expr) if expression.ContainOuterNot(expr) { return false @@ -460,14 +459,14 @@ func isNullRejected(ctx sessionctx.Context, schema *expression.Schema, expr expr // isNullRejectedSpecially handles some null-rejected cases specially, since the current in // EvaluateExprWithNull is too strict for some cases, e.g. #49616. -func isNullRejectedSpecially(ctx sessionctx.Context, schema *expression.Schema, expr expression.Expression) bool { +func isNullRejectedSpecially(ctx PlanContext, schema *expression.Schema, expr expression.Expression) bool { return specialNullRejectedCase1(ctx, schema, expr) // only 1 case now } // specialNullRejectedCase1 is mainly for #49616. // Case1 specially handles `null-rejected OR (null-rejected AND {others})`, then no matter what the result // of `{others}` is (True, False or Null), the result of this predicate is null, so this predicate is null-rejected. -func specialNullRejectedCase1(ctx sessionctx.Context, schema *expression.Schema, expr expression.Expression) bool { +func specialNullRejectedCase1(ctx PlanContext, schema *expression.Schema, expr expression.Expression) bool { isFunc := func(e expression.Expression, lowerFuncName string) *expression.ScalarFunction { f, ok := e.(*expression.ScalarFunction) if !ok { @@ -720,7 +719,7 @@ func DeriveOtherConditions( // deriveNotNullExpr generates a new expression `not(isnull(col))` given `col1 op col2`, // in which `col` is in specified schema. Caller guarantees that only one of `col1` or // `col2` is in schema. 
-func deriveNotNullExpr(ctx sessionctx.Context, expr expression.Expression, schema *expression.Schema) expression.Expression { +func deriveNotNullExpr(ctx PlanContext, expr expression.Expression, schema *expression.Schema) expression.Expression { binop, ok := expr.(*expression.ScalarFunction) if !ok || len(binop.GetArgs()) != 2 { return nil @@ -926,7 +925,7 @@ func appendAddSelectionTraceStep(p LogicalPlan, child LogicalPlan, sel *LogicalS // "SELECT * FROM test WHERE tidb_shard(a) = val1 AND a = 10 OR tidb_shard(a) = val2 AND a = 20" // @param[in] conds the original condtion of this datasource // @retval - the new condition after adding expression prefix -func (ds *DataSource) AddPrefix4ShardIndexes(sc sessionctx.Context, conds []expression.Expression) []expression.Expression { +func (ds *DataSource) AddPrefix4ShardIndexes(sc PlanContext, conds []expression.Expression) []expression.Expression { if !ds.containExprPrefixUk { return conds } @@ -953,7 +952,7 @@ func (ds *DataSource) AddPrefix4ShardIndexes(sc sessionctx.Context, conds []expr return newConds } -func (ds *DataSource) addExprPrefixCond(sc sessionctx.Context, path *util.AccessPath, +func (ds *DataSource) addExprPrefixCond(sc PlanContext, path *util.AccessPath, conds []expression.Expression) ([]expression.Expression, error) { idxCols, idxColLens := expression.IndexInfo2PrefixCols(ds.Columns, ds.schema.Columns, path.Index) diff --git a/pkg/planner/core/rule_predicate_simplification.go b/pkg/planner/core/rule_predicate_simplification.go index 501c78767012e..1319bcab0a848 100644 --- a/pkg/planner/core/rule_predicate_simplification.go +++ b/pkg/planner/core/rule_predicate_simplification.go @@ -21,7 +21,6 @@ import ( "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/sessionctx" ) // predicateSimplification consolidates different predcicates on a column and its equivalence classes. Initial out is for @@ -81,7 +80,7 @@ func (s *baseLogicalPlan) predicateSimplification(opt *logicalOptimizeOp) Logica // updateInPredicate applies intersection of an in list with <> value. It returns updated In list and a flag for // a special case if an element in the inlist is not removed to keep the list not empty. -func updateInPredicate(ctx sessionctx.Context, inPredicate expression.Expression, notEQPredicate expression.Expression) (expression.Expression, bool) { +func updateInPredicate(ctx PlanContext, inPredicate expression.Expression, notEQPredicate expression.Expression) (expression.Expression, bool) { _, inPredicateType := findPredicateType(inPredicate) _, notEQPredicateType := findPredicateType(notEQPredicate) if inPredicateType != inListPredicate || notEQPredicateType != notEqualPredicate { @@ -117,7 +116,7 @@ func updateInPredicate(ctx sessionctx.Context, inPredicate expression.Expression return newPred, specialCase } -func applyPredicateSimplification(sctx sessionctx.Context, predicates []expression.Expression) []expression.Expression { +func applyPredicateSimplification(sctx PlanContext, predicates []expression.Expression) []expression.Expression { if len(predicates) <= 1 { return predicates } diff --git a/pkg/planner/core/stats.go b/pkg/planner/core/stats.go index 675f93363a795..4ce85d413e0d4 100644 --- a/pkg/planner/core/stats.go +++ b/pkg/planner/core/stats.go @@ -220,7 +220,7 @@ func init() { } // getTblInfoForUsedStatsByPhysicalID get table name, partition name and HintedTable that will be used to record used stats. 
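The predicate-simplification hunk above is again only a signature change, but the transformation updateInPredicate performs is worth spelling out: col IN (...) AND col <> v drops v from the IN list, and when that would empty the list one element is kept and flagged instead. A sketch of that behaviour on plain ints (an illustrative reading of the comment, not the expression-level implementation):

package simplifysketch

// updateInList removes notEq from the IN list. If removal would leave the
// list empty, the first element is kept and specialCase is set so the caller
// can handle that predicate pair separately.
func updateInList(in []int, notEq int) (out []int, specialCase bool) {
	for _, v := range in {
		if v != notEq {
			out = append(out, v)
		}
	}
	if len(out) == 0 && len(in) > 0 {
		return in[:1], true
	}
	return out, false
}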
-func getTblInfoForUsedStatsByPhysicalID(sctx sessionctx.Context, id int64) (fullName string, tblInfo *model.TableInfo) { +func getTblInfoForUsedStatsByPhysicalID(sctx PlanContext, id int64) (fullName string, tblInfo *model.TableInfo) { fullName = "tableID " + strconv.FormatInt(id, 10) is := domain.GetDomain(sctx).InfoSchema() @@ -343,7 +343,7 @@ func (ds *DataSource) derivePathStatsAndTryHeuristics() error { selected = path break } - if path.OnlyPointRange(ds.SCtx()) { + if path.OnlyPointRange(ds.SCtx().GetSessionVars().StmtCtx.TypeCtx()) { if path.IsTablePath() || path.Index.Unique { if path.IsSingleScan { selected = path diff --git a/pkg/planner/core/task.go b/pkg/planner/core/task.go index bcc564e2e944b..9567b4b5993fa 100644 --- a/pkg/planner/core/task.go +++ b/pkg/planner/core/task.go @@ -29,7 +29,6 @@ import ( "github.com/pingcap/tidb/pkg/planner/core/internal/base" "github.com/pingcap/tidb/pkg/planner/property" "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/statistics" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" @@ -55,7 +54,7 @@ type task interface { copy() task plan() PhysicalPlan invalid() bool - convertToRootTask(ctx sessionctx.Context) *rootTask + convertToRootTask(ctx PlanContext) *rootTask MemoryUsage() int64 } @@ -392,7 +391,7 @@ func negotiateCommonType(lType, rType *types.FieldType) (*types.FieldType, bool, return commonType, needConvert(lType, commonType), needConvert(rType, commonType) } -func getProj(ctx sessionctx.Context, p PhysicalPlan) *PhysicalProjection { +func getProj(ctx PlanContext, p PhysicalPlan) *PhysicalProjection { proj := PhysicalProjection{ Exprs: make([]expression.Expression, 0, len(p.Schema().Columns)), }.Init(ctx, p.StatsInfo(), p.QueryBlockOffset()) @@ -581,7 +580,7 @@ func (p *PhysicalMergeJoin) attach2Task(tasks ...task) task { return t } -func buildIndexLookUpTask(ctx sessionctx.Context, t *copTask) *rootTask { +func buildIndexLookUpTask(ctx PlanContext, t *copTask) *rootTask { newTask := &rootTask{} p := PhysicalIndexLookUpReader{ tablePlan: t.tablePlan, @@ -629,7 +628,7 @@ func extractRows(p PhysicalPlan) float64 { } // calcPagingCost calculates the cost for paging processing which may increase the seekCnt and reduce scanned rows. -func calcPagingCost(ctx sessionctx.Context, indexPlan PhysicalPlan, expectCnt uint64) float64 { +func calcPagingCost(ctx PlanContext, indexPlan PhysicalPlan, expectCnt uint64) float64 { sessVars := ctx.GetSessionVars() indexRows := indexPlan.StatsCount() sourceRows := extractRows(indexPlan) @@ -650,16 +649,16 @@ func calcPagingCost(ctx sessionctx.Context, indexPlan PhysicalPlan, expectCnt ui return math.Max(pagingCst-sessVars.GetSeekFactor(nil), 0) } -func (t *rootTask) convertToRootTask(_ sessionctx.Context) *rootTask { +func (t *rootTask) convertToRootTask(_ PlanContext) *rootTask { return t.copy().(*rootTask) } -func (t *copTask) convertToRootTask(ctx sessionctx.Context) *rootTask { +func (t *copTask) convertToRootTask(ctx PlanContext) *rootTask { // copy one to avoid changing itself. return t.copy().(*copTask).convertToRootTaskImpl(ctx) } -func (t *copTask) convertToRootTaskImpl(ctx sessionctx.Context) *rootTask { +func (t *copTask) convertToRootTaskImpl(ctx PlanContext) *rootTask { // copTasks are run in parallel, to make the estimated cost closer to execution time, we amortize // the cost to cop iterator workers. 
According to `CopClient::Send`, the concurrency // is Min(DistSQLScanConcurrency, numRegionsInvolvedInScan), since we cannot infer @@ -760,7 +759,7 @@ func (t *copTask) convertToRootTaskImpl(ctx sessionctx.Context) *rootTask { return newTask } -func (t *copTask) handleRootTaskConds(ctx sessionctx.Context, newTask *rootTask) { +func (t *copTask) handleRootTaskConds(ctx PlanContext, newTask *rootTask) { if len(t.rootTaskConds) > 0 { selectivity, _, err := cardinality.Selectivity(ctx, t.tblColHists, t.rootTaskConds, nil) if err != nil { @@ -1286,7 +1285,7 @@ func (sel *PhysicalSelection) attach2Task(tasks ...task) task { // CheckAggCanPushCop checks whether the aggFuncs and groupByItems can // be pushed down to coprocessor. -func CheckAggCanPushCop(sctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc, groupByItems []expression.Expression, storeType kv.StoreType) bool { +func CheckAggCanPushCop(sctx PlanContext, aggFuncs []*aggregation.AggFuncDesc, groupByItems []expression.Expression, storeType kv.StoreType) bool { sc := sctx.GetSessionVars().StmtCtx client := sctx.GetClient() ret := true @@ -1365,7 +1364,7 @@ type AggInfo struct { // building the aggregate executor(e.g. buildHashAgg will split the AggDesc further for parallel executing). // firstRowFuncMap is a map between partial first_row to final first_row, will be used in RemoveUnnecessaryFirstRow func BuildFinalModeAggregation( - sctx sessionctx.Context, original *AggInfo, partialIsCop bool, isMPPTask bool) (partial, final *AggInfo, firstRowFuncMap map[*aggregation.AggFuncDesc]*aggregation.AggFuncDesc) { + sctx PlanContext, original *AggInfo, partialIsCop bool, isMPPTask bool) (partial, final *AggInfo, firstRowFuncMap map[*aggregation.AggFuncDesc]*aggregation.AggFuncDesc) { firstRowFuncMap = make(map[*aggregation.AggFuncDesc]*aggregation.AggFuncDesc, len(original.AggFuncs)) partial = &AggInfo{ AggFuncs: make([]*aggregation.AggFuncDesc, 0, len(original.AggFuncs)), @@ -1849,7 +1848,7 @@ func (p *basePhysicalAgg) canUse3Stage4SingleDistinctAgg() bool { return num == 1 } -func genFirstRowAggForGroupBy(ctx sessionctx.Context, groupByItems []expression.Expression) ([]*aggregation.AggFuncDesc, error) { +func genFirstRowAggForGroupBy(ctx PlanContext, groupByItems []expression.Expression) ([]*aggregation.AggFuncDesc, error) { aggFuncs := make([]*aggregation.AggFuncDesc, 0, len(groupByItems)) for _, groupBy := range groupByItems { agg, err := aggregation.NewAggFuncDesc(ctx, ast.AggFuncFirstRow, []expression.Expression{groupBy}, false) @@ -1869,7 +1868,7 @@ func genFirstRowAggForGroupBy(ctx sessionctx.Context, groupByItems []expression. // The schema is [firstrow(a), count(b), a]. The column firstrow(a) is unnecessary. // Can optimize the schema to [count(b), a] , and change the index to get value. 
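The convertToRootTaskImpl comment above describes how the accumulated coprocessor cost is amortized over the workers that run cop tasks in parallel. A minimal sketch of that arithmetic (a hypothetical helper; the real implementation derives these numbers from the session variables and the plan):

package copcostsketch

// amortizeCopCost divides the accumulated cop cost by the effective
// concurrency, taken as min(DistSQLScanConcurrency, regions involved in the
// scan), mirroring the Min(...) rule quoted from CopClient::Send above.
func amortizeCopCost(copCost float64, distSQLScanConcurrency, regionCount int) float64 {
	concurrency := min(distSQLScanConcurrency, regionCount)
	if concurrency < 1 {
		concurrency = 1
	}
	return copCost / float64(concurrency)
}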
func RemoveUnnecessaryFirstRow( - sctx sessionctx.Context, + sctx PlanContext, finalGbyItems []expression.Expression, partialAggFuncs []*aggregation.AggFuncDesc, partialGbyItems []expression.Expression, @@ -2588,7 +2587,7 @@ func (t *mppTask) invalid() bool { return t.p == nil } -func (t *mppTask) convertToRootTask(ctx sessionctx.Context) *rootTask { +func (t *mppTask) convertToRootTask(ctx PlanContext) *rootTask { return t.copy().(*mppTask).convertToRootTaskImpl(ctx) } @@ -2643,7 +2642,7 @@ func tryExpandVirtualColumn(p PhysicalPlan) { } } -func (t *mppTask) convertToRootTaskImpl(ctx sessionctx.Context) *rootTask { +func (t *mppTask) convertToRootTaskImpl(ctx PlanContext) *rootTask { // In disaggregated-tiflash mode, need to consider generated column. tryExpandVirtualColumn(t.p) sender := PhysicalExchangeSender{ diff --git a/pkg/planner/core/tiflash_selection_late_materialization.go b/pkg/planner/core/tiflash_selection_late_materialization.go index aab128d4166df..2039b42e2b2ba 100644 --- a/pkg/planner/core/tiflash_selection_late_materialization.go +++ b/pkg/planner/core/tiflash_selection_late_materialization.go @@ -22,7 +22,6 @@ import ( "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/planner/cardinality" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/util/logutil" "go.uber.org/zap" ) @@ -48,7 +47,7 @@ type expressionGroup struct { // predicatePushDownToTableScan is used find the selection just above the table scan // and try to push down the predicates to the table scan. // Used for TiFlash late materialization. -func predicatePushDownToTableScan(sctx sessionctx.Context, plan PhysicalPlan) PhysicalPlan { +func predicatePushDownToTableScan(sctx PlanContext, plan PhysicalPlan) PhysicalPlan { switch p := plan.(type) { case *PhysicalSelection: if physicalTableScan, ok := plan.Children()[0].(*PhysicalTableScan); ok && physicalTableScan.StoreType == kv.TiFlash { @@ -104,7 +103,7 @@ func transformColumnsToCode(cols []*expression.Column, totalColumnCount int) str // @example: conds = [a > 1, b > 1, a > 2, c > 1, a > 3, b > 2], return = [[a > 3, a > 2, a > 1], [b > 2, b > 1], [c > 1]] // @note: when the selectivity of one group is larger than the threshold, we will remove it from the returned result. // @note: when the number of columns of one group is larger than the threshold, we will remove it from the returned result. -func groupByColumnsSortBySelectivity(sctx sessionctx.Context, conds []expression.Expression, physicalTableScan *PhysicalTableScan) []expressionGroup { +func groupByColumnsSortBySelectivity(sctx PlanContext, conds []expression.Expression, physicalTableScan *PhysicalTableScan) []expressionGroup { // Create a map to store the groupMap of conditions keyed by the columns groupMap := make(map[string][]expression.Expression) @@ -205,7 +204,7 @@ func removeSpecificExprsFromSelection(physicalSelection *PhysicalSelection, expr // @param: sctx: the session context // @param: physicalSelection: the PhysicalSelection containing the conditions to be pushed down // @param: physicalTableScan: the PhysicalTableScan to be pushed down to -func predicatePushDownToTableScanImpl(sctx sessionctx.Context, physicalSelection *PhysicalSelection, physicalTableScan *PhysicalTableScan) { +func predicatePushDownToTableScanImpl(sctx PlanContext, physicalSelection *PhysicalSelection, physicalTableScan *PhysicalTableScan) { // When the table is small, there is no need to push down the conditions. 
if physicalTableScan.tblColHists.RealtimeCount <= tiflashDataPackSize || physicalTableScan.KeepOrder { return diff --git a/pkg/planner/property/BUILD.bazel b/pkg/planner/property/BUILD.bazel index eb02cf0275784..907013afdcdc1 100644 --- a/pkg/planner/property/BUILD.bazel +++ b/pkg/planner/property/BUILD.bazel @@ -12,7 +12,6 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/expression", - "//pkg/sessionctx", "//pkg/statistics", "//pkg/util/codec", "//pkg/util/collate", diff --git a/pkg/planner/property/physical_property.go b/pkg/planner/property/physical_property.go index 79a16448572d8..c83592e8bdc25 100644 --- a/pkg/planner/property/physical_property.go +++ b/pkg/planner/property/physical_property.go @@ -21,7 +21,6 @@ import ( "github.com/pingcap/log" "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/util/codec" "github.com/pingcap/tidb/pkg/util/collate" "github.com/pingcap/tidb/pkg/util/size" @@ -130,7 +129,7 @@ func (partitionCol *MPPPartitionColumn) MemoryUsage() (sum int64) { } // ExplainColumnList generates explain information for a list of columns. -func ExplainColumnList(ctx sessionctx.Context, cols []*MPPPartitionColumn) []byte { +func ExplainColumnList(ctx expression.EvalContext, cols []*MPPPartitionColumn) []byte { buffer := bytes.NewBufferString("") for i, col := range cols { buffer.WriteString("[name: ") diff --git a/pkg/planner/util/BUILD.bazel b/pkg/planner/util/BUILD.bazel index 75c6a1418507f..c2597d6777ec7 100644 --- a/pkg/planner/util/BUILD.bazel +++ b/pkg/planner/util/BUILD.bazel @@ -15,6 +15,7 @@ go_library( "//pkg/kv", "//pkg/parser/ast", "//pkg/parser/model", + "//pkg/planner/context", "//pkg/sessionctx", "//pkg/types", "//pkg/util/collate", diff --git a/pkg/planner/util/debugtrace/BUILD.bazel b/pkg/planner/util/debugtrace/BUILD.bazel index fded8bf9198d6..13deb51fd2a5d 100644 --- a/pkg/planner/util/debugtrace/BUILD.bazel +++ b/pkg/planner/util/debugtrace/BUILD.bazel @@ -5,5 +5,5 @@ go_library( srcs = ["base.go"], importpath = "github.com/pingcap/tidb/pkg/planner/util/debugtrace", visibility = ["//visibility:public"], - deps = ["//pkg/sessionctx"], + deps = ["//pkg/planner/context"], ) diff --git a/pkg/planner/util/debugtrace/base.go b/pkg/planner/util/debugtrace/base.go index 3aee91302237a..45021ea51d9e0 100644 --- a/pkg/planner/util/debugtrace/base.go +++ b/pkg/planner/util/debugtrace/base.go @@ -19,7 +19,7 @@ import ( "encoding/json" "runtime" - "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/planner/context" ) // OptimizerDebugTraceRoot is for recording the optimizer debug trace. @@ -94,7 +94,7 @@ func (root *OptimizerDebugTraceRoot) AppendStepWithNameToCurrentContext(step any // GetOrInitDebugTraceRoot returns the debug trace root. // If it's not initialized, it will initialize it first. -func GetOrInitDebugTraceRoot(sctx sessionctx.Context) *OptimizerDebugTraceRoot { +func GetOrInitDebugTraceRoot(sctx context.PlanContext) *OptimizerDebugTraceRoot { stmtCtx := sctx.GetSessionVars().StmtCtx res, ok := stmtCtx.OptimizerDebugTrace.(*OptimizerDebugTraceRoot) if !ok || res == nil { @@ -123,7 +123,7 @@ func EncodeJSONCommon(input any) ([]byte, error) { // EnterContextCommon records the function name of the caller, // then creates and enter a new context for this debug trace structure. -func EnterContextCommon(sctx sessionctx.Context) { +func EnterContextCommon(sctx context.PlanContext) { root := GetOrInitDebugTraceRoot(sctx) funcName := "Fail to get function name." 
pc, _, _, ok := runtime.Caller(1) @@ -139,7 +139,7 @@ func EnterContextCommon(sctx sessionctx.Context) { } // LeaveContextCommon makes the debug trace goes to its parent context. -func LeaveContextCommon(sctx sessionctx.Context) { +func LeaveContextCommon(sctx context.PlanContext) { root := GetOrInitDebugTraceRoot(sctx) root.currentCtx = root.currentCtx.parentCtx } @@ -148,7 +148,7 @@ func LeaveContextCommon(sctx sessionctx.Context) { // The vals arguments should be a slice like ["name1", value1, "name2", value2]. // The names must be string, the values can be any type. func RecordAnyValuesWithNames( - s sessionctx.Context, + s context.PlanContext, vals ...any, ) { root := GetOrInitDebugTraceRoot(s) diff --git a/pkg/planner/util/path.go b/pkg/planner/util/path.go index 4fe34ad2f799a..925e3fa7843d2 100644 --- a/pkg/planner/util/path.go +++ b/pkg/planner/util/path.go @@ -21,7 +21,7 @@ import ( "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/planner/context" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/collate" "github.com/pingcap/tidb/pkg/util/ranger" @@ -121,7 +121,7 @@ func (path *AccessPath) IsTablePath() bool { // SplitCorColAccessCondFromFilters move the necessary filter in the form of index_col = corrlated_col to access conditions. // The function consider the `idx_col_1 = const and index_col_2 = cor_col and index_col_3 = const` case. // It enables more index columns to be considered. The range will be rebuilt in 'ResolveCorrelatedColumns'. -func (path *AccessPath) SplitCorColAccessCondFromFilters(ctx sessionctx.Context, eqOrInCount int) (access, remained []expression.Expression) { +func (path *AccessPath) SplitCorColAccessCondFromFilters(ctx context.PlanContext, eqOrInCount int) (access, remained []expression.Expression) { // The plan cache do not support subquery now. So we skip this function when // 'MaybeOverOptimized4PlanCache' function return true . if expression.MaybeOverOptimized4PlanCache(ctx, path.TableFilters) { @@ -217,8 +217,7 @@ func isColEqExpr(expr expression.Expression, col *expression.Column, checkFn fun } // OnlyPointRange checks whether each range is a point(no interval range exists). -func (path *AccessPath) OnlyPointRange(sctx sessionctx.Context) bool { - tc := sctx.GetSessionVars().StmtCtx.TypeCtx() +func (path *AccessPath) OnlyPointRange(tc types.Context) bool { if path.IsIntHandlePath { for _, ran := range path.Ranges { if !ran.IsPointNullable(tc) { @@ -241,7 +240,7 @@ type Col2Len map[int64]int // ExtractCol2Len collects index/table columns with lengths from expressions. If idxCols and idxColLens are not nil, it collects index columns with lengths(maybe prefix lengths). // Otherwise it collects table columns with full lengths. 
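With the debug-trace helpers in `pkg/planner/util/debugtrace` now keyed on `context.PlanContext`, the usual call pattern in estimation code is unchanged: guard on `EnableOptimizerDebugTrace`, enter a trace context, record values, and leave on return. The sketch below only restates that pattern as it appears in the statistics hunks further down; `estimateRows` and the recorded name are placeholders, not real TiDB functions.

package example

import (
	planctx "github.com/pingcap/tidb/pkg/planner/context"
	"github.com/pingcap/tidb/pkg/planner/util/debugtrace"
)

// estimateRows shows the nil-safe tracing pattern: callers may pass a nil
// sctx, in which case no trace is collected.
func estimateRows(sctx planctx.PlanContext) (result float64) {
	if sctx != nil && sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace {
		debugtrace.EnterContextCommon(sctx)
		defer func() {
			debugtrace.RecordAnyValuesWithNames(sctx, "Result", result)
			debugtrace.LeaveContextCommon(sctx)
		}()
	}
	// ... real estimation work would go here ...
	return result
}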
-func ExtractCol2Len(ctx sessionctx.Context, exprs []expression.Expression, idxCols []*expression.Column, idxColLens []int) Col2Len { +func ExtractCol2Len(ctx expression.EvalContext, exprs []expression.Expression, idxCols []*expression.Column, idxColLens []int) Col2Len { col2len := make(Col2Len, len(idxCols)) for _, expr := range exprs { extractCol2LenFromExpr(ctx, expr, idxCols, idxColLens, col2len) @@ -249,7 +248,7 @@ func ExtractCol2Len(ctx sessionctx.Context, exprs []expression.Expression, idxCo return col2len } -func extractCol2LenFromExpr(ctx sessionctx.Context, expr expression.Expression, idxCols []*expression.Column, idxColLens []int, col2Len Col2Len) { +func extractCol2LenFromExpr(ctx expression.EvalContext, expr expression.Expression, idxCols []*expression.Column, idxColLens []int, col2Len Col2Len) { switch v := expr.(type) { case *expression.Column: if idxCols == nil { @@ -333,7 +332,7 @@ func CompareCol2Len(c1, c2 Col2Len) (int, bool) { } // GetCol2LenFromAccessConds returns columns with lengths from path.AccessConds. -func (path *AccessPath) GetCol2LenFromAccessConds(ctx sessionctx.Context) Col2Len { +func (path *AccessPath) GetCol2LenFromAccessConds(ctx context.PlanContext) Col2Len { if path.IsTablePath() { return ExtractCol2Len(ctx, path.AccessConds, nil, nil) } diff --git a/pkg/planner/util/path_test.go b/pkg/planner/util/path_test.go index 596b8aa693906..36bc3f852c514 100644 --- a/pkg/planner/util/path_test.go +++ b/pkg/planner/util/path_test.go @@ -102,21 +102,22 @@ func TestOnlyPointRange(t *testing.T) { Collators: collate.GetBinaryCollatorSlice(1), } + tc := sctx.GetSessionVars().StmtCtx.TypeCtx() intHandlePath := &util.AccessPath{IsIntHandlePath: true} intHandlePath.Ranges = []*ranger.Range{&nullPointRange, &onePointRange} - require.True(t, intHandlePath.OnlyPointRange(sctx)) + require.True(t, intHandlePath.OnlyPointRange(tc)) intHandlePath.Ranges = []*ranger.Range{&onePointRange, &one2TwoRange} - require.False(t, intHandlePath.OnlyPointRange(sctx)) + require.False(t, intHandlePath.OnlyPointRange(tc)) indexPath := &util.AccessPath{Index: &model.IndexInfo{Columns: make([]*model.IndexColumn, 1)}} indexPath.Ranges = []*ranger.Range{&onePointRange} - require.True(t, indexPath.OnlyPointRange(sctx)) + require.True(t, indexPath.OnlyPointRange(tc)) indexPath.Ranges = []*ranger.Range{&nullPointRange, &onePointRange} - require.False(t, indexPath.OnlyPointRange(sctx)) + require.False(t, indexPath.OnlyPointRange(tc)) indexPath.Ranges = []*ranger.Range{&onePointRange, &one2TwoRange} - require.False(t, indexPath.OnlyPointRange(sctx)) + require.False(t, indexPath.OnlyPointRange(tc)) indexPath.Index.Columns = make([]*model.IndexColumn, 2) indexPath.Ranges = []*ranger.Range{&onePointRange} - require.False(t, indexPath.OnlyPointRange(sctx)) + require.False(t, indexPath.OnlyPointRange(tc)) } diff --git a/pkg/sessionctx/BUILD.bazel b/pkg/sessionctx/BUILD.bazel index 0def5d860d80f..06bb54efd5603 100644 --- a/pkg/sessionctx/BUILD.bazel +++ b/pkg/sessionctx/BUILD.bazel @@ -14,6 +14,7 @@ go_library( "//pkg/sessionctx/variable", "//pkg/statistics/handle/usage/indexusage", "//pkg/util", + "//pkg/util/context", "//pkg/util/kvcache", "//pkg/util/plancache", "//pkg/util/sli", diff --git a/pkg/sessionctx/context.go b/pkg/sessionctx/context.go index da06fcfc1e0fd..d6304e000da05 100644 --- a/pkg/sessionctx/context.go +++ b/pkg/sessionctx/context.go @@ -16,7 +16,6 @@ package sessionctx import ( "context" - "fmt" "time" "github.com/pingcap/errors" @@ -29,6 +28,7 @@ import ( 
"github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/statistics/handle/usage/indexusage" "github.com/pingcap/tidb/pkg/util" + contextutil "github.com/pingcap/tidb/pkg/util/context" "github.com/pingcap/tidb/pkg/util/kvcache" utilpc "github.com/pingcap/tidb/pkg/util/plancache" "github.com/pingcap/tidb/pkg/util/sli" @@ -67,6 +67,7 @@ type PlanCache interface { // Context is an interface for transaction and executive args environment. type Context interface { SessionStatesHandler + contextutil.ValueStoreContext // SetDiskFullOpt set the disk full opt when tikv disk full happened. SetDiskFullOpt(level kvrpcpb.DiskFullOpt) // RollbackTxn rolls back the current transaction. @@ -85,15 +86,6 @@ type Context interface { // GetMPPClient gets a kv.MPPClient. GetMPPClient() kv.MPPClient - // SetValue saves a value associated with this context for key. - SetValue(key fmt.Stringer, value any) - - // Value returns the value associated with this context for key. - Value(key fmt.Stringer) any - - // ClearValue clears the value associated with this context for key. - ClearValue(key fmt.Stringer) - // Deprecated: the semantics of session.GetInfoSchema() is ambiguous // If you want to get the infoschema of the current transaction in SQL layer, use sessiontxn.GetTxnManager(ctx).GetTxnInfoSchema() // If you want to get the latest infoschema use `GetDomainInfoSchema` diff --git a/pkg/statistics/BUILD.bazel b/pkg/statistics/BUILD.bazel index 6997f6ceb6d1e..7b8a5054788c8 100644 --- a/pkg/statistics/BUILD.bazel +++ b/pkg/statistics/BUILD.bazel @@ -30,6 +30,7 @@ go_library( "//pkg/parser/model", "//pkg/parser/mysql", "//pkg/parser/terror", + "//pkg/planner/context", "//pkg/planner/util/debugtrace", "//pkg/sessionctx", "//pkg/sessionctx/stmtctx", diff --git a/pkg/statistics/cmsketch.go b/pkg/statistics/cmsketch.go index e3ca298d60b3a..9a38594ecf2e7 100644 --- a/pkg/statistics/cmsketch.go +++ b/pkg/statistics/cmsketch.go @@ -28,6 +28,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner/context" "github.com/pingcap/tidb/pkg/planner/util/debugtrace" "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" @@ -258,7 +259,7 @@ func (c *CMSketch) SubValue(h1, h2 uint64, count uint64) { } // QueryValue is used to query the count of specified value. -func QueryValue(sctx sessionctx.Context, c *CMSketch, t *TopN, val types.Datum) (uint64, error) { +func QueryValue(sctx context.PlanContext, c *CMSketch, t *TopN, val types.Datum) (uint64, error) { var sc *stmtctx.StatementContext tz := time.UTC if sctx != nil { @@ -289,7 +290,7 @@ func (c *CMSketch) QueryBytes(d []byte) uint64 { } // The input sctx is just for debug trace, you can pass nil safely if that's not needed. -func (c *CMSketch) queryHashValue(sctx sessionctx.Context, h1, h2 uint64) (result uint64) { +func (c *CMSketch) queryHashValue(sctx context.PlanContext, h1, h2 uint64) (result uint64) { vals := make([]uint32, c.depth) originVals := make([]uint32, c.depth) minValue := uint32(math.MaxUint32) @@ -628,7 +629,7 @@ type TopNMeta struct { // QueryTopN returns the results for (h1, h2) in murmur3.Sum128(), if not exists, return (0, false). // The input sctx is just for debug trace, you can pass nil safely if that's not needed. 
-func (c *TopN) QueryTopN(sctx sessionctx.Context, d []byte) (result uint64, found bool) { +func (c *TopN) QueryTopN(sctx context.PlanContext, d []byte) (result uint64, found bool) { if sctx != nil && sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { debugtrace.EnterContextCommon(sctx) defer func() { @@ -692,7 +693,7 @@ func (c *TopN) LowerBound(d []byte) (idx int, match bool) { // BetweenCount estimates the row count for interval [l, r). // The input sctx is just for debug trace, you can pass nil safely if that's not needed. -func (c *TopN) BetweenCount(sctx sessionctx.Context, l, r []byte) (result uint64) { +func (c *TopN) BetweenCount(sctx context.PlanContext, l, r []byte) (result uint64) { if sctx != nil && sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { debugtrace.EnterContextCommon(sctx) defer func() { diff --git a/pkg/statistics/column.go b/pkg/statistics/column.go index 507a72c2e04a3..0f634cd652c91 100644 --- a/pkg/statistics/column.go +++ b/pkg/statistics/column.go @@ -17,8 +17,8 @@ package statistics import ( "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner/context" "github.com/pingcap/tidb/pkg/planner/util/debugtrace" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/logutil" @@ -142,7 +142,7 @@ var HistogramNeededItems = neededStatsMap{items: map[model.TableItemID]struct{}{ // If this column has histogram but not loaded yet, // then we mark it as need histogram. func (c *Column) IsInvalid( - sctx sessionctx.Context, + sctx context.PlanContext, collPseudo bool, ) (res bool) { var totalCount float64 diff --git a/pkg/statistics/debugtrace.go b/pkg/statistics/debugtrace.go index 2b6503a6437f8..3936feafe2fd2 100644 --- a/pkg/statistics/debugtrace.go +++ b/pkg/statistics/debugtrace.go @@ -17,8 +17,8 @@ package statistics import ( "slices" + "github.com/pingcap/tidb/pkg/planner/context" "github.com/pingcap/tidb/pkg/planner/util/debugtrace" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/types" "golang.org/x/exp/maps" ) @@ -183,7 +183,7 @@ type locateBucketInfo struct { } func debugTraceLocateBucket( - s sessionctx.Context, + s context.PlanContext, value *types.Datum, exceed bool, bucketIdx int, @@ -212,7 +212,7 @@ type bucketInfo struct { } // DebugTraceBuckets is used to trace the buckets used in the histogram. 
-func DebugTraceBuckets(s sessionctx.Context, hg *Histogram, bucketIdxs []int) { +func DebugTraceBuckets(s context.PlanContext, hg *Histogram, bucketIdxs []int) { root := debugtrace.GetOrInitDebugTraceRoot(s) buckets := make([]bucketInfo, len(bucketIdxs)) for i := range buckets { @@ -242,7 +242,7 @@ type topNRangeInfo struct { LastIdx int } -func debugTraceTopNRange(s sessionctx.Context, t *TopN, startIdx, endIdx int) { +func debugTraceTopNRange(s context.PlanContext, t *TopN, startIdx, endIdx int) { if endIdx <= startIdx { return } diff --git a/pkg/statistics/histogram.go b/pkg/statistics/histogram.go index c69eb18f3c08a..662889459d1e4 100644 --- a/pkg/statistics/histogram.go +++ b/pkg/statistics/histogram.go @@ -30,8 +30,8 @@ import ( "github.com/pingcap/tidb/pkg/parser/charset" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/planner/context" "github.com/pingcap/tidb/pkg/planner/util/debugtrace" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" statslogutil "github.com/pingcap/tidb/pkg/statistics/handle/logutil" @@ -465,7 +465,7 @@ func (hg *Histogram) ToString(idxCols int) string { // EqualRowCount estimates the row count where the column equals to value. // matched: return true if this returned row count is from Bucket.Repeat or bucket NDV, which is more accurate than if not. // The input sctx is just for debug trace, you can pass nil safely if that's not needed. -func (hg *Histogram) EqualRowCount(sctx sessionctx.Context, value types.Datum, hasBucketNDV bool) (count float64, matched bool) { +func (hg *Histogram) EqualRowCount(sctx context.PlanContext, value types.Datum, hasBucketNDV bool) (count float64, matched bool) { if sctx != nil && sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { debugtrace.EnterContextCommon(sctx) defer func() { @@ -515,7 +515,7 @@ func (hg *Histogram) GreaterRowCount(value types.Datum) float64 { // locateBucket(val2): false, 2, false, false // locateBucket(val3): false, 2, true, false // locateBucket(val4): true, 3, false, false -func (hg *Histogram) LocateBucket(sctx sessionctx.Context, value types.Datum) (exceed bool, bucketIdx int, inBucket, matchLastValue bool) { +func (hg *Histogram) LocateBucket(sctx context.PlanContext, value types.Datum) (exceed bool, bucketIdx int, inBucket, matchLastValue bool) { if sctx != nil && sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { defer func() { debugTraceLocateBucket(sctx, &value, exceed, bucketIdx, inBucket, matchLastValue) @@ -548,7 +548,7 @@ func (hg *Histogram) LocateBucket(sctx sessionctx.Context, value types.Datum) (e // LessRowCountWithBktIdx estimates the row count where the column less than value. // The input sctx is just for debug trace, you can pass nil safely if that's not needed. -func (hg *Histogram) LessRowCountWithBktIdx(sctx sessionctx.Context, value types.Datum) (result float64, bucketIdx int) { +func (hg *Histogram) LessRowCountWithBktIdx(sctx context.PlanContext, value types.Datum) (result float64, bucketIdx int) { if sctx != nil && sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { debugtrace.EnterContextCommon(sctx) defer func() { @@ -583,14 +583,14 @@ func (hg *Histogram) LessRowCountWithBktIdx(sctx sessionctx.Context, value types // LessRowCount estimates the row count where the column less than value. // The input sctx is just for debug trace, you can pass nil safely if that's not needed. 
-func (hg *Histogram) LessRowCount(sctx sessionctx.Context, value types.Datum) float64 { +func (hg *Histogram) LessRowCount(sctx context.PlanContext, value types.Datum) float64 { result, _ := hg.LessRowCountWithBktIdx(sctx, value) return result } // BetweenRowCount estimates the row count where column greater or equal to a and less than b. // The input sctx is just for debug trace, you can pass nil safely if that's not needed. -func (hg *Histogram) BetweenRowCount(sctx sessionctx.Context, a, b types.Datum) float64 { +func (hg *Histogram) BetweenRowCount(sctx context.PlanContext, a, b types.Datum) float64 { lessCountA := hg.LessRowCount(sctx, a) lessCountB := hg.LessRowCount(sctx, b) // If lessCountA is not less than lessCountB, it may be that they fall to the same bucket and we cannot estimate @@ -935,7 +935,7 @@ func (hg *Histogram) OutOfRange(val types.Datum) bool { // leftPercent = (math.Pow(actualR-boundL, 2) - math.Pow(actualL-boundL, 2)) / math.Pow(histWidth, 2) // You can find more details at https://github.com/pingcap/tidb/pull/47966#issuecomment-1778866876 func (hg *Histogram) OutOfRangeRowCount( - sctx sessionctx.Context, + sctx context.PlanContext, lDatum, rDatum *types.Datum, modifyCount, histNDV int64, ) (result float64) { diff --git a/pkg/statistics/index.go b/pkg/statistics/index.go index 6ae66914e8a3e..a7dd513a8d955 100644 --- a/pkg/statistics/index.go +++ b/pkg/statistics/index.go @@ -17,8 +17,8 @@ package statistics import ( "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner/context" "github.com/pingcap/tidb/pkg/planner/util/debugtrace" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/twmb/murmur3" @@ -128,7 +128,7 @@ func (idx *Index) TotalRowCount() float64 { } // IsInvalid checks if this index is invalid. -func (idx *Index) IsInvalid(sctx sessionctx.Context, collPseudo bool) (res bool) { +func (idx *Index) IsInvalid(sctx context.PlanContext, collPseudo bool) (res bool) { idx.CheckStats() var totalCount float64 if sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { @@ -181,7 +181,7 @@ func (idx *Index) MemoryUsage() CacheItemMemoryUsage { // QueryBytes is used to query the count of specified bytes. // The input sctx is just for debug trace, you can pass nil safely if that's not needed. -func (idx *Index) QueryBytes(sctx sessionctx.Context, d []byte) (result uint64) { +func (idx *Index) QueryBytes(sctx context.PlanContext, d []byte) (result uint64) { if sctx != nil && sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { debugtrace.EnterContextCommon(sctx) defer func() { diff --git a/pkg/statistics/table.go b/pkg/statistics/table.go index e488dfff04c5c..b1e62cbfa10ce 100644 --- a/pkg/statistics/table.go +++ b/pkg/statistics/table.go @@ -24,7 +24,7 @@ import ( "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/planner/context" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/ranger" "go.uber.org/atomic" @@ -47,13 +47,13 @@ var ( // Note: all functions below will be removed after finishing moving all estimation functions into the cardinality package. // GetRowCountByIndexRanges is a function type to get row count by index ranges. 
- GetRowCountByIndexRanges func(sctx sessionctx.Context, coll *HistColl, idxID int64, indexRanges []*ranger.Range) (result float64, err error) + GetRowCountByIndexRanges func(sctx context.PlanContext, coll *HistColl, idxID int64, indexRanges []*ranger.Range) (result float64, err error) // GetRowCountByIntColumnRanges is a function type to get row count by int column ranges. - GetRowCountByIntColumnRanges func(sctx sessionctx.Context, coll *HistColl, colID int64, intRanges []*ranger.Range) (result float64, err error) + GetRowCountByIntColumnRanges func(sctx context.PlanContext, coll *HistColl, colID int64, intRanges []*ranger.Range) (result float64, err error) // GetRowCountByColumnRanges is a function type to get row count by column ranges. - GetRowCountByColumnRanges func(sctx sessionctx.Context, coll *HistColl, colID int64, colRanges []*ranger.Range) (result float64, err error) + GetRowCountByColumnRanges func(sctx context.PlanContext, coll *HistColl, colID int64, colRanges []*ranger.Range) (result float64, err error) ) // Table represents statistics for a table. diff --git a/pkg/table/temptable/infoschema.go b/pkg/table/temptable/infoschema.go index c42e2d0854f04..cca9207c394dd 100644 --- a/pkg/table/temptable/infoschema.go +++ b/pkg/table/temptable/infoschema.go @@ -16,11 +16,11 @@ package temptable import ( "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" ) // AttachLocalTemporaryTableInfoSchema attach local temporary table information schema to is -func AttachLocalTemporaryTableInfoSchema(sctx sessionctx.Context, is infoschema.InfoSchema) infoschema.InfoSchema { +func AttachLocalTemporaryTableInfoSchema(sctx variable.SessionVarsProvider, is infoschema.InfoSchema) infoschema.InfoSchema { localTemporaryTables := getLocalTemporaryTables(sctx) if localTemporaryTables == nil { return is @@ -47,7 +47,7 @@ func DetachLocalTemporaryTableInfoSchema(is infoschema.InfoSchema) infoschema.In return is } -func getLocalTemporaryTables(sctx sessionctx.Context) *infoschema.SessionTables { +func getLocalTemporaryTables(sctx variable.SessionVarsProvider) *infoschema.SessionTables { localTemporaryTables := sctx.GetSessionVars().LocalTemporaryTables if localTemporaryTables == nil { return nil @@ -56,7 +56,7 @@ func getLocalTemporaryTables(sctx sessionctx.Context) *infoschema.SessionTables return localTemporaryTables.(*infoschema.SessionTables) } -func ensureLocalTemporaryTables(sctx sessionctx.Context) *infoschema.SessionTables { +func ensureLocalTemporaryTables(sctx variable.SessionVarsProvider) *infoschema.SessionTables { sessVars := sctx.GetSessionVars() if sessVars.LocalTemporaryTables == nil { localTempTables := infoschema.NewSessionTables() diff --git a/pkg/util/context/BUILD.bazel b/pkg/util/context/BUILD.bazel index a879bb3b0ac9e..2a362524e55d7 100644 --- a/pkg/util/context/BUILD.bazel +++ b/pkg/util/context/BUILD.bazel @@ -2,7 +2,10 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library") go_library( name = "context", - srcs = ["warn.go"], + srcs = [ + "context.go", + "warn.go", + ], importpath = "github.com/pingcap/tidb/pkg/util/context", visibility = ["//visibility:public"], ) diff --git a/pkg/util/context/context.go b/pkg/util/context/context.go new file mode 100644 index 0000000000000..cfea2a65ca359 --- /dev/null +++ b/pkg/util/context/context.go @@ -0,0 +1,31 @@ +// Copyright 2024 PingCAP, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "fmt" +) + +// ValueStoreContext is a context that can store values. +type ValueStoreContext interface { + // SetValue saves a value associated with this context for key. + SetValue(key fmt.Stringer, value any) + + // Value returns the value associated with this context for key. + Value(key fmt.Stringer) any + + // ClearValue clears the value associated with this context for key. + ClearValue(key fmt.Stringer) +} diff --git a/pkg/util/ranger/BUILD.bazel b/pkg/util/ranger/BUILD.bazel index 0953357f764ad..f1be0c5c6d121 100644 --- a/pkg/util/ranger/BUILD.bazel +++ b/pkg/util/ranger/BUILD.bazel @@ -12,6 +12,7 @@ go_library( importpath = "github.com/pingcap/tidb/pkg/util/ranger", visibility = ["//visibility:public"], deps = [ + "//pkg/errctx", "//pkg/expression", "//pkg/kv", "//pkg/parser/ast", @@ -20,6 +21,7 @@ go_library( "//pkg/parser/model", "//pkg/parser/mysql", "//pkg/parser/terror", + "//pkg/planner/context", "//pkg/planner/util/fixcontrol", "//pkg/sessionctx", "//pkg/sessionctx/stmtctx", diff --git a/pkg/util/ranger/checker.go b/pkg/util/ranger/checker.go index 7fe2131ed162f..32f2a5c3125d5 100644 --- a/pkg/util/ranger/checker.go +++ b/pkg/util/ranger/checker.go @@ -18,14 +18,13 @@ import ( "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/collate" ) // conditionChecker checks if this condition can be pushed to index planner. type conditionChecker struct { - ctx sessionctx.Context + ctx expression.EvalContext checkerCol *expression.Column length int optPrefixIndexSingleScan bool diff --git a/pkg/util/ranger/detacher.go b/pkg/util/ranger/detacher.go index b1b67222c7765..a1f56792b495a 100644 --- a/pkg/util/ranger/detacher.go +++ b/pkg/util/ranger/detacher.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" + planctx "github.com/pingcap/tidb/pkg/planner/context" "github.com/pingcap/tidb/pkg/planner/util/fixcontrol" "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" @@ -32,7 +33,7 @@ import ( // detachColumnCNFConditions detaches the condition for calculating range from the other conditions. // Please make sure that the top level is CNF form. 
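The new `ValueStoreContext` interface added in `pkg/util/context` above captures just the key/value behaviour that used to sit directly on `sessionctx.Context` (which now embeds it). The following is a minimal sketch of a type satisfying the interface, purely illustrative and not part of the change; `mapStore` is hypothetical.

package example

import (
	"fmt"

	contextutil "github.com/pingcap/tidb/pkg/util/context"
)

// mapStore is a toy ValueStoreContext backed by a map; the real
// implementation lives on the session.
type mapStore struct {
	values map[fmt.Stringer]any
}

func (s *mapStore) SetValue(key fmt.Stringer, value any) {
	if s.values == nil {
		s.values = make(map[fmt.Stringer]any)
	}
	s.values[key] = value
}

func (s *mapStore) Value(key fmt.Stringer) any {
	return s.values[key]
}

func (s *mapStore) ClearValue(key fmt.Stringer) {
	delete(s.values, key)
}

// Compile-time check that the toy type satisfies the new interface.
var _ contextutil.ValueStoreContext = (*mapStore)(nil)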
-func detachColumnCNFConditions(sctx sessionctx.Context, conditions []expression.Expression, checker *conditionChecker) ([]expression.Expression, []expression.Expression) { +func detachColumnCNFConditions(sctx expression.BuildContext, conditions []expression.Expression, checker *conditionChecker) ([]expression.Expression, []expression.Expression) { var accessConditions, filterConditions []expression.Expression //nolint: prealloc for _, cond := range conditions { if sf, ok := cond.(*expression.ScalarFunction); ok && sf.FuncName.L == ast.LogicOr { @@ -65,7 +66,7 @@ func detachColumnCNFConditions(sctx sessionctx.Context, conditions []expression. // detachColumnDNFConditions detaches the condition for calculating range from the other conditions. // Please make sure that the top level is DNF form. -func detachColumnDNFConditions(sctx sessionctx.Context, conditions []expression.Expression, checker *conditionChecker) ([]expression.Expression, bool) { +func detachColumnDNFConditions(sctx expression.BuildContext, conditions []expression.Expression, checker *conditionChecker) ([]expression.Expression, bool) { var ( hasResidualConditions bool accessConditions []expression.Expression @@ -101,7 +102,7 @@ func detachColumnDNFConditions(sctx sessionctx.Context, conditions []expression. // in function which is `column in (constant list)`. // If so, it will return the offset of this column in the slice, otherwise return -1 for not found. // Since combining `x >= 2` and `x <= 2` can lead to an eq condition `x = 2`, we take le/ge/lt/gt into consideration. -func getPotentialEqOrInColOffset(sctx sessionctx.Context, expr expression.Expression, cols []*expression.Column) int { +func getPotentialEqOrInColOffset(sctx planctx.PlanContext, expr expression.Expression, cols []*expression.Column) int { f, ok := expr.(*expression.ScalarFunction) if !ok { return -1 @@ -195,7 +196,7 @@ type cnfItemRangeResult struct { minColNum int } -func getCNFItemRangeResult(sctx sessionctx.Context, rangeResult *DetachRangeResult, offset int) *cnfItemRangeResult { +func getCNFItemRangeResult(sctx planctx.PlanContext, rangeResult *DetachRangeResult, offset int) *cnfItemRangeResult { sameLenPointRanges := true var maxColNum, minColNum int for i, ran := range rangeResult.Ranges { @@ -240,7 +241,7 @@ func compareCNFItemRangeResult(curResult, bestResult *cnfItemRangeResult) (curIs // item ranges. // e.g, for input CNF expressions ((a,b) in ((1,1),(2,2))) and a > 1 and ((a,b,c) in (1,1,1),(2,2,2)) // ((a,b,c) in (1,1,1),(2,2,2)) would be extracted. -func extractBestCNFItemRanges(sctx sessionctx.Context, conds []expression.Expression, cols []*expression.Column, +func extractBestCNFItemRanges(sctx planctx.PlanContext, conds []expression.Expression, cols []*expression.Column, lengths []int, rangeMaxSize int64, convertToSortKey bool) (*cnfItemRangeResult, []*valueInfo, error) { if len(conds) < 2 { return nil, nil, nil @@ -605,7 +606,7 @@ func extractValueInfo(expr expression.Expression) *valueInfo { // // columnValues: the constant column values for all index columns. columnValues[i] is nil if cols[i] is not constant. // bool: indicate whether there's nil range when merging eq and in conditions. 
-func ExtractEqAndInCondition(sctx sessionctx.Context, conditions []expression.Expression, cols []*expression.Column, +func ExtractEqAndInCondition(sctx planctx.PlanContext, conditions []expression.Expression, cols []*expression.Column, lengths []int) ([]expression.Expression, []expression.Expression, []expression.Expression, []*valueInfo, bool) { var filters []expression.Expression rb := builder{sctx: sctx} @@ -864,7 +865,7 @@ type DetachRangeResult struct { // rangeMaxSize is the max memory limit for ranges. O indicates no memory limit. If you ask that all conditions must be used // for building ranges, set rangeMemQuota to 0 to avoid range fallback. // The returned values are encapsulated into a struct DetachRangeResult, see its comments for explanation. -func DetachCondAndBuildRangeForIndex(sctx sessionctx.Context, conditions []expression.Expression, cols []*expression.Column, +func DetachCondAndBuildRangeForIndex(sctx planctx.PlanContext, conditions []expression.Expression, cols []*expression.Column, lengths []int, rangeMaxSize int64) (*DetachRangeResult, error) { d := &rangeDetacher{ sctx: sctx, @@ -880,7 +881,7 @@ func DetachCondAndBuildRangeForIndex(sctx sessionctx.Context, conditions []expre // detachCondAndBuildRangeWithoutMerging detaches the index filters from table filters and uses them to build ranges. // When building ranges, it doesn't merge consecutive ranges. -func detachCondAndBuildRangeWithoutMerging(sctx sessionctx.Context, conditions []expression.Expression, cols []*expression.Column, +func detachCondAndBuildRangeWithoutMerging(sctx planctx.PlanContext, conditions []expression.Expression, cols []*expression.Column, lengths []int, rangeMaxSize int64, convertToSortKey bool) (*DetachRangeResult, error) { d := &rangeDetacher{ sctx: sctx, @@ -898,13 +899,13 @@ func detachCondAndBuildRangeWithoutMerging(sctx sessionctx.Context, conditions [ // rangeMaxSize is the max memory limit for ranges. O indicates no memory limit. If you ask that all conditions must be used // for building ranges, set rangeMemQuota to 0 to avoid range fallback. // The returned values are encapsulated into a struct DetachRangeResult, see its comments for explanation. -func DetachCondAndBuildRangeForPartition(sctx sessionctx.Context, conditions []expression.Expression, cols []*expression.Column, +func DetachCondAndBuildRangeForPartition(sctx planctx.PlanContext, conditions []expression.Expression, cols []*expression.Column, lengths []int, rangeMaxSize int64) (*DetachRangeResult, error) { return detachCondAndBuildRangeWithoutMerging(sctx, conditions, cols, lengths, rangeMaxSize, false) } type rangeDetacher struct { - sctx sessionctx.Context + sctx planctx.PlanContext allConds []expression.Expression cols []*expression.Column lengths []int @@ -986,7 +987,7 @@ func AppendConditionsIfNotExist(conditions, condsToAppend []expression.Expressio // ExtractAccessConditionsForColumn extracts the access conditions used for range calculation. Since // we don't need to return the remained filter conditions, it is much simpler than DetachCondsForColumn. 
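Since the detacher's exported entry points now accept `planctx.PlanContext` directly, planner callers no longer have to supply a full `sessionctx.Context`. A hedged usage sketch follows; `buildIndexRanges` is hypothetical, and the `0` passed for `rangeMaxSize` follows the "no memory limit" convention described in the comments above.

package example

import (
	"github.com/pingcap/tidb/pkg/expression"
	planctx "github.com/pingcap/tidb/pkg/planner/context"
	"github.com/pingcap/tidb/pkg/util/ranger"
)

// buildIndexRanges detaches the given conditions over the index columns and
// returns only the resulting ranges.
func buildIndexRanges(sctx planctx.PlanContext, conds []expression.Expression,
	cols []*expression.Column, lengths []int) (ranger.Ranges, error) {
	res, err := ranger.DetachCondAndBuildRangeForIndex(sctx, conds, cols, lengths, 0)
	if err != nil {
		return nil, err
	}
	return res.Ranges, nil
}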
-func ExtractAccessConditionsForColumn(ctx sessionctx.Context, conds []expression.Expression, col *expression.Column) []expression.Expression { +func ExtractAccessConditionsForColumn(ctx planctx.PlanContext, conds []expression.Expression, col *expression.Column) []expression.Expression { checker := conditionChecker{ checkerCol: col, length: types.UnspecifiedLength, @@ -1002,7 +1003,7 @@ func ExtractAccessConditionsForColumn(ctx sessionctx.Context, conds []expression } // DetachCondsForColumn detaches access conditions for specified column from other filter conditions. -func DetachCondsForColumn(sctx sessionctx.Context, conds []expression.Expression, col *expression.Column) (accessConditions, otherConditions []expression.Expression) { +func DetachCondsForColumn(sctx planctx.PlanContext, conds []expression.Expression, col *expression.Column) (accessConditions, otherConditions []expression.Expression) { checker := &conditionChecker{ checkerCol: col, length: types.UnspecifiedLength, @@ -1014,7 +1015,7 @@ func DetachCondsForColumn(sctx sessionctx.Context, conds []expression.Expression // MergeDNFItems4Col receives a slice of DNF conditions, merges some of them which can be built into ranges on a single column, then returns. // For example, [a > 5, b > 6, c > 7, a = 1, b > 3] will become [a > 5 or a = 1, b > 6 or b > 3, c > 7]. -func MergeDNFItems4Col(ctx sessionctx.Context, dnfItems []expression.Expression) []expression.Expression { +func MergeDNFItems4Col(ctx planctx.PlanContext, dnfItems []expression.Expression) []expression.Expression { mergedDNFItems := make([]expression.Expression, 0, len(dnfItems)) col2DNFItems := make(map[int64][]expression.Expression) for _, dnfItem := range dnfItems { @@ -1067,7 +1068,7 @@ func MergeDNFItems4Col(ctx sessionctx.Context, dnfItems []expression.Expression) // @retval - []expression.Expression the new conditions after adding `tidb_shard() = xxx` prefix // // error if error gernerated, return error -func AddGcColumnCond(sctx sessionctx.Context, +func AddGcColumnCond(sctx planctx.PlanContext, cols []*expression.Column, accessesCond []expression.Expression, columnValues []*valueInfo) ([]expression.Expression, error) { @@ -1090,7 +1091,7 @@ func AddGcColumnCond(sctx sessionctx.Context, // @retval - []expression.Expression the new conditions after adding `tidb_shard() = xxx` prefix // // error if error gernerated, return error -func AddGcColumn4InCond(sctx sessionctx.Context, +func AddGcColumn4InCond(sctx planctx.PlanContext, cols []*expression.Column, accessesCond []expression.Expression) ([]expression.Expression, error) { var errRes error @@ -1158,7 +1159,7 @@ func AddGcColumn4InCond(sctx sessionctx.Context, // // []*valueInfo the values of every columns in the returned new conditions // error if error gernerated, return error -func AddGcColumn4EqCond(sctx sessionctx.Context, +func AddGcColumn4EqCond(sctx planctx.PlanContext, cols []*expression.Column, accessesCond []expression.Expression, columnValues []*valueInfo) ([]expression.Expression, error) { @@ -1200,7 +1201,7 @@ func AddGcColumn4EqCond(sctx sessionctx.Context, // @param[in] cols the columns of shard index, such as [tidb_shard(a), a, ...] 
// @param[in] lengths the length for every column of shard index // @retval - the new condition after adding tidb_shard() prefix -func AddExpr4EqAndInCondition(sctx sessionctx.Context, conditions []expression.Expression, +func AddExpr4EqAndInCondition(sctx planctx.PlanContext, conditions []expression.Expression, cols []*expression.Column) ([]expression.Expression, error) { accesses := make([]expression.Expression, len(cols)) columnValues := make([]*valueInfo, len(cols)) diff --git a/pkg/util/ranger/points.go b/pkg/util/ranger/points.go index 0d3b4e8e26893..05752d6f76ea4 100644 --- a/pkg/util/ranger/points.go +++ b/pkg/util/ranger/points.go @@ -25,7 +25,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/charset" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" + planctx "github.com/pingcap/tidb/pkg/planner/context" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/collate" @@ -130,7 +130,7 @@ func rangePointEqualValueLess(a, b *point) bool { return a.excl && !b.excl } -func pointsConvertToSortKey(sctx sessionctx.Context, inputPs []*point, newTp *types.FieldType) ([]*point, error) { +func pointsConvertToSortKey(sctx planctx.PlanContext, inputPs []*point, newTp *types.FieldType) ([]*point, error) { // Only handle normal string type here. // Currently, set won't be pushed down and it shouldn't reach here in theory. // For enum, we have separate logic for it, like handleEnumFromBinOp(). For now, it only supports point range, @@ -152,7 +152,7 @@ func pointsConvertToSortKey(sctx sessionctx.Context, inputPs []*point, newTp *ty } func pointConvertToSortKey( - sctx sessionctx.Context, + sctx planctx.PlanContext, inputP *point, newTp *types.FieldType, trimTrailingSpace bool, @@ -223,7 +223,7 @@ func NullRange() Ranges { // builder is the range builder struct. type builder struct { err error - sctx sessionctx.Context + sctx planctx.PlanContext } // build converts Expression on one column into point, which can be further built into Range. 
diff --git a/pkg/util/ranger/ranger.go b/pkg/util/ranger/ranger.go index 01a381c5f6b26..02498a340afc7 100644 --- a/pkg/util/ranger/ranger.go +++ b/pkg/util/ranger/ranger.go @@ -19,9 +19,11 @@ import ( "math" "regexp" "slices" + "time" "unicode/utf8" "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/errctx" "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/parser/ast" @@ -29,7 +31,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/format" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" + planctx "github.com/pingcap/tidb/pkg/planner/context" "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/types" driver "github.com/pingcap/tidb/pkg/types/parser_driver" @@ -37,18 +39,17 @@ import ( "github.com/pingcap/tidb/pkg/util/collate" ) -func validInterval(sctx sessionctx.Context, low, high *point) (bool, error) { - sc := sctx.GetSessionVars().StmtCtx - l, err := codec.EncodeKey(sc.TimeZone(), nil, low.value) - err = sc.HandleError(err) +func validInterval(ec errctx.Context, loc *time.Location, low, high *point) (bool, error) { + l, err := codec.EncodeKey(loc, nil, low.value) + err = ec.HandleError(err) if err != nil { return false, errors.Trace(err) } if low.excl { l = kv.Key(l).PrefixNext() } - r, err := codec.EncodeKey(sc.TimeZone(), nil, high.value) - err = sc.HandleError(err) + r, err := codec.EncodeKey(loc, nil, high.value) + err = ec.HandleError(err) if err != nil { return false, errors.Trace(err) } @@ -60,7 +61,7 @@ func validInterval(sctx sessionctx.Context, low, high *point) (bool, error) { // convertPoints does some preprocessing on rangePoints to make them ready to build ranges. Preprocessing includes converting // points to the specified type, validating intervals and skipping impossible intervals. -func convertPoints(sctx sessionctx.Context, rangePoints []*point, newTp *types.FieldType, skipNull bool, tableRange bool) ([]*point, error) { +func convertPoints(sctx planctx.PlanContext, rangePoints []*point, newTp *types.FieldType, skipNull bool, tableRange bool) ([]*point, error) { i := 0 numPoints := len(rangePoints) var minValueDatum, maxValueDatum types.Datum @@ -100,7 +101,8 @@ func convertPoints(sctx sessionctx.Context, rangePoints []*point, newTp *types.F if skipNull && endPoint.value.Kind() == types.KindNull { continue } - less, err := validInterval(sctx, startPoint, endPoint) + sc := sctx.GetSessionVars().StmtCtx + less, err := validInterval(sc.ErrCtx(), sc.TimeZone(), startPoint, endPoint) if err != nil { return nil, errors.Trace(err) } @@ -124,7 +126,7 @@ func estimateMemUsageForPoints2Ranges(rangePoints []*point) int64 { // Only one column is built there. If there're multiple columns, use appendPoints2Ranges. // rangeMaxSize is the max memory limit for ranges. O indicates no memory limit. // If the second return value is true, it means that the estimated memory usage of ranges exceeds rangeMaxSize and it falls back to full range. 
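After the `validInterval` change above, encoding helpers in this file need only an error context and a time zone rather than the whole session context. The sketch below reuses nothing beyond the `codec.EncodeKey` and `errctx.Context.HandleError` calls already present in the hunk; `encodeDatum` is hypothetical.

package example

import (
	"time"

	"github.com/pingcap/tidb/pkg/errctx"
	"github.com/pingcap/tidb/pkg/types"
	"github.com/pingcap/tidb/pkg/util/codec"
)

// encodeDatum encodes a single datum into a key, routing any encoding error
// through the error context exactly as validInterval does above.
func encodeDatum(ec errctx.Context, loc *time.Location, d types.Datum) ([]byte, error) {
	key, err := codec.EncodeKey(loc, nil, d)
	if err = ec.HandleError(err); err != nil {
		return nil, err
	}
	return key, nil
}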
-func points2Ranges(sctx sessionctx.Context, rangePoints []*point, newTp *types.FieldType, rangeMaxSize int64) (Ranges, bool, error) { +func points2Ranges(sctx planctx.PlanContext, rangePoints []*point, newTp *types.FieldType, rangeMaxSize int64) (Ranges, bool, error) { convertedPoints, err := convertPoints(sctx, rangePoints, newTp, mysql.HasNotNullFlag(newTp.GetFlag()), false) if err != nil { return nil, false, errors.Trace(err) @@ -154,7 +156,7 @@ func points2Ranges(sctx sessionctx.Context, rangePoints []*point, newTp *types.F return ranges, false, nil } -func convertPoint(sctx sessionctx.Context, point *point, newTp *types.FieldType) (*point, error) { +func convertPoint(sctx planctx.PlanContext, point *point, newTp *types.FieldType) (*point, error) { sc := sctx.GetSessionVars().StmtCtx switch point.value.Kind() { case types.KindMaxValue, types.KindMinNotNull: @@ -271,7 +273,7 @@ func estimateMemUsageForAppendPoints2Ranges(origin Ranges, rangePoints []*point) // rangeMaxSize is the max memory limit for ranges. O indicates no memory limit. // If the second return value is true, it means that the estimated memory usage of ranges after appending points exceeds // rangeMaxSize and the function rejects appending points to ranges. -func appendPoints2Ranges(sctx sessionctx.Context, origin Ranges, rangePoints []*point, +func appendPoints2Ranges(sctx planctx.PlanContext, origin Ranges, rangePoints []*point, newTp *types.FieldType, rangeMaxSize int64) (Ranges, bool, error) { convertedPoints, err := convertPoints(sctx, rangePoints, newTp, false, false) if err != nil { @@ -384,7 +386,7 @@ func AppendRanges2PointRanges(pointRanges Ranges, ranges Ranges, rangeMaxSize in // It will remove the nil and convert MinNotNull and MaxValue to MinInt64 or MinUint64 and MaxInt64 or MaxUint64. // rangeMaxSize is the max memory limit for ranges. O indicates no memory limit. // If the second return value is true, it means that the estimated memory usage of ranges exceeds rangeMaxSize and it falls back to full range. -func points2TableRanges(sctx sessionctx.Context, rangePoints []*point, newTp *types.FieldType, rangeMaxSize int64) (Ranges, bool, error) { +func points2TableRanges(sctx planctx.PlanContext, rangePoints []*point, newTp *types.FieldType, rangeMaxSize int64) (Ranges, bool, error) { convertedPoints, err := convertPoints(sctx, rangePoints, newTp, true, true) if err != nil { return nil, false, errors.Trace(err) @@ -410,7 +412,7 @@ func points2TableRanges(sctx sessionctx.Context, rangePoints []*point, newTp *ty // buildColumnRange builds range from CNF conditions. // rangeMaxSize is the max memory limit for ranges. O indicates no memory limit. // The second return value is the conditions used to build ranges and the third return value is the remained conditions. -func buildColumnRange(accessConditions []expression.Expression, sctx sessionctx.Context, tp *types.FieldType, tableRange bool, +func buildColumnRange(accessConditions []expression.Expression, sctx planctx.PlanContext, tp *types.FieldType, tableRange bool, colLen int, rangeMaxSize int64) (Ranges, []expression.Expression, []expression.Expression, error) { rb := builder{sctx: sctx} newTp := newFieldType(tp) @@ -455,7 +457,7 @@ func buildColumnRange(accessConditions []expression.Expression, sctx sessionctx. // The second return value is the conditions used to build ranges and the third return value is the remained conditions. 
// If you use the function to build ranges for some access path, you need to update the path's access conditions and filter // conditions by the second and third return values respectively. -func BuildTableRange(accessConditions []expression.Expression, sctx sessionctx.Context, tp *types.FieldType, +func BuildTableRange(accessConditions []expression.Expression, sctx planctx.PlanContext, tp *types.FieldType, rangeMaxSize int64) (Ranges, []expression.Expression, []expression.Expression, error) { return buildColumnRange(accessConditions, sctx, tp, true, types.UnspecifiedLength, rangeMaxSize) } @@ -466,7 +468,7 @@ func BuildTableRange(accessConditions []expression.Expression, sctx sessionctx.C // The second return value is the conditions used to build ranges and the third return value is the remained conditions. // If you use the function to build ranges for some access path, you need to update the path's access conditions and filter // conditions by the second and third return values respectively. -func BuildColumnRange(conds []expression.Expression, sctx sessionctx.Context, tp *types.FieldType, colLen int, +func BuildColumnRange(conds []expression.Expression, sctx planctx.PlanContext, tp *types.FieldType, colLen int, rangeMemQuota int64) (Ranges, []expression.Expression, []expression.Expression, error) { if len(conds) == 0 { return FullRange(), nil, nil, nil @@ -578,7 +580,7 @@ type sortRange struct { // For two intervals [a, b], [c, d], we have guaranteed that a <= c. If b >= c. Then two intervals are overlapped. // And this two can be merged as [a, max(b, d)]. // Otherwise they aren't overlapped. -func UnionRanges(sctx sessionctx.Context, ranges Ranges, mergeConsecutive bool) (Ranges, error) { +func UnionRanges(sctx planctx.PlanContext, ranges Ranges, mergeConsecutive bool) (Ranges, error) { sc := sctx.GetSessionVars().StmtCtx if len(ranges) == 0 { return nil, nil @@ -725,7 +727,7 @@ func newFieldType(tp *types.FieldType) *types.FieldType { // 'points'. `col` is the target column to construct the Equal or In condition. // NOTE: // 1. 'points' should not be empty. -func points2EqOrInCond(ctx sessionctx.Context, points []*point, col *expression.Column) expression.Expression { +func points2EqOrInCond(ctx expression.BuildContext, points []*point, col *expression.Column) expression.Expression { // len(points) cannot be 0 here, since we impose early termination in ExtractEqAndInCondition // Constant and Column args should have same RetType, simply get from first arg retType := col.GetType() diff --git a/pkg/util/ranger/types.go b/pkg/util/ranger/types.go index d55e14c04d2b2..dab82fabc2813 100644 --- a/pkg/util/ranger/types.go +++ b/pkg/util/ranger/types.go @@ -18,12 +18,13 @@ import ( "fmt" "math" "strings" + "time" "unsafe" "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/errctx" "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + planctx "github.com/pingcap/tidb/pkg/planner/context" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/codec" "github.com/pingcap/tidb/pkg/util/collate" @@ -95,7 +96,7 @@ func (ran *Range) Clone() *Range { } // IsPoint returns if the range is a point. 
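As the comment on `BuildColumnRange` above asks, callers are expected to take the second and third return values and use them to update the access path's access and filter conditions. A hedged sketch of such a caller, now passing a `planctx.PlanContext`; `columnRangeAndConds` is hypothetical, and the `0` quota again means no memory limit per the comments above.

package example

import (
	"github.com/pingcap/tidb/pkg/expression"
	planctx "github.com/pingcap/tidb/pkg/planner/context"
	"github.com/pingcap/tidb/pkg/types"
	"github.com/pingcap/tidb/pkg/util/ranger"
)

// columnRangeAndConds builds ranges on a single column and hands back the
// access/remained condition split for the caller to store on its access path.
func columnRangeAndConds(sctx planctx.PlanContext, conds []expression.Expression,
	tp *types.FieldType) (ranges ranger.Ranges, access, remained []expression.Expression, err error) {
	return ranger.BuildColumnRange(conds, sctx, tp, types.UnspecifiedLength, 0)
}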
-func (ran *Range) IsPoint(sctx sessionctx.Context) bool { +func (ran *Range) IsPoint(sctx planctx.PlanContext) bool { return ran.isPoint(sctx.GetSessionVars().StmtCtx.TypeCtx(), sctx.GetSessionVars().RegardNULLAsPoint) } @@ -193,18 +194,18 @@ func (ran *Range) String() string { } // Encode encodes the range to its encoded value. -func (ran *Range) Encode(sc *stmtctx.StatementContext, lowBuffer, highBuffer []byte) ([]byte, []byte, error) { +func (ran *Range) Encode(ec errctx.Context, loc *time.Location, lowBuffer, highBuffer []byte) ([]byte, []byte, error) { var err error - lowBuffer, err = codec.EncodeKey(sc.TimeZone(), lowBuffer[:0], ran.LowVal...) - err = sc.HandleError(err) + lowBuffer, err = codec.EncodeKey(loc, lowBuffer[:0], ran.LowVal...) + err = ec.HandleError(err) if err != nil { return nil, nil, err } if ran.LowExclude { lowBuffer = kv.Key(lowBuffer).PrefixNext() } - highBuffer, err = codec.EncodeKey(sc.TimeZone(), highBuffer[:0], ran.HighVal...) - err = sc.HandleError(err) + highBuffer, err = codec.EncodeKey(loc, highBuffer[:0], ran.HighVal...) + err = ec.HandleError(err) if err != nil { return nil, nil, err } @@ -216,10 +217,10 @@ func (ran *Range) Encode(sc *stmtctx.StatementContext, lowBuffer, highBuffer []b // PrefixEqualLen tells you how long the prefix of the range is a point. // e.g. If this range is (1 2 3, 1 2 +inf), then the return value is 2. -func (ran *Range) PrefixEqualLen(sc *stmtctx.StatementContext) (int, error) { +func (ran *Range) PrefixEqualLen(tc types.Context) (int, error) { // Here, len(ran.LowVal) always equal to len(ran.HighVal) for i := 0; i < len(ran.LowVal); i++ { - cmp, err := ran.LowVal[i].Compare(sc.TypeCtx(), &ran.HighVal[i], ran.Collators[i]) + cmp, err := ran.LowVal[i].Compare(tc, &ran.HighVal[i], ran.Collators[i]) if err != nil { return 0, errors.Trace(err) }
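Callers that used to hand `Encode` and `PrefixEqualLen` a `*stmtctx.StatementContext` now pass the narrower pieces themselves. Those call sites are not part of this hunk, so the following is only a sketch of the expected adaptation, built from the `ErrCtx()`, `TimeZone()` and `TypeCtx()` accessors already used elsewhere in this change; `encodeAndMeasure` is hypothetical.

package example

import (
	"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
	"github.com/pingcap/tidb/pkg/util/ranger"
)

// encodeAndMeasure encodes a range and reports how long its point prefix is,
// feeding the new narrow arguments from the statement context it already holds.
func encodeAndMeasure(sc *stmtctx.StatementContext, ran *ranger.Range) (low, high []byte, prefixLen int, err error) {
	low, high, err = ran.Encode(sc.ErrCtx(), sc.TimeZone(), nil, nil)
	if err != nil {
		return nil, nil, 0, err
	}
	prefixLen, err = ran.PrefixEqualLen(sc.TypeCtx())
	if err != nil {
		return nil, nil, 0, err
	}
	return low, high, prefixLen, nil
}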