diff --git a/Cargo.lock b/Cargo.lock index 9fdecf57..adfc7662 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,15 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "aho-corasick" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" +dependencies = [ + "memchr", +] + [[package]] name = "ai-commit-cli" version = "0.0.0" @@ -28,6 +37,7 @@ dependencies = [ "clap_complete", "colored", "inquire", + "regex", "tokio", "tracing", "tracing-subscriber", @@ -999,6 +1009,35 @@ dependencies = [ "bitflags 1.3.2", ] +[[package]] +name = "regex" +version = "1.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" + [[package]] name = "reqwest" version = "0.11.22" diff --git a/Cargo.toml b/Cargo.toml index cf91dd06..c787ead9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ clap = { features = ["cargo", "derive", "env"], version = "4.4.8" } clap_complete = "4.4.4" colored = "2.0.4" inquire = "0.6.2" +regex = "1.10.2" tokio = { features = ["full"], version = "1.34.0" } tracing = "0.1.40" tracing-subscriber = "0.3.18" diff --git a/src/cmd/commit.rs b/src/cmd/commit.rs index 633e949b..69bc2106 100644 --- a/src/cmd/commit.rs +++ b/src/cmd/commit.rs @@ -10,6 +10,7 @@ use async_openai::Client; use clap::Args; use colored::Colorize; use inquire::Select; +use regex::Regex; use crate::cmd::Run; use crate::common::log::LogResult; @@ -22,6 +23,9 @@ pub struct Cmd { #[arg(short, long)] exclude: Vec, + #[arg(short, long)] + include: Vec, + #[arg(long, default_value = "gpt-3.5-turbo-16k")] model: String, @@ -55,12 +59,13 @@ impl Cmd { #[async_trait::async_trait] impl Run for Cmd { async fn run(&self) -> anyhow::Result<()> { + crate::external::pre_commit::run()?; let mut exclude: Vec<_> = EXCLUDE.iter().map(PathBuf::from).collect(); self.exclude .iter() .for_each(|f| exclude.push(f.to_path_buf())); - crate::external::git::status(&exclude)?; - let diff = crate::external::git::diff(exclude)?; + crate::external::git::status(&exclude, &self.include)?; + let diff = crate::external::git::diff(exclude, &self.include)?; crate::ensure!(!diff.trim().is_empty()); let client = Client::with_config(OpenAIConfig::new().with_api_key(self.api_key()?)); let request = CreateChatCompletionRequestArgs::default() @@ -104,6 +109,7 @@ impl Run for Cmd { .choices .iter() .filter_map(|c| c.message.content.as_deref()) + .filter_map(sanitize) .collect(), ) .prompt() @@ -111,3 +117,33 @@ impl Run for Cmd { crate::external::git::commit(select) } } + +fn sanitize(message: S) -> Option +where + S: AsRef, +{ + let message = message.as_ref(); + let mut lines: Vec<_> = message.trim().split('\n').collect(); + let subject = lines[0].trim(); + let pattern: Regex = + Regex::new(r"(?P\w+)(?:\((?P\w+)\))?(?P!)?: (?P.+)") + .log() + .unwrap(); + let matches = pattern.captures(subject)?; + let type_ = matches.name("type")?.as_str(); + let scope = matches.name("scope").map(|s| s.as_str().to_lowercase()); + let breaking = matches.name("breaking").is_some(); + let description = matches.name("description")?.as_str(); + let description = description.chars().next()?.to_lowercase().to_string() + &description[1..]; + let mut subject = type_.to_string(); + if let Some(scope) = scope { + subject += &format!("({})", scope); + } + if breaking { + subject += "!"; + } + subject += ": "; + subject += &description; + lines[0] = &subject; + Some(lines.join("\n")) +} diff --git a/src/external/git.rs b/src/external/git.rs index c0e44bd8..21b4d6ae 100644 --- a/src/external/git.rs +++ b/src/external/git.rs @@ -27,16 +27,21 @@ where Ok(()) } -pub fn diff(exclude: I) -> Result +pub fn diff(exclude: I, include: J) -> Result where I: IntoIterator, + J: IntoIterator, S: AsRef, + T: AsRef, { let mut cmd = Command::new("git"); cmd.args(["diff", "--cached"]); exclude.into_iter().for_each(|p| { cmd.arg(format!(":(exclude){}", p.as_ref().to_str().unwrap())); }); + include.into_iter().for_each(|p| { + cmd.arg(p.as_ref().to_str().unwrap()); + }); cmd.stdin(Stdio::null()) .stdout(Stdio::piped()) .stderr(Stdio::inherit()); @@ -45,10 +50,12 @@ where String::from_utf8(output.stdout).log() } -pub fn status(exclude: I) -> Result<()> +pub fn status(exclude: I, include: J) -> Result<()> where I: IntoIterator, + J: IntoIterator, S: AsRef, + T: AsRef, { let mut cmd = Command::new("git"); if std::io::stdout().is_terminal() { @@ -58,6 +65,9 @@ where exclude.into_iter().for_each(|p| { cmd.arg(format!(":(exclude){}", p.as_ref().to_str().unwrap())); }); + include.into_iter().for_each(|p| { + cmd.arg(p.as_ref().to_str().unwrap()); + }); cmd.stdin(Stdio::null()) .stdout(Stdio::inherit()) .stderr(Stdio::inherit()); diff --git a/src/external/mod.rs b/src/external/mod.rs index ea80b743..21768613 100644 --- a/src/external/mod.rs +++ b/src/external/mod.rs @@ -1,2 +1,3 @@ pub mod bitwarden; pub mod git; +pub mod pre_commit; diff --git a/src/external/pre_commit.rs b/src/external/pre_commit.rs new file mode 100644 index 00000000..9dd3aa17 --- /dev/null +++ b/src/external/pre_commit.rs @@ -0,0 +1,16 @@ +use std::io::IsTerminal; +use std::process::Command; + +use crate::common::log::LogResult; +use anyhow::Result; + +pub fn run() -> Result<()> { + let mut cmd = Command::new("pre-commit"); + cmd.arg("run"); + if std::io::stdout().is_terminal() { + cmd.arg("--color=always"); + } + let status = cmd.status()?; + crate::ensure!(status.success()); + Ok(()) +}