diff --git a/pkg/executor/adapter.go b/pkg/executor/adapter.go index ebb033fe247a5..2aba9fce14c73 100644 --- a/pkg/executor/adapter.go +++ b/pkg/executor/adapter.go @@ -233,7 +233,7 @@ func (a *recordSet) OnFetchReturned() { a.stmt.LogSlowQuery(a.txnStartTS, len(a.lastErrs) == 0, true) } -// Detach creates a new `RecordSet` which doesn't depend on the current session context. +// TryDetach creates a new `RecordSet` which doesn't depend on the current session context. func (a *recordSet) TryDetach() (sqlexec.RecordSet, bool, error) { // TODO: also detach the executor. Currently, the executor inside may contain the session context. Once // the executor itself supports detach, we should also detach it here. diff --git a/pkg/executor/staticrecordset/BUILD.bazel b/pkg/executor/staticrecordset/BUILD.bazel index b9d04da961767..bff3c7ec97b2c 100644 --- a/pkg/executor/staticrecordset/BUILD.bazel +++ b/pkg/executor/staticrecordset/BUILD.bazel @@ -25,11 +25,12 @@ go_test( timeout = "short", srcs = ["integration_test.go"], flaky = True, - shard_count = 6, + shard_count = 7, deps = [ "//pkg/session/cursor", "//pkg/testkit", "//pkg/util/sqlexec", + "@com_github_pingcap_failpoint//:failpoint", "@com_github_stretchr_testify//require", "@com_github_tikv_client_go_v2//tikv", ], diff --git a/pkg/executor/staticrecordset/integration_test.go b/pkg/executor/staticrecordset/integration_test.go index 0e5aed4e279b0..961d7460474a6 100644 --- a/pkg/executor/staticrecordset/integration_test.go +++ b/pkg/executor/staticrecordset/integration_test.go @@ -19,6 +19,7 @@ import ( "testing" "time" + "github.com/pingcap/failpoint" "github.com/pingcap/tidb/pkg/session/cursor" "github.com/pingcap/tidb/pkg/testkit" "github.com/pingcap/tidb/pkg/util/sqlexec" @@ -211,3 +212,23 @@ func TestCursorWillBlockMinStartTS(t *testing.T) { return infoSyncer.GetMinStartTS() == secondStartTS }, time.Second*5, time.Millisecond*100) } + +func TestFinishStmtError(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + + tk.MustExec("use test") + tk.MustExec("create table t(id int)") + tk.MustExec("insert into t values (1), (2), (3)") + + rs, err := tk.Exec("select * from t") + require.NoError(t, err) + drs := rs.(sqlexec.DetachableRecordSet) + + failpoint.Enable("github.com/pingcap/tidb/pkg/session/finishStmtError", "return") + defer failpoint.Disable("github.com/pingcap/tidb/pkg/session/finishStmtError") + // Then `TryDetach` should return `true`, because the original record set is detached and cannot be used anymore. + _, ok, err := drs.TryDetach() + require.True(t, ok) + require.Error(t, err) +} diff --git a/pkg/server/conn.go b/pkg/server/conn.go index caab2400d71c8..1c8c8236c75e2 100644 --- a/pkg/server/conn.go +++ b/pkg/server/conn.go @@ -2438,10 +2438,10 @@ func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs resultset start = time.Now() } - iter := rs.GetRowContainerReader() + iter := rs.GetRowIterator() // send the rows to the client according to fetchSize. - for i := 0; i < fetchSize && iter.Current() != iter.End(); i++ { - row := iter.Current() + for i := 0; i < fetchSize && iter.Current(ctx) != iter.End(); i++ { + row := iter.Current(ctx) data = data[0:4] data, err = column.DumpBinaryRow(data, rs.Columns(), row, cc.rsEncoder) @@ -2452,7 +2452,7 @@ func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs resultset return err } - iter.Next() + iter.Next(ctx) } if iter.Error() != nil { return iter.Error() @@ -2460,7 +2460,7 @@ func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs resultset // tell the client COM_STMT_FETCH has finished by setting proper serverStatus, // and close ResultSet. - if iter.Current() == iter.End() { + if iter.Current(ctx) == iter.End() { serverStatus &^= mysql.ServerStatusCursorExists serverStatus |= mysql.ServerStatusLastRowSend } diff --git a/pkg/server/conn_stmt.go b/pkg/server/conn_stmt.go index fbf690b135cd9..19ac430500944 100644 --- a/pkg/server/conn_stmt.go +++ b/pkg/server/conn_stmt.go @@ -276,8 +276,8 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm // first, try to clear the left cursor if there is one if useCursor && stmt.GetCursorActive() { - if stmt.GetResultSet() != nil && stmt.GetResultSet().GetRowContainerReader() != nil { - stmt.GetResultSet().GetRowContainerReader().Close() + if stmt.GetResultSet() != nil && stmt.GetResultSet().GetRowIterator() != nil { + stmt.GetResultSet().GetRowIterator().Close() } if stmt.GetRowContainer() != nil { stmt.GetRowContainer().GetMemTracker().Detach() @@ -304,8 +304,13 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm } execStmt.SetText(charset.EncodingUTF8Impl, sql) rs, err := (&cc.ctx).ExecuteStmt(ctx, execStmt) + var lazy bool if rs != nil { - defer rs.Close() + defer func() { + if !lazy { + rs.Close() + } + }() } if err != nil { // If error is returned during the planner phase or the executor.Open @@ -331,97 +336,143 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm // we should hold the ResultSet in PreparedStatement for next stmt_fetch, and only send back ColumnInfo. // Tell the client cursor exists in server by setting proper serverStatus. if useCursor { - crs := resultset.WrapWithCursor(rs) - - cc.initResultEncoder(ctx) - defer cc.rsEncoder.Clean() - // fetch all results of the resultSet, and stored them locally, so that the future `FETCH` command can read - // the rows directly to avoid running executor and accessing shared params/variables in the session - // NOTE: chunk should not be allocated from the connection allocator, which will reset after executing this command - // but the rows are still needed in the following FETCH command. - - // create the row container to manage spill - // this `rowContainer` will be released when the statement (or the connection) is closed. - rowContainer := chunk.NewRowContainer(crs.FieldTypes(), vars.MaxChunkSize) - rowContainer.GetMemTracker().AttachTo(vars.MemTracker) - rowContainer.GetMemTracker().SetLabel(memory.LabelForCursorFetch) - rowContainer.GetDiskTracker().AttachTo(vars.DiskTracker) - rowContainer.GetDiskTracker().SetLabel(memory.LabelForCursorFetch) - if variable.EnableTmpStorageOnOOM.Load() { - failpoint.Inject("testCursorFetchSpill", func(val failpoint.Value) { - if val, ok := val.(bool); val && ok { - actionSpill := rowContainer.ActionSpillForTest() - defer actionSpill.WaitForTest() - } - }) - action := memory.NewActionWithPriority(rowContainer.ActionSpill(), memory.DefCursorFetchSpillPriority) - vars.MemTracker.FallbackOldAndSetNewAction(action) - } - defer func() { - if err != nil { - rowContainer.GetMemTracker().Detach() - rowContainer.GetDiskTracker().Detach() - errCloseRowContainer := rowContainer.Close() - if errCloseRowContainer != nil { - logutil.Logger(ctx).Error("Fail to close rowContainer in error handler. May cause resource leak", - zap.NamedError("original-error", err), zap.NamedError("close-error", errCloseRowContainer)) - } - } - }() - - for { - chk := crs.NewChunk(nil) + lazy, err = cc.executeWithCursor(ctx, stmt, rs) + return false, err + } + retryable, err := cc.writeResultSet(ctx, rs, true, cc.ctx.Status(), 0) + if err != nil { + return retryable, errors.Annotate(err, cc.preparedStmt2String(uint32(stmt.ID()))) + } + return false, nil +} - if err = crs.Next(ctx, chk); err != nil { - return false, err - } - rowCount := chk.NumRows() - if rowCount == 0 { - break - } +func (cc *clientConn) executeWithCursor(ctx context.Context, stmt PreparedStatement, rs resultset.ResultSet) (lazy bool, err error) { + vars := (&cc.ctx).GetSessionVars() + if vars.EnableLazyCursorFetch { + // try to execute with lazy cursor fetch + ok, err := cc.executeWithLazyCursor(ctx, stmt, rs) - err = rowContainer.Add(chk) - if err != nil { - return false, err - } + // if `ok` is false, should try to execute without lazy cursor fetch + if ok { + return true, err } + } - reader := chunk.NewRowContainerReader(rowContainer) - crs.StoreRowContainerReader(reader) - stmt.StoreResultSet(crs) - stmt.StoreRowContainer(rowContainer) - if cl, ok := crs.(resultset.FetchNotifier); ok { - cl.OnFetchReturned() + failpoint.Inject("avoidEagerCursorFetch", func() { + failpoint.Return(false, errors.New("failpoint avoids eager cursor fetch")) + }) + cc.initResultEncoder(ctx) + defer cc.rsEncoder.Clean() + // fetch all results of the resultSet, and stored them locally, so that the future `FETCH` command can read + // the rows directly to avoid running executor and accessing shared params/variables in the session + // NOTE: chunk should not be allocated from the connection allocator, which will reset after executing this command + // but the rows are still needed in the following FETCH command. + + // create the row container to manage spill + // this `rowContainer` will be released when the statement (or the connection) is closed. + rowContainer := chunk.NewRowContainer(rs.FieldTypes(), vars.MaxChunkSize) + rowContainer.GetMemTracker().AttachTo(vars.MemTracker) + rowContainer.GetMemTracker().SetLabel(memory.LabelForCursorFetch) + rowContainer.GetDiskTracker().AttachTo(vars.DiskTracker) + rowContainer.GetDiskTracker().SetLabel(memory.LabelForCursorFetch) + if variable.EnableTmpStorageOnOOM.Load() { + failpoint.Inject("testCursorFetchSpill", func(val failpoint.Value) { + if val, ok := val.(bool); val && ok { + actionSpill := rowContainer.ActionSpillForTest() + defer actionSpill.WaitForTest() + } + }) + action := memory.NewActionWithPriority(rowContainer.ActionSpill(), memory.DefCursorFetchSpillPriority) + vars.MemTracker.FallbackOldAndSetNewAction(action) + } + defer func() { + if err != nil { + rowContainer.GetMemTracker().Detach() + rowContainer.GetDiskTracker().Detach() + errCloseRowContainer := rowContainer.Close() + if errCloseRowContainer != nil { + logutil.Logger(ctx).Error("Fail to close rowContainer in error handler. May cause resource leak", + zap.NamedError("original-error", err), zap.NamedError("close-error", errCloseRowContainer)) + } } - stmt.SetCursorActive(true) - defer func() { - if err != nil { - reader.Close() + }() - // the resultSet and rowContainer have been closed in former "defer" statement. - stmt.StoreResultSet(nil) - stmt.StoreRowContainer(nil) - stmt.SetCursorActive(false) - } - }() + for { + chk := rs.NewChunk(nil) - if err = cc.writeColumnInfo(crs.Columns()); err != nil { + if err = rs.Next(ctx, chk); err != nil { return false, err } + rowCount := chk.NumRows() + if rowCount == 0 { + break + } - // explicitly flush columnInfo to client. - err = cc.writeEOF(ctx, cc.ctx.Status()) + err = rowContainer.Add(chk) if err != nil { return false, err } + } + + reader := chunk.NewRowContainerReader(rowContainer) + defer func() { + if err != nil { + reader.Close() + } + }() + crs := resultset.WrapWithRowContainerCursor(rs, reader) + if cl, ok := crs.(resultset.FetchNotifier); ok { + cl.OnFetchReturned() + } + stmt.StoreRowContainer(rowContainer) + + err = cc.writeExecuteResultWithCursor(ctx, stmt, crs) + return false, err +} - return false, cc.flush(ctx) +// executeWithLazyCursor tries to detach the `ResultSet` and make it suitable to execute lazily. +// Be careful that the return value `(bool, error)` has different meaning with other similar functions. The first `bool` represent whether +// the `ResultSet` is suitable for lazy execution. If the return value is `(false, _)`, the `rs` in argument can still be used. If the +// first return value is `true` and `err` is not nil, the `rs` cannot be used anymore and should return the error to the upper layer. +func (cc *clientConn) executeWithLazyCursor(ctx context.Context, stmt PreparedStatement, rs resultset.ResultSet) (ok bool, err error) { + drs, ok, err := rs.TryDetach() + if !ok || err != nil { + return false, err } - retryable, err := cc.writeResultSet(ctx, rs, true, cc.ctx.Status(), 0) + + vars := (&cc.ctx).GetSessionVars() + crs := resultset.WrapWithLazyCursor(drs, vars.InitChunkSize, vars.MaxChunkSize) + err = cc.writeExecuteResultWithCursor(ctx, stmt, crs) + return true, err +} + +// writeExecuteResultWithCursor will store the `ResultSet` in `stmt` and send the column info to the client. The logic is shared between +// lazy cursor fetch and normal(eager) cursor fetch. +func (cc *clientConn) writeExecuteResultWithCursor(ctx context.Context, stmt PreparedStatement, rs resultset.CursorResultSet) error { + var err error + + stmt.StoreResultSet(rs) + stmt.SetCursorActive(true) + defer func() { + if err != nil { + // the resultSet and rowContainer have been closed in former "defer" statement. + stmt.StoreResultSet(nil) + stmt.StoreRowContainer(nil) + stmt.SetCursorActive(false) + } + }() + + if err = cc.writeColumnInfo(rs.Columns()); err != nil { + return err + } + + // explicitly flush columnInfo to client. + err = cc.writeEOF(ctx, cc.ctx.Status()) if err != nil { - return retryable, errors.Annotate(err, cc.preparedStmt2String(uint32(stmt.ID()))) + return err } - return false, nil + + return cc.flush(ctx) } func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err error) { @@ -476,7 +527,7 @@ func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err err _, err = cc.writeResultSet(ctx, rs, true, cc.ctx.Status(), int(fetchSize)) // if the iterator reached the end before writing result, we could say the `FETCH` command will send EOF - if rs.GetRowContainerReader().Current() == rs.GetRowContainerReader().End() { + if rs.GetRowIterator().Current(ctx) == rs.GetRowIterator().End() { // also reset the statement when the cursor reaches the end // don't overwrite the `err` in outer scope, to avoid redundant `Reset()` in `defer` statement (though, it's not // a big problem, as the `Reset()` function call is idempotent.) diff --git a/pkg/server/conn_stmt_test.go b/pkg/server/conn_stmt_test.go index 8709202e7bb5f..3a7f1c45e645b 100644 --- a/pkg/server/conn_stmt_test.go +++ b/pkg/server/conn_stmt_test.go @@ -129,11 +129,11 @@ func TestCursorWithParams(t *testing.T) { 0x0, 0x1, 0x3, 0x0, 0x3, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, ))) - rows := c.Context().stmts[stmt1.ID()].GetResultSet().GetRowContainerReader() - require.Equal(t, int64(1), rows.Current().GetInt64(0)) - require.Equal(t, int64(2), rows.Current().GetInt64(1)) - rows.Next() - require.Equal(t, rows.End(), rows.Current()) + rows := c.Context().stmts[stmt1.ID()].GetResultSet().GetRowIterator() + require.Equal(t, int64(1), rows.Current(context.Background()).GetInt64(0)) + require.Equal(t, int64(2), rows.Current(context.Background()).GetInt64(1)) + rows.Next(context.Background()) + require.Equal(t, rows.End(), rows.Current(context.Background())) // `execute stmt2 using 1` with cursor require.NoError(t, c.Dispatch(ctx, append( @@ -142,13 +142,13 @@ func TestCursorWithParams(t *testing.T) { 0x0, 0x1, 0x3, 0x0, 0x1, 0x0, 0x0, 0x0, ))) - rows = c.Context().stmts[stmt2.ID()].GetResultSet().GetRowContainerReader() - require.Equal(t, int64(1), rows.Current().GetInt64(0)) - require.Equal(t, int64(1), rows.Current().GetInt64(1)) - require.Equal(t, int64(1), rows.Next().GetInt64(0)) - require.Equal(t, int64(2), rows.Current().GetInt64(1)) - rows.Next() - require.Equal(t, rows.End(), rows.Current()) + rows = c.Context().stmts[stmt2.ID()].GetResultSet().GetRowIterator() + require.Equal(t, int64(1), rows.Current(context.Background()).GetInt64(0)) + require.Equal(t, int64(1), rows.Current(context.Background()).GetInt64(1)) + require.Equal(t, int64(1), rows.Next(context.Background()).GetInt64(0)) + require.Equal(t, int64(2), rows.Current(context.Background()).GetInt64(1)) + rows.Next(context.Background()) + require.Equal(t, rows.End(), rows.Current(context.Background())) // fetch stmt2 with fetch size 256 require.NoError(t, c.Dispatch(ctx, append( diff --git a/pkg/server/driver_tidb.go b/pkg/server/driver_tidb.go index bd8446cf56392..2369342728861 100644 --- a/pkg/server/driver_tidb.go +++ b/pkg/server/driver_tidb.go @@ -150,8 +150,8 @@ func (ts *TiDBStatement) Reset() error { } ts.hasActiveCursor = false - if ts.rs != nil && ts.rs.GetRowContainerReader() != nil { - ts.rs.GetRowContainerReader().Close() + if ts.rs != nil && ts.rs.GetRowIterator() != nil { + ts.rs.GetRowIterator().Close() } ts.rs = nil @@ -173,8 +173,8 @@ func (ts *TiDBStatement) Reset() error { // Close implements PreparedStatement Close method. func (ts *TiDBStatement) Close() error { - if ts.rs != nil && ts.rs.GetRowContainerReader() != nil { - ts.rs.GetRowContainerReader().Close() + if ts.rs != nil && ts.rs.GetRowIterator() != nil { + ts.rs.GetRowIterator().Close() } if ts.rowContainer != nil { diff --git a/pkg/server/internal/resultset/cursor.go b/pkg/server/internal/resultset/cursor.go index 1b07a96668ee4..624d0201d675a 100644 --- a/pkg/server/internal/resultset/cursor.go +++ b/pkg/server/internal/resultset/cursor.go @@ -14,37 +14,70 @@ package resultset -import "github.com/pingcap/tidb/pkg/util/chunk" +import ( + "context" + + "github.com/pingcap/tidb/pkg/util/chunk" +) + +// RowIterator has similar method with `chunk.RowContainerReader`. The only difference is that +// it needs a `context.Context` for the `Next` and `Current` method. +type RowIterator interface { + // Next returns the next Row. + Next(context.Context) chunk.Row + + // Current returns the current Row. + Current(context.Context) chunk.Row + + // End returns the invalid end Row. + End() chunk.Row + + // Error returns none-nil error if anything wrong happens during the iteration. + Error() error + + // Close closes the dumper + Close() +} // CursorResultSet extends the `ResultSet` to provide the ability to store an iterator type CursorResultSet interface { ResultSet - StoreRowContainerReader(reader chunk.RowContainerReader) - GetRowContainerReader() chunk.RowContainerReader + GetRowIterator() RowIterator } -// WrapWithCursor wraps a ResultSet into a CursorResultSet -func WrapWithCursor(rs ResultSet) CursorResultSet { +// WrapWithRowContainerCursor wraps a ResultSet into a CursorResultSet +func WrapWithRowContainerCursor(rs ResultSet, rowContainer chunk.RowContainerReader) CursorResultSet { return &tidbCursorResultSet{ - rs, nil, + ResultSet: rs, + reader: rowContainerReaderIter{ + RowContainerReader: rowContainer, + }, } } -var _ CursorResultSet = &tidbCursorResultSet{} - type tidbCursorResultSet struct { ResultSet - reader chunk.RowContainerReader + reader rowContainerReaderIter +} + +func (tcrs *tidbCursorResultSet) GetRowIterator() RowIterator { + return &tcrs.reader } -func (tcrs *tidbCursorResultSet) StoreRowContainerReader(reader chunk.RowContainerReader) { - tcrs.reader = reader +var _ RowIterator = &rowContainerReaderIter{} + +type rowContainerReaderIter struct { + chunk.RowContainerReader } -func (tcrs *tidbCursorResultSet) GetRowContainerReader() chunk.RowContainerReader { - return tcrs.reader +func (iter *rowContainerReaderIter) Next(context.Context) chunk.Row { + return iter.RowContainerReader.Next() +} + +func (iter *rowContainerReaderIter) Current(context.Context) chunk.Row { + return iter.RowContainerReader.Current() } // FetchNotifier represents notifier will be called in COM_FETCH. @@ -53,3 +86,84 @@ type FetchNotifier interface { // it will be used in server-side cursor. OnFetchReturned() } + +var _ CursorResultSet = &tidbLazyCursorResultSet{} + +type tidbLazyCursorResultSet struct { + ResultSet + iter lazyRowIterator +} + +// WrapWithLazyCursor wraps a ResultSet into a CursorResultSet +func WrapWithLazyCursor(rs ResultSet, capacity, maxChunkSize int) CursorResultSet { + chk := chunk.New(rs.FieldTypes(), capacity, maxChunkSize) + + return &tidbLazyCursorResultSet{ + ResultSet: rs, + iter: lazyRowIterator{ + rs: rs, + chk: chk, + }, + } +} + +func (tcrs *tidbLazyCursorResultSet) GetRowIterator() RowIterator { + return &tcrs.iter +} + +type lazyRowIterator struct { + rs ResultSet + err error + chk *chunk.Chunk + idxInChk int + started bool +} + +func (iter *lazyRowIterator) Next(ctx context.Context) chunk.Row { + if !iter.started { + iter.started = true + } + + iter.idxInChk++ + + if iter.idxInChk >= iter.chk.NumRows() { + // Reached the end, update the chunk + err := iter.rs.Next(ctx, iter.chk) + if err != nil { + iter.err = err + return chunk.Row{} + } + + if iter.chk.NumRows() == 0 { + // reached the end + return chunk.Row{} + } + iter.idxInChk = 0 + } + + return iter.chk.GetRow(iter.idxInChk) +} + +func (iter *lazyRowIterator) Current(ctx context.Context) chunk.Row { + if !iter.started { + return iter.Next(ctx) + } + + if iter.chk.NumRows() == 0 { + return chunk.Row{} + } + + return iter.chk.GetRow(iter.idxInChk) +} + +func (*lazyRowIterator) End() chunk.Row { + return chunk.Row{} +} + +func (iter *lazyRowIterator) Error() error { + return iter.err +} + +func (iter *lazyRowIterator) Close() { + iter.rs.Close() +} diff --git a/pkg/server/internal/resultset/resultset.go b/pkg/server/internal/resultset/resultset.go index 3062afdc9feed..a4b5493b44848 100644 --- a/pkg/server/internal/resultset/resultset.go +++ b/pkg/server/internal/resultset/resultset.go @@ -38,6 +38,7 @@ type ResultSet interface { FieldTypes() []*types.FieldType SetPreparedStmt(stmt *core.PlanCacheStmt) Finish() error + TryDetach() (ResultSet, bool, error) } var _ ResultSet = &tidbResultSet{} @@ -142,3 +143,22 @@ func (trs *tidbResultSet) FieldTypes() []*types.FieldType { func (trs *tidbResultSet) SetPreparedStmt(stmt *core.PlanCacheStmt) { trs.preparedStmt = stmt } + +// TryDetach creates a new `ResultSet` which doesn't depend on the current session context. +func (trs *tidbResultSet) TryDetach() (ResultSet, bool, error) { + detachableRecordSet, ok := trs.recordSet.(sqlexec.DetachableRecordSet) + if !ok { + return nil, false, nil + } + + recordSet, detached, err := detachableRecordSet.TryDetach() + if !detached || err != nil { + return nil, detached, err + } + + return &tidbResultSet{ + recordSet: recordSet, + preparedStmt: trs.preparedStmt, + columns: trs.columns, + }, true, nil +} diff --git a/pkg/server/tests/commontest/BUILD.bazel b/pkg/server/tests/commontest/BUILD.bazel index 8e1baedf9abbf..4f0ea5dac5cba 100644 --- a/pkg/server/tests/commontest/BUILD.bazel +++ b/pkg/server/tests/commontest/BUILD.bazel @@ -4,6 +4,7 @@ go_test( name = "commontest_test", timeout = "short", srcs = [ + "cursor_test.go", "main_test.go", "tidb_test.go", ], diff --git a/pkg/server/tests/commontest/cursor_test.go b/pkg/server/tests/commontest/cursor_test.go new file mode 100644 index 0000000000000..f1cf1743506ed --- /dev/null +++ b/pkg/server/tests/commontest/cursor_test.go @@ -0,0 +1,63 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package commontest + +import ( + "context" + "fmt" + "testing" + + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/server/internal/resultset" + "github.com/pingcap/tidb/pkg/server/tests/servertestkit" + "github.com/stretchr/testify/require" +) + +func TestLazyRowIterator(t *testing.T) { + ts := servertestkit.CreateTidbTestSuite(t) + qctx, err := ts.Tidbdrv.OpenCtx(uint64(0), 0, uint8(mysql.DefaultCollationID), "test", nil, nil) + require.NoError(t, err) + + ctx := context.Background() + _, err = Execute(ctx, qctx, "use test") + require.NoError(t, err) + _, err = Execute(ctx, qctx, "create table t (id int)") + require.NoError(t, err) + // insert 1000 rows + for i := 0; i < 1000; i++ { + _, err = Execute(ctx, qctx, fmt.Sprintf("insert into t values (%d)", i)) + require.NoError(t, err) + } + + for _, chkSize := range []struct{ initSize, maxSize int }{ + {1024, 1024}, {512, 512}, {256, 256}, {100, 100}} { + rs, err := Execute(ctx, qctx, "select * from t") + require.NoError(t, err) + crs := resultset.WrapWithLazyCursor(rs, chkSize.initSize, chkSize.maxSize) + iter := crs.GetRowIterator() + for i := 0; i < 1000; i++ { + row := iter.Current(ctx) + require.Equal(t, int64(i), row.GetInt64(0)) + row = iter.Next(ctx) + if i+1 >= 1000 { + require.True(t, row.IsEmpty()) + } else { + require.Equal(t, int64(i+1), row.GetInt64(0)) + } + } + require.True(t, iter.Current(ctx).IsEmpty()) + iter.Close() + } +} diff --git a/pkg/server/tests/cursor/BUILD.bazel b/pkg/server/tests/cursor/BUILD.bazel index ab1460b231e05..cfc030e98e7da 100644 --- a/pkg/server/tests/cursor/BUILD.bazel +++ b/pkg/server/tests/cursor/BUILD.bazel @@ -8,7 +8,7 @@ go_test( "main_test.go", ], flaky = True, - shard_count = 4, + shard_count = 5, deps = [ "//pkg/config", "//pkg/metrics", diff --git a/pkg/server/tests/cursor/cursor_test.go b/pkg/server/tests/cursor/cursor_test.go index bde15dd88a331..907b08e7718ab 100644 --- a/pkg/server/tests/cursor/cursor_test.go +++ b/pkg/server/tests/cursor/cursor_test.go @@ -190,7 +190,7 @@ outerLoop: require.NoError(t, err) stmt := rawStmt.(mysqlcursor.Statement) - // This query will return 10000 rows and use cursor fetch. + // This query will return `rowCount` rows and use cursor fetch. rows, err := stmt.QueryContext(context.Background(), nil) require.NoError(t, err) @@ -226,3 +226,66 @@ outerLoop: } } } + +func TestSerialLazyExecuteAndFetch(t *testing.T) { + ts := servertestkit.CreateTidbTestSuite(t) + + mysqldriver := &mysqlcursor.MySQLDriver{} + rawConn, err := mysqldriver.Open(ts.GetDSNWithCursor(10)) + require.NoError(t, err) + conn := rawConn.(mysqlcursor.Connection) + defer conn.Close() + + _, err = conn.ExecContext(context.Background(), "drop table if exists t1", nil) + require.NoError(t, err) + _, err = conn.ExecContext(context.Background(), "create table t1(id int primary key, v int)", nil) + require.NoError(t, err) + rowCount := 1000 + for i := 0; i < rowCount; i++ { + _, err = conn.ExecContext(context.Background(), fmt.Sprintf("insert into t1 values(%d, %d)", i, i), nil) + require.NoError(t, err) + } + + _, err = conn.ExecContext(context.Background(), "set tidb_enable_lazy_cursor_fetch = 'ON'", nil) + require.NoError(t, err) + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/server/avoidEagerCursorFetch", "return")) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/server/avoidEagerCursorFetch")) + }() + + // Normal execute. Simple table reader. + execTimes := 0 +outerLoop: + for execTimes < 50 { + execTimes++ + rawStmt, err := conn.Prepare("select * from t1") + require.NoError(t, err) + stmt := rawStmt.(mysqlcursor.Statement) + + // This query will return `rowCount` rows and use cursor fetch. + rows, err := stmt.QueryContext(context.Background(), nil) + require.NoError(t, err) + + dest := make([]driver.Value, 2) + fetchRowCount := int64(0) + + for { + // it'll send `FETCH` commands for every 10 rows. + err := rows.Next(dest) + if err != nil { + switch err { + case io.EOF: + require.Equal(t, int64(rowCount), fetchRowCount) + rows.Close() + break outerLoop + default: + require.NoError(t, err) + } + } + require.Equal(t, fetchRowCount, dest[0]) + require.Equal(t, fetchRowCount, dest[1]) + fetchRowCount++ + } + } +} diff --git a/pkg/session/session.go b/pkg/session/session.go index 708c97b10b902..ff7d7854f9a16 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -2412,7 +2412,7 @@ func (rs *execStmtResult) TryDetach() (sqlexec.RecordSet, bool, error) { if err2 != nil { logutil.BgLogger().Error("close detached record set failed", zap.Error(err2)) } - return nil, false, err + return nil, true, err } return crs, true, nil diff --git a/pkg/session/tidb.go b/pkg/session/tidb.go index 270e7c59c5a9d..226f351bd27c1 100644 --- a/pkg/session/tidb.go +++ b/pkg/session/tidb.go @@ -25,6 +25,7 @@ import ( "time" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/tidb/pkg/config" "github.com/pingcap/tidb/pkg/ddl" "github.com/pingcap/tidb/pkg/ddl/schematracker" @@ -224,6 +225,9 @@ func recordAbortTxnDuration(sessVars *variable.SessionVars, isInternal bool) { } func finishStmt(ctx context.Context, se *session, meetsErr error, sql sqlexec.Statement) error { + failpoint.Inject("finishStmtError", func() { + failpoint.Return(errors.New("occur an error after finishStmt")) + }) sessVars := se.sessionVars if !sql.IsReadOnly(sessVars) { // All the history should be added here. diff --git a/pkg/sessionctx/variable/session.go b/pkg/sessionctx/variable/session.go index 3bd19fb828e98..d51e46be1fcd2 100644 --- a/pkg/sessionctx/variable/session.go +++ b/pkg/sessionctx/variable/session.go @@ -1622,6 +1622,9 @@ type SessionVars struct { // GroupConcatMaxLen represents the maximum length of the result of GROUP_CONCAT. GroupConcatMaxLen uint64 + + // EnableLazyCursorFetch defines whether to enable the lazy cursor fetch. + EnableLazyCursorFetch bool } // GetOptimizerFixControlMap returns the specified value of the optimizer fix control. diff --git a/pkg/sessionctx/variable/sysvar.go b/pkg/sessionctx/variable/sysvar.go index 8321b29fc7cdd..f3fb13dee7ef6 100644 --- a/pkg/sessionctx/variable/sysvar.go +++ b/pkg/sessionctx/variable/sysvar.go @@ -3256,6 +3256,10 @@ var defaultSysVars = []*SysVar{ }, IsHintUpdatableVerified: true, }, + {Scope: ScopeGlobal | ScopeSession, Name: TiDBEnableLazyCursorFetch, Value: BoolToOnOff(DefTiDBEnableLazyCursorFetch), Type: TypeBool, SetSession: func(s *SessionVars, val string) error { + s.EnableLazyCursorFetch = TiDBOptOn(val) + return nil + }}, } // GlobalSystemVariableInitialValue gets the default value for a system variable including ones that are dynamically set (e.g. based on the store) diff --git a/pkg/sessionctx/variable/tidb_vars.go b/pkg/sessionctx/variable/tidb_vars.go index a793b691848ac..cf073f1a2430d 100644 --- a/pkg/sessionctx/variable/tidb_vars.go +++ b/pkg/sessionctx/variable/tidb_vars.go @@ -1171,6 +1171,9 @@ const ( // The value can be STANDARD, BULK. // Currently, the BULK mode only affects auto-committed DML. TiDBDMLType = "tidb_dml_type" + // TiDBEnableLazyCursorFetch defines whether to enable the lazy cursor fetch. If it's `OFF`, all results of + // of a cursor will be stored in the tidb node in `EXECUTE` command. + TiDBEnableLazyCursorFetch = "tidb_enable_lazy_cursor_fetch" ) // TiDB intentional limits @@ -1503,6 +1506,7 @@ const ( DefTiDBDMLType = "STANDARD" DefGroupConcatMaxLen = uint64(1024) DefDefaultWeekFormat = "0" + DefTiDBEnableLazyCursorFetch = false ) // Process global variables.