diff --git a/src/client/executor.rs b/src/client/executor.rs index a59f4adcf5..3b252cd8fb 100644 --- a/src/client/executor.rs +++ b/src/client/executor.rs @@ -159,7 +159,13 @@ impl Client { T: DeserializeOwned + Unpin + Send + Sync, { let mut details = self.execute_operation_with_details(op, session).await?; - let pinned = self.pin_connection_for_cursor(&mut details.output)?; + let pinned = if details.output.connection.is_pinned() { + // Cursor operations on load-balanced transactions will be pinned via the transaction + // pin. + None + } else { + self.pin_connection_for_cursor(&mut details.output)? + }; Ok(SessionCursor::new( self.clone(), details.output.operation_output, @@ -167,6 +173,10 @@ impl Client { )) } + fn is_load_balanced(&self) -> bool { + self.inner.options.load_balanced.unwrap_or(false) + } + fn pin_connection_for_cursor( &self, details: &mut ExecutionOutput, @@ -174,8 +184,7 @@ impl Client { where Op: Operation>, { - let is_load_balanced = self.inner.options.load_balanced.unwrap_or(false); - if is_load_balanced && details.operation_output.info.id != 0 { + if self.is_load_balanced() && details.operation_output.info.id != 0 { Ok(Some(details.connection.pin()?)) } else { Ok(None) @@ -205,7 +214,7 @@ impl Client { let selection_criteria = session .as_ref() - .and_then(|s| s.transaction.pinned_mongos.as_ref()) + .and_then(|s| s.transaction.pinned_mongos()) .or_else(|| op.selection_criteria()); let server = match self.select_server(selection_criteria).await { @@ -317,9 +326,20 @@ impl Client { } }; - let mut conn = match op.pinned_connection() { - Some(c) => c.take_connection().await?, - None => match server.pool.check_out().await { + let session_pinned = session + .as_ref() + .and_then(|s| s.transaction.pinned_connection()); + let mut conn = match (session_pinned, op.pinned_connection()) { + (Some(c), None) | (None, Some(c)) => c.take_connection().await?, + (Some(c), Some(_)) => { + // An operation executing in a transaction should never have a pinned connection, + // but in case it does, prefer the transaction's pin. + if cfg!(debug_assertions) { + panic!("pinned operation executing in pinned transaction"); + } + c.take_connection().await? + } + (None, None) => match server.pool.check_out().await { Ok(c) => c, Err(_) => return Err(first_error), }, @@ -411,7 +431,9 @@ impl Client { cmd.set_start_transaction(); cmd.set_autocommit(); cmd.set_txn_read_concern(*session); - if is_sharded { + if self.is_load_balanced() { + session.pin_connection(connection.pin()?); + } else if is_sharded { session.pin_mongos(connection.address().clone()); } session.transaction.state = TransactionState::InProgress; @@ -816,7 +838,7 @@ impl Error { if self.contains_label(TRANSIENT_TRANSACTION_ERROR) || self.contains_label(UNKNOWN_TRANSACTION_COMMIT_RESULT) { - session.unpin_mongos(); + session.unpin(); } } diff --git a/src/client/session/mod.rs b/src/client/session/mod.rs index 0ee7038124..477b282546 100644 --- a/src/client/session/mod.rs +++ b/src/client/session/mod.rs @@ -14,6 +14,7 @@ use uuid::Uuid; use crate::{ bson::{doc, spec::BinarySubtype, Binary, Bson, Document, Timestamp}, + cmap::conn::PinnedConnectionHandle, error::{ErrorKind, Result}, operation::{AbortTransaction, CommitTransaction, Operation}, options::{SessionOptions, TransactionOptions}, @@ -110,11 +111,11 @@ pub struct ClientSession { pub(crate) snapshot_time: Option, } -#[derive(Clone, Debug)] +#[derive(Debug)] pub(crate) struct Transaction { pub(crate) state: TransactionState, pub(crate) options: Option, - pub(crate) pinned_mongos: Option, + pub(crate) pinned: Option, pub(crate) recovery_token: Option, } @@ -132,15 +133,38 @@ impl Transaction { pub(crate) fn abort(&mut self) { self.state = TransactionState::Aborted; self.options = None; - self.pinned_mongos = None; + self.pinned = None; } pub(crate) fn reset(&mut self) { self.state = TransactionState::None; self.options = None; - self.pinned_mongos = None; + self.pinned = None; self.recovery_token = None; } + + pub(crate) fn pinned_mongos(&self) -> Option<&SelectionCriteria> { + match &self.pinned { + Some(TransactionPin::Mongos(s)) => Some(s), + _ => None, + } + } + + pub(crate) fn pinned_connection(&self) -> Option<&PinnedConnectionHandle> { + match &self.pinned { + Some(TransactionPin::Connection(c)) => Some(c), + _ => None, + } + } + + fn take(&mut self) -> Self { + Transaction { + state: self.state.clone(), + options: self.options.take(), + pinned: self.pinned.take(), + recovery_token: self.recovery_token.take(), + } + } } impl Default for Transaction { @@ -148,7 +172,7 @@ impl Default for Transaction { Self { state: TransactionState::None, options: None, - pinned_mongos: None, + pinned: None, recovery_token: None, } } @@ -168,6 +192,12 @@ pub(crate) enum TransactionState { Aborted, } +#[derive(Debug)] +pub(crate) enum TransactionPin { + Mongos(SelectionCriteria), + Connection(PinnedConnectionHandle), +} + impl ClientSession { /// Creates a new `ClientSession` wrapping the provided server session. pub(crate) fn new( @@ -256,13 +286,18 @@ impl ClientSession { /// Pin mongos to session. pub(crate) fn pin_mongos(&mut self, address: ServerAddress) { - self.transaction.pinned_mongos = Some(SelectionCriteria::Predicate(Arc::new( - move |server_info: &ServerInfo| *server_info.address() == address, + self.transaction.pinned = Some(TransactionPin::Mongos(SelectionCriteria::Predicate( + Arc::new(move |server_info: &ServerInfo| *server_info.address() == address), ))); } - pub(crate) fn unpin_mongos(&mut self) { - self.transaction.pinned_mongos = None; + /// Pin the connection to the session. + pub(crate) fn pin_connection(&mut self, handle: PinnedConnectionHandle) { + self.transaction.pinned = Some(TransactionPin::Connection(handle)); + } + + pub(crate) fn unpin(&mut self) { + self.transaction.pinned = None; } /// Whether this session is dirty. @@ -319,7 +354,7 @@ impl ClientSession { .into()); } TransactionState::Committed { .. } => { - self.unpin_mongos(); // Unpin session if previous transaction is committed. + self.unpin(); // Unpin session if previous transaction is committed. } _ => {} } @@ -495,8 +530,8 @@ impl ClientSession { .as_ref() .and_then(|options| options.write_concern.as_ref()) .cloned(); - let selection_criteria = self.transaction.pinned_mongos.clone(); - let abort_transaction = AbortTransaction::new(write_concern, selection_criteria); + let abort_transaction = + AbortTransaction::new(write_concern, self.transaction.pinned.take()); self.transaction.abort(); // Errors returned from running an abortTransaction command should be ignored. let _result = self @@ -549,7 +584,7 @@ impl Drop for ClientSession { client: self.client.clone(), is_implicit: self.is_implicit, options: self.options.clone(), - transaction: self.transaction.clone(), + transaction: self.transaction.take(), snapshot_time: self.snapshot_time, }; RUNTIME.execute(async move { diff --git a/src/cmap/conn/mod.rs b/src/cmap/conn/mod.rs index 18bd8294ff..ab11a515b6 100644 --- a/src/cmap/conn/mod.rs +++ b/src/cmap/conn/mod.rs @@ -313,6 +313,11 @@ impl Connection { }) } + /// Whether this connection has a live `PinnedConnectionHandle`. + pub(crate) fn is_pinned(&self) -> bool { + self.pinned_sender.is_some() + } + /// Close this connection, emitting a `ConnectionClosedEvent` with the supplied reason. pub(super) fn close_and_drop(mut self, reason: ConnectionClosedReason) { self.close(reason); diff --git a/src/operation/abort_transaction/mod.rs b/src/operation/abort_transaction/mod.rs index 79be500409..f9d5ea682e 100644 --- a/src/operation/abort_transaction/mod.rs +++ b/src/operation/abort_transaction/mod.rs @@ -2,7 +2,8 @@ use bson::Document; use crate::{ bson::doc, - cmap::{Command, StreamDescription}, + client::session::TransactionPin, + cmap::{conn::PinnedConnectionHandle, Command, StreamDescription}, error::Result, operation::{Operation, Retryability}, options::WriteConcern, @@ -13,17 +14,14 @@ use super::{CommandResponse, Response, WriteConcernOnlyBody}; pub(crate) struct AbortTransaction { write_concern: Option, - selection_criteria: Option, + pinned: Option, } impl AbortTransaction { - pub(crate) fn new( - write_concern: Option, - selection_criteria: Option, - ) -> Self { + pub(crate) fn new(write_concern: Option, pinned: Option) -> Self { Self { write_concern, - selection_criteria, + pinned, } } } @@ -59,7 +57,17 @@ impl Operation for AbortTransaction { } fn selection_criteria(&self) -> Option<&SelectionCriteria> { - self.selection_criteria.as_ref() + match &self.pinned { + Some(TransactionPin::Mongos(s)) => Some(s), + _ => None, + } + } + + fn pinned_connection(&self) -> Option<&PinnedConnectionHandle> { + match &self.pinned { + Some(TransactionPin::Connection(h)) => Some(h), + _ => None, + } } fn write_concern(&self) -> Option<&WriteConcern> { @@ -72,6 +80,6 @@ impl Operation for AbortTransaction { fn update_for_retry(&mut self) { // The session must be "unpinned" before server selection for a retry. - self.selection_criteria = None; + self.pinned = None; } } diff --git a/src/test/spec/unified_runner/operation.rs b/src/test/spec/unified_runner/operation.rs index 66178a0c0f..3bb3b137bf 100644 --- a/src/test/spec/unified_runner/operation.rs +++ b/src/test/spec/unified_runner/operation.rs @@ -1077,8 +1077,8 @@ impl TestOperation for TargetedFailPoint { let session = test_runner.get_session(&self.session); let selection_criteria = session .transaction - .pinned_mongos - .clone() + .pinned_mongos() + .cloned() .unwrap_or_else(|| panic!("ClientSession not pinned")); let fail_point_guard = test_runner .internal_client @@ -1312,7 +1312,7 @@ impl TestOperation for AssertSessionPinned { assert!(test_runner .get_session(&self.session) .transaction - .pinned_mongos + .pinned_mongos() .is_some()); } .boxed() @@ -1334,7 +1334,7 @@ impl TestOperation for AssertSessionUnpinned { assert!(test_runner .get_session(&self.session) .transaction - .pinned_mongos + .pinned_mongos() .is_none()); } .boxed() diff --git a/src/test/spec/v2_runner/mod.rs b/src/test/spec/v2_runner/mod.rs index 75ed3776f8..a95c8e64dd 100644 --- a/src/test/spec/v2_runner/mod.rs +++ b/src/test/spec/v2_runner/mod.rs @@ -263,8 +263,8 @@ pub async fn run_v2_test(test_file: TestFile) { let selection_criteria = session .unwrap() .transaction - .pinned_mongos - .clone() + .pinned_mongos() + .cloned() .unwrap_or_else(|| panic!("ClientSession is not pinned")); fail_point_guards.push( diff --git a/src/test/spec/v2_runner/operation.rs b/src/test/spec/v2_runner/operation.rs index ce6d5ad719..863f22d512 100644 --- a/src/test/spec/v2_runner/operation.rs +++ b/src/test/spec/v2_runner/operation.rs @@ -1056,7 +1056,7 @@ impl TestOperation for AssertSessionPinned { session: &'a mut ClientSession, ) -> BoxFuture<'a, Result>> { async move { - assert!(session.transaction.pinned_mongos.is_some()); + assert!(session.transaction.pinned_mongos().is_some()); Ok(None) } .boxed() @@ -1072,7 +1072,7 @@ impl TestOperation for AssertSessionUnpinned { session: &'a mut ClientSession, ) -> BoxFuture<'a, Result>> { async move { - assert!(session.transaction.pinned_mongos.is_none()); + assert!(session.transaction.pinned_mongos().is_none()); Ok(None) } .boxed()