Skip to content

Commit

Permalink
move around
Browse files Browse the repository at this point in the history
  • Loading branch information
benr-ml committed Jun 16, 2024
1 parent 3cbc700 commit 74d0cde
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 138 deletions.
277 changes: 139 additions & 138 deletions crates/sui-core/src/epoch/randomness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,75 @@ impl VersionedProcessedMessage {
VersionedProcessedMessage::V1(msg) => msg.message.sender,
}
}

pub fn process(
dkg_version: u64,
party: Arc<dkg::Party<PkG, EncG>>,
message: VersionedDkgMessage,
) -> FastCryptoResult<VersionedProcessedMessage> {
match message {
VersionedDkgMessage::V0(msg) => {
if dkg_version != 0 {
panic!("BUG: invalid versioned message")
}
let processed = party.process_message(msg, &mut rand::thread_rng())?;
Ok(VersionedProcessedMessage::V0(processed))
}
VersionedDkgMessage::V1(msg) => {
if dkg_version != 1 {
panic!("BUG: invalid versioned message")
}
let processed = party.process_message_v1(msg, &mut rand::thread_rng())?;
Ok(VersionedProcessedMessage::V1(processed))
}
}
}

pub fn merge(
dkg_version: u64,
party: Arc<dkg::Party<PkG, EncG>>,
messages: Vec<Self>,
) -> FastCryptoResult<(VersionedDkgConfimation, VersionedUsedProcessedMessages)> {
match dkg_version {
0 => {
let (conf, msgs) = party.merge(
&messages
.into_iter()
.map(|vm| {
if let VersionedProcessedMessage::V0(msg) = vm {
msg
} else {
panic!("BUG: invalid versioned message")
}
})
.collect::<Vec<_>>(),
)?;
Ok((
VersionedDkgConfimation::V0(conf),
VersionedUsedProcessedMessages::V0(msgs),
))
}
1 => {
let (conf, msgs) = party.merge_v1(
&messages
.into_iter()
.map(|vm| {
if let VersionedProcessedMessage::V1(msg) = vm {
msg
} else {
panic!("BUG: invalid versioned message")
}
})
.collect::<Vec<_>>(),
)?;
Ok((
VersionedDkgConfimation::V1(conf),
VersionedUsedProcessedMessages::V1(msgs),
))
}
_ => panic!("BUG: invalid DKG version"),
}
}
}

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
Expand Down Expand Up @@ -322,124 +391,6 @@ impl RandomnessManager {
Some(rm)
}

fn create_dkg_message(&self, version: u64) -> FastCryptoResult<VersionedDkgMessage> {
match version {
0 => {
let msg = self.party.create_message(&mut rand::thread_rng())?;
Ok(VersionedDkgMessage::V0(msg))
}
1 => {
let msg = self.party.create_message_v1(&mut rand::thread_rng())?;
Ok(VersionedDkgMessage::V1(msg))
}
_ => panic!("BUG: invalid DKG version"),
}
}

fn merge_dkg_messages(
&self,
dkg_version: u64,
) -> FastCryptoResult<(VersionedDkgConfimation, VersionedUsedProcessedMessages)> {
match dkg_version {
0 => {
let (conf, msgs) = self.party.merge(
&self
.processed_messages
.values()
.map(|vm| {
if let VersionedProcessedMessage::V0(msg) = vm {
msg
} else {
panic!("BUG: invalid versioned message")
}
})
.cloned()
.collect::<Vec<_>>(),
)?;
Ok((
VersionedDkgConfimation::V0(conf),
VersionedUsedProcessedMessages::V0(msgs),
))
}
1 => {
let (conf, msgs) = self.party.merge_v1(
&self
.processed_messages
.values()
.map(|vm| {
if let VersionedProcessedMessage::V1(msg) = vm {
msg
} else {
panic!("BUG: invalid versioned message")
}
})
.cloned()
.collect::<Vec<_>>(),
)?;
Ok((
VersionedDkgConfimation::V1(conf),
VersionedUsedProcessedMessages::V1(msgs),
))
}
_ => panic!("BUG: invalid DKG version"),
}
}

