diff --git a/consensus/src/dag/dag_driver.rs b/consensus/src/dag/dag_driver.rs index 1e55fae2c9018..90b2dc81a33cb 100644 --- a/consensus/src/dag/dag_driver.rs +++ b/consensus/src/dag/dag_driver.rs @@ -55,11 +55,7 @@ impl DagDriver { pub fn add_node(&mut self, node: CertifiedNode) -> anyhow::Result<()> { let mut dag_writer = self.dag.write(); let round = node.metadata().round(); - if dag_writer.all_exists( - node.parents() - .iter() - .map(|certificate| certificate.metadata().digest()), - ) { + if dag_writer.all_exists(node.parents()) { dag_writer.add_node(node)?; if self.current_round == round { let maybe_strong_links = dag_writer diff --git a/consensus/src/dag/dag_fetcher.rs b/consensus/src/dag/dag_fetcher.rs index 394e200258635..5e3b2e1bd5a81 100644 --- a/consensus/src/dag/dag_fetcher.rs +++ b/consensus/src/dag/dag_fetcher.rs @@ -5,7 +5,7 @@ use crate::{ dag::{ dag_network::DAGNetworkSender, dag_store::Dag, - types::{CertifiedNode, DAGMessage, FetchRequest, FetchResponse, Node}, + types::{CertifiedNode, DAGMessage, FetchResponse, Node, RemoteFetchRequest}, }, network::TConsensusMsg, }; @@ -19,36 +19,43 @@ use tokio::sync::{ oneshot, }; -enum FetchCallback { +pub enum LocalFetchRequest { Node(Node, oneshot::Sender), CertifiedNode(CertifiedNode, oneshot::Sender), } -impl FetchCallback { +impl LocalFetchRequest { pub fn responders(&self, validators: &[Author]) -> Vec { match self { - FetchCallback::Node(node, _) => vec![*node.author()], - FetchCallback::CertifiedNode(node, _) => node.certificate().signers(validators), + LocalFetchRequest::Node(node, _) => vec![*node.author()], + LocalFetchRequest::CertifiedNode(node, _) => node.certificate().signers(validators), } } pub fn notify(self) { if match self { - FetchCallback::Node(node, sender) => sender.send(node).map_err(|_| ()), - FetchCallback::CertifiedNode(node, sender) => sender.send(node).map_err(|_| ()), + LocalFetchRequest::Node(node, sender) => sender.send(node).map_err(|_| ()), + LocalFetchRequest::CertifiedNode(node, sender) => sender.send(node).map_err(|_| ()), } .is_err() { error!("Failed to send node back"); } } + + pub fn node(&self) -> &Node { + match self { + LocalFetchRequest::Node(node, _) => node, + LocalFetchRequest::CertifiedNode(node, _) => node, + } + } } struct DagFetcher { epoch_state: Arc, network: Arc, dag: Arc>, - request_rx: Receiver<(FetchRequest, FetchCallback)>, + request_rx: Receiver, } impl DagFetcher { @@ -56,7 +63,7 @@ impl DagFetcher { epoch_state: Arc, network: Arc, dag: Arc>, - ) -> (Self, Sender<(FetchRequest, FetchCallback)>) { + ) -> (Self, Sender) { let (request_tx, request_rx) = tokio::sync::mpsc::channel(16); ( Self { @@ -70,17 +77,29 @@ impl DagFetcher { } pub async fn start(mut self) { - while let Some((request, callback)) = self.request_rx.recv().await { - let responders = - callback.responders(&self.epoch_state.verifier.get_ordered_account_addresses()); - let network_request = DAGMessage::from(request.clone()).into_network_message(); + while let Some(local_request) = self.request_rx.recv().await { + let responders = local_request + .responders(&self.epoch_state.verifier.get_ordered_account_addresses()); + let remote_request = { + let dag_reader = self.dag.read(); + if dag_reader.all_exists(local_request.node().parents()) { + local_request.notify(); + continue; + } + RemoteFetchRequest::new( + local_request.node().metadata().clone(), + dag_reader.lowest_round(), + dag_reader.bitmask(), + ) + }; + let network_request = DAGMessage::from(remote_request.clone()).into_network_message(); if let Ok(response) = self .network .send_rpc_with_fallbacks(responders, network_request, Duration::from_secs(1)) .await .and_then(DAGMessage::try_from) .and_then(FetchResponse::try_from) - .and_then(|response| response.verify(&request, &self.epoch_state.verifier)) + .and_then(|response| response.verify(&remote_request, &self.epoch_state.verifier)) { // TODO: support chunk response or fallback to state sync let mut dag_writer = self.dag.write(); @@ -91,7 +110,7 @@ impl DagFetcher { } } } - callback.notify(); + local_request.notify(); } } } diff --git a/consensus/src/dag/dag_store.rs b/consensus/src/dag/dag_store.rs index b369d5283556e..e749162356825 100644 --- a/consensus/src/dag/dag_store.rs +++ b/consensus/src/dag/dag_store.rs @@ -40,7 +40,7 @@ impl Dag { .unwrap_or(&0) } - fn highest_round(&self) -> Round { + pub fn highest_round(&self) -> Round { *self .nodes_by_round .last_key_value() @@ -79,8 +79,11 @@ impl Dag { self.nodes_by_digest.contains_key(digest) } - pub fn all_exists<'a>(&self, mut digests: impl Iterator) -> bool { - digests.all(|digest| self.nodes_by_digest.contains_key(digest)) + pub fn all_exists(&self, nodes: &[NodeCertificate]) -> bool { + nodes.iter().all(|certificate| { + self.nodes_by_digest + .contains_key(certificate.metadata().digest()) + }) } pub fn get_node(&self, digest: &HashValue) -> Option> { @@ -115,4 +118,9 @@ impl Dag { None } } + + pub fn bitmask(&self) -> Vec> { + // TODO: extract local bitvec + todo!(); + } } diff --git a/consensus/src/dag/types.rs b/consensus/src/dag/types.rs index 37dc2bbcf3bba..fc1ef5e7938d6 100644 --- a/consensus/src/dag/types.rs +++ b/consensus/src/dag/types.rs @@ -386,12 +386,22 @@ impl BroadcastStatus for CertificateAckState { /// the first round we care about in the DAG, `exists_bitmask` is a two dimensional bitmask represents /// if a node exist at [start_round + index][validator_index]. #[derive(Serialize, Deserialize, Clone, Debug)] -pub struct FetchRequest { +pub struct RemoteFetchRequest { target: NodeMetadata, start_round: Round, exists_bitmask: Vec>, } +impl RemoteFetchRequest { + pub fn new(target: NodeMetadata, start_round: Round, exists_bitmask: Vec>) -> Self { + Self { + target, + start_round, + exists_bitmask, + } + } +} + /// Represents a response to FetchRequest, `certified_nodes` are indexed by [round][validator_index] /// It should fill in gaps from the `exists_bitmask` according to the parents from the `target_digest` node. #[derive(Serialize, Deserialize, Clone, Debug)] @@ -407,7 +417,7 @@ impl FetchResponse { pub fn verify( self, - _request: &FetchRequest, + _request: &RemoteFetchRequest, _validator_verifier: &ValidatorVerifier, ) -> anyhow::Result { todo!("verification"); @@ -427,7 +437,7 @@ pub enum DAGMessage { NodeDigestSignatureMsg(NodeDigestSignature), NodeCertificateMsg(NodeCertificate), CertifiedAckMsg(CertifiedAck), - FetchRequest(FetchRequest), + FetchRequest(RemoteFetchRequest), FetchResponse(FetchResponse), #[cfg(test)] @@ -522,7 +532,7 @@ impl TryFrom for CertifiedAck { } } -impl TryFrom for FetchRequest { +impl TryFrom for RemoteFetchRequest { type Error = anyhow::Error; fn try_from(msg: DAGMessage) -> Result { @@ -568,8 +578,8 @@ impl From for DAGMessage { } } -impl From for DAGMessage { - fn from(req: FetchRequest) -> Self { +impl From for DAGMessage { + fn from(req: RemoteFetchRequest) -> Self { Self::FetchRequest(req) } }