diff --git a/integration-tests/tests/docker/mod.rs b/integration-tests/tests/docker/mod.rs index 95cefbe52..feaddc50e 100644 --- a/integration-tests/tests/docker/mod.rs +++ b/integration-tests/tests/docker/mod.rs @@ -161,6 +161,7 @@ impl LeaderNode { near_root_account: &AccountId, account_creator_id: &AccountId, account_creator_sk: &SecretKey, + pagoda_firebase_audience_id: &str, ) -> anyhow::Result { create_network(docker, network).await?; let web_port = portpicker::pick_unused_port().expect("no free ports"); @@ -185,6 +186,8 @@ impl LeaderNode { account_creator_id.to_string(), "--account-creator-sk".to_string(), account_creator_sk.to_string(), + "--pagoda-firebase-audience-id".to_string(), + pagoda_firebase_audience_id.to_string(), ]; for sign_node in sign_nodes { cmd.push("--sign-nodes".to_string()); diff --git a/integration-tests/tests/lib.rs b/integration-tests/tests/lib.rs index 0041270b5..5f25f92e9 100644 --- a/integration-tests/tests/lib.rs +++ b/integration-tests/tests/lib.rs @@ -76,6 +76,9 @@ where let addr = SignNode::start(&docker, NETWORK, i as u64, &pk_set, &sk_shares[i - 1]).await?; sign_nodes.push(addr); } + + let pagoda_firebase_audience_id = "not actually used in integration tests"; + let leader_node = LeaderNode::start( &docker, NETWORK, @@ -88,6 +91,7 @@ where near_root_account.id(), &creator_account_id, &creator_account_sk, + pagoda_firebase_audience_id, ) .await?; diff --git a/mpc-recovery/src/leader_node/mod.rs b/mpc-recovery/src/leader_node/mod.rs index 5250fe2e8..d7172a995 100644 --- a/mpc-recovery/src/leader_node/mod.rs +++ b/mpc-recovery/src/leader_node/mod.rs @@ -39,6 +39,7 @@ pub struct Config { // TODO: temporary solution pub account_creator_sk: SecretKey, pub account_lookup_url: String, + pub pagoda_firebase_audience_id: String, } pub async fn run(config: Config) { @@ -54,6 +55,7 @@ pub async fn run(config: Config) { account_creator_id, account_creator_sk, account_lookup_url, + pagoda_firebase_audience_id, } = config; let _span = tracing::debug_span!("run", id, port); tracing::debug!(?sign_nodes, "running a leader node"); @@ -90,6 +92,7 @@ pub async fn run(config: Config) { account_creator_id, account_creator_sk, account_lookup_url, + pagoda_firebase_audience_id, }; //TODO: not secure, allow only for testnet, whitelist endpoint etc. for mainnet @@ -122,6 +125,7 @@ struct LeaderState { // TODO: temporary solution account_creator_sk: SecretKey, account_lookup_url: String, + pagoda_firebase_audience_id: String, } async fn parse(response_future: ResponseFuture) -> anyhow::Result { @@ -156,9 +160,10 @@ async fn process_new_account( .near_account_id .parse() .map_err(|e| NewAccountError::MalformedAccountId(request.near_account_id, e))?; - let oidc_token_claims = T::verify_token(&request.oidc_token) - .await - .map_err(NewAccountError::OidcVerificationFailed)?; + let oidc_token_claims = + T::verify_token(&request.oidc_token, &state.pagoda_firebase_audience_id) + .await + .map_err(NewAccountError::OidcVerificationFailed)?; let internal_acc_id = get_internal_account_id(oidc_token_claims); state @@ -333,9 +338,10 @@ async fn process_add_key( state: LeaderState, request: AddKeyRequest, ) -> Result { - let oidc_token_claims = T::verify_token(&request.oidc_token) - .await - .map_err(AddKeyError::OidcVerificationFailed)?; + let oidc_token_claims = + T::verify_token(&request.oidc_token, &state.pagoda_firebase_audience_id) + .await + .map_err(AddKeyError::OidcVerificationFailed)?; let internal_acc_id = get_internal_account_id(oidc_token_claims); let user_recovery_pk = get_user_recovery_pk(internal_acc_id.clone()); let user_recovery_sk = get_user_recovery_sk(internal_acc_id); @@ -451,7 +457,7 @@ async fn submit( // TODO: extract access token from payload let access_token = "validToken"; - match T::verify_token(access_token).await { + match T::verify_token(access_token, &state.pagoda_firebase_audience_id).await { Ok(_) => { tracing::info!("access token is valid"); // continue execution diff --git a/mpc-recovery/src/main.rs b/mpc-recovery/src/main.rs index fae9d3a42..b98f7c848 100644 --- a/mpc-recovery/src/main.rs +++ b/mpc-recovery/src/main.rs @@ -57,6 +57,8 @@ enum Cli { default_value("https://api.kitwallet.app") )] account_lookup_url: String, + #[arg(long, env("PAGODA_FIREBASE_AUDIENCE_ID"))] + pagoda_firebase_audience_id: String, }, StartSign { /// Node ID @@ -136,6 +138,7 @@ async fn main() -> anyhow::Result<()> { account_creator_id, account_creator_sk, account_lookup_url, + pagoda_firebase_audience_id, } => { let gcp_service = GcpService::new().await?; let sk_share = load_sh_skare(&gcp_service, node_id, sk_share).await?; @@ -159,6 +162,7 @@ async fn main() -> anyhow::Result<()> { account_creator_id, account_creator_sk, account_lookup_url, + pagoda_firebase_audience_id, }) .await; } diff --git a/mpc-recovery/src/oauth.rs b/mpc-recovery/src/oauth.rs index 6855c9972..336e9f522 100644 --- a/mpc-recovery/src/oauth.rs +++ b/mpc-recovery/src/oauth.rs @@ -6,7 +6,7 @@ use std::collections::HashMap; #[async_trait::async_trait] pub trait OAuthTokenVerifier { - async fn verify_token(token: &str) -> anyhow::Result; + async fn verify_token(token: &str, audience: &str) -> anyhow::Result; /// This function validates JWT (OIDC ID token) by checking the signature received /// from the issuer, issuer, audience, and expiration time. @@ -38,13 +38,13 @@ pub struct UniversalTokenVerifier {} #[async_trait::async_trait] impl OAuthTokenVerifier for UniversalTokenVerifier { - async fn verify_token(token: &str) -> anyhow::Result { + async fn verify_token(token: &str, audience: &str) -> anyhow::Result { match get_token_verifier_type(token) { SupportedTokenVerifiers::PagodaFirebaseTokenVerifier => { - return PagodaFirebaseTokenVerifier::verify_token(token).await; + return PagodaFirebaseTokenVerifier::verify_token(token, audience).await; } SupportedTokenVerifiers::TestTokenVerifier => { - return TestTokenVerifier::verify_token(token).await; + return TestTokenVerifier::verify_token(token, audience).await; } } } @@ -72,21 +72,17 @@ impl OAuthTokenVerifier for PagodaFirebaseTokenVerifier { // Specs for ID token verification: // Google: https://developers.google.com/identity/openid-connect/openid-connect#validatinganidtoken // Firebase: https://firebase.google.com/docs/auth/admin/verify-id-tokens#verify_id_tokens_using_a_third-party_jwt_library - async fn verify_token(token: &str) -> anyhow::Result { + async fn verify_token(token: &str, audience: &str) -> anyhow::Result { let public_key = get_pagoda_firebase_public_key().expect("Failed to get Google public key"); - // this is a tmp Project ID, the real one is: pagoda-onboarding-dev - let pagoda_firebase_audience_id: String = "pagoda-fast-auth-441fe".to_string(); - let pagoda_firebase_issuer_id: String = format!( - "https://securetoken.google.com/{}", - pagoda_firebase_audience_id - ); + let pagoda_firebase_issuer_id: String = + format!("https://securetoken.google.com/{}", audience); let claims = Self::validate_jwt( token, public_key.as_bytes(), &pagoda_firebase_issuer_id, - &pagoda_firebase_audience_id, + audience, ) .expect("Failed to validate JWT"); @@ -99,7 +95,7 @@ pub struct TestTokenVerifier {} #[async_trait::async_trait] impl OAuthTokenVerifier for TestTokenVerifier { - async fn verify_token(token: &str) -> anyhow::Result { + async fn verify_token(token: &str, _audience: &str) -> anyhow::Result { match token { "validToken" => { tracing::info!(target: "test-token-verifier", "access token is valid"); @@ -242,15 +238,17 @@ mod tests { #[tokio::test] async fn test_verify_token_valid() { let token = "validToken"; - let claims = TestTokenVerifier::verify_token(token).await.unwrap(); let test_claims = get_test_claims(); + let claims = TestTokenVerifier::verify_token(token, &test_claims.aud) + .await + .unwrap(); assert!(compare_claims(claims, test_claims)); } #[tokio::test] async fn test_verify_token_invalid_with_test_verifier() { let token = "invalid"; - let result = TestTokenVerifier::verify_token(token).await; + let result = TestTokenVerifier::verify_token(token, "rand").await; match result { Ok(_) => panic!("Token verification should fail"), Err(e) => assert_eq!(e.to_string(), "Invalid token"), @@ -260,15 +258,17 @@ mod tests { #[tokio::test] async fn test_verify_token_valid_with_test_verifier() { let token = "validToken"; - let claims = TestTokenVerifier::verify_token(token).await.unwrap(); let test_claims = get_test_claims(); + let claims = TestTokenVerifier::verify_token(token, &test_claims.aud) + .await + .unwrap(); assert!(compare_claims(claims, test_claims)); } #[tokio::test] async fn test_verify_token_invalid_with_universal_verifier() { let token = "invalid"; - let result = UniversalTokenVerifier::verify_token(token).await; + let result = UniversalTokenVerifier::verify_token(token, "rand").await; match result { Ok(_) => panic!("Token verification should fail"), Err(e) => assert_eq!(e.to_string(), "Invalid token"), @@ -278,8 +278,10 @@ mod tests { #[tokio::test] async fn test_verify_token_valid_with_universal_verifier() { let token = "validToken"; - let claims = UniversalTokenVerifier::verify_token(token).await.unwrap(); let test_claims = get_test_claims(); + let claims = UniversalTokenVerifier::verify_token(token, &test_claims.aud) + .await + .unwrap(); assert!(compare_claims(claims, test_claims)); } diff --git a/mpc-recovery/src/sign_node/mod.rs b/mpc-recovery/src/sign_node/mod.rs index 93eecfd7f..c3bbd35ba 100644 --- a/mpc-recovery/src/sign_node/mod.rs +++ b/mpc-recovery/src/sign_node/mod.rs @@ -14,7 +14,13 @@ pub async fn run(id: NodeId, pk_set: PublicKeySet, sk_share: SecretKeyShare, por return; } - let state = SignNodeState { id, sk_share }; + let pagoda_firebase_audience_id = "pagoda-firebase-audience-id".to_string(); + + let state = SignNodeState { + id, + sk_share, + pagoda_firebase_audience_id, + }; let app = Router::new() .route("/sign", post(sign::)) @@ -32,6 +38,7 @@ pub async fn run(id: NodeId, pk_set: PublicKeySet, sk_share: SecretKeyShare, por struct SignNodeState { id: NodeId, sk_share: SecretKeyShare, + pagoda_firebase_audience_id: String, } #[tracing::instrument(level = "debug", skip_all, fields(id = state.id))] @@ -43,7 +50,7 @@ async fn sign( // TODO: extract access token from payload let access_token = "validToken"; - match T::verify_token(access_token).await { + match T::verify_token(access_token, &state.pagoda_firebase_audience_id).await { Ok(_) => { tracing::debug!("access token is valid"); let response = SigShareResponse::Ok {