Skip to content

Commit

Permalink
Fix memory leak from JWT cache (and fix the usage of the JWT auth cac…
Browse files Browse the repository at this point in the history
…he) (elastic#101799)

This commit fixes a memory leak and ensures that the JWT authentication cache is actually used to short circuit authenticate. 

When JWT authentication is successful we cache a hash(JWT) -> User. On subsequent authentication attempts we should 
short circuit some of the more expensive validation thanks to this cache. The existing cache key is a `BytesArray` constructed 
from `jwtAuthenticationToken.getUserCredentialsHash()` (the hash value of the JWT). However,  `jwtAuthenticationToken` is 
comes from the security context which is effectively a request scoped object. When the security context goes out of scope the 
value of `jwtAuthenticationToken.getUserCredentialsHash()` is zero'ed out to help keep sensitive data out of the heap. It is 
arguable if zero'ing out that data is  useful especially for a hashed value, but is inline with the internal contract/expectations. 
Since the cache key is derived from `jwtAuthenticationToken.getUserCredentialsHash()` when that get object is zero'ed out, 
so is the cache key. 

This results in a cache key that changes from a valid value to an empty byte array. This results in a junk  cache entry.  Subsequent 
authentication with the same JWT will result in a new cache entry which will then follow the same pattern of getting zero'ed out. 
This results in a useless cache with nothing but zero'ed out cache keys. This negates any benefits of having a cache at all in that 
a full authentication is preformed all the time which can be expensive for JWT (especially since JWT requires role mappings 
and role mappings are not cached). 

Fortunately the default cache size is 100k (by count) so the actual memory leak is technically capped but can vary depending 
on how large the cache values. This is an approximate cap of ~55MB where 3.125 MB (100k * 256 bits) for the 
sha256 cache keys + ~50MB (100k * ~50 bytes) cache values, however, it is possible for the cache to be larger if the values in 
the cache are larger. 

The fix here is to ensure the cache key used is a copy of the value that is zero'ed out (before it is zero'ed out).

fixes: elastic#101752
  • Loading branch information
jakelandis authored Nov 6, 2023
1 parent 63f29d4 commit cef2b80
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 16 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/101799.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 101799
summary: Fix memory leak from JWT cache (and fix the usage of the JWT auth cache)
area: Authentication
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesArray;
Expand Down Expand Up @@ -283,7 +284,9 @@ public void authenticate(final AuthenticationToken authenticationToken, final Ac
return; // FAILED (secret is missing or mismatched)
}

