Skip to content

Commit

Permalink
Make the issue optional on upstream OAuth 2.0 providers
Browse files Browse the repository at this point in the history
  • Loading branch information
sandhose committed Dec 17, 2024
1 parent 80903ed commit a97d2da
Show file tree
Hide file tree
Showing 26 changed files with 85 additions and 58 deletions.
4 changes: 3 additions & 1 deletion crates/cli/src/commands/manage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -764,8 +764,10 @@ impl std::fmt::Display for HumanReadable<&UpstreamOAuthProvider> {
let provider = self.0;
if let Some(human_name) = &provider.human_name {
write!(f, "{} ({})", human_name, provider.id)
} else if let Some(issuer) = &provider.issuer {
write!(f, "{} ({})", issuer, provider.id)
} else {
write!(f, "{} ({})", provider.issuer, provider.id)
write!(f, "{}", provider.id)
}
}
}
Expand Down
13 changes: 12 additions & 1 deletion crates/config/src/sections/upstream_oauth2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ impl ConfigurationSection for UpstreamOAuth2Config {
Err(error)
};

if !matches!(provider.discovery_mode, DiscoveryMode::Disabled)
&& provider.issuer.is_none()
{
return annotate(figment::Error::custom(
"The `issuer` field is required when discovery is enabled",
));
}

