Skip to content

Commit

Permalink
improve xsuaa jwk fetch error handling (#1560)
Browse files Browse the repository at this point in the history
---------

Signed-off-by: liga-oz <liga.ozolina@sap.com>
Co-authored-by: Manuel Fink <123368068+finkmanAtSap@users.noreply.github.com>
  • Loading branch information
liga-oz and finkmanAtSap authored Jun 12, 2024
1 parent d65e418 commit 580213f
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import com.sap.cloud.security.token.validation.CombiningValidator;
import com.sap.cloud.security.token.validation.ValidationResult;
import com.sap.cloud.security.token.validation.validators.JwtValidatorBuilder;
import com.sap.cloud.security.xsuaa.client.OAuth2ServiceException;
import com.sap.cloud.security.xsuaa.client.OAuth2TokenKeyService;
import org.apache.commons.io.IOUtils;
import org.junit.ClassRule;
Expand All @@ -23,12 +24,15 @@

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Map;

import static com.sap.cloud.security.config.Service.XSUAA;
import static com.sap.cloud.security.config.ServiceConstants.XSUAA.VERIFICATION_KEY;
import static com.sap.cloud.security.test.SecurityTestRule.DEFAULT_CLIENT_ID;
import static com.sap.cloud.security.test.SecurityTestRule.DEFAULT_UAA_DOMAIN;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;

/**
* Xsuaa integration test with single binding scenario.
Expand Down Expand Up @@ -110,9 +114,12 @@ public void createToken_withCorrectVerificationKey_tokenIsValid() throws IOExcep
.withProperty(VERIFICATION_KEY, publicKey)
.build();

OAuth2TokenKeyService tokenKeyServiceMock = Mockito.mock(OAuth2TokenKeyService.class);
when(tokenKeyServiceMock.retrieveTokenKeys(any(), (Map<String, String>) any())).thenThrow(
OAuth2ServiceException.class);
CombiningValidator<Token> tokenValidator = JwtValidatorBuilder.getInstance(configuration)
// mocked because we use the key from the verificationkey property here
.withOAuth2TokenKeyService(Mockito.mock(OAuth2TokenKeyService.class))
.withOAuth2TokenKeyService(tokenKeyServiceMock)
.build();

Token token = rule.getPreconfiguredJwtGenerator().createToken();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,25 @@
import com.sap.cloud.security.token.validation.CombiningValidator;
import com.sap.cloud.security.token.validation.ValidationResult;
import com.sap.cloud.security.token.validation.validators.JwtValidatorBuilder;
import com.sap.cloud.security.xsuaa.client.OAuth2ServiceException;
import com.sap.cloud.security.xsuaa.client.OAuth2TokenKeyService;
import org.apache.commons.io.IOUtils;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Map;

import static com.sap.cloud.security.config.Service.XSUAA;
import static com.sap.cloud.security.config.ServiceConstants.XSUAA.VERIFICATION_KEY;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;

/**
* Performance test for java-security jwt token validation.
Expand Down Expand Up @@ -78,9 +84,12 @@ private CombiningValidator<Token> createOfflineTokenValidator() throws IOExcepti
OAuth2ServiceConfiguration configuration = createConfigurationBuilder()
.withProperty(VERIFICATION_KEY, publicKey)
.build();
OAuth2TokenKeyService tokenKeyServiceMock = Mockito.mock(OAuth2TokenKeyService.class);
when(tokenKeyServiceMock.retrieveTokenKeys(any(), (Map<String, String>) any())).thenThrow(
OAuth2ServiceException.class);
return JwtValidatorBuilder.getInstance(configuration)
// oAuth2TokenKeyService mocked because verificationkey property is used for offline token validation
.withOAuth2TokenKeyService((uri, zoneId) -> "{\"keys\": []}")
.withOAuth2TokenKeyService(tokenKeyServiceMock)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,27 @@ protected PublicKey getPublicKey(Token token, JwtSignatureAlgorithm algorithm)
key = fetchPublicKey(token, algorithm);
} catch (OAuth2ServiceException | InvalidKeySpecException | NoSuchAlgorithmException
| IllegalArgumentException e) {
LOGGER.error("Error fetching public key from XSUAA service: {}", e.getMessage());
if (!configuration.hasProperty(ServiceConstants.XSUAA.VERIFICATION_KEY)) {
throw e;
}
}

if (key == null && configuration.hasProperty(ServiceConstants.XSUAA.VERIFICATION_KEY)) {
String fallbackKey = configuration.getProperty(ServiceConstants.XSUAA.VERIFICATION_KEY);
try {
key = JsonWebKeyImpl.createPublicKeyFromPemEncodedPublicKey(JwtSignatureAlgorithm.RS256, fallbackKey);
} catch (NoSuchAlgorithmException | InvalidKeySpecException ex) {
throw new IllegalArgumentException(
"Fallback validation key supplied via " + ServiceConstants.XSUAA.VERIFICATION_KEY
+ " property in service credentials could not be used: {}",
ex);
if (configuration.hasProperty(ServiceConstants.XSUAA.VERIFICATION_KEY)) {
String fallbackKey = configuration.getProperty(ServiceConstants.XSUAA.VERIFICATION_KEY);
try {
key = JsonWebKeyImpl.createPublicKeyFromPemEncodedPublicKey(JwtSignatureAlgorithm.RS256,
fallbackKey);
} catch (NoSuchAlgorithmException | InvalidKeySpecException ex) {
IllegalArgumentException illegalArgEx = new IllegalArgumentException(
"Fallback validation key supplied via " + ServiceConstants.XSUAA.VERIFICATION_KEY
+ " property in service credentials could not be used: " + ex.getMessage());
if (e instanceof OAuth2ServiceException) {
e.addSuppressed(illegalArgEx);
throw e;
}
throw illegalArgEx;

}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.sap.cloud.security.token.Token;
import com.sap.cloud.security.token.XsuaaToken;
import com.sap.cloud.security.token.validation.ValidationResult;
import com.sap.cloud.security.xsuaa.client.OAuth2ServiceException;
import com.sap.cloud.security.xsuaa.client.OAuth2TokenKeyService;
import com.sap.cloud.security.xsuaa.client.OidcConfigurationService;
import com.sap.cloud.security.xsuaa.http.HttpHeaders;
Expand Down Expand Up @@ -76,6 +77,38 @@ public void xsuaa_RSASignatureMatchesJWKS() {
assertThat(cut.validate(xsuaaToken).isValid(), is(true));
}

@Test
public void onlineVerificationFails_noVerificationKey() throws IOException {
when(tokenKeyServiceMock
.retrieveTokenKeys(
URI.create("https://authentication.stagingaws.hanavlab.ondemand.com/token_keys?zid=uaa"),
Map.of(HttpHeaders.X_ZID, "uaa")))
.thenThrow(new OAuth2ServiceException("Error retrieving token keys"));

ValidationResult result = cut.validate(xsuaaToken);
assertThat(result.isErroneous(), is(true));
assertThat(result.getErrorDescription(), containsString("JWKS could not be fetched"));
}

@Test
public void onlineVerificationFails_withNotWorkingVerificationKey() throws IOException {
when(tokenKeyServiceMock
.retrieveTokenKeys(
URI.create("https://authentication.stagingaws.hanavlab.ondemand.com/token_keys?zid=uaa"),
Map.of(HttpHeaders.X_ZID, "uaa")))
.thenThrow(new OAuth2ServiceException("Error retrieving token keys"));
when(mockConfiguration.hasProperty("verificationkey")).thenReturn(true);
when(mockConfiguration.getProperty("verificationkey")).thenReturn(
"""
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAm1QaZzMjtEfHdimrHP3/
oQIDAQAB
-----END PUBLIC KEY-----""");
ValidationResult result = cut.validate(xsuaaToken);
assertThat(result.isErroneous(), is(true));
assertThat(result.getErrorDescription(), containsString("JWKS could not be fetched"));
}

@Test
public void generatedToken_SignatureMatchesVerificationkey() {
when(mockConfiguration.hasProperty("verificationkey")).thenReturn(true);
Expand All @@ -94,7 +127,7 @@ public void generatedToken_SignatureMatchesVerificationkey() {
}

@Test
public void validationFails_whenVerificationkeyIsInvalid() {
public void validationFails_whenVerificationKeyIsInvalid() {
when(mockConfiguration.hasProperty("verificationkey")).thenReturn(true);
when(mockConfiguration.getProperty("verificationkey")).thenReturn("INVALIDKEY");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.springframework.security.oauth2.jwt.BadJwtException;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtException;
import org.springframework.util.Assert;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
Expand Down Expand Up @@ -82,7 +83,11 @@ public Jwt decode(String encodedToken) {
default -> throw new BadJwtException("Tokens issued by " + token.getService() + " service aren't supported.");
}
if (validationResult.isErroneous()) {
throw new BadJwtException("The token is invalid: " + validationResult.getErrorDescription());
if (validationResult.getErrorDescription().contains("JWKS could not be fetched")) {
throw new JwtException(validationResult.getErrorDescription());
} else {
throw new BadJwtException("The token is invalid: " + validationResult.getErrorDescription());
}
}
logger.debug("Token issued by {} service was successfully validated.", token.getService());
return jwt;
Expand All @@ -99,5 +104,4 @@ public static Jwt parseJwt(Token token) {
return new Jwt(token.getTokenValue(), token.getNotBefore(), token.getExpiration(),
token.getHeaders(), token.getClaims());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import ch.qos.logback.classic.Logger;
import ch.qos.logback.classic.spi.ILoggingEvent;
import ch.qos.logback.core.read.ListAppender;
import com.sap.cloud.security.config.OAuth2ServiceConfiguration;
import com.sap.cloud.security.test.JwtGenerator;
import com.sap.cloud.security.token.SecurityContext;
import com.sap.cloud.security.token.Token;
import com.sap.cloud.security.token.TokenClaims;
import com.sap.cloud.security.token.validation.CombiningValidator;
import com.sap.cloud.security.token.validation.ValidationResults;
import com.sap.cloud.security.token.validation.validators.JwtValidatorBuilder;
import com.sap.cloud.security.x509.X509Certificate;
import org.apache.commons.io.IOUtils;
import org.assertj.core.api.Assertions;
Expand All @@ -24,6 +26,7 @@
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.oauth2.jwt.BadJwtException;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtException;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

Expand Down Expand Up @@ -127,6 +130,17 @@ void decodeWithCorruptToken_throwsBadJwtException() {
assertThrows(BadJwtException.class, () -> cut.decode("Bearerabc"));
}

@Test
void decode_cantRetrieveJWK() {
OAuth2ServiceConfiguration configuration = Mockito.mock(OAuth2ServiceConfiguration.class);
when(configuration.getService()).thenReturn(XSUAA);
when(configuration.getClientId()).thenReturn("theClientId");
CombiningValidator<Token> xsuaaValidators = JwtValidatorBuilder.getInstance(configuration).build();
HybridJwtDecoder cut = new HybridJwtDecoder(xsuaaValidators, null);
String encodedToken = JwtGenerator.getInstance(XSUAA, "theClientId").createToken().getTokenValue();
assertThrows(JwtException.class, () -> cut.decode(encodedToken));
}

@Test
void instantiateForXsuaaOnly() {
cut = new HybridJwtDecoder(combiningValidator, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,10 @@ private Jwt verifyToken(JWT jwt) {
validateJwksParameters(kid, uaaDomain);

return verifyToken(jwt.getParsedString(), kid, uaaDomain, getZid(jwt));
} catch (BadJwtException e) {
} catch (JwtException e) {
if (e.getMessage().contains("Couldn't retrieve remote JWK set")
|| e.getMessage().contains("Cannot verify with online token key, uaadomain is")) {
logger.debug(e.getMessage());
logger.error(e.getMessage());
return tryToVerifyWithVerificationKey(jwt.getParsedString(), e);
} else {
throw e;
Expand Down Expand Up @@ -169,13 +169,8 @@ private Jwt verifyToken(String token, String kid, String uaaDomain, String zid)
}
}

try {
return verifyWithKey(token, jku, kid);
} catch (JwtValidationException ex) {
throw ex;
} catch (JwtException ex) {
throw new BadJwtException("JWT verification failed: " + ex.getMessage());
}

}

private void validateJwksParameters(String kid, String uaadomain) {
Expand Down Expand Up @@ -232,18 +227,19 @@ private Jwt tryToVerifyWithVerificationKey(String token, JwtException verificati
if (!hasText(verificationKey)) {
throw verificationException;
}
return verifyWithVerificationKey(token, verificationKey);
return verifyWithVerificationKey(token, verificationKey, verificationException);
}

private Jwt verifyWithVerificationKey(String token, String verificationKey) {
private Jwt verifyWithVerificationKey(String token, String verificationKey,
JwtException onlineVerificationException) {
try {
RSAPublicKey rsaPublicKey = createPublicKey(verificationKey);
NimbusJwtDecoder decoder = NimbusJwtDecoder.withPublicKey(rsaPublicKey).build();
decoder.setJwtValidator(tokenValidators);
return decoder.decode(token);
} catch (NoSuchAlgorithmException | IllegalArgumentException | InvalidKeySpecException | BadJwtException e) {
logger.debug("Jwt signature validation with fallback verificationkey failed: {}", e.getMessage());
throw new BadJwtException("Jwt validation with fallback verificationkey failed");
} catch (NoSuchAlgorithmException | IllegalArgumentException | InvalidKeySpecException e) {
logger.error("Jwt signature validation with fallback verificationkey failed: {}", e.getMessage());
throw new JwtException("Jwt validation with fallback verificationkey failed", onlineVerificationException);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.mockito.Mockito;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.ResponseEntity;
import org.springframework.security.oauth2.jwt.BadJwtException;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtException;
Expand Down Expand Up @@ -101,7 +100,7 @@ public void decode_withInvalidFallbackVerificationKey_withoutUaaDomain() {

final JwtDecoder cut = new XsuaaJwtDecoderBuilder(config).build();

assertThatThrownBy(() -> cut.decode(rsaToken)).isInstanceOf(BadJwtException.class)
assertThatThrownBy(() -> cut.decode(rsaToken)).isInstanceOf(JwtException.class)
.hasMessageContaining("Jwt validation with fallback verificationkey failed");
}

Expand Down

0 comments on commit 580213f

Please sign in to comment.