Skip to content

Commit

Permalink
Merge pull request #24 from near/serhii/OAuth
Browse files Browse the repository at this point in the history
OAuth mock
  • Loading branch information
volovyks authored Apr 7, 2023
2 parents 1acafab + 04be724 commit b83f102
Show file tree
Hide file tree
Showing 8 changed files with 534 additions and 93 deletions.
414 changes: 352 additions & 62 deletions Cargo.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion mpc-recovery/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,5 @@ serde = "1"
serde_json = "1"
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
oauth2 = "4.3.0"
51 changes: 34 additions & 17 deletions mpc-recovery/src/leader_node/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::msg::{LeaderRequest, LeaderResponse, SigShareRequest, SigShareResponse};
use crate::oauth::{OAuthTokenVerifier, UniversalTokenVerifier};
use crate::NodeId;
use axum::{extract::State, http::StatusCode, routing::post, Json, Router};
use futures::stream::FuturesUnordered;
Expand Down Expand Up @@ -32,7 +33,7 @@ pub async fn run(
};

let app = Router::new()
.route("/submit", post(submit))
.route("/submit", post(submit::<UniversalTokenVerifier>))
.with_state(state);

let addr = SocketAddr::from(([0, 0, 0, 0], port));
Expand All @@ -58,14 +59,24 @@ async fn parse(response_future: ResponseFuture) -> anyhow::Result<SigShareRespon
}

