Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve xsuaa jwk fetch error handling #1560

Merged
merged 7 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import com.sap.cloud.security.token.validation.CombiningValidator;
import com.sap.cloud.security.token.validation.ValidationResult;
import com.sap.cloud.security.token.validation.validators.JwtValidatorBuilder;
import com.sap.cloud.security.xsuaa.client.OAuth2ServiceException;
import com.sap.cloud.security.xsuaa.client.OAuth2TokenKeyService;
import org.apache.commons.io.IOUtils;
import org.junit.ClassRule;
Expand All @@ -23,12 +24,15 @@

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Map;

import static com.sap.cloud.security.config.Service.XSUAA;
import static com.sap.cloud.security.config.ServiceConstants.XSUAA.VERIFICATION_KEY;
import static com.sap.cloud.security.test.SecurityTestRule.DEFAULT_CLIENT_ID;
import static com.sap.cloud.security.test.SecurityTestRule.DEFAULT_UAA_DOMAIN;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;

/**
* Xsuaa integration test with single binding scenario.
Expand Down Expand Up @@ -110,9 +114,12 @@ public void createToken_withCorrectVerificationKey_tokenIsValid() throws IOExcep
.withProperty(VERIFICATION_KEY, publicKey)
.build();

OAuth2TokenKeyService tokenKeyServiceMock = Mockito.mock(OAuth2TokenKeyService.class);
when(tokenKeyServiceMock.retrieveTokenKeys(any(), (Map<String, String>) any())).thenThrow(
OAuth2ServiceException.class);
CombiningValidator<Token> tokenValidator = JwtValidatorBuilder.getInstance(configuration)
// mocked because we use the key from the verificationkey property here
.withOAuth2TokenKeyService(Mockito.mock(OAuth2TokenKeyService.class))
.withOAuth2TokenKeyService(tokenKeyServiceMock)
.build();

Token token = rule.getPreconfiguredJwtGenerator().createToken();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,25 @@
import com.sap.cloud.security.token.validation.CombiningValidator;
import com.sap.cloud.security.token.validation.ValidationResult;
import com.sap.cloud.security.token.validation.validators.JwtValidatorBuilder;
import com.sap.cloud.security.xsuaa.client.OAuth2ServiceException;
import com.sap.cloud.security.xsuaa.client.OAuth2TokenKeyService;
import org.apache.commons.io.IOUtils;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Map;

import static com.sap.cloud.security.config.Service.XSUAA;
import static com.sap.cloud.security.config.ServiceConstants.XSUAA.VERIFICATION_KEY;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;

/**
* Performance test for java-security jwt token validation.
Expand Down Expand Up @@ -78,9 +84,12 @@ private CombiningValidator<Token> createOfflineTokenValidator() throws IOExcepti
OAuth2ServiceConfiguration configuration = createConfigurationBuilder()
.withProperty(VERIFICATION_KEY, publicKey)
.build();
OAuth2TokenKeyService tokenKeyServiceMock = Mockito.mock(OAuth2TokenKeyService.class);
when(tokenKeyServiceMock.retrieveTokenKeys(any(), (Map<String, String>) any())).thenThrow(
OAuth2ServiceException.class);
return JwtValidatorBuilder.getInstance(configuration)
// oAuth2TokenKeyService mocked because verificationkey property is used for offline token validation
.withOAuth2TokenKeyService((uri, zoneId) -> "{\"keys\": []}")
.withOAuth2TokenKeyService(tokenKeyServiceMock)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,27 @@ protected PublicKey getPublicKey(Token token, JwtSignatureAlgorithm algorithm)
key = fetchPublicKey(token, algorithm);
} catch (OAuth2ServiceException | InvalidKeySpecException | NoSuchAlgorithmException
| IllegalArgumentException e) {
LOGGER.error("Error fetching public key from XSUAA service: {}", e.getMessage());
if (!configuration.hasProperty(ServiceConstants.XSUAA.VERIFICATION_KEY)) {
throw e;
}
}

if (key == null && configuration.hasProperty(ServiceConstants.XSUAA.VERIFICATION_KEY)) {
String fallbackKey = configuration.getProperty(ServiceConstants.XSUAA.VERIFICATION_KEY);
try {
key = JsonWebKeyImpl.createPublicKeyFromPemEncodedPublicKey(JwtSignatureAlgorithm.RS256, fallbackKey);
} catch (NoSuchAlgorithmException | InvalidKeySpecException ex) {
throw new IllegalArgumentException(
"Fallback validation key supplied via " + ServiceConstants.XSUAA.VERIFICATION_KEY
+ " property in service credentials could not be used: {}",
ex);
if (configuration.hasProperty(ServiceConstants.XSUAA.VERIFICATION_KEY)) {
String fallbackKey = configuration.getProperty(ServiceConstants.XSUAA.VERIFICATION_KEY);
try {
key = JsonWebKeyImpl.createPublicKeyFromPemEncodedPublicKey(JwtSignatureAlgorithm.RS256,
fallbackKey);
} catch (NoSuchAlgorithmException | InvalidKeySpecException ex) {
IllegalArgumentException illegalArgEx = new IllegalArgumentException(
"Fallback validation key supplied via " + ServiceConstants.XSUAA.VERIFICATION_KEY
+ " property in service credentials could not be used: " + ex.getMessage());
if (e.getClass().equals(OAuth2ServiceException.class)) {
finkmanAtSap marked this conversation as resolved.
Show resolved Hide resolved
e.addSuppressed(illegalArgEx);
throw e;
}
throw illegalArgEx;

}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.sap.cloud.security.token.Token;
import com.sap.cloud.security.token.XsuaaToken;
import com.sap.cloud.security.token.validation.ValidationResult;
import com.sap.cloud.security.xsuaa.client.OAuth2ServiceException;
import com.sap.cloud.security.xsuaa.client.OAuth2TokenKeyService;
import com.sap.cloud.security.xsuaa.client.OidcConfigurationService;
import com.sap.cloud.security.xsuaa.http.HttpHeaders;
Expand Down Expand Up @@ -76,6 +77,38 @@ public void xsuaa_RSASignatureMatchesJWKS() {
assertThat(cut.validate(xsuaaToken).isValid(), is(true));
}

@Test
public void onlineVerificationFails_noVerificationKey() throws IOException {
when(tokenKeyServiceMock
.retrieveTokenKeys(
URI.create("https://authentication.stagingaws.hanavlab.ondemand.com/token_keys?zid=uaa"),
Map.of(HttpHeaders.X_ZID, "uaa")))
.thenThrow(new OAuth2ServiceException("Error retrieving token keys"));

ValidationResult result = cut.validate(xsuaaToken);
assertThat(result.isErroneous(), is(true));
assertThat(result.getErrorDescription(), containsString("JWKS could not be fetched"));
}

@Test
public void onlineVerificationFails_withNotWorkingVerificationKey() throws IOException {
when(tokenKeyServiceMock
.retrieveTokenKeys(
URI.create("https://authentication.stagingaws.hanavlab.ondemand.com/token_keys?zid=uaa"),
Map.of(HttpHeaders.X_ZID, "uaa")))
.thenThrow(new OAuth2ServiceException("Error retrieving token keys"));
when(mockConfiguration.hasProperty("verificationkey")).thenReturn(true);
when(mockConfiguration.getProperty("verificationkey")).thenReturn(
"""
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAm1QaZzMjtEfHdimrHP3/
oQIDAQAB
-----END PUBLIC KEY-----""");
ValidationResult result = cut.validate(xsuaaToken);
assertThat(result.isErroneous(), is(true));
assertThat(result.getErrorDescription(), containsString("JWKS could not be fetched"));
}

@Test
public void generatedToken_SignatureMatchesVerificationkey() {
when(mockConfiguration.hasProperty("verificationkey")).thenReturn(true);
Expand All @@ -94,7 +127,7 @@ public void generatedToken_SignatureMatchesVerificationkey() {
}

@Test
public void validationFails_whenVerificationkeyIsInvalid() {
public void validationFails_whenVerificationKeyIsInvalid() {
when(mockConfiguration.hasProperty("verificationkey")).thenReturn(true);
when(mockConfiguration.getProperty("verificationkey")).thenReturn("INVALIDKEY");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.springframework.security.oauth2.jwt.BadJwtException;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtException;
import org.springframework.util.Assert;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
Expand Down Expand Up @@ -82,7 +83,11 @@ public Jwt decode(String encodedToken) {
default -> throw new BadJwtException("Tokens issued by " + token.getService() + " service aren't supported.");
}
if (validationResult.isErroneous()) {
throw new BadJwtException("The token is invalid: " + validationResult.getErrorDescription());
if (validationResult.getErrorDescription().contains("Error retrieving token keys")) {
finkmanAtSap marked this conversation as resolved.
Show resolved Hide resolved
throw new JwtException("Error retrieving token keys: " + validationResult.getErrorDescription());
liga-oz marked this conversation as resolved.
Show resolved Hide resolved
} else {
throw new BadJwtException("The token is invalid: " + validationResult.getErrorDescription());
}
}
logger.debug("Token issued by {} service was successfully validated.", token.getService());
return jwt;
Expand All @@ -99,5 +104,4 @@ public static Jwt parseJwt(Token token) {
return new Jwt(token.getTokenValue(), token.getNotBefore(), token.getExpiration(),
token.getHeaders(), token.getClaims());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import ch.qos.logback.classic.Logger;
import ch.qos.logback.classic.spi.ILoggingEvent;
import ch.qos.logback.core.read.ListAppender;
import com.sap.cloud.security.config.OAuth2ServiceConfiguration;
import com.sap.cloud.security.test.JwtGenerator;
import com.sap.cloud.security.token.SecurityContext;
import com.sap.cloud.security.token.Token;
import com.sap.cloud.security.token.TokenClaims;
import com.sap.cloud.security.token.validation.CombiningValidator;
import com.sap.cloud.security.token.validation.ValidationResults;
import com.sap.cloud.security.token.validation.validators.JwtValidatorBuilder;
import com.sap.cloud.security.x509.X509Certificate;
import org.apache.commons.io.IOUtils;
import org.assertj.core.api.Assertions;
Expand All @@ -24,6 +26,7 @@
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.oauth2.jwt.BadJwtException;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtException;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

Expand Down Expand Up @@ -127,6 +130,17 @@ void decodeWithCorruptToken_throwsBadJwtException() {
assertThrows(BadJwtException.class, () -> cut.decode("Bearerabc"));
}

@Test
void decode_cantRetrieveJWK() {
OAuth2ServiceConfiguration configuration = Mockito.mock(OAuth2ServiceConfiguration.class);
when(configuration.getService()).thenReturn(XSUAA);
when(configuration.getClientId()).thenReturn("theClientId");
CombiningValidator<Token> xsuaaValidators = JwtValidatorBuilder.getInstance(configuration).build();
HybridJwtDecoder cut = new HybridJwtDecoder(xsuaaValidators, null);
String encodedToken = JwtGenerator.getInstance(XSUAA, "theClientId").createToken().getTokenValue();
assertThrows(JwtException.class, () -> cut.decode(encodedToken));
}

@Test
void instantiateForXsuaaOnly() {
cut = new HybridJwtDecoder(combiningValidator, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,10 @@ private Jwt verifyToken(JWT jwt) {
validateJwksParameters(kid, uaaDomain);

return verifyToken(jwt.getParsedString(), kid, uaaDomain, getZid(jwt));
} catch (BadJwtException e) {
} catch (JwtException e) {
if (e.getMessage().contains("Couldn't retrieve remote JWK set")
|| e.getMessage().contains("Cannot verify with online token key, uaadomain is")) {
logger.debug(e.getMessage());
logger.error(e.getMessage());
return tryToVerifyWithVerificationKey(jwt.getParsedString(), e);
} else {
throw e;
Expand Down Expand Up @@ -169,13 +169,8 @@ private Jwt verifyToken(String token, String kid, String uaaDomain, String zid)
}
}

try {
return verifyWithKey(token, jku, kid);
} catch (JwtValidationException ex) {
throw ex;
} catch (JwtException ex) {
throw new BadJwtException("JWT verification failed: " + ex.getMessage());
}

}

private void validateJwksParameters(String kid, String uaadomain) {
Expand Down Expand Up @@ -232,18 +227,19 @@ private Jwt tryToVerifyWithVerificationKey(String token, JwtException verificati
if (!hasText(verificationKey)) {
throw verificationException;
}
return verifyWithVerificationKey(token, verificationKey);
return verifyWithVerificationKey(token, verificationKey, verificationException);
}

private Jwt verifyWithVerificationKey(String token, String verificationKey) {
private Jwt verifyWithVerificationKey(String token, String verificationKey,
JwtException onlineVerificationException) {
try {
RSAPublicKey rsaPublicKey = createPublicKey(verificationKey);
NimbusJwtDecoder decoder = NimbusJwtDecoder.withPublicKey(rsaPublicKey).build();
decoder.setJwtValidator(tokenValidators);
return decoder.decode(token);
} catch (NoSuchAlgorithmException | IllegalArgumentException | InvalidKeySpecException | BadJwtException e) {
logger.debug("Jwt signature validation with fallback verificationkey failed: {}", e.getMessage());
throw new BadJwtException("Jwt validation with fallback verificationkey failed");
} catch (NoSuchAlgorithmException | IllegalArgumentException | InvalidKeySpecException e) {
logger.error("Jwt signature validation with fallback verificationkey failed: {}", e.getMessage());
throw new JwtException("Jwt validation with fallback verificationkey failed", onlineVerificationException);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.mockito.Mockito;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.ResponseEntity;
import org.springframework.security.oauth2.jwt.BadJwtException;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtException;
Expand Down Expand Up @@ -101,7 +100,7 @@ public void decode_withInvalidFallbackVerificationKey_withoutUaaDomain() {

final JwtDecoder cut = new XsuaaJwtDecoderBuilder(config).build();

assertThatThrownBy(() -> cut.decode(rsaToken)).isInstanceOf(BadJwtException.class)
assertThatThrownBy(() -> cut.decode(rsaToken)).isInstanceOf(JwtException.class)
.hasMessageContaining("Jwt validation with fallback verificationkey failed");
}

Expand Down
Loading