diff --git a/README.md b/README.md index f896c13a..cb64d4a8 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,6 @@ Download it from [GitHub Releases](https://github.com/sigoden/aichat/releases), - [x] LocalAI: user deployed opensource LLMs - [x] Azure-OpenAI: user created gpt3.5/gpt4 - [x] Gemini: gemini-pro/gemini-pro-vision/gemini-ultra -- [x] PaLM: chat-bison-001 - [x] Ernie: ernie-bot-turbo/ernie-bot/ernie-bot-8k/ernie-bot-4 - [x] Qianwen: qwen-turbo/qwen-plus/qwen-max diff --git a/config.example.yaml b/config.example.yaml index 847becac..502969b9 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -43,10 +43,6 @@ clients: - type: gemini api_key: AIxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx - # See https://developers.generativeai.google/guide - - type: palm - api_key: AIxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx - # See https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html - type: ernie api_key: xxxxxxxxxxxxxxxxxxxxxxxx diff --git a/src/client/common.rs b/src/client/common.rs index 98cffa20..9ff02ba5 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -292,6 +292,7 @@ pub fn create_config(list: &[PromptType], client: &str) -> Result { Ok(clients) } +#[allow(unused)] pub async fn send_message_as_streaming( builder: RequestBuilder, handler: &mut ReplyHandler, diff --git a/src/client/mod.rs b/src/client/mod.rs index dc6b0938..d2ada444 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -17,7 +17,6 @@ register_client!( AzureOpenAIClient ), (gemini, "gemini", GeminiConfig, GeminiClient), - (palm, "palm", PaLMConfig, PaLMClient), (ernie, "ernie", ErnieConfig, ErnieClient), (qianwen, "qianwen", QianwenConfig, QianwenClient), ); diff --git a/src/client/palm.rs b/src/client/palm.rs deleted file mode 100644 index a5197680..00000000 --- a/src/client/palm.rs +++ /dev/null @@ -1,131 +0,0 @@ -use super::{PaLMClient, Client, ExtraConfig, Model, PromptType, SendData, TokensCountFactors, send_message_as_streaming, patch_system_message}; - -use crate::{config::GlobalConfig, render::ReplyHandler, utils::PromptKind}; - -use anyhow::{anyhow, bail, Result}; -use async_trait::async_trait; -use reqwest::{Client as ReqwestClient, RequestBuilder}; -use serde::Deserialize; -use serde_json::{json, Value}; - -const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta2/models/"; - -const MODELS: [(&str, usize); 1] = [("chat-bison-001", 4096)]; - -const TOKENS_COUNT_FACTORS: TokensCountFactors = (3, 8); - -#[derive(Debug, Clone, Deserialize, Default)] -pub struct PaLMConfig { - pub name: Option, - pub api_key: Option, - pub extra: Option, -} - -#[async_trait] -impl Client for PaLMClient { - fn config(&self) -> (&GlobalConfig, &Option) { - (&self.global_config, &self.config.extra) - } - - async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result { - let builder = self.request_builder(client, data)?; - send_message(builder).await - } - - async fn send_message_streaming_inner( - &self, - client: &ReqwestClient, - handler: &mut ReplyHandler, - data: SendData, - ) -> Result<()> { - let builder = self.request_builder(client, data)?; - send_message_as_streaming(builder, handler, send_message).await - } -} - -impl PaLMClient { - config_get_fn!(api_key, get_api_key); - - pub const PROMPTS: [PromptType<'static>; 1] = - [("api_key", "API Key:", true, PromptKind::String)]; - - pub fn list_models(local_config: &PaLMConfig) -> Vec { - let client_name = Self::name(local_config); - MODELS - .into_iter() - .map(|(name, max_tokens)| { - Model::new(client_name, name) - .set_max_tokens(Some(max_tokens)) - .set_tokens_count_factors(TOKENS_COUNT_FACTORS) - }) - .collect() - } - - fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { - let api_key = self.get_api_key()?; - - let body = build_body(data, self.model.name.clone()); - - let model = self.model.name.clone(); - - let url = format!("{API_BASE}{}:generateMessage?key={}", model, api_key); - - debug!("PaLM Request: {url} {body}"); - - let builder = client.post(url).json(&body); - - Ok(builder) - } -} - -async fn send_message(builder: RequestBuilder) -> Result { - let data: Value = builder.send().await?.json().await?; - check_error(&data)?; - - let output = data["candidates"][0]["content"] - .as_str() - .ok_or_else(|| { - if let Some(reason) = data["filters"][0]["reason"].as_str() { - anyhow!("Content Filtering: {reason}") - } else { - anyhow!("Unexpected response") - } - })?; - - Ok(output.to_string()) -} - -fn check_error(data: &Value) -> Result<()> { - if let Some(error) = data["error"].as_object() { - if let Some(message) = error["message"].as_str() { - bail!("{message}"); - } else { - bail!("Error {}", Value::Object(error.clone())); - } - } - Ok(()) -} - -fn build_body(data: SendData, _model: String) -> Value { - let SendData { - mut messages, - temperature, - .. - } = data; - - patch_system_message(&mut messages); - - let messages: Vec = messages.into_iter().map(|v| json!({ "content": v.content })).collect(); - - let prompt = json!({ "messages": messages }); - - let mut body = json!({ - "prompt": prompt, - }); - - if let Some(temperature) = temperature { - body["temperature"] = (temperature / 2.0).into(); - } - - body -}