Skip to content

Commit

Permalink
feat: add http api bindings (#410)
Browse files Browse the repository at this point in the history
* feat: add http-api-bindings

* feat: add http-api-bindings

* handle max_input_length

* rename

* update

* update

* add examples/simple.rs

* update

* add default value for stop words

* update

* fix lint

* update
  • Loading branch information
wsxiaoys authored Sep 9, 2023
1 parent ad3b974 commit 17397c8
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 6 deletions.
24 changes: 18 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ members = [
"crates/rust-cxx-cmake-bridge",
"crates/llama-cpp-bindings",
"crates/stop-words",
"crates/http-api-bindings",
]

[workspace.package]
Expand Down
14 changes: 14 additions & 0 deletions crates/http-api-bindings/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "http-api-bindings"
version = "0.1.0"
edition = "2021"

[dependencies]
async-trait.workspace = true
reqwest = { workspace = true, features = ["json"] }
serde = { workspace = true, features = ["derive"] }
serde_json = "1.0.105"
tabby-inference = { version = "0.1.0", path = "../tabby-inference" }

[dev-dependencies]
tokio = { workspace = true, features = ["full"] }
10 changes: 10 additions & 0 deletions crates/http-api-bindings/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
## Usage

```bash
export MODEL_ID="code-gecko"
export PROJECT_ID="$(gcloud config get project)"
export API_ENDPOINT="https://us-central1-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/${MODEL_ID}:predict"
export AUTHORIZATION="Bearer $(gcloud auth print-access-token)"

cargo run --example simple
```
20 changes: 20 additions & 0 deletions crates/http-api-bindings/examples/simple.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use std::env;

use http_api_bindings::vertex_ai::VertexAIEngine;
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};

/// Minimal end-to-end example: reads the endpoint and auth token from the
/// environment, requests one short completion, and prints prompt + completion.
#[tokio::main]
async fn main() {
    let endpoint = env::var("API_ENDPOINT").expect("API_ENDPOINT not set");
    let auth = env::var("AUTHORIZATION").expect("AUTHORIZATION not set");

    let engine = VertexAIEngine::create(&endpoint, &auth);

    let prompt = "def fib(n)";
    let options = TextGenerationOptionsBuilder::default()
        .sampling_temperature(0.1)
        .max_decoding_length(32)
        .build()
        .unwrap();

    let completion = engine.generate(prompt, options).await;
    println!("{}{}", prompt, completion);
}
1 change: 1 addition & 0 deletions crates/http-api-bindings/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/// HTTP client for Google Vertex AI's code-completion (`:predict`) API.
pub mod vertex_ai;
99 changes: 99 additions & 0 deletions crates/http-api-bindings/src/vertex_ai.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
use async_trait::async_trait;
use reqwest::header;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tabby_inference::{TextGeneration, TextGenerationOptions};

/// Top-level JSON payload for the Vertex AI `:predict` endpoint.
#[derive(Serialize)]
struct Request {
    /// Prompt instances to complete; this client always sends exactly one.
    instances: Vec<Instance>,
    parameters: Parameters,
}

/// A single code-completion prompt.
#[derive(Serialize)]
struct Instance {
    /// Text before the insertion point.
    prefix: String,
    /// Text after the insertion point; this client always sends `None`.
    suffix: Option<String>,
}

/// Sampling parameters, serialized in camelCase as the Vertex AI API expects.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct Parameters {
    temperature: f32,
    max_output_tokens: usize,
    /// Generation stops when any of these sequences is produced.
    stop_sequences: Vec<String>,
}

/// Shape of a successful `:predict` response body.
#[derive(Deserialize)]
struct Response {
    predictions: Vec<Prediction>,
}

/// One generated completion returned by the model.
#[derive(Deserialize)]
struct Prediction {
    content: String,
}

/// Text-generation engine backed by the Google Vertex AI HTTP API.
pub struct VertexAIEngine {
    /// HTTP client pre-configured with an `Authorization` default header.
    client: reqwest::Client,
    /// Full `:predict` URL of the target model.
    api_endpoint: String,
}

impl VertexAIEngine {
    /// Builds an engine that POSTs to `api_endpoint`, attaching `authorization`
    /// (e.g. `"Bearer <token>"`) as a default `Authorization` header on every
    /// request.
    ///
    /// # Panics
    /// Panics if `authorization` contains bytes that are invalid in an HTTP
    /// header value, or if the HTTP client cannot be constructed.
    pub fn create(api_endpoint: &str, authorization: &str) -> Self {
        let mut headers = header::HeaderMap::new();
        headers.insert(
            // Use the typed header-name constant instead of a string literal.
            header::AUTHORIZATION,
            header::HeaderValue::from_str(authorization)
                .expect("Failed to create authorization header"),
        );
        let client = reqwest::Client::builder()
            .default_headers(headers)
            .build()
            .expect("Failed to construct HTTP client");
        Self {
            api_endpoint: api_endpoint.to_owned(),
            client,
        }
    }
}

#[async_trait]
impl TextGeneration for VertexAIEngine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let stop_sequences: Vec<String> =
options.stop_words.iter().map(|x| x.to_string()).collect();

let request = Request {
instances: vec![Instance {
prefix: prompt.to_owned(),
suffix: None,
}],
// options.max_input_length is ignored.
parameters: Parameters {
temperature: options.sampling_temperature,
max_output_tokens: options.max_decoding_length,
stop_sequences,
},
};

// API Documentation: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#code-completion-prompt-parameters
let resp = self
.client
.post(&self.api_endpoint)
.json(&request)
.send()
.await
.expect("Failed to making completion request");

if resp.status() != 200 {
let err: Value = resp.json().await.expect("Failed to parse response");
println!("Request failed: {}", err);
std::process::exit(1);
}

let resp: Response = resp.json().await.expect("Failed to parse response");

resp.predictions[0].content.clone()
}
}
3 changes: 3 additions & 0 deletions crates/tabby-inference/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@ pub struct TextGenerationOptions {
#[builder(default = "1.0")]
pub sampling_temperature: f32,

#[builder(default = "&EMPTY_STOP_WORDS")]
pub stop_words: &'static Vec<&'static str>,
}

// Default for `stop_words`: no stop sequences.
static EMPTY_STOP_WORDS: Vec<&'static str> = vec![];

#[async_trait]
pub trait TextGeneration: Sync + Send {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String;
Expand Down

0 comments on commit 17397c8

Please sign in to comment.