diff --git a/src/consensus/dkg.rs b/src/consensus/dkg.rs index 814527db42..fd8886c988 100644 --- a/src/consensus/dkg.rs +++ b/src/consensus/dkg.rs @@ -172,8 +172,9 @@ impl DkgVoter { let _ = self.sessions.insert(dkg_key, session); // Remove uneeded old sessions. - self.sessions - .retain(|old_dkg_key, _| old_dkg_key.generation >= dkg_key.generation); + self.sessions.retain(|existing_dkg_key, _| { + existing_dkg_key.generation >= dkg_key.generation + }); self.backlog.prune(&dkg_key); commands diff --git a/src/network/mod.rs b/src/network/mod.rs index e88e75caf4..5047a6ca40 100644 --- a/src/network/mod.rs +++ b/src/network/mod.rs @@ -111,7 +111,7 @@ impl Network { /// key is not in our `section_chain`. To prove the key is valid, it must be accompanied by an /// additional `key_proof` which signs it using a key that is present in `section_chain`. /// - /// If this is for a non-sibling section, the currently we require the info to be signed by our + /// If this is for a non-sibling section, then currently we require the info to be signed by our /// section (so we need to accumulate the signature for it first) and so `key_proof` is not /// needed in that case. pub fn update_neighbour_info( diff --git a/src/routing/approved.rs b/src/routing/approved.rs index 3de693f03b..30df0cbd9e 100644 --- a/src/routing/approved.rs +++ b/src/routing/approved.rs @@ -1384,11 +1384,12 @@ impl Approved { .ok_or(Error::InvalidSrcLocation)? .peer; + let generation = self.section.chain().main_branch_len() as u64; let elders_info = self .section .promote_and_demote_elders(&self.node.name()) .into_iter() - .find(|elders_info| proofs.verify(elders_info, self.section.chain().len() as u64)); + .find(|elders_info| proofs.verify(elders_info, generation)); let elders_info = if let Some(elders_info) = elders_info { elders_info } else { @@ -1993,7 +1994,8 @@ impl Approved { ) -> Result> { trace!("Send DKGStart for {} to {:?}", elders_info, recipients); - let dkg_key = DkgKey::new(&elders_info, self.section.chain().len() as u64); + let generation = self.section.chain().main_branch_len() as u64; + let dkg_key = DkgKey::new(&elders_info, generation); let variant = Variant::DKGStart { dkg_key, elders_info, diff --git a/src/section/section_chain.rs b/src/section/section_chain.rs index a349e6f5df..cca3cf23e5 100644 --- a/src/section/section_chain.rs +++ b/src/section/section_chain.rs @@ -303,26 +303,7 @@ impl SectionChain { I: IntoIterator, { let trusted_keys: HashSet<_> = trusted_keys.into_iter().collect(); - let mut index = self.tree.len(); - - loop { - let (key, parent_index) = if index > 0 { - let block = &self.tree[index - 1]; - (&block.key, Some(block.parent_index)) - } else { - (&self.root, None) - }; - - if trusted_keys.contains(key) { - return true; - } - - if let Some(next_index) = parent_index { - index = next_index; - } else { - return false; - } - } + self.main_branch().any(|key| trusted_keys.contains(key)) } /// Compare the two keys by their position in the chain. The key that is higher (closer to the @@ -343,6 +324,14 @@ impl SectionChain { 1 + self.tree.len() } + /// Returns the number of block on the main branch of the chain - that is - the ones reachable + /// from the last block. + /// + /// NOTE: this is a `O(n)` operation. + pub fn main_branch_len(&self) -> usize { + self.main_branch().count() + } + fn insert_block(&mut self, new_block: Block) -> usize { // Find the index into `self.tree` to insert the new block at so that the block order as // described in the `SectionChain` doc comment is maintained. @@ -429,6 +418,14 @@ impl SectionChain { max_index -= 1; } } + + // Iterator over the key on the main branch of the chain in reverse order. + fn main_branch(&self) -> Branch { + Branch { + chain: self, + index: Some(self.tree.len()), + } + } } impl Debug for SectionChain { @@ -479,6 +476,29 @@ impl PartialOrd for Block { } } +// Iterator over the keys on a single branch of the chain in reverse order. +struct Branch<'a> { + chain: &'a SectionChain, + index: Option, +} + +impl<'a> Iterator for Branch<'a> { + type Item = &'a bls::PublicKey; + + fn next(&mut self) -> Option { + let index = self.index?; + + if index == 0 { + self.index = None; + Some(&self.chain.root) + } else { + let block = self.chain.tree.get(index - 1)?; + self.index = Some(block.parent_index); + Some(&block.key) + } + } +} + // `SectionChain` is deserialized by first deserializing it into this intermediate structure and // then converting it into `SectionChain` using `try_from` which fails when the chain is invalid. // This makes it impossible to obtain invalid `SectionChain` from malformed serialized data, thus @@ -1071,6 +1091,23 @@ mod tests { assert_eq!(main_chain.cmp_by_position(&pk0, &pk1), Ordering::Less); } + #[test] + fn main_branch_len() { + let (sk0, pk0) = gen_keypair(); + let (_, pk1, sig1) = gen_signed_keypair(&sk0); + let (_, pk2, sig2) = gen_signed_keypair(&sk0); + + // 0->1 + let chain = make_chain(pk0, vec![(&pk0, pk1, sig1.clone())]); + assert_eq!(chain.main_branch_len(), 2); + + // 0->1 + // | + // +->2 + let chain = make_chain(pk0, vec![(&pk0, pk1, sig1), (&pk0, pk2, sig2)]); + assert_eq!(chain.main_branch_len(), 2); + } + fn gen_keypair() -> (bls::SecretKey, bls::PublicKey) { let sk = bls::SecretKey::random(); let pk = sk.public_key();