Skip to content

Commit

Permalink
feat: add .compress session REPL command (#907)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored Oct 5, 2024
1 parent 11a706a commit ea4c213
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 19 deletions.
24 changes: 20 additions & 4 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1133,11 +1133,28 @@ impl Config {
false
}

pub fn compress_session(&mut self, summary: &str) {
if let Some(session) = self.session.as_mut() {
let summary_prompt = self.summary_prompt.as_deref().unwrap_or(SUMMARY_PROMPT);
pub async fn compress_session(config: &GlobalConfig) -> Result<()> {
match config.read().session.as_ref() {
Some(session) => {
if !session.has_user_messages() {
bail!("No need to compress since there are no messages in the session")
}
}
None => bail!("No session"),
}
let input = Input::from_str(config, config.read().summarize_prompt(), None);
let client = input.create_client()?;
let summary = client.chat_completions(input).await?.text;
let summary_prompt = config
.read()
.summary_prompt
.clone()
.unwrap_or_else(|| SUMMARY_PROMPT.into());
if let Some(session) = config.write().session.as_mut() {
session.compress(format!("{}{}", summary_prompt, summary));
}
config.write().last_message = None;
Ok(())
}

pub fn summarize_prompt(&self) -> &str {
Expand All @@ -1155,7 +1172,6 @@ impl Config {
if let Some(session) = self.session.as_mut() {
session.set_compressing(false);
}
self.last_message = None;
}

pub async fn use_rag(
Expand Down
9 changes: 5 additions & 4 deletions src/config/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ impl Session {
self.model().total_tokens(&self.messages)
}

pub fn has_user_messages(&self) -> bool {
self.messages.iter().any(|v| v.role.is_user())
}

pub fn user_messages_len(&self) -> usize {
self.messages.iter().filter(|v| v.role.is_user()).count()
}
Expand Down Expand Up @@ -372,12 +376,9 @@ impl Session {
}
}
} else {
let mut need_add_msg = true;
if self.messages.is_empty() {
self.messages.extend(input.role().build_messages(input));
need_add_msg = false;
}
if need_add_msg {
} else {
self.messages
.push(Message::new(MessageRole::User, input.message_content()));
}
Expand Down
36 changes: 25 additions & 11 deletions src/repl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::client::{call_chat_completions, call_chat_completions_streaming};
use crate::config::{AssertState, Config, GlobalConfig, Input, StateFlags};
use crate::function::need_send_tool_results;
use crate::render::render_error;
use crate::utils::{create_abort_signal, set_text, temp_file, AbortSignal};
use crate::utils::{create_abort_signal, create_spinner, set_text, temp_file, AbortSignal};

use anyhow::{bail, Context, Result};
use fancy_regex::Regex;
Expand All @@ -31,7 +31,7 @@ lazy_static::lazy_static! {
const MENU_NAME: &str = "completion_menu";

lazy_static::lazy_static! {
static ref REPL_COMMANDS: [ReplCommand; 33] = [
static ref REPL_COMMANDS: [ReplCommand; 34] = [
ReplCommand::new(".help", "Show this help message", AssertState::pass()),
ReplCommand::new(".info", "View system info", AssertState::pass()),
ReplCommand::new(".model", "Change the current LLM", AssertState::pass()),
Expand Down Expand Up @@ -75,6 +75,11 @@ lazy_static::lazy_static! {
"Erase messages in the current session",
AssertState::True(StateFlags::SESSION)
),
ReplCommand::new(
".compress session",
"Compress messages in the current session",
AssertState::True(StateFlags::SESSION)
),
ReplCommand::new(
".info session",
"View session info",
Expand Down Expand Up @@ -360,6 +365,23 @@ impl Repl {
}
}
}
".compress" => {
match args.map(|v| match v.split_once(' ') {
Some((subcmd, args)) => (subcmd, Some(args.trim())),
None => (v, None),
}) {
Some(("session", _)) => {
let spinner = create_spinner("Compressing").await;
let ret = Config::compress_session(&self.config).await;
spinner.stop();
ret?;
println!("✨ Successfully compressed the session");
}
_ => {
println!(r#"Usage: .compress session"#)
}
}
}
".rebuild" => {
match args.map(|v| match v.split_once(' ') {
Some((subcmd, args)) => (subcmd, Some(args.trim())),
Expand Down Expand Up @@ -657,7 +679,7 @@ async fn ask(
color.italic().paint("Compressing the session."),
);
tokio::spawn(async move {
let _ = compress_session(&config).await;
let _ = Config::compress_session(&config).await;
config.write().end_compressing_session();
});
}
Expand Down Expand Up @@ -696,14 +718,6 @@ fn parse_command(line: &str) -> Option<(&str, Option<&str>)> {
}
}

async fn compress_session(config: &GlobalConfig) -> Result<()> {
let input = Input::from_str(config, config.read().summarize_prompt(), None);
let client = input.create_client()?;
let summary = client.chat_completions(input).await?.text;
config.write().compress_session(&summary);
Ok(())
}

fn split_files_text(args: &str) -> (&str, &str) {
match SPLIT_FILES_TEXT_ARGS_RE.find(args).ok().flatten() {
Some(mat) => {
Expand Down

0 comments on commit ea4c213

Please sign in to comment.