Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(dht): remove some invalid saf failure cases #4787

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions comms/core/src/protocol/rpc/server/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,17 @@ pub enum RpcServerError {
ServiceCallExceededDeadline,
#[error("Stream read exceeded deadline")]
ReadStreamExceededDeadline,
#[error("Early close error: {0}")]
EarlyCloseError(#[from] EarlyCloseError<BytesMut>),
#[error("Early close: {0}")]
EarlyClose(#[from] EarlyCloseError<BytesMut>),
}

impl RpcServerError {
pub fn early_close_io(&self) -> Option<&io::Error> {
match self {
Self::EarlyClose(e) => e.io(),
_ => None,
}
}
}

impl From<oneshot::error::RecvError> for RpcServerError {
Expand Down
16 changes: 12 additions & 4 deletions comms/core/src/protocol/rpc/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ use std::{
convert::TryFrom,
future::Future,
io,
io::ErrorKind,
pin::Pin,
sync::Arc,
task::Poll,
Expand Down Expand Up @@ -353,7 +354,7 @@ where
{
Ok(_) => {},
Err(err @ RpcServerError::HandshakeError(_)) => {
debug!(target: LOG_TARGET, "{}", err);
debug!(target: LOG_TARGET, "Handshake error: {}", err);
metrics::handshake_error_counter(&node_id, &notification.protocol).inc();
},
Err(err) => {
Expand Down Expand Up @@ -530,7 +531,7 @@ where
metrics::error_counter(&self.node_id, &self.protocol, &err).inc();
let level = match &err {
RpcServerError::Io(e) => err_to_log_level(e),
RpcServerError::EarlyCloseError(e) => e.io().map(err_to_log_level).unwrap_or(log::Level::Error),
RpcServerError::EarlyClose(e) => e.io().map(err_to_log_level).unwrap_or(log::Level::Error),
_ => log::Level::Error,
};
log!(
Expand Down Expand Up @@ -562,8 +563,10 @@ where
err,
);
}
error!(
let level = err.early_close_io().map(err_to_log_level).unwrap_or(log::Level::Error);
log!(
target: LOG_TARGET,
level,
"(peer: {}, protocol: {}) Failed to handle request: {}",
self.node_id,
self.protocol_name(),
Expand Down Expand Up @@ -880,8 +883,13 @@ fn into_response(request_id: u32, result: Result<BodyBytes, RpcStatus>) -> RpcRe
}

fn err_to_log_level(err: &io::Error) -> log::Level {
error!(target: LOG_TARGET, "KIND: {}", err.kind());
match err.kind() {
io::ErrorKind::BrokenPipe | io::ErrorKind::WriteZero => log::Level::Debug,
ErrorKind::ConnectionReset |
ErrorKind::ConnectionAborted |
ErrorKind::BrokenPipe |
ErrorKind::WriteZero |
ErrorKind::UnexpectedEof => log::Level::Debug,
_ => log::Level::Error,
}
}
16 changes: 16 additions & 0 deletions comms/dht/src/dht.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use tari_comms::{
pipeline::PipelineError,
};
use tari_shutdown::ShutdownSignal;
use tari_utilities::epoch_time::EpochTime;
use thiserror::Error;
use tokio::sync::{broadcast, mpsc};
use tower::{layer::Layer, Service, ServiceBuilder};
Expand Down Expand Up @@ -298,6 +299,7 @@ impl Dht {
.layer(MetricsLayer::new(self.metrics_collector.clone()))
.layer(inbound::DeserializeLayer::new(self.peer_manager.clone()))
.layer(filter::FilterLayer::new(self.unsupported_saf_messages_filter()))
.layer(filter::FilterLayer::new(discard_expired_messages))
.layer(inbound::DecryptionLayer::new(
self.config.clone(),
self.node_identity.clone(),
Expand Down Expand Up @@ -432,6 +434,20 @@ fn filter_messages_to_rebroadcast(msg: &DecryptedDhtMessage) -> bool {
}
}

/// Check message expiry and immediately discard if expired
fn discard_expired_messages(msg: &DhtInboundMessage) -> bool {
if let Some(expires) = msg.dht_header.expires {
if expires < EpochTime::now() {
debug!(
target: LOG_TARGET,
"[discard_expired_messages] Discarding expired message {}", msg
);
return false;
}
}
true
}

#[cfg(test)]
mod test {
use std::{sync::Arc, time::Duration};
Expand Down
2 changes: 1 addition & 1 deletion comms/dht/src/envelope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use crate::version::DhtProtocolVersion;
pub(crate) fn datetime_to_timestamp(datetime: DateTime<Utc>) -> Timestamp {
Timestamp {
seconds: datetime.timestamp(),
nanos: datetime.timestamp_subsec_nanos().try_into().unwrap_or(std::i32::MAX),
nanos: datetime.timestamp_subsec_nanos().try_into().unwrap_or(i32::MAX),
}
}

Expand Down
10 changes: 4 additions & 6 deletions comms/dht/src/store_forward/database/stored_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use std::convert::TryInto;

use chrono::NaiveDateTime;
use tari_comms::message::MessageExt;
use tari_utilities::{hex, hex::Hex};
Expand Down Expand Up @@ -50,7 +48,7 @@ pub struct NewStoredMessage {
}

impl NewStoredMessage {
pub fn try_construct(message: DecryptedDhtMessage, priority: StoredMessagePriority) -> Option<Self> {
pub fn new(message: DecryptedDhtMessage, priority: StoredMessagePriority) -> Self {
let DecryptedDhtMessage {
authenticated_origin,
decryption_result,
Expand All @@ -64,8 +62,8 @@ impl NewStoredMessage {
};
let body_hash = hex::to_hex(&dedup::create_message_hash(&dht_header.message_signature, &body));

Some(Self {
version: dht_header.version.as_major().try_into().ok()?,
Self {
version: dht_header.version.as_major() as i32,
origin_pubkey: authenticated_origin.as_ref().map(|pk| pk.to_hex()),
message_type: dht_header.message_type as i32,
destination_pubkey: dht_header.destination.public_key().map(|pk| pk.to_hex()),
Expand All @@ -81,7 +79,7 @@ impl NewStoredMessage {
},
body_hash,
body,
})
}
}
}

Expand Down
10 changes: 5 additions & 5 deletions comms/dht/src/store_forward/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use tari_comms::{
message::MessageError,
peer_manager::{NodeId, PeerManagerError},
};
use tari_utilities::byte_array::ByteArrayError;
use tari_utilities::{byte_array::ByteArrayError, epoch_time::EpochTime};
use thiserror::Error;

use crate::{
Expand Down Expand Up @@ -81,10 +81,10 @@ pub enum StoreAndForwardError {
RequesterChannelClosed,
#[error("The request was cancelled by the store and forward service")]
RequestCancelled,
#[error("The message was not valid for store and forward")]
InvalidStoreMessage,
#[error("The envelope version is invalid")]
InvalidEnvelopeVersion,
#[error("The {field} field was not valid, discarding SAF response: {details}")]
InvalidSafResponseMessage { field: &'static str, details: String },
#[error("The message has expired, not storing message in SAF db (expiry: {expired}, now: {now})")]
NotStoringExpiredMessage { expired: EpochTime, now: EpochTime },
#[error("MalformedNodeId: {0}")]
MalformedNodeId(#[from] ByteArrayError),
#[error("DHT message type should not have been forwarded")]
Expand Down
7 changes: 2 additions & 5 deletions comms/dht/src/store_forward/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use std::convert::{TryFrom, TryInto};
use std::convert::TryFrom;

use chrono::{DateTime, Utc};
use prost::Message;
Expand Down Expand Up @@ -76,10 +76,7 @@ impl TryFrom<database::StoredMessage> for StoredMessage {
let dht_header = DhtHeader::decode(message.header.as_slice())?;
Ok(Self {
stored_at: Some(datetime_to_timestamp(DateTime::from_utc(message.stored_at, Utc))),
version: message
.version
.try_into()
.map_err(|_| StoreAndForwardError::InvalidEnvelopeVersion)?,
version: message.version as u32,
body: message.body,
dht_header: Some(dht_header),
})
Expand Down
17 changes: 11 additions & 6 deletions comms/dht/src/store_forward/saf_handler/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use tari_comms::{
types::CommsPublicKey,
BytesMut,
};
use tari_utilities::{convert::try_convert_all, ByteArray};
use tari_utilities::ByteArray;
use tokio::sync::mpsc;
use tower::{Service, ServiceExt};

Expand Down Expand Up @@ -216,7 +216,7 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
let messages = self.saf_requester.fetch_messages(query.clone()).await?;

let stored_messages = StoredMessagesResponse {
messages: try_convert_all(messages)?,
messages: messages.into_iter().map(TryInto::try_into).collect::<Result<_, _>>()?,
request_id: retrieve_msgs.request_id,
response_type: resp_type as i32,
};
Expand Down Expand Up @@ -430,8 +430,13 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
.stored_at
.map(|t| {
Result::<_, StoreAndForwardError>::Ok(DateTime::from_utc(
NaiveDateTime::from_timestamp_opt(t.seconds, t.nanos.try_into().unwrap_or(u32::MAX))
.ok_or(StoreAndForwardError::InvalidStoreMessage)?,
NaiveDateTime::from_timestamp_opt(t.seconds, 0).ok_or_else(|| {
StoreAndForwardError::InvalidSafResponseMessage {
field: "stored_at",
details: "number of seconds provided represents more days than can fit in a u32"
.to_string(),
}
})?,
Utc,
))
})
Expand Down Expand Up @@ -618,7 +623,7 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
mod test {
use std::time::Duration;

use chrono::Utc;
use chrono::{Timelike, Utc};
use tari_comms::{message::MessageExt, runtime, wrap_in_envelope_body};
use tari_test_utils::collect_recv;
use tari_utilities::{hex, hex::Hex};
Expand Down Expand Up @@ -932,7 +937,7 @@ mod test {
.unwrap()
.unwrap();

assert_eq!(last_saf_received, msg2_time);
assert_eq!(last_saf_received.second(), msg2_time.second());
}

#[runtime::test]
Expand Down
8 changes: 4 additions & 4 deletions comms/dht/src/store_forward/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,13 +437,13 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError> + Se
);

if let Some(expires) = message.dht_header.expires {
if expires < EpochTime::now() {
return SafResult::Err(StoreAndForwardError::InvalidStoreMessage);
let now = EpochTime::now();
if expires < now {
return Err(StoreAndForwardError::NotStoringExpiredMessage { expired: expires, now });
}
}

let stored_message =
NewStoredMessage::try_construct(message, priority).ok_or(StoreAndForwardError::InvalidStoreMessage)?;
let stored_message = NewStoredMessage::new(message, priority);
self.saf_requester.insert_message(stored_message).await
}
}
Expand Down