Skip to content
This repository has been archived by the owner on Aug 28, 2024. It is now read-only.

Commit

Permalink
skip AAD internal filter when is already authenticated or token not i…
Browse files Browse the repository at this point in the history
…ssued by AAD (#872)

Co-authored-by: Xiaolu Dai <xiada@microsoft.com>
  • Loading branch information
saragluna and saragluna authored May 14, 2020
1 parent 7a2cb96 commit 1041b83
Show file tree
Hide file tree
Showing 5 changed files with 231 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,27 +52,8 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
final String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION);
boolean cleanupRequired = false;

if (hasText(authHeader) && authHeader.startsWith(TOKEN_TYPE)) {
try {
final String token = authHeader.replace(TOKEN_TYPE, "");
final UserPrincipal principal = principalManager.buildUserPrincipal(token);
final JSONArray roles = Optional.ofNullable((JSONArray) principal.getClaims().get("roles"))
.filter(r -> !r.isEmpty())
.orElse(DEFAULT_ROLE_CLAIM);
final Authentication authentication = new PreAuthenticatedAuthenticationToken(
principal, null, rolesToGrantedAuthorities(roles));
authentication.setAuthenticated(true);
log.info("Request token verification success. {}", authentication);
SecurityContextHolder.getContext().setAuthentication(authentication);
cleanupRequired = true;
} catch (BadJWTException ex) {
final String errorMessage = "Invalid JWT. Either expired or not yet valid. " + ex.getMessage();
log.warn(errorMessage);
throw new ServletException(errorMessage, ex);
} catch (ParseException | BadJOSEException | JOSEException ex) {
log.error("Failed to initialize UserPrincipal.", ex);
throw new ServletException(ex);
}
if (!alreadyAuthenticated() && hasText(authHeader) && authHeader.startsWith(TOKEN_TYPE)) {
cleanupRequired = verifyToken(authHeader.replace(TOKEN_TYPE, ""));
}

try {
Expand All @@ -85,6 +66,39 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
}
}

private boolean verifyToken(String token) throws ServletException {
if (!principalManager.isTokenIssuedByAAD(token)) {
log.info("Token {} is not issued by AAD", token);
return false;
}

try {
final UserPrincipal principal = principalManager.buildUserPrincipal(token);
final JSONArray roles = Optional.ofNullable((JSONArray) principal.getClaims().get("roles"))
.filter(r -> !r.isEmpty())
.orElse(DEFAULT_ROLE_CLAIM);

final Authentication authentication = new PreAuthenticatedAuthenticationToken(
principal, null, rolesToGrantedAuthorities(roles));
authentication.setAuthenticated(true);
log.info("Request token verification success. {}", authentication);
SecurityContextHolder.getContext().setAuthentication(authentication);
return true;
} catch (BadJWTException ex) {
final String errorMessage = "Invalid JWT. Either expired or not yet valid. " + ex.getMessage();
log.warn(errorMessage);
throw new ServletException(errorMessage, ex);
} catch (ParseException | BadJOSEException | JOSEException ex) {
log.error("Failed to initialize UserPrincipal.", ex);
throw new ServletException(ex);
}
}

private boolean alreadyAuthenticated() {
final Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
return authentication != null && authentication.isAuthenticated();
}

