Skip to content

Commit

Permalink
feat: save rag in YAML instead of bin (#848)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden committed Sep 8, 2024
1 parent 89554e0 commit a56d5f2
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 15 deletions.
6 changes: 3 additions & 3 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ impl Config {
pub fn rag_file(&self, name: &str) -> Result<PathBuf> {
let path = match &self.agent {
Some(agent) => Self::agent_rag_file(agent.name(), name)?,
None => Self::rags_dir()?.join(format!("{name}.bin")),
None => Self::rags_dir()?.join(format!("{name}.yaml")),
};
Ok(path)
}
Expand All @@ -359,7 +359,7 @@ impl Config {
}

pub fn agent_rag_file(agent_name: &str, rag_name: &str) -> Result<PathBuf> {
Ok(Self::agent_config_dir(agent_name)?.join(format!("{rag_name}.bin")))
Ok(Self::agent_config_dir(agent_name)?.join(format!("{rag_name}.yaml")))
}

pub fn agent_variables_file(name: &str) -> Result<PathBuf> {
Expand Down Expand Up @@ -1186,7 +1186,7 @@ impl Config {
let mut names = vec![];
for entry in rd.flatten() {
let name = entry.file_name();
if let Some(name) = name.to_string_lossy().strip_suffix(".bin") {
if let Some(name) = name.to_string_lossy().strip_suffix(".yaml") {
names.push(name.to_string());
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/config/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,10 +326,10 @@ impl Session {
self.path = Some(session_path.display().to_string());

let content = serde_yaml::to_string(&self)
.with_context(|| format!("Failed to serde session {}", self.name))?;
.with_context(|| format!("Failed to serde session '{}'", self.name))?;
write(session_path, content).with_context(|| {
format!(
"Failed to write session {} to {}",
"Failed to write session '{}' to '{}'",
self.name,
session_path.display()
)
Expand Down
23 changes: 13 additions & 10 deletions src/rag/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,17 @@ use crate::utils::*;

mod bm25;
mod loader;
mod serde_vectors;
mod splitter;

use anyhow::bail;
use anyhow::{anyhow, Context, Result};
use anyhow::{anyhow, bail, Context, Result};
use hnsw_rs::prelude::*;
use indexmap::{IndexMap, IndexSet};
use inquire::{required, validator::Validation, Confirm, Select, Text};
use path_absolutize::Absolutize;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
use std::{fmt::Debug, io::BufReader, path::Path};
use std::{collections::HashMap, fmt::Debug, fs, path::Path};

pub struct Rag {
config: GlobalConfig,
Expand Down Expand Up @@ -103,9 +102,8 @@ impl Rag {

pub fn load(config: &GlobalConfig, name: &str, path: &Path) -> Result<Self> {
let err = || format!("Failed to load rag '{name}' at '{}'", path.display());
let file = std::fs::File::open(path).with_context(err)?;
let reader = BufReader::new(file);
let data: RagData = bincode::deserialize_from(reader).with_context(err)?;
let content = fs::read_to_string(path).with_context(err)?;
let data: RagData = serde_yaml::from_str(&content).with_context(err)?;
Self::create(config, name, path, data)
}

Expand Down Expand Up @@ -236,9 +234,13 @@ impl Rag {
}
let path = Path::new(&self.path);
ensure_parent_exists(path)?;
let mut file = std::fs::File::create(path)?;
bincode::serialize_into(&mut file, &self.data)
.with_context(|| format!("Failed to save rag '{}'", self.name))?;

let content = serde_yaml::to_string(&self.data)
.with_context(|| format!("Failed to serde rag '{}'", self.name))?;
fs::write(path, content).with_context(|| {
format!("Failed to save rag '{}' to '{}'", self.name, path.display())
})?;

Ok(true)
}

Expand Down Expand Up @@ -576,6 +578,7 @@ pub struct RagData {
pub next_file_id: FileId,
pub document_paths: Vec<String>,
pub files: IndexMap<FileId, RagFile>,
#[serde(with = "serde_vectors")]
pub vectors: IndexMap<DocumentId, Vec<f32>>,
}

Expand Down
69 changes: 69 additions & 0 deletions src/rag/serde_vectors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use super::*;

use base64::{engine::general_purpose::STANDARD, Engine};
use serde::{de, Deserializer, Serializer};

pub fn serialize<S>(
vectors: &IndexMap<DocumentId, Vec<f32>>,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let encoded_map: IndexMap<String, String> = vectors
.iter()
.map(|(key, vec)| {
let (h, l) = split_document_id(*key);
let byte_slice = unsafe {
std::slice::from_raw_parts(
vec.as_ptr() as *const u8,
vec.len() * std::mem::size_of::<f32>(),
)
};
(format!("{h}-{l}"), STANDARD.encode(byte_slice))
})
.collect();

encoded_map.serialize(serializer)
}

pub fn deserialize<'de, D>(deserializer: D) -> Result<IndexMap<DocumentId, Vec<f32>>, D::Error>
where
D: Deserializer<'de>,
{
let encoded_map: IndexMap<String, String> =
IndexMap::<String, String>::deserialize(deserializer)?;

let mut decoded_map = IndexMap::new();
for (key, base64_str) in encoded_map {
let decoded_key: DocumentId = key
.split_once('-')
.and_then(|(h, l)| {
let h = h.parse::<usize>().ok()?;
let l = l.parse::<usize>().ok()?;
Some(combine_document_id(h, l))
})
.ok_or_else(|| de::Error::custom(format!("Invalid key '{key}'")))?;

let decoded_data = STANDARD.decode(&base64_str).map_err(de::Error::custom)?;

if decoded_data.len() % std::mem::size_of::<f32>() != 0 {
return Err(de::Error::custom(format!("Invalid vector at '{key}'")));
}

let num_f32s = decoded_data.len() / std::mem::size_of::<f32>();

let mut vec_f32 = vec![0.0f32; num_f32s];
unsafe {
std::ptr::copy_nonoverlapping(
decoded_data.as_ptr(),
vec_f32.as_mut_ptr() as *mut u8,
decoded_data.len(),
);
}

decoded_map.insert(decoded_key, vec_f32);
}

Ok(decoded_map)
}

0 comments on commit a56d5f2

Please sign in to comment.