Skip to content

Commit

Permalink
feat: add .regenerate repl command (#610)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored Jun 17, 2024
1 parent ff28477 commit 62b297e
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 24 deletions.
15 changes: 15 additions & 0 deletions src/config/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub struct Input {
text: String,
patched_text: Option<String>,
continue_output: Option<String>,
regenerate: bool,
medias: Vec<String>,
data_urls: HashMap<String, String>,
tool_call: Option<ToolResults>,
Expand All @@ -48,6 +49,7 @@ impl Input {
text: text.to_string(),
patched_text: None,
continue_output: None,
regenerate: false,
medias: Default::default(),
data_urls: Default::default(),
tool_call: None,
Expand Down Expand Up @@ -106,6 +108,7 @@ impl Input {
text: texts.join("\n"),
patched_text: None,
continue_output: None,
regenerate: false,
medias,
data_urls,
tool_call: Default::default(),
Expand Down Expand Up @@ -147,6 +150,18 @@ impl Input {
self.continue_output = Some(output);
}

pub fn regenerate(&self) -> bool {
self.regenerate
}

pub fn set_regenerate(&mut self) {
let role = self.config.read().extract_role();
if role.name() == self.role().name() {
self.role = role;
}
self.regenerate = true;
}

pub async fn use_embeddings(&mut self, abort_signal: AbortSignal) -> Result<()> {
if self.text.is_empty() {
return Ok(());
Expand Down
7 changes: 5 additions & 2 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -933,8 +933,11 @@ impl Config {
}

pub fn exit_bot(&mut self) -> Result<()> {
self.rag.take();
self.bot.take();
if self.bot.take().is_some() {
self.exit_session()?;
self.rag.take();
self.last_message = None;
}
Ok(())
}

Expand Down
46 changes: 26 additions & 20 deletions src/config/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,30 +353,33 @@ impl Session {
}

pub fn add_message(&mut self, input: &Input, output: &str) -> Result<()> {
match input.continue_output() {
Some(_) => {
if let Some(message) = self.messages.last_mut() {
if let MessageContent::Text(text) = &mut message.content {
*text = format!("{text}{output}");
}
if input.continue_output().is_some() {
if let Some(message) = self.messages.last_mut() {
if let MessageContent::Text(text) = &mut message.content {
*text = format!("{text}{output}");
}
}
None => {
let mut need_add_msg = true;
if self.messages.is_empty() {
self.messages.extend(input.role().build_messages(input));
need_add_msg = false;
} else if input.regenerate() {
if let Some(message) = self.messages.last_mut() {
if let MessageContent::Text(text) = &mut message.content {
*text = output.to_string();
}
if need_add_msg {
self.messages
.push(Message::new(MessageRole::User, input.message_content()));
}
self.data_urls.extend(input.data_urls());
self.messages.push(Message::new(
MessageRole::Assistant,
MessageContent::Text(output.to_string()),
));
}
} else {
let mut need_add_msg = true;
if self.messages.is_empty() {
self.messages.extend(input.role().build_messages(input));
need_add_msg = false;
}
if need_add_msg {
self.messages
.push(Message::new(MessageRole::User, input.message_content()));
}
self.data_urls.extend(input.data_urls());
self.messages.push(Message::new(
MessageRole::Assistant,
MessageContent::Text(output.to_string()),
));
}
self.dirty = true;
Ok(())
Expand All @@ -398,6 +401,9 @@ impl Session {
let mut messages = self.messages.clone();
if input.continue_output().is_some() {
return messages;
} else if input.regenerate() {
messages.pop();
return messages;
}
let mut need_add_msg = true;
let len = messages.len();
Expand Down
17 changes: 15 additions & 2 deletions src/repl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ lazy_static! {
const MENU_NAME: &str = "completion_menu";

lazy_static! {
static ref REPL_COMMANDS: [ReplCommand; 25] = [
static ref REPL_COMMANDS: [ReplCommand; 26] = [
ReplCommand::new(".help", "Show this help message", AssertState::pass()),
ReplCommand::new(".info", "View system info", AssertState::pass()),
ReplCommand::new(".model", "Change the current LLM", AssertState::pass()),
Expand Down Expand Up @@ -124,6 +124,11 @@ lazy_static! {
AssertState::pass()
),
ReplCommand::new(".continue", "Continue response", AssertState::pass()),
ReplCommand::new(
".regenerate",
"Regenerate the last response",
AssertState::pass()
),
ReplCommand::new(".set", "Adjust settings", AssertState::pass()),
ReplCommand::new(".copy", "Copy the last response", AssertState::pass()),
ReplCommand::new(".exit", "Exit the REPL", AssertState::pass()),
Expand Down Expand Up @@ -321,11 +326,19 @@ Tips: use <tab> to autocomplete conversation starter text.
".continue" => {
let (mut input, output) = match self.config.read().last_message.clone() {
Some(v) => v,
None => bail!("No incomplete response."),
None => bail!("Unable to continue response"),
};
input.set_continue_output(&output);
ask(&self.config, self.abort_signal.clone(), input, true).await?;
}
".regenerate" => {
let (mut input, _) = match self.config.read().last_message.clone() {
Some(v) => v,
None => bail!("Unable to regenerate the last response"),
};
input.set_regenerate();
ask(&self.config, self.abort_signal.clone(), input, true).await?;
}
".set" => match args {
Some(args) => {
self.config.write().update(args)?;
Expand Down

0 comments on commit 62b297e

Please sign in to comment.