diff --git a/br/pkg/lightning/backend/kv/BUILD.bazel b/br/pkg/lightning/backend/kv/BUILD.bazel index ff24fe7493763..3c158d912c1e2 100644 --- a/br/pkg/lightning/backend/kv/BUILD.bazel +++ b/br/pkg/lightning/backend/kv/BUILD.bazel @@ -24,12 +24,15 @@ go_library( "//br/pkg/utils", "//pkg/errctx", "//pkg/expression", + "//pkg/expression/context", + "//pkg/expression/contextimpl", "//pkg/infoschema/context", "//pkg/kv", "//pkg/meta/autoid", "//pkg/parser/model", "//pkg/parser/mysql", "//pkg/planner/context", + "//pkg/planner/contextimpl", "//pkg/sessionctx", "//pkg/sessionctx/variable", "//pkg/table", diff --git a/br/pkg/lightning/backend/kv/base.go b/br/pkg/lightning/backend/kv/base.go index b207faaf2fa34..63611beaf5e6d 100644 --- a/br/pkg/lightning/backend/kv/base.go +++ b/br/pkg/lightning/backend/kv/base.go @@ -280,7 +280,7 @@ func (e *BaseKVEncoder) getActualDatum(col *table.Column, rowID int64, inputDatu e.SessionCtx.Vars.TxnCtx = nil }() } - value, err = table.GetColDefaultValue(e.SessionCtx, col.ToInfo()) + value, err = table.GetColDefaultValue(e.SessionCtx.GetExprCtx(), col.ToInfo()) } return value, err } @@ -353,7 +353,7 @@ func evalGeneratedColumns(se *Session, record []types.Datum, cols []*table.Colum mutRow := chunk.MutRowFromDatums(record) for _, gc := range genCols { col := cols[gc.Index].ToInfo() - evaluated, err := gc.Expr.Eval(se, mutRow.ToRow()) + evaluated, err := gc.Expr.Eval(se.GetExprCtx(), mutRow.ToRow()) if err != nil { return col, err } diff --git a/br/pkg/lightning/backend/kv/session.go b/br/pkg/lightning/backend/kv/session.go index abc1e18d9e365..6d52a7a4788e3 100644 --- a/br/pkg/lightning/backend/kv/session.go +++ b/br/pkg/lightning/backend/kv/session.go @@ -30,10 +30,13 @@ import ( "github.com/pingcap/tidb/br/pkg/lightning/manual" "github.com/pingcap/tidb/br/pkg/utils" "github.com/pingcap/tidb/pkg/errctx" + exprctx "github.com/pingcap/tidb/pkg/expression/context" + exprctximpl "github.com/pingcap/tidb/pkg/expression/contextimpl" infoschema "github.com/pingcap/tidb/pkg/infoschema/context" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/parser/model" planctx "github.com/pingcap/tidb/pkg/planner/context" + planctximpl "github.com/pingcap/tidb/pkg/planner/contextimpl" "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" tbctx "github.com/pingcap/tidb/pkg/table/context" @@ -264,15 +267,21 @@ func (*transaction) SetAssertion(_ []byte, _ ...kv.FlagsOp) error { return nil } +type planCtxImpl struct { + *Session + *planctximpl.PlanCtxExtendedImpl +} + // Session is a trimmed down Session type which only wraps our own trimmed-down // transaction type and provides the session variables to the TiDB library // optimized for Lightning. type Session struct { sessionctx.Context planctx.EmptyPlanContextExtended - txn transaction - Vars *variable.SessionVars - tblctx *tbctximpl.TableContextImpl + txn transaction + Vars *variable.SessionVars + planctx *planCtxImpl + tblctx *tbctximpl.TableContextImpl // currently, we only set `CommonAddRecordCtx` values map[fmt.Stringer]any } @@ -333,7 +342,11 @@ func NewSession(options *encode.SessionOptions, logger log.Logger) *Session { } vars.TxnCtx = nil s.Vars = vars - s.tblctx = tbctximpl.NewTableContextImpl(s) + s.planctx = &planCtxImpl{ + Session: s, + PlanCtxExtendedImpl: planctximpl.NewPlanCtxExtendedImpl(s, exprctximpl.NewExprExtendedImpl(s)), + } + s.tblctx = tbctximpl.NewTableContextImpl(s, s.planctx) s.txn.kvPairs = &Pairs{} return s @@ -363,7 +376,12 @@ func (se *Session) GetSessionVars() *variable.SessionVars { // GetPlanCtx returns the PlanContext. func (se *Session) GetPlanCtx() planctx.PlanContext { - return se + return se.planctx +} + +// GetExprCtx returns the expression context of the session. +func (se *Session) GetExprCtx() exprctx.BuildContext { + return se.planctx } // GetTableCtx returns the table.MutateContext diff --git a/br/pkg/lightning/backend/kv/sql2kv.go b/br/pkg/lightning/backend/kv/sql2kv.go index 4d6854c3749ee..ad4b188eddc81 100644 --- a/br/pkg/lightning/backend/kv/sql2kv.go +++ b/br/pkg/lightning/backend/kv/sql2kv.go @@ -116,7 +116,7 @@ func CollectGeneratedColumns(se *Session, meta *model.TableInfo, cols []*table.C for i, col := range cols { if col.GeneratedExpr != nil { expr, err := expression.BuildSimpleExpr( - se, + se.GetExprCtx(), col.GeneratedExpr.Internal(), expression.WithInputSchemaAndNames(schema, names, meta), expression.WithAllowCastArray(true), diff --git a/pkg/ddl/backfilling.go b/pkg/ddl/backfilling.go index 2ac76983d7c8d..5a7266fdb8ed8 100644 --- a/pkg/ddl/backfilling.go +++ b/pkg/ddl/backfilling.go @@ -627,7 +627,7 @@ func makeupDecodeColMap(sessCtx sessionctx.Context, dbName model.CIStr, t table. for _, col := range t.WritableCols() { writableColInfos = append(writableColInfos, col.ColumnInfo) } - exprCols, _, err := expression.ColumnInfos2ColumnsAndNames(sessCtx, dbName, t.Meta().Name, writableColInfos, t.Meta()) + exprCols, _, err := expression.ColumnInfos2ColumnsAndNames(sessCtx.GetExprCtx(), dbName, t.Meta().Name, writableColInfos, t.Meta()) if err != nil { return nil, err } diff --git a/pkg/ddl/column.go b/pkg/ddl/column.go index cba5a4256815c..0ce3042037c69 100644 --- a/pkg/ddl/column.go +++ b/pkg/ddl/column.go @@ -1346,7 +1346,7 @@ func (w *updateColumnWorker) getRowRecord(handle kv.Handle, recordKey []byte, ra val := w.rowMap[w.oldColInfo.ID] col := w.newColInfo if val.Kind() == types.KindNull && col.FieldType.GetType() == mysql.TypeTimestamp && mysql.HasNotNullFlag(col.GetFlag()) { - if v, err := expression.GetTimeCurrentTimestamp(w.sessCtx, col.GetType(), col.GetDecimal()); err == nil { + if v, err := expression.GetTimeCurrentTimestamp(w.sessCtx.GetExprCtx(), col.GetType(), col.GetDecimal()); err == nil { // convert null value to timestamp should be substituted with current timestamp if NOT_NULL flag is set. w.rowMap[w.oldColInfo.ID] = v } @@ -1971,7 +1971,7 @@ func generateOriginDefaultValue(col *model.ColumnInfo, ctx sessionctx.Context) ( if ctx == nil { t = time.Now() } else { - t, _ = expression.GetStmtTimestamp(ctx) + t, _ = expression.GetStmtTimestamp(ctx.GetExprCtx()) } if col.GetType() == mysql.TypeTimestamp { odValue = types.NewTime(types.FromGoTime(t.UTC()), col.GetType(), col.GetDecimal()).String() diff --git a/pkg/ddl/copr/copr_ctx.go b/pkg/ddl/copr/copr_ctx.go index 0eb0ebe79befc..87f1b43adea6f 100644 --- a/pkg/ddl/copr/copr_ctx.go +++ b/pkg/ddl/copr/copr_ctx.go @@ -115,7 +115,7 @@ func NewCopContextBase( handleIDs = []int64{extra.ID} } - expColInfos, _, err := expression.ColumnInfos2ColumnsAndNames(sessCtx, + expColInfos, _, err := expression.ColumnInfos2ColumnsAndNames(sessCtx.GetExprCtx(), model.CIStr{} /* unused */, tblInfo.Name, colInfos, tblInfo) if err != nil { return nil, err diff --git a/pkg/ddl/ddl_api.go b/pkg/ddl/ddl_api.go index 0e1c074747af7..c32235df35484 100644 --- a/pkg/ddl/ddl_api.go +++ b/pkg/ddl/ddl_api.go @@ -1032,7 +1032,7 @@ func checkColumnDefaultValue(ctx sessionctx.Context, col *table.Column, value an if value != nil && ctx.GetSessionVars().SQLMode.HasNoZeroDateMode() && ctx.GetSessionVars().SQLMode.HasStrictMode() && types.IsTypeTime(col.GetType()) { if vv, ok := value.(string); ok { - timeValue, err := expression.GetTimeValue(ctx, vv, col.GetType(), col.GetDecimal(), nil) + timeValue, err := expression.GetTimeValue(ctx.GetExprCtx(), vv, col.GetType(), col.GetDecimal(), nil) if err != nil { return hasDefaultValue, value, errors.Trace(err) } @@ -1385,7 +1385,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu } if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime || tp == mysql.TypeDate { - vd, err := expression.GetTimeValue(ctx, option.Expr, tp, fsp, nil) + vd, err := expression.GetTimeValue(ctx.GetExprCtx(), option.Expr, tp, fsp, nil) value := vd.GetValue() if err != nil { return nil, false, dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) @@ -1405,7 +1405,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu } // evaluate the non-function-call expr to a certain value. - v, err := expression.EvalSimpleAst(ctx, option.Expr) + v, err := expression.EvalSimpleAst(ctx.GetExprCtx(), option.Expr) if err != nil { return nil, false, errors.Trace(err) } @@ -1614,7 +1614,7 @@ func checkDefaultValue(ctx sessionctx.Context, c *table.Column, hasDefaultValue if c.DefaultIsExpr { return nil } - if _, err := table.GetColDefaultValue(ctx, c.ToInfo()); err != nil { + if _, err := table.GetColDefaultValue(ctx.GetExprCtx(), c.ToInfo()); err != nil { return types.ErrInvalidDefault.GenWithStackByArgs(c.Name) } return nil @@ -3212,7 +3212,7 @@ func checkPartitionByRange(ctx sessionctx.Context, tbInfo *model.TableInfo) erro // checkPartitionByList checks validity of a "BY LIST" partition. func checkPartitionByList(ctx sessionctx.Context, tbInfo *model.TableInfo) error { - return checkListPartitionValue(ctx, tbInfo) + return checkListPartitionValue(ctx.GetExprCtx(), tbInfo) } func isValidKeyPartitionColType(fieldType types.FieldType) bool { @@ -3305,7 +3305,7 @@ func checkTwoRangeColumns(ctx sessionctx.Context, curr, prev *model.PartitionDef // PARTITION p1 VALUES LESS THAN (10,20,'mmm') // PARTITION p2 VALUES LESS THAN (15,30,'sss') colInfo := findColumnByName(pi.Columns[i].L, tbInfo) - succ, err := parseAndEvalBoolExpr(ctx, curr.LessThan[i], prev.LessThan[i], colInfo, tbInfo) + succ, err := parseAndEvalBoolExpr(ctx.GetExprCtx(), curr.LessThan[i], prev.LessThan[i], colInfo, tbInfo) if err != nil { return false, err } @@ -3317,7 +3317,7 @@ func checkTwoRangeColumns(ctx sessionctx.Context, curr, prev *model.PartitionDef return false, nil } -func parseAndEvalBoolExpr(ctx sessionctx.Context, l, r string, colInfo *model.ColumnInfo, tbInfo *model.TableInfo) (bool, error) { +func parseAndEvalBoolExpr(ctx expression.BuildContext, l, r string, colInfo *model.ColumnInfo, tbInfo *model.TableInfo) (bool, error) { lexpr, err := expression.ParseSimpleExpr(ctx, l, expression.WithTableInfo("", tbInfo), expression.WithCastExprTo(&colInfo.FieldType)) if err != nil { return false, err @@ -4366,7 +4366,7 @@ func (d *ddl) AddTablePartitions(ctx sessionctx.Context, ident ast.Ident, spec * return d.hashPartitionManagement(ctx, ident, spec, pi) } - partInfo, err := BuildAddedPartitionInfo(ctx, meta, spec) + partInfo, err := BuildAddedPartitionInfo(ctx.GetExprCtx(), meta, spec) if err != nil { return errors.Trace(err) } @@ -4612,7 +4612,7 @@ func (d *ddl) ReorganizePartitions(ctx sessionctx.Context, ident ast.Ident, spec if err != nil { return errors.Trace(err) } - partInfo, err := BuildAddedPartitionInfo(ctx, meta, spec) + partInfo, err := BuildAddedPartitionInfo(ctx.GetExprCtx(), meta, spec) if err != nil { return errors.Trace(err) } @@ -4678,7 +4678,7 @@ func (d *ddl) RemovePartitioning(ctx sessionctx.Context, ident ast.Ident, spec * partNames[i] = pi.Definitions[i].Name.L } meta.Partition.Type = model.PartitionTypeNone - partInfo, err := BuildAddedPartitionInfo(ctx, meta, newSpec) + partInfo, err := BuildAddedPartitionInfo(ctx.GetExprCtx(), meta, newSpec) if err != nil { return errors.Trace(err) } @@ -4756,11 +4756,11 @@ func checkReorgPartitionDefs(ctx sessionctx.Context, action model.ActionType, tb } isUnsigned := isPartExprUnsigned(tblInfo) - currentRangeValue, _, err := getRangeValue(ctx, pi.Definitions[lastPartIdx].LessThan[0], isUnsigned) + currentRangeValue, _, err := getRangeValue(ctx.GetExprCtx(), pi.Definitions[lastPartIdx].LessThan[0], isUnsigned) if err != nil { return errors.Trace(err) } - newRangeValue, _, err := getRangeValue(ctx, partInfo.Definitions[len(partInfo.Definitions)-1].LessThan[0], isUnsigned) + newRangeValue, _, err := getRangeValue(ctx.GetExprCtx(), partInfo.Definitions[len(partInfo.Definitions)-1].LessThan[0], isUnsigned) if err != nil { return errors.Trace(err) } @@ -4935,7 +4935,7 @@ func (d *ddl) DropTablePartition(ctx sessionctx.Context, ident ast.Ident, spec * } if spec.Tp == ast.AlterTableDropFirstPartition { - intervalOptions := getPartitionIntervalFromTable(ctx, meta) + intervalOptions := getPartitionIntervalFromTable(ctx.GetExprCtx(), meta) if intervalOptions == nil { return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs( "FIRST PARTITION, does not seem like an INTERVAL partitioned table") @@ -4945,7 +4945,7 @@ func (d *ddl) DropTablePartition(ctx sessionctx.Context, ident ast.Ident, spec * "FIRST PARTITION, table info already contains partition definitions") } spec.Partition.Interval = intervalOptions - err = GeneratePartDefsFromInterval(ctx, spec.Tp, meta, spec.Partition) + err = GeneratePartDefsFromInterval(ctx.GetExprCtx(), spec.Tp, meta, spec.Partition) if err != nil { return err } @@ -5447,7 +5447,7 @@ func setDefaultValueWithBinaryPadding(col *table.Column, value any) error { } func setColumnComment(ctx sessionctx.Context, col *table.Column, option *ast.ColumnOption) error { - value, err := expression.EvalSimpleAst(ctx, option.Expr) + value, err := expression.EvalSimpleAst(ctx.GetExprCtx(), option.Expr) if err != nil { return errors.Trace(err) } @@ -5808,7 +5808,7 @@ func GetModifiableColumnJob( oldTypeFlags := sv.TypeFlags() newTypeFlags := oldTypeFlags.WithTruncateAsWarning(false).WithIgnoreTruncateErr(false) sv.SetTypeFlags(newTypeFlags) - _, err = buildPartitionDefinitionsInfo(sctx, pAst.Definitions, &newTblInfo, uint64(len(newTblInfo.Partition.Definitions))) + _, err = buildPartitionDefinitionsInfo(sctx.GetExprCtx(), pAst.Definitions, &newTblInfo, uint64(len(newTblInfo.Partition.Definitions))) sv.SetTypeFlags(oldTypeFlags) if err != nil { return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStack("New column does not match partition definitions: %s", err.Error()) @@ -7373,7 +7373,7 @@ func BuildHiddenColumnInfo(ctx sessionctx.Context, indexPartSpecifications []*as if err != nil { return nil, errors.Trace(err) } - expr, err := expression.BuildSimpleExpr(ctx, idxPart.Expr, + expr, err := expression.BuildSimpleExpr(ctx.GetExprCtx(), idxPart.Expr, expression.WithTableInfo(ctx.GetSessionVars().CurrentDB, tblInfo), expression.WithAllowCastArray(true), ) @@ -7968,7 +7968,7 @@ func validateCommentLength(vars *variable.SessionVars, name string, comment *str } // BuildAddedPartitionInfo build alter table add partition info -func BuildAddedPartitionInfo(ctx sessionctx.Context, meta *model.TableInfo, spec *ast.AlterTableSpec) (*model.PartitionInfo, error) { +func BuildAddedPartitionInfo(ctx expression.BuildContext, meta *model.TableInfo, spec *ast.AlterTableSpec) (*model.PartitionInfo, error) { numParts := uint64(0) switch meta.Partition.Type { case model.PartitionTypeNone: @@ -8034,7 +8034,7 @@ func BuildAddedPartitionInfo(ctx sessionctx.Context, meta *model.TableInfo, spec return part, nil } -func buildAddedPartitionDefs(ctx sessionctx.Context, meta *model.TableInfo, spec *ast.AlterTableSpec) error { +func buildAddedPartitionDefs(ctx expression.BuildContext, meta *model.TableInfo, spec *ast.AlterTableSpec) error { partInterval := getPartitionIntervalFromTable(ctx, meta) if partInterval == nil { return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs( @@ -8052,7 +8052,7 @@ func buildAddedPartitionDefs(ctx sessionctx.Context, meta *model.TableInfo, spec return GeneratePartDefsFromInterval(ctx, spec.Tp, meta, spec.Partition) } -func checkAndGetColumnsTypeAndValuesMatch(ctx sessionctx.Context, colTypes []types.FieldType, exprs []ast.ExprNode) ([]string, error) { +func checkAndGetColumnsTypeAndValuesMatch(ctx expression.BuildContext, colTypes []types.FieldType, exprs []ast.ExprNode) ([]string, error) { // Validate() has already checked len(colNames) = len(exprs) // create table ... partition by range columns (cols) // partition p0 values less than (expr) diff --git a/pkg/ddl/index_cop.go b/pkg/ddl/index_cop.go index 51a2d2b26ff2e..739361abd5cbc 100644 --- a/pkg/ddl/index_cop.go +++ b/pkg/ddl/index_cop.go @@ -360,7 +360,7 @@ func buildDAGPB(sCtx sessionctx.Context, tblInfo *model.TableInfo, colInfos []*m func constructTableScanPB(sCtx sessionctx.Context, tblInfo *model.TableInfo, colInfos []*model.ColumnInfo) (*tipb.Executor, error) { tblScan := tables.BuildTableScanFromInfos(tblInfo, colInfos) tblScan.TableId = tblInfo.ID - err := tables.SetPBColumnsDefaultValue(sCtx, tblScan.Columns, colInfos) + err := tables.SetPBColumnsDefaultValue(sCtx.GetExprCtx(), tblScan.Columns, colInfos) return &tipb.Executor{Tp: tipb.ExecType_TypeTableScan, TblScan: tblScan}, err } diff --git a/pkg/ddl/partition.go b/pkg/ddl/partition.go index 7910da77c6b9d..9989e76721c2a 100644 --- a/pkg/ddl/partition.go +++ b/pkg/ddl/partition.go @@ -559,7 +559,7 @@ func buildTablePartitionInfo(ctx sessionctx.Context, s *ast.PartitionOptions, tb } tbInfo.Partition = pi if s.Expr != nil { - if err := checkPartitionFuncValid(ctx, tbInfo, s.Expr); err != nil { + if err := checkPartitionFuncValid(ctx.GetExprCtx(), tbInfo, s.Expr); err != nil { return errors.Trace(err) } buf := new(bytes.Buffer) @@ -591,12 +591,13 @@ func buildTablePartitionInfo(ctx sessionctx.Context, s *ast.PartitionOptions, tb } } - err := generatePartitionDefinitionsFromInterval(ctx, s, tbInfo) + exprCtx := ctx.GetExprCtx() + err := generatePartitionDefinitionsFromInterval(exprCtx, s, tbInfo) if err != nil { return errors.Trace(err) } - defs, err := buildPartitionDefinitionsInfo(ctx, s.Definitions, tbInfo, s.Num) + defs, err := buildPartitionDefinitionsInfo(exprCtx, s.Definitions, tbInfo, s.Num) if err != nil { return errors.Trace(err) } @@ -619,7 +620,7 @@ func buildTablePartitionInfo(ctx sessionctx.Context, s *ast.PartitionOptions, tb } } - partCols, err := getPartitionColSlices(ctx, tbInfo, s) + partCols, err := getPartitionColSlices(exprCtx, tbInfo, s) if err != nil { return errors.Trace(err) } @@ -632,7 +633,7 @@ func buildTablePartitionInfo(ctx sessionctx.Context, s *ast.PartitionOptions, tb return nil } -func getPartitionColSlices(sctx sessionctx.Context, tblInfo *model.TableInfo, s *ast.PartitionOptions) (partCols stringSlice, err error) { +func getPartitionColSlices(sctx expression.BuildContext, tblInfo *model.TableInfo, s *ast.PartitionOptions) (partCols stringSlice, err error) { if s.Expr != nil { extractCols := newPartitionExprChecker(sctx, tblInfo) s.Expr.Accept(extractCols) @@ -659,7 +660,7 @@ func getPartitionColSlices(sctx sessionctx.Context, tblInfo *model.TableInfo, s // getPartitionIntervalFromTable checks if a partitioned table matches a generated INTERVAL partitioned scheme // will return nil if error occurs, i.e. not an INTERVAL partitioned table -func getPartitionIntervalFromTable(ctx sessionctx.Context, tbInfo *model.TableInfo) *ast.PartitionInterval { +func getPartitionIntervalFromTable(ctx expression.BuildContext, tbInfo *model.TableInfo) *ast.PartitionInterval { if tbInfo.Partition == nil || tbInfo.Partition.Type != model.PartitionTypeRange { return nil @@ -795,7 +796,7 @@ func getPartitionIntervalFromTable(ctx sessionctx.Context, tbInfo *model.TableIn } // comparePartitionAstAndModel compares a generated *ast.PartitionOptions and a *model.PartitionInfo -func comparePartitionAstAndModel(ctx sessionctx.Context, pAst *ast.PartitionOptions, pModel *model.PartitionInfo, partCol *model.ColumnInfo) error { +func comparePartitionAstAndModel(ctx expression.BuildContext, pAst *ast.PartitionOptions, pModel *model.PartitionInfo, partCol *model.ColumnInfo) error { a := pAst.Definitions m := pModel.Definitions if len(pAst.Definitions) != len(pModel.Definitions) { @@ -853,7 +854,7 @@ func comparePartitionAstAndModel(ctx sessionctx.Context, pAst *ast.PartitionOpti // comparePartitionDefinitions check if generated definitions are the same as the given ones // Allow names to differ // returns error in case of error or non-accepted difference -func comparePartitionDefinitions(ctx sessionctx.Context, a, b []*ast.PartitionDefinition) error { +func comparePartitionDefinitions(ctx expression.BuildContext, a, b []*ast.PartitionDefinition) error { if len(a) != len(b) { return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("number of partitions generated != partition defined (%d != %d)", len(a), len(b)) } @@ -910,7 +911,7 @@ func getLowerBoundInt(partCols ...*model.ColumnInfo) int64 { } // generatePartitionDefinitionsFromInterval generates partition Definitions according to INTERVAL options on partOptions -func generatePartitionDefinitionsFromInterval(ctx sessionctx.Context, partOptions *ast.PartitionOptions, tbInfo *model.TableInfo) error { +func generatePartitionDefinitionsFromInterval(ctx expression.BuildContext, partOptions *ast.PartitionOptions, tbInfo *model.TableInfo) error { if partOptions.Interval == nil { return nil } @@ -1050,7 +1051,7 @@ func astIntValueExprFromStr(s string, unsigned bool) (ast.ExprNode, error) { // - ALTER TABLE LAST PARTITION (expr): Creates new partitions from (excluding) old LAST partition to (including) new LAST partition // // partition definitions will be set on partitionOptions -func GeneratePartDefsFromInterval(ctx sessionctx.Context, tp ast.AlterTableType, tbInfo *model.TableInfo, partitionOptions *ast.PartitionOptions) error { +func GeneratePartDefsFromInterval(ctx expression.BuildContext, tp ast.AlterTableType, tbInfo *model.TableInfo, partitionOptions *ast.PartitionOptions) error { if partitionOptions == nil { return nil } @@ -1205,7 +1206,7 @@ func GeneratePartDefsFromInterval(ctx sessionctx.Context, tp ast.AlterTableType, } // buildPartitionDefinitionsInfo build partition definitions info without assign partition id. tbInfo will be constant -func buildPartitionDefinitionsInfo(ctx sessionctx.Context, defs []*ast.PartitionDefinition, tbInfo *model.TableInfo, numParts uint64) (partitions []model.PartitionDefinition, err error) { +func buildPartitionDefinitionsInfo(ctx expression.BuildContext, defs []*ast.PartitionDefinition, tbInfo *model.TableInfo, numParts uint64) (partitions []model.PartitionDefinition, err error) { switch tbInfo.Partition.Type { case model.PartitionTypeNone: if len(defs) != 1 { @@ -1218,7 +1219,7 @@ func buildPartitionDefinitionsInfo(ctx sessionctx.Context, defs []*ast.Partition case model.PartitionTypeRange: partitions, err = buildRangePartitionDefinitions(ctx, defs, tbInfo) case model.PartitionTypeHash, model.PartitionTypeKey: - partitions, err = buildHashPartitionDefinitions(ctx, defs, tbInfo, numParts) + partitions, err = buildHashPartitionDefinitions(defs, tbInfo, numParts) case model.PartitionTypeList: partitions, err = buildListPartitionDefinitions(ctx, defs, tbInfo) default: @@ -1266,7 +1267,7 @@ func isNonDefaultPartitionOptionsUsed(defs []model.PartitionDefinition) bool { return false } -func buildHashPartitionDefinitions(_ sessionctx.Context, defs []*ast.PartitionDefinition, tbInfo *model.TableInfo, numParts uint64) ([]model.PartitionDefinition, error) { +func buildHashPartitionDefinitions(defs []*ast.PartitionDefinition, tbInfo *model.TableInfo, numParts uint64) ([]model.PartitionDefinition, error) { if err := checkAddPartitionTooManyPartitions(tbInfo.Partition.Num); err != nil { return nil, err } @@ -1296,7 +1297,7 @@ func buildHashPartitionDefinitions(_ sessionctx.Context, defs []*ast.PartitionDe return definitions, nil } -func buildListPartitionDefinitions(ctx sessionctx.Context, defs []*ast.PartitionDefinition, tbInfo *model.TableInfo) ([]model.PartitionDefinition, error) { +func buildListPartitionDefinitions(ctx expression.BuildContext, defs []*ast.PartitionDefinition, tbInfo *model.TableInfo) ([]model.PartitionDefinition, error) { definitions := make([]model.PartitionDefinition, 0, len(defs)) exprChecker := newPartitionExprChecker(ctx, nil, checkPartitionExprAllowed) colTypes := collectColumnsType(tbInfo) @@ -1374,7 +1375,7 @@ func collectColumnsType(tbInfo *model.TableInfo) []types.FieldType { return nil } -func buildRangePartitionDefinitions(ctx sessionctx.Context, defs []*ast.PartitionDefinition, tbInfo *model.TableInfo) ([]model.PartitionDefinition, error) { +func buildRangePartitionDefinitions(ctx expression.BuildContext, defs []*ast.PartitionDefinition, tbInfo *model.TableInfo) ([]model.PartitionDefinition, error) { definitions := make([]model.PartitionDefinition, 0, len(defs)) exprChecker := newPartitionExprChecker(ctx, nil, checkPartitionExprAllowed) colTypes := collectColumnsType(tbInfo) @@ -1448,7 +1449,7 @@ func buildRangePartitionDefinitions(ctx sessionctx.Context, defs []*ast.Partitio return definitions, nil } -func checkPartitionValuesIsInt(ctx sessionctx.Context, defName any, exprs []ast.ExprNode, tbInfo *model.TableInfo) error { +func checkPartitionValuesIsInt(ctx expression.BuildContext, defName any, exprs []ast.ExprNode, tbInfo *model.TableInfo) error { tp := types.NewFieldType(mysql.TypeLonglong) if isPartExprUnsigned(tbInfo) { tp.AddFlag(mysql.UnsignedFlag) @@ -1575,7 +1576,7 @@ func checkAndOverridePartitionID(newTableInfo, oldTableInfo *model.TableInfo) er } // checkPartitionFuncValid checks partition function validly. -func checkPartitionFuncValid(ctx sessionctx.Context, tblInfo *model.TableInfo, expr ast.ExprNode) error { +func checkPartitionFuncValid(ctx expression.BuildContext, tblInfo *model.TableInfo, expr ast.ExprNode) error { if expr == nil { return nil } @@ -1611,7 +1612,7 @@ func checkPartitionFuncType(ctx sessionctx.Context, expr ast.ExprNode, schema st schema = ctx.GetSessionVars().CurrentDB } - e, err := expression.BuildSimpleExpr(ctx, expr, expression.WithTableInfo(schema, tblInfo)) + e, err := expression.BuildSimpleExpr(ctx.GetExprCtx(), expr, expression.WithTableInfo(schema, tblInfo)) if err != nil { return errors.Trace(err) } @@ -1645,7 +1646,7 @@ func checkRangePartitionValue(ctx sessionctx.Context, tblInfo *model.TableInfo) return errors.Trace(dbterror.ErrPartitionMaxvalue) } - currentRangeValue, fromExpr, err := getRangeValue(ctx, defs[i].LessThan[0], isUnsigned) + currentRangeValue, fromExpr, err := getRangeValue(ctx.GetExprCtx(), defs[i].LessThan[0], isUnsigned) if err != nil { return errors.Trace(err) } @@ -1673,7 +1674,7 @@ func checkRangePartitionValue(ctx sessionctx.Context, tblInfo *model.TableInfo) return nil } -func checkListPartitionValue(ctx sessionctx.Context, tblInfo *model.TableInfo) error { +func checkListPartitionValue(ctx expression.BuildContext, tblInfo *model.TableInfo) error { pi := tblInfo.Partition if len(pi.Definitions) == 0 { return ast.ErrPartitionsMustBeDefined.GenWithStackByArgs("LIST") @@ -1694,7 +1695,7 @@ func checkListPartitionValue(ctx sessionctx.Context, tblInfo *model.TableInfo) e return nil } -func formatListPartitionValue(ctx sessionctx.Context, tblInfo *model.TableInfo) ([]string, error) { +func formatListPartitionValue(ctx expression.BuildContext, tblInfo *model.TableInfo) ([]string, error) { defs := tblInfo.Partition.Definitions pi := tblInfo.Partition var colTps []*types.FieldType @@ -1770,7 +1771,7 @@ func formatListPartitionValue(ctx sessionctx.Context, tblInfo *model.TableInfo) // getRangeValue gets an integer from the range value string. // The returned boolean value indicates whether the input string is a constant expression. -func getRangeValue(ctx sessionctx.Context, str string, unsigned bool) (any, bool, error) { +func getRangeValue(ctx expression.BuildContext, str string, unsigned bool) (any, bool, error) { // Unsigned bigint was converted to uint64 handle. if unsigned { if value, err := strconv.ParseUint(str, 10, 64); err == nil { @@ -3331,7 +3332,7 @@ func (w *reorgPartitionWorker) fetchRowColVals(txn kv.Transaction, taskRange reo } tmpRow[offset] = d } - p, err := w.reorgedTbl.GetPartitionByRow(w.sessCtx, tmpRow) + p, err := w.reorgedTbl.GetPartitionByRow(w.sessCtx.GetExprCtx(), tmpRow) if err != nil { return false, errors.Trace(err) } @@ -3831,7 +3832,7 @@ func checkPartitioningKeysConstraints(sctx sessionctx.Context, s *ast.CreateTabl return nil } - partCols, err := getPartitionColSlices(sctx, tblInfo, s.Partition) + partCols, err := getPartitionColSlices(sctx.GetExprCtx(), tblInfo, s.Partition) if err != nil { return errors.Trace(err) } @@ -4033,18 +4034,18 @@ func truncateTableByReassignPartitionIDs(t *meta.Meta, tblInfo *model.TableInfo, return nil } -type partitionExprProcessor func(sessionctx.Context, *model.TableInfo, ast.ExprNode) error +type partitionExprProcessor func(expression.BuildContext, *model.TableInfo, ast.ExprNode) error type partitionExprChecker struct { processors []partitionExprProcessor - ctx sessionctx.Context + ctx expression.BuildContext tbInfo *model.TableInfo err error columns []*model.ColumnInfo } -func newPartitionExprChecker(ctx sessionctx.Context, tbInfo *model.TableInfo, processor ...partitionExprProcessor) *partitionExprChecker { +func newPartitionExprChecker(ctx expression.BuildContext, tbInfo *model.TableInfo, processor ...partitionExprProcessor) *partitionExprChecker { p := &partitionExprChecker{processors: processor, ctx: ctx, tbInfo: tbInfo} p.processors = append(p.processors, p.extractColumns) return p @@ -4069,7 +4070,7 @@ func (p *partitionExprChecker) Leave(n ast.Node) (node ast.Node, ok bool) { return n, p.err == nil } -func (p *partitionExprChecker) extractColumns(_ sessionctx.Context, _ *model.TableInfo, expr ast.ExprNode) error { +func (p *partitionExprChecker) extractColumns(_ expression.BuildContext, _ *model.TableInfo, expr ast.ExprNode) error { columnNameExpr, ok := expr.(*ast.ColumnNameExpr) if !ok { return nil @@ -4083,7 +4084,7 @@ func (p *partitionExprChecker) extractColumns(_ sessionctx.Context, _ *model.Tab return nil } -func checkPartitionExprAllowed(_ sessionctx.Context, tb *model.TableInfo, e ast.ExprNode) error { +func checkPartitionExprAllowed(_ expression.BuildContext, tb *model.TableInfo, e ast.ExprNode) error { switch v := e.(type) { case *ast.FuncCallExpr: if _, ok := expression.AllowedPartitionFuncMap[v.FnName.L]; ok { @@ -4104,7 +4105,7 @@ func checkPartitionExprAllowed(_ sessionctx.Context, tb *model.TableInfo, e ast. return errors.Trace(dbterror.ErrPartitionFunctionIsNotAllowed) } -func checkPartitionExprArgs(_ sessionctx.Context, tblInfo *model.TableInfo, e ast.ExprNode) error { +func checkPartitionExprArgs(_ expression.BuildContext, tblInfo *model.TableInfo, e ast.ExprNode) error { expr, ok := e.(*ast.FuncCallExpr) if !ok { return nil diff --git a/pkg/ddl/schematracker/dm_tracker.go b/pkg/ddl/schematracker/dm_tracker.go index 4ddc4852479e2..4d80800060e83 100644 --- a/pkg/ddl/schematracker/dm_tracker.go +++ b/pkg/ddl/schematracker/dm_tracker.go @@ -787,7 +787,7 @@ func (d SchemaTracker) addTablePartitions(ctx sessionctx.Context, ident ast.Iden return errors.Trace(dbterror.ErrPartitionMgmtOnNonpartitioned) } - partInfo, err := ddl.BuildAddedPartitionInfo(ctx, tblInfo, spec) + partInfo, err := ddl.BuildAddedPartitionInfo(ctx.GetExprCtx(), tblInfo, spec) if err != nil { return errors.Trace(err) } diff --git a/pkg/ddl/ttl.go b/pkg/ddl/ttl.go index 53c7b529d1761..7b6b1c4a7ec04 100644 --- a/pkg/ddl/ttl.go +++ b/pkg/ddl/ttl.go @@ -97,7 +97,7 @@ func onTTLInfoChange(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err er } func checkTTLInfoValid(ctx sessionctx.Context, schema model.CIStr, tblInfo *model.TableInfo) error { - if err := checkTTLIntervalExpr(ctx, tblInfo.TTLInfo); err != nil { + if err := checkTTLIntervalExpr(ctx.GetExprCtx(), tblInfo.TTLInfo); err != nil { return err } @@ -108,7 +108,7 @@ func checkTTLInfoValid(ctx sessionctx.Context, schema model.CIStr, tblInfo *mode return checkTTLInfoColumnType(tblInfo) } -func checkTTLIntervalExpr(ctx sessionctx.Context, ttlInfo *model.TTLInfo) error { +func checkTTLIntervalExpr(ctx expression.BuildContext, ttlInfo *model.TTLInfo) error { // FIXME: use a better way to validate the interval expression in ttl var nowAddIntervalExpr ast.ExprNode diff --git a/pkg/executor/admin.go b/pkg/executor/admin.go index 1c48817eb739a..8731dda55d961 100644 --- a/pkg/executor/admin.go +++ b/pkg/executor/admin.go @@ -152,7 +152,7 @@ func (e *CheckIndexRangeExec) buildDAGPB() (*tipb.DAGRequest, error) { execPB := e.constructIndexScanPB() dagReq.Executors = append(dagReq.Executors, execPB) - err := tables.SetPBColumnsDefaultValue(e.Ctx(), dagReq.Executors[0].IdxScan.Columns, e.cols) + err := tables.SetPBColumnsDefaultValue(e.Ctx().GetExprCtx(), dagReq.Executors[0].IdxScan.Columns, e.cols) if err != nil { return nil, err } @@ -231,7 +231,7 @@ func (e *RecoverIndexExec) Open(ctx context.Context) error { func (e *RecoverIndexExec) constructTableScanPB(tblInfo *model.TableInfo, colInfos []*model.ColumnInfo) (*tipb.Executor, error) { tblScan := tables.BuildTableScanFromInfos(tblInfo, colInfos) tblScan.TableId = e.physicalID - err := tables.SetPBColumnsDefaultValue(e.Ctx(), tblScan.Columns, colInfos) + err := tables.SetPBColumnsDefaultValue(e.Ctx().GetExprCtx(), tblScan.Columns, colInfos) return &tipb.Executor{Tp: tipb.ExecType_TypeTableScan, TblScan: tblScan}, err } @@ -403,7 +403,7 @@ func (e *RecoverIndexExec) buildIndexedValues(row chunk.Row, idxVals []types.Dat } if e.cols == nil { - columns, _, err := expression.ColumnInfos2ColumnsAndNames(e.Ctx(), model.NewCIStr("mock"), e.table.Meta().Name, e.table.Meta().Columns, e.table.Meta()) + columns, _, err := expression.ColumnInfos2ColumnsAndNames(e.Ctx().GetExprCtx(), model.NewCIStr("mock"), e.table.Meta().Name, e.table.Meta().Columns, e.table.Meta()) if err != nil { return nil, err } @@ -419,7 +419,7 @@ func (e *RecoverIndexExec) buildIndexedValues(row chunk.Row, idxVals []types.Dat sctx := e.Ctx() for i, col := range e.index.Meta().Columns { if e.table.Meta().Columns[col.Offset].IsGenerated() { - val, err := e.cols[col.Offset].EvalVirtualColumn(sctx, row) + val, err := e.cols[col.Offset].EvalVirtualColumn(sctx.GetExprCtx(), row) if err != nil { return nil, err } @@ -857,7 +857,7 @@ func (e *CleanupIndexExec) buildIdxDAGPB() (*tipb.DAGRequest, error) { execPB := e.constructIndexScanPB() dagReq.Executors = append(dagReq.Executors, execPB) - err := tables.SetPBColumnsDefaultValue(e.Ctx(), dagReq.Executors[0].IdxScan.Columns, e.columns) + err := tables.SetPBColumnsDefaultValue(e.Ctx().GetExprCtx(), dagReq.Executors[0].IdxScan.Columns, e.columns) if err != nil { return nil, err } diff --git a/pkg/executor/aggfuncs/aggfunc_test.go b/pkg/executor/aggfuncs/aggfunc_test.go index 0ca2e19dadbfc..474a5e5124d0b 100644 --- a/pkg/executor/aggfuncs/aggfunc_test.go +++ b/pkg/executor/aggfuncs/aggfunc_test.go @@ -373,7 +373,7 @@ func buildAggTesterWithFieldType(funcName string, ft *types.FieldType, numRows i return pt } -func testMultiArgsMergePartialResult(t *testing.T, ctx sessionctx.Context, p multiArgsAggTest) { +func testMultiArgsMergePartialResult(t *testing.T, ctx expression.BuildContext, p multiArgsAggTest) { srcChk := p.genSrcChk() iter := chunk.NewIterator4Chunk(srcChk) @@ -666,7 +666,7 @@ func testAggMemFunc(t *testing.T, p aggMemTest) { } } -func testMultiArgsAggFunc(t *testing.T, ctx sessionctx.Context, p multiArgsAggTest) { +func testMultiArgsAggFunc(t *testing.T, ctx expression.BuildContext, p multiArgsAggTest) { srcChk := p.genSrcChk() args := make([]expression.Expression, len(p.dataTypes)) @@ -790,7 +790,7 @@ func testMultiArgsAggMemFunc(t *testing.T, p multiArgsAggMemTest) { } } -func benchmarkAggFunc(b *testing.B, ctx sessionctx.Context, p aggTest) { +func benchmarkAggFunc(b *testing.B, ctx expression.BuildContext, p aggTest) { srcChk := chunk.NewChunkWithCapacity([]*types.FieldType{p.dataType}, p.numRows) for i := 0; i < p.numRows; i++ { dt := p.dataGen(i) @@ -838,7 +838,7 @@ func benchmarkAggFunc(b *testing.B, ctx sessionctx.Context, p aggTest) { }) } -func benchmarkMultiArgsAggFunc(b *testing.B, ctx sessionctx.Context, p multiArgsAggTest) { +func benchmarkMultiArgsAggFunc(b *testing.B, ctx expression.BuildContext, p multiArgsAggTest) { srcChk := chunk.NewChunkWithCapacity(p.dataTypes, p.numRows) for i := 0; i < p.numRows; i++ { for j := 0; j < len(p.dataGens); j++ { @@ -892,7 +892,7 @@ func benchmarkMultiArgsAggFunc(b *testing.B, ctx sessionctx.Context, p multiArgs }) } -func baseBenchmarkAggFunc(b *testing.B, ctx sessionctx.Context, finalFunc aggfuncs.AggFunc, input []chunk.Row, output *chunk.Chunk) { +func baseBenchmarkAggFunc(b *testing.B, ctx expression.BuildContext, finalFunc aggfuncs.AggFunc, input []chunk.Row, output *chunk.Chunk) { finalPr, _ := finalFunc.AllocPartialResult() output.Reset() b.ResetTimer() diff --git a/pkg/executor/aggfuncs/func_group_concat_test.go b/pkg/executor/aggfuncs/func_group_concat_test.go index dc3e805440cbc..f2d472d82508e 100644 --- a/pkg/executor/aggfuncs/func_group_concat_test.go +++ b/pkg/executor/aggfuncs/func_group_concat_test.go @@ -129,7 +129,7 @@ func groupConcatOrderMultiArgsUpdateMemDeltaGens(ctx sessionctx.Context, srcChk } memDelta := int64(buffer.Cap() - oldMemSize) for _, byItem := range byItems { - fdt, _ := byItem.Expr.Eval(ctx, row) + fdt, _ := byItem.Expr.Eval(ctx.GetPlanCtx(), row) datumMem := aggfuncs.GetDatumMemSize(&fdt) memDelta += datumMem } @@ -202,7 +202,7 @@ func groupConcatDistinctOrderMultiArgsUpdateMemDeltaGens(ctx sessionctx.Context, valSet.Insert(joinedVal) memDelta := int64(len(joinedVal) + (valsBuf.Cap() + cap(encodeBytesBuffer) - oldMemSize)) for _, byItem := range byItems { - fdt, _ := byItem.Expr.Eval(ctx, row) + fdt, _ := byItem.Expr.Eval(ctx.GetExprCtx(), row) datumMem := aggfuncs.GetDatumMemSize(&fdt) memDelta += datumMem } diff --git a/pkg/executor/aggregate/agg_hash_executor.go b/pkg/executor/aggregate/agg_hash_executor.go index 65e3ee4224d6d..df18b7436d2f0 100644 --- a/pkg/executor/aggregate/agg_hash_executor.go +++ b/pkg/executor/aggregate/agg_hash_executor.go @@ -625,6 +625,7 @@ func (e *HashAggExec) parallelExec(ctx context.Context, chk *chunk.Chunk) error func (e *HashAggExec) unparallelExec(ctx context.Context, chk *chunk.Chunk) error { chk.Reset() for { + exprCtx := e.Ctx().GetExprCtx() if e.prepared.Load() { // Since we return e.MaxChunkSize() rows every time, so we should not traverse // `groupSet` because of its randomness. @@ -634,7 +635,7 @@ func (e *HashAggExec) unparallelExec(ctx context.Context, chk *chunk.Chunk) erro chk.SetNumVirtualRows(chk.NumRows() + 1) } for i, af := range e.PartialAggFuncs { - if err := af.AppendFinalResult2Chunk(e.Ctx(), partialResults[i], chk); err != nil { + if err := af.AppendFinalResult2Chunk(exprCtx, partialResults[i], chk); err != nil { return err } } @@ -684,6 +685,7 @@ func (e *HashAggExec) execute(ctx context.Context) (err error) { e.tmpChkForSpill.Reset() } }() + exprCtx := e.Ctx().GetExprCtx() for { mSize := e.childResult.MemoryUsage() if err := e.getNextChunk(ctx); err != nil { @@ -726,7 +728,7 @@ func (e *HashAggExec) execute(ctx context.Context) (err error) { partialResults := e.getPartialResults(groupKey) for i, af := range e.PartialAggFuncs { tmpBuf[0] = e.childResult.GetRow(j) - memDelta, err := af.UpdatePartialResult(e.Ctx(), tmpBuf[:], partialResults[i]) + memDelta, err := af.UpdatePartialResult(exprCtx, tmpBuf[:], partialResults[i]) if err != nil { return err } diff --git a/pkg/executor/aggregate/agg_hash_final_worker.go b/pkg/executor/aggregate/agg_hash_final_worker.go index afc404ad99b98..e35288e88ec6e 100644 --- a/pkg/executor/aggregate/agg_hash_final_worker.go +++ b/pkg/executor/aggregate/agg_hash_final_worker.go @@ -95,6 +95,7 @@ func (w *HashAggFinalWorker) mergeInputIntoResultMap(sctx sessionctx.Context, in execStart := time.Now() allMemDelta := int64(0) + exprCtx := sctx.GetExprCtx() for key, value := range *input { dstVal, ok := w.partialResultMap[key] if !ok { @@ -103,7 +104,7 @@ func (w *HashAggFinalWorker) mergeInputIntoResultMap(sctx sessionctx.Context, in } for j, af := range w.aggFuncs { - memDelta, err := af.MergePartialResult(sctx, value[j], dstVal[j]) + memDelta, err := af.MergePartialResult(exprCtx, value[j], dstVal[j]) if err != nil { return err } @@ -140,9 +141,10 @@ func (w *HashAggFinalWorker) consumeIntermData(sctx sessionctx.Context) error { func (w *HashAggFinalWorker) generateResultAndSend(sctx sessionctx.Context, result *chunk.Chunk) { var finished bool + exprCtx := sctx.GetExprCtx() for _, results := range w.partialResultMap { for j, af := range w.aggFuncs { - if err := af.AppendFinalResult2Chunk(sctx, results[j], result); err != nil { + if err := af.AppendFinalResult2Chunk(exprCtx, results[j], result); err != nil { logutil.BgLogger().Error("HashAggFinalWorker failed to append final result to Chunk", zap.Error(err)) } } diff --git a/pkg/executor/aggregate/agg_hash_partial_worker.go b/pkg/executor/aggregate/agg_hash_partial_worker.go index 5151eef1895c7..0c7227c5b1c48 100644 --- a/pkg/executor/aggregate/agg_hash_partial_worker.go +++ b/pkg/executor/aggregate/agg_hash_partial_worker.go @@ -267,11 +267,12 @@ func (w *HashAggPartialWorker) updatePartialResult(ctx sessionctx.Context, chk * numRows := chk.NumRows() rows := make([]chunk.Row, 1) allMemDelta := int64(0) + exprCtx := ctx.GetExprCtx() for i := 0; i < numRows; i++ { partialResult := partialResultOfEachRow[i] rows[0] = chk.GetRow(i) for j, af := range w.aggFuncs { - memDelta, err := af.UpdatePartialResult(ctx, rows, partialResult[j]) + memDelta, err := af.UpdatePartialResult(exprCtx, rows, partialResult[j]) if err != nil { return err } diff --git a/pkg/executor/aggregate/agg_spill.go b/pkg/executor/aggregate/agg_spill.go index 59d73bbb10f94..8d6fd8090c3ec 100644 --- a/pkg/executor/aggregate/agg_spill.go +++ b/pkg/executor/aggregate/agg_spill.go @@ -291,9 +291,10 @@ func (p *parallelHashAggSpillHelper) processRow(context *processRowContext) (tot key := context.chunk.GetRow(context.rowPos).GetString(context.keyColPos) prs, ok := (*context.restoreadData)[key] if ok { + exprCtx := context.ctx.GetExprCtx() // The key has appeared before, merge results. for aggPos := 0; aggPos < context.aggFuncNum; aggPos++ { - memDelta, err := p.aggFuncsForRestoring[aggPos].MergePartialResult(context.ctx, context.partialResultsRestored[aggPos][context.rowPos], prs[aggPos]) + memDelta, err := p.aggFuncsForRestoring[aggPos].MergePartialResult(exprCtx, context.partialResultsRestored[aggPos][context.rowPos], prs[aggPos]) if err != nil { return totalMemDelta, 0, err } diff --git a/pkg/executor/aggregate/agg_spill_test.go b/pkg/executor/aggregate/agg_spill_test.go index 731d4c7324679..5e05b613d6595 100644 --- a/pkg/executor/aggregate/agg_spill_test.go +++ b/pkg/executor/aggregate/agg_spill_test.go @@ -191,12 +191,12 @@ func buildHashAggExecutor(t *testing.T, ctx sessionctx.Context, child exec.Execu schema := expression.NewSchema(childCols...) groupItems := []expression.Expression{childCols[0]} - aggFirstRow, err := aggregation.NewAggFuncDesc(ctx, ast.AggFuncFirstRow, []expression.Expression{childCols[0]}, false) + aggFirstRow, err := aggregation.NewAggFuncDesc(ctx.GetExprCtx(), ast.AggFuncFirstRow, []expression.Expression{childCols[0]}, false) if err != nil { t.Fatal(err) } - aggFunc, err := aggregation.NewAggFuncDesc(ctx, ast.AggFuncSum, []expression.Expression{childCols[1]}, false) + aggFunc, err := aggregation.NewAggFuncDesc(ctx.GetExprCtx(), ast.AggFuncSum, []expression.Expression{childCols[1]}, false) if err != nil { t.Fatal(err) } @@ -216,8 +216,8 @@ func buildHashAggExecutor(t *testing.T, ctx sessionctx.Context, child exec.Execu ordinal := []int{partialOrdinal} partialOrdinal++ partialAggDesc, finalDesc := aggDesc.Split(ordinal) - partialAggFunc := aggfuncs.Build(ctx, partialAggDesc, i) - finalAggFunc := aggfuncs.Build(ctx, finalDesc, i) + partialAggFunc := aggfuncs.Build(ctx.GetExprCtx(), partialAggDesc, i) + finalAggFunc := aggfuncs.Build(ctx.GetExprCtx(), finalDesc, i) aggExec.PartialAggFuncs = append(aggExec.PartialAggFuncs, partialAggFunc) aggExec.FinalAggFuncs = append(aggExec.FinalAggFuncs, finalAggFunc) } diff --git a/pkg/executor/aggregate/agg_stream_executor.go b/pkg/executor/aggregate/agg_stream_executor.go index ceb785d1745ff..be4cb215f782d 100644 --- a/pkg/executor/aggregate/agg_stream_executor.go +++ b/pkg/executor/aggregate/agg_stream_executor.go @@ -167,8 +167,9 @@ func (e *StreamAggExec) consumeGroupRows() error { } allMemDelta := int64(0) + exprCtx := e.Ctx().GetExprCtx() for i, aggFunc := range e.AggFuncs { - memDelta, err := aggFunc.UpdatePartialResult(e.Ctx(), e.groupRows, e.partialResults[i]) + memDelta, err := aggFunc.UpdatePartialResult(exprCtx, e.groupRows, e.partialResults[i]) if err != nil { return err } @@ -214,8 +215,9 @@ func (e *StreamAggExec) consumeCurGroupRowsAndFetchChild(ctx context.Context, ch // appendResult2Chunk appends result of all the aggregation functions to the // result chunk, and reset the evaluation context for each aggregation. func (e *StreamAggExec) appendResult2Chunk(chk *chunk.Chunk) error { + exprCtx := e.Ctx().GetExprCtx() for i, aggFunc := range e.AggFuncs { - err := aggFunc.AppendFinalResult2Chunk(e.Ctx(), e.partialResults[i], chk) + err := aggFunc.AppendFinalResult2Chunk(exprCtx, e.partialResults[i], chk) if err != nil { return err } diff --git a/pkg/executor/aggregate/agg_util.go b/pkg/executor/aggregate/agg_util.go index 41e2220fff781..23ac736242f2e 100644 --- a/pkg/executor/aggregate/agg_util.go +++ b/pkg/executor/aggregate/agg_util.go @@ -76,6 +76,7 @@ func GetGroupKey(ctx sessionctx.Context, input *chunk.Chunk, groupKey [][]byte, } errCtx := ctx.GetSessionVars().StmtCtx.ErrCtx() + exprCtx := ctx.GetExprCtx() for _, item := range groupByItems { tp := item.GetType() @@ -96,7 +97,7 @@ func GetGroupKey(ctx sessionctx.Context, input *chunk.Chunk, groupKey [][]byte, tp = &newTp } - if err := expression.EvalExpr(ctx, item, tp.EvalType(), input, buf); err != nil { + if err := expression.EvalExpr(exprCtx, item, tp.EvalType(), input, buf); err != nil { expression.PutColumn(buf) return nil, err } diff --git a/pkg/executor/batch_checker.go b/pkg/executor/batch_checker.go index f871b83b60f9c..09e6a3132ac76 100644 --- a/pkg/executor/batch_checker.go +++ b/pkg/executor/batch_checker.go @@ -98,7 +98,7 @@ func getKeysNeedCheckOneRow(ctx sessionctx.Context, t table.Table, row []types.D pkIdxInfo *model.IndexInfo, result []toBeCheckedRow) ([]toBeCheckedRow, error) { var err error if p, ok := t.(table.PartitionedTable); ok { - t, err = p.GetPartitionByRow(ctx, row) + t, err = p.GetPartitionByRow(ctx.GetExprCtx(), row) if err != nil { if terr, ok := errors.Cause(err).(*terror.Error); ok && (terr.Code() == errno.ErrNoPartitionForGivenValue || terr.Code() == errno.ErrRowDoesNotMatchGivenPartitionSet) { ec := ctx.GetSessionVars().StmtCtx.ErrCtx() @@ -166,7 +166,7 @@ func getKeysNeedCheckOneRow(ctx sessionctx.Context, t table.Table, row []types.D if col.State != model.StatePublic { // only append origin default value for index fetch values if col.Offset >= len(row) { - value, err := table.GetColOriginDefaultValue(ctx, col.ToInfo()) + value, err := table.GetColOriginDefaultValue(ctx.GetExprCtx(), col.ToInfo()) if err != nil { return nil, err } @@ -278,11 +278,12 @@ func getOldRow(ctx context.Context, sctx sessionctx.Context, txn kv.Transaction, } // Fill write-only and write-reorg columns with originDefaultValue if not found in oldValue. gIdx := 0 + exprCtx := sctx.GetExprCtx() for _, col := range cols { if col.State != model.StatePublic && oldRow[col.Offset].IsNull() { _, found := oldRowMap[col.ID] if !found { - oldRow[col.Offset], err = table.GetColOriginDefaultValue(sctx, col.ToInfo()) + oldRow[col.Offset], err = table.GetColOriginDefaultValue(exprCtx, col.ToInfo()) if err != nil { return nil, err } @@ -292,7 +293,7 @@ func getOldRow(ctx context.Context, sctx sessionctx.Context, txn kv.Transaction, // only the virtual column needs fill back. // Insert doesn't fill the generated columns at non-public state. if !col.GeneratedStored { - val, err := genExprs[gIdx].Eval(sctx, chunk.MutRowFromDatums(oldRow).ToRow()) + val, err := genExprs[gIdx].Eval(sctx.GetExprCtx(), chunk.MutRowFromDatums(oldRow).ToRow()) if err != nil { return nil, err } diff --git a/pkg/executor/benchmark_test.go b/pkg/executor/benchmark_test.go index c4447b0897623..5eeb1a3a391b6 100644 --- a/pkg/executor/benchmark_test.go +++ b/pkg/executor/benchmark_test.go @@ -137,7 +137,7 @@ func buildAggExecutor(b *testing.B, testCase *testutil.AggTestCase, child exec.E childCols := testCase.Columns() schema := expression.NewSchema(childCols...) groupBy := []expression.Expression{childCols[1]} - aggFunc, err := aggregation.NewAggFuncDesc(testCase.Ctx, testCase.AggFunc, []expression.Expression{childCols[0]}, testCase.HasDistinct) + aggFunc, err := aggregation.NewAggFuncDesc(testCase.Ctx.GetExprCtx(), testCase.AggFunc, []expression.Expression{childCols[0]}, testCase.HasDistinct) if err != nil { b.Fatal(err) } @@ -298,7 +298,7 @@ func buildWindowExecutor(ctx sessionctx.Context, windowFunc string, funcs int, f default: args = append(args, partitionBy[0]) } - desc, _ := aggregation.NewWindowFuncDesc(ctx, windowFunc, args, false) + desc, _ := aggregation.NewWindowFuncDesc(ctx.GetExprCtx(), windowFunc, args, false) win.WindowFuncDescs = append(win.WindowFuncDescs, desc) winSchema.Append(&expression.Column{ diff --git a/pkg/executor/builder.go b/pkg/executor/builder.go index 54960749e35f3..b096a047a2992 100644 --- a/pkg/executor/builder.go +++ b/pkg/executor/builder.go @@ -1667,9 +1667,10 @@ func (b *executorBuilder) buildHashAgg(v *plannercore.PhysicalHashAgg) exec.Exec e.IsUnparallelExec = true } partialOrdinal := 0 + exprCtx := b.ctx.GetExprCtx() for i, aggDesc := range v.AggFuncs { if e.IsUnparallelExec { - e.PartialAggFuncs = append(e.PartialAggFuncs, aggfuncs.Build(b.ctx, aggDesc, i)) + e.PartialAggFuncs = append(e.PartialAggFuncs, aggfuncs.Build(exprCtx, aggDesc, i)) } else { ordinal := []int{partialOrdinal} partialOrdinal++ @@ -1678,8 +1679,8 @@ func (b *executorBuilder) buildHashAgg(v *plannercore.PhysicalHashAgg) exec.Exec partialOrdinal++ } partialAggDesc, finalDesc := aggDesc.Split(ordinal) - partialAggFunc := aggfuncs.Build(b.ctx, partialAggDesc, i) - finalAggFunc := aggfuncs.Build(b.ctx, finalDesc, i) + partialAggFunc := aggfuncs.Build(exprCtx, partialAggDesc, i) + finalAggFunc := aggfuncs.Build(exprCtx, finalDesc, i) e.PartialAggFuncs = append(e.PartialAggFuncs, partialAggFunc) e.FinalAggFuncs = append(e.FinalAggFuncs, finalAggFunc) if partialAggDesc.Name == ast.AggFuncGroupConcat { @@ -1704,9 +1705,10 @@ func (b *executorBuilder) buildStreamAgg(v *plannercore.PhysicalStreamAgg) exec. if b.err != nil { return nil } + exprCtx := b.ctx.GetExprCtx() e := &aggregate.StreamAggExec{ BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), src), - GroupChecker: vecgroupchecker.NewVecGroupChecker(b.ctx, v.GroupByItems), + GroupChecker: vecgroupchecker.NewVecGroupChecker(exprCtx, v.GroupByItems), AggFuncs: make([]aggfuncs.AggFunc, 0, len(v.AggFuncs)), } @@ -1719,7 +1721,7 @@ func (b *executorBuilder) buildStreamAgg(v *plannercore.PhysicalStreamAgg) exec. } } for i, aggDesc := range v.AggFuncs { - aggFunc := aggfuncs.Build(b.ctx, aggDesc, i) + aggFunc := aggfuncs.Build(exprCtx, aggDesc, i) e.AggFuncs = append(e.AggFuncs, aggFunc) if e.DefaultVal != nil { value := aggDesc.GetDefaultValue() @@ -2622,7 +2624,7 @@ func (b *executorBuilder) buildAnalyzeSamplingPushdown( e.analyzePB.ColReq.PrimaryPrefixColumnIds = tables.PrimaryPrefixColumnIDs(task.TblInfo) } } - b.err = tables.SetPBColumnsDefaultValue(b.ctx, e.analyzePB.ColReq.ColumnsInfo, task.ColsInfo) + b.err = tables.SetPBColumnsDefaultValue(b.ctx.GetExprCtx(), e.analyzePB.ColReq.ColumnsInfo, task.ColsInfo) return &analyzeTask{taskType: colTask, colExec: e, job: job} } @@ -2779,7 +2781,7 @@ func (b *executorBuilder) buildAnalyzeColumnsPushdown( e.analyzePB.Tp = tipb.AnalyzeType_TypeMixed e.commonHandle = task.CommonHandleInfo } - b.err = tables.SetPBColumnsDefaultValue(b.ctx, e.analyzePB.ColReq.ColumnsInfo, cols) + b.err = tables.SetPBColumnsDefaultValue(b.ctx.GetExprCtx(), e.analyzePB.ColReq.ColumnsInfo, cols) return &analyzeTask{taskType: colTask, colExec: e, job: job} } @@ -2798,9 +2800,10 @@ func (b *executorBuilder) buildAnalyze(v *plannercore.Analyze) exec.Executor { if b.ctx.GetSessionVars().InRestrictedSQL { autoAnalyze = "auto " } + exprCtx := b.ctx.GetExprCtx() for _, task := range v.ColTasks { columns, _, err := expression.ColumnInfos2ColumnsAndNames( - b.ctx, + exprCtx, model.NewCIStr(task.AnalyzeInfo.DBName), task.TblInfo.Name, task.ColsInfo, @@ -3489,11 +3492,12 @@ func (builder *dataReaderBuilder) prunePartitionForInnerExecutor(tbl table.Table locateKey := make([]types.Datum, len(partitionTbl.Cols())) partitions := make(map[int64]table.PhysicalTable) contentPos = make([]int64, len(lookUpContent)) + exprCtx := builder.ctx.GetExprCtx() for idx, content := range lookUpContent { for i, data := range content.keys { locateKey[keyColOffsets[i]] = data } - p, err := partitionTbl.GetPartitionByRow(builder.ctx, locateKey) + p, err := partitionTbl.GetPartitionByRow(exprCtx, locateKey) if table.ErrNoPartitionForGivenValue.Equal(err) { continue } @@ -4124,11 +4128,12 @@ func (builder *dataReaderBuilder) buildTableReaderForIndexJoin(ctx context.Conte kvRanges = make([]kv.KeyRange, 0, len(lookUpContents)) // lookUpContentsByPID groups lookUpContents by pid(partition) so that kv ranges for same partition can be merged. lookUpContentsByPID := make(map[int64][]*indexJoinLookUpContent) + exprCtx := e.Ctx().GetExprCtx() for _, content := range lookUpContents { for i, data := range content.keys { locateKey[keyColOffsets[i]] = data } - p, err := pt.GetPartitionByRow(e.Ctx(), locateKey) + p, err := pt.GetPartitionByRow(exprCtx, locateKey) if table.ErrNoPartitionForGivenValue.Equal(err) { continue } @@ -4171,11 +4176,12 @@ func (builder *dataReaderBuilder) buildTableReaderForIndexJoin(ctx context.Conte if len(keyColOffsets) > 0 { locateKey := make([]types.Datum, len(pt.Cols())) kvRanges = make([]kv.KeyRange, 0, len(lookUpContents)) + exprCtx := e.Ctx().GetExprCtx() for _, content := range lookUpContents { for i, data := range content.keys { locateKey[keyColOffsets[i]] = data } - p, err := pt.GetPartitionByRow(e.Ctx(), locateKey) + p, err := pt.GetPartitionByRow(exprCtx, locateKey) if table.ErrNoPartitionForGivenValue.Equal(err) { continue } @@ -4660,13 +4666,14 @@ func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) exec.Execut windowFuncs := make([]aggfuncs.AggFunc, 0, len(v.WindowFuncDescs)) partialResults := make([]aggfuncs.PartialResult, 0, len(v.WindowFuncDescs)) resultColIdx := v.Schema().Len() - len(v.WindowFuncDescs) + exprCtx := b.ctx.GetExprCtx() for _, desc := range v.WindowFuncDescs { - aggDesc, err := aggregation.NewAggFuncDescForWindowFunc(b.ctx, desc, false) + aggDesc, err := aggregation.NewAggFuncDescForWindowFunc(exprCtx, desc, false) if err != nil { b.err = err return nil } - agg := aggfuncs.BuildWindowFunctions(b.ctx, aggDesc, resultColIdx, orderByCols) + agg := aggfuncs.BuildWindowFunctions(exprCtx, aggDesc, resultColIdx, orderByCols) windowFuncs = append(windowFuncs, agg) partialResult, _ := agg.AllocPartialResult() partialResults = append(partialResults, partialResult) @@ -4677,7 +4684,7 @@ func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) exec.Execut if b.ctx.GetSessionVars().EnablePipelinedWindowExec { exec := &PipelinedWindowExec{ BaseExecutor: base, - groupChecker: vecgroupchecker.NewVecGroupChecker(b.ctx, groupByItems), + groupChecker: vecgroupchecker.NewVecGroupChecker(b.ctx.GetExprCtx(), groupByItems), numWindowFuncs: len(v.WindowFuncDescs), } @@ -4755,7 +4762,7 @@ func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) exec.Execut } return &WindowExec{BaseExecutor: base, processor: processor, - groupChecker: vecgroupchecker.NewVecGroupChecker(b.ctx, groupByItems), + groupChecker: vecgroupchecker.NewVecGroupChecker(b.ctx.GetExprCtx(), groupByItems), numWindowFuncs: len(v.WindowFuncDescs), } } @@ -4903,7 +4910,7 @@ func NewRowDecoder(ctx sessionctx.Context, schema *expression.Schema, tbl *model } ci := getColInfoByID(tbl, reqCols[i].ID) - d, err := table.GetColOriginDefaultValue(ctx, ci) + d, err := table.GetColOriginDefaultValue(ctx.GetExprCtx(), ci) if err != nil { return err } diff --git a/pkg/executor/cluster_table_test.go b/pkg/executor/cluster_table_test.go index 373414b13d91d..9d5deb209fac0 100644 --- a/pkg/executor/cluster_table_test.go +++ b/pkg/executor/cluster_table_test.go @@ -301,7 +301,7 @@ func TestSQLDigestTextRetriever(t *testing.T) { updateDigest.String(): "", }, } - err := r.RetrieveLocal(context.Background(), tk.Session()) + err := r.RetrieveLocal(context.Background(), tk.Session().GetExprCtx()) require.NoError(t, err) require.Equal(t, insertNormalized, r.SQLDigestsMap[insertDigest.String()]) require.Equal(t, "", r.SQLDigestsMap[updateDigest.String()]) diff --git a/pkg/executor/distsql.go b/pkg/executor/distsql.go index 6430c5af3adb9..1e01546905f8b 100644 --- a/pkg/executor/distsql.go +++ b/pkg/executor/distsql.go @@ -152,7 +152,7 @@ func closeAll(objs ...Closeable) error { func rebuildIndexRanges(ctx sessionctx.Context, is *plannercore.PhysicalIndexScan, idxCols []*expression.Column, colLens []int) (ranges []*ranger.Range, err error) { access := make([]expression.Expression, 0, len(is.AccessCondition)) for _, cond := range is.AccessCondition { - newCond, err1 := expression.SubstituteCorCol2Constant(ctx, cond) + newCond, err1 := expression.SubstituteCorCol2Constant(ctx.GetExprCtx(), cond) if err1 != nil { return nil, err1 } @@ -341,11 +341,11 @@ func (e *IndexReaderExecutor) open(ctx context.Context, kvRanges []kv.KeyRange) args = append(args, expression.NewInt64Const(pid)) } - inCondition, err := expression.NewFunction(e.Ctx(), ast.In, types.NewFieldType(mysql.TypeLonglong), args...) + inCondition, err := expression.NewFunction(e.Ctx().GetExprCtx(), ast.In, types.NewFieldType(mysql.TypeLonglong), args...) if err != nil { return err } - pbConditions, err := expression.ExpressionsToPBList(e.Ctx(), []expression.Expression{inCondition}, e.Ctx().GetClient()) + pbConditions, err := expression.ExpressionsToPBList(e.Ctx().GetExprCtx(), []expression.Expression{inCondition}, e.Ctx().GetClient()) if err != nil { return err } diff --git a/pkg/executor/executor.go b/pkg/executor/executor.go index 8677ae8528492..8560997f7988b 100644 --- a/pkg/executor/executor.go +++ b/pkg/executor/executor.go @@ -1621,7 +1621,7 @@ func (e *SelectionExec) Next(ctx context.Context, req *chunk.Chunk) error { if e.childResult.NumRows() == 0 { return nil } - e.selected, err = expression.VectorizedFilter(e.Ctx(), e.filters, e.inputIter, e.selected) + e.selected, err = expression.VectorizedFilter(e.Ctx().GetExprCtx(), e.filters, e.inputIter, e.selected) if err != nil { return err } @@ -1633,9 +1633,10 @@ func (e *SelectionExec) Next(ctx context.Context, req *chunk.Chunk) error { // For sql with "SETVAR" in filter and "GETVAR" in projection, for example: "SELECT @a FROM t WHERE (@a := 2) > 0", // we have to set batch size to 1 to do the evaluation of filter and projection. func (e *SelectionExec) unBatchedNext(ctx context.Context, chk *chunk.Chunk) error { + exprCtx := e.Ctx().GetExprCtx() for { for ; e.inputRow != e.inputIter.End(); e.inputRow = e.inputIter.Next() { - selected, _, err := expression.EvalBool(e.Ctx(), e.filters, e.inputRow) + selected, _, err := expression.EvalBool(exprCtx, e.filters, e.inputRow) if err != nil { return err } diff --git a/pkg/executor/executor_required_rows_test.go b/pkg/executor/executor_required_rows_test.go index fcfd719d1d027..ad222b6ddf920 100644 --- a/pkg/executor/executor_required_rows_test.go +++ b/pkg/executor/executor_required_rows_test.go @@ -459,7 +459,7 @@ func TestSelectionRequiredRows(t *testing.T) { } else { ds = newRequiredRowsDataSourceWithGenerator(sctx, testCase.totalRows, testCase.expectedRowsDS, testCase.gen) f, err := expression.NewFunction( - sctx, ast.EQ, types.NewFieldType(byte(types.ETInt)), ds.Schema().Columns[1], &expression.Constant{ + sctx.GetExprCtx(), ast.EQ, types.NewFieldType(byte(types.ETInt)), ds.Schema().Columns[1], &expression.Constant{ Value: types.NewDatum(testCase.filtersOfCol1), RetType: types.NewFieldType(mysql.TypeTiny), }) @@ -667,7 +667,7 @@ func TestStreamAggRequiredRows(t *testing.T) { childCols := ds.Schema().Columns schema := expression.NewSchema(childCols...) groupBy := []expression.Expression{childCols[1]} - aggFunc, err := aggregation.NewAggFuncDesc(sctx, testCase.aggFunc, []expression.Expression{childCols[0]}, true) + aggFunc, err := aggregation.NewAggFuncDesc(sctx.GetExprCtx(), testCase.aggFunc, []expression.Expression{childCols[0]}, true) require.NoError(t, err) aggFuncs := []*aggregation.AggFuncDesc{aggFunc} executor := buildStreamAggExecutor(sctx, ds, schema, aggFuncs, groupBy, 1, true) diff --git a/pkg/executor/importer/import.go b/pkg/executor/importer/import.go index b9bef1cdd20e0..d5d31b0885718 100644 --- a/pkg/executor/importer/import.go +++ b/pkg/executor/importer/import.go @@ -598,7 +598,7 @@ func (p *Plan) initOptions(ctx context.Context, seCtx sessionctx.Context, option if opt.Value.GetType().GetType() != mysql.TypeVarString { return "", exeerrors.ErrInvalidOptionVal.FastGenByArgs(opt.Name) } - val, isNull, err2 := opt.Value.EvalString(seCtx, chunk.Row{}) + val, isNull, err2 := opt.Value.EvalString(seCtx.GetExprCtx(), chunk.Row{}) if err2 != nil || isNull { return "", exeerrors.ErrInvalidOptionVal.FastGenByArgs(opt.Name) } @@ -609,7 +609,7 @@ func (p *Plan) initOptions(ctx context.Context, seCtx sessionctx.Context, option if opt.Value.GetType().GetType() != mysql.TypeLonglong || mysql.HasIsBooleanFlag(opt.Value.GetType().GetFlag()) { return 0, exeerrors.ErrInvalidOptionVal.FastGenByArgs(opt.Name) } - val, isNull, err2 := opt.Value.EvalInt(seCtx, chunk.Row{}) + val, isNull, err2 := opt.Value.EvalInt(seCtx.GetExprCtx(), chunk.Row{}) if err2 != nil || isNull { return 0, exeerrors.ErrInvalidOptionVal.FastGenByArgs(opt.Name) } diff --git a/pkg/executor/importer/kv_encode.go b/pkg/executor/importer/kv_encode.go index 5db149f369f24..7436cc0eed7df 100644 --- a/pkg/executor/importer/kv_encode.go +++ b/pkg/executor/importer/kv_encode.go @@ -142,7 +142,7 @@ func (en *tableKVEncoder) parserData2TableData(parserData []types.Datum, rowID i } for i := 0; i < len(en.columnAssignments); i++ { // eval expression of `SET` clause - d, err := en.columnAssignments[i].Eval(en.SessionCtx, chunk.Row{}) + d, err := en.columnAssignments[i].Eval(en.SessionCtx.GetExprCtx(), chunk.Row{}) if err != nil { return nil, err } diff --git a/pkg/executor/index_lookup_join.go b/pkg/executor/index_lookup_join.go index 9c40333ee8201..e8fcc3486f42c 100644 --- a/pkg/executor/index_lookup_join.go +++ b/pkg/executor/index_lookup_join.go @@ -456,11 +456,12 @@ func (ow *outerWorker) buildTask(ctx context.Context) (*lookUpJoinTask, error) { if ow.filter != nil { task.outerMatch = make([][]bool, task.outerResult.NumChunks()) var err error + exprCtx := ow.ctx.GetExprCtx() for i := 0; i < numChks; i++ { chk := task.outerResult.GetChunk(i) outerMatch := make([]bool, 0, chk.NumRows()) task.memTracker.Consume(int64(cap(outerMatch))) - task.outerMatch[i], err = expression.VectorizedFilter(ow.ctx, ow.filter, chunk.NewIterator4Chunk(chk), outerMatch) + task.outerMatch[i], err = expression.VectorizedFilter(exprCtx, ow.filter, chunk.NewIterator4Chunk(chk), outerMatch) if err != nil { return task, err } diff --git a/pkg/executor/index_lookup_merge_join.go b/pkg/executor/index_lookup_merge_join.go index f526058962998..2904313094490 100644 --- a/pkg/executor/index_lookup_merge_join.go +++ b/pkg/executor/index_lookup_merge_join.go @@ -422,10 +422,11 @@ func (imw *innerMergeWorker) handleTask(ctx context.Context, task *lookUpMergeJo numOuterChks := task.outerResult.NumChunks() if imw.outerMergeCtx.filter != nil { task.outerMatch = make([][]bool, numOuterChks) + exprCtx := imw.ctx.GetExprCtx() for i := 0; i < numOuterChks; i++ { chk := task.outerResult.GetChunk(i) task.outerMatch[i] = make([]bool, chk.NumRows()) - task.outerMatch[i], err = expression.VectorizedFilter(imw.ctx, imw.outerMergeCtx.filter, chunk.NewIterator4Chunk(chk), task.outerMatch[i]) + task.outerMatch[i], err = expression.VectorizedFilter(exprCtx, imw.outerMergeCtx.filter, chunk.NewIterator4Chunk(chk), task.outerMatch[i]) if err != nil { return err } @@ -449,13 +450,14 @@ func (imw *innerMergeWorker) handleTask(ctx context.Context, task *lookUpMergeJo // Because the necessary condition of merge join is both outer and inner keep order of join keys. // In this case, we need sort the outer side. if imw.outerMergeCtx.needOuterSort { + exprCtx := imw.ctx.GetExprCtx() slices.SortFunc(task.outerOrderIdx, func(idxI, idxJ chunk.RowPtr) int { rowI, rowJ := task.outerResult.GetRow(idxI), task.outerResult.GetRow(idxJ) var c int64 var err error for _, keyOff := range imw.keyOff2KeyOffOrderByIdx { joinKey := imw.outerMergeCtx.joinKeys[keyOff] - c, _, err = imw.outerMergeCtx.compareFuncs[keyOff](imw.ctx, joinKey, joinKey, rowI, rowJ) + c, _, err = imw.outerMergeCtx.compareFuncs[keyOff](exprCtx, joinKey, joinKey, rowI, rowJ) terror.Log(err) if c != 0 { break @@ -628,8 +630,9 @@ func (imw *innerMergeWorker) fetchInnerRowsWithSameKey(ctx context.Context, task } func (imw *innerMergeWorker) compare(outerRow, innerRow chunk.Row) (int, error) { + exprCtx := imw.ctx.GetExprCtx() for _, keyOff := range imw.innerMergeCtx.keyOff2KeyOffOrderByIdx { - cmp, _, err := imw.innerMergeCtx.compareFuncs[keyOff](imw.ctx, imw.outerMergeCtx.joinKeys[keyOff], imw.innerMergeCtx.joinKeys[keyOff], outerRow, innerRow) + cmp, _, err := imw.innerMergeCtx.compareFuncs[keyOff](exprCtx, imw.outerMergeCtx.joinKeys[keyOff], imw.innerMergeCtx.joinKeys[keyOff], outerRow, innerRow) if err != nil || cmp != 0 { return int(cmp), err } diff --git a/pkg/executor/infoschema_reader.go b/pkg/executor/infoschema_reader.go index cc2ed8a4d07b3..54fe7243a3e5f 100644 --- a/pkg/executor/infoschema_reader.go +++ b/pkg/executor/infoschema_reader.go @@ -951,7 +951,7 @@ ForColumnsTag: case "CURRENT_TIMESTAMP": default: if ft.GetType() == mysql.TypeTimestamp && columnDefault != types.ZeroDatetimeStr { - timeValue, err := table.GetColDefaultValue(sctx, col) + timeValue, err := table.GetColDefaultValue(sctx.GetExprCtx(), col) if err == nil { columnDefault = timeValue.GetMysqlTime().String() } @@ -2650,7 +2650,7 @@ func (e *tidbTrxTableRetriever) retrieve(ctx context.Context, sctx sessionctx.Co } // Retrieve the SQL texts if necessary. if sqlRetriever != nil { - err1 := sqlRetriever.RetrieveLocal(ctx, sctx) + err1 := sqlRetriever.RetrieveLocal(ctx, sctx.GetExprCtx()) if err1 != nil { return errors.Trace(err1) } @@ -2771,7 +2771,7 @@ func (r *dataLockWaitsTableRetriever) retrieve(ctx context.Context, sctx session sqlRetriever.SQLDigestsMap[digest] = "" } } - err := sqlRetriever.RetrieveGlobal(ctx, sctx) + err := sqlRetriever.RetrieveGlobal(ctx, sctx.GetExprCtx()) if err != nil { return errors.Trace(err) } @@ -2969,7 +2969,7 @@ func (r *deadlocksTableRetriever) retrieve(ctx context.Context, sctx sessionctx. } // Retrieve the SQL texts if necessary. if sqlRetriever != nil { - err1 := sqlRetriever.RetrieveGlobal(ctx, sctx) + err1 := sqlRetriever.RetrieveGlobal(ctx, sctx.GetExprCtx()) if err1 != nil { return errors.Trace(err1) } diff --git a/pkg/executor/insert.go b/pkg/executor/insert.go index 806a7c62b02d7..99bf773f2a63d 100644 --- a/pkg/executor/insert.go +++ b/pkg/executor/insert.go @@ -397,13 +397,14 @@ func (e *InsertExec) doDupRowUpdate(ctx context.Context, handle kv.Handle, oldRo // Update old row when the key is duplicated. e.evalBuffer4Dup.SetDatums(e.row4Update...) sctx := e.Ctx() + exprCtx := sctx.GetExprCtx() sc := sctx.GetSessionVars().StmtCtx warnCnt := int(sc.WarningCount()) for _, col := range cols { if col.LazyErr != nil { return col.LazyErr } - val, err1 := col.Expr.Eval(sctx, e.evalBuffer4Dup.ToRow()) + val, err1 := col.Expr.Eval(exprCtx, e.evalBuffer4Dup.ToRow()) if err1 != nil { return err1 } diff --git a/pkg/executor/insert_common.go b/pkg/executor/insert_common.go index af4b5e271bfa8..562841a6d39b6 100644 --- a/pkg/executor/insert_common.go +++ b/pkg/executor/insert_common.go @@ -348,10 +348,11 @@ func (e *InsertValues) evalRow(ctx context.Context, list []expression.Expression e.evalBuffer.SetDatums(row...) sctx := e.Ctx() + exprCtx := sctx.GetExprCtx() sc := sctx.GetSessionVars().StmtCtx warnCnt := int(sc.WarningCount()) for i, expr := range list { - val, err := expr.Eval(sctx, e.evalBuffer.ToRow()) + val, err := expr.Eval(exprCtx, e.evalBuffer.ToRow()) if err != nil { return nil, err } @@ -388,11 +389,12 @@ func (e *InsertValues) fastEvalRow(ctx context.Context, list []expression.Expres row := make([]types.Datum, rowLen) hasValue := make([]bool, rowLen) sctx := e.Ctx() + exprCtx := sctx.GetExprCtx() sc := sctx.GetSessionVars().StmtCtx warnCnt := int(sc.WarningCount()) for i, expr := range list { con := expr.(*expression.Constant) - val, err := con.Eval(sctx, emptyRow) + val, err := con.Eval(exprCtx, emptyRow) if err = e.handleErr(e.insertColumns[i], &val, rowIdx, err); err != nil { return nil, err } @@ -580,9 +582,9 @@ func (e *InsertValues) getColDefaultValue(idx int, col *table.Column) (d types.D var defaultVal types.Datum if col.DefaultIsExpr && col.DefaultExpr != nil { - defaultVal, err = table.EvalColDefaultExpr(e.Ctx(), col.ToInfo(), col.DefaultExpr) + defaultVal, err = table.EvalColDefaultExpr(e.Ctx().GetExprCtx(), col.ToInfo(), col.DefaultExpr) } else { - defaultVal, err = table.GetColDefaultValue(e.Ctx(), col.ToInfo()) + defaultVal, err = table.GetColDefaultValue(e.Ctx().GetExprCtx(), col.ToInfo()) } if err != nil { return types.Datum{}, err @@ -701,11 +703,12 @@ func (e *InsertValues) fillRow(ctx context.Context, row []types.Datum, hasValue } sctx := e.Ctx() + exprCtx := sctx.GetExprCtx() sc := sctx.GetSessionVars().StmtCtx warnCnt := int(sc.WarningCount()) for i, gCol := range gCols { colIdx := gCol.ColumnInfo.Offset - val, err := e.GenExprs[i].Eval(sctx, chunk.MutRowFromDatums(row).ToRow()) + val, err := e.GenExprs[i].Eval(exprCtx, chunk.MutRowFromDatums(row).ToRow()) if err != nil && gCol.FieldType.IsArray() { return nil, completeError(tbl, gCol.Offset, rowIdx, err) } diff --git a/pkg/executor/internal/builder/builder_utils.go b/pkg/executor/internal/builder/builder_utils.go index 6d2bb03f4fa31..a29754ab352bb 100644 --- a/pkg/executor/internal/builder/builder_utils.go +++ b/pkg/executor/internal/builder/builder_utils.go @@ -25,7 +25,7 @@ import ( // ConstructTreeBasedDistExec constructs tree based DAGRequest func ConstructTreeBasedDistExec(sctx sessionctx.Context, p plannercore.PhysicalPlan) ([]*tipb.Executor, error) { - execPB, err := p.ToPB(sctx, kv.TiFlash) + execPB, err := p.ToPB(sctx.GetPlanCtx(), kv.TiFlash) return []*tipb.Executor{execPB}, err } @@ -33,7 +33,7 @@ func ConstructTreeBasedDistExec(sctx sessionctx.Context, p plannercore.PhysicalP func ConstructListBasedDistExec(sctx sessionctx.Context, plans []plannercore.PhysicalPlan) ([]*tipb.Executor, error) { executors := make([]*tipb.Executor, 0, len(plans)) for _, p := range plans { - execPB, err := p.ToPB(sctx, kv.TiKV) + execPB, err := p.ToPB(sctx.GetPlanCtx(), kv.TiKV) if err != nil { return nil, err } diff --git a/pkg/executor/internal/querywatch/query_watch.go b/pkg/executor/internal/querywatch/query_watch.go index c0ca9db8af3f2..4ffe51d1c67a6 100644 --- a/pkg/executor/internal/querywatch/query_watch.go +++ b/pkg/executor/internal/querywatch/query_watch.go @@ -48,7 +48,7 @@ func setWatchOption(ctx context.Context, if err != nil { return err } - name, isNull, err := expr.EvalString(sctx, chunk.Row{}) + name, isNull, err := expr.EvalString(sctx.GetExprCtx(), chunk.Row{}) if err != nil { return err } @@ -66,7 +66,7 @@ func setWatchOption(ctx context.Context, if err != nil { return err } - strval, isNull, err := expr.EvalString(sctx, chunk.Row{}) + strval, isNull, err := expr.EvalString(sctx.GetExprCtx(), chunk.Row{}) if err != nil { return err } diff --git a/pkg/executor/internal/vecgroupchecker/BUILD.bazel b/pkg/executor/internal/vecgroupchecker/BUILD.bazel index 2c2cdcab18b5d..ba5fce361b700 100644 --- a/pkg/executor/internal/vecgroupchecker/BUILD.bazel +++ b/pkg/executor/internal/vecgroupchecker/BUILD.bazel @@ -7,7 +7,6 @@ go_library( visibility = ["//pkg/executor:__subpackages__"], deps = [ "//pkg/expression", - "//pkg/sessionctx", "//pkg/types", "//pkg/util/chunk", "//pkg/util/codec", diff --git a/pkg/executor/internal/vecgroupchecker/vec_group_checker.go b/pkg/executor/internal/vecgroupchecker/vec_group_checker.go index e50e35415e951..75ff6176cbd2d 100644 --- a/pkg/executor/internal/vecgroupchecker/vec_group_checker.go +++ b/pkg/executor/internal/vecgroupchecker/vec_group_checker.go @@ -19,7 +19,6 @@ import ( "fmt" "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/codec" @@ -28,7 +27,7 @@ import ( // VecGroupChecker is used to split a given chunk according to the `group by` expression in a vectorized manner // It is usually used for streamAgg type VecGroupChecker struct { - ctx sessionctx.Context + ctx expression.EvalContext releaseBuffer func(buf *chunk.Column) // set these functions for testing @@ -61,7 +60,7 @@ type VecGroupChecker struct { } // NewVecGroupChecker creates a new VecGroupChecker -func NewVecGroupChecker(ctx sessionctx.Context, items []expression.Expression) *VecGroupChecker { +func NewVecGroupChecker(ctx expression.EvalContext, items []expression.Expression) *VecGroupChecker { return &VecGroupChecker{ ctx: ctx, GroupByItems: items, diff --git a/pkg/executor/join.go b/pkg/executor/join.go index 9eed4654ea5f1..96afc1af3723c 100644 --- a/pkg/executor/join.go +++ b/pkg/executor/join.go @@ -987,7 +987,7 @@ func (w *probeWorker) getNewJoinResult() (bool, *hashjoinWorkerResult) { func (w *probeWorker) join2Chunk(probeSideChk *chunk.Chunk, hCtx *hashContext, joinResult *hashjoinWorkerResult, selected []bool) (ok bool, _ *hashjoinWorkerResult) { var err error - selected, err = expression.VectorizedFilter(w.hashJoinCtx.sessCtx, w.hashJoinCtx.outerFilter, chunk.NewIterator4Chunk(probeSideChk), selected) + selected, err = expression.VectorizedFilter(w.hashJoinCtx.sessCtx.GetExprCtx(), w.hashJoinCtx.outerFilter, chunk.NewIterator4Chunk(probeSideChk), selected) if err != nil { joinResult.err = err return false, joinResult @@ -1249,7 +1249,7 @@ func (w *buildWorker) buildHashTableForList(buildSideResultCh <-chan *chunk.Chun if len(w.hashJoinCtx.outerFilter) == 0 { err = w.hashJoinCtx.rowContainer.PutChunk(chk, w.hashJoinCtx.isNullEQ) } else { - selected, err = expression.VectorizedFilter(w.hashJoinCtx.sessCtx, w.hashJoinCtx.outerFilter, chunk.NewIterator4Chunk(chk), selected) + selected, err = expression.VectorizedFilter(w.hashJoinCtx.sessCtx.GetExprCtx(), w.hashJoinCtx.outerFilter, chunk.NewIterator4Chunk(chk), selected) if err != nil { return err } @@ -1396,7 +1396,7 @@ func (e *NestedLoopApplyExec) fetchSelectedOuterRow(ctx context.Context, chk *ch if e.outerChunk.NumRows() == 0 { return nil, nil } - e.outerSelected, err = expression.VectorizedFilter(e.ctx, e.outerFilter, outerIter, e.outerSelected) + e.outerSelected, err = expression.VectorizedFilter(e.ctx.GetExprCtx(), e.outerFilter, outerIter, e.outerSelected) if err != nil { return nil, err } @@ -1447,7 +1447,7 @@ func (e *NestedLoopApplyExec) fetchAllInners(ctx context.Context) error { return nil } - e.innerSelected, err = expression.VectorizedFilter(e.ctx, e.innerFilter, innerIter, e.innerSelected) + e.innerSelected, err = expression.VectorizedFilter(e.ctx.GetExprCtx(), e.innerFilter, innerIter, e.innerSelected) if err != nil { return err } diff --git a/pkg/executor/joiner.go b/pkg/executor/joiner.go index 16b961d2ca2a9..50dca9781aee7 100644 --- a/pkg/executor/joiner.go +++ b/pkg/executor/joiner.go @@ -261,7 +261,7 @@ func (j *baseJoiner) makeShallowJoinRow(isRightJoin bool, inner, outer chunk.Row // indicates whether the outer row matches any inner rows. func (j *baseJoiner) filter(input, output *chunk.Chunk, outerColLen int, lUsed, rUsed []int) (bool, error) { var err error - j.selected, err = expression.VectorizedFilter(j.ctx, j.conditions, chunk.NewIterator4Chunk(input), j.selected) + j.selected, err = expression.VectorizedFilter(j.ctx.GetExprCtx(), j.conditions, chunk.NewIterator4Chunk(input), j.selected) if err != nil { return false, err } @@ -301,7 +301,7 @@ func (j *baseJoiner) filterAndCheckOuterRowStatus( input, output *chunk.Chunk, innerColsLen int, outerRowStatus []outerRowStatusFlag, lUsed, rUsed []int) ([]outerRowStatusFlag, error) { var err error - j.selected, j.isNull, err = expression.VectorizedFilterConsiderNull(j.ctx, j.conditions, chunk.NewIterator4Chunk(input), j.selected, j.isNull) + j.selected, j.isNull, err = expression.VectorizedFilterConsiderNull(j.ctx.GetExprCtx(), j.conditions, chunk.NewIterator4Chunk(input), j.selected, j.isNull) if err != nil { return nil, err } @@ -376,12 +376,13 @@ func (j *semiJoiner) tryToMatchInners(outer chunk.Row, inners chunk.Iterator, ch return true, false, nil } + exprCtx := j.ctx.GetExprCtx() for inner := inners.Current(); inner != inners.End(); inner = inners.Next() { j.makeShallowJoinRow(j.outerIsRight, inner, outer) // For SemiJoin, we can safely treat null result of join conditions as false, // so we ignore the nullness returned by EvalBool here. - matched, _, err = expression.EvalBool(j.ctx, j.conditions, j.shallowRow.ToRow()) + matched, _, err = expression.EvalBool(exprCtx, j.conditions, j.shallowRow.ToRow()) if err != nil { return false, false, err } @@ -405,11 +406,12 @@ func (j *semiJoiner) tryToMatchOuters(outers chunk.Iterator, inner chunk.Row, ch } return outerRowStatus, nil } + exprCtx := j.ctx.GetExprCtx() for outer := outers.Current(); outer != outers.End() && numToAppend > 0; outer, numToAppend = outers.Next(), numToAppend-1 { j.makeShallowJoinRow(j.outerIsRight, inner, outer) // For SemiJoin, we can safely treat null result of join conditions as false, // so we ignore the nullness returned by EvalBool here. - matched, _, err := expression.EvalBool(j.ctx, j.conditions, j.shallowRow.ToRow()) + matched, _, err := expression.EvalBool(exprCtx, j.conditions, j.shallowRow.ToRow()) if err != nil { return outerRowStatus, err } @@ -467,9 +469,10 @@ func (naaj *nullAwareAntiSemiJoiner) tryToMatchInners(outer chunk.Row, inners ch inners.ReachEnd() return true, false, nil } + exprCtx := naaj.ctx.GetExprCtx() for inner := inners.Current(); inner != inners.End(); inner = inners.Next() { naaj.makeShallowJoinRow(naaj.outerIsRight, inner, outer) - valid, _, err := expression.EvalBool(naaj.ctx, naaj.conditions, naaj.shallowRow.ToRow()) + valid, _, err := expression.EvalBool(exprCtx, naaj.conditions, naaj.shallowRow.ToRow()) if err != nil { return false, false, err } @@ -517,10 +520,11 @@ func (j *antiSemiJoiner) tryToMatchInners(outer chunk.Row, inners chunk.Iterator return true, false, nil } + exprCtx := j.ctx.GetExprCtx() for inner := inners.Current(); inner != inners.End(); inner = inners.Next() { j.makeShallowJoinRow(j.outerIsRight, inner, outer) - matched, isNull, err := expression.EvalBool(j.ctx, j.conditions, j.shallowRow.ToRow()) + matched, isNull, err := expression.EvalBool(exprCtx, j.conditions, j.shallowRow.ToRow()) if err != nil { return false, false, err } @@ -543,9 +547,10 @@ func (j *antiSemiJoiner) tryToMatchOuters(outers chunk.Iterator, inner chunk.Row } return outerRowStatus, nil } + exprCtx := j.ctx.GetExprCtx() for outer := outers.Current(); outer != outers.End() && numToAppend > 0; outer, numToAppend = outers.Next(), numToAppend-1 { j.makeShallowJoinRow(j.outerIsRight, inner, outer) - matched, isNull, err := expression.EvalBool(j.ctx, j.conditions, j.shallowRow.ToRow()) + matched, isNull, err := expression.EvalBool(exprCtx, j.conditions, j.shallowRow.ToRow()) if err != nil { return outerRowStatus, err } @@ -591,10 +596,11 @@ func (j *leftOuterSemiJoiner) tryToMatchInners(outer chunk.Row, inners chunk.Ite return true, false, nil } + exprCtx := j.ctx.GetExprCtx() for inner := inners.Current(); inner != inners.End(); inner = inners.Next() { j.makeShallowJoinRow(false, inner, outer) - matched, isNull, err := expression.EvalBool(j.ctx, j.conditions, j.shallowRow.ToRow()) + matched, isNull, err := expression.EvalBool(exprCtx, j.conditions, j.shallowRow.ToRow()) if err != nil { return false, false, err } @@ -620,9 +626,10 @@ func (j *leftOuterSemiJoiner) tryToMatchOuters(outers chunk.Iterator, inner chun return outerRowStatus, nil } + exprCtx := j.ctx.GetExprCtx() for ; outer != outers.End() && numToAppend > 0; outer, numToAppend = outers.Next(), numToAppend-1 { j.makeShallowJoinRow(false, inner, outer) - matched, isNull, err := expression.EvalBool(j.ctx, j.conditions, j.shallowRow.ToRow()) + matched, isNull, err := expression.EvalBool(exprCtx, j.conditions, j.shallowRow.ToRow()) if err != nil { return nil, err } @@ -680,10 +687,12 @@ func (naal *nullAwareAntiLeftOuterSemiJoiner) tryToMatchInners(outer chunk.Row, inners.ReachEnd() return true, false, nil } + + exprCtx := naal.ctx.GetExprCtx() for inner := inners.Current(); inner != inners.End(); inner = inners.Next() { naal.makeShallowJoinRow(false, inner, outer) - valid, _, err := expression.EvalBool(naal.ctx, naal.conditions, naal.shallowRow.ToRow()) + valid, _, err := expression.EvalBool(exprCtx, naal.conditions, naal.shallowRow.ToRow()) if err != nil { return false, false, err } @@ -753,10 +762,11 @@ func (j *antiLeftOuterSemiJoiner) tryToMatchInners(outer chunk.Row, inners chunk return true, false, nil } + exprCtx := j.ctx.GetExprCtx() for inner := inners.Current(); inner != inners.End(); inner = inners.Next() { j.makeShallowJoinRow(false, inner, outer) - matched, isNull, err := expression.EvalBool(j.ctx, j.conditions, j.shallowRow.ToRow()) + matched, isNull, err := expression.EvalBool(exprCtx, j.conditions, j.shallowRow.ToRow()) if err != nil { return false, false, err } @@ -782,9 +792,10 @@ func (j *antiLeftOuterSemiJoiner) tryToMatchOuters(outers chunk.Iterator, inner return outerRowStatus, nil } + exprCtx := j.ctx.GetExprCtx() for i := 0; outer != outers.End() && numToAppend > 0; outer, numToAppend, i = outers.Next(), numToAppend-1, i+1 { j.makeShallowJoinRow(false, inner, outer) - matched, isNull, err := expression.EvalBool(j.ctx, j.conditions, j.shallowRow.ToRow()) + matched, isNull, err := expression.EvalBool(exprCtx, j.conditions, j.shallowRow.ToRow()) if err != nil { return nil, err } diff --git a/pkg/executor/load_data.go b/pkg/executor/load_data.go index fb3482637dcf8..e3e0e5f6a325b 100644 --- a/pkg/executor/load_data.go +++ b/pkg/executor/load_data.go @@ -540,7 +540,7 @@ func (w *encodeWorker) parserData2TableData( } for i := 0; i < len(w.colAssignExprs); i++ { // eval expression of `SET` clause - d, err := w.colAssignExprs[i].Eval(w.Ctx(), chunk.Row{}) + d, err := w.colAssignExprs[i].Eval(w.Ctx().GetExprCtx(), chunk.Row{}) if err != nil { if w.controller.Restrictive { return nil, err diff --git a/pkg/executor/mem_reader.go b/pkg/executor/mem_reader.go index 2fcb1ed766edc..2f84d2eeff3ac 100644 --- a/pkg/executor/mem_reader.go +++ b/pkg/executor/mem_reader.go @@ -155,7 +155,7 @@ func (m *memIndexReader) getMemRows(ctx context.Context) ([][]types.Datum, error } mutableRow.SetDatums(data...) - matched, _, err := expression.EvalBool(m.ctx, m.conditions, mutableRow.ToRow()) + matched, _, err := expression.EvalBool(m.ctx.GetExprCtx(), m.conditions, mutableRow.ToRow()) if err != nil || !matched { return err } @@ -248,7 +248,7 @@ func buildMemTableReader(ctx context.Context, us *UnionScanExec, kvRanges []kv.K } defVal := func(i int) ([]byte, error) { - d, err := table.GetColOriginDefaultValueWithoutStrictSQLMode(us.Ctx(), us.columns[i]) + d, err := table.GetColOriginDefaultValueWithoutStrictSQLMode(us.Ctx().GetExprCtx(), us.columns[i]) if err != nil { return nil, err } @@ -414,7 +414,7 @@ func (m *memTableReader) getMemRows(ctx context.Context) ([][]types.Datum, error } mutableRow.SetDatums(resultRows...) - matched, _, err := expression.EvalBool(m.ctx, m.conditions, mutableRow.ToRow()) + matched, _, err := expression.EvalBool(m.ctx.GetExprCtx(), m.conditions, mutableRow.ToRow()) if err != nil || !matched { return err } @@ -905,7 +905,7 @@ func (iter *memRowsIterForTable) Next() ([]types.Datum, error) { mutableRow := chunk.MutRowFromTypes(iter.retFieldTypes) mutableRow.SetDatums(iter.datumRow...) - matched, _, err := expression.EvalBool(iter.ctx, iter.conditions, mutableRow.ToRow()) + matched, _, err := expression.EvalBool(iter.ctx.GetExprCtx(), iter.conditions, mutableRow.ToRow()) if err != nil { return nil, errors.Trace(err) } @@ -921,7 +921,7 @@ func (iter *memRowsIterForTable) Next() ([]types.Datum, error) { } row := iter.chk.GetRow(0) - matched, _, err := expression.EvalBool(iter.ctx, iter.conditions, row) + matched, _, err := expression.EvalBool(iter.ctx.GetExprCtx(), iter.conditions, row) if err != nil { return nil, errors.Trace(err) } @@ -962,7 +962,7 @@ func (iter *memRowsIterForIndex) Next() ([]types.Datum, error) { } iter.mutableRow.SetDatums(data...) - matched, _, err := expression.EvalBool(iter.memIndexReader.ctx, iter.memIndexReader.conditions, iter.mutableRow.ToRow()) + matched, _, err := expression.EvalBool(iter.memIndexReader.ctx.GetExprCtx(), iter.memIndexReader.conditions, iter.mutableRow.ToRow()) if err != nil { return nil, errors.Trace(err) } @@ -1158,7 +1158,7 @@ func getColIDAndPkColIDs(ctx sessionctx.Context, tbl table.Table, columns []*mod pkColIDs = []int64{-1} } defVal := func(i int) ([]byte, error) { - d, err := table.GetColOriginDefaultValueWithoutStrictSQLMode(ctx, columns[i]) + d, err := table.GetColOriginDefaultValueWithoutStrictSQLMode(ctx.GetExprCtx(), columns[i]) if err != nil { return nil, err } diff --git a/pkg/executor/merge_join.go b/pkg/executor/merge_join.go index 2f129b1babeb0..fc7e2a3f0cc52 100644 --- a/pkg/executor/merge_join.go +++ b/pkg/executor/merge_join.go @@ -86,7 +86,7 @@ func (t *mergeJoinTable) init(executor *MergeJoinExec) { for _, col := range t.joinKeys { items = append(items, col) } - t.groupChecker = vecgroupchecker.NewVecGroupChecker(executor.Ctx(), items) + t.groupChecker = vecgroupchecker.NewVecGroupChecker(executor.Ctx().GetExprCtx(), items) t.groupRowsIter = chunk.NewIterator4Chunk(t.childChunk) if t.isInner { @@ -257,7 +257,7 @@ func (t *mergeJoinTable) fetchNextOuterGroup(ctx context.Context, exec *MergeJoi } t.childChunkIter.Begin() - t.filtersSelected, err = expression.VectorizedFilter(exec.Ctx(), t.filters, t.childChunkIter, t.filtersSelected) + t.filtersSelected, err = expression.VectorizedFilter(exec.Ctx().GetExprCtx(), t.filters, t.childChunkIter, t.filtersSelected) if err != nil { return err } @@ -401,7 +401,7 @@ func (e *MergeJoinExec) compare(outerRow, innerRow chunk.Row) (int, error) { outerJoinKeys := e.outerTable.joinKeys innerJoinKeys := e.innerTable.joinKeys for i := range outerJoinKeys { - cmp, _, err := e.compareFuncs[i](e.Ctx(), outerJoinKeys[i], innerJoinKeys[i], outerRow, innerRow) + cmp, _, err := e.compareFuncs[i](e.Ctx().GetPlanCtx(), outerJoinKeys[i], innerJoinKeys[i], outerRow, innerRow) if err != nil { return 0, err } diff --git a/pkg/executor/parallel_apply.go b/pkg/executor/parallel_apply.go index d697b4f4cf0f5..d1fdaeed213b5 100644 --- a/pkg/executor/parallel_apply.go +++ b/pkg/executor/parallel_apply.go @@ -219,7 +219,7 @@ func (e *ParallelNestedLoopApplyExec) outerWorker(ctx context.Context) { } e.outerList.Add(chk) outerIter := chunk.NewIterator4Chunk(chk) - selected, err = expression.VectorizedFilter(e.Ctx(), e.outerFilter, outerIter, selected) + selected, err = expression.VectorizedFilter(e.Ctx().GetExprCtx(), e.outerFilter, outerIter, selected) if err != nil { e.putResult(nil, err) return @@ -326,7 +326,7 @@ func (e *ParallelNestedLoopApplyExec) fetchAllInners(ctx context.Context, id int break } - e.innerSelected[id], err = expression.VectorizedFilter(e.Ctx(), e.innerFilter[id], innerIter, e.innerSelected[id]) + e.innerSelected[id], err = expression.VectorizedFilter(e.Ctx().GetExprCtx(), e.innerFilter[id], innerIter, e.innerSelected[id]) if err != nil { return err } diff --git a/pkg/executor/pipelined_window.go b/pkg/executor/pipelined_window.go index 2b43349fcd6fa..d661a4cde2e0d 100644 --- a/pkg/executor/pipelined_window.go +++ b/pkg/executor/pipelined_window.go @@ -263,7 +263,7 @@ func (e *PipelinedWindowExec) getStart(ctx sessionctx.Context) (uint64, error) { var res int64 var err error for i := range e.orderByCols { - res, _, err = e.start.CmpFuncs[i](ctx, e.start.CompareCols[i], e.start.CalcFuncs[i], e.getRow(start), e.getRow(e.curRowIdx)) + res, _, err = e.start.CmpFuncs[i](ctx.GetExprCtx(), e.start.CompareCols[i], e.start.CalcFuncs[i], e.getRow(start), e.getRow(e.curRowIdx)) if err != nil { return 0, err } @@ -303,7 +303,7 @@ func (e *PipelinedWindowExec) getEnd(ctx sessionctx.Context) (uint64, error) { var res int64 var err error for i := range e.orderByCols { - res, _, err = e.end.CmpFuncs[i](ctx, e.end.CalcFuncs[i], e.end.CompareCols[i], e.getRow(e.curRowIdx), e.getRow(end)) + res, _, err = e.end.CmpFuncs[i](ctx.GetExprCtx(), e.end.CalcFuncs[i], e.end.CompareCols[i], e.getRow(e.curRowIdx), e.getRow(end)) if err != nil { return 0, err } @@ -369,7 +369,7 @@ func (e *PipelinedWindowExec) produce(ctx sessionctx.Context, chk *chunk.Chunk, if !e.emptyFrame { wf.ResetPartialResult(e.partialResults[i]) } - err = wf.AppendFinalResult2Chunk(ctx, e.partialResults[i], chk) + err = wf.AppendFinalResult2Chunk(ctx.GetExprCtx(), e.partialResults[i], chk) if err != nil { return } @@ -384,7 +384,7 @@ func (e *PipelinedWindowExec) produce(ctx sessionctx.Context, chk *chunk.Chunk, slidingWindowAggFunc := e.slidingWindowFuncs[i] if e.lastStartRow != start || e.lastEndRow != end { if slidingWindowAggFunc != nil && e.initializedSlidingWindow { - err = slidingWindowAggFunc.Slide(ctx, e.getRow, e.lastStartRow, e.lastEndRow, start-e.lastStartRow, end-e.lastEndRow, e.partialResults[i]) + err = slidingWindowAggFunc.Slide(ctx.GetExprCtx(), e.getRow, e.lastStartRow, e.lastEndRow, start-e.lastStartRow, end-e.lastEndRow, e.partialResults[i]) } else { // For MinMaxSlidingWindowAggFuncs, it needs the absolute value of each start of window, to compare // whether elements inside deque are out of current window. @@ -394,13 +394,13 @@ func (e *PipelinedWindowExec) produce(ctx sessionctx.Context, chk *chunk.Chunk, } // TODO(zhifeng): track memory usage here wf.ResetPartialResult(e.partialResults[i]) - _, err = wf.UpdatePartialResult(ctx, e.getRows(start, end), e.partialResults[i]) + _, err = wf.UpdatePartialResult(ctx.GetExprCtx(), e.getRows(start, end), e.partialResults[i]) } } if err != nil { return } - err = wf.AppendFinalResult2Chunk(ctx, e.partialResults[i], chk) + err = wf.AppendFinalResult2Chunk(ctx.GetExprCtx(), e.partialResults[i], chk) if err != nil { return } diff --git a/pkg/executor/point_get.go b/pkg/executor/point_get.go index abcd5ae18f44a..545fcea0de493 100644 --- a/pkg/executor/point_get.go +++ b/pkg/executor/point_get.go @@ -637,7 +637,7 @@ func decodeOldRowValToChunk(sctx sessionctx.Context, schema *expression.Schema, cutPos := colID2CutPos[col.ID] if len(cutVals[cutPos]) == 0 { colInfo := getColInfoByID(tblInfo, col.ID) - d, err1 := table.GetColOriginDefaultValue(sctx, colInfo) + d, err1 := table.GetColOriginDefaultValue(sctx.GetExprCtx(), colInfo) if err1 != nil { return err1 } diff --git a/pkg/executor/projection.go b/pkg/executor/projection.go index 1c8c451829184..3fbe1f64d4598 100644 --- a/pkg/executor/projection.go +++ b/pkg/executor/projection.go @@ -204,7 +204,7 @@ func (e *ProjectionExec) unParallelExecute(ctx context.Context, chk *chunk.Chunk if e.childResult.NumRows() == 0 { return nil } - err = e.evaluatorSuit.Run(e.Ctx(), e.childResult, chk) + err = e.evaluatorSuit.Run(e.Ctx().GetExprCtx(), e.childResult, chk) return err } @@ -448,7 +448,7 @@ func (w *projectionWorker) run(ctx context.Context) { } mSize := output.chk.MemoryUsage() + input.chk.MemoryUsage() - err := w.evaluatorSuit.Run(w.sctx, input.chk, output.chk) + err := w.evaluatorSuit.Run(w.sctx.GetExprCtx(), input.chk, output.chk) failpoint.Inject("ConsumeRandomPanic", nil) w.proj.memTracker.Consume(output.chk.MemoryUsage() + input.chk.MemoryUsage() - mSize) output.done <- err diff --git a/pkg/executor/set.go b/pkg/executor/set.go index ffc7632b41d32..1b92c5fb8fa1c 100644 --- a/pkg/executor/set.go +++ b/pkg/executor/set.go @@ -71,7 +71,7 @@ func (e *SetExecutor) Next(ctx context.Context, req *chunk.Chunk) error { } continue } - dt, err := v.Expr.(*expression.Constant).Eval(sctx, chunk.Row{}) + dt, err := v.Expr.(*expression.Constant).Eval(sctx.GetExprCtx(), chunk.Row{}) if err != nil { return err } @@ -89,7 +89,7 @@ func (e *SetExecutor) Next(ctx context.Context, req *chunk.Chunk) error { name := strings.ToLower(v.Name) if !v.IsSystem { // Set user variable. - value, err := v.Expr.Eval(sctx, chunk.Row{}) + value, err := v.Expr.Eval(sctx.GetExprCtx(), chunk.Row{}) if err != nil { return err } @@ -306,7 +306,7 @@ func (e *SetExecutor) getVarValue(ctx context.Context, v *expression.VarAssignme } return e.Ctx().GetSessionVars().GetGlobalSystemVar(ctx, v.Name) } - nativeVal, err := v.Expr.Eval(e.Ctx(), chunk.Row{}) + nativeVal, err := v.Expr.Eval(e.Ctx().GetExprCtx(), chunk.Row{}) if err != nil || nativeVal.IsNull() { return "", err } diff --git a/pkg/executor/set_config.go b/pkg/executor/set_config.go index c9a47cd4e4ef3..220174de0755d 100644 --- a/pkg/executor/set_config.go +++ b/pkg/executor/set_config.go @@ -190,13 +190,13 @@ func ConvertConfigItem2JSON(ctx sessionctx.Context, key string, val expression.E switch val.GetType().EvalType() { case types.ETString: var s string - s, isNull, err = val.EvalString(ctx, chunk.Row{}) + s, isNull, err = val.EvalString(ctx.GetExprCtx(), chunk.Row{}) if err == nil && !isNull { str = fmt.Sprintf("%q", s) } case types.ETInt: var i int64 - i, isNull, err = val.EvalInt(ctx, chunk.Row{}) + i, isNull, err = val.EvalInt(ctx.GetExprCtx(), chunk.Row{}) if err == nil && !isNull { if mysql.HasIsBooleanFlag(val.GetType().GetFlag()) { str = "true" @@ -209,13 +209,13 @@ func ConvertConfigItem2JSON(ctx sessionctx.Context, key string, val expression.E } case types.ETReal: var f float64 - f, isNull, err = val.EvalReal(ctx, chunk.Row{}) + f, isNull, err = val.EvalReal(ctx.GetExprCtx(), chunk.Row{}) if err == nil && !isNull { str = fmt.Sprintf("%v", f) } case types.ETDecimal: var d *types.MyDecimal - d, isNull, err = val.EvalDecimal(ctx, chunk.Row{}) + d, isNull, err = val.EvalDecimal(ctx.GetExprCtx(), chunk.Row{}) if err == nil && !isNull { str = string(d.ToString()) } diff --git a/pkg/executor/show.go b/pkg/executor/show.go index 5817884218254..73087f568af3e 100644 --- a/pkg/executor/show.go +++ b/pkg/executor/show.go @@ -653,7 +653,7 @@ func (e *ShowExec) fetchShowColumns(ctx context.Context) error { defaultValStr := fmt.Sprintf("%v", desc.DefaultValue) // If column is timestamp, and default value is not current_timestamp, should convert the default value to the current session time zone. if col.GetType() == mysql.TypeTimestamp && defaultValStr != types.ZeroDatetimeStr && !strings.HasPrefix(strings.ToUpper(defaultValStr), strings.ToUpper(ast.CurrentTimestamp)) { - timeValue, err := table.GetColDefaultValue(e.Ctx(), col.ToInfo()) + timeValue, err := table.GetColDefaultValue(e.Ctx().GetExprCtx(), col.ToInfo()) if err != nil { return errors.Trace(err) } @@ -1042,7 +1042,7 @@ func constructResultOfShowCreateTable(ctx sessionctx.Context, dbName *model.CISt defaultValStr := fmt.Sprintf("%v", defaultValue) // If column is timestamp, and default value is not current_timestamp, should convert the default value to the current session time zone. if defaultValStr != types.ZeroDatetimeStr && col.GetType() == mysql.TypeTimestamp { - timeValue, err := table.GetColDefaultValue(ctx, col) + timeValue, err := table.GetColDefaultValue(ctx.GetExprCtx(), col) if err != nil { return errors.Trace(err) } diff --git a/pkg/executor/shuffle.go b/pkg/executor/shuffle.go index e390c952bc03b..63603ddbd6491 100644 --- a/pkg/executor/shuffle.go +++ b/pkg/executor/shuffle.go @@ -465,7 +465,7 @@ func buildPartitionRangeSplitter(ctx sessionctx.Context, concurrency int, byItem return &partitionRangeSplitter{ byItems: byItems, numWorkers: concurrency, - groupChecker: vecgroupchecker.NewVecGroupChecker(ctx, byItems), + groupChecker: vecgroupchecker.NewVecGroupChecker(ctx.GetExprCtx(), byItems), idx: 0, } } diff --git a/pkg/executor/test/executor/executor_test.go b/pkg/executor/test/executor/executor_test.go index 1013afc5c6b9d..dc658d0811264 100644 --- a/pkg/executor/test/executor/executor_test.go +++ b/pkg/executor/test/executor/executor_test.go @@ -463,7 +463,7 @@ func TestTimestampDefaultValueTimeZone(t *testing.T) { tk.MustExec(`set time_zone = '+00:00'`) timeIn0 := tk.MustQuery("select b from t").Rows()[0][0] require.NotEqual(t, timeIn8, timeIn0) - datumTimeIn8, err := expression.GetTimeValue(tk.Session(), timeIn8, mysql.TypeTimestamp, 0, nil) + datumTimeIn8, err := expression.GetTimeValue(tk.Session().GetExprCtx(), timeIn8, mysql.TypeTimestamp, 0, nil) require.NoError(t, err) tIn8To0 := datumTimeIn8.GetMysqlTime() timeZoneIn8, err := time.LoadLocation("Asia/Shanghai") diff --git a/pkg/executor/union_scan.go b/pkg/executor/union_scan.go index ce7e853c0d46a..19f397a51f3a4 100644 --- a/pkg/executor/union_scan.go +++ b/pkg/executor/union_scan.go @@ -155,7 +155,7 @@ func (us *UnionScanExec) Next(ctx context.Context, req *chunk.Chunk) error { sctx := us.Ctx() for _, idx := range us.virtualColumnIndex { - datum, err := us.Schema().Columns[idx].EvalVirtualColumn(sctx, mutableRow.ToRow()) + datum, err := us.Schema().Columns[idx].EvalVirtualColumn(sctx.GetExprCtx(), mutableRow.ToRow()) if err != nil { return err } @@ -172,7 +172,7 @@ func (us *UnionScanExec) Next(ctx context.Context, req *chunk.Chunk) error { mutableRow.SetDatum(idx, castDatum) } - matched, _, err := expression.EvalBool(us.Ctx(), us.conditionsWithVirCol, mutableRow.ToRow()) + matched, _, err := expression.EvalBool(us.Ctx().GetExprCtx(), us.conditionsWithVirCol, mutableRow.ToRow()) if err != nil { return err } diff --git a/pkg/executor/update.go b/pkg/executor/update.go index 44ae2c371ba9a..3b33644cfd8b9 100644 --- a/pkg/executor/update.go +++ b/pkg/executor/update.go @@ -352,7 +352,7 @@ func (e *UpdateExec) fastComposeNewRow(rowIdx int, oldRow []types.Datum, cols [] continue } con := assign.Expr.(*expression.Constant) - val, err := con.Eval(e.Ctx(), emptyRow) + val, err := con.Eval(e.Ctx().GetExprCtx(), emptyRow) if err = e.handleErr(assign.ColName, rowIdx, err); err != nil { return nil, err } @@ -379,7 +379,7 @@ func (e *UpdateExec) composeNewRow(rowIdx int, oldRow []types.Datum, cols []*tab if tblIdx >= 0 && !e.tableUpdatable[tblIdx] { continue } - val, err := assign.Expr.Eval(e.Ctx(), e.evalBuffer.ToRow()) + val, err := assign.Expr.Eval(e.Ctx().GetExprCtx(), e.evalBuffer.ToRow()) if err != nil { return nil, err } @@ -408,7 +408,7 @@ func (e *UpdateExec) composeGeneratedColumns(rowIdx int, newRowData []types.Datu if tblIdx >= 0 && !e.tableUpdatable[tblIdx] { continue } - val, err := assign.Expr.Eval(e.Ctx(), e.evalBuffer.ToRow()) + val, err := assign.Expr.Eval(e.Ctx().GetExprCtx(), e.evalBuffer.ToRow()) if err = e.handleErr(assign.ColName, rowIdx, err); err != nil { return nil, err } diff --git a/pkg/executor/window.go b/pkg/executor/window.go index 22f85b71a2129..2a719e0f40a46 100644 --- a/pkg/executor/window.go +++ b/pkg/executor/window.go @@ -205,7 +205,7 @@ type aggWindowProcessor struct { func (p *aggWindowProcessor) consumeGroupRows(ctx sessionctx.Context, rows []chunk.Row) ([]chunk.Row, error) { for i, windowFunc := range p.windowFuncs { // @todo Add memory trace - _, err := windowFunc.UpdatePartialResult(ctx, rows, p.partialResults[i]) + _, err := windowFunc.UpdatePartialResult(ctx.GetExprCtx(), rows, p.partialResults[i]) if err != nil { return nil, err } @@ -218,7 +218,7 @@ func (p *aggWindowProcessor) appendResult2Chunk(ctx sessionctx.Context, rows []c for remained > 0 { for i, windowFunc := range p.windowFuncs { // TODO: We can extend the agg func interface to avoid the `for` loop here. - err := windowFunc.AppendFinalResult2Chunk(ctx, p.partialResults[i], chk) + err := windowFunc.AppendFinalResult2Chunk(ctx.GetExprCtx(), p.partialResults[i], chk) if err != nil { return nil, err } @@ -321,14 +321,14 @@ func (p *rowFrameWindowProcessor) appendResult2Chunk(ctx sessionctx.Context, row for i, windowFunc := range p.windowFuncs { slidingWindowAggFunc := slidingWindowAggFuncs[i] if slidingWindowAggFunc != nil && initializedSlidingWindow { - err = slidingWindowAggFunc.Slide(ctx, func(u uint64) chunk.Row { + err = slidingWindowAggFunc.Slide(ctx.GetExprCtx(), func(u uint64) chunk.Row { return rows[u] }, lastStart, lastEnd, shiftStart, shiftEnd, p.partialResults[i]) if err != nil { return nil, err } } - err = windowFunc.AppendFinalResult2Chunk(ctx, p.partialResults[i], chk) + err = windowFunc.AppendFinalResult2Chunk(ctx.GetExprCtx(), p.partialResults[i], chk) if err != nil { return nil, err } @@ -339,7 +339,7 @@ func (p *rowFrameWindowProcessor) appendResult2Chunk(ctx sessionctx.Context, row for i, windowFunc := range p.windowFuncs { slidingWindowAggFunc := slidingWindowAggFuncs[i] if slidingWindowAggFunc != nil && initializedSlidingWindow { - err = slidingWindowAggFunc.Slide(ctx, func(u uint64) chunk.Row { + err = slidingWindowAggFunc.Slide(ctx.GetExprCtx(), func(u uint64) chunk.Row { return rows[u] }, lastStart, lastEnd, shiftStart, shiftEnd, p.partialResults[i]) } else { @@ -349,12 +349,12 @@ func (p *rowFrameWindowProcessor) appendResult2Chunk(ctx sessionctx.Context, row // Store start inside MaxMinSlidingWindowAggFunc.windowInfo minMaxSlidingWindowAggFunc.SetWindowStart(start) } - _, err = windowFunc.UpdatePartialResult(ctx, rows[start:end], p.partialResults[i]) + _, err = windowFunc.UpdatePartialResult(ctx.GetExprCtx(), rows[start:end], p.partialResults[i]) } if err != nil { return nil, err } - err = windowFunc.AppendFinalResult2Chunk(ctx, p.partialResults[i], chk) + err = windowFunc.AppendFinalResult2Chunk(ctx.GetExprCtx(), p.partialResults[i], chk) if err != nil { return nil, err } @@ -398,7 +398,7 @@ func (p *rangeFrameWindowProcessor) getStartOffset(ctx sessionctx.Context, rows var res int64 var err error for i := range p.orderByCols { - res, _, err = p.start.CmpFuncs[i](ctx, p.start.CompareCols[i], p.start.CalcFuncs[i], rows[p.lastStartOffset], rows[p.curRowIdx]) + res, _, err = p.start.CmpFuncs[i](ctx.GetExprCtx(), p.start.CompareCols[i], p.start.CalcFuncs[i], rows[p.lastStartOffset], rows[p.curRowIdx]) if err != nil { return 0, err } @@ -424,7 +424,7 @@ func (p *rangeFrameWindowProcessor) getEndOffset(ctx sessionctx.Context, rows [] var res int64 var err error for i := range p.orderByCols { - res, _, err = p.end.CmpFuncs[i](ctx, p.end.CalcFuncs[i], p.end.CompareCols[i], rows[p.curRowIdx], rows[p.lastEndOffset]) + res, _, err = p.end.CmpFuncs[i](ctx.GetExprCtx(), p.end.CalcFuncs[i], p.end.CompareCols[i], rows[p.curRowIdx], rows[p.lastEndOffset]) if err != nil { return 0, err } @@ -475,14 +475,14 @@ func (p *rangeFrameWindowProcessor) appendResult2Chunk(ctx sessionctx.Context, r for i, windowFunc := range p.windowFuncs { slidingWindowAggFunc := slidingWindowAggFuncs[i] if slidingWindowAggFunc != nil && initializedSlidingWindow { - err = slidingWindowAggFunc.Slide(ctx, func(u uint64) chunk.Row { + err = slidingWindowAggFunc.Slide(ctx.GetExprCtx(), func(u uint64) chunk.Row { return rows[u] }, lastStart, lastEnd, shiftStart, shiftEnd, p.partialResults[i]) if err != nil { return nil, err } } - err = windowFunc.AppendFinalResult2Chunk(ctx, p.partialResults[i], chk) + err = windowFunc.AppendFinalResult2Chunk(ctx.GetExprCtx(), p.partialResults[i], chk) if err != nil { return nil, err } @@ -493,19 +493,19 @@ func (p *rangeFrameWindowProcessor) appendResult2Chunk(ctx sessionctx.Context, r for i, windowFunc := range p.windowFuncs { slidingWindowAggFunc := slidingWindowAggFuncs[i] if slidingWindowAggFunc != nil && initializedSlidingWindow { - err = slidingWindowAggFunc.Slide(ctx, func(u uint64) chunk.Row { + err = slidingWindowAggFunc.Slide(ctx.GetExprCtx(), func(u uint64) chunk.Row { return rows[u] }, lastStart, lastEnd, shiftStart, shiftEnd, p.partialResults[i]) } else { if minMaxSlidingWindowAggFunc, ok := windowFunc.(aggfuncs.MaxMinSlidingWindowAggFunc); ok { minMaxSlidingWindowAggFunc.SetWindowStart(start) } - _, err = windowFunc.UpdatePartialResult(ctx, rows[start:end], p.partialResults[i]) + _, err = windowFunc.UpdatePartialResult(ctx.GetExprCtx(), rows[start:end], p.partialResults[i]) } if err != nil { return nil, err } - err = windowFunc.AppendFinalResult2Chunk(ctx, p.partialResults[i], chk) + err = windowFunc.AppendFinalResult2Chunk(ctx.GetExprCtx(), p.partialResults[i], chk) if err != nil { return nil, err } diff --git a/pkg/executor/write.go b/pkg/executor/write.go index 4437f5a66ff5a..a0b5d4129a503 100644 --- a/pkg/executor/write.go +++ b/pkg/executor/write.go @@ -144,7 +144,7 @@ func updateRecord( // Fill values into on-update-now fields, only if they are really changed. for i, col := range t.Cols() { if mysql.HasOnUpdateNowFlag(col.GetFlag()) && !modified[i] && !onUpdateSpecified[i] { - v, err := expression.GetTimeValue(sctx, strings.ToUpper(ast.CurrentTimestamp), col.GetType(), col.GetDecimal(), nil) + v, err := expression.GetTimeValue(sctx.GetExprCtx(), strings.ToUpper(ast.CurrentTimestamp), col.GetType(), col.GetDecimal(), nil) if err != nil { return false, err } @@ -246,7 +246,7 @@ func addUnchangedKeysForLockByRow( count := 0 physicalID := t.Meta().ID if pt, ok := t.(table.PartitionedTable); ok { - p, err := pt.GetPartitionByRow(sctx, row) + p, err := pt.GetPartitionByRow(sctx.GetExprCtx(), row) if err != nil { return 0, err } diff --git a/pkg/expression/context.go b/pkg/expression/context.go index d5605bcfed907..894f8852580d4 100644 --- a/pkg/expression/context.go +++ b/pkg/expression/context.go @@ -32,7 +32,7 @@ type EvalContext = context.EvalContext type BuildContext = context.BuildContext func sqlMode(ctx EvalContext) mysql.SQLMode { - return ctx.GetSessionVars().SQLMode + return ctx.SQLMode() } func typeCtx(ctx EvalContext) types.Context { diff --git a/pkg/expression/context/BUILD.bazel b/pkg/expression/context/BUILD.bazel index 408555e638ba1..5cfad68731423 100644 --- a/pkg/expression/context/BUILD.bazel +++ b/pkg/expression/context/BUILD.bazel @@ -8,6 +8,7 @@ go_library( deps = [ "//pkg/infoschema/context", "//pkg/kv", + "//pkg/parser/mysql", "//pkg/sessionctx/variable", ], ) diff --git a/pkg/expression/context/context.go b/pkg/expression/context/context.go index 791a32b65ec76..3f953593e52eb 100644 --- a/pkg/expression/context/context.go +++ b/pkg/expression/context/context.go @@ -19,11 +19,14 @@ import ( infoschema "github.com/pingcap/tidb/pkg/infoschema/context" "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/sessionctx/variable" ) // EvalContext is used to evaluate an expression type EvalContext interface { + // SQLMode returns the sql mode + SQLMode() mysql.SQLMode // GetSessionVars gets the session variables. GetSessionVars() *variable.SessionVars // Value returns the value associated with this context for key. diff --git a/pkg/expression/contextimpl/BUILD.bazel b/pkg/expression/contextimpl/BUILD.bazel new file mode 100644 index 0000000000000..5c637edbe3818 --- /dev/null +++ b/pkg/expression/contextimpl/BUILD.bazel @@ -0,0 +1,13 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "contextimpl", + srcs = ["expression.go"], + importpath = "github.com/pingcap/tidb/pkg/expression/contextimpl", + visibility = ["//visibility:public"], + deps = [ + "//pkg/expression/context", + "//pkg/parser/mysql", + "//pkg/sessionctx", + ], +) diff --git a/pkg/expression/contextimpl/expression.go b/pkg/expression/contextimpl/expression.go new file mode 100644 index 0000000000000..f5ff39fdf434b --- /dev/null +++ b/pkg/expression/contextimpl/expression.go @@ -0,0 +1,43 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package contextimpl + +import ( + exprctx "github.com/pingcap/tidb/pkg/expression/context" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/sessionctx" +) + +// sessionctx.Context + *ExprCtxExtendedImpl should implement `expression.BuildContext` +// Only used to assert `ExprCtxExtendedImpl` should implement all methods not in `sessionctx.Context` +var _ exprctx.BuildContext = struct { + sessionctx.Context + *ExprCtxExtendedImpl +}{} + +// ExprCtxExtendedImpl extends the sessionctx.Context to implement `expression.BuildContext` +type ExprCtxExtendedImpl struct { + sctx sessionctx.Context +} + +// NewExprExtendedImpl creates a new ExprCtxExtendedImpl. +func NewExprExtendedImpl(sctx sessionctx.Context) *ExprCtxExtendedImpl { + return &ExprCtxExtendedImpl{sctx: sctx} +} + +// SQLMode returns the sql mode +func (ctx *ExprCtxExtendedImpl) SQLMode() mysql.SQLMode { + return ctx.sctx.GetSessionVars().SQLMode +} diff --git a/pkg/expression/integration_test/integration_test.go b/pkg/expression/integration_test/integration_test.go index 41986cd93f00d..6d011e3eee624 100644 --- a/pkg/expression/integration_test/integration_test.go +++ b/pkg/expression/integration_test/integration_test.go @@ -411,9 +411,9 @@ func TestFilterExtractFromDNF(t *testing.T) { selection := p.(plannercore.LogicalPlan).Children()[0].(*plannercore.LogicalSelection) conds := make([]expression.Expression, len(selection.Conditions)) for i, cond := range selection.Conditions { - conds[i] = expression.PushDownNot(sctx, cond) + conds[i] = expression.PushDownNot(sctx.GetExprCtx(), cond) } - afterFunc := expression.ExtractFiltersFromDNFs(sctx, conds) + afterFunc := expression.ExtractFiltersFromDNFs(sctx.GetExprCtx(), conds) sort.Slice(afterFunc, func(i, j int) bool { return bytes.Compare(afterFunc[i].HashCode(), afterFunc[j].HashCode()) < 0 }) diff --git a/pkg/planner/contextimpl/BUILD.bazel b/pkg/planner/contextimpl/BUILD.bazel index 7daf4a767c97d..08778b9480db4 100644 --- a/pkg/planner/contextimpl/BUILD.bazel +++ b/pkg/planner/contextimpl/BUILD.bazel @@ -1,23 +1,12 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library") -go_library( - name = "impl", - srcs = ["impl.go"], - importpath = "github.com/pingcap/tidb/pkg/planner/context/impl", - visibility = ["//visibility:public"], - deps = [ - "//pkg/planner/context", - "//pkg/sessionctx", - "//pkg/sessiontxn", - ], -) - go_library( name = "contextimpl", srcs = ["impl.go"], importpath = "github.com/pingcap/tidb/pkg/planner/contextimpl", visibility = ["//visibility:public"], deps = [ + "//pkg/expression/contextimpl", "//pkg/planner/context", "//pkg/sessionctx", "//pkg/sessiontxn", diff --git a/pkg/planner/contextimpl/impl.go b/pkg/planner/contextimpl/impl.go index 1d16fff10d8f4..76c42128f2e60 100644 --- a/pkg/planner/contextimpl/impl.go +++ b/pkg/planner/contextimpl/impl.go @@ -15,6 +15,7 @@ package contextimpl import ( + exprctximpl "github.com/pingcap/tidb/pkg/expression/contextimpl" "github.com/pingcap/tidb/pkg/planner/context" "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessiontxn" @@ -22,20 +23,21 @@ import ( var _ context.PlanContext = struct { sessionctx.Context - SessionContextExtended + *PlanCtxExtendedImpl }{} -// SessionContextExtended provides extended method for session context to implement `PlanContext` -type SessionContextExtended struct { +// PlanCtxExtendedImpl provides extended method for session context to implement `PlanContext` +type PlanCtxExtendedImpl struct { sctx sessionctx.Context + *exprctximpl.ExprCtxExtendedImpl } -// NewSessionContextExtended creates a new SessionContextExtended. -func NewSessionContextExtended(sctx sessionctx.Context) SessionContextExtended { - return SessionContextExtended{sctx: sctx} +// NewPlanCtxExtendedImpl creates a new PlanCtxExtendedImpl. +func NewPlanCtxExtendedImpl(sctx sessionctx.Context, exprCtx *exprctximpl.ExprCtxExtendedImpl) *PlanCtxExtendedImpl { + return &PlanCtxExtendedImpl{sctx: sctx, ExprCtxExtendedImpl: exprCtx} } // AdviseTxnWarmup advises the txn to warm up. -func (ctx SessionContextExtended) AdviseTxnWarmup() error { +func (ctx *PlanCtxExtendedImpl) AdviseTxnWarmup() error { return sessiontxn.GetTxnManager(ctx.sctx).AdviseWarmup() } diff --git a/pkg/planner/core/expression_test.go b/pkg/planner/core/expression_test.go index 03410fc07b675..c31da835b47ab 100644 --- a/pkg/planner/core/expression_test.go +++ b/pkg/planner/core/expression_test.go @@ -25,7 +25,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/charset" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/testkit/testutil" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" @@ -40,7 +39,7 @@ func parseExpr(t *testing.T, expr string) ast.ExprNode { return stmt.Fields.Fields[0].Expr } -func buildExpr(t *testing.T, ctx sessionctx.Context, exprNode any, opts ...expression.BuildOption) (expr expression.Expression, err error) { +func buildExpr(t *testing.T, ctx expression.BuildContext, exprNode any, opts ...expression.BuildOption) (expr expression.Expression, err error) { switch x := exprNode.(type) { case string: node := parseExpr(t, x) @@ -59,7 +58,7 @@ func buildExpr(t *testing.T, ctx sessionctx.Context, exprNode any, opts ...expre return } -func buildExprAndEval(t *testing.T, ctx sessionctx.Context, exprNode any) types.Datum { +func buildExprAndEval(t *testing.T, ctx expression.BuildContext, exprNode any) types.Datum { expr, err := buildExpr(t, ctx, exprNode) require.NoError(t, err) val, err := expr.Eval(ctx, chunk.Row{}) diff --git a/pkg/planner/core/integration_test.go b/pkg/planner/core/integration_test.go index 0ad06b689112c..543e15a9330dd 100644 --- a/pkg/planner/core/integration_test.go +++ b/pkg/planner/core/integration_test.go @@ -164,10 +164,10 @@ func TestPartitionPruningForEQ(t *testing.T) { tbl, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) require.NoError(t, err) pt := tbl.(table.PartitionedTable) - query, err := expression.ParseSimpleExpr(tk.Session(), "a = '2020-01-01 00:00:00'", expression.WithTableInfo("", tbl.Meta())) + query, err := expression.ParseSimpleExpr(tk.Session().GetExprCtx(), "a = '2020-01-01 00:00:00'", expression.WithTableInfo("", tbl.Meta())) require.NoError(t, err) dbName := model.NewCIStr(tk.Session().GetSessionVars().CurrentDB) - columns, names, err := expression.ColumnInfos2ColumnsAndNames(tk.Session(), dbName, tbl.Meta().Name, tbl.Meta().Cols(), tbl.Meta()) + columns, names, err := expression.ColumnInfos2ColumnsAndNames(tk.Session().GetExprCtx(), dbName, tbl.Meta().Name, tbl.Meta().Cols(), tbl.Meta()) require.NoError(t, err) // Even the partition is not monotonous, EQ condition should be prune! // select * from t where a = '2020-01-01 00:00:00' diff --git a/pkg/planner/core/logical_plans.go b/pkg/planner/core/logical_plans.go index 47a3bffeb53bb..24a88593be030 100644 --- a/pkg/planner/core/logical_plans.go +++ b/pkg/planner/core/logical_plans.go @@ -2058,7 +2058,7 @@ func (fb *FrameBound) UpdateCompareCols(ctx sessionctx.Context, orderByCols []*e fb.CompareCols = make([]expression.Expression, len(orderByCols)) if fb.CalcFuncs[0].GetType().EvalType() != orderByCols[0].GetType().EvalType() { var err error - fb.CompareCols[0], err = expression.NewFunctionBase(ctx, ast.Cast, fb.CalcFuncs[0].GetType(), orderByCols[0]) + fb.CompareCols[0], err = expression.NewFunctionBase(ctx.GetExprCtx(), ast.Cast, fb.CalcFuncs[0].GetType(), orderByCols[0]) if err != nil { return err } diff --git a/pkg/planner/core/physical_plan_test.go b/pkg/planner/core/physical_plan_test.go index d58a8073f46ab..94120fc911cf7 100644 --- a/pkg/planner/core/physical_plan_test.go +++ b/pkg/planner/core/physical_plan_test.go @@ -485,7 +485,7 @@ func TestPhysicalTableScanExtractCorrelatedCols(t *testing.T) { ts := findTableScan(p) require.NotNil(t, ts) - pb, err := ts.ToPB(tk.Session(), kv.TiFlash) + pb, err := ts.ToPB(tk.Session().GetPlanCtx(), kv.TiFlash) require.NoError(t, err) // make sure the pushed down filter condition is correct require.Equal(t, 1, len(pb.TblScan.PushedDownFilterConditions)) diff --git a/pkg/planner/core/plan.go b/pkg/planner/core/plan.go index f131c970f0207..12e062ed3c473 100644 --- a/pkg/planner/core/plan.go +++ b/pkg/planner/core/plan.go @@ -371,7 +371,7 @@ type PhysicalPlan interface { attach2Task(...task) task // ToPB converts physical plan to tipb executor. - ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*tipb.Executor, error) + ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.Executor, error) // GetChildReqProps gets the required property by child index. GetChildReqProps(idx int) *property.PhysicalProperty diff --git a/pkg/planner/core/plan_to_pb.go b/pkg/planner/core/plan_to_pb.go index fb04acac1e2db..912eddb25b635 100644 --- a/pkg/planner/core/plan_to_pb.go +++ b/pkg/planner/core/plan_to_pb.go @@ -20,7 +20,6 @@ import ( "github.com/pingcap/tidb/pkg/expression/aggregation" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/table/tables" "github.com/pingcap/tidb/pkg/util" "github.com/pingcap/tidb/pkg/util/ranger" @@ -28,12 +27,12 @@ import ( ) // ToPB implements PhysicalPlan ToPB interface. -func (p *basePhysicalPlan) ToPB(_ sessionctx.Context, _ kv.StoreType) (*tipb.Executor, error) { +func (p *basePhysicalPlan) ToPB(_ PlanContext, _ kv.StoreType) (*tipb.Executor, error) { return nil, errors.Errorf("plan %s fails converts to PB", p.Plan.ExplainID()) } // ToPB implements PhysicalPlan ToPB interface. -func (p *PhysicalExpand) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*tipb.Executor, error) { +func (p *PhysicalExpand) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.Executor, error) { if len(p.LevelExprs) > 0 { return p.toPBV2(ctx, storeType) } @@ -57,7 +56,7 @@ func (p *PhysicalExpand) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (* return &tipb.Executor{Tp: tipb.ExecType_TypeExpand, Expand: expand, ExecutorId: &executorID}, nil } -func (p *PhysicalExpand) toPBV2(ctx sessionctx.Context, storeType kv.StoreType) (*tipb.Executor, error) { +func (p *PhysicalExpand) toPBV2(ctx PlanContext, storeType kv.StoreType) (*tipb.Executor, error) { client := ctx.GetClient() projExprsPB := make([]*tipb.ExprSlice, 0, len(p.LevelExprs)) for _, exprs := range p.LevelExprs { @@ -84,7 +83,7 @@ func (p *PhysicalExpand) toPBV2(ctx sessionctx.Context, storeType kv.StoreType) } // ToPB implements PhysicalPlan ToPB interface. -func (p *PhysicalHashAgg) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*tipb.Executor, error) { +func (p *PhysicalHashAgg) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.Executor, error) { client := ctx.GetClient() groupByExprs, err := expression.ExpressionsToPBList(ctx, p.GroupByItems, client) if err != nil { @@ -119,7 +118,7 @@ func (p *PhysicalHashAgg) ToPB(ctx sessionctx.Context, storeType kv.StoreType) ( } // ToPB implements PhysicalPlan ToPB interface. -func (p *PhysicalStreamAgg) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*tipb.Executor, error) { +func (p *PhysicalStreamAgg) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.Executor, error) { client := ctx.GetClient() groupByExprs, err := expression.ExpressionsToPBList(ctx, p.GroupByItems, client) if err != nil { @@ -148,7 +147,7 @@ func (p *PhysicalStreamAgg) ToPB(ctx sessionctx.Context, storeType kv.StoreType) } // ToPB implements PhysicalPlan ToPB interface. -func (p *PhysicalSelection) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*tipb.Executor, error) { +func (p *PhysicalSelection) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.Executor, error) { client := ctx.GetClient() conditions, err := expression.ExpressionsToPBList(ctx, p.Conditions, client) if err != nil { @@ -170,7 +169,7 @@ func (p *PhysicalSelection) ToPB(ctx sessionctx.Context, storeType kv.StoreType) } // ToPB implements PhysicalPlan ToPB interface. -func (p *PhysicalProjection) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*tipb.Executor, error) { +func (p *PhysicalProjection) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.Executor, error) { client := ctx.GetClient() exprs, err := expression.ExpressionsToPBList(ctx, p.Exprs, client) if err != nil { @@ -192,7 +191,7 @@ func (p *PhysicalProjection) ToPB(ctx sessionctx.Context, storeType kv.StoreType } // ToPB implements PhysicalPlan ToPB interface. -func (p *PhysicalTopN) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*tipb.Executor, error) { +func (p *PhysicalTopN) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.Executor, error) { client := ctx.GetClient() topNExec := &tipb.TopN{ Limit: p.Count, @@ -216,7 +215,7 @@ func (p *PhysicalTopN) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*ti } // ToPB implements PhysicalPlan ToPB interface. -func (p *PhysicalLimit) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*tipb.Executor, error) { +func (p *PhysicalLimit) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.Executor, error) { client := ctx.GetClient() limitExec := &tipb.Limit{ Limit: p.Count, @@ -237,7 +236,7 @@ func (p *PhysicalLimit) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*t } // ToPB implements PhysicalPlan ToPB interface. -func (p *PhysicalTableScan) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*tipb.Executor, error) { +func (p *PhysicalTableScan) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.Executor, error) { if storeType == kv.TiFlash && p.Table.GetPartitionInfo() != nil && p.IsMPPOrBatchCop && p.SCtx().GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { return p.partitionTableScanToPBForFlash(ctx) } @@ -274,7 +273,7 @@ func (p *PhysicalTableScan) ToPB(ctx sessionctx.Context, storeType kv.StoreType) return &tipb.Executor{Tp: tipb.ExecType_TypeTableScan, TblScan: tsExec, ExecutorId: &executorID}, err } -func (p *PhysicalTableScan) partitionTableScanToPBForFlash(ctx sessionctx.Context) (*tipb.Executor, error) { +func (p *PhysicalTableScan) partitionTableScanToPBForFlash(ctx PlanContext) (*tipb.Executor, error) { ptsExec := tables.BuildPartitionTableScanFromInfos(p.Table, p.Columns, ctx.GetSessionVars().TiFlashFastScan) if len(p.LateMaterializationFilterCondition) > 0 { @@ -338,7 +337,7 @@ func FindColumnInfoByID(colInfos []*model.ColumnInfo, id int64) *model.ColumnInf } // ToPB generates the pb structure. -func (e *PhysicalExchangeSender) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*tipb.Executor, error) { +func (e *PhysicalExchangeSender) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.Executor, error) { child, err := e.Children()[0].ToPB(ctx, kv.TiFlash) if err != nil { return nil, errors.Trace(err) @@ -413,7 +412,7 @@ func (e *PhysicalExchangeSender) ToPB(ctx sessionctx.Context, storeType kv.Store } // ToPB generates the pb structure. -func (e *PhysicalExchangeReceiver) ToPB(ctx sessionctx.Context, _ kv.StoreType) (*tipb.Executor, error) { +func (e *PhysicalExchangeReceiver) ToPB(ctx PlanContext, _ kv.StoreType) (*tipb.Executor, error) { encodedTask := make([][]byte, 0, len(e.Tasks)) for _, task := range e.Tasks { @@ -452,7 +451,7 @@ func (e *PhysicalExchangeReceiver) ToPB(ctx sessionctx.Context, _ kv.StoreType) } // ToPB implements PhysicalPlan ToPB interface. -func (p *PhysicalIndexScan) ToPB(_ sessionctx.Context, _ kv.StoreType) (*tipb.Executor, error) { +func (p *PhysicalIndexScan) ToPB(_ PlanContext, _ kv.StoreType) (*tipb.Executor, error) { columns := make([]*model.ColumnInfo, 0, p.schema.Len()) tableColumns := p.Table.Cols() for _, col := range p.schema.Columns { @@ -486,7 +485,7 @@ func (p *PhysicalIndexScan) ToPB(_ sessionctx.Context, _ kv.StoreType) (*tipb.Ex } // ToPB implements PhysicalPlan ToPB interface. -func (p *PhysicalHashJoin) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*tipb.Executor, error) { +func (p *PhysicalHashJoin) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.Executor, error) { client := ctx.GetClient() if len(p.LeftJoinKeys) > 0 && len(p.LeftNAJoinKeys) > 0 { @@ -633,7 +632,7 @@ func (p *PhysicalHashJoin) ToPB(ctx sessionctx.Context, storeType kv.StoreType) } // ToPB converts FrameBound to tipb structure. -func (fb *FrameBound) ToPB(ctx sessionctx.Context) (*tipb.WindowFrameBound, error) { +func (fb *FrameBound) ToPB(ctx PlanContext) (*tipb.WindowFrameBound, error) { pbBound := &tipb.WindowFrameBound{ Type: tipb.WindowBoundType(fb.Type), Unbounded: fb.UnBounded, @@ -655,7 +654,7 @@ func (fb *FrameBound) ToPB(ctx sessionctx.Context) (*tipb.WindowFrameBound, erro } // ToPB implements PhysicalPlan ToPB interface. -func (p *PhysicalWindow) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*tipb.Executor, error) { +func (p *PhysicalWindow) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.Executor, error) { client := ctx.GetClient() windowExec := &tipb.Window{} @@ -707,7 +706,7 @@ func (p *PhysicalWindow) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (* } // ToPB implements PhysicalPlan ToPB interface. -func (p *PhysicalSort) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*tipb.Executor, error) { +func (p *PhysicalSort) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.Executor, error) { if !p.IsPartialSort { return nil, errors.Errorf("sort %s can't convert to pb, because it isn't a partial sort", p.Plan.ExplainID()) } diff --git a/pkg/planner/core/point_get_plan.go b/pkg/planner/core/point_get_plan.go index 1f835b3623f18..5cbe2f75eabeb 100644 --- a/pkg/planner/core/point_get_plan.go +++ b/pkg/planner/core/point_get_plan.go @@ -153,7 +153,7 @@ func (*PointGetPlan) attach2Task(...task) task { } // ToPB converts physical plan to tipb executor. -func (*PointGetPlan) ToPB(_ sessionctx.Context, _ kv.StoreType) (*tipb.Executor, error) { +func (*PointGetPlan) ToPB(_ PlanContext, _ kv.StoreType) (*tipb.Executor, error) { return nil, nil } @@ -409,7 +409,7 @@ func (*BatchPointGetPlan) attach2Task(...task) task { } // ToPB converts physical plan to tipb executor. -func (*BatchPointGetPlan) ToPB(_ sessionctx.Context, _ kv.StoreType) (*tipb.Executor, error) { +func (*BatchPointGetPlan) ToPB(_ PlanContext, _ kv.StoreType) (*tipb.Executor, error) { return nil, nil } diff --git a/pkg/planner/core/runtime_filter.go b/pkg/planner/core/runtime_filter.go index 12d46b2c147b4..cb468260fa0a1 100644 --- a/pkg/planner/core/runtime_filter.go +++ b/pkg/planner/core/runtime_filter.go @@ -20,7 +20,6 @@ import ( "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/util" "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" @@ -202,7 +201,7 @@ func (rf *RuntimeFilter) Clone() *RuntimeFilter { } // RuntimeFilterListToPB convert runtime filter list to PB list -func RuntimeFilterListToPB(ctx sessionctx.Context, runtimeFilterList []*RuntimeFilter, client kv.Client) ([]*tipb.RuntimeFilter, error) { +func RuntimeFilterListToPB(ctx PlanContext, runtimeFilterList []*RuntimeFilter, client kv.Client) ([]*tipb.RuntimeFilter, error) { result := make([]*tipb.RuntimeFilter, 0, len(runtimeFilterList)) for _, runtimeFilter := range runtimeFilterList { rfPB, err := runtimeFilter.ToPB(ctx, client) @@ -215,7 +214,7 @@ func RuntimeFilterListToPB(ctx sessionctx.Context, runtimeFilterList []*RuntimeF } // ToPB convert runtime filter to PB -func (rf *RuntimeFilter) ToPB(ctx sessionctx.Context, client kv.Client) (*tipb.RuntimeFilter, error) { +func (rf *RuntimeFilter) ToPB(ctx PlanContext, client kv.Client) (*tipb.RuntimeFilter, error) { pc := expression.NewPBConverter(client, ctx) srcExprListPB := make([]*tipb.Expr, 0, len(rf.srcExprList)) for _, srcExpr := range rf.srcExprList { diff --git a/pkg/session/BUILD.bazel b/pkg/session/BUILD.bazel index 737c828d93f68..f2ae3889b45be 100644 --- a/pkg/session/BUILD.bazel +++ b/pkg/session/BUILD.bazel @@ -5,9 +5,9 @@ go_library( srcs = [ "advisory_locks.go", "bootstrap.go", + "contextimpl.go", "mock_bootstrap.go", "nontransactional.go", - "plancontext.go", "session.go", "sync_upgrade.go", "testutil.go", #keep @@ -30,6 +30,8 @@ go_library( "//pkg/errno", "//pkg/executor", "//pkg/expression", + "//pkg/expression/context", + "//pkg/expression/contextimpl", "//pkg/extension", "//pkg/extension/extensionimpl", "//pkg/infoschema", diff --git a/pkg/session/bootstrap_test.go b/pkg/session/bootstrap_test.go index ef144832801a2..6e1632d5f9598 100644 --- a/pkg/session/bootstrap_test.go +++ b/pkg/session/bootstrap_test.go @@ -150,8 +150,9 @@ func TestBootstrapWithError(t *testing.T) { store: store, sessionVars: variable.NewSessionVars(nil), } - se.pctx = newPlanContextImpl(se) - se.tblctx = tbctximpl.NewTableContextImpl(se) + se.exprctx = newExpressionContextImpl(se) + se.pctx = newPlanContextImpl(se, se.exprctx.ExprCtxExtendedImpl) + se.tblctx = tbctximpl.NewTableContextImpl(se, se.exprctx) globalVarsAccessor := variable.NewMockGlobalAccessor4Tests() se.GetSessionVars().GlobalVarsAccessor = globalVarsAccessor se.txn.init() diff --git a/pkg/session/plancontext.go b/pkg/session/contextimpl.go similarity index 66% rename from pkg/session/plancontext.go rename to pkg/session/contextimpl.go index ac8eff2560b4e..bbbbf65a94437 100644 --- a/pkg/session/plancontext.go +++ b/pkg/session/contextimpl.go @@ -15,6 +15,7 @@ package session import ( + exprctximpl "github.com/pingcap/tidb/pkg/expression/contextimpl" planctx "github.com/pingcap/tidb/pkg/planner/context" planctximpl "github.com/pingcap/tidb/pkg/planner/contextimpl" ) @@ -26,13 +27,25 @@ var _ planctx.PlanContext = &planContextImpl{} // the `session` here to make it safe for casting. type planContextImpl struct { *session - planctximpl.SessionContextExtended + *planctximpl.PlanCtxExtendedImpl } // NewPlanContextImpl creates a new PlanContextImpl. -func newPlanContextImpl(s *session) *planContextImpl { +func newPlanContextImpl(s *session, exprExtended *exprctximpl.ExprCtxExtendedImpl) *planContextImpl { return &planContextImpl{ - session: s, - SessionContextExtended: planctximpl.NewSessionContextExtended(s), + session: s, + PlanCtxExtendedImpl: planctximpl.NewPlanCtxExtendedImpl(s, exprExtended), + } +} + +type expressionContextImpl struct { + *session + *exprctximpl.ExprCtxExtendedImpl +} + +func newExpressionContextImpl(s *session) *expressionContextImpl { + return &expressionContextImpl{ + session: s, + ExprCtxExtendedImpl: exprctximpl.NewExprExtendedImpl(s), } } diff --git a/pkg/session/session.go b/pkg/session/session.go index 3b79a885c30ec..cd04ff34061c2 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -48,6 +48,7 @@ import ( "github.com/pingcap/tidb/pkg/errno" "github.com/pingcap/tidb/pkg/executor" "github.com/pingcap/tidb/pkg/expression" + exprctx "github.com/pingcap/tidb/pkg/expression/context" "github.com/pingcap/tidb/pkg/extension" "github.com/pingcap/tidb/pkg/extension/extensionimpl" "github.com/pingcap/tidb/pkg/infoschema" @@ -177,8 +178,9 @@ type session struct { sessionVars *variable.SessionVars sessionManager util.SessionManager - pctx *planContextImpl - tblctx *tbctximpl.TableContextImpl + pctx *planContextImpl + exprctx *expressionContextImpl + tblctx *tbctximpl.TableContextImpl statsCollector *usage.SessionStatsItem // ddlOwnerManager is used in `select tidb_is_ddl_owner()` statement; @@ -2597,6 +2599,11 @@ func (s *session) GetPlanCtx() planctx.PlanContext { return s.pctx } +// GetExprCtx returns the expression context of the session. +func (s *session) GetExprCtx() exprctx.BuildContext { + return s.exprctx +} + // GetTableCtx returns the table.MutateContext func (s *session) GetTableCtx() tbctx.MutateContext { return s.tblctx @@ -3563,8 +3570,9 @@ func createSessionWithOpt(store kv.Storage, opt *Opt) (*session, error) { sessionStatesHandlers: make(map[sessionstates.SessionStateType]sessionctx.SessionStatesHandler), } s.sessionVars = variable.NewSessionVars(s) - s.pctx = newPlanContextImpl(s) - s.tblctx = tbctximpl.NewTableContextImpl(s) + s.exprctx = newExpressionContextImpl(s) + s.pctx = newPlanContextImpl(s, s.exprctx.ExprCtxExtendedImpl) + s.tblctx = tbctximpl.NewTableContextImpl(s, s.exprctx) if opt != nil && opt.PreparedPlanCache != nil { s.sessionPlanCache = opt.PreparedPlanCache @@ -3625,8 +3633,9 @@ func CreateSessionWithDomain(store kv.Storage, dom *domain.Domain) (*session, er stmtStats: stmtstats.CreateStatementStats(), sessionStatesHandlers: make(map[sessionstates.SessionStateType]sessionctx.SessionStatesHandler), } - s.pctx = newPlanContextImpl(s) - s.tblctx = tbctximpl.NewTableContextImpl(s) + s.exprctx = newExpressionContextImpl(s) + s.pctx = newPlanContextImpl(s, s.exprctx.ExprCtxExtendedImpl) + s.tblctx = tbctximpl.NewTableContextImpl(s, s.exprctx) s.mu.values = make(map[fmt.Stringer]any) s.lockedTables = make(map[int64]model.TableLockTpInfo) domain.BindDomain(s, dom) diff --git a/pkg/sessionctx/context.go b/pkg/sessionctx/context.go index d1eeffe3445b6..62b2537e269e8 100644 --- a/pkg/sessionctx/context.go +++ b/pkg/sessionctx/context.go @@ -66,8 +66,6 @@ type PlanCache interface { type Context interface { SessionStatesHandler contextutil.ValueStoreContext - exprctx.EvalContext - exprctx.BuildContext tablelock.TableLockContext // SetDiskFullOpt set the disk full opt when tikv disk full happened. SetDiskFullOpt(level kvrpcpb.DiskFullOpt) @@ -99,14 +97,17 @@ type Context interface { GetSessionVars() *variable.SessionVars + // GetExprCtx returns the expression context of the session. + GetExprCtx() exprctx.BuildContext + // GetTableCtx returns the table.MutateContext GetTableCtx() tbctx.MutateContext - GetSessionManager() util.SessionManager - // GetPlanCtx gets the plan context of the current session. GetPlanCtx() planctx.PlanContext + GetSessionManager() util.SessionManager + // RefreshTxnCtx commits old transaction without retry, // and creates a new transaction. // now just for load data and batch insert. diff --git a/pkg/sessiontxn/staleread/util.go b/pkg/sessiontxn/staleread/util.go index c5d30d4ce38b4..71716c61da995 100644 --- a/pkg/sessiontxn/staleread/util.go +++ b/pkg/sessiontxn/staleread/util.go @@ -68,12 +68,12 @@ func CalculateAsOfTsExpr(ctx context.Context, sctx pctx.PlanContext, tsExpr ast. // CalculateTsWithReadStaleness calculates the TsExpr for readStaleness duration func CalculateTsWithReadStaleness(sctx sessionctx.Context, readStaleness time.Duration) (uint64, error) { - nowVal, err := expression.GetStmtTimestamp(sctx) + nowVal, err := expression.GetStmtTimestamp(sctx.GetExprCtx()) if err != nil { return 0, err } tsVal := nowVal.Add(readStaleness) - minTsVal := expression.GetMinSafeTime(sctx) + minTsVal := expression.GetMinSafeTime(sctx.GetExprCtx()) return oracle.GoTimeToTS(expression.CalAppropriateTime(tsVal, nowVal, minTsVal)), nil } diff --git a/pkg/store/mockstore/mockcopr/aggregate.go b/pkg/store/mockstore/mockcopr/aggregate.go index 1a996929783f2..015f8b18a22d6 100644 --- a/pkg/store/mockstore/mockcopr/aggregate.go +++ b/pkg/store/mockstore/mockcopr/aggregate.go @@ -147,7 +147,7 @@ func (e *hashAggExec) getGroupKey() ([]byte, [][]byte, error) { sc := e.evalCtx.sctx.GetSessionVars().StmtCtx errCtx := sc.ErrCtx() for _, item := range e.groupByExprs { - v, err := item.Eval(e.evalCtx.sctx, chunk.MutRowFromDatums(e.row).ToRow()) + v, err := item.Eval(e.evalCtx.sctx.GetExprCtx(), chunk.MutRowFromDatums(e.row).ToRow()) if err != nil { return nil, nil, errors.Trace(err) } @@ -199,7 +199,7 @@ func (e *hashAggExec) getContexts(groupKey []byte) []*aggregation.AggEvaluateCon if !ok { aggCtxs = make([]*aggregation.AggEvaluateContext, 0, len(e.aggExprs)) for _, agg := range e.aggExprs { - aggCtxs = append(aggCtxs, agg.CreateContext(e.evalCtx.sctx)) + aggCtxs = append(aggCtxs, agg.CreateContext(e.evalCtx.sctx.GetExprCtx())) } e.aggCtxsMap[groupKeyString] = aggCtxs } @@ -265,7 +265,7 @@ func (e *streamAggExec) getPartialResult() ([][]byte, error) { value = append(value, data) } // Clear the aggregate context. - e.aggCtxs[i] = agg.CreateContext(e.evalCtx.sctx) + e.aggCtxs[i] = agg.CreateContext(e.evalCtx.sctx.GetExprCtx()) } e.currGroupByValues = e.currGroupByValues[:0] for _, d := range e.currGroupByRow { @@ -291,7 +291,7 @@ func (e *streamAggExec) meetNewGroup(row [][]byte) (bool, error) { matched, firstGroup = false, true } for i, item := range e.groupByExprs { - d, err := item.Eval(e.evalCtx.sctx, chunk.MutRowFromDatums(e.row).ToRow()) + d, err := item.Eval(e.evalCtx.sctx.GetExprCtx(), chunk.MutRowFromDatums(e.row).ToRow()) if err != nil { return false, errors.Trace(err) } diff --git a/pkg/store/mockstore/mockcopr/cop_handler_dag.go b/pkg/store/mockstore/mockcopr/cop_handler_dag.go index d9c7e2a1a15ab..a5a53931b3195 100644 --- a/pkg/store/mockstore/mockcopr/cop_handler_dag.go +++ b/pkg/store/mockstore/mockcopr/cop_handler_dag.go @@ -325,7 +325,7 @@ func (h coprHandler) getAggInfo(ctx *dagContext, executor *tipb.Executor) ([]agg var relatedColOffsets []int for _, expr := range executor.Aggregation.AggFunc { var aggExpr aggregation.Aggregation - aggExpr, _, err = aggregation.NewDistAggFunc(expr, ctx.evalCtx.fieldTps, ctx.evalCtx.sctx) + aggExpr, _, err = aggregation.NewDistAggFunc(expr, ctx.evalCtx.fieldTps, ctx.evalCtx.sctx.GetExprCtx()) if err != nil { return nil, nil, nil, errors.Trace(err) } @@ -374,7 +374,7 @@ func (h coprHandler) buildStreamAgg(ctx *dagContext, executor *tipb.Executor) (* } aggCtxs := make([]*aggregation.AggEvaluateContext, 0, len(aggs)) for _, agg := range aggs { - aggCtxs = append(aggCtxs, agg.CreateContext(ctx.evalCtx.sctx)) + aggCtxs = append(aggCtxs, agg.CreateContext(ctx.evalCtx.sctx.GetExprCtx())) } groupByCollators := make([]collate.Collator, 0, len(groupBys)) for _, expr := range groupBys { diff --git a/pkg/store/mockstore/mockcopr/executor.go b/pkg/store/mockstore/mockcopr/executor.go index 6b415878c623f..d1a54dad84956 100644 --- a/pkg/store/mockstore/mockcopr/executor.go +++ b/pkg/store/mockstore/mockcopr/executor.go @@ -406,7 +406,7 @@ func (e *selectionExec) Counts() []int64 { // evalBool evaluates expression to a boolean value. func evalBool(exprs []expression.Expression, row []types.Datum, ctx sessionctx.Context) (bool, error) { for _, expr := range exprs { - data, err := expr.Eval(ctx, chunk.MutRowFromDatums(row).ToRow()) + data, err := expr.Eval(ctx.GetExprCtx(), chunk.MutRowFromDatums(row).ToRow()) if err != nil { return false, errors.Trace(err) } @@ -541,7 +541,7 @@ func (e *topNExec) evalTopN(value [][]byte) error { return errors.Trace(err) } for i, expr := range e.orderByExprs { - newRow.key[i], err = expr.Eval(e.evalCtx.sctx, chunk.MutRowFromDatums(e.row).ToRow()) + newRow.key[i], err = expr.Eval(e.evalCtx.sctx.GetExprCtx(), chunk.MutRowFromDatums(e.row).ToRow()) if err != nil { return errors.Trace(err) } @@ -664,7 +664,7 @@ func getRowData(columns []*tipb.ColumnInfo, colIDs map[int64]int, handle int64, func convertToExprs(sctx sessionctx.Context, fieldTps []*types.FieldType, pbExprs []*tipb.Expr) ([]expression.Expression, error) { exprs := make([]expression.Expression, 0, len(pbExprs)) for _, expr := range pbExprs { - e, err := expression.PBToExpr(sctx, expr, fieldTps) + e, err := expression.PBToExpr(sctx.GetExprCtx(), expr, fieldTps) if err != nil { return nil, errors.Trace(err) } diff --git a/pkg/store/mockstore/unistore/cophandler/closure_exec.go b/pkg/store/mockstore/unistore/cophandler/closure_exec.go index 67e0e5a59edd5..0c34942d09bc2 100644 --- a/pkg/store/mockstore/unistore/cophandler/closure_exec.go +++ b/pkg/store/mockstore/unistore/cophandler/closure_exec.go @@ -186,7 +186,7 @@ func buildClosureExecutor(dagCtx *dagContext, dagReq *tipb.DAGRequest) (*closure func convertToExprs(sctx sessionctx.Context, fieldTps []*types.FieldType, pbExprs []*tipb.Expr) ([]expression.Expression, error) { exprs := make([]expression.Expression, 0, len(pbExprs)) for _, expr := range pbExprs { - e, err := expression.PBToExpr(sctx, expr, fieldTps) + e, err := expression.PBToExpr(sctx.GetExprCtx(), expr, fieldTps) if err != nil { return nil, errors.Trace(err) } @@ -787,7 +787,7 @@ func (e *closureExecutor) processSelection(needCollectDetail bool) (gotRow bool, gotRow = true for _, expr := range e.selectionCtx.conditions { wc := e.sctx.GetSessionVars().StmtCtx.WarningCount() - d, err := expr.Eval(e.sctx, row) + d, err := expr.Eval(e.sctx.GetExprCtx(), row) if err != nil { return false, errors.Trace(err) } @@ -1026,7 +1026,7 @@ func (e *topNProcessor) Process(key, value []byte) (err error) { ctx := e.topNCtx row := e.scanCtx.chk.GetRow(0) for i, expr := range ctx.orderByExprs { - d, err := expr.Eval(e.sctx, row) + d, err := expr.Eval(e.sctx.GetExprCtx(), row) if err != nil { return errors.Trace(err) } @@ -1124,7 +1124,7 @@ func (e *hashAggProcessor) getGroupKey(row chunk.Row) ([]byte, error) { sc := e.sctx.GetSessionVars().StmtCtx errCtx := sc.ErrCtx() for _, item := range e.groupByExprs { - v, err := item.Eval(e.sctx, row) + v, err := item.Eval(e.sctx.GetExprCtx(), row) if err != nil { return nil, errors.Trace(err) } @@ -1143,7 +1143,7 @@ func (e *hashAggProcessor) getContexts(groupKey []byte) []*aggregation.AggEvalua if !ok { aggCtxs = make([]*aggregation.AggEvaluateContext, 0, len(e.aggExprs)) for _, agg := range e.aggExprs { - aggCtxs = append(aggCtxs, agg.CreateContext(e.sctx)) + aggCtxs = append(aggCtxs, agg.CreateContext(e.sctx.GetExprCtx())) } e.aggCtxsMap[string(groupKey)] = aggCtxs } diff --git a/pkg/store/mockstore/unistore/cophandler/cop_handler.go b/pkg/store/mockstore/unistore/cophandler/cop_handler.go index 7e64a9f3930ba..9cf1aca52b86a 100644 --- a/pkg/store/mockstore/unistore/cophandler/cop_handler.go +++ b/pkg/store/mockstore/unistore/cophandler/cop_handler.go @@ -331,7 +331,7 @@ func getAggInfo(ctx *dagContext, pbAgg *tipb.Aggregation) ([]aggregation.Aggrega var err error for _, expr := range pbAgg.AggFunc { var aggExpr aggregation.Aggregation - aggExpr, _, err = aggregation.NewDistAggFunc(expr, ctx.fieldTps, ctx.sctx) + aggExpr, _, err = aggregation.NewDistAggFunc(expr, ctx.fieldTps, ctx.sctx.GetExprCtx()) if err != nil { return nil, nil, errors.Trace(err) } diff --git a/pkg/store/mockstore/unistore/cophandler/mpp.go b/pkg/store/mockstore/unistore/cophandler/mpp.go index cbf3799c6ef70..8cb5116063266 100644 --- a/pkg/store/mockstore/unistore/cophandler/mpp.go +++ b/pkg/store/mockstore/unistore/cophandler/mpp.go @@ -315,7 +315,7 @@ func (b *mppExecBuilder) buildMPPExchangeSender(pb *tipb.ExchangeSender) (*exchS if pb.Tp == tipb.ExchangeType_Hash { // remove the limitation of len(pb.PartitionKeys) == 1 for _, partitionKey := range pb.PartitionKeys { - expr, err := expression.PBToExpr(b.sctx, partitionKey, child.getFieldTypes()) + expr, err := expression.PBToExpr(b.sctx.GetExprCtx(), partitionKey, child.getFieldTypes()) if err != nil { return nil, errors.Trace(err) } @@ -412,12 +412,12 @@ func (b *mppExecBuilder) buildMPPJoin(pb *tipb.Join, children []*tipb.Executor) if pb.InnerIdx == 1 { e.probeChild = leftCh e.buildChild = rightCh - probeExpr, err := expression.PBToExpr(b.sctx, pb.LeftJoinKeys[0], leftCh.getFieldTypes()) + probeExpr, err := expression.PBToExpr(b.sctx.GetExprCtx(), pb.LeftJoinKeys[0], leftCh.getFieldTypes()) if err != nil { return nil, errors.Trace(err) } e.probeKey = probeExpr.(*expression.Column) - buildExpr, err := expression.PBToExpr(b.sctx, pb.RightJoinKeys[0], rightCh.getFieldTypes()) + buildExpr, err := expression.PBToExpr(b.sctx.GetExprCtx(), pb.RightJoinKeys[0], rightCh.getFieldTypes()) if err != nil { return nil, errors.Trace(err) } @@ -425,12 +425,12 @@ func (b *mppExecBuilder) buildMPPJoin(pb *tipb.Join, children []*tipb.Executor) } else { e.probeChild = rightCh e.buildChild = leftCh - buildExpr, err := expression.PBToExpr(b.sctx, pb.LeftJoinKeys[0], leftCh.getFieldTypes()) + buildExpr, err := expression.PBToExpr(b.sctx.GetExprCtx(), pb.LeftJoinKeys[0], leftCh.getFieldTypes()) if err != nil { return nil, errors.Trace(err) } e.buildKey = buildExpr.(*expression.Column) - probeExpr, err := expression.PBToExpr(b.sctx, pb.RightJoinKeys[0], rightCh.getFieldTypes()) + probeExpr, err := expression.PBToExpr(b.sctx.GetExprCtx(), pb.RightJoinKeys[0], rightCh.getFieldTypes()) if err != nil { return nil, errors.Trace(err) } @@ -459,7 +459,7 @@ func (b *mppExecBuilder) buildMPPProj(proj *tipb.Projection) (*projExec, error) e.children = []mppExec{chExec} for _, pbExpr := range proj.Exprs { - expr, err := expression.PBToExpr(b.sctx, pbExpr, chExec.getFieldTypes()) + expr, err := expression.PBToExpr(b.sctx.GetExprCtx(), pbExpr, chExec.getFieldTypes()) if err != nil { return nil, errors.Trace(err) } @@ -484,7 +484,7 @@ func (b *mppExecBuilder) buildMPPSel(sel *tipb.Selection) (*selExec, error) { } for _, pbExpr := range sel.Conditions { - expr, err := expression.PBToExpr(b.sctx, pbExpr, chExec.getFieldTypes()) + expr, err := expression.PBToExpr(b.sctx.GetExprCtx(), pbExpr, chExec.getFieldTypes()) if err != nil { return nil, errors.Trace(err) } @@ -512,7 +512,7 @@ func (b *mppExecBuilder) buildMPPAgg(agg *tipb.Aggregation) (*aggExec, error) { for _, aggFunc := range agg.AggFunc { ft := expression.PbTypeToFieldType(aggFunc.FieldType) e.fieldTypes = append(e.fieldTypes, ft) - aggExpr, _, err := aggregation.NewDistAggFunc(aggFunc, chExec.getFieldTypes(), b.sctx) + aggExpr, _, err := aggregation.NewDistAggFunc(aggFunc, chExec.getFieldTypes(), b.sctx.GetExprCtx()) if err != nil { return nil, errors.Trace(err) } @@ -524,7 +524,7 @@ func (b *mppExecBuilder) buildMPPAgg(agg *tipb.Aggregation) (*aggExec, error) { ft := expression.PbTypeToFieldType(gby.FieldType) e.fieldTypes = append(e.fieldTypes, ft) e.groupByTypes = append(e.groupByTypes, ft) - gbyExpr, err := expression.PBToExpr(b.sctx, gby, chExec.getFieldTypes()) + gbyExpr, err := expression.PBToExpr(b.sctx.GetExprCtx(), gby, chExec.getFieldTypes()) if err != nil { return nil, errors.Trace(err) } diff --git a/pkg/store/mockstore/unistore/cophandler/mpp_exec.go b/pkg/store/mockstore/unistore/cophandler/mpp_exec.go index ca2c7d600c424..613472ce85816 100644 --- a/pkg/store/mockstore/unistore/cophandler/mpp_exec.go +++ b/pkg/store/mockstore/unistore/cophandler/mpp_exec.go @@ -559,7 +559,7 @@ func (e *topNExec) open() error { for i := 0; i < numRows; i++ { row := chk.GetRow(i) for j, cond := range e.conds { - d, err := cond.Eval(e.sctx, row) + d, err := cond.Eval(e.sctx.GetExprCtx(), row) if err != nil { return err } @@ -1012,7 +1012,7 @@ func (e *aggExec) getGroupKey(row chunk.Row) (*chunk.MutRow, []byte, error) { gbyRow := chunk.MutRowFromTypes(e.groupByTypes) sc := e.sctx.GetSessionVars().StmtCtx for i, item := range e.groupByExprs { - v, err := item.Eval(e.sctx, row) + v, err := item.Eval(e.sctx.GetExprCtx(), row) if err != nil { return nil, nil, errors.Trace(err) } @@ -1032,7 +1032,7 @@ func (e *aggExec) getContexts(groupKey []byte) []*aggregation.AggEvaluateContext if !ok { aggCtxs = make([]*aggregation.AggEvaluateContext, 0, len(e.aggExprs)) for _, agg := range e.aggExprs { - aggCtxs = append(aggCtxs, agg.CreateContext(e.sctx)) + aggCtxs = append(aggCtxs, agg.CreateContext(e.sctx.GetExprCtx())) } e.aggCtxsMap[string(groupKey)] = aggCtxs } @@ -1131,7 +1131,7 @@ func (e *selExec) next() (*chunk.Chunk, error) { row := chk.GetRow(rows) passCheck := true for _, cond := range e.conditions { - d, err := cond.Eval(e.sctx, row) + d, err := cond.Eval(e.sctx.GetExprCtx(), row) if err != nil { return nil, errors.Trace(err) } @@ -1182,7 +1182,7 @@ func (e *projExec) next() (*chunk.Chunk, error) { row := chk.GetRow(i) newRow := chunk.MutRowFromTypes(e.fieldTypes) for i, expr := range e.exprs { - d, err := expr.Eval(e.sctx, row) + d, err := expr.Eval(e.sctx.GetExprCtx(), row) if err != nil { return nil, errors.Trace(err) } diff --git a/pkg/table/column.go b/pkg/table/column.go index c34faaefe3c03..f89a21edb4ab1 100644 --- a/pkg/table/column.go +++ b/pkg/table/column.go @@ -530,7 +530,7 @@ func GetColOriginDefaultValue(ctx expression.BuildContext, col *model.ColumnInfo } // GetColOriginDefaultValueWithoutStrictSQLMode gets default value of the column from original default value with Strict SQL mode. -func GetColOriginDefaultValueWithoutStrictSQLMode(ctx sessionctx.Context, col *model.ColumnInfo) (types.Datum, error) { +func GetColOriginDefaultValueWithoutStrictSQLMode(ctx expression.BuildContext, col *model.ColumnInfo) (types.Datum, error) { return getColDefaultValue(ctx, col, col.GetOriginDefaultValue(), &getColOriginDefaultValue{ StrictSQLMode: false, }) @@ -736,7 +736,7 @@ func FillVirtualColumnValue(virtualRetTypes []*types.FieldType, virtualColumnInd iter := chunk.NewIterator4Chunk(req) for i, idx := range virtualColumnIndex { for row := iter.Begin(); row != iter.End(); row = iter.Next() { - datum, err := expCols[idx].EvalVirtualColumn(sctx, row) + datum, err := expCols[idx].EvalVirtualColumn(sctx.GetExprCtx(), row) if err != nil { return err } diff --git a/pkg/table/constraint.go b/pkg/table/constraint.go index 50a94d0e7aec6..91bdaab9cf3a3 100644 --- a/pkg/table/constraint.go +++ b/pkg/table/constraint.go @@ -20,7 +20,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/util/dbterror" "github.com/pingcap/tidb/pkg/util/logutil" "github.com/pingcap/tidb/pkg/util/mock" @@ -81,7 +80,7 @@ func ToConstraint(constraintInfo *model.ConstraintInfo, tblInfo *model.TableInfo }, nil } -func buildConstraintExpression(ctx sessionctx.Context, exprString string, db string, tblInfo *model.TableInfo) (expression.Expression, error) { +func buildConstraintExpression(ctx expression.BuildContext, exprString string, db string, tblInfo *model.TableInfo) (expression.Expression, error) { expr, err := expression.ParseSimpleExpr(ctx, exprString, expression.WithTableInfo(db, tblInfo)) if err != nil { // If it got an error here, ddl may hang forever, so this error log is important. diff --git a/pkg/table/contextimpl/table.go b/pkg/table/contextimpl/table.go index 6e1faca1625cd..f6402f1198ed5 100644 --- a/pkg/table/contextimpl/table.go +++ b/pkg/table/contextimpl/table.go @@ -29,11 +29,12 @@ var _ context.AllocatorContext = &TableContextImpl{} // TableContextImpl is used to provide context for table operations. type TableContextImpl struct { sessionctx.Context + exprCtx exprctx.BuildContext } // NewTableContextImpl creates a new TableContextImpl. -func NewTableContextImpl(sctx sessionctx.Context) *TableContextImpl { - return &TableContextImpl{Context: sctx} +func NewTableContextImpl(sctx sessionctx.Context, exprCtx exprctx.BuildContext) *TableContextImpl { + return &TableContextImpl{Context: sctx, exprCtx: exprCtx} } // TxnRecordTempTable record the temporary table to the current transaction. @@ -44,7 +45,7 @@ func (ctx *TableContextImpl) TxnRecordTempTable(tbl *model.TableInfo) tableutil. // GetExprCtx returns the ExprContext func (ctx *TableContextImpl) GetExprCtx() exprctx.BuildContext { - return ctx.Context + return ctx.exprCtx } func (ctx *TableContextImpl) vars() *variable.SessionVars { diff --git a/pkg/table/tables/partition.go b/pkg/table/tables/partition.go index bd385b9ac81f9..61014448951d2 100644 --- a/pkg/table/tables/partition.go +++ b/pkg/table/tables/partition.go @@ -34,7 +34,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/table" "github.com/pingcap/tidb/pkg/tablecodec" @@ -342,7 +341,7 @@ type ForRangeColumnsPruning struct { LessThan [][]*expression.Expression } -func dataForRangeColumnsPruning(ctx sessionctx.Context, defs []model.PartitionDefinition, schema *expression.Schema, names []*types.FieldName, p *parser.Parser, colOffsets []int) (*ForRangeColumnsPruning, error) { +func dataForRangeColumnsPruning(ctx expression.BuildContext, defs []model.PartitionDefinition, schema *expression.Schema, names []*types.FieldName, p *parser.Parser, colOffsets []int) (*ForRangeColumnsPruning, error) { var res ForRangeColumnsPruning res.LessThan = make([][]*expression.Expression, 0, len(defs)) for i := 0; i < len(defs); i++ { @@ -608,7 +607,7 @@ type ForRangePruning struct { } // dataForRangePruning extracts the less than parts from 'partition p0 less than xx ... partition p1 less than ...' -func dataForRangePruning(sctx sessionctx.Context, defs []model.PartitionDefinition) (*ForRangePruning, error) { +func dataForRangePruning(sctx expression.BuildContext, defs []model.PartitionDefinition) (*ForRangePruning, error) { var maxValue bool var unsigned bool lessThan := make([]int64, len(defs)) @@ -643,7 +642,7 @@ func dataForRangePruning(sctx sessionctx.Context, defs []model.PartitionDefiniti }, nil } -func fixOldVersionPartitionInfo(sctx sessionctx.Context, str string) (int64, bool) { +func fixOldVersionPartitionInfo(sctx expression.BuildContext, str string) (int64, bool) { // less than value should be calculate to integer before persistent. // Old version TiDB may not do it and store the raw expression. tmp, err := parseSimpleExprWithNames(parser.New(), sctx, str, nil, nil) @@ -670,7 +669,7 @@ func rangePartitionExprStrings(cols []model.CIStr, expr string) []string { return s } -func generateKeyPartitionExpr(ctx sessionctx.Context, expr string, partCols []model.CIStr, +func generateKeyPartitionExpr(ctx expression.BuildContext, expr string, partCols []model.CIStr, columns []*expression.Column, names types.NameSlice) (*PartitionExpr, error) { ret := &PartitionExpr{ ForKeyPruning: &ForKeyPruning{}, @@ -685,7 +684,7 @@ func generateKeyPartitionExpr(ctx sessionctx.Context, expr string, partCols []mo return ret, nil } -func generateRangePartitionExpr(ctx sessionctx.Context, expr string, partCols []model.CIStr, +func generateRangePartitionExpr(ctx expression.BuildContext, expr string, partCols []model.CIStr, defs []model.PartitionDefinition, columns []*expression.Column, names types.NameSlice) (*PartitionExpr, error) { // The caller should assure partition info is not nil. p := parser.New() @@ -722,7 +721,7 @@ func generateRangePartitionExpr(ctx sessionctx.Context, expr string, partCols [] return ret, nil } -func getRangeLocateExprs(ctx sessionctx.Context, p *parser.Parser, defs []model.PartitionDefinition, partStrs []string, schema *expression.Schema, names types.NameSlice) ([]expression.Expression, error) { +func getRangeLocateExprs(ctx expression.BuildContext, p *parser.Parser, defs []model.PartitionDefinition, partStrs []string, schema *expression.Schema, names types.NameSlice) ([]expression.Expression, error) { var buf bytes.Buffer locateExprs := make([]expression.Expression, 0, len(defs)) for i := 0; i < len(defs); i++ { @@ -775,7 +774,7 @@ func findIdxByColUniqueID(cols []*expression.Column, col *expression.Column) int return -1 } -func extractPartitionExprColumns(ctx sessionctx.Context, expr string, partCols []model.CIStr, columns []*expression.Column, names types.NameSlice) (expression.Expression, []*expression.Column, []int, error) { +func extractPartitionExprColumns(ctx expression.BuildContext, expr string, partCols []model.CIStr, columns []*expression.Column, names types.NameSlice) (expression.Expression, []*expression.Column, []int, error) { var cols []*expression.Column var partExpr expression.Expression if len(partCols) == 0 { @@ -806,7 +805,7 @@ func extractPartitionExprColumns(ctx sessionctx.Context, expr string, partCols [ return partExpr, deDupCols, offset, nil } -func generateListPartitionExpr(ctx sessionctx.Context, tblInfo *model.TableInfo, expr string, partCols []model.CIStr, +func generateListPartitionExpr(ctx expression.BuildContext, tblInfo *model.TableInfo, expr string, partCols []model.CIStr, defs []model.PartitionDefinition, columns []*expression.Column, names types.NameSlice) (*PartitionExpr, error) { // The caller should assure partition info is not nil. partExpr, exprCols, offset, err := extractPartitionExprColumns(ctx, expr, partCols, columns, names) @@ -853,7 +852,7 @@ func (lp *ForListPruning) Clone() *ForListPruning { return &ret } -func (lp *ForListPruning) buildListPruner(ctx sessionctx.Context, exprStr string, defs []model.PartitionDefinition, exprCols []*expression.Column, +func (lp *ForListPruning) buildListPruner(ctx expression.BuildContext, exprStr string, defs []model.PartitionDefinition, exprCols []*expression.Column, columns []*expression.Column, names types.NameSlice) error { schema := expression.NewSchema(columns...) p := parser.New() @@ -882,7 +881,7 @@ func (lp *ForListPruning) buildListPruner(ctx sessionctx.Context, exprStr string return nil } -func (lp *ForListPruning) buildListColumnsPruner(ctx sessionctx.Context, +func (lp *ForListPruning) buildListColumnsPruner(ctx expression.BuildContext, tblInfo *model.TableInfo, partCols []model.CIStr, defs []model.PartitionDefinition, columns []*expression.Column, names types.NameSlice) error { schema := expression.NewSchema(columns...) @@ -933,7 +932,7 @@ func (lp *ForListPruning) buildListColumnsPruner(ctx sessionctx.Context, // buildListPartitionValueMap builds list partition value map. // The map is column value -> partition index. // colIdx is the column index in the list columns. -func (lp *ForListPruning) buildListPartitionValueMap(ctx sessionctx.Context, defs []model.PartitionDefinition, +func (lp *ForListPruning) buildListPartitionValueMap(ctx expression.BuildContext, defs []model.PartitionDefinition, schema *expression.Schema, names types.NameSlice, p *parser.Parser) error { lp.valueMap = map[int64]int{} lp.nullPartitionIdx = -1 @@ -1183,7 +1182,7 @@ func (lp *ForListColumnPruning) LocateRanges(tc types.Context, ec errctx.Context return locations, nil } -func generateHashPartitionExpr(ctx sessionctx.Context, exprStr string, +func generateHashPartitionExpr(ctx expression.BuildContext, exprStr string, columns []*expression.Column, names types.NameSlice) (*PartitionExpr, error) { // The caller should assure partition info is not nil. schema := expression.NewSchema(columns...) diff --git a/pkg/table/tables/tables.go b/pkg/table/tables/tables.go index fe0b54b14ddc6..c963f3d12cbe4 100644 --- a/pkg/table/tables/tables.go +++ b/pkg/table/tables/tables.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/meta" "github.com/pingcap/tidb/pkg/meta/autoid" @@ -1671,7 +1672,7 @@ func GetColDefaultValue(ctx sessionctx.Context, col *table.Column, defaultVals [ return colVal, errors.New("Miss column") } if defaultVals[col.Offset].IsNull() { - colVal, err = table.GetColOriginDefaultValue(ctx, col.ToInfo()) + colVal, err = table.GetColOriginDefaultValue(ctx.GetExprCtx(), col.ToInfo()) if err != nil { return colVal, err } @@ -2300,7 +2301,7 @@ func BuildPartitionTableScanFromInfos(tableInfo *model.TableInfo, columnInfos [] } // SetPBColumnsDefaultValue sets the default values of tipb.ColumnInfo. -func SetPBColumnsDefaultValue(ctx sessionctx.Context, pbColumns []*tipb.ColumnInfo, columns []*model.ColumnInfo) error { +func SetPBColumnsDefaultValue(ctx expression.BuildContext, pbColumns []*tipb.ColumnInfo, columns []*model.ColumnInfo) error { for i, c := range columns { // For virtual columns, we set their default values to NULL so that TiKV will return NULL properly, // They real values will be computed later. diff --git a/pkg/util/admin/admin.go b/pkg/util/admin/admin.go index 86dd16be58f5e..a971bbd8550ee 100644 --- a/pkg/util/admin/admin.go +++ b/pkg/util/admin/admin.go @@ -161,7 +161,7 @@ func CheckRecordAndIndex(ctx context.Context, sessCtx sessionctx.Context, txn kv return false, errors.Errorf("Column %v define as not null, but can't find the value where handle is %v", col.Name, h1) } // NULL value is regarded as its default value. - colDefVal, err := table.GetColOriginDefaultValue(sessCtx, col.ToInfo()) + colDefVal, err := table.GetColOriginDefaultValue(sessCtx.GetExprCtx(), col.ToInfo()) if err != nil { return false, errors.Trace(err) } @@ -194,7 +194,7 @@ func CheckRecordAndIndex(ctx context.Context, sessCtx sessionctx.Context, txn kv func makeRowDecoder(t table.Table, sctx sessionctx.Context) (*decoder.RowDecoder, error) { dbName := model.NewCIStr(sctx.GetSessionVars().CurrentDB) - exprCols, _, err := expression.ColumnInfos2ColumnsAndNames(sctx, dbName, t.Meta().Name, t.Meta().Cols(), t.Meta()) + exprCols, _, err := expression.ColumnInfos2ColumnsAndNames(sctx.GetExprCtx(), dbName, t.Meta().Name, t.Meta().Cols(), t.Meta()) if err != nil { return nil, err } diff --git a/pkg/util/mock/BUILD.bazel b/pkg/util/mock/BUILD.bazel index 3c436f656e89f..2ef3aa0c823d3 100644 --- a/pkg/util/mock/BUILD.bazel +++ b/pkg/util/mock/BUILD.bazel @@ -12,6 +12,8 @@ go_library( importpath = "github.com/pingcap/tidb/pkg/util/mock", visibility = ["//visibility:public"], deps = [ + "//pkg/expression/context", + "//pkg/expression/contextimpl", "//pkg/extension", "//pkg/infoschema/context", "//pkg/kv", diff --git a/pkg/util/mock/context.go b/pkg/util/mock/context.go index 93828f7e854e7..48ec7ead3bf4d 100644 --- a/pkg/util/mock/context.go +++ b/pkg/util/mock/context.go @@ -22,6 +22,8 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/kvrpcpb" + exprctx "github.com/pingcap/tidb/pkg/expression/context" + exprctximpl "github.com/pingcap/tidb/pkg/expression/contextimpl" "github.com/pingcap/tidb/pkg/extension" infoschema "github.com/pingcap/tidb/pkg/infoschema/context" "github.com/pingcap/tidb/pkg/kv" @@ -55,6 +57,7 @@ var ( // Context represents mocked sessionctx.Context. type Context struct { planctx.EmptyPlanContextExtended + *exprctximpl.ExprCtxExtendedImpl txn wrapTxn // mock global variable Store kv.Storage // mock global variable ctx context.Context @@ -202,6 +205,11 @@ func (c *Context) GetPlanCtx() planctx.PlanContext { return c } +// GetExprCtx returns the expression context of the session. +func (c *Context) GetExprCtx() exprctx.BuildContext { + return c +} + // GetTableCtx returns the table.MutateContext func (c *Context) GetTableCtx() tbctx.MutateContext { return c.tblctx @@ -514,7 +522,8 @@ func NewContext() *Context { } vars := variable.NewSessionVars(sctx) sctx.sessionVars = vars - sctx.tblctx = tbctximpl.NewTableContextImpl(sctx) + sctx.ExprCtxExtendedImpl = exprctximpl.NewExprExtendedImpl(sctx) + sctx.tblctx = tbctximpl.NewTableContextImpl(sctx, sctx) vars.InitChunkSize = 2 vars.MaxChunkSize = 32 vars.TimeZone = time.UTC diff --git a/pkg/util/ranger/bench_test.go b/pkg/util/ranger/bench_test.go index e5f2a3377b0ef..4da75ec75e168 100644 --- a/pkg/util/ranger/bench_test.go +++ b/pkg/util/ranger/bench_test.go @@ -119,7 +119,7 @@ WHERE require.NotNil(b, selection) conds := make([]expression.Expression, len(selection.Conditions)) for i, cond := range selection.Conditions { - conds[i] = expression.PushDownNot(sctx, cond) + conds[i] = expression.PushDownNot(sctx.GetExprCtx(), cond) } cols, lengths := expression.IndexInfo2PrefixCols(tbl.Columns, selection.Schema().Columns, tbl.Indices[0]) require.NotNil(b, cols) diff --git a/pkg/util/ranger/ranger_test.go b/pkg/util/ranger/ranger_test.go index 39206d206b1d3..58dfd8686f824 100644 --- a/pkg/util/ranger/ranger_test.go +++ b/pkg/util/ranger/ranger_test.go @@ -271,7 +271,7 @@ func TestTableRange(t *testing.T) { selection := p.(plannercore.LogicalPlan).Children()[0].(*plannercore.LogicalSelection) conds := make([]expression.Expression, len(selection.Conditions)) for i, cond := range selection.Conditions { - conds[i] = expression.PushDownNot(sctx, cond) + conds[i] = expression.PushDownNot(sctx.GetExprCtx(), cond) } tbl := selection.Children()[0].(*plannercore.DataSource).TableInfo() col := expression.ColInfo2Col(selection.Schema().Columns, tbl.Columns[0]) @@ -471,7 +471,7 @@ create table t( require.NotNil(t, selection) conds := make([]expression.Expression, len(selection.Conditions)) for i, cond := range selection.Conditions { - conds[i] = expression.PushDownNot(sctx, cond) + conds[i] = expression.PushDownNot(sctx.GetExprCtx(), cond) } cols, lengths := expression.IndexInfo2PrefixCols(tbl.Columns, selection.Schema().Columns, tbl.Indices[tt.indexPos]) require.NotNil(t, cols) @@ -833,7 +833,7 @@ func TestColumnRange(t *testing.T) { require.True(t, ok) conds := make([]expression.Expression, len(sel.Conditions)) for i, cond := range sel.Conditions { - conds[i] = expression.PushDownNot(sctx, cond) + conds[i] = expression.PushDownNot(sctx.GetExprCtx(), cond) } col := expression.ColInfo2Col(sel.Schema().Columns, ds.TableInfo().Columns[tt.colPos]) require.NotNil(t, col) @@ -991,7 +991,7 @@ func TestIndexRangeForYear(t *testing.T) { require.NotNil(t, selection) conds := make([]expression.Expression, len(selection.Conditions)) for i, cond := range selection.Conditions { - conds[i] = expression.PushDownNot(sctx, cond) + conds[i] = expression.PushDownNot(sctx.GetExprCtx(), cond) } cols, lengths := expression.IndexInfo2PrefixCols(tbl.Columns, selection.Schema().Columns, tbl.Indices[tt.indexPos]) require.NotNil(t, cols) @@ -1060,7 +1060,7 @@ func TestPrefixIndexRangeScan(t *testing.T) { require.NotNil(t, selection) conds := make([]expression.Expression, len(selection.Conditions)) for i, cond := range selection.Conditions { - conds[i] = expression.PushDownNot(sctx, cond) + conds[i] = expression.PushDownNot(sctx.GetExprCtx(), cond) } cols, lengths := expression.IndexInfo2PrefixCols(tbl.Columns, selection.Schema().Columns, tbl.Indices[tt.indexPos]) require.NotNil(t, cols) @@ -1407,7 +1407,7 @@ create table t( require.NotNil(t, selection) conds := make([]expression.Expression, len(selection.Conditions)) for i, cond := range selection.Conditions { - conds[i] = expression.PushDownNot(sctx, cond) + conds[i] = expression.PushDownNot(sctx.GetExprCtx(), cond) } cols, lengths := expression.IndexInfo2PrefixCols(tbl.Columns, selection.Schema().Columns, tbl.Indices[tt.indexPos]) require.NotNil(t, cols) @@ -1646,7 +1646,7 @@ func TestTableShardIndex(t *testing.T) { selection := p.(plannercore.LogicalPlan).Children()[0].(*plannercore.LogicalSelection) conds := make([]expression.Expression, len(selection.Conditions)) for i, cond := range selection.Conditions { - conds[i] = expression.PushDownNot(sctx, cond) + conds[i] = expression.PushDownNot(sctx.GetExprCtx(), cond) } ds, ok := selection.Children()[0].(*plannercore.DataSource) if !ok { @@ -1709,10 +1709,10 @@ func TestShardIndexFuncSuites(t *testing.T) { col1 := &expression.Column{UniqueID: 1, ID: 1, RetType: longlongType} // col2 is GC column and VirtualExpr = tidb_shard(col0) col2 := &expression.Column{UniqueID: 2, ID: 2, RetType: longlongType} - col2.VirtualExpr = expression.NewFunctionInternal(sctx, ast.TiDBShard, col2.RetType, col0) + col2.VirtualExpr = expression.NewFunctionInternal(sctx.GetExprCtx(), ast.TiDBShard, col2.RetType, col0) // col3 is GC column and VirtualExpr = abs(col0) col3 := &expression.Column{UniqueID: 3, ID: 3, RetType: longlongType} - col3.VirtualExpr = expression.NewFunctionInternal(sctx, ast.Abs, col2.RetType, col0) + col3.VirtualExpr = expression.NewFunctionInternal(sctx.GetExprCtx(), ast.Abs, col2.RetType, col0) col4 := &expression.Column{UniqueID: 4, ID: 4, RetType: longlongType} cols := []*expression.Column{col0, col1} @@ -1736,8 +1736,8 @@ func TestShardIndexFuncSuites(t *testing.T) { // normal case con1 := &expression.Constant{Value: types.NewDatum(1), RetType: longlongType} con5 := &expression.Constant{Value: types.NewDatum(5), RetType: longlongType} - exprEq := expression.NewFunctionInternal(sctx, ast.EQ, col0.RetType, col0, con1) - exprIn := expression.NewFunctionInternal(sctx, ast.In, col0.RetType, col0, con1, con5) + exprEq := expression.NewFunctionInternal(sctx.GetExprCtx(), ast.EQ, col0.RetType, col0, con1) + exprIn := expression.NewFunctionInternal(sctx.GetExprCtx(), ast.In, col0.RetType, col0, con1, con5) require.NotNil(t, exprEq) require.NotNil(t, exprIn) // input is nil @@ -1745,13 +1745,13 @@ func TestShardIndexFuncSuites(t *testing.T) { // input is column require.Equal(t, len(ranger.ExtractColumnsFromExpr(exprEq.(*expression.ScalarFunction))), 1) // (col0 = 1 and col3 > 1) or (col4 < 5 and 5) - exprGt := expression.NewFunctionInternal(sctx, ast.GT, longlongType, col3, con1) + exprGt := expression.NewFunctionInternal(sctx.GetExprCtx(), ast.GT, longlongType, col3, con1) require.NotNil(t, exprGt) - andExpr1 := expression.NewFunctionInternal(sctx, ast.And, longlongType, exprEq, exprGt) + andExpr1 := expression.NewFunctionInternal(sctx.GetExprCtx(), ast.And, longlongType, exprEq, exprGt) require.NotNil(t, andExpr1) - exprLt := expression.NewFunctionInternal(sctx, ast.LT, longlongType, col4, con5) - andExpr2 := expression.NewFunctionInternal(sctx, ast.And, longlongType, exprLt, con5) - orExpr2 := expression.NewFunctionInternal(sctx, ast.Or, longlongType, andExpr1, andExpr2) + exprLt := expression.NewFunctionInternal(sctx.GetExprCtx(), ast.LT, longlongType, col4, con5) + andExpr2 := expression.NewFunctionInternal(sctx.GetExprCtx(), ast.And, longlongType, exprLt, con5) + orExpr2 := expression.NewFunctionInternal(sctx.GetExprCtx(), ast.Or, longlongType, andExpr1, andExpr2) require.Equal(t, len(ranger.ExtractColumnsFromExpr(orExpr2.(*expression.ScalarFunction))), 3) // ------------------------------------------- @@ -1770,12 +1770,12 @@ func TestShardIndexFuncSuites(t *testing.T) { require.False(t, ranger.NeedAddColumn4InCond(shardIndexCols, accessCond, nil)) // col1 in (1, 5) - exprIn2 := expression.NewFunctionInternal(sctx, ast.In, col1.RetType, col1, con1, con5) + exprIn2 := expression.NewFunctionInternal(sctx.GetExprCtx(), ast.In, col1.RetType, col1, con1, con5) accessCond[1] = exprIn2 require.False(t, ranger.NeedAddColumn4InCond(shardIndexCols, accessCond, exprIn2.(*expression.ScalarFunction))) // col0 in (1, col1) - exprIn3 := expression.NewFunctionInternal(sctx, ast.In, col0.RetType, col1, con1, col1) + exprIn3 := expression.NewFunctionInternal(sctx.GetExprCtx(), ast.In, col0.RetType, col1, con1, col1) accessCond[1] = exprIn3 require.False(t, ranger.NeedAddColumn4InCond(shardIndexCols, accessCond, exprIn3.(*expression.ScalarFunction))) @@ -1795,7 +1795,7 @@ func TestShardIndexFuncSuites(t *testing.T) { // ------------------------------------------- // test AddExpr4EqAndInCondition function // ------------------------------------------- - exprIn4 := expression.NewFunctionInternal(sctx, ast.In, col0.RetType, col0, con1) + exprIn4 := expression.NewFunctionInternal(sctx.GetExprCtx(), ast.In, col0.RetType, col0, con1) test := []struct { inputConds []expression.Expression outputConds string @@ -2278,7 +2278,7 @@ create table t( require.NotNil(t, selection, fmt.Sprintf("expr:%v", tt.exprStr)) conds := make([]expression.Expression, len(selection.Conditions)) for i, cond := range selection.Conditions { - conds[i] = expression.PushDownNot(sctx, cond) + conds[i] = expression.PushDownNot(sctx.GetExprCtx(), cond) } cols, lengths := expression.IndexInfo2PrefixCols(tbl.Columns, selection.Schema().Columns, tbl.Indices[tt.indexPos]) require.NotNil(t, cols) diff --git a/pkg/util/rowDecoder/decoder.go b/pkg/util/rowDecoder/decoder.go index 6c84f5afaeef9..7b19814990507 100644 --- a/pkg/util/rowDecoder/decoder.go +++ b/pkg/util/rowDecoder/decoder.go @@ -188,7 +188,7 @@ func (rd *RowDecoder) EvalRemainedExprColumnMap(ctx sessionctx.Context, row map[ continue } // Eval the column value - val, err := col.GenExpr.Eval(ctx, rd.mutRow.ToRow()) + val, err := col.GenExpr.Eval(ctx.GetExprCtx(), rd.mutRow.ToRow()) if err != nil { return nil, err }