Skip to content

Commit

Permalink
server: add protocol support for lazy cursor fetch (#54527)
Browse files Browse the repository at this point in the history
close #54526
  • Loading branch information
YangKeao authored Jul 18, 2024
1 parent 0ad8b84 commit afd6d6a
Show file tree
Hide file tree
Showing 18 changed files with 466 additions and 117 deletions.
2 changes: 1 addition & 1 deletion pkg/executor/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ func (a *recordSet) OnFetchReturned() {
a.stmt.LogSlowQuery(a.txnStartTS, len(a.lastErrs) == 0, true)
}

// Detach creates a new `RecordSet` which doesn't depend on the current session context.
// TryDetach creates a new `RecordSet` which doesn't depend on the current session context.
func (a *recordSet) TryDetach() (sqlexec.RecordSet, bool, error) {
// TODO: also detach the executor. Currently, the executor inside may contain the session context. Once
// the executor itself supports detach, we should also detach it here.
Expand Down
3 changes: 2 additions & 1 deletion pkg/executor/staticrecordset/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ go_test(
timeout = "short",
srcs = ["integration_test.go"],
flaky = True,
shard_count = 6,
shard_count = 7,
deps = [
"//pkg/session/cursor",
"//pkg/testkit",
"//pkg/util/sqlexec",
"@com_github_pingcap_failpoint//:failpoint",
"@com_github_stretchr_testify//require",
"@com_github_tikv_client_go_v2//tikv",
],
Expand Down
21 changes: 21 additions & 0 deletions pkg/executor/staticrecordset/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"testing"
"time"

"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/pkg/session/cursor"
"github.com/pingcap/tidb/pkg/testkit"
"github.com/pingcap/tidb/pkg/util/sqlexec"
Expand Down Expand Up @@ -211,3 +212,23 @@ func TestCursorWillBlockMinStartTS(t *testing.T) {
return infoSyncer.GetMinStartTS() == secondStartTS
}, time.Second*5, time.Millisecond*100)
}

// TestFinishStmtError checks the result of `TryDetach` while the
// `finishStmtError` failpoint is active: it must report `ok == true` together
// with a non-nil error.
func TestFinishStmtError(t *testing.T) {
	store := testkit.CreateMockStore(t)
	tk := testkit.NewTestKit(t, store)

	tk.MustExec("use test")
	tk.MustExec("create table t(id int)")
	tk.MustExec("insert into t values (1), (2), (3)")

	rs, err := tk.Exec("select * from t")
	require.NoError(t, err)
	drs := rs.(sqlexec.DetachableRecordSet)

	// NOTE(review): presumably this failpoint makes the session's statement-finish
	// step fail — confirm against pkg/session.
	const finishStmtErrorFP = "github.com/pingcap/tidb/pkg/session/finishStmtError"
	// Fail fast if the failpoint cannot be toggled; otherwise the assertions below
	// would fail with a misleading message.
	require.NoError(t, failpoint.Enable(finishStmtErrorFP, "return"))
	defer func() {
		require.NoError(t, failpoint.Disable(finishStmtErrorFP))
	}()

	// Then `TryDetach` should return `true`, because the original record set is detached and cannot be used anymore.
	_, ok, err := drs.TryDetach()
	require.True(t, ok)
	require.Error(t, err)
}
10 changes: 5 additions & 5 deletions pkg/server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2438,10 +2438,10 @@ func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs resultset
start = time.Now()
}

iter := rs.GetRowContainerReader()
iter := rs.GetRowIterator()
// send the rows to the client according to fetchSize.
for i := 0; i < fetchSize && iter.Current() != iter.End(); i++ {
row := iter.Current()
for i := 0; i < fetchSize && iter.Current(ctx) != iter.End(); i++ {
row := iter.Current(ctx)

data = data[0:4]
data, err = column.DumpBinaryRow(data, rs.Columns(), row, cc.rsEncoder)
Expand All @@ -2452,15 +2452,15 @@ func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs resultset
return err
}

iter.Next()
iter.Next(ctx)
}
if iter.Error() != nil {
return iter.Error()
}

