Skip to content

Commit

Permalink
feat: add .save agent-config repl command (#870)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored Sep 14, 2024
1 parent 5a26c59 commit 6211d01
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 91 deletions.
14 changes: 7 additions & 7 deletions config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ editor: null # Specifies the command used to edit input buff
wrap: no # Controls text wrapping (no, auto, <max-width>)
wrap_code: false # Enables or disables wrapping of code blocks

# ---- function-calling ----
# Visit https://github.com/sigoden/llm-functions for setup instructions
function_calling: true # Enables or disables function calling (Globally).
mapping_tools: # Alias for a tool or toolset
fs: 'fs_cat,fs_ls,fs_mkdir,fs_rm,fs_write'
use_tools: null # Which tools to use by default. (e.g. 'fs,web_search')

# ---- prelude ----
prelude: null # Set a default role or session to start with (e.g. role:<name>, session:<name>)
repl_prelude: null # Overrides the `prelude` setting specifically for conversations started in REPL
Expand All @@ -26,13 +33,6 @@ summarize_prompt: 'Summarize the discussion briefly in 200 words or less to use
# Text prompt used for including the summary of the entire session
summary_prompt: 'This is a summary of the chat history as a recap: '

# ---- function-calling ----
# Visit https://github.com/sigoden/llm-functions for setup instructions
function_calling: true # Enables or disables function calling (Globally).
mapping_tools: # Alias for a tool or toolset
fs: 'fs_cat,fs_ls,fs_mkdir,fs_rm,fs_write'
use_tools: null # Which tools to use by default. (e.g. 'fs,web_search')

# ---- RAG ----
# See [RAG-Guide](https://github.com/sigoden/aichat/wiki/RAG-Guide) for more details.
rag_embedding_model: null # Specifies the embedding model to use
Expand Down
35 changes: 26 additions & 9 deletions src/config/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl Agent {
let agent_config = if config_path.exists() {
AgentConfig::load(&config_path)?
} else {
AgentConfig::default()
AgentConfig::new(&config.read())
};
let mut definition = AgentDefinition::load(&definition_file_path)?;
init_variables(&variables_path, &mut definition.variables)
Expand Down Expand Up @@ -91,6 +91,18 @@ impl Agent {
})
}

pub fn save_config(&self) -> Result<()> {
let config_path = Config::agent_config_file(&self.name)?;
ensure_parent_exists(&config_path)?;
let content = serde_yaml::to_string(&self.config)?;
fs::write(&config_path, content).with_context(|| {
format!("Failed to save agent config to '{}'", config_path.display())
})?;

println!("✨ Saved agent config to '{}'", config_path.display());
Ok(())
}

pub fn export(&self) -> Result<String> {
let mut agent = self.clone();
agent.definition.instructions = self.interpolated_instructions();
Expand Down Expand Up @@ -143,6 +155,10 @@ impl Agent {
self.config.agent_prelude.as_deref()
}

pub fn set_agent_prelude(&mut self, value: Option<String>) {
self.config.agent_prelude = value;
}

pub fn variables(&self) -> &[AgentVariable] {
&self.definition.variables
}
Expand Down Expand Up @@ -208,22 +224,23 @@ impl RoleLike for Agent {

#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct AgentConfig {
#[serde(
rename(serialize = "model", deserialize = "model"),
skip_serializing_if = "Option::is_none"
)]
#[serde(rename(serialize = "model", deserialize = "model"))]
pub model_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub use_tools: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub agent_prelude: Option<String>,
}

