diff --git a/executor/builder.go b/executor/builder.go index ecd2907d41def..c9f5398c6e8ec 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -1270,7 +1270,7 @@ func (b *executorBuilder) getStartTS() (uint64, error) { if err != nil { return 0, errors.Trace(err) } - if startTS == 0 && txn.Valid() { + if startTS == 0 { startTS = txn.StartTS() } b.startTS = startTS diff --git a/executor/executor_test.go b/executor/executor_test.go index 27f6920aaa870..f504ae47dacff 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -2370,7 +2370,7 @@ func (s *testSuite) TestSelectForUpdate(c *C) { tk.MustExec("drop table if exists t, t1") txn, err := tk.Se.Txn(true) - c.Assert(err, IsNil) + c.Assert(kv.ErrInvalidTxn.Equal(err), IsTrue) c.Assert(txn.Valid(), IsFalse) tk.MustExec("create table t (c1 int, c2 int, c3 int)") tk.MustExec("insert t values (11, 2, 3)") diff --git a/executor/point_get.go b/executor/point_get.go index e9dbed0617865..db506dfb55bf0 100644 --- a/executor/point_get.go +++ b/executor/point_get.go @@ -154,7 +154,7 @@ func (e *PointGetExecutor) encodeIndexKey() (_ []byte, err error) { } func (e *PointGetExecutor) get(key kv.Key) (val []byte, err error) { - txn, err := e.ctx.Txn(true) + txn, err := e.ctx.Txn(false) if err != nil { return nil, errors.Trace(err) } diff --git a/executor/simple.go b/executor/simple.go index ab87766e2cf7a..4cd5a9c37cc22 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -140,7 +140,7 @@ func (e *SimpleExec) executeRollback(s *ast.RollbackStmt) error { sessVars := e.ctx.GetSessionVars() logutil.Logger(context.Background()).Debug("execute rollback statement", zap.Uint64("conn", sessVars.ConnectionID)) sessVars.SetStatusFlag(mysql.ServerStatusInTrans, false) - txn, err := e.ctx.Txn(true) + txn, err := e.ctx.Txn(false) if err != nil { return errors.Trace(err) } diff --git a/session/session.go b/session/session.go index 03b731715b821..a6b9b2e5ebf1f 100644 --- a/session/session.go +++ b/session/session.go @@ -1047,6 +1047,9 @@ func (s *session) DropPreparedStmt(stmtID uint32) error { } func (s *session) Txn(active bool) (kv.Transaction, error) { + if !s.txn.validOrPending() && active { + return &s.txn, kv.ErrInvalidTxn + } if s.txn.pending() && active { // Transaction is lazy initialized. // PrepareTxnCtx is called to get a tso future, makes s.txn a pending txn, diff --git a/session/session_test.go b/session/session_test.go index 418ff6e88c25e..cf42d4f0eea41 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -242,7 +242,7 @@ func (s *testSessionSuite) TestRowLock(c *C) { tk.MustExec("drop table if exists t") txn, err := tk.Se.Txn(true) - c.Assert(err, IsNil) + c.Assert(kv.ErrInvalidTxn.Equal(err), IsTrue) c.Assert(txn.Valid(), IsFalse) tk.MustExec("create table t (c1 int, c2 int, c3 int)") tk.MustExec("insert t values (11, 2, 3)") @@ -330,7 +330,9 @@ func (s *testSessionSuite) TestTxnLazyInitialize(c *C) { tk.MustExec("create table t (id int)") tk.MustExec("set @@autocommit = 0") - txn, err := tk.Se.Txn(false) + txn, err := tk.Se.Txn(true) + c.Assert(kv.ErrInvalidTxn.Equal(err), IsTrue) + txn, err = tk.Se.Txn(false) c.Assert(err, IsNil) c.Assert(txn.Valid(), IsFalse) tk.MustQuery("select @@tidb_current_ts").Check(testkit.Rows("0")) @@ -623,7 +625,7 @@ func (s *testSessionSuite) TestRetryPreparedStmt(c *C) { tk.MustExec("drop table if exists t") txn, err := tk.Se.Txn(true) - c.Assert(err, IsNil) + c.Assert(kv.ErrInvalidTxn.Equal(err), IsTrue) c.Assert(txn.Valid(), IsFalse) tk.MustExec("create table t (c1 int, c2 int, c3 int)") tk.MustExec("insert t values (11, 2, 3)")