diff --git a/src/request_identity.rs b/src/request_identity.rs index fa3b4fd..ce8451a 100644 --- a/src/request_identity.rs +++ b/src/request_identity.rs @@ -137,12 +137,25 @@ impl IdentityVerifier { .map_err(|e| VerifyError::ExtractHeader(SIGNATURE_JWT_V1_HEADER, Box::new(e)))? .ok_or(VerifyError::MissingHeader(SIGNATURE_JWT_V1_HEADER))?; - self.check_v1_keys(jwt, path) + self.check_v1_keys(jwt, Self::normalise_path(path)) } SIGNATURE_SCHEME_UNSIGNED => Err(VerifyError::UnsignedRequest), scheme => Err(VerifyError::BadSchemeHeader(scheme.to_owned())), } } + + fn normalise_path<'a>(path: &'a str) -> &'a str { + let slashes: Vec = path.match_indices('/').map(|(index, _)| index).collect(); + if slashes.len() >= 3 + && &path[slashes[slashes.len() - 3]..slashes[slashes.len() - 2]] == "/invoke" + { + &path[slashes[slashes.len() - 3]..] + } else if !slashes.is_empty() && &path[slashes[slashes.len() - 1]..] == "/discover" { + &path[slashes[slashes.len() - 1]..] + } else { + path + } + } } #[cfg(test)] @@ -195,6 +208,31 @@ mod tests { assert!(verifier.verify_identity(&headers, "/invoke/foo").is_err()) } + #[test] + fn normalise_path() { + let paths = vec![ + ("/invoke/a/b", "/invoke/a/b"), + ("/foo/invoke/a/b", "/invoke/a/b"), + ("/foo/bar/invoke/a/b", "/invoke/a/b"), + ("/discover", "/discover"), + ("/foo/discover", "/discover"), + ("/foo/bar/discover", "/discover"), + ("/foo", "/foo"), + ("/invoke", "/invoke"), + ("/foo/invoke", "/foo/invoke"), + ("/invoke/a", "/invoke/a"), + ("/foo/invoke/a", "/foo/invoke/a"), + ("", ""), + ("/", "/"), + ("discover", "discover"), + ]; + + for (path, expected_path) in paths { + let actual_path = IdentityVerifier::normalise_path(path); + assert_eq!(expected_path, actual_path) + } + } + fn mock_token_and_key() -> (String, String) { let serialized_keypair = Ed25519KeyPair::generate_pkcs8(&SystemRandom::new()).unwrap(); let keypair = Ed25519KeyPair::from_pkcs8(serialized_keypair.as_ref()).unwrap();