Skip to content

Commit

Permalink
refactor(p2p_network): use tokio::mpsc instead of flume
Browse files Browse the repository at this point in the history
  • Loading branch information
CHr15F0x committed Feb 19, 2024
1 parent c55e120 commit c2932e3
Showing 1 changed file with 37 additions and 38 deletions.
75 changes: 37 additions & 38 deletions crates/pathfinder/src/p2p_network/sync_handlers.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use anyhow::Context;
use futures::channel::mpsc;
use futures::{SinkExt, StreamExt};
use futures::SinkExt;
use p2p_proto::class::{Class, ClassesRequest, ClassesResponse};
use p2p_proto::common::{
Address, BlockNumberOrHash, ConsensusSignature, Direction, Hash, Iteration, Merkle, Patricia,
Expand All @@ -16,6 +15,7 @@ use pathfinder_crypto::Felt;
use pathfinder_storage::Storage;
use pathfinder_storage::Transaction;
use starknet_gateway_types::class_definition;
use tokio::sync::mpsc;

pub mod conv;
#[cfg(test)]
Expand All @@ -34,47 +34,47 @@ const MAX_BLOCKS_COUNT: u64 = MAX_COUNT_IN_TESTS;
pub async fn get_headers(
storage: Storage,
request: BlockHeadersRequest,
tx: mpsc::Sender<BlockHeadersResponse>,
tx: futures::channel::mpsc::Sender<BlockHeadersResponse>,
) -> anyhow::Result<()> {
spawn_blocking_get(request, storage, blocking::get_headers, tx).await
}

pub async fn get_classes(
storage: Storage,
request: ClassesRequest,
tx: mpsc::Sender<ClassesResponse>,
tx: futures::channel::mpsc::Sender<ClassesResponse>,
) -> anyhow::Result<()> {
spawn_blocking_get(request, storage, blocking::get_classes, tx).await
}

pub async fn get_state_diffs(
storage: Storage,
request: StateDiffsRequest,
tx: mpsc::Sender<StateDiffsResponse>,
tx: futures::channel::mpsc::Sender<StateDiffsResponse>,
) -> anyhow::Result<()> {
spawn_blocking_get(request, storage, blocking::get_state_diffs, tx).await
}

pub async fn get_transactions(
storage: Storage,
request: TransactionsRequest,
tx: mpsc::Sender<TransactionsResponse>,
tx: futures::channel::mpsc::Sender<TransactionsResponse>,
) -> anyhow::Result<()> {
spawn_blocking_get(request, storage, blocking::get_transactions, tx).await
}

pub async fn get_receipts(
storage: Storage,
request: ReceiptsRequest,
tx: mpsc::Sender<ReceiptsResponse>,
tx: futures::channel::mpsc::Sender<ReceiptsResponse>,
) -> anyhow::Result<()> {
spawn_blocking_get(request, storage, blocking::get_receipts, tx).await
}

pub async fn get_events(
storage: Storage,
request: EventsRequest,
tx: mpsc::Sender<EventsResponse>,
tx: futures::channel::mpsc::Sender<EventsResponse>,
) -> anyhow::Result<()> {
spawn_blocking_get(request, storage, blocking::get_events, tx).await
}
Expand All @@ -85,47 +85,47 @@ pub(crate) mod blocking {
pub(crate) fn get_headers(
db_tx: Transaction<'_>,
request: BlockHeadersRequest,
tx: flume::Sender<BlockHeadersResponse>,
tx: mpsc::Sender<BlockHeadersResponse>,
) -> anyhow::Result<()> {
iterate(db_tx, request.iteration, get_header, tx)
}

pub(crate) fn get_classes(
db_tx: Transaction<'_>,
request: ClassesRequest,
tx: flume::Sender<ClassesResponse>,
tx: mpsc::Sender<ClassesResponse>,
) -> anyhow::Result<()> {
iterate(db_tx, request.iteration, get_classes_for_block, tx)
}

pub(crate) fn get_state_diffs(
db_tx: Transaction<'_>,
request: StateDiffsRequest,
tx: flume::Sender<StateDiffsResponse>,
tx: mpsc::Sender<StateDiffsResponse>,
) -> anyhow::Result<()> {
iterate(db_tx, request.iteration, get_state_diff, tx)
}

pub(crate) fn get_transactions(
db_tx: Transaction<'_>,
request: TransactionsRequest,
tx: flume::Sender<TransactionsResponse>,
tx: mpsc::Sender<TransactionsResponse>,
) -> anyhow::Result<()> {
iterate(db_tx, request.iteration, get_transactions_for_block, tx)
}

pub(crate) fn get_receipts(
db_tx: Transaction<'_>,
request: ReceiptsRequest,
tx: flume::Sender<ReceiptsResponse>,
tx: mpsc::Sender<ReceiptsResponse>,
) -> anyhow::Result<()> {
iterate(db_tx, request.iteration, get_receipts_for_block, tx)
}

pub(crate) fn get_events(
db_tx: Transaction<'_>,
request: EventsRequest,
tx: flume::Sender<EventsResponse>,
tx: mpsc::Sender<EventsResponse>,
) -> anyhow::Result<()> {
iterate(db_tx, request.iteration, get_events_for_block, tx)
}
Expand All @@ -134,7 +134,7 @@ pub(crate) mod blocking {
fn get_header(
db_tx: &Transaction<'_>,
block_number: BlockNumber,
tx: &flume::Sender<BlockHeadersResponse>,
tx: &mpsc::Sender<BlockHeadersResponse>,
) -> anyhow::Result<bool> {
if let Some(header) = db_tx.block_header(block_number.into())? {
if let Some(signature) = db_tx.signature(block_number.into())? {
Expand All @@ -143,7 +143,7 @@ fn get_header(
.try_into()
.context("invalid transaction count")?;

tx.send(BlockHeadersResponse::Header(Box::new(SignedBlockHeader {
tx.blocking_send(BlockHeadersResponse::Header(Box::new(SignedBlockHeader {
block_hash: Hash(header.hash.0),
parent_hash: Hash(header.parent_hash.0),
number: header.number.get(),
Expand Down Expand Up @@ -198,7 +198,7 @@ enum ClassDefinition {
fn get_classes_for_block(
db_tx: &Transaction<'_>,
block_number: BlockNumber,
tx: &flume::Sender<ClassesResponse>,
tx: &mpsc::Sender<ClassesResponse>,
) -> anyhow::Result<bool> {
let get_definition =
|block_number: BlockNumber, class_hash| -> anyhow::Result<ClassDefinition> {
Expand Down Expand Up @@ -249,7 +249,7 @@ fn get_classes_for_block(
}
};

tx.send(ClassesResponse::Class(class))
tx.blocking_send(ClassesResponse::Class(class))
.map_err(|_| anyhow::anyhow!("Sending class"))?;
}

Expand All @@ -259,14 +259,14 @@ fn get_classes_for_block(
fn get_state_diff(
db_tx: &Transaction<'_>,
block_number: BlockNumber,
tx: &flume::Sender<StateDiffsResponse>,
tx: &mpsc::Sender<StateDiffsResponse>,
) -> anyhow::Result<bool> {
let Some(state_diff) = db_tx.state_update(block_number.into())? else {
return Ok(false);
};

for (address, update) in state_diff.contract_updates {
tx.send(StateDiffsResponse::ContractDiff(ContractDiff {
tx.blocking_send(StateDiffsResponse::ContractDiff(ContractDiff {
address: Address(address.0),
nonce: update.nonce.map(|n| n.0),
class_hash: update.class.as_ref().map(|c| c.class_hash().0),
Expand All @@ -285,7 +285,7 @@ fn get_state_diff(
}

for (address, update) in state_diff.system_contract_updates {
tx.send(StateDiffsResponse::ContractDiff(ContractDiff {
tx.blocking_send(StateDiffsResponse::ContractDiff(ContractDiff {
address: Address(address.0),
nonce: None,
class_hash: None,
Expand All @@ -309,14 +309,14 @@ fn get_state_diff(
fn get_transactions_for_block(
db_tx: &Transaction<'_>,
block_number: BlockNumber,
tx: &flume::Sender<TransactionsResponse>,
tx: &mpsc::Sender<TransactionsResponse>,
) -> anyhow::Result<bool> {
let Some(txn_data) = db_tx.transaction_data_for_block(block_number.into())? else {
return Ok(false);
};

for (txn, _) in txn_data {
tx.send(TransactionsResponse::Transaction(txn.to_dto()))
tx.blocking_send(TransactionsResponse::Transaction(txn.to_dto()))
.map_err(|_| anyhow::anyhow!("Sending transaction"))?;
}

Expand All @@ -326,14 +326,14 @@ fn get_transactions_for_block(
fn get_receipts_for_block(
db_tx: &Transaction<'_>,
block_number: BlockNumber,
tx: &flume::Sender<ReceiptsResponse>,
tx: &mpsc::Sender<ReceiptsResponse>,
) -> anyhow::Result<bool> {
let Some(txn_data) = db_tx.transaction_data_for_block(block_number.into())? else {
return Ok(false);
};

for tr in txn_data {
tx.send(ReceiptsResponse::Receipt(tr.to_dto()))
tx.blocking_send(ReceiptsResponse::Receipt(tr.to_dto()))
.map_err(|_| anyhow::anyhow!("Sending receipt"))?;
}

Expand All @@ -343,15 +343,15 @@ fn get_receipts_for_block(
fn get_events_for_block(
db_tx: &Transaction<'_>,
block_number: BlockNumber,
tx: &flume::Sender<EventsResponse>,
tx: &mpsc::Sender<EventsResponse>,
) -> anyhow::Result<bool> {
let Some(txn_data) = db_tx.transaction_data_for_block(block_number.into())? else {
return Ok(false);
};

for (_, r) in txn_data {
for event in r.events {
tx.send(EventsResponse::Event((r.transaction_hash, event).to_dto()))
tx.blocking_send(EventsResponse::Event((r.transaction_hash, event).to_dto()))
.map_err(|_| anyhow::anyhow!("Sending event"))?;
}
}
Expand All @@ -365,8 +365,8 @@ fn get_events_for_block(
fn iterate<T: Default + std::fmt::Debug>(
db_tx: Transaction<'_>,
iteration: Iteration,
block_handler: impl Fn(&Transaction<'_>, BlockNumber, &flume::Sender<T>) -> anyhow::Result<bool>,
tx: flume::Sender<T>,
block_handler: impl Fn(&Transaction<'_>, BlockNumber, &mpsc::Sender<T>) -> anyhow::Result<bool>,
tx: mpsc::Sender<T>,
) -> anyhow::Result<()> {
let Iteration {
start,
Expand All @@ -376,15 +376,15 @@ fn iterate<T: Default + std::fmt::Debug>(
} = iteration;

if limit == 0 {
tx.send(T::default())
tx.blocking_send(T::default())
.map_err(|_| anyhow::anyhow!("Sending Fin"))?;
return Ok(());
}

let mut block_number = match get_start_block_number(start, &db_tx)? {
Some(x) => x,
None => {
tx.send(T::default())
tx.blocking_send(T::default())
.map_err(|_| anyhow::anyhow!("Sending Fin"))?;
return Ok(());
}
Expand All @@ -409,7 +409,7 @@ fn iterate<T: Default + std::fmt::Debug>(
}
}

tx.send(T::default())
tx.blocking_send(T::default())
.map_err(|_| anyhow::anyhow!("Sending Fin"))?;

Ok(())
Expand All @@ -427,24 +427,24 @@ fn get_start_block_number(

/// Spawns a blocking task and forwards the result to the given channel.
/// Bails out early if the database operation fails or sending fails.
/// The `getter` function is expected to send partial results through the flume channel as soon as possible,
/// The `getter` function is expected to send partial results through the tokio channel as soon as possible,
/// ideally after each database read operation.
async fn spawn_blocking_get<Request, Response, Getter>(
request: Request,
storage: Storage,
getter: Getter,
mut tx: mpsc::Sender<Response>,
mut tx: futures::channel::mpsc::Sender<Response>,
) -> anyhow::Result<()>
where
Request: Send + 'static,
Response: Send + 'static,
Getter: FnOnce(Transaction<'_>, Request, flume::Sender<Response>) -> anyhow::Result<()>
Getter: FnOnce(Transaction<'_>, Request, mpsc::Sender<Response>) -> anyhow::Result<()>
+ Send
+ 'static,
{
let span = tracing::Span::current();

let (sync_tx, rx) = flume::bounded(0); // For backpressure
let (sync_tx, mut rx) = mpsc::channel(1); // For backpressure

let db_fut = async {
tokio::task::spawn_blocking(move || {
Expand All @@ -463,8 +463,7 @@ where
};

let fwd_fut = async move {
let mut rx = rx.into_stream();
while let Some(x) = rx.next().await {
while let Some(x) = rx.recv().await {
tx.send(x).await.context("Sending item")?;
}
Ok::<_, anyhow::Error>(())
Expand Down

0 comments on commit c2932e3

Please sign in to comment.