diff --git a/core/store/src/trie/prefetching_trie_storage.rs b/core/store/src/trie/prefetching_trie_storage.rs
index b00aff22bee..6f03b2f7642 100644
--- a/core/store/src/trie/prefetching_trie_storage.rs
+++ b/core/store/src/trie/prefetching_trie_storage.rs
@@ -11,6 +11,7 @@ use near_primitives::trie_key::TrieKey;
 use near_primitives::types::{AccountId, ShardId, StateRoot, TrieNodesCount};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
+use tracing::error;
 
 const MAX_QUEUED_WORK_ITEMS: usize = 16 * 1024;
 const MAX_PREFETCH_STAGING_MEMORY: usize = 200 * 1024 * 1024;
@@ -62,8 +63,7 @@ pub struct PrefetchApi {
     /// changing the queue to an enum.
     /// The state root is also included because multiple chunks could be applied
     /// at the same time.
-    work_queue_tx: crossbeam::channel::Sender<(StateRoot, TrieKey)>,
-    work_queue_rx: crossbeam::channel::Receiver<(StateRoot, TrieKey)>,
+    work_queue: WorkQueue,
     /// Prefetching IO threads will insert fetched data here. This is also used
     /// to mark what is already being fetched, to avoid fetching the same data
     /// multiple times.
@@ -383,24 +383,33 @@ impl PrefetchApi {
         shard_uid: ShardUId,
         trie_config: &TrieConfig,
     ) -> Self {
-        let (work_queue_tx, work_queue_rx) = crossbeam::channel::bounded(MAX_QUEUED_WORK_ITEMS);
+        let (tx, rx) = crossbeam::channel::bounded(MAX_QUEUED_WORK_ITEMS);
         let sweat_prefetch_receivers = trie_config.sweat_prefetch_receivers.clone();
         let sweat_prefetch_senders = trie_config.sweat_prefetch_senders.clone();
         let enable_receipt_prefetching = trie_config.enable_receipt_prefetching;
-
-        let this = Self {
-            work_queue_tx,
-            work_queue_rx,
-            prefetching: PrefetchStagingArea::new(shard_uid.shard_id()),
+        let prefetching = PrefetchStagingArea::new(shard_uid.shard_id());
+
+        let handles = (0..NUM_IO_THREADS)
+            .map(|_| {
+                Self::start_io_thread(
+                    rx.clone(),
+                    prefetching.clone(),
+                    store.clone(),
+                    shard_cache.clone(),
+                    shard_uid.clone(),
+                )
+            })
+            .collect();
+        // Do not clone tx before this point, or `WorkQueue` invariant is broken.
+        let work_queue = WorkQueue { rx, tx, _handles: Arc::new(JoinGuard(handles)) };
+        Self {
+            work_queue,
+            prefetching,
             enable_receipt_prefetching,
             sweat_prefetch_receivers,
             sweat_prefetch_senders,
             shard_uid,
-        };
-        for _ in 0..NUM_IO_THREADS {
-            this.start_io_thread(store.clone(), shard_cache.clone(), shard_uid.clone());
         }
-        this
     }
 
     /// Returns the argument back if queue is full.
@@ -409,18 +418,18 @@ impl PrefetchApi {
         root: StateRoot,
         trie_key: TrieKey,
     ) -> Result<(), (StateRoot, TrieKey)> {
-        self.work_queue_tx.send((root, trie_key)).map_err(|e| e.0)
+        self.work_queue.tx.send((root, trie_key)).map_err(|e| e.0)
     }
 
-    pub fn start_io_thread(
-        &self,
+    fn start_io_thread(
+        work_queue: crossbeam::channel::Receiver<(StateRoot, TrieKey)>,
+        prefetching: PrefetchStagingArea,
         store: Store,
         shard_cache: TrieCache,
         shard_uid: ShardUId,
     ) -> std::thread::JoinHandle<()> {
         let prefetcher_storage =
-            TriePrefetchingStorage::new(store, shard_uid, shard_cache, self.prefetching.clone());
-        let work_queue = self.work_queue_rx.clone();
+            TriePrefetchingStorage::new(store, shard_uid, shard_cache, prefetching);
         let metric_prefetch_sent =
             metrics::PREFETCH_SENT.with_label_values(&[&shard_uid.shard_id.to_string()]);
         let metric_prefetch_fail =
@@ -451,7 +460,7 @@ impl PrefetchApi {
     /// Queued up work will not be finished. But trie keys that are already
     /// being fetched will finish.
     pub fn clear_queue(&self) {
-        while let Ok(_dropped) = self.work_queue_rx.try_recv() {}
+        while let Ok(_dropped) = self.work_queue.rx.try_recv() {}
     }
 
     /// Clear prefetched staging area from data that has not been picked up by the main thread.
@@ -460,6 +469,47 @@ impl PrefetchApi {
     }
 }
 
+/// Bounded, shared queue for all IO threads to take work from.
+///
+/// Work items are defined as `TrieKey` because currently the only
+/// work is to prefetch a trie key. If other IO work is added, consider
+/// changing the queue to an enum.
+/// The state root is also included because multiple chunks could be applied
+/// at the same time.
+#[derive(Clone)]
+struct WorkQueue {
+    /// The channel to the IO prefetch work queue.
+    rx: crossbeam::channel::Receiver<(StateRoot, TrieKey)>,
+    tx: crossbeam::channel::Sender<(StateRoot, TrieKey)>,
+    /// Thread handles for threads sitting behind the channel.
+    ///
+    /// Invariant: The number of existing clones of `tx` is equal to
+    /// the reference count of the join handles.
+    ///
+    /// The invariant holds because when `WorkQueue` is created there is no
+    /// clone of `tx`, yet. And afterwards the only `tx` clones are created through
+    /// `WorkQueue::clone()`, which also increases the handles reference count.
+    ///
+    /// When the last reference to `handles` is dropped, the handles
+    /// are joined, which will terminate because the last `tx` has
+    /// already been dropped (field order matters!) and therefore the crossbeam
+    /// channel has been closed.
+    _handles: Arc<JoinGuard>,
+}
+
+/// Only exists to implement `Drop`.
+struct JoinGuard(Vec<std::thread::JoinHandle<()>>);
+
+impl Drop for JoinGuard {
+    fn drop(&mut self) {
+        for handle in self.0.drain(..) {
+            if let Err(e) = handle.join() {
+                error!("Failed to join background thread: {e:?}")
+            }
+        }
+    }
+}
+
 fn prefetch_state_matches(expected: PrefetchSlot, actual: &PrefetchSlot) -> bool {
     match (expected, actual) {
         (PrefetchSlot::PendingPrefetch, PrefetchSlot::PendingPrefetch)
@@ -498,7 +548,7 @@ mod tests {
     }
 
     pub fn work_queued(&self) -> bool {
-        !self.work_queue_rx.is_empty()
+        !self.work_queue.rx.is_empty()
     }
 }
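The doc comments on `_handles` describe the shutdown contract this diff introduces: because struct fields drop in declaration order, the last `tx` is dropped before the `Arc<JoinGuard>`, the crossbeam channel disconnects, the IO threads fall out of their receive loops, and only then does `JoinGuard::drop` join them. Below is a minimal, self-contained sketch of that same pattern, assuming only the crossbeam crate; the `Pool` type, the `u64` work item, and the queue capacity are illustrative stand-ins, not nearcore APIs.

use std::sync::Arc;
use std::thread::JoinHandle;

struct Pool {
    // Field order matters: `tx` is dropped before `_handles`, so the channel
    // disconnects and the workers' `recv()` loops end before joining starts.
    tx: crossbeam::channel::Sender<u64>,
    _handles: Arc<JoinGuard>,
}

/// Only exists to implement `Drop`, mirroring the JoinGuard in the diff.
struct JoinGuard(Vec<JoinHandle<()>>);

impl Drop for JoinGuard {
    fn drop(&mut self) {
        for handle in self.0.drain(..) {
            // Joining only blocks until each worker observes the closed channel.
            let _ = handle.join();
        }
    }
}

impl Pool {
    fn new(num_threads: usize) -> Self {
        let (tx, rx) = crossbeam::channel::bounded::<u64>(1024);
        let handles = (0..num_threads)
            .map(|_| {
                let rx = rx.clone();
                std::thread::spawn(move || {
                    // `recv()` keeps draining buffered items and returns Err
                    // only once all senders are gone and the queue is empty.
                    while let Ok(item) = rx.recv() {
                        println!("processed {item}");
                    }
                })
            })
            .collect();
        // No clone of `tx` exists yet, mirroring the invariant in the diff:
        // every later clone of `tx` comes with a clone of the Arc'd guard.
        Pool { tx, _handles: Arc::new(JoinGuard(handles)) }
    }
}

fn main() {
    let pool = Pool::new(4);
    for i in 0..10 {
        pool.tx.send(i).unwrap();
    }
    // Dropping `pool` here drops `tx` first, closing the channel; the
    // JoinGuard then joins the workers, so no thread is leaked on exit.
}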