Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

evalengine: implement date/time math #13274

Merged
merged 4 commits into from
Jun 12, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
evalengine: implement date/time math
Signed-off-by: Vicent Marti <vmg@strn.cat>
vmg committed Jun 8, 2023
commit 1951d20e7ed93711370111908f8d363fa404b546
12 changes: 12 additions & 0 deletions go/vt/vtgate/evalengine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

57 changes: 51 additions & 6 deletions go/vt/vtgate/evalengine/compiler_asm.go
Original file line number Diff line number Diff line change
@@ -3343,7 +3343,7 @@ func (asm *assembler) Fn_Sysdate(prec uint8) {
if tz := env.currentTimezone(); tz != nil {
now = now.In(tz)
}
val.bytes = datetime.FromStdTime(now).Format(prec)
val.bytes = datetime.NewDateTimeFromStd(now).Format(prec)
env.vm.stack[env.vm.sp] = val
env.vm.sp++
return 1
@@ -3591,7 +3591,7 @@ func (asm *assembler) Fn_FROM_UNIXTIME_i() {
if tz := env.currentTimezone(); tz != nil {
t = t.In(tz)
}
env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalDateTime(datetime.FromStdTime(t), 0)
env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalDateTime(datetime.NewDateTimeFromStd(t), 0)
return 1
}, "FN FROM_UNIXTIME INT64(SP-1)")
}
@@ -3607,7 +3607,7 @@ func (asm *assembler) Fn_FROM_UNIXTIME_u() {
if tz := env.currentTimezone(); tz != nil {
t = t.In(tz)
}
env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalDateTime(datetime.FromStdTime(t), 0)
env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalDateTime(datetime.NewDateTimeFromStd(t), 0)
return 1
}, "FN FROM_UNIXTIME UINT64(SP-1)")
}
@@ -3631,7 +3631,7 @@ func (asm *assembler) Fn_FROM_UNIXTIME_d() {
if tz := env.currentTimezone(); tz != nil {
t = t.In(tz)
}
env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalDateTime(datetime.FromStdTime(t), int(arg.length))
env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalDateTime(datetime.NewDateTimeFromStd(t), int(arg.length))
return 1
}, "FN FROM_UNIXTIME DECIMAL(SP-1)")
}
@@ -3648,7 +3648,7 @@ func (asm *assembler) Fn_FROM_UNIXTIME_f() {
if tz := env.currentTimezone(); tz != nil {
t = t.In(tz)
}
env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalDateTime(datetime.FromStdTime(t), 6)
env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalDateTime(datetime.NewDateTimeFromStd(t), 6)
return 1
}, "FN FROM_UNIXTIME FLOAT(SP-1)")
}
@@ -3674,7 +3674,7 @@ func (asm *assembler) Fn_MAKEDATE() {
if t.IsZero() {
env.vm.stack[env.vm.sp-2] = nil
} else {
env.vm.stack[env.vm.sp-2] = env.vm.arena.newEvalDate(datetime.FromStdTime(t).Date)
env.vm.stack[env.vm.sp-2] = env.vm.arena.newEvalDate(datetime.NewDateTimeFromStd(t).Date)
}
env.vm.sp--
return 1
@@ -4240,3 +4240,48 @@ func (asm *assembler) Fn_UUID_TO_BIN1() {
return 1
}, "FN UUID_TO_BIN VARBINARY(SP-2) INT64(SP-1)")
}

func (asm *assembler) Fn_DATEADD_D(unit datetime.IntervalType, sub bool) {
asm.adjustStack(-1)
asm.emit(func(env *ExpressionEnv) int {
interval := evalToInterval(env.vm.stack[env.vm.sp-1], unit, sub)
if interval == nil {
env.vm.stack[env.vm.sp-2] = nil
env.vm.sp--
return 1
}

tmp := env.vm.stack[env.vm.sp-2].(*evalTemporal)
env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, collations.TypedCollation{})
env.vm.sp--
return 1
}, "FN DATEADD TEMPORAL(SP-2), INTERVAL(SP-1)")
}