match provider.token_endpoint_auth_method {
TokenAuthMethod::None
| TokenAuthMethod::PrivateKeyJwt
Expand Down Expand Up @@ -438,7 +446,10 @@ pub struct Provider {
pub id: Ulid,

/// The OIDC issuer URL
pub issuer: String,
///
/// This is required if OIDC discovery is enabled (which is the default)
#[serde(skip_serializing_if = "Option::is_none")]
pub issuer: Option<String>,

/// A human-readable name for the provider, that will be shown to users
#[serde(skip_serializing_if = "Option::is_none")]
Expand Down
2 changes: 1 addition & 1 deletion crates/data-model/src/upstream_oauth2/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ pub struct InvalidUpstreamOAuth2TokenAuthMethod(String);
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct UpstreamOAuthProvider {
pub id: Ulid,
pub issuer: String,
pub issuer: Option<String>,
pub human_name: Option<String>,
pub brand_name: Option<String>,
pub discovery_mode: DiscoveryMode,
Expand Down
4 changes: 2 additions & 2 deletions crates/handlers/src/graphql/model/upstream_oauth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ impl UpstreamOAuth2Provider {
}

/// OpenID Connect issuer URL.
pub async fn issuer(&self) -> &str {
&self.provider.issuer
pub async fn issuer(&self) -> Option<&str> {
self.provider.issuer.as_deref()
}

/// Client ID used for this provider.
Expand Down
20 changes: 13 additions & 7 deletions crates/handlers/src/upstream_oauth2/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,11 @@ impl<'a> LazyProviderInfos<'a> {
}
};

let metadata = self
.cache
.get(self.client, &self.provider.issuer, verify)
.await?;
let Some(issuer) = &self.provider.issuer else {
return Err(DiscoveryError::MissingIssuer);
};

let metadata = self.cache.get(self.client, issuer, verify).await?;

self.loaded_metadata = Some(metadata);
}
Expand Down Expand Up @@ -179,8 +180,13 @@ impl MetadataCache {
UpstreamOAuthProviderDiscoveryMode::Disabled => continue,
};

if let Err(e) = self.fetch(client, &provider.issuer, verify).await {
tracing::error!(issuer = %provider.issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata");
let Some(issuer) = &provider.issuer else {
tracing::error!(%provider.id, "Provider doesn't have an issuer set, but discovery is enabled!");
continue;
};

if let Err(e) = self.fetch(client, issuer, verify).await {
tracing::error!(%issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata");
}
}

Expand Down Expand Up @@ -395,7 +401,7 @@ mod tests {
let clock = MockClock::default();
let provider = UpstreamOAuthProvider {
id: Ulid::nil(),
issuer: mock_server.uri(),
issuer: Some(mock_server.uri()),
human_name: Some("Example Ltd.".to_owned()),
brand_name: None,
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Insecure,
Expand Down
4 changes: 2 additions & 2 deletions crates/handlers/src/upstream_oauth2/callback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ pub(crate) async fn handler(
);

let id_token_verification_data = JwtVerificationData {
issuer: &provider.issuer,
issuer: provider.issuer.as_deref(),
jwks: jwks.as_ref().unwrap(),
signing_algorithm: &provider.id_token_signed_response_alg,
client_id: &provider.client_id,
Expand Down Expand Up @@ -350,7 +350,7 @@ pub(crate) async fn handler(
lazy_metadata.userinfo_endpoint().await?,
token_response.access_token.as_str(),
Some(JwtVerificationData {
issuer: &provider.issuer,
issuer: provider.issuer.as_deref(),
jwks: &jwks,
signing_algorithm,
client_id: &provider.client_id,
Expand Down
2 changes: 1 addition & 1 deletion crates/handlers/src/upstream_oauth2/link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,7 @@ mod tests {
&mut rng,
&state.clock,
UpstreamOAuthProviderParams {
issuer: "https://example.com/".to_owned(),
issuer: Some("https://example.com/".to_owned()),
human_name: Some("Example Ltd.".to_owned()),
brand_name: None,
scope: Scope::from_iter([OPENID]),
Expand Down
1 change: 0 additions & 1 deletion crates/handlers/src/upstream_oauth2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ fn client_credentials_for_provider(

ClientCredentials::SignInWithApple {
client_id,
audience: provider.issuer.clone(),
key,
key_id: params.key_id,
team_id: params.team_id,
Expand Down
4 changes: 2 additions & 2 deletions crates/handlers/src/views/login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ mod test {
&mut rng,
&state.clock,
UpstreamOAuthProviderParams {
issuer: "https://first.com/".to_owned(),
issuer: Some("https://first.com/".to_owned()),
human_name: Some("First Ltd.".to_owned()),
brand_name: None,
scope: [OPENID].into_iter().collect(),
Expand Down Expand Up @@ -438,7 +438,7 @@ mod test {
&mut rng,
&state.clock,
UpstreamOAuthProviderParams {
issuer: "https://second.com/".to_owned(),
issuer: Some("https://second.com/".to_owned()),
human_name: None,
brand_name: None,
scope: [OPENID].into_iter().collect(),
Expand Down
5 changes: 5 additions & 0 deletions crates/oidc-client/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ pub enum DiscoveryError {
/// An error occurred validating the metadata.
Validation(#[from] ProviderMetadataVerificationError),

/// The provider doesn't have an issuer set, which is required if discovery
/// is enabled.
#[error("Provider doesn't have an issuer set")]
MissingIssuer,

/// Discovery is disabled for this provider.
#[error("Discovery is disabled for this provider")]
Disabled,
Expand Down
10 changes: 6 additions & 4 deletions crates/oidc-client/src/requests/jose.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ pub async fn fetch_jwks(
#[derive(Clone, Copy)]
pub struct JwtVerificationData<'a> {
/// The URL of the issuer that generated the ID Token.
pub issuer: &'a str,
pub issuer: Option<&'a str>,

/// The issuer's JWKS.
pub jwks: &'a PublicJsonWebKeySet,
Expand All @@ -76,7 +76,7 @@ pub struct JwtVerificationData<'a> {
///
/// * The signature is verified with the given JWKS.
///
/// * The `iss` claim must be present and match the issuer.
/// * The `iss` claim must be present and match the issuer, if present
///
/// * The `aud` claim must be present and match the client ID.
///
Expand Down Expand Up @@ -117,8 +117,10 @@ pub fn verify_signed_jwt<'a>(

let (header, mut claims) = jwt.clone().into_parts();

// Must have the proper issuer.
claims::ISS.extract_required_with_options(&mut claims, issuer)?;
if let Some(issuer) = issuer {
// Must have the proper issuer.
claims::ISS.extract_required_with_options(&mut claims, issuer)?;
}

// Must have the proper audience.
claims::AUD.extract_required_with_options(&mut claims, client_id)?;
Expand Down
6 changes: 1 addition & 5 deletions crates/oidc-client/src/types/client_credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,6 @@ pub enum ClientCredentials {
/// The unique ID for the client.
client_id: String,

/// The audience to use. Usually `https://appleid.apple.com`
audience: String,

/// The ECDSA key used to sign
key: elliptic_curve::SecretKey<p256::NistP256>,

Expand Down Expand Up @@ -240,7 +237,6 @@ impl ClientCredentials {

ClientCredentials::SignInWithApple {
client_id,
audience,
key,
key_id,
team_id,
Expand All @@ -253,7 +249,7 @@ impl ClientCredentials {

claims::ISS.insert(&mut claims, team_id)?;
claims::SUB.insert(&mut claims, client_id)?;
claims::AUD.insert(&mut claims, audience.clone())?;
claims::AUD.insert(&mut claims, "https://appleid.apple.com".to_owned())?;
claims::IAT.insert(&mut claims, now)?;
claims::EXP.insert(&mut claims, now + Duration::microseconds(60 * 1000 * 1000))?;

Expand Down
6 changes: 3 additions & 3 deletions crates/oidc-client/tests/it/requests/authorization_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ async fn pass_access_token_with_authorization_code() {

let (id_token, jwks) = id_token(issuer.as_str());
let id_token_verification_data = JwtVerificationData {
issuer: issuer.as_str(),
issuer: Some(issuer.as_str()),
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
Expand Down Expand Up @@ -251,7 +251,7 @@ async fn fail_access_token_with_authorization_code_wrong_nonce() {

let (id_token, jwks) = id_token(issuer.as_str());
let id_token_verification_data = JwtVerificationData {
issuer: issuer.as_str(),
issuer: Some(issuer.as_str()),
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
Expand Down Expand Up @@ -312,7 +312,7 @@ async fn fail_access_token_with_authorization_code_no_id_token() {
};

let id_token_verification_data = JwtVerificationData {
issuer: issuer.as_str(),
issuer: Some(issuer.as_str()),
jwks: &PublicJsonWebKeySet::default(),
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
Expand Down
14 changes: 7 additions & 7 deletions crates/oidc-client/tests/it/requests/jose.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ async fn pass_verify_id_token() {
let (id_token, jwks) = id_token(issuer, None, Some(now));

let verification_data = JwtVerificationData {
issuer,
issuer: Some(issuer),
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
Expand All @@ -111,7 +111,7 @@ async fn fail_verify_id_token_wrong_issuer() {
let now = now();

let verification_data = JwtVerificationData {
issuer: wrong_issuer,
issuer: Some(wrong_issuer),
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
Expand All @@ -135,7 +135,7 @@ async fn fail_verify_id_token_wrong_audience() {
let now = now();

let verification_data = JwtVerificationData {
issuer,
issuer: Some(issuer),
jwks: &jwks,
client_id: &"wrong_client_id".to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
Expand All @@ -159,7 +159,7 @@ async fn fail_verify_id_token_wrong_signing_algorithm() {
let now = now();

let verification_data = JwtVerificationData {
issuer,
issuer: Some(issuer),
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &JsonWebSignatureAlg::Unknown("wrong_algorithm".to_owned()),
Expand All @@ -180,7 +180,7 @@ async fn fail_verify_id_token_wrong_expiration() {
let now = now();

let verification_data = JwtVerificationData {
issuer,
issuer: Some(issuer),
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
Expand All @@ -199,7 +199,7 @@ async fn fail_verify_id_token_wrong_subject() {
let (id_token, jwks) = id_token(issuer, Some(IdTokenFlag::WrongSubject), None);

let verification_data = JwtVerificationData {
issuer,
issuer: Some(issuer),
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
Expand All @@ -224,7 +224,7 @@ async fn fail_verify_id_token_wrong_auth_time() {
let (id_token, jwks) = id_token(issuer, None, Some(now + Duration::try_hours(1).unwrap()));

let verification_data = JwtVerificationData {
issuer,
issuer: Some(issuer),
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Copyright 2024 New Vector Ltd.
--
-- SPDX-License-Identifier: AGPL-3.0-only
-- Please see LICENSE in the repository root for full details.

-- Make the issuer field in the upstream_oauth_providers table optional
ALTER TABLE "upstream_oauth_providers"
ALTER COLUMN "issuer" DROP NOT NULL;
4 changes: 2 additions & 2 deletions crates/storage-pg/src/upstream_oauth2/link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ impl UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'_> {
db.query.text,
upstream_oauth_link.subject = subject,
%upstream_oauth_provider.id,
%upstream_oauth_provider.issuer,
upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
%upstream_oauth_provider.client_id,
),
err,
Expand Down Expand Up @@ -192,7 +192,7 @@ impl UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'_> {
upstream_oauth_link.subject = subject,
upstream_oauth_link.human_account_name = human_account_name,
%upstream_oauth_provider.id,
%upstream_oauth_provider.issuer,
upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
%upstream_oauth_provider.client_id,
),
err,
Expand Down
9 changes: 4 additions & 5 deletions crates/storage-pg/src/upstream_oauth2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ mod tests {
&mut rng,
&clock,
UpstreamOAuthProviderParams {
issuer: "https://example.com/".to_owned(),
issuer: Some("https://example.com/".to_owned()),
human_name: None,
brand_name: None,
scope: Scope::from_iter([OPENID]),
Expand Down Expand Up @@ -88,13 +88,13 @@ mod tests {
.await
.unwrap()
.expect("provider to be found in the database");
assert_eq!(provider.issuer, "https://example.com/");
assert_eq!(provider.issuer.as_deref(), Some("https://example.com/"));
assert_eq!(provider.client_id, "client-id");

// It should be in the list of all providers
let providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
assert_eq!(providers.len(), 1);
assert_eq!(providers[0].issuer, "https://example.com/");
assert_eq!(providers[0].issuer.as_deref(), Some("https://example.com/"));
assert_eq!(providers[0].client_id, "client-id");

// Start a session
Expand Down Expand Up @@ -277,7 +277,6 @@ mod tests {
/// provider repository
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_provider_repository_pagination(pool: PgPool) {
const ISSUER: &str = "https://example.com/";
let scope = Scope::from_iter([OPENID]);

let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
Expand All @@ -302,7 +301,7 @@ mod tests {
&mut rng,
&clock,
UpstreamOAuthProviderParams {
issuer: ISSUER.to_owned(),
issuer: None,
human_name: None,
brand_name: None,
scope: scope.clone(),
Expand Down
Loading

0 comments on commit a97d2da

Please sign in to comment.