diff --git a/assets/arena.html b/assets/arena.html
index b43c4640..98080847 100644
--- a/assets/arena.html
+++ b/assets/arena.html
@@ -565,7 +565,7 @@
async init() {
try {
const models = await fetchJSON(MODELS_API);
- this.models = models;
+ this.models = models.filter(v => !v.mode || v.mode === "chat");
} catch (err) {
toast("No available model");
console.error("Failed to load models", err);
diff --git a/assets/playground.html b/assets/playground.html
index 4487b444..6ef1f1a1 100644
--- a/assets/playground.html
+++ b/assets/playground.html
@@ -741,7 +741,7 @@
async init() {
await Promise.all([
fetchJSON(MODELS_API).then(models => {
- this.models = models;
+ this.models = models.filter(v => !v.mode || v.mode === "chat");
}).catch(err => {
toast("No model available");
console.error("Failed to load models", err);
diff --git a/src/client/model.rs b/src/client/model.rs
index ebf12646..06f5577c 100644
--- a/src/client/model.rs
+++ b/src/client/model.rs
@@ -8,7 +8,7 @@ use crate::config::Config;
use crate::utils::{estimate_token_length, format_option_value};
use anyhow::{bail, Result};
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
const PER_MESSAGES_TOKENS: usize = 5;
const BASIS_TOKENS: usize = 2;
@@ -242,7 +242,7 @@ impl Model {
}
}
-#[derive(Debug, Clone, Default, Deserialize)]
+#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModelData {
pub name: String,
#[serde(default = "default_model_mode")]
diff --git a/src/config/mod.rs b/src/config/mod.rs
index 903face9..97a6017e 100644
--- a/src/config/mod.rs
+++ b/src/config/mod.rs
@@ -1445,6 +1445,9 @@ impl WorkingMode {
pub fn is_repl(&self) -> bool {
*self == WorkingMode::Repl
}
+ pub fn is_serve(&self) -> bool {
+ *self == WorkingMode::Serve
+ }
}
bitflags::bitflags! {
diff --git a/src/logger.rs b/src/logger.rs
index f7ef2f5f..25d6382f 100644
--- a/src/logger.rs
+++ b/src/logger.rs
@@ -6,8 +6,9 @@ use simplelog::{format_description, Config as LogConfig, ConfigBuilder};
#[cfg(debug_assertions)]
pub fn setup_logger(working_mode: WorkingMode) -> Result<()> {
- let config = build_config();
- if working_mode == WorkingMode::Serve {
+ let is_serve = working_mode.is_serve();
+ let config = build_config(is_serve);
+ if is_serve {
simplelog::SimpleLogger::init(LevelFilter::Debug, config)?;
} else {
let file = std::fs::File::create(crate::config::Config::local_path("debug.log")?)?;
@@ -18,17 +19,21 @@ pub fn setup_logger(working_mode: WorkingMode) -> Result<()> {
#[cfg(not(debug_assertions))]
pub fn setup_logger(working_mode: WorkingMode) -> Result<()> {
- let config = build_config();
- if working_mode == WorkingMode::Serve {
+ if working_mode.is_serve() {
+ let config = build_config(true);
simplelog::SimpleLogger::init(log::LevelFilter::Info, config)?;
}
Ok(())
}
-fn build_config() -> LogConfig {
- let log_filter = match std::env::var("AICHAT_LOG_FILTER") {
- Ok(v) => v,
- Err(_) => "aichat".into(),
+fn build_config(is_serve: bool) -> LogConfig {
+ let log_filter = if is_serve {
+ "aichat::serve".into()
+ } else {
+ match std::env::var("AICHAT_LOG_FILTER") {
+ Ok(v) => v,
+ Err(_) => "aichat".into(),
+ }
};
ConfigBuilder::new()
.add_filter_allow(log_filter)
diff --git a/src/serve.rs b/src/serve.rs
index 3f33fc2a..fc80f2a9 100644
--- a/src/serve.rs
+++ b/src/serve.rs
@@ -49,6 +49,7 @@ pub async fn run(config: GlobalConfig, addr: Option) -> Result<()> {
let listener = TcpListener::bind(&addr).await?;
let stop_server = server.run(listener).await?;
println!("Chat Completions API: http://{addr}/v1/chat/completions");
+ println!("Embeddings API: http://{addr}/v1/embeddings");
println!("LLM Playground: http://{addr}/playground");
println!("LLM Arena: http://{addr}/arena");
shutdown_signal().await;
@@ -69,7 +70,7 @@ impl Server {
let clients = config.clients.clone();
let model = config.model.clone();
let roles = config.roles.clone();
- let mut models = list_chat_models(&config);
+ let mut models = list_models(&config);
let mut default_model = model.clone();
default_model.data_mut().name = DEFAULT_MODEL_NAME.into();
models.insert(0, &default_model);
@@ -82,26 +83,14 @@ impl Server {
} else {
model.id()
};
- let ModelData {
- max_input_tokens,
- max_output_tokens,
- require_max_tokens,
- input_price,
- output_price,
- supports_vision,
- supports_function_calling,
- ..
- } = model.data();
- json!({
- "id": id,
- "max_input_tokens": max_input_tokens,
- "max_output_tokens": max_output_tokens,
- "require_max_tokens": require_max_tokens,
- "input_price": input_price,
- "output_price": output_price,
- "supports_vision": supports_vision,
- "supports_function_calling": supports_function_calling,
- })
+ let mut value = json!(model.data());
+ if let Some(value_obj) = value.as_object_mut() {
+ value_obj.insert("id".into(), id.into());
+ value_obj.insert("object".into(), "model".into());
+ value_obj.insert("owned_by".into(), model.client_name().into());
+ value_obj.remove("name");
+ }
+ value
})
.collect();
Self {
@@ -161,7 +150,9 @@ impl Server {
let mut status = StatusCode::OK;
let res = if path == "/v1/chat/completions" {
- self.chat_completion(req).await
+ self.chat_completions(req).await
+ } else if path == "/v1/embeddings" {
+ self.embeddings(req).await
} else if path == "/v1/models" {
self.list_models()
} else if path == "/v1/roles" {
@@ -220,12 +211,16 @@ impl Server {
Ok(res)
}
- async fn chat_completion(&self, req: hyper::Request) -> Result {
+ async fn chat_completions(&self, req: hyper::Request) -> Result {
let req_body = req.collect().await?.to_bytes();
- let req_body: ChatCompletionReqBody = serde_json::from_slice(&req_body)
- .map_err(|err| anyhow!("Invalid request body, {err}"))?;
+ let req_body: Value = serde_json::from_slice(&req_body)
+ .map_err(|err| anyhow!("Invalid request json, {err}"))?;
- let ChatCompletionReqBody {
+ debug!("chat completions request: {req_body}");
+ let req_body = serde_json::from_value(req_body)
+ .map_err(|err| anyhow!("Invalid request body, {err}"))?;
+
+ let ChatCompletionsReqBody {
model,
messages,
temperature,
@@ -358,10 +353,68 @@ impl Server {
Ok(res)
}
}
+
+ async fn embeddings(&self, req: hyper::Request) -> Result {
+ let req_body = req.collect().await?.to_bytes();
+ let req_body: Value = serde_json::from_slice(&req_body)
+ .map_err(|err| anyhow!("Invalid request json, {err}"))?;
+
+ debug!("embeddings request: {req_body}");
+ let req_body = serde_json::from_value(req_body)
+ .map_err(|err| anyhow!("Invalid request body, {err}"))?;
+
+ let EmbeddingsReqBody {
+ input,
+ model: embedding_model_id,
+ } = req_body;
+
+ let config = Config {
+ clients: self.clients.to_vec(),
+ ..Default::default()
+ };
+ let config = Arc::new(RwLock::new(config));
+ let embedding_model = Model::retrieve_embedding(&config.read(), &embedding_model_id)?;
+
+ let texts = match input {
+ EmbeddingsReqBodyInput::Single(v) => vec![v],
+ EmbeddingsReqBodyInput::Multiple(v) => v,
+ };
+ let client = init_client(&config, Some(embedding_model))?;
+ let data = client
+ .embeddings(EmbeddingsData {
+ query: false,
+ texts,
+ })
+ .await?;
+ let data: Vec<_> = data
+ .into_iter()
+ .enumerate()
+ .map(|(i, v)| {
+ json!({
+ "object": "embedding",
+ "embedding": v,
+ "index": i,
+ })
+ })
+ .collect();
+ let output = json!({
+ "object": "list",
+ "data": data,
+ "model": embedding_model_id,
+ "usage": {
+ "prompt_tokens": 0,
+ "total_tokens": 0,
+ }
+ });
+ let res = Response::builder()
+ .header("Content-Type", "application/json")
+ .body(Full::new(Bytes::from(output.to_string())).boxed())?;
+ Ok(res)
+ }
}
#[derive(Debug, Deserialize)]
-struct ChatCompletionReqBody {
+struct ChatCompletionsReqBody {
model: String,
messages: Vec,
temperature: Option,
@@ -371,6 +424,19 @@ struct ChatCompletionReqBody {
stream: bool,
}
+#[derive(Debug, Deserialize)]
+struct EmbeddingsReqBody {
+ pub input: EmbeddingsReqBodyInput,
+ pub model: String,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(untagged)]
+enum EmbeddingsReqBodyInput {
+ Single(String),
+ Multiple(Vec<String>),
+}
+
#[derive(Debug)]
enum ResEvent {
First(Option),