Skip to content

Commit

Permalink
[dag] async message handler support (#10055)
Browse files Browse the repository at this point in the history
* [dag] async message handler support
* [dag] pull payload in dag driver
  • Loading branch information
ibalajiarun authored Sep 18, 2023
1 parent 6b2197f commit c147d54
Show file tree
Hide file tree
Showing 11 changed files with 145 additions and 77 deletions.
7 changes: 5 additions & 2 deletions consensus/src/dag/bootstrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ pub(super) fn bootstrap_dag_for_test(

let (ordered_nodes_tx, ordered_nodes_rx) = futures_channel::mpsc::unbounded();
let adapter = Arc::new(NotifierAdapter::new(ordered_nodes_tx, storage.clone()));
let (dag_rpc_tx, mut dag_rpc_rx) = aptos_channel::new(QueueStyle::FIFO, 64, None);
let (dag_rpc_tx, dag_rpc_rx) = aptos_channel::new(QueueStyle::FIFO, 64, None);

let (dag_store, order_rule) =
bootstraper.bootstrap_dag_store(latest_ledger_info, adapter.clone());
Expand All @@ -266,7 +266,10 @@ pub(super) fn bootstrap_dag_for_test(
let (handler, fetch_service) =
bootstraper.bootstrap_components(dag_store.clone(), order_rule, state_sync_trigger);

let dh_handle = tokio::spawn(async move { handler.run(&mut dag_rpc_rx).await });
let dh_handle = tokio::spawn(async move {
let mut dag_rpc_rx = dag_rpc_rx;
handler.run(&mut dag_rpc_rx).await
});
let df_handle = tokio::spawn(fetch_service.start());

(dh_handle, df_handle, dag_rpc_tx, ordered_nodes_rx)
Expand Down
81 changes: 58 additions & 23 deletions consensus/src/dag/dag_driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,20 @@ use crate::{
},
state_replication::PayloadClient,
};
use anyhow::{bail, Ok};
use aptos_consensus_types::common::{Author, Payload};
use anyhow::bail;
use aptos_consensus_types::common::{Author, PayloadFilter};
use aptos_infallible::RwLock;
use aptos_logger::error;
use aptos_logger::{debug, error};
use aptos_reliable_broadcast::ReliableBroadcast;
use aptos_time_service::{TimeService, TimeServiceTrait};
use aptos_types::{block_info::Round, epoch_state::EpochState};
use async_trait::async_trait;
use futures::{
executor::block_on,
future::{AbortHandle, Abortable},
FutureExt,
};
use std::sync::Arc;
use std::{sync::Arc, time::Duration};
use thiserror::Error as ThisError;
use tokio_retry::strategy::ExponentialBackoff;

Expand Down Expand Up @@ -71,6 +73,12 @@ impl DagDriver {
.read()
.get_strong_links_for_round(highest_round, &epoch_state.verifier)
.map_or_else(|| highest_round.saturating_sub(1), |_| highest_round);

debug!(
"highest_round: {}, current_round: {}",
highest_round, current_round
);

let mut driver = Self {
author,
epoch_state,
Expand All @@ -96,37 +104,62 @@ impl DagDriver {
.read()
.get_strong_links_for_round(current_round, &driver.epoch_state.verifier)
.unwrap_or(vec![]);
driver.enter_new_round(current_round + 1, strong_links);
block_on(driver.enter_new_round(current_round + 1, strong_links));
}
driver
}

pub fn add_node(&mut self, node: CertifiedNode) -> anyhow::Result<()> {
let mut dag_writer = self.dag.write();
let round = node.metadata().round();
pub async fn add_node(&mut self, node: CertifiedNode) -> anyhow::Result<()> {
let maybe_strong_links = {
let mut dag_writer = self.dag.write();
let round = node.metadata().round();

if !dag_writer.all_exists(node.parents_metadata()) {
if let Err(err) = self.fetch_requester.request_for_certified_node(node) {
error!("request to fetch failed: {}", err);
if !dag_writer.all_exists(node.parents_metadata()) {
if let Err(err) = self.fetch_requester.request_for_certified_node(node) {
error!("request to fetch failed: {}", err);
}
bail!(DagDriverError::MissingParents);
}
bail!(DagDriverError::MissingParents);
}

dag_writer.add_node(node)?;
if self.current_round == round {
let maybe_strong_links = dag_writer
.get_strong_links_for_round(self.current_round, &self.epoch_state.verifier);
drop(dag_writer);
if let Some(strong_links) = maybe_strong_links {
self.enter_new_round(self.current_round + 1, strong_links);
dag_writer.add_node(node)?;
if self.current_round == round {
dag_writer
.get_strong_links_for_round(self.current_round, &self.epoch_state.verifier)
} else {
None
}
};

if let Some(strong_links) = maybe_strong_links {
self.enter_new_round(self.current_round + 1, strong_links)
.await;
}
Ok(())
}

pub fn enter_new_round(&mut self, new_round: Round, strong_links: Vec<NodeCertificate>) {
pub async fn enter_new_round(&mut self, new_round: Round, strong_links: Vec<NodeCertificate>) {
debug!("entering new round {}", new_round);
// TODO: support pulling payload
let payload = Payload::empty(false);
let payload = match self
.payload_client
.pull_payload(
Duration::from_secs(1),
100,
1000,
PayloadFilter::Empty,
Box::pin(async {}),
false,
0,
0.0,
)
.await
{
Ok(payload) => payload,
Err(e) => {
error!("error pulling payload: {}", e);
return;
},
};
// TODO: need to wait to pass median of parents timestamp
let timestamp = self.time_service.now_unix_time();
self.current_round = new_round;
Expand Down Expand Up @@ -171,11 +204,12 @@ impl DagDriver {
}
}

#[async_trait]
impl RpcHandler for DagDriver {
type Request = CertifiedNode;
type Response = CertifiedAck;

fn process(&mut self, node: Self::Request) -> anyhow::Result<Self::Response> {
async fn process(&mut self, node: Self::Request) -> anyhow::Result<Self::Response> {
let epoch = node.metadata().epoch();
{
let dag_reader = self.dag.read();
Expand All @@ -186,6 +220,7 @@ impl RpcHandler for DagDriver {

let node_metadata = node.metadata().clone();
self.add_node(node)
.await
.map(|_| self.order_rule.process_new_node(&node_metadata))?;

Ok(CertifiedAck::new(epoch))
Expand Down
3 changes: 2 additions & 1 deletion consensus/src/dag/dag_fetcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,11 +303,12 @@ impl FetchRequestHandler {
}
}

#[async_trait]
impl RpcHandler for FetchRequestHandler {
type Request = RemoteFetchRequest;
type Response = FetchResponse;

fn process(&mut self, message: Self::Request) -> anyhow::Result<Self::Response> {
async fn process(&mut self, message: Self::Request) -> anyhow::Result<Self::Response> {
let dag_reader = self.dag.read();

// `Certified Node`: In the good case, there should exist at least one honest validator that
Expand Down
86 changes: 54 additions & 32 deletions consensus/src/dag/dag_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::{
use anyhow::bail;
use aptos_channels::aptos_channel;
use aptos_consensus_types::common::Author;
use aptos_logger::{error, warn};
use aptos_logger::{debug, warn};
use aptos_network::protocols::network::RpcError;
use aptos_types::epoch_state::EpochState;
use bytes::Bytes;
Expand Down Expand Up @@ -76,19 +76,42 @@ impl NetworkHandler {
}
},
Some(res) = self.node_fetch_waiter.next() => {
if let Err(e) = res.map_err(|e| anyhow::anyhow!("recv error: {}", e)).and_then(|node| self.node_receiver.process(node)) {
warn!(error = ?e, "error processing node fetch notification");
}
match res {
Ok(node) => if let Err(e) = self.node_receiver.process(node).await {
warn!(error = ?e, "error processing node fetch notification");
},
Err(e) => {
debug!("sender dropped channel: {}", e);
},
};
},
Some(res) = self.certified_node_fetch_waiter.next() => {
if let Err(e) = res.map_err(|e| anyhow::anyhow!("recv error: {}", e)).and_then(|certified_node| self.dag_driver.process(certified_node)) {
warn!(error = ?e, "error processing certified node fetch notification");
}
match res {
Ok(certified_node) => if let Err(e) = self.dag_driver.process(certified_node).await {
warn!(error = ?e, "error processing certified node fetch notification"); },
Err(e) => {
debug!("sender dropped channel: {}", e);
},
};
}
}
}
}

fn verify_incoming_rpc(&self, dag_message: &DAGMessage) -> Result<(), anyhow::Error> {
match dag_message {
DAGMessage::NodeMsg(node) => node.verify(&self.epoch_state.verifier),
DAGMessage::CertifiedNodeMsg(certified_node) => {
certified_node.verify(&self.epoch_state.verifier)
},
DAGMessage::FetchRequest(request) => request.verify(&self.epoch_state.verifier),
_ => Err(anyhow::anyhow!(
"unexpected rpc message{:?}",
std::mem::discriminant(dag_message)
)),
}
}

async fn process_rpc(
&mut self,
rpc_request: IncomingDAGRequest,
Expand All @@ -102,32 +125,31 @@ impl NetworkHandler {
bail!("message author and network author mismatch");
}

let response: anyhow::Result<DAGMessage> = match dag_message {
DAGMessage::NodeMsg(node) => node
.verify(&self.epoch_state.verifier)
.and_then(|_| self.node_receiver.process(node))
.map(|r| r.into()),
DAGMessage::CertifiedNodeMsg(certified_node_msg) => {
match certified_node_msg.verify(&self.epoch_state.verifier) {
Ok(_) => match self.state_sync_trigger.check(certified_node_msg).await {
ret @ (NeedsSync(_), None) => return Ok(ret.0),
(Synced, Some(certified_node_msg)) => self
.dag_driver
.process(certified_node_msg.certified_node())
.map(|r| r.into()),
_ => unreachable!(),
let response: anyhow::Result<DAGMessage> = {
let verification_result = self.verify_incoming_rpc(&dag_message);
match verification_result {
Ok(_) => match dag_message {
DAGMessage::NodeMsg(node) => {
self.node_receiver.process(node).await.map(|r| r.into())
},
Err(e) => Err(e),
}
},
DAGMessage::FetchRequest(request) => request
.verify(&self.epoch_state.verifier)
.and_then(|_| self.fetch_receiver.process(request))
.map(|r| r.into()),
_ => {
error!("unknown rpc message {:?}", dag_message);
Err(anyhow::anyhow!("unknown rpc message"))
},
DAGMessage::CertifiedNodeMsg(certified_node_msg) => {
match self.state_sync_trigger.check(certified_node_msg).await {
ret @ (NeedsSync(_), None) => return Ok(ret.0),
(Synced, Some(certified_node_msg)) => self
.dag_driver
.process(certified_node_msg.certified_node())
.await
.map(|r| r.into()),
_ => unreachable!(),
}
},
DAGMessage::FetchRequest(request) => {
self.fetch_receiver.process(request).await.map(|r| r.into())
},
_ => unreachable!("verification must catch this error"),
},
Err(err) => Err(err),
}
};

let response = response
Expand Down
3 changes: 2 additions & 1 deletion consensus/src/dag/dag_network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ use std::{
time::Duration,
};

#[async_trait]
pub trait RpcHandler {
type Request;
type Response;

fn process(&mut self, message: Self::Request) -> anyhow::Result<Self::Response>;
async fn process(&mut self, message: Self::Request) -> anyhow::Result<Self::Response>;
}

#[async_trait]
Expand Down
3 changes: 1 addition & 2 deletions consensus/src/dag/dag_state_sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ impl StateSyncTrigger {
}

/// This method checks if a state sync is required, and if so,
/// notifies the bootstraper and yields the current task infinitely,
/// to let the bootstraper can abort this task.
/// notifies the bootstraper, to let the bootstraper can abort this task.
pub(super) async fn check(
&self,
node: CertifiedNodeMessage,
Expand Down
4 changes: 3 additions & 1 deletion consensus/src/dag/rb_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use aptos_consensus_types::common::{Author, Round};
use aptos_infallible::RwLock;
use aptos_logger::error;
use aptos_types::{epoch_state::EpochState, validator_signer::ValidatorSigner};
use async_trait::async_trait;
use std::{collections::BTreeMap, mem, sync::Arc};
use thiserror::Error as ThisError;

Expand Down Expand Up @@ -140,11 +141,12 @@ fn read_votes_from_storage(
votes_by_round_peer
}

#[async_trait]
impl RpcHandler for NodeBroadcastHandler {
type Request = Node;
type Response = Vote;

fn process(&mut self, node: Self::Request) -> anyhow::Result<Self::Response> {
async fn process(&mut self, node: Self::Request) -> anyhow::Result<Self::Response> {
let node = self.validate(node)?;

let votes_by_peer = self
Expand Down
6 changes: 3 additions & 3 deletions consensus/src/dag/tests/dag_driver_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,14 @@ async fn test_certified_node_handler() {

let first_round_node = new_certified_node(1, signers[0].author(), vec![]);
// expect an ack for a valid message
assert_ok!(driver.process(first_round_node.clone()));
assert_ok!(driver.process(first_round_node.clone()).await);
// expect an ack if the same message is sent again
assert_ok_eq!(driver.process(first_round_node), CertifiedAck::new(1));
assert_ok_eq!(driver.process(first_round_node).await, CertifiedAck::new(1));

let parent_node = new_certified_node(1, signers[1].author(), vec![]);
let invalid_node = new_certified_node(2, signers[0].author(), vec![parent_node.certificate()]);
assert_eq!(
driver.process(invalid_node).unwrap_err().to_string(),
driver.process(invalid_node).await.unwrap_err().to_string(),
DagDriverError::MissingParents.to_string()
);
}
1 change: 1 addition & 0 deletions consensus/src/dag/tests/dag_state_sync_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ impl TDagFetcher for MockDagFetcher {
) -> anyhow::Result<()> {
let response = FetchRequestHandler::new(self.target_dag.clone(), self.epoch_state.clone())
.process(remote_request)
.await
.unwrap();

let mut new_dag_writer = new_dag.write();
Expand Down
6 changes: 3 additions & 3 deletions consensus/src/dag/tests/fetcher_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ use aptos_types::{epoch_state::EpochState, validator_verifier::random_validator_
use claims::assert_ok_eq;
use std::sync::Arc;

#[test]
fn test_dag_fetcher_receiver() {
#[tokio::test]
async fn test_dag_fetcher_receiver() {
let (signers, validator_verifier) = random_validator_verifier(4, None, false);
let epoch_state = Arc::new(EpochState {
epoch: 1,
Expand Down Expand Up @@ -56,7 +56,7 @@ fn test_dag_fetcher_receiver() {
DagSnapshotBitmask::new(1, vec![vec![true, false]]),
);
assert_ok_eq!(
fetcher.process(request),
fetcher.process(request).await,
FetchResponse::new(1, vec![first_round_nodes[1].clone()])
);
}
Expand Down
Loading

0 comments on commit c147d54

Please sign in to comment.