feat: add support vertex-ai http bindings (#419)
* feat: add support vertex-ai http bindings

* support prefix / suffix
wsxiaoys authored Sep 9, 2023
1 parent 17397c8 commit f0ed366
Showing 6 changed files with 96 additions and 28 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

13 changes: 12 additions & 1 deletion crates/http-api-bindings/README.md
@@ -1,4 +1,4 @@
-## Usage
+## Examples
 
 ```bash
 export MODEL_ID="code-gecko"
@@ -8,3 +8,14 @@ export AUTHORIZATION="Bearer $(gcloud auth print-access-token)"
 
 cargo run --example simple
 ```
+
+## Usage
+
+```bash
+export MODEL_ID="code-gecko"
+export PROJECT_ID="$(gcloud config get project)"
+export API_ENDPOINT="https://us-central1-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/${MODEL_ID}:predict"
+export AUTHORIZATION="Bearer $(gcloud auth print-access-token)"
+
+cargo run serve --device experimental-http --model "{\"kind\": \"vertex-ai\", \"api_endpoint\": \"$API_ENDPOINT\", \"authorization\": \"$AUTHORIZATION\"}"
+```
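For clarity, after shell expansion the value handed to `--model` above is just a small JSON object; schematically (endpoint and token shortened to placeholders):

```json
{
  "kind": "vertex-ai",
  "api_endpoint": "https://us-central1-aiplatform.googleapis.com/v1/projects/<PROJECT_ID>/locations/us-central1/publishers/google/models/code-gecko:predict",
  "authorization": "Bearer <ACCESS_TOKEN>"
}
```

These three fields are exactly what the new `create_engine` in `crates/tabby/src/serve/completions.rs` below reads out of the string.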
21 changes: 16 additions & 5 deletions crates/http-api-bindings/src/vertex_ai.rs
@@ -56,23 +56,34 @@ impl VertexAIEngine {
             client,
         }
     }
+
+    pub fn prompt_template() -> String {
+        "{prefix}<MID>{suffix}".to_owned()
+    }
 }
 
 #[async_trait]
 impl TextGeneration for VertexAIEngine {
     async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
-        let stop_sequences: Vec<String> =
-            options.stop_words.iter().map(|x| x.to_string()).collect();
+        let stop_sequences: Vec<String> = options
+            .stop_words
+            .iter()
+            .map(|x| x.to_string())
+            // vertex supports at most 5 stop sequences.
+            .take(5)
+            .collect();
 
+        let tokens: Vec<&str> = prompt.split("<MID>").collect();
         let request = Request {
             instances: vec![Instance {
-                prefix: prompt.to_owned(),
-                suffix: None,
+                prefix: tokens[0].to_owned(),
+                suffix: Some(tokens[1].to_owned()),
             }],
             // options.max_input_length is ignored.
             parameters: Parameters {
                 temperature: options.sampling_temperature,
-                max_output_tokens: options.max_decoding_length,
+                // vertex supports at most 64 output tokens.
+                max_output_tokens: std::cmp::min(options.max_decoding_length, 64),
                 stop_sequences,
             },
         };
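The template returned by `prompt_template()` and the `split("<MID>")` in `generate` are two halves of one contract: the prompt builder renders `{prefix}<MID>{suffix}`, and the engine splits the rendered prompt back into the `prefix`/`suffix` fields of the request. A minimal sketch of that round trip (`split_fim_prompt` is a hypothetical helper, not part of the commit; unlike the `tokens[1]` indexing above, it degrades to a prefix-only request when the sentinel is missing instead of panicking):

```rust
/// Hypothetical helper mirroring VertexAIEngine's "<MID>" handling.
fn split_fim_prompt(prompt: &str) -> (String, Option<String>) {
    match prompt.split_once("<MID>") {
        // "{prefix}<MID>{suffix}" -> prefix plus Some(suffix)
        Some((prefix, suffix)) => (prefix.to_owned(), Some(suffix.to_owned())),
        // No sentinel: treat the whole prompt as the prefix.
        None => (prompt.to_owned(), None),
    }
}

fn main() {
    // Render the template the way the prompt builder would...
    let prompt = format!("{}<MID>{}", "fn add(a: i32, b: i32) -> i32 {", "}");
    // ...then recover the fill-in-the-middle halves for the request body.
    let (prefix, suffix) = split_fim_prompt(&prompt);
    assert_eq!(prefix, "fn add(a: i32, b: i32) -> i32 {");
    assert_eq!(suffix.as_deref(), Some("}"));
}
```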
1 change: 1 addition & 0 deletions crates/tabby/Cargo.toml
@@ -35,6 +35,7 @@ tantivy = { workspace = true }
 anyhow = { workspace = true }
 sysinfo = "0.29.8"
 nvml-wrapper = "0.9.0"
+http-api-bindings = { path = "../http-api-bindings" }
 
 [target.'cfg(all(target_os="macos", target_arch="aarch64"))'.dependencies]
 llama-cpp-bindings = { path = "../llama-cpp-bindings" }
47 changes: 41 additions & 6 deletions crates/tabby/src/serve/completions.rs
@@ -5,8 +5,10 @@ use std::{path::Path, sync::Arc};
 
 use axum::{extract::State, Json};
 use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder};
+use http_api_bindings::vertex_ai::VertexAIEngine;
 use hyper::StatusCode;
 use serde::{Deserialize, Serialize};
+use serde_json::Value;
 use tabby_common::{config::Config, events, path::ModelDir};
 use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
 use tracing::{debug, instrument};
@@ -128,22 +130,55 @@ pub struct CompletionState {
 
 impl CompletionState {
     pub fn new(args: &crate::serve::ServeArgs, config: &Config) -> Self {
-        let model_dir = get_model_dir(&args.model);
-        let metadata = read_metadata(&model_dir);
-        let engine = create_engine(args, &model_dir, &metadata);
+        let (engine, prompt_template) = create_engine(args);
 
         Self {
             engine,
             prompt_builder: prompt::PromptBuilder::new(
-                metadata.prompt_template,
+                prompt_template,
                 config.experimental.enable_prompt_rewrite,
             ),
         }
     }
 }
 
+fn create_engine(args: &crate::serve::ServeArgs) -> (Box<dyn TextGeneration>, Option<String>) {
+    if args.device != super::Device::ExperimentalHttp {
+        let model_dir = get_model_dir(&args.model);
+        let metadata = read_metadata(&model_dir);
+        let engine = create_local_engine(args, &model_dir, &metadata);
+        (engine, metadata.prompt_template)
+    } else {
+        let params: Value =
+            serdeconv::from_json_str(&args.model).expect("Failed to parse model string");
+
+        let kind = params
+            .get("kind")
+            .expect("Missing kind field")
+            .as_str()
+            .expect("Type unmatched");
+
+        if kind != "vertex-ai" {
+            fatal!("Only vertex_ai is supported for http backend");
+        }
+
+        let api_endpoint = params
+            .get("api_endpoint")
+            .expect("Missing api_endpoint field")
+            .as_str()
+            .expect("Type unmatched");
+        let authorization = params
+            .get("authorization")
+            .expect("Missing authorization field")
+            .as_str()
+            .expect("Type unmatched");
+        let engine = Box::new(VertexAIEngine::create(api_endpoint, authorization));
+        (engine, Some(VertexAIEngine::prompt_template()))
+    }
+}
+
 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
-fn create_engine(
+fn create_local_engine(
     args: &crate::serve::ServeArgs,
     model_dir: &ModelDir,
     metadata: &Metadata,
@@ -152,7 +187,7 @@ fn create_engine(
 }
 
 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
-fn create_engine(
+fn create_local_engine(
     args: &crate::serve::ServeArgs,
     model_dir: &ModelDir,
     metadata: &Metadata,
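The `Value` lookups in the new `create_engine` could equally be written as one typed deserialization; a sketch under the assumption that the `--model` string carries exactly the three fields read above (the struct and helper names here are illustrative, not part of the commit):

```rust
// Illustrative typed counterpart to the Value-based parsing in create_engine.
use serde::Deserialize;

#[derive(Deserialize)]
struct HttpModelConfig {
    kind: String,          // must be "vertex-ai" for this backend
    api_endpoint: String,  // full Vertex AI :predict URL
    authorization: String, // e.g. "Bearer <token>"
}

fn parse_model_string(model: &str) -> HttpModelConfig {
    let config: HttpModelConfig =
        serde_json::from_str(model).expect("Failed to parse model string");
    assert_eq!(config.kind, "vertex-ai", "Only vertex-ai is supported for the http backend");
    config
}
```

Either way, a malformed string aborts startup; the typed version just centralizes the field checks.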
41 changes: 25 additions & 16 deletions crates/tabby/src/serve/mod.rs
@@ -14,7 +14,7 @@ use clap::Args;
 use tabby_common::{config::Config, usage};
 use tokio::time::sleep;
 use tower_http::cors::CorsLayer;
-use tracing::info;
+use tracing::{info, warn};
 use utoipa::OpenApi;
 use utoipa_swagger_ui::SwaggerUi;
 
@@ -58,6 +58,9 @@ pub enum Device {
     #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
     #[strum(serialize = "metal")]
     Metal,
+
+    #[strum(serialize = "experimental_http")]
+    ExperimentalHttp,
 }
 
 #[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
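One naming subtlety in this enum: the strum attribute only controls how the variant prints. Assuming `Device` carries the same `clap::ValueEnum` derive as the enum that follows, clap exposes the variant on the command line in kebab-case, which is why the README above invokes `--device experimental-http` while the Display form is `experimental_http`. A stripped-down sketch of that split:

```rust
// Sketch only: a one-variant Device reproducing the CLI vs. Display naming.
use clap::ValueEnum;

#[derive(Clone, PartialEq, clap::ValueEnum, strum::Display)]
enum Device {
    #[strum(serialize = "experimental_http")]
    ExperimentalHttp,
}

fn main() {
    // clap's ValueEnum derive maps the variant name to "experimental-http"...
    let device = Device::from_str("experimental-http", /* ignore_case = */ true).unwrap();
    // ...while strum's Display serialization prints "experimental_http".
    assert_eq!(device.to_string(), "experimental_http");
}
```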
@@ -129,22 +132,28 @@ fn should_download_ggml_files(device: &Device) -> bool {
 pub async fn main(config: &Config, args: &ServeArgs) {
     valid_args(args);
 
-    // Ensure model exists.
-    tabby_download::download_model(
-        &args.model,
-        /* download_ctranslate2_files= */
-        !should_download_ggml_files(&args.device),
-        /* download_ggml_files= */ should_download_ggml_files(&args.device),
-        /* prefer_local_file= */ true,
-    )
-    .await
-    .unwrap_or_else(|err| {
-        fatal!(
-            "Failed to fetch model due to '{}', is '{}' a valid model id?",
-            err,
-            args.model
+    if args.device != Device::ExperimentalHttp {
+        let download_ctranslate2_files = !should_download_ggml_files(&args.device);
+        let download_ggml_files = should_download_ggml_files(&args.device);
+
+        // Ensure model exists.
+        tabby_download::download_model(
+            &args.model,
+            download_ctranslate2_files,
+            download_ggml_files,
+            /* prefer_local_file= */ true,
        )
-    });
+        .await
+        .unwrap_or_else(|err| {
+            fatal!(
+                "Failed to fetch model due to '{}', is '{}' a valid model id?",
+                err,
+                args.model
+            )
+        });
+    } else {
+        warn!("HTTP device is unstable and does not comply with semver expectations.")
+    }
 
     info!("Starting server, this might takes a few minutes...");
     let app = Router::new()
