Skip to content

Commit

Permalink
feat: new openai embedding (#350)
Browse files Browse the repository at this point in the history
* feat: new openai embedding

Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>

* fix

Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>

* fix

Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>

---------

Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>
  • Loading branch information
cutecutecat authored Feb 21, 2024
1 parent 3d1621b commit 5889b5d
Show file tree
Hide file tree
Showing 11 changed files with 1,434 additions and 15 deletions.
1,119 changes: 1,104 additions & 15 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ base = { path = "crates/base" }
detect = { path = "crates/detect" }
send_fd = { path = "crates/send_fd" }
service = { path = "crates/service" }
embedding = { path = "crates/embedding" }
interprocess_atomic_wait = { path = "crates/interprocess-atomic-wait" }
memfd = { path = "crates/memfd" }
pgrx = { version = "0.11.3", default-features = false, features = [] }
Expand All @@ -42,6 +43,7 @@ toml = "0.8.10"

[dev-dependencies]
pgrx-tests = "0.11.3"
httpmock = "0.7"

[patch.crates-io]
pgrx = { git = "https://github.com/tensorchord/pgrx.git", branch = "v0.11.3-patch" }
Expand Down
19 changes: 19 additions & 0 deletions crates/embedding/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
[package]
name = "embedding"
version.workspace = true
edition.workspace = true

[dependencies]
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] }
thiserror = "~1.0"
serde = "~1.0"

[lints]
rust.internal_features = "allow"
rust.unsafe_op_in_unsafe_fn = "forbid"
rust.unused_lifetimes = "warn"
rust.unused_qualifications = "warn"

[dev-dependencies]
httpmock = "0.7"
serde_json = "~1.0"
145 changes: 145 additions & 0 deletions crates/embedding/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
pub mod openai;

use crate::openai::EmbeddingError;
use crate::openai::EmbeddingRequest;
use crate::openai::EmbeddingResponse;
use reqwest::blocking::Client;
use std::time::Duration;

pub struct OpenAIOptions {
pub base_url: String,
pub api_key: String,
}

pub fn openai_embedding(
input: String,
model: String,
opt: OpenAIOptions,
) -> Result<EmbeddingResponse, EmbeddingError> {
let url = format!("{}/embeddings", opt.base_url);
let client = match Client::builder().timeout(Duration::from_secs(30)).build() {
Ok(c) => c,
Err(e) => {
return Err(EmbeddingError {
hint: e.to_string(),
})
}
};
let form: EmbeddingRequest = EmbeddingRequest::new(model.to_string(), input);
let resp = match client
.post(url)
.header("Authorization", format!("Bearer {}", opt.api_key))
.form(&form)
.send()
{
Ok(c) => c,
Err(e) => {
return Err(EmbeddingError {
hint: e.to_string(),
})
}
};
match resp.json::<EmbeddingResponse>() {
Ok(c) => Ok(c),
Err(e) => Err(EmbeddingError {
hint: e.to_string(),
}),
}
}

#[cfg(test)]
mod tests {
use crate::openai::EmbeddingData;
use crate::openai::Usage;

use super::openai_embedding;
use super::EmbeddingResponse;
use super::OpenAIOptions;
use httpmock::Method::POST;
use httpmock::MockServer;

fn mock_server(resp: EmbeddingResponse) -> MockServer {
let server = MockServer::start();
let data = serde_json::to_string(&resp).unwrap();
let _ = server.mock(|when, then| {
when.method(POST).path("/embeddings");
then.status(200)
.header("content-type", "text/html; charset=UTF-8")
.body(data);
});
server
}

#[test]
fn test_openai_embedding_successful() {
let embedding = vec![1.0, 2.0, 3.0];
let resp = EmbeddingResponse {
object: "mock-object".to_string(),
data: vec![EmbeddingData {
object: "mock-object".to_string(),
embedding: embedding.clone(),
index: 0,
}],
model: "mock-model".to_string(),
usage: Usage {
prompt_tokens: 0,
total_tokens: 0,
},
};
let server = mock_server(resp);

let opt = OpenAIOptions {
base_url: server.url(""),
api_key: "fake-key".to_string(),
};

let real_resp = openai_embedding("mock-input".to_string(), "mock-model".to_string(), opt);
assert!(real_resp.is_ok());
let real_embedding = real_resp.unwrap().try_pop_embedding();
assert!(real_embedding.is_ok());
}

#[test]
fn test_openai_embedding_empty_embedding() {
let resp = EmbeddingResponse {
object: "mock-object".to_string(),
data: vec![],
model: "mock-model".to_string(),
usage: Usage {
prompt_tokens: 0,
total_tokens: 0,
},
};
let server = mock_server(resp);

let opt = OpenAIOptions {
base_url: server.url(""),
api_key: "fake-key".to_string(),
};

let real_resp = openai_embedding("mock-input".to_string(), "mock-model".to_string(), opt);
assert!(real_resp.is_ok());
let real_embedding = real_resp.unwrap().try_pop_embedding();
assert!(real_embedding.is_err());
}

#[test]
fn test_openai_embedding_error() {
let server = MockServer::start();

server.mock(|when, then| {
when.method(POST).path("/embeddings");
then.status(502)
.header("content-type", "text/html; charset=UTF-8")
.body("502 Bad Gateway");
});

let opt = OpenAIOptions {
base_url: server.url(""),
api_key: "fake-key".to_string(),
};

let real_resp = openai_embedding("mock-input".to_string(), "mock-model".to_string(), opt);
assert!(real_resp.is_err());
}
}
66 changes: 66 additions & 0 deletions crates/embedding/src/openai.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use thiserror::Error;

#[derive(Debug, Error)]
#[error(
"\
Error happens at embedding.
INFORMATION: hint = {hint}"
)]
pub struct EmbeddingError {
pub hint: String,
}

