Skip to content

Commit

Permalink
Refactor AuthMiddleware and improve authentication coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
zanieb committed Apr 11, 2024
1 parent 50f265a commit 350bd0b
Show file tree
Hide file tree
Showing 6 changed files with 476 additions and 111 deletions.
22 changes: 22 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/uv-auth/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ tempfile = { workspace = true }
tokio = { workspace = true }
wiremock = { workspace = true }
insta = { version = "1.36.1" }
test-log = { version = "0.2.15", features = ["trace"], default-features = false }
122 changes: 71 additions & 51 deletions crates/uv-auth/src/credentials.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use base64::prelude::BASE64_STANDARD;
use base64::read::DecoderReader;
use base64::write::EncoderWriter;
use netrc::Authenticator;
use netrc::Netrc;
use reqwest::header::HeaderValue;
use reqwest::Request;
use std::io::Read;
use std::io::Write;
use url::Url;

Expand All @@ -24,9 +28,16 @@ impl Credentials {
self.password.as_deref()
}

/// Extract credentials from a URL.
/// Return [`Credentials`] for a [`Url`] from a [`Netrc`] file, if any.
pub fn from_netrc(netrc: &Netrc, url: &Url) -> Option<Self> {
url.host_str()
.and_then(|host| netrc.hosts.get(host).or_else(|| netrc.hosts.get("default")))
.map(Self::from)
}

/// Parse [`Credentials`] from a URL, if any.
///
/// Returns `None` if `username` and `password` are not populated.
/// Returns [`None`] if both `username` and `password` are not populated.
pub fn from_url(url: &Url) -> Option<Self> {
if url.username().is_empty() && url.password().is_none() {
return None;
Expand All @@ -44,74 +55,83 @@ impl Credentials {
}),
})
}
}

impl From<Authenticator> for Credentials {
fn from(auth: Authenticator) -> Self {
Credentials {
username: auth.login,
password: Some(auth.password),
/// Parse [`Credentials`] from an HTTP request, if any.
///
/// Only HTTP Basic Authentication is supported.
pub fn from_request(request: &Request) -> Option<Self> {
// First, attempt to retrieve the credentials from the URL
Self::from_url(request.url()).or(
// Then, attempt to pull the credentials from the headers
request
.headers()
.get(reqwest::header::AUTHORIZATION)
.map(Self::from_header_value)?,
)
}

/// Parse [`Credentials`] from an authorization header, if any.
///
/// Only HTTP Basic Authentication is supported.
///
/// [`None`] will be returned if any error is encountered.
pub(crate) fn from_header_value(header: &HeaderValue) -> Option<Self> {
let mut value = header.as_bytes().strip_prefix(b"Basic ")?;
let mut decoder = DecoderReader::new(&mut value, &BASE64_STANDARD);
let mut buf = String::new();
decoder.read_to_string(&mut buf).ok()?;
let (username, password) = buf.split_once(':')?;
let password = if password.is_empty() {
None
} else {
Some(password.to_string())
};
Some(Self::new(username.to_string(), password))
}

/// Create an HTTP Basic Authentication header for the credentials.
pub(crate) fn to_header_value(&self) -> HeaderValue {
// See: <https://github.com/seanmonstar/reqwest/blob/2c11ef000b151c2eebeed2c18a7b81042220c6b0/src/util.rs#L3>
let mut buf = b"Basic ".to_vec();
{
let mut encoder = EncoderWriter::new(&mut buf, &BASE64_STANDARD);
let _ = write!(encoder, "{}:", self.username());
if let Some(password) = self.password() {
let _ = write!(encoder, "{}", password);
}
}
let mut header = HeaderValue::from_bytes(&buf).expect("base64 is always valid HeaderValue");
header.set_sensitive(true);
header
}
}

impl Credentials {
/// Attach the credentials to the given request.
///
/// Any existing credentials will be overridden.
#[must_use]
pub fn authenticated_request(&self, mut request: reqwest::Request) -> reqwest::Request {
request.headers_mut().insert(
reqwest::header::AUTHORIZATION,
basic_auth(self.username(), self.password()),
);
request
.headers_mut()
.insert(reqwest::header::AUTHORIZATION, Self::to_header_value(self));
request
}
}

/// Create a `HeaderValue` for basic authentication.
///
/// Source: <https://github.com/seanmonstar/reqwest/blob/2c11ef000b151c2eebeed2c18a7b81042220c6b0/src/util.rs#L3>
fn basic_auth<U, P>(username: U, password: Option<P>) -> HeaderValue
where
U: std::fmt::Display,
P: std::fmt::Display,
{
let mut buf = b"Basic ".to_vec();
{
let mut encoder = EncoderWriter::new(&mut buf, &BASE64_STANDARD);
let _ = write!(encoder, "{}:", username);
if let Some(password) = password {
let _ = write!(encoder, "{}", password);
impl From<&Authenticator> for Credentials {
fn from(auth: &Authenticator) -> Self {
Credentials {
username: auth.login.clone(),
password: Some(auth.password.clone()),
}
}
let mut header = HeaderValue::from_bytes(&buf).expect("base64 is always valid HeaderValue");
header.set_sensitive(true);
header
}

#[cfg(test)]
mod test {
use std::io::Read;

use base64::read::DecoderReader;
use insta::{assert_debug_snapshot, assert_snapshot};
use insta::assert_debug_snapshot;

use super::*;

fn decode_basic_auth(header: HeaderValue) -> String {
let mut value = header.as_bytes();
value = value
.strip_prefix(b"Basic ")
.expect("Basic authentication should start with 'Basic '");
let mut decoder = DecoderReader::new(&mut value, &BASE64_STANDARD);
let mut buf = "Basic: ".to_string();
decoder
.read_to_string(&mut buf)
.expect("Header contents should be valid base64");
buf
}

#[test]
fn from_url_no_credentials() {
let url = &Url::parse("https://example.com/simple/first/").unwrap();
Expand Down Expand Up @@ -148,7 +168,7 @@ mod test {
header.set_sensitive(false);

assert_debug_snapshot!(header, @r###""Basic dXNlcjpwYXNzd29yZA==""###);
assert_snapshot!(decode_basic_auth(header), @"Basic: user:password");
assert_eq!(Credentials::from_header_value(&header), Some(credentials));
}

#[test]
Expand All @@ -170,7 +190,7 @@ mod test {
header.set_sensitive(false);

assert_debug_snapshot!(header, @r###""Basic dXNlckBkb21haW46cGFzc3dvcmQ=""###);
assert_snapshot!(decode_basic_auth(header), @"Basic: user@domain:password");
assert_eq!(Credentials::from_header_value(&header), Some(credentials));
}

#[test]
Expand All @@ -192,6 +212,6 @@ mod test {
header.set_sensitive(false);

assert_debug_snapshot!(header, @r###""Basic dXNlcjpwYXNzd29yZD09""###);
assert_snapshot!(decode_basic_auth(header), @"Basic: user:password==");
assert_eq!(Credentials::from_header_value(&header), Some(credentials));
}
}
Loading

0 comments on commit 350bd0b

Please sign in to comment.