Skip to content

Commit

Permalink
spanner: fix race condition in batch read-only tx
Browse files Browse the repository at this point in the history
BatchReadOnlyTransaction contained a possible race condition as the
same read or query request would be re-used for the different partitions
of a transaction. That could cause one of the partitions to execute the
wrong query, as it could be assigned the partition token of a different
partition if the partitions were being used in parallel by the user
application.

Fixes #1895.

Change-Id: Ie79c58d843ca59d1259f8fc7a78b00d9e7fa1d40
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/54190
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Hengfeng Li <hengfeng@google.com>
  • Loading branch information
olavloite committed Apr 7, 2020
1 parent 03bba21 commit b0b68af
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 30 deletions.
14 changes: 8 additions & 6 deletions spanner/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,16 +261,18 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R
}
// Read or query partition.
if p.rreq != nil {
p.rreq.PartitionToken = p.pt
req := *p.rreq
req.PartitionToken = p.pt
rpc = func(ctx context.Context, resumeToken []byte) (streamingReceiver, error) {
p.rreq.ResumeToken = resumeToken
return client.StreamingRead(ctx, p.rreq)
req.ResumeToken = resumeToken
return client.StreamingRead(ctx, &req)
}
} else {
p.qreq.PartitionToken = p.pt
req := *p.qreq
req.PartitionToken = p.pt
rpc = func(ctx context.Context, resumeToken []byte) (streamingReceiver, error) {
p.qreq.ResumeToken = resumeToken
return client.ExecuteStreamingSql(ctx, p.qreq)
req.ResumeToken = resumeToken
return client.ExecuteStreamingSql(ctx, &req)
}
}
return stream(
Expand Down
52 changes: 52 additions & 0 deletions spanner/batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package spanner
import (
"context"
"os"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -116,3 +117,54 @@ func TestPartitionQuery_QueryOptions(t *testing.T) {
})
}
}

func TestPartitionQuery_Parallel(t *testing.T) {
ctx := context.Background()
server, client, teardown := setupMockedTestServer(t)
defer teardown()

txn, err := client.BatchReadOnlyTransaction(ctx, StrongRead())
if err != nil {
t.Fatal(err)
}
defer txn.Cleanup(ctx)
ps, err := txn.PartitionQuery(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums), PartitionOptions{0, 10})
if err != nil {
t.Fatal(err)
}
for i, p := range ps {
server.TestSpanner.PutPartitionResult(p.pt, server.CreateSingleRowSingersResult(int64(i)))
}

wg := &sync.WaitGroup{}
mu := sync.Mutex{}
var total int64

for _, p := range ps {
p := p
go func() {
iter := txn.Execute(context.Background(), p)
defer iter.Stop()

var count int64
err := iter.Do(func(row *Row) error {
count++
return nil
})
if err != nil {
return
}

mu.Lock()
total += count
mu.Unlock()
wg.Done()
}()
wg.Add(1)
}