#[derive(Debug, Deserialize, Serialize)]
pub struct EmbeddingData {
pub object: String,
pub embedding: Vec<f32>,
pub index: i32,
}

#[derive(Debug, Serialize, Clone)]
pub struct EmbeddingRequest {
pub model: String,
pub input: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub dimensions: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
}

impl EmbeddingRequest {
pub fn new(model: String, input: String) -> Self {
Self {
model,
input,
dimensions: None,
user: None,
}
}
}

#[derive(Debug, Deserialize, Serialize)]
pub struct EmbeddingResponse {
pub object: String,
pub data: Vec<EmbeddingData>,
pub model: String,
pub usage: Usage,
}

impl EmbeddingResponse {
pub fn try_pop_embedding(mut self) -> Result<Vec<f32>, EmbeddingError> {
match self.data.pop() {
Some(d) => Ok(d.embedding),
None => Err(EmbeddingError {
hint: "no embedding from service".to_string(),
}),
}
}
}

#[derive(Debug, Deserialize, Serialize)]
pub struct Usage {
pub prompt_tokens: i32,
pub total_tokens: i32,
}
20 changes: 20 additions & 0 deletions src/embedding/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use crate::datatype::vecf32::{Vecf32, Vecf32Output};
use crate::gucs::embedding::openai_options;
use embedding::openai_embedding;
use pgrx::error;
use service::prelude::F32;

#[pgrx::pg_extern(volatile, strict)]
fn _vectors_text2vec_openai(input: String, model: String) -> Vecf32Output {
let options = openai_options();
let resp = match openai_embedding(input, model, options) {
Ok(r) => r,
Err(e) => error!("{}", e.to_string()),
};
let embedding = match resp.try_pop_embedding() {
Ok(emb) => emb.into_iter().map(F32).collect::<Vec<_>>(),
Err(e) => error!("{}", e.to_string()),
};

Vecf32::new_in_postgres(&embedding)
}
35 changes: 35 additions & 0 deletions src/gucs/embedding.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use super::guc_string_parse;
use embedding::OpenAIOptions;
use pgrx::{GucContext, GucFlags, GucRegistry, GucSetting};
use std::ffi::CStr;

pub fn openai_options() -> OpenAIOptions {
let base_url = guc_string_parse(&OPENAI_BASE_URL, "vectors.openai_base");
let api_key = guc_string_parse(&OPENAI_API_KEY, "vectors.openai_api_key");
OpenAIOptions { base_url, api_key }
}

static OPENAI_API_KEY: GucSetting<Option<&'static CStr>> =
GucSetting::<Option<&'static CStr>>::new(None);

static OPENAI_BASE_URL: GucSetting<Option<&'static CStr>> =
GucSetting::<Option<&'static CStr>>::new(Some(c"https://api.openai.com/v1/"));

pub unsafe fn init() {
GucRegistry::define_string_guc(
"vectors.openai_api_key",
"The API key of OpenAI.",
"",
&OPENAI_API_KEY,
GucContext::Userset,
GucFlags::default(),
);
GucRegistry::define_string_guc(
"vectors.openai_base_url",
"The base url of OpenAI or compatible server.",
"",
&OPENAI_BASE_URL,
GucContext::Userset,
GucFlags::default(),
);
}
20 changes: 20 additions & 0 deletions src/gucs/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
use crate::prelude::bad_guc_literal;
use pgrx::GucSetting;
use std::ffi::CStr;

pub mod embedding;
pub mod executing;
pub mod internal;
pub mod planning;
Expand All @@ -7,5 +12,20 @@ pub unsafe fn init() {
self::planning::init();
self::internal::init();
self::executing::init();
self::embedding::init();
}
}

fn guc_string_parse(
target: &'static GucSetting<Option<&'static CStr>>,
name: &'static str,
) -> String {
let value = match target.get() {
Some(s) => s,
None => bad_guc_literal(name, "should not be `NULL`"),
};
match value.to_str() {
Ok(s) => s.to_string(),
Err(_e) => bad_guc_literal(name, "should be a valid UTF-8 string"),
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

mod bgworker;
mod datatype;
mod embedding;
mod gucs;
mod index;
mod ipc;
Expand Down
8 changes: 8 additions & 0 deletions src/prelude/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ You should edit `shared_preload_libraries` in `postgresql.conf` to include `vect
or simply run the command `psql -U postgres -c 'ALTER SYSTEM SET shared_preload_libraries = \"vectors.so\"'`.");
}

pub fn bad_guc_literal(key: &str, hint: &str) -> ! {
error!(
"\
Failed to parse a GUC variable.
INFORMATION: GUC = {key}, hint = {hint}"
);
}

pub fn check_type_dimensions(dimensions: Option<NonZeroU16>) -> NonZeroU16 {
match dimensions {
None => {
Expand Down
14 changes: 14 additions & 0 deletions src/sql/finalize.sql
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,20 @@ STRICT LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_pgvectors_upgrade_wrapper';
CREATE FUNCTION to_svector(dims INT, indices INT[], vals real[]) RETURNS svector
IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_to_svector_wrapper';

CREATE FUNCTION text2vec_openai(input TEXT, model TEXT) RETURNS vector
STRICT LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_text2vec_openai_wrapper';

CREATE FUNCTION text2vec_openai_v3(input TEXT) RETURNS vector
STRICT LANGUAGE plpgsql AS
$$
DECLARE
variable vectors.vector;
BEGIN
variable := vectors.text2vec_openai(input, 'text-embedding-3-small');
RETURN variable;
END;
$$;

-- List of casts

CREATE CAST (real[] AS vector)
Expand Down

0 comments on commit 5889b5d

Please sign in to comment.