From 06342ddbe28851ef7ba19ffc0c7e20b0c3cb49c3 Mon Sep 17 00:00:00 2001 From: Yuanjia Zhang Date: Wed, 28 Aug 2019 12:00:35 +0800 Subject: [PATCH] expression: make `builtinCastIntAsInt` support vectorized evaluation (#11826) --- expression/builtin.go | 1 + expression/builtin_cast.go | 21 ++++++++++ expression/builtin_cast_bench_test.go | 60 +++++++++++++++++++++++++++ expression/builtin_cast_test.go | 27 ++++++++++++ expression/vectorized.go | 5 ++- 5 files changed, 113 insertions(+), 1 deletion(-) create mode 100644 expression/builtin_cast_bench_test.go diff --git a/expression/builtin.go b/expression/builtin.go index 437e7a81237e1..86b14ea6d2496 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -297,6 +297,7 @@ func (b *baseBuiltinFunc) cloneFrom(from *baseBuiltinFunc) { b.ctx = from.ctx b.tp = from.tp b.pbCode = from.pbCode + b.columnBufferAllocator = newLocalSliceBuffer(len(b.args)) } func (b *baseBuiltinFunc) Clone() builtinFunc { diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index be3b5273330b3..ab12b728f628d 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -435,6 +435,10 @@ func (b *builtinCastIntAsIntSig) Clone() builtinFunc { return newSig } +func (b *builtinCastIntAsIntSig) vectorized() bool { + return true +} + func (b *builtinCastIntAsIntSig) evalInt(row chunk.Row) (res int64, isNull bool, err error) { res, isNull, err = b.args[0].EvalInt(b.ctx, row) if isNull || err != nil { @@ -446,6 +450,23 @@ func (b *builtinCastIntAsIntSig) evalInt(row chunk.Row) (res int64, isNull bool, return } +func (b *builtinCastIntAsIntSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) error { + if err := b.args[0].VecEvalInt(b.ctx, input, result); err != nil { + return err + } + if b.inUnion && mysql.HasUnsignedFlag(b.tp.Flag) { + i64s := result.Int64s() + // the null array of result is already set by its child args[0], + // so we skip the null check here, which keeps this loop simple and improves its performance. 
+ for i := range i64s { + if i64s[i] < 0 { + i64s[i] = 0 + } + } + } + return nil +} + type builtinCastIntAsRealSig struct { baseBuiltinCastFunc } diff --git a/expression/builtin_cast_bench_test.go b/expression/builtin_cast_bench_test.go new file mode 100644 index 0000000000000..e68d0bfdf175d --- /dev/null +++ b/expression/builtin_cast_bench_test.go @@ -0,0 +1,60 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "math/rand" + "testing" + + "github.com/pingcap/parser/mysql" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/mock" +) + +func genCastIntAsInt() (*builtinCastIntAsIntSig, *chunk.Chunk, *chunk.Column) { + col := &Column{RetType: types.NewFieldType(mysql.TypeLonglong), Index: 0} + baseFunc := newBaseBuiltinFunc(mock.NewContext(), []Expression{col}) + baseCast := newBaseBuiltinCastFunc(baseFunc, false) + cast := &builtinCastIntAsIntSig{baseCast} + input := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeLonglong)}, 1024) + for i := 0; i < 1024; i++ { + input.AppendInt64(0, rand.Int63n(10000)-5000) + } + result := chunk.NewColumn(types.NewFieldType(mysql.TypeLonglong), 1024) + return cast, input, result +} + +func BenchmarkCastIntAsIntRow(b *testing.B) { + cast, input, _ := genCastIntAsInt() + it := chunk.NewIterator4Chunk(input) + b.ResetTimer() + for i := 0; i < b.N; i++ { + for row := it.Begin(); row != it.End(); row = it.Next() { + if _, _, err := cast.evalInt(row); err != 
nil { + b.Fatal(err) + } + } + } +} + +func BenchmarkCastIntAsIntVec(b *testing.B) { + cast, input, result := genCastIntAsInt() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := cast.vecEvalInt(input, result); err != nil { + b.Fatal(err) + } + } +} diff --git a/expression/builtin_cast_test.go b/expression/builtin_cast_test.go index eb2ca59062599..c81f2fcd556ed 100644 --- a/expression/builtin_cast_test.go +++ b/expression/builtin_cast_test.go @@ -1375,3 +1375,30 @@ func (s *testEvaluatorSuite) TestWrapWithCastAsJSON(c *C) { c.Assert(ok, IsTrue) c.Assert(output, Equals, input) } + +func (s *testEvaluatorSuite) TestCastIntAsIntVec(c *C) { + cast, input, result := genCastIntAsInt() + c.Assert(cast.vecEvalInt(input, result), IsNil) + i64s := result.Int64s() + it := chunk.NewIterator4Chunk(input) + i := 0 + for row := it.Begin(); row != it.End(); row = it.Next() { + v, _, err := cast.evalInt(row) + c.Assert(err, IsNil) + c.Assert(v, Equals, i64s[i]) + i++ + } + + cast.inUnion = true + cast.getRetTp().Flag |= mysql.UnsignedFlag + c.Assert(cast.vecEvalInt(input, result), IsNil) + i64s = result.Int64s() + it = chunk.NewIterator4Chunk(input) + i = 0 + for row := it.Begin(); row != it.End(); row = it.Next() { + v, _, err := cast.evalInt(row) + c.Assert(err, IsNil) + c.Assert(v, Equals, i64s[i]) + i++ + } +} diff --git a/expression/vectorized.go b/expression/vectorized.go index b6e32e1be8e0c..51ecc20c4878a 100644 --- a/expression/vectorized.go +++ b/expression/vectorized.go @@ -21,7 +21,10 @@ import ( ) func genVecFromConstExpr(ctx sessionctx.Context, expr Expression, targetType types.EvalType, input *chunk.Chunk, result *chunk.Column) error { - n := input.NumRows() + n := 1 + if input != nil { + n = input.NumRows() + } switch targetType { case types.ETInt: result.ResizeInt64(n)