Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: when saving input to message.md, use file paths instead of file contents #905

Merged
merged 1 commit into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 45 additions & 16 deletions src/config/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::utils::{base64_encode, sha256, AbortSignal};

use anyhow::{bail, Context, Result};
use fancy_regex::Regex;
use path_absolutize::Absolutize;
use std::{collections::HashMap, fs::File, io::Read, path::Path};
use unicode_width::{UnicodeWidthChar, UnicodeWidthStr};

Expand All @@ -22,6 +23,7 @@ lazy_static::lazy_static! {
pub struct Input {
config: GlobalConfig,
text: String,
raw: (String, Vec<String>),
patched_text: Option<String>,
continue_output: Option<String>,
regenerate: bool,
Expand All @@ -40,6 +42,7 @@ impl Input {
Self {
config: config.clone(),
text: text.to_string(),
raw: (text.to_string(), vec![]),
patched_text: None,
continue_output: None,
regenerate: false,
Expand All @@ -59,16 +62,33 @@ impl Input {
paths: Vec<String>,
role: Option<Role>,
) -> Result<Self> {
let spinner = create_spinner("Loading files").await;
let raw_text = text.to_string();
let mut raw_paths = vec![];
let mut local_paths = vec![];
let mut remote_urls = vec![];
for path in paths {
match resolve_local_path(&path) {
Some(v) => {
if let Ok(path) = Path::new(&v).absolutize() {
raw_paths.push(path.display().to_string());
}
local_paths.push(v);
}
None => {
raw_paths.push(path.clone());
remote_urls.push(path);
}
}
}
let ret = load_documents(config, local_paths, remote_urls).await;
spinner.stop();
let (files, medias, data_urls) = ret?;
let mut texts = vec![];
if !text.is_empty() {
texts.push(text.to_string());
};
let spinner = create_spinner("Loading files").await;
let ret = load_paths(config, paths).await;
spinner.stop();
let (files, medias, data_urls) = ret?;
let files_len = files.len();
if files_len > 0 {
if !files.is_empty() {
texts.push(String::new());
}
for (path, contents) in files {
Expand All @@ -78,6 +98,7 @@ impl Input {
Ok(Self {
config: config.clone(),
text: texts.join("\n"),
raw: (raw_text, raw_paths),
patched_text: None,
continue_output: None,
regenerate: false,
Expand Down Expand Up @@ -266,6 +287,21 @@ impl Input {
}
}

pub fn raw(&self) -> String {
let (text, files) = &self.raw;
let mut segments = files.to_vec();
if !segments.is_empty() {
segments.insert(0, ".file".into());
}
if !text.is_empty() {
if !segments.is_empty() {
segments.push("--".into());
}
segments.push(text.clone());
}
segments.join(" ")
}

pub fn render(&self) -> String {
let text = self.text();
if self.medias.is_empty() {
Expand Down Expand Up @@ -316,22 +352,15 @@ fn resolve_role(config: &Config, role: Option<Role>) -> (Role, bool, bool) {
}
}

async fn load_paths(
async fn load_documents(
config: &GlobalConfig,
paths: Vec<String>,
local_paths: Vec<String>,
remote_urls: Vec<String>,
) -> Result<(Vec<(String, String)>, Vec<String>, HashMap<String, String>)> {
let mut files = vec![];
let mut medias = vec![];
let mut data_urls = HashMap::new();
let loaders = config.read().document_loaders.clone();
let mut local_paths = vec![];
let mut remote_urls = vec![];
for path in paths {
match resolve_local_path(&path) {
Some(v) => local_paths.push(v),
None => remote_urls.push(path),
}
}
let local_files = expand_glob_paths(&local_paths).await?;
for file_path in local_files {
if is_image(&file_path) {
Expand Down
6 changes: 4 additions & 2 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1759,7 +1759,7 @@ impl Config {
}
let timestamp = now();
let summary = input.summary();
let input_markdown = input.render();
let raw_input = input.raw();
let scope = if self.agent.is_none() {
let role_name = if input.role().is_derived() {
None
Expand All @@ -1775,7 +1775,9 @@ impl Config {
} else {
String::new()
};
let output = format!("# CHAT: {summary} [{timestamp}]{scope}\n{input_markdown}\n--------\n{output}\n--------\n\n",);
let output = format!(
"# CHAT: {summary} [{timestamp}]{scope}\n{raw_input}\n--------\n{output}\n--------\n\n",
);
file.write_all(output.as_bytes())
.with_context(|| "Failed to save message")
}
Expand Down
2 changes: 1 addition & 1 deletion src/config/role.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ impl Role {
} else if self.is_embedded_prompt() {
self.prompt.replace(INPUT_PLACEHOLDER, &input_markdown)
} else {
format!("{}\n\n{}", self.prompt, input.render())
format!("{}\n\n{}", self.prompt, input_markdown)
}
}

Expand Down