Skip to content

Commit

Permalink
expression: fix the behavior when adding date with big interval (#49228)
Browse files Browse the repository at this point in the history
close #49227
  • Loading branch information
lcwangchao authored Dec 12, 2023
1 parent 2ca7121 commit 724b88b
Show file tree
Hide file tree
Showing 9 changed files with 801 additions and 78 deletions.
1 change: 1 addition & 0 deletions pkg/expression/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/config",
"//pkg/errctx",
"//pkg/errno",
"//pkg/extension",
"//pkg/kv",
Expand Down
117 changes: 82 additions & 35 deletions pkg/expression/builtin_time.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/pkg/config"
"github.com/pingcap/tidb/pkg/errctx"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/parser/terror"
Expand Down Expand Up @@ -2754,7 +2755,7 @@ type baseDateArithmetical struct {

func newDateArithmeticalUtil() baseDateArithmetical {
return baseDateArithmetical{
intervalRegexp: regexp.MustCompile(`-?[\d]+`),
intervalRegexp: regexp.MustCompile(`^[+-]?[\d]+`),
}
}

Expand Down Expand Up @@ -2864,17 +2865,58 @@ func (du *baseDateArithmetical) getIntervalFromString(ctx sessionctx.Context, ar
if isNull || err != nil {
return "", true, err
}
// unit "DAY" and "HOUR" has to be specially handled.
if toLower := strings.ToLower(unit); toLower == "day" || toLower == "hour" {
if strings.ToLower(interval) == "true" {
interval = "1"
} else if strings.ToLower(interval) == "false" {

interval, err = du.intervalReformatString(ctx.GetSessionVars().StmtCtx.ErrCtx(), interval, unit)
return interval, false, err
}

func (du *baseDateArithmetical) intervalReformatString(ec errctx.Context, str string, unit string) (interval string, err error) {
switch strings.ToUpper(unit) {
case "MICROSECOND", "MINUTE", "HOUR", "DAY", "WEEK", "MONTH", "QUARTER", "YEAR":
str = strings.TrimSpace(str)
// a single unit value has to be specially handled.
interval = du.intervalRegexp.FindString(str)
if interval == "" {
interval = "0"
} else {
interval = du.intervalRegexp.FindString(interval)
}

if interval != str {
err = ec.HandleError(types.ErrTruncatedWrongVal.GenWithStackByArgs("DECIMAL", str))
}
case "SECOND":
// The unit SECOND is specially handled, for example:
// date + INTERVAL "1e2" SECOND = date + INTERVAL 100 second
// date + INTERVAL "1.6" SECOND = date + INTERVAL 1.6 second
// But:
// date + INTERVAL "1e2" MINUTE = date + INTERVAL 1 MINUTE
// date + INTERVAL "1.6" MINUTE = date + INTERVAL 1 MINUTE
var dec types.MyDecimal
if err = dec.FromString([]byte(str)); err != nil {
truncatedErr := types.ErrTruncatedWrongVal.GenWithStackByArgs("DECIMAL", str)
err = ec.HandleErrorWithAlias(err, truncatedErr, truncatedErr)
}
interval = string(dec.ToString())
default:
interval = str
}
return interval, false, nil
return interval, err
}

func (du *baseDateArithmetical) intervalDecimalToString(ec errctx.Context, dec *types.MyDecimal) (string, error) {
var rounded types.MyDecimal
err := dec.Round(&rounded, 0, types.ModeHalfUp)
if err != nil {
return "", err
}

intVal, err := rounded.ToInt()
if err != nil {
if err = ec.HandleError(types.ErrTruncatedWrongVal.GenWithStackByArgs("DECIMAL", dec.String())); err != nil {
return "", err
}
}

return strconv.FormatInt(intVal, 10), nil
}

func (du *baseDateArithmetical) getIntervalFromDecimal(ctx sessionctx.Context, args []Expression, row chunk.Row, unit string) (string, bool, error) {
Expand Down Expand Up @@ -2921,9 +2963,8 @@ func (du *baseDateArithmetical) getIntervalFromDecimal(ctx sessionctx.Context, a
// interval is already like the %f format.
default:
// YEAR, QUARTER, MONTH, WEEK, DAY, HOUR, MINUTE, MICROSECOND
castExpr := WrapWithCastAsString(ctx, WrapWithCastAsInt(ctx, args[1]))
interval, isNull, err = castExpr.EvalString(ctx, row)
if isNull || err != nil {
interval, err = du.intervalDecimalToString(ctx.GetSessionVars().StmtCtx.ErrCtx(), decimal)
if err != nil {
return "", true, err
}
}
Expand All @@ -2936,6 +2977,11 @@ func (du *baseDateArithmetical) getIntervalFromInt(ctx sessionctx.Context, args
if isNull || err != nil {
return "", true, err
}

if mysql.HasUnsignedFlag(args[1].GetType().GetFlag()) {
return strconv.FormatUint(uint64(interval), 10), false, nil
}

return strconv.FormatInt(interval, 10), false, nil
}

Expand All @@ -2962,7 +3008,10 @@ func (du *baseDateArithmetical) addDate(ctx sessionctx.Context, date types.Time,
}

goTime = goTime.Add(time.Duration(nano))
goTime = types.AddDate(year, month, day, goTime)
goTime, err = types.AddDate(year, month, day, goTime)
if err != nil {
return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime"))
}

// Adjust fsp as required by outer - always respect type inference.
date.SetFsp(resultFsp)
Expand All @@ -2974,10 +3023,6 @@ func (du *baseDateArithmetical) addDate(ctx sessionctx.Context, date types.Time,
return date, false, nil
}

if goTime.Year() < 0 || goTime.Year() > 9999 {
return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime"))
}

date.SetCoreTime(types.FromGoTime(goTime))
overflow, err := types.DateTimeIsOverflow(ctx.GetSessionVars().StmtCtx.TypeCtx(), date)
if err := handleInvalidTimeError(ctx, err); err != nil {
Expand Down Expand Up @@ -3236,28 +3281,19 @@ func (du *baseDateArithmetical) vecGetIntervalFromString(b *baseBuiltinFunc, ctx
return err
}

amendInterval := func(val string) string {
return val
}
if unitLower := strings.ToLower(unit); unitLower == "day" || unitLower == "hour" {
amendInterval = func(val string) string {
if intervalLower := strings.ToLower(val); intervalLower == "true" {
return "1"
} else if intervalLower == "false" {
return "0"
}
return du.intervalRegexp.FindString(val)
}
}

ec := ctx.GetSessionVars().StmtCtx.ErrCtx()
result.ReserveString(n)
for i := 0; i < n; i++ {
if buf.IsNull(i) {
result.AppendNull()
continue
}

result.AppendString(amendInterval(buf.GetString(i)))
interval, err := du.intervalReformatString(ec, buf.GetString(i), unit)
if err != nil {
return err
}
result.AppendString(interval)
}
return nil
}
Expand Down Expand Up @@ -3325,10 +3361,18 @@ func (du *baseDateArithmetical) vecGetIntervalFromDecimal(b *baseBuiltinFunc, ct
/* keep interval as original decimal */
default:
// YEAR, QUARTER, MONTH, WEEK, DAY, HOUR, MINUTE, MICROSECOND
castExpr := WrapWithCastAsString(ctx, WrapWithCastAsInt(ctx, b.args[1]))
amendInterval = func(_ string, row *chunk.Row) (string, bool, error) {
interval, isNull, err := castExpr.EvalString(ctx, *row)
return interval, isNull || err != nil, err
dec, isNull, err := b.args[1].EvalDecimal(ctx, *row)
if isNull || err != nil {
return "", true, err
}

str, err := du.intervalDecimalToString(ctx.GetSessionVars().StmtCtx.ErrCtx(), dec)
if err != nil {
return "", true, err
}

return str, false, nil
}
}

Expand Down Expand Up @@ -3376,9 +3420,12 @@ func (du *baseDateArithmetical) vecGetIntervalFromInt(b *baseBuiltinFunc, ctx se

result.ReserveString(n)
i64s := buf.Int64s()
unsigned := mysql.HasUnsignedFlag(b.args[1].GetType().GetFlag())
for i := 0; i < n; i++ {
if buf.IsNull(i) {
result.AppendNull()
} else if unsigned {
result.AppendString(strconv.FormatUint(uint64(i64s[i]), 10))
} else {
result.AppendString(strconv.FormatInt(i64s[i], 10))
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/expression/integration_test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2258,11 +2258,11 @@ func TestTimeBuiltin(t *testing.T) {
{"\"2011-11-11 10:10:10\"", "\"20\"", "DAY", "2011-12-01 10:10:10", "2011-10-22 10:10:10"},
{"\"2011-11-11 10:10:10\"", "19.88", "DAY", "2011-12-01 10:10:10", "2011-10-22 10:10:10"},
{"\"2011-11-11 10:10:10\"", "\"19.88\"", "DAY", "2011-11-30 10:10:10", "2011-10-23 10:10:10"},
{"\"2011-11-11 10:10:10\"", "\"prefix19suffix\"", "DAY", "2011-11-30 10:10:10", "2011-10-23 10:10:10"},
{"\"2011-11-11 10:10:10\"", "\"prefix19suffix\"", "DAY", "2011-11-11 10:10:10", "2011-11-11 10:10:10"},
{"\"2011-11-11 10:10:10\"", "\"20-11\"", "DAY", "2011-12-01 10:10:10", "2011-10-22 10:10:10"},
{"\"2011-11-11 10:10:10\"", "\"20,11\"", "daY", "2011-12-01 10:10:10", "2011-10-22 10:10:10"},
{"\"2011-11-11 10:10:10\"", "\"1000\"", "dAy", "2014-08-07 10:10:10", "2009-02-14 10:10:10"},
{"\"2011-11-11 10:10:10\"", "\"true\"", "Day", "2011-11-12 10:10:10", "2011-11-10 10:10:10"},
{"\"2011-11-11 10:10:10\"", "\"true\"", "Day", "2011-11-11 10:10:10", "2011-11-11 10:10:10"},
{"\"2011-11-11 10:10:10\"", "true", "Day", "2011-11-12 10:10:10", "2011-11-10 10:10:10"},
{"\"2011-11-11\"", "1", "DAY", "2011-11-12", "2011-11-10"},
{"\"2011-11-11\"", "10", "HOUR", "2011-11-11 10:00:00", "2011-11-10 14:00:00"},
Expand Down Expand Up @@ -2340,8 +2340,8 @@ func TestTimeBuiltin(t *testing.T) {
{"\"2009-01-01\"", "6/0", "HOUR_MINUTE", "<nil>", "<nil>"},
{"\"1970-01-01 12:00:00\"", "CAST(6/4 AS DECIMAL(3,1))", "HOUR_MINUTE", "1970-01-01 13:05:00", "1970-01-01 10:55:00"},
// for issue #8077
{"\"2012-01-02\"", "\"prefix8\"", "HOUR", "2012-01-02 08:00:00", "2012-01-01 16:00:00"},
{"\"2012-01-02\"", "\"prefix8prefix\"", "HOUR", "2012-01-02 08:00:00", "2012-01-01 16:00:00"},
{"\"2012-01-02\"", "\"prefix8\"", "HOUR", "2012-01-02 00:00:00", "2012-01-02 00:00:00"},
{"\"2012-01-02\"", "\"prefix8prefix\"", "HOUR", "2012-01-02 00:00:00", "2012-01-02 00:00:00"},
{"\"2012-01-02\"", "\"8:00\"", "HOUR", "2012-01-02 08:00:00", "2012-01-01 16:00:00"},
{"\"2012-01-02\"", "\"8:00:00\"", "HOUR", "2012-01-02 08:00:00", "2012-01-01 16:00:00"},
}
Expand Down
20 changes: 18 additions & 2 deletions pkg/types/core_time.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,14 +280,30 @@ func compareTime(a, b CoreTime) int {
// Dig it and we found it's caused by golang api time.Date(year int, month Month, day, hour, min, sec, nsec int, loc *Location) Time ,
// it says October 32 converts to November 1 ,it conflicts with mysql.
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add
func AddDate(year, month, day int64, ot gotime.Time) (nt gotime.Time) {
func AddDate(year, month, day int64, ot gotime.Time) (nt gotime.Time, _ error) {
// We must limit the range of year, month and day to avoid overflow.
// The datetime range is from '1000-01-01 00:00:00.000000' to '9999-12-31 23:59:59.499999',
// so it is safe to limit the added value from -10000*365 to 10000*365.
const maxAdd = 10000 * 365
const minAdd = -maxAdd
if year > maxAdd || year < minAdd ||
month > maxAdd || month < minAdd ||
day > maxAdd || day < minAdd {
return nt, ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime")
}

df := getFixDays(int(year), int(month), int(day), ot)
if df != 0 {
nt = ot.AddDate(int(year), int(month), df)
} else {
nt = ot.AddDate(int(year), int(month), int(day))
}
return nt

if nt.Year() < 0 || nt.Year() > 9999 {
return nt, ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime")
}

return nt, nil
}

func calcTimeFromSec(to *CoreTime, seconds, microseconds int) {
Expand Down
29 changes: 23 additions & 6 deletions pkg/types/core_time_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,16 +263,33 @@ func TestAddDate(t *testing.T) {
month int
day int
ot time.Time
err bool
}{
{01, 1, 0, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC)},
{02, 1, 12, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC)},
{03, 1, 12, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC)},
{04, 2, 24, time.Date(2000, 2, 10, 0, 0, 0, 0, time.UTC)},
{01, 04, 05, time.Date(2019, 04, 01, 1, 2, 3, 4, time.UTC)},
{01, 1, 0, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), false},
{02, 1, 12, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), false},
{03, 1, 12, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), false},
{04, 2, 24, time.Date(2000, 2, 10, 0, 0, 0, 0, time.UTC), false},
{01, 04, 05, time.Date(2019, 04, 01, 1, 2, 3, 4, time.UTC), false},
{7999, 1, 1, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), false},
{-2000, 1, 1, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), false},
{8000, 1, 1, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), true},
{10001 * 365, 1, 1, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), true},
{01, 10001 * 36, 1, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), true},
{01, 1, 10001 * 365, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), true},
{-2001, 1, 1, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), true},
{-10001 * 365, 1, 1, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), true},
{01, -10001 * 36, 1, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), true},
{01, 1, -10001 * 365, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), true},
}

for _, tt := range tests {
res := AddDate(int64(tt.year), int64(tt.month), int64(tt.day), tt.ot)
res, err := AddDate(int64(tt.year), int64(tt.month), int64(tt.day), tt.ot)
if tt.err {
require.EqualError(t, err, ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime").Error())
require.True(t, ErrDatetimeFunctionOverflow.Equal(err))
continue
}
require.NoError(t, err)
require.Equal(t, tt.year+tt.ot.Year(), res.Year())
}
}
Expand Down
Loading

0 comments on commit 724b88b

Please sign in to comment.