Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add jwt_signature_pk_urls to state #866

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions integration-tests/fastauth/src/env/containers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ use testcontainers::{
use tokio::io::AsyncWriteExt;
use tracing;

use std::collections::HashMap;
use std::fs;

use crate::env::{Context, LeaderNodeApi, SignerNodeApi};
Expand Down Expand Up @@ -522,6 +523,8 @@ impl SignerNode<'_> {
cipher_key: &GenericArray<u8, U32>,
) -> anyhow::Result<SignerNode<'a>> {
tracing::info!("Running signer node container {}...", node_id);
let mut jwt_signature_pk_urls = HashMap::new();
jwt_signature_pk_urls.insert(ctx.issuer.clone(), ctx.oidc_provider.jwt_pk_url.clone());
let args = mpc_recovery::Cli::StartSign {
env: ctx.env.clone(),
node_id: node_id as u64,
Expand All @@ -530,7 +533,7 @@ impl SignerNode<'_> {
cipher_key: Some(hex::encode(cipher_key)),
gcp_project_id: ctx.gcp_project_id.clone(),
gcp_datastore_url: Some(ctx.datastore.address.clone()),
jwt_signature_pk_url: ctx.oidc_provider.jwt_pk_url.clone(),
jwt_signature_pk_urls,
logging_options: logging::Options::default(),
}
.into_str_args();
Expand Down Expand Up @@ -636,6 +639,8 @@ impl<'a> LeaderNode<'a> {
pub async fn run(ctx: &Context<'a>, sign_nodes: Vec<String>) -> anyhow::Result<LeaderNode<'a>> {
tracing::info!("Running leader node container...");
let account_creator = &ctx.relayer_ctx.creator_account;
let mut jwt_signature_pk_urls = HashMap::new();
jwt_signature_pk_urls.insert(ctx.issuer.clone(), ctx.oidc_provider.jwt_pk_url.clone());
let args = mpc_recovery::Cli::StartLeader {
env: ctx.env.clone(),
web_port: Self::CONTAINER_PORT,
Expand Down Expand Up @@ -667,7 +672,7 @@ impl<'a> LeaderNode<'a> {
fast_auth_partners_filepath: None,
gcp_project_id: ctx.gcp_project_id.clone(),
gcp_datastore_url: Some(ctx.datastore.address.to_string()),
jwt_signature_pk_url: ctx.oidc_provider.jwt_pk_url.to_string(),
jwt_signature_pk_urls,
logging_options: logging::Options::default(),
}
.into_str_args();
Expand Down
10 changes: 8 additions & 2 deletions integration-tests/fastauth/src/env/local.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::collections::HashMap;

use crate::env::{LeaderNodeApi, SignerNodeApi};
use crate::mpc::{self, NodeProcess};
use crate::util;
Expand Down Expand Up @@ -29,6 +31,8 @@ impl SignerNode {
cipher_key: &GenericArray<u8, U32>,
) -> anyhow::Result<Self> {
let web_port = util::pick_unused_port().await?;
let mut jwt_signature_pk_urls = HashMap::new();
jwt_signature_pk_urls.insert(ctx.issuer.clone(), ctx.oidc_provider.jwt_pk_url.clone());
let cli = mpc_recovery::Cli::StartSign {
env: ctx.env.clone(),
node_id,
Expand All @@ -37,7 +41,7 @@ impl SignerNode {
cipher_key: Some(hex::encode(cipher_key)),
gcp_project_id: ctx.gcp_project_id.clone(),
gcp_datastore_url: Some(ctx.datastore.local_address.clone()),
jwt_signature_pk_url: ctx.oidc_provider.jwt_pk_local_url.clone(),
jwt_signature_pk_urls,
logging_options: logging::Options::default(),
};

Expand Down Expand Up @@ -87,6 +91,8 @@ impl LeaderNode {
tracing::info!("Running leader node...");
let account_creator = &ctx.relayer_ctx.creator_account;
let web_port = util::pick_unused_port().await?;
let mut jwt_signature_pk_urls = HashMap::new();
jwt_signature_pk_urls.insert(ctx.issuer.clone(), ctx.oidc_provider.jwt_pk_url.clone());
let cli = mpc_recovery::Cli::StartLeader {
env: ctx.env.clone(),
web_port,
Expand Down Expand Up @@ -118,7 +124,7 @@ impl LeaderNode {
),
gcp_project_id: ctx.gcp_project_id.clone(),
gcp_datastore_url: Some(ctx.datastore.local_address.clone()),
jwt_signature_pk_url: ctx.oidc_provider.jwt_pk_local_url.clone(),
jwt_signature_pk_urls,
logging_options: logging::Options::default(),
};

Expand Down
15 changes: 8 additions & 7 deletions mpc-recovery/src/leader_node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use near_primitives::delegate_action::{DelegateAction, NonDelegateAction};
use near_primitives::transaction::{Action, DeleteAccountAction, DeleteKeyAction};
use near_primitives::types::AccountId;
use prometheus::{Encoder, TextEncoder};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Instant;
Expand All @@ -47,7 +48,7 @@ pub struct Config {
// TODO: temporary solution
pub account_creator_signer: KeyRotatingSigner,
pub partners: PartnerList,
pub jwt_signature_pk_url: String,
pub jwt_signature_pk_urls: HashMap<String, String>,
}

pub async fn run(config: Config) {
Expand All @@ -59,7 +60,7 @@ pub async fn run(config: Config) {
near_root_account,
account_creator_signer,
partners,
jwt_signature_pk_url,
jwt_signature_pk_urls
} = config;
let _span = tracing::debug_span!("run", env, port);
tracing::debug!(?sign_nodes, "running a leader node");
Expand All @@ -74,7 +75,7 @@ pub async fn run(config: Config) {
near_root_account: near_root_account.parse().unwrap(),
account_creator_signer,
partners,
jwt_signature_pk_url,
jwt_signature_pk_urls
});

// Get keys from all sign nodes, and broadcast them out as a set.
Expand Down Expand Up @@ -198,7 +199,7 @@ struct LeaderState {
// TODO: temporary solution
account_creator_signer: KeyRotatingSigner,
partners: PartnerList,
jwt_signature_pk_url: String,
jwt_signature_pk_urls: HashMap<String, String>,
}

async fn mpc_public_key(
Expand Down Expand Up @@ -302,7 +303,7 @@ async fn process_user_credentials(
&request.oidc_token,
Some(&state.partners.oidc_providers()),
&state.reqwest_client,
&state.jwt_signature_pk_url,
&state.jwt_signature_pk_urls,
)
.await
.map_err(LeaderNodeError::OidcVerificationFailed)?;
Expand Down Expand Up @@ -334,7 +335,7 @@ async fn process_new_account(
&request.oidc_token,
Some(&state.partners.oidc_providers()),
&state.reqwest_client,
&state.jwt_signature_pk_url,
&state.jwt_signature_pk_urls,
)
.await
.map_err(LeaderNodeError::OidcVerificationFailed)?;
Expand Down Expand Up @@ -477,7 +478,7 @@ async fn process_sign(
&request.oidc_token,
Some(&state.partners.oidc_providers()),
&state.reqwest_client,
&state.jwt_signature_pk_url,
&state.jwt_signature_pk_urls,
)
.await
.map_err(LeaderNodeError::OidcVerificationFailed)?;
Expand Down
39 changes: 23 additions & 16 deletions mpc-recovery/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#![allow(clippy::result_large_err)]

use std::path::PathBuf;
use std::collections::HashMap;

use aes_gcm::aead::consts::U32;
use aes_gcm::aead::generic_array::GenericArray;
Expand Down Expand Up @@ -113,9 +114,9 @@ pub enum Cli {
/// GCP datastore URL
#[arg(long, env("MPC_RECOVERY_GCP_DATASTORE_URL"))]
gcp_datastore_url: Option<String>,
/// URL to the public key used to sign JWT tokens
#[arg(long, env("MPC_RECOVERY_JWT_SIGNATURE_PK_URL"))]
jwt_signature_pk_url: String,
/// URLs of the public keys used by all issuers
#[arg(long, value_parser = parse_json_str::<HashMap<String, String>>, env("MPC_RECOVERY_JWT_SIGNATURE_PK_URLS"))]
jwt_signature_pk_urls: HashMap<String, String>,
/// Enables export of span data using opentelemetry protocol.
#[clap(flatten)]
logging_options: logging::Options,
Expand All @@ -142,9 +143,9 @@ pub enum Cli {
/// GCP datastore URL
#[arg(long, env("MPC_RECOVERY_GCP_DATASTORE_URL"))]
gcp_datastore_url: Option<String>,
/// URL to the public key used to sign JWT tokens
#[arg(long, env("MPC_RECOVERY_JWT_SIGNATURE_PK_URL"))]
jwt_signature_pk_url: String,
/// URLs of the public keys used by all issuers
#[arg(long, value_parser = parse_json_str::<HashMap<String, String>>, env("MPC_RECOVERY_JWT_SIGNATURE_PK_URLS"))]
jwt_signature_pk_urls: HashMap<String, String>,
/// Enables export of span data using opentelemetry protocol.
#[clap(flatten)]
logging_options: logging::Options,
Expand Down Expand Up @@ -203,7 +204,7 @@ pub async fn run(cmd: Cli) -> anyhow::Result<()> {
fast_auth_partners_filepath: partners_filepath,
gcp_project_id,
gcp_datastore_url,
jwt_signature_pk_url,
jwt_signature_pk_urls,
logging_options,
} => {
let _subscriber_guard = logging::subscribe_global(
Expand Down Expand Up @@ -231,7 +232,7 @@ pub async fn run(cmd: Cli) -> anyhow::Result<()> {
near_root_account,
account_creator_signer,
partners,
jwt_signature_pk_url,
jwt_signature_pk_urls
};

run_leader_node(config).await;
Expand All @@ -244,7 +245,7 @@ pub async fn run(cmd: Cli) -> anyhow::Result<()> {
web_port,
gcp_project_id,
gcp_datastore_url,
jwt_signature_pk_url,
jwt_signature_pk_urls,
logging_options,
} => {
let _subscriber_guard = logging::subscribe_global(
Expand Down Expand Up @@ -272,7 +273,7 @@ pub async fn run(cmd: Cli) -> anyhow::Result<()> {
node_key: sk_share,
cipher,
port: web_port,
jwt_signature_pk_url,
jwt_signature_pk_urls
};
run_sign_node(config).await;
}
Expand Down Expand Up @@ -428,7 +429,7 @@ impl Cli {
fast_auth_partners_filepath,
gcp_project_id,
gcp_datastore_url,
jwt_signature_pk_url,
jwt_signature_pk_urls,
logging_options,
} => {
let mut buf = vec![
Expand All @@ -445,8 +446,6 @@ impl Cli {
account_creator_id.to_string(),
"--gcp-project-id".to_string(),
gcp_project_id,
"--jwt-signature-pk-url".to_string(),
jwt_signature_pk_url,
];

if let Some(partners) = fast_auth_partners {
Expand All @@ -465,6 +464,11 @@ impl Cli {
buf.push("--sign-nodes".to_string());
buf.push(sign_node);
}

let jwt_signature_pk_urls = serde_json::to_string(&jwt_signature_pk_urls).unwrap();
buf.push("--jwt-signature-pk-urls".to_string());
buf.push(jwt_signature_pk_urls);

let account_creator_sk = serde_json::to_string(&account_creator_sk).unwrap();
buf.push("--account-creator-sk".to_string());
buf.push(account_creator_sk);
Expand All @@ -480,7 +484,7 @@ impl Cli {
sk_share,
gcp_project_id,
gcp_datastore_url,
jwt_signature_pk_url,
jwt_signature_pk_urls,
logging_options,
} => {
let mut buf = vec![
Expand All @@ -493,8 +497,6 @@ impl Cli {
web_port.to_string(),
"--gcp-project-id".to_string(),
gcp_project_id,
"--jwt-signature-pk-url".to_string(),
jwt_signature_pk_url,
];
if let Some(key) = cipher_key {
buf.push("--cipher-key".to_string());
Expand All @@ -508,6 +510,11 @@ impl Cli {
buf.push("--gcp-datastore-url".to_string());
buf.push(gcp_datastore_url);
}

let jwt_signature_pk_urls = serde_json::to_string(&jwt_signature_pk_urls).unwrap();
buf.push("--jwt-signature-pk-urls".to_string());
buf.push(jwt_signature_pk_urls);

buf.extend(logging_options.into_str_args());

buf
Expand Down
72 changes: 59 additions & 13 deletions mpc-recovery/src/oauth.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use anyhow::{Context, Result};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use jsonwebtoken::{Algorithm, DecodingKey};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;

use crate::firewall::allowed::OidcProviderList;
Expand All @@ -13,15 +16,21 @@ pub async fn verify_oidc_token(
token: &OidcToken,
oidc_providers: Option<&OidcProviderList>,
client: &reqwest::Client,
jwt_signature_pk_url: &str,
jwt_signature_pk_urls: &HashMap<String, String>,
) -> anyhow::Result<IdTokenClaims> {
let public_keys = get_pagoda_firebase_public_keys(client, jwt_signature_pk_url)
let (_, claims, _) = token.decode_unverified()?;
let issuer = &claims.iss;

let jwks_url = jwt_signature_pk_urls.get(issuer)
.ok_or_else(|| anyhow::anyhow!("No JWKS URL found for issuer: {}", issuer))?;

let public_keys = get_public_keys(client, jwks_url)
.await
.map_err(|e| anyhow::anyhow!("failed to get Firebase public key: {e}"))?;
tracing::info!("verify_oidc_token firebase public keys: {public_keys:?}");
.map_err(|e| anyhow::anyhow!("failed to get public keys: {e}"))?;
tracing::info!("verify_oidc_token public keys: {public_keys:?}");

let mut last_occured_error =
anyhow::anyhow!("Unexpected error. Firebase public keys not found");
anyhow::anyhow!("Unexpected error. Public keys not found");
for public_key in public_keys {
match validate_jwt(token, public_key.as_bytes(), oidc_providers) {
Ok(claims) => {
Expand Down Expand Up @@ -99,13 +108,49 @@ impl IdTokenClaims {
}
}

pub async fn get_pagoda_firebase_public_keys(
client: &reqwest::Client,
jwt_signature_pk_url: &str,
) -> anyhow::Result<Vec<String>> {
let response = client.get(jwt_signature_pk_url).send().await?;
let json: HashMap<String, String> = response.json().await?;
Ok(json.into_values().collect())
pub async fn get_public_keys(client: &reqwest::Client, jwks_url: &str) -> Result<Vec<String>> {
let response = client
.get(jwks_url)
.send()
.await
.context("Failed to send request")?;

let json: Value = response.json().await.context("Failed to parse JSON")?;

match json {
Value::Object(obj) if obj.contains_key("keys") => parse_jwks_format(&obj),
Value::Object(obj) => parse_firebase_format(&obj),
_ => {
tracing::warn!("Unexpected response format from {}", jwks_url);
Ok(vec![])
}
}
}

fn parse_jwks_format(obj: &serde_json::Map<String, Value>) -> Result<Vec<String>> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's test this

obj["keys"]
.as_array()
.context("'keys' is not an array")?
.iter()
.filter_map(|key| match (key["n"].as_str(), key["e"].as_str()) {
(Some(n), Some(e)) => Some(format_rsa_key(n, e)),
_ => None,
})
.collect::<Result<Vec<_>>>()
}

fn parse_firebase_format(obj: &serde_json::Map<String, Value>) -> Result<Vec<String>> {
Ok(obj
.values()
.filter_map(|value| value.as_str().map(String::from))
.collect())
}

fn format_rsa_key(n: &str, e: &str) -> Result<String> {
Ok(format!(
"-----BEGIN PUBLIC KEY-----\n{}\n-----END PUBLIC KEY-----",
BASE64.encode(format!("{}:{}", n, e))
))
}

#[cfg(test)]
Expand All @@ -124,7 +169,8 @@ mod tests {
let url =
"https://www.googleapis.com/robot/v1/metadata/x509/securetoken@system.gserviceaccount.com";
let client = reqwest::Client::new();
let pk = get_pagoda_firebase_public_keys(&client, url).await.unwrap();
let pk = get_public_keys(&client, url).await.unwrap();

assert!(!pk.is_empty());
}

Expand Down
Loading
Loading