Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: .model repl completions show max tokens and price #462

Merged
merged 1 commit into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions src/client/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,17 +147,22 @@ macro_rules! register_client {
anyhow::bail!("Unknown client '{}'", client)
}

pub fn list_models(config: &$crate::config::Config) -> Vec<$crate::client::Model> {
config
.clients
.iter()
.flat_map(|v| match v {
$(ClientConfig::$config(c) => $client::list_models(c),)+
ClientConfig::Unknown => vec![],
})
.collect()
static mut ALL_CLIENTS: Option<Vec<$crate::client::Model>> = None;

pub fn list_models(config: &$crate::config::Config) -> Vec<&$crate::client::Model> {
if unsafe { ALL_CLIENTS.is_none() } {
let models: Vec<_> = config
.clients
.iter()
.flat_map(|v| match v {
$(ClientConfig::$config(c) => $client::list_models(c),)+
ClientConfig::Unknown => vec![],
})
.collect();
unsafe { ALL_CLIENTS = Some(models) };
}
unsafe { ALL_CLIENTS.as_ref().unwrap().iter().collect() }
}

};
}

Expand Down
62 changes: 57 additions & 5 deletions src/client/model.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::message::{Message, MessageContent};

use crate::utils::count_tokens;
use crate::utils::{count_tokens, format_option_value};

use anyhow::{bail, Result};
use serde::Deserialize;
Expand All @@ -14,6 +14,9 @@ pub struct Model {
pub name: String,
pub max_input_tokens: Option<usize>,
pub max_output_tokens: Option<isize>,
pub ref_max_output_tokens: Option<isize>,
pub input_price: Option<f64>,
pub output_price: Option<f64>,
pub extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
pub capabilities: ModelCapabilities,
}
Expand All @@ -32,6 +35,9 @@ impl Model {
extra_fields: None,
max_input_tokens: None,
max_output_tokens: None,
ref_max_output_tokens: None,
input_price: None,
output_price: None,
capabilities: ModelCapabilities::Text,
}
}
Expand All @@ -43,13 +49,16 @@ impl Model {
Model::new(client_name, &v.name)
.set_max_input_tokens(v.max_input_tokens)
.set_max_output_tokens(v.max_output_tokens)
.set_ref_max_output_tokens(v.ref_max_output_tokens)
.set_input_price(v.input_price)
.set_output_price(v.output_price)
.set_supports_vision(v.supports_vision)
.set_extra_fields(&v.extra_fields)
})
.collect()
}