fn complete_dkg(&self, dkg_version: u64) -> FastCryptoResult<Output<PkG, EncG>> {
let rng = &mut StdRng::from_rng(OsRng).expect("RNG construction should not fail");
match dkg_version {
0 => self.party.complete(
if let VersionedUsedProcessedMessages::V0(msg) = self
.used_messages
.get()
.expect("checked above that `used_messages` is initialized")
{
msg
} else {
panic!("BUG: used_messages should be V0")
},
&self
.confirmations
.values()
.map(|vm| {
if let VersionedDkgConfimation::V0(msg) = vm {
msg
} else {
panic!("BUG: invalid versioned message")
}
})
.cloned()
.collect::<Vec<_>>(),
rng,
),
1 => self.party.complete_v1(
if let VersionedUsedProcessedMessages::V1(msg) = self
.used_messages
.get()
.expect("checked above that `used_messages` is initialized")
{
msg
} else {
panic!("BUG: used_messages should be V1")
},
&self
.confirmations
.values()
.map(|vm| {
if let VersionedDkgConfimation::V1(msg) = vm {
msg
} else {
panic!("BUG: invalid versioned message")
}
})
.cloned()
.collect::<Vec<_>>(),
rng,
),
_ => panic!("BUG: invalid DKG version"),
}
}

