Skip to content

Commit

Permalink
feat: custom REPL prompt (#283)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden committed Dec 24, 2023
1 parent 89fefb4 commit 1c9ca1b
Show file tree
Hide file tree
Showing 8 changed files with 253 additions and 51 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,9 @@ aichat has a powerful Chat REPL.
The Chat REPL supports:
- Emacs/Vi keybinding
- Command autocompletion
- Edit/paste multiline input
- [Custom REPL Prompt](https://github.com/sigoden/aichat/wiki/Custom-REPL-Prompt)
- Tab Completion
- Edit/paste multiline text
- Undo support
### `.help` - print help message
Expand Down
6 changes: 5 additions & 1 deletion config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ auto_copy: false # Automatically copy the last output to the cli
keybindings: emacs # REPL keybindings. (emacs, vi)
prelude: '' # Set a default role or session (role:<name>, session:<name>)

# Custom REPL prompt, see https://github.com/sigoden/aichat/wiki/Custom-REPL-Prompt
left_prompt: '{color.green}{?session {session}{?role /}}{role}{color.cyan}{?session )}{!session >}{color.reset} '
right_prompt: '{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}'

clients:
# All clients have the following configuration:
# - type: xxxx
Expand Down Expand Up @@ -38,7 +42,7 @@ clients:
# See https://github.com/jmorganca/ollama
- type: ollama
api_base: http://localhost:11434/api
api_key: Baisc xxx
api_key: Basic xxx # Set authorization header
chat_endpoint: /chat # Optional field
models:
- name: gpt4all-j
Expand Down
4 changes: 2 additions & 2 deletions src/client/qianwen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ fn build_body(data: SendData, model: String, is_vl: bool) -> Result<(Value, bool
Ok((body, has_upload))
}

/// Patch messsages, upload emebeded images to oss
/// Patch messages, upload embedded images to oss
async fn patch_messages(model: &str, api_key: &str, messages: &mut Vec<Message>) -> Result<()> {
for message in messages {
if let MessageContent::Array(list) = message.content.borrow_mut() {
Expand All @@ -258,7 +258,7 @@ async fn patch_messages(model: &str, api_key: &str, messages: &mut Vec<Message>)
if url.starts_with("data:") {
*url = upload(model, api_key, url)
.await
.with_context(|| "Failed to upload embeded image to oss")?;
.with_context(|| "Failed to upload embedded image to oss")?;
}
}
}
Expand Down
92 changes: 80 additions & 12 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ use crate::client::{
Model, OpenAIClient, SendData,
};
use crate::render::{MarkdownRender, RenderOptions};
use crate::utils::{get_env_name, light_theme_from_colorfgbg, now, prompt_op_err};
use crate::utils::{get_env_name, light_theme_from_colorfgbg, now, prompt_op_err, render_prompt};

use anyhow::{anyhow, bail, Context, Result};
use inquire::{Confirm, Select, Text};
use is_terminal::IsTerminal;
use parking_lot::RwLock;
use serde::Deserialize;
use std::collections::HashMap;
use std::{
env,
fs::{create_dir_all, read_dir, read_to_string, remove_file, File, OpenOptions},
Expand Down Expand Up @@ -66,6 +67,10 @@ pub struct Config {
pub keybindings: Keybindings,
/// Set a default role or session (role:<name>, session:<name>)
pub prelude: String,
/// REPL left prompt
pub left_prompt: String,
/// REPL right prompt
pub right_prompt: String,
/// Setup clients
pub clients: Vec<ClientConfig>,
/// Predefined roles
Expand Down Expand Up @@ -99,6 +104,9 @@ impl Default for Config {
auto_copy: false,
keybindings: Default::default(),
prelude: String::new(),
left_prompt: "{color.green}{?session {session}{?role /}}{role}{color.cyan}{?session )}{!session >}{color.reset} ".to_string(),
right_prompt: "{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}"
.to_string(),
clients: vec![ClientConfig::default()],
roles: vec![],
role: None,
Expand Down Expand Up @@ -648,18 +656,14 @@ impl Config {
Ok(RenderOptions::new(theme, wrap, self.wrap_code))
}

/// Render the left-hand REPL prompt by expanding the user-configurable
/// `left_prompt` template with the current prompt context
/// (model, role, session, color variables, ...).
pub fn render_prompt_left(&self) -> String {
    render_prompt(&self.left_prompt, &self.generate_prompt_context())
}

pub fn render_prompt_right(&self) -> String {
if let Some(session) = &self.session {
let (tokens, percent) = session.tokens_and_percent();
let percent = if percent == 0.0 {
String::new()
} else {
format!("({percent}%)")
};
format!("{tokens}{percent}")
} else {
String::new()
}
let variables = self.generate_prompt_context();
render_prompt(&self.right_prompt, &variables)
}

pub fn prepare_send_data(&self, input: &Input, stream: bool) -> Result<SendData> {
Expand All @@ -681,6 +685,70 @@ impl Config {
}
}

/// Build the variable map used to render the `left_prompt` / `right_prompt`
/// templates.
///
/// Variables are only inserted when meaningful in the current state, so
/// templates can use presence conditionals (`{?var ...}` / `{!var ...}`)
/// to branch on them.
fn generate_prompt_context(&self) -> HashMap<&str, String> {
    let mut output = HashMap::new();
    // Model identifiers are always available.
    output.insert("model", self.model.id());
    output.insert("client_name", self.model.client_name.clone());
    output.insert("model_name", self.model.name.clone());
    output.insert(
        "max_tokens",
        self.model.max_tokens.unwrap_or_default().to_string(),
    );
    // Optional settings: exposed only when set to a non-default value.
    if let Some(temperature) = self.temperature {
        if temperature != 0.0 {
            output.insert("temperature", temperature.to_string());
        }
    }
    if self.dry_run {
        output.insert("dry_run", "true".to_string());
    }
    if self.save {
        output.insert("save", "true".to_string());
    }
    if let Some(wrap) = &self.wrap {
        if wrap != "no" {
            output.insert("wrap", wrap.clone());
        }
    }
    if self.auto_copy {
        output.insert("auto_copy", "true".to_string());
    }
    if let Some(role) = &self.role {
        output.insert("role", role.name.clone());
    }
    if let Some(session) = &self.session {
        output.insert("session", session.name().to_string());
        let (tokens, percent) = session.tokens_and_percent();
        output.insert("consume_tokens", tokens.to_string());
        output.insert("consume_percent", percent.to_string());
        output.insert("user_messages_len", session.user_messages_len().to_string());
    }

    if self.highlight {
        // ANSI SGR color escapes, provided only when highlighting is on so
        // templates degrade to plain text on terminals without color.
        // NOTE(review): "white" maps to 37 (normal intensity) while
        // "light_gray" maps to 97 (bright white) — confirm this naming is
        // intentional.
        const COLORS: [(&str, &str); 19] = [
            ("color.reset", "\u{1b}[0m"),
            ("color.black", "\u{1b}[30m"),
            ("color.dark_gray", "\u{1b}[90m"),
            ("color.red", "\u{1b}[31m"),
            ("color.light_red", "\u{1b}[91m"),
            ("color.green", "\u{1b}[32m"),
            ("color.light_green", "\u{1b}[92m"),
            ("color.yellow", "\u{1b}[33m"),
            ("color.light_yellow", "\u{1b}[93m"),
            ("color.blue", "\u{1b}[34m"),
            ("color.light_blue", "\u{1b}[94m"),
            ("color.purple", "\u{1b}[35m"),
            ("color.light_purple", "\u{1b}[95m"),
            ("color.magenta", "\u{1b}[35m"), // alias of purple
            ("color.light_magenta", "\u{1b}[95m"), // alias of light_purple
            ("color.cyan", "\u{1b}[36m"),
            ("color.light_cyan", "\u{1b}[96m"),
            ("color.white", "\u{1b}[37m"),
            ("color.light_gray", "\u{1b}[97m"),
        ];
        for (name, code) in COLORS {
            output.insert(name, code.to_string());
        }
    }

    output
}

fn open_message_file(&self) -> Result<File> {
let path = Self::messages_file()?;
ensure_parent_exists(&path)?;
Expand Down
4 changes: 4 additions & 0 deletions src/config/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ impl Session {
self.model.total_tokens(&self.messages)
}

/// Number of messages in this session whose role is the user.
pub fn user_messages_len(&self) -> usize {
    let mut count = 0;
    for message in &self.messages {
        if message.role.is_user() {
            count += 1;
        }
    }
    count
}

pub fn export(&self) -> Result<String> {
self.guard_save()?;
let (tokens, percent) = self.tokens_and_percent();
Expand Down
36 changes: 2 additions & 34 deletions src/repl/prompt.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
use crate::config::GlobalConfig;

use crossterm::style::Color;
use reedline::{Prompt, PromptHistorySearch, PromptHistorySearchStatus};
use std::borrow::Cow;

const PROMPT_COLOR: Color = Color::Green;
const PROMPT_MULTILINE_COLOR: nu_ansi_term::Color = nu_ansi_term::Color::LightBlue;
const INDICATOR_COLOR: Color = Color::Cyan;
const PROMPT_RIGHT_COLOR: Color = Color::AnsiValue(5);

#[derive(Clone)]
pub struct ReplPrompt {
config: GlobalConfig,
Expand All @@ -24,25 +18,15 @@ impl ReplPrompt {

impl Prompt for ReplPrompt {
fn render_prompt_left(&self) -> Cow<str> {
if let Some(session) = &self.config.read().session {
Cow::Owned(session.name().to_string())
} else if let Some(role) = &self.config.read().role {
Cow::Owned(role.name.clone())
} else {
Cow::Borrowed("")
}
Cow::Owned(self.config.read().render_prompt_left())
}

fn render_prompt_right(&self) -> Cow<str> {
    // Rendered from the user's `right_prompt` template via Config.
    let rendered = self.config.read().render_prompt_right();
    Cow::Owned(rendered)
}

fn render_prompt_indicator(&self, _prompt_mode: reedline::PromptEditMode) -> Cow<str> {
if self.config.read().session.is_some() {
Cow::Borrowed(") ")
} else {
Cow::Borrowed("> ")
}
Cow::Borrowed("")
}

fn render_prompt_multiline_indicator(&self) -> Cow<str> {
Expand All @@ -64,20 +48,4 @@ impl Prompt for ReplPrompt {
prefix, history_search.term
))
}

fn get_prompt_color(&self) -> Color {
PROMPT_COLOR
}
/// Get the default multiline prompt color
fn get_prompt_multiline_color(&self) -> nu_ansi_term::Color {
PROMPT_MULTILINE_COLOR
}
/// Get the default indicator color
fn get_indicator_color(&self) -> Color {
INDICATOR_COLOR
}
/// Get the default right prompt color
fn get_prompt_right_color(&self) -> Color {
PROMPT_RIGHT_COLOR
}
}
2 changes: 2 additions & 0 deletions src/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
mod abort_signal;
mod clipboard;
mod prompt_input;
mod render_prompt;
mod tiktoken;

pub use self::abort_signal::{create_abort_signal, AbortSignal};
pub use self::clipboard::set_text;
pub use self::prompt_input::*;
pub use self::render_prompt::render_prompt;
pub use self::tiktoken::cl100k_base_singleton;

use sha2::{Digest, Sha256};
Expand Down
Loading

0 comments on commit 1c9ca1b

Please sign in to comment.