diff --git a/sqlx-postgres/Cargo.toml b/sqlx-postgres/Cargo.toml index 34ad38e13a..987d8e8cbc 100644 --- a/sqlx-postgres/Cargo.toml +++ b/sqlx-postgres/Cargo.toml @@ -29,7 +29,7 @@ futures-util = { version = "0.3.19", default-features = false, features = ["allo # Cryptographic Primitives crc = "3.0.0" hkdf = "0.12.0" -hmac = { version = "0.12.0", default-features = false } +hmac = { version = "0.12.0", default-features = false, features = ["reset"]} md-5 = { version = "0.10.0", default-features = false } rand = { version = "0.8.4", default-features = false, features = ["std", "std_rng"] } sha1 = { version = "0.10.1", default-features = false } diff --git a/sqlx-postgres/src/connection/sasl.rs b/sqlx-postgres/src/connection/sasl.rs index e82df42582..16ccd20484 100644 --- a/sqlx-postgres/src/connection/sasl.rs +++ b/sqlx-postgres/src/connection/sasl.rs @@ -195,15 +195,33 @@ fn hi<'a>(s: &'a str, salt: &'a [u8], iter_count: u32) -> Result<[u8; 32], Error mac.update(&salt); mac.update(&1u32.to_be_bytes()); - let mut u = mac.finalize().into_bytes(); + let mut u = mac.finalize_reset().into_bytes(); let mut hi = u; for _ in 1..iter_count { - let mut mac = Hmac::::new_from_slice(s.as_bytes()).map_err(Error::protocol)?; mac.update(u.as_slice()); - u = mac.finalize().into_bytes(); + u = mac.finalize_reset().into_bytes(); hi = hi.iter().zip(u.iter()).map(|(&a, &b)| a ^ b).collect(); } Ok(hi.into()) } + +#[cfg(all(test, not(debug_assertions)))] +#[bench] +fn bench_sasl_hi(b: &mut test::Bencher) { + use test::black_box; + + let mut rng = rand::thread_rng(); + let nonce: Vec = std::iter::repeat(()) + .map(|()| rng.sample(rand::distributions::Alphanumeric)) + .take(64) + .collect(); + b.iter(|| { + let _ = hi( + test::black_box("secret_password"), + test::black_box(&nonce), + test::black_box(4096), + ); + }); +}