diff --git a/src/config/input.rs b/src/config/input.rs index 350f74c1..4ddc06fa 100644 --- a/src/config/input.rs +++ b/src/config/input.rs @@ -31,6 +31,7 @@ pub struct Input { text: String, patched_text: Option, continue_output: Option, + regenerate: bool, medias: Vec, data_urls: HashMap, tool_call: Option, @@ -48,6 +49,7 @@ impl Input { text: text.to_string(), patched_text: None, continue_output: None, + regenerate: false, medias: Default::default(), data_urls: Default::default(), tool_call: None, @@ -106,6 +108,7 @@ impl Input { text: texts.join("\n"), patched_text: None, continue_output: None, + regenerate: false, medias, data_urls, tool_call: Default::default(), @@ -147,6 +150,18 @@ impl Input { self.continue_output = Some(output); } + pub fn regenerate(&self) -> bool { + self.regenerate + } + + pub fn set_regenerate(&mut self) { + let role = self.config.read().extract_role(); + if role.name() == self.role().name() { + self.role = role; + } + self.regenerate = true; + } + pub async fn use_embeddings(&mut self, abort_signal: AbortSignal) -> Result<()> { if self.text.is_empty() { return Ok(()); diff --git a/src/config/mod.rs b/src/config/mod.rs index a27245db..cb510eff 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -933,8 +933,11 @@ impl Config { } pub fn exit_bot(&mut self) -> Result<()> { - self.rag.take(); - self.bot.take(); + if self.bot.take().is_some() { + self.exit_session()?; + self.rag.take(); + self.last_message = None; + } Ok(()) } diff --git a/src/config/session.rs b/src/config/session.rs index 0b48cbda..2149ba5d 100644 --- a/src/config/session.rs +++ b/src/config/session.rs @@ -353,30 +353,33 @@ impl Session { } pub fn add_message(&mut self, input: &Input, output: &str) -> Result<()> { - match input.continue_output() { - Some(_) => { - if let Some(message) = self.messages.last_mut() { - if let MessageContent::Text(text) = &mut message.content { - *text = format!("{text}{output}"); - } + if input.continue_output().is_some() { + if let Some(message) = self.messages.last_mut() { + if let MessageContent::Text(text) = &mut message.content { + *text = format!("{text}{output}"); } } - None => { - let mut need_add_msg = true; - if self.messages.is_empty() { - self.messages.extend(input.role().build_messages(input)); - need_add_msg = false; + } else if input.regenerate() { + if let Some(message) = self.messages.last_mut() { + if let MessageContent::Text(text) = &mut message.content { + *text = output.to_string(); } - if need_add_msg { - self.messages - .push(Message::new(MessageRole::User, input.message_content())); - } - self.data_urls.extend(input.data_urls()); - self.messages.push(Message::new( - MessageRole::Assistant, - MessageContent::Text(output.to_string()), - )); } + } else { + let mut need_add_msg = true; + if self.messages.is_empty() { + self.messages.extend(input.role().build_messages(input)); + need_add_msg = false; + } + if need_add_msg { + self.messages + .push(Message::new(MessageRole::User, input.message_content())); + } + self.data_urls.extend(input.data_urls()); + self.messages.push(Message::new( + MessageRole::Assistant, + MessageContent::Text(output.to_string()), + )); } self.dirty = true; Ok(()) @@ -398,6 +401,9 @@ impl Session { let mut messages = self.messages.clone(); if input.continue_output().is_some() { return messages; + } else if input.regenerate() { + messages.pop(); + return messages; } let mut need_add_msg = true; let len = messages.len(); diff --git a/src/repl/mod.rs b/src/repl/mod.rs index cfeb3410..ca718b5f 100644 --- a/src/repl/mod.rs +++ b/src/repl/mod.rs @@ -33,7 +33,7 @@ lazy_static! { const MENU_NAME: &str = "completion_menu"; lazy_static! { - static ref REPL_COMMANDS: [ReplCommand; 25] = [ + static ref REPL_COMMANDS: [ReplCommand; 26] = [ ReplCommand::new(".help", "Show this help message", AssertState::pass()), ReplCommand::new(".info", "View system info", AssertState::pass()), ReplCommand::new(".model", "Change the current LLM", AssertState::pass()), @@ -124,6 +124,11 @@ lazy_static! { AssertState::pass() ), ReplCommand::new(".continue", "Continue response", AssertState::pass()), + ReplCommand::new( + ".regenerate", + "Regenerate the last response", + AssertState::pass() + ), ReplCommand::new(".set", "Adjust settings", AssertState::pass()), ReplCommand::new(".copy", "Copy the last response", AssertState::pass()), ReplCommand::new(".exit", "Exit the REPL", AssertState::pass()), @@ -321,11 +326,19 @@ Tips: use to autocomplete conversation starter text. ".continue" => { let (mut input, output) = match self.config.read().last_message.clone() { Some(v) => v, - None => bail!("No incomplete response."), + None => bail!("Unable to continue response"), }; input.set_continue_output(&output); ask(&self.config, self.abort_signal.clone(), input, true).await?; } + ".regenerate" => { + let (mut input, _) = match self.config.read().last_message.clone() { + Some(v) => v, + None => bail!("Unable to regenerate the last response"), + }; + input.set_regenerate(); + ask(&self.config, self.abort_signal.clone(), input, true).await?; + } ".set" => match args { Some(args) => { self.config.write().update(args)?;