From 21bc50e31d94a0fe61d8e27d786dcefcb7c7b73d Mon Sep 17 00:00:00 2001 From: sigoden Date: Wed, 6 Nov 2024 14:08:26 +0800 Subject: [PATCH] feat: support overriding agent config with env vars --- src/config/agent.rs | 29 +++++++++++++++++++- src/config/mod.rs | 67 +++++++++++++++++++++++---------------------- 2 files changed, 62 insertions(+), 34 deletions(-) diff --git a/src/config/agent.rs b/src/config/agent.rs index 5589f0ea..9a7e67ce 100644 --- a/src/config/agent.rs +++ b/src/config/agent.rs @@ -41,7 +41,7 @@ impl Agent { let functions_file_path = functions_dir.join("functions.json"); let rag_path = Config::agent_rag_file(name, DEFAULT_AGENT_NAME)?; let config_path = Config::agent_config_file(name)?; - let agent_config = if config_path.exists() { + let mut agent_config = if config_path.exists() { AgentConfig::load(&config_path)? } else { AgentConfig::new(&config.read()) @@ -54,6 +54,8 @@ impl Agent { }; definition.replace_tools_placeholder(&functions); + agent_config.load_envs(&definition.name); + let model = { let config = config.read(); match agent_config.model_id.as_ref() { @@ -330,6 +332,31 @@ impl AgentConfig { .with_context(|| format!("Failed to load agent config at '{}'", path.display()))?; Ok(config) } + + fn load_envs(&mut self, name: &str) { + let with_prefix = |v: &str| normalize_env_name(&format!("{name}_{v}")); + + if let Some(v) = read_env_value::(&with_prefix("model")) { + self.model_id = v; + } + if let Some(v) = read_env_value::(&with_prefix("temperature")) { + self.temperature = v; + } + if let Some(v) = read_env_value::(&with_prefix("top_p")) { + self.top_p = v; + } + if let Some(v) = read_env_value::(&with_prefix("use_tools")) { + self.use_tools = v; + } + if let Some(v) = read_env_value::(&with_prefix("agent_prelude")) { + self.agent_prelude = v; + } + if let Ok(v) = env::var(with_prefix("variables")) { + if let Ok(v) = serde_json::from_str(&v) { + self.variables = v; + } + } + } } #[derive(Debug, Clone, Default, Deserialize, Serialize)] diff --git a/src/config/mod.rs b/src/config/mod.rs index b9d9d8e0..6736624a 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1942,20 +1942,20 @@ impl Config { if let Ok(v) = env::var(get_env_name("model")) { self.model_id = v; } - if let Some(v) = read_env_value::("temperature") { + if let Some(v) = read_env_value::(&get_env_name("temperature")) { self.temperature = v; } - if let Some(v) = read_env_value::("top_p") { + if let Some(v) = read_env_value::(&get_env_name("top_p")) { self.top_p = v; } - if let Some(Some(v)) = read_env_bool("dry_run") { + if let Some(Some(v)) = read_env_bool(&get_env_name("dry_run")) { self.dry_run = v; } - if let Some(Some(v)) = read_env_bool("stream") { + if let Some(Some(v)) = read_env_bool(&get_env_name("stream")) { self.stream = v; } - if let Some(Some(v)) = read_env_bool("save") { + if let Some(Some(v)) = read_env_bool(&get_env_name("save")) { self.save = v; } if let Ok(v) = env::var(get_env_name("keybindings")) { @@ -1963,17 +1963,17 @@ impl Config { self.keybindings = v; } } - if let Some(v) = read_env_value::("editor") { + if let Some(v) = read_env_value::(&get_env_name("editor")) { self.editor = v; } - if let Some(v) = read_env_value::("wrap") { + if let Some(v) = read_env_value::(&get_env_name("wrap")) { self.wrap = v; } - if let Some(Some(v)) = read_env_bool("wrap_code") { + if let Some(Some(v)) = read_env_bool(&get_env_name("wrap_code")) { self.wrap_code = v; } - if let Some(Some(v)) = read_env_bool("function_calling") { + if let Some(Some(v)) = read_env_bool(&get_env_name("function_calling")) { self.function_calling = v; } if let Ok(v) = env::var(get_env_name("mapping_tools")) { @@ -1981,55 +1981,56 @@ impl Config { self.mapping_tools = v; } } - if let Some(v) = read_env_value::("use_tools") { + if let Some(v) = read_env_value::(&get_env_name("use_tools")) { self.use_tools = v; } - if let Some(v) = read_env_value::("prelude") { + if let Some(v) = read_env_value::(&get_env_name("prelude")) { self.prelude = v; } - if let Some(v) = read_env_value::("repl_prelude") { + if let Some(v) = read_env_value::(&get_env_name("repl_prelude")) { self.repl_prelude = v; } - if let Some(v) = read_env_value::("agent_prelude") { + if let Some(v) = read_env_value::(&get_env_name("agent_prelude")) { self.agent_prelude = v; } - if let Some(v) = read_env_bool("save_session") { + if let Some(v) = read_env_bool(&get_env_name("save_session")) { self.save_session = v; } - if let Some(Some(v)) = read_env_value::("compress_threshold") { + if let Some(Some(v)) = read_env_value::(&get_env_name("compress_threshold")) { self.compress_threshold = v; } - if let Some(v) = read_env_value::("summarize_prompt") { + if let Some(v) = read_env_value::(&get_env_name("summarize_prompt")) { self.summarize_prompt = v; } - if let Some(v) = read_env_value::("summary_prompt") { + if let Some(v) = read_env_value::(&get_env_name("summary_prompt")) { self.summary_prompt = v; } - if let Some(v) = read_env_value::("rag_embedding_model") { + if let Some(v) = read_env_value::(&get_env_name("rag_embedding_model")) { self.rag_embedding_model = v; } - if let Some(v) = read_env_value::("rag_reranker_model") { + if let Some(v) = read_env_value::(&get_env_name("rag_reranker_model")) { self.rag_reranker_model = v; } - if let Some(Some(v)) = read_env_value::("rag_top_k") { + if let Some(Some(v)) = read_env_value::(&get_env_name("rag_top_k")) { self.rag_top_k = v; } - if let Some(v) = read_env_value::("rag_chunk_size") { + if let Some(v) = read_env_value::(&get_env_name("rag_chunk_size")) { self.rag_chunk_size = v; } - if let Some(v) = read_env_value::("rag_chunk_overlap") { + if let Some(v) = read_env_value::(&get_env_name("rag_chunk_overlap")) { self.rag_chunk_overlap = v; } - if let Some(Some(v)) = read_env_value::("rag_min_score_vector_search") { + if let Some(Some(v)) = read_env_value::(&get_env_name("rag_min_score_vector_search")) { self.rag_min_score_vector_search = v; } - if let Some(Some(v)) = read_env_value::("rag_min_score_keyword_search") { + if let Some(Some(v)) = read_env_value::(&get_env_name("rag_min_score_keyword_search")) + { self.rag_min_score_keyword_search = v; } - if let Some(v) = read_env_value::("rag_template") { + if let Some(v) = read_env_value::(&get_env_name("rag_template")) { self.rag_template = v; } @@ -2039,13 +2040,13 @@ impl Config { } } - if let Some(Some(v)) = read_env_bool("highlight") { + if let Some(Some(v)) = read_env_bool(&get_env_name("highlight")) { self.highlight = v; } if *NO_COLOR { self.highlight = false; } - if let Some(Some(v)) = read_env_bool("light_theme") { + if let Some(Some(v)) = read_env_bool(&get_env_name("light_theme")) { self.light_theme = v; } else if !self.light_theme { if let Ok(v) = env::var("COLORFGBG") { @@ -2054,17 +2055,17 @@ impl Config { } } } - if let Some(v) = read_env_value::("left_prompt") { + if let Some(v) = read_env_value::(&get_env_name("left_prompt")) { self.left_prompt = v; } - if let Some(v) = read_env_value::("right_prompt") { + if let Some(v) = read_env_value::(&get_env_name("right_prompt")) { self.right_prompt = v; } - if let Some(v) = read_env_value::("serve_addr") { + if let Some(v) = read_env_value::(&get_env_name("serve_addr")) { self.serve_addr = v; } - if let Some(v) = read_env_value::("user_agent") { + if let Some(v) = read_env_value::(&get_env_name("user_agent")) { self.user_agent = v; } } @@ -2230,7 +2231,7 @@ fn read_env_value(key: &str) -> Option> where T: std::str::FromStr, { - let value = env::var(get_env_name(key)).ok()?; + let value = env::var(key).ok()?; let value = parse_value(&value).ok()?; Some(value) } @@ -2252,7 +2253,7 @@ where } fn read_env_bool(key: &str) -> Option> { - let value = env::var(get_env_name(key)).ok()?; + let value = env::var(key).ok()?; Some(parse_bool(&value)) }