-
Notifications
You must be signed in to change notification settings - Fork 71
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: new openai embedding Signed-off-by: cutecutecat <junyuchen@tensorchord.ai> * fix Signed-off-by: cutecutecat <junyuchen@tensorchord.ai> * fix Signed-off-by: cutecutecat <junyuchen@tensorchord.ai> --------- Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>
- Loading branch information
1 parent
3d1621b
commit 5889b5d
Showing
11 changed files
with
1,434 additions
and
15 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
[package] | ||
name = "embedding" | ||
version.workspace = true | ||
edition.workspace = true | ||
|
||
[dependencies] | ||
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] } | ||
thiserror = "~1.0" | ||
serde = "~1.0" | ||
|
||
[lints] | ||
rust.internal_features = "allow" | ||
rust.unsafe_op_in_unsafe_fn = "forbid" | ||
rust.unused_lifetimes = "warn" | ||
rust.unused_qualifications = "warn" | ||
|
||
[dev-dependencies] | ||
httpmock = "0.7" | ||
serde_json = "~1.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
pub mod openai; | ||
|
||
use crate::openai::EmbeddingError; | ||
use crate::openai::EmbeddingRequest; | ||
use crate::openai::EmbeddingResponse; | ||
use reqwest::blocking::Client; | ||
use std::time::Duration; | ||
|
||
/// Connection settings for an OpenAI-compatible embedding service.
pub struct OpenAIOptions {
    /// Root of the API; `openai_embedding` appends `/embeddings` to it.
    pub base_url: String,
    /// Secret sent as the bearer token in the `Authorization` request header.
    pub api_key: String,
}
|
||
pub fn openai_embedding( | ||
input: String, | ||
model: String, | ||
opt: OpenAIOptions, | ||
) -> Result<EmbeddingResponse, EmbeddingError> { | ||
let url = format!("{}/embeddings", opt.base_url); | ||
let client = match Client::builder().timeout(Duration::from_secs(30)).build() { | ||
Ok(c) => c, | ||
Err(e) => { | ||
return Err(EmbeddingError { | ||
hint: e.to_string(), | ||
}) | ||
} | ||
}; | ||
let form: EmbeddingRequest = EmbeddingRequest::new(model.to_string(), input); | ||
let resp = match client | ||
.post(url) | ||
.header("Authorization", format!("Bearer {}", opt.api_key)) | ||
.form(&form) | ||
.send() | ||
{ | ||
Ok(c) => c, | ||
Err(e) => { | ||
return Err(EmbeddingError { | ||
hint: e.to_string(), | ||
}) | ||
} | ||
}; | ||
match resp.json::<EmbeddingResponse>() { | ||
Ok(c) => Ok(c), | ||
Err(e) => Err(EmbeddingError { | ||
hint: e.to_string(), | ||
}), | ||
} | ||
} | ||
|
||
#[cfg(test)]
mod tests {
    use crate::openai::EmbeddingData;
    use crate::openai::Usage;

    use super::openai_embedding;
    use super::EmbeddingResponse;
    use super::OpenAIOptions;
    use httpmock::Method::POST;
    use httpmock::MockServer;

    /// Starts a mock server that answers `POST /embeddings` with `resp` serialized as JSON.
    fn mock_server(resp: EmbeddingResponse) -> MockServer {
        let server = MockServer::start();
        let data = serde_json::to_string(&resp).unwrap();
        let _ = server.mock(|when, then| {
            when.method(POST).path("/embeddings");
            then.status(200)
                .header("content-type", "application/json")
                .body(data);
        });
        server
    }

    /// Builds client options that point at `server` and carry a dummy API key.
    fn options_for(server: &MockServer) -> OpenAIOptions {
        OpenAIOptions {
            base_url: server.url(""),
            api_key: "fake-key".to_string(),
        }
    }

    /// Wraps `data` in a response envelope with placeholder metadata.
    fn response_with(data: Vec<EmbeddingData>) -> EmbeddingResponse {
        EmbeddingResponse {
            object: "mock-object".to_string(),
            data,
            model: "mock-model".to_string(),
            usage: Usage {
                prompt_tokens: 0,
                total_tokens: 0,
            },
        }
    }

    #[test]
    fn test_openai_embedding_successful() {
        let embedding = vec![1.0, 2.0, 3.0];
        let server = mock_server(response_with(vec![EmbeddingData {
            object: "mock-object".to_string(),
            embedding: embedding.clone(),
            index: 0,
        }]));

        let real_resp = openai_embedding(
            "mock-input".to_string(),
            "mock-model".to_string(),
            options_for(&server),
        );
        assert!(real_resp.is_ok());
        let real_embedding = real_resp.unwrap().try_pop_embedding();
        assert!(real_embedding.is_ok());
        // The vector must survive the round trip unchanged, not merely be present.
        assert_eq!(real_embedding.unwrap(), embedding);
    }

    #[test]
    fn test_openai_embedding_empty_embedding() {
        let server = mock_server(response_with(vec![]));

        let real_resp = openai_embedding(
            "mock-input".to_string(),
            "mock-model".to_string(),
            options_for(&server),
        );
        assert!(real_resp.is_ok());
        let real_embedding = real_resp.unwrap().try_pop_embedding();
        assert!(real_embedding.is_err());
    }

