Skip to content

Commit

Permalink
Make the issue optional on upstream OAuth 2.0 providers
Browse files Browse the repository at this point in the history
  • Loading branch information
sandhose committed Dec 17, 2024

Verified

This commit was signed with the committer’s verified signature.
sandhose Quentin Gliech
1 parent 80903ed commit 8122c25
Showing 26 changed files with 85 additions and 58 deletions.
4 changes: 3 additions & 1 deletion crates/cli/src/commands/manage.rs
Original file line number Diff line number Diff line change
@@ -764,8 +764,10 @@ impl std::fmt::Display for HumanReadable<&UpstreamOAuthProvider> {
let provider = self.0;
if let Some(human_name) = &provider.human_name {
write!(f, "{} ({})", human_name, provider.id)
} else if let Some(issuer) = &provider.issuer {
write!(f, "{} ({})", issuer, provider.id)
} else {
write!(f, "{} ({})", provider.issuer, provider.id)
write!(f, "{}", provider.id)
}
}
}
13 changes: 12 additions & 1 deletion crates/config/src/sections/upstream_oauth2.rs
Original file line number Diff line number Diff line change
@@ -47,6 +47,14 @@ impl ConfigurationSection for UpstreamOAuth2Config {
Err(error)
};

if !matches!(provider.discovery_mode, DiscoveryMode::Disabled)
&& provider.issuer.is_none()
{
return annotate(figment::Error::custom(
"The `issuer` field is required when discovery is enabled",
));
}

