Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OAuth mock #24

Merged
merged 13 commits into from
Apr 7, 2023
414 changes: 352 additions & 62 deletions Cargo.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion mpc-recovery/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,5 @@ serde = "1"
serde_json = "1"
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
oauth2 = "4.3.0"
51 changes: 34 additions & 17 deletions mpc-recovery/src/leader_node/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::msg::{LeaderRequest, LeaderResponse, SigShareRequest, SigShareResponse};
use crate::oauth::{OAuthTokenVerifier, UniversalTokenVerifier};
use crate::NodeId;
use axum::{extract::State, http::StatusCode, routing::post, Json, Router};
use futures::stream::FuturesUnordered;
Expand Down Expand Up @@ -32,7 +33,7 @@ pub async fn run(
};

let app = Router::new()
.route("/submit", post(submit))
.route("/submit", post(submit::<UniversalTokenVerifier>))
.with_state(state);

let addr = SocketAddr::from(([0, 0, 0, 0], port));
Expand All @@ -58,14 +59,24 @@ async fn parse(response_future: ResponseFuture) -> anyhow::Result<SigShareRespon
}

#[tracing::instrument(level = "debug", skip_all, fields(id = state.id))]
async fn submit(
async fn submit<T: OAuthTokenVerifier>(
State(state): State<LeaderState>,
Json(request): Json<LeaderRequest>,
) -> (StatusCode, Json<LeaderResponse>) {
tracing::info!(payload = request.payload, "submit request");

// TODO: run some check that the payload makes sense, fail if not
tracing::debug!("approved");
// TODO: extract access token from payload
let access_token = "validToken";
match T::verify_token(access_token).await {
Some(_) => {
tracing::info!("access token is valid");
// continue execution
}
None => {
tracing::error!("access token verification failed");
return (StatusCode::UNAUTHORIZED, Json(LeaderResponse::Err));
}
}

let sig_share_request = SigShareRequest {
payload: request.payload.clone(),
Expand Down Expand Up @@ -100,42 +111,48 @@ async fn submit(
let mut sig_shares = BTreeMap::new();
sig_shares.insert(state.id, state.sk_share.sign(&request.payload));
for response_future in response_futures {
let response = match parse(response_future).await {
Ok(response) => response,
let (node_id, sig_share) = match parse(response_future).await {
Ok(response) => match response {
SigShareResponse::Ok { node_id, sig_share } => (node_id, sig_share),
SigShareResponse::Err => {
tracing::error!("Received an error response");
continue;
}
},
Err(err) => {
tracing::error!(%err, "failed to get response");
tracing::error!(%err, "Failed to get response");
continue;
}
};

if state
.pk_set
.public_key_share(response.node_id)
.verify(&response.sig_share, &request.payload)
.public_key_share(node_id)
.verify(&sig_share, &request.payload)
{
match sig_shares.entry(response.node_id) {
match sig_shares.entry(node_id) {
Entry::Vacant(e) => {
tracing::debug!(?response, "received valid signature share");
e.insert(response.sig_share);
tracing::debug!(?sig_share, "received valid signature share");
e.insert(sig_share);
}
Entry::Occupied(e) if e.get() == &response.sig_share => {
Entry::Occupied(e) if e.get() == &sig_share => {
tracing::error!(
node_id = response.node_id,
node_id,
sig_share = ?e.get(),
"received a duplicate share"
);
}
Entry::Occupied(e) => {
tracing::error!(
node_id = response.node_id,
node_id = node_id,
sig_share_1 = ?e.get(),
sig_share_2 = ?response.sig_share,
sig_share_2 = ?sig_share,
"received two different valid shares for the same node (should be impossible)"
);
}
}
} else {
tracing::error!(?response, "received invalid signature",);
tracing::error!("received invalid signature",);
}

if sig_shares.len() > state.pk_set.threshold() {
Expand Down
1 change: 1 addition & 0 deletions mpc-recovery/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use threshold_crypto::{PublicKeySet, SecretKeySet, SecretKeyShare};

mod leader_node;
pub mod msg;
mod oauth;
mod sign_node;

type NodeId = u64;
Expand Down
2 changes: 2 additions & 0 deletions mpc-recovery/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
mod oauth;

use clap::Parser;
use threshold_crypto::{serde_impl::SerdeSecret, PublicKeySet, SecretKeyShare};

Expand Down
10 changes: 7 additions & 3 deletions mpc-recovery/src/msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,13 @@ pub struct SigShareRequest {
}

#[derive(Serialize, Deserialize, Debug)]
pub struct SigShareResponse {
pub node_id: NodeId,
pub sig_share: SignatureShare,
#[allow(clippy::large_enum_variant)]
pub enum SigShareResponse {
Ok {
node_id: NodeId,
sig_share: SignatureShare,
},
Err,
}

mod hex_sig_share {
Expand Down
115 changes: 115 additions & 0 deletions mpc-recovery/src/oauth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#[async_trait::async_trait]
pub trait OAuthTokenVerifier {
async fn verify_token(token: &str) -> Option<&str>;
}

pub enum SupportedTokenVerifiers {
GoogleTokenVerifier,
TestTokenVerifier,
}

/* Universal token verifier */
pub struct UniversalTokenVerifier {}

#[async_trait::async_trait]
impl OAuthTokenVerifier for UniversalTokenVerifier {
async fn verify_token(token: &str) -> Option<&str> {
// TODO: here we assume that verifier type can be determined from the token
match get_token_verifier_type(token) {
SupportedTokenVerifiers::GoogleTokenVerifier => {
return GoogleTokenVerifier::verify_token(token).await;
}
SupportedTokenVerifiers::TestTokenVerifier => {
return TestTokenVerifier::verify_token(token).await;
}
}
}
}

fn get_token_verifier_type(token: &str) -> SupportedTokenVerifiers {
match token.len() {
// TODO: add real token type detection
0 => {
tracing::info!("Using GoogleTokenVerifier");
SupportedTokenVerifiers::GoogleTokenVerifier
}
_ => {
tracing::info!("Using TestTokenVerifier");
SupportedTokenVerifiers::TestTokenVerifier
}
}
}

/* Google verifier */
pub struct GoogleTokenVerifier {}

#[async_trait::async_trait]
impl OAuthTokenVerifier for GoogleTokenVerifier {
// TODO: replace with real implementation
async fn verify_token(token: &str) -> Option<&str> {
match token {
"validToken" => {
tracing::info!("GoogleTokenVerifier: access token is valid");
Some("testAccountId")
}
_ => {
tracing::info!("GoogleTokenVerifier: access token verification failed");
None
}
}
}
}

/* Test verifier */
pub struct TestTokenVerifier {}

#[async_trait::async_trait]
impl OAuthTokenVerifier for TestTokenVerifier {
async fn verify_token(token: &str) -> Option<&str> {
match token {
"validToken" => {
tracing::info!("TestTokenVerifier: access token is valid");
Some("testAccountId")
}
_ => {
tracing::info!("TestTokenVerifier: access token verification failed");
None
}
}
}
}

#[tokio::test]
async fn test_verify_token_valid() {
let token = "validToken";
let account_id = TestTokenVerifier::verify_token(token).await.unwrap();
assert_eq!(account_id, "testAccountId");
}

#[tokio::test]
async fn test_verify_token_invalid_with_test_verifier() {
let token = "invalid";
let account_id = TestTokenVerifier::verify_token(token).await;
assert_eq!(account_id, None);
}

#[tokio::test]
async fn test_verify_token_valid_with_test_verifier() {
let token = "validToken";
let account_id = TestTokenVerifier::verify_token(token).await.unwrap();
assert_eq!(account_id, "testAccountId");
}

#[tokio::test]
async fn test_verify_token_invalid_with_universal_verifier() {
let token = "invalid";
let account_id = UniversalTokenVerifier::verify_token(token).await;
assert_eq!(account_id, None);
}

#[tokio::test]
async fn test_verify_token_valid_with_universal_verifier() {
let token = "validToken";
let account_id = UniversalTokenVerifier::verify_token(token).await.unwrap();
assert_eq!(account_id, "testAccountId");
}
31 changes: 21 additions & 10 deletions mpc-recovery/src/sign_node/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::msg::{SigShareRequest, SigShareResponse};
use crate::oauth::{OAuthTokenVerifier, UniversalTokenVerifier};
use crate::NodeId;
use axum::{extract::State, http::StatusCode, routing::post, Json, Router};
use std::net::SocketAddr;
Expand All @@ -15,7 +16,9 @@ pub async fn run(id: NodeId, pk_set: PublicKeySet, sk_share: SecretKeyShare, por

let state = SignNodeState { id, sk_share };

let app = Router::new().route("/sign", post(sign)).with_state(state);
let app = Router::new()
.route("/sign", post(sign::<UniversalTokenVerifier>))
.with_state(state);

let addr = SocketAddr::from(([0, 0, 0, 0], port));
tracing::debug!(?addr, "starting http server");
Expand All @@ -32,18 +35,26 @@ struct SignNodeState {
}

#[tracing::instrument(level = "debug", skip_all, fields(id = state.id))]
async fn sign(
async fn sign<T: OAuthTokenVerifier>(
State(state): State<SignNodeState>,
Json(request): Json<SigShareRequest>,
) -> (StatusCode, Json<SigShareResponse>) {
tracing::info!(payload = request.payload, "sign request");

// TODO: run some check that the payload makes sense, fail if not
tracing::debug!("approved");

let response = SigShareResponse {
node_id: state.id,
sig_share: state.sk_share.sign(request.payload),
};
(StatusCode::OK, Json(response))
// TODO: extract access token from payload
let access_token = "validToken";
match T::verify_token(access_token).await {
Some(_) => {
tracing::debug!("access token is valid");
let response = SigShareResponse::Ok {
node_id: state.id,
sig_share: state.sk_share.sign(request.payload),
};
(StatusCode::OK, Json(response))
}
None => {
tracing::debug!("access token verification failed");
(StatusCode::UNAUTHORIZED, Json(SigShareResponse::Err))
}
}
}