    #[test]
    fn test_openai_embedding_error() {
        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(POST).path("/embeddings");
            then.status(502)
                .header("content-type", "text/html; charset=UTF-8")
                .body("502 Bad Gateway");
        });

        let real_resp = openai_embedding(
            "mock-input".to_string(),
            "mock-model".to_string(),
            options_for(&server),
        );
        assert!(real_resp.is_err());
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
use serde::{Deserialize, Serialize}; | ||
use std::fmt::Debug; | ||
use thiserror::Error; | ||
|
||
#[derive(Debug, Error)] | ||
#[error( | ||
"\ | ||
Error happens at embedding. | ||
INFORMATION: hint = {hint}" | ||
)] | ||
pub struct EmbeddingError { | ||
pub hint: String, | ||
} | ||
|
||
#[derive(Debug, Deserialize, Serialize)] | ||
pub struct EmbeddingData { | ||
pub object: String, | ||
pub embedding: Vec<f32>, | ||
pub index: i32, | ||
} | ||
|
||
#[derive(Debug, Serialize, Clone)] | ||
pub struct EmbeddingRequest { | ||
pub model: String, | ||
pub input: String, | ||
#[serde(skip_serializing_if = "Option::is_none")] | ||
pub dimensions: Option<i32>, | ||
#[serde(skip_serializing_if = "Option::is_none")] | ||
pub user: Option<String>, | ||
} | ||
|
||
impl EmbeddingRequest { | ||
pub fn new(model: String, input: String) -> Self { | ||
Self { | ||
model, | ||
input, | ||
dimensions: None, | ||
user: None, | ||
} | ||
} | ||
} | ||
|
||
#[derive(Debug, Deserialize, Serialize)] | ||
pub struct EmbeddingResponse { | ||
pub object: String, | ||
pub data: Vec<EmbeddingData>, | ||
pub model: String, | ||
pub usage: Usage, | ||
} | ||
|
||
impl EmbeddingResponse { | ||
pub fn try_pop_embedding(mut self) -> Result<Vec<f32>, EmbeddingError> { | ||
match self.data.pop() { | ||
Some(d) => Ok(d.embedding), | ||
None => Err(EmbeddingError { | ||
hint: "no embedding from service".to_string(), | ||
}), | ||
} | ||
} | ||
} | ||
|
||
#[derive(Debug, Deserialize, Serialize)] | ||
pub struct Usage { | ||
pub prompt_tokens: i32, | ||
pub total_tokens: i32, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
use crate::datatype::vecf32::{Vecf32, Vecf32Output}; | ||
use crate::gucs::embedding::openai_options; | ||
use embedding::openai_embedding; | ||
use pgrx::error; | ||
use service::prelude::F32; | ||
|
||
#[pgrx::pg_extern(volatile, strict)] | ||
fn _vectors_text2vec_openai(input: String, model: String) -> Vecf32Output { | ||
let options = openai_options(); | ||
let resp = match openai_embedding(input, model, options) { | ||
Ok(r) => r, | ||
Err(e) => error!("{}", e.to_string()), | ||
}; | ||
let embedding = match resp.try_pop_embedding() { | ||
Ok(emb) => emb.into_iter().map(F32).collect::<Vec<_>>(), | ||
Err(e) => error!("{}", e.to_string()), | ||
}; | ||
|
||
Vecf32::new_in_postgres(&embedding) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
use super::guc_string_parse; | ||
use embedding::OpenAIOptions; | ||
use pgrx::{GucContext, GucFlags, GucRegistry, GucSetting}; | ||
use std::ffi::CStr; | ||
|
||
pub fn openai_options() -> OpenAIOptions { | ||
let base_url = guc_string_parse(&OPENAI_BASE_URL, "vectors.openai_base"); | ||
let api_key = guc_string_parse(&OPENAI_API_KEY, "vectors.openai_api_key"); | ||
OpenAIOptions { base_url, api_key } | ||
} | ||
|
||
static OPENAI_API_KEY: GucSetting<Option<&'static CStr>> = | ||
GucSetting::<Option<&'static CStr>>::new(None); | ||
|
||
static OPENAI_BASE_URL: GucSetting<Option<&'static CStr>> = | ||
GucSetting::<Option<&'static CStr>>::new(Some(c"https://api.openai.com/v1/")); | ||
|
||
pub unsafe fn init() { | ||
GucRegistry::define_string_guc( | ||
"vectors.openai_api_key", | ||
"The API key of OpenAI.", | ||
"", | ||
&OPENAI_API_KEY, | ||
GucContext::Userset, | ||
GucFlags::default(), | ||
); | ||
GucRegistry::define_string_guc( | ||
"vectors.openai_base_url", | ||
"The base url of OpenAI or compatible server.", | ||
"", | ||
&OPENAI_BASE_URL, | ||
GucContext::Userset, | ||
GucFlags::default(), | ||
); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
|
||
mod bgworker; | ||
mod datatype; | ||
mod embedding; | ||
mod gucs; | ||
mod index; | ||
mod ipc; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters