From 3f7cff79c0069099843dd0d96ccbd7a98f8573f6 Mon Sep 17 00:00:00 2001 From: sigoden Date: Thu, 26 Oct 2023 16:42:54 +0800 Subject: [PATCH] feat: support multi bots and custom url (#150) --- Cargo.lock | 12 +++ Cargo.toml | 1 + src/client.rs | 177 ---------------------------------- src/client/localai.rs | 181 ++++++++++++++++++++++++++++++++++ src/client/mod.rs | 198 ++++++++++++++++++++++++++++++++++++++ src/client/openai.rs | 219 ++++++++++++++++++++++++++++++++++++++++++ src/config/mod.rs | 192 ++++++++++++++++-------------------- src/main.rs | 28 ++++-- src/render/mod.rs | 4 +- src/repl/handler.rs | 8 +- src/repl/mod.rs | 4 +- src/utils/mod.rs | 28 ++---- 12 files changed, 728 insertions(+), 324 deletions(-) delete mode 100644 src/client.rs create mode 100644 src/client/localai.rs create mode 100644 src/client/mod.rs create mode 100644 src/client/openai.rs diff --git a/Cargo.lock b/Cargo.lock index 17ec999a..ccfde4a8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -32,6 +32,7 @@ version = "0.8.0" dependencies = [ "anyhow", "arboard", + "async-trait", "atty", "base64", "bincode", @@ -154,6 +155,17 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b" +[[package]] +name = "async-trait" +version = "0.1.74" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.38", +] + [[package]] name = "atty" version = "0.2.14" diff --git a/Cargo.toml b/Cargo.toml index 3d71497a..f17c01a7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ rustc-hash = "1.1.0" bstr = "1.3.0" nu-ansi-term = "0.47.0" arboard = { version = "3.2.0", default-features = false } +async-trait = "0.1.74" [dependencies.reqwest] version = "0.11.14" diff --git a/src/client.rs b/src/client.rs deleted file mode 100644 index 74871a0a..00000000 --- a/src/client.rs +++ /dev/null @@ -1,177 +0,0 @@ -use crate::config::SharedConfig; -use crate::repl::{ReplyStreamHandler, SharedAbortSignal}; - -use anyhow::{anyhow, bail, Context, Result}; -use eventsource_stream::Eventsource; -use futures_util::StreamExt; -use reqwest::{Client, Proxy, RequestBuilder}; -use serde_json::{json, Value}; -use std::time::Duration; -use tokio::runtime::Runtime; -use tokio::time::sleep; - -const API_URL: &str = "https://api.openai.com/v1/chat/completions"; - -#[allow(clippy::module_name_repetitions)] -#[derive(Debug)] -pub struct ChatGptClient { - config: SharedConfig, - runtime: Runtime, -} - -impl ChatGptClient { - pub fn init(config: SharedConfig) -> Result { - let runtime = init_runtime()?; - let s = Self { config, runtime }; - let _ = s.build_client()?; // check error - Ok(s) - } - - pub fn send_message(&self, input: &str) -> Result { - self.runtime.block_on(async { - self.send_message_inner(input) - .await - .with_context(|| "Failed to fetch") - }) - } - - pub fn send_message_streaming( - &self, - input: &str, - handler: &mut ReplyStreamHandler, - ) -> Result<()> { - async fn watch_abort(abort: SharedAbortSignal) { - loop { - if abort.aborted() { - break; - } - sleep(Duration::from_millis(100)).await; - } - } - let abort = handler.get_abort(); - self.runtime.block_on(async { - tokio::select! 
{ - ret = self.send_message_streaming_inner(input, handler) => { - handler.done()?; - ret.with_context(|| "Failed to fetch stream") - } - _ = watch_abort(abort.clone()) => { - handler.done()?; - Ok(()) - }, - _ = tokio::signal::ctrl_c() => { - abort.set_ctrlc(); - Ok(()) - } - } - }) - } - - async fn send_message_inner(&self, content: &str) -> Result { - if self.config.read().dry_run { - return Ok(self.config.read().echo_messages(content)); - } - let builder = self.request_builder(content, false)?; - let data: Value = builder.send().await?.json().await?; - if let Some(err_msg) = data["error"]["message"].as_str() { - bail!("Request failed, {err_msg}"); - } - - let output = data["choices"][0]["message"]["content"] - .as_str() - .ok_or_else(|| anyhow!("Unexpected response {data}"))?; - - Ok(output.to_string()) - } - - async fn send_message_streaming_inner( - &self, - content: &str, - handler: &mut ReplyStreamHandler, - ) -> Result<()> { - if self.config.read().dry_run { - handler.text(&self.config.read().echo_messages(content))?; - return Ok(()); - } - let builder = self.request_builder(content, true)?; - let res = builder.send().await?; - if !res.status().is_success() { - let data: Value = res.json().await?; - if let Some(err_msg) = data["error"]["message"].as_str() { - bail!("Request failed, {err_msg}"); - } - bail!("Request failed"); - } - let mut stream = res.bytes_stream().eventsource(); - while let Some(part) = stream.next().await { - let chunk = part?.data; - if chunk == "[DONE]" { - break; - } - let data: Value = serde_json::from_str(&chunk)?; - let text = data["choices"][0]["delta"]["content"] - .as_str() - .unwrap_or_default(); - if text.is_empty() { - continue; - } - handler.text(text)?; - } - - Ok(()) - } - - fn build_client(&self) -> Result { - let mut builder = Client::builder(); - if let Some(proxy) = self.config.read().proxy.as_ref() { - builder = builder - .proxy(Proxy::all(proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?); - } - let timeout = self.config.read().get_connect_timeout(); - let client = builder - .connect_timeout(timeout) - .build() - .with_context(|| "Failed to build http client")?; - Ok(client) - } - - fn request_builder(&self, content: &str, stream: bool) -> Result { - let (model, _) = self.config.read().get_model(); - let messages = self.config.read().build_messages(content)?; - let mut body = json!({ - "model": model, - "messages": messages, - }); - - if let Some(v) = self.config.read().get_temperature() { - body.as_object_mut() - .and_then(|m| m.insert("temperature".into(), json!(v))); - } - - if stream { - body.as_object_mut() - .and_then(|m| m.insert("stream".into(), json!(true))); - } - - let (api_key, organization_id) = self.config.read().get_api_key(); - - let mut builder = self - .build_client()? 
- .post(API_URL) - .bearer_auth(api_key) - .json(&body); - - if let Some(organization_id) = organization_id { - builder = builder.header("OpenAI-Organization", organization_id); - } - - Ok(builder) - } -} - -fn init_runtime() -> Result { - tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .with_context(|| "Failed to init tokio") -} diff --git a/src/client/localai.rs b/src/client/localai.rs new file mode 100644 index 00000000..4375c1d5 --- /dev/null +++ b/src/client/localai.rs @@ -0,0 +1,181 @@ +use super::openai::{openai_send_message, openai_send_message_streaming}; +use super::{Client, ModelInfo}; + +use crate::config::SharedConfig; +use crate::repl::ReplyStreamHandler; + +use anyhow::{anyhow, Context, Result}; +use async_trait::async_trait; +use inquire::{Confirm, Text}; +use reqwest::{Client as ReqwestClient, Proxy, RequestBuilder}; +use serde::Deserialize; +use serde_json::json; +use std::time::Duration; +use tokio::runtime::Runtime; + +#[allow(clippy::module_name_repetitions)] +#[derive(Debug)] +pub struct LocalAIClient { + global_config: SharedConfig, + local_config: LocalAIConfig, + model_info: ModelInfo, + runtime: Runtime, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct LocalAIConfig { + pub url: String, + pub api_key: Option, + pub models: Vec, + pub proxy: Option, + /// Set a timeout in seconds for connect to server + pub connect_timeout: Option, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct LocalAIModel { + name: String, + max_tokens: usize, +} + +#[async_trait] +impl Client for LocalAIClient { + fn get_config(&self) -> &SharedConfig { + &self.global_config + } + + fn get_runtime(&self) -> &Runtime { + &self.runtime + } + + async fn send_message_inner(&self, content: &str) -> Result { + let builder = self.request_builder(content, false)?; + openai_send_message(builder).await + } + + async fn send_message_streaming_inner( + &self, + content: &str, + handler: &mut ReplyStreamHandler, + ) -> Result<()> { + let builder = self.request_builder(content, true)?; + openai_send_message_streaming(builder, handler).await + } +} + +impl LocalAIClient { + pub fn new( + global_config: SharedConfig, + local_config: LocalAIConfig, + model_info: ModelInfo, + runtime: Runtime, + ) -> Self { + Self { + global_config, + local_config, + model_info, + runtime, + } + } + + pub fn name() -> &'static str { + "localai" + } + + pub fn list_models(local_config: &LocalAIConfig) -> Vec<(String, usize)> { + local_config + .models + .iter() + .map(|v| (v.name.to_string(), v.max_tokens)) + .collect() + } + + pub fn create_config() -> Result { + let mut client_config = format!("clients:\n - type: {}\n", Self::name()); + + let url = Text::new("URL:") + .prompt() + .map_err(|_| anyhow!("An error happened when asking for url, try again later."))?; + + client_config.push_str(&format!(" url: {url}\n")); + + let ans = Confirm::new("Use auth?") + .with_default(false) + .prompt() + .map_err(|_| anyhow!("Not finish questionnaire, try again later."))?; + + if ans { + let api_key = Text::new("API key:").prompt().map_err(|_| { + anyhow!("An error happened when asking for api key, try again later.") + })?; + + client_config.push_str(&format!(" api_key: {api_key}\n")); + } + + let model_name = Text::new("Model Name:").prompt().map_err(|_| { + anyhow!("An error happened when asking for model name, try again later.") + })?; + + let max_tokens = Text::new("Max tokens:").prompt().map_err(|_| { + anyhow!("An error happened when asking for max tokens, try again later.") + })?; + + let ans = 
Confirm::new("Use proxy?") + .with_default(false) + .prompt() + .map_err(|_| anyhow!("Not finish questionnaire, try again later."))?; + + if ans { + let proxy = Text::new("Set proxy:").prompt().map_err(|_| { + anyhow!("An error happened when asking for proxy, try again later.") + })?; + client_config.push_str(&format!(" proxy: {proxy}\n")); + } + + client_config.push_str(&format!( + " models:\n - name: {model_name}\n max_tokens: {max_tokens}\n" + )); + + Ok(client_config) + } + + fn request_builder(&self, content: &str, stream: bool) -> Result { + let messages = self.global_config.read().build_messages(content)?; + + let mut body = json!({ + "model": self.model_info.name, + "messages": messages, + }); + + if let Some(v) = self.global_config.read().get_temperature() { + body.as_object_mut() + .and_then(|m| m.insert("temperature".into(), json!(v))); + } + + if stream { + body.as_object_mut() + .and_then(|m| m.insert("stream".into(), json!(true))); + } + + let client = { + let mut builder = ReqwestClient::builder(); + if let Some(proxy) = &self.local_config.proxy { + builder = builder + .proxy(Proxy::all(proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?); + } + let timeout = Duration::from_secs(self.local_config.connect_timeout.unwrap_or(10)); + builder + .connect_timeout(timeout) + .build() + .with_context(|| "Failed to build client")? + }; + + let mut builder = client.post(&self.local_config.url); + if let Some(api_key) = &self.local_config.api_key { + builder = builder.bearer_auth(api_key); + }; + builder = builder.json(&body); + + Ok(builder) + } +} diff --git a/src/client/mod.rs b/src/client/mod.rs new file mode 100644 index 00000000..b541cfff --- /dev/null +++ b/src/client/mod.rs @@ -0,0 +1,198 @@ +pub mod localai; +pub mod openai; + +use self::{ + localai::LocalAIConfig, + openai::{OpenAIClient, OpenAIConfig}, +}; + +use anyhow::{bail, Context, Result}; +use async_trait::async_trait; +use serde::Deserialize; +use std::time::Duration; +use tokio::runtime::Runtime; +use tokio::time::sleep; + +use crate::{ + client::localai::LocalAIClient, + config::{Config, SharedConfig}, + repl::{ReplyStreamHandler, SharedAbortSignal}, +}; + +#[allow(clippy::struct_excessive_bools)] +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type")] +pub enum ClientConfig { + #[serde(rename = "openai")] + OpenAI(OpenAIConfig), + #[serde(rename = "localai")] + LocalAI(LocalAIConfig), +} + +#[derive(Debug, Clone)] +pub struct ModelInfo { + pub client: String, + pub name: String, + pub max_tokens: usize, + pub index: usize, +} + +impl Default for ModelInfo { + fn default() -> Self { + let client = OpenAIClient::name(); + let (name, max_tokens) = &OpenAIClient::list_models(&OpenAIConfig::default())[0]; + Self::new(client, name, *max_tokens, 0) + } +} + +impl ModelInfo { + pub fn new(client: &str, name: &str, max_tokens: usize, index: usize) -> Self { + Self { + client: client.into(), + name: name.into(), + max_tokens, + index, + } + } + pub fn stringify(&self) -> String { + format!("{}:{}", self.client, self.name) + } +} + +#[async_trait] +pub trait Client { + fn get_config(&self) -> &SharedConfig; + + fn get_runtime(&self) -> &Runtime; + + fn send_message(&self, content: &str) -> Result { + self.get_runtime().block_on(async { + if self.get_config().read().dry_run { + return Ok(self.get_config().read().echo_messages(content)); + } + self.send_message_inner(content) + .await + .with_context(|| "Failed to fetch") + }) + } + + fn send_message_streaming( + &self, + content: &str, + handler: &mut 
ReplyStreamHandler, + ) -> Result<()> { + async fn watch_abort(abort: SharedAbortSignal) { + loop { + if abort.aborted() { + break; + } + sleep(Duration::from_millis(100)).await; + } + } + let abort = handler.get_abort(); + self.get_runtime().block_on(async { + tokio::select! { + ret = async { + if self.get_config().read().dry_run { + handler.text(&self.get_config().read().echo_messages(content))?; + return Ok(()); + } + self.send_message_streaming_inner(content, handler).await + } => { + handler.done()?; + ret.with_context(|| "Failed to fetch stream") + } + _ = watch_abort(abort.clone()) => { + handler.done()?; + Ok(()) + }, + _ = tokio::signal::ctrl_c() => { + abort.set_ctrlc(); + Ok(()) + } + } + }) + } + + async fn send_message_inner(&self, content: &str) -> Result; + + async fn send_message_streaming_inner( + &self, + content: &str, + handler: &mut ReplyStreamHandler, + ) -> Result<()>; +} + +pub fn init_client(config: SharedConfig, runtime: Runtime) -> Result> { + let model_info = config.read().model_info.clone(); + let model_info_err = |model_info: &ModelInfo| { + bail!( + "Unknown client {} at config.clients[{}]", + &model_info.client, + &model_info.index + ) + }; + if model_info.client == OpenAIClient::name() { + let local_config = { + if let ClientConfig::OpenAI(c) = &config.read().clients[model_info.index] { + c.clone() + } else { + return model_info_err(&model_info); + } + }; + Ok(Box::new(OpenAIClient::new( + config, + local_config, + model_info, + runtime, + ))) + } else if model_info.client == LocalAIClient::name() { + let local_config = { + if let ClientConfig::LocalAI(c) = &config.read().clients[model_info.index] { + c.clone() + } else { + return model_info_err(&model_info); + } + }; + Ok(Box::new(LocalAIClient::new( + config, + local_config, + model_info, + runtime, + ))) + } else { + bail!("Unknown client {}", &model_info.client) + } +} + +pub fn all_clients() -> Vec<&'static str> { + vec![OpenAIClient::name(), LocalAIClient::name()] +} + +pub fn create_client_config(client: &str) -> Result { + if client == OpenAIClient::name() { + OpenAIClient::create_config() + } else if client == LocalAIClient::name() { + LocalAIClient::create_config() + } else { + bail!("Unknown client {}", &client) + } +} + +pub fn list_models(config: &Config) -> Vec { + config + .clients + .iter() + .enumerate() + .flat_map(|(i, v)| match v { + ClientConfig::OpenAI(c) => OpenAIClient::list_models(c) + .iter() + .map(|(x, y)| ModelInfo::new(OpenAIClient::name(), x, *y, i)) + .collect::>(), + ClientConfig::LocalAI(c) => LocalAIClient::list_models(c) + .iter() + .map(|(x, y)| ModelInfo::new(LocalAIClient::name(), x, *y, i)) + .collect::>(), + }) + .collect() +} diff --git a/src/client/openai.rs b/src/client/openai.rs new file mode 100644 index 00000000..932dc796 --- /dev/null +++ b/src/client/openai.rs @@ -0,0 +1,219 @@ +use super::{Client, ModelInfo}; + +use crate::repl::ReplyStreamHandler; +use crate::{config::SharedConfig, utils::get_env_name}; + +use anyhow::{anyhow, bail, Context, Result}; +use async_trait::async_trait; +use eventsource_stream::Eventsource; +use futures_util::StreamExt; +use inquire::{Confirm, Text}; +use reqwest::{Client as ReqwestClient, Proxy, RequestBuilder}; +use serde::Deserialize; +use serde_json::{json, Value}; +use std::env; +use std::time::Duration; +use tokio::runtime::Runtime; + +const API_URL: &str = "https://api.openai.com/v1/chat/completions"; + +#[allow(clippy::module_name_repetitions)] +#[derive(Debug)] +pub struct OpenAIClient { + global_config: SharedConfig, + 
local_config: OpenAIConfig, + model_info: ModelInfo, + runtime: Runtime, +} + +#[allow(clippy::struct_excessive_bools)] +#[derive(Debug, Clone, Deserialize, Default)] +pub struct OpenAIConfig { + pub api_key: Option, + pub organization_id: Option, + pub proxy: Option, + /// Set a timeout in seconds for connect to openai server + pub connect_timeout: Option, +} + +#[async_trait] +impl Client for OpenAIClient { + fn get_config(&self) -> &SharedConfig { + &self.global_config + } + + fn get_runtime(&self) -> &Runtime { + &self.runtime + } + + async fn send_message_inner(&self, content: &str) -> Result { + let builder = self.request_builder(content, false)?; + openai_send_message(builder).await + } + + async fn send_message_streaming_inner( + &self, + content: &str, + handler: &mut ReplyStreamHandler, + ) -> Result<()> { + let builder = self.request_builder(content, true)?; + openai_send_message_streaming(builder, handler).await + } +} + +impl OpenAIClient { + pub fn new( + global_config: SharedConfig, + local_config: OpenAIConfig, + model_info: ModelInfo, + runtime: Runtime, + ) -> Self { + Self { + global_config, + local_config, + model_info, + runtime, + } + } + + pub fn name() -> &'static str { + "openai" + } + + pub fn list_models(_local_config: &OpenAIConfig) -> Vec<(String, usize)> { + vec![ + ("gpt-3.5-turbo".into(), 4096), + ("gpt-3.5-turbo-16k".into(), 16384), + ("gpt-4".into(), 8192), + ("gpt-4-32k".into(), 32768), + ] + } + + pub fn create_config() -> Result { + let mut client_config = format!("clients:\n - type: {}\n", Self::name()); + + let api_key = Text::new("API key:") + .prompt() + .map_err(|_| anyhow!("An error happened when asking for api key, try again later."))?; + + client_config.push_str(&format!(" api_key: {api_key}\n")); + + let ans = Confirm::new("Has Organization?") + .with_default(false) + .prompt() + .map_err(|_| anyhow!("Not finish questionnaire, try again later."))?; + + if ans { + let organization_id = Text::new("Organization ID:").prompt().map_err(|_| { + anyhow!("An error happened when asking for proxy, try again later.") + })?; + client_config.push_str(&format!(" organization_id: {organization_id}\n")); + } + + let ans = Confirm::new("Use proxy?") + .with_default(false) + .prompt() + .map_err(|_| anyhow!("Not finish questionnaire, try again later."))?; + + if ans { + let proxy = Text::new("Set proxy:").prompt().map_err(|_| { + anyhow!("An error happened when asking for proxy, try again later.") + })?; + client_config.push_str(&format!(" proxy: {proxy}\n")); + } + + Ok(client_config) + } + + fn request_builder(&self, content: &str, stream: bool) -> Result { + let api_key = if let Some(api_key) = &self.local_config.api_key { + api_key.to_string() + } else if let Ok(api_key) = env::var(get_env_name("api_key")) { + api_key.to_string() + } else { + bail!("Miss api_key") + }; + + let messages = self.global_config.read().build_messages(content)?; + + let mut body = json!({ + "model": self.model_info.name, + "messages": messages, + }); + + if let Some(v) = self.global_config.read().get_temperature() { + body.as_object_mut() + .and_then(|m| m.insert("temperature".into(), json!(v))); + } + + if stream { + body.as_object_mut() + .and_then(|m| m.insert("stream".into(), json!(true))); + } + + let client = { + let mut builder = ReqwestClient::builder(); + if let Some(proxy) = &self.local_config.proxy { + builder = builder + .proxy(Proxy::all(proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?); + } + let timeout = 
Duration::from_secs(self.local_config.connect_timeout.unwrap_or(10)); + builder + .connect_timeout(timeout) + .build() + .with_context(|| "Failed to build client")? + }; + + let mut builder = client.post(API_URL).bearer_auth(api_key).json(&body); + + if let Some(organization_id) = &self.local_config.organization_id { + builder = builder.header("OpenAI-Organization", organization_id); + } + + Ok(builder) + } +} + +pub(crate) async fn openai_send_message(builder: RequestBuilder) -> Result { + let data: Value = builder.send().await?.json().await?; + if let Some(err_msg) = data["error"]["message"].as_str() { + bail!("Request failed, {err_msg}"); + } + + let output = data["choices"][0]["message"]["content"] + .as_str() + .ok_or_else(|| anyhow!("Unexpected response {data}"))?; + + Ok(output.to_string()) +} + +pub(crate) async fn openai_send_message_streaming( + builder: RequestBuilder, + handler: &mut ReplyStreamHandler, +) -> Result<()> { + let res = builder.send().await?; + if !res.status().is_success() { + let data: Value = res.json().await?; + if let Some(err_msg) = data["error"]["message"].as_str() { + bail!("Request failed, {err_msg}"); + } + bail!("Request failed"); + } + let mut stream = res.bytes_stream().eventsource(); + while let Some(part) = stream.next().await { + let chunk = part?.data; + if chunk == "[DONE]" { + break; + } + let data: Value = serde_json::from_str(&chunk)?; + let text = data["choices"][0]["delta"]["content"] + .as_str() + .unwrap_or_default(); + if text.is_empty() { + continue; + } + handler.text(text)?; + } + + Ok(()) +} diff --git a/src/config/mod.rs b/src/config/mod.rs index 3fc3f5bc..4a0a1264 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -6,14 +6,15 @@ use self::conversation::Conversation; use self::message::Message; use self::role::Role; +use crate::client::openai::{OpenAIClient, OpenAIConfig}; +use crate::client::{all_clients, create_client_config, list_models, ClientConfig, ModelInfo}; use crate::config::message::num_tokens_from_messages; -use crate::utils::{mask_text, now}; +use crate::utils::{get_env_name, now}; use anyhow::{anyhow, bail, Context, Result}; -use inquire::{Confirm, Text}; +use inquire::{Confirm, Select}; use parking_lot::RwLock; use serde::Deserialize; -use std::time::Duration; use std::{ env, fs::{create_dir_all, read_to_string, File, OpenOptions}, @@ -23,24 +24,16 @@ use std::{ sync::Arc, }; -pub const MODELS: [(&str, usize); 4] = [ - ("gpt-4", 8192), - ("gpt-4-32k", 32768), - ("gpt-3.5-turbo", 4096), - ("gpt-3.5-turbo-16k", 16384), -]; - const CONFIG_FILE_NAME: &str = "config.yaml"; const ROLES_FILE_NAME: &str = "roles.yaml"; const HISTORY_FILE_NAME: &str = "history.txt"; const MESSAGE_FILE_NAME: &str = "messages.md"; -const SET_COMPLETIONS: [&str; 8] = [ +const SET_COMPLETIONS: [&str; 7] = [ ".set temperature", ".set save true", ".set save false", ".set highlight true", ".set highlight false", - ".set proxy", ".set dry_run true", ".set dry_run false", ]; @@ -49,33 +42,26 @@ const SET_COMPLETIONS: [&str; 8] = [ #[derive(Debug, Clone, Deserialize)] #[serde(default)] pub struct Config { - /// OpenAI api key - pub api_key: Option, - /// OpenAI organization id - pub organization_id: Option, - /// OpenAI model - #[serde(rename(serialize = "model", deserialize = "model"))] - pub model_name: Option, + /// LLM model + pub model: Option, /// What sampling temperature to use, between 0 and 2 pub temperature: Option, /// Whether to persistently save chat messages pub save: bool, /// Whether to disable highlight pub highlight: bool, - /// 
Set proxy - pub proxy: Option, /// Used only for debugging pub dry_run: bool, /// If set ture, start a conversation immediately upon repl pub conversation_first: bool, /// Is ligth theme pub light_theme: bool, - /// Set a timeout in seconds for connect to gpt - pub connect_timeout: usize, /// Automatically copy the last output to the clipboard pub auto_copy: bool, /// Use vi keybindings, overriding the default Emacs keybindings pub vi_keybindings: bool, + /// LLM clients + pub clients: Vec, /// Predefined roles #[serde(skip)] pub roles: Vec, @@ -86,29 +72,26 @@ pub struct Config { #[serde(skip)] pub conversation: Option, #[serde(skip)] - pub model: (String, usize), + pub model_info: ModelInfo, } impl Default for Config { fn default() -> Self { Self { - api_key: None, - organization_id: None, - model_name: None, + model: None, temperature: None, save: false, highlight: true, - proxy: None, dry_run: false, conversation_first: false, light_theme: false, - connect_timeout: 10, auto_copy: false, vi_keybindings: false, roles: vec![], + clients: vec![ClientConfig::OpenAI(OpenAIConfig::default())], role: None, conversation: None, - model: ("gpt-3.5-turbo".into(), 4096), + model_info: Default::default(), } } } @@ -118,27 +101,29 @@ pub type SharedConfig = Arc>; impl Config { pub fn init(is_interactive: bool) -> Result { - let api_key = env::var(get_env_name("api_key")).ok(); let config_path = Self::config_file()?; - if is_interactive && api_key.is_none() && !config_path.exists() { + + let api_key = env::var(get_env_name("api_key")).ok(); + + let exist_config_path = config_path.exists(); + if is_interactive && api_key.is_none() && !exist_config_path { create_config_file(&config_path)?; } - let mut config = if api_key.is_some() && !config_path.exists() { + let mut config = if api_key.is_some() && !exist_config_path { Self::default() } else { Self::load_config(&config_path)? 
}; - if api_key.is_some() { - config.api_key = api_key; - } - if config.api_key.is_none() { - bail!("api_key not set"); + + // Compatible with old configuration files + if exist_config_path { + config.compat_old_config(&config_path)?; } - if let Some(name) = config.model_name.clone() { + + if let Some(name) = config.model.clone() { config.set_model(&name)?; } config.merge_env_vars(); - config.maybe_proxy(); config.load_roles()?; Ok(config) @@ -211,12 +196,6 @@ impl Config { Self::local_file(CONFIG_FILE_NAME) } - pub fn get_api_key(&self) -> (String, Option) { - let api_key = self.api_key.as_ref().expect("api_key not set"); - let organization_id = self.organization_id.as_ref(); - (api_key.into(), organization_id.cloned()) - } - pub fn roles_file() -> Result { let env_name = get_env_name("roles_file"); env::var(env_name).map_or_else( @@ -283,14 +262,6 @@ impl Config { } } - pub const fn get_connect_timeout(&self) -> Duration { - Duration::from_secs(self.connect_timeout as u64) - } - - pub fn get_model(&self) -> (String, usize) { - self.model.clone() - } - pub fn build_messages(&self, content: &str) -> Result> { #[allow(clippy::option_if_let_else)] let messages = if let Some(conversation) = self.conversation.as_ref() { @@ -302,24 +273,29 @@ impl Config { vec![message] }; let tokens = num_tokens_from_messages(&messages); - if tokens >= self.model.1 { + if tokens >= self.model_info.max_tokens { bail!("Exceed max tokens limit") } Ok(messages) } - pub fn set_model(&mut self, name: &str) -> Result<()> { - if let Some(token) = MODELS.iter().find(|(v, _)| *v == name).map(|(_, v)| *v) { - self.model = (name.to_string(), token); - } else { - bail!("Invalid model") + pub fn set_model(&mut self, value: &str) -> Result<()> { + let models = list_models(self); + if value.contains(':') { + if let Some(model) = models.iter().find(|v| v.stringify() == value) { + self.model_info = model.clone(); + return Ok(()); + } + } else if let Some(model) = models.iter().find(|v| v.client == value) { + self.model_info = model.clone(); + return Ok(()); } - Ok(()) + bail!("Invalid model") } pub const fn get_reamind_tokens(&self) -> usize { - let mut tokens = self.model.1; + let mut tokens = self.model_info.max_tokens; if let Some(conversation) = self.conversation.as_ref() { tokens = tokens.saturating_sub(conversation.tokens); } @@ -331,30 +307,19 @@ impl Config { let state = if path.exists() { "" } else { " ⚠️" }; format!("{}{state}", path.display()) }; - let proxy = self - .proxy - .as_ref() - .map_or_else(|| String::from("-"), std::string::ToString::to_string); let temperature = self .temperature .map_or_else(|| String::from("-"), |v| v.to_string()); - let (api_key, organization_id) = self.get_api_key(); - let api_key = mask_text(&api_key, 3, 4); - let organization_id = organization_id.map_or_else(|| "-".into(), |v| mask_text(&v, 3, 4)); let items = vec![ ("config_file", file_info(&Self::config_file()?)), ("roles_file", file_info(&Self::roles_file()?)), ("messages_file", file_info(&Self::messages_file()?)), - ("api_key", api_key), - ("organization_id", organization_id), - ("model", self.model.0.to_string()), + ("model", self.model_info.stringify()), ("temperature", temperature), ("save", self.save.to_string()), ("highlight", self.highlight.to_string()), - ("proxy", proxy), ("conversation_first", self.conversation_first.to_string()), ("light_theme", self.light_theme.to_string()), - ("connect_timeout", self.connect_timeout.to_string()), ("dry_run", self.dry_run.to_string()), ("vi_keybindings", 
self.vi_keybindings.to_string()), ]; @@ -373,7 +338,11 @@ impl Config { .collect(); completion.extend(SET_COMPLETIONS.map(std::string::ToString::to_string)); - completion.extend(MODELS.map(|(v, _)| format!(".model {v}"))); + completion.extend( + list_models(self) + .iter() + .map(|v| format!(".model {}", v.stringify())), + ); completion } @@ -402,13 +371,6 @@ impl Config { let value = value.parse().with_context(|| "Invalid value")?; self.highlight = value; } - "proxy" => { - if unset { - self.proxy = None; - } else { - self.proxy = Some(value.to_string()); - } - } "dry_run" => { let value = value.parse().with_context(|| "Invalid value")?; self.dry_run = value; @@ -501,44 +463,62 @@ impl Config { } } - fn maybe_proxy(&mut self) { - if self.proxy.is_some() { - return; + fn compat_old_config(&mut self, config_path: &PathBuf) -> Result<()> { + let content = read_to_string(config_path)?; + let value: serde_json::Value = serde_yaml::from_str(&content)?; + if value.get("client").is_some() { + return Ok(()); } - if let Ok(value) = env::var("HTTPS_PROXY").or_else(|_| env::var("ALL_PROXY")) { - self.proxy = Some(value); + + if let Some(model_name) = value.get("model").and_then(|v| v.as_str()) { + if model_name.starts_with("gpt") { + self.model = Some(format!("{}:{}", OpenAIClient::name(), model_name)); + } } + + if let Some(ClientConfig::OpenAI(client_config)) = self.clients.get_mut(0) { + if let Some(api_key) = value.get("api_key").and_then(|v| v.as_str()) { + client_config.api_key = Some(api_key.to_string()) + } + + if let Some(organization_id) = value.get("organization_id").and_then(|v| v.as_str()) { + client_config.organization_id = Some(organization_id.to_string()) + } + + if let Some(proxy) = value.get("proxy").and_then(|v| v.as_str()) { + client_config.proxy = Some(proxy.to_string()) + } + + if let Some(connect_timeout) = value.get("connect_timeout").and_then(|v| v.as_i64()) { + client_config.connect_timeout = Some(connect_timeout as _) + } + } + Ok(()) } } fn create_config_file(config_path: &Path) -> Result<()> { - let confirm_map_err = |_| anyhow!("Not finish questionnaire, try again later."); - let text_map_err = |_| anyhow!("An error happened when asking for your key, try again later."); let ans = Confirm::new("No config file, create a new one?") .with_default(true) .prompt() - .map_err(confirm_map_err)?; + .map_err(|_| anyhow!("Not finish questionnaire, try again later."))?; if !ans { exit(0); } - let api_key = Text::new("OpenAI API Key:") - .prompt() - .map_err(text_map_err)?; - let mut raw_config = format!("api_key: {api_key}\n"); - let ans = Confirm::new("Use proxy?") - .with_default(false) + let client = Select::new("Choose bots?", all_clients()) .prompt() - .map_err(confirm_map_err)?; - if ans { - let proxy = Text::new("Set proxy:").prompt().map_err(text_map_err)?; - raw_config.push_str(&format!("proxy: {proxy}\n")); - } + .map_err(|_| anyhow!("An error happened when selecting bots, try again later."))?; + + let mut raw_config = create_client_config(client)?; + + raw_config.push_str(&format!("model: {client}\n")); let ans = Confirm::new("Save chat messages") .with_default(true) .prompt() - .map_err(confirm_map_err)?; + .map_err(|_| anyhow!("Not finish questionnaire, try again later."))?; + if ans { raw_config.push_str("save: true\n"); } @@ -571,14 +551,6 @@ fn ensure_parent_exists(path: &Path) -> Result<()> { Ok(()) } -fn get_env_name(key: &str) -> String { - format!( - "{}_{}", - env!("CARGO_CRATE_NAME").to_ascii_uppercase(), - key.to_ascii_uppercase(), - ) -} - fn set_bool(target: 
&mut bool, value: &str) { match value { "1" | "true" => *target = true, diff --git a/src/main.rs b/src/main.rs index b201fd6c..ec8abeef 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,11 +8,12 @@ mod term; mod utils; use crate::cli::Cli; -use crate::client::ChatGptClient; +use crate::client::Client; use crate::config::{Config, SharedConfig}; -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context, Result}; use clap::Parser; +use client::{init_client, list_models}; use crossbeam::sync::WaitGroup; use is_terminal::IsTerminal; use parking_lot::RwLock; @@ -21,6 +22,7 @@ use repl::{AbortSignal, Repl}; use std::io::{stdin, Read}; use std::sync::Arc; use std::{io::stdout, process::exit}; +use tokio::runtime::Runtime; use utils::cl100k_base_singleton; fn main() -> Result<()> { @@ -36,8 +38,8 @@ fn main() -> Result<()> { exit(0); } if cli.list_models { - for (name, _) in &config::MODELS { - println!("{name}"); + for model in list_models(&config.read()) { + println!("{}", model.stringify()); } exit(0); } @@ -69,24 +71,25 @@ fn main() -> Result<()> { exit(0); } let no_stream = cli.no_stream; - let client = ChatGptClient::init(config.clone())?; + let runtime = init_runtime()?; + let client = init_client(config.clone(), runtime)?; if atty::isnt(atty::Stream::Stdin) { let mut input = String::new(); stdin().read_to_string(&mut input)?; if let Some(text) = text { input = format!("{text}\n{input}"); } - start_directive(&client, &config, &input, no_stream) + start_directive(client.as_ref(), &config, &input, no_stream) } else { match text { - Some(text) => start_directive(&client, &config, &text, no_stream), + Some(text) => start_directive(client.as_ref(), &config, &text, no_stream), None => start_interactive(client, config), } } } fn start_directive( - client: &ChatGptClient, + client: &dyn Client, config: &SharedConfig, input: &str, no_stream: bool, @@ -120,9 +123,16 @@ fn start_directive( config.read().save_message(input, &output) } -fn start_interactive(client: ChatGptClient, config: SharedConfig) -> Result<()> { +fn start_interactive(client: Box, config: SharedConfig) -> Result<()> { cl100k_base_singleton(); config.write().on_repl()?; let mut repl = Repl::init(config.clone())?; repl.run(client, config) } + +fn init_runtime() -> Result { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .with_context(|| "Failed to init tokio") +} diff --git a/src/render/mod.rs b/src/render/mod.rs index 20d4d0e0..e7c6228b 100644 --- a/src/render/mod.rs +++ b/src/render/mod.rs @@ -7,7 +7,7 @@ use self::cmd::cmd_render_stream; pub use self::markdown::MarkdownRender; use self::repl::repl_render_stream; -use crate::client::ChatGptClient; +use crate::client::Client; use crate::config::SharedConfig; use crate::print_now; use crate::repl::{ReplyStreamHandler, SharedAbortSignal}; @@ -20,7 +20,7 @@ use std::thread::spawn; #[allow(clippy::module_name_repetitions)] pub fn render_stream( input: &str, - client: &ChatGptClient, + client: &dyn Client, config: &SharedConfig, repl: bool, abort: SharedAbortSignal, diff --git a/src/repl/handler.rs b/src/repl/handler.rs index 2057b8be..53cdd984 100644 --- a/src/repl/handler.rs +++ b/src/repl/handler.rs @@ -1,4 +1,4 @@ -use crate::client::ChatGptClient; +use crate::client::Client; use crate::config::SharedConfig; use crate::print_now; use crate::render::render_stream; @@ -26,7 +26,7 @@ pub enum ReplCmd { #[allow(clippy::module_name_repetitions)] pub struct ReplCmdHandler { - client: ChatGptClient, + client: Box, config: SharedConfig, reply: RefCell, abort: 
SharedAbortSignal, @@ -35,7 +35,7 @@ pub struct ReplCmdHandler { impl ReplCmdHandler { #[allow(clippy::unnecessary_wraps)] pub fn init( - client: ChatGptClient, + client: Box, config: SharedConfig, abort: SharedAbortSignal, ) -> Result { @@ -59,7 +59,7 @@ impl ReplCmdHandler { let wg = WaitGroup::new(); let ret = render_stream( &input, - &self.client, + self.client.as_ref(), &self.config, true, self.abort.clone(), diff --git a/src/repl/mod.rs b/src/repl/mod.rs index 2e685288..ce640f81 100644 --- a/src/repl/mod.rs +++ b/src/repl/mod.rs @@ -9,7 +9,7 @@ pub use self::abort::*; pub use self::handler::*; pub use self::init::Repl; -use crate::client::ChatGptClient; +use crate::client::Client; use crate::config::SharedConfig; use crate::print_now; use crate::term; @@ -35,7 +35,7 @@ pub const REPL_COMMANDS: [(&str, &str); 13] = [ ]; impl Repl { - pub fn run(&mut self, client: ChatGptClient, config: SharedConfig) -> Result<()> { + pub fn run(&mut self, client: Box, config: SharedConfig) -> Result<()> { let abort = AbortSignal::new(); let handler = ReplCmdHandler::init(client, config, abort.clone())?; print_now!("Welcome to aichat {}\n", env!("CARGO_PKG_VERSION")); diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 8b5dae63..ae759696 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -24,32 +24,20 @@ pub fn now() -> String { now.to_rfc3339_opts(SecondsFormat::Secs, false) } +pub fn get_env_name(key: &str) -> String { + format!( + "{}_{}", + env!("CARGO_CRATE_NAME").to_ascii_uppercase(), + key.to_ascii_uppercase(), + ) +} + #[allow(unused)] pub fn emphasis(text: &str) -> String { text.stylize().with(Color::White).to_string() } -pub fn mask_text(text: &str, head: usize, tail: usize) -> String { - if text.len() <= head + tail { - return text.to_string(); - } - format!("{}...{}", &text[0..head], &text[text.len() - tail..]) -} - pub fn copy(src: &str) -> Result<(), arboard::Error> { let mut clipboard = Clipboard::new()?; clipboard.set_text(src) } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_mask_text() { - assert_eq!(mask_text("123456", 3, 4), "123456"); - assert_eq!(mask_text("1234567", 3, 4), "1234567"); - assert_eq!(mask_text("12345678", 3, 4), "123...5678"); - assert_eq!(mask_text("12345678", 4, 3), "1234...678"); - } -}
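
Illustration of the new configuration layout introduced by this patch (a minimal sketch, not taken from the patch itself): the top-level `api_key`, `organization_id`, `proxy`, and `connect_timeout` settings move into per-client entries under `clients`, and the active model is selected as `<client>` or `<client>:<model-name>`. The API key, proxy, URL, and model names below are placeholders.

    model: openai:gpt-3.5-turbo          # or just "openai" / "localai" to pick that client's first model
    save: true
    clients:
      - type: openai
        api_key: sk-xxxx                 # placeholder; may also be supplied via the AICHAT_API_KEY env var
        organization_id: org-xxxx        # optional
        proxy: socks5://127.0.0.1:1080   # optional
        connect_timeout: 10              # seconds, optional
      - type: localai
        url: http://localhost:8080/v1/chat/completions   # hypothetical endpoint; requests are POSTed here directly
        api_key: null                    # optional bearer token
        models:
          - name: llama2                 # hypothetical model name
            max_tokens: 4096

Older single-client config files keep working: `compat_old_config` copies any top-level `api_key`, `organization_id`, `proxy`, and `connect_timeout` values into the default OpenAI client entry, and a top-level `model` beginning with `gpt` is rewritten to `openai:<model>`.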