Skip to content

Commit

Permalink
feat: add support fastchat http bindings
Browse files Browse the repository at this point in the history
* feat: add support fastchat http bindings

Signed-off-by: Lei Wen <wenlei03@qiyi.com>
  • Loading branch information
wenlei03 committed Sep 10, 2023
1 parent 12a37e2 commit 18e1909
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 7 deletions.
93 changes: 93 additions & 0 deletions crates/http-api-bindings/src/fastchat.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use async_trait::async_trait;
use reqwest::header;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tabby_inference::{TextGeneration, TextGenerationOptions};

/// JSON body that `FastChatEngine::generate` posts to the completion endpoint.
#[derive(Serialize)]
struct Request {
/// Filled from the engine's configured `model_name`.
model: String,
/// Prompts to complete; `generate` always sends exactly one entry.
prompt: Vec<String>,
/// Filled from `TextGenerationOptions::max_decoding_length`.
max_tokens: usize,
/// Filled from `TextGenerationOptions::sampling_temperature`.
temperature: f32,
}

/// Body of a successful (HTTP 200) completion response, as parsed by
/// `FastChatEngine::generate`. Fields not listed here are ignored by serde.
#[derive(Deserialize)]
struct Response {
/// Candidate completions; `generate` only reads the first one.
choices: Vec<Prediction>,
}

/// One entry of `Response::choices`.
#[derive(Deserialize)]
struct Prediction {
/// Generated text; `generate` returns the first element.
// NOTE(review): OpenAI-style completion responses usually carry `text` as a
// single string rather than an array — confirm against the FastChat server,
// otherwise deserialization of every response will fail.
text: Vec<String>,
}

/// Text-generation engine that forwards completion requests over HTTP to a
/// FastChat endpoint.
pub struct FastChatEngine {
/// HTTP client built in `create`; carries the default headers configured there.
client: reqwest::Client,
/// URL that `generate` POSTs every request to.
api_endpoint: String,
/// Sent as the `model` field of every request.
model_name: String,
}

impl FastChatEngine {
    /// Builds an engine that sends completion requests to `api_endpoint`.
    ///
    /// * `api_endpoint` - URL every completion request is POSTed to.
    /// * `model_name` - value sent as the `model` field of each request.
    /// * `authorization` - full value of the `Authorization` header; pass an
    ///   empty string to send no `Authorization` header at all.
    ///
    /// # Panics
    ///
    /// Panics if `authorization` is not a valid HTTP header value, or if the
    /// underlying HTTP client cannot be constructed.
    pub fn create(api_endpoint: &str, model_name: &str, authorization: &str) -> Self {
        let mut headers = header::HeaderMap::new();
        if !authorization.is_empty() {
            let mut value = header::HeaderValue::from_str(authorization)
                .expect("Failed to create authorization header");
            // Keep the credential out of `Debug` output of the header map.
            value.set_sensitive(true);
            headers.insert(header::AUTHORIZATION, value);
        }

        let client = reqwest::Client::builder()
            .default_headers(headers)
            .build()
            .expect("Failed to construct HTTP client");

        Self {
            client,
            api_endpoint: api_endpoint.to_owned(),
            model_name: model_name.to_owned(),
        }
    }

    /// Returns the prompt template used with this engine: the prefix and the
    /// suffix joined by a `<MID>` marker, which `generate` later splits on.
    pub fn prompt_template() -> String {
        "{prefix}<MID>{suffix}".to_owned()
    }
}

#[async_trait]
impl TextGeneration for FastChatEngine {
    /// Requests a completion for `prompt` from the configured endpoint and
    /// returns the first generated text.
    ///
    /// `prompt` is expected to follow `FastChatEngine::prompt_template()`;
    /// only the part before the first `<MID>` marker is sent, the suffix is
    /// dropped. `options.max_decoding_length` and
    /// `options.sampling_temperature` are forwarded; `options.stop_words` is
    /// not, because the request body has no field for it.
    ///
    /// Returns an empty string when the backend answers successfully but
    /// without any choice or text.
    ///
    /// # Panics
    ///
    /// Panics if the request cannot be sent or a response body cannot be
    /// parsed as JSON of the expected shape.
    async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
        // `split` always yields at least one item, so the fallback is never
        // taken; it merely avoids a panicking index.
        let prefix = prompt.split("<MID>").next().unwrap_or(prompt);

        let request = Request {
            model: self.model_name.clone(),
            prompt: vec![prefix.to_owned()],
            max_tokens: options.max_decoding_length,
            temperature: options.sampling_temperature,
        };

        // API Documentation: https://github.com/lm-sys/FastChat/blob/main/docs/openai_api.md
        let resp = self
            .client
            .post(&self.api_endpoint)
            .json(&request)
            .send()
            .await
            .expect("Failed to making completion request");

        if resp.status() != 200 {
            let err: Value = resp.json().await.expect("Failed to parse response");
            println!("Request failed: {}", err);
            // NOTE(review): this terminates the whole process on a single
            // failed upstream request — confirm that is intended.
            std::process::exit(1);
        }

        let resp: Response = resp.json().await.expect("Failed to parse response");

        // Move the first text out instead of indexing + cloning; an empty
        // `choices` or `text` list yields an empty completion, not a panic.
        resp.choices
            .into_iter()
            .next()
            .and_then(|choice| choice.text.into_iter().next())
            .unwrap_or_default()
    }
}
1 change: 1 addition & 0 deletions crates/http-api-bindings/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod vertex_ai;
pub mod fastchat;
22 changes: 16 additions & 6 deletions crates/tabby/src/serve/completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::{path::Path, sync::Arc};
use axum::{extract::State, Json};
use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder};
use http_api_bindings::vertex_ai::VertexAIEngine;
use http_api_bindings::fastchat::FastChatEngine;
use hyper::StatusCode;
use serde::{Deserialize, Serialize};
use serde_json::Value;
Expand Down Expand Up @@ -158,10 +159,6 @@ fn create_engine(args: &crate::serve::ServeArgs) -> (Box<dyn TextGeneration>, Op
.as_str()
.expect("Type unmatched");

if kind != "vertex-ai" {
fatal!("Only vertex_ai is supported for http backend");
}

let api_endpoint = params
.get("api_endpoint")
.expect("Missing api_endpoint field")
Expand All @@ -172,8 +169,21 @@ fn create_engine(args: &crate::serve::ServeArgs) -> (Box<dyn TextGeneration>, Op
.expect("Missing authorization field")
.as_str()
.expect("Type unmatched");
let engine = Box::new(VertexAIEngine::create(api_endpoint, authorization));
(engine, Some(VertexAIEngine::prompt_template()))

if kind == "vertex-ai" {
let engine = Box::new(VertexAIEngine::create(api_endpoint, authorization));
(engine, Some(VertexAIEngine::prompt_template()))
} else if kind == "fastchat" {
let model_name = params
.get("model_name")
.expect("Missing model_name field")
.as_str()
.expect("Type unmatched");
let engine = Box::new(FastChatEngine::create(api_endpoint, model_name, authorization));
(engine, Some(FastChatEngine::prompt_template()))
} else {
fatal!("Only vertex_ai and fastchat are supported for http backend");
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/tabby/src/serve/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ pub struct ServeArgs {
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn should_download_ggml_files(device: &Device) -> bool {
fn should_download_ggml_files(_device: &Device) -> bool {
false
}

Expand Down

0 comments on commit 18e1909

Please sign in to comment.