Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add .continue repl command #608

Merged
merged 1 commit into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions src/config/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,31 @@ pub struct Input {
config: GlobalConfig,
text: String,
patched_text: Option<String>,
continue_output: Option<String>,
medias: Vec<String>,
data_urls: HashMap<String, String>,
tool_call: Option<ToolResults>,
rag_name: Option<String>,
role: Role,
with_session: bool,
with_bot: bool,
}

impl Input {
pub fn from_str(config: &GlobalConfig, text: &str, role: Option<Role>) -> Self {
let (role, with_session) = resolve_role(&config.read(), role);
let (role, with_session, with_bot) = resolve_role(&config.read(), role);
Self {
config: config.clone(),
text: text.to_string(),
patched_text: None,
continue_output: None,
medias: Default::default(),
data_urls: Default::default(),
tool_call: None,
rag_name: None,
role,
with_session,
with_bot,
}
}

Expand Down Expand Up @@ -96,17 +100,19 @@ impl Input {
}
}

let (role, session) = resolve_role(&config.read(), role);
let (role, with_session, with_bot) = resolve_role(&config.read(), role);
Ok(Self {
config: config.clone(),
text: texts.join("\n"),
patched_text: None,
continue_output: None,
medias,
data_urls,
tool_call: Default::default(),
rag_name: None,
role,
with_session: session,
with_session,
with_bot,
})
}

Expand All @@ -129,6 +135,18 @@ impl Input {
self.text = text;
}

pub fn continue_output(&self) -> Option<&str> {
self.continue_output.as_deref()
}

pub fn set_continue_output(&mut self, output: &str) {
let output = match &self.continue_output {
Some(v) => format!("{v}{output}"),
None => output.to_string(),
};
self.continue_output = Some(output);
}

