Skip to content

Commit

Permalink
feat: .model repl completions show max tokens and price
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden committed Apr 29, 2024
1 parent 3a00fb2 commit 5a6bb57
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 41 deletions.
25 changes: 15 additions & 10 deletions src/client/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,17 +147,22 @@ macro_rules! register_client {
anyhow::bail!("Unknown client '{}'", client)
}

pub fn list_models(config: &$crate::config::Config) -> Vec<$crate::client::Model> {
config
.clients
.iter()
.flat_map(|v| match v {
$(ClientConfig::$config(c) => $client::list_models(c),)+
ClientConfig::Unknown => vec![],
})
.collect()
static mut ALL_CLIENTS: Option<Vec<$crate::client::Model>> = None;

pub fn list_models(config: &$crate::config::Config) -> Vec<&$crate::client::Model> {
if unsafe { ALL_CLIENTS.is_none() } {
let models: Vec<_> = config
.clients
.iter()
.flat_map(|v| match v {
$(ClientConfig::$config(c) => $client::list_models(c),)+
ClientConfig::Unknown => vec![],
})
.collect();
unsafe { ALL_CLIENTS = Some(models) };
}
unsafe { ALL_CLIENTS.as_ref().unwrap().iter().collect() }
}

};
}

Expand Down
62 changes: 57 additions & 5 deletions src/client/model.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::message::{Message, MessageContent};

use crate::utils::count_tokens;
use crate::utils::{count_tokens, format_option_value};

use anyhow::{bail, Result};
use serde::Deserialize;
Expand All @@ -14,6 +14,9 @@ pub struct Model {
pub name: String,
pub max_input_tokens: Option<usize>,
pub max_output_tokens: Option<isize>,
pub ref_max_output_tokens: Option<isize>,
pub input_price: Option<f64>,
pub output_price: Option<f64>,
pub extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
pub capabilities: ModelCapabilities,
}
Expand All @@ -32,6 +35,9 @@ impl Model {
extra_fields: None,
max_input_tokens: None,
max_output_tokens: None,
ref_max_output_tokens: None,
input_price: None,
output_price: None,
capabilities: ModelCapabilities::Text,
}
}
Expand All @@ -43,13 +49,16 @@ impl Model {
Model::new(client_name, &v.name)
.set_max_input_tokens(v.max_input_tokens)
.set_max_output_tokens(v.max_output_tokens)
.set_ref_max_output_tokens(v.ref_max_output_tokens)
.set_input_price(v.input_price)
.set_output_price(v.output_price)
.set_supports_vision(v.supports_vision)
.set_extra_fields(&v.extra_fields)
})
.collect()
}