/// Sends the initial dkg::Message to begin the randomness DKG protocol.
pub fn start_dkg(&mut self) -> SuiResult {
if self.used_messages.initialized() || self.dkg_output.initialized() {
Expand All @@ -450,7 +401,10 @@ impl RandomnessManager {

let epoch_store = self.epoch_store()?;

let msg = match self.create_dkg_message(epoch_store.protocol_config().dkg_version()) {
let msg = match VersionedDkgMessage::create(
epoch_store.protocol_config().dkg_version(),
self.party.clone(),
) {
Ok(msg) => msg,
Err(FastCryptoError::IgnoredMessage) => {
info!(
Expand Down Expand Up @@ -492,10 +446,60 @@ impl RandomnessManager {
Ok(())
}

fn complete_dkg(&self, dkg_version: u64) -> FastCryptoResult<Output<PkG, EncG>> {
let used_processed_messages = self
.used_messages
.get()
.expect("checked above that `used_messages` is initialized");
let confirmations = self.confirmations.values();
let rng = &mut StdRng::from_rng(OsRng).expect("RNG construction should not fail");

match dkg_version {
0 => self.party.complete(
if let VersionedUsedProcessedMessages::V0(msg) = used_processed_messages {
msg
} else {
panic!("BUG: used_messages should be V0")
},
&confirmations
.map(|vm| {
if let VersionedDkgConfimation::V0(msg) = vm {
msg
} else {
panic!("BUG: invalid versioned message")
}
})
.cloned()
.collect::<Vec<_>>(),
rng,
),
1 => self.party.complete_v1(
if let VersionedUsedProcessedMessages::V1(msg) = used_processed_messages {
msg
} else {
panic!("BUG: used_messages should be V1")
},
&confirmations
.map(|vm| {
if let VersionedDkgConfimation::V1(msg) = vm {
msg
} else {
panic!("BUG: invalid versioned message")
}
})
.cloned()
.collect::<Vec<_>>(),
rng,
),
_ => panic!("BUG: invalid DKG version"),
}
}

/// Processes all received messages and advances the randomness DKG state machine when possible,
/// sending out a dkg::Confirmation and generating final output.
pub async fn advance_dkg(&mut self, batch: &mut DBBatch, round: Round) -> SuiResult {
let epoch_store = self.epoch_store()?;
let dkg_version = epoch_store.protocol_config().dkg_version();

// Once we have enough Messages, send a Confirmation.
if !self.dkg_output.initialized() && !self.used_messages.initialized() {
Expand All @@ -515,7 +519,14 @@ impl RandomnessManager {
}

// Attempt to generate the Confirmation.
match self.merge_dkg_messages(self.epoch_store()?.protocol_config().dkg_version()) {
match VersionedProcessedMessage::merge(
self.epoch_store()?.protocol_config().dkg_version(),
self.party.clone(),
self.processed_messages
.values()
.cloned()
.collect::<Vec<_>>(),
) {
Ok((conf, used_msgs)) => {
info!(
"random beacon: sending DKG Confirmation with {} complaints",
Expand Down Expand Up @@ -560,7 +571,7 @@ impl RandomnessManager {

// Once we have enough Confirmations, process them and update shares.
if !self.dkg_output.initialized() && self.used_messages.initialized() {
match self.complete_dkg(self.epoch_store()?.protocol_config().dkg_version()) {
match self.complete_dkg(dkg_version) {
Ok(output) => {
let num_shares = output.shares.as_ref().map_or(0, |shares| shares.len());
let epoch_elapsed = epoch_store.epoch_open_time.elapsed().as_millis();
Expand Down Expand Up @@ -645,26 +656,16 @@ impl RandomnessManager {
}

let party = self.party.clone();
let dkg_version = self.epoch_store()?.protocol_config().dkg_version();
// TODO: Could save some CPU by not processing messages if we already have enough to merge.
self.enqueued_messages.insert(
msg.sender(),
tokio::task::spawn_blocking(move || match msg {
VersionedDkgMessage::V0(msg) => {
match party.process_message(msg, &mut rand::thread_rng()) {
Ok(processed) => Some(VersionedProcessedMessage::V0(processed)),
Err(err) => {
debug!("random beacon: error while processing DKG Message: {err:?}");
None
}
}
}
VersionedDkgMessage::V1(msg) => {
match party.process_message_v1(msg, &mut rand::thread_rng()) {
Ok(processed) => Some(VersionedProcessedMessage::V1(processed)),
Err(err) => {
debug!("random beacon: error while processing DKG Message: {err:?}");
None
}
tokio::task::spawn_blocking(move || {
match VersionedProcessedMessage::process(dkg_version, party, msg) {
Ok(processed) => Some(processed),
Err(err) => {
debug!("random beacon: error while processing DKG Message: {err:?}");
None
}
}
}),
Expand Down
19 changes: 19 additions & 0 deletions crates/sui-types/src/messages_consensus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::messages_checkpoint::{
};
use crate::transaction::CertifiedTransaction;
use byteorder::{BigEndian, ReadBytesExt};
use fastcrypto::error::FastCryptoResult;
use fastcrypto::groups::bls12381;
use fastcrypto_tbls::{dkg, dkg_v0, dkg_v1};
use fastcrypto_zkp::bn254::zk_login::{JwkId, JWK};
Expand All @@ -17,6 +18,7 @@ use serde::{Deserialize, Serialize};
use std::collections::hash_map::DefaultHasher;
use std::fmt::{Debug, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use sui_protocol_config::SupportedProtocolVersions;

Expand Down Expand Up @@ -256,6 +258,23 @@ impl VersionedDkgMessage {
VersionedDkgMessage::V1(msg) => msg.sender,
}
}

pub fn create(
dkg_version: u64,
party: Arc<dkg::Party<bls12381::G2Element, bls12381::G2Element>>,
) -> FastCryptoResult<VersionedDkgMessage> {
match dkg_version {
0 => {
let msg = party.create_message(&mut rand::thread_rng())?;
Ok(VersionedDkgMessage::V0(msg))
}
1 => {
let msg = party.create_message_v1(&mut rand::thread_rng())?;
Ok(VersionedDkgMessage::V1(msg))
}
_ => panic!("BUG: invalid DKG version"),
}
}
}

impl VersionedDkgConfimation {
Expand Down

0 comments on commit 74d0cde

Please sign in to comment.