From a0e96f48ecf5cfbfa31dd162414171cf0ee860cf Mon Sep 17 00:00:00 2001 From: sigoden Date: Sat, 5 Oct 2024 10:05:39 +0800 Subject: [PATCH] feat: add `.compress session` REPL command --- src/config/mod.rs | 24 ++++++++++++++++++++---- src/config/session.rs | 9 +++++---- src/repl/mod.rs | 36 +++++++++++++++++++++++++----------- 3 files changed, 50 insertions(+), 19 deletions(-) diff --git a/src/config/mod.rs b/src/config/mod.rs index 3243c26c..97cd3021 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1133,11 +1133,28 @@ impl Config { false } - pub fn compress_session(&mut self, summary: &str) { - if let Some(session) = self.session.as_mut() { - let summary_prompt = self.summary_prompt.as_deref().unwrap_or(SUMMARY_PROMPT); + pub async fn compress_session(config: &GlobalConfig) -> Result<()> { + match config.read().session.as_ref() { + Some(session) => { + if !session.has_user_messages() { + bail!("No need to compress since there are no messages in the session") + } + } + None => bail!("No session"), + } + let input = Input::from_str(config, config.read().summarize_prompt(), None); + let client = input.create_client()?; + let summary = client.chat_completions(input).await?.text; + let summary_prompt = config + .read() + .summary_prompt + .clone() + .unwrap_or_else(|| SUMMARY_PROMPT.into()); + if let Some(session) = config.write().session.as_mut() { session.compress(format!("{}{}", summary_prompt, summary)); } + config.write().last_message = None; + Ok(()) } pub fn summarize_prompt(&self) -> &str { @@ -1155,7 +1172,6 @@ impl Config { if let Some(session) = self.session.as_mut() { session.set_compressing(false); } - self.last_message = None; } pub async fn use_rag( diff --git a/src/config/session.rs b/src/config/session.rs index 0bb55e8c..4bda3478 100644 --- a/src/config/session.rs +++ b/src/config/session.rs @@ -110,6 +110,10 @@ impl Session { self.model().total_tokens(&self.messages) } + pub fn has_user_messages(&self) -> bool { + self.messages.iter().any(|v| v.role.is_user()) + } + pub fn user_messages_len(&self) -> usize { self.messages.iter().filter(|v| v.role.is_user()).count() } @@ -372,12 +376,9 @@ impl Session { } } } else { - let mut need_add_msg = true; if self.messages.is_empty() { self.messages.extend(input.role().build_messages(input)); - need_add_msg = false; - } - if need_add_msg { + } else { self.messages .push(Message::new(MessageRole::User, input.message_content())); } diff --git a/src/repl/mod.rs b/src/repl/mod.rs index 2b103f48..5a99ae98 100644 --- a/src/repl/mod.rs +++ b/src/repl/mod.rs @@ -10,7 +10,7 @@ use crate::client::{call_chat_completions, call_chat_completions_streaming}; use crate::config::{AssertState, Config, GlobalConfig, Input, StateFlags}; use crate::function::need_send_tool_results; use crate::render::render_error; -use crate::utils::{create_abort_signal, set_text, temp_file, AbortSignal}; +use crate::utils::{create_abort_signal, create_spinner, set_text, temp_file, AbortSignal}; use anyhow::{bail, Context, Result}; use fancy_regex::Regex; @@ -31,7 +31,7 @@ lazy_static::lazy_static! { const MENU_NAME: &str = "completion_menu"; lazy_static::lazy_static! { - static ref REPL_COMMANDS: [ReplCommand; 33] = [ + static ref REPL_COMMANDS: [ReplCommand; 34] = [ ReplCommand::new(".help", "Show this help message", AssertState::pass()), ReplCommand::new(".info", "View system info", AssertState::pass()), ReplCommand::new(".model", "Change the current LLM", AssertState::pass()), @@ -75,6 +75,11 @@ lazy_static::lazy_static! { "Erase messages in the current session", AssertState::True(StateFlags::SESSION) ), + ReplCommand::new( + ".compress session", + "Compress messages in the current session", + AssertState::True(StateFlags::SESSION) + ), ReplCommand::new( ".info session", "View session info", @@ -360,6 +365,23 @@ impl Repl { } } } + ".compress" => { + match args.map(|v| match v.split_once(' ') { + Some((subcmd, args)) => (subcmd, Some(args.trim())), + None => (v, None), + }) { + Some(("session", _)) => { + let spinner = create_spinner("Compressing").await; + let ret = Config::compress_session(&self.config).await; + spinner.stop(); + ret?; + println!("✨ Successfully compressed the session"); + } + _ => { + println!(r#"Usage: .compress session"#) + } + } + } ".rebuild" => { match args.map(|v| match v.split_once(' ') { Some((subcmd, args)) => (subcmd, Some(args.trim())), @@ -657,7 +679,7 @@ async fn ask( color.italic().paint("Compressing the session."), ); tokio::spawn(async move { - let _ = compress_session(&config).await; + let _ = Config::compress_session(&config).await; config.write().end_compressing_session(); }); } @@ -696,14 +718,6 @@ fn parse_command(line: &str) -> Option<(&str, Option<&str>)> { } } -async fn compress_session(config: &GlobalConfig) -> Result<()> { - let input = Input::from_str(config, config.read().summarize_prompt(), None); - let client = input.create_client()?; - let summary = client.chat_completions(input).await?.text; - config.write().compress_session(&summary); - Ok(()) -} - fn split_files_text(args: &str) -> (&str, &str) { match SPLIT_FILES_TEXT_ARGS_RE.find(args).ok().flatten() { Some(mat) => {