Skip to content

Commit

Permalink
bring retry logic into spec
Browse files Browse the repository at this point in the history
  • Loading branch information
abr-egn committed Mar 23, 2022
1 parent d5da67c commit e9b1ce7
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 29 deletions.
21 changes: 15 additions & 6 deletions src/client/auth/sasl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::{
client::{auth::AuthMechanism, options::ServerApi},
cmap::Command,
error::{Error, Result},
operation::{CommandErrorBody, CommandResponse},
};

/// Encapsulates the command building of a `saslStart` command.
Expand Down Expand Up @@ -98,12 +99,20 @@ fn validate_command_success(auth_mechanism: &str, response: &Document) -> Result

match bson_util::get_int(ok) {
Some(1) => Ok(()),
Some(_) => Err(Error::authentication_error(
auth_mechanism,
response
.get_str("errmsg")
.unwrap_or("Authentication failure"),
)),
Some(_) => {
let source = bson::from_bson::<CommandResponse<CommandErrorBody>>(Bson::Document(
response.clone(),
))
.map(|cmd_resp| cmd_resp.body.into())
.ok();
Err(Error::authentication_error(
auth_mechanism,
response
.get_str("errmsg")
.unwrap_or("Authentication failure"),
)
.with_source(source))
}
_ => Err(Error::invalid_authentication_response(auth_mechanism)),
}
}
Expand Down
68 changes: 45 additions & 23 deletions src/client/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,16 +331,24 @@ impl Client {
Ok(c) => c,
Err(mut err) => {
err.add_labels_and_update_pin(None, &mut session, None)?;
if err.is_read_retryable() {
err.add_label(RETRYABLE_WRITE_ERROR);
}

if err.is_network_error() || err.is_auth_error() {
let op_retry = match self.get_op_retryability(&op, &session) {
Retryability::Read => err.is_read_retryable(),
Retryability::Write => err.is_write_retryable(),
_ => false,
};
if err.is_pool_cleared() || op_retry {
return self.execute_retry(&mut op, &mut session, None, err).await;
} else {
return Err(err);
}
}
};

let retryability = self.get_retryability(&conn, &op, &session).await?;
let retryability = self.get_retryability(&conn, &op, &session)?;

let txn_number = match session {
Some(ref mut session) => {
Expand Down Expand Up @@ -432,7 +440,7 @@ impl Client {
Err(_) => return Err(first_error),
};

let retryability = self.get_retryability(&conn, op, session).await?;
let retryability = self.get_retryability(&conn, op, session)?;
if retryability == Retryability::None {
return Err(first_error);
}
Expand Down Expand Up @@ -824,35 +832,49 @@ impl Client {
}

/// Returns the retryability level for the execution of this operation.
async fn get_retryability<T: Operation>(
fn get_op_retryability<T: Operation>(
&self,
conn: &Connection,
op: &T,
session: &Option<&mut ClientSession>,
) -> Result<Retryability> {
if !session
) -> Retryability {
if session
.as_ref()
.map(|session| session.in_transaction())
.unwrap_or(false)
{
match op.retryability() {
Retryability::Read if self.inner.options.retry_reads != Some(false) => {
return Ok(Retryability::Read);
}
Retryability::Write if conn.stream_description()?.supports_retryable_writes() => {
// commitTransaction and abortTransaction should be retried regardless of the
// value for retry_writes set on the Client
if op.name() == CommitTransaction::NAME
|| op.name() == AbortTransaction::NAME
|| self.inner.options.retry_writes != Some(false)
{
return Ok(Retryability::Write);
}
}
_ => {}
return Retryability::None;
}
match op.retryability() {
Retryability::Read if self.inner.options.retry_reads != Some(false) => {
Retryability::Read
}
// commitTransaction and abortTransaction should be retried regardless of the
// value for retry_writes set on the Client
Retryability::Write
if op.name() == CommitTransaction::NAME
|| op.name() == AbortTransaction::NAME
|| self.inner.options.retry_writes != Some(false) =>
{
Retryability::Write
}
_ => Retryability::None,
}
}

/// Returns the retryability level for the execution of this operation on this connection.
fn get_retryability<T: Operation>(
&self,
conn: &Connection,
op: &T,
session: &Option<&mut ClientSession>,
) -> Result<Retryability> {
match self.get_op_retryability(op, session) {
Retryability::Read => Ok(Retryability::Read),
Retryability::Write if conn.stream_description()?.supports_retryable_writes() => {
Ok(Retryability::Write)
}
_ => Ok(Retryability::None),
}
Ok(Retryability::None)
}

async fn update_cluster_time(
Expand Down
13 changes: 13 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ pub struct Error {
pub kind: Box<ErrorKind>,
labels: HashSet<String>,
pub(crate) wire_version: Option<i32>,
#[source]
pub(crate) source: Option<Box<Error>>,
}

impl Error {
Expand All @@ -61,6 +63,7 @@ impl Error {
kind: Box::new(kind),
labels,
wire_version: None,
source: None,
}
}

Expand Down Expand Up @@ -237,6 +240,7 @@ impl Error {
ErrorKind::Write(WriteFailure::WriteConcernError(wc_error)) => Some(wc_error.code),
_ => None,
}
.or_else(|| self.source.as_ref().and_then(|s| s.code()))
}

/// Gets the message for this error, if applicable, for use in testing.
Expand Down Expand Up @@ -314,6 +318,10 @@ impl Error {
.unwrap_or(false)
}

pub(crate) fn is_pool_cleared(&self) -> bool {
matches!(self.kind.as_ref(), ErrorKind::ConnectionPoolCleared { .. })
}

/// If this error is resumable as per the change streams spec.
pub(crate) fn is_resumable(&self) -> bool {
if !self.is_server_error() {
Expand Down Expand Up @@ -341,6 +349,11 @@ impl Error {
}
false
}

pub(crate) fn with_source<E: Into<Option<Error>>>(mut self, source: E) -> Self {
self.source = source.into().map(Box::new);
self
}
}

impl<E> From<E> for Error
Expand Down

0 comments on commit e9b1ce7

Please sign in to comment.