pub fn find(models: &[Self], value: &str) -> Option<Self> {
pub fn find(models: &[&Self], value: &str) -> Option<Self> {
let mut model = None;
let (client_name, model_name) = match value.split_once(':') {
Some((client_name, model_name)) => {
Expand All @@ -64,16 +73,16 @@ impl Model {
match model_name {
Some(model_name) => {
if let Some(found) = models.iter().find(|v| v.id() == value) {
model = Some(found.clone());
model = Some((*found).clone());
} else if let Some(found) = models.iter().find(|v| v.client_name == client_name) {
let mut found = found.clone();
let mut found = (*found).clone();
found.name = model_name.to_string();
model = Some(found)
}
}
None => {
if let Some(found) = models.iter().find(|v| v.client_name == client_name) {
model = Some(found.clone());
model = Some((*found).clone());
}
}
}
Expand All @@ -84,6 +93,23 @@ impl Model {
format!("{}:{}", self.client_name, self.name)
}

pub fn description(&self) -> String {
let max_input_tokens = format_option_value(&self.max_input_tokens);
let max_output_tokens =
format_option_value(&self.max_output_tokens.or(self.ref_max_output_tokens));
let input_price = format_option_value(&self.input_price);
let output_price = format_option_value(&self.output_price);
let vision = if self.capabilities.contains(ModelCapabilities::Vision) {
"👁"
} else {
""
};
format!(
"{:>8} / {:>8} | {:>6} / {:>6} {}",
max_input_tokens, max_output_tokens, input_price, output_price, vision
)
}

pub fn set_max_input_tokens(mut self, max_input_tokens: Option<usize>) -> Self {
match max_input_tokens {
None | Some(0) => self.max_input_tokens = None,
Expand All @@ -100,6 +126,30 @@ impl Model {
self
}

pub fn set_ref_max_output_tokens(mut self, ref_max_output_tokens: Option<isize>) -> Self {
match ref_max_output_tokens {
None | Some(0) => self.ref_max_output_tokens = None,
_ => self.ref_max_output_tokens = ref_max_output_tokens,
}
self
}

pub fn set_input_price(mut self, input_price: Option<f64>) -> Self {
match input_price {
None => self.input_price = None,
_ => self.input_price = input_price,
}
self
}

pub fn set_output_price(mut self, output_price: Option<f64>) -> Self {
match output_price {
None => self.output_price = None,
_ => self.output_price = output_price,
}
self
}

pub fn set_supports_vision(mut self, supports_vision: bool) -> Self {
if supports_vision {
self.capabilities |= ModelCapabilities::Vision;
Expand Down Expand Up @@ -178,6 +228,8 @@ pub struct ModelConfig {
pub name: String,
pub max_input_tokens: Option<usize>,
pub max_output_tokens: Option<isize>,
#[serde(rename = "max_output_tokens?")]
pub ref_max_output_tokens: Option<isize>,
pub input_price: Option<f64>,
pub output_price: Option<f64>,
#[serde(default)]
Expand Down
51 changes: 29 additions & 22 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ use crate::client::{
create_client_config, list_client_types, list_models, ClientConfig, Message, Model, SendData,
};
use crate::render::{MarkdownRender, RenderOptions};
use crate::utils::{get_env_name, light_theme_from_colorfgbg, now, render_prompt, set_text};
use crate::utils::{
format_option_value, fuzzy_match, get_env_name, light_theme_from_colorfgbg, now, render_prompt,
set_text,
};

use anyhow::{anyhow, bail, Context, Result};
use inquire::{Confirm, Select, Text};
Expand Down Expand Up @@ -415,18 +418,18 @@ impl Config {
.map_or_else(|| String::from("no"), |v| v.to_string());
let items = vec![
("model", self.model.id()),
("temperature", format_option(&self.temperature)),
("top_p", format_option(&self.top_p)),
("temperature", format_option_value(&self.temperature)),
("top_p", format_option_value(&self.top_p)),
("dry_run", self.dry_run.to_string()),
("save", self.save.to_string()),
("save_session", format_option(&self.save_session)),
("save_session", format_option_value(&self.save_session)),
("highlight", self.highlight.to_string()),
("light_theme", self.light_theme.to_string()),
("wrap", wrap),
("wrap_code", self.wrap_code.to_string()),
("auto_copy", self.auto_copy.to_string()),
("keybindings", self.keybindings.stringify().into()),
("prelude", format_option(&self.prelude)),
("prelude", format_option_value(&self.prelude)),
("compress_threshold", self.compress_threshold.to_string()),
("config_file", display_path(&Self::config_file()?)),
("roles_file", display_path(&Self::roles_file()?)),
Expand Down Expand Up @@ -476,12 +479,23 @@ impl Config {
.unwrap_or_default()
}

pub fn repl_complete(&self, cmd: &str, args: &[&str]) -> Vec<String> {
pub fn repl_complete(&self, cmd: &str, args: &[&str]) -> Vec<(String, String)> {
let (values, filter) = if args.len() == 1 {
let values = match cmd {
".role" => self.roles.iter().map(|v| v.name.clone()).collect(),
".model" => list_models(self).into_iter().map(|v| v.id()).collect(),
".session" => self.list_sessions(),
".role" => self
.roles
.iter()
.map(|v| (v.name.clone(), String::new()))
.collect(),
".model" => list_models(self)
.into_iter()
.map(|v| (v.id(), v.description()))
.collect(),
".session" => self
.list_sessions()
.into_iter()
.map(|v| (v.clone(), String::new()))
.collect(),
".set" => vec![
"temperature ",
"top_p ",
Expand All @@ -493,7 +507,7 @@ impl Config {
"auto_copy ",
]
.into_iter()
.map(|v| v.to_string())
.map(|v| (v.to_string(), String::new()))
.collect(),
_ => vec![],
};
Expand All @@ -514,13 +528,16 @@ impl Config {
"auto_copy" => complete_bool(self.auto_copy),
_ => vec![],
};
(values, args[1])
(
values.into_iter().map(|v| (v, String::new())).collect(),
args[1],
)
} else {
return vec![];
};
values
.into_iter()
.filter(|v| v.starts_with(filter))
.filter(|(value, _)| fuzzy_match(value, filter))
.collect()
}

Expand Down Expand Up @@ -1136,16 +1153,6 @@ where
Ok(value)
}

fn format_option<T>(value: &Option<T>) -> String
where
T: std::fmt::Display,
{
match value {
Some(value) => value.to_string(),
None => "-".to_string(),
}
}

fn complete_bool(value: bool) -> Vec<String> {
vec![(!value).to_string()]
}
Expand Down
13 changes: 9 additions & 4 deletions src/repl/completer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl Completer for ReplCompleter {
.read()
.repl_complete(cmd, &args)
.iter()
.map(|name| create_suggestion(name.clone(), None, span)),
.map(|(value, description)| create_suggestion(value, description, span)),
)
}

Expand All @@ -69,7 +69,7 @@ impl Completer for ReplCompleter {
} else {
format!("{name} ")
};
create_suggestion(name, Some(description.to_string()), span)
create_suggestion(&name, description, span)
}))
}
suggestions
Expand Down Expand Up @@ -105,9 +105,14 @@ impl ReplCompleter {
}
}

fn create_suggestion(value: String, description: Option<String>, span: Span) -> Suggestion {
fn create_suggestion(value: &str, description: &str, span: Span) -> Suggestion {
let description = if description.is_empty() {
None
} else {
Some(description.to_string())
};
Suggestion {
value,
value: value.to_string(),
description,
style: None,
extra: None,
Expand Down
34 changes: 34 additions & 0 deletions src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,33 @@ pub fn extract_block(input: &str) -> String {
}
}

pub fn format_option_value<T>(value: &Option<T>) -> String
where
T: std::fmt::Display,
{
match value {
Some(value) => value.to_string(),
None => "-".to_string(),
}
}

pub fn fuzzy_match(text: &str, pattern: &str) -> bool {
let text_chars: Vec<char> = text.chars().collect();
let pattern_chars: Vec<char> = pattern.chars().collect();

let mut pattern_index = 0;
let mut text_index = 0;

while pattern_index < pattern_chars.len() && text_index < text_chars.len() {
if pattern_chars[pattern_index] == text_chars[text_index] {
pattern_index += 1;
}
text_index += 1;
}

pattern_index == pattern_chars.len()
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -180,4 +207,11 @@ mod tests {
fn test_count_tokens() {
assert_eq!(count_tokens("😊 hello world"), 4);
}

#[test]
fn test_fuzzy_match() {
assert!(fuzzy_match("openai:gpt-4-turbo", "gpt4"));
assert!(fuzzy_match("openai:gpt-4-turbo", "oai4"));
assert!(!fuzzy_match("openai:gpt-4-turbo", "4gpt"));
}
}

0 comments on commit 5a6bb57

Please sign in to comment.