feat: serve embeddings api #624

Merged · 1 commit · Jun 21, 2024

assets/arena.html (1 addition, 1 deletion)

@@ -565,7 +565,7 @@
         async init() {
             try {
                 const models = await fetchJSON(MODELS_API);
-                this.models = models;
+                this.models = models.filter(v => !v.mode || v.mode === "chat");
             } catch (err) {
                 toast("No available model");
                 console.error("Failed to load models", err);

assets/playground.html (1 addition, 1 deletion)

@@ -741,7 +741,7 @@
         async init() {
             await Promise.all([
                 fetchJSON(MODELS_API).then(models => {
-                    this.models = models;
+                    this.models = models.filter(v => !v.mode || v.mode === "chat");
                 }).catch(err => {
                     toast("No model available");
                     console.error("Failed to load models", err);

src/client/model.rs (2 additions, 2 deletions)

@@ -8,7 +8,7 @@ use crate::config::Config;
 use crate::utils::{estimate_token_length, format_option_value};

 use anyhow::{bail, Result};
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};

 const PER_MESSAGES_TOKENS: usize = 5;
 const BASIS_TOKENS: usize = 2;
@@ -242,7 +242,7 @@ impl Model {
     }
 }

-#[derive(Debug, Clone, Default, Deserialize)]
+#[derive(Debug, Clone, Default, Serialize, Deserialize)]
 pub struct ModelData {
     pub name: String,
     #[serde(default = "default_model_mode")]
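
Deriving Serialize on ModelData is what lets the reworked /v1/models handler in src/serve.rs (further down) turn model metadata into JSON wholesale instead of destructuring and re-listing every field by hand. A minimal sketch of the pattern, using a hypothetical two-field stand-in for the real ModelData:

use serde::{Deserialize, Serialize};
use serde_json::json;

// Hypothetical stand-in; the real ModelData carries many more fields.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
struct ModelData {
    name: String,
    max_input_tokens: Option<usize>,
}

fn main() {
    let data = ModelData {
        name: "gpt-4o".into(),
        max_input_tokens: Some(128_000),
    };
    // `json!(data)` only compiles because ModelData now derives Serialize.
    let mut value = json!(data);
    if let Some(obj) = value.as_object_mut() {
        // Reshape into an OpenAI-style model object, as the handler does.
        obj.insert("id".into(), "openai:gpt-4o".into());
        obj.insert("object".into(), "model".into());
        obj.remove("name");
    }
    println!("{value}");
}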

src/config/mod.rs (3 additions)

@@ -1445,6 +1445,9 @@ impl WorkingMode {
     pub fn is_repl(&self) -> bool {
         *self == WorkingMode::Repl
     }
+    pub fn is_serve(&self) -> bool {
+        *self == WorkingMode::Serve
+    }
 }

 bitflags::bitflags! {

src/logger.rs (13 additions, 8 deletions)

@@ -6,8 +6,9 @@ use simplelog::{format_description, Config as LogConfig, ConfigBuilder};

 #[cfg(debug_assertions)]
 pub fn setup_logger(working_mode: WorkingMode) -> Result<()> {
-    let config = build_config();
-    if working_mode == WorkingMode::Serve {
+    let is_serve = working_mode.is_serve();
+    let config = build_config(is_serve);
+    if is_serve {
         simplelog::SimpleLogger::init(LevelFilter::Debug, config)?;
     } else {
         let file = std::fs::File::create(crate::config::Config::local_path("debug.log")?)?;
@@ -18,17 +19,21 @@ pub fn setup_logger(working_mode: WorkingMode) -> Result<()> {

 #[cfg(not(debug_assertions))]
 pub fn setup_logger(working_mode: WorkingMode) -> Result<()> {
-    let config = build_config();
-    if working_mode == WorkingMode::Serve {
+    if working_mode.is_serve() {
+        let config = build_config(true);
         simplelog::SimpleLogger::init(log::LevelFilter::Info, config)?;
     }
     Ok(())
 }

-fn build_config() -> LogConfig {
-    let log_filter = match std::env::var("AICHAT_LOG_FILTER") {
-        Ok(v) => v,
-        Err(_) => "aichat".into(),
+fn build_config(is_serve: bool) -> LogConfig {
+    let log_filter = if is_serve {
+        "aichat::serve".into()
+    } else {
+        match std::env::var("AICHAT_LOG_FILTER") {
+            Ok(v) => v,
+            Err(_) => "aichat".into(),
+        }
     };
     ConfigBuilder::new()
         .add_filter_allow(log_filter)

src/serve.rs (93 additions, 27 deletions)

@@ -49,6 +49,7 @@ pub async fn run(config: GlobalConfig, addr: Option<String>) -> Result<()> {
     let listener = TcpListener::bind(&addr).await?;
     let stop_server = server.run(listener).await?;
     println!("Chat Completions API: http://{addr}/v1/chat/completions");
+    println!("Embeddings API: http://{addr}/v1/embeddings");
     println!("LLM Playground: http://{addr}/playground");
     println!("LLM Arena: http://{addr}/arena");
     shutdown_signal().await;
@@ -69,7 +70,7 @@ impl Server {
         let clients = config.clients.clone();
         let model = config.model.clone();
         let roles = config.roles.clone();
-        let mut models = list_chat_models(&config);
+        let mut models = list_models(&config);
         let mut default_model = model.clone();
         default_model.data_mut().name = DEFAULT_MODEL_NAME.into();
         models.insert(0, &default_model);
@@ -82,26 +83,14 @@
                 } else {
                     model.id()
                 };
-                let ModelData {
-                    max_input_tokens,
-                    max_output_tokens,
-                    require_max_tokens,
-                    input_price,
-                    output_price,
-                    supports_vision,
-                    supports_function_calling,
-                    ..
-                } = model.data();
-                json!({
-                    "id": id,
-                    "max_input_tokens": max_input_tokens,
-                    "max_output_tokens": max_output_tokens,
-                    "require_max_tokens": require_max_tokens,
-                    "input_price": input_price,
-                    "output_price": output_price,
-                    "supports_vision": supports_vision,
-                    "supports_function_calling": supports_function_calling,
-                })
+                let mut value = json!(model.data());
+                if let Some(value_obj) = value.as_object_mut() {
+                    value_obj.insert("id".into(), id.into());
+                    value_obj.insert("object".into(), "model".into());
+                    value_obj.insert("owned_by".into(), model.client_name().into());
+                    value_obj.remove("name");
+                }
+                value
             })
             .collect();
         Self {
@@ -161,7 +150,9 @@

         let mut status = StatusCode::OK;
         let res = if path == "/v1/chat/completions" {
-            self.chat_completion(req).await
+            self.chat_completions(req).await
+        } else if path == "/v1/embeddings" {
+            self.embeddings(req).await
         } else if path == "/v1/models" {
             self.list_models()
         } else if path == "/v1/roles" {
@@ -220,12 +211,16 @@
         Ok(res)
     }

-    async fn chat_completion(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
+    async fn chat_completions(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
         let req_body = req.collect().await?.to_bytes();
-        let req_body: ChatCompletionReqBody = serde_json::from_slice(&req_body)
-            .map_err(|err| anyhow!("Invalid request body, {err}"))?;
+        let req_body: Value = serde_json::from_slice(&req_body)
+            .map_err(|err| anyhow!("Invalid request json, {err}"))?;

-        let ChatCompletionReqBody {
+        debug!("chat completions request: {req_body}");
+        let req_body = serde_json::from_value(req_body)
+            .map_err(|err| anyhow!("Invalid request body, {err}"))?;
+
+        let ChatCompletionsReqBody {
             model,
             messages,
             temperature,
@@ -358,10 +353,68 @@ impl Server {
             Ok(res)
         }
     }
+
+    async fn embeddings(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
+        let req_body = req.collect().await?.to_bytes();
+        let req_body: Value = serde_json::from_slice(&req_body)
+            .map_err(|err| anyhow!("Invalid request json, {err}"))?;
+
+        debug!("embeddings request: {req_body}");
+        let req_body = serde_json::from_value(req_body)
+            .map_err(|err| anyhow!("Invalid request body, {err}"))?;
+
+        let EmbeddingsReqBody {
+            input,
+            model: embedding_model_id,
+        } = req_body;
+
+        let config = Config {
+            clients: self.clients.to_vec(),
+            ..Default::default()
+        };
+        let config = Arc::new(RwLock::new(config));
+        let embedding_model = Model::retrieve_embedding(&config.read(), &embedding_model_id)?;
+
+        let texts = match input {
+            EmbeddingsReqBodyInput::Single(v) => vec![v],
+            EmbeddingsReqBodyInput::Multiple(v) => v,
+        };
+        let client = init_client(&config, Some(embedding_model))?;
+        let data = client
+            .embeddings(EmbeddingsData {
+                query: false,
+                texts,
+            })
+            .await?;
+        let data: Vec<_> = data
+            .into_iter()
+            .enumerate()
+            .map(|(i, v)| {
+                json!({
+                    "object": "embedding",
+                    "embedding": v,
+                    "index": i,
+                })
+            })
+            .collect();
+        let output = json!({
+            "object": "list",
+            "data": data,
+            "model": embedding_model_id,
+            "usage": {
+                "prompt_tokens": 0,
+                "total_tokens": 0,
+            }
+        });
+        let res = Response::builder()
+            .header("Content-Type", "application/json")
+            .body(Full::new(Bytes::from(output.to_string())).boxed())?;
+        Ok(res)
+    }
 }

 #[derive(Debug, Deserialize)]
-struct ChatCompletionReqBody {
+struct ChatCompletionsReqBody {
     model: String,
     messages: Vec<Message>,
     temperature: Option<f64>,
@@ -371,6 +424,19 @@ struct ChatCompletionReqBody {
     stream: bool,
 }

+#[derive(Debug, Deserialize)]
+struct EmbeddingsReqBody {
+    pub input: EmbeddingsReqBodyInput,
+    pub model: String,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(untagged)]
+enum EmbeddingsReqBodyInput {
+    Single(String),
+    Multiple(Vec<String>),
+}
+
 #[derive(Debug)]
 enum ResEvent {
     First(Option<String>),
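
The #[serde(untagged)] attribute on EmbeddingsReqBodyInput is what lets the endpoint accept `input` as either a single string or an array of strings, mirroring the OpenAI embeddings API. A self-contained sketch of that behavior (hypothetical test code, not part of this PR):

use serde::Deserialize;

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum Input {
    Single(String),
    Multiple(Vec<String>),
}

fn main() {
    // With `untagged`, serde tries each variant in order until one fits.
    let a: Input = serde_json::from_str(r#""hello""#).unwrap();
    let b: Input = serde_json::from_str(r#"["hello", "world"]"#).unwrap();
    println!("{a:?}"); // Single("hello")
    println!("{b:?}"); // Multiple(["hello", "world"])
}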
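
For reference, a hypothetical client-side call against the new endpoint (not part of this PR), assuming the server listens on aichat's default 127.0.0.1:8000 address and that the reqwest crate (blocking client, "json" feature) and serde_json are available; the model id is a placeholder:

use serde_json::{json, Value};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let body = json!({
        "model": "openai:text-embedding-3-small", // placeholder embedding model id
        "input": ["hello", "world"],              // a bare string is accepted too
    });
    let res: Value = reqwest::blocking::Client::new()
        .post("http://127.0.0.1:8000/v1/embeddings")
        .json(&body)
        .send()?
        .error_for_status()?
        .json()?;
    // Per the handler above, the response has the shape
    // {"object":"list","data":[{"object":"embedding","embedding":[...],"index":0}, ...], ...}.
    println!("{}", serde_json::to_string_pretty(&res)?);
    Ok(())
}

Note that the handler currently hardcodes prompt_tokens and total_tokens to 0, so the usage block in the response carries no real token accounting.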