Skip to content

Commit

Permalink
feat: all config fields have related environment variables (#751)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored Jul 27, 2024
1 parent d020893 commit 9030f3b
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 81 deletions.
243 changes: 166 additions & 77 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ pub struct Config {

pub dry_run: bool,
pub save: bool,
pub keybindings: Keybindings,
pub keybindings: String,
pub buffer_editor: Option<String>,
pub wrap: Option<String>,
pub wrap_code: bool,
Expand All @@ -116,9 +116,10 @@ pub struct Config {
pub rag_min_score_vector_search: f32,
pub rag_min_score_keyword_search: f32,
pub rag_min_score_rerank: f32,
pub rag_template: Option<String>,

#[serde(default)]
pub document_loaders: HashMap<String, String>,
pub rag_template: Option<String>,

pub highlight: bool,
pub light_theme: bool,
Expand Down Expand Up @@ -156,7 +157,7 @@ impl Default for Config {

dry_run: false,
save: false,
keybindings: Default::default(),
keybindings: "emacs".into(),
buffer_editor: None,
wrap: None,
wrap_code: false,
Expand All @@ -165,6 +166,11 @@ impl Default for Config {
repl_prelude: None,
agent_prelude: None,

save_session: None,
compress_threshold: 4000,
summarize_prompt: None,
summary_prompt: None,

function_calling: true,
mapping_tools: Default::default(),
use_tools: None,
Expand All @@ -177,13 +183,9 @@ impl Default for Config {
rag_min_score_vector_search: 0.0,
rag_min_score_keyword_search: 0.0,
rag_min_score_rerank: 0.0,
document_loaders: Default::default(),
rag_template: None,

save_session: None,
compress_threshold: 4000,
summarize_prompt: None,
summary_prompt: None,
document_loaders: Default::default(),

highlight: true,
light_theme: false,
Expand Down Expand Up @@ -216,23 +218,23 @@ impl Config {
create_config_file(&config_path)?;
}
let mut config = if platform.is_some() {
Self::load_config_env(&platform.unwrap())?
Self::load_dynamic_config(&platform.unwrap())?
} else {
Self::load_config_file(&config_path)?
};

config.working_mode = working_mode;

config.load_envs();

if let Some(wrap) = config.wrap.clone() {
config.set_wrap(&wrap)?;
}

config.working_mode = working_mode;

config.load_functions()?;
config.load_roles()?;

config.setup_model()?;
config.setup_highlight();
config.setup_light_theme()?;
config.setup_document_loaders();

Ok(config)
Expand Down Expand Up @@ -512,7 +514,7 @@ impl Config {
("top_p", format_option_value(&role.top_p())),
("dry_run", self.dry_run.to_string()),
("save", self.save.to_string()),
("keybindings", self.keybindings.stringify().into()),
("keybindings", self.keybindings.clone()),
("wrap", wrap),
("wrap_code", self.wrap_code.to_string()),
("save_session", format_option_value(&self.save_session)),
Expand Down Expand Up @@ -1518,11 +1520,7 @@ impl Config {
Ok(config)
}

fn load_config_env(platform: &str) -> Result<Self> {
let model_id = match env::var(get_env_name("model_name")) {
Ok(model_name) => format!("{platform}:{model_name}"),
Err(_) => platform.to_string(),
};
fn load_dynamic_config(platform: &str) -> Result<Self> {
let is_openai_compatible = OPENAI_COMPATIBLE_PLATFORMS
.into_iter()
.any(|(name, _)| platform == name);
Expand All @@ -1532,7 +1530,7 @@ impl Config {
json!({ "type": platform })
};
let config = json!({
"model": model_id,
"model": platform.to_string(),
"save": false,
"clients": vec![client],
});
Expand All @@ -1541,6 +1539,132 @@ impl Config {
Ok(config)
}

fn load_envs(&mut self) {
if let Ok(v) = env::var(get_env_name("model")) {
self.model_id = v;
}
if let Some(v) = read_env_value::<f64>("temperature") {
self.temperature = v;
}
if let Some(v) = read_env_value::<f64>("top_p") {
self.top_p = v;
}

if let Some(Some(v)) = read_env_bool("dry_run") {
self.dry_run = v;
}
if let Some(Some(v)) = read_env_bool("save") {
self.save = v;
}
if let Ok(v) = env::var(get_env_name("keybindings")) {
if v == "vi" {
self.keybindings = v;
}
}
if let Some(v) = read_env_value::<String>("buffer_editor") {
self.buffer_editor = v;
}
if let Some(v) = read_env_value::<String>("wrap") {
self.wrap = v;
}
if let Some(Some(v)) = read_env_bool("wrap_code") {
self.wrap_code = v;
}

if let Some(v) = read_env_value::<String>("prelude") {
self.prelude = v;
}
if let Some(v) = read_env_value::<String>("repl_prelude") {
self.repl_prelude = v;
}
if let Some(v) = read_env_value::<String>("agent_prelude") {
self.agent_prelude = v;
}

if let Some(v) = read_env_bool("save_session") {
self.save_session = v;
}
if let Some(Some(v)) = read_env_value::<usize>("compress_threshold") {
self.compress_threshold = v;
}
if let Some(v) = read_env_value::<String>("summarize_prompt") {
self.summarize_prompt = v;
}
if let Some(v) = read_env_value::<String>("summary_prompt") {
self.summary_prompt = v;
}

if let Some(Some(v)) = read_env_bool("function_calling") {
self.function_calling = v;
}
if let Ok(v) = env::var(get_env_name("mapping_tools")) {
if let Ok(v) = serde_json::from_str(&v) {
self.mapping_tools = v;
}
}
if let Some(v) = read_env_value::<String>("use_tools") {
self.use_tools = v;
}

if let Some(v) = read_env_value::<String>("rag_embedding_model") {
self.rag_embedding_model = v;
}
if let Some(v) = read_env_value::<String>("rag_reranker_model") {
self.rag_reranker_model = v;
}
if let Some(Some(v)) = read_env_value::<usize>("rag_top_k") {
self.rag_top_k = v;
}
if let Some(v) = read_env_value::<usize>("rag_chunk_size") {
self.rag_chunk_size = v;
}
if let Some(v) = read_env_value::<usize>("rag_chunk_overlap") {
self.rag_chunk_overlap = v;
}
if let Some(Some(v)) = read_env_value::<f32>("rag_min_score_vector_search") {
self.rag_min_score_vector_search = v;
}
if let Some(Some(v)) = read_env_value::<f32>("rag_min_score_keyword_search") {
self.rag_min_score_keyword_search = v;
}
if let Some(Some(v)) = read_env_value::<f32>("rag_min_score_rerank") {
self.rag_min_score_rerank = v;
}
if let Some(v) = read_env_value::<String>("rag_template") {
self.rag_template = v;
}

if let Ok(v) = env::var(get_env_name("document_loaders")) {
if let Ok(v) = serde_json::from_str(&v) {
self.document_loaders = v;
}
}

if let Some(Some(v)) = read_env_bool("highlight") {
self.highlight = v;
}
if let Ok(value) = env::var("NO_COLOR") {
if let Some(false) = parse_bool(&value) {
self.highlight = false;
}
}
if let Some(Some(v)) = read_env_bool("light_theme") {
self.light_theme = v;
} else if !self.light_theme {
if let Ok(v) = env::var("COLORFGBG") {
if let Some(v) = light_theme_from_colorfgbg(&v) {
self.light_theme = v
}
}
}
if let Some(v) = read_env_value::<String>("left_prompt") {
self.left_prompt = v;
}
if let Some(v) = read_env_value::<String>("right_prompt") {
self.right_prompt = v;
}
}

fn load_functions(&mut self) -> Result<()> {
self.functions = Functions::init(&Self::functions_file()?)?;
if self.functions.is_empty() {
Expand Down Expand Up @@ -1569,10 +1693,7 @@ impl Config {
}

fn setup_model(&mut self) -> Result<()> {
let mut model_id = match env::var(get_env_name("model")) {
Ok(v) => v,
Err(_) => self.model_id.clone(),
};
let mut model_id = self.model_id.clone();
if model_id.is_empty() {
let models = list_chat_models(self);
if models.is_empty() {
Expand All @@ -1585,31 +1706,6 @@ impl Config {
Ok(())
}

fn setup_highlight(&mut self) {
if let Ok(value) = env::var("NO_COLOR") {
let mut no_color = false;
set_bool(&mut no_color, &value);
if no_color {
self.highlight = false;
}
}
}

fn setup_light_theme(&mut self) -> Result<()> {
if self.light_theme {
return Ok(());
}
if let Ok(value) = env::var(get_env_name("light_theme")) {
set_bool(&mut self.light_theme, &value);
return Ok(());
} else if let Ok(value) = env::var("COLORFGBG") {
if let Some(light) = light_theme_from_colorfgbg(&value) {
self.light_theme = light
}
};
Ok(())
}

fn setup_document_loaders(&mut self) {
[
("pdf", "pdftotext $1 -"),
Expand Down Expand Up @@ -1637,33 +1733,12 @@ pub fn load_env_file() -> Result<()> {
continue;
}
if let Some((key, value)) = line.split_once('=') {
std::env::set_var(key.trim(), value.trim());
env::set_var(key.trim(), value.trim());
}
}
Ok(())
}

#[derive(Debug, Clone, Deserialize, Default)]
pub enum Keybindings {
#[serde(rename = "emacs")]
#[default]
Emacs,
#[serde(rename = "vi")]
Vi,
}

impl Keybindings {
pub fn is_vi(&self) -> bool {
matches!(self, Keybindings::Vi)
}
pub fn stringify(&self) -> &str {
match self {
Keybindings::Emacs => "emacs",
Keybindings::Vi => "vi",
}
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WorkingMode {
Command,
Expand Down Expand Up @@ -1757,12 +1832,13 @@ pub(crate) fn ensure_parent_exists(path: &Path) -> Result<()> {
Ok(())
}

fn set_bool(target: &mut bool, value: &str) {
match value {
"1" | "true" => *target = true,
"0" | "false" => *target = false,
_ => {}
}
fn read_env_value<T>(key: &str) -> Option<Option<T>>
where
T: std::str::FromStr,
{
let value = env::var(get_env_name(key)).ok()?;
let value = parse_value(&value).ok()?;
Some(value)
}

fn parse_value<T>(value: &str) -> Result<Option<T>>
Expand All @@ -1781,6 +1857,19 @@ where
Ok(value)
}

fn read_env_bool(key: &str) -> Option<Option<bool>> {
let value = env::var(get_env_name(key)).ok()?;
Some(parse_bool(&value))
}

fn parse_bool(value: &str) -> Option<bool> {
match value {
"1" | "true" => Some(true),
"0" | "false" => Some(false),
_ => None,
}
}

fn complete_bool(value: bool) -> Vec<String> {
vec![(!value).to_string()]
}
Expand Down
9 changes: 6 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,12 @@ use inquire::{Select, Text};
use is_terminal::IsTerminal;
use parking_lot::RwLock;
use simplelog::{format_description, ConfigBuilder, LevelFilter, SimpleLogger, WriteLogger};
use std::io::{stderr, stdin, Read};
use std::process;
use std::sync::Arc;
use std::{
env,
io::{stderr, stdin, Read},
process,
sync::Arc,
};

#[tokio::main]
async fn main() -> Result<()> {
Expand Down
2 changes: 1 addition & 1 deletion src/repl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ Type ".help" for additional help.
}

fn create_edit_mode(config: &GlobalConfig) -> Box<dyn EditMode> {
let edit_mode: Box<dyn EditMode> = if config.read().keybindings.is_vi() {
let edit_mode: Box<dyn EditMode> = if config.read().keybindings == "vi" {
let mut normal_keybindings = default_vi_normal_keybindings();
let mut insert_keybindings = default_vi_insert_keybindings();
Self::extra_keybindings(&mut normal_keybindings);
Expand Down

0 comments on commit 9030f3b

Please sign in to comment.