pub fn find(models: &[Self], value: &str) -> Option<Self> {
pub fn find(models: &[&Self], value: &str) -> Option<Self> {
let mut model = None;
let (client_name, model_name) = match value.split_once(':') {
Some((client_name, model_name)) => {
Expand All @@ -64,16 +73,16 @@ impl Model {
match model_name {
Some(model_name) => {
if let Some(found) = models.iter().find(|v| v.id() == value) {
model = Some(found.clone());
model = Some((*found).clone());
} else if let Some(found) = models.iter().find(|v| v.client_name == client_name) {
let mut found = found.clone();
let mut found = (*found).clone();
found.name = model_name.to_string();
model = Some(found)
}
}
None => {
if let Some(found) = models.iter().find(|v| v.client_name == client_name) {
model = Some(found.clone());
model = Some((*found).clone());
}
}
}
Expand All @@ -84,6 +93,23 @@ impl Model {
format!("{}:{}", self.client_name, self.name)
}

pub fn description(&self) -> String {
let max_input_tokens = format_option_value(&self.max_input_tokens);
let max_output_tokens =
format_option_value(&self.max_output_tokens.or(self.ref_max_output_tokens));
let input_price = format_option_value(&self.input_price);
let output_price = format_option_value(&self.output_price);
let vision = if self.capabilities.contains(ModelCapabilities::Vision) {
"👁"
} else {
""
};
format!(
"{:>8} / {:>8} | {:>6} / {:>6} {}",
max_input_tokens, max_output_tokens, input_price, output_price, vision
)
}

pub fn set_max_input_tokens(mut self, max_input_tokens: Option<usize>) -> Self {
match max_input_tokens {
None | Some(0) => self.max_input_tokens = None,
Expand All @@ -100,6 +126,30 @@ impl Model {
self
}

pub fn set_ref_max_output_tokens(mut self, ref_max_output_tokens: Option<isize>) -> Self {
match ref_max_output_tokens {
None | Some(0) => self.ref_max_output_tokens = None,
_ => self.ref_max_output_tokens = ref_max_output_tokens,
}
self
}

pub fn set_input_price(mut self, input_price: Option<f64>) -> Self {
match input_price {
None => self.input_price = None,
_ => self.input_price = input_price,
}
self
}

pub fn set_output_price(mut self, output_price: Option<f64>) -> Self {
match output_price {
None => self.output_price = None,
_ => self.output_price = output_price,
}
self
}

pub fn set_supports_vision(mut self, supports_vision: bool) -> Self {
if supports_vision {
self.capabilities |= ModelCapabilities::Vision;
Expand Down Expand Up @@ -178,6 +228,8 @@ pub struct ModelConfig {
pub name: String,
pub max_input_tokens: Option<usize>,
pub max_output_tokens: Option<isize>,
#[serde(rename = "max_output_tokens?")]
pub ref_max_output_tokens: Option<isize>,
pub input_price: Option<f64>,
pub output_price: Option<f64>,
#[serde(default)]
Expand Down
51 changes: 29 additions & 22 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ use crate::client::{
create_client_config, list_client_types, list_models, ClientConfig, Message, Model, SendData,
};
use crate::render::{MarkdownRender, RenderOptions};
use crate::utils::{get_env_name, light_theme_from_colorfgbg, now, render_prompt, set_text};
use crate::utils::{
format_option_value, fuzzy_match, get_env_name, light_theme_from_colorfgbg, now, render_prompt,
set_text,
};

use anyhow::{anyhow, bail, Context, Result};
use inquire::{Confirm, Select, Text};
Expand Down Expand Up @@ -415,18 +418,18 @@ impl Config {
.map_or_else(|| String::from("no"), |v| v.to_string());
let items = vec![
("model", self.model.id()),
("temperature", format_option(&self.temperature)),
("top_p", format_option(&self.top_p)),
("temperature", format_option_value(&self.temperature)),
("top_p", format_option_value(&self.top_p)),
("dry_run", self.dry_run.to_string()),
("save", self.save.to_string()),
("save_session", format_option(&self.save_session)),
("save_session", format_option_value(&self.save_session)),
("highlight", self.highlight.to_string()),
("light_theme", self.light_theme.to_string()),
("wrap", wrap),
("wrap_code", self.wrap_code.to_string()),
("auto_copy", self.auto_copy.to_string()),
("keybindings", self.keybindings.stringify().into()),
("prelude", format_option(&self.prelude)),
("prelude", format_option_value(&self.prelude)),
("compress_threshold", self.compress_threshold.to_string()),
("config_file", display_path(&Self::config_file()?)),
("roles_file", display_path(&Self::roles_file()?)),
Expand Down Expand Up @@ -476,12 +479,23 @@ impl Config {
.unwrap_or_default()
}

pub fn repl_complete(&self, cmd: &str, args: &[&str]) -> Vec<String> {
pub fn repl_complete(&self, cmd: &str, args: &[&str]) -> Vec<(String, String)> {
let (values, filter) = if args.len() == 1 {
let values = match cmd {
".role" => self.roles.iter().map(|v| v.name.clone()).collect(),
".model" => list_models(self).into_iter().map(|v| v.id()).collect(),
".session" => self.list_sessions(),
".role" => self
.roles
.iter()
.map(|v| (v.name.clone(), String::new()))
.collect(),
".model" => list_models(self)
.into_iter()
.map(|v| (v.id(), v.description()))
.collect(),
".session" => self
.list_sessions()
.into_iter()
.map(|v| (v.clone(), String::new()))
.collect(),
".set" => vec![
"temperature ",
"top_p ",
Expand All @@ -493,7 +507,7 @@ impl Config {
"auto_copy ",
]
.into_iter()
.map(|v| v.to_string())
.map(|v| (v.to_string(), String::new()))
.collect(),
_ => vec![],
};
Expand All @@ -514,13 +528,16 @@ impl Config {
"auto_copy" => complete_bool(self.auto_copy),
_ => vec![],
};
(values, args[1])
(
values.into_iter().map(|v| (v, String::new())).collect(),
args[1],
)
} else {
return vec![];
};
values
.into_iter()
.filter(|v| v.starts_with(filter))
.filter(|(value, _)| fuzzy_match(value, filter))
.collect()
}

Expand Down Expand Up @@ -1136,16 +1153,6 @@ where
Ok(value)
}

fn format_option<T>(value: &Option<T>) -> String
where
T: std::fmt::Display,
{
match value {
Some(value) => value.to_string(),
None => "-".to_string(),
}
}

fn complete_bool(value: bool) -> Vec<String> {
vec![(!value).to_string()]
}
Expand Down
13 changes: 9 additions & 4 deletions src/repl/completer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl Completer for ReplCompleter {
.read()
.repl_complete(cmd, &args)
.iter()
.map(|name| create_suggestion(name.clone(), None, span)),
.map(|(value, description)| create_suggestion(value, description, span)),
)
}

Expand All @@ -69,7 +69,7 @@ impl Completer for ReplCompleter {
} else {
format!("{name} ")
};
create_suggestion(name, Some(description.to_string()), span)
create_suggestion(&name, description, span)
}))
}
suggestions
Expand Down Expand Up @@ -105,9 +105,14 @@ impl ReplCompleter {
}
}

fn create_suggestion(value: String, description: Option<String>, span: Span) -> Suggestion {
fn create_suggestion(value: &str, description: &str, span: Span) -> Suggestion {
let description = if description.is_empty() {
None
} else {
Some(description.to_string())
};
Suggestion {
value,
value: value.to_string(),
description,
style: None,
extra: None,
Expand Down
34 changes: 34 additions & 0 deletions src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,33 @@ pub fn extract_block(input: &str) -> String {
}
}

pub fn format_option_value<T>(value: &Option<T>) -> String
where
T: std::fmt::Display,
{
match value {
Some(value) => value.to_string(),
None => "-".to_string(),
}
}

pub fn fuzzy_match(text: &str, pattern: &str) -> bool {
let text_chars: Vec<char> = text.chars().collect();
let pattern_chars: Vec<char> = pattern.chars().collect();

let mut pattern_index = 0;
let mut text_index = 0;

while pattern_index < pattern_chars.len() && text_index < text_chars.len() {
if pattern_chars[pattern_index] == text_chars[text_index] {
pattern_index += 1;
}
text_index += 1;
}

pattern_index == pattern_chars.len()
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -180,4 +207,11 @@ mod tests {
fn test_count_tokens() {
assert_eq!(count_tokens("😊 hello world"), 4);
}

#[test]
fn test_fuzzy_match() {
assert!(fuzzy_match("openai:gpt-4-turbo", "gpt4"));
assert!(fuzzy_match("openai:gpt-4-turbo", "oai4"));
assert!(!fuzzy_match("openai:gpt-4-turbo", "4gpt"));
}
}