wg.Wait()
if g, w := total, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount; g != w {
t.Errorf("Row count mismatch\nGot: %d\nWant: %d", g, w)
}
}
41 changes: 39 additions & 2 deletions spanner/internal/testutil/inmem_spanner_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,11 @@ type InMemSpannerServer interface {
// expect a SQL statement, including (batch) DML methods.
PutStatementResult(sql string, result *StatementResult) error

// Puts a mocked result on the server for a specific partition token. The
// result will only be used for query requests that specify a partition
// token.
PutPartitionResult(partitionToken []byte, result *StatementResult) error

// Adds a PartialResultSetExecutionTime to the server that should be returned
// for the specified SQL string.
AddPartialResultSetError(sql string, err PartialResultSetExecutionTime)
Expand Down Expand Up @@ -248,6 +253,7 @@ type inMemSpannerServer struct {
partitionedDmlTransactions map[string]bool
// The mocked results for this server.
statementResults map[string]*StatementResult
partitionResults map[string]*StatementResult
// The simulated execution times per method.
executionTimes map[string]*SimulatedExecutionTime
// The simulated errors for partial result sets
Expand All @@ -271,6 +277,7 @@ func NewInMemSpannerServer() InMemSpannerServer {
res := &inMemSpannerServer{}
res.initDefaults()
res.statementResults = make(map[string]*StatementResult)
res.partitionResults = make(map[string]*StatementResult)
res.executionTimes = make(map[string]*SimulatedExecutionTime)
res.partialResultSetErrors = make(map[string][]*PartialResultSetExecutionTime)
res.receivedRequests = make(chan interface{}, 1000000)
Expand Down Expand Up @@ -318,6 +325,15 @@ func (s *inMemSpannerServer) RemoveStatementResult(sql string) {
delete(s.statementResults, sql)
}

// Registers a mocked result for a partition token on the server.
func (s *inMemSpannerServer) PutPartitionResult(partitionToken []byte, result *StatementResult) error {
tokenString := string(partitionToken)
s.mu.Lock()
defer s.mu.Unlock()
s.partitionResults[tokenString] = result
return nil
}

func (s *inMemSpannerServer) AbortTransaction(id []byte) {
s.mu.Lock()
defer s.mu.Unlock()
Expand Down Expand Up @@ -527,6 +543,17 @@ func (s *inMemSpannerServer) removeTransaction(tx *spannerpb.Transaction) {
delete(s.partitionedDmlTransactions, string(tx.Id))
}

func (s *inMemSpannerServer) getPartitionResult(partitionToken []byte) (*StatementResult, error) {
tokenString := string(partitionToken)
s.mu.Lock()
defer s.mu.Unlock()
result, ok := s.partitionResults[tokenString]
if !ok {
return nil, gstatus.Error(codes.Internal, fmt.Sprintf("No result found for partition token %v", tokenString))
}
return result, nil
}

func (s *inMemSpannerServer) getStatementResult(sql string) (*StatementResult, error) {
s.mu.Lock()
defer s.mu.Unlock()
Expand Down Expand Up @@ -711,7 +738,12 @@ func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.Exec
return nil, err
}
}
statementResult, err := s.getStatementResult(req.Sql)
var statementResult *StatementResult
if req.PartitionToken != nil {
statementResult, err = s.getPartitionResult(req.PartitionToken)
} else {
statementResult, err = s.getStatementResult(req.Sql)
}
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -748,7 +780,12 @@ func (s *inMemSpannerServer) ExecuteStreamingSql(req *spannerpb.ExecuteSqlReques
return err
}
}
statementResult, err := s.getStatementResult(req.Sql)
var statementResult *StatementResult
if req.PartitionToken != nil {
statementResult, err = s.getPartitionResult(req.PartitionToken)
} else {
statementResult, err = s.getStatementResult(req.Sql)
}
if err != nil {
return err
}
Expand Down
76 changes: 54 additions & 22 deletions spanner/internal/testutil/mocked_inmem_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,45 @@ func (s *MockedSpannerInMemTestServer) setupFooResults() {
}

func (s *MockedSpannerInMemTestServer) setupSingersResults() {
metadata := createSingersMetadata()
rows := make([]*structpb.ListValue, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
var idx int64
for idx = 0; idx < SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount; idx++ {
rows[idx] = createSingersRow(idx)
}
resultSet := &spannerpb.ResultSet{
Metadata: metadata,
Rows: rows,
}
result := &StatementResult{Type: StatementResultResultSet, ResultSet: resultSet}
s.TestSpanner.PutStatementResult(SelectSingerIDAlbumIDAlbumTitleFromAlbums, result)
}

// CreateSingleRowSingersResult creates a result set containing a single row of
// the SelectSingerIDAlbumIDAlbumTitleFromAlbums result set, or zero rows if
// the given rowNum is greater than the number of rows in the result set. This
// method can be used to mock results for different partitions of a
// BatchReadOnlyTransaction.
func (s *MockedSpannerInMemTestServer) CreateSingleRowSingersResult(rowNum int64) *StatementResult {
metadata := createSingersMetadata()
var returnedRows int
if rowNum < SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
returnedRows = 1
} else {
returnedRows = 0
}
rows := make([]*structpb.ListValue, returnedRows)
if returnedRows > 0 {
rows[0] = createSingersRow(rowNum)
}
resultSet := &spannerpb.ResultSet{
Metadata: metadata,
Rows: rows,
}
return &StatementResult{Type: StatementResultResultSet, ResultSet: resultSet}
}

func createSingersMetadata() *spannerpb.ResultSetMetadata {
fields := make([]*spannerpb.StructType_Field, SelectSingerIDAlbumIDAlbumTitleFromAlbumsColCount)
fields[0] = &spannerpb.StructType_Field{
Name: "SingerId",
Expand All @@ -159,30 +198,23 @@ func (s *MockedSpannerInMemTestServer) setupSingersResults() {
rowType := &spannerpb.StructType{
Fields: fields,
}
metadata := &spannerpb.ResultSetMetadata{
return &spannerpb.ResultSetMetadata{
RowType: rowType,
}
rows := make([]*structpb.ListValue, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
var idx int64
for idx = 0; idx < SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount; idx++ {
rowValue := make([]*structpb.Value, SelectSingerIDAlbumIDAlbumTitleFromAlbumsColCount)
rowValue[0] = &structpb.Value{
Kind: &structpb.Value_StringValue{StringValue: strconv.FormatInt(idx+1, 10)},
}
rowValue[1] = &structpb.Value{
Kind: &structpb.Value_StringValue{StringValue: strconv.FormatInt(idx*10+idx, 10)},
}
rowValue[2] = &structpb.Value{
Kind: &structpb.Value_StringValue{StringValue: fmt.Sprintf("Album title %d", idx)},
}
rows[idx] = &structpb.ListValue{
Values: rowValue,
}
}

func createSingersRow(idx int64) *structpb.ListValue {
rowValue := make([]*structpb.Value, SelectSingerIDAlbumIDAlbumTitleFromAlbumsColCount)
rowValue[0] = &structpb.Value{
Kind: &structpb.Value_StringValue{StringValue: strconv.FormatInt(idx+1, 10)},
}
resultSet := &spannerpb.ResultSet{
Metadata: metadata,
Rows: rows,
rowValue[1] = &structpb.Value{
Kind: &structpb.Value_StringValue{StringValue: strconv.FormatInt(idx*10+idx, 10)},
}
rowValue[2] = &structpb.Value{
Kind: &structpb.Value_StringValue{StringValue: fmt.Sprintf("Album title %d", idx)},
}
return &structpb.ListValue{
Values: rowValue,
}
result := &StatementResult{Type: StatementResultResultSet, ResultSet: resultSet}
s.TestSpanner.PutStatementResult(SelectSingerIDAlbumIDAlbumTitleFromAlbums, result)
}

0 comments on commit b0b68af

Please sign in to comment.