diff --git a/assets/arena.html b/assets/arena.html index b43c4640..98080847 100644 --- a/assets/arena.html +++ b/assets/arena.html @@ -565,7 +565,7 @@ async init() { try { const models = await fetchJSON(MODELS_API); - this.models = models; + this.models = models.filter(v => !v.mode || v.mode === "chat"); } catch (err) { toast("No available model"); console.error("Failed to load models", err); diff --git a/assets/playground.html b/assets/playground.html index 4487b444..6ef1f1a1 100644 --- a/assets/playground.html +++ b/assets/playground.html @@ -741,7 +741,7 @@ async init() { await Promise.all([ fetchJSON(MODELS_API).then(models => { - this.models = models; + this.models = models.filter(v => !v.mode || v.mode === "chat"); }).catch(err => { toast("No model available"); console.error("Failed to load models", err); diff --git a/src/client/model.rs b/src/client/model.rs index ebf12646..06f5577c 100644 --- a/src/client/model.rs +++ b/src/client/model.rs @@ -8,7 +8,7 @@ use crate::config::Config; use crate::utils::{estimate_token_length, format_option_value}; use anyhow::{bail, Result}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; const PER_MESSAGES_TOKENS: usize = 5; const BASIS_TOKENS: usize = 2; @@ -242,7 +242,7 @@ impl Model { } } -#[derive(Debug, Clone, Default, Deserialize)] +#[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct ModelData { pub name: String, #[serde(default = "default_model_mode")] diff --git a/src/config/mod.rs b/src/config/mod.rs index 903face9..97a6017e 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1445,6 +1445,9 @@ impl WorkingMode { pub fn is_repl(&self) -> bool { *self == WorkingMode::Repl } + pub fn is_serve(&self) -> bool { + *self == WorkingMode::Serve + } } bitflags::bitflags! 
{ diff --git a/src/logger.rs b/src/logger.rs index f7ef2f5f..25d6382f 100644 --- a/src/logger.rs +++ b/src/logger.rs @@ -6,8 +6,9 @@ use simplelog::{format_description, Config as LogConfig, ConfigBuilder}; #[cfg(debug_assertions)] pub fn setup_logger(working_mode: WorkingMode) -> Result<()> { - let config = build_config(); - if working_mode == WorkingMode::Serve { + let is_serve = working_mode.is_serve(); + let config = build_config(is_serve); + if is_serve { simplelog::SimpleLogger::init(LevelFilter::Debug, config)?; } else { let file = std::fs::File::create(crate::config::Config::local_path("debug.log")?)?; @@ -18,17 +19,21 @@ pub fn setup_logger(working_mode: WorkingMode) -> Result<()> { #[cfg(not(debug_assertions))] pub fn setup_logger(working_mode: WorkingMode) -> Result<()> { - let config = build_config(); - if working_mode == WorkingMode::Serve { + if working_mode.is_serve() { + let config = build_config(true); simplelog::SimpleLogger::init(log::LevelFilter::Info, config)?; } Ok(()) } -fn build_config() -> LogConfig { - let log_filter = match std::env::var("AICHAT_LOG_FILTER") { - Ok(v) => v, - Err(_) => "aichat".into(), +fn build_config(is_serve: bool) -> LogConfig { + let log_filter = if is_serve { + "aichat::serve".into() + } else { + match std::env::var("AICHAT_LOG_FILTER") { + Ok(v) => v, + Err(_) => "aichat".into(), + } }; ConfigBuilder::new() .add_filter_allow(log_filter) diff --git a/src/serve.rs b/src/serve.rs index 3f33fc2a..fc80f2a9 100644 --- a/src/serve.rs +++ b/src/serve.rs @@ -49,6 +49,7 @@ pub async fn run(config: GlobalConfig, addr: Option) -> Result<()> { let listener = TcpListener::bind(&addr).await?; let stop_server = server.run(listener).await?; println!("Chat Completions API: http://{addr}/v1/chat/completions"); + println!("Embeddings API: http://{addr}/v1/embeddings"); println!("LLM Playground: http://{addr}/playground"); println!("LLM Arena: http://{addr}/arena"); shutdown_signal().await; @@ -69,7 +70,7 @@ impl Server { let clients = 
config.clients.clone(); let model = config.model.clone(); let roles = config.roles.clone(); - let mut models = list_chat_models(&config); + let mut models = list_models(&config); let mut default_model = model.clone(); default_model.data_mut().name = DEFAULT_MODEL_NAME.into(); models.insert(0, &default_model); @@ -82,26 +83,14 @@ impl Server { } else { model.id() }; - let ModelData { - max_input_tokens, - max_output_tokens, - require_max_tokens, - input_price, - output_price, - supports_vision, - supports_function_calling, - .. - } = model.data(); - json!({ - "id": id, - "max_input_tokens": max_input_tokens, - "max_output_tokens": max_output_tokens, - "require_max_tokens": require_max_tokens, - "input_price": input_price, - "output_price": output_price, - "supports_vision": supports_vision, - "supports_function_calling": supports_function_calling, - }) + let mut value = json!(model.data()); + if let Some(value_obj) = value.as_object_mut() { + value_obj.insert("id".into(), id.into()); + value_obj.insert("object".into(), "model".into()); + value_obj.insert("owned_by".into(), model.client_name().into()); + value_obj.remove("name"); + } + value }) .collect(); Self { @@ -161,7 +150,9 @@ impl Server { let mut status = StatusCode::OK; let res = if path == "/v1/chat/completions" { - self.chat_completion(req).await + self.chat_completions(req).await + } else if path == "/v1/embeddings" { + self.embeddings(req).await } else if path == "/v1/models" { self.list_models() } else if path == "/v1/roles" { @@ -220,12 +211,16 @@ impl Server { Ok(res) } - async fn chat_completion(&self, req: hyper::Request) -> Result { + async fn chat_completions(&self, req: hyper::Request) -> Result { let req_body = req.collect().await?.to_bytes(); - let req_body: ChatCompletionReqBody = serde_json::from_slice(&req_body) - .map_err(|err| anyhow!("Invalid request body, {err}"))?; + let req_body: Value = serde_json::from_slice(&req_body) + .map_err(|err| anyhow!("Invalid request json, {err}"))?; - let 
ChatCompletionReqBody { + debug!("chat completions request: {req_body}"); + let req_body = serde_json::from_value(req_body) + .map_err(|err| anyhow!("Invalid request body, {err}"))?; + + let ChatCompletionsReqBody { model, messages, temperature, @@ -358,10 +353,68 @@ impl Server { Ok(res) } } + + async fn embeddings(&self, req: hyper::Request) -> Result { + let req_body = req.collect().await?.to_bytes(); + let req_body: Value = serde_json::from_slice(&req_body) + .map_err(|err| anyhow!("Invalid request json, {err}"))?; + + debug!("embeddings request: {req_body}"); + let req_body = serde_json::from_value(req_body) + .map_err(|err| anyhow!("Invalid request body, {err}"))?; + + let EmbeddingsReqBody { + input, + model: embedding_model_id, + } = req_body; + + let config = Config { + clients: self.clients.to_vec(), + ..Default::default() + }; + let config = Arc::new(RwLock::new(config)); + let embedding_model = Model::retrieve_embedding(&config.read(), &embedding_model_id)?; + + let texts = match input { + EmbeddingsReqBodyInput::Single(v) => vec![v], + EmbeddingsReqBodyInput::Multiple(v) => v, + }; + let client = init_client(&config, Some(embedding_model))?; + let data = client + .embeddings(EmbeddingsData { + query: false, + texts, + }) + .await?; + let data: Vec<_> = data + .into_iter() + .enumerate() + .map(|(i, v)| { + json!({ + "object": "embedding", + "embedding": v, + "index": i, + }) + }) + .collect(); + let output = json!({ + "object": "list", + "data": data, + "model": embedding_model_id, + "usage": { + "prompt_tokens": 0, + "total_tokens": 0, + } + }); + let res = Response::builder() + .header("Content-Type", "application/json") + .body(Full::new(Bytes::from(output.to_string())).boxed())?; + Ok(res) + } } #[derive(Debug, Deserialize)] -struct ChatCompletionReqBody { +struct ChatCompletionsReqBody { model: String, messages: Vec, temperature: Option, @@ -371,6 +424,19 @@ struct ChatCompletionReqBody { stream: bool, } +#[derive(Debug, Deserialize)] +struct 
EmbeddingsReqBody { + pub input: EmbeddingsReqBodyInput, + pub model: String, +} + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +enum EmbeddingsReqBodyInput { + Single(String), + Multiple(Vec), +} + #[derive(Debug)] enum ResEvent { First(Option),