match provider.token_endpoint_auth_method {
TokenAuthMethod::None
| TokenAuthMethod::PrivateKeyJwt
@@ -438,7 +446,10 @@ pub struct Provider {
pub id: Ulid,

/// The OIDC issuer URL
pub issuer: String,
///
/// This is required if OIDC discovery is enabled (which is the default)
#[serde(skip_serializing_if = "Option::is_none")]
pub issuer: Option<String>,

/// A human-readable name for the provider, that will be shown to users
#[serde(skip_serializing_if = "Option::is_none")]
2 changes: 1 addition & 1 deletion crates/data-model/src/upstream_oauth2/provider.rs
Original file line number Diff line number Diff line change
@@ -219,7 +219,7 @@ pub struct InvalidUpstreamOAuth2TokenAuthMethod(String);
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct UpstreamOAuthProvider {
pub id: Ulid,
pub issuer: String,
pub issuer: Option<String>,
pub human_name: Option<String>,
pub brand_name: Option<String>,
pub discovery_mode: DiscoveryMode,
4 changes: 2 additions & 2 deletions crates/handlers/src/graphql/model/upstream_oauth.rs
Original file line number Diff line number Diff line change
@@ -37,8 +37,8 @@ impl UpstreamOAuth2Provider {
}

/// OpenID Connect issuer URL.
pub async fn issuer(&self) -> &str {
&self.provider.issuer
pub async fn issuer(&self) -> Option<&str> {
self.provider.issuer.as_deref()
}

/// Client ID used for this provider.
20 changes: 13 additions & 7 deletions crates/handlers/src/upstream_oauth2/cache.rs
Original file line number Diff line number Diff line change
@@ -61,10 +61,11 @@ impl<'a> LazyProviderInfos<'a> {
}
};

let metadata = self
.cache
.get(self.client, &self.provider.issuer, verify)
.await?;
let Some(issuer) = &self.provider.issuer else {
return Err(DiscoveryError::MissingIssuer);
};

let metadata = self.cache.get(self.client, issuer, verify).await?;

self.loaded_metadata = Some(metadata);
}
@@ -179,8 +180,13 @@ impl MetadataCache {
UpstreamOAuthProviderDiscoveryMode::Disabled => continue,
};

if let Err(e) = self.fetch(client, &provider.issuer, verify).await {
tracing::error!(issuer = %provider.issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata");
let Some(issuer) = &provider.issuer else {
tracing::error!(%provider.id, "Provider doesn't have an issuer set, but discovery is enabled!");
continue;
};

if let Err(e) = self.fetch(client, issuer, verify).await {
tracing::error!(%issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata");
}
}

@@ -395,7 +401,7 @@ mod tests {
let clock = MockClock::default();
let provider = UpstreamOAuthProvider {
id: Ulid::nil(),
issuer: mock_server.uri(),
issuer: Some(mock_server.uri()),
human_name: Some("Example Ltd.".to_owned()),
brand_name: None,
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Insecure,
4 changes: 2 additions & 2 deletions crates/handlers/src/upstream_oauth2/callback.rs
Original file line number Diff line number Diff line change
@@ -284,7 +284,7 @@ pub(crate) async fn handler(
);

let id_token_verification_data = JwtVerificationData {
issuer: &provider.issuer,
issuer: provider.issuer.as_deref(),
jwks: jwks.as_ref().unwrap(),
signing_algorithm: &provider.id_token_signed_response_alg,
client_id: &provider.client_id,
@@ -350,7 +350,7 @@ pub(crate) async fn handler(
lazy_metadata.userinfo_endpoint().await?,
token_response.access_token.as_str(),
Some(JwtVerificationData {
issuer: &provider.issuer,
issuer: provider.issuer.as_deref(),
jwks: &jwks,
signing_algorithm,
client_id: &provider.client_id,
2 changes: 1 addition & 1 deletion crates/handlers/src/upstream_oauth2/link.rs
Original file line number Diff line number Diff line change
@@ -916,7 +916,7 @@ mod tests {
&mut rng,
&state.clock,
UpstreamOAuthProviderParams {
issuer: "https://example.com/".to_owned(),
issuer: Some("https://example.com/".to_owned()),
human_name: Some("Example Ltd.".to_owned()),
brand_name: None,
scope: Scope::from_iter([OPENID]),
1 change: 0 additions & 1 deletion crates/handlers/src/upstream_oauth2/mod.rs
Original file line number Diff line number Diff line change
@@ -131,7 +131,6 @@ fn client_credentials_for_provider(

ClientCredentials::SignInWithApple {
client_id,
audience: provider.issuer.clone(),
key,
key_id: params.key_id,
team_id: params.team_id,
4 changes: 2 additions & 2 deletions crates/handlers/src/views/login.rs
Original file line number Diff line number Diff line change
@@ -398,7 +398,7 @@ mod test {
&mut rng,
&state.clock,
UpstreamOAuthProviderParams {
issuer: "https://first.com/".to_owned(),
issuer: Some("https://first.com/".to_owned()),
human_name: Some("First Ltd.".to_owned()),
brand_name: None,
scope: [OPENID].into_iter().collect(),
@@ -438,7 +438,7 @@ mod test {
&mut rng,
&state.clock,
UpstreamOAuthProviderParams {
issuer: "https://second.com/".to_owned(),
issuer: Some("https://second.com/".to_owned()),
human_name: None,
brand_name: None,
scope: [OPENID].into_iter().collect(),
5 changes: 5 additions & 0 deletions crates/oidc-client/src/error.rs
Original file line number Diff line number Diff line change
@@ -55,6 +55,11 @@ pub enum DiscoveryError {
/// An error occurred validating the metadata.
Validation(#[from] ProviderMetadataVerificationError),

/// The provider doesn't have an issuer set, which is required if discovery
/// is enabled.
#[error("Provider doesn't have an issuer set")]
MissingIssuer,

/// Discovery is disabled for this provider.
#[error("Discovery is disabled for this provider")]
Disabled,
10 changes: 6 additions & 4 deletions crates/oidc-client/src/requests/jose.rs
Original file line number Diff line number Diff line change
@@ -57,7 +57,7 @@ pub async fn fetch_jwks(
#[derive(Clone, Copy)]
pub struct JwtVerificationData<'a> {
/// The URL of the issuer that generated the ID Token.
pub issuer: &'a str,
pub issuer: Option<&'a str>,

/// The issuer's JWKS.
pub jwks: &'a PublicJsonWebKeySet,
@@ -76,7 +76,7 @@ pub struct JwtVerificationData<'a> {
///
/// * The signature is verified with the given JWKS.
///
/// * The `iss` claim must be present and match the issuer.
/// * The `iss` claim must be present and match the issuer, if present
///
/// * The `aud` claim must be present and match the client ID.
///
@@ -117,8 +117,10 @@ pub fn verify_signed_jwt<'a>(

let (header, mut claims) = jwt.clone().into_parts();

// Must have the proper issuer.
claims::ISS.extract_required_with_options(&mut claims, issuer)?;
if let Some(issuer) = issuer {
// Must have the proper issuer.
claims::ISS.extract_required_with_options(&mut claims, issuer)?;
}

// Must have the proper audience.
claims::AUD.extract_required_with_options(&mut claims, client_id)?;
6 changes: 1 addition & 5 deletions crates/oidc-client/src/types/client_credentials.rs
Original file line number Diff line number Diff line change
@@ -103,9 +103,6 @@ pub enum ClientCredentials {
/// The unique ID for the client.
client_id: String,

/// The audience to use. Usually `https://appleid.apple.com`
audience: String,

/// The ECDSA key used to sign
key: elliptic_curve::SecretKey<p256::NistP256>,

@@ -240,7 +237,6 @@ impl ClientCredentials {

ClientCredentials::SignInWithApple {
client_id,
audience,
key,
key_id,
team_id,
@@ -253,7 +249,7 @@ impl ClientCredentials {

claims::ISS.insert(&mut claims, team_id)?;
claims::SUB.insert(&mut claims, client_id)?;
claims::AUD.insert(&mut claims, audience.clone())?;
claims::AUD.insert(&mut claims, "https://appleid.apple.com".to_owned())?;
claims::IAT.insert(&mut claims, now)?;
claims::EXP.insert(&mut claims, now + Duration::microseconds(60 * 1000 * 1000))?;

6 changes: 3 additions & 3 deletions crates/oidc-client/tests/it/requests/authorization_code.rs
Original file line number Diff line number Diff line change
@@ -193,7 +193,7 @@ async fn pass_access_token_with_authorization_code() {

let (id_token, jwks) = id_token(issuer.as_str());
let id_token_verification_data = JwtVerificationData {
issuer: issuer.as_str(),
issuer: Some(issuer.as_str()),
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
@@ -251,7 +251,7 @@ async fn fail_access_token_with_authorization_code_wrong_nonce() {

let (id_token, jwks) = id_token(issuer.as_str());
let id_token_verification_data = JwtVerificationData {
issuer: issuer.as_str(),
issuer: Some(issuer.as_str()),
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
@@ -312,7 +312,7 @@ async fn fail_access_token_with_authorization_code_no_id_token() {
};

let id_token_verification_data = JwtVerificationData {
issuer: issuer.as_str(),
issuer: Some(issuer.as_str()),
jwks: &PublicJsonWebKeySet::default(),
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
14 changes: 7 additions & 7 deletions crates/oidc-client/tests/it/requests/jose.rs
Original file line number Diff line number Diff line change
@@ -88,7 +88,7 @@ async fn pass_verify_id_token() {
let (id_token, jwks) = id_token(issuer, None, Some(now));

let verification_data = JwtVerificationData {
issuer,
issuer: Some(issuer),
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
@@ -111,7 +111,7 @@ async fn fail_verify_id_token_wrong_issuer() {
let now = now();

let verification_data = JwtVerificationData {
issuer: wrong_issuer,
issuer: Some(wrong_issuer),
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
@@ -135,7 +135,7 @@ async fn fail_verify_id_token_wrong_audience() {
let now = now();

let verification_data = JwtVerificationData {
issuer,
issuer: Some(issuer),
jwks: &jwks,
client_id: &"wrong_client_id".to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
@@ -159,7 +159,7 @@ async fn fail_verify_id_token_wrong_signing_algorithm() {
let now = now();

let verification_data = JwtVerificationData {
issuer,
issuer: Some(issuer),
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &JsonWebSignatureAlg::Unknown("wrong_algorithm".to_owned()),
@@ -180,7 +180,7 @@ async fn fail_verify_id_token_wrong_expiration() {
let now = now();

let verification_data = JwtVerificationData {
issuer,
issuer: Some(issuer),
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
@@ -199,7 +199,7 @@ async fn fail_verify_id_token_wrong_subject() {
let (id_token, jwks) = id_token(issuer, Some(IdTokenFlag::WrongSubject), None);

let verification_data = JwtVerificationData {
issuer,
issuer: Some(issuer),
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
@@ -224,7 +224,7 @@ async fn fail_verify_id_token_wrong_auth_time() {
let (id_token, jwks) = id_token(issuer, None, Some(now + Duration::try_hours(1).unwrap()));

let verification_data = JwtVerificationData {
issuer,
issuer: Some(issuer),
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Copyright 2024 New Vector Ltd.
--
-- SPDX-License-Identifier: AGPL-3.0-only
-- Please see LICENSE in the repository root for full details.

-- Make the issuer field in the upstream_oauth_providers table optional
ALTER TABLE "upstream_oauth_providers"
ALTER COLUMN "issuer" DROP NOT NULL;
4 changes: 2 additions & 2 deletions crates/storage-pg/src/upstream_oauth2/link.rs
Original file line number Diff line number Diff line change
@@ -148,7 +148,7 @@ impl UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'_> {
db.query.text,
upstream_oauth_link.subject = subject,
%upstream_oauth_provider.id,
%upstream_oauth_provider.issuer,
upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
%upstream_oauth_provider.client_id,
),
err,
@@ -192,7 +192,7 @@ impl UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'_> {
upstream_oauth_link.subject = subject,
upstream_oauth_link.human_account_name = human_account_name,
%upstream_oauth_provider.id,
%upstream_oauth_provider.issuer,
upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
%upstream_oauth_provider.client_id,
),
err,
9 changes: 4 additions & 5 deletions crates/storage-pg/src/upstream_oauth2/mod.rs
Original file line number Diff line number Diff line change
@@ -56,7 +56,7 @@ mod tests {
&mut rng,
&clock,
UpstreamOAuthProviderParams {
issuer: "https://example.com/".to_owned(),
issuer: Some("https://example.com/".to_owned()),
human_name: None,
brand_name: None,
scope: Scope::from_iter([OPENID]),
@@ -88,13 +88,13 @@ mod tests {
.await
.unwrap()
.expect("provider to be found in the database");
assert_eq!(provider.issuer, "https://example.com/");
assert_eq!(provider.issuer.as_deref(), Some("https://example.com/"));
assert_eq!(provider.client_id, "client-id");

// It should be in the list of all providers
let providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
assert_eq!(providers.len(), 1);
assert_eq!(providers[0].issuer, "https://example.com/");
assert_eq!(providers[0].issuer.as_deref(), Some("https://example.com/"));
assert_eq!(providers[0].client_id, "client-id");

// Start a session
@@ -277,7 +277,6 @@ mod tests {
/// provider repository
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_provider_repository_pagination(pool: PgPool) {
const ISSUER: &str = "https://example.com/";
let scope = Scope::from_iter([OPENID]);

let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
@@ -302,7 +301,7 @@ mod tests {
&mut rng,
&clock,
UpstreamOAuthProviderParams {
issuer: ISSUER.to_owned(),
issuer: None,
human_name: None,
brand_name: None,
scope: scope.clone(),
Loading

0 comments on commit 8122c25

Please sign in to comment.