From 6a1e5215dff2e5d1e58b8a0c50c12ba7564e4094 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=B6=85?= Date: Wed, 4 Aug 2021 14:19:07 +0800 Subject: [PATCH] executor: Add insert/replace ignore/on duplicate key support for local temporary table (#26636) --- executor/batch_checker.go | 4 +- executor/executor_test.go | 10 +++ executor/insert.go | 23 +++++-- executor/insert_common.go | 22 +++++-- executor/point_get.go | 2 +- executor/replace.go | 20 +++--- session/session_test.go | 117 +++++++++++++++++++++++++++++++++ sessionctx/variable/session.go | 44 ++++++++++--- table/tables/index.go | 2 +- table/tables/tables.go | 2 +- 10 files changed, 212 insertions(+), 34 deletions(-) diff --git a/executor/batch_checker.go b/executor/batch_checker.go index ed454a899bd52..4705038c4f537 100644 --- a/executor/batch_checker.go +++ b/executor/batch_checker.go @@ -234,9 +234,9 @@ func formatDataForDupError(data []types.Datum) (string, error) { // getOldRow gets the table record row from storage for batch check. // t could be a normal table or a partition, but it must not be a PartitionedTable. -func getOldRow(ctx context.Context, sctx sessionctx.Context, txn kv.Transaction, t table.Table, handle kv.Handle, +func getOldRow(ctx context.Context, sctx sessionctx.Context, kvGetter kv.Getter, t table.Table, handle kv.Handle, genExprs []expression.Expression) ([]types.Datum, error) { - oldValue, err := txn.Get(ctx, tablecodec.EncodeRecordKey(t.RecordPrefix(), handle)) + oldValue, err := kvGetter.Get(ctx, tablecodec.EncodeRecordKey(t.RecordPrefix(), handle)) if err != nil { return nil, err } diff --git a/executor/executor_test.go b/executor/executor_test.go index 73221883de66a..225828460661f 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -8463,9 +8463,19 @@ func (s testSerialSuite) assertTemporaryTableNoNetwork(c *C, temporaryTableType tk.MustQuery("select /*+ USE_INDEX(tmp_t, a) */ b from tmp_t where a = 1").Check(testkit.Rows("1")) tk.MustExec("rollback") + // prepare some data for local temporary table, when for global temporary table, the below operations have no effect. + tk.MustExec("insert into tmp_t value(10, 10, 10)") + tk.MustExec("insert into tmp_t value(11, 11, 11)") + // Pessimistic lock tk.MustExec("begin pessimistic") tk.MustExec("insert into tmp_t values (3, 3, 3)") + tk.MustExec("insert ignore into tmp_t values (4, 4, 4)") + tk.MustExec("insert into tmp_t values (5, 5, 5) on duplicate key update a=100") + tk.MustExec("insert into tmp_t values (10, 10, 10) on duplicate key update a=100") + tk.MustExec("insert ignore into tmp_t values (10, 10, 10) on duplicate key update id=11") + tk.MustExec("replace into tmp_t values(6, 6, 6)") + tk.MustExec("replace into tmp_t values(11, 100, 100)") tk.MustExec("update tmp_t set id = id + 1 where a = 1") tk.MustExec("delete from tmp_t where a > 1") tk.MustQuery("select count(*) from tmp_t where a >= 1 for update") diff --git a/executor/insert.go b/executor/insert.go index c6195ccef34c9..f0b84091240fa 100644 --- a/executor/insert.go +++ b/executor/insert.go @@ -20,6 +20,8 @@ import ( "runtime/trace" "time" + "github.com/pingcap/parser/model" + "github.com/opentracing/opentracing-go" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" @@ -166,7 +168,12 @@ func prefetchConflictedOldRows(ctx context.Context, txn kv.Transaction, rows []t return err } -func prefetchDataCache(ctx context.Context, txn kv.Transaction, rows []toBeCheckedRow) error { +func (e *InsertValues) prefetchDataCache(ctx context.Context, txn kv.Transaction, rows []toBeCheckedRow) error { + // Temporary table need not to do prefetch because its all data are stored in the memory. + if e.Table.Meta().TempTableType != model.TempTableNone { + return nil + } + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("prefetchDataCache", opentracing.ChildOf(span.Context())) defer span1.Finish() @@ -180,8 +187,8 @@ func prefetchDataCache(ctx context.Context, txn kv.Transaction, rows []toBeCheck } // updateDupRow updates a duplicate row to a new row. -func (e *InsertExec) updateDupRow(ctx context.Context, idxInBatch int, txn kv.Transaction, row toBeCheckedRow, handle kv.Handle, onDuplicate []*expression.Assignment) error { - oldRow, err := getOldRow(ctx, e.ctx, txn, row.t, handle, e.GenExprs) +func (e *InsertExec) updateDupRow(ctx context.Context, idxInBatch int, kvGetter kv.Getter, row toBeCheckedRow, handle kv.Handle, onDuplicate []*expression.Assignment) error { + oldRow, err := getOldRow(ctx, e.ctx, kvGetter, row.t, handle, e.GenExprs) if err != nil { return err } @@ -222,12 +229,14 @@ func (e *InsertExec) batchUpdateDupRows(ctx context.Context, newRows [][]types.D prefetchStart := time.Now() // Use BatchGet to fill cache. // It's an optimization and could be removed without affecting correctness. - if err = prefetchDataCache(ctx, txn, toBeCheckedRows); err != nil { + if err = e.prefetchDataCache(ctx, txn, toBeCheckedRows); err != nil { return err } if e.stats != nil { e.stats.Prefetch += time.Since(prefetchStart) } + + txnValueGetter := e.txnValueGetter(txn) for i, r := range toBeCheckedRows { if r.handleKey != nil { handle, err := tablecodec.DecodeRowKey(r.handleKey.newKey) @@ -235,7 +244,7 @@ func (e *InsertExec) batchUpdateDupRows(ctx context.Context, newRows [][]types.D return err } - err = e.updateDupRow(ctx, i, txn, r, handle, e.OnDuplicate) + err = e.updateDupRow(ctx, i, txnValueGetter, r, handle, e.OnDuplicate) if err == nil { continue } @@ -245,7 +254,7 @@ func (e *InsertExec) batchUpdateDupRows(ctx context.Context, newRows [][]types.D } for _, uk := range r.uniqueKeys { - val, err := txn.Get(ctx, uk.newKey) + val, err := txnValueGetter.Get(ctx, uk.newKey) if err != nil { if kv.IsErrNotFound(err) { continue @@ -257,7 +266,7 @@ func (e *InsertExec) batchUpdateDupRows(ctx context.Context, newRows [][]types.D return err } - err = e.updateDupRow(ctx, i, txn, r, handle, e.OnDuplicate) + err = e.updateDupRow(ctx, i, txnValueGetter, r, handle, e.OnDuplicate) if err != nil { if kv.IsErrNotFound(err) { // Data index inconsistent? A unique key provide the handle information, but the diff --git a/executor/insert_common.go b/executor/insert_common.go index 392976a743348..2c25541a7d512 100644 --- a/executor/insert_common.go +++ b/executor/insert_common.go @@ -1051,13 +1051,18 @@ func (e *InsertValues) batchCheckAndInsert(ctx context.Context, rows [][]types.D } prefetchStart := time.Now() // Fill cache using BatchGet, the following Get requests don't need to visit TiKV. - if _, err = prefetchUniqueIndices(ctx, txn, toBeCheckedRows); err != nil { - return err + // Temporary table need not to do prefetch because its all data are stored in the memory. + if e.Table.Meta().TempTableType == model.TempTableNone { + if _, err = prefetchUniqueIndices(ctx, txn, toBeCheckedRows); err != nil { + return err + } } + if e.stats != nil { e.stats.Prefetch += time.Since(prefetchStart) } + txnValueGetter := e.txnValueGetter(txn) // append warnings and get no duplicated error rows for i, r := range toBeCheckedRows { if r.ignored { @@ -1065,7 +1070,7 @@ func (e *InsertValues) batchCheckAndInsert(ctx context.Context, rows [][]types.D } skip := false if r.handleKey != nil { - _, err := txn.Get(ctx, r.handleKey.newKey) + _, err := txnValueGetter.Get(ctx, r.handleKey.newKey) if err == nil { e.ctx.GetSessionVars().StmtCtx.AppendWarning(r.handleKey.dupErr) continue @@ -1075,7 +1080,7 @@ func (e *InsertValues) batchCheckAndInsert(ctx context.Context, rows [][]types.D } } for _, uk := range r.uniqueKeys { - _, err := txn.Get(ctx, uk.newKey) + _, err := txnValueGetter.Get(ctx, uk.newKey) if err == nil { // If duplicate keys were found in BatchGet, mark row = nil. e.ctx.GetSessionVars().StmtCtx.AppendWarning(uk.dupErr) @@ -1104,6 +1109,15 @@ func (e *InsertValues) batchCheckAndInsert(ctx context.Context, rows [][]types.D return nil } +func (e *InsertValues) txnValueGetter(txn kv.Transaction) kv.Getter { + tblInfo := e.Table.Meta() + if tblInfo.TempTableType == model.TempTableNone { + return txn + } + + return e.ctx.GetSessionVars().TemporaryTableTxnReader(txn, tblInfo) +} + func (e *InsertValues) addRecord(ctx context.Context, row []types.Datum) error { return e.addRecordWithAutoIDHint(ctx, row, 0) } diff --git a/executor/point_get.go b/executor/point_get.go index 31bd7c16e53a3..6bc96f3e755e3 100644 --- a/executor/point_get.go +++ b/executor/point_get.go @@ -410,7 +410,7 @@ func (e *PointGetExecutor) get(ctx context.Context, key kv.Key) ([]byte, error) // Local temporary table always get snapshot value from session if e.tblInfo.TempTableType == model.TempTableLocal { - return e.ctx.GetSessionVars().GetTemporaryTableSnapshotValue(ctx, key) + return e.ctx.GetSessionVars().TemporaryTableSnapshotReader(e.tblInfo).Get(ctx, key) } lock := e.tblInfo.Lock diff --git a/executor/replace.go b/executor/replace.go index 83df806489524..e1c70f07a15b6 100644 --- a/executor/replace.go +++ b/executor/replace.go @@ -61,9 +61,9 @@ func (e *ReplaceExec) Open(ctx context.Context) error { // removeRow removes the duplicate row and cleanup its keys in the key-value map, // but if the to-be-removed row equals to the to-be-added row, no remove or add things to do. -func (e *ReplaceExec) removeRow(ctx context.Context, txn kv.Transaction, handle kv.Handle, r toBeCheckedRow) (bool, error) { +func (e *ReplaceExec) removeRow(ctx context.Context, kvGetter kv.Getter, handle kv.Handle, r toBeCheckedRow) (bool, error) { newRow := r.row - oldRow, err := getOldRow(ctx, e.ctx, txn, r.t, handle, e.GenExprs) + oldRow, err := getOldRow(ctx, e.ctx, kvGetter, r.t, handle, e.GenExprs) if err != nil { logutil.BgLogger().Error("get old row failed when replace", zap.String("handle", handle.String()), @@ -119,14 +119,15 @@ func (e *ReplaceExec) replaceRow(ctx context.Context, r toBeCheckedRow) error { return err } + txnValueGetter := e.txnValueGetter(txn) if r.handleKey != nil { handle, err := tablecodec.DecodeRowKey(r.handleKey.newKey) if err != nil { return err } - if _, err := txn.Get(ctx, r.handleKey.newKey); err == nil { - rowUnchanged, err := e.removeRow(ctx, txn, handle, r) + if _, err := txnValueGetter.Get(ctx, r.handleKey.newKey); err == nil { + rowUnchanged, err := e.removeRow(ctx, txnValueGetter, handle, r) if err != nil { return err } @@ -142,7 +143,7 @@ func (e *ReplaceExec) replaceRow(ctx context.Context, r toBeCheckedRow) error { // Keep on removing duplicated rows. for { - rowUnchanged, foundDupKey, err := e.removeIndexRow(ctx, txn, r) + rowUnchanged, foundDupKey, err := e.removeIndexRow(ctx, txnValueGetter, r) if err != nil { return err } @@ -169,9 +170,9 @@ func (e *ReplaceExec) replaceRow(ctx context.Context, r toBeCheckedRow) error { // 2. bool: true when found the duplicated key. This only means that duplicated key was found, // and the row was removed. // 3. error: the error. -func (e *ReplaceExec) removeIndexRow(ctx context.Context, txn kv.Transaction, r toBeCheckedRow) (bool, bool, error) { +func (e *ReplaceExec) removeIndexRow(ctx context.Context, kvGetter kv.Getter, r toBeCheckedRow) (bool, bool, error) { for _, uk := range r.uniqueKeys { - val, err := txn.Get(ctx, uk.newKey) + val, err := kvGetter.Get(ctx, uk.newKey) if err != nil { if kv.IsErrNotFound(err) { continue @@ -182,7 +183,7 @@ func (e *ReplaceExec) removeIndexRow(ctx context.Context, txn kv.Transaction, r if err != nil { return false, true, err } - rowUnchanged, err := e.removeRow(ctx, txn, handle, r) + rowUnchanged, err := e.removeRow(ctx, kvGetter, handle, r) if err != nil { return false, true, err } @@ -228,9 +229,10 @@ func (e *ReplaceExec) exec(ctx context.Context, newRows [][]types.Datum) error { prefetchStart := time.Now() // Use BatchGet to fill cache. // It's an optimization and could be removed without affecting correctness. - if err = prefetchDataCache(ctx, txn, toBeCheckedRows); err != nil { + if err = e.prefetchDataCache(ctx, txn, toBeCheckedRows); err != nil { return err } + if e.stats != nil { e.stats.Prefetch = time.Since(prefetchStart) } diff --git a/session/session_test.go b/session/session_test.go index 0bd5309870798..42378b536b8c2 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -4994,6 +4994,123 @@ func (s *testSessionSuite) TestLocalTemporaryTableInsert(c *C) { tk.MustQuery("select * from tmp1 where id=5").Check(testkit.Rows()) } +func (s *testSessionSuite) TestLocalTemporaryTableInsertIgnore(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("set @@tidb_enable_noop_functions=1") + tk.MustExec("use test") + tk.MustExec("create temporary table tmp1 (id int primary key auto_increment, u int unique, v int)") + tk.MustExec("insert into tmp1 values(1, 11, 101)") + tk.MustExec("insert into tmp1 values(2, 12, 102)") + + // test outside transaction + tk.MustExec("insert ignore into tmp1 values(1, 100, 1000)") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1062 Duplicate entry '1' for key 'PRIMARY'")) + tk.MustQuery("select * from tmp1 where id=1").Check(testkit.Rows("1 11 101")) + tk.MustExec("insert ignore into tmp1 values(5, 15, 105)") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select * from tmp1 where id=5").Check(testkit.Rows("5 15 105")) + + // test in transaction and rollback + tk.MustExec("begin") + tk.MustExec("insert ignore into tmp1 values(1, 100, 1000)") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1062 Duplicate entry '1' for key 'PRIMARY'")) + tk.MustQuery("select * from tmp1 where id=1").Check(testkit.Rows("1 11 101")) + tk.MustExec("insert ignore into tmp1 values(3, 13, 103)") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select * from tmp1 where id=3").Check(testkit.Rows("3 13 103")) + tk.MustExec("insert ignore into tmp1 values(3, 100, 1000)") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1062 Duplicate entry '3' for key 'PRIMARY'")) + tk.MustQuery("select * from tmp1 where id=3").Check(testkit.Rows("3 13 103")) + tk.MustExec("rollback") + tk.MustQuery("select * from tmp1").Check(testkit.Rows("1 11 101", "2 12 102", "5 15 105")) + + // test commit + tk.MustExec("begin") + tk.MustExec("insert ignore into tmp1 values(1, 100, 1000)") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1062 Duplicate entry '1' for key 'PRIMARY'")) + tk.MustExec("insert ignore into tmp1 values(3, 13, 103)") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustExec("insert ignore into tmp1 values(3, 100, 1000)") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1062 Duplicate entry '3' for key 'PRIMARY'")) + tk.MustExec("commit") + tk.MustQuery("select * from tmp1").Check(testkit.Rows("1 11 101", "2 12 102", "3 13 103", "5 15 105")) +} + +func (s *testSessionSuite) TestLocalTemporaryTableInsertOnDuplicateKeyUpdate(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("set @@tidb_enable_noop_functions=1") + tk.MustExec("use test") + tk.MustExec("create temporary table tmp1 (id int primary key auto_increment, u int unique, v int)") + tk.MustExec("insert into tmp1 values(1, 11, 101)") + tk.MustExec("insert into tmp1 values(2, 12, 102)") + + // test outside transaction + tk.MustExec("insert ignore into tmp1 values(1, 100, 1000) on duplicate key update u=12") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1062 Duplicate entry '12' for key 'u'")) + tk.MustQuery("select * from tmp1 where id=1").Check(testkit.Rows("1 11 101")) + tk.MustExec("insert into tmp1 values(2, 100, 1000) on duplicate key update v=202") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select * from tmp1 where id=2").Check(testkit.Rows("2 12 202")) + tk.MustExec("insert into tmp1 values(3, 13, 103) on duplicate key update v=203") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select * from tmp1 where id=3").Check(testkit.Rows("3 13 103")) + + // test in transaction and rollback + tk.MustExec("begin") + tk.MustExec("insert ignore into tmp1 values(1, 100, 1000) on duplicate key update u=12") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1062 Duplicate entry '12' for key 'u'")) + tk.MustQuery("select * from tmp1 where id=1").Check(testkit.Rows("1 11 101")) + tk.MustExec("insert into tmp1 values(2, 100, 1000) on duplicate key update v=302") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select * from tmp1 where id=2").Check(testkit.Rows("2 12 302")) + tk.MustExec("insert into tmp1 values(4, 14, 104) on duplicate key update v=204") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select * from tmp1 where id=4").Check(testkit.Rows("4 14 104")) + tk.MustExec("rollback") + tk.MustQuery("select * from tmp1").Check(testkit.Rows("1 11 101", "2 12 202", "3 13 103")) + + // test commit + tk.MustExec("begin") + tk.MustExec("insert ignore into tmp1 values(1, 100, 1000) on duplicate key update u=12") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1062 Duplicate entry '12' for key 'u'")) + tk.MustExec("insert into tmp1 values(2, 100, 1000) on duplicate key update v=302") + tk.MustExec("insert into tmp1 values(4, 14, 104) on duplicate key update v=204") + tk.MustExec("commit") + tk.MustQuery("select * from tmp1").Check(testkit.Rows("1 11 101", "2 12 302", "3 13 103", "4 14 104")) +} + +func (s *testSessionSuite) TestLocalTemporaryTableReplace(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("set @@tidb_enable_noop_functions=1") + tk.MustExec("use test") + tk.MustExec("create temporary table tmp1 (id int primary key auto_increment, u int unique, v int)") + tk.MustExec("insert into tmp1 values(1, 11, 101)") + tk.MustExec("insert into tmp1 values(2, 12, 102)") + tk.MustExec("insert into tmp1 values(3, 13, 103)") + + // out of transaction + tk.MustExec("replace into tmp1 values(1, 12, 1000)") + tk.MustQuery("select * from tmp1").Check(testkit.Rows("1 12 1000", "3 13 103")) + tk.MustExec("replace into tmp1 values(4, 14, 104)") + tk.MustQuery("select * from tmp1 where id=4").Check(testkit.Rows("4 14 104")) + + // in transaction and rollback + tk.MustExec("begin") + tk.MustExec("replace into tmp1 values(1, 13, 999)") + tk.MustQuery("select * from tmp1").Check(testkit.Rows("1 13 999", "4 14 104")) + tk.MustExec("replace into tmp1 values(5, 15, 105)") + tk.MustQuery("select * from tmp1 where id=5").Check(testkit.Rows("5 15 105")) + tk.MustExec("rollback") + tk.MustQuery("select * from tmp1").Check(testkit.Rows("1 12 1000", "3 13 103", "4 14 104")) + + // out of transaction + tk.MustExec("begin") + tk.MustExec("replace into tmp1 values(1, 13, 999)") + tk.MustExec("replace into tmp1 values(5, 15, 105)") + tk.MustExec("commit") + tk.MustQuery("select * from tmp1").Check(testkit.Rows("1 13 999", "4 14 104", "5 15 105")) +} + func (s *testSessionSuite) TestLocalTemporaryTableDelete(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("set @@tidb_enable_noop_functions=1") diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 960add8b2c57b..226df03b2d6fd 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -2220,14 +2220,18 @@ func (s *SessionVars) GetSeekFactor(tbl *model.TableInfo) float64 { return s.seekFactor } -// GetTemporaryTableSnapshotValue get temporary table value from session -func (s *SessionVars) GetTemporaryTableSnapshotValue(ctx context.Context, key kv.Key) ([]byte, error) { - memData := s.TemporaryTableData - if memData == nil { +// TemporaryTableSnapshotReader can read the temporary table snapshot data +type TemporaryTableSnapshotReader struct { + memBuffer kv.MemBuffer +} + +// Get gets the value for key k from snapshot. +func (s *TemporaryTableSnapshotReader) Get(ctx context.Context, k kv.Key) ([]byte, error) { + if s.memBuffer == nil { return nil, kv.ErrNotExist } - v, err := memData.Get(ctx, key) + v, err := s.memBuffer.Get(ctx, k) if err != nil { return v, err } @@ -2239,9 +2243,23 @@ func (s *SessionVars) GetTemporaryTableSnapshotValue(ctx context.Context, key kv return v, nil } -// GetTemporaryTableTxnValue returns a kv.Getter to fetch temporary table data in txn -func (s *SessionVars) GetTemporaryTableTxnValue(ctx context.Context, txn kv.Transaction, key kv.Key) ([]byte, error) { - v, err := txn.GetMemBuffer().Get(ctx, key) +// TemporaryTableSnapshotReader can read the temporary table snapshot data +func (s *SessionVars) TemporaryTableSnapshotReader(tblInfo *model.TableInfo) *TemporaryTableSnapshotReader { + if tblInfo.TempTableType == model.TempTableGlobal { + return &TemporaryTableSnapshotReader{nil} + } + return &TemporaryTableSnapshotReader{s.TemporaryTableData} +} + +// TemporaryTableTxnReader can read the temporary table txn data +type TemporaryTableTxnReader struct { + memBuffer kv.MemBuffer + snapshot *TemporaryTableSnapshotReader +} + +// Get gets the value for key k from txn. +func (s *TemporaryTableTxnReader) Get(ctx context.Context, k kv.Key) ([]byte, error) { + v, err := s.memBuffer.Get(ctx, k) if err == nil { if len(v) == 0 { return nil, kv.ErrNotExist @@ -2254,5 +2272,13 @@ func (s *SessionVars) GetTemporaryTableTxnValue(ctx context.Context, txn kv.Tran return v, err } - return s.GetTemporaryTableSnapshotValue(ctx, key) + return s.snapshot.Get(ctx, k) +} + +// TemporaryTableTxnReader can read the temporary table txn data +func (s *SessionVars) TemporaryTableTxnReader(txn kv.Transaction, tblInfo *model.TableInfo) *TemporaryTableTxnReader { + return &TemporaryTableTxnReader{ + memBuffer: txn.GetMemBuffer(), + snapshot: s.TemporaryTableSnapshotReader(tblInfo), + } } diff --git a/table/tables/index.go b/table/tables/index.go index ae0eca1339482..afb9275c2dad9 100644 --- a/table/tables/index.go +++ b/table/tables/index.go @@ -201,7 +201,7 @@ func (c *index) Create(sctx sessionctx.Context, txn kv.Transaction, indexedValue var value []byte if c.tblInfo.TempTableType != model.TempTableNone { // Always check key for temporary table because it does not write to TiKV - value, err = sctx.GetSessionVars().GetTemporaryTableTxnValue(ctx, txn, key) + value, err = sctx.GetSessionVars().TemporaryTableTxnReader(txn, c.tblInfo).Get(ctx, key) } else if sctx.GetSessionVars().LazyCheckKeyNotExists() { value, err = txn.GetMemBuffer().Get(ctx, key) } else { diff --git a/table/tables/tables.go b/table/tables/tables.go index 471c31d5c0ef4..b43110c2e4a05 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -777,7 +777,7 @@ func (t *TableCommon) AddRecord(sctx sessionctx.Context, r []types.Datum, opts . if (t.meta.IsCommonHandle || t.meta.PKIsHandle) && !skipCheck && !opt.SkipHandleCheck { if t.meta.TempTableType != model.TempTableNone { // Always check key for temporary table because it does not write to TiKV - _, err = sctx.GetSessionVars().GetTemporaryTableTxnValue(ctx, txn, key) + _, err = sctx.GetSessionVars().TemporaryTableTxnReader(txn, t.meta).Get(ctx, key) } else if sctx.GetSessionVars().LazyCheckKeyNotExists() { var v []byte v, err = txn.GetMemBuffer().Get(ctx, key)