Skip to content

Commit

Permalink
feat: serve embeddings api (#624)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored Jun 21, 2024
1 parent 97c82e5 commit 054e998
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 39 deletions.
2 changes: 1 addition & 1 deletion assets/arena.html
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@
async init() {
try {
const models = await fetchJSON(MODELS_API);
this.models = models;
this.models = models.filter(v => !v.mode || v.mode === "chat");
} catch (err) {
toast("No available model");
console.error("Failed to load models", err);
Expand Down
2 changes: 1 addition & 1 deletion assets/playground.html
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@
async init() {
await Promise.all([
fetchJSON(MODELS_API).then(models => {
this.models = models;
this.models = models.filter(v => !v.mode || v.mode === "chat");
}).catch(err => {
toast("No model available");
console.error("Failed to load models", err);
Expand Down
4 changes: 2 additions & 2 deletions src/client/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::config::Config;
use crate::utils::{estimate_token_length, format_option_value};

use anyhow::{bail, Result};
use serde::Deserialize;
use serde::{Deserialize, Serialize};

const PER_MESSAGES_TOKENS: usize = 5;
const BASIS_TOKENS: usize = 2;
Expand Down Expand Up @@ -242,7 +242,7 @@ impl Model {
}
}

#[derive(Debug, Clone, Default, Deserialize)]
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModelData {
pub name: String,
#[serde(default = "default_model_mode")]
Expand Down
3 changes: 3 additions & 0 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1445,6 +1445,9 @@ impl WorkingMode {
pub fn is_repl(&self) -> bool {
*self == WorkingMode::Repl
}
pub fn is_serve(&self) -> bool {
*self == WorkingMode::Serve
}
}

bitflags::bitflags! {
Expand Down
21 changes: 13 additions & 8 deletions src/logger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ use simplelog::{format_description, Config as LogConfig, ConfigBuilder};

#[cfg(debug_assertions)]
pub fn setup_logger(working_mode: WorkingMode) -> Result<()> {
let config = build_config();
if working_mode == WorkingMode::Serve {
let is_serve = working_mode.is_serve();
let config = build_config(is_serve);
if is_serve {
simplelog::SimpleLogger::init(LevelFilter::Debug, config)?;
} else {
let file = std::fs::File::create(crate::config::Config::local_path("debug.log")?)?;
Expand All @@ -18,17 +19,21 @@ pub fn setup_logger(working_mode: WorkingMode) -> Result<()> {

#[cfg(not(debug_assertions))]
pub fn setup_logger(working_mode: WorkingMode) -> Result<()> {
let config = build_config();
if working_mode == WorkingMode::Serve {
if working_mode.is_serve() {
let config = build_config(true);
simplelog::SimpleLogger::init(log::LevelFilter::Info, config)?;
}
Ok(())
}

fn build_config() -> LogConfig {
let log_filter = match std::env::var("AICHAT_LOG_FILTER") {
Ok(v) => v,
Err(_) => "aichat".into(),
fn build_config(is_serve: bool) -> LogConfig {
let log_filter = if is_serve {
"aichat::serve".into()
} else {
match std::env::var("AICHAT_LOG_FILTER") {
Ok(v) => v,
Err(_) => "aichat".into(),
}
};
ConfigBuilder::new()
.add_filter_allow(log_filter)
Expand Down
120 changes: 93 additions & 27 deletions src/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ pub async fn run(config: GlobalConfig, addr: Option<String>) -> Result<()> {
let listener = TcpListener::bind(&addr).await?;
let stop_server = server.run(listener).await?;
println!("Chat Completions API: http://{addr}/v1/chat/completions");
println!("Embeddings API: http://{addr}/v1/embeddings");
println!("LLM Playground: http://{addr}/playground");
println!("LLM Arena: http://{addr}/arena");
shutdown_signal().await;
Expand All @@ -69,7 +70,7 @@ impl Server {
let clients = config.clients.clone();
let model = config.model.clone();
let roles = config.roles.clone();
let mut models = list_chat_models(&config);
let mut models = list_models(&config);
let mut default_model = model.clone();
default_model.data_mut().name = DEFAULT_MODEL_NAME.into();
models.insert(0, &default_model);
Expand All @@ -82,26 +83,14 @@ impl Server {
} else {
model.id()
};
let ModelData {
max_input_tokens,
max_output_tokens,
require_max_tokens,
input_price,
output_price,
supports_vision,
supports_function_calling,
..
} = model.data();
json!({
"id": id,
"max_input_tokens": max_input_tokens,
"max_output_tokens": max_output_tokens,
"require_max_tokens": require_max_tokens,
"input_price": input_price,
"output_price": output_price,
"supports_vision": supports_vision,
"supports_function_calling": supports_function_calling,
})
let mut value = json!(model.data());
if let Some(value_obj) = value.as_object_mut() {
value_obj.insert("id".into(), id.into());
value_obj.insert("object".into(), "model".into());
value_obj.insert("owned_by".into(), model.client_name().into());
value_obj.remove("name");
}
value
})
.collect();
Self {
Expand Down Expand Up @@ -161,7 +150,9 @@ impl Server {

let mut status = StatusCode::OK;
let res = if path == "/v1/chat/completions" {
self.chat_completion(req).await
self.chat_completions(req).await
} else if path == "/v1/embeddings" {
self.embeddings(req).await
} else if path == "/v1/models" {
self.list_models()
} else if path == "/v1/roles" {
Expand Down Expand Up @@ -220,12 +211,16 @@ impl Server {
Ok(res)
}

async fn chat_completion(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
async fn chat_completions(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
let req_body = req.collect().await?.to_bytes();
let req_body: ChatCompletionReqBody = serde_json::from_slice(&req_body)
.map_err(|err| anyhow!("Invalid request body, {err}"))?;
let req_body: Value = serde_json::from_slice(&req_body)
.map_err(|err| anyhow!("Invalid request json, {err}"))?;

let ChatCompletionReqBody {
debug!("chat completions request: {req_body}");
let req_body = serde_json::from_value(req_body)
.map_err(|err| anyhow!("Invalid requst body, {err}"))?;

let ChatCompletionsReqBody {
model,
messages,
temperature,
Expand Down Expand Up @@ -358,10 +353,68 @@ impl Server {
Ok(res)
}
}

async fn embeddings(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
let req_body = req.collect().await?.to_bytes();
let req_body: Value = serde_json::from_slice(&req_body)
.map_err(|err| anyhow!("Invalid request json, {err}"))?;

debug!("embeddings request: {req_body}");
let req_body = serde_json::from_value(req_body)
.map_err(|err| anyhow!("Invalid requst body, {err}"))?;

let EmbeddingsReqBody {
input,
model: embedding_model_id,
} = req_body;

let config = Config {
clients: self.clients.to_vec(),
..Default::default()
};
let config = Arc::new(RwLock::new(config));
let embedding_model = Model::retrieve_embedding(&config.read(), &embedding_model_id)?;

let texts = match input {
EmbeddingsReqBodyInput::Single(v) => vec![v],
EmbeddingsReqBodyInput::Multiple(v) => v,
};
let client = init_client(&config, Some(embedding_model))?;
let data = client
.embeddings(EmbeddingsData {
query: false,
texts,
})
.await?;
let data: Vec<_> = data
.into_iter()
.enumerate()
.map(|(i, v)| {
json!({
"object": "embedding",
"embedding": v,
"index": i,
})
})
.collect();
let output = json!({
"object": "list",
"data": data,
"model": embedding_model_id,
"usage": {
"prompt_tokens": 0,
"total_tokens": 0,
}
});
let res = Response::builder()
.header("Content-Type", "application/json")
.body(Full::new(Bytes::from(output.to_string())).boxed())?;
Ok(res)
}
}

#[derive(Debug, Deserialize)]
struct ChatCompletionReqBody {
struct ChatCompletionsReqBody {
model: String,
messages: Vec<Message>,
temperature: Option<f64>,
Expand All @@ -371,6 +424,19 @@ struct ChatCompletionReqBody {
stream: bool,
}

#[derive(Debug, Deserialize)]
struct EmbeddingsReqBody {
pub input: EmbeddingsReqBodyInput,
pub model: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum EmbeddingsReqBodyInput {
Single(String),
Multiple(Vec<String>),
}

#[derive(Debug)]
enum ResEvent {
First(Option<String>),
Expand Down

0 comments on commit 054e998

Please sign in to comment.