Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

experssion: table: standalone implement for EvalContext and BuildContext #51299

Merged
merged 3 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions br/pkg/lightning/backend/kv/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,15 @@ go_library(
"//br/pkg/utils",
"//pkg/errctx",
"//pkg/expression",
"//pkg/expression/context",
"//pkg/expression/contextimpl",
"//pkg/infoschema/context",
"//pkg/kv",
"//pkg/meta/autoid",
"//pkg/parser/model",
"//pkg/parser/mysql",
"//pkg/planner/context",
"//pkg/planner/contextimpl",
"//pkg/sessionctx",
"//pkg/sessionctx/variable",
"//pkg/table",
Expand Down
4 changes: 2 additions & 2 deletions br/pkg/lightning/backend/kv/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ func (e *BaseKVEncoder) getActualDatum(col *table.Column, rowID int64, inputDatu
e.SessionCtx.Vars.TxnCtx = nil
}()
}
value, err = table.GetColDefaultValue(e.SessionCtx, col.ToInfo())
value, err = table.GetColDefaultValue(e.SessionCtx.GetExprCtx(), col.ToInfo())
}
return value, err
}
Expand Down Expand Up @@ -353,7 +353,7 @@ func evalGeneratedColumns(se *Session, record []types.Datum, cols []*table.Colum
mutRow := chunk.MutRowFromDatums(record)
for _, gc := range genCols {
col := cols[gc.Index].ToInfo()
evaluated, err := gc.Expr.Eval(se, mutRow.ToRow())
evaluated, err := gc.Expr.Eval(se.GetExprCtx(), mutRow.ToRow())
if err != nil {
return col, err
}
Expand Down
28 changes: 23 additions & 5 deletions br/pkg/lightning/backend/kv/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,13 @@ import (
"github.com/pingcap/tidb/br/pkg/lightning/manual"
"github.com/pingcap/tidb/br/pkg/utils"
"github.com/pingcap/tidb/pkg/errctx"
exprctx "github.com/pingcap/tidb/pkg/expression/context"
exprctximpl "github.com/pingcap/tidb/pkg/expression/contextimpl"
infoschema "github.com/pingcap/tidb/pkg/infoschema/context"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/parser/model"
planctx "github.com/pingcap/tidb/pkg/planner/context"
planctximpl "github.com/pingcap/tidb/pkg/planner/contextimpl"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
tbctx "github.com/pingcap/tidb/pkg/table/context"
Expand Down Expand Up @@ -264,15 +267,21 @@ func (*transaction) SetAssertion(_ []byte, _ ...kv.FlagsOp) error {
return nil
}

type planCtxImpl struct {
*Session
*planctximpl.PlanCtxExtendedImpl
}

// Session is a trimmed down Session type which only wraps our own trimmed-down
// transaction type and provides the session variables to the TiDB library
// optimized for Lightning.
type Session struct {
sessionctx.Context
planctx.EmptyPlanContextExtended
txn transaction
Vars *variable.SessionVars
tblctx *tbctximpl.TableContextImpl
txn transaction
Vars *variable.SessionVars
planctx *planCtxImpl
tblctx *tbctximpl.TableContextImpl
// currently, we only set `CommonAddRecordCtx`
values map[fmt.Stringer]any
}
Expand Down Expand Up @@ -333,7 +342,11 @@ func NewSession(options *encode.SessionOptions, logger log.Logger) *Session {
}
vars.TxnCtx = nil
s.Vars = vars
s.tblctx = tbctximpl.NewTableContextImpl(s)
s.planctx = &planCtxImpl{
Session: s,
PlanCtxExtendedImpl: planctximpl.NewPlanCtxExtendedImpl(s, exprctximpl.NewExprExtendedImpl(s)),
}
s.tblctx = tbctximpl.NewTableContextImpl(s, s.planctx)
s.txn.kvPairs = &Pairs{}

return s
Expand Down Expand Up @@ -363,7 +376,12 @@ func (se *Session) GetSessionVars() *variable.SessionVars {

// GetPlanCtx returns the PlanContext.
func (se *Session) GetPlanCtx() planctx.PlanContext {
return se
return se.planctx
}

// GetExprCtx returns the expression context of the session.
func (se *Session) GetExprCtx() exprctx.BuildContext {
return se.planctx
}

// GetTableCtx returns the table.MutateContext
Expand Down
2 changes: 1 addition & 1 deletion br/pkg/lightning/backend/kv/sql2kv.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func CollectGeneratedColumns(se *Session, meta *model.TableInfo, cols []*table.C
for i, col := range cols {
if col.GeneratedExpr != nil {
expr, err := expression.BuildSimpleExpr(
se,
se.GetExprCtx(),
col.GeneratedExpr.Internal(),
expression.WithInputSchemaAndNames(schema, names, meta),
expression.WithAllowCastArray(true),
Expand Down
2 changes: 1 addition & 1 deletion pkg/ddl/backfilling.go
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ func makeupDecodeColMap(sessCtx sessionctx.Context, dbName model.CIStr, t table.
for _, col := range t.WritableCols() {
writableColInfos = append(writableColInfos, col.ColumnInfo)
}
exprCols, _, err := expression.ColumnInfos2ColumnsAndNames(sessCtx, dbName, t.Meta().Name, writableColInfos, t.Meta())
exprCols, _, err := expression.ColumnInfos2ColumnsAndNames(sessCtx.GetExprCtx(), dbName, t.Meta().Name, writableColInfos, t.Meta())
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/ddl/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -1346,7 +1346,7 @@ func (w *updateColumnWorker) getRowRecord(handle kv.Handle, recordKey []byte, ra
val := w.rowMap[w.oldColInfo.ID]
col := w.newColInfo
if val.Kind() == types.KindNull && col.FieldType.GetType() == mysql.TypeTimestamp && mysql.HasNotNullFlag(col.GetFlag()) {
if v, err := expression.GetTimeCurrentTimestamp(w.sessCtx, col.GetType(), col.GetDecimal()); err == nil {
if v, err := expression.GetTimeCurrentTimestamp(w.sessCtx.GetExprCtx(), col.GetType(), col.GetDecimal()); err == nil {
// convert null value to timestamp should be substituted with current timestamp if NOT_NULL flag is set.
w.rowMap[w.oldColInfo.ID] = v
}
Expand Down Expand Up @@ -1971,7 +1971,7 @@ func generateOriginDefaultValue(col *model.ColumnInfo, ctx sessionctx.Context) (
if ctx == nil {
t = time.Now()
} else {
t, _ = expression.GetStmtTimestamp(ctx)
t, _ = expression.GetStmtTimestamp(ctx.GetExprCtx())
}
if col.GetType() == mysql.TypeTimestamp {
odValue = types.NewTime(types.FromGoTime(t.UTC()), col.GetType(), col.GetDecimal()).String()
Expand Down
2 changes: 1 addition & 1 deletion pkg/ddl/copr/copr_ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func NewCopContextBase(
handleIDs = []int64{extra.ID}
}

expColInfos, _, err := expression.ColumnInfos2ColumnsAndNames(sessCtx,
expColInfos, _, err := expression.ColumnInfos2ColumnsAndNames(sessCtx.GetExprCtx(),
model.CIStr{} /* unused */, tblInfo.Name, colInfos, tblInfo)
if err != nil {
return nil, err
Expand Down
40 changes: 20 additions & 20 deletions pkg/ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1032,7 +1032,7 @@ func checkColumnDefaultValue(ctx sessionctx.Context, col *table.Column, value an
if value != nil && ctx.GetSessionVars().SQLMode.HasNoZeroDateMode() &&
ctx.GetSessionVars().SQLMode.HasStrictMode() && types.IsTypeTime(col.GetType()) {
if vv, ok := value.(string); ok {
timeValue, err := expression.GetTimeValue(ctx, vv, col.GetType(), col.GetDecimal(), nil)
timeValue, err := expression.GetTimeValue(ctx.GetExprCtx(), vv, col.GetType(), col.GetDecimal(), nil)
if err != nil {
return hasDefaultValue, value, errors.Trace(err)
}
Expand Down Expand Up @@ -1385,7 +1385,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu
}

if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime || tp == mysql.TypeDate {
vd, err := expression.GetTimeValue(ctx, option.Expr, tp, fsp, nil)
vd, err := expression.GetTimeValue(ctx.GetExprCtx(), option.Expr, tp, fsp, nil)
value := vd.GetValue()
if err != nil {
return nil, false, dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
Expand All @@ -1405,7 +1405,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu
}

// evaluate the non-function-call expr to a certain value.
v, err := expression.EvalSimpleAst(ctx, option.Expr)
v, err := expression.EvalSimpleAst(ctx.GetExprCtx(), option.Expr)
if err != nil {
return nil, false, errors.Trace(err)
}
Expand Down Expand Up @@ -1614,7 +1614,7 @@ func checkDefaultValue(ctx sessionctx.Context, c *table.Column, hasDefaultValue
if c.DefaultIsExpr {
return nil
}
if _, err := table.GetColDefaultValue(ctx, c.ToInfo()); err != nil {
if _, err := table.GetColDefaultValue(ctx.GetExprCtx(), c.ToInfo()); err != nil {
return types.ErrInvalidDefault.GenWithStackByArgs(c.Name)
}
return nil
Expand Down Expand Up @@ -3212,7 +3212,7 @@ func checkPartitionByRange(ctx sessionctx.Context, tbInfo *model.TableInfo) erro

// checkPartitionByList checks validity of a "BY LIST" partition.
func checkPartitionByList(ctx sessionctx.Context, tbInfo *model.TableInfo) error {
return checkListPartitionValue(ctx, tbInfo)
return checkListPartitionValue(ctx.GetExprCtx(), tbInfo)
}

func isValidKeyPartitionColType(fieldType types.FieldType) bool {
Expand Down Expand Up @@ -3305,7 +3305,7 @@ func checkTwoRangeColumns(ctx sessionctx.Context, curr, prev *model.PartitionDef
// PARTITION p1 VALUES LESS THAN (10,20,'mmm')
// PARTITION p2 VALUES LESS THAN (15,30,'sss')
colInfo := findColumnByName(pi.Columns[i].L, tbInfo)
succ, err := parseAndEvalBoolExpr(ctx, curr.LessThan[i], prev.LessThan[i], colInfo, tbInfo)
succ, err := parseAndEvalBoolExpr(ctx.GetExprCtx(), curr.LessThan[i], prev.LessThan[i], colInfo, tbInfo)
if err != nil {
return false, err
}
Expand All @@ -3317,7 +3317,7 @@ func checkTwoRangeColumns(ctx sessionctx.Context, curr, prev *model.PartitionDef
return false, nil
}

func parseAndEvalBoolExpr(ctx sessionctx.Context, l, r string, colInfo *model.ColumnInfo, tbInfo *model.TableInfo) (bool, error) {
func parseAndEvalBoolExpr(ctx expression.BuildContext, l, r string, colInfo *model.ColumnInfo, tbInfo *model.TableInfo) (bool, error) {
lexpr, err := expression.ParseSimpleExpr(ctx, l, expression.WithTableInfo("", tbInfo), expression.WithCastExprTo(&colInfo.FieldType))
if err != nil {
return false, err
Expand Down Expand Up @@ -4366,7 +4366,7 @@ func (d *ddl) AddTablePartitions(ctx sessionctx.Context, ident ast.Ident, spec *
return d.hashPartitionManagement(ctx, ident, spec, pi)
}

partInfo, err := BuildAddedPartitionInfo(ctx, meta, spec)
partInfo, err := BuildAddedPartitionInfo(ctx.GetExprCtx(), meta, spec)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -4612,7 +4612,7 @@ func (d *ddl) ReorganizePartitions(ctx sessionctx.Context, ident ast.Ident, spec
if err != nil {
return errors.Trace(err)
}
partInfo, err := BuildAddedPartitionInfo(ctx, meta, spec)
partInfo, err := BuildAddedPartitionInfo(ctx.GetExprCtx(), meta, spec)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -4678,7 +4678,7 @@ func (d *ddl) RemovePartitioning(ctx sessionctx.Context, ident ast.Ident, spec *
partNames[i] = pi.Definitions[i].Name.L
}
meta.Partition.Type = model.PartitionTypeNone
partInfo, err := BuildAddedPartitionInfo(ctx, meta, newSpec)
partInfo, err := BuildAddedPartitionInfo(ctx.GetExprCtx(), meta, newSpec)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -4756,11 +4756,11 @@ func checkReorgPartitionDefs(ctx sessionctx.Context, action model.ActionType, tb
}

isUnsigned := isPartExprUnsigned(tblInfo)
currentRangeValue, _, err := getRangeValue(ctx, pi.Definitions[lastPartIdx].LessThan[0], isUnsigned)
currentRangeValue, _, err := getRangeValue(ctx.GetExprCtx(), pi.Definitions[lastPartIdx].LessThan[0], isUnsigned)
if err != nil {
return errors.Trace(err)
}
newRangeValue, _, err := getRangeValue(ctx, partInfo.Definitions[len(partInfo.Definitions)-1].LessThan[0], isUnsigned)
newRangeValue, _, err := getRangeValue(ctx.GetExprCtx(), partInfo.Definitions[len(partInfo.Definitions)-1].LessThan[0], isUnsigned)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -4935,7 +4935,7 @@ func (d *ddl) DropTablePartition(ctx sessionctx.Context, ident ast.Ident, spec *
}

if spec.Tp == ast.AlterTableDropFirstPartition {
intervalOptions := getPartitionIntervalFromTable(ctx, meta)
intervalOptions := getPartitionIntervalFromTable(ctx.GetExprCtx(), meta)
if intervalOptions == nil {
return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs(
"FIRST PARTITION, does not seem like an INTERVAL partitioned table")
Expand All @@ -4945,7 +4945,7 @@ func (d *ddl) DropTablePartition(ctx sessionctx.Context, ident ast.Ident, spec *
"FIRST PARTITION, table info already contains partition definitions")
}
spec.Partition.Interval = intervalOptions
err = GeneratePartDefsFromInterval(ctx, spec.Tp, meta, spec.Partition)
err = GeneratePartDefsFromInterval(ctx.GetExprCtx(), spec.Tp, meta, spec.Partition)
if err != nil {
return err
}
Expand Down Expand Up @@ -5447,7 +5447,7 @@ func setDefaultValueWithBinaryPadding(col *table.Column, value any) error {
}

func setColumnComment(ctx sessionctx.Context, col *table.Column, option *ast.ColumnOption) error {
value, err := expression.EvalSimpleAst(ctx, option.Expr)
value, err := expression.EvalSimpleAst(ctx.GetExprCtx(), option.Expr)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -5808,7 +5808,7 @@ func GetModifiableColumnJob(
oldTypeFlags := sv.TypeFlags()
newTypeFlags := oldTypeFlags.WithTruncateAsWarning(false).WithIgnoreTruncateErr(false)
sv.SetTypeFlags(newTypeFlags)
_, err = buildPartitionDefinitionsInfo(sctx, pAst.Definitions, &newTblInfo, uint64(len(newTblInfo.Partition.Definitions)))
_, err = buildPartitionDefinitionsInfo(sctx.GetExprCtx(), pAst.Definitions, &newTblInfo, uint64(len(newTblInfo.Partition.Definitions)))
sv.SetTypeFlags(oldTypeFlags)
if err != nil {
return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStack("New column does not match partition definitions: %s", err.Error())
Expand Down Expand Up @@ -7373,7 +7373,7 @@ func BuildHiddenColumnInfo(ctx sessionctx.Context, indexPartSpecifications []*as
if err != nil {
return nil, errors.Trace(err)
}
expr, err := expression.BuildSimpleExpr(ctx, idxPart.Expr,
expr, err := expression.BuildSimpleExpr(ctx.GetExprCtx(), idxPart.Expr,
expression.WithTableInfo(ctx.GetSessionVars().CurrentDB, tblInfo),
expression.WithAllowCastArray(true),
)
Expand Down Expand Up @@ -7968,7 +7968,7 @@ func validateCommentLength(vars *variable.SessionVars, name string, comment *str
}

// BuildAddedPartitionInfo build alter table add partition info
func BuildAddedPartitionInfo(ctx sessionctx.Context, meta *model.TableInfo, spec *ast.AlterTableSpec) (*model.PartitionInfo, error) {
func BuildAddedPartitionInfo(ctx expression.BuildContext, meta *model.TableInfo, spec *ast.AlterTableSpec) (*model.PartitionInfo, error) {
numParts := uint64(0)
switch meta.Partition.Type {
case model.PartitionTypeNone:
Expand Down Expand Up @@ -8034,7 +8034,7 @@ func BuildAddedPartitionInfo(ctx sessionctx.Context, meta *model.TableInfo, spec
return part, nil
}

func buildAddedPartitionDefs(ctx sessionctx.Context, meta *model.TableInfo, spec *ast.AlterTableSpec) error {
func buildAddedPartitionDefs(ctx expression.BuildContext, meta *model.TableInfo, spec *ast.AlterTableSpec) error {
partInterval := getPartitionIntervalFromTable(ctx, meta)
if partInterval == nil {
return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs(
Expand All @@ -8052,7 +8052,7 @@ func buildAddedPartitionDefs(ctx sessionctx.Context, meta *model.TableInfo, spec
return GeneratePartDefsFromInterval(ctx, spec.Tp, meta, spec.Partition)
}

func checkAndGetColumnsTypeAndValuesMatch(ctx sessionctx.Context, colTypes []types.FieldType, exprs []ast.ExprNode) ([]string, error) {
func checkAndGetColumnsTypeAndValuesMatch(ctx expression.BuildContext, colTypes []types.FieldType, exprs []ast.ExprNode) ([]string, error) {
// Validate() has already checked len(colNames) = len(exprs)
// create table ... partition by range columns (cols)
// partition p0 values less than (expr)
Expand Down
2 changes: 1 addition & 1 deletion pkg/ddl/index_cop.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ func buildDAGPB(sCtx sessionctx.Context, tblInfo *model.TableInfo, colInfos []*m
func constructTableScanPB(sCtx sessionctx.Context, tblInfo *model.TableInfo, colInfos []*model.ColumnInfo) (*tipb.Executor, error) {
tblScan := tables.BuildTableScanFromInfos(tblInfo, colInfos)
tblScan.TableId = tblInfo.ID
err := tables.SetPBColumnsDefaultValue(sCtx, tblScan.Columns, colInfos)
err := tables.SetPBColumnsDefaultValue(sCtx.GetExprCtx(), tblScan.Columns, colInfos)
return &tipb.Executor{Tp: tipb.ExecType_TypeTableScan, TblScan: tblScan}, err
}

Expand Down
Loading
Loading