From f5f8043ea898ebaa599d37595294fc9e228a63f0 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Fri, 4 Sep 2020 15:23:43 +0800 Subject: [PATCH] expression: fallback vectorized control expressions (#19367) (#19749) --- expression/builtin_control_vec_generated.go | 828 +++++++++++++++++--- expression/generator/control_vec.go | 208 ++++- expression/integration_test.go | 7 + sessionctx/stmtctx/stmtctx.go | 14 + 4 files changed, 943 insertions(+), 114 deletions(-) diff --git a/expression/builtin_control_vec_generated.go b/expression/builtin_control_vec_generated.go index 39fc8153e5a99..22addcc9f19b3 100644 --- a/expression/builtin_control_vec_generated.go +++ b/expression/builtin_control_vec_generated.go @@ -22,6 +22,30 @@ import ( "github.com/pingcap/tidb/util/chunk" ) +// NOTE: Control expressions optionally evaluate some branches depending on conditions, but vectorization executes all +// branches, during which the unnecessary branches may return errors or warnings. To avoid this case, when branches +// meet errors or warnings, the vectorization falls back the scalar execution. + +func (b *builtinCaseWhenIntSig) fallbackEvalInt(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + result.ResizeInt64(n, false) + x := result.Int64s() + for i := 0; i < n; i++ { + res, isNull, err := b.evalInt(input.GetRow(i)) + if err != nil { + return err + } + result.SetNull(i, isNull) + if isNull { + continue + } + + x[i] = res + + } + return nil +} + func (b *builtinCaseWhenIntSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() args, l := b.getArgs(), len(b.getArgs()) @@ -31,6 +55,8 @@ func (b *builtinCaseWhenIntSig) vecEvalInt(input *chunk.Chunk, result *chunk.Col var eLse *chunk.Column thensSlice := make([][]int64, l/2) var eLseSlice []int64 + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() for j := 0; j < l-1; j += 2 { bufWhen, err := b.bufAllocator.get(types.ETInt, n) @@ -38,8 +64,13 @@ func (b *builtinCaseWhenIntSig) vecEvalInt(input *chunk.Chunk, result *chunk.Col return err } defer b.bufAllocator.put(bufWhen) - if err := args[j].VecEvalInt(b.ctx, input, bufWhen); err != nil { - return err + err = args[j].VecEvalInt(b.ctx, input, bufWhen) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalInt(input, result) } whens[j/2] = bufWhen whensSlice[j/2] = bufWhen.Int64s() @@ -49,8 +80,13 @@ func (b *builtinCaseWhenIntSig) vecEvalInt(input *chunk.Chunk, result *chunk.Col return err } defer b.bufAllocator.put(bufThen) - if err := args[j+1].VecEvalInt(b.ctx, input, bufThen); err != nil { - return err + err = args[j+1].VecEvalInt(b.ctx, input, bufThen) + afterWarns = sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalInt(input, result) } thens[j/2] = bufThen thensSlice[j/2] = bufThen.Int64s() @@ -64,8 +100,13 @@ func (b *builtinCaseWhenIntSig) vecEvalInt(input *chunk.Chunk, result *chunk.Col return err } defer b.bufAllocator.put(bufElse) - if err := args[l-1].VecEvalInt(b.ctx, input, bufElse); err != nil { - return err + err = args[l-1].VecEvalInt(b.ctx, input, bufElse) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalInt(input, result) } eLse = bufElse eLseSlice = bufElse.Int64s() @@ -96,6 +137,26 @@ func (b *builtinCaseWhenIntSig) vectorized() bool { return true } +func (b *builtinCaseWhenRealSig) fallbackEvalReal(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + result.ResizeFloat64(n, false) + x := result.Float64s() + for i := 0; i < n; i++ { + res, isNull, err := b.evalReal(input.GetRow(i)) + if err != nil { + return err + } + result.SetNull(i, isNull) + if isNull { + continue + } + + x[i] = res + + } + return nil +} + func (b *builtinCaseWhenRealSig) vecEvalReal(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() args, l := b.getArgs(), len(b.getArgs()) @@ -105,6 +166,8 @@ func (b *builtinCaseWhenRealSig) vecEvalReal(input *chunk.Chunk, result *chunk.C var eLse *chunk.Column thensSlice := make([][]float64, l/2) var eLseSlice []float64 + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() for j := 0; j < l-1; j += 2 { bufWhen, err := b.bufAllocator.get(types.ETInt, n) @@ -112,8 +175,13 @@ func (b *builtinCaseWhenRealSig) vecEvalReal(input *chunk.Chunk, result *chunk.C return err } defer b.bufAllocator.put(bufWhen) - if err := args[j].VecEvalInt(b.ctx, input, bufWhen); err != nil { - return err + err = args[j].VecEvalInt(b.ctx, input, bufWhen) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalReal(input, result) } whens[j/2] = bufWhen whensSlice[j/2] = bufWhen.Int64s() @@ -123,8 +191,13 @@ func (b *builtinCaseWhenRealSig) vecEvalReal(input *chunk.Chunk, result *chunk.C return err } defer b.bufAllocator.put(bufThen) - if err := args[j+1].VecEvalReal(b.ctx, input, bufThen); err != nil { - return err + err = args[j+1].VecEvalReal(b.ctx, input, bufThen) + afterWarns = sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalReal(input, result) } thens[j/2] = bufThen thensSlice[j/2] = bufThen.Float64s() @@ -138,8 +211,13 @@ func (b *builtinCaseWhenRealSig) vecEvalReal(input *chunk.Chunk, result *chunk.C return err } defer b.bufAllocator.put(bufElse) - if err := args[l-1].VecEvalReal(b.ctx, input, bufElse); err != nil { - return err + err = args[l-1].VecEvalReal(b.ctx, input, bufElse) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalReal(input, result) } eLse = bufElse eLseSlice = bufElse.Float64s() @@ -170,6 +248,26 @@ func (b *builtinCaseWhenRealSig) vectorized() bool { return true } +func (b *builtinCaseWhenDecimalSig) fallbackEvalDecimal(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + result.ResizeDecimal(n, false) + x := result.Decimals() + for i := 0; i < n; i++ { + res, isNull, err := b.evalDecimal(input.GetRow(i)) + if err != nil { + return err + } + result.SetNull(i, isNull) + if isNull { + continue + } + + x[i] = *res + + } + return nil +} + func (b *builtinCaseWhenDecimalSig) vecEvalDecimal(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() args, l := b.getArgs(), len(b.getArgs()) @@ -179,6 +277,8 @@ func (b *builtinCaseWhenDecimalSig) vecEvalDecimal(input *chunk.Chunk, result *c var eLse *chunk.Column thensSlice := make([][]types.MyDecimal, l/2) var eLseSlice []types.MyDecimal + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() for j := 0; j < l-1; j += 2 { bufWhen, err := b.bufAllocator.get(types.ETInt, n) @@ -186,8 +286,13 @@ func (b *builtinCaseWhenDecimalSig) vecEvalDecimal(input *chunk.Chunk, result *c return err } defer b.bufAllocator.put(bufWhen) - if err := args[j].VecEvalInt(b.ctx, input, bufWhen); err != nil { - return err + err = args[j].VecEvalInt(b.ctx, input, bufWhen) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalDecimal(input, result) } whens[j/2] = bufWhen whensSlice[j/2] = bufWhen.Int64s() @@ -197,8 +302,13 @@ func (b *builtinCaseWhenDecimalSig) vecEvalDecimal(input *chunk.Chunk, result *c return err } defer b.bufAllocator.put(bufThen) - if err := args[j+1].VecEvalDecimal(b.ctx, input, bufThen); err != nil { - return err + err = args[j+1].VecEvalDecimal(b.ctx, input, bufThen) + afterWarns = sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalDecimal(input, result) } thens[j/2] = bufThen thensSlice[j/2] = bufThen.Decimals() @@ -212,8 +322,13 @@ func (b *builtinCaseWhenDecimalSig) vecEvalDecimal(input *chunk.Chunk, result *c return err } defer b.bufAllocator.put(bufElse) - if err := args[l-1].VecEvalDecimal(b.ctx, input, bufElse); err != nil { - return err + err = args[l-1].VecEvalDecimal(b.ctx, input, bufElse) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalDecimal(input, result) } eLse = bufElse eLseSlice = bufElse.Decimals() @@ -244,6 +359,23 @@ func (b *builtinCaseWhenDecimalSig) vectorized() bool { return true } +func (b *builtinCaseWhenStringSig) fallbackEvalString(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + result.ReserveString(n) + for i := 0; i < n; i++ { + res, isNull, err := b.evalString(input.GetRow(i)) + if err != nil { + return err + } + if isNull { + result.AppendNull() + continue + } + result.AppendString(res) + } + return nil +} + func (b *builtinCaseWhenStringSig) vecEvalString(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() args, l := b.getArgs(), len(b.getArgs()) @@ -251,6 +383,8 @@ func (b *builtinCaseWhenStringSig) vecEvalString(input *chunk.Chunk, result *chu whensSlice := make([][]int64, l/2) thens := make([]*chunk.Column, l/2) var eLse *chunk.Column + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() for j := 0; j < l-1; j += 2 { bufWhen, err := b.bufAllocator.get(types.ETInt, n) @@ -258,8 +392,13 @@ func (b *builtinCaseWhenStringSig) vecEvalString(input *chunk.Chunk, result *chu return err } defer b.bufAllocator.put(bufWhen) - if err := args[j].VecEvalInt(b.ctx, input, bufWhen); err != nil { - return err + err = args[j].VecEvalInt(b.ctx, input, bufWhen) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalString(input, result) } whens[j/2] = bufWhen whensSlice[j/2] = bufWhen.Int64s() @@ -269,8 +408,13 @@ func (b *builtinCaseWhenStringSig) vecEvalString(input *chunk.Chunk, result *chu return err } defer b.bufAllocator.put(bufThen) - if err := args[j+1].VecEvalString(b.ctx, input, bufThen); err != nil { - return err + err = args[j+1].VecEvalString(b.ctx, input, bufThen) + afterWarns = sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalString(input, result) } thens[j/2] = bufThen } @@ -283,8 +427,13 @@ func (b *builtinCaseWhenStringSig) vecEvalString(input *chunk.Chunk, result *chu return err } defer b.bufAllocator.put(bufElse) - if err := args[l-1].VecEvalString(b.ctx, input, bufElse); err != nil { - return err + err = args[l-1].VecEvalString(b.ctx, input, bufElse) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalString(input, result) } eLse = bufElse } @@ -319,6 +468,26 @@ func (b *builtinCaseWhenStringSig) vectorized() bool { return true } +func (b *builtinCaseWhenTimeSig) fallbackEvalTime(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + result.ResizeTime(n, false) + x := result.Times() + for i := 0; i < n; i++ { + res, isNull, err := b.evalTime(input.GetRow(i)) + if err != nil { + return err + } + result.SetNull(i, isNull) + if isNull { + continue + } + + x[i] = res + + } + return nil +} + func (b *builtinCaseWhenTimeSig) vecEvalTime(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() args, l := b.getArgs(), len(b.getArgs()) @@ -328,6 +497,8 @@ func (b *builtinCaseWhenTimeSig) vecEvalTime(input *chunk.Chunk, result *chunk.C var eLse *chunk.Column thensSlice := make([][]types.Time, l/2) var eLseSlice []types.Time + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() for j := 0; j < l-1; j += 2 { bufWhen, err := b.bufAllocator.get(types.ETInt, n) @@ -335,8 +506,13 @@ func (b *builtinCaseWhenTimeSig) vecEvalTime(input *chunk.Chunk, result *chunk.C return err } defer b.bufAllocator.put(bufWhen) - if err := args[j].VecEvalInt(b.ctx, input, bufWhen); err != nil { - return err + err = args[j].VecEvalInt(b.ctx, input, bufWhen) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalTime(input, result) } whens[j/2] = bufWhen whensSlice[j/2] = bufWhen.Int64s() @@ -346,8 +522,13 @@ func (b *builtinCaseWhenTimeSig) vecEvalTime(input *chunk.Chunk, result *chunk.C return err } defer b.bufAllocator.put(bufThen) - if err := args[j+1].VecEvalTime(b.ctx, input, bufThen); err != nil { - return err + err = args[j+1].VecEvalTime(b.ctx, input, bufThen) + afterWarns = sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalTime(input, result) } thens[j/2] = bufThen thensSlice[j/2] = bufThen.Times() @@ -361,8 +542,13 @@ func (b *builtinCaseWhenTimeSig) vecEvalTime(input *chunk.Chunk, result *chunk.C return err } defer b.bufAllocator.put(bufElse) - if err := args[l-1].VecEvalTime(b.ctx, input, bufElse); err != nil { - return err + err = args[l-1].VecEvalTime(b.ctx, input, bufElse) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalTime(input, result) } eLse = bufElse eLseSlice = bufElse.Times() @@ -393,6 +579,26 @@ func (b *builtinCaseWhenTimeSig) vectorized() bool { return true } +func (b *builtinCaseWhenDurationSig) fallbackEvalDuration(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + result.ResizeGoDuration(n, false) + x := result.GoDurations() + for i := 0; i < n; i++ { + res, isNull, err := b.evalDuration(input.GetRow(i)) + if err != nil { + return err + } + result.SetNull(i, isNull) + if isNull { + continue + } + + x[i] = res.Duration + + } + return nil +} + func (b *builtinCaseWhenDurationSig) vecEvalDuration(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() args, l := b.getArgs(), len(b.getArgs()) @@ -402,6 +608,8 @@ func (b *builtinCaseWhenDurationSig) vecEvalDuration(input *chunk.Chunk, result var eLse *chunk.Column thensSlice := make([][]time.Duration, l/2) var eLseSlice []time.Duration + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() for j := 0; j < l-1; j += 2 { bufWhen, err := b.bufAllocator.get(types.ETInt, n) @@ -409,8 +617,13 @@ func (b *builtinCaseWhenDurationSig) vecEvalDuration(input *chunk.Chunk, result return err } defer b.bufAllocator.put(bufWhen) - if err := args[j].VecEvalInt(b.ctx, input, bufWhen); err != nil { - return err + err = args[j].VecEvalInt(b.ctx, input, bufWhen) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalDuration(input, result) } whens[j/2] = bufWhen whensSlice[j/2] = bufWhen.Int64s() @@ -420,8 +633,13 @@ func (b *builtinCaseWhenDurationSig) vecEvalDuration(input *chunk.Chunk, result return err } defer b.bufAllocator.put(bufThen) - if err := args[j+1].VecEvalDuration(b.ctx, input, bufThen); err != nil { - return err + err = args[j+1].VecEvalDuration(b.ctx, input, bufThen) + afterWarns = sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalDuration(input, result) } thens[j/2] = bufThen thensSlice[j/2] = bufThen.GoDurations() @@ -435,8 +653,13 @@ func (b *builtinCaseWhenDurationSig) vecEvalDuration(input *chunk.Chunk, result return err } defer b.bufAllocator.put(bufElse) - if err := args[l-1].VecEvalDuration(b.ctx, input, bufElse); err != nil { - return err + err = args[l-1].VecEvalDuration(b.ctx, input, bufElse) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalDuration(input, result) } eLse = bufElse eLseSlice = bufElse.GoDurations() @@ -467,6 +690,23 @@ func (b *builtinCaseWhenDurationSig) vectorized() bool { return true } +func (b *builtinCaseWhenJSONSig) fallbackEvalJSON(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + result.ReserveJSON(n) + for i := 0; i < n; i++ { + res, isNull, err := b.evalJSON(input.GetRow(i)) + if err != nil { + return err + } + if isNull { + result.AppendNull() + continue + } + result.AppendJSON(res) + } + return nil +} + func (b *builtinCaseWhenJSONSig) vecEvalJSON(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() args, l := b.getArgs(), len(b.getArgs()) @@ -474,6 +714,8 @@ func (b *builtinCaseWhenJSONSig) vecEvalJSON(input *chunk.Chunk, result *chunk.C whensSlice := make([][]int64, l/2) thens := make([]*chunk.Column, l/2) var eLse *chunk.Column + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() for j := 0; j < l-1; j += 2 { bufWhen, err := b.bufAllocator.get(types.ETInt, n) @@ -481,8 +723,13 @@ func (b *builtinCaseWhenJSONSig) vecEvalJSON(input *chunk.Chunk, result *chunk.C return err } defer b.bufAllocator.put(bufWhen) - if err := args[j].VecEvalInt(b.ctx, input, bufWhen); err != nil { - return err + err = args[j].VecEvalInt(b.ctx, input, bufWhen) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalJSON(input, result) } whens[j/2] = bufWhen whensSlice[j/2] = bufWhen.Int64s() @@ -492,8 +739,13 @@ func (b *builtinCaseWhenJSONSig) vecEvalJSON(input *chunk.Chunk, result *chunk.C return err } defer b.bufAllocator.put(bufThen) - if err := args[j+1].VecEvalJSON(b.ctx, input, bufThen); err != nil { - return err + err = args[j+1].VecEvalJSON(b.ctx, input, bufThen) + afterWarns = sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalJSON(input, result) } thens[j/2] = bufThen } @@ -506,8 +758,13 @@ func (b *builtinCaseWhenJSONSig) vecEvalJSON(input *chunk.Chunk, result *chunk.C return err } defer b.bufAllocator.put(bufElse) - if err := args[l-1].VecEvalJSON(b.ctx, input, bufElse); err != nil { - return err + err = args[l-1].VecEvalJSON(b.ctx, input, bufElse) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalJSON(input, result) } eLse = bufElse } @@ -542,6 +799,25 @@ func (b *builtinCaseWhenJSONSig) vectorized() bool { return true } +func (b *builtinIfNullIntSig) fallbackEvalInt(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + x := result.Int64s() + for i := 0; i < n; i++ { + res, isNull, err := b.evalInt(input.GetRow(i)) + if err != nil { + return err + } + result.SetNull(i, isNull) + if isNull { + continue + } + + x[i] = res + + } + return nil +} + func (b *builtinIfNullIntSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() if err := b.args[0].VecEvalInt(b.ctx, input, result); err != nil { @@ -552,10 +828,16 @@ func (b *builtinIfNullIntSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colum return err } defer b.bufAllocator.put(buf1) - if err := b.args[1].VecEvalInt(b.ctx, input, buf1); err != nil { - return err + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() + err = b.args[1].VecEvalInt(b.ctx, input, buf1) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalInt(input, result) } - arg0 := result.Int64s() arg1 := buf1.Int64s() for i := 0; i < n; i++ { @@ -571,6 +853,25 @@ func (b *builtinIfNullIntSig) vectorized() bool { return true } +func (b *builtinIfNullRealSig) fallbackEvalReal(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + x := result.Float64s() + for i := 0; i < n; i++ { + res, isNull, err := b.evalReal(input.GetRow(i)) + if err != nil { + return err + } + result.SetNull(i, isNull) + if isNull { + continue + } + + x[i] = res + + } + return nil +} + func (b *builtinIfNullRealSig) vecEvalReal(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() if err := b.args[0].VecEvalReal(b.ctx, input, result); err != nil { @@ -581,10 +882,16 @@ func (b *builtinIfNullRealSig) vecEvalReal(input *chunk.Chunk, result *chunk.Col return err } defer b.bufAllocator.put(buf1) - if err := b.args[1].VecEvalReal(b.ctx, input, buf1); err != nil { - return err + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() + err = b.args[1].VecEvalReal(b.ctx, input, buf1) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalReal(input, result) } - arg0 := result.Float64s() arg1 := buf1.Float64s() for i := 0; i < n; i++ { @@ -600,6 +907,25 @@ func (b *builtinIfNullRealSig) vectorized() bool { return true } +func (b *builtinIfNullDecimalSig) fallbackEvalDecimal(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + x := result.Decimals() + for i := 0; i < n; i++ { + res, isNull, err := b.evalDecimal(input.GetRow(i)) + if err != nil { + return err + } + result.SetNull(i, isNull) + if isNull { + continue + } + + x[i] = *res + + } + return nil +} + func (b *builtinIfNullDecimalSig) vecEvalDecimal(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() if err := b.args[0].VecEvalDecimal(b.ctx, input, result); err != nil { @@ -610,10 +936,16 @@ func (b *builtinIfNullDecimalSig) vecEvalDecimal(input *chunk.Chunk, result *chu return err } defer b.bufAllocator.put(buf1) - if err := b.args[1].VecEvalDecimal(b.ctx, input, buf1); err != nil { - return err + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() + err = b.args[1].VecEvalDecimal(b.ctx, input, buf1) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalDecimal(input, result) } - arg0 := result.Decimals() arg1 := buf1.Decimals() for i := 0; i < n; i++ { @@ -629,6 +961,23 @@ func (b *builtinIfNullDecimalSig) vectorized() bool { return true } +func (b *builtinIfNullStringSig) fallbackEvalString(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + result.ReserveString(n) + for i := 0; i < n; i++ { + res, isNull, err := b.evalString(input.GetRow(i)) + if err != nil { + return err + } + if isNull { + result.AppendNull() + continue + } + result.AppendString(res) + } + return nil +} + func (b *builtinIfNullStringSig) vecEvalString(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() buf0, err := b.bufAllocator.get(types.ETString, n) @@ -644,8 +993,15 @@ func (b *builtinIfNullStringSig) vecEvalString(input *chunk.Chunk, result *chunk return err } defer b.bufAllocator.put(buf1) - if err := b.args[1].VecEvalString(b.ctx, input, buf1); err != nil { - return err + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() + err = b.args[1].VecEvalString(b.ctx, input, buf1) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalString(input, result) } result.ReserveString(n) @@ -665,6 +1021,25 @@ func (b *builtinIfNullStringSig) vectorized() bool { return true } +func (b *builtinIfNullTimeSig) fallbackEvalTime(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + x := result.Times() + for i := 0; i < n; i++ { + res, isNull, err := b.evalTime(input.GetRow(i)) + if err != nil { + return err + } + result.SetNull(i, isNull) + if isNull { + continue + } + + x[i] = res + + } + return nil +} + func (b *builtinIfNullTimeSig) vecEvalTime(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() if err := b.args[0].VecEvalTime(b.ctx, input, result); err != nil { @@ -675,10 +1050,16 @@ func (b *builtinIfNullTimeSig) vecEvalTime(input *chunk.Chunk, result *chunk.Col return err } defer b.bufAllocator.put(buf1) - if err := b.args[1].VecEvalTime(b.ctx, input, buf1); err != nil { - return err + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() + err = b.args[1].VecEvalTime(b.ctx, input, buf1) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalTime(input, result) } - arg0 := result.Times() arg1 := buf1.Times() for i := 0; i < n; i++ { @@ -694,6 +1075,25 @@ func (b *builtinIfNullTimeSig) vectorized() bool { return true } +func (b *builtinIfNullDurationSig) fallbackEvalDuration(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + x := result.GoDurations() + for i := 0; i < n; i++ { + res, isNull, err := b.evalDuration(input.GetRow(i)) + if err != nil { + return err + } + result.SetNull(i, isNull) + if isNull { + continue + } + + x[i] = res.Duration + + } + return nil +} + func (b *builtinIfNullDurationSig) vecEvalDuration(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() if err := b.args[0].VecEvalDuration(b.ctx, input, result); err != nil { @@ -704,10 +1104,16 @@ func (b *builtinIfNullDurationSig) vecEvalDuration(input *chunk.Chunk, result *c return err } defer b.bufAllocator.put(buf1) - if err := b.args[1].VecEvalDuration(b.ctx, input, buf1); err != nil { - return err + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() + err = b.args[1].VecEvalDuration(b.ctx, input, buf1) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalDuration(input, result) } - arg0 := result.GoDurations() arg1 := buf1.GoDurations() for i := 0; i < n; i++ { @@ -723,6 +1129,23 @@ func (b *builtinIfNullDurationSig) vectorized() bool { return true } +func (b *builtinIfNullJSONSig) fallbackEvalJSON(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + result.ReserveJSON(n) + for i := 0; i < n; i++ { + res, isNull, err := b.evalJSON(input.GetRow(i)) + if err != nil { + return err + } + if isNull { + result.AppendNull() + continue + } + result.AppendJSON(res) + } + return nil +} + func (b *builtinIfNullJSONSig) vecEvalJSON(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() buf0, err := b.bufAllocator.get(types.ETJson, n) @@ -738,8 +1161,15 @@ func (b *builtinIfNullJSONSig) vecEvalJSON(input *chunk.Chunk, result *chunk.Col return err } defer b.bufAllocator.put(buf1) - if err := b.args[1].VecEvalJSON(b.ctx, input, buf1); err != nil { - return err + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() + err = b.args[1].VecEvalJSON(b.ctx, input, buf1) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalJSON(input, result) } result.ReserveJSON(n) @@ -759,6 +1189,25 @@ func (b *builtinIfNullJSONSig) vectorized() bool { return true } +func (b *builtinIfIntSig) fallbackEvalInt(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + x := result.Int64s() + for i := 0; i < n; i++ { + res, isNull, err := b.evalInt(input.GetRow(i)) + if err != nil { + return err + } + result.SetNull(i, isNull) + if isNull { + continue + } + + x[i] = res + + } + return nil +} + func (b *builtinIfIntSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() buf0, err := b.bufAllocator.get(types.ETInt, n) @@ -769,16 +1218,29 @@ func (b *builtinIfIntSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) e if err := b.args[0].VecEvalInt(b.ctx, input, buf0); err != nil { return err } - if err := b.args[1].VecEvalInt(b.ctx, input, result); err != nil { - return err + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() + err = b.args[1].VecEvalInt(b.ctx, input, result) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalInt(input, result) } + buf2, err := b.bufAllocator.get(types.ETInt, n) if err != nil { return err } defer b.bufAllocator.put(buf2) - if err := b.args[2].VecEvalInt(b.ctx, input, buf2); err != nil { - return err + err = b.args[2].VecEvalInt(b.ctx, input, buf2) + afterWarns = sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalInt(input, result) } arg0 := buf0.Int64s() @@ -805,6 +1267,25 @@ func (b *builtinIfIntSig) vectorized() bool { return true } +func (b *builtinIfRealSig) fallbackEvalReal(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + x := result.Float64s() + for i := 0; i < n; i++ { + res, isNull, err := b.evalReal(input.GetRow(i)) + if err != nil { + return err + } + result.SetNull(i, isNull) + if isNull { + continue + } + + x[i] = res + + } + return nil +} + func (b *builtinIfRealSig) vecEvalReal(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() buf0, err := b.bufAllocator.get(types.ETInt, n) @@ -815,16 +1296,29 @@ func (b *builtinIfRealSig) vecEvalReal(input *chunk.Chunk, result *chunk.Column) if err := b.args[0].VecEvalInt(b.ctx, input, buf0); err != nil { return err } - if err := b.args[1].VecEvalReal(b.ctx, input, result); err != nil { - return err + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() + err = b.args[1].VecEvalReal(b.ctx, input, result) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalReal(input, result) } + buf2, err := b.bufAllocator.get(types.ETReal, n) if err != nil { return err } defer b.bufAllocator.put(buf2) - if err := b.args[2].VecEvalReal(b.ctx, input, buf2); err != nil { - return err + err = b.args[2].VecEvalReal(b.ctx, input, buf2) + afterWarns = sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalReal(input, result) } arg0 := buf0.Int64s() @@ -851,6 +1345,25 @@ func (b *builtinIfRealSig) vectorized() bool { return true } +func (b *builtinIfDecimalSig) fallbackEvalDecimal(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + x := result.Decimals() + for i := 0; i < n; i++ { + res, isNull, err := b.evalDecimal(input.GetRow(i)) + if err != nil { + return err + } + result.SetNull(i, isNull) + if isNull { + continue + } + + x[i] = *res + + } + return nil +} + func (b *builtinIfDecimalSig) vecEvalDecimal(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() buf0, err := b.bufAllocator.get(types.ETInt, n) @@ -861,16 +1374,29 @@ func (b *builtinIfDecimalSig) vecEvalDecimal(input *chunk.Chunk, result *chunk.C if err := b.args[0].VecEvalInt(b.ctx, input, buf0); err != nil { return err } - if err := b.args[1].VecEvalDecimal(b.ctx, input, result); err != nil { - return err + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() + err = b.args[1].VecEvalDecimal(b.ctx, input, result) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalDecimal(input, result) } + buf2, err := b.bufAllocator.get(types.ETDecimal, n) if err != nil { return err } defer b.bufAllocator.put(buf2) - if err := b.args[2].VecEvalDecimal(b.ctx, input, buf2); err != nil { - return err + err = b.args[2].VecEvalDecimal(b.ctx, input, buf2) + afterWarns = sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalDecimal(input, result) } arg0 := buf0.Int64s() @@ -897,6 +1423,23 @@ func (b *builtinIfDecimalSig) vectorized() bool { return true } +func (b *builtinIfStringSig) fallbackEvalString(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + result.ReserveString(n) + for i := 0; i < n; i++ { + res, isNull, err := b.evalString(input.GetRow(i)) + if err != nil { + return err + } + if isNull { + result.AppendNull() + continue + } + result.AppendString(res) + } + return nil +} + func (b *builtinIfStringSig) vecEvalString(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() buf0, err := b.bufAllocator.get(types.ETInt, n) @@ -907,21 +1450,34 @@ func (b *builtinIfStringSig) vecEvalString(input *chunk.Chunk, result *chunk.Col if err := b.args[0].VecEvalInt(b.ctx, input, buf0); err != nil { return err } + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() buf1, err := b.bufAllocator.get(types.ETString, n) if err != nil { return err } defer b.bufAllocator.put(buf1) - if err := b.args[1].VecEvalString(b.ctx, input, buf1); err != nil { - return err + err = b.args[1].VecEvalString(b.ctx, input, buf1) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalString(input, result) } + buf2, err := b.bufAllocator.get(types.ETString, n) if err != nil { return err } defer b.bufAllocator.put(buf2) - if err := b.args[2].VecEvalString(b.ctx, input, buf2); err != nil { - return err + err = b.args[2].VecEvalString(b.ctx, input, buf2) + afterWarns = sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalString(input, result) } result.ReserveString(n) @@ -951,6 +1507,25 @@ func (b *builtinIfStringSig) vectorized() bool { return true } +func (b *builtinIfTimeSig) fallbackEvalTime(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + x := result.Times() + for i := 0; i < n; i++ { + res, isNull, err := b.evalTime(input.GetRow(i)) + if err != nil { + return err + } + result.SetNull(i, isNull) + if isNull { + continue + } + + x[i] = res + + } + return nil +} + func (b *builtinIfTimeSig) vecEvalTime(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() buf0, err := b.bufAllocator.get(types.ETInt, n) @@ -961,16 +1536,29 @@ func (b *builtinIfTimeSig) vecEvalTime(input *chunk.Chunk, result *chunk.Column) if err := b.args[0].VecEvalInt(b.ctx, input, buf0); err != nil { return err } - if err := b.args[1].VecEvalTime(b.ctx, input, result); err != nil { - return err + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() + err = b.args[1].VecEvalTime(b.ctx, input, result) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalTime(input, result) } + buf2, err := b.bufAllocator.get(types.ETDatetime, n) if err != nil { return err } defer b.bufAllocator.put(buf2) - if err := b.args[2].VecEvalTime(b.ctx, input, buf2); err != nil { - return err + err = b.args[2].VecEvalTime(b.ctx, input, buf2) + afterWarns = sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalTime(input, result) } arg0 := buf0.Int64s() @@ -997,6 +1585,25 @@ func (b *builtinIfTimeSig) vectorized() bool { return true } +func (b *builtinIfDurationSig) fallbackEvalDuration(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + x := result.GoDurations() + for i := 0; i < n; i++ { + res, isNull, err := b.evalDuration(input.GetRow(i)) + if err != nil { + return err + } + result.SetNull(i, isNull) + if isNull { + continue + } + + x[i] = res.Duration + + } + return nil +} + func (b *builtinIfDurationSig) vecEvalDuration(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() buf0, err := b.bufAllocator.get(types.ETInt, n) @@ -1007,16 +1614,29 @@ func (b *builtinIfDurationSig) vecEvalDuration(input *chunk.Chunk, result *chunk if err := b.args[0].VecEvalInt(b.ctx, input, buf0); err != nil { return err } - if err := b.args[1].VecEvalDuration(b.ctx, input, result); err != nil { - return err + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() + err = b.args[1].VecEvalDuration(b.ctx, input, result) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalDuration(input, result) } + buf2, err := b.bufAllocator.get(types.ETDuration, n) if err != nil { return err } defer b.bufAllocator.put(buf2) - if err := b.args[2].VecEvalDuration(b.ctx, input, buf2); err != nil { - return err + err = b.args[2].VecEvalDuration(b.ctx, input, buf2) + afterWarns = sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalDuration(input, result) } arg0 := buf0.Int64s() @@ -1043,6 +1663,23 @@ func (b *builtinIfDurationSig) vectorized() bool { return true } +func (b *builtinIfJSONSig) fallbackEvalJSON(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + result.ReserveJSON(n) + for i := 0; i < n; i++ { + res, isNull, err := b.evalJSON(input.GetRow(i)) + if err != nil { + return err + } + if isNull { + result.AppendNull() + continue + } + result.AppendJSON(res) + } + return nil +} + func (b *builtinIfJSONSig) vecEvalJSON(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() buf0, err := b.bufAllocator.get(types.ETInt, n) @@ -1053,21 +1690,34 @@ func (b *builtinIfJSONSig) vecEvalJSON(input *chunk.Chunk, result *chunk.Column) if err := b.args[0].VecEvalInt(b.ctx, input, buf0); err != nil { return err } + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() buf1, err := b.bufAllocator.get(types.ETJson, n) if err != nil { return err } defer b.bufAllocator.put(buf1) - if err := b.args[1].VecEvalJSON(b.ctx, input, buf1); err != nil { - return err + err = b.args[1].VecEvalJSON(b.ctx, input, buf1) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalJSON(input, result) } + buf2, err := b.bufAllocator.get(types.ETJson, n) if err != nil { return err } defer b.bufAllocator.put(buf2) - if err := b.args[2].VecEvalJSON(b.ctx, input, buf2); err != nil { - return err + err = b.args[2].VecEvalJSON(b.ctx, input, buf2) + afterWarns = sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEvalJSON(input, result) } result.ReserveJSON(n) diff --git a/expression/generator/control_vec.go b/expression/generator/control_vec.go index 733c73603e777..0b77d127d6f9a 100644 --- a/expression/generator/control_vec.go +++ b/expression/generator/control_vec.go @@ -49,10 +49,54 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" ) + +// NOTE: Control expressions optionally evaluate some branches depending on conditions, but vectorization executes all +// branches, during which the unnecessary branches may return errors or warnings. To avoid this case, when branches +// meet errors or warnings, the vectorization falls back the scalar execution. + ` var builtinCaseWhenVec = template.Must(template.New("builtinCaseWhenVec").Parse(` {{ range .Sigs }}{{ with .Arg0 }} +func (b *builtinCaseWhen{{ .TypeName }}Sig) fallbackEval{{ .TypeName }}(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + {{- if .Fixed }} + result.Resize{{ .TypeNameInColumn }}(n, false) + x := result.{{ .TypeNameInColumn }}s() + for i := 0; i < n; i++ { + res, isNull, err := b.eval{{ .TypeName }}(input.GetRow(i)) + if err != nil { + return err + } + result.SetNull(i, isNull) + if isNull { + continue + } + {{ if eq .TypeName "Decimal" }} + x[i] = *res + {{ else if eq .TypeName "Duration" }} + x[i] = res.Duration + {{ else }} + x[i] = res + {{ end }} + } + {{ else }} + result.Reserve{{ .TypeNameInColumn }}(n) + for i := 0; i < n; i++ { + res, isNull, err := b.eval{{ .TypeName }}(input.GetRow(i)) + if err != nil { + return err + } + if isNull { + result.AppendNull() + continue + } + result.Append{{ .TypeNameInColumn }}(res) + } + {{ end -}} + return nil +} + func (b *builtinCaseWhen{{ .TypeName }}Sig) vecEval{{ .TypeName }}(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() args, l := b.getArgs(), len(b.getArgs()) @@ -64,6 +108,8 @@ func (b *builtinCaseWhen{{ .TypeName }}Sig) vecEval{{ .TypeName }}(input *chunk. thensSlice := make([][]{{.TypeNameGo}}, l/2) var eLseSlice []{{.TypeNameGo}} {{- end }} + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() for j := 0; j < l-1; j+=2 { bufWhen, err := b.bufAllocator.get(types.ETInt, n) @@ -71,8 +117,13 @@ func (b *builtinCaseWhen{{ .TypeName }}Sig) vecEval{{ .TypeName }}(input *chunk. return err } defer b.bufAllocator.put(bufWhen) - if err := args[j].VecEvalInt(b.ctx, input, bufWhen); err != nil { - return err + err = args[j].VecEvalInt(b.ctx, input, bufWhen) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEval{{ .TypeName }}(input, result) } whens[j/2] = bufWhen whensSlice[j/2] = bufWhen.Int64s() @@ -82,8 +133,13 @@ func (b *builtinCaseWhen{{ .TypeName }}Sig) vecEval{{ .TypeName }}(input *chunk. return err } defer b.bufAllocator.put(bufThen) - if err := args[j+1].VecEval{{ .TypeName }}(b.ctx, input, bufThen); err != nil { - return err + err = args[j+1].VecEval{{ .TypeName }}(b.ctx, input, bufThen) + afterWarns = sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEval{{ .TypeName }}(input, result) } thens[j/2] = bufThen {{- if .Fixed }} @@ -99,8 +155,13 @@ func (b *builtinCaseWhen{{ .TypeName }}Sig) vecEval{{ .TypeName }}(input *chunk. return err } defer b.bufAllocator.put(bufElse) - if err := args[l-1].VecEval{{ .TypeName }}(b.ctx, input, bufElse); err != nil { - return err + err = args[l-1].VecEval{{ .TypeName }}(b.ctx, input, bufElse) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEval{{ .TypeName }}(input, result) } eLse = bufElse {{- if .Fixed }} @@ -163,9 +224,46 @@ func (b *builtinCaseWhen{{ .TypeName }}Sig) vectorized() bool { var builtinIfNullVec = template.Must(template.New("builtinIfNullVec").Parse(` {{ range .Sigs }}{{ with .Arg0 }} -func (b *builtinIfNull{{ .TypeName }}Sig) vecEval{{ .TypeName }}(input *chunk.Chunk, result *chunk.Column) error { +func (b *builtinIfNull{{ .TypeName }}Sig) fallbackEval{{ .TypeName }}(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() + {{- if .Fixed }} + x := result.{{ .TypeNameInColumn }}s() + for i := 0; i < n; i++ { + res, isNull, err := b.eval{{ .TypeName }}(input.GetRow(i)) + if err != nil { + return err + } + result.SetNull(i, isNull) + if isNull { + continue + } + {{ if eq .TypeName "Decimal" }} + x[i] = *res + {{ else if eq .TypeName "Duration" }} + x[i] = res.Duration + {{ else }} + x[i] = res + {{ end }} + } + {{ else }} + result.Reserve{{ .TypeNameInColumn }}(n) + for i := 0; i < n; i++ { + res, isNull, err := b.eval{{ .TypeName }}(input.GetRow(i)) + if err != nil { + return err + } + if isNull { + result.AppendNull() + continue + } + result.Append{{ .TypeNameInColumn }}(res) + } + {{ end -}} + return nil +} +func (b *builtinIfNull{{ .TypeName }}Sig) vecEval{{ .TypeName }}(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() {{- if .Fixed }} if err := b.args[0].VecEval{{ .TypeName }}(b.ctx, input, result); err != nil { return err @@ -175,10 +273,16 @@ func (b *builtinIfNull{{ .TypeName }}Sig) vecEval{{ .TypeName }}(input *chunk.Ch return err } defer b.bufAllocator.put(buf1) - if err := b.args[1].VecEval{{ .TypeName }}(b.ctx, input, buf1); err != nil { - return err - } - + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() + err = b.args[1].VecEval{{ .TypeName }}(b.ctx, input, buf1) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEval{{ .TypeName }}(input, result) + } arg0 := result.{{ .TypeNameInColumn }}s() arg1 := buf1.{{ .TypeNameInColumn }}s() for i := 0; i < n; i++ { @@ -201,9 +305,16 @@ func (b *builtinIfNull{{ .TypeName }}Sig) vecEval{{ .TypeName }}(input *chunk.Ch return err } defer b.bufAllocator.put(buf1) - if err := b.args[1].VecEval{{ .TypeName }}(b.ctx, input, buf1); err != nil { - return err - } + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() + err = b.args[1].VecEval{{ .TypeName }}(b.ctx, input, buf1) + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEval{{ .TypeName }}(input,result) + } result.Reserve{{ .TypeNameInColumn }}(n) for i := 0; i < n; i++ { @@ -228,6 +339,44 @@ func (b *builtinIfNull{{ .TypeName }}Sig) vectorized() bool { var builtinIfVec = template.Must(template.New("builtinIfVec").Parse(` {{ range .Sigs }}{{ with .Arg0 }} +func (b *builtinIf{{ .TypeName }}Sig) fallbackEval{{ .TypeName }}(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + {{- if .Fixed }} + x := result.{{ .TypeNameInColumn }}s() + for i := 0; i < n; i++ { + res, isNull, err := b.eval{{ .TypeName }}(input.GetRow(i)) + if err != nil { + return err + } + result.SetNull(i, isNull) + if isNull { + continue + } + {{ if eq .TypeName "Decimal" }} + x[i] = *res + {{ else if eq .TypeName "Duration" }} + x[i] = res.Duration + {{ else }} + x[i] = res + {{ end }} + } + {{ else }} + result.Reserve{{ .TypeNameInColumn }}(n) + for i := 0; i < n; i++ { + res, isNull, err := b.eval{{ .TypeName }}(input.GetRow(i)) + if err != nil { + return err + } + if isNull { + result.AppendNull() + continue + } + result.Append{{ .TypeNameInColumn }}(res) + } + {{ end -}} + return nil +} + func (b *builtinIf{{ .TypeName }}Sig) vecEval{{ .TypeName }}(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() buf0, err := b.bufAllocator.get(types.ETInt, n) @@ -238,30 +387,39 @@ func (b *builtinIf{{ .TypeName }}Sig) vecEval{{ .TypeName }}(input *chunk.Chunk, if err := b.args[0].VecEvalInt(b.ctx, input, buf0); err != nil { return err } - + sc := b.ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() {{- if .Fixed }} - if err := b.args[1].VecEval{{ .TypeName }}(b.ctx, input, result); err != nil { - return err - } + err = b.args[1].VecEval{{ .TypeName }}(b.ctx, input, result) {{- else }} buf1, err := b.bufAllocator.get(types.ET{{ .ETName }}, n) if err != nil { return err } defer b.bufAllocator.put(buf1) - if err := b.args[1].VecEval{{ .TypeName }}(b.ctx, input, buf1); err != nil { - return err - } + err = b.args[1].VecEval{{ .TypeName }}(b.ctx, input, buf1) {{- end }} + afterWarns := sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEval{{ .TypeName }}(input, result) + } + buf2, err := b.bufAllocator.get(types.ET{{ .ETName }}, n) if err != nil { return err } defer b.bufAllocator.put(buf2) - if err := b.args[2].VecEval{{ .TypeName }}(b.ctx, input, buf2); err != nil { - return err - } - + err = b.args[2].VecEval{{ .TypeName }}(b.ctx, input, buf2) + afterWarns = sc.WarningCount() + if err != nil || afterWarns > beforeWarns { + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + } + return b.fallbackEval{{ .TypeName }}(input, result) + } {{ if not .Fixed }} result.Reserve{{ .TypeNameInColumn }}(n) {{- end }} diff --git a/expression/integration_test.go b/expression/integration_test.go index e9d24de988a4b..7b645be880722 100755 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -2800,6 +2800,13 @@ func (s *testIntegrationSuite2) TestBuiltin(c *C) { result.Check(testkit.Rows(" 4")) result = tk.MustQuery("select * from t where b = case when a is null then 4 when a = 'str5' then 7 else 9 end") result.Check(testkit.Rows(" 4")) + // test warnings + tk.MustQuery("select case when b=0 then 1 else 1/b end from t") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select if(b=0, 1, 1/b) from t") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select ifnull(b, b/0) from t") + tk.MustQuery("show warnings").Check(testkit.Rows()) tk.MustQuery("select case 2.0 when 2.0 then 3.0 when 3.0 then 2.0 end").Check(testkit.Rows("3.0")) tk.MustQuery("select case 2.0 when 3.0 then 2.0 when 4.0 then 3.0 else 5.0 end").Check(testkit.Rows("5.0")) diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index fcfb774b66095..e052fb23cdffd 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -368,6 +368,20 @@ func (sc *StatementContext) SetWarnings(warns []SQLWarn) { sc.mu.Unlock() } +// TruncateWarnings truncates wanrings begin from start and returns the truncated warnings. +func (sc *StatementContext) TruncateWarnings(start int) []SQLWarn { + sc.mu.Lock() + defer sc.mu.Unlock() + sz := len(sc.mu.warnings) - start + if sz <= 0 { + return nil + } + ret := make([]SQLWarn, sz) + copy(ret, sc.mu.warnings[start:]) + sc.mu.warnings = sc.mu.warnings[:start] + return ret +} + // AppendWarning appends a warning with level 'Warning'. func (sc *StatementContext) AppendWarning(warn error) { sc.mu.Lock()