Skip to content

Commit

Permalink
feat: add support for prompt file in commit command
Browse files Browse the repository at this point in the history
This commit adds support for specifying a prompt file in the commit command. If a prompt file is provided, the content of the file will be used as the prompt for the chat completion. This allows users to provide a custom prompt for generating chat completions.

The `prompt_file` field has been added to the `Cmd` struct in the `commit` module. If the `prompt_file` option is provided, the content of the file will be read and used as the prompt. If the `prompt_file` option is not provided, the default prompt from the `res/prompt.md` file will be used.

This change enhances the flexibility of the commit command by allowing users to easily customize the prompt used for generating chat completions.
  • Loading branch information
liblaf committed Nov 24, 2023
1 parent 582915e commit 0d48d56
Show file tree
Hide file tree
Showing 10 changed files with 133 additions and 47 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
/dist
/target
53 changes: 36 additions & 17 deletions src/cmd/commit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ mod tests;

#[derive(Debug, Args)]
pub struct Cmd {
/// If not provided, will use `bw get notes OPENAI_API_KEY`
#[arg(short, long, env = "OPENAI_API_KEY")]
api_key: Option<String>,

#[arg(short, long)]
#[arg(short, long, default_values = EXCLUDE)]
exclude: Vec<PathBuf>,

#[arg(short, long)]
Expand All @@ -32,34 +33,33 @@ pub struct Cmd {
#[arg(long, default_value_t = false)]
no_pre_commit: bool,

#[arg(short, long)]
prompt: Option<String>,

#[arg(long)]
prompt_file: Option<PathBuf>,

/// ID of the model to use
#[arg(long, default_value = "gpt-3.5-turbo-16k")]
model: String,

/// The maximum number of tokens to generate in the chat completion
#[arg(long, default_value_t = 500)]
max_tokens: u16,

/// How many chat completion choices to generate for each input message
#[arg(short, default_value_t = 1)]
n: u8,

/// What sampling temperature to use, between 0 and 2
#[arg(long, default_value_t = 0.0)]
temperature: f32,

/// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass
#[arg(long, default_value_t = 0.1)]
top_p: f32,
}

impl Cmd {
fn api_key(&self) -> Result<String> {
if let Some(api_key) = self.api_key.as_ref() {
return Ok(api_key.to_string());
}
if let Ok(api_key) = crate::external::bitwarden::get_notes("OPENAI_KEY") {
return Ok(api_key);
}
crate::bail!("OPENAI_API_KEY is not provided");
}
}

const EXCLUDE: &[&str] = &["*-lock.*", "*.lock"];

#[async_trait::async_trait]
Expand All @@ -77,10 +77,7 @@ impl Run for Cmd {
let request = CreateChatCompletionRequestArgs::default()
.messages([
ChatCompletionRequestSystemMessageArgs::default()
.content(include_str!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/res/prompt.md"
)))
.content(self.prompt()?)
.build()
.log()?
.into(),
Expand Down Expand Up @@ -126,6 +123,28 @@ impl Run for Cmd {
}
}

impl Cmd {
fn api_key(&self) -> Result<String> {
if let Some(api_key) = self.api_key.as_deref() {
return Ok(api_key.to_string());
}
if let Ok(api_key) = crate::external::bitwarden::get_notes("OPENAI_API_KEY") {
return Ok(api_key);
}
crate::bail!("OPENAI_API_KEY is not provided");
}

fn prompt(&self) -> Result<String> {
if let Some(prompt) = self.prompt.as_deref() {
return Ok(prompt.to_string());
}
if let Some(prompt_file) = self.prompt_file.as_deref() {
return std::fs::read_to_string(prompt_file).log();
}
Ok(include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/res/prompt.md")).to_string())
}
}

fn sanitize<S>(message: S) -> Option<String>
where
S: AsRef<str>,
Expand Down
3 changes: 3 additions & 0 deletions src/cmd/complete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ use clap_complete::Shell;

use crate::cmd::Run;

/// Generate tab-completion scripts for your shell
///
/// $ ai-commit-cli complete fish >$HOME/.local/share/fish/vendor_completions.d
#[derive(Debug, Args)]
pub struct Cmd {
shell: Shell,
Expand Down
24 changes: 18 additions & 6 deletions src/cmd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use anyhow::Result;
use clap::builder::styling::AnsiColor;
use clap::builder::Styles;
use clap::{Parser, Subcommand};
use tracing::Level;

use crate::common::log::Level;

mod commit;
mod complete;
Expand All @@ -13,7 +14,7 @@ pub struct Cmd {
#[command(subcommand)]
sub_cmd: SubCmd,

#[arg(short, long, default_value_t = Level::INFO)]
#[arg(short, long, env, default_value_t = Level::Info)]
log_level: Level,
}

Expand All @@ -37,10 +38,21 @@ enum SubCmd {
#[async_trait::async_trait]
impl Run for Cmd {
async fn run(&self) -> Result<()> {
tracing_subscriber::fmt()
.pretty()
.with_max_level(self.log_level)
.init();
if self.log_level < Level::Info {
tracing_subscriber::fmt()
.pretty()
.with_max_level(self.log_level.as_level())
.init();
} else {
tracing_subscriber::fmt()
.pretty()
.with_file(false)
.with_line_number(false)
.with_max_level(self.log_level.as_level())
.with_target(false)
.without_time()
.init();
}
match &self.sub_cmd {
SubCmd::Commit(cmd) => cmd.run().await,
SubCmd::Complete(cmd) => cmd.run().await,
Expand Down
26 changes: 20 additions & 6 deletions src/common/err.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,42 @@
#[macro_export]
macro_rules! anyhow {
($msg:literal $(,)?) => {
$crate::common::log::LogError::log(anyhow::anyhow!($msg))
};

($err:expr $(,)?) => {
$crate::common::log::LogError::log(anyhow::anyhow!($err))
};

($fmt:expr, $($arg:tt)*) => {
$crate::common::log::LogError::log(anyhow::anyhow!($fmt, $($arg)*))
};
}

#[macro_export]
macro_rules! bail {
($msg:literal $(,)?) => {
return Err(anyhow::anyhow!($msg)).log()
return Err($crate::anyhow!($msg))
};

($err:expr $(,)?) => {
return Err(anyhow::anyhow!($err)).log()
return Err($crate::anyhow!($err))
};

($fmt:expr, $($arg:tt)*) => {
return Err(anyhow::anyhow!($fmt, $($arg)*)).log()
return Err($crate::anyhow!($fmt, $($arg)*))
};
}

#[macro_export]
macro_rules! ensure {
($cond:expr $(,)?) => {
if !$cond {
return Err(anyhow::anyhow!(concat!(
return Err($crate::anyhow!(concat!(
"Condition failed: `",
stringify!($cond),
"`"
)))
.log();
)));
}
};
}
66 changes: 51 additions & 15 deletions src/common/log.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,46 @@
use std::panic::Location;
use std::{fmt::Display, panic::Location};

pub trait LogResult<T> {
#[track_caller]
fn log(self) -> anyhow::Result<T>;
use clap::ValueEnum;

#[derive(Clone, Debug, ValueEnum, PartialEq, PartialOrd)]
pub enum Level {
Trace,
Debug,
Info,
Warn,
Error,
}

pub trait LogError {
#[track_caller]
fn log(self) -> anyhow::Error;
}

impl<T, E> LogResult<T> for Result<T, E>
where
E: Into<anyhow::Error>,
{
pub trait LogResult<T> {
#[track_caller]
fn log(self) -> anyhow::Result<T> {
fn log(self) -> anyhow::Result<T>;
}

impl Display for Level {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Ok(t) => Ok(t),
Err(e) => Err(e.log()),
Self::Trace => write!(f, "trace"),
Self::Debug => write!(f, "debug"),
Self::Info => write!(f, "info"),
Self::Warn => write!(f, "warn"),
Self::Error => write!(f, "error"),
}
}
}

impl Level {
pub fn as_level(&self) -> tracing::Level {
match self {
Self::Trace => tracing::Level::TRACE,
Self::Debug => tracing::Level::DEBUG,
Self::Info => tracing::Level::INFO,
Self::Warn => tracing::Level::WARN,
Self::Error => tracing::Level::ERROR,
}
}
}
Expand All @@ -30,17 +52,18 @@ where
#[track_caller]
fn log(self) -> anyhow::Error {
let e = self.into();
let mut message = e.to_string() + "\n";
let mut message = e.to_string();
let sources = e
.chain()
.skip(1)
.enumerate()
.map(|(i, e)| format!("{:>5}: {}\n", i, e))
.map(|(i, e)| format!("{:>5}: {}", i, e))
.collect::<Vec<String>>()
.join("");
.join("\n");
if !sources.is_empty() {
message += "Caused by:\n";
message += "\nCaused by:\n";
message += &sources;
message += "\n";
}
let location = Location::caller();
tracing::error!(
Expand All @@ -51,3 +74,16 @@ where
e
}
}

impl<T, E> LogResult<T> for Result<T, E>
where
E: Into<anyhow::Error>,
{
#[track_caller]
fn log(self) -> anyhow::Result<T> {
match self {
Ok(t) => Ok(t),
Err(e) => Err(e.log()),
}
}
}
2 changes: 1 addition & 1 deletion src/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
pub mod err;
mod err;
pub mod log;
1 change: 1 addition & 0 deletions src/external/bitwarden.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ where
.stdout(Stdio::piped())
.stderr(Stdio::inherit());
let output = cmd.output().log()?;
crate::ensure!(output.status.success());
String::from_utf8(output.stdout).log()
}
2 changes: 1 addition & 1 deletion src/external/git.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ where
.stdin(Stdio::piped())
.stdout(Stdio::inherit())
.stderr(Stdio::inherit());
let mut child = cmd.spawn()?;
let mut child = cmd.spawn().log()?;
child
.stdin
.as_ref()
Expand Down
2 changes: 1 addition & 1 deletion src/external/pre_commit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub fn run() -> Result<()> {
cmd.stdin(Stdio::null())
.stdout(Stdio::inherit())
.stderr(Stdio::inherit());
let status = cmd.status()?;
let status = cmd.status().log()?;
crate::ensure!(status.success());
Ok(())
}

0 comments on commit 0d48d56

Please sign in to comment.