Skip to content

Commit

Permalink
feat: proxy rerank api (#851)
Browse files: browse the repository at this point in the history
  • Loading branch information
sigoden authored Sep 9, 2024
1 parent 6996546 commit e5cc194
Showing 1 changed file with 64 additions and 2 deletions.
66 changes: 64 additions & 2 deletions src/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ pub async fn run(config: GlobalConfig, addr: Option<String>) -> Result<()> {
let stop_server = server.run(listener).await?;
println!("Chat Completions API: http://{addr}/v1/chat/completions");
println!("Embeddings API: http://{addr}/v1/embeddings");
println!("Rerank API: http://{addr}/v1/rerank");
println!("LLM Playground: http://{addr}/playground");
println!("LLM Arena: http://{addr}/arena?num=2");
shutdown_signal().await;
Expand Down Expand Up @@ -158,6 +159,8 @@ impl Server {
self.chat_completions(req).await
} else if path == "/v1/embeddings" {
self.embeddings(req).await
} else if path == "/v1/rerank" {
self.rerank(req).await
} else if path == "/v1/models" {
self.list_models()
} else if path == "/v1/roles" {
Expand Down Expand Up @@ -498,6 +501,57 @@ impl Server {
.body(Full::new(Bytes::from(output.to_string())).boxed())?;
Ok(res)
}

async fn rerank(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
let req_body = req.collect().await?.to_bytes();
let req_body: Value = serde_json::from_slice(&req_body)
.map_err(|err| anyhow!("Invalid request json, {err}"))?;

debug!("rerank request: {req_body}");
let req_body = serde_json::from_value(req_body)
.map_err(|err| anyhow!("Invalid request body, {err}"))?;

let RerankReqBody {
model: reranker_model_id,
documents,
query,
top_n,
} = req_body;

let top_n = top_n.unwrap_or(documents.len());

let config = Arc::new(RwLock::new(self.config.clone()));

let reranker_model = Model::retrieve_embedding(&config.read(), &reranker_model_id)?;

let client = init_client(&config, Some(reranker_model))?;
let data = client
.rerank(RerankData {
query,
documents: documents.clone(),
top_n,
})
.await?;

let results: Vec<_> = data
.into_iter()
.map(|v| {
json!({
"index": v.index,
"relevance_score": v.relevance_score,
"document": documents.get(v.index).map(|v| json!(v)).unwrap_or_default(),
})
})
.collect();
let output = json!({
"id": uuid::Uuid::new_v4().to_string(),
"results": results,
});
let res = Response::builder()
.header("Content-Type", "application/json")
.body(Full::new(Bytes::from(output.to_string())).boxed())?;
Ok(res)
}
}

#[derive(Debug, Deserialize)]
Expand All @@ -520,8 +574,8 @@ struct ChatCompletionsReqBody {

#[derive(Debug, Deserialize)]
struct EmbeddingsReqBody {
pub input: EmbeddingsReqBodyInput,
pub model: String,
input: EmbeddingsReqBodyInput,
model: String,
}

#[derive(Debug, Deserialize)]
Expand All @@ -531,6 +585,14 @@ enum EmbeddingsReqBodyInput {
Multiple(Vec<String>),
}

/// JSON request body accepted by the rerank endpoint.
#[derive(Debug, Deserialize)]
struct RerankReqBody {
/// Candidate texts to be ranked against `query`.
documents: Vec<String>,
/// Text the documents are ranked against.
query: String,
/// Identifier of the model used to perform the reranking.
model: String,
/// Optional result limit; the handler falls back to `documents.len()` when absent.
top_n: Option<usize>,
}

#[derive(Debug)]
enum ResEvent {
First(Option<String>),
Expand Down

0 comments on commit e5cc194

Please sign in to comment.