Skip to content

Commit

Permalink
types: use flags in types package to handle clip zero case (pingcap#4…
Browse files Browse the repository at this point in the history
  • Loading branch information
lcwangchao authored and wuhuizuo committed Apr 2, 2024
1 parent 8a75fb8 commit 5709ccd
Show file tree
Hide file tree
Showing 14 changed files with 65 additions and 55 deletions.
4 changes: 4 additions & 0 deletions br/pkg/lightning/backend/kv/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/topsql/stmtstats"
"go.uber.org/zap"
)
Expand Down Expand Up @@ -313,6 +314,9 @@ func NewSession(options *encode.SessionOptions, logger log.Logger) *Session {
}
}
vars.StmtCtx.SetTimeZone(vars.Location())
vars.StmtCtx.SetTypeFlags(types.StrictFlags.
WithClipNegativeToZero(true),
)
if err := vars.SetSystemVar("timestamp", strconv.FormatInt(options.Timestamp, 10)); err != nil {
logger.Warn("new session: failed to set timestamp",
log.ShortError(err))
Expand Down
16 changes: 6 additions & 10 deletions pkg/ddl/backfilling_scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"context"
"fmt"
"sync"
"time"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/ddl/copr"
Expand All @@ -33,6 +32,7 @@ import (
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/table"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util"
"github.com/pingcap/tidb/pkg/util/dbterror"
"github.com/pingcap/tidb/pkg/util/intest"
Expand Down Expand Up @@ -148,12 +148,6 @@ func initSessCtx(
sqlMode mysql.SQLMode,
tzLocation *model.TimeZoneLocation,
) error {
// Unify the TimeZone settings in newContext.
if sessCtx.GetSessionVars().StmtCtx.TimeZone() == nil {
tz := *time.UTC
sessCtx.GetSessionVars().StmtCtx.SetTimeZone(&tz)
}
sessCtx.GetSessionVars().StmtCtx.IsDDLJobInQueue = true
// Set the row encode format version.
rowFormat := variable.GetDDLReorgRowFormat()
sessCtx.GetSessionVars().RowEncoder.Enable = rowFormat != variable.DefTiDBRowFormatV1
Expand All @@ -162,15 +156,17 @@ func initSessCtx(
if err := setSessCtxLocation(sessCtx, tzLocation); err != nil {
return errors.Trace(err)
}
sessCtx.GetSessionVars().StmtCtx.SetTimeZone(sessCtx.GetSessionVars().Location())
sessCtx.GetSessionVars().StmtCtx.BadNullAsWarning = !sqlMode.HasStrictMode()
sessCtx.GetSessionVars().StmtCtx.OverflowAsWarning = !sqlMode.HasStrictMode()
sessCtx.GetSessionVars().StmtCtx.AllowInvalidDate = sqlMode.HasAllowInvalidDatesMode()
sessCtx.GetSessionVars().StmtCtx.DividedByZeroAsWarning = !sqlMode.HasStrictMode()
sessCtx.GetSessionVars().StmtCtx.IgnoreZeroInDate = !sqlMode.HasStrictMode() || sqlMode.HasAllowInvalidDatesMode()
sessCtx.GetSessionVars().StmtCtx.NoZeroDate = sqlMode.HasStrictMode()

typeFlags := sessCtx.GetSessionVars().StmtCtx.TypeFlags().WithTruncateAsWarning(!sqlMode.HasStrictMode())
sessCtx.GetSessionVars().StmtCtx.SetTypeFlags(typeFlags)
sessCtx.GetSessionVars().StmtCtx.SetTypeFlags(types.StrictFlags.
WithTruncateAsWarning(!sqlMode.HasStrictMode()).
WithClipNegativeToZero(true),
)

// Prevent initializing the mock context in the workers concurrently.
// For details, see https://github.com/pingcap/tidb/issues/40879.
Expand Down
7 changes: 6 additions & 1 deletion pkg/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2162,7 +2162,12 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
sc.SetTypeFlags(sc.TypeFlags().
WithSkipUTF8Check(vars.SkipUTF8Check).
WithSkipSACIICheck(vars.SkipASCIICheck).
WithSkipUTF8MB4Check(!globalConfig.Instance.CheckMb4ValueInUTF8.Load()))
WithSkipUTF8MB4Check(!globalConfig.Instance.CheckMb4ValueInUTF8.Load()).
// WithClipNegativeToZero indicates whether values less than 0 should be clipped to 0 for unsigned integer types.
// This is the case for `insert`, `update`, `alter table`, `create table` and `load data infile` statements, when not in strict SQL mode.
// see https://dev.mysql.com/doc/refman/5.7/en/out-of-range-and-overflow.html
WithClipNegativeToZero(sc.InInsertStmt || sc.InLoadDataStmt || sc.InUpdateStmt || sc.InCreateOrAlterStmt),
)

vars.PlanCacheParams.Reset()
if priority := mysql.PriorityEnum(atomic.LoadInt32(&variable.ForcePriority)); priority != mysql.NoPriority {
Expand Down
6 changes: 4 additions & 2 deletions pkg/expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,9 @@ var fakeSctx = newFakeSctx()

func newFakeSctx() *stmtctx.StatementContext {
sc := stmtctx.NewStmtCtx()
sc.InInsertStmt = true
sc.SetTypeFlags(types.StrictFlags.
WithClipNegativeToZero(true),
)
return sc
}

Expand Down Expand Up @@ -980,7 +982,7 @@ func (b *builtinCastRealAsIntSig) evalInt(row chunk.Row) (res int64, isNull bool
} else {
var uintVal uint64
sc := b.ctx.GetSessionVars().StmtCtx
uintVal, err = types.ConvertFloatToUint(sc, val, types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong)
uintVal, err = types.ConvertFloatToUint(sc.TypeFlags(), val, types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong)
res = int64(uintVal)
}
if types.ErrOverflow.Equal(err) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/builtin_cast_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@ func (b *builtinCastRealAsIntSig) vecEvalInt(input *chunk.Chunk, result *chunk.C
} else {
var uintVal uint64
sc := b.ctx.GetSessionVars().StmtCtx
uintVal, err = types.ConvertFloatToUint(sc, f64s[i], types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong)
uintVal, err = types.ConvertFloatToUint(sc.TypeFlags(), f64s[i], types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong)
i64s[i] = int64(uintVal)
}
if types.ErrOverflow.Equal(err) {
Expand Down
22 changes: 4 additions & 18 deletions pkg/sessionctx/stmtctx/stmtctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -473,12 +473,6 @@ func (sc *StatementContext) SetTypeFlags(flags typectx.Flags) {
sc.typeCtx = sc.typeCtx.WithFlags(flags)
}

// UpdateTypeFlags updates the flags of the type context
func (sc *StatementContext) UpdateTypeFlags(fn func(typectx.Flags) typectx.Flags) {
flags := fn(sc.typeCtx.Flags())
sc.typeCtx = sc.typeCtx.WithFlags(flags)
}

// HandleTruncate ignores or returns the error based on the TypeContext inside.
func (sc *StatementContext) HandleTruncate(err error) error {
return sc.typeCtx.HandleTruncate(err)
Expand Down Expand Up @@ -1133,13 +1127,6 @@ func (sc *StatementContext) GetExecDetails() execdetails.ExecDetails {
return details
}

// ShouldClipToZero indicates whether values less than 0 should be clipped to 0 for unsigned integer types.
// This is the case for `insert`, `update`, `alter table`, `create table` and `load data infile` statements, when not in strict SQL mode.
// see https://dev.mysql.com/doc/refman/5.7/en/out-of-range-and-overflow.html
func (sc *StatementContext) ShouldClipToZero() bool {
return sc.InInsertStmt || sc.InLoadDataStmt || sc.InUpdateStmt || sc.InCreateOrAlterStmt || sc.IsDDLJobInQueue
}

// ShouldIgnoreOverflowError indicates whether we should ignore the error when type conversion overflows,
// so we can leave it for further processing like clipping values less than 0 to 0 for unsigned integer types.
func (sc *StatementContext) ShouldIgnoreOverflowError() bool {
Expand Down Expand Up @@ -1236,12 +1223,11 @@ func (sc *StatementContext) InitFromPBFlagAndTz(flags uint64, tz *time.Location)
sc.IgnoreZeroInDate = (flags & model.FlagIgnoreZeroInDate) > 0
sc.DividedByZeroAsWarning = (flags & model.FlagDividedByZeroAsWarning) > 0
sc.SetTimeZone(tz)

typeFlags := sc.TypeFlags()
typeFlags = typeFlags.
sc.SetTypeFlags(typectx.StrictFlags.
WithIgnoreTruncateErr((flags & model.FlagIgnoreTruncate) > 0).
WithTruncateAsWarning((flags & model.FlagTruncateAsWarning) > 0)
sc.typeCtx = typectx.NewContext(typeFlags, tz, sc.AppendWarning)
WithTruncateAsWarning((flags & model.FlagTruncateAsWarning) > 0).
WithClipNegativeToZero(sc.InInsertStmt),
)
}

// GetLockWaitStartTime returns the statement pessimistic lock wait start time
Expand Down
6 changes: 0 additions & 6 deletions pkg/sessionctx/stmtctx/stmtctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,12 +356,6 @@ func TestSetStmtCtxTypeFlags(t *testing.T) {
sc.SetTypeFlags(typectx.FlagSkipASCIICheck | typectx.FlagSkipUTF8Check | typectx.FlagInvalidDateAsWarning)
require.Equal(t, typectx.FlagSkipASCIICheck|typectx.FlagSkipUTF8Check|typectx.FlagInvalidDateAsWarning, sc.TypeFlags())
require.Equal(t, sc.TypeFlags(), sc.TypeFlags())

sc.UpdateTypeFlags(func(flags typectx.Flags) typectx.Flags {
return (flags | typectx.FlagSkipUTF8Check | typectx.FlagClipNegativeToZero) &^ typectx.FlagSkipASCIICheck
})
require.Equal(t, typectx.FlagSkipUTF8Check|typectx.FlagClipNegativeToZero|typectx.FlagInvalidDateAsWarning, sc.TypeFlags())
require.Equal(t, sc.TypeFlags(), sc.TypeFlags())
}

func TestResetStmtCtx(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/store/mockstore/mockcopr/cop_handler_dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ func (e *evalContext) decodeRelatedColumnVals(relatedColOffsets []int, value [][

// flagsAndTzToStatementContext creates a StatementContext from a `tipb.SelectRequest.Flags`.
func flagsAndTzToStatementContext(flags uint64, tz *time.Location) *stmtctx.StatementContext {
sc := new(stmtctx.StatementContext)
sc := stmtctx.NewStmtCtx()
sc.InitFromPBFlagAndTz(flags, tz)
return sc
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/store/mockstore/unistore/cophandler/cop_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ func newRowDecoder(columnInfos []*tipb.ColumnInfo, fieldTps []*types.FieldType,

// flagsAndTzToStatementContext creates a StatementContext from a `tipb.SelectRequest.Flags`.
func flagsAndTzToStatementContext(flags uint64, tz *time.Location) *stmtctx.StatementContext {
sc := new(stmtctx.StatementContext)
sc := stmtctx.NewStmtCtx()
sc.InitFromPBFlagAndTz(flags, tz)
return sc
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/table/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ func FillVirtualColumnValue(virtualRetTypes []*types.FieldType, virtualColumnInd
}

// Clip to zero if get negative value after cast to unsigned.
if mysql.HasUnsignedFlag(colInfos[idx].FieldType.GetFlag()) && !castDatum.IsNull() && !sctx.GetSessionVars().StmtCtx.ShouldClipToZero() {
if mysql.HasUnsignedFlag(colInfos[idx].FieldType.GetFlag()) && !castDatum.IsNull() && !sctx.GetSessionVars().StmtCtx.TypeFlags().ClipNegativeToZero() {
switch datum.Kind() {
case types.KindInt64:
if datum.GetInt64() < 0 {
Expand Down
13 changes: 13 additions & 0 deletions pkg/types/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,19 @@ const (
FlagSkipUTF8MB4Check
)

// ClipNegativeToZero indicates whether the flag `FlagClipNegativeToZero` is set
func (f Flags) ClipNegativeToZero() bool {
return f&FlagClipNegativeToZero != 0
}

// WithClipNegativeToZero returns a new flags with `FlagClipNegativeToZero` set/unset according to the clip parameter
func (f Flags) WithClipNegativeToZero(clip bool) Flags {
if clip {
return f | FlagClipNegativeToZero
}
return f &^ FlagClipNegativeToZero
}

// SkipASCIICheck indicates whether the flag `FlagSkipASCIICheck` is set
func (f Flags) SkipASCIICheck() bool {
return f&FlagSkipASCIICheck != 0
Expand Down
16 changes: 13 additions & 3 deletions pkg/types/context/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,23 @@ func TestWithNewFlags(t *testing.T) {
require.Equal(t, time.UTC, ctx2.Location())
}

func TestStringFlags(t *testing.T) {
func TestSimpleOnOffFlags(t *testing.T) {
cases := []struct {
name string
flag Flags
readFn func(f Flags) bool
writeFn func(f Flags, skip bool) Flags
readFn func(Flags) bool
writeFn func(Flags, bool) Flags
}{
{
name: "FlagClipNegativeToZero",
flag: FlagClipNegativeToZero,
readFn: func(f Flags) bool {
return f.ClipNegativeToZero()
},
writeFn: func(f Flags, clip bool) Flags {
return f.WithClipNegativeToZero(clip)
},
},
{
name: "FlagSkipASCIICheck",
flag: FlagSkipASCIICheck,
Expand Down
12 changes: 6 additions & 6 deletions pkg/types/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ func ConvertUintToInt(val uint64, upperBound int64, tp byte) (int64, error) {
}

// ConvertIntToUint converts an int value to an uint value.
func ConvertIntToUint(sc *stmtctx.StatementContext, val int64, upperBound uint64, tp byte) (uint64, error) {
if sc.ShouldClipToZero() && val < 0 {
func ConvertIntToUint(flags Flags, val int64, upperBound uint64, tp byte) (uint64, error) {
if val < 0 && flags.ClipNegativeToZero() {
return 0, overflow(val, tp)
}

Expand All @@ -167,10 +167,10 @@ func ConvertUintToUint(val uint64, upperBound uint64, tp byte) (uint64, error) {
}

// ConvertFloatToUint converts a float value to an uint value.
func ConvertFloatToUint(sc *stmtctx.StatementContext, fval float64, upperBound uint64, tp byte) (uint64, error) {
func ConvertFloatToUint(flags Flags, fval float64, upperBound uint64, tp byte) (uint64, error) {
val := RoundFloat(fval)
if val < 0 {
if sc.ShouldClipToZero() {
if flags.ClipNegativeToZero() {
return 0, overflow(val, tp)
}
return uint64(int64(val)), overflow(val, tp)
Expand Down Expand Up @@ -586,7 +586,7 @@ func ConvertJSONToInt(sc *stmtctx.StatementContext, j BinaryJSON, unsigned bool,
i := j.GetInt64()
if unsigned {
uBound := IntergerUnsignedUpperBound(tp)
u, err := ConvertIntToUint(sc, i, uBound, tp)
u, err := ConvertIntToUint(sc.TypeFlags(), i, uBound, tp)
return int64(u), sc.HandleOverflow(err, err)
}

Expand Down Expand Up @@ -614,7 +614,7 @@ func ConvertJSONToInt(sc *stmtctx.StatementContext, j BinaryJSON, unsigned bool,
return u, sc.HandleOverflow(e, e)
}
bound := IntergerUnsignedUpperBound(tp)
u, err := ConvertFloatToUint(sc, f, bound, tp)
u, err := ConvertFloatToUint(sc.TypeFlags(), f, bound, tp)
return int64(u), sc.HandleOverflow(err, err)
case JSONTypeCodeString:
str := string(hack.String(j.GetString()))
Expand Down
10 changes: 5 additions & 5 deletions pkg/types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -1194,11 +1194,11 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) (
)
switch d.k {
case KindInt64:
val, err = ConvertIntToUint(sc, d.GetInt64(), upperBound, tp)
val, err = ConvertIntToUint(sc.TypeFlags(), d.GetInt64(), upperBound, tp)
case KindUint64:
val, err = ConvertUintToUint(d.GetUint64(), upperBound, tp)
case KindFloat32, KindFloat64:
val, err = ConvertFloatToUint(sc, d.GetFloat64(), upperBound, tp)
val, err = ConvertFloatToUint(sc.TypeFlags(), d.GetFloat64(), upperBound, tp)
case KindString, KindBytes:
uval, err1 := StrToUint(sc.TypeCtxOrDefault(), d.GetString(), false)
if err1 != nil && ErrOverflow.Equal(err1) && !sc.ShouldIgnoreOverflowError() {
Expand All @@ -1215,7 +1215,7 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) (
if err == nil {
err = err1
}
val, err1 = ConvertIntToUint(sc, ival, upperBound, tp)
val, err1 = ConvertIntToUint(sc.TypeFlags(), ival, upperBound, tp)
if err == nil {
err = err1
}
Expand All @@ -1230,9 +1230,9 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) (
case KindMysqlDecimal:
val, err = ConvertDecimalToUint(sc, d.GetMysqlDecimal(), upperBound, tp)
case KindMysqlEnum:
val, err = ConvertFloatToUint(sc, d.GetMysqlEnum().ToNumber(), upperBound, tp)
val, err = ConvertFloatToUint(sc.TypeFlags(), d.GetMysqlEnum().ToNumber(), upperBound, tp)
case KindMysqlSet:
val, err = ConvertFloatToUint(sc, d.GetMysqlSet().ToNumber(), upperBound, tp)
val, err = ConvertFloatToUint(sc.TypeFlags(), d.GetMysqlSet().ToNumber(), upperBound, tp)
case KindBinaryLiteral, KindMysqlBit:
val, err = d.GetBinaryLiteral().ToInt(sc.TypeCtxOrDefault())
if err == nil {
Expand Down

0 comments on commit 5709ccd

Please sign in to comment.