// tell the client COM_STMT_FETCH has finished by setting proper serverStatus,
// and close ResultSet.
if iter.Current() == iter.End() {
if iter.Current(ctx) == iter.End() {
serverStatus &^= mysql.ServerStatusCursorExists
serverStatus |= mysql.ServerStatusLastRowSend
}
Expand Down
207 changes: 129 additions & 78 deletions pkg/server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,8 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm

// first, try to clear the left cursor if there is one
if useCursor && stmt.GetCursorActive() {
if stmt.GetResultSet() != nil && stmt.GetResultSet().GetRowContainerReader() != nil {
stmt.GetResultSet().GetRowContainerReader().Close()
if stmt.GetResultSet() != nil && stmt.GetResultSet().GetRowIterator() != nil {
stmt.GetResultSet().GetRowIterator().Close()
}
if stmt.GetRowContainer() != nil {
stmt.GetRowContainer().GetMemTracker().Detach()
Expand All @@ -304,8 +304,13 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm
}
execStmt.SetText(charset.EncodingUTF8Impl, sql)
rs, err := (&cc.ctx).ExecuteStmt(ctx, execStmt)
var lazy bool
if rs != nil {
defer rs.Close()
defer func() {
if !lazy {
rs.Close()
}
}()
}
if err != nil {
// If error is returned during the planner phase or the executor.Open
Expand All @@ -331,97 +336,143 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm
// we should hold the ResultSet in PreparedStatement for next stmt_fetch, and only send back ColumnInfo.
// Tell the client cursor exists in server by setting proper serverStatus.
if useCursor {
crs := resultset.WrapWithCursor(rs)

cc.initResultEncoder(ctx)
defer cc.rsEncoder.Clean()
// fetch all results of the resultSet, and stored them locally, so that the future `FETCH` command can read
// the rows directly to avoid running executor and accessing shared params/variables in the session
// NOTE: chunk should not be allocated from the connection allocator, which will reset after executing this command
// but the rows are still needed in the following FETCH command.

// create the row container to manage spill
// this `rowContainer` will be released when the statement (or the connection) is closed.
rowContainer := chunk.NewRowContainer(crs.FieldTypes(), vars.MaxChunkSize)
rowContainer.GetMemTracker().AttachTo(vars.MemTracker)
rowContainer.GetMemTracker().SetLabel(memory.LabelForCursorFetch)
rowContainer.GetDiskTracker().AttachTo(vars.DiskTracker)
rowContainer.GetDiskTracker().SetLabel(memory.LabelForCursorFetch)
if variable.EnableTmpStorageOnOOM.Load() {
failpoint.Inject("testCursorFetchSpill", func(val failpoint.Value) {
if val, ok := val.(bool); val && ok {
actionSpill := rowContainer.ActionSpillForTest()
defer actionSpill.WaitForTest()
}
})
action := memory.NewActionWithPriority(rowContainer.ActionSpill(), memory.DefCursorFetchSpillPriority)
vars.MemTracker.FallbackOldAndSetNewAction(action)
}
defer func() {
if err != nil {
rowContainer.GetMemTracker().Detach()
rowContainer.GetDiskTracker().Detach()
errCloseRowContainer := rowContainer.Close()
if errCloseRowContainer != nil {
logutil.Logger(ctx).Error("Fail to close rowContainer in error handler. May cause resource leak",
zap.NamedError("original-error", err), zap.NamedError("close-error", errCloseRowContainer))
}
}
}()

for {
chk := crs.NewChunk(nil)
lazy, err = cc.executeWithCursor(ctx, stmt, rs)
return false, err
}
retryable, err := cc.writeResultSet(ctx, rs, true, cc.ctx.Status(), 0)
if err != nil {
return retryable, errors.Annotate(err, cc.preparedStmt2String(uint32(stmt.ID())))
}
return false, nil
}

if err = crs.Next(ctx, chk); err != nil {
return false, err
}
rowCount := chk.NumRows()
if rowCount == 0 {
break
}
func (cc *clientConn) executeWithCursor(ctx context.Context, stmt PreparedStatement, rs resultset.ResultSet) (lazy bool, err error) {
vars := (&cc.ctx).GetSessionVars()
if vars.EnableLazyCursorFetch {
// try to execute with lazy cursor fetch
ok, err := cc.executeWithLazyCursor(ctx, stmt, rs)

err = rowContainer.Add(chk)
if err != nil {
return false, err
}
// if `ok` is false, should try to execute without lazy cursor fetch
if ok {
return true, err
}
}

reader := chunk.NewRowContainerReader(rowContainer)
crs.StoreRowContainerReader(reader)
stmt.StoreResultSet(crs)
stmt.StoreRowContainer(rowContainer)
if cl, ok := crs.(resultset.FetchNotifier); ok {
cl.OnFetchReturned()
failpoint.Inject("avoidEagerCursorFetch", func() {
failpoint.Return(false, errors.New("failpoint avoids eager cursor fetch"))
})
cc.initResultEncoder(ctx)
defer cc.rsEncoder.Clean()
// fetch all results of the resultSet, and stored them locally, so that the future `FETCH` command can read
// the rows directly to avoid running executor and accessing shared params/variables in the session
// NOTE: chunk should not be allocated from the connection allocator, which will reset after executing this command
// but the rows are still needed in the following FETCH command.

// create the row container to manage spill
// this `rowContainer` will be released when the statement (or the connection) is closed.
rowContainer := chunk.NewRowContainer(rs.FieldTypes(), vars.MaxChunkSize)
rowContainer.GetMemTracker().AttachTo(vars.MemTracker)
rowContainer.GetMemTracker().SetLabel(memory.LabelForCursorFetch)
rowContainer.GetDiskTracker().AttachTo(vars.DiskTracker)
rowContainer.GetDiskTracker().SetLabel(memory.LabelForCursorFetch)
if variable.EnableTmpStorageOnOOM.Load() {
failpoint.Inject("testCursorFetchSpill", func(val failpoint.Value) {
if val, ok := val.(bool); val && ok {
actionSpill := rowContainer.ActionSpillForTest()
defer actionSpill.WaitForTest()
}
})
action := memory.NewActionWithPriority(rowContainer.ActionSpill(), memory.DefCursorFetchSpillPriority)
vars.MemTracker.FallbackOldAndSetNewAction(action)
}
defer func() {
if err != nil {
rowContainer.GetMemTracker().Detach()
rowContainer.GetDiskTracker().Detach()
errCloseRowContainer := rowContainer.Close()
if errCloseRowContainer != nil {
logutil.Logger(ctx).Error("Fail to close rowContainer in error handler. May cause resource leak",
zap.NamedError("original-error", err), zap.NamedError("close-error", errCloseRowContainer))
}
}
stmt.SetCursorActive(true)
defer func() {
if err != nil {
reader.Close()
}()

// the resultSet and rowContainer have been closed in former "defer" statement.
stmt.StoreResultSet(nil)
stmt.StoreRowContainer(nil)
stmt.SetCursorActive(false)
}
}()
for {
chk := rs.NewChunk(nil)

if err = cc.writeColumnInfo(crs.Columns()); err != nil {
if err = rs.Next(ctx, chk); err != nil {
return false, err
}
rowCount := chk.NumRows()
if rowCount == 0 {
break
}

// explicitly flush columnInfo to client.
err = cc.writeEOF(ctx, cc.ctx.Status())
err = rowContainer.Add(chk)
if err != nil {
return false, err
}
}

reader := chunk.NewRowContainerReader(rowContainer)
defer func() {
if err != nil {
reader.Close()
}
}()
crs := resultset.WrapWithRowContainerCursor(rs, reader)
if cl, ok := crs.(resultset.FetchNotifier); ok {
cl.OnFetchReturned()
}
stmt.StoreRowContainer(rowContainer)

err = cc.writeExecuteResultWithCursor(ctx, stmt, crs)
return false, err
}

return false, cc.flush(ctx)
// executeWithLazyCursor tries to detach the `ResultSet` and make it suitable to execute lazily.
// Be careful that the return value `(bool, error)` has a different meaning from other similar functions. The first `bool` represents whether
// the `ResultSet` is suitable for lazy execution. If the return value is `(false, _)`, the `rs` in argument can still be used. If the
// first return value is `true` and `err` is not nil, the `rs` cannot be used anymore and the error should be returned to the upper layer.
func (cc *clientConn) executeWithLazyCursor(ctx context.Context, stmt PreparedStatement, rs resultset.ResultSet) (ok bool, err error) {
drs, ok, err := rs.TryDetach()
if !ok || err != nil {
return false, err
}
retryable, err := cc.writeResultSet(ctx, rs, true, cc.ctx.Status(), 0)

vars := (&cc.ctx).GetSessionVars()
crs := resultset.WrapWithLazyCursor(drs, vars.InitChunkSize, vars.MaxChunkSize)
err = cc.writeExecuteResultWithCursor(ctx, stmt, crs)
return true, err
}

// writeExecuteResultWithCursor will store the `ResultSet` in `stmt` and send the column info to the client. The logic is shared between
// lazy cursor fetch and normal(eager) cursor fetch.
func (cc *clientConn) writeExecuteResultWithCursor(ctx context.Context, stmt PreparedStatement, rs resultset.CursorResultSet) error {
var err error

stmt.StoreResultSet(rs)
stmt.SetCursorActive(true)
defer func() {
if err != nil {
// the resultSet and rowContainer have been closed in former "defer" statement.
stmt.StoreResultSet(nil)
stmt.StoreRowContainer(nil)
stmt.SetCursorActive(false)
}
}()

if err = cc.writeColumnInfo(rs.Columns()); err != nil {
return err
}

// explicitly flush columnInfo to client.
err = cc.writeEOF(ctx, cc.ctx.Status())
if err != nil {
return retryable, errors.Annotate(err, cc.preparedStmt2String(uint32(stmt.ID())))
return err
}
return false, nil

return cc.flush(ctx)
}

func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err error) {
Expand Down Expand Up @@ -476,7 +527,7 @@ func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err err

_, err = cc.writeResultSet(ctx, rs, true, cc.ctx.Status(), int(fetchSize))
// if the iterator reached the end before writing result, we could say the `FETCH` command will send EOF
if rs.GetRowContainerReader().Current() == rs.GetRowContainerReader().End() {
if rs.GetRowIterator().Current(ctx) == rs.GetRowIterator().End() {
// also reset the statement when the cursor reaches the end
// don't overwrite the `err` in outer scope, to avoid redundant `Reset()` in `defer` statement (though, it's not
// a big problem, as the `Reset()` function call is idempotent.)
Expand Down
24 changes: 12 additions & 12 deletions pkg/server/conn_stmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,11 @@ func TestCursorWithParams(t *testing.T) {
0x0, 0x1, 0x3, 0x0, 0x3, 0x0,
0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0,
)))
rows := c.Context().stmts[stmt1.ID()].GetResultSet().GetRowContainerReader()
require.Equal(t, int64(1), rows.Current().GetInt64(0))
require.Equal(t, int64(2), rows.Current().GetInt64(1))
rows.Next()
require.Equal(t, rows.End(), rows.Current())
rows := c.Context().stmts[stmt1.ID()].GetResultSet().GetRowIterator()
require.Equal(t, int64(1), rows.Current(context.Background()).GetInt64(0))
require.Equal(t, int64(2), rows.Current(context.Background()).GetInt64(1))
rows.Next(context.Background())
require.Equal(t, rows.End(), rows.Current(context.Background()))

// `execute stmt2 using 1` with cursor
require.NoError(t, c.Dispatch(ctx, append(
Expand All @@ -142,13 +142,13 @@ func TestCursorWithParams(t *testing.T) {
0x0, 0x1, 0x3, 0x0,
0x1, 0x0, 0x0, 0x0,
)))
rows = c.Context().stmts[stmt2.ID()].GetResultSet().GetRowContainerReader()
require.Equal(t, int64(1), rows.Current().GetInt64(0))
require.Equal(t, int64(1), rows.Current().GetInt64(1))
require.Equal(t, int64(1), rows.Next().GetInt64(0))
require.Equal(t, int64(2), rows.Current().GetInt64(1))
rows.Next()
require.Equal(t, rows.End(), rows.Current())
rows = c.Context().stmts[stmt2.ID()].GetResultSet().GetRowIterator()
require.Equal(t, int64(1), rows.Current(context.Background()).GetInt64(0))
require.Equal(t, int64(1), rows.Current(context.Background()).GetInt64(1))
require.Equal(t, int64(1), rows.Next(context.Background()).GetInt64(0))
require.Equal(t, int64(2), rows.Current(context.Background()).GetInt64(1))
rows.Next(context.Background())
require.Equal(t, rows.End(), rows.Current(context.Background()))

// fetch stmt2 with fetch size 256
require.NoError(t, c.Dispatch(ctx, append(
Expand Down
8 changes: 4 additions & 4 deletions pkg/server/driver_tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ func (ts *TiDBStatement) Reset() error {
}
ts.hasActiveCursor = false

if ts.rs != nil && ts.rs.GetRowContainerReader() != nil {
ts.rs.GetRowContainerReader().Close()
if ts.rs != nil && ts.rs.GetRowIterator() != nil {
ts.rs.GetRowIterator().Close()
}
ts.rs = nil

Expand All @@ -173,8 +173,8 @@ func (ts *TiDBStatement) Reset() error {

// Close implements PreparedStatement Close method.
func (ts *TiDBStatement) Close() error {
if ts.rs != nil && ts.rs.GetRowContainerReader() != nil {
ts.rs.GetRowContainerReader().Close()
if ts.rs != nil && ts.rs.GetRowIterator() != nil {
ts.rs.GetRowIterator().Close()
}

if ts.rowContainer != nil {
Expand Down
Loading

0 comments on commit afd6d6a

Please sign in to comment.