add argument 'ctx' in the GetType method
Signed-off-by: Yang Keao <yangkeao@chunibyo.icu>
YangKeao committed Jun 3, 2024
1 parent 2896454 commit c216499
Showing 141 changed files with 1,096 additions and 1,027 deletions.
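The change applies one pattern across all of these files: the GetType method on expressions now takes an expression.EvalContext argument instead of none, so every call site has to thread an evaluation context through. A minimal illustrative sketch of the before/after call shape follows; it is not part of the diff, and the helper name is hypothetical.

```go
package example

import (
	"github.com/pingcap/tidb/pkg/expression"
	"github.com/pingcap/tidb/pkg/types"
)

// fieldTypeOf shows the migrated call shape. Before this commit the call was
// col.GetType(); afterwards callers must supply an evaluation context.
func fieldTypeOf(ectx expression.EvalContext, col *expression.Column) *types.FieldType {
	return col.GetType(ectx)
}
```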
1 change: 1 addition & 0 deletions pkg/ddl/BUILD.bazel
@@ -283,6 +283,7 @@ go_test(
"//pkg/errctx",
"//pkg/errno",
"//pkg/executor",
+"//pkg/expression",
"//pkg/infoschema",
"//pkg/keyspace",
"//pkg/kv",
2 changes: 1 addition & 1 deletion pkg/ddl/bench_test.go
@@ -65,7 +65,7 @@ func BenchmarkExtractDatumByOffsets(b *testing.B) {

b.ResetTimer()
for i := 0; i < b.N; i++ {
-ddl.ExtractDatumByOffsetsForTest(row, offsets, c.ExprColumnInfos, handleDataBuf)
+ddl.ExtractDatumByOffsetsForTest(tk.Session().GetExprCtx().GetEvalCtx(), row, offsets, c.ExprColumnInfos, handleDataBuf)
}
}

1 change: 1 addition & 0 deletions pkg/ddl/copr/BUILD.bazel
@@ -26,6 +26,7 @@ go_test(
shard_count = 3,
deps = [
"//pkg/expression",
+"//pkg/expression/contextstatic",
"//pkg/parser/model",
"//pkg/parser/mysql",
"//pkg/types",
7 changes: 4 additions & 3 deletions pkg/ddl/copr/copr_ctx.go
@@ -19,6 +19,7 @@ import (
distsqlctx "github.com/pingcap/tidb/pkg/distsql/context"
"github.com/pingcap/tidb/pkg/expression"
exprctx "github.com/pingcap/tidb/pkg/expression/context"

// make sure mock.MockInfoschema is initialized to make sure the test pass
_ "github.com/pingcap/tidb/pkg/infoschema"
"github.com/pingcap/tidb/pkg/parser/model"
@@ -127,7 +128,7 @@ func NewCopContextBase(
return nil, err
}
hdColOffsets := resolveIndicesForHandle(expColInfos, handleIDs)
-vColOffsets, vColFts := collectVirtualColumnOffsetsAndTypes(expColInfos)
+vColOffsets, vColFts := collectVirtualColumnOffsetsAndTypes(exprCtx.GetEvalCtx(), expColInfos)

return &CopContextBase{
TableInfo: tblInfo,
@@ -319,13 +320,13 @@ func resolveIndicesForHandle(cols []*expression.Column, handleIDs []int64) []int
return offsets
}

-func collectVirtualColumnOffsetsAndTypes(cols []*expression.Column) ([]int, []*types.FieldType) {
+func collectVirtualColumnOffsetsAndTypes(ctx expression.EvalContext, cols []*expression.Column) ([]int, []*types.FieldType) {
var offsets []int
var fts []*types.FieldType
for i, col := range cols {
if col.VirtualExpr != nil {
offsets = append(offsets, i)
-fts = append(fts, col.GetType())
+fts = append(fts, col.GetType(ctx))
}
}
return offsets, fts
4 changes: 3 additions & 1 deletion pkg/ddl/copr/copr_ctx_test.go
@@ -19,6 +19,7 @@ import (
"testing"

"github.com/pingcap/tidb/pkg/expression"
+"github.com/pingcap/tidb/pkg/expression/contextstatic"
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/types"
@@ -200,7 +201,8 @@ func TestCollectVirtualColumnOffsetsAndTypes(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
-gotOffsets, gotFt := collectVirtualColumnOffsetsAndTypes(tt.cols)
+ctx := contextstatic.NewStaticEvalContext()
+gotOffsets, gotFt := collectVirtualColumnOffsetsAndTypes(ctx, tt.cols)
require.Equal(t, gotOffsets, tt.offsets)
require.Equal(t, len(gotFt), len(tt.fieldTp))
for i, ft := range gotFt {
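As the test above shows, code that has no session available can build a static evaluation context instead. A small sketch of that pattern, using only packages imported by the test; the helper below is hypothetical.

```go
package example

import (
	"github.com/pingcap/tidb/pkg/expression"
	"github.com/pingcap/tidb/pkg/expression/contextstatic"
	"github.com/pingcap/tidb/pkg/types"
)

// fieldTypesOf resolves the field types of a set of columns with a static
// eval context, mirroring how the test calls collectVirtualColumnOffsetsAndTypes.
func fieldTypesOf(cols []*expression.Column) []*types.FieldType {
	ectx := contextstatic.NewStaticEvalContext()
	fts := make([]*types.FieldType, 0, len(cols))
	for _, col := range cols {
		fts = append(fts, col.GetType(ectx))
	}
	return fts
}
```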
8 changes: 4 additions & 4 deletions pkg/ddl/ddl_api.go
@@ -2252,7 +2252,7 @@ func BuildTableInfo(
return nil, errors.Trace(err)
}
// check if the expression is bool type
-if err := table.IfCheckConstraintExprBoolType(constraintInfo, tbInfo); err != nil {
+if err := table.IfCheckConstraintExprBoolType(ctx.GetExprCtx().GetEvalCtx(), constraintInfo, tbInfo); err != nil {
return nil, err
}
constraintInfo.ID = allocateConstraintID(tbInfo)
@@ -4890,7 +4890,7 @@ func checkReorgPartitionDefs(ctx sessionctx.Context, action model.ActionType, tb
return nil
}

-isUnsigned := isPartExprUnsigned(tblInfo)
+isUnsigned := isPartExprUnsigned(ctx.GetExprCtx().GetEvalCtx(), tblInfo)
currentRangeValue, _, err := getRangeValue(ctx.GetExprCtx(), pi.Definitions[lastPartIdx].LessThan[0], isUnsigned)
if err != nil {
return errors.Trace(err)
@@ -7561,7 +7561,7 @@ func BuildHiddenColumnInfo(ctx sessionctx.Context, indexPartSpecifications []*as
Version: model.CurrLatestColumnInfoVersion,
Dependences: make(map[string]struct{}),
Hidden: true,
-FieldType: *expr.GetType(),
+FieldType: *expr.GetType(ctx.GetExprCtx().GetEvalCtx()),
}
// Reset some flag, it may be caused by wrong type infer. But it's not easy to fix them all, so reset them here for safety.
colInfo.DelFlag(mysql.PriKeyFlag | mysql.UniqueKeyFlag | mysql.AutoIncrementFlag)
@@ -9405,7 +9405,7 @@ func (d *ddl) CreateCheckConstraint(ctx sessionctx.Context, ti ast.Ident, constr
return errors.Trace(err)
}
// check if the expression is bool type
-if err := table.IfCheckConstraintExprBoolType(constraintInfo, tblInfo); err != nil {
+if err := table.IfCheckConstraintExprBoolType(ctx.GetExprCtx().GetEvalCtx(), constraintInfo, tblInfo); err != nil {
return err
}
job := &model.Job{
6 changes: 4 additions & 2 deletions pkg/ddl/export_test.go
@@ -22,6 +22,7 @@ import (
"github.com/pingcap/tidb/pkg/ddl/copr"
"github.com/pingcap/tidb/pkg/ddl/internal/session"
"github.com/pingcap/tidb/pkg/errctx"
+"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/table"
@@ -67,11 +68,12 @@ func FetchChunk4Test(copCtx copr.CopContext, tbl table.PhysicalTable, startKey,
}

func ConvertRowToHandleAndIndexDatum(
+ctx expression.EvalContext,
handleDataBuf, idxDataBuf []types.Datum,
row chunk.Row, copCtx copr.CopContext, idxID int64) (kv.Handle, []types.Datum, error) {
c := copCtx.GetBase()
-idxData := extractDatumByOffsets(row, copCtx.IndexColumnOutputOffsets(idxID), c.ExprColumnInfos, idxDataBuf)
-handleData := extractDatumByOffsets(row, c.HandleOutputOffsets, c.ExprColumnInfos, handleDataBuf)
+idxData := extractDatumByOffsets(ctx, row, copCtx.IndexColumnOutputOffsets(idxID), c.ExprColumnInfos, idxDataBuf)
+handleData := extractDatumByOffsets(ctx, row, c.HandleOutputOffsets, c.ExprColumnInfos, handleDataBuf)
handle, err := buildHandle(handleData, c.TableInfo, c.PrimaryKeyInfo, time.Local, errctx.StrictNoWarningContext)
return handle, idxData, err
}
5 changes: 3 additions & 2 deletions pkg/ddl/index.go
@@ -1746,6 +1746,7 @@ func writeChunkToLocal(
) (int, kv.Handle, error) {
iter := chunk.NewIterator4Chunk(copChunk)
c := copCtx.GetBase()
+ectx := c.ExprCtx.GetEvalCtx()

maxIdxColCnt := maxIndexColumnCount(indexes)
idxDataBuf := make([]types.Datum, maxIdxColCnt)
@@ -1778,7 +1779,7 @@
restoreDataBuf = make([]types.Datum, len(c.HandleOutputOffsets))
}
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
-handleDataBuf := extractDatumByOffsets(row, c.HandleOutputOffsets, c.ExprColumnInfos, handleDataBuf)
+handleDataBuf := extractDatumByOffsets(ectx, row, c.HandleOutputOffsets, c.ExprColumnInfos, handleDataBuf)
if restore {
// restoreDataBuf should not truncate index values.
for i, datum := range handleDataBuf {
@@ -1791,7 +1792,7 @@
}
for i, index := range indexes {
idxID := index.Meta().ID
-idxDataBuf = extractDatumByOffsets(
+idxDataBuf = extractDatumByOffsets(ectx,
row, copCtx.IndexColumnOutputOffsets(idxID), c.ExprColumnInfos, idxDataBuf)
idxData := idxDataBuf[:len(index.Meta().Columns)]
var rsData []types.Datum
4 changes: 2 additions & 2 deletions pkg/ddl/index_cop.go
@@ -369,10 +369,10 @@ func constructTableScanPB(ctx exprctx.BuildContext, tblInfo *model.TableInfo, co
return &tipb.Executor{Tp: tipb.ExecType_TypeTableScan, TblScan: tblScan}, err
}

-func extractDatumByOffsets(row chunk.Row, offsets []int, expCols []*expression.Column, buf []types.Datum) []types.Datum {
+func extractDatumByOffsets(ctx expression.EvalContext, row chunk.Row, offsets []int, expCols []*expression.Column, buf []types.Datum) []types.Datum {
for i, offset := range offsets {
c := expCols[offset]
-row.DatumWithBuffer(offset, c.GetType(), &buf[i])
+row.DatumWithBuffer(offset, c.GetType(ctx), &buf[i])
}
return buf
}
2 changes: 1 addition & 1 deletion pkg/ddl/index_cop_test.go
@@ -60,7 +60,7 @@ func TestAddIndexFetchRowsFromCoprocessor(t *testing.T) {
idxDataBuf := make([]types.Datum, len(idxInfo.Columns))

for row := iter.Begin(); row != iter.End(); row = iter.Next() {
-handle, idxDatum, err := ddl.ConvertRowToHandleAndIndexDatum(handleDataBuf, idxDataBuf, row, copCtx, idxInfo.ID)
+handle, idxDatum, err := ddl.ConvertRowToHandleAndIndexDatum(tk.Session().GetExprCtx().GetEvalCtx(), handleDataBuf, idxDataBuf, row, copCtx, idxInfo.ID)
require.NoError(t, err)
handles = append(handles, handle)
copiedIdxDatum := make([]types.Datum, len(idxDatum))
16 changes: 8 additions & 8 deletions pkg/ddl/partition.go
@@ -695,7 +695,7 @@ func getPartitionIntervalFromTable(ctx expression.BuildContext, tbInfo *model.Ta
return nil
}
} else {
-if !isPartExprUnsigned(tbInfo) {
+if !isPartExprUnsigned(ctx.GetEvalCtx(), tbInfo) {
minVal = "-9223372036854775808"
}
}
@@ -985,7 +985,7 @@ func generatePartitionDefinitionsFromInterval(ctx expression.BuildContext, partO
if partCol != nil {
min = getLowerBoundInt(partCol)
} else {
-if !isPartExprUnsigned(tbInfo) {
+if !isPartExprUnsigned(ctx.GetEvalCtx(), tbInfo) {
min = math.MinInt64
}
}
@@ -1476,7 +1476,7 @@ func buildRangePartitionDefinitions(ctx expression.BuildContext, defs []*ast.Par

func checkPartitionValuesIsInt(ctx expression.BuildContext, defName any, exprs []ast.ExprNode, tbInfo *model.TableInfo) error {
tp := types.NewFieldType(mysql.TypeLonglong)
-if isPartExprUnsigned(tbInfo) {
+if isPartExprUnsigned(ctx.GetEvalCtx(), tbInfo) {
tp.AddFlag(mysql.UnsignedFlag)
}
for _, exp := range exprs {
@@ -1642,7 +1642,7 @@ func checkPartitionFuncType(ctx sessionctx.Context, expr ast.ExprNode, schema st
if err != nil {
return errors.Trace(err)
}
-if e.GetType().EvalType() == types.ETInt {
+if e.GetType(ctx.GetExprCtx().GetEvalCtx()).EvalType() == types.ETInt {
return nil
}

@@ -1665,7 +1665,7 @@ func checkRangePartitionValue(ctx sessionctx.Context, tblInfo *model.TableInfo)
if strings.EqualFold(defs[len(defs)-1].LessThan[0], partitionMaxValue) {
defs = defs[:len(defs)-1]
}
-isUnsigned := isPartExprUnsigned(tblInfo)
+isUnsigned := isPartExprUnsigned(ctx.GetExprCtx().GetEvalCtx(), tblInfo)
var prevRangeValue any
for i := 0; i < len(defs); i++ {
if strings.EqualFold(defs[i].LessThan[0], partitionMaxValue) {
@@ -1728,7 +1728,7 @@ func formatListPartitionValue(ctx expression.BuildContext, tblInfo *model.TableI
cols := make([]*model.ColumnInfo, 0, len(pi.Columns))
if len(pi.Columns) == 0 {
tp := types.NewFieldType(mysql.TypeLonglong)
-if isPartExprUnsigned(tblInfo) {
+if isPartExprUnsigned(ctx.GetEvalCtx(), tblInfo) {
tp.AddFlag(mysql.UnsignedFlag)
}
colTps = []*types.FieldType{tp}
@@ -4051,7 +4051,7 @@ func (cns columnNameSlice) At(i int) string {
return cns[i].Name.L
}

-func isPartExprUnsigned(tbInfo *model.TableInfo) bool {
+func isPartExprUnsigned(ectx expression.EvalContext, tbInfo *model.TableInfo) bool {
// We should not rely on any configuration, system or session variables, so use a mock ctx!
// Same as in tables.newPartitionExpr
ctx := mock.NewContext()
@@ -4060,7 +4060,7 @@ func isPartExprUnsigned(tbInfo *model.TableInfo) bool {
logutil.DDLLogger().Error("isPartExpr failed parsing expression!", zap.Error(err))
return false
}
-if mysql.HasUnsignedFlag(expr.GetType().GetFlag()) {
+if mysql.HasUnsignedFlag(expr.GetType(ectx).GetFlag()) {
return true
}
return false
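Taken together, the call sites in this commit obtain the new EvalContext argument in three ways: from a sessionctx.Context via GetExprCtx().GetEvalCtx(), from an expression.BuildContext via GetEvalCtx(), and from contextstatic.NewStaticEvalContext() in tests. The sketch below only summarizes those three sources; the function itself and the pkg/sessionctx import path are assumptions for illustration.

```go
package example

import (
	"github.com/pingcap/tidb/pkg/expression"
	"github.com/pingcap/tidb/pkg/expression/contextstatic"
	"github.com/pingcap/tidb/pkg/sessionctx"
)

// evalContexts collects the three EvalContext sources seen in this diff.
func evalContexts(sctx sessionctx.Context, bctx expression.BuildContext) []expression.EvalContext {
	return []expression.EvalContext{
		sctx.GetExprCtx().GetEvalCtx(),       // from a session context (ddl_api.go, partition.go)
		bctx.GetEvalCtx(),                    // from an expression build context (partition.go)
		contextstatic.NewStaticEvalContext(), // static context used in tests (copr_ctx_test.go)
	}
}
```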
8 changes: 4 additions & 4 deletions pkg/distsql/select_result.go
@@ -100,14 +100,14 @@ func (h *chunkRowHeap) Pop() any {

// NewSortedSelectResults is only for partition table
// If schema == nil, sort by first few columns.
-func NewSortedSelectResults(selectResult []SelectResult, schema *expression.Schema, byitems []*util.ByItems, memTracker *memory.Tracker) SelectResult {
+func NewSortedSelectResults(ectx expression.EvalContext, selectResult []SelectResult, schema *expression.Schema, byitems []*util.ByItems, memTracker *memory.Tracker) SelectResult {
s := &sortedSelectResults{
schema: schema,
selectResult: selectResult,
byItems: byitems,
memTracker: memTracker,
}
-s.initCompareFuncs()
+s.initCompareFuncs(ectx)
s.buildKeyColumns()
s.heap = &chunkRowHeap{s}
s.cachedChunks = make([]*chunk.Chunk, len(selectResult))
@@ -141,10 +141,10 @@ func (ssr *sortedSelectResults) updateCachedChunk(ctx context.Context, idx uint3
return nil
}

-func (ssr *sortedSelectResults) initCompareFuncs() {
+func (ssr *sortedSelectResults) initCompareFuncs(ectx expression.EvalContext) {
ssr.compareFuncs = make([]chunk.CompareFunc, len(ssr.byItems))
for i, item := range ssr.byItems {
-keyType := item.Expr.GetType()
+keyType := item.Expr.GetType(ectx)
ssr.compareFuncs[i] = chunk.GetCompareFunc(keyType)
}
}
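In the sorted-results path above, the comparator for each sort key is still chosen from the key's field type; the only difference is that resolving the type now needs the evaluation context. A minimal sketch of the pattern, assuming the chunk package lives at pkg/util/chunk; the helper name is hypothetical.

```go
package example

import (
	"github.com/pingcap/tidb/pkg/expression"
	"github.com/pingcap/tidb/pkg/util/chunk"
)

// compareFuncsFor mirrors initCompareFuncs: one compare function per sort key,
// selected from the key's field type resolved against the eval context.
func compareFuncsFor(ectx expression.EvalContext, keys []expression.Expression) []chunk.CompareFunc {
	fns := make([]chunk.CompareFunc, len(keys))
	for i, k := range keys {
		fns[i] = chunk.GetCompareFunc(k.GetType(ectx))
	}
	return fns
}
```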
4 changes: 2 additions & 2 deletions pkg/executor/aggfuncs/aggfunc_test.go
@@ -389,7 +389,7 @@ func testMultiArgsMergePartialResult(t *testing.T, ctx *mock.Context, p multiArg
{Expr: args[0], Desc: true},
}
}
-ctor := collate.GetCollator(args[0].GetType().GetCollate())
+ctor := collate.GetCollator(args[0].GetType(ctx).GetCollate())
partialDesc, finalDesc := desc.Split([]int{0, 1})

// build partial func for partial phase.
@@ -684,7 +684,7 @@ func testMultiArgsAggFunc(t *testing.T, ctx *mock.Context, p multiArgsAggTest) {
{Expr: args[0], Desc: true},
}
}
-ctor := collate.GetCollator(args[0].GetType().GetCollate())
+ctor := collate.GetCollator(args[0].GetType(ctx).GetCollate())
finalFunc := aggfuncs.Build(ctx, desc, 0)
finalPr, _ := finalFunc.AllocPartialResult()
resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1)
(The remaining changed files in this commit are not shown here.)