#[tracing::instrument(level = "debug", skip_all, fields(id = state.id))]
async fn submit(
async fn submit<T: OAuthTokenVerifier>(
State(state): State<LeaderState>,
Json(request): Json<LeaderRequest>,
) -> (StatusCode, Json<LeaderResponse>) {
tracing::info!(payload = request.payload, "submit request");

// TODO: run some check that the payload makes sense, fail if not
tracing::debug!("approved");
// TODO: extract access token from payload
let access_token = "validToken";
match T::verify_token(access_token).await {
Some(_) => {
tracing::info!("access token is valid");
// continue execution
}
None => {
tracing::error!("access token verification failed");
return (StatusCode::UNAUTHORIZED, Json(LeaderResponse::Err));
}
}

let sig_share_request = SigShareRequest {
payload: request.payload.clone(),
Expand Down Expand Up @@ -100,42 +111,48 @@ async fn submit(
let mut sig_shares = BTreeMap::new();
sig_shares.insert(state.id, state.sk_share.sign(&request.payload));
for response_future in response_futures {
let response = match parse(response_future).await {
Ok(response) => response,
let (node_id, sig_share) = match parse(response_future).await {
Ok(response) => match response {
SigShareResponse::Ok { node_id, sig_share } => (node_id, sig_share),
SigShareResponse::Err => {
tracing::error!("Received an error response");
continue;
}
},
Err(err) => {
tracing::error!(%err, "failed to get response");
tracing::error!(%err, "Failed to get response");
continue;
}
};

if state
.pk_set
.public_key_share(response.node_id)
.verify(&response.sig_share, &request.payload)
.public_key_share(node_id)
.verify(&sig_share, &request.payload)
{
match sig_shares.entry(response.node_id) {
match sig_shares.entry(node_id) {
Entry::Vacant(e) => {
tracing::debug!(?response, "received valid signature share");
e.insert(response.sig_share);
tracing::debug!(?sig_share, "received valid signature share");
e.insert(sig_share);
}
Entry::Occupied(e) if e.get() == &response.sig_share => {
Entry::Occupied(e) if e.get() == &sig_share => {
tracing::error!(
node_id = response.node_id,
node_id,
sig_share = ?e.get(),
"received a duplicate share"
);
}
Entry::Occupied(e) => {
tracing::error!(
node_id = response.node_id,
node_id = node_id,
sig_share_1 = ?e.get(),
sig_share_2 = ?response.sig_share,
sig_share_2 = ?sig_share,
"received two different valid shares for the same node (should be impossible)"
);
}
}
} else {
tracing::error!(?response, "received invalid signature",);
tracing::error!("received invalid signature",);
}

if sig_shares.len() > state.pk_set.threshold() {
Expand Down
1 change: 1 addition & 0 deletions mpc-recovery/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use threshold_crypto::{PublicKeySet, SecretKeySet, SecretKeyShare};

mod leader_node;
pub mod msg;
mod oauth;
mod sign_node;

type NodeId = u64;
Expand Down
2 changes: 2 additions & 0 deletions mpc-recovery/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
mod oauth;

use clap::Parser;
use threshold_crypto::{serde_impl::SerdeSecret, PublicKeySet, SecretKeyShare};

Expand Down
10 changes: 7 additions & 3 deletions mpc-recovery/src/msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,13 @@ pub struct SigShareRequest {
}

#[derive(Serialize, Deserialize, Debug)]
pub struct SigShareResponse {
pub node_id: NodeId,
pub sig_share: SignatureShare,
#[allow(clippy::large_enum_variant)]
pub enum SigShareResponse {
Ok {
node_id: NodeId,
sig_share: SignatureShare,
},
Err,
}

mod hex_sig_share {
Expand Down
115 changes: 115 additions & 0 deletions mpc-recovery/src/oauth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#[async_trait::async_trait]
pub trait OAuthTokenVerifier {
async fn verify_token(token: &str) -> Option<&str>;
}

pub enum SupportedTokenVerifiers {
GoogleTokenVerifier,
TestTokenVerifier,
}

/* Universal token verifier */
pub struct UniversalTokenVerifier {}

#[async_trait::async_trait]
impl OAuthTokenVerifier for UniversalTokenVerifier {
async fn verify_token(token: &str) -> Option<&str> {
// TODO: here we assume that verifier type can be determined from the token
match get_token_verifier_type(token) {
SupportedTokenVerifiers::GoogleTokenVerifier => {
return GoogleTokenVerifier::verify_token(token).await;
}
SupportedTokenVerifiers::TestTokenVerifier => {
return TestTokenVerifier::verify_token(token).await;
}
}
}
}

fn get_token_verifier_type(token: &str) -> SupportedTokenVerifiers {
match token.len() {
// TODO: add real token type detection
0 => {
tracing::info!("Using GoogleTokenVerifier");
SupportedTokenVerifiers::GoogleTokenVerifier
}
_ => {
tracing::info!("Using TestTokenVerifier");
SupportedTokenVerifiers::TestTokenVerifier
}
}
}

/* Google verifier */
pub struct GoogleTokenVerifier {}

#[async_trait::async_trait]
impl OAuthTokenVerifier for GoogleTokenVerifier {
// TODO: replace with real implementation
async fn verify_token(token: &str) -> Option<&str> {
match token {
"validToken" => {
tracing::info!("GoogleTokenVerifier: access token is valid");
Some("testAccountId")
}
_ => {
tracing::info!("GoogleTokenVerifier: access token verification failed");
None
}
}
}
}

/* Test verifier */
pub struct TestTokenVerifier {}

#[async_trait::async_trait]
impl OAuthTokenVerifier for TestTokenVerifier {
async fn verify_token(token: &str) -> Option<&str> {
match token {
"validToken" => {
tracing::info!("TestTokenVerifier: access token is valid");
Some("testAccountId")
}
_ => {
tracing::info!("TestTokenVerifier: access token verification failed");
None
}
}
}
}

#[tokio::test]
async fn test_verify_token_valid() {
let token = "validToken";
let account_id = TestTokenVerifier::verify_token(token).await.unwrap();
assert_eq!(account_id, "testAccountId");
}

#[tokio::test]
async fn test_verify_token_invalid_with_test_verifier() {
let token = "invalid";
let account_id = TestTokenVerifier::verify_token(token).await;
assert_eq!(account_id, None);
}

#[tokio::test]
async fn test_verify_token_valid_with_test_verifier() {
let token = "validToken";
let account_id = TestTokenVerifier::verify_token(token).await.unwrap();
assert_eq!(account_id, "testAccountId");
}

#[tokio::test]
async fn test_verify_token_invalid_with_universal_verifier() {
let token = "invalid";
let account_id = UniversalTokenVerifier::verify_token(token).await;
assert_eq!(account_id, None);
}

#[tokio::test]
async fn test_verify_token_valid_with_universal_verifier() {
let token = "validToken";
let account_id = UniversalTokenVerifier::verify_token(token).await.unwrap();
assert_eq!(account_id, "testAccountId");
}
31 changes: 21 additions & 10 deletions mpc-recovery/src/sign_node/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::msg::{SigShareRequest, SigShareResponse};
use crate::oauth::{OAuthTokenVerifier, UniversalTokenVerifier};
use crate::NodeId;
use axum::{extract::State, http::StatusCode, routing::post, Json, Router};
use std::net::SocketAddr;
Expand All @@ -15,7 +16,9 @@ pub async fn run(id: NodeId, pk_set: PublicKeySet, sk_share: SecretKeyShare, por

let state = SignNodeState { id, sk_share };

let app = Router::new().route("/sign", post(sign)).with_state(state);
let app = Router::new()
.route("/sign", post(sign::<UniversalTokenVerifier>))
.with_state(state);

let addr = SocketAddr::from(([0, 0, 0, 0], port));
tracing::debug!(?addr, "starting http server");
Expand All @@ -32,18 +35,26 @@ struct SignNodeState {
}

#[tracing::instrument(level = "debug", skip_all, fields(id = state.id))]
async fn sign(
async fn sign<T: OAuthTokenVerifier>(
State(state): State<SignNodeState>,
Json(request): Json<SigShareRequest>,
) -> (StatusCode, Json<SigShareResponse>) {
tracing::info!(payload = request.payload, "sign request");

// TODO: run some check that the payload makes sense, fail if not
tracing::debug!("approved");

let response = SigShareResponse {
node_id: state.id,
sig_share: state.sk_share.sign(request.payload),
};
(StatusCode::OK, Json(response))
// TODO: extract access token from payload
let access_token = "validToken";
match T::verify_token(access_token).await {
Some(_) => {
tracing::debug!("access token is valid");
let response = SigShareResponse::Ok {
node_id: state.id,
sig_share: state.sk_share.sign(request.payload),
};
(StatusCode::OK, Json(response))
}
None => {
tracing::debug!("access token verification failed");
(StatusCode::UNAUTHORIZED, Json(SigShareResponse::Err))
}
}
}

0 comments on commit b83f102

Please sign in to comment.