diff --git a/src/rag/mod.rs b/src/rag/mod.rs index 96fb9d5c..4f65925f 100644 --- a/src/rag/mod.rs +++ b/src/rag/mod.rs @@ -19,7 +19,11 @@ use parking_lot::RwLock; use path_absolutize::Absolutize; use serde::{Deserialize, Serialize}; use serde_json::json; -use std::{collections::HashMap, fmt::Debug, fs, path::Path}; +use std::{collections::HashMap, fmt::Debug, fs, path::Path, time::Duration}; +use tokio::time::sleep; + +const EMBEDDING_RETRY_LIMIT: usize = 3; +const RERANK_RETRY_LIMIT: usize = 2; pub struct Rag { config: GlobalConfig, @@ -483,7 +487,23 @@ impl Rag { } } let data = RerankData::new(query.to_string(), documents, top_k); - let list = client.rerank(&data).await?; + let mut retry = 0; + let list = loop { + retry += 1; + match client.rerank(&data).await { + Ok(result) => break result, + Err(e) if retry < RERANK_RETRY_LIMIT => { + debug!("retry {} failed: {}", retry, e); + sleep(Duration::from_secs(retry as _)).await; + continue; + } + Err(e) => { + return Err(e).with_context(|| { + format!("Failed to rerank after {RERANK_RETRY_LIMIT} attempts") + })? + } + } + }; let ids: Vec<_> = list .into_iter() .take(top_k) @@ -587,10 +607,25 @@ impl Rag { texts: texts.to_vec(), query, }; - let chunk_output = embedding_client - .embeddings(&chunk_data) - .await - .context("Failed to create embedding")?; + let mut retry = 0; + let chunk_output = loop { + retry += 1; + match embedding_client.embeddings(&chunk_data).await { + Ok(v) => break v, + Err(e) if retry < EMBEDDING_RETRY_LIMIT => { + debug!("retry {} failed: {}", retry, e); + sleep(Duration::from_secs(retry as _)).await; + continue; + } + Err(e) => { + return Err(e).with_context(|| { + format!( + "Failed to create embedding after {EMBEDDING_RETRY_LIMIT} attempts" + ) + })? + } + } + }; output.extend(chunk_output); } Ok(output)