Skip to content

Commit

Permalink
Merge branch 'master' into refine-prefer-range-scan
Browse files Browse the repository at this point in the history
  • Loading branch information
xuyifangreeneyes authored Sep 2, 2021
2 parents 700c291 + b0c9d19 commit 9632446
Show file tree
Hide file tree
Showing 7 changed files with 615 additions and 38 deletions.
11 changes: 0 additions & 11 deletions expression/builtin_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,6 @@ func numericContextResultType(ft *types.FieldType) types.EvalType {
return evalTp4Ft
}

// setFlenDecimal4Int is called to set proper `Flen` and `Decimal` of return
// type according to the two input parameter's types.
func setFlenDecimal4Int(retTp, a, b *types.FieldType) {
retTp.Decimal = 0
retTp.Flen = mysql.MaxIntWidth
}

// setFlenDecimal4RealOrDecimal is called to set proper `Flen` and `Decimal` of return
// type according to the two input parameter's types.
func setFlenDecimal4RealOrDecimal(retTp, a, b *types.FieldType, isReal bool, isMultiply bool) {
Expand Down Expand Up @@ -190,7 +183,6 @@ func (c *arithmeticPlusFunctionClass) getFunction(ctx sessionctx.Context, args [
if mysql.HasUnsignedFlag(args[0].GetType().Flag) || mysql.HasUnsignedFlag(args[1].GetType().Flag) {
bf.tp.Flag |= mysql.UnsignedFlag
}
setFlenDecimal4Int(bf.tp, args[0].GetType(), args[1].GetType())
sig := &builtinArithmeticPlusIntSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_PlusInt)
return sig, nil
Expand Down Expand Up @@ -338,7 +330,6 @@ func (c *arithmeticMinusFunctionClass) getFunction(ctx sessionctx.Context, args
if err != nil {
return nil, err
}
setFlenDecimal4Int(bf.tp, args[0].GetType(), args[1].GetType())
if (mysql.HasUnsignedFlag(args[0].GetType().Flag) || mysql.HasUnsignedFlag(args[1].GetType().Flag)) && !ctx.GetSessionVars().SQLMode.HasNoUnsignedSubtractionMode() {
bf.tp.Flag |= mysql.UnsignedFlag
}
Expand Down Expand Up @@ -523,12 +514,10 @@ func (c *arithmeticMultiplyFunctionClass) getFunction(ctx sessionctx.Context, ar
}
if mysql.HasUnsignedFlag(lhsTp.Flag) || mysql.HasUnsignedFlag(rhsTp.Flag) {
bf.tp.Flag |= mysql.UnsignedFlag
setFlenDecimal4Int(bf.tp, args[0].GetType(), args[1].GetType())
sig := &builtinArithmeticMultiplyIntUnsignedSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_MultiplyIntUnsigned)
return sig, nil
}
setFlenDecimal4Int(bf.tp, args[0].GetType(), args[1].GetType())
sig := &builtinArithmeticMultiplyIntSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_MultiplyInt)
return sig, nil
Expand Down
25 changes: 0 additions & 25 deletions expression/builtin_arithmetic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,31 +91,6 @@ func (s *testEvaluatorSuite) TestSetFlenDecimal4RealOrDecimal(c *C) {
c.Assert(ret.Flen, Equals, types.UnspecifiedLength)
}

func (s *testEvaluatorSuite) TestSetFlenDecimal4Int(c *C) {
ret := &types.FieldType{}
a := &types.FieldType{
Decimal: 1,
Flen: 3,
}
b := &types.FieldType{
Decimal: 0,
Flen: 2,
}
setFlenDecimal4Int(ret, a, b)
c.Assert(ret.Decimal, Equals, 0)
c.Assert(ret.Flen, Equals, mysql.MaxIntWidth)

b.Flen = mysql.MaxIntWidth + 1
setFlenDecimal4Int(ret, a, b)
c.Assert(ret.Decimal, Equals, 0)
c.Assert(ret.Flen, Equals, mysql.MaxIntWidth)

b.Flen = types.UnspecifiedLength
setFlenDecimal4Int(ret, a, b)
c.Assert(ret.Decimal, Equals, 0)
c.Assert(ret.Flen, Equals, mysql.MaxIntWidth)
}

func (s *testEvaluatorSuite) TestArithmeticPlus(c *C) {
// case: 1
args := []interface{}{int64(12), int64(1)}
Expand Down
3 changes: 2 additions & 1 deletion expression/constant_fold.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ func init() {
// FoldConstant does constant folding optimization on an expression excluding deferred ones.
func FoldConstant(expr Expression) Expression {
e, _ := foldConstant(expr)
// keep the original coercibility values after folding
// keep the original coercibility, charset and collation values after folding
e.SetCoercibility(expr.Coercibility())
e.GetType().Charset, e.GetType().Collate = expr.GetType().Charset, expr.GetType().Collate
return e
}

Expand Down
2 changes: 2 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6540,6 +6540,8 @@ func (s *testIntegrationSerialSuite) TestCollationBasic(c *C) {
tk.MustExec("INSERT INTO `t1` VALUES ('Ȇ');")
tk.MustQuery("select * from t1 where col1 not in (0xc484, 0xe5a4bc, 0xc3b3);").Check(testkit.Rows("Ȇ"))
tk.MustQuery("select * from t1 where col1 >= 0xc484 and col1 <= 0xc3b3;").Check(testkit.Rows("Ȇ"))

tk.MustQuery("select collation(IF('a' < 'B' collate utf8mb4_general_ci, 'smaller', 'greater' collate utf8mb4_unicode_ci));").Check(testkit.Rows("utf8mb4_unicode_ci"))
}

func (s *testIntegrationSerialSuite) TestWeightString(c *C) {
Expand Down
147 changes: 147 additions & 0 deletions store/driver/txn/ranged_kv_retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ package txn

import (
"bytes"
"context"
"sort"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/kv"
)

Expand Down Expand Up @@ -62,3 +65,147 @@ func (s *RangedKVRetriever) Intersect(startKey, endKey kv.Key) *RangedKVRetrieve

return nil
}

// ScanCurrentRange scans the retriever for current range
func (s *RangedKVRetriever) ScanCurrentRange(reverse bool) (kv.Iterator, error) {
if !s.Valid() {
return nil, errors.New("retriever is invalid")
}

startKey, endKey := s.StartKey, s.EndKey
if !reverse {
return s.Iter(startKey, endKey)
}

iter, err := s.IterReverse(endKey)
if err != nil {
return nil, err
}

if len(startKey) > 0 {
iter = newLowerBoundReverseIter(iter, startKey)
}

return iter, nil
}

type sortedRetrievers []*RangedKVRetriever

func (retrievers sortedRetrievers) TryGet(ctx context.Context, k kv.Key) (bool, []byte, error) {
if len(retrievers) == 0 {
return false, nil, nil
}

_, r := retrievers.searchRetrieverByKey(k)
if r == nil || !r.Contains(k) {
return false, nil, nil
}

val, err := r.Get(ctx, k)
return true, val, err
}

func (retrievers sortedRetrievers) TryBatchGet(ctx context.Context, keys []kv.Key, collectF func(k kv.Key, v []byte)) ([]kv.Key, error) {
if len(retrievers) == 0 {
return keys, nil
}

var nonCustomKeys []kv.Key
for _, k := range keys {
custom, val, err := retrievers.TryGet(ctx, k)
if !custom {
nonCustomKeys = append(nonCustomKeys, k)
continue
}

if kv.ErrNotExist.Equal(err) {
continue
}

if err != nil {
return nil, err
}

collectF(k, val)
}

return nonCustomKeys, nil
}

// GetScanRetrievers gets all retrievers who have intersections with range [StartKey, endKey).
// If snapshot is not nil, the range between two custom retrievers with a snapshot retriever will also be returned.
func (retrievers sortedRetrievers) GetScanRetrievers(startKey, endKey kv.Key, snapshot kv.Retriever) []*RangedKVRetriever {
// According to our experience, in most cases there is only one retriever returned.
result := make([]*RangedKVRetriever, 0, 1)

// Firstly, we should find the first retriever whose EndKey is after input startKey,
// it is obvious that the retrievers before it do not have a common range with the input.
idx, _ := retrievers.searchRetrieverByKey(startKey)

// If not found, it means the scan range is located out of retrievers, just use snapshot to scan it
if idx == len(retrievers) {
if snapshot != nil {
result = append(result, NewRangeRetriever(snapshot, startKey, endKey))
}
return result
}

// Check every retriever whose index >= idx whether it intersects with the input range.
// If it is true, put the intersected range to the result.
// The range between two retrievers should also be checked because we read snapshot data from there.
checks := retrievers[idx:]
for i, retriever := range checks {
// Intersect with the range which is on the left of the retriever and use snapshot to read it
// Notice that when len(retriever.StartKey) == 0, that means there is no left range for it
if len(retriever.StartKey) > 0 && snapshot != nil {
var snapStartKey kv.Key
if i != 0 {
snapStartKey = checks[i-1].EndKey
} else {
snapStartKey = nil
}

if r := NewRangeRetriever(snapshot, snapStartKey, retriever.StartKey).Intersect(startKey, endKey); r != nil {
result = append(result, r)
}
}

// Intersect the current retriever
if r := retriever.Intersect(startKey, endKey); r != nil {
result = append(result, r)
continue
}

// Not necessary to continue when the current retriever does not have a valid intersection
return result
}

// If the last retriever has an intersection, we should still check the range on its right.
lastRetriever := checks[len(checks)-1]
if snapshot != nil && len(lastRetriever.EndKey) > 0 {
if r := NewRangeRetriever(snapshot, lastRetriever.EndKey, nil).Intersect(startKey, endKey); r != nil {
result = append(result, r)
}
}

return result
}

// searchRetrieverByKey searches the first retriever whose EndKey after the specified key
func (retrievers sortedRetrievers) searchRetrieverByKey(k kv.Key) (int, *RangedKVRetriever) {
n := len(retrievers)
if n == 0 {
return n, nil
}

i := sort.Search(n, func(i int) bool {
r := retrievers[i]
return len(r.EndKey) == 0 || bytes.Compare(r.EndKey, k) > 0
})

if i < n {
return i, retrievers[i]
}

return n, nil
}
Loading

0 comments on commit 9632446

Please sign in to comment.