Updated Candle example [skip ci]
ankane committed Oct 21, 2024
1 parent 5f3b7cf · commit 0522a5d
Showing 2 changed files with 6 additions and 7 deletions.

The commit bumps the Candle crates from 0.6 to 0.7 and tokenizers from 0.19 to 0.20. In candle 0.7, BertModel::forward takes an additional argument (an optional attention mask, per the issue referenced in the removed comment), so the example's call now passes None.
8 changes: 4 additions & 4 deletions examples/candle/Cargo.toml
@@ -5,11 +5,11 @@ edition = "2021"
 publish = false
 
 [dependencies]
-candle-core = "0.6"
-candle-nn = "0.6"
-candle-transformers = "0.6"
+candle-core = "0.7"
+candle-nn = "0.7"
+candle-transformers = "0.7"
 hf-hub = "0.3"
 pgvector = { path = "../..", features = ["postgres"] }
 postgres = "0.19"
 serde_json = "1"
-tokenizers = "0.19"
+tokenizers = "0.20"
5 changes: 2 additions & 3 deletions examples/candle/src/main.rs
@@ -81,13 +81,12 @@ impl EmbeddingModel {
         Ok(Self { tokenizer, model })
     }
 
-    // embed one at a time since BertModel does not support attention mask
-    // https://github.com/huggingface/candle/issues/1798
+    // TODO support multiple texts
     fn embed(&self, text: &str) -> Result<Vec<f32>, Box<dyn Error + Send + Sync>> {
         let tokens = self.tokenizer.encode(text, true)?;
         let token_ids = Tensor::new(vec![tokens.get_ids().to_vec()], &self.model.device)?;
         let token_type_ids = token_ids.zeros_like()?;
-        let embeddings = self.model.forward(&token_ids, &token_type_ids)?;
+        let embeddings = self.model.forward(&token_ids, &token_type_ids, None)?;
         let embeddings = (embeddings.sum(1)? / (embeddings.dim(1)? as f64))?;
         let embeddings = embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?;
         Ok(embeddings.squeeze(0)?.to_vec1::<f32>()?)
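For context, here is a minimal sketch of how the example's embed method (which returns a mean-pooled, L2-normalized Vec<f32>) might be used with the pgvector and postgres crates from the Cargo.toml above. The connection string, the items table, and the store_embedding helper are illustrative assumptions, not part of the commit; EmbeddingModel is assumed to be in scope from the example's main.rs.

use pgvector::Vector;
use postgres::{Client, NoTls};

// Hypothetical caller: embed one text and store it in a vector column.
// Assumes: CREATE TABLE items (content text, embedding vector(384));
fn store_embedding(
    model: &EmbeddingModel,
    content: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let mut client = Client::connect("postgres://localhost/pgvector_example", NoTls)?;
    // embed() tokenizes, runs BertModel::forward (now with attention mask None),
    // mean-pools over the token dimension, and L2-normalizes the result
    let embedding = Vector::from(model.embed(content)?);
    client.execute(
        "INSERT INTO items (content, embedding) VALUES ($1, $2)",
        &[&content, &embedding],
    )?;
    Ok(())
}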
