Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: all config fields have related environment variables #751

Merged
merged 1 commit into from
Jul 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 166 additions & 77 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ pub struct Config {

pub dry_run: bool,
pub save: bool,
pub keybindings: Keybindings,
pub keybindings: String,
pub buffer_editor: Option<String>,
pub wrap: Option<String>,
pub wrap_code: bool,
Expand All @@ -116,9 +116,10 @@ pub struct Config {
pub rag_min_score_vector_search: f32,
pub rag_min_score_keyword_search: f32,
pub rag_min_score_rerank: f32,
pub rag_template: Option<String>,

#[serde(default)]
pub document_loaders: HashMap<String, String>,
pub rag_template: Option<String>,

pub highlight: bool,
pub light_theme: bool,
Expand Down Expand Up @@ -156,7 +157,7 @@ impl Default for Config {

dry_run: false,
save: false,
keybindings: Default::default(),
keybindings: "emacs".into(),
buffer_editor: None,
wrap: None,
wrap_code: false,
Expand All @@ -165,6 +166,11 @@ impl Default for Config {
repl_prelude: None,
agent_prelude: None,

save_session: None,
compress_threshold: 4000,
summarize_prompt: None,
summary_prompt: None,

function_calling: true,
mapping_tools: Default::default(),
use_tools: None,
Expand All @@ -177,13 +183,9 @@ impl Default for Config {
rag_min_score_vector_search: 0.0,
rag_min_score_keyword_search: 0.0,
rag_min_score_rerank: 0.0,
document_loaders: Default::default(),
rag_template: None,

save_session: None,
compress_threshold: 4000,
summarize_prompt: None,
summary_prompt: None,
document_loaders: Default::default(),

highlight: true,
light_theme: false,
Expand Down Expand Up @@ -216,23 +218,23 @@ impl Config {
create_config_file(&config_path)?;
}
let mut config = if platform.is_some() {
Self::load_config_env(&platform.unwrap())?
Self::load_dynamic_config(&platform.unwrap())?
} else {
Self::load_config_file(&config_path)?
};

config.working_mode = working_mode;

config.load_envs();

if let Some(wrap) = config.wrap.clone() {
config.set_wrap(&wrap)?;
}

config.working_mode = working_mode;

config.load_functions()?;
config.load_roles()?;

config.setup_model()?;
config.setup_highlight();
config.setup_light_theme()?;
config.setup_document_loaders();

Ok(config)
Expand Down Expand Up @@ -512,7 +514,7 @@ impl Config {
("top_p", format_option_value(&role.top_p())),
("dry_run", self.dry_run.to_string()),
("save", self.save.to_string()),
("keybindings", self.keybindings.stringify().into()),
("keybindings", self.keybindings.clone()),
("wrap", wrap),
("wrap_code", self.wrap_code.to_string()),
("save_session", format_option_value(&self.save_session)),
Expand Down Expand Up @@ -1518,11 +1520,7 @@ impl Config {
Ok(config)
}

fn load_config_env(platform: &str) -> Result<Self> {
let model_id = match env::var(get_env_name("model_name")) {
Ok(model_name) => format!("{platform}:{model_name}"),
Err(_) => platform.to_string(),
};
fn load_dynamic_config(platform: &str) -> Result<Self> {
let is_openai_compatible = OPENAI_COMPATIBLE_PLATFORMS
.into_iter()
.any(|(name, _)| platform == name);
Expand All @@ -1532,7 +1530,7 @@ impl Config {
json!({ "type": platform })
};
let config = json!({
"model": model_id,
"model": platform.to_string(),
"save": false,
"clients": vec![client],
});
Expand All @@ -1541,6 +1539,132 @@ impl Config {
Ok(config)
}

fn load_envs(&mut self) {
if let Ok(v) = env::var(get_env_name("model")) {
self.model_id = v;
}
if let Some(v) = read_env_value::<f64>("temperature") {
self.temperature = v;
}
if let Some(v) = read_env_value::<f64>("top_p") {
self.top_p = v;
}

if let Some(Some(v)) = read_env_bool("dry_run") {
self.dry_run = v;
}
if let Some(Some(v)) = read_env_bool("save") {
self.save = v;
}
if let Ok(v) = env::var(get_env_name("keybindings")) {
if v == "vi" {
self.keybindings = v;
}
}
if let Some(v) = read_env_value::<String>("buffer_editor") {
self.buffer_editor = v;
}
if let Some(v) = read_env_value::<String>("wrap") {
self.wrap = v;
}
if let Some(Some(v)) = read_env_bool("wrap_code") {
self.wrap_code = v;
}

if let Some(v) = read_env_value::<String>("prelude") {
self.prelude = v;
}
if let Some(v) = read_env_value::<String>("repl_prelude") {
self.repl_prelude = v;
}
if let Some(v) = read_env_value::<String>("agent_prelude") {
self.agent_prelude = v;
}

if let Some(v) = read_env_bool("save_session") {
self.save_session = v;
}
if let Some(Some(v)) = read_env_value::<usize>("compress_threshold") {
self.compress_threshold = v;
}
if let Some(v) = read_env_value::<String>("summarize_prompt") {
self.summarize_prompt = v;
}
if let Some(v) = read_env_value::<String>("summary_prompt") {
self.summary_prompt = v;
}

if let Some(Some(v)) = read_env_bool("function_calling") {
self.function_calling = v;
}
if let Ok(v) = env::var(get_env_name("mapping_tools")) {
if let Ok(v) = serde_json::from_str(&v) {
self.mapping_tools = v;
}
}
if let Some(v) = read_env_value::<String>("use_tools") {
self.use_tools = v;
}

if let Some(v) = read_env_value::<String>("rag_embedding_model") {
self.rag_embedding_model = v;
}
if let Some(v) = read_env_value::<String>("rag_reranker_model") {
self.rag_reranker_model = v;
}
if let Some(Some(v)) = read_env_value::<usize>("rag_top_k") {
self.rag_top_k = v;
}
if let Some(v) = read_env_value::<usize>("rag_chunk_size") {
self.rag_chunk_size = v;
}
if let Some(v) = read_env_value::<usize>("rag_chunk_overlap") {
self.rag_chunk_overlap = v;
}
if let Some(Some(v)) = read_env_value::<f32>("rag_min_score_vector_search") {
self.rag_min_score_vector_search = v;
}
if let Some(Some(v)) = read_env_value::<f32>("rag_min_score_keyword_search") {
self.rag_min_score_keyword_search = v;
}
if let Some(Some(v)) = read_env_value::<f32>("rag_min_score_rerank") {
self.rag_min_score_rerank = v;
}
if let Some(v) = read_env_value::<String>("rag_template") {
self.rag_template = v;
}

if let Ok(v) = env::var(get_env_name("document_loaders")) {
if let Ok(v) = serde_json::from_str(&v) {
self.document_loaders = v;
}
}

if let Some(Some(v)) = read_env_bool("highlight") {
self.highlight = v;
}
if let Ok(value) = env::var("NO_COLOR") {
if let Some(false) = parse_bool(&value) {
self.highlight = false;
}
}
if let Some(Some(v)) = read_env_bool("light_theme") {
self.light_theme = v;
} else if !self.light_theme {
if let Ok(v) = env::var("COLORFGBG") {
if let Some(v) = light_theme_from_colorfgbg(&v) {
self.light_theme = v
}
}
}
if let Some(v) = read_env_value::<String>("left_prompt") {
self.left_prompt = v;
}
if let Some(v) = read_env_value::<String>("right_prompt") {
self.right_prompt = v;
}
}

fn load_functions(&mut self) -> Result<()> {
self.functions = Functions::init(&Self::functions_file()?)?;
if self.functions.is_empty() {
Expand Down Expand Up @@ -1569,10 +1693,7 @@ impl Config {
}

fn setup_model(&mut self) -> Result<()> {
let mut model_id = match env::var(get_env_name("model")) {
Ok(v) => v,
Err(_) => self.model_id.clone(),
};
let mut model_id = self.model_id.clone();
if model_id.is_empty() {
let models = list_chat_models(self);
if models.is_empty() {
Expand All @@ -1585,31 +1706,6 @@ impl Config {
Ok(())
}

fn setup_highlight(&mut self) {
if let Ok(value) = env::var("NO_COLOR") {
let mut no_color = false;
set_bool(&mut no_color, &value);
if no_color {
self.highlight = false;
}
}
}

fn setup_light_theme(&mut self) -> Result<()> {
if self.light_theme {
return Ok(());
}
if let Ok(value) = env::var(get_env_name("light_theme")) {
set_bool(&mut self.light_theme, &value);
return Ok(());
} else if let Ok(value) = env::var("COLORFGBG") {
if let Some(light) = light_theme_from_colorfgbg(&value) {
self.light_theme = light
}
};
Ok(())
}

fn setup_document_loaders(&mut self) {
[
("pdf", "pdftotext $1 -"),
Expand Down Expand Up @@ -1637,33 +1733,12 @@ pub fn load_env_file() -> Result<()> {
continue;
}
if let Some((key, value)) = line.split_once('=') {
std::env::set_var(key.trim(), value.trim());
env::set_var(key.trim(), value.trim());
}
}
Ok(())
}

#[derive(Debug, Clone, Deserialize, Default)]
pub enum Keybindings {
#[serde(rename = "emacs")]
#[default]
Emacs,
#[serde(rename = "vi")]
Vi,
}

impl Keybindings {
pub fn is_vi(&self) -> bool {
matches!(self, Keybindings::Vi)
}
pub fn stringify(&self) -> &str {
match self {
Keybindings::Emacs => "emacs",
Keybindings::Vi => "vi",
}
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WorkingMode {
Command,
Expand Down Expand Up @@ -1757,12 +1832,13 @@ pub(crate) fn ensure_parent_exists(path: &Path) -> Result<()> {
Ok(())
}

fn set_bool(target: &mut bool, value: &str) {
match value {
"1" | "true" => *target = true,
"0" | "false" => *target = false,
_ => {}
}
fn read_env_value<T>(key: &str) -> Option<Option<T>>
where
T: std::str::FromStr,
{
let value = env::var(get_env_name(key)).ok()?;
let value = parse_value(&value).ok()?;
Some(value)
}

fn parse_value<T>(value: &str) -> Result<Option<T>>
Expand All @@ -1781,6 +1857,19 @@ where
Ok(value)
}

fn read_env_bool(key: &str) -> Option<Option<bool>> {
let value = env::var(get_env_name(key)).ok()?;
Some(parse_bool(&value))
}

fn parse_bool(value: &str) -> Option<bool> {
match value {
"1" | "true" => Some(true),
"0" | "false" => Some(false),
_ => None,
}
}

fn complete_bool(value: bool) -> Vec<String> {
vec![(!value).to_string()]
}
Expand Down
9 changes: 6 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,12 @@ use inquire::{Select, Text};
use is_terminal::IsTerminal;
use parking_lot::RwLock;
use simplelog::{format_description, ConfigBuilder, LevelFilter, SimpleLogger, WriteLogger};
use std::io::{stderr, stdin, Read};
use std::process;
use std::sync::Arc;
use std::{
env,
io::{stderr, stdin, Read},
process,
sync::Arc,
};

#[tokio::main]
async fn main() -> Result<()> {
Expand Down
2 changes: 1 addition & 1 deletion src/repl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ Type ".help" for additional help.
}

fn create_edit_mode(config: &GlobalConfig) -> Box<dyn EditMode> {
let edit_mode: Box<dyn EditMode> = if config.read().keybindings.is_vi() {
let edit_mode: Box<dyn EditMode> = if config.read().keybindings == "vi" {
let mut normal_keybindings = default_vi_normal_keybindings();
let mut insert_keybindings = default_vi_insert_keybindings();
Self::extra_keybindings(&mut normal_keybindings);
Expand Down