protected Set<SimpleGrantedAuthority> rolesToGrantedAuthorities(JSONArray roles) {
return roles.stream()
.filter(Objects::nonNull)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationToken;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter;

import javax.naming.ServiceUnavailableException;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import java.io.IOException;
import java.net.MalformedURLException;
import java.text.ParseException;
Expand All @@ -43,80 +45,93 @@ public class AADAuthenticationFilter extends OncePerRequestFilter {
public AADAuthenticationFilter(AADAuthenticationProperties aadAuthProps,
ServiceEndpointsProperties serviceEndpointsProps,
ResourceRetriever resourceRetriever) {
this.aadAuthProps = aadAuthProps;
this.serviceEndpointsProps = serviceEndpointsProps;
this.principalManager = new UserPrincipalManager(serviceEndpointsProps, aadAuthProps, resourceRetriever, false);
this(aadAuthProps, serviceEndpointsProps, new UserPrincipalManager(serviceEndpointsProps,
aadAuthProps,
resourceRetriever,
false));
}

public AADAuthenticationFilter(AADAuthenticationProperties aadAuthProps,
ServiceEndpointsProperties serviceEndpointsProps,
ResourceRetriever resourceRetriever,
JWKSetCache jwkSetCache) {
this.aadAuthProps = aadAuthProps;
this.serviceEndpointsProps = serviceEndpointsProps;
this.principalManager = new UserPrincipalManager(serviceEndpointsProps,
this(aadAuthProps, serviceEndpointsProps, new UserPrincipalManager(serviceEndpointsProps,
aadAuthProps,
resourceRetriever,
false,
jwkSetCache);
jwkSetCache));
}

public AADAuthenticationFilter(AADAuthenticationProperties aadAuthProps,
ServiceEndpointsProperties serviceEndpointsProps,
UserPrincipalManager userPrincipalManager) {
this.aadAuthProps = aadAuthProps;
this.serviceEndpointsProps = serviceEndpointsProps;
this.principalManager = userPrincipalManager;
}

@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
FilterChain filterChain) throws ServletException, IOException {
final String authHeader = request.getHeader(TOKEN_HEADER);

if (authHeader != null && authHeader.startsWith(TOKEN_TYPE)) {
try {
final String idToken = authHeader.replace(TOKEN_TYPE, "");
UserPrincipal principal = (UserPrincipal) request
.getSession().getAttribute(CURRENT_USER_PRINCIPAL);
String graphApiToken = (String) request
.getSession().getAttribute(CURRENT_USER_PRINCIPAL_GRAPHAPI_TOKEN);
final String currentToken = (String) request
.getSession().getAttribute(CURRENT_USER_PRINCIPAL_JWT_TOKEN);

final AzureADGraphClient client = new AzureADGraphClient(aadAuthProps.getClientId(),
aadAuthProps.getClientSecret(), aadAuthProps, serviceEndpointsProps);

if (principal == null ||
graphApiToken == null ||
graphApiToken.isEmpty() ||
!idToken.equals(currentToken)
) {
principal = principalManager.buildUserPrincipal(idToken);

final String tenantId = principal.getClaim().toString();
graphApiToken = client.acquireTokenForGraphApi(idToken, tenantId).accessToken();

principal.setUserGroups(client.getGroups(graphApiToken));

request.getSession().setAttribute(CURRENT_USER_PRINCIPAL, principal);
request.getSession().setAttribute(CURRENT_USER_PRINCIPAL_GRAPHAPI_TOKEN, graphApiToken);
request.getSession().setAttribute(CURRENT_USER_PRINCIPAL_JWT_TOKEN, idToken);
}

final Authentication authentication = new PreAuthenticatedAuthenticationToken(
principal, null, client.convertGroupsToGrantedAuthorities(principal.getUserGroups()));

authentication.setAuthenticated(true);
log.info("Request token verification success. {}", authentication);
SecurityContextHolder.getContext().setAuthentication(authentication);
} catch (MalformedURLException | ParseException | BadJOSEException | JOSEException ex) {
log.error("Failed to initialize UserPrincipal.", ex);
throw new ServletException(ex);
} catch (ServiceUnavailableException ex) {
log.error("Failed to acquire graph api token.", ex);
throw new ServletException(ex);
} catch (MsalServiceException ex) {
if (ex.claims() != null && !ex.claims().isEmpty()) {
throw new ServletException("Handle conditional access policy", ex);
} else {
throw ex;
}
}
if (!alreadyAuthenticated() && authHeader != null && authHeader.startsWith(TOKEN_TYPE)) {
verifyToken(request.getSession(), authHeader.replace(TOKEN_TYPE, ""));
}

filterChain.doFilter(request, response);
}

private boolean alreadyAuthenticated() {
final Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
return authentication != null && authentication.isAuthenticated();
}

private void verifyToken(HttpSession session, String token) throws IOException, ServletException {
if (!principalManager.isTokenIssuedByAAD(token)) {
log.info("Token {} is not issued by AAD", token);
return;
}

try {
final String currentToken = (String) session.getAttribute(CURRENT_USER_PRINCIPAL_JWT_TOKEN);
UserPrincipal principal = (UserPrincipal) session.getAttribute(CURRENT_USER_PRINCIPAL);
String graphApiToken = (String) session.getAttribute(CURRENT_USER_PRINCIPAL_GRAPHAPI_TOKEN);

final AzureADGraphClient client = new AzureADGraphClient(aadAuthProps.getClientId(),
aadAuthProps.getClientSecret(), aadAuthProps, serviceEndpointsProps);

if (principal == null || graphApiToken == null || graphApiToken.isEmpty() || !token.equals(currentToken)) {
principal = principalManager.buildUserPrincipal(token);

final String tenantId = principal.getClaim().toString();
graphApiToken = client.acquireTokenForGraphApi(token, tenantId).accessToken();

principal.setUserGroups(client.getGroups(graphApiToken));

session.setAttribute(CURRENT_USER_PRINCIPAL, principal);
session.setAttribute(CURRENT_USER_PRINCIPAL_GRAPHAPI_TOKEN, graphApiToken);
session.setAttribute(CURRENT_USER_PRINCIPAL_JWT_TOKEN, token);
}

final Authentication authentication = new PreAuthenticatedAuthenticationToken(
principal, null, client.convertGroupsToGrantedAuthorities(principal.getUserGroups()));

authentication.setAuthenticated(true);
log.info("Request token verification success. {}", authentication);
SecurityContextHolder.getContext().setAuthentication(authentication);
} catch (MalformedURLException | ParseException | BadJOSEException | JOSEException ex) {
log.error("Failed to initialize UserPrincipal.", ex);
throw new ServletException(ex);
} catch (ServiceUnavailableException ex) {
log.error("Failed to acquire graph api token.", ex);
throw new ServletException(ex);
} catch (MsalServiceException ex) {
if (ex.claims() != null && !ex.claims().isEmpty()) {
throw new ServletException("Handle conditional access policy", ex);
} else {
throw ex;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,22 @@
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jose.util.ResourceRetriever;
import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import lombok.extern.slf4j.Slf4j;

import com.nimbusds.jwt.JWTParser;
import com.nimbusds.jwt.proc.BadJWTException;
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier;
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import com.nimbusds.jwt.proc.JWTClaimsSetVerifier;
import lombok.extern.slf4j.Slf4j;

import java.net.MalformedURLException;
import java.net.URL;
import java.text.ParseException;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;

@Slf4j
public class UserPrincipalManager {
Expand Down Expand Up @@ -130,6 +131,24 @@ public UserPrincipal buildUserPrincipal(String idToken) throws ParseException, J
return new UserPrincipal(jwsObject, jwtClaimsSet);
}

public boolean isTokenIssuedByAAD(String token) {
try {
final JWT jwt = JWTParser.parse(token);
return isAADIssuer(jwt.getJWTClaimsSet().getIssuer());
} catch (ParseException e) {
log.info("Fail to parse JWT {}, exception {}", token, e);
}
return false;
}

private static boolean isAADIssuer(String issuer) {
if (issuer == null) {
return false;
}
return issuer.startsWith(LOGIN_MICROSOFT_ONLINE_ISSUER) || issuer.startsWith(STS_WINDOWS_ISSUER)
|| issuer.startsWith(STS_CHINA_CLOUD_API_ISSUER);
}

private ConfigurableJWTProcessor<SecurityContext> getAadJwtTokenValidator(JWSAlgorithm jwsAlgorithm) {
final ConfigurableJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();

Expand All @@ -143,9 +162,7 @@ private ConfigurableJWTProcessor<SecurityContext> getAadJwtTokenValidator(JWSAlg
public void verify(JWTClaimsSet claimsSet, SecurityContext ctx) throws BadJWTException {
super.verify(claimsSet, ctx);
final String issuer = claimsSet.getIssuer();
if (issuer == null || !(issuer.startsWith(LOGIN_MICROSOFT_ONLINE_ISSUER)
|| issuer.startsWith(STS_WINDOWS_ISSUER)
|| issuer.startsWith(STS_CHINA_CLOUD_API_ISSUER))) {
if (!isAADIssuer(issuer)) {
throw new BadJWTException("Invalid token issuer");
}
if (explicitAudienceCheck) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -47,7 +48,7 @@

public class AADAppRoleAuthenticationFilterTest {

public static final String TOKEN = "dummy-token";
private static final String TOKEN = "dummy-token";

private final UserPrincipalManager userPrincipalManager;
private final HttpServletRequest request;
Expand Down Expand Up @@ -84,12 +85,14 @@ public void testDoFilterGoodCase()

when(request.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn("Bearer " + TOKEN);
when(userPrincipalManager.buildUserPrincipal(TOKEN)).thenReturn(dummyPrincipal);
when(userPrincipalManager.isTokenIssuedByAAD(TOKEN)).thenReturn(true);

// Check in subsequent filter that authentication is available!
final FilterChain filterChain = new FilterChain() {
@Override
public void doFilter(ServletRequest request, ServletResponse response)
throws IOException, ServletException {

final SecurityContext context = SecurityContextHolder.getContext();
assertNotNull(context);
final Authentication authentication = context.getAuthentication();
Expand All @@ -113,6 +116,7 @@ public void testDoFilterShouldRethrowJWTException()

when(request.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn("Bearer " + TOKEN);
when(userPrincipalManager.buildUserPrincipal(any())).thenThrow(new BadJWTException("bad token"));
when(userPrincipalManager.isTokenIssuedByAAD(any())).thenReturn(true);

filter.doFilterInternal(request, response, mock(FilterChain.class));
}
Expand All @@ -125,6 +129,7 @@ public void testDoFilterAddsDefaultRole()

when(request.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn("Bearer " + TOKEN);
when(userPrincipalManager.buildUserPrincipal(TOKEN)).thenReturn(dummyPrincipal);
when(userPrincipalManager.isTokenIssuedByAAD(TOKEN)).thenReturn(true);

// Check in subsequent filter that authentication is available and default roles are filled.
final FilterChain filterChain = new FilterChain() {
Expand Down Expand Up @@ -158,4 +163,39 @@ public void testRolesToGrantedAuthoritiesShouldConvertRolesAndFilterNulls() {
new SimpleGrantedAuthority("ROLE_ADMIN")));
}

@Test
public void testTokenNotIssuedByAAD() throws ServletException, IOException {
when(userPrincipalManager.isTokenIssuedByAAD(TOKEN)).thenReturn(false);

final FilterChain filterChain = (request, response) -> {
final SecurityContext context = SecurityContextHolder.getContext();
assertNotNull(context);
final Authentication authentication = context.getAuthentication();
assertNull(authentication);
};

filter.doFilterInternal(request, response, filterChain);
}

@Test
public void testAlreadyAuthenticated() throws ServletException, IOException, ParseException, JOSEException,
BadJOSEException {
final Authentication authentication = mock(Authentication.class);
when(authentication.isAuthenticated()).thenReturn(true);
when(userPrincipalManager.isTokenIssuedByAAD(TOKEN)).thenReturn(true);

SecurityContextHolder.getContext().setAuthentication(authentication);

final FilterChain filterChain = (request, response) -> {
final SecurityContext context = SecurityContextHolder.getContext();
assertNotNull(context);
assertNotNull(context.getAuthentication());
SecurityContextHolder.clearContext();
};

filter.doFilterInternal(request, response, filterChain);
verify(userPrincipalManager, times(0)).buildUserPrincipal(TOKEN);

}

}
Loading

0 comments on commit 1041b83

Please sign in to comment.