Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RUST-1274 Fix commitTransaction on check out retries #651

Merged
merged 8 commits into from
May 16, 2022
47 changes: 28 additions & 19 deletions src/client/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,27 +351,15 @@ impl Client {

let retryability = self.get_retryability(&conn, &op, &session)?;

let txn_number = match session {
Some(ref mut session) => {
if session.transaction.state != TransactionState::None {
Some(session.txn_number())
} else {
match retryability {
Retryability::Write => Some(session.get_and_increment_txn_number()),
_ => None,
}
}
}
None => None,
};
let txn_number = get_txn_number(&mut session, retryability);

match self
.execute_operation_on_connection(
&mut op,
&mut conn,
&mut session,
txn_number,
&retryability,
retryability,
)
.await
{
Expand Down Expand Up @@ -424,7 +412,7 @@ impl Client {
&self,
op: &mut T,
session: &mut Option<&mut ClientSession>,
txn_number: Option<i64>,
prior_txn_number: Option<i64>,
first_error: Error,
) -> Result<ExecutionOutput<T>> {
op.update_for_retry();
Expand All @@ -446,8 +434,10 @@ impl Client {
return Err(first_error);
}

let txn_number = prior_txn_number.or_else(|| get_txn_number(session, retryability));

match self
.execute_operation_on_connection(op, &mut conn, session, txn_number, &retryability)
.execute_operation_on_connection(op, &mut conn, session, txn_number, retryability)
.await
{
Ok(operation_output) => Ok(ExecutionOutput {
Expand Down Expand Up @@ -481,7 +471,7 @@ impl Client {
connection: &mut Connection,
session: &mut Option<&mut ClientSession>,
txn_number: Option<i64>,
retryability: &Retryability,
retryability: Retryability,
) -> Result<T::O> {
if let Some(wc) = op.write_concern() {
wc.validate()?;
Expand Down Expand Up @@ -918,6 +908,25 @@ async fn get_connection<T: Operation>(
}
}

/// Picks the transaction number to send with an operation, if any.
///
/// * With no session, no transaction number is used.
/// * If the session's transaction state is anything other than
///   `TransactionState::None`, the session's current transaction number is
///   returned without being advanced.
/// * Otherwise a number is only produced for `Retryability::Write`, in which
///   case the session's counter is advanced and the resulting value returned.
fn get_txn_number(
    session: &mut Option<&mut ClientSession>,
    retryability: Retryability,
) -> Option<i64> {
    // No session means there is nothing to number.
    let active_session = session.as_deref_mut()?;

    if active_session.transaction.state != TransactionState::None {
        return Some(active_session.txn_number());
    }

    if retryability == Retryability::Write {
        Some(active_session.get_and_increment_txn_number())
    } else {
        None
    }
}

impl Error {
/// Adds the necessary labels to this Error, and unpins the session if needed.
///
Expand All @@ -936,7 +945,7 @@ impl Error {
&mut self,
conn: Option<&Connection>,
session: &mut Option<&mut ClientSession>,
retryability: Option<&Retryability>,
retryability: Option<Retryability>,
) -> Result<()> {
let transaction_state = session.as_ref().map_or(&TransactionState::None, |session| {
&session.transaction.state
Expand Down Expand Up @@ -970,7 +979,7 @@ impl Error {
}
}
TransactionState::None => {
if retryability == Some(&Retryability::Write) {
if retryability == Some(Retryability::Write) {
if let Some(max_wire_version) = max_wire_version {
if self.should_add_retryable_write_label(max_wire_version) {
self.add_label(RETRYABLE_WRITE_ERROR);
Expand Down
11 changes: 11 additions & 0 deletions src/event/sdam/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@ pub struct ServerDescriptionChangedEvent {
pub new_description: ServerDescription,
}

impl ServerDescriptionChangedEvent {
    /// Test-only helper: reports whether this event moves a server from a
    /// type for which `is_available()` holds to `ServerType::Unknown`.
    #[cfg(test)]
    pub(crate) fn is_marked_unknown_event(&self) -> bool {
        let previous_type = &self.previous_description.description.server_type;
        let new_type = &self.new_description.description.server_type;

        previous_type.is_available() && *new_type == crate::ServerType::Unknown
    }
}

/// Published when a server is initialized.
#[derive(Clone, Debug, Deserialize, PartialEq)]
#[non_exhaustive]
Expand Down
2 changes: 1 addition & 1 deletion src/operation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ where
}
}

#[derive(Debug, PartialEq)]
#[derive(Copy, Clone, Debug, PartialEq)]
pub(crate) enum Retryability {
Write,
Read,
Expand Down
1 change: 1 addition & 0 deletions src/sdam/description/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ impl PartialEq for ServerDescription {

self_response == other_response
}
(Err(self_err), Err(other_err)) => self_err == other_err,
_ => false,
}
}
Expand Down
5 changes: 2 additions & 3 deletions src/sdam/monitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,8 @@ impl HeartbeatMonitor {
let mut topology_check_requests_subscriber =
topology.subscribe_to_topology_check_requests();

if self.check_server(&topology, &server).await {
topology.notify_topology_changed();
}
self.check_server(&topology, &server).await;
topology.notify_topology_changed();
Copy link
Contributor Author

@patrickfreed patrickfreed May 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once I added a commit to fix RUST-1317, some tests started failing due to this if statement. The problem is that if the error stayed the same, this check would not call notify_topology_changed, which in turn would mean that the operation blocked on server selection would not try again and would not request another immediate check. This means that the blocked operation wouldn't succeed until the next regularly scheduled heartbeat, which could be 10s of seconds in the future.

Note that this new implementation matches the pseudocode in the server monitoring spec as well as the Python implementation. Also, since they pass the test added in DRIVERS-1251, the C#, Go, and Java drivers likely have this behavior too. The spec does allow for our prior only-notify-if-changed behavior though, so I filed DRIVERS-2329 to require the behavior seen here.


// drop strong reference to topology before going back to sleep in case it drops off
// in between checks.
Expand Down
149 changes: 147 additions & 2 deletions src/test/client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{borrow::Cow, collections::HashMap, time::Duration};
use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration};

use bson::Document;
use serde::Deserialize;
Expand All @@ -7,11 +7,25 @@ use tokio::sync::{RwLockReadGuard, RwLockWriteGuard};
use crate::{
bson::{doc, Bson},
error::{CommandError, Error, ErrorKind},
hello::LEGACY_HELLO_COMMAND_NAME,
options::{AuthMechanism, ClientOptions, Credential, ListDatabasesOptions, ServerAddress},
runtime,
selection_criteria::{ReadPreference, ReadPreferenceOptions, SelectionCriteria},
test::{log_uncaptured, util::TestClient, CLIENT_OPTIONS, LOCK},
test::{
log_uncaptured,
util::TestClient,
CmapEvent,
Event,
EventHandler,
FailCommandOptions,
FailPoint,
FailPointMode,
SdamEvent,
CLIENT_OPTIONS,
LOCK,
},
Client,
ServerType,
};

#[derive(Debug, Deserialize)]
Expand Down Expand Up @@ -663,3 +677,134 @@ async fn plain_auth() {
}
);
}

/// Test verifies that retrying a commitTransaction operation after a checkOut
/// failure works.
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn retry_commit_txn_check_out() {
kmahar marked this conversation as resolved.
Show resolved Hide resolved
let _guard: RwLockWriteGuard<_> = LOCK.run_exclusively().await;

let setup_client = TestClient::new().await;
if !setup_client.is_replica_set() {
log_uncaptured("skipping retry_commit_txn_check_out due to non-replicaset topology");
return;
}

if !setup_client.supports_transactions() {
log_uncaptured("skipping retry_commit_txn_check_out due to lack of transaction support");
return;
}

if !setup_client.supports_fail_command_appname_initial_handshake() {
log_uncaptured(
"skipping retry_commit_txn_check_out due to insufficient failCommand support",
);
return;
}

// ensure namespace exists
setup_client
.database("retry_commit_txn_check_out")
.collection("retry_commit_txn_check_out")
.insert_one(doc! {}, None)
.await
.unwrap();

let mut options = CLIENT_OPTIONS.clone();
let handler = Arc::new(EventHandler::new());
options.cmap_event_handler = Some(handler.clone());
options.sdam_event_handler = Some(handler.clone());
options.heartbeat_freq = Some(Duration::from_secs(120));
options.app_name = Some("retry_commit_txn_check_out".to_string());
let client = Client::with_options(options).unwrap();

let mut session = client.start_session(None).await.unwrap();
session.start_transaction(None).await.unwrap();
// transition transaction to "in progress" so that the commit
// actually executes an operation.
client
.database("retry_commit_txn_check_out")
.collection("retry_commit_txn_check_out")
.insert_one_with_session(doc! {}, None, &mut session)
.await
.unwrap();

// enable a fail point that clears the connection pools so that
// commitTransaction will create a new connection during check out.
let fp = FailPoint::fail_command(
&["ping"],
FailPointMode::Times(1),
FailCommandOptions::builder().error_code(11600).build(),
);
let _guard = setup_client.enable_failpoint(fp, None).await.unwrap();

let mut subscriber = handler.subscribe();
client
.database("foo")
.run_command(doc! { "ping": 1 }, None)
.await
.unwrap_err();

// failing with a state change error will request an immediate check
// wait for the mark unknown and subsequent succeeded heartbeat
let mut primary = None;
subscriber
.wait_for_event(Duration::from_secs(1), |e| {
if let Event::Sdam(SdamEvent::ServerDescriptionChanged(event)) = e {
if event.is_marked_unknown_event() {
primary = Some(event.address.clone());
return true;
}
}
false
})
.await
.expect("should see marked unknown event");

subscriber
.wait_for_event(Duration::from_secs(1), |e| {
if let Event::Sdam(SdamEvent::ServerDescriptionChanged(event)) = e {
if &event.address == primary.as_ref().unwrap()
&& event.previous_description.server_type() == ServerType::Unknown
{
return true;
}
}
false
})
.await
.expect("should see mark available event");

// enable a failpoint on the handshake to cause check_out
// to fail with a retryable error
let fp = FailPoint::fail_command(
&[LEGACY_HELLO_COMMAND_NAME, "hello"],
FailPointMode::Times(1),
FailCommandOptions::builder()
.error_code(11600)
.app_name("retry_commit_txn_check_out".to_string())
.build(),
);
let _guard2 = setup_client.enable_failpoint(fp, None).await.unwrap();

// finally, attempt the commit.
// this should succeed due to retry
session.commit_transaction().await.unwrap();

// ensure the first check out attempt fails
subscriber
.wait_for_event(Duration::from_secs(1), |e| {
matches!(e, Event::Cmap(CmapEvent::ConnectionCheckOutFailed(_)))
})
.await
.expect("should see check out failed event");

// ensure the second one succeeds
subscriber
.wait_for_event(Duration::from_secs(1), |e| {
matches!(e, Event::Cmap(CmapEvent::ConnectionCheckedOut(_)))
})
.await
.expect("should see checked out event");
}
4 changes: 2 additions & 2 deletions src/test/util/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,9 +389,9 @@ pub struct EventSubscriber<'a> {
}

impl EventSubscriber<'_> {
pub async fn wait_for_event<F>(&mut self, timeout: Duration, filter: F) -> Option<Event>
pub async fn wait_for_event<F>(&mut self, timeout: Duration, mut filter: F) -> Option<Event>
where
F: Fn(&Event) -> bool,
F: FnMut(&Event) -> bool,
{
runtime::timeout(timeout, async {
loop {
Expand Down
15 changes: 15 additions & 0 deletions src/test/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,21 @@ impl TestClient {
version.matches(&self.server_version)
}

/// Reports whether the server version is one on which the initial handshake
/// can be failed only for a specified appName.
///
/// The accepted version ranges are `>= 4.2.15, < 4.3.0`,
/// `>= 4.4.7, < 4.5.0`, and `>= 4.9.0`.
///
/// See SERVER-49336 for more info.
pub fn supports_fail_command_appname_initial_handshake(&self) -> bool {
    const SUPPORTED_RANGES: [&str; 3] = [
        ">= 4.2.15, < 4.3.0",
        ">= 4.4.7, < 4.5.0",
        ">= 4.9.0",
    ];

    SUPPORTED_RANGES.iter().any(|range| {
        let requirement = VersionReq::parse(range).unwrap();
        requirement.matches(&self.server_version)
    })
}

pub fn supports_transactions(&self) -> bool {
self.is_replica_set() && self.server_version_gte(4, 0)
|| self.is_sharded() && self.server_version_gte(4, 2)
Expand Down