Skip to content

Commit

Permalink
RUST-1274 Fix commitTransaction on check out retries (#651)
Browse files Browse the repository at this point in the history
This also fixes RUST-1317.
  • Loading branch information
patrickfreed authored May 16, 2022
1 parent cb45c29 commit ab491ae
Show file tree
Hide file tree
Showing 8 changed files with 207 additions and 27 deletions.
47 changes: 28 additions & 19 deletions src/client/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,27 +351,15 @@ impl Client {

let retryability = self.get_retryability(&conn, &op, &session)?;

let txn_number = match session {
Some(ref mut session) => {
if session.transaction.state != TransactionState::None {
Some(session.txn_number())
} else {
match retryability {
Retryability::Write => Some(session.get_and_increment_txn_number()),
_ => None,
}
}
}
None => None,
};
let txn_number = get_txn_number(&mut session, retryability);

match self
.execute_operation_on_connection(
&mut op,
&mut conn,
&mut session,
txn_number,
&retryability,
retryability,
)
.await
{
Expand Down Expand Up @@ -424,7 +412,7 @@ impl Client {
&self,
op: &mut T,
session: &mut Option<&mut ClientSession>,
txn_number: Option<i64>,
prior_txn_number: Option<i64>,
first_error: Error,
) -> Result<ExecutionOutput<T>> {
op.update_for_retry();
Expand All @@ -446,8 +434,10 @@ impl Client {
return Err(first_error);
}

let txn_number = prior_txn_number.or_else(|| get_txn_number(session, retryability));

match self
.execute_operation_on_connection(op, &mut conn, session, txn_number, &retryability)
.execute_operation_on_connection(op, &mut conn, session, txn_number, retryability)
.await
{
Ok(operation_output) => Ok(ExecutionOutput {
Expand Down Expand Up @@ -481,7 +471,7 @@ impl Client {
connection: &mut Connection,
session: &mut Option<&mut ClientSession>,
txn_number: Option<i64>,
retryability: &Retryability,
retryability: Retryability,
) -> Result<T::O> {
if let Some(wc) = op.write_concern() {
wc.validate()?;
Expand Down Expand Up @@ -918,6 +908,25 @@ async fn get_connection<T: Operation>(
}
}

fn get_txn_number(
session: &mut Option<&mut ClientSession>,
retryability: Retryability,
) -> Option<i64> {
match session {
Some(ref mut session) => {
if session.transaction.state != TransactionState::None {
Some(session.txn_number())
} else {
match retryability {
Retryability::Write => Some(session.get_and_increment_txn_number()),
_ => None,
}
}
}
None => None,
}
}

impl Error {
/// Adds the necessary labels to this Error, and unpins the session if needed.
///
Expand All @@ -936,7 +945,7 @@ impl Error {
&mut self,
conn: Option<&Connection>,
session: &mut Option<&mut ClientSession>,
retryability: Option<&Retryability>,
retryability: Option<Retryability>,
) -> Result<()> {
let transaction_state = session.as_ref().map_or(&TransactionState::None, |session| {
&session.transaction.state
Expand Down Expand Up @@ -970,7 +979,7 @@ impl Error {
}
}
TransactionState::None => {
if retryability == Some(&Retryability::Write) {
if retryability == Some(Retryability::Write) {
if let Some(max_wire_version) = max_wire_version {
if self.should_add_retryable_write_label(max_wire_version) {
self.add_label(RETRYABLE_WRITE_ERROR);
Expand Down
11 changes: 11 additions & 0 deletions src/event/sdam/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@ pub struct ServerDescriptionChangedEvent {
pub new_description: ServerDescription,
}

impl ServerDescriptionChangedEvent {
#[cfg(test)]
pub(crate) fn is_marked_unknown_event(&self) -> bool {
self.previous_description
.description
.server_type
.is_available()
&& self.new_description.description.server_type == crate::ServerType::Unknown
}
}

/// Published when a server is initialized.
#[derive(Clone, Debug, Deserialize, PartialEq)]
#[non_exhaustive]
Expand Down
2 changes: 1 addition & 1 deletion src/operation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ where
}
}

#[derive(Debug, PartialEq)]
#[derive(Copy, Clone, Debug, PartialEq)]
pub(crate) enum Retryability {
Write,
Read,
Expand Down
1 change: 1 addition & 0 deletions src/sdam/description/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ impl PartialEq for ServerDescription {

self_response == other_response
}
(Err(self_err), Err(other_err)) => self_err == other_err,
_ => false,
}
}
Expand Down
5 changes: 2 additions & 3 deletions src/sdam/monitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,8 @@ impl HeartbeatMonitor {
let mut topology_check_requests_subscriber =
topology.subscribe_to_topology_check_requests();

if self.check_server(&topology, &server).await {
topology.notify_topology_changed();
}
self.check_server(&topology, &server).await;
topology.notify_topology_changed();

// drop strong reference to topology before going back to sleep in case it drops off
// in between checks.
Expand Down
149 changes: 147 additions & 2 deletions src/test/client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{borrow::Cow, collections::HashMap, time::Duration};
use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration};

use bson::Document;
use serde::Deserialize;
Expand All @@ -7,11 +7,25 @@ use tokio::sync::{RwLockReadGuard, RwLockWriteGuard};
use crate::{
bson::{doc, Bson},
error::{CommandError, Error, ErrorKind},
hello::LEGACY_HELLO_COMMAND_NAME,
options::{AuthMechanism, ClientOptions, Credential, ListDatabasesOptions, ServerAddress},
runtime,
selection_criteria::{ReadPreference, ReadPreferenceOptions, SelectionCriteria},
test::{log_uncaptured, util::TestClient, CLIENT_OPTIONS, LOCK},
test::{
log_uncaptured,
util::TestClient,
CmapEvent,
Event,
EventHandler,
FailCommandOptions,
FailPoint,
FailPointMode,
SdamEvent,
CLIENT_OPTIONS,
LOCK,
},
Client,
ServerType,
};

#[derive(Debug, Deserialize)]
Expand Down Expand Up @@ -663,3 +677,134 @@ async fn plain_auth() {
}
);
}

/// Test verifies that retrying a commitTransaction operation after a checkOut
/// failure works.
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn retry_commit_txn_check_out() {
let _guard: RwLockWriteGuard<_> = LOCK.run_exclusively().await;

let setup_client = TestClient::new().await;
if !setup_client.is_replica_set() {
log_uncaptured("skipping retry_commit_txn_check_out due to non-replicaset topology");
return;
}

if !setup_client.supports_transactions() {
log_uncaptured("skipping retry_commit_txn_check_out due to lack of transaction support");
return;
}

if !setup_client.supports_fail_command_appname_initial_handshake() {
log_uncaptured(
"skipping retry_commit_txn_check_out due to insufficient failCommand support",
);
return;
}

// ensure namespace exists
setup_client
.database("retry_commit_txn_check_out")
.collection("retry_commit_txn_check_out")
.insert_one(doc! {}, None)
.await
.unwrap();

let mut options = CLIENT_OPTIONS.clone();
let handler = Arc::new(EventHandler::new());
options.cmap_event_handler = Some(handler.clone());
options.sdam_event_handler = Some(handler.clone());
options.heartbeat_freq = Some(Duration::from_secs(120));
options.app_name = Some("retry_commit_txn_check_out".to_string());
let client = Client::with_options(options).unwrap();

let mut session = client.start_session(None).await.unwrap();
session.start_transaction(None).await.unwrap();
// transition transaction to "in progress" so that the commit
// actually executes an operation.
client
.database("retry_commit_txn_check_out")
.collection("retry_commit_txn_check_out")
.insert_one_with_session(doc! {}, None, &mut session)
.await
.unwrap();

// enable a fail point that clears the connection pools so that
// commitTransaction will create a new connection during check out.
let fp = FailPoint::fail_command(
&["ping"],
FailPointMode::Times(1),
FailCommandOptions::builder().error_code(11600).build(),
);
let _guard = setup_client.enable_failpoint(fp, None).await.unwrap();

let mut subscriber = handler.subscribe();
client
.database("foo")
.run_command(doc! { "ping": 1 }, None)
.await
.unwrap_err();

// failing with a state change error will request an immediate check
// wait for the mark unknown and subsequent succeeded heartbeat
let mut primary = None;
subscriber
.wait_for_event(Duration::from_secs(1), |e| {
if let Event::Sdam(SdamEvent::ServerDescriptionChanged(event)) = e {
if event.is_marked_unknown_event() {
primary = Some(event.address.clone());
return true;
}
}
false
})
.await
.expect("should see marked unknown event");

subscriber
.wait_for_event(Duration::from_secs(1), |e| {
if let Event::Sdam(SdamEvent::ServerDescriptionChanged(event)) = e {
if &event.address == primary.as_ref().unwrap()
&& event.previous_description.server_type() == ServerType::Unknown
{
return true;
}
}
false
})
.await
.expect("should see mark available event");

// enable a failpoint on the handshake to cause check_out
// to fail with a retryable error
let fp = FailPoint::fail_command(
&[LEGACY_HELLO_COMMAND_NAME, "hello"],
FailPointMode::Times(1),
FailCommandOptions::builder()
.error_code(11600)
.app_name("retry_commit_txn_check_out".to_string())
.build(),
);
let _guard2 = setup_client.enable_failpoint(fp, None).await.unwrap();

// finally, attempt the commit.
// this should succeed due to retry
session.commit_transaction().await.unwrap();

// ensure the first check out attempt fails
subscriber
.wait_for_event(Duration::from_secs(1), |e| {
matches!(e, Event::Cmap(CmapEvent::ConnectionCheckOutFailed(_)))
})
.await
.expect("should see check out failed event");

// ensure the second one succeeds
subscriber
.wait_for_event(Duration::from_secs(1), |e| {
matches!(e, Event::Cmap(CmapEvent::ConnectionCheckedOut(_)))
})
.await
.expect("should see checked out event");
}
4 changes: 2 additions & 2 deletions src/test/util/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,9 +389,9 @@ pub struct EventSubscriber<'a> {
}

impl EventSubscriber<'_> {
pub async fn wait_for_event<F>(&mut self, timeout: Duration, filter: F) -> Option<Event>
pub async fn wait_for_event<F>(&mut self, timeout: Duration, mut filter: F) -> Option<Event>
where
F: Fn(&Event) -> bool,
F: FnMut(&Event) -> bool,
{
runtime::timeout(timeout, async {
loop {
Expand Down
15 changes: 15 additions & 0 deletions src/test/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,21 @@ impl TestClient {
version.matches(&self.server_version)
}

/// Whether the deployment supports failing the initial handshake
/// only when it uses a specified appName.
///
/// See SERVER-49336 for more info.
pub fn supports_fail_command_appname_initial_handshake(&self) -> bool {
let requirements = [
VersionReq::parse(">= 4.2.15, < 4.3.0").unwrap(),
VersionReq::parse(">= 4.4.7, < 4.5.0").unwrap(),
VersionReq::parse(">= 4.9.0").unwrap(),
];
requirements
.iter()
.any(|req| req.matches(&self.server_version))
}

pub fn supports_transactions(&self) -> bool {
self.is_replica_set() && self.server_version_gte(4, 0)
|| self.is_sharded() && self.server_version_gte(4, 2)
Expand Down

0 comments on commit ab491ae

Please sign in to comment.