func (asm *assembler) Fn_DATEADD_s(unit datetime.IntervalType, sub bool, col collations.TypedCollation) {
asm.adjustStack(-1)
asm.emit(func(env *ExpressionEnv) int {
var interval *datetime.Interval
var tmp *evalTemporal

interval = evalToInterval(env.vm.stack[env.vm.sp-1], unit, sub)
if interval == nil {
goto baddate
}

tmp = evalToTemporal(env.vm.stack[env.vm.sp-2])
if tmp == nil {
goto baddate
}

env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, col)
env.vm.sp--
return 1

baddate:
env.vm.stack[env.vm.sp-2] = nil
env.vm.sp--
return 1
}, "FN DATEADD TEMPORAL(SP-2), INTERVAL(SP-1)")

}
93 changes: 93 additions & 0 deletions go/vt/vtgate/evalengine/eval_temporal.go
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@ import (
"time"

"vitess.io/vitess/go/hack"
"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/mysql/datetime"
"vitess.io/vitess/go/mysql/decimal"
"vitess.io/vitess/go/mysql/json"
@@ -143,6 +144,30 @@ func (e *evalTemporal) toStdTime(loc *time.Location) time.Time {
return e.dt.ToStdTime(loc)
}

func (e *evalTemporal) addInterval(interval *datetime.Interval, strcoll collations.TypedCollation) eval {
var tmp *evalTemporal
var ok bool

switch tt := e.SQLType(); {
case tt == sqltypes.Date && !interval.Unit().HasTimeParts():
tmp = &evalTemporal{t: e.t}
tmp.dt.Date, ok = e.dt.Date.AddInterval(interval)
case tt == sqltypes.Time && !interval.Unit().HasDateParts():
tmp = &evalTemporal{t: e.t}
tmp.dt.Time, tmp.prec, ok = e.dt.Time.AddInterval(interval, strcoll.Valid())
case tt == sqltypes.Datetime || tt == sqltypes.Timestamp || (tt == sqltypes.Date && interval.Unit().HasTimeParts()) || (tt == sqltypes.Time && interval.Unit().HasDateParts()):
tmp = e.toDateTime(int(e.prec))
tmp.dt, tmp.prec, ok = e.dt.AddInterval(interval, strcoll.Valid())
}
if !ok {
return nil
}
if strcoll.Valid() {
return newEvalRaw(sqltypes.Char, tmp.ToRawBytes(), strcoll)
}
return tmp
}

func newEvalDateTime(dt datetime.DateTime, l int) *evalTemporal {
return &evalTemporal{t: sqltypes.Datetime, dt: dt.Round(l), prec: uint8(l)}
}
@@ -190,6 +215,74 @@ func precision(req, got int) int {
return req
}

func evalToTemporal(e eval) *evalTemporal {
switch e := e.(type) {
case *evalTemporal:
return e
case *evalBytes:
if t, l, ok := datetime.ParseDateTime(e.string(), -1); ok {
return newEvalDateTime(t, l)
}
if d, ok := datetime.ParseDate(e.string()); ok {
return newEvalDate(d)
}
if t, l, ok := datetime.ParseTime(e.string(), -1); ok {
return newEvalTime(t, l)
}
case *evalInt64:
if t, ok := datetime.ParseDateTimeInt64(e.i); ok {
return newEvalDateTime(t, 0)
}
if d, ok := datetime.ParseDateInt64(e.i); ok {
return newEvalDate(d)
}
if t, ok := datetime.ParseTimeInt64(e.i); ok {
return newEvalTime(t, 0)
}
case *evalUint64:
if t, ok := datetime.ParseDateTimeInt64(int64(e.u)); ok {
return newEvalDateTime(t, 0)
}
if d, ok := datetime.ParseDateInt64(int64(e.u)); ok {
return newEvalDate(d)
}
if t, ok := datetime.ParseTimeInt64(int64(e.u)); ok {
return newEvalTime(t, 0)
}
case *evalFloat:
if t, l, ok := datetime.ParseDateTimeFloat(e.f, -1); ok {
return newEvalDateTime(t, l)
}
if d, ok := datetime.ParseDateFloat(e.f); ok {
return newEvalDate(d)
}
if t, l, ok := datetime.ParseTimeFloat(e.f, -1); ok {
return newEvalTime(t, l)
}
case *evalDecimal:
if t, l, ok := datetime.ParseDateTimeDecimal(e.dec, e.length, -1); ok {
return newEvalDateTime(t, l)
}
if d, ok := datetime.ParseDateDecimal(e.dec); ok {
return newEvalDate(d)
}
if d, l, ok := datetime.ParseTimeDecimal(e.dec, e.length, -1); ok {
return newEvalTime(d, l)
}
case *evalJSON:
if dt, ok := e.DateTime(); ok {
if dt.Date.IsZero() {
return newEvalTime(dt.Time, datetime.DefaultPrecision)
}
if dt.Time.IsZero() {
return newEvalDate(dt.Date)
}
return newEvalDateTime(dt, datetime.DefaultPrecision)
}
}
return nil
}

func evalToTime(e eval, l int) *evalTemporal {
switch e := e.(type) {
case *evalTemporal:
4 changes: 2 additions & 2 deletions go/vt/vtgate/evalengine/expr_env.go
Original file line number Diff line number Diff line change
@@ -51,9 +51,9 @@ type (

func (env *ExpressionEnv) time(utc bool) datetime.DateTime {
if utc {
return datetime.FromStdTime(env.now.UTC())
return datetime.NewDateTimeFromStd(env.now.UTC())
}
return datetime.FromStdTime(env.now)
return datetime.NewDateTimeFromStd(env.now)
}

func (env *ExpressionEnv) currentUser() string {
104 changes: 100 additions & 4 deletions go/vt/vtgate/evalengine/fn_time.go
Original file line number Diff line number Diff line change
@@ -141,6 +141,13 @@ type (
builtinYearWeek struct {
CallExpr
}

builtinDateMath struct {
CallExpr
sub bool
unit datetime.IntervalType
collate collations.ID
}
)

var _ Expr = (*builtinNow)(nil)
@@ -212,7 +219,7 @@ func (call *builtinSysdate) eval(env *ExpressionEnv) (eval, error) {
if tz := env.currentTimezone(); tz != nil {
now = now.In(tz)
}
return newEvalRaw(sqltypes.Datetime, datetime.FromStdTime(now).Format(call.prec), collationBinary), nil
return newEvalRaw(sqltypes.Datetime, datetime.NewDateTimeFromStd(now).Format(call.prec), collationBinary), nil
}

func (call *builtinSysdate) typeof(_ *ExpressionEnv, _ []*querypb.Field) (sqltypes.Type, typeFlag) {
@@ -340,7 +347,7 @@ func convertTz(dt datetime.DateTime, from, to *time.Location) (datetime.DateTime
if err != nil {
return datetime.DateTime{}, false
}
return datetime.FromStdTime(ts.In(to)), true
return datetime.NewDateTimeFromStd(ts.In(to)), true
}

func (call *builtinConvertTz) eval(env *ExpressionEnv) (eval, error) {
@@ -646,7 +653,7 @@ func (b *builtinFromUnixtime) eval(env *ExpressionEnv) (eval, error) {
t = t.In(tz)
}

dt := newEvalDateTime(datetime.FromStdTime(t), prec)
dt := newEvalDateTime(datetime.NewDateTimeFromStd(t), prec)

if len(b.Arguments) == 1 {
return dt, nil
@@ -809,7 +816,7 @@ func (b *builtinMakedate) eval(env *ExpressionEnv) (eval, error) {
if t.IsZero() {
return nil, nil
}
return newEvalDate(datetime.FromStdTime(t).Date), nil
return newEvalDate(datetime.NewDateTimeFromStd(t).Date), nil
}

func (b *builtinMakedate) typeof(env *ExpressionEnv, fields []*querypb.Field) (sqltypes.Type, typeFlag) {
@@ -1688,3 +1695,92 @@ func (call *builtinYearWeek) compile(c *compiler) (ctype, error) {
c.asm.jumpDestination(skip1, skip2)
return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: arg.Flag | flagNullable}, nil
}

func evalToInterval(itv eval, unit datetime.IntervalType, negate bool) *datetime.Interval {
switch itv := itv.(type) {
case *evalBytes:
return datetime.ParseInterval(itv.string(), unit, negate)
case *evalFloat:
return datetime.ParseIntervalFloat(itv.f, unit, negate)
case *evalDecimal:
return datetime.ParseIntervalDecimal(itv.dec, itv.length, unit, negate)
default:
return datetime.ParseIntervalInt64(evalToNumeric(itv, false).toInt64().i, unit, negate)
}
}

func (call *builtinDateMath) eval(env *ExpressionEnv) (eval, error) {
date, err := call.Arguments[0].eval(env)
if err != nil || date == nil {
return date, err
}

itv, err := call.Arguments[1].eval(env)
if err != nil || itv == nil {
return itv, err
}

interval := evalToInterval(itv, call.unit, call.sub)
if interval == nil {
return nil, nil
}

if tmp, ok := date.(*evalTemporal); ok {
return tmp.addInterval(interval, collations.TypedCollation{}), nil
}

if tmp := evalToTemporal(date); tmp != nil {
return tmp.addInterval(interval, defaultCoercionCollation(call.collate)), nil
}

return nil, nil
}

func (call *builtinDateMath) typeof(env *ExpressionEnv, fields []*querypb.Field) (sqltypes.Type, typeFlag) {
tt, f := call.Arguments[0].typeof(env, fields)

switch {
case tt == sqltypes.Date && !call.unit.HasTimeParts():
return sqltypes.Date, f | flagNullable
case tt == sqltypes.Time && !call.unit.HasDateParts():
return sqltypes.Time, f | flagNullable
case tt == sqltypes.Datetime || tt == sqltypes.Timestamp || (tt == sqltypes.Date && call.unit.HasTimeParts()) || (tt == sqltypes.Time && call.unit.HasDateParts()):
return sqltypes.Datetime, f | flagNullable
default:
return sqltypes.Char, f | flagNullable
}
}

func (call *builtinDateMath) compile(c *compiler) (ctype, error) {
date, err := call.Arguments[0].compile(c)
if err != nil {
return ctype{}, err
}

// TODO: constant propagation
_, err = call.Arguments[1].compile(c)
if err != nil {
return ctype{}, err
}

var ret ctype
ret.Flag = date.Flag | flagNullable
ret.Col = collationBinary

switch {
case date.Type == sqltypes.Date && !call.unit.HasTimeParts():
ret.Type = sqltypes.Date
c.asm.Fn_DATEADD_D(call.unit, call.sub)
case date.Type == sqltypes.Time && !call.unit.HasDateParts():
ret.Type = sqltypes.Time
c.asm.Fn_DATEADD_D(call.unit, call.sub)
case date.Type == sqltypes.Datetime || date.Type == sqltypes.Timestamp || (date.Type == sqltypes.Date && call.unit.HasTimeParts()) || (date.Type == sqltypes.Time && call.unit.HasDateParts()):
ret.Type = sqltypes.Datetime
c.asm.Fn_DATEADD_D(call.unit, call.sub)
default:
ret.Type = sqltypes.VarChar
ret.Col = defaultCoercionCollation(c.cfg.Collation)
c.asm.Fn_DATEADD_s(call.unit, call.sub, ret.Col)
}
return ret, nil
}
Loading