impl AgentConfig {
pub fn new(config: &Config) -> Self {
Self {
use_tools: config.use_tools.clone(),
agent_prelude: config.agent_prelude.clone(),
..Default::default()
}
}

pub fn load(path: &Path) -> Result<Self> {
let contents = read_to_string(path)
.with_context(|| format!("Failed to read agent config file at '{}'", path.display()))?;
Expand Down
127 changes: 73 additions & 54 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ pub struct Config {
pub wrap: Option<String>,
pub wrap_code: bool,

pub function_calling: bool,
pub mapping_tools: IndexMap<String, String>,
pub use_tools: Option<String>,

pub prelude: Option<String>,
pub repl_prelude: Option<String>,
pub agent_prelude: Option<String>,
Expand All @@ -108,10 +112,6 @@ pub struct Config {
pub summarize_prompt: Option<String>,
pub summary_prompt: Option<String>,

pub function_calling: bool,
pub mapping_tools: IndexMap<String, String>,
pub use_tools: Option<String>,

pub rag_embedding_model: Option<String>,
pub rag_reranker_model: Option<String>,
pub rag_top_k: usize,
Expand Down Expand Up @@ -166,6 +166,10 @@ impl Default for Config {
wrap: None,
wrap_code: false,

function_calling: true,
mapping_tools: Default::default(),
use_tools: None,

prelude: None,
repl_prelude: None,
agent_prelude: None,
Expand All @@ -175,10 +179,6 @@ impl Default for Config {
summarize_prompt: None,
summary_prompt: None,

function_calling: true,
mapping_tools: Default::default(),
use_tools: None,

rag_embedding_model: None,
rag_reranker_model: None,
rag_top_k: 4,
Expand Down Expand Up @@ -402,7 +402,7 @@ impl Config {
self.serve_addr.clone().unwrap_or_else(|| SERVE_ADDR.into())
}

pub fn log(is_serve: bool) -> Result<(LevelFilter, Option<PathBuf>)> {
pub fn log_config(is_serve: bool) -> Result<(LevelFilter, Option<PathBuf>)> {
let log_level = env::var(get_env_name("log_level"))
.ok()
.and_then(|v| v.parse().ok())
Expand Down Expand Up @@ -513,10 +513,14 @@ impl Config {
.wrap
.clone()
.map_or_else(|| String::from("no"), |v| v.to_string());
let (rag_reranker_model, rag_top_k) = match self.rag.as_ref() {
let (rag_reranker_model, rag_top_k) = match &self.rag {
Some(rag) => rag.get_config(),
None => (self.rag_reranker_model.clone(), self.rag_top_k),
};
let agent_prelude = match &self.agent {
Some(agent) => agent.agent_prelude(),
None => self.agent_prelude.as_deref(),
};
let role = self.extract_role();
let mut items = vec![
("model", role.model().id()),
Expand All @@ -535,10 +539,11 @@ impl Config {
("keybindings", self.keybindings.clone()),
("wrap", wrap),
("wrap_code", self.wrap_code.to_string()),
("save_session", format_option_value(&self.save_session)),
("compress_threshold", self.compress_threshold.to_string()),
("function_calling", self.function_calling.to_string()),
("use_tools", format_option_value(&role.use_tools())),
("agent_prelude", format_option_value(&agent_prelude)),
("save_session", format_option_value(&self.save_session)),
("compress_threshold", self.compress_threshold.to_string()),
(
"rag_reranker_model",
format_option_value(&rag_reranker_model),
Expand All @@ -554,7 +559,7 @@ impl Config {
("functions_dir", display_path(&Self::functions_dir()?)),
("messages_file", display_path(&self.messages_file()?)),
];
if let Ok((_, Some(log_path))) = Self::log(self.working_mode.is_serve()) {
if let Ok((_, Some(log_path))) = Self::log_config(self.working_mode.is_serve()) {
items.push(("log_path", display_path(&log_path)));
}
let output = items
Expand Down Expand Up @@ -597,14 +602,6 @@ impl Config {
let value = value.parse().with_context(|| "Invalid value")?;
config.write().save = value;
}
"rag_reranker_model" => {
let value = parse_value(value)?;
Self::set_rag_reranker_model(config, value)?;
}
"rag_top_k" => {
let value = value.parse().with_context(|| "Invalid value")?;
Self::set_rag_top_k(config, value)?;
}
"function_calling" => {
let value = value.parse().with_context(|| "Invalid value")?;
if value && config.write().functions.is_empty() {
Expand All @@ -616,6 +613,10 @@ impl Config {
let value = parse_value(value)?;
config.write().set_use_tools(value);
}
"agent_prelude" => {
let value = parse_value(value)?;
config.write().set_agent_prelude(value);
}
"save_session" => {
let value = parse_value(value)?;
config.write().set_save_session(value);
Expand All @@ -624,6 +625,14 @@ impl Config {
let value = parse_value(value)?;
config.write().set_compress_threshold(value);
}
"rag_reranker_model" => {
let value = parse_value(value)?;
Self::set_rag_reranker_model(config, value)?;
}
"rag_top_k" => {
let value = value.parse().with_context(|| "Invalid value")?;
Self::set_rag_top_k(config, value)?;
}
"highlight" => {
let value = value.parse().with_context(|| "Invalid value")?;
config.write().highlight = value;
Expand All @@ -638,7 +647,7 @@ impl Config {
"roles" => (Self::roles_dir()?, Some(".md")),
"sessions" => (config.read().sessions_dir()?, Some(".yaml")),
"rags" => (Self::rags_dir()?, Some(".yaml")),
"agents-config" => (Self::agents_config_dir()?, None),
"agents" => (Self::agents_config_dir()?, None),
_ => bail!("Unknown kind '{kind}'"),
};
let names = match read_dir(&dir) {
Expand Down Expand Up @@ -722,6 +731,13 @@ impl Config {
}
}

pub fn set_agent_prelude(&mut self, value: Option<String>) {
match self.agent.as_mut() {
Some(agent) => agent.set_agent_prelude(value),
None => self.agent_prelude = value,
}
}

pub fn set_save_session(&mut self, value: Option<bool>) {
if let Some(session) = self.session.as_mut() {
session.set_save_session(value);
Expand Down Expand Up @@ -1269,13 +1285,9 @@ impl Config {
bail!("Already in a agent, please run '.exit agent' first to exit the current agent.");
}
let agent = Agent::init(config, name, abort_signal).await?;
let session = session.map(|v| v.to_string()).or_else(|| {
agent
.agent_prelude()
.map(|v| v.to_string())
.or_else(|| config.read().agent_prelude.clone())
.and_then(|v| if v.is_empty() { None } else { Some(v) })
});
let session = session
.map(|v| v.to_string())
.or_else(|| agent.agent_prelude().map(|v| v.to_string()));
config.write().rag = agent.rag();
config.write().agent = Some(agent);
if let Some(session) = session {
Expand Down Expand Up @@ -1314,6 +1326,14 @@ impl Config {
Ok(())
}

pub fn save_agent_config(&mut self) -> Result<()> {
let agent = match &self.agent {
Some(v) => v,
None => bail!("No agent"),
};
agent.save_config()
}

pub fn exit_agent(&mut self) -> Result<()> {
self.exit_session()?;
if self.agent.take().is_some() {
Expand Down Expand Up @@ -1472,10 +1492,11 @@ impl Config {
"dry_run",
"stream",
"save",
"save_session",
"compress_threshold",
"function_calling",
"use_tools",
"agent_prelude",
"save_session",
"compress_threshold",
"rag_reranker_model",
"rag_top_k",
"highlight",
Expand All @@ -1486,9 +1507,7 @@ impl Config {
.map(|v| (format!("{v} "), None))
.collect()
}
".delete" => {
map_completion_values(vec!["roles", "sessions", "rags", "agents-config"])
}
".delete" => map_completion_values(vec!["roles", "sessions", "rags", "agents"]),
_ => vec![],
};
filter = args[0]
Expand All @@ -1501,14 +1520,6 @@ impl Config {
"dry_run" => complete_bool(self.dry_run),
"stream" => complete_bool(self.stream),
"save" => complete_bool(self.save),
"save_session" => {
let save_session = if let Some(session) = &self.session {
session.save_session()
} else {
self.save_session
};
complete_option_bool(save_session)
}
"function_calling" => complete_bool(self.function_calling),
"use_tools" => {
let mut prefix = String::new();
Expand All @@ -1529,6 +1540,14 @@ impl Config {
.map(|v| format!("{prefix}{v}"))
.collect()
}
"save_session" => {
let save_session = if let Some(session) = &self.session {
session.save_session()
} else {
self.save_session
};
complete_option_bool(save_session)
}
"rag_reranker_model" => list_reranker_models(self).iter().map(|v| v.id()).collect(),
"highlight" => complete_bool(self.highlight),
_ => vec![],
Expand Down Expand Up @@ -1840,6 +1859,18 @@ impl Config {
self.wrap_code = v;
}

if let Some(Some(v)) = read_env_bool("function_calling") {
self.function_calling = v;
}
if let Ok(v) = env::var(get_env_name("mapping_tools")) {
if let Ok(v) = serde_json::from_str(&v) {
self.mapping_tools = v;
}
}
if let Some(v) = read_env_value::<String>("use_tools") {
self.use_tools = v;
}

if let Some(v) = read_env_value::<String>("prelude") {
self.prelude = v;
}
Expand All @@ -1863,18 +1894,6 @@ impl Config {
self.summary_prompt = v;
}

if let Some(Some(v)) = read_env_bool("function_calling") {
self.function_calling = v;
}
if let Ok(v) = env::var(get_env_name("mapping_tools")) {
if let Ok(v) = serde_json::from_str(&v) {
self.mapping_tools = v;
}
}
if let Some(v) = read_env_value::<String>("use_tools") {
self.use_tools = v;
}

if let Some(v) = read_env_value::<String>("rag_embedding_model") {
self.rag_embedding_model = v;
}
Expand Down
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ async fn create_input(
}

fn setup_logger(is_serve: bool) -> Result<()> {
let (log_level, log_path) = Config::log(is_serve)?;
let (log_level, log_path) = Config::log_config(is_serve)?;
if log_level == LevelFilter::Off {
return Ok(());
}
Expand Down
Loading

0 comments on commit 6211d01

Please sign in to comment.