diff --git a/mongo/integration/client_test.go b/mongo/integration/client_test.go index 76a6d5121a..0139d273da 100644 --- a/mongo/integration/client_test.go +++ b/mongo/integration/client_test.go @@ -768,6 +768,77 @@ func TestClient(t *testing.T) { "expected 'OP_MSG' OpCode in wire message, got %q", pair.Sent.OpCode.String()) } }) + + opts := mtest.NewOptions(). + // Blocking failpoints don't work on pre-4.2 and sharded clusters. + Topologies(mtest.Single, mtest.ReplicaSet). + MinServerVersion("4.2"). + // Explicitly enable retryable reads and retryable writes. + ClientOptions(options.Client().SetRetryReads(true).SetRetryWrites(true)) + mt.RunOpts("operations don't retry after a context timeout", opts, func(mt *mtest.T) { + testCases := []struct { + desc string + operation func(context.Context, *mongo.Collection) error + }{ + { + desc: "read op", + operation: func(ctx context.Context, coll *mongo.Collection) error { + return coll.FindOne(ctx, bson.D{}).Err() + }, + }, + { + desc: "write op", + operation: func(ctx context.Context, coll *mongo.Collection) error { + _, err := coll.InsertOne(ctx, bson.D{}) + return err + }, + }, + } + + for _, tc := range testCases { + mt.Run(tc.desc, func(mt *mtest.T) { + _, err := mt.Coll.InsertOne(context.Background(), bson.D{}) + require.NoError(mt, err) + + mt.SetFailPoint(mtest.FailPoint{ + ConfigureFailPoint: "failCommand", + Mode: "alwaysOn", + Data: mtest.FailPointData{ + FailCommands: []string{"find", "insert"}, + BlockConnection: true, + BlockTimeMS: 500, + }, + }) + + mt.ClearEvents() + + for i := 0; i < 50; i++ { + // Run 50 operations, each with a timeout of 50ms. Expect + // them to all return a timeout error because the failpoint + // blocks find and insert operations for 500ms. Run 50 to increase the + // probability that an operation will time out in a way that + // can cause a retry. 
+ ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + err = tc.operation(ctx, mt.Coll) + cancel() + assert.ErrorIs(mt, err, context.DeadlineExceeded) + assert.True(mt, mongo.IsTimeout(err), "expected mongo.IsTimeout(err) to be true") + + // Assert that each operation reported exactly one command + // started event, which means the operation did not retry + // after the context timeout. + evts := mt.GetAllStartedEvents() + require.Len(mt, + mt.GetAllStartedEvents(), + 1, + "expected exactly 1 command started event per operation, but got %d after %d iterations", + len(evts), + i) + mt.ClearEvents() + } + }) + } + }) } func TestClient_BSONOptions(t *testing.T) { diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 8f87c21d3f..eb1acec88f 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -622,6 +622,13 @@ func (op Operation) Execute(ctx context.Context) error { } }() for { + // If we're starting a retry and the error from the previous try was + // a context canceled or deadline exceeded error, stop retrying and + // return that error. + if errors.Is(prevErr, context.Canceled) || errors.Is(prevErr, context.DeadlineExceeded) { + return prevErr + } + requestID := wiremessage.NextRequestID() // If the server or connection are nil, try to select a new server and get a new connection.