Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner: introduce PlanContext to provide context for planner phase #51074

Merged
merged 2 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/domain/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ go_library(
"//pkg/types",
"//pkg/util",
"//pkg/util/chunk",
"//pkg/util/context",
"//pkg/util/dbterror",
"//pkg/util/disttask",
"//pkg/util/domainutil",
Expand Down
6 changes: 3 additions & 3 deletions pkg/domain/domainctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
package domain

import (
"github.com/pingcap/tidb/pkg/sessionctx"
contextutil "github.com/pingcap/tidb/pkg/util/context"
)

// domainKeyType is a dummy type to avoid naming collision in context.
Expand All @@ -29,12 +29,12 @@ func (domainKeyType) String() string {
const domainKey domainKeyType = 0

// BindDomain binds domain to context.
func BindDomain(ctx sessionctx.Context, domain *Domain) {
func BindDomain(ctx contextutil.ValueStoreContext, domain *Domain) {
ctx.SetValue(domainKey, domain)
}

// GetDomain gets domain from context.
func GetDomain(ctx sessionctx.Context) *Domain {
func GetDomain(ctx contextutil.ValueStoreContext) *Domain {
v, ok := ctx.Value(domainKey).(*Domain)
if !ok {
return nil
Expand Down
1 change: 1 addition & 0 deletions pkg/executor/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ go_library(
"//pkg/parser/types",
"//pkg/planner",
"//pkg/planner/cardinality",
"//pkg/planner/context",
"//pkg/planner/core",
"//pkg/planner/util",
"//pkg/planner/util/fixcontrol",
Expand Down
2 changes: 1 addition & 1 deletion pkg/executor/compact_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ const (
func getTiFlashStores(ctx sessionctx.Context) ([]infoschema.ServerInfo, error) {
// TODO: Don't use infoschema, to preserve StoreID information.
aliveTiFlashStores := make([]infoschema.ServerInfo, 0)
stores, err := infoschema.GetStoreServerInfo(ctx)
stores, err := infoschema.GetStoreServerInfo(ctx.GetStore())
if err != nil {
return nil, err
}
Expand Down
12 changes: 10 additions & 2 deletions pkg/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import (
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/parser/terror"
planctx "github.com/pingcap/tidb/pkg/planner/context"
plannercore "github.com/pingcap/tidb/pkg/planner/core"
"github.com/pingcap/tidb/pkg/planner/util/fixcontrol"
"github.com/pingcap/tidb/pkg/privilege"
Expand All @@ -74,6 +75,7 @@ import (
"github.com/pingcap/tidb/pkg/util/deadlockhistory"
"github.com/pingcap/tidb/pkg/util/disk"
"github.com/pingcap/tidb/pkg/util/execdetails"
"github.com/pingcap/tidb/pkg/util/intest"
"github.com/pingcap/tidb/pkg/util/logutil"
"github.com/pingcap/tidb/pkg/util/logutil/consistency"
"github.com/pingcap/tidb/pkg/util/memory"
Expand Down Expand Up @@ -1457,9 +1459,9 @@ func init() {
// While doing optimization in the plan package, we need to execute uncorrelated subquery,
// but the plan package cannot import the executor package because of the dependency cycle.
// So we assign a function implemented in the executor package to the plan package to avoid the dependency cycle.
plannercore.EvalSubqueryFirstRow = func(ctx context.Context, p plannercore.PhysicalPlan, is infoschema.InfoSchema, sctx sessionctx.Context) ([]types.Datum, error) {
plannercore.EvalSubqueryFirstRow = func(ctx context.Context, p plannercore.PhysicalPlan, is infoschema.InfoSchema, pctx planctx.PlanContext) ([]types.Datum, error) {
defer func(begin time.Time) {
s := sctx.GetSessionVars()
s := pctx.GetSessionVars()
s.StmtCtx.SetSkipPlanCache(errors.NewNoStackError("query has uncorrelated sub-queries is un-cacheable"))
s.RewritePhaseInfo.PreprocessSubQueries++
s.RewritePhaseInfo.DurationPreprocessSubQuery += time.Since(begin)
Expand All @@ -1468,6 +1470,12 @@ func init() {
r, ctx := tracing.StartRegionEx(ctx, "executor.EvalSubQuery")
defer r.End()

sctx, ok := pctx.(sessionctx.Context)
intest.Assert(ok)
if !ok {
return nil, errors.New("plan context should be sessionctx.Context to EvalSubqueryFirstRow")
}

e := newExecutorBuilder(sctx, is, nil)
executor := e.build(p)
if e.err != nil {
Expand Down
2 changes: 1 addition & 1 deletion pkg/executor/infoschema_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -3130,7 +3130,7 @@ func (e *TiFlashSystemTableRetriever) retrieve(ctx context.Context, sctx session
}

func (e *TiFlashSystemTableRetriever) initialize(sctx sessionctx.Context, tiflashInstances set.StringSet) error {
storeInfo, err := infoschema.GetStoreServerInfo(sctx)
storeInfo, err := infoschema.GetStoreServerInfo(sctx.GetStore())
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/executor/memtable_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ func (e *clusterServerInfoRetriever) retrieve(ctx context.Context, sctx sessionc
return nil, err
}
serversInfo = infoschema.FilterClusterServerInfo(serversInfo, e.extractor.NodeTypes, e.extractor.Instances)
return infoschema.FetchClusterServerInfoWithoutPrivilegeCheck(ctx, sctx, serversInfo, e.serverInfoType, true)
return infoschema.FetchClusterServerInfoWithoutPrivilegeCheck(ctx, sctx.GetSessionVars(), serversInfo, e.serverInfoType, true)
}

func parseFailpointServerInfo(s string) []infoschema.ServerInfo {
Expand Down
2 changes: 1 addition & 1 deletion pkg/executor/metrics_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func (e *MetricRetriever) queryMetric(ctx context.Context, sctx sessionctx.Conte
promQLAPI := promv1.NewAPI(promClient)
ctx, cancel := context.WithTimeout(ctx, promReadTimeout)
defer cancel()
promQL := e.tblDef.GenPromQL(sctx, e.extractor.LabelConditions, quantile)
promQL := e.tblDef.GenPromQL(sctx.GetSessionVars().MetricSchemaRangeDuration, e.extractor.LabelConditions, quantile)

// Add retry to avoid network error.
for i := 0; i < 5; i++ {
Expand Down
5 changes: 2 additions & 3 deletions pkg/infoschema/metrics_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"github.com/pingcap/tidb/pkg/meta/autoid"
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/table"
"github.com/pingcap/tidb/pkg/util"
"github.com/pingcap/tidb/pkg/util/set"
Expand Down Expand Up @@ -102,11 +101,11 @@ func (def *MetricTableDef) genColumnInfos() []columnInfo {
}

// GenPromQL generates the promQL.
func (def *MetricTableDef) GenPromQL(sctx sessionctx.Context, labels map[string]set.StringSet, quantile float64) string {
func (def *MetricTableDef) GenPromQL(metricsSchemaRangeDuration int64, labels map[string]set.StringSet, quantile float64) string {
promQL := def.PromQL
promQL = strings.ReplaceAll(promQL, promQLQuantileKey, strconv.FormatFloat(quantile, 'f', -1, 64))
promQL = strings.ReplaceAll(promQL, promQLLabelConditionKey, def.genLabelCondition(labels))
promQL = strings.ReplaceAll(promQL, promQRangeDurationKey, strconv.FormatInt(sctx.GetSessionVars().MetricSchemaRangeDuration, 10)+"s")
promQL = strings.ReplaceAll(promQL, promQRangeDurationKey, strconv.FormatInt(metricsSchemaRangeDuration, 10)+"s")
return promQL
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/infoschema/perfschema/tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ func dataForRemoteProfile(ctx sessionctx.Context, nodeType, uri string, isGorout
)
switch nodeType {
case "tikv":
servers, err = infoschema.GetStoreServerInfo(ctx)
servers, err = infoschema.GetStoreServerInfo(ctx.GetStore())
case "pd":
servers, err = infoschema.GetPDServerInfo(ctx)
default:
Expand Down
17 changes: 9 additions & 8 deletions pkg/infoschema/tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -1806,7 +1806,9 @@ func GetClusterServerInfo(ctx sessionctx.Context) ([]ServerInfo, error) {
type retriever func(ctx sessionctx.Context) ([]ServerInfo, error)
//nolint: prealloc
var servers []ServerInfo
for _, r := range []retriever{GetTiDBServerInfo, GetPDServerInfo, GetStoreServerInfo, GetTiProxyServerInfo} {
for _, r := range []retriever{GetTiDBServerInfo, GetPDServerInfo, func(ctx sessionctx.Context) ([]ServerInfo, error) {
return GetStoreServerInfo(ctx.GetStore())
}, GetTiProxyServerInfo} {
nodes, err := r(ctx)
if err != nil {
return nil, err
Expand Down Expand Up @@ -1968,7 +1970,7 @@ func isTiFlashWriteNode(store *metapb.Store) bool {
}

// GetStoreServerInfo returns all store nodes(TiKV or TiFlash) cluster information
func GetStoreServerInfo(ctx sessionctx.Context) ([]ServerInfo, error) {
func GetStoreServerInfo(store kv.Storage) ([]ServerInfo, error) {
failpoint.Inject("mockStoreServerInfo", func(val failpoint.Value) {
if s := val.(string); len(s) > 0 {
var servers []ServerInfo
Expand All @@ -1987,7 +1989,6 @@ func GetStoreServerInfo(ctx sessionctx.Context) ([]ServerInfo, error) {
}
})

store := ctx.GetStore()
// Get TiKV servers info.
tikvStore, ok := store.(tikv.Storage)
if !ok {
Expand Down Expand Up @@ -2051,7 +2052,7 @@ func GetTiFlashStoreCount(ctx sessionctx.Context) (cnt uint64, err error) {
}
})

stores, err := GetStoreServerInfo(ctx)
stores, err := GetStoreServerInfo(ctx.GetStore())
if err != nil {
return cnt, err
}
Expand Down Expand Up @@ -2426,11 +2427,11 @@ func (vt *VirtualTable) Type() table.Type {
}

// GetTiFlashServerInfo returns all TiFlash server infos
func GetTiFlashServerInfo(sctx sessionctx.Context) ([]ServerInfo, error) {
func GetTiFlashServerInfo(store kv.Storage) ([]ServerInfo, error) {
if config.GetGlobalConfig().DisaggregatedTiFlash {
return nil, table.ErrUnsupportedOp
}
serversInfo, err := GetStoreServerInfo(sctx)
serversInfo, err := GetStoreServerInfo(store)
if err != nil {
return nil, err
}
Expand All @@ -2439,7 +2440,7 @@ func GetTiFlashServerInfo(sctx sessionctx.Context) ([]ServerInfo, error) {
}

// FetchClusterServerInfoWithoutPrivilegeCheck fetches cluster server information
func FetchClusterServerInfoWithoutPrivilegeCheck(ctx context.Context, sctx sessionctx.Context, serversInfo []ServerInfo, serverInfoType diagnosticspb.ServerInfoType, recordWarningInStmtCtx bool) ([][]types.Datum, error) {
func FetchClusterServerInfoWithoutPrivilegeCheck(ctx context.Context, vars *variable.SessionVars, serversInfo []ServerInfo, serverInfoType diagnosticspb.ServerInfoType, recordWarningInStmtCtx bool) ([][]types.Datum, error) {
type result struct {
idx int
rows [][]types.Datum
Expand Down Expand Up @@ -2476,7 +2477,7 @@ func FetchClusterServerInfoWithoutPrivilegeCheck(ctx context.Context, sctx sessi
for result := range ch {
if result.err != nil {
if recordWarningInStmtCtx {
sctx.GetSessionVars().StmtCtx.AppendWarning(result.err)
vars.StmtCtx.AppendWarning(result.err)
} else {
log.Warn(result.err.Error())
}
Expand Down
1 change: 1 addition & 0 deletions pkg/planner/cardinality/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ go_library(
"//pkg/parser/format",
"//pkg/parser/model",
"//pkg/parser/mysql",
"//pkg/planner/context",
"//pkg/planner/property",
"//pkg/planner/util",
"//pkg/planner/util/debugtrace",
Expand Down
14 changes: 7 additions & 7 deletions pkg/planner/cardinality/cross_estimation.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ import (
"math"

"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/planner/context"
"github.com/pingcap/tidb/pkg/planner/property"
"github.com/pingcap/tidb/pkg/planner/util"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/statistics"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/ranger"
Expand All @@ -33,7 +33,7 @@ const SelectionFactor = 0.8
// AdjustRowCountForTableScanByLimit will adjust the row count for table scan by limit.
// For a query like `select pk from t using index(primary) where pk > 10 limit 1`, the row count of the table scan
// should be adjusted by the limit number 1, because only one row is returned.
func AdjustRowCountForTableScanByLimit(sctx sessionctx.Context,
func AdjustRowCountForTableScanByLimit(sctx context.PlanContext,
dsStatsInfo, dsTableStats *property.StatsInfo, dsStatisticTable *statistics.Table,
path *util.AccessPath, expectedCnt float64, desc bool) float64 {
rowCount := path.CountAfterAccess
Expand Down Expand Up @@ -67,7 +67,7 @@ func AdjustRowCountForTableScanByLimit(sctx sessionctx.Context,
// `select * from tbl where a = 1 order by pk limit 1`
// if order of column `a` is strictly correlated with column `pk`, the row count of table scan should be:
// `1 + row_count(a < 1 or a is null)`
func crossEstimateTableRowCount(sctx sessionctx.Context,
func crossEstimateTableRowCount(sctx context.PlanContext,
dsStatsInfo, dsTableStats *property.StatsInfo, dsStatisticTable *statistics.Table,
path *util.AccessPath, expectedCnt float64, desc bool) (float64, bool, float64) {
if dsStatisticTable.Pseudo || len(path.TableFilters) == 0 || !sctx.GetSessionVars().EnableCorrelationAdjustment {
Expand All @@ -80,7 +80,7 @@ func crossEstimateTableRowCount(sctx sessionctx.Context,
// AdjustRowCountForIndexScanByLimit will adjust the row count for table scan by limit.
// For a query like `select k from t using index(k) where k > 10 limit 1`, the row count of the index scan
// should be adjusted by the limit number 1, because only one row is returned.
func AdjustRowCountForIndexScanByLimit(sctx sessionctx.Context,
func AdjustRowCountForIndexScanByLimit(sctx context.PlanContext,
dsStatsInfo, dsTableStats *property.StatsInfo, dsStatisticTable *statistics.Table,
path *util.AccessPath, expectedCnt float64, desc bool) float64 {
rowCount := path.CountAfterAccess
Expand Down Expand Up @@ -114,7 +114,7 @@ func AdjustRowCountForIndexScanByLimit(sctx sessionctx.Context,
// `select * from tbl where a = 1 order by b limit 1`
// if order of column `a` is strictly correlated with column `b`, the row count of IndexScan(b) should be:
// `1 + row_count(a < 1 or a is null)`
func crossEstimateIndexRowCount(sctx sessionctx.Context,
func crossEstimateIndexRowCount(sctx context.PlanContext,
dsStatsInfo, dsTableStats *property.StatsInfo, dsStatisticTable *statistics.Table,
path *util.AccessPath, expectedCnt float64, desc bool) (float64, bool, float64) {
filtersLen := len(path.TableFilters) + len(path.IndexFilters)
Expand All @@ -130,7 +130,7 @@ func crossEstimateIndexRowCount(sctx sessionctx.Context,
}

// crossEstimateRowCount is the common logic of crossEstimateTableRowCount and crossEstimateIndexRowCount.
func crossEstimateRowCount(sctx sessionctx.Context,
func crossEstimateRowCount(sctx context.PlanContext,
dsStatsInfo, dsTableStats *property.StatsInfo,
path *util.AccessPath, conds []expression.Expression, col *expression.Column,
corr, expectedCnt float64, desc bool) (float64, bool, float64) {
Expand Down Expand Up @@ -182,7 +182,7 @@ func crossEstimateRowCount(sctx sessionctx.Context,
}

// getColumnRangeCounts estimates row count for each range respectively.
func getColumnRangeCounts(sctx sessionctx.Context, colID int64, ranges []*ranger.Range, histColl *statistics.HistColl, idxID int64) ([]float64, bool) {
func getColumnRangeCounts(sctx context.PlanContext, colID int64, ranges []*ranger.Range, histColl *statistics.HistColl, idxID int64) ([]float64, bool) {
var err error
var count float64
rangeCounts := make([]float64, len(ranges))
Expand Down
4 changes: 2 additions & 2 deletions pkg/planner/cardinality/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ import (
"math"

"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/planner/context"
"github.com/pingcap/tidb/pkg/planner/property"
"github.com/pingcap/tidb/pkg/sessionctx"
)

// EstimateFullJoinRowCount estimates the row count of a full join.
func EstimateFullJoinRowCount(sctx sessionctx.Context,
func EstimateFullJoinRowCount(sctx context.PlanContext,
isCartesian bool,
leftProfile, rightProfile *property.StatsInfo,
leftJoinKeys, rightJoinKeys []*expression.Column,
Expand Down
11 changes: 5 additions & 6 deletions pkg/planner/cardinality/pseudo.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
"github.com/pingcap/tidb/pkg/statistics"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/ranger"
Expand Down Expand Up @@ -163,15 +162,15 @@ func getPseudoRowCountByUnsignedIntRanges(intRanges []*ranger.Range, tableRowCou
return rowCount
}

func getPseudoRowCountByIndexRanges(sc *stmtctx.StatementContext, indexRanges []*ranger.Range,
func getPseudoRowCountByIndexRanges(tc types.Context, indexRanges []*ranger.Range,
tableRowCount float64, colsLen int) (float64, error) {
if tableRowCount == 0 {
return 0, nil
}
var totalCount float64
for _, indexRange := range indexRanges {
count := tableRowCount
i, err := indexRange.PrefixEqualLen(sc)
i, err := indexRange.PrefixEqualLen(tc)
if err != nil {
return 0, errors.Trace(err)
}
Expand All @@ -182,7 +181,7 @@ func getPseudoRowCountByIndexRanges(sc *stmtctx.StatementContext, indexRanges []
if i >= len(indexRange.LowVal) {
i = len(indexRange.LowVal) - 1
}
rowCount, err := getPseudoRowCountByColumnRanges(sc, tableRowCount, []*ranger.Range{indexRange}, i)
rowCount, err := getPseudoRowCountByColumnRanges(tc, tableRowCount, []*ranger.Range{indexRange}, i)
if err != nil {
return 0, errors.Trace(err)
}
Expand All @@ -201,7 +200,7 @@ func getPseudoRowCountByIndexRanges(sc *stmtctx.StatementContext, indexRanges []
}

// getPseudoRowCountByColumnRanges calculate the row count by the ranges if there's no statistics information for this column.
func getPseudoRowCountByColumnRanges(sc *stmtctx.StatementContext, tableRowCount float64, columnRanges []*ranger.Range, colIdx int) (float64, error) {
func getPseudoRowCountByColumnRanges(tc types.Context, tableRowCount float64, columnRanges []*ranger.Range, colIdx int) (float64, error) {
var rowCount float64
for _, ran := range columnRanges {
if ran.LowVal[colIdx].Kind() == types.KindNull && ran.HighVal[colIdx].Kind() == types.KindMaxValue {
Expand All @@ -217,7 +216,7 @@ func getPseudoRowCountByColumnRanges(sc *stmtctx.StatementContext, tableRowCount
} else if ran.HighVal[colIdx].Kind() == types.KindMaxValue {
rowCount += tableRowCount / pseudoLessRate
} else {
compare, err := ran.LowVal[colIdx].Compare(sc.TypeCtx(), &ran.HighVal[colIdx], ran.Collators[colIdx])
compare, err := ran.LowVal[colIdx].Compare(tc, &ran.HighVal[colIdx], ran.Collators[colIdx])
if err != nil {
return 0, errors.Trace(err)
}
Expand Down
Loading