Skip to content

Commit

Permalink
experssion: table: standalone implement for EvalContext and `BuildC…
Browse files Browse the repository at this point in the history
…ontext` (#51299)

close #51298
  • Loading branch information
lcwangchao authored Feb 27, 2024
1 parent 4c88ad1 commit 6f02e99
Show file tree
Hide file tree
Showing 98 changed files with 507 additions and 365 deletions.
3 changes: 3 additions & 0 deletions br/pkg/lightning/backend/kv/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,15 @@ go_library(
"//br/pkg/utils",
"//pkg/errctx",
"//pkg/expression",
"//pkg/expression/context",
"//pkg/expression/contextimpl",
"//pkg/infoschema/context",
"//pkg/kv",
"//pkg/meta/autoid",
"//pkg/parser/model",
"//pkg/parser/mysql",
"//pkg/planner/context",
"//pkg/planner/contextimpl",
"//pkg/sessionctx",
"//pkg/sessionctx/variable",
"//pkg/table",
Expand Down
4 changes: 2 additions & 2 deletions br/pkg/lightning/backend/kv/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ func (e *BaseKVEncoder) getActualDatum(col *table.Column, rowID int64, inputDatu
e.SessionCtx.Vars.TxnCtx = nil
}()
}
value, err = table.GetColDefaultValue(e.SessionCtx, col.ToInfo())
value, err = table.GetColDefaultValue(e.SessionCtx.GetExprCtx(), col.ToInfo())
}
return value, err
}
Expand Down Expand Up @@ -353,7 +353,7 @@ func evalGeneratedColumns(se *Session, record []types.Datum, cols []*table.Colum
mutRow := chunk.MutRowFromDatums(record)
for _, gc := range genCols {
col := cols[gc.Index].ToInfo()
evaluated, err := gc.Expr.Eval(se, mutRow.ToRow())
evaluated, err := gc.Expr.Eval(se.GetExprCtx(), mutRow.ToRow())
if err != nil {
return col, err
}
Expand Down
28 changes: 23 additions & 5 deletions br/pkg/lightning/backend/kv/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,13 @@ import (
"github.com/pingcap/tidb/br/pkg/lightning/manual"
"github.com/pingcap/tidb/br/pkg/utils"
"github.com/pingcap/tidb/pkg/errctx"
exprctx "github.com/pingcap/tidb/pkg/expression/context"
exprctximpl "github.com/pingcap/tidb/pkg/expression/contextimpl"
infoschema "github.com/pingcap/tidb/pkg/infoschema/context"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/parser/model"
planctx "github.com/pingcap/tidb/pkg/planner/context"
planctximpl "github.com/pingcap/tidb/pkg/planner/contextimpl"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
tbctx "github.com/pingcap/tidb/pkg/table/context"
Expand Down Expand Up @@ -264,15 +267,21 @@ func (*transaction) SetAssertion(_ []byte, _ ...kv.FlagsOp) error {
return nil
}

type planCtxImpl struct {
*Session
*planctximpl.PlanCtxExtendedImpl
}

// Session is a trimmed down Session type which only wraps our own trimmed-down
// transaction type and provides the session variables to the TiDB library
// optimized for Lightning.
type Session struct {
sessionctx.Context
planctx.EmptyPlanContextExtended
txn transaction
Vars *variable.SessionVars
tblctx *tbctximpl.TableContextImpl
txn transaction
Vars *variable.SessionVars
planctx *planCtxImpl
tblctx *tbctximpl.TableContextImpl
// currently, we only set `CommonAddRecordCtx`
values map[fmt.Stringer]any
}
Expand Down Expand Up @@ -333,7 +342,11 @@ func NewSession(options *encode.SessionOptions, logger log.Logger) *Session {
}
vars.TxnCtx = nil
s.Vars = vars
s.tblctx = tbctximpl.NewTableContextImpl(s)
s.planctx = &planCtxImpl{
Session: s,
PlanCtxExtendedImpl: planctximpl.NewPlanCtxExtendedImpl(s, exprctximpl.NewExprExtendedImpl(s)),
}
s.tblctx = tbctximpl.NewTableContextImpl(s, s.planctx)
s.txn.kvPairs = &Pairs{}

return s
Expand Down Expand Up @@ -363,7 +376,12 @@ func (se *Session) GetSessionVars() *variable.SessionVars {

// GetPlanCtx returns the PlanContext.
func (se *Session) GetPlanCtx() planctx.PlanContext {
return se
return se.planctx
}

// GetExprCtx returns the expression context of the session.
func (se *Session) GetExprCtx() exprctx.BuildContext {
return se.planctx
}

// GetTableCtx returns the table.MutateContext
Expand Down
2 changes: 1 addition & 1 deletion br/pkg/lightning/backend/kv/sql2kv.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func CollectGeneratedColumns(se *Session, meta *model.TableInfo, cols []*table.C
for i, col := range cols {
if col.GeneratedExpr != nil {
expr, err := expression.BuildSimpleExpr(
se,
se.GetExprCtx(),
col.GeneratedExpr.Internal(),
expression.WithInputSchemaAndNames(schema, names, meta),
expression.WithAllowCastArray(true),
Expand Down
2 changes: 1 addition & 1 deletion pkg/ddl/backfilling.go
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ func makeupDecodeColMap(sessCtx sessionctx.Context, dbName model.CIStr, t table.
for _, col := range t.WritableCols() {
writableColInfos = append(writableColInfos, col.ColumnInfo)
}
exprCols, _, err := expression.ColumnInfos2ColumnsAndNames(sessCtx, dbName, t.Meta().Name, writableColInfos, t.Meta())
exprCols, _, err := expression.ColumnInfos2ColumnsAndNames(sessCtx.GetExprCtx(), dbName, t.Meta().Name, writableColInfos, t.Meta())
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/ddl/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -1346,7 +1346,7 @@ func (w *updateColumnWorker) getRowRecord(handle kv.Handle, recordKey []byte, ra
val := w.rowMap[w.oldColInfo.ID]
col := w.newColInfo
if val.Kind() == types.KindNull && col.FieldType.GetType() == mysql.TypeTimestamp && mysql.HasNotNullFlag(col.GetFlag()) {
if v, err := expression.GetTimeCurrentTimestamp(w.sessCtx, col.GetType(), col.GetDecimal()); err == nil {
if v, err := expression.GetTimeCurrentTimestamp(w.sessCtx.GetExprCtx(), col.GetType(), col.GetDecimal()); err == nil {
// convert null value to timestamp should be substituted with current timestamp if NOT_NULL flag is set.
w.rowMap[w.oldColInfo.ID] = v
}
Expand Down Expand Up @@ -1971,7 +1971,7 @@ func generateOriginDefaultValue(col *model.ColumnInfo, ctx sessionctx.Context) (
if ctx == nil {
t = time.Now()
} else {
t, _ = expression.GetStmtTimestamp(ctx)
t, _ = expression.GetStmtTimestamp(ctx.GetExprCtx())
}
if col.GetType() == mysql.TypeTimestamp {
odValue = types.NewTime(types.FromGoTime(t.UTC()), col.GetType(), col.GetDecimal()).String()
Expand Down
2 changes: 1 addition & 1 deletion pkg/ddl/copr/copr_ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func NewCopContextBase(
handleIDs = []int64{extra.ID}
}

expColInfos, _, err := expression.ColumnInfos2ColumnsAndNames(sessCtx,
expColInfos, _, err := expression.ColumnInfos2ColumnsAndNames(sessCtx.GetExprCtx(),
model.CIStr{} /* unused */, tblInfo.Name, colInfos, tblInfo)
if err != nil {
return nil, err
Expand Down
40 changes: 20 additions & 20 deletions pkg/ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1032,7 +1032,7 @@ func checkColumnDefaultValue(ctx sessionctx.Context, col *table.Column, value an
if value != nil && ctx.GetSessionVars().SQLMode.HasNoZeroDateMode() &&
ctx.GetSessionVars().SQLMode.HasStrictMode() && types.IsTypeTime(col.GetType()) {
if vv, ok := value.(string); ok {
timeValue, err := expression.GetTimeValue(ctx, vv, col.GetType(), col.GetDecimal(), nil)
timeValue, err := expression.GetTimeValue(ctx.GetExprCtx(), vv, col.GetType(), col.GetDecimal(), nil)
if err != nil {
return hasDefaultValue, value, errors.Trace(err)
}
Expand Down Expand Up @@ -1425,7 +1425,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu
}

if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime || tp == mysql.TypeDate {
vd, err := expression.GetTimeValue(ctx, option.Expr, tp, fsp, nil)
vd, err := expression.GetTimeValue(ctx.GetExprCtx(), option.Expr, tp, fsp, nil)
value := vd.GetValue()
if err != nil {
return nil, false, dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
Expand All @@ -1445,7 +1445,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu
}

// evaluate the non-function-call expr to a certain value.
v, err := expression.EvalSimpleAst(ctx, option.Expr)
v, err := expression.EvalSimpleAst(ctx.GetExprCtx(), option.Expr)
if err != nil {
return nil, false, errors.Trace(err)
}
Expand Down Expand Up @@ -1654,7 +1654,7 @@ func checkDefaultValue(ctx sessionctx.Context, c *table.Column, hasDefaultValue
if c.DefaultIsExpr {
return nil
}
if _, err := table.GetColDefaultValue(ctx, c.ToInfo()); err != nil {
if _, err := table.GetColDefaultValue(ctx.GetExprCtx(), c.ToInfo()); err != nil {
return types.ErrInvalidDefault.GenWithStackByArgs(c.Name)
}
return nil
Expand Down Expand Up @@ -3252,7 +3252,7 @@ func checkPartitionByRange(ctx sessionctx.Context, tbInfo *model.TableInfo) erro

// checkPartitionByList checks validity of a "BY LIST" partition.
func checkPartitionByList(ctx sessionctx.Context, tbInfo *model.TableInfo) error {
return checkListPartitionValue(ctx, tbInfo)
return checkListPartitionValue(ctx.GetExprCtx(), tbInfo)
}

func isValidKeyPartitionColType(fieldType types.FieldType) bool {
Expand Down Expand Up @@ -3345,7 +3345,7 @@ func checkTwoRangeColumns(ctx sessionctx.Context, curr, prev *model.PartitionDef
// PARTITION p1 VALUES LESS THAN (10,20,'mmm')
// PARTITION p2 VALUES LESS THAN (15,30,'sss')
colInfo := findColumnByName(pi.Columns[i].L, tbInfo)
succ, err := parseAndEvalBoolExpr(ctx, curr.LessThan[i], prev.LessThan[i], colInfo, tbInfo)
succ, err := parseAndEvalBoolExpr(ctx.GetExprCtx(), curr.LessThan[i], prev.LessThan[i], colInfo, tbInfo)
if err != nil {
return false, err
}
Expand All @@ -3357,7 +3357,7 @@ func checkTwoRangeColumns(ctx sessionctx.Context, curr, prev *model.PartitionDef
return false, nil
}

func parseAndEvalBoolExpr(ctx sessionctx.Context, l, r string, colInfo *model.ColumnInfo, tbInfo *model.TableInfo) (bool, error) {
func parseAndEvalBoolExpr(ctx expression.BuildContext, l, r string, colInfo *model.ColumnInfo, tbInfo *model.TableInfo) (bool, error) {
lexpr, err := expression.ParseSimpleExpr(ctx, l, expression.WithTableInfo("", tbInfo), expression.WithCastExprTo(&colInfo.FieldType))
if err != nil {
return false, err
Expand Down Expand Up @@ -4406,7 +4406,7 @@ func (d *ddl) AddTablePartitions(ctx sessionctx.Context, ident ast.Ident, spec *
return d.hashPartitionManagement(ctx, ident, spec, pi)
}

partInfo, err := BuildAddedPartitionInfo(ctx, meta, spec)
partInfo, err := BuildAddedPartitionInfo(ctx.GetExprCtx(), meta, spec)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -4652,7 +4652,7 @@ func (d *ddl) ReorganizePartitions(ctx sessionctx.Context, ident ast.Ident, spec
if err != nil {
return errors.Trace(err)
}
partInfo, err := BuildAddedPartitionInfo(ctx, meta, spec)
partInfo, err := BuildAddedPartitionInfo(ctx.GetExprCtx(), meta, spec)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -4718,7 +4718,7 @@ func (d *ddl) RemovePartitioning(ctx sessionctx.Context, ident ast.Ident, spec *
partNames[i] = pi.Definitions[i].Name.L
}
meta.Partition.Type = model.PartitionTypeNone
partInfo, err := BuildAddedPartitionInfo(ctx, meta, newSpec)
partInfo, err := BuildAddedPartitionInfo(ctx.GetExprCtx(), meta, newSpec)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -4796,11 +4796,11 @@ func checkReorgPartitionDefs(ctx sessionctx.Context, action model.ActionType, tb
}

isUnsigned := isPartExprUnsigned(tblInfo)
currentRangeValue, _, err := getRangeValue(ctx, pi.Definitions[lastPartIdx].LessThan[0], isUnsigned)
currentRangeValue, _, err := getRangeValue(ctx.GetExprCtx(), pi.Definitions[lastPartIdx].LessThan[0], isUnsigned)
if err != nil {
return errors.Trace(err)
}
newRangeValue, _, err := getRangeValue(ctx, partInfo.Definitions[len(partInfo.Definitions)-1].LessThan[0], isUnsigned)
newRangeValue, _, err := getRangeValue(ctx.GetExprCtx(), partInfo.Definitions[len(partInfo.Definitions)-1].LessThan[0], isUnsigned)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -4975,7 +4975,7 @@ func (d *ddl) DropTablePartition(ctx sessionctx.Context, ident ast.Ident, spec *
}

if spec.Tp == ast.AlterTableDropFirstPartition {
intervalOptions := getPartitionIntervalFromTable(ctx, meta)
intervalOptions := getPartitionIntervalFromTable(ctx.GetExprCtx(), meta)
if intervalOptions == nil {
return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs(
"FIRST PARTITION, does not seem like an INTERVAL partitioned table")
Expand All @@ -4985,7 +4985,7 @@ func (d *ddl) DropTablePartition(ctx sessionctx.Context, ident ast.Ident, spec *
"FIRST PARTITION, table info already contains partition definitions")
}
spec.Partition.Interval = intervalOptions
err = GeneratePartDefsFromInterval(ctx, spec.Tp, meta, spec.Partition)
err = GeneratePartDefsFromInterval(ctx.GetExprCtx(), spec.Tp, meta, spec.Partition)
if err != nil {
return err
}
Expand Down Expand Up @@ -5490,7 +5490,7 @@ func setDefaultValueWithBinaryPadding(col *table.Column, value any) error {
}

func setColumnComment(ctx sessionctx.Context, col *table.Column, option *ast.ColumnOption) error {
value, err := expression.EvalSimpleAst(ctx, option.Expr)
value, err := expression.EvalSimpleAst(ctx.GetExprCtx(), option.Expr)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -5851,7 +5851,7 @@ func GetModifiableColumnJob(
oldTypeFlags := sv.TypeFlags()
newTypeFlags := oldTypeFlags.WithTruncateAsWarning(false).WithIgnoreTruncateErr(false)
sv.SetTypeFlags(newTypeFlags)
_, err = buildPartitionDefinitionsInfo(sctx, pAst.Definitions, &newTblInfo, uint64(len(newTblInfo.Partition.Definitions)))
_, err = buildPartitionDefinitionsInfo(sctx.GetExprCtx(), pAst.Definitions, &newTblInfo, uint64(len(newTblInfo.Partition.Definitions)))
sv.SetTypeFlags(oldTypeFlags)
if err != nil {
return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStack("New column does not match partition definitions: %s", err.Error())
Expand Down Expand Up @@ -7416,7 +7416,7 @@ func BuildHiddenColumnInfo(ctx sessionctx.Context, indexPartSpecifications []*as
if err != nil {
return nil, errors.Trace(err)
}
expr, err := expression.BuildSimpleExpr(ctx, idxPart.Expr,
expr, err := expression.BuildSimpleExpr(ctx.GetExprCtx(), idxPart.Expr,
expression.WithTableInfo(ctx.GetSessionVars().CurrentDB, tblInfo),
expression.WithAllowCastArray(true),
)
Expand Down Expand Up @@ -8011,7 +8011,7 @@ func validateCommentLength(vars *variable.SessionVars, name string, comment *str
}

// BuildAddedPartitionInfo build alter table add partition info
func BuildAddedPartitionInfo(ctx sessionctx.Context, meta *model.TableInfo, spec *ast.AlterTableSpec) (*model.PartitionInfo, error) {
func BuildAddedPartitionInfo(ctx expression.BuildContext, meta *model.TableInfo, spec *ast.AlterTableSpec) (*model.PartitionInfo, error) {
numParts := uint64(0)
switch meta.Partition.Type {
case model.PartitionTypeNone:
Expand Down Expand Up @@ -8077,7 +8077,7 @@ func BuildAddedPartitionInfo(ctx sessionctx.Context, meta *model.TableInfo, spec
return part, nil
}

func buildAddedPartitionDefs(ctx sessionctx.Context, meta *model.TableInfo, spec *ast.AlterTableSpec) error {
func buildAddedPartitionDefs(ctx expression.BuildContext, meta *model.TableInfo, spec *ast.AlterTableSpec) error {
partInterval := getPartitionIntervalFromTable(ctx, meta)
if partInterval == nil {
return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs(
Expand All @@ -8095,7 +8095,7 @@ func buildAddedPartitionDefs(ctx sessionctx.Context, meta *model.TableInfo, spec
return GeneratePartDefsFromInterval(ctx, spec.Tp, meta, spec.Partition)
}

func checkAndGetColumnsTypeAndValuesMatch(ctx sessionctx.Context, colTypes []types.FieldType, exprs []ast.ExprNode) ([]string, error) {
func checkAndGetColumnsTypeAndValuesMatch(ctx expression.BuildContext, colTypes []types.FieldType, exprs []ast.ExprNode) ([]string, error) {
// Validate() has already checked len(colNames) = len(exprs)
// create table ... partition by range columns (cols)
// partition p0 values less than (expr)
Expand Down
2 changes: 1 addition & 1 deletion pkg/ddl/index_cop.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ func buildDAGPB(sCtx sessionctx.Context, tblInfo *model.TableInfo, colInfos []*m
func constructTableScanPB(sCtx sessionctx.Context, tblInfo *model.TableInfo, colInfos []*model.ColumnInfo) (*tipb.Executor, error) {
tblScan := tables.BuildTableScanFromInfos(tblInfo, colInfos)
tblScan.TableId = tblInfo.ID
err := tables.SetPBColumnsDefaultValue(sCtx, tblScan.Columns, colInfos)
err := tables.SetPBColumnsDefaultValue(sCtx.GetExprCtx(), tblScan.Columns, colInfos)
return &tipb.Executor{Tp: tipb.ExecType_TypeTableScan, TblScan: tblScan}, err
}

Expand Down
Loading

0 comments on commit 6f02e99

Please sign in to comment.