From 0e740d81e94505bd57036755abaaecb12c3b26e3 Mon Sep 17 00:00:00 2001 From: sigoden Date: Sun, 28 Jul 2024 06:04:36 +0800 Subject: [PATCH] feat: abandon rag_dedicated client and improve (#757) --- config.example.yaml | 4 +- src/client/azure_openai.rs | 73 ++++++----- src/client/bedrock.rs | 3 +- src/client/claude.rs | 44 ++++--- src/client/cloudflare.rs | 81 ++++++------ src/client/cohere.rs | 90 ++++++++------ src/client/common.rs | 77 +++++++----- src/client/ernie.rs | 194 ++++++++++++++--------------- src/client/gemini.rs | 83 +++++++------ src/client/macros.rs | 94 +++----------- src/client/mod.rs | 15 +-- src/client/ollama.rs | 75 +++++++----- src/client/openai.rs | 80 +++++++----- src/client/openai_compatible.rs | 176 ++++++++++++++++---------- src/client/qianwen.rs | 118 +++++++++--------- src/client/rag_dedicated.rs | 150 ----------------------- src/client/replicate.rs | 53 ++++---- src/client/vertexai.rs | 211 +++++++++++++++++--------------- 18 files changed, 776 insertions(+), 845 deletions(-) delete mode 100644 src/client/rag_dedicated.rs diff --git a/config.example.yaml b/config.example.yaml index 803ace79..29966a01 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -292,13 +292,13 @@ clients: api_key: xxx # ENV: {client}_API_KEY # See https://jina.ai - - type: rag-dedicated + - type: openai-compatible name: jina api_base: https://api.jina.ai/v1 api_key: xxx # ENV: {client}_API_KEY # See https://docs.voyageai.com/docs/introduction - - type: rag-dedicated + - type: openai-compatible name: voyageai api_base: https://api.voyageai.ai/v1 api_key: xxx # ENV: {client}_API_KEY \ No newline at end of file diff --git a/src/client/azure_openai.rs b/src/client/azure_openai.rs index 8b583e03..de6bb40b 100644 --- a/src/client/azure_openai.rs +++ b/src/client/azure_openai.rs @@ -30,49 +30,56 @@ impl AzureOpenAIClient { PromptKind::Integer, ), ]; +} + +impl_client_trait!( + AzureOpenAIClient, + ( + prepare_chat_completions, + openai_chat_completions, + openai_chat_completions_streaming + ), + (prepare_embeddings, openai_embeddings), + (noop_prepare_rerank, noop_rerank), +); - fn prepare_chat_completions(&self, data: ChatCompletionsData) -> Result { - let api_base = self.get_api_base()?; - let api_key = self.get_api_key()?; +fn prepare_chat_completions( + self_: &AzureOpenAIClient, + data: ChatCompletionsData, +) -> Result { + let api_base = self_.get_api_base()?; + let api_key = self_.get_api_key()?; - let url = format!( - "{}/openai/deployments/{}/chat/completions?api-version=2024-02-01", - &api_base, - self.model.name() - ); + let url = format!( + "{}/openai/deployments/{}/chat/completions?api-version=2024-02-01", + &api_base, + self_.model.name() + ); - let body = openai_build_chat_completions_body(data, &self.model); + let body = openai_build_chat_completions_body(data, &self_.model); - let mut request_data = RequestData::new(url, body); + let mut request_data = RequestData::new(url, body); - request_data.header("api-key", api_key); + request_data.header("api-key", api_key); - Ok(request_data) - } + Ok(request_data) +} - fn prepare_embeddings(&self, data: EmbeddingsData) -> Result { - let api_base = self.get_api_base()?; - let api_key = self.get_api_key()?; +fn prepare_embeddings(self_: &AzureOpenAIClient, data: EmbeddingsData) -> Result { + let api_base = self_.get_api_base()?; + let api_key = self_.get_api_key()?; - let url = format!( - "{}/openai/deployments/{}/embeddings?api-version=2024-02-01", - &api_base, - self.model.name() - ); + let url = format!( + "{}/openai/deployments/{}/embeddings?api-version=2024-02-01", + &api_base, + self_.model.name() + ); - let body = openai_build_embeddings_body(data, &self.model); + let body = openai_build_embeddings_body(data, &self_.model); - let mut request_data = RequestData::new(url, body); + let mut request_data = RequestData::new(url, body); - request_data.header("api-key", api_key); + request_data.header("api-key", api_key); - Ok(request_data) - } + Ok(request_data) } - -impl_client_trait!( - AzureOpenAIClient, - openai_chat_completions, - openai_chat_completions_streaming, - openai_embeddings -); diff --git a/src/client/bedrock.rs b/src/client/bedrock.rs index 2b1ca6ed..25c666c1 100644 --- a/src/client/bedrock.rs +++ b/src/client/bedrock.rs @@ -3,7 +3,6 @@ use super::*; use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256}; use anyhow::{bail, Context, Result}; -use async_trait::async_trait; use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder}; use aws_smithy_eventstream::smithy::parse_response_headers; use bytes::BytesMut; @@ -148,7 +147,7 @@ impl BedrockClient { } } -#[async_trait] +#[async_trait::async_trait] impl Client for BedrockClient { client_common_fns!(); diff --git a/src/client/claude.rs b/src/client/claude.rs index 8a7b14c4..d8046a9a 100644 --- a/src/client/claude.rs +++ b/src/client/claude.rs @@ -22,30 +22,41 @@ impl ClaudeClient { pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key:", true, PromptKind::String)]; +} - fn prepare_chat_completions(&self, data: ChatCompletionsData) -> Result { - let api_key = self.get_api_key().ok(); +impl_client_trait!( + ClaudeClient, + ( + prepare_chat_completions, + claude_chat_completions, + claude_chat_completions_streaming + ), + (noop_prepare_embeddings, noop_embeddings), + (noop_prepare_rerank, noop_rerank), +); - let body = claude_build_chat_completions_body(data, &self.model)?; +fn prepare_chat_completions( + self_: &ClaudeClient, + data: ChatCompletionsData, +) -> Result { + let api_key = self_.get_api_key().ok(); - let mut request_data = RequestData::new(API_BASE, body); + let body = claude_build_chat_completions_body(data, &self_.model)?; - request_data.header("anthropic-version", "2023-06-01"); - if let Some(api_key) = api_key { - request_data.header("x-api-key", api_key) - } + let mut request_data = RequestData::new(API_BASE, body); - Ok(request_data) + request_data.header("anthropic-version", "2023-06-01"); + if let Some(api_key) = api_key { + request_data.header("x-api-key", api_key) } -} -impl_client_trait!( - ClaudeClient, - claude_chat_completions, - claude_chat_completions_streaming -); + Ok(request_data) +} -pub async fn claude_chat_completions(builder: RequestBuilder) -> Result { +pub async fn claude_chat_completions( + builder: RequestBuilder, + _model: &Model, +) -> Result { let res = builder.send().await?; let status = res.status(); let data: Value = res.json().await?; @@ -59,6 +70,7 @@ pub async fn claude_chat_completions(builder: RequestBuilder) -> Result Result<()> { let mut function_name = String::new(); let mut function_arguments = String::new(); diff --git a/src/client/cloudflare.rs b/src/client/cloudflare.rs index 3ee0a91c..a24a1c67 100644 --- a/src/client/cloudflare.rs +++ b/src/client/cloudflare.rs @@ -26,54 +26,64 @@ impl CloudflareClient { ("account_id", "Account ID:", true, PromptKind::String), ("api_key", "API Key:", true, PromptKind::String), ]; +} - fn prepare_chat_completions(&self, data: ChatCompletionsData) -> Result { - let account_id = self.get_account_id()?; - let api_key = self.get_api_key()?; +impl_client_trait!( + CloudflareClient, + ( + prepare_chat_completions, + chat_completions, + chat_completions_streaming + ), + (prepare_embeddings, embeddings), + (noop_prepare_rerank, noop_rerank), +); - let url = format!( - "{API_BASE}/accounts/{account_id}/ai/run/{}", - self.model.name() - ); +fn prepare_chat_completions( + self_: &CloudflareClient, + data: ChatCompletionsData, +) -> Result { + let account_id = self_.get_account_id()?; + let api_key = self_.get_api_key()?; - let body = build_chat_completions_body(data, &self.model)?; + let url = format!( + "{API_BASE}/accounts/{account_id}/ai/run/{}", + self_.model.name() + ); - let mut request_data = RequestData::new(url, body); + let body = build_chat_completions_body(data, &self_.model)?; - request_data.bearer_auth(api_key); + let mut request_data = RequestData::new(url, body); - Ok(request_data) - } + request_data.bearer_auth(api_key); + + Ok(request_data) +} - fn prepare_embeddings(&self, data: EmbeddingsData) -> Result { - let account_id = self.get_account_id()?; - let api_key = self.get_api_key()?; +fn prepare_embeddings(self_: &CloudflareClient, data: EmbeddingsData) -> Result { + let account_id = self_.get_account_id()?; + let api_key = self_.get_api_key()?; - let url = format!( - "{API_BASE}/accounts/{account_id}/ai/run/{}", - self.model.name() - ); + let url = format!( + "{API_BASE}/accounts/{account_id}/ai/run/{}", + self_.model.name() + ); - let body = json!({ - "text": data.texts, - }); + let body = json!({ + "text": data.texts, + }); - let mut request_data = RequestData::new(url, body); + let mut request_data = RequestData::new(url, body); - request_data.bearer_auth(api_key); + request_data.bearer_auth(api_key); - Ok(request_data) - } + Ok(request_data) } -impl_client_trait!( - CloudflareClient, - chat_completions, - chat_completions_streaming, - embeddings -); - -async fn chat_completions(builder: RequestBuilder) -> Result { +async fn chat_completions( + builder: RequestBuilder, + _model: &Model, +) -> Result { let res = builder.send().await?; let status = res.status(); let data: Value = res.json().await?; @@ -88,6 +98,7 @@ async fn chat_completions(builder: RequestBuilder) -> Result Result<()> { let handle = |message: SseMmessage| -> Result { if message.data == "[DONE]" { @@ -103,7 +114,7 @@ async fn chat_completions_streaming( sse_stream(builder, handle).await } -async fn embeddings(builder: RequestBuilder) -> Result { +async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result { let res = builder.send().await?; let status = res.status(); let data: Value = res.json().await?; diff --git a/src/client/cohere.rs b/src/client/cohere.rs index b6e97553..aff919e6 100644 --- a/src/client/cohere.rs +++ b/src/client/cohere.rs @@ -1,5 +1,5 @@ -use super::rag_dedicated::*; use super::*; +use super::openai_compatible::*; use anyhow::{bail, Context, Result}; use reqwest::RequestBuilder; @@ -25,62 +25,71 @@ impl CohereClient { pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key:", true, PromptKind::String)]; +} - fn prepare_chat_completions(&self, data: ChatCompletionsData) -> Result { - let api_key = self.get_api_key()?; +impl_client_trait!( + CohereClient, + ( + prepare_chat_completions, + chat_completions, + chat_completions_streaming + ), + (prepare_embeddings, embeddings), + (prepare_rerank, generic_rerank), +); - let body = build_chat_completions_body(data, &self.model)?; +fn prepare_chat_completions( + self_: &CohereClient, + data: ChatCompletionsData, +) -> Result { + let api_key = self_.get_api_key()?; - let mut request_data = RequestData::new(CHAT_COMPLETIONS_API_URL, body); + let body = build_chat_completions_body(data, &self_.model)?; - request_data.bearer_auth(api_key); + let mut request_data = RequestData::new(CHAT_COMPLETIONS_API_URL, body); - Ok(request_data) - } + request_data.bearer_auth(api_key); - fn prepare_embeddings(&self, data: EmbeddingsData) -> Result { - let api_key = self.get_api_key()?; + Ok(request_data) +} - let input_type = match data.query { - true => "search_query", - false => "search_document", - }; +fn prepare_embeddings(self_: &CohereClient, data: EmbeddingsData) -> Result { + let api_key = self_.get_api_key()?; + + let input_type = match data.query { + true => "search_query", + false => "search_document", + }; - let body = json!({ - "model": self.model.name(), - "texts": data.texts, - "input_type": input_type, - }); + let body = json!({ + "model": self_.model.name(), + "texts": data.texts, + "input_type": input_type, + }); - let mut request_data = RequestData::new(EMBEDDINGS_API_URL, body); + let mut request_data = RequestData::new(EMBEDDINGS_API_URL, body); - request_data.bearer_auth(api_key); + request_data.bearer_auth(api_key); - Ok(request_data) - } + Ok(request_data) +} - fn prepare_rerank(&self, data: RerankData) -> Result { - let api_key = self.get_api_key()?; +fn prepare_rerank(self_: &CohereClient, data: RerankData) -> Result { + let api_key = self_.get_api_key()?; - let body = rag_dedicated_build_rerank_body(data, &self.model); + let body = generic_build_rerank_body(data, &self_.model); - let mut request_data = RequestData::new(RERANK_API_URL, body); + let mut request_data = RequestData::new(RERANK_API_URL, body); - request_data.bearer_auth(api_key); + request_data.bearer_auth(api_key); - Ok(request_data) - } + Ok(request_data) } -impl_client_trait!( - CohereClient, - chat_completions, - chat_completions_streaming, - embeddings, - rag_dedicated_rerank -); - -async fn chat_completions(builder: RequestBuilder) -> Result { +async fn chat_completions( + builder: RequestBuilder, + _model: &Model, +) -> Result { let res = builder.send().await?; let status = res.status(); let data: Value = res.json().await?; @@ -95,6 +104,7 @@ async fn chat_completions(builder: RequestBuilder) -> Result Result<()> { let res = builder.send().await?; let status = res.status(); @@ -131,7 +141,7 @@ async fn chat_completions_streaming( Ok(()) } -async fn embeddings(builder: RequestBuilder) -> Result { +async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result { let res = builder.send().await?; let status = res.status(); let data: Value = res.json().await?; diff --git a/src/client/common.rs b/src/client/common.rs index 33215112..21089419 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -8,7 +8,6 @@ use crate::{ }; use anyhow::{bail, Context, Result}; -use async_trait::async_trait; use fancy_regex::Regex; use indexmap::IndexMap; use lazy_static::lazy_static; @@ -25,7 +24,7 @@ lazy_static! { static ref ESCAPE_SLASH_RE: Regex = Regex::new(r"(? &GlobalConfig; @@ -110,6 +109,35 @@ pub trait Client: Sync + Send { .context("Failed to call rerank api") } + async fn chat_completions_inner( + &self, + client: &ReqwestClient, + data: ChatCompletionsData, + ) -> Result; + + async fn chat_completions_streaming_inner( + &self, + client: &ReqwestClient, + handler: &mut SseHandler, + data: ChatCompletionsData, + ) -> Result<()>; + + async fn embeddings_inner( + &self, + _client: &ReqwestClient, + _data: EmbeddingsData, + ) -> Result { + bail!("The client doesn't support embeddings api") + } + + async fn rerank_inner( + &self, + _client: &ReqwestClient, + _data: RerankData, + ) -> Result { + bail!("The client doesn't support rerank api") + } + fn request_builder( &self, client: &reqwest::Client, @@ -147,35 +175,6 @@ pub trait Client: Sync + Send { } } } - - async fn chat_completions_inner( - &self, - client: &ReqwestClient, - data: ChatCompletionsData, - ) -> Result; - - async fn chat_completions_streaming_inner( - &self, - client: &ReqwestClient, - handler: &mut SseHandler, - data: ChatCompletionsData, - ) -> Result<()>; - - async fn embeddings_inner( - &self, - _client: &ReqwestClient, - _data: EmbeddingsData, - ) -> Result { - bail!("The client doesn't support embeddings api") - } - - async fn rerank_inner( - &self, - _client: &ReqwestClient, - _data: RerankData, - ) -> Result { - bail!("The client doesn't support rerank api") - } } impl Default for ClientConfig { @@ -448,6 +447,22 @@ where Ok(()) } +pub fn noop_prepare_embeddings(_client: &T, _data: EmbeddingsData) -> Result { + bail!("The client doesn't support embeddings api") +} + +pub async fn noop_embeddings(_builder: RequestBuilder, _model: &Model) -> Result { + bail!("The client doesn't support embeddings api") +} + +pub fn noop_prepare_rerank(_client: &T, _data: RerankData) -> Result { + bail!("The client doesn't support rerank api") +} + +pub async fn noop_rerank(_builder: RequestBuilder, _model: &Model) -> Result { + bail!("The client doesn't support rerank api") +} + pub fn catch_error(data: &Value, status: u16) -> Result<()> { if (200..300).contains(&status) { return Ok(()); diff --git a/src/client/ernie.rs b/src/client/ernie.rs index e5089788..090e34f2 100644 --- a/src/client/ernie.rs +++ b/src/client/ernie.rs @@ -1,13 +1,11 @@ use super::access_token::*; -use super::rag_dedicated::*; +use super::openai_compatible::*; use super::*; use anyhow::{anyhow, bail, Context, Result}; -use async_trait::async_trait; use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; use serde_json::{json, Value}; -use std::env; const API_BASE: &str = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1"; const ACCESS_TOKEN_URL: &str = "https://aip.baidubce.com/oauth/2.0/token"; @@ -24,93 +22,15 @@ pub struct ErnieConfig { } impl ErnieClient { + config_get_fn!(api_key, get_api_key); + config_get_fn!(secret_key, get_secret_key); pub const PROMPTS: [PromptAction<'static>; 2] = [ ("api_key", "API Key:", true, PromptKind::String), ("secret_key", "Secret Key:", true, PromptKind::String), ]; - - fn prepare_chat_completions(&self, data: ChatCompletionsData) -> Result { - let access_token = get_access_token(self.name())?; - - let url = format!( - "{API_BASE}/wenxinworkshop/chat/{}?access_token={access_token}", - &self.model.name(), - ); - - let body = build_chat_completions_body(data, &self.model); - - let request_data = RequestData::new(url, body); - - Ok(request_data) - } - - fn prepare_embeddings(&self, data: EmbeddingsData) -> Result { - let access_token = get_access_token(self.name())?; - - let url = format!( - "{API_BASE}/wenxinworkshop/embeddings/{}?access_token={access_token}", - &self.model.name(), - ); - - let body = json!({ - "input": data.texts, - }); - - let request_data = RequestData::new(url, body); - - Ok(request_data) - } - - fn prepare_rerank(&self, data: RerankData) -> Result { - let access_token = get_access_token(self.name())?; - - let url = format!( - "{API_BASE}/wenxinworkshop/reranker/{}?access_token={access_token}", - &self.model.name(), - ); - - let RerankData { - query, - documents, - top_n, - } = data; - - let body = json!({ - "query": query, - "documents": documents, - "top_n": top_n - }); - - let request_data = RequestData::new(url, body); - - Ok(request_data) - } - - async fn prepare_access_token(&self) -> Result<()> { - let client_name = self.name(); - if !is_valid_access_token(client_name) { - let env_prefix = Self::name(&self.config).to_ascii_uppercase(); - let api_key = self.config.api_key.clone(); - let api_key = api_key - .or_else(|| env::var(format!("{env_prefix}_API_KEY")).ok()) - .ok_or_else(|| anyhow!("Miss api_key"))?; - - let secret_key = self.config.secret_key.clone(); - let secret_key = secret_key - .or_else(|| env::var(format!("{env_prefix}_SECRET_KEY")).ok()) - .ok_or_else(|| anyhow!("Miss secret_key"))?; - - let client = self.build_client()?; - let token = fetch_access_token(&client, &api_key, &secret_key) - .await - .with_context(|| "Failed to fetch access token")?; - set_access_token(client_name, token, 86400); - } - Ok(()) - } } -#[async_trait] +#[async_trait::async_trait] impl Client for ErnieClient { client_common_fns!(); @@ -119,10 +39,10 @@ impl Client for ErnieClient { client: &ReqwestClient, data: ChatCompletionsData, ) -> Result { - self.prepare_access_token().await?; - let request_data = self.prepare_chat_completions(data)?; + prepare_access_token(self, client).await?; + let request_data = prepare_chat_completions(self, data)?; let builder = self.request_builder(client, request_data, ApiType::ChatCompletions); - chat_completions(builder).await + chat_completions(builder, &self.model).await } async fn chat_completions_streaming_inner( @@ -131,10 +51,10 @@ impl Client for ErnieClient { handler: &mut SseHandler, data: ChatCompletionsData, ) -> Result<()> { - self.prepare_access_token().await?; - let request_data = self.prepare_chat_completions(data)?; + prepare_access_token(self, client).await?; + let request_data = prepare_chat_completions(self, data)?; let builder = self.request_builder(client, request_data, ApiType::ChatCompletions); - chat_completions_streaming(builder, handler).await + chat_completions_streaming(builder, handler, &self.model).await } async fn embeddings_inner( @@ -142,20 +62,95 @@ impl Client for ErnieClient { client: &ReqwestClient, data: EmbeddingsData, ) -> Result { - self.prepare_access_token().await?; - let request_data = self.prepare_embeddings(data)?; + prepare_access_token(self, client).await?; + let request_data = prepare_embeddings(self, data)?; let builder = self.request_builder(client, request_data, ApiType::Embeddings); - embeddings(builder).await + embeddings(builder, &self.model).await } async fn rerank_inner(&self, client: &ReqwestClient, data: RerankData) -> Result { - let request_data = self.prepare_rerank(data)?; + prepare_access_token(self, client).await?; + let request_data = prepare_rerank(self, data)?; let builder = self.request_builder(client, request_data, ApiType::Rerank); - rerank(builder).await + rerank(builder, &self.model).await } } -async fn chat_completions(builder: RequestBuilder) -> Result { +fn prepare_chat_completions(self_: &ErnieClient, data: ChatCompletionsData) -> Result { + let access_token = get_access_token(self_.name())?; + + let url = format!( + "{API_BASE}/wenxinworkshop/chat/{}?access_token={access_token}", + self_.model.name(), + ); + + let body = build_chat_completions_body(data, &self_.model); + + let request_data = RequestData::new(url, body); + + Ok(request_data) +} + +fn prepare_embeddings(self_: &ErnieClient, data: EmbeddingsData) -> Result { + let access_token = get_access_token(self_.name())?; + + let url = format!( + "{API_BASE}/wenxinworkshop/embeddings/{}?access_token={access_token}", + self_.model.name(), + ); + + let body = json!({ + "input": data.texts, + }); + + let request_data = RequestData::new(url, body); + + Ok(request_data) +} + +fn prepare_rerank(self_: &ErnieClient, data: RerankData) -> Result { + let access_token = get_access_token(self_.name())?; + + let url = format!( + "{API_BASE}/wenxinworkshop/reranker/{}?access_token={access_token}", + self_.model.name(), + ); + + let RerankData { + query, + documents, + top_n, + } = data; + + let body = json!({ + "query": query, + "documents": documents, + "top_n": top_n + }); + + let request_data = RequestData::new(url, body); + + Ok(request_data) +} + +async fn prepare_access_token(self_: &ErnieClient, client: &ReqwestClient) -> Result<()> { + let client_name = self_.name(); + if !is_valid_access_token(client_name) { + let api_key = self_.get_api_key()?; + let secret_key = self_.get_secret_key()?; + + let token = fetch_access_token(client, &api_key, &secret_key) + .await + .with_context(|| "Failed to fetch access token")?; + set_access_token(client_name, token, 86400); + } + Ok(()) +} + +async fn chat_completions( + builder: RequestBuilder, + _model: &Model, +) -> Result { let data: Value = builder.send().await?.json().await?; maybe_catch_error(&data)?; debug!("non-stream-data: {data}"); @@ -165,6 +160,7 @@ async fn chat_completions(builder: RequestBuilder) -> Result Result<()> { let handle = |message: SseMmessage| -> Result { let data: Value = serde_json::from_str(&message.data)?; @@ -188,7 +184,7 @@ async fn chat_completions_streaming( sse_stream(builder, handle).await } -async fn embeddings(builder: RequestBuilder) -> Result { +async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result { let data: Value = builder.send().await?.json().await?; maybe_catch_error(&data)?; let res_body: EmbeddingsResBody = @@ -207,10 +203,10 @@ struct EmbeddingsResBodyEmbedding { embedding: Vec, } -async fn rerank(builder: RequestBuilder) -> Result { +async fn rerank(builder: RequestBuilder, _model: &Model) -> Result { let data: Value = builder.send().await?.json().await?; maybe_catch_error(&data)?; - let res_body: RagDedicatedRerankResBody = + let res_body: GenericRerankResBody = serde_json::from_value(data).context("Invalid rerank data")?; Ok(res_body.results) } diff --git a/src/client/gemini.rs b/src/client/gemini.rs index aa1a5b1c..26162183 100644 --- a/src/client/gemini.rs +++ b/src/client/gemini.rs @@ -23,57 +23,64 @@ impl GeminiClient { pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key:", true, PromptKind::String)]; +} + +impl_client_trait!( + GeminiClient, + ( + prepare_chat_completions, + gemini_chat_completions, + gemini_chat_completions_streaming + ), + (prepare_embeddings, gemini_embeddings), + (noop_prepare_rerank, noop_rerank), +); - fn prepare_chat_completions(&self, data: ChatCompletionsData) -> Result { - let api_key = self.get_api_key()?; +fn prepare_chat_completions( + self_: &GeminiClient, + data: ChatCompletionsData, +) -> Result { + let api_key = self_.get_api_key()?; - let func = match data.stream { - true => "streamGenerateContent", - false => "generateContent", - }; + let func = match data.stream { + true => "streamGenerateContent", + false => "generateContent", + }; - let url = format!("{API_BASE}{}:{}?key={}", &self.model.name(), func, api_key); + let url = format!("{API_BASE}{}:{}?key={}", self_.model.name(), func, api_key); - let body = gemini_build_chat_completions_body(data, &self.model)?; + let body = gemini_build_chat_completions_body(data, &self_.model)?; - let request_data = RequestData::new(url, body); + let request_data = RequestData::new(url, body); - Ok(request_data) - } + Ok(request_data) +} - fn prepare_embeddings(&self, data: EmbeddingsData) -> Result { - let api_key = self.get_api_key()?; +fn prepare_embeddings(self_: &GeminiClient, data: EmbeddingsData) -> Result { + let api_key = self_.get_api_key()?; - let url = format!( - "{API_BASE}{}:embedContent?key={}", - &self.model.name(), - api_key - ); + let url = format!( + "{API_BASE}{}:embedContent?key={}", + self_.model.name(), + api_key + ); - let body = json!({ - "content": { - "parts": [ - { - "text": data.texts[0], - } - ] - } - }); + let body = json!({ + "content": { + "parts": [ + { + "text": data.texts[0], + } + ] + } + }); - let request_data = RequestData::new(url, body); + let request_data = RequestData::new(url, body); - Ok(request_data) - } + Ok(request_data) } -impl_client_trait!( - GeminiClient, - gemini_chat_completions, - gemini_chat_completions_streaming, - gemini_embeddings -); - -async fn gemini_embeddings(builder: RequestBuilder) -> Result { +async fn gemini_embeddings(builder: RequestBuilder, _model: &Model) -> Result { let res = builder.send().await?; let status = res.status(); let data: Value = res.json().await?; diff --git a/src/client/macros.rs b/src/client/macros.rs index 19e77343..67a2d836 100644 --- a/src/client/macros.rs +++ b/src/client/macros.rs @@ -54,8 +54,7 @@ macro_rules! register_client { if local_config.models.is_empty() { if let Some(models) = $crate::client::ALL_MODELS.iter().find(|v| { v.platform == $name || - ($name == OpenAICompatibleClient::NAME && local_config.name.as_deref() == Some(&v.platform)) || - ($name == RagDedicatedClient::NAME && local_config.name.as_deref() == Some(&v.platform)) + ($name == OpenAICompatibleClient::NAME && local_config.name.as_deref() == Some(&v.platform)) }) { return Model::from_config(client_name, &models.models); } @@ -161,71 +160,12 @@ macro_rules! client_common_fns { #[macro_export] macro_rules! impl_client_trait { - ($client:ident, $chat_completions:path, $chat_completions_streaming:path) => { - #[async_trait::async_trait] - impl $crate::client::Client for $crate::client::$client { - client_common_fns!(); - - async fn chat_completions_inner( - &self, - client: &reqwest::Client, - data: $crate::client::ChatCompletionsData, - ) -> anyhow::Result<$crate::client::ChatCompletionsOutput> { - let request_data = self.prepare_chat_completions(data)?; - let builder = self.request_builder(client, request_data, ApiType::ChatCompletions); - $chat_completions(builder).await - } - - async fn chat_completions_streaming_inner( - &self, - client: &reqwest::Client, - handler: &mut $crate::client::SseHandler, - data: $crate::client::ChatCompletionsData, - ) -> Result<()> { - let request_data = self.prepare_chat_completions(data)?; - let builder = self.request_builder(client, request_data, ApiType::ChatCompletions); - $chat_completions_streaming(builder, handler).await - } - } - }; - ($client:ident, $chat_completions:path, $chat_completions_streaming:path, $embeddings:path) => { - #[async_trait::async_trait] - impl $crate::client::Client for $crate::client::$client { - client_common_fns!(); - - async fn chat_completions_inner( - &self, - client: &reqwest::Client, - data: $crate::client::ChatCompletionsData, - ) -> anyhow::Result<$crate::client::ChatCompletionsOutput> { - let request_data = self.prepare_chat_completions(data)?; - let builder = self.request_builder(client, request_data, ApiType::ChatCompletions); - $chat_completions(builder).await - } - - async fn chat_completions_streaming_inner( - &self, - client: &reqwest::Client, - handler: &mut $crate::client::SseHandler, - data: $crate::client::ChatCompletionsData, - ) -> Result<()> { - let request_data = self.prepare_chat_completions(data)?; - let builder = self.request_builder(client, request_data, ApiType::ChatCompletions); - $chat_completions_streaming(builder, handler).await - } - - async fn embeddings_inner( - &self, - client: &reqwest::Client, - data: $crate::client::EmbeddingsData, - ) -> Result<$crate::client::EmbeddingsOutput> { - let request_data = self.prepare_embeddings(data)?; - let builder = self.request_builder(client, request_data, ApiType::Embeddings); - $embeddings(builder).await - } - } - }; - ($client:ident, $chat_completions:path, $chat_completions_streaming:path, $embeddings:path, $rerank:path) => { + ( + $client:ident, + ($prepare_chat_completions:path, $chat_completions:path, $chat_completions_streaming:path), + ($prepare_embeddings:path, $embeddings:path), + ($prepare_rerank:path, $rerank:path), + ) => { #[async_trait::async_trait] impl $crate::client::Client for $crate::client::$client { client_common_fns!(); @@ -235,9 +175,9 @@ macro_rules! impl_client_trait { client: &reqwest::Client, data: $crate::client::ChatCompletionsData, ) -> anyhow::Result<$crate::client::ChatCompletionsOutput> { - let request_data = self.prepare_chat_completions(data)?; + let request_data = $prepare_chat_completions(self, data)?; let builder = self.request_builder(client, request_data, ApiType::ChatCompletions); - $chat_completions(builder).await + $chat_completions(builder, self.model()).await } async fn chat_completions_streaming_inner( @@ -246,9 +186,9 @@ macro_rules! impl_client_trait { handler: &mut $crate::client::SseHandler, data: $crate::client::ChatCompletionsData, ) -> Result<()> { - let request_data = self.prepare_chat_completions(data)?; + let request_data = $prepare_chat_completions(self, data)?; let builder = self.request_builder(client, request_data, ApiType::ChatCompletions); - $chat_completions_streaming(builder, handler).await + $chat_completions_streaming(builder, handler, self.model()).await } async fn embeddings_inner( @@ -256,9 +196,9 @@ macro_rules! impl_client_trait { client: &reqwest::Client, data: $crate::client::EmbeddingsData, ) -> Result<$crate::client::EmbeddingsOutput> { - let request_data = self.prepare_embeddings(data)?; + let request_data = $prepare_embeddings(self, data)?; let builder = self.request_builder(client, request_data, ApiType::Embeddings); - $embeddings(builder).await + $embeddings(builder, self.model()).await } async fn rerank_inner( @@ -266,9 +206,9 @@ macro_rules! impl_client_trait { client: &reqwest::Client, data: $crate::client::RerankData, ) -> Result<$crate::client::RerankOutput> { - let request_data = self.prepare_rerank(data)?; + let request_data = $prepare_rerank(self, data)?; let builder = self.request_builder(client, request_data, ApiType::Rerank); - $rerank(builder).await + $rerank(builder, self.model()).await } } }; @@ -286,9 +226,7 @@ macro_rules! config_get_fn { format!("{}_{}", env_prefix, stringify!($field_name)).to_ascii_uppercase(); std::env::var(&env_name).ok() }) - .ok_or_else(|| { - anyhow::anyhow!("Miss '{}' in client configuration", stringify!($field_name)) - }) + .ok_or_else(|| anyhow::anyhow!("Miss '{}'", stringify!($field_name))) } }; } diff --git a/src/client/mod.rs b/src/client/mod.rs index b2bd3def..80f526a7 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -22,12 +22,6 @@ register_client!( OpenAICompatibleConfig, OpenAICompatibleClient ), - ( - rag_dedicated, - "rag-dedicated", - RagDedicatedConfig, - RagDedicatedClient - ), (gemini, "gemini", GeminiConfig, GeminiClient), (claude, "claude", ClaudeConfig, ClaudeClient), (cohere, "cohere", CohereConfig, CohereClient), @@ -46,11 +40,13 @@ register_client!( (qianwen, "qianwen", QianwenConfig, QianwenClient), ); -pub const OPENAI_COMPATIBLE_PLATFORMS: [(&str, &str); 12] = [ +pub const OPENAI_COMPATIBLE_PLATFORMS: [(&str, &str); 14] = [ ("deepinfra", "https://api.deepinfra.com/v1/openai"), ("deepseek", "https://api.deepseek.com"), ("fireworks", "https://api.fireworks.ai/inference/v1"), ("groq", "https://api.groq.com/openai/v1"), + ("jina", "https://api.jina.ai/v1"), + ("lingyiwanwu", "https://api.lingyiwanwu.com/v1"), ("mistral", "https://api.mistral.ai/v1"), ("moonshot", "https://api.moonshot.cn/v1"), ("openrouter", "https://openrouter.ai/api/v1"), @@ -58,10 +54,5 @@ pub const OPENAI_COMPATIBLE_PLATFORMS: [(&str, &str); 12] = [ ("perplexity", "https://api.perplexity.ai"), ("together", "https://api.together.xyz/v1"), ("zhipuai", "https://open.bigmodel.cn/api/paas/v4"), - ("lingyiwanwu", "https://api.lingyiwanwu.com/v1"), -]; - -pub const RAG_DEDICATED_PLATFORMS: [(&str, &str); 2] = [ - ("jina", "https://api.jina.ai/v1"), ("voyageai", "https://api.voyageai.com/v1"), ]; diff --git a/src/client/ollama.rs b/src/client/ollama.rs index 5c26b992..4c2f3445 100644 --- a/src/client/ollama.rs +++ b/src/client/ollama.rs @@ -31,53 +31,63 @@ impl OllamaClient { PromptKind::Integer, ), ]; +} - fn prepare_chat_completions(&self, data: ChatCompletionsData) -> Result { - let api_base = self.get_api_base()?; - let api_auth = self.get_api_auth().ok(); +impl_client_trait!( + OllamaClient, + ( + prepare_chat_completions, + chat_completions, + chat_completions_streaming + ), + (prepare_embeddings, embeddings), + (noop_prepare_rerank, noop_rerank), +); - let url = format!("{api_base}/api/chat"); +fn prepare_chat_completions( + self_: &OllamaClient, + data: ChatCompletionsData, +) -> Result { + let api_base = self_.get_api_base()?; + let api_auth = self_.get_api_auth().ok(); - let body = build_chat_completions_body(data, &self.model)?; + let url = format!("{api_base}/api/chat"); - let mut request_data = RequestData::new(url, body); + let body = build_chat_completions_body(data, &self_.model)?; - if let Some(api_auth) = api_auth { - request_data.header("Authorization", api_auth) - } + let mut request_data = RequestData::new(url, body); - Ok(request_data) + if let Some(api_auth) = api_auth { + request_data.header("Authorization", api_auth) } - fn prepare_embeddings(&self, data: EmbeddingsData) -> Result { - let api_base = self.get_api_base()?; - let api_auth = self.get_api_auth().ok(); + Ok(request_data) +} - let url = format!("{api_base}/api/embed"); +fn prepare_embeddings(self_: &OllamaClient, data: EmbeddingsData) -> Result { + let api_base = self_.get_api_base()?; + let api_auth = self_.get_api_auth().ok(); - let body = json!({ - "model": self.model.name(), - "input": data.texts, - }); + let url = format!("{api_base}/api/embed"); - let mut request_data = RequestData::new(url, body); + let body = json!({ + "model": self_.model.name(), + "input": data.texts, + }); - if let Some(api_auth) = api_auth { - request_data.header("Authorization", api_auth) - } + let mut request_data = RequestData::new(url, body); - Ok(request_data) + if let Some(api_auth) = api_auth { + request_data.header("Authorization", api_auth) } -} -impl_client_trait!( - OllamaClient, - chat_completions, - chat_completions_streaming, - embeddings -); + Ok(request_data) +} -async fn chat_completions(builder: RequestBuilder) -> Result { +async fn chat_completions( + builder: RequestBuilder, + _model: &Model, +) -> Result { let res = builder.send().await?; let status = res.status(); let data = res.json().await?; @@ -92,6 +102,7 @@ async fn chat_completions(builder: RequestBuilder) -> Result Result<()> { let res = builder.send().await?; let status = res.status(); @@ -120,7 +131,7 @@ async fn chat_completions_streaming( Ok(()) } -async fn embeddings(builder: RequestBuilder) -> Result { +async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result { let res = builder.send().await?; let status = res.status(); let data = res.json().await?; diff --git a/src/client/openai.rs b/src/client/openai.rs index 2b83b7d1..ec9cb9de 100644 --- a/src/client/openai.rs +++ b/src/client/openai.rs @@ -25,45 +25,66 @@ impl OpenAIClient { pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key:", true, PromptKind::String)]; +} - fn prepare_chat_completions(&self, data: ChatCompletionsData) -> Result { - let api_key = self.get_api_key()?; - let api_base = self.get_api_base().unwrap_or_else(|_| API_BASE.to_string()); +impl_client_trait!( + OpenAIClient, + ( + prepare_chat_completions, + openai_chat_completions, + openai_chat_completions_streaming + ), + (prepare_embeddings, openai_embeddings), + (noop_prepare_rerank, noop_rerank), +); - let url = format!("{api_base}/chat/completions"); +fn prepare_chat_completions( + self_: &OpenAIClient, + data: ChatCompletionsData, +) -> Result { + let api_key = self_.get_api_key()?; + let api_base = self_ + .get_api_base() + .unwrap_or_else(|_| API_BASE.to_string()); - let body = openai_build_chat_completions_body(data, &self.model); + let url = format!("{api_base}/chat/completions"); - let mut request_data = RequestData::new(url, body); + let body = openai_build_chat_completions_body(data, &self_.model); - request_data.bearer_auth(api_key); - if let Some(organization_id) = &self.config.organization_id { - request_data.header("OpenAI-Organization", organization_id); - } + let mut request_data = RequestData::new(url, body); - Ok(request_data) + request_data.bearer_auth(api_key); + if let Some(organization_id) = &self_.config.organization_id { + request_data.header("OpenAI-Organization", organization_id); } - fn prepare_embeddings(&self, data: EmbeddingsData) -> Result { - let api_key = self.get_api_key()?; - let api_base = self.get_api_base().unwrap_or_else(|_| API_BASE.to_string()); + Ok(request_data) +} - let url = format!("{api_base}/embeddings"); +fn prepare_embeddings(self_: &OpenAIClient, data: EmbeddingsData) -> Result { + let api_key = self_.get_api_key()?; + let api_base = self_ + .get_api_base() + .unwrap_or_else(|_| API_BASE.to_string()); - let body = openai_build_embeddings_body(data, &self.model); + let url = format!("{api_base}/embeddings"); - let mut request_data = RequestData::new(url, body); + let body = openai_build_embeddings_body(data, &self_.model); - request_data.bearer_auth(api_key); - if let Some(organization_id) = &self.config.organization_id { - request_data.header("OpenAI-Organization", organization_id); - } + let mut request_data = RequestData::new(url, body); - Ok(request_data) + request_data.bearer_auth(api_key); + if let Some(organization_id) = &self_.config.organization_id { + request_data.header("OpenAI-Organization", organization_id); } + + Ok(request_data) } -pub async fn openai_chat_completions(builder: RequestBuilder) -> Result { +pub async fn openai_chat_completions( + builder: RequestBuilder, + _model: &Model, +) -> Result { let res = builder.send().await?; let status = res.status(); let data: Value = res.json().await?; @@ -78,6 +99,7 @@ pub async fn openai_chat_completions(builder: RequestBuilder) -> Result Result<()> { let mut function_index = 0; let mut function_name = String::new(); @@ -133,7 +155,10 @@ pub async fn openai_chat_completions_streaming( sse_stream(builder, handle).await } -pub async fn openai_embeddings(builder: RequestBuilder) -> Result { +pub async fn openai_embeddings( + builder: RequestBuilder, + _model: &Model, +) -> Result { let res = builder.send().await?; let status = res.status(); let data: Value = res.json().await?; @@ -277,10 +302,3 @@ pub fn openai_extract_chat_completions(data: &Value) -> Result Result { - let api_key = self.get_api_key().ok(); - let api_base = self.get_api_base_ext()?; +impl_client_trait!( + OpenAICompatibleClient, + ( + prepare_chat_completions, + openai_chat_completions, + openai_chat_completions_streaming + ), + (prepare_embeddings, openai_embeddings), + (prepare_rerank, generic_rerank), +); - let chat_endpoint = self - .config - .chat_endpoint - .as_deref() - .unwrap_or("/chat/completions"); +fn prepare_chat_completions( + self_: &OpenAICompatibleClient, + data: ChatCompletionsData, +) -> Result { + let api_key = self_.get_api_key().ok(); + let api_base = get_api_base_ext(self_)?; - let url = format!("{api_base}{chat_endpoint}"); + let chat_endpoint = self_ + .config + .chat_endpoint + .as_deref() + .unwrap_or("/chat/completions"); - let body = openai_build_chat_completions_body(data, &self.model); + let url = format!("{api_base}{chat_endpoint}"); - let mut request_data = RequestData::new(url, body); + let body = openai_build_chat_completions_body(data, &self_.model); - if let Some(api_key) = api_key { - request_data.bearer_auth(api_key); - } + let mut request_data = RequestData::new(url, body); - Ok(request_data) + if let Some(api_key) = api_key { + request_data.bearer_auth(api_key); } - fn prepare_embeddings(&self, data: EmbeddingsData) -> Result { - let api_key = self.get_api_key().ok(); - let api_base = self.get_api_base_ext()?; + Ok(request_data) +} - let url = format!("{api_base}/embeddings"); +fn prepare_embeddings(self_: &OpenAICompatibleClient, data: EmbeddingsData) -> Result { + let api_key = self_.get_api_key().ok(); + let api_base = get_api_base_ext(self_)?; - let body = openai_build_embeddings_body(data, &self.model); + let url = format!("{api_base}/embeddings"); - let mut request_data = RequestData::new(url, body); + let body = openai_build_embeddings_body(data, &self_.model); - if let Some(api_key) = api_key { - request_data.bearer_auth(api_key); - } + let mut request_data = RequestData::new(url, body); - Ok(request_data) + if let Some(api_key) = api_key { + request_data.bearer_auth(api_key); } - fn prepare_rerank(&self, data: RerankData) -> Result { - let api_key = self.get_api_key().ok(); - let api_base = self.get_api_base_ext()?; + Ok(request_data) +} - let url = format!("{api_base}/rerank"); +fn prepare_rerank(self_: &OpenAICompatibleClient, data: RerankData) -> Result { + let api_key = self_.get_api_key().ok(); + let api_base = get_api_base_ext(self_)?; - let body = rag_dedicated_build_rerank_body(data, &self.model); + let url = format!("{api_base}/rerank"); - let mut request_data = RequestData::new(url, body); + let body = generic_build_rerank_body(data, &self_.model); - if let Some(api_key) = api_key { - request_data.bearer_auth(api_key); - } + let mut request_data = RequestData::new(url, body); - Ok(request_data) + if let Some(api_key) = api_key { + request_data.bearer_auth(api_key); } - fn get_api_base_ext(&self) -> Result { - let api_base = match self.get_api_base() { - Ok(v) => v, - Err(err) => { - match OPENAI_COMPATIBLE_PLATFORMS - .into_iter() - .find_map(|(name, api_base)| { - if name == self.model.client_name() { - Some(api_base.to_string()) - } else { - None - } - }) { - Some(v) => v, - None => return Err(err), - } + Ok(request_data) +} + +fn get_api_base_ext(self_: &OpenAICompatibleClient) -> Result { + let api_base = match self_.get_api_base() { + Ok(v) => v, + Err(err) => { + match OPENAI_COMPATIBLE_PLATFORMS + .into_iter() + .find_map(|(name, api_base)| { + if name == self_.model.client_name() { + Some(api_base.to_string()) + } else { + None + } + }) { + Some(v) => v, + None => return Err(err), } - }; - Ok(api_base) + } + }; + Ok(api_base) +} + +pub async fn generic_rerank(builder: RequestBuilder, _model: &Model) -> Result { + let res = builder.send().await?; + let status = res.status(); + let mut data: Value = res.json().await?; + if !status.is_success() { + catch_error(&data, status.as_u16())?; + } + if data.get("results").is_none() && data.get("data").is_some() { + if let Some(data_obj) = data.as_object_mut() { + if let Some(value) = data_obj.remove("data") { + data_obj.insert("results".to_string(), value); + } + } } + let res_body: GenericRerankResBody = + serde_json::from_value(data).context("Invalid rerank data")?; + Ok(res_body.results) } -impl_client_trait!( - OpenAICompatibleClient, - openai_chat_completions, - openai_chat_completions_streaming, - openai_embeddings, - rag_dedicated_rerank -); +#[derive(Deserialize)] +pub struct GenericRerankResBody { + pub results: RerankOutput, +} + +pub fn generic_build_rerank_body(data: RerankData, model: &Model) -> Value { + let RerankData { + query, + documents, + top_n, + } = data; + + let mut body = json!({ + "model": model.name(), + "query": query, + "documents": documents, + }); + if model.client_name() == "voyageai" { + body["top_k"] = top_n.into() + } else { + body["top_n"] = top_n.into() + } + body +} \ No newline at end of file diff --git a/src/client/qianwen.rs b/src/client/qianwen.rs index 4aa67aee..3c246f9e 100644 --- a/src/client/qianwen.rs +++ b/src/client/qianwen.rs @@ -3,7 +3,6 @@ use super::*; use crate::utils::{base64_decode, sha256}; use anyhow::{anyhow, bail, Context, Result}; -use async_trait::async_trait; use reqwest::{ multipart::{Form, Part}, Client as ReqwestClient, RequestBuilder, @@ -36,60 +35,9 @@ impl QianwenClient { pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key:", true, PromptKind::String)]; - - fn prepare_chat_completions(&self, data: ChatCompletionsData) -> Result { - let api_key = self.get_api_key()?; - - let stream = data.stream; - - let url = match self.model.supports_vision() { - true => CHAT_COMPLETIONS_API_URL_VL, - false => CHAT_COMPLETIONS_API_URL, - }; - - let (body, has_upload) = build_chat_completions_body(data, &self.model)?; - - let mut request_data = RequestData::new(url, body); - - request_data.bearer_auth(api_key); - - if stream { - request_data.header("X-DashScope-SSE", "enable"); - } - if has_upload { - request_data.header("X-DashScope-OssResourceResolve", "enable"); - } - - Ok(request_data) - } - - fn prepare_embeddings(&self, data: EmbeddingsData) -> Result { - let api_key = self.get_api_key()?; - - let text_type = match data.query { - true => "query", - false => "document", - }; - - let body = json!({ - "model": self.model.name(), - "input": { - "texts": data.texts, - }, - "parameters": { - "text_type": text_type, - } - }); - - let mut request_data = RequestData::new(EMBEDDINGS_API_URL, body); - - request_data.bearer_auth(api_key); - - Ok(request_data) - } } -#[async_trait] +#[async_trait::async_trait] impl Client for QianwenClient { client_common_fns!(); @@ -100,7 +48,7 @@ impl Client for QianwenClient { ) -> Result { let api_key = self.get_api_key()?; patch_messages(self.model.name(), &api_key, &mut data.messages).await?; - let request_data = self.prepare_chat_completions(data)?; + let request_data = prepare_chat_completions(self, data)?; let builder = self.request_builder(client, request_data, ApiType::ChatCompletions); chat_completions(builder, &self.model).await } @@ -113,7 +61,7 @@ impl Client for QianwenClient { ) -> Result<()> { let api_key = self.get_api_key()?; patch_messages(self.model.name(), &api_key, &mut data.messages).await?; - let request_data = self.prepare_chat_completions(data)?; + let request_data = prepare_chat_completions(self, data)?; let builder = self.request_builder(client, request_data, ApiType::ChatCompletions); chat_completions_streaming(builder, handler, &self.model).await } @@ -123,12 +71,66 @@ impl Client for QianwenClient { client: &ReqwestClient, data: EmbeddingsData, ) -> Result>> { - let request_data = self.prepare_embeddings(data)?; + let request_data = prepare_embeddings(self, data)?; let builder = self.request_builder(client, request_data, ApiType::Embeddings); - embeddings(builder).await + embeddings(builder, &self.model).await } } +fn prepare_chat_completions( + self_: &QianwenClient, + data: ChatCompletionsData, +) -> Result { + let api_key = self_.get_api_key()?; + + let stream = data.stream; + + let url = match self_.model().supports_vision() { + true => CHAT_COMPLETIONS_API_URL_VL, + false => CHAT_COMPLETIONS_API_URL, + }; + + let (body, has_upload) = build_chat_completions_body(data, &self_.model)?; + + let mut request_data = RequestData::new(url, body); + + request_data.bearer_auth(api_key); + + if stream { + request_data.header("X-DashScope-SSE", "enable"); + } + if has_upload { + request_data.header("X-DashScope-OssResourceResolve", "enable"); + } + + Ok(request_data) +} + +fn prepare_embeddings(self_: &QianwenClient, data: EmbeddingsData) -> Result { + let api_key = self_.get_api_key()?; + + let text_type = match data.query { + true => "query", + false => "document", + }; + + let body = json!({ + "model": self_.model.name(), + "input": { + "texts": data.texts, + }, + "parameters": { + "text_type": text_type, + } + }); + + let mut request_data = RequestData::new(EMBEDDINGS_API_URL, body); + + request_data.bearer_auth(api_key); + + Ok(request_data) +} + async fn chat_completions(builder: RequestBuilder, model: &Model) -> Result { let data: Value = builder.send().await?.json().await?; maybe_catch_error(&data)?; @@ -322,7 +324,7 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu Ok((body, has_upload)) } -async fn embeddings(builder: RequestBuilder) -> Result { +async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result { let data: Value = builder.send().await?.json().await?; maybe_catch_error(&data)?; let res_body: EmbeddingsResBody = diff --git a/src/client/rag_dedicated.rs b/src/client/rag_dedicated.rs deleted file mode 100644 index 7d2b8463..00000000 --- a/src/client/rag_dedicated.rs +++ /dev/null @@ -1,150 +0,0 @@ -use super::openai::*; -use super::*; - -use anyhow::bail; -use anyhow::Context; -use anyhow::Result; -use reqwest::RequestBuilder; -use serde::Deserialize; -use serde_json::json; -use serde_json::Value; - -#[derive(Debug, Clone, Deserialize)] -pub struct RagDedicatedConfig { - pub name: Option, - pub api_base: Option, - pub api_key: Option, - #[serde(default)] - pub models: Vec, - pub patch: Option, - pub extra: Option, -} - -impl RagDedicatedClient { - config_get_fn!(api_base, get_api_base); - config_get_fn!(api_key, get_api_key); - - pub const PROMPTS: [PromptAction<'static>; 0] = []; - - fn prepare_chat_completions(&self, _data: ChatCompletionsData) -> Result { - bail!("The client doesn't support chat-completions api"); - } - - fn prepare_embeddings(&self, data: EmbeddingsData) -> Result { - let api_key = self.get_api_key().ok(); - let api_base = self.get_api_base_ext()?; - - let url = format!("{api_base}/embeddings"); - - let body = openai_build_embeddings_body(data, &self.model); - - let mut request_data = RequestData::new(url, body); - - if let Some(api_key) = api_key { - request_data.bearer_auth(api_key); - } - - Ok(request_data) - } - - fn prepare_rerank(&self, data: RerankData) -> Result { - let api_key = self.get_api_key().ok(); - let api_base = self.get_api_base_ext()?; - - let url = format!("{api_base}/rerank"); - - let body = rag_dedicated_build_rerank_body(data, &self.model); - - let mut request_data = RequestData::new(url, body); - - if let Some(api_key) = api_key { - request_data.bearer_auth(api_key); - } - - Ok(request_data) - } - - fn get_api_base_ext(&self) -> Result { - let api_base = match self.get_api_base() { - Ok(v) => v, - Err(err) => { - match RAG_DEDICATED_PLATFORMS - .into_iter() - .find_map(|(name, api_base)| { - if name == self.model.client_name() { - Some(api_base.to_string()) - } else { - None - } - }) { - Some(v) => v, - None => return Err(err), - } - } - }; - Ok(api_base) - } -} - -impl_client_trait!( - RagDedicatedClient, - no_chat_completions, - no_chat_completions_streaming, - openai_embeddings, - rag_dedicated_rerank -); - -pub async fn no_chat_completions(_builder: RequestBuilder) -> Result { - bail!("The client doesn't support chat-completions api"); -} - -pub async fn no_chat_completions_streaming( - _builder: RequestBuilder, - _handler: &mut SseHandler, -) -> Result<()> { - bail!("The client doesn't support chat-completions api") -} - -pub async fn rag_dedicated_rerank(builder: RequestBuilder) -> Result { - let res = builder.send().await?; - let status = res.status(); - let mut data: Value = res.json().await?; - if !status.is_success() { - catch_error(&data, status.as_u16())?; - } - if data.get("results").is_none() && data.get("data").is_some() { - if let Some(data_obj) = data.as_object_mut() { - if let Some(value) = data_obj.remove("data") { - data_obj.insert("results".to_string(), value); - } - } - } - let res_body: RagDedicatedRerankResBody = - serde_json::from_value(data).context("Invalid rerank data")?; - Ok(res_body.results) -} - -#[derive(Deserialize)] -pub struct RagDedicatedRerankResBody { - pub results: RerankOutput, -} - -pub fn rag_dedicated_build_rerank_body(data: RerankData, model: &Model) -> Value { - let RerankData { - query, - documents, - top_n, - } = data; - - let mut body = json!({ - "model": model.name(), - "query": query, - "documents": documents, - }); - if model.client_name() == "voyageai" { - body["top_k"] = top_n.into() - } else { - body["top_n"] = top_n.into() - } - body -} diff --git a/src/client/replicate.rs b/src/client/replicate.rs index d6ca401a..f0db24ad 100644 --- a/src/client/replicate.rs +++ b/src/client/replicate.rs @@ -2,7 +2,6 @@ use super::prompt_format::*; use super::*; use anyhow::{anyhow, Result}; -use async_trait::async_trait; use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; use serde_json::{json, Value}; @@ -25,25 +24,9 @@ impl ReplicateClient { pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key:", true, PromptKind::String)]; - - fn prepare_chat_completions( - &self, - data: ChatCompletionsData, - api_key: &str, - ) -> Result { - let url = format!("{API_BASE}/models/{}/predictions", self.model.name()); - - let body = build_chat_completions_body(data, &self.model)?; - - let mut request_data = RequestData::new(url, body); - - request_data.bearer_auth(api_key); - - Ok(request_data) - } } -#[async_trait] +#[async_trait::async_trait] impl Client for ReplicateClient { client_common_fns!(); @@ -52,10 +35,9 @@ impl Client for ReplicateClient { client: &ReqwestClient, data: ChatCompletionsData, ) -> Result { - let api_key = self.get_api_key()?; - let request_data = self.prepare_chat_completions(data, &api_key)?; + let request_data = prepare_chat_completions(self, data)?; let builder = self.request_builder(client, request_data, ApiType::ChatCompletions); - chat_completions(client, builder, &api_key).await + chat_completions(builder, client, &self.get_api_key()?).await } async fn chat_completions_streaming_inner( @@ -64,16 +46,32 @@ impl Client for ReplicateClient { handler: &mut SseHandler, data: ChatCompletionsData, ) -> Result<()> { - let api_key = self.get_api_key()?; - let request_data = self.prepare_chat_completions(data, &api_key)?; + let request_data = prepare_chat_completions(self, data)?; let builder = self.request_builder(client, request_data, ApiType::ChatCompletions); - chat_completions_streaming(client, builder, handler).await + chat_completions_streaming(builder, handler, client).await } } +fn prepare_chat_completions( + self_: &ReplicateClient, + data: ChatCompletionsData, +) -> Result { + let api_key = self_.get_api_key()?; + + let url = format!("{API_BASE}/models/{}/predictions", self_.model.name()); + + let body = build_chat_completions_body(data, &self_.model)?; + + let mut request_data = RequestData::new(url, body); + + request_data.bearer_auth(api_key); + + Ok(request_data) +} + async fn chat_completions( - client: &ReqwestClient, builder: RequestBuilder, + client: &ReqwestClient, api_key: &str, ) -> Result { let res = builder.send().await?; @@ -106,9 +104,9 @@ async fn chat_completions( } async fn chat_completions_streaming( - client: &ReqwestClient, builder: RequestBuilder, handler: &mut SseHandler, + client: &ReqwestClient, ) -> Result<()> { let res = builder.send().await?; let status = res.status(); @@ -126,6 +124,9 @@ async fn chat_completions_streaming( if message.event == "done" { return Ok(true); } + + debug!("stream-data: {}", message.data); + handler.text(&message.data)?; Ok(false) }; diff --git a/src/client/vertexai.rs b/src/client/vertexai.rs index 4ad07b83..5349eaf2 100644 --- a/src/client/vertexai.rs +++ b/src/client/vertexai.rs @@ -4,7 +4,6 @@ use super::openai::*; use super::*; use anyhow::{anyhow, bail, Context, Result}; -use async_trait::async_trait; use chrono::{Duration, Utc}; use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; @@ -31,93 +30,9 @@ impl VertexAIClient { ("project_id", "Project ID", true, PromptKind::String), ("location", "Location", true, PromptKind::String), ]; - - fn prepare_chat_completions( - &self, - data: ChatCompletionsData, - model_category: &ModelCategory, - ) -> Result { - let project_id = self.get_project_id()?; - let location = self.get_location()?; - let access_token = get_access_token(self.name())?; - - let base_url = format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers"); - - let model_name = self.model.name(); - - let url = match model_category { - ModelCategory::Gemini => { - let func = match data.stream { - true => "streamGenerateContent", - false => "generateContent", - }; - format!("{base_url}/google/models/{model_name}:{func}") - } - ModelCategory::Claude => { - format!("{base_url}/anthropic/models/{model_name}:streamRawPredict") - } - ModelCategory::Mistral => { - let func = match data.stream { - true => "streamRawPredict", - false => "rawPredict", - }; - format!("{base_url}/mistralai/models/{model_name}:{func}") - } - }; - - let body = match model_category { - ModelCategory::Gemini => gemini_build_chat_completions_body(data, &self.model)?, - ModelCategory::Claude => { - let mut body = claude_build_chat_completions_body(data, &self.model)?; - if let Some(body_obj) = body.as_object_mut() { - body_obj.remove("model"); - } - body["anthropic_version"] = "vertex-2023-10-16".into(); - body - } - ModelCategory::Mistral => { - let mut body = openai_build_chat_completions_body(data, &self.model); - if let Some(body_obj) = body.as_object_mut() { - body_obj["model"] = strip_model_version(self.model.name()).into(); - } - body - } - }; - - let mut request_data = RequestData::new(url, body); - - request_data.bearer_auth(access_token); - - Ok(request_data) - } - - fn prepare_embeddings(&self, data: EmbeddingsData) -> Result { - let project_id = self.get_project_id()?; - let location = self.get_location()?; - let access_token = get_access_token(self.name())?; - - let base_url = format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers"); - let url = format!("{base_url}/google/models/{}:predict", self.model.name()); - - let instances: Vec<_> = data - .texts - .into_iter() - .map(|v| json!({"content": v})) - .collect(); - - let body = json!({ - "instances": instances, - }); - - let mut request_data = RequestData::new(url, body); - - request_data.bearer_auth(access_token); - - Ok(request_data) - } } -#[async_trait] +#[async_trait::async_trait] impl Client for VertexAIClient { client_common_fns!(); @@ -127,13 +42,14 @@ impl Client for VertexAIClient { data: ChatCompletionsData, ) -> Result { prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?; - let model_category = ModelCategory::from_str(self.model.name())?; - let request_data = self.prepare_chat_completions(data, &model_category)?; + let model = self.model(); + let model_category = ModelCategory::from_str(model.name())?; + let request_data = prepare_chat_completions(self, data, &model_category)?; let builder = self.request_builder(client, request_data, ApiType::ChatCompletions); match model_category { - ModelCategory::Gemini => gemini_chat_completions(builder).await, - ModelCategory::Claude => claude_chat_completions(builder).await, - ModelCategory::Mistral => openai_chat_completions(builder).await, + ModelCategory::Gemini => gemini_chat_completions(builder, model).await, + ModelCategory::Claude => claude_chat_completions(builder, model).await, + ModelCategory::Mistral => openai_chat_completions(builder, model).await, } } @@ -144,13 +60,20 @@ impl Client for VertexAIClient { data: ChatCompletionsData, ) -> Result<()> { prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?; - let model_category = ModelCategory::from_str(self.model.name())?; - let request_data = self.prepare_chat_completions(data, &model_category)?; + let model = self.model(); + let model_category = ModelCategory::from_str(model.name())?; + let request_data = prepare_chat_completions(self, data, &model_category)?; let builder = self.request_builder(client, request_data, ApiType::ChatCompletions); match model_category { - ModelCategory::Gemini => gemini_chat_completions_streaming(builder, handler).await, - ModelCategory::Claude => claude_chat_completions_streaming(builder, handler).await, - ModelCategory::Mistral => openai_chat_completions_streaming(builder, handler).await, + ModelCategory::Gemini => { + gemini_chat_completions_streaming(builder, handler, model).await + } + ModelCategory::Claude => { + claude_chat_completions_streaming(builder, handler, model).await + } + ModelCategory::Mistral => { + openai_chat_completions_streaming(builder, handler, model).await + } } } @@ -160,13 +83,100 @@ impl Client for VertexAIClient { data: EmbeddingsData, ) -> Result>> { prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?; - let request_data = self.prepare_embeddings(data)?; + let request_data = prepare_embeddings(self, data)?; let builder = self.request_builder(client, request_data, ApiType::Embeddings); - embeddings(builder).await + embeddings(builder, self.model()).await } } -pub async fn gemini_chat_completions(builder: RequestBuilder) -> Result { +fn prepare_chat_completions( + self_: &VertexAIClient, + data: ChatCompletionsData, + model_category: &ModelCategory, +) -> Result { + let project_id = self_.get_project_id()?; + let location = self_.get_location()?; + let access_token = get_access_token(self_.name())?; + + let base_url = format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers"); + + let model_name = self_.model.name(); + + let url = match model_category { + ModelCategory::Gemini => { + let func = match data.stream { + true => "streamGenerateContent", + false => "generateContent", + }; + format!("{base_url}/google/models/{model_name}:{func}") + } + ModelCategory::Claude => { + format!("{base_url}/anthropic/models/{model_name}:streamRawPredict") + } + ModelCategory::Mistral => { + let func = match data.stream { + true => "streamRawPredict", + false => "rawPredict", + }; + format!("{base_url}/mistralai/models/{model_name}:{func}") + } + }; + + let body = match model_category { + ModelCategory::Gemini => gemini_build_chat_completions_body(data, &self_.model)?, + ModelCategory::Claude => { + let mut body = claude_build_chat_completions_body(data, &self_.model)?; + if let Some(body_obj) = body.as_object_mut() { + body_obj.remove("model"); + } + body["anthropic_version"] = "vertex-2023-10-16".into(); + body + } + ModelCategory::Mistral => { + let mut body = openai_build_chat_completions_body(data, &self_.model); + if let Some(body_obj) = body.as_object_mut() { + body_obj["model"] = strip_model_version(self_.model.name()).into(); + } + body + } + }; + + let mut request_data = RequestData::new(url, body); + + request_data.bearer_auth(access_token); + + Ok(request_data) +} + +fn prepare_embeddings(self_: &VertexAIClient, data: EmbeddingsData) -> Result { + let project_id = self_.get_project_id()?; + let location = self_.get_location()?; + let access_token = get_access_token(self_.name())?; + + let base_url = format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers"); + let url = format!("{base_url}/google/models/{}:predict", self_.model.name()); + + let instances: Vec<_> = data + .texts + .into_iter() + .map(|v| json!({"content": v})) + .collect(); + + let body = json!({ + "instances": instances, + }); + + let mut request_data = RequestData::new(url, body); + + request_data.bearer_auth(access_token); + + Ok(request_data) +} + +pub async fn gemini_chat_completions( + builder: RequestBuilder, + _model: &Model, +) -> Result { let res = builder.send().await?; let status = res.status(); let data: Value = res.json().await?; @@ -180,6 +190,7 @@ pub async fn gemini_chat_completions(builder: RequestBuilder) -> Result Result<()> { let res = builder.send().await?; let status = res.status(); @@ -217,7 +228,7 @@ pub async fn gemini_chat_completions_streaming( Ok(()) } -async fn embeddings(builder: RequestBuilder) -> Result { +async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result { let res = builder.send().await?; let status = res.status(); let data: Value = res.json().await?;