
Commit

refactor try_join to use tokio instead of Runtime (#31110)
GitOrigin-RevId: d6393400a503d15cb8d10c28be69ba742413b269
ldanilek authored and Convex, Inc. committed Oct 28, 2024
1 parent 9a18c5e commit 151b69a
Showing 29 changed files with 197 additions and 291 deletions.
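The shape of the refactor, as a hedged sketch rather than code from this commit: helpers no longer take an `RT: Runtime` handle just to call `rt.spawn(..)` and `handle.join()`; they spawn onto the ambient tokio runtime instead, through a thin metrics-aware `tokio_spawn` wrapper added in `crates/common/src/runtime/mod.rs` below.

```rust
// Minimal sketch of the new calling pattern, assuming a running tokio runtime and
// the `anyhow` crate; the function name is illustrative, not from the diff.
async fn spawn_and_wait() -> anyhow::Result<i32> {
    // No `RT: Runtime` parameter is threaded through anymore; any async context
    // on the tokio runtime can spawn work directly.
    let handle = tokio::spawn(async { anyhow::Ok(42) });
    // A JoinError converts into anyhow::Error via `?`; the task's own Result is returned.
    handle.await?
}
```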
1 change: 1 addition & 0 deletions Cargo.lock

(Generated file; diff not rendered.)

4 changes: 2 additions & 2 deletions crates/application/src/lib.rs
@@ -826,11 +826,11 @@ impl<RT: Runtime> Application<RT> {
.await
}

-pub fn snapshot(&self, ts: RepeatableTimestamp) -> anyhow::Result<Snapshot<RT>> {
+pub fn snapshot(&self, ts: RepeatableTimestamp) -> anyhow::Result<Snapshot> {
self.database.snapshot(ts)
}

-pub fn latest_snapshot(&self) -> anyhow::Result<Snapshot<RT>> {
+pub fn latest_snapshot(&self) -> anyhow::Result<Snapshot> {
self.database.latest_snapshot()
}

1 change: 1 addition & 0 deletions crates/common/Cargo.toml
@@ -71,6 +71,7 @@ strum = { workspace = true }
sync_types = { package = "convex_sync_types", path = "../convex/sync_types" }
thiserror = { workspace = true }
tokio = { workspace = true }
+tokio-metrics-collector = { workspace = true }
tokio-stream = { workspace = true }
tonic = { workspace = true }
tonic-health = { workspace = true }
61 changes: 52 additions & 9 deletions crates/common/src/runtime/mod.rs
@@ -2,6 +2,7 @@
//! implementations for test, dev, prod, etc.
use std::{
collections::HashMap,
future::Future,
hash::Hash,
num::TryFromIntError,
@@ -10,6 +11,7 @@ use std::{
Sub,
},
pin::Pin,
+sync::LazyLock,
time::{
Duration,
SystemTime,
@@ -40,18 +42,21 @@ use governor::{
},
Quota,
};
+use metrics::CONVEX_METRICS_REGISTRY;
use minitrace::{
collector::SpanContext,
full_name,
future::FutureExt as MinitraceFutureExt,
Span,
};
+use parking_lot::Mutex;
#[cfg(any(test, feature = "testing"))]
use proptest::prelude::*;
use rand::RngCore;
use serde::Serialize;
use thiserror::Error;
use tokio::sync::oneshot;
+use tokio_metrics_collector::TaskMonitor;
use uuid::Uuid;
use value::heap_size::HeapSize;

@@ -110,7 +115,6 @@ pub async fn try_join_buffered<
T: Send + 'static,
C: Default + Send + 'static + Extend<T>,
>(
-rt: &RT,
name: &'static str,
tasks: impl Iterator<Item = impl Future<Output = anyhow::Result<T>> + Send + 'static>
+ Send
@@ -121,7 +125,7 @@
let span = SpanContext::current_local_parent()
.map(|ctx| Span::root(format!("{}::{name}", full_name!()), ctx))
.unwrap_or(Span::noop());
-assert_send(try_join(rt, name, assert_send(task), span))
+assert_send(try_join(name, assert_send(task), span))
}))
.buffered(JOIN_BUFFER_SIZE)
.try_collect(),
@@ -139,11 +143,9 @@ fn assert_send<'a, T>(
}

pub async fn try_join_buffer_unordered<
-RT: Runtime,
T: Send + 'static,
C: Default + Send + 'static + Extend<T>,
>(
-rt: &RT,
name: &'static str,
tasks: impl Iterator<Item = impl Future<Output = anyhow::Result<T>> + Send + 'static>
+ Send
@@ -154,30 +156,29 @@ pub async fn try_join_buffer_unordered<
let span = SpanContext::current_local_parent()
.map(|ctx| Span::root(format!("{}::{name}", full_name!()), ctx))
.unwrap_or(Span::noop());
-try_join(rt, name, task, span)
+try_join(name, task, span)
}))
.buffer_unordered(JOIN_BUFFER_SIZE)
.try_collect(),
)
.await
}

-pub async fn try_join<RT: Runtime, T: Send + 'static>(
-rt: &RT,
+pub async fn try_join<T: Send + 'static>(
name: &'static str,
fut: impl Future<Output = anyhow::Result<T>> + Send + 'static,
span: Span,
) -> anyhow::Result<T> {
let (tx, rx) = oneshot::channel();
-let mut handle = rt.spawn(
+let handle = tokio_spawn(
name,
async {
let result = fut.await;
let _ = tx.send(result);
}
.in_span(span),
);
-handle.join().await?;
+handle.await?;
rx.await?
}

@@ -461,3 +462,45 @@ pub struct TimeoutError {
description: &'static str,
duration: Duration,
}

/// Transitional function while we move away from using our own special
/// `spawn`. Just wraps `tokio::spawn` with our tokio metrics
/// integration.
pub fn tokio_spawn<F>(name: &'static str, f: F) -> tokio::task::JoinHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let monitor = GLOBAL_TASK_MANAGER.lock().get(name);
tokio::spawn(monitor.instrument(f))
}

pub static GLOBAL_TASK_MANAGER: LazyLock<Mutex<TaskManager>> = LazyLock::new(|| {
let task_collector = tokio_metrics_collector::default_task_collector();
CONVEX_METRICS_REGISTRY
.register(Box::new(task_collector))
.unwrap();

let manager = TaskManager {
monitors: HashMap::new(),
};
Mutex::new(manager)
});

pub struct TaskManager {
monitors: HashMap<&'static str, TaskMonitor>,
}

impl TaskManager {
pub fn get(&mut self, name: &'static str) -> TaskMonitor {
if let Some(monitor) = self.monitors.get(name) {
return monitor.clone();
}
let monitor = TaskMonitor::new();
self.monitors.insert(name, monitor.clone());
tokio_metrics_collector::default_task_collector()
.add(name, monitor.clone())
.expect("Duplicate task label?");
monitor
}
}
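A hedged sketch of a call site after this change, assuming the crate is imported as `common` and the full signatures match what the visible hunks above show (the trait bounds the rendering cuts off are assumptions):

```rust
use common::runtime::try_join_buffered;

// Illustrative caller, not part of the commit.
async fn load_all(ids: Vec<u64>) -> anyhow::Result<Vec<String>> {
    // Each task is an ordinary `Send + 'static` future; no Runtime handle is passed in.
    let tasks = ids
        .into_iter()
        .map(|id| async move { anyhow::Ok(format!("doc-{id}")) });
    // Internally each task flows through try_join -> tokio_spawn, so it runs under
    // tokio::spawn and is instrumented with a per-name TaskMonitor registered in
    // GLOBAL_TASK_MANAGER (and exported via CONVEX_METRICS_REGISTRY).
    let docs: Vec<String> = try_join_buffered("load_all", tasks).await?;
    Ok(docs)
}
```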
37 changes: 20 additions & 17 deletions crates/database/src/committer.rs
@@ -150,7 +150,7 @@ pub struct Committer<RT: Runtime> {
// External log of writes for subscriptions.
log: LogWriter,

-snapshot_manager: Writer<SnapshotManager<RT>>,
+snapshot_manager: Writer<SnapshotManager>,
persistence: Arc<dyn Persistence>,
runtime: RT,

@@ -167,12 +167,12 @@ impl<RT: Runtime> Committer<RT> {
impl<RT: Runtime> Committer<RT> {
pub(crate) fn start(
log: LogWriter,
-snapshot_manager: Writer<SnapshotManager<RT>>,
+snapshot_manager: Writer<SnapshotManager>,
persistence: Arc<dyn Persistence>,
runtime: RT,
retention_validator: Arc<dyn RetentionValidator>,
shutdown: ShutdownSignal,
-) -> CommitterClient<RT> {
+) -> CommitterClient {
let persistence_reader = persistence.reader();
let conflict_checker = PendingWrites::new(persistence_reader.version());
let (tx, rx) = mpsc::channel(*COMMITTER_QUEUE_SIZE);
@@ -198,7 +198,7 @@ impl<RT: Runtime> Committer<RT> {
}
}

-async fn go(mut self, mut rx: mpsc::Receiver<CommitterMessage<RT>>) {
+async fn go(mut self, mut rx: mpsc::Receiver<CommitterMessage>) {
let mut last_bumped_repeatable_ts = self.runtime.monotonic_now();
// Assume there were commits just before the backend restarted, so first do a
// quick bump.
@@ -316,7 +316,7 @@ impl<RT: Runtime> Committer<RT> {
text_index_manager,
vector_index_manager,
tables_with_indexes,
-}: &mut BootstrappedSearchIndexes<RT>,
+}: &mut BootstrappedSearchIndexes,
bootstrap_ts: Timestamp,
persistence: RepeatablePersistence,
registry: &IndexRegistry,
@@ -360,7 +360,7 @@

async fn finish_search_and_vector_bootstrap(
&mut self,
-mut bootstrapped_indexes: BootstrappedSearchIndexes<RT>,
+mut bootstrapped_indexes: BootstrappedSearchIndexes,
bootstrap_ts: RepeatableTimestamp,
result: oneshot::Sender<anyhow::Result<()>>,
) {
@@ -851,15 +851,15 @@ struct ValidatedDocumentWrite {
doc_in_vector_index: DocInVectorIndex,
}

-pub struct CommitterClient<RT: Runtime> {
+pub struct CommitterClient {
handle: Arc<Mutex<Box<dyn SpawnHandle>>>,
-sender: mpsc::Sender<CommitterMessage<RT>>,
+sender: mpsc::Sender<CommitterMessage>,
persistence_reader: Arc<dyn PersistenceReader>,
retention_validator: Arc<dyn RetentionValidator>,
-snapshot_reader: Reader<SnapshotManager<RT>>,
+snapshot_reader: Reader<SnapshotManager>,
}

-impl<RT: Runtime> Clone for CommitterClient<RT> {
+impl Clone for CommitterClient {
fn clone(&self) -> Self {
Self {
handle: self.handle.clone(),
@@ -871,10 +871,10 @@ impl<RT: Runtime> Clone for CommitterClient<RT> {
}
}

-impl<RT: Runtime> CommitterClient<RT> {
+impl CommitterClient {
pub async fn finish_search_and_vector_bootstrap(
&self,
-bootstrapped_indexes: BootstrappedSearchIndexes<RT>,
+bootstrapped_indexes: BootstrappedSearchIndexes,
bootstrap_ts: RepeatableTimestamp,
) -> anyhow::Result<()> {
let (tx, rx) = oneshot::channel();
@@ -906,7 +906,7 @@ impl<RT: Runtime> CommitterClient<RT> {
rx.await.map_err(|_| metrics::shutdown_error())?
}

-pub fn commit(
+pub fn commit<RT: Runtime>(
&self,
transaction: Transaction<RT>,
write_source: WriteSource,
@@ -915,7 +915,7 @@
}

#[minitrace::trace]
-async fn _commit(
+async fn _commit<RT: Runtime>(
&self,
transaction: Transaction<RT>,
write_source: WriteSource,
@@ -968,7 +968,10 @@ impl<RT: Runtime> CommitterClient<RT> {
Ok(rx.await?)
}

-async fn check_generated_ids(&self, transaction: &Transaction<RT>) -> anyhow::Result<()> {
+async fn check_generated_ids<RT: Runtime>(
+&self,
+transaction: &Transaction<RT>,
+) -> anyhow::Result<()> {
// Check that none of the DocumentIds generated in this transaction
// are already in use.
// We can check at the begin_timestamp+1 because generated_ids are also
@@ -1020,7 +1023,7 @@ impl<RT: Runtime> CommitterClient<RT> {
}
}

-enum CommitterMessage<RT: Runtime> {
+enum CommitterMessage {
Commit {
queue_timer: Timer<VMHistogram>,
transaction: FinalTransaction,
@@ -1035,7 +1038,7 @@ enum CommitterMessage<RT: Runtime> {
result: oneshot::Sender<anyhow::Result<()>>,
},
FinishTextAndVectorBootstrap {
-bootstrapped_indexes: BootstrappedSearchIndexes<RT>,
+bootstrapped_indexes: BootstrappedSearchIndexes,
bootstrap_ts: RepeatableTimestamp,
result: oneshot::Sender<anyhow::Result<()>>,
},
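One pattern in the committer.rs changes worth calling out: the `RT: Runtime` parameter moves off `CommitterClient` and onto the individual methods that still need it (`commit`, `_commit`, `check_generated_ids`), so non-generic code can own the client. Below is a minimal, self-contained illustration of that pattern, with simplified names and bodies rather than the real Convex types:

```rust
use std::marker::PhantomData;

trait Runtime {}

// Stand-in for the real Transaction<RT>, which stays generic.
struct Transaction<RT: Runtime>(PhantomData<RT>);

// The client itself is no longer generic...
struct CommitterClient;

impl CommitterClient {
    // ...only the call that actually handles a Transaction<RT> carries the type parameter.
    fn commit<RT: Runtime>(&self, _transaction: Transaction<RT>) {
        // hand the work off to the committer task, await the result, etc.
    }
}

fn main() {
    struct Tokio;
    impl Runtime for Tokio {}
    // Non-generic: easy to store in structs that know nothing about RT.
    let client = CommitterClient;
    client.commit(Transaction::<Tokio>(PhantomData));
}
```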
(Diffs for the remaining changed files are not shown.)