pub async fn use_embeddings(&mut self, abort_signal: AbortSignal) -> Result<()> {
if self.text.is_empty() {
return Ok(());
Expand All @@ -150,10 +168,6 @@ impl Input {
self.rag_name.as_deref()
}

pub fn clear_patch_text(&mut self) {
self.patched_text.take();
}

pub fn merge_tool_call(mut self, output: String, tool_call_results: Vec<ToolResult>) -> Self {
match self.tool_call.as_mut() {
Some(exist_tool_call_results) => {
Expand Down Expand Up @@ -234,6 +248,10 @@ impl Input {
}
}

pub fn with_bot(&self) -> bool {
self.with_bot
}

pub fn summary(&self) -> String {
let text: String = self
.text
Expand Down Expand Up @@ -296,10 +314,14 @@ impl Input {
}
}

fn resolve_role(config: &Config, role: Option<Role>) -> (Role, bool) {
fn resolve_role(config: &Config, role: Option<Role>) -> (Role, bool, bool) {
match role {
Some(v) => (v, false),
None => (config.extract_role(), config.session.is_some()),
Some(v) => (v, false, false),
None => (
config.extract_role(),
config.session.is_some(),
config.bot.is_some(),
),
}
}

Expand Down
32 changes: 15 additions & 17 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ pub struct Config {
#[serde(skip)]
pub working_mode: WorkingMode,
#[serde(skip)]
pub last_message: Option<(Input, String, bool)>, // (input, output, is_bot)
pub last_message: Option<(Input, String)>,
}

impl Default for Config {
Expand Down Expand Up @@ -668,8 +668,8 @@ impl Config {
}
if let Some(session) = session.as_mut() {
if session.is_empty() {
if let Some((input, output, is_bot)) = &self.last_message {
if self.bot.is_some() == *is_bot {
if let Some((input, output)) = &self.last_message {
if self.bot.is_some() == input.with_bot() {
let ans = Confirm::new(
"Start a session that incorporates the last question and answer?",
)
Expand Down Expand Up @@ -739,13 +739,17 @@ impl Config {
)
})?;
self.session = Some(Session::load(self, &name, &session_path)?);
self.last_message = None;
Ok(())
}

pub fn clear_session_messages(&mut self) -> Result<()> {
if let Some(session) = self.session.as_mut() {
session.clear_messages();
} else {
bail!("No session")
}
self.last_message = None;
Ok(())
}

Expand Down Expand Up @@ -1081,7 +1085,7 @@ impl Config {
pub fn last_reply(&self) -> &str {
self.last_message
.as_ref()
.map(|(_, reply, _)| reply.as_str())
.map(|(_, reply)| reply.as_str())
.unwrap_or_default()
}

Expand Down Expand Up @@ -1207,7 +1211,7 @@ impl Config {
}

pub fn before_chat_completion(&mut self, input: &Input) -> Result<()> {
self.last_message = Some((input.clone(), String::new(), self.bot.is_some()));
self.last_message = Some((input.clone(), String::new()));
Ok(())
}

Expand All @@ -1216,23 +1220,17 @@ impl Config {
input: &mut Input,
output: &str,
tool_results: &[ToolResult],
) -> Result<()> {
input.clear_patch_text();
self.last_message = Some((input.clone(), output.to_string(), self.bot.is_some()));
self.save_message(input, output, tool_results)?;
Ok(())
}

fn save_message(
&mut self,
input: &mut Input,
output: &str,
tool_results: &[ToolResult],
) -> Result<()> {
if self.dry_run || output.is_empty() || !tool_results.is_empty() {
self.last_message = None;
return Ok(());
}
self.last_message = Some((input.clone(), output.to_string()));
self.save_message(input, output)?;
Ok(())
}

fn save_message(&mut self, input: &mut Input, output: &str) -> Result<()> {
if let Some(session) = input.session_mut(&mut self.session) {
session.add_message(input, output)?;
return Ok(());
Expand Down
9 changes: 8 additions & 1 deletion src/config/role.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ async function timeout(ms) {

pub fn build_messages(&self, input: &Input) -> Vec<Message> {
let mut content = input.message_content();
if self.is_empty_prompt() {
let mut messages = if self.is_empty_prompt() {
vec![Message::new(MessageRole::User, content)]
} else if self.is_embedded_prompt() {
content.merge_prompt(|v: &str| self.prompt.replace(INPUT_PLACEHOLDER, v));
Expand All @@ -209,7 +209,14 @@ async function timeout(ms) {
}
messages.push(Message::new(MessageRole::User, content));
messages
};
if let Some(text) = input.continue_output() {
messages.push(Message::new(
MessageRole::Assistant,
MessageContent::Text(text.into()),
));
}
messages
}
}

Expand Down
40 changes: 27 additions & 13 deletions src/config/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,20 +353,31 @@ impl Session {
}

pub fn add_message(&mut self, input: &Input, output: &str) -> Result<()> {
let mut need_add_msg = true;
if self.messages.is_empty() {
self.messages.extend(input.role().build_messages(input));
need_add_msg = false;
}
if need_add_msg {
self.messages
.push(Message::new(MessageRole::User, input.message_content()));
match input.continue_output() {
Some(_) => {
if let Some(message) = self.messages.last_mut() {
if let MessageContent::Text(text) = &mut message.content {
*text = format!("{text}{output}");
}
}
}
None => {
let mut need_add_msg = true;
if self.messages.is_empty() {
self.messages.extend(input.role().build_messages(input));
need_add_msg = false;
}
if need_add_msg {
self.messages
.push(Message::new(MessageRole::User, input.message_content()));
}
self.data_urls.extend(input.data_urls());
self.messages.push(Message::new(
MessageRole::Assistant,
MessageContent::Text(output.to_string()),
));
}
}
self.data_urls.extend(input.data_urls());
self.messages.push(Message::new(
MessageRole::Assistant,
MessageContent::Text(output.to_string()),
));
self.dirty = true;
Ok(())
}
Expand All @@ -385,6 +396,9 @@ impl Session {

pub fn build_messages(&self, input: &Input) -> Vec<Message> {
let mut messages = self.messages.clone();
if input.continue_output().is_some() {
return messages;
}
let mut need_add_msg = true;
let len = messages.len();
if len == 0 {
Expand Down
29 changes: 19 additions & 10 deletions src/repl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ lazy_static! {
const MENU_NAME: &str = "completion_menu";

lazy_static! {
static ref REPL_COMMANDS: [ReplCommand; 24] = [
static ref REPL_COMMANDS: [ReplCommand; 25] = [
ReplCommand::new(".help", "Show this help message", AssertState::pass()),
ReplCommand::new(".info", "View system info", AssertState::pass()),
ReplCommand::new(".model", "Change the current LLM", AssertState::pass()),
Expand Down Expand Up @@ -123,6 +123,7 @@ lazy_static! {
"Include files with the message",
AssertState::pass()
),
ReplCommand::new(".continue", "Continue response", AssertState::pass()),
ReplCommand::new(".set", "Adjust settings", AssertState::pass()),
ReplCommand::new(".copy", "Copy the last response", AssertState::pass()),
ReplCommand::new(".exit", "Exit the REPL", AssertState::pass()),
Expand Down Expand Up @@ -308,6 +309,23 @@ Tips: use <tab> to autocomplete conversation starter text.
}
}
}
".file" => match args {
Some(args) => {
let (files, text) = split_files_text(args);
let files = shell_words::split(files).with_context(|| "Invalid args")?;
let input = Input::new(&self.config, text, files, None)?;
ask(&self.config, self.abort_signal.clone(), input, true).await?;
}
None => println!("Usage: .file <files>... [-- <text>...]"),
},
".continue" => {
let (mut input, output) = match self.config.read().last_message.clone() {
Some(v) => v,
None => bail!("No incomplete response."),
};
input.set_continue_output(&output);
ask(&self.config, self.abort_signal.clone(), input, true).await?;
}
".set" => match args {
Some(args) => {
self.config.write().update(args)?;
Expand All @@ -321,15 +339,6 @@ Tips: use <tab> to autocomplete conversation starter text.
self.copy(config.last_reply())
.with_context(|| "Failed to copy the last output")?;
}
".file" => match args {
Some(args) => {
let (files, text) = split_files_text(args);
let files = shell_words::split(files).with_context(|| "Invalid args")?;
let input = Input::new(&self.config, text, files, None)?;
ask(&self.config, self.abort_signal.clone(), input, true).await?;
}
None => println!("Usage: .file <files>... [-- <text>...]"),
},
".exit" => match args {
Some("role") => {
self.config.write().exit_role()?;
Expand Down