Skip to content

Commit

Permalink
fix: merge sort for receiving fields in multiple packets
Browse files Browse the repository at this point in the history
Signed-off-by: Harshit Gangal <harshit@planetscale.com>
  • Loading branch information
harshit-gangal committed Dec 19, 2024
1 parent 6c71923 commit 9dc7651
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 25 deletions.
30 changes: 12 additions & 18 deletions go/vt/vtgate/engine/limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,30 +112,17 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars
defer mu.Unlock()

inputSize := len(qr.Rows)
if inputSize == 0 {
if wantfields && len(qr.Fields) != 0 {
wantfields = false
}
return callback(qr)
}

// If this is the first callback and fields are requested, send the fields immediately.
if wantfields && len(qr.Fields) != 0 {
wantfields = false
// otherwise, we need to send the fields first, and then the rows
if err := callback(&sqltypes.Result{Fields: qr.Fields}); err != nil {
return err
}
}

// If we still need to skip `offset` rows before returning any to the client:
if offset > 0 {
if inputSize <= offset {
// not enough to return anything yet, but we still want to pass on metadata such as last_insert_id
offset -= inputSize
if !l.mustRetrieveAll(vcursor) {
if !wantfields && !l.mustRetrieveAll(vcursor) {
return nil
}
if len(qr.Fields) > 0 {
wantfields = false
}
qr.Rows = nil
return callback(qr)
}
Expand All @@ -147,16 +134,23 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars
// At this point, we've dealt with the offset. Now handle the count (limit).
if count == 0 {
// If count is zero, we've fetched everything we need.
if !l.mustRetrieveAll(vcursor) {
if !wantfields && !l.mustRetrieveAll(vcursor) {
return io.EOF
}
if len(qr.Fields) > 0 {
wantfields = false
}

// If we require the complete input, or we are in a transaction, we cannot return io.EOF early.
// Instead, we return empty results as needed until input ends.
qr.Rows = nil
return callback(qr)
}

if len(qr.Fields) > 0 {
wantfields = false
}

// reduce count till 0.
resultSize := len(qr.Rows)
if count > resultSize {
Expand Down
14 changes: 7 additions & 7 deletions go/vt/vtgate/engine/merge_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,11 @@ import (
"io"

"vitess.io/vitess/go/mysql/sqlerror"
"vitess.io/vitess/go/vt/vtgate/evalengine"

"vitess.io/vitess/go/sqltypes"

querypb "vitess.io/vitess/go/vt/proto/query"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/evalengine"
)

// StreamExecutor is a subset of Primitive that MergeSort
Expand Down Expand Up @@ -216,9 +214,10 @@ func (ms *MergeSort) description() PrimitiveDescription {
// routine that pulls the rows out of each streamHandle can abort the stream
// by calling canceling the context.
type streamHandle struct {
fields chan []*querypb.Field
row chan []sqltypes.Value
err error
fields chan []*querypb.Field
fieldSeen bool
row chan []sqltypes.Value
err error
}

// runOnestream starts a streaming query on one shard, and returns a streamHandle for it.
Expand All @@ -233,7 +232,8 @@ func runOneStream(ctx context.Context, vcursor VCursor, input StreamExecutor, bi
defer close(handle.row)

handle.err = input.StreamExecute(ctx, vcursor, bindVars, wantfields, func(qr *sqltypes.Result) error {
if len(qr.Fields) != 0 {
if !handle.fieldSeen && len(qr.Fields) != 0 {
handle.fieldSeen = true
select {
case handle.fields <- qr.Fields:
case <-ctx.Done():
Expand Down
43 changes: 43 additions & 0 deletions go/vt/vtgate/executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3302,6 +3302,49 @@ func TestSelectFromInformationSchema(t *testing.T) {
sbc1.StringQueries())
}

func TestStreamOrderByWithMultipleResults(t *testing.T) {
ctx := utils.LeakCheckContext(t)

// Special setup: Don't use createExecutorEnv.
cell := "aa"
hc := discovery.NewFakeHealthCheck(nil)
u := createSandbox(KsTestUnsharded)
s := createSandbox(KsTestSharded)
s.VSchema = executorVSchema
u.VSchema = unshardedVSchema
serv := newSandboxForCells(ctx, []string{cell})
resolver := newTestResolver(ctx, hc, serv, cell)
shards := []string{"-20", "20-40", "40-60", "60-80", "80-a0", "a0-c0", "c0-e0", "e0-"}
count := 1
for _, shard := range shards {
sbc := hc.AddTestTablet(cell, shard, 1, "TestExecutor", shard, topodatapb.TabletType_PRIMARY, true, 1, nil)
sbc.SetResults([]*sqltypes.Result{
sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col|weight_string(id)", "int32|int32|varchar"), fmt.Sprintf("%d|%d|NULL", count, count)),
sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col|weight_string(id)", "int32|int32|varchar"), fmt.Sprintf("%d|%d|NULL", count+10, count)),
})
count++
}
queryLogger := streamlog.New[*logstats.LogStats]("VTGate", queryLogBufferSize)
plans := DefaultPlanCache()
executor := NewExecutor(ctx, vtenv.NewTestEnv(), serv, cell, resolver, true, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4, 0)
executor.SetQueryLogger(queryLogger)
defer executor.Close()
// some sleep for all goroutines to start
time.Sleep(100 * time.Millisecond)
before := runtime.NumGoroutine()

query := "select id, col from user order by id"
gotResult, err := executorStream(ctx, executor, query)
require.NoError(t, err)

wantResult := sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col", "int32|int32"),
"1|1", "2|2", "3|3", "4|4", "5|5", "6|6", "7|7", "8|8", "11|1", "12|2", "13|3", "14|4", "15|5", "16|6", "17|7", "18|8")
assert.Equal(t, fmt.Sprintf("%v", wantResult.Rows), fmt.Sprintf("%v", gotResult.Rows))
// some sleep to close all goroutines.
time.Sleep(100 * time.Millisecond)
assert.GreaterOrEqual(t, before, runtime.NumGoroutine(), "left open goroutines lingering")
}

func TestStreamOrderByLimitWithMultipleResults(t *testing.T) {
ctx := utils.LeakCheckContext(t)

Expand Down

0 comments on commit 9dc7651

Please sign in to comment.