Skip to content

Commit

Permalink
Request Identity (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
slinkydeveloper authored Jul 23, 2024
1 parent 7ae644f commit 5e7825f
Show file tree
Hide file tree
Showing 4 changed files with 251 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ jobs:
with:
target: ${{ matrix.target }}
rustflags: ""
- name: Install musl
if: ${{ endsWith( matrix.target, '-musl' ) }}
run: sudo apt install musl-tools

- name: Install protoc
uses: ./.github/actions/install-protoc
Expand Down
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ tracing = "0.1.40"
paste = "1.0.15"
strum = { version = "0.26", features = ["derive"] }
base64 = "0.22"
bs58 = { version = "0.5.0" }
sha2 = "0.11.0-pre.3"
ring = { version = "0.17.8" }
jsonwebtoken = { version = "9.3.0" }
serde = { version = "1.0.204", features = ["derive"] }

[dev-dependencies]
googletest = "0.11.0"
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
mod request_identity;
mod service_protocol;
mod vm;

use std::borrow::Cow;
use std::time::Duration;

pub use request_identity::*;
pub use vm::CoreVM;

#[derive(Debug, Eq, PartialEq)]
Expand Down
242 changes: 242 additions & 0 deletions src/request_identity.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
// Copyright (c) 2024 - Restate Software, Inc., Restate GmbH.
// All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

use jsonwebtoken::{DecodingKey, Validation};
use serde::Deserialize;
use std::collections::{HashMap, HashSet};
use std::convert::Infallible;

const SIGNATURE_SCHEME_HEADER: &str = "x-restate-signature-scheme";
const SIGNATURE_SCHEME_V1: &str = "v1";
const SIGNATURE_SCHEME_UNSIGNED: &str = "unsigned";
const SIGNATURE_JWT_V1_HEADER: &str = "x-restate-jwt-v1";
const IDENTITY_V1_PREFIX: &str = "publickeyv1_";

pub trait IdentityHeaderMap {
type Error;

fn extract(&self, name: &str) -> Result<Option<&str>, Self::Error>;
}

impl IdentityHeaderMap for HashMap<String, String> {
type Error = Infallible;

fn extract(&self, name: &str) -> Result<Option<&str>, Self::Error> {
Ok(self.get(name).map(|x| x.as_str()))
}
}

#[derive(Debug, thiserror::Error)]
pub enum KeyError {
#[error("identity v1 jwt public keys are expected to start with {IDENTITY_V1_PREFIX}")]
MissingPrefix,
#[error("cannot decode the public key with base58: {0}")]
Base58(#[from] bs58::decode::Error),
#[error("decoded key should have length of 32, was {0}")]
BadLength(usize),
}

#[derive(Debug, thiserror::Error)]
pub enum VerifyError {
#[error("cannot read header {0}: {1}")]
ExtractHeader(
&'static str,
#[source] Box<dyn std::error::Error + Sync + Send + 'static>,
),
#[error("missing header: {0}")]
MissingHeader(&'static str),
#[error("bad {SIGNATURE_SCHEME_HEADER} header, unexpected value {0}")]
BadSchemeHeader(String),
#[error("got unsigned request, expecting only signed requests matching the configured keys")]
UnsignedRequest,
#[error("invalid JWT: {0}")]
InvalidJWT(#[from] jsonwebtoken::errors::Error),
}

pub struct IdentityVerifier {
validation: Validation,
keys: Vec<DecodingKey>,
}

impl Default for IdentityVerifier {
fn default() -> Self {
let mut validation = Validation::new(jsonwebtoken::Algorithm::EdDSA);
validation.required_spec_claims =
HashSet::from(["aud".into(), "exp".into(), "iat".into(), "nbf".into()]);
validation.leeway = 0;
validation.reject_tokens_expiring_in_less_than = 0;
validation.validate_exp = true;
validation.validate_nbf = true;
validation.validate_aud = true;

Self {
validation,
keys: vec![],
}
}
}

#[derive(Deserialize)]
struct Claims {}

impl IdentityVerifier {
pub fn new(keys: &[&str]) -> Result<Self, KeyError> {
let mut iv = IdentityVerifier::default();
for k in keys {
iv = iv.with_key(k)?;
}
Ok(iv)
}

pub fn with_key(mut self, key: &str) -> Result<Self, KeyError> {
self.keys.push(Self::parse_key(key)?);
Ok(Self {
validation: self.validation,
keys: self.keys,
})
}

fn parse_key(key: &str) -> Result<DecodingKey, KeyError> {
if !key.starts_with(IDENTITY_V1_PREFIX) {
return Err(KeyError::MissingPrefix);
}

let decoded_key = bs58::decode(key.split_at(IDENTITY_V1_PREFIX.len()).1).into_vec()?;
if decoded_key.len() != 32 {
return Err(KeyError::BadLength(decoded_key.len()));
}

Ok(DecodingKey::from_ed_der(&decoded_key))
}

fn check_v1_keys(&self, jwt_token: &str, path: &str) -> Result<(), VerifyError> {
let mut validation = self.validation.clone();
validation.set_audience(&[path]);
let mut res = Ok(());
for k in &self.keys {
if let Err(e) = jsonwebtoken::decode::<Claims>(jwt_token, k, &validation) {
res = Err(e);
} else {
return Ok(());
}
}
res.map_err(VerifyError::InvalidJWT)
}

pub fn verify_identity<I>(&self, hm: &I, path: &str) -> Result<(), VerifyError>
where
I: IdentityHeaderMap,
<I as IdentityHeaderMap>::Error: std::error::Error + Send + Sync + 'static,
{
if self.keys.is_empty() {
return Ok(());
}

let scheme_header = hm
.extract(SIGNATURE_SCHEME_HEADER)
.map_err(|e| VerifyError::ExtractHeader(SIGNATURE_SCHEME_HEADER, Box::new(e)))?
.ok_or(VerifyError::MissingHeader(SIGNATURE_SCHEME_HEADER))?;

match scheme_header {
SIGNATURE_SCHEME_V1 => {
let jwt = hm
.extract(SIGNATURE_JWT_V1_HEADER)
.map_err(|e| VerifyError::ExtractHeader(SIGNATURE_JWT_V1_HEADER, Box::new(e)))?
.ok_or(VerifyError::MissingHeader(SIGNATURE_JWT_V1_HEADER))?;

self.check_v1_keys(jwt, path)
}
SIGNATURE_SCHEME_UNSIGNED => Err(VerifyError::UnsignedRequest),
scheme => Err(VerifyError::BadSchemeHeader(scheme.to_owned())),
}
}
}

#[cfg(test)]
mod tests {
use super::*;

use ring::rand::SystemRandom;
use ring::signature::{Ed25519KeyPair, KeyPair};
use serde::Serialize;
use std::time::SystemTime;

#[derive(Serialize)]
pub(crate) struct Claims<'aud> {
aud: &'aud str,
exp: u64,
iat: u64,
nbf: u64,
}

#[test]
fn verify() {
let (jwt, identity_key) = mock_token_and_key();

let verifier = IdentityVerifier::new(&[&identity_key]).unwrap();

let headers: HashMap<String, String> = [
(
SIGNATURE_SCHEME_HEADER.to_owned(),
SIGNATURE_SCHEME_V1.to_owned(),
),
(SIGNATURE_JWT_V1_HEADER.to_owned(), jwt),
]
.into_iter()
.collect();

verifier.verify_identity(&headers, "/invoke/foo").unwrap();
}

#[test]
fn bad_key() {
let verifier =
IdentityVerifier::new(&["publickeyv1_ChjENKeMvCtRnqG2mrBK1HmPKufgFUc98K8B3ononQvp"])
.unwrap();

let headers: HashMap<String, String> = [
(SIGNATURE_SCHEME_HEADER.to_owned(), SIGNATURE_SCHEME_V1.to_owned()),
(SIGNATURE_JWT_V1_HEADER.to_owned(), "eyJ0eXAiOiJKV1QiLCJhbGciOiJFZERTQSIsImtpZCI6InB1YmxpY2tleXYxX0FmUXdtd2ZnRVpocldwdnY4TjUyU0hwUnRacUdHYUZyNEFaTjZxdFlXU2lZIn0.eyJhdWQiOiIvaW52b2tlL2ZvbyIsImV4cCI6MTcyMTY2MjcwOSwiaWF0IjoxNzIxNjYyNjQ5LCJuYmYiOjE3MjE2NjI1ODl9.UBReG_9cdFQ5VcaJxAV0rM8U_zaNw9kMXJZt691SiI0SWw7Ucmz5Zz3wtmVUc1jrkNsnTDhNEvOFGEZoKXTMCQ".to_owned())
].into_iter().collect();

assert!(verifier.verify_identity(&headers, "/invoke/foo").is_err())
}

fn mock_token_and_key() -> (String, String) {
let serialized_keypair = Ed25519KeyPair::generate_pkcs8(&SystemRandom::new()).unwrap();
let keypair = Ed25519KeyPair::from_pkcs8(serialized_keypair.as_ref()).unwrap();

let kid = format!(
"{IDENTITY_V1_PREFIX}{}",
bs58::encode(keypair.public_key()).into_string()
);
let signing_key = jsonwebtoken::EncodingKey::from_ed_der(serialized_keypair.as_ref());

let header = jsonwebtoken::Header {
typ: Some("JWT".into()),
kid: Some(kid.clone()),
alg: jsonwebtoken::Algorithm::EdDSA,
..Default::default()
};
let unix_seconds = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.expect("duration since Unix epoch should be well-defined")
.as_secs();
let claims = Claims {
aud: "/invoke/foo",
nbf: unix_seconds.saturating_sub(60),
iat: unix_seconds,
exp: unix_seconds.saturating_add(60),
};
let jwt = jsonwebtoken::encode(&header, &claims, &signing_key).unwrap();

(jwt, kid)
}
}

0 comments on commit 5e7825f

Please sign in to comment.