diff --git a/src/lib.rs b/src/lib.rs index 06c794b8..74bc5244 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -194,6 +194,7 @@ use crate::service::body::BodyStreamExt; use chrono::{DateTime, Utc}; use http::{HeaderMap, HeaderValue, Method, Uri}; +use service::middleware::auth_header::AuthHeaderLayer; use std::convert::{Infallible, TryInto}; use std::fmt; use std::io::Write; @@ -685,27 +686,22 @@ impl OctocrabBuilder )); } - let auth_state = match self.config.auth { - Auth::None => AuthState::None, - Auth::Basic { username, password } => AuthState::BasicAuth { username, password }, - Auth::PersonalToken(token) => { - hmap.push(( - http::header::AUTHORIZATION, - format!("Bearer {}", token.expose_secret()).parse().unwrap(), - )); - AuthState::None + let (auth_header, auth_state): (Option, _) = match self.config.auth { + Auth::None => (None, AuthState::None), + Auth::Basic { username, password } => { + (None, AuthState::BasicAuth { username, password }) } - Auth::UserAccessToken(token) => { - hmap.push(( - http::header::AUTHORIZATION, - format!("Bearer {}", token.expose_secret()).parse().unwrap(), - )); - AuthState::None - } - Auth::App(app_auth) => AuthState::App(app_auth), - Auth::OAuth(device) => { - hmap.push(( - http::header::AUTHORIZATION, + Auth::PersonalToken(token) => ( + Some(format!("Bearer {}", token.expose_secret()).parse().unwrap()), + AuthState::None, + ), + Auth::UserAccessToken(token) => ( + Some(format!("Bearer {}", token.expose_secret()).parse().unwrap()), + AuthState::None, + ), + Auth::App(app_auth) => (None, AuthState::App(app_auth)), + Auth::OAuth(device) => ( + Some( format!( "{} {}", device.token_type, @@ -713,9 +709,9 @@ impl OctocrabBuilder ) .parse() .unwrap(), - )); - AuthState::None - } + ), + AuthState::None, + ), }; for (key, value) in self.config.extra_headers.iter() { @@ -742,6 +738,8 @@ impl OctocrabBuilder let client = BaseUriLayer::new(uri).layer(client); + let client = AuthHeaderLayer::new(auth_header).layer(client); + Ok(Octocrab::new(client, auth_state)) } } @@ -1515,10 +1513,16 @@ impl Octocrab { }; if let Some(mut auth_header) = auth_header { - auth_header.set_sensitive(true); - parts - .headers - .insert(http::header::AUTHORIZATION, auth_header); + // Only set the auth_header if the authority (host) is empty (destined for + // GitHub). Otherwise, leave it off as we could have been redirected + // away from GitHub (via follow_location_to_data()), and we don't + // want to give our credentials to third-party services. + if parts.uri.authority().is_none() { + auth_header.set_sensitive(true); + parts + .headers + .insert(http::header::AUTHORIZATION, auth_header); + } } let request = http::Request::from_parts(parts, body); diff --git a/src/service/middleware/auth_header.rs b/src/service/middleware/auth_header.rs new file mode 100644 index 00000000..e25b2fb3 --- /dev/null +++ b/src/service/middleware/auth_header.rs @@ -0,0 +1,65 @@ +use std::sync::Arc; + +use http::{header::AUTHORIZATION, request::Request, HeaderValue}; +use tower::{Layer, Service}; + +#[derive(Clone)] +/// Layer that adds the authentication header to github-bound requests +pub struct AuthHeaderLayer { + pub(crate) auth_header: Arc>, +} + +impl AuthHeaderLayer { + pub fn new(auth_header: Option) -> Self { + AuthHeaderLayer { + auth_header: Arc::new(auth_header), + } + } +} + +impl Layer for AuthHeaderLayer { + type Service = AuthHeader; + + fn layer(&self, inner: S) -> Self::Service { + AuthHeader { + inner, + auth_header: self.auth_header.clone(), + } + } +} + +#[derive(Clone)] +/// Service that adds a static set of extra headers to each request +pub struct AuthHeader { + inner: S, + pub(crate) auth_header: Arc>, +} + +impl Service> for AuthHeader +where + S: Service>, +{ + type Error = S::Error; + type Future = S::Future; + type Response = S::Response; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + // Only set the auth_header if the authority (host) is empty (destined for + // GitHub). Otherwise, leave it off as we could have been redirected + // away from GitHub (via follow_location_to_data()), and we don't + // want to give our credentials to third-party services. + if req.uri().authority().is_none() { + if let Some(auth_header) = &*self.auth_header { + req.headers_mut().append(AUTHORIZATION, auth_header.clone()); + } + } + self.inner.call(req) + } +} diff --git a/src/service/middleware/mod.rs b/src/service/middleware/mod.rs index f9e05649..7dec247a 100644 --- a/src/service/middleware/mod.rs +++ b/src/service/middleware/mod.rs @@ -1,3 +1,4 @@ +pub mod auth_header; pub mod base_uri; pub mod extra_headers; #[cfg(feature = "retry")]