final BytesArray jwtCacheKey = isCacheEnabled() ? new BytesArray(jwtAuthenticationToken.getUserCredentialsHash()) : null;
final BytesArray jwtCacheKey = isCacheEnabled()
? new BytesArray(new BytesRef(jwtAuthenticationToken.getUserCredentialsHash()), true)
: null;
if (jwtCacheKey != null) {
final User cachedUser = tryAuthenticateWithCache(tokenPrincipal, jwtCacheKey);
if (cachedUser != null) {
Expand Down Expand Up @@ -483,6 +486,11 @@ private boolean isCacheEnabled() {
return jwtCache != null && jwtCacheHelper != null;
}

// package private for testing
Cache<BytesArray, ExpiringUser> getJwtCache() {
return jwtCache;
}

/**
* Format and filter JWT contents as user metadata.
* @param claimsSet Claims are supported. Claim keys are prefixed by "jwt_claim_".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,26 @@ public void testJwtAuthcRealmAuthcAuthzWithEmptyRoles() throws Exception {
doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, clientSecret, jwtAuthcCount);
}

public void testJwtCache() throws Exception {
jwtIssuerAndRealms = generateJwtIssuerRealmPairs(1, 1, 1, 1, 1, 1, 99, false);
JwtRealm realm = jwtIssuerAndRealms.get(0).realm();
realm.expireAll();
assertThat(realm.getJwtCache().count(), is(0));
final JwtIssuerAndRealm jwtIssuerAndRealm = randomJwtIssuerRealmPair();
final SecureString clientSecret = JwtRealmInspector.getClientAuthenticationSharedSecret(jwtIssuerAndRealm.realm());
for (int i = 1; i <= randomIntBetween(2, 10); i++) {
User user = randomUser(jwtIssuerAndRealm.issuer());
doMultipleAuthcAuthzAndVerifySuccess(
jwtIssuerAndRealm.realm(),
user,
randomJwt(jwtIssuerAndRealm, user),
clientSecret,
randomIntBetween(2, 10)
);
assertThat(realm.getJwtCache().count(), is(i));
}
}

/**
* Test with no authz realms.
* @throws Exception Unexpected test failure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import com.nimbusds.openid.connect.sdk.Nonce;

import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.settings.MockSecureSettings;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
Expand Down Expand Up @@ -46,6 +48,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.HexFormat;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;
Expand Down Expand Up @@ -290,7 +293,7 @@ protected JwtRealmSettingsBuilder createJwtRealmSettingsBuilder(final JwtIssuer
if (randomBoolean()) {
authcSettings.put(
RealmSettings.getFullSettingKey(authcRealmName, JwtRealmSettings.JWT_CACHE_TTL),
randomIntBetween(10, 120) + randomFrom("s", "m", "h")
randomIntBetween(10, 120) + randomFrom("m", "h")
);
}
authcSettings.put(RealmSettings.getFullSettingKey(authcRealmName, JwtRealmSettings.JWT_CACHE_SIZE), jwtCacheSize);
Expand Down Expand Up @@ -378,11 +381,12 @@ protected void doMultipleAuthcAuthzAndVerifySuccess(
final int jwtAuthcRepeats
) {
final List<JwtRealm> jwtRealmsList = jwtIssuerAndRealms.stream().map(p -> p.realm).toList();

BytesArray firstCacheKeyFound = null;
// Select different test JWKs from the JWT realm, and generate test JWTs for the test user. Run the JWT through the chain.
for (int authcRun = 1; authcRun <= jwtAuthcRepeats; authcRun++) {

final ThreadContext requestThreadContext = createThreadContext(jwt, sharedSecret);
logger.info("REQ[" + authcRun + "/" + jwtAuthcRepeats + "] HEADERS=" + requestThreadContext.getHeaders());
logger.debug("REQ[" + authcRun + "/" + jwtAuthcRepeats + "] HEADERS=" + requestThreadContext.getHeaders());

// Any JWT realm can recognize and extract the request headers.
final var jwtAuthenticationToken = (JwtAuthenticationToken) randomFrom(jwtRealmsList).token(requestThreadContext);
Expand All @@ -393,11 +397,11 @@ protected void doMultipleAuthcAuthzAndVerifySuccess(
// Loop through all authc/authz realms. Confirm user is returned with expected principal and roles.
User authenticatedUser = null;
realmLoop: for (final JwtRealm candidateJwtRealm : jwtRealmsList) {
logger.info("TRY AUTHC: expected=[" + jwtRealm.name() + "], candidate[" + candidateJwtRealm.name() + "].");
logger.debug("TRY AUTHC: expected=[" + jwtRealm.name() + "], candidate[" + candidateJwtRealm.name() + "].");
final PlainActionFuture<AuthenticationResult<User>> authenticateFuture = PlainActionFuture.newFuture();
candidateJwtRealm.authenticate(jwtAuthenticationToken, authenticateFuture);
final AuthenticationResult<User> authenticationResult = authenticateFuture.actionGet();
logger.info("Authentication result with realm [{}]: [{}]", candidateJwtRealm.name(), authenticationResult);
logger.debug("Authentication result with realm [{}]: [{}]", candidateJwtRealm.name(), authenticationResult);
switch (authenticationResult.getStatus()) {
case SUCCESS:
assertThat("Unexpected realm SUCCESS status", candidateJwtRealm.name(), equalTo(jwtRealm.name()));
Expand Down Expand Up @@ -430,20 +434,41 @@ protected void doMultipleAuthcAuthzAndVerifySuccess(
equalTo(Map.of("jwt_token_type", JwtRealmInspector.getTokenType(jwtRealm).value()))
);
}
// if the cache is enabled ensure the cache is used and does not change for the provided jwt
if (jwtRealm.getJwtCache() != null) {
Cache<BytesArray, JwtRealm.ExpiringUser> cache = jwtRealm.getJwtCache();
if (firstCacheKeyFound == null) {
assertNotNull("could not find cache keys", cache.keys());
firstCacheKeyFound = cache.keys().iterator().next();
}
jwtAuthenticationToken.clearCredentials(); // simulates the realm's context closing which clears the credential
boolean foundInCache = false;
for (BytesArray key : cache.keys()) {
logger.trace("cache key: " + HexFormat.of().formatHex(key.array()));
if (key.equals(firstCacheKeyFound)) {
foundInCache = true;
}
assertFalse(
"cache key should not be nulled out",
IntStream.range(0, key.array().length).map(idx -> key.array()[idx]).allMatch(b -> b == 0)
);
}
assertTrue("cache key was not found in cache", foundInCache);
}
}
logger.info("Test succeeded");
logger.debug("Test succeeded");
}

protected User randomUser(final JwtIssuer jwtIssuer) {
final User user = randomFrom(jwtIssuer.principals.values());
logger.info("USER[" + user.principal() + "]: roles=[" + String.join(",", user.roles()) + "].");
logger.debug("USER[" + user.principal() + "]: roles=[" + String.join(",", user.roles()) + "].");
return user;
}

protected SecureString randomJwt(final JwtIssuerAndRealm jwtIssuerAndRealm, User user) throws Exception {
final JwtIssuer.AlgJwkPair algJwkPair = randomFrom(jwtIssuerAndRealm.issuer.algAndJwksAll);
final JWK jwk = algJwkPair.jwk();
logger.info(
logger.debug(
"ALG["
+ algJwkPair.alg()
+ "]. JWK: kty=["
Expand Down Expand Up @@ -491,7 +516,7 @@ protected void printJwtRealmAndIssuer(JwtIssuerAndRealm jwtIssuerAndRealm) throw
}

protected void printJwtRealm(final JwtRealm jwtRealm) {
logger.info(
logger.debug(
"REALM["
+ jwtRealm.name()
+ ","
Expand Down Expand Up @@ -527,15 +552,15 @@ protected void printJwtRealm(final JwtRealm jwtRealm) {
+ "]."
);
for (final JWK jwk : JwtRealmInspector.getJwksAlgsHmac(jwtRealm).jwks()) {
logger.info("REALM HMAC: jwk=[{}]", jwk);
logger.debug("REALM HMAC: jwk=[{}]", jwk);
}
for (final JWK jwk : JwtRealmInspector.getJwksAlgsPkc(jwtRealm).jwks()) {
logger.info("REALM PKC: jwk=[{}]", jwk);
logger.debug("REALM PKC: jwk=[{}]", jwk);
}
}

protected void printJwtIssuer(final JwtIssuer jwtIssuer) {
logger.info(
logger.debug(
"ISSUER: iss=["
+ jwtIssuer.issuerClaimValue
+ "], aud=["
Expand All @@ -549,13 +574,13 @@ protected void printJwtIssuer(final JwtIssuer jwtIssuer) {
+ "]."
);
if (jwtIssuer.algAndJwkHmacOidc != null) {
logger.info("ISSUER HMAC OIDC: alg=[{}] jwk=[{}]", jwtIssuer.algAndJwkHmacOidc.alg(), jwtIssuer.encodedKeyHmacOidc);
logger.debug("ISSUER HMAC OIDC: alg=[{}] jwk=[{}]", jwtIssuer.algAndJwkHmacOidc.alg(), jwtIssuer.encodedKeyHmacOidc);
}
for (final JwtIssuer.AlgJwkPair pair : jwtIssuer.algAndJwksHmac) {
logger.info("ISSUER HMAC: alg=[{}] jwk=[{}]", pair.alg(), pair.jwk());
logger.debug("ISSUER HMAC: alg=[{}] jwk=[{}]", pair.alg(), pair.jwk());
}
for (final JwtIssuer.AlgJwkPair pair : jwtIssuer.algAndJwksPkc) {
logger.info("ISSUER PKC: alg=[{}] jwk=[{}]", pair.alg(), pair.jwk());
logger.debug("ISSUER PKC: alg=[{}] jwk=[{}]", pair.alg(), pair.jwk());
}
}
}

0 comments on commit cef2b80

Please sign in to comment.