diff --git a/expression/builtin_arithmetic.go b/expression/builtin_arithmetic.go index 06a7f07d0ce99..9711f585ca0f4 100644 --- a/expression/builtin_arithmetic.go +++ b/expression/builtin_arithmetic.go @@ -86,13 +86,6 @@ func numericContextResultType(ft *types.FieldType) types.EvalType { return evalTp4Ft } -// setFlenDecimal4Int is called to set proper `Flen` and `Decimal` of return -// type according to the two input parameter's types. -func setFlenDecimal4Int(retTp, a, b *types.FieldType) { - retTp.Decimal = 0 - retTp.Flen = mysql.MaxIntWidth -} - // setFlenDecimal4RealOrDecimal is called to set proper `Flen` and `Decimal` of return // type according to the two input parameter's types. func setFlenDecimal4RealOrDecimal(retTp, a, b *types.FieldType, isReal bool, isMultiply bool) { @@ -190,7 +183,6 @@ func (c *arithmeticPlusFunctionClass) getFunction(ctx sessionctx.Context, args [ if mysql.HasUnsignedFlag(args[0].GetType().Flag) || mysql.HasUnsignedFlag(args[1].GetType().Flag) { bf.tp.Flag |= mysql.UnsignedFlag } - setFlenDecimal4Int(bf.tp, args[0].GetType(), args[1].GetType()) sig := &builtinArithmeticPlusIntSig{bf} sig.setPbCode(tipb.ScalarFuncSig_PlusInt) return sig, nil @@ -338,7 +330,6 @@ func (c *arithmeticMinusFunctionClass) getFunction(ctx sessionctx.Context, args if err != nil { return nil, err } - setFlenDecimal4Int(bf.tp, args[0].GetType(), args[1].GetType()) if (mysql.HasUnsignedFlag(args[0].GetType().Flag) || mysql.HasUnsignedFlag(args[1].GetType().Flag)) && !ctx.GetSessionVars().SQLMode.HasNoUnsignedSubtractionMode() { bf.tp.Flag |= mysql.UnsignedFlag } @@ -523,12 +514,10 @@ func (c *arithmeticMultiplyFunctionClass) getFunction(ctx sessionctx.Context, ar } if mysql.HasUnsignedFlag(lhsTp.Flag) || mysql.HasUnsignedFlag(rhsTp.Flag) { bf.tp.Flag |= mysql.UnsignedFlag - setFlenDecimal4Int(bf.tp, args[0].GetType(), args[1].GetType()) sig := &builtinArithmeticMultiplyIntUnsignedSig{bf} sig.setPbCode(tipb.ScalarFuncSig_MultiplyIntUnsigned) return sig, nil } - setFlenDecimal4Int(bf.tp, args[0].GetType(), args[1].GetType()) sig := &builtinArithmeticMultiplyIntSig{bf} sig.setPbCode(tipb.ScalarFuncSig_MultiplyInt) return sig, nil diff --git a/expression/builtin_arithmetic_test.go b/expression/builtin_arithmetic_test.go index e9c7ab959bce8..af0e0d27970d1 100644 --- a/expression/builtin_arithmetic_test.go +++ b/expression/builtin_arithmetic_test.go @@ -91,31 +91,6 @@ func (s *testEvaluatorSuite) TestSetFlenDecimal4RealOrDecimal(c *C) { c.Assert(ret.Flen, Equals, types.UnspecifiedLength) } -func (s *testEvaluatorSuite) TestSetFlenDecimal4Int(c *C) { - ret := &types.FieldType{} - a := &types.FieldType{ - Decimal: 1, - Flen: 3, - } - b := &types.FieldType{ - Decimal: 0, - Flen: 2, - } - setFlenDecimal4Int(ret, a, b) - c.Assert(ret.Decimal, Equals, 0) - c.Assert(ret.Flen, Equals, mysql.MaxIntWidth) - - b.Flen = mysql.MaxIntWidth + 1 - setFlenDecimal4Int(ret, a, b) - c.Assert(ret.Decimal, Equals, 0) - c.Assert(ret.Flen, Equals, mysql.MaxIntWidth) - - b.Flen = types.UnspecifiedLength - setFlenDecimal4Int(ret, a, b) - c.Assert(ret.Decimal, Equals, 0) - c.Assert(ret.Flen, Equals, mysql.MaxIntWidth) -} - func (s *testEvaluatorSuite) TestArithmeticPlus(c *C) { // case: 1 args := []interface{}{int64(12), int64(1)} diff --git a/expression/constant_fold.go b/expression/constant_fold.go index 7293134fa5bc1..8ce4615eb567e 100644 --- a/expression/constant_fold.go +++ b/expression/constant_fold.go @@ -38,8 +38,9 @@ func init() { // FoldConstant does constant folding optimization on an expression excluding deferred ones. func FoldConstant(expr Expression) Expression { e, _ := foldConstant(expr) - // keep the original coercibility values after folding + // keep the original coercibility, charset and collation values after folding e.SetCoercibility(expr.Coercibility()) + e.GetType().Charset, e.GetType().Collate = expr.GetType().Charset, expr.GetType().Collate return e } diff --git a/expression/integration_test.go b/expression/integration_test.go index e1768fd0de6dc..a173d1cb2e699 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -6540,6 +6540,8 @@ func (s *testIntegrationSerialSuite) TestCollationBasic(c *C) { tk.MustExec("INSERT INTO `t1` VALUES ('Ȇ');") tk.MustQuery("select * from t1 where col1 not in (0xc484, 0xe5a4bc, 0xc3b3);").Check(testkit.Rows("Ȇ")) tk.MustQuery("select * from t1 where col1 >= 0xc484 and col1 <= 0xc3b3;").Check(testkit.Rows("Ȇ")) + + tk.MustQuery("select collation(IF('a' < 'B' collate utf8mb4_general_ci, 'smaller', 'greater' collate utf8mb4_unicode_ci));").Check(testkit.Rows("utf8mb4_unicode_ci")) } func (s *testIntegrationSerialSuite) TestWeightString(c *C) { diff --git a/store/driver/txn/ranged_kv_retriever.go b/store/driver/txn/ranged_kv_retriever.go index a342afb080452..414652e8afbfa 100644 --- a/store/driver/txn/ranged_kv_retriever.go +++ b/store/driver/txn/ranged_kv_retriever.go @@ -16,7 +16,10 @@ package txn import ( "bytes" + "context" + "sort" + "github.com/pingcap/errors" "github.com/pingcap/tidb/kv" ) @@ -62,3 +65,147 @@ func (s *RangedKVRetriever) Intersect(startKey, endKey kv.Key) *RangedKVRetrieve return nil } + +// ScanCurrentRange scans the retriever for current range +func (s *RangedKVRetriever) ScanCurrentRange(reverse bool) (kv.Iterator, error) { + if !s.Valid() { + return nil, errors.New("retriever is invalid") + } + + startKey, endKey := s.StartKey, s.EndKey + if !reverse { + return s.Iter(startKey, endKey) + } + + iter, err := s.IterReverse(endKey) + if err != nil { + return nil, err + } + + if len(startKey) > 0 { + iter = newLowerBoundReverseIter(iter, startKey) + } + + return iter, nil +} + +type sortedRetrievers []*RangedKVRetriever + +func (retrievers sortedRetrievers) TryGet(ctx context.Context, k kv.Key) (bool, []byte, error) { + if len(retrievers) == 0 { + return false, nil, nil + } + + _, r := retrievers.searchRetrieverByKey(k) + if r == nil || !r.Contains(k) { + return false, nil, nil + } + + val, err := r.Get(ctx, k) + return true, val, err +} + +func (retrievers sortedRetrievers) TryBatchGet(ctx context.Context, keys []kv.Key, collectF func(k kv.Key, v []byte)) ([]kv.Key, error) { + if len(retrievers) == 0 { + return keys, nil + } + + var nonCustomKeys []kv.Key + for _, k := range keys { + custom, val, err := retrievers.TryGet(ctx, k) + if !custom { + nonCustomKeys = append(nonCustomKeys, k) + continue + } + + if kv.ErrNotExist.Equal(err) { + continue + } + + if err != nil { + return nil, err + } + + collectF(k, val) + } + + return nonCustomKeys, nil +} + +// GetScanRetrievers gets all retrievers who have intersections with range [StartKey, endKey). +// If snapshot is not nil, the range between two custom retrievers with a snapshot retriever will also be returned. +func (retrievers sortedRetrievers) GetScanRetrievers(startKey, endKey kv.Key, snapshot kv.Retriever) []*RangedKVRetriever { + // According to our experience, in most cases there is only one retriever returned. + result := make([]*RangedKVRetriever, 0, 1) + + // Firstly, we should find the first retriever whose EndKey is after input startKey, + // it is obvious that the retrievers before it do not have a common range with the input. + idx, _ := retrievers.searchRetrieverByKey(startKey) + + // If not found, it means the scan range is located out of retrievers, just use snapshot to scan it + if idx == len(retrievers) { + if snapshot != nil { + result = append(result, NewRangeRetriever(snapshot, startKey, endKey)) + } + return result + } + + // Check every retriever whose index >= idx whether it intersects with the input range. + // If it is true, put the intersected range to the result. + // The range between two retrievers should also be checked because we read snapshot data from there. + checks := retrievers[idx:] + for i, retriever := range checks { + // Intersect with the range which is on the left of the retriever and use snapshot to read it + // Notice that when len(retriever.StartKey) == 0, that means there is no left range for it + if len(retriever.StartKey) > 0 && snapshot != nil { + var snapStartKey kv.Key + if i != 0 { + snapStartKey = checks[i-1].EndKey + } else { + snapStartKey = nil + } + + if r := NewRangeRetriever(snapshot, snapStartKey, retriever.StartKey).Intersect(startKey, endKey); r != nil { + result = append(result, r) + } + } + + // Intersect the current retriever + if r := retriever.Intersect(startKey, endKey); r != nil { + result = append(result, r) + continue + } + + // Not necessary to continue when the current retriever does not have a valid intersection + return result + } + + // If the last retriever has an intersection, we should still check the range on its right. + lastRetriever := checks[len(checks)-1] + if snapshot != nil && len(lastRetriever.EndKey) > 0 { + if r := NewRangeRetriever(snapshot, lastRetriever.EndKey, nil).Intersect(startKey, endKey); r != nil { + result = append(result, r) + } + } + + return result +} + +// searchRetrieverByKey searches the first retriever whose EndKey after the specified key +func (retrievers sortedRetrievers) searchRetrieverByKey(k kv.Key) (int, *RangedKVRetriever) { + n := len(retrievers) + if n == 0 { + return n, nil + } + + i := sort.Search(n, func(i int) bool { + r := retrievers[i] + return len(r.EndKey) == 0 || bytes.Compare(r.EndKey, k) > 0 + }) + + if i < n { + return i, retrievers[i] + } + + return n, nil +} diff --git a/store/driver/txn/ranged_kv_retriever_test.go b/store/driver/txn/ranged_kv_retriever_test.go index b7190d3da8ba0..80645638ebcf5 100644 --- a/store/driver/txn/ranged_kv_retriever_test.go +++ b/store/driver/txn/ranged_kv_retriever_test.go @@ -14,10 +14,13 @@ package txn import ( + "context" "testing" + "github.com/pingcap/errors" "github.com/pingcap/tidb/kv" "github.com/stretchr/testify/assert" + "github.com/tikv/client-go/v2/txnkv/transaction" ) func newRetriever(startKey, endKey kv.Key) *RangedKVRetriever { @@ -249,3 +252,463 @@ func TestRangedKVRetrieverIntersect(t *testing.T) { } } } + +func TestRangedKVRetrieverScanCurrentRange(t *testing.T) { + memBuffer := newMemBufferRetriever(t, [][]interface{}{{"a1", "v1"}, {"a2", "v2"}, {"a3", "v3"}, {"a4", "v4"}}) + cases := []struct { + retriever *RangedKVRetriever + reverse bool + result [][]interface{} + }{ + { + retriever: NewRangeRetriever(memBuffer, nil, nil), + result: [][]interface{}{{"a1", "v1"}, {"a2", "v2"}, {"a3", "v3"}, {"a4", "v4"}}, + }, + { + retriever: NewRangeRetriever(memBuffer, nil, nil), + reverse: true, + result: [][]interface{}{{"a4", "v4"}, {"a3", "v3"}, {"a2", "v2"}, {"a1", "v1"}}, + }, + { + retriever: NewRangeRetriever(memBuffer, kv.Key("a10"), kv.Key("a4")), + result: [][]interface{}{{"a2", "v2"}, {"a3", "v3"}}, + }, + { + retriever: NewRangeRetriever(memBuffer, kv.Key("a10"), kv.Key("a4")), + reverse: true, + result: [][]interface{}{{"a3", "v3"}, {"a2", "v2"}}, + }, + { + retriever: NewRangeRetriever(memBuffer, nil, kv.Key("a4")), + reverse: true, + result: [][]interface{}{{"a3", "v3"}, {"a2", "v2"}, {"a1", "v1"}}, + }, + } + + for _, c := range cases { + iter, err := c.retriever.ScanCurrentRange(c.reverse) + assert.Nil(t, err) + for i := range c.result { + expectedKey := makeBytes(c.result[i][0]) + expectedVal := makeBytes(c.result[i][1]) + assert.True(t, iter.Valid()) + gotKey := []byte(iter.Key()) + gotVal := iter.Value() + assert.Equal(t, expectedKey, gotKey) + assert.Equal(t, expectedVal, gotVal) + err = iter.Next() + assert.Nil(t, err) + } + assert.False(t, iter.Valid()) + } +} + +func TestSearchRetrieverByKey(t *testing.T) { + retrievers := sortedRetrievers{ + NewRangeRetriever(&kv.EmptyRetriever{}, kv.Key("ab1"), kv.Key("ab3")), + NewRangeRetriever(&kv.EmptyRetriever{}, kv.Key("ab3"), kv.Key("ab5")), + NewRangeRetriever(&kv.EmptyRetriever{}, kv.Key("ab6"), kv.Key("ab7")), + NewRangeRetriever(&kv.EmptyRetriever{}, kv.Key("ab8"), kv.Key("ab9")), + } + + retrievers2 := sortedRetrievers{ + NewRangeRetriever(&kv.EmptyRetriever{}, nil, kv.Key("ab1")), + NewRangeRetriever(&kv.EmptyRetriever{}, kv.Key("ab3"), kv.Key("ab5")), + NewRangeRetriever(&kv.EmptyRetriever{}, kv.Key("ab6"), kv.Key("ab7")), + NewRangeRetriever(&kv.EmptyRetriever{}, kv.Key("ab8"), nil), + } + + cases := []struct { + retrievers sortedRetrievers + search interface{} + expectedIdx int + }{ + {retrievers: retrievers, search: nil, expectedIdx: 0}, + {retrievers: retrievers, search: "ab", expectedIdx: 0}, + {retrievers: retrievers, search: "ab1", expectedIdx: 0}, + {retrievers: retrievers, search: "ab2", expectedIdx: 0}, + {retrievers: retrievers, search: "ab3", expectedIdx: 1}, + {retrievers: retrievers, search: "ab31", expectedIdx: 1}, + {retrievers: retrievers, search: "ab4", expectedIdx: 1}, + {retrievers: retrievers, search: "ab5", expectedIdx: 2}, + {retrievers: retrievers, search: "ab51", expectedIdx: 2}, + {retrievers: retrievers, search: "ab71", expectedIdx: 3}, + {retrievers: retrievers, search: "ab8", expectedIdx: 3}, + {retrievers: retrievers, search: "ab81", expectedIdx: 3}, + {retrievers: retrievers, search: "ab9", expectedIdx: 4}, + {retrievers: retrievers, search: "aba", expectedIdx: 4}, + {retrievers: retrievers2, search: nil, expectedIdx: 0}, + {retrievers: retrievers2, search: "ab0", expectedIdx: 0}, + {retrievers: retrievers2, search: "ab8", expectedIdx: 3}, + {retrievers: retrievers2, search: "ab9", expectedIdx: 3}, + } + + for _, c := range cases { + idx, r := c.retrievers.searchRetrieverByKey(makeBytes(c.search)) + assert.Equal(t, c.expectedIdx, idx) + if idx < len(c.retrievers) { + assert.Equal(t, c.retrievers[idx], r) + } else { + assert.Nil(t, r) + } + } +} + +func TestGetScanRetrievers(t *testing.T) { + type mockRetriever struct { + kv.EmptyRetriever + // Avoid zero size struct, make it can be compared for different variables + _ interface{} + } + + snap := &mockRetriever{} + retrievers1 := sortedRetrievers{ + NewRangeRetriever(&mockRetriever{}, kv.Key("ab1"), kv.Key("ab3")), + NewRangeRetriever(&mockRetriever{}, kv.Key("ab3"), kv.Key("ab5")), + NewRangeRetriever(&mockRetriever{}, kv.Key("ab6"), kv.Key("ab7")), + } + + retrievers2 := sortedRetrievers{ + NewRangeRetriever(&mockRetriever{}, nil, kv.Key("ab1")), + NewRangeRetriever(&mockRetriever{}, kv.Key("ab3"), kv.Key("ab5")), + NewRangeRetriever(&mockRetriever{}, kv.Key("ab5"), kv.Key("ab7")), + NewRangeRetriever(&mockRetriever{}, kv.Key("ab8"), nil), + } + + cases := []struct { + retrievers sortedRetrievers + startKey interface{} + endKey interface{} + expected sortedRetrievers + }{ + { + retrievers: retrievers1, + startKey: "ab0", endKey: "ab1", + expected: sortedRetrievers{ + NewRangeRetriever(snap, kv.Key("ab0"), kv.Key("ab1")), + }, + }, + { + retrievers: retrievers1, + startKey: "ab0", endKey: "ab2", + expected: sortedRetrievers{ + NewRangeRetriever(snap, kv.Key("ab0"), kv.Key("ab1")), + NewRangeRetriever(retrievers1[0].Retriever, kv.Key("ab1"), kv.Key("ab2")), + }, + }, + { + retrievers: retrievers1, + startKey: "ab0", endKey: "ab3", + expected: sortedRetrievers{ + NewRangeRetriever(snap, kv.Key("ab0"), kv.Key("ab1")), + NewRangeRetriever(retrievers1[0].Retriever, kv.Key("ab1"), kv.Key("ab3")), + }, + }, + { + retrievers: retrievers1, + startKey: "ab0", endKey: "ab4", + expected: sortedRetrievers{ + NewRangeRetriever(snap, kv.Key("ab0"), kv.Key("ab1")), + NewRangeRetriever(retrievers1[0].Retriever, kv.Key("ab1"), kv.Key("ab3")), + NewRangeRetriever(retrievers1[1].Retriever, kv.Key("ab3"), kv.Key("ab4")), + }, + }, + { + retrievers: retrievers1, + startKey: "ab0", endKey: "ab51", + expected: sortedRetrievers{ + NewRangeRetriever(snap, kv.Key("ab0"), kv.Key("ab1")), + NewRangeRetriever(retrievers1[0].Retriever, kv.Key("ab1"), kv.Key("ab3")), + NewRangeRetriever(retrievers1[1].Retriever, kv.Key("ab3"), kv.Key("ab5")), + NewRangeRetriever(snap, kv.Key("ab5"), kv.Key("ab51")), + }, + }, + { + retrievers: retrievers1, + startKey: "ab0", endKey: "ab61", + expected: sortedRetrievers{ + NewRangeRetriever(snap, kv.Key("ab0"), kv.Key("ab1")), + NewRangeRetriever(retrievers1[0].Retriever, kv.Key("ab1"), kv.Key("ab3")), + NewRangeRetriever(retrievers1[1].Retriever, kv.Key("ab3"), kv.Key("ab5")), + NewRangeRetriever(snap, kv.Key("ab5"), kv.Key("ab6")), + NewRangeRetriever(retrievers1[2].Retriever, kv.Key("ab6"), kv.Key("ab61")), + }, + }, + { + retrievers: retrievers1, + startKey: "ab0", endKey: "ab8", + expected: sortedRetrievers{ + NewRangeRetriever(snap, kv.Key("ab0"), kv.Key("ab1")), + NewRangeRetriever(retrievers1[0].Retriever, kv.Key("ab1"), kv.Key("ab3")), + NewRangeRetriever(retrievers1[1].Retriever, kv.Key("ab3"), kv.Key("ab5")), + NewRangeRetriever(snap, kv.Key("ab5"), kv.Key("ab6")), + NewRangeRetriever(retrievers1[2].Retriever, kv.Key("ab6"), kv.Key("ab7")), + NewRangeRetriever(snap, kv.Key("ab7"), kv.Key("ab8")), + }, + }, + { + retrievers: retrievers1, + startKey: nil, endKey: nil, + expected: sortedRetrievers{ + NewRangeRetriever(snap, nil, kv.Key("ab1")), + NewRangeRetriever(retrievers1[0].Retriever, kv.Key("ab1"), kv.Key("ab3")), + NewRangeRetriever(retrievers1[1].Retriever, kv.Key("ab3"), kv.Key("ab5")), + NewRangeRetriever(snap, kv.Key("ab5"), kv.Key("ab6")), + NewRangeRetriever(retrievers1[2].Retriever, kv.Key("ab6"), kv.Key("ab7")), + NewRangeRetriever(snap, kv.Key("ab7"), nil), + }, + }, + { + retrievers: retrievers1, + startKey: "ab0", endKey: nil, + expected: sortedRetrievers{ + NewRangeRetriever(snap, kv.Key("ab0"), kv.Key("ab1")), + NewRangeRetriever(retrievers1[0].Retriever, kv.Key("ab1"), kv.Key("ab3")), + NewRangeRetriever(retrievers1[1].Retriever, kv.Key("ab3"), kv.Key("ab5")), + NewRangeRetriever(snap, kv.Key("ab5"), kv.Key("ab6")), + NewRangeRetriever(retrievers1[2].Retriever, kv.Key("ab6"), kv.Key("ab7")), + NewRangeRetriever(snap, kv.Key("ab7"), nil), + }, + }, + { + retrievers: retrievers1, + startKey: "ab2", endKey: nil, + expected: sortedRetrievers{ + NewRangeRetriever(retrievers1[0].Retriever, kv.Key("ab2"), kv.Key("ab3")), + NewRangeRetriever(retrievers1[1].Retriever, kv.Key("ab3"), kv.Key("ab5")), + NewRangeRetriever(snap, kv.Key("ab5"), kv.Key("ab6")), + NewRangeRetriever(retrievers1[2].Retriever, kv.Key("ab6"), kv.Key("ab7")), + NewRangeRetriever(snap, kv.Key("ab7"), nil), + }, + }, + { + retrievers: retrievers1, + startKey: "ab3", endKey: nil, + expected: sortedRetrievers{ + NewRangeRetriever(retrievers1[1].Retriever, kv.Key("ab3"), kv.Key("ab5")), + NewRangeRetriever(snap, kv.Key("ab5"), kv.Key("ab6")), + NewRangeRetriever(retrievers1[2].Retriever, kv.Key("ab6"), kv.Key("ab7")), + NewRangeRetriever(snap, kv.Key("ab7"), nil), + }, + }, + { + retrievers: retrievers1, + startKey: "ab51", endKey: nil, + expected: sortedRetrievers{ + NewRangeRetriever(snap, kv.Key("ab51"), kv.Key("ab6")), + NewRangeRetriever(retrievers1[2].Retriever, kv.Key("ab6"), kv.Key("ab7")), + NewRangeRetriever(snap, kv.Key("ab7"), nil), + }, + }, + { + retrievers: retrievers1, + startKey: "ab3", endKey: "ab4", + expected: sortedRetrievers{ + NewRangeRetriever(retrievers1[1].Retriever, kv.Key("ab3"), kv.Key("ab4")), + }, + }, + { + retrievers: retrievers1, + startKey: "ab51", endKey: "ab52", + expected: sortedRetrievers{ + NewRangeRetriever(snap, kv.Key("ab51"), kv.Key("ab52")), + }, + }, + { + retrievers: retrievers1, + startKey: "ab8", endKey: nil, + expected: sortedRetrievers{ + NewRangeRetriever(snap, kv.Key("ab8"), nil), + }, + }, + { + retrievers: retrievers1, + startKey: "ab8", endKey: "ab9", + expected: sortedRetrievers{ + NewRangeRetriever(snap, kv.Key("ab8"), kv.Key("ab9")), + }, + }, + { + retrievers: retrievers2, + startKey: "ab0", endKey: "ab9", + expected: sortedRetrievers{ + NewRangeRetriever(retrievers2[0].Retriever, kv.Key("ab0"), kv.Key("ab1")), + NewRangeRetriever(snap, kv.Key("ab1"), kv.Key("ab3")), + NewRangeRetriever(retrievers2[1].Retriever, kv.Key("ab3"), kv.Key("ab5")), + NewRangeRetriever(retrievers2[2].Retriever, kv.Key("ab5"), kv.Key("ab7")), + NewRangeRetriever(snap, kv.Key("ab7"), kv.Key("ab8")), + NewRangeRetriever(retrievers2[3].Retriever, kv.Key("ab8"), kv.Key("ab9")), + }, + }, + } + + for _, c := range cases { + result := c.retrievers.GetScanRetrievers(makeBytes(c.startKey), makeBytes(c.endKey), snap) + expected := c.expected + assert.Equal(t, len(expected), len(result)) + for i := range expected { + assert.Same(t, expected[i].Retriever, result[i].Retriever) + assert.Equal(t, expected[i].StartKey, result[i].StartKey) + assert.Equal(t, expected[i].EndKey, result[i].EndKey) + } + + result = c.retrievers.GetScanRetrievers(makeBytes(c.startKey), makeBytes(c.endKey), nil) + expected = make([]*RangedKVRetriever, 0) + for _, r := range c.expected { + if r.Retriever != snap { + expected = append(expected, r) + } + } + assert.Equal(t, len(expected), len(result)) + for i := range expected { + assert.Same(t, expected[i].Retriever, result[i].Retriever) + assert.Equal(t, expected[i].StartKey, result[i].StartKey) + assert.Equal(t, expected[i].EndKey, result[i].EndKey) + } + } +} + +func newMemBufferRetriever(t *testing.T, data [][]interface{}) kv.MemBuffer { + txn, err := transaction.NewTiKVTxn(nil, nil, 0, "") + assert.Nil(t, err) + memBuffer := NewTiKVTxn(txn).GetMemBuffer() + for _, d := range data { + err := memBuffer.Set(makeBytes(d[0]), makeBytes(d[1])) + assert.Nil(t, err) + } + return memBuffer +} + +func TestSortedRetrieversTryGet(t *testing.T) { + retrievers := sortedRetrievers{ + NewRangeRetriever( + newMemBufferRetriever(t, [][]interface{}{{"ab0", "v0"}, {"ab2", "v2"}, {"ab3", "v3"}, {"ab4", "v4x"}}), + kv.Key("ab1"), kv.Key("ab3"), + ), + NewRangeRetriever( + newMemBufferRetriever(t, [][]interface{}{{"ab4", "v4"}, {"ab41", "v41"}}), + kv.Key("ab3"), kv.Key("ab5"), + ), + NewRangeRetriever( + newMemBufferRetriever(t, [][]interface{}{{"ab7", "v7"}}), + kv.Key("ab6"), kv.Key("ab8"), + ), + } + + tryGetCases := [][]interface{}{ + // {key, expectedValue, fromRetriever} + {"ab0", nil, false}, + {"ab1", kv.ErrNotExist, true}, + {"ab11", kv.ErrNotExist, true}, + {"ab2", "v2", true}, + {"ab3", kv.ErrNotExist, true}, + {"ab4", "v4", true}, + {"ab41", "v41", true}, + {"ab5", nil, false}, + {"ab7", "v7", true}, + {"ab8", nil, false}, + {"ab9", nil, false}, + } + + for _, c := range tryGetCases { + fromRetriever, val, err := retrievers.TryGet(context.TODO(), makeBytes(c[0])) + assert.Equal(t, c[2], fromRetriever) + if !fromRetriever { + assert.Nil(t, err) + assert.Nil(t, val) + continue + } + + if expectedErr, ok := c[1].(error); ok { + assert.True(t, errors.ErrorEqual(expectedErr, err)) + assert.Nil(t, val) + } else { + assert.Equal(t, makeBytes(c[1]), val) + assert.Nil(t, err) + } + } +} + +func TestSortedRetrieversTryBatchGet(t *testing.T) { + retrievers := sortedRetrievers{ + NewRangeRetriever( + newMemBufferRetriever(t, [][]interface{}{{"ab0", "v0"}, {"ab2", "v2"}, {"ab3", "v3"}, {"ab4", "v4x"}}), + kv.Key("ab1"), kv.Key("ab3"), + ), + NewRangeRetriever( + newMemBufferRetriever(t, [][]interface{}{{"ab4", "v4"}, {"ab41", "v41"}, {"ab51", "v51"}}), + kv.Key("ab3"), kv.Key("ab5"), + ), + NewRangeRetriever( + newMemBufferRetriever(t, [][]interface{}{{"ab7", "v7"}}), + kv.Key("ab6"), kv.Key("ab8"), + ), + } + + tryBatchGetCases := []struct { + keys []kv.Key + result map[string]string + retKeys []kv.Key + }{ + { + keys: []kv.Key{kv.Key("ab0")}, + result: map[string]string{}, + retKeys: []kv.Key{kv.Key("ab0")}, + }, + { + keys: []kv.Key{kv.Key("ab0"), kv.Key("ab51"), kv.Key("ab52"), kv.Key("ab9")}, + result: map[string]string{}, + retKeys: []kv.Key{kv.Key("ab0"), kv.Key("ab51"), kv.Key("ab52"), kv.Key("ab9")}, + }, + { + keys: []kv.Key{kv.Key("ab21"), kv.Key("ab3"), kv.Key("ab4"), kv.Key("ab41"), kv.Key("ab7")}, + result: map[string]string{ + "ab4": "v4", + "ab41": "v41", + "ab7": "v7", + }, + retKeys: nil, + }, + { + keys: []kv.Key{kv.Key("ab0"), kv.Key("ab2"), kv.Key("ab51"), kv.Key("ab7"), kv.Key("ab9")}, + result: map[string]string{ + "ab2": "v2", + "ab7": "v7", + }, + retKeys: []kv.Key{kv.Key("ab0"), kv.Key("ab51"), kv.Key("ab9")}, + }, + { + keys: []kv.Key{kv.Key("ab2"), kv.Key("ab4"), kv.Key("ab51"), kv.Key("ab52"), kv.Key("ab6"), kv.Key("ab7"), kv.Key("ab9")}, + result: map[string]string{ + "ab2": "v2", + "ab4": "v4", + "ab7": "v7", + }, + retKeys: []kv.Key{kv.Key("ab51"), kv.Key("ab52"), kv.Key("ab9")}, + }, + { + keys: []kv.Key{kv.Key("ab0"), kv.Key("ab51"), kv.Key("ab6"), kv.Key("ab7"), kv.Key("ab8"), kv.Key("ab9")}, + result: map[string]string{ + "ab7": "v7", + }, + retKeys: []kv.Key{kv.Key("ab0"), kv.Key("ab51"), kv.Key("ab8"), kv.Key("ab9")}, + }, + } + + for _, c := range tryBatchGetCases { + got := make(map[string][]byte) + keys, err := retrievers.TryBatchGet(context.TODO(), c.keys, func(k kv.Key, v []byte) { + _, ok := got[string(k)] + assert.False(t, ok) + got[string(k)] = v + }) + assert.Nil(t, err) + assert.Equal(t, c.retKeys, keys) + assert.Equal(t, len(c.result), len(got)) + for k, v := range c.result { + val, ok := got[k] + assert.True(t, ok) + assert.Equal(t, []byte(v), val) + } + } +} diff --git a/util/misc.go b/util/misc.go index e891cda5a82c6..896c7f880ba1d 100644 --- a/util/misc.go +++ b/util/misc.go @@ -591,7 +591,7 @@ func QueryStrForLog(query string) string { } func createTLSCertificates(certpath string, keypath string, rsaKeySize int) error { - privkey, err := rsa.GenerateKey(rand.Reader, 4096) + privkey, err := rsa.GenerateKey(rand.Reader, rsaKeySize) if err != nil { return err }