Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KNOX-3019 - Allow token renewal without upper bound for non-expired tokens #880

Merged
merged 1 commit into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -270,12 +270,13 @@ public boolean isExpired(final JWTToken token) throws UnknownTokenException {
return getTokenExpiration(token) <= System.currentTimeMillis();
}

protected void setMaxLifetime(final String token, long parsedMaxLifeTime) {
maxTokenLifetimes.put(token, parsedMaxLifeTime);
protected void setMaxLifetime(final String token, long maxLifeTime) {
maxTokenLifetimes.put(token, maxLifeTime);
}

protected void setMaxLifetime(final String token, long issueTime, long maxLifetimeDuration) {
maxTokenLifetimes.put(token, issueTime + maxLifetimeDuration);
final long maxLifetime = maxLifetimeDuration < 0 ? maxLifetimeDuration : issueTime + maxLifetimeDuration;
setMaxLifetime(token, maxLifetime);
}

/**
Expand Down Expand Up @@ -313,8 +314,9 @@ private void removeTokenState(final Set<String> tokenIds) {
}

protected boolean hasRemainingRenewals(final String tokenId, long renewInterval) {
final long maximumTokenLifetime = getMaxLifetime(tokenId);
// If the current time + buffer + the renewal interval is less than the max lifetime for the token?
return ((System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(30) + renewInterval) < getMaxLifetime(tokenId));
return maximumTokenLifetime < 0 ? true : ((System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(30) + renewInterval) < maximumTokenLifetime);
}

protected long getMaxLifetime(final String tokenId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,10 @@ protected void updateExpiration(String tokenId, long expiration) {

@Override
protected long getMaxLifetime(String tokenId) {
long maxLifetime = super.getMaxLifetime(tokenId);
long maxLifetime = super.getMaxLifetime(tokenId); // returns 0, if not found in memory

// If there is no result from the in-memory collection, proceed to check the Database
if (maxLifetime < 1L) {
if (maxLifetime == 0L) {
try {
maxLifetime = tokenDatabase.getMaxLifetime(tokenId);
log.fetchedMaxLifetimeFromDatabase(Tokens.getTokenIDDisplayText(tokenId), maxLifetime);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ boolean addToken(String tokenId, long issueTime, long expiration, long maxLifeti
addTokenStatement.setString(1, tokenId);
addTokenStatement.setLong(2, issueTime);
addTokenStatement.setLong(3, expiration);
addTokenStatement.setLong(4, issueTime + maxLifetimeDuration);
addTokenStatement.setLong(4, maxLifetimeDuration < 0 ? maxLifetimeDuration : issueTime + maxLifetimeDuration);
return addTokenStatement.executeUpdate() == 1;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,15 @@ public void testAddToken() throws Exception {
assertEquals(issueTime + maxLifetimeDuration, getLongTokenAttributeFromDatabase(tokenId, TokenStateDatabase.GET_MAX_LIFETIME_SQL));
}

@Test
public void testAddTokenThatCanAlwaysBeRenewed() throws Exception {
final String tokenId = UUID.randomUUID().toString();
jdbcTokenStateService.addToken(tokenId, System.currentTimeMillis(), System.currentTimeMillis() + 30, -1);

assertEquals(-1L, jdbcTokenStateService.getMaxLifetime(tokenId));
assertEquals(-1L, getLongTokenAttributeFromDatabase(tokenId, TokenStateDatabase.GET_MAX_LIFETIME_SQL));
}

@Test
public void testAddTokensForMultipleUsers() throws Exception {
String user1 = "user1";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,12 @@
import static javax.ws.rs.core.MediaType.APPLICATION_XML;

/**
* @deprecated The public REST API endpoints in this class (bound to
* '/knoxtoken/v1/api/token/...') are no longer acceptable for
* token-related operations. Please use the
* '/knoxtoken/v2/api/token/...' path instead.
* Some of the public REST API endpoints in this class (bound to
* '/knoxtoken/v1/api/token/...') are no longer acceptable for token-related
* operations. Please use the '/knoxtoken/v2/api/token/...' path instead.
*
* @see TokenResourceV2
*/
@Deprecated
@Singleton
@Path(TokenResource.RESOURCE_PATH)
public class TokenResource {
Expand Down Expand Up @@ -206,7 +204,8 @@ public enum ErrorCode {
UNKNOWN_TOKEN(50),
ALREADY_DISABLED(60),
ALREADY_ENABLED(70),
DISABLED_KNOXSSO_COOKIE(80);
DISABLED_KNOXSSO_COOKIE(80),
TOKEN_EXPIRED(90);

private final int code;

Expand Down Expand Up @@ -549,13 +548,14 @@ public Response renew(String token) {
if (allowedRenewers.contains(renewer)) {
try {
JWTToken jwt = new JWTToken(token);
// If renewal fails, it should be an exception
expiration = tokenStateService.renewToken(jwt,
renewInterval.orElse(tokenStateService.getDefaultRenewInterval()));
log.renewedToken(getTopologyName(),
Tokens.getTokenDisplayText(token),
Tokens.getTokenIDDisplayText(TokenUtils.getTokenId(jwt)),
renewer);
if (tokenStateService.isExpired(jwt)) {
errorCode = ErrorCode.TOKEN_EXPIRED;
error = "Expired tokens must not be renewed.";
} else {
// If renewal fails, it should be an exception
expiration = tokenStateService.renewToken(jwt, renewInterval.orElse(tokenStateService.getDefaultRenewInterval()));
log.renewedToken(getTopologyName(), Tokens.getTokenDisplayText(token), Tokens.getTokenIDDisplayText(TokenUtils.getTokenId(jwt)), renewer);
}
} catch (ParseException e) {
log.invalidToken(getTopologyName(), Tokens.getTokenDisplayText(token), e);
errorCode = ErrorCode.INVALID_TOKEN;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
import org.apache.knox.gateway.services.ServiceLifecycleException;
import org.apache.knox.gateway.services.ServiceType;
import org.apache.knox.gateway.services.security.AliasService;
import org.apache.knox.gateway.services.security.AliasServiceException;
import org.apache.knox.gateway.services.security.token.JWTokenAttributes;
import org.apache.knox.gateway.services.security.token.JWTokenAuthority;
import org.apache.knox.gateway.services.security.token.KnoxToken;
Expand Down Expand Up @@ -784,6 +785,24 @@ public void testTokenRenewal_Enabled_WithConfigurableMaxTokenLifetime() throws E
assertEquals(10L, tss.getMaxLifetime(token) - tss.getIssueTime(token));
}

@Test
public void testTokenRenewalShouldFailOnExpiredTokens() throws Exception {
final long tokenTTL = 1;
final String renewer = "yarn";
final Map<String, String> contextExpectations = new HashMap<>();
contextExpectations.put(TokenStateService.CONFIG_SERVER_MANAGED, "true");
contextExpectations.put("knox.token.ttl", String.valueOf(tokenTTL));
contextExpectations.put("knox.token.renewer.whitelist", renewer);
Thread.sleep(tokenTTL + 10); // so that the token is expired
configureCommonExpectations(contextExpectations);
final TokenResource tokenResource = new TokenResource();
final String accessToken = getAccessToken(tokenResource);

final Response renewalResponse = requestTokenRenewal(tokenResource, accessToken, createTestSubject(renewer));
assertEquals(Response.Status.BAD_REQUEST, renewalResponse.getStatusInfo());
assertTrue(renewalResponse.getEntity().toString().contains("Expired tokens must not be renewed."));
assertTrue(renewalResponse.getEntity().toString().contains("\"code\": " + TokenResource.ErrorCode.TOKEN_EXPIRED.toInt()));
}

@Test
public void testTokenRevocation_ServerManagedStateNotConfigured() throws Exception {
Expand Down Expand Up @@ -1497,19 +1516,8 @@ private Map.Entry<TestTokenStateService, Response> doTestTokenLifecyle(final Tok

configureCommonExpectations(contextExpectations, gatewayLevelConfig);

TokenResource tr = new TokenResource();
tr.request = request;
tr.context = context;
tr.init();

// Request a token
Response retResponse = tr.doGet();
assertEquals(200, retResponse.getStatus());

// Parse the response
String retString = retResponse.getEntity().toString();
String accessToken = getTagValue(retString, "access_token");
assertNotNull(accessToken);
final TokenResource tr = new TokenResource();
final String accessToken = getAccessToken(tr);

Response response;
switch (operation) {
Expand All @@ -1526,6 +1534,22 @@ private Map.Entry<TestTokenStateService, Response> doTestTokenLifecyle(final Tok
return new AbstractMap.SimpleEntry<>(tss, response);
}

private String getAccessToken(TokenResource tokenResource) throws KeyLengthException, AliasServiceException, ServiceLifecycleException {
tokenResource.request = request;
tokenResource.context = context;
tokenResource.init();

// Request a token
final Response retResponse = tokenResource.doGet();
assertEquals(200, retResponse.getStatus());

// Parse the response
final String retString = retResponse.getEntity().toString();
final String accessToken = getTagValue(retString, "access_token");
assertNotNull(accessToken);
return accessToken;
}

private static Response requestTokenRenewal(final TokenResource tr, final String tokenData, final Subject caller) {
Response response;
if (caller != null) {
Expand Down Expand Up @@ -1681,6 +1705,10 @@ public void addToken(String tokenId, long issueTime, long expiration, long maxLi

@Override
public boolean isExpired(JWTToken token) {
try {
return getTokenExpiration(token) <= System.currentTimeMillis();
} catch (UnknownTokenException e) {
}
return false;
}

Expand Down Expand Up @@ -1719,7 +1747,7 @@ public long renewToken(String tokenId, long renewInterval) {

@Override
public long getTokenExpiration(JWT token) throws UnknownTokenException {
return 0;
return getTokenExpiration(TokenUtils.getTokenId(token));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public long getExpirationLong() {
}

public String getMaxLifetime() {
return KNOX_TOKEN_TS_FORMAT.get().format(new Date(maxLifetime));
return maxLifetime < 0 ? "Unbounded" : KNOX_TOKEN_TS_FORMAT.get().format(new Date(maxLifetime));
}

public long getMaxLifetimeLong() {
Expand Down
Loading