From f0ed366420493ef2193cbd4f15c5937a370759a7 Mon Sep 17 00:00:00 2001
From: Meng Zhang
Date: Sat, 9 Sep 2023 19:22:58 +0800
Subject: [PATCH] feat: add support vertex-ai http bindings (#419)

* feat: add support vertex-ai http bindings

* support prefix / suffix
---
 Cargo.lock                                |  1 +
 crates/http-api-bindings/README.md        | 13 ++++++-
 crates/http-api-bindings/src/vertex_ai.rs | 21 +++++++---
 crates/tabby/Cargo.toml                   |  1 +
 crates/tabby/src/serve/completions.rs     | 47 ++++++++++++++++++++---
 crates/tabby/src/serve/mod.rs             | 41 ++++++++++++--------
 6 files changed, 96 insertions(+), 28 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 999933915442..c522efd42763 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2959,6 +2959,7 @@ dependencies = [
  "axum-tracing-opentelemetry",
  "clap",
  "ctranslate2-bindings",
+ "http-api-bindings",
  "hyper",
  "lazy_static",
  "llama-cpp-bindings",
diff --git a/crates/http-api-bindings/README.md b/crates/http-api-bindings/README.md
index 81ad900ceb7a..8664710a2bc0 100644
--- a/crates/http-api-bindings/README.md
+++ b/crates/http-api-bindings/README.md
@@ -1,4 +1,4 @@
-## Usage
+## Examples
 
 ```bash
 export MODEL_ID="code-gecko"
@@ -8,3 +8,14 @@ export AUTHORIZATION="Bearer $(gcloud auth print-access-token)"
 
 cargo run --example simple
 ```
+
+## Usage
+
+```bash
+export MODEL_ID="code-gecko"
+export PROJECT_ID="$(gcloud config get project)"
+export API_ENDPOINT="https://us-central1-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/${MODEL_ID}:predict"
+export AUTHORIZATION="Bearer $(gcloud auth print-access-token)"
+
+cargo run serve --device experimental-http --model "{\"kind\": \"vertex-ai\", \"api_endpoint\": \"$API_ENDPOINT\", \"authorization\": \"$AUTHORIZATION\"}"
+```
diff --git a/crates/http-api-bindings/src/vertex_ai.rs b/crates/http-api-bindings/src/vertex_ai.rs
index e1f7e8a775ff..c2dd226379e0 100644
--- a/crates/http-api-bindings/src/vertex_ai.rs
+++ b/crates/http-api-bindings/src/vertex_ai.rs
@@ -56,23 +56,34 @@ impl VertexAIEngine {
             client,
         }
     }
+
+    pub fn prompt_template() -> String {
+        "{prefix}<MID>{suffix}".to_owned()
+    }
 }
 
 #[async_trait]
 impl TextGeneration for VertexAIEngine {
     async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
-        let stop_sequences: Vec<String> =
-            options.stop_words.iter().map(|x| x.to_string()).collect();
+        let stop_sequences: Vec<String> = options
+            .stop_words
+            .iter()
+            .map(|x| x.to_string())
+            // Vertex AI supports at most 5 stop sequences.
+            .take(5)
+            .collect();
+        let tokens: Vec<&str> = prompt.split("<MID>").collect();
 
         let request = Request {
             instances: vec![Instance {
-                prefix: prompt.to_owned(),
-                suffix: None,
+                prefix: tokens[0].to_owned(),
+                suffix: Some(tokens[1].to_owned()),
             }],
             // options.max_input_length is ignored.
             parameters: Parameters {
                 temperature: options.sampling_temperature,
-                max_output_tokens: options.max_decoding_length,
+                // Vertex AI supports at most 64 output tokens.
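+                // Anything above the cap is clamped client-side, so a larger
+                // max_decoding_length silently becomes 64 in the request.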
+                max_output_tokens: std::cmp::min(options.max_decoding_length, 64),
                 stop_sequences,
             },
         };
diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml
index 209db9945edf..ff79db125cde 100644
--- a/crates/tabby/Cargo.toml
+++ b/crates/tabby/Cargo.toml
@@ -35,6 +35,7 @@ tantivy = { workspace = true }
 anyhow = { workspace = true }
 sysinfo = "0.29.8"
 nvml-wrapper = "0.9.0"
+http-api-bindings = { path = "../http-api-bindings" }
 
 [target.'cfg(all(target_os="macos", target_arch="aarch64"))'.dependencies]
 llama-cpp-bindings = { path = "../llama-cpp-bindings" }
diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs
index 64255d7296f0..10e558fc8f50 100644
--- a/crates/tabby/src/serve/completions.rs
+++ b/crates/tabby/src/serve/completions.rs
@@ -5,8 +5,10 @@ use std::{path::Path, sync::Arc};
 use axum::{extract::State, Json};
 use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder};
+use http_api_bindings::vertex_ai::VertexAIEngine;
 use hyper::StatusCode;
 use serde::{Deserialize, Serialize};
+use serde_json::Value;
 use tabby_common::{config::Config, events, path::ModelDir};
 use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
 use tracing::{debug, instrument};
@@ -128,22 +130,55 @@ pub struct CompletionState {
 
 impl CompletionState {
     pub fn new(args: &crate::serve::ServeArgs, config: &Config) -> Self {
-        let model_dir = get_model_dir(&args.model);
-        let metadata = read_metadata(&model_dir);
-        let engine = create_engine(args, &model_dir, &metadata);
+        let (engine, prompt_template) = create_engine(args);
         Self {
             engine,
             prompt_builder: prompt::PromptBuilder::new(
-                metadata.prompt_template,
+                prompt_template,
                 config.experimental.enable_prompt_rewrite,
             ),
         }
     }
 }
 
+fn create_engine(args: &crate::serve::ServeArgs) -> (Box<dyn TextGeneration>, Option<String>) {
+    if args.device != super::Device::ExperimentalHttp {
+        let model_dir = get_model_dir(&args.model);
+        let metadata = read_metadata(&model_dir);
+        let engine = create_local_engine(args, &model_dir, &metadata);
+        (engine, metadata.prompt_template)
+    } else {
+        let params: Value =
+            serdeconv::from_json_str(&args.model).expect("Failed to parse model string");
+
+        let kind = params
+            .get("kind")
+            .expect("Missing kind field")
+            .as_str()
+            .expect("Type unmatched");
+
+        if kind != "vertex-ai" {
+            fatal!("Only vertex-ai is supported for http backend");
+        }
+
+        let api_endpoint = params
+            .get("api_endpoint")
+            .expect("Missing api_endpoint field")
+            .as_str()
+            .expect("Type unmatched");
+        let authorization = params
+            .get("authorization")
+            .expect("Missing authorization field")
+            .as_str()
+            .expect("Type unmatched");
+        let engine = Box::new(VertexAIEngine::create(api_endpoint, authorization));
+        (engine, Some(VertexAIEngine::prompt_template()))
+    }
+}
+
 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
-fn create_engine(
+fn create_local_engine(
     args: &crate::serve::ServeArgs,
     model_dir: &ModelDir,
     metadata: &Metadata,
@@ -152,7 +187,7 @@ fn create_engine(
 }
 
 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
-fn create_engine(
+fn create_local_engine(
     args: &crate::serve::ServeArgs,
     model_dir: &ModelDir,
     metadata: &Metadata,
diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs
index db03fe92eb72..2e0535df4921 100644
--- a/crates/tabby/src/serve/mod.rs
+++ b/crates/tabby/src/serve/mod.rs
@@ -14,7 +14,7 @@ use clap::Args;
 use tabby_common::{config::Config, usage};
 use tokio::time::sleep;
 use tower_http::cors::CorsLayer;
-use tracing::info;
+use tracing::{info, warn};
 use utoipa::OpenApi;
 use utoipa_swagger_ui::SwaggerUi;
@@ -58,6 +58,9 @@ pub enum Device {
     #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
     #[strum(serialize = "metal")]
     Metal,
+
+    #[strum(serialize = "experimental_http")]
+    ExperimentalHttp,
 }
 
 #[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
@@ -129,22 +132,28 @@ fn should_download_ggml_files(device: &Device) -> bool {
 pub async fn main(config: &Config, args: &ServeArgs) {
     valid_args(args);
 
-    // Ensure model exists.
-    tabby_download::download_model(
-        &args.model,
-        /* download_ctranslate2_files= */
-        !should_download_ggml_files(&args.device),
-        /* download_ggml_files= */ should_download_ggml_files(&args.device),
-        /* prefer_local_file= */ true,
-    )
-    .await
-    .unwrap_or_else(|err| {
-        fatal!(
-            "Failed to fetch model due to '{}', is '{}' a valid model id?",
-            err,
-            args.model
+    if args.device != Device::ExperimentalHttp {
+        let download_ctranslate2_files = !should_download_ggml_files(&args.device);
+        let download_ggml_files = should_download_ggml_files(&args.device);
+
+        // Ensure model exists.
+        tabby_download::download_model(
+            &args.model,
+            download_ctranslate2_files,
+            download_ggml_files,
+            /* prefer_local_file= */ true,
         )
-    });
+        .await
+        .unwrap_or_else(|err| {
+            fatal!(
+                "Failed to fetch model due to '{}', is '{}' a valid model id?",
+                err,
+                args.model
+            )
+        });
+    } else {
+        warn!("HTTP device is unstable and does not comply with semver expectations.")
+    }
 
     info!("Starting server, this might take a few minutes...");
 
     let app = Router::new()
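
The vertex_ai.rs hunk above constructs `Request`, `Instance`, and `Parameters` values whose definitions sit outside the changed lines, so they never appear in this patch. The sketch below shows what those serde types could plausibly look like, inferred from the fields used in `generate()`; the camelCase renaming and the `Response`/`Prediction` shape follow Vertex AI's predict-API conventions and are assumptions, not part of this diff.

```rust
use serde::{Deserialize, Serialize};

// Request body for Vertex AI's `:predict` endpoint, as assembled in
// `generate()` above.
#[derive(Serialize)]
struct Request {
    instances: Vec<Instance>,
    parameters: Parameters,
}

#[derive(Serialize)]
struct Instance {
    prefix: String,
    suffix: Option<String>,
}

// Assumption: the Vertex AI API expects camelCase keys
// (maxOutputTokens, stopSequences).
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct Parameters {
    temperature: f32,
    max_output_tokens: usize,
    stop_sequences: Vec<String>,
}

// The patch does not show response handling; this shape is an assumption
// based on the predict endpoint returning a `predictions` array.
#[derive(Deserialize)]
struct Response {
    predictions: Vec<Prediction>,
}

#[derive(Deserialize)]
struct Prediction {
    content: String,
}
```

Since `prompt_template()` yields `{prefix}<MID>{suffix}`, the server's prompt builder hands `generate()` a single string that splits cleanly back into the two `Instance` fields, which is what the `prompt.split("<MID>")` call relies on.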