Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[8.11] Fix memory leak from JWT cache (and fix the usage of the JWT auth cache) (#101799) #101841

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/101799.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 101799
summary: Fix memory leak from JWT cache (and fix the usage of the JWT auth cache)
area: Authentication
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesArray;
Expand Down Expand Up @@ -276,7 +277,9 @@ public void authenticate(final AuthenticationToken authenticationToken, final Ac
return; // FAILED (secret is missing or mismatched)
}

final BytesArray jwtCacheKey = isCacheEnabled() ? new BytesArray(jwtAuthenticationToken.getUserCredentialsHash()) : null;
final BytesArray jwtCacheKey = isCacheEnabled()
? new BytesArray(new BytesRef(jwtAuthenticationToken.getUserCredentialsHash()), true)
: null;
if (jwtCacheKey != null) {
final User cachedUser = tryAuthenticateWithCache(tokenPrincipal, jwtCacheKey);
if (cachedUser != null) {
Expand Down Expand Up @@ -476,6 +479,11 @@ private boolean isCacheEnabled() {
return jwtCache != null && jwtCacheHelper != null;
}

// package private for testing
Cache<BytesArray, ExpiringUser> getJwtCache() {
return jwtCache;
}

/**
* Format and filter JWT contents as user metadata.
* @param claimsSet Claims are supported. Claim keys are prefixed by "jwt_claim_".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,26 @@ public void testJwtAuthcRealmAuthcAuthzWithEmptyRoles() throws Exception {
doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, clientSecret, jwtAuthcCount);
}

public void testJwtCache() throws Exception {
jwtIssuerAndRealms = generateJwtIssuerRealmPairs(1, 1, 1, 1, 1, 1, 99, false);
JwtRealm realm = jwtIssuerAndRealms.get(0).realm();
realm.expireAll();
assertThat(realm.getJwtCache().count(), is(0));
final JwtIssuerAndRealm jwtIssuerAndRealm = randomJwtIssuerRealmPair();
final SecureString clientSecret = JwtRealmInspector.getClientAuthenticationSharedSecret(jwtIssuerAndRealm.realm());
for (int i = 1; i <= randomIntBetween(2, 10); i++) {
User user = randomUser(jwtIssuerAndRealm.issuer());
doMultipleAuthcAuthzAndVerifySuccess(
jwtIssuerAndRealm.realm(),
user,
randomJwt(jwtIssuerAndRealm, user),
clientSecret,
randomIntBetween(2, 10)
);
assertThat(realm.getJwtCache().count(), is(i));
}
}

/**
* Test with no authz realms.
* @throws Exception Unexpected test failure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import com.nimbusds.openid.connect.sdk.Nonce;

import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.settings.MockSecureSettings;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
Expand Down Expand Up @@ -46,6 +48,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.HexFormat;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;
Expand Down Expand Up @@ -290,7 +293,7 @@ protected JwtRealmSettingsBuilder createJwtRealmSettingsBuilder(final JwtIssuer
if (randomBoolean()) {
authcSettings.put(
RealmSettings.getFullSettingKey(authcRealmName, JwtRealmSettings.JWT_CACHE_TTL),
randomIntBetween(10, 120) + randomFrom("s", "m", "h")
randomIntBetween(10, 120) + randomFrom("m", "h")
);
}
authcSettings.put(RealmSettings.getFullSettingKey(authcRealmName, JwtRealmSettings.JWT_CACHE_SIZE), jwtCacheSize);
Expand Down Expand Up @@ -378,11 +381,12 @@ protected void doMultipleAuthcAuthzAndVerifySuccess(
final int jwtAuthcRepeats
) {
final List<JwtRealm> jwtRealmsList = jwtIssuerAndRealms.stream().map(p -> p.realm).toList();

BytesArray firstCacheKeyFound = null;
// Select different test JWKs from the JWT realm, and generate test JWTs for the test user. Run the JWT through the chain.
for (int authcRun = 1; authcRun <= jwtAuthcRepeats; authcRun++) {

final ThreadContext requestThreadContext = createThreadContext(jwt, sharedSecret);
logger.info("REQ[" + authcRun + "/" + jwtAuthcRepeats + "] HEADERS=" + requestThreadContext.getHeaders());
logger.debug("REQ[" + authcRun + "/" + jwtAuthcRepeats + "] HEADERS=" + requestThreadContext.getHeaders());

// Any JWT realm can recognize and extract the request headers.
final var jwtAuthenticationToken = (JwtAuthenticationToken) randomFrom(jwtRealmsList).token(requestThreadContext);
Expand All @@ -393,11 +397,11 @@ protected void doMultipleAuthcAuthzAndVerifySuccess(
// Loop through all authc/authz realms. Confirm user is returned with expected principal and roles.
User authenticatedUser = null;
realmLoop: for (final JwtRealm candidateJwtRealm : jwtRealmsList) {
logger.info("TRY AUTHC: expected=[" + jwtRealm.name() + "], candidate[" + candidateJwtRealm.name() + "].");
logger.debug("TRY AUTHC: expected=[" + jwtRealm.name() + "], candidate[" + candidateJwtRealm.name() + "].");
final PlainActionFuture<AuthenticationResult<User>> authenticateFuture = PlainActionFuture.newFuture();
candidateJwtRealm.authenticate(jwtAuthenticationToken, authenticateFuture);
final AuthenticationResult<User> authenticationResult = authenticateFuture.actionGet();
logger.info("Authentication result with realm [{}]: [{}]", candidateJwtRealm.name(), authenticationResult);
logger.debug("Authentication result with realm [{}]: [{}]", candidateJwtRealm.name(), authenticationResult);
switch (authenticationResult.getStatus()) {
case SUCCESS:
assertThat("Unexpected realm SUCCESS status", candidateJwtRealm.name(), equalTo(jwtRealm.name()));
Expand Down Expand Up @@ -430,20 +434,41 @@ protected void doMultipleAuthcAuthzAndVerifySuccess(
equalTo(Map.of("jwt_token_type", JwtRealmInspector.getTokenType(jwtRealm).value()))
);
}
// if the cache is enabled ensure the cache is used and does not change for the provided jwt
if (jwtRealm.getJwtCache() != null) {
Cache<BytesArray, JwtRealm.ExpiringUser> cache = jwtRealm.getJwtCache();
if (firstCacheKeyFound == null) {
assertNotNull("could not find cache keys", cache.keys());
firstCacheKeyFound = cache.keys().iterator().next();
}
jwtAuthenticationToken.clearCredentials(); // simulates the realm's context closing which clears the credential
boolean foundInCache = false;
for (BytesArray key : cache.keys()) {
logger.trace("cache key: " + HexFormat.of().formatHex(key.array()));
if (key.equals(firstCacheKeyFound)) {
foundInCache = true;
}
assertFalse(
"cache key should not be nulled out",
IntStream.range(0, key.array().length).map(idx -> key.array()[idx]).allMatch(b -> b == 0)
);
}
assertTrue("cache key was not found in cache", foundInCache);
}
}
logger.info("Test succeeded");
logger.debug("Test succeeded");
}

protected User randomUser(final JwtIssuer jwtIssuer) {
final User user = randomFrom(jwtIssuer.principals.values());
logger.info("USER[" + user.principal() + "]: roles=[" + String.join(",", user.roles()) + "].");
logger.debug("USER[" + user.principal() + "]: roles=[" + String.join(",", user.roles()) + "].");
return user;
}

protected SecureString randomJwt(final JwtIssuerAndRealm jwtIssuerAndRealm, User user) throws Exception {
final JwtIssuer.AlgJwkPair algJwkPair = randomFrom(jwtIssuerAndRealm.issuer.algAndJwksAll);
final JWK jwk = algJwkPair.jwk();
logger.info(
logger.debug(
"ALG["
+ algJwkPair.alg()
+ "]. JWK: kty=["
Expand Down Expand Up @@ -491,7 +516,7 @@ protected void printJwtRealmAndIssuer(JwtIssuerAndRealm jwtIssuerAndRealm) throw
}

protected void printJwtRealm(final JwtRealm jwtRealm) {
logger.info(
logger.debug(
"REALM["
+ jwtRealm.name()
+ ","
Expand Down Expand Up @@ -527,15 +552,15 @@ protected void printJwtRealm(final JwtRealm jwtRealm) {
+ "]."
);
for (final JWK jwk : JwtRealmInspector.getJwksAlgsHmac(jwtRealm).jwks()) {
logger.info("REALM HMAC: jwk=[{}]", jwk);
logger.debug("REALM HMAC: jwk=[{}]", jwk);
}
for (final JWK jwk : JwtRealmInspector.getJwksAlgsPkc(jwtRealm).jwks()) {
logger.info("REALM PKC: jwk=[{}]", jwk);
logger.debug("REALM PKC: jwk=[{}]", jwk);
}
}

protected void printJwtIssuer(final JwtIssuer jwtIssuer) {
logger.info(
logger.debug(
"ISSUER: iss=["
+ jwtIssuer.issuerClaimValue
+ "], aud=["
Expand All @@ -549,13 +574,13 @@ protected void printJwtIssuer(final JwtIssuer jwtIssuer) {
+ "]."
);
if (jwtIssuer.algAndJwkHmacOidc != null) {
logger.info("ISSUER HMAC OIDC: alg=[{}] jwk=[{}]", jwtIssuer.algAndJwkHmacOidc.alg(), jwtIssuer.encodedKeyHmacOidc);
logger.debug("ISSUER HMAC OIDC: alg=[{}] jwk=[{}]", jwtIssuer.algAndJwkHmacOidc.alg(), jwtIssuer.encodedKeyHmacOidc);
}
for (final JwtIssuer.AlgJwkPair pair : jwtIssuer.algAndJwksHmac) {
logger.info("ISSUER HMAC: alg=[{}] jwk=[{}]", pair.alg(), pair.jwk());
logger.debug("ISSUER HMAC: alg=[{}] jwk=[{}]", pair.alg(), pair.jwk());
}
for (final JwtIssuer.AlgJwkPair pair : jwtIssuer.algAndJwksPkc) {
logger.info("ISSUER PKC: alg=[{}] jwk=[{}]", pair.alg(), pair.jwk());
logger.debug("ISSUER PKC: alg=[{}] jwk=[{}]", pair.alg(), pair.jwk());
}
}
}