Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
jgrandja authored and mohammedBalhaddad committed Oct 12, 2020
1 parent ccbf157 commit 3c74c7e
Show file tree
Hide file tree
Showing 12 changed files with 430 additions and 635 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
*/
public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthenticationToken {
private static final long serialVersionUID = SpringSecurityCoreVersion2.SERIAL_VERSION_UID;
private RegisteredClient registeredClient;
private Authentication clientPrincipal;
private final RegisteredClient registeredClient;
private final Authentication clientPrincipal;
private final OAuth2AccessToken accessToken;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.jose.JoseHeader;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.Jwt;
Expand All @@ -39,12 +39,12 @@
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Base64;
Expand All @@ -54,6 +54,7 @@
* An {@link AuthenticationProvider} implementation for the OAuth 2.0 Authorization Code Grant.
*
* @author Joe Grandja
* @author Daniel Garnier-Moiroux
* @since 0.0.1
* @see OAuth2AuthorizationCodeAuthenticationToken
* @see OAuth2AccessTokenAuthenticationToken
Expand Down Expand Up @@ -91,12 +92,14 @@ public Authentication authenticate(Authentication authentication) throws Authent
(OAuth2AuthorizationCodeAuthenticationToken) authentication;

OAuth2ClientAuthenticationToken clientPrincipal = null;
RegisteredClient registeredClient = null;
RegisteredClient registeredClient;
if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(authorizationCodeAuthentication.getPrincipal().getClass())) {
clientPrincipal = (OAuth2ClientAuthenticationToken) authorizationCodeAuthentication.getPrincipal();
if (!clientPrincipal.isAuthenticated()) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
}
registeredClient = clientPrincipal.getRegisteredClient();
} else if (StringUtils.hasText(authorizationCodeAuthentication.getClientId())) {
// When the principal is a string, it is the clientId, REQUIRED for public clients
String clientId = authorizationCodeAuthentication.getClientId();
registeredClient = this.registeredClientRepository.findByClientId(clientId);
if (registeredClient == null) {
Expand All @@ -106,10 +109,6 @@ public Authentication authenticate(Authentication authentication) throws Authent
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
}

if (clientPrincipal != null && !clientPrincipal.isAuthenticated()) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
}

OAuth2Authorization authorization = this.authorizationService.findByToken(
authorizationCodeAuthentication.getCode(), TokenType.AUTHORIZATION_CODE);
if (authorization == null) {
Expand All @@ -118,24 +117,21 @@ public Authentication authenticate(Authentication authentication) throws Authent

OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
if (StringUtils.hasText(authorizationRequest.getRedirectUri()) &&
!authorizationRequest.getRedirectUri().equals(authorizationCodeAuthentication.getRedirectUri())) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}

if (!registeredClient.getClientId().equals(authorizationRequest.getClientId())) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}

if (StringUtils.hasText(authorizationRequest.getRedirectUri()) &&
!authorizationRequest.getRedirectUri().equals(authorizationCodeAuthentication.getRedirectUri())) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}

String codeChallenge;
Object codeChallengeParameter = authorizationRequest
// Validate PKCE parameters
String codeChallenge = (String) authorizationRequest
.getAdditionalParameters()
.get(PkceParameterNames.CODE_CHALLENGE);

if (codeChallengeParameter != null) {
codeChallenge = (String) codeChallengeParameter;

if (StringUtils.hasText(codeChallenge)) {
String codeChallengeMethod = (String) authorizationRequest
.getAdditionalParameters()
.get(PkceParameterNames.CODE_CHALLENGE_METHOD);
Expand All @@ -147,11 +143,10 @@ public Authentication authenticate(Authentication authentication) throws Authent
if (!pkceCodeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}
} else if (registeredClient.getClientSettings().requireProofKey()){
} else if (registeredClient.getClientSettings().requireProofKey()) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}


JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();

// TODO Allow configuration for issuer claim
Expand Down Expand Up @@ -189,24 +184,22 @@ public Authentication authenticate(Authentication authentication) throws Authent
new OAuth2AccessTokenAuthenticationToken(registeredClient, new OAuth2ClientAuthenticationToken(registeredClient), accessToken);
}

private boolean pkceCodeVerifierValid(String codeVerifier, String codeChallenge, String codeChallengeMethod) {
if (codeVerifier == null) {
private static boolean pkceCodeVerifierValid(String codeVerifier, String codeChallenge, String codeChallengeMethod) {
if (!StringUtils.hasText(codeVerifier)) {
return false;
} else if (codeChallengeMethod == null || codeChallengeMethod.equals("plain")) {
} else if (!StringUtils.hasText(codeChallengeMethod) || "plain".equals(codeChallengeMethod)) {
return codeVerifier.equals(codeChallenge);
} else if ("S256".equals(codeChallengeMethod)) {
try {
MessageDigest md = MessageDigest.getInstance("SHA-256");
byte[] digest = md.digest(codeVerifier.getBytes(StandardCharsets.US_ASCII));
String encodedVerifier = Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
return codeChallenge.equals(encodedVerifier);
} catch (NoSuchAlgorithmException e) {
return encodedVerifier.equals(codeChallenge);
} catch (NoSuchAlgorithmException ex) {
// It is unlikely that SHA-256 is not available on the server. If it is not available,
// there will likely be bigger issues as well. We default to SERVER_ERROR.
}
}

// Unsupported algorithm should be caught in OAuth2AuthorizationEndpointFilter
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
*
* @author Joe Grandja
* @author Madhu Bhat
* @author Daniel Garnier-Moiroux
* @since 0.0.1
* @see AbstractAuthenticationToken
* @see OAuth2AuthorizationCodeAuthenticationProvider
Expand All @@ -37,7 +38,7 @@
public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenticationToken {
private static final long serialVersionUID = SpringSecurityCoreVersion2.SERIAL_VERSION_UID;
private final String code;
private Authentication clientPrincipal;
private final Authentication clientPrincipal;
private final String clientId;
private final String redirectUri;
private final Map<String, Object> additionalParameters;
Expand All @@ -50,22 +51,21 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti
* @param redirectUri the redirect uri
* @param additionalParameters the additional parameters
*/
public OAuth2AuthorizationCodeAuthenticationToken(String code,
Authentication clientPrincipal, @Nullable String redirectUri,
Map<String, Object> additionalParameters) {
public OAuth2AuthorizationCodeAuthenticationToken(String code, Authentication clientPrincipal,
@Nullable String redirectUri, @Nullable Map<String, Object> additionalParameters) {
super(Collections.emptyList());
Assert.hasText(code, "code cannot be empty");
Assert.notNull(clientPrincipal, "clientPrincipal cannot be null");
this.code = code;
this.clientPrincipal = clientPrincipal;
this.clientId = OAuth2ClientAuthenticationToken.class.isAssignableFrom(this.clientPrincipal.getClass()) ?
(String) this.clientPrincipal.getPrincipal() :
null;
this.redirectUri = redirectUri;
this.additionalParameters = Collections.unmodifiableMap(additionalParameters != null ? additionalParameters : Collections.emptyMap());

if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(this.clientPrincipal.getClass())) {
this.clientId = (String) this.clientPrincipal.getPrincipal();
} else {
this.clientId = null;
}
this.additionalParameters = Collections.unmodifiableMap(
additionalParameters != null ?
additionalParameters :
Collections.emptyMap());
}

/**
Expand All @@ -76,16 +76,19 @@ public OAuth2AuthorizationCodeAuthenticationToken(String code,
* @param redirectUri the redirect uri
* @param additionalParameters the additional parameters
*/
public OAuth2AuthorizationCodeAuthenticationToken(String code,
String clientId, @Nullable String redirectUri,
Map<String, Object> additionalParameters) {
public OAuth2AuthorizationCodeAuthenticationToken(String code, String clientId,
@Nullable String redirectUri, @Nullable Map<String, Object> additionalParameters) {
super(Collections.emptyList());
Assert.hasText(code, "code cannot be empty");
Assert.hasText(clientId, "clientId cannot be empty");
this.code = code;
this.clientPrincipal = null;
this.clientId = clientId;
this.redirectUri = redirectUri;
this.additionalParameters = Collections.unmodifiableMap(additionalParameters != null ? additionalParameters : Collections.emptyMap());
this.additionalParameters = Collections.unmodifiableMap(
additionalParameters != null ?
additionalParameters :
Collections.emptyMap());
}

@Override
Expand All @@ -107,6 +110,15 @@ public String getCode() {
return this.code;
}

/**
* Returns the client identifier
*
* @return the client identifier
*/
public @Nullable String getClientId() {
return this.clientId;
}

/**
* Returns the redirect uri.
*
Expand All @@ -124,13 +136,4 @@ public String getCode() {
public Map<String, Object> getAdditionalParameters() {
return this.additionalParameters;
}

/**
* Returns the client id
*
* @return the client id
*/
public @Nullable String getClientId() {
return this.clientId;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
*
* @author Joe Grandja
* @author Paurav Munshi
* @author Daniel Garnier-Moiroux
* @since 0.0.1
* @see RegisteredClientRepository
* @see OAuth2AuthorizationService
Expand All @@ -74,12 +75,13 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
*/
public static final String DEFAULT_AUTHORIZATION_ENDPOINT_URI = "/oauth2/authorize";

private static final String PKCE_ERROR_URI = "https://tools.ietf.org/html/rfc7636#section-4.4.1";

private final RegisteredClientRepository registeredClientRepository;
private final OAuth2AuthorizationService authorizationService;
private final RequestMatcher authorizationEndpointMatcher;
private final StringKeyGenerator codeGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
private final String PKCE_ERROR_URI = "https://tools.ietf.org/html/rfc7636#section-4.4.1";

/**
* Constructs an {@code OAuth2AuthorizationEndpointFilter} using the provided parameters.
Expand Down Expand Up @@ -185,15 +187,16 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
return;
}

if (parameters.get(PkceParameterNames.CODE_CHALLENGE_METHOD) != null &&
parameters.get(PkceParameterNames.CODE_CHALLENGE_METHOD).size() > 1) {
String codeChallengeMethod = parameters.getFirst(PkceParameterNames.CODE_CHALLENGE_METHOD);
if (StringUtils.hasText(codeChallengeMethod) &&
parameters.get(PkceParameterNames.CODE_CHALLENGE_METHOD).size() != 1) {
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD, PKCE_ERROR_URI);
sendErrorResponse(request, response, error, stateParameter, redirectUri);
return;
}

String codeChallengeMethod = parameters.getFirst(PkceParameterNames.CODE_CHALLENGE_METHOD);
if (codeChallengeMethod != null && !Arrays.asList("plain", "S256").contains(codeChallengeMethod)) {
if (StringUtils.hasText(codeChallengeMethod) &&
(!"S256".equals(codeChallengeMethod) && !"plain".equals(codeChallengeMethod))) {
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD, PKCE_ERROR_URI);
sendErrorResponse(request, response, error, stateParameter, redirectUri);
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
*
* @author Joe Grandja
* @author Madhu Bhat
* @author Daniel Garnier-Moiroux
* @since 0.0.1
* @see AuthenticationManager
* @see OAuth2AuthorizationService
Expand Down Expand Up @@ -188,6 +189,12 @@ private static void throwError(String errorCode, String parameterName) {
throw new OAuth2AuthenticationException(error);
}

private static boolean isClientAuthenticated(Authentication clientPrincipal) {
return clientPrincipal != null &&
OAuth2ClientAuthenticationToken.class.isAssignableFrom(clientPrincipal.getClass()) &&
clientPrincipal.isAuthenticated();
}

private static class AuthorizationCodeAuthenticationConverter implements Converter<HttpServletRequest, Authentication> {

@Override
Expand All @@ -200,11 +207,26 @@ public Authentication convert(HttpServletRequest request) {

MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);

// code (REQUIRED)
String code = parameters.getFirst(OAuth2ParameterNames.CODE);
if (!StringUtils.hasText(code) ||
parameters.get(OAuth2ParameterNames.CODE).size() != 1) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CODE);
}

// redirect_uri (REQUIRED)
// Required only if the "redirect_uri" parameter was included in the authorization request
String redirectUri = parameters.getFirst(OAuth2ParameterNames.REDIRECT_URI);
if (StringUtils.hasText(redirectUri) &&
parameters.get(OAuth2ParameterNames.REDIRECT_URI).size() != 1) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI);
}

// client_id (REQUIRED)
Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
// Required only if the client did not authenticate
String clientId = null;
if (clientPrincipal == null ||
!OAuth2ClientAuthenticationToken.class.isAssignableFrom(clientPrincipal.getClass())) {
Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
if (!isClientAuthenticated(clientPrincipal)) {
clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);
if (!StringUtils.hasText(clientId) ||
parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) {
Expand All @@ -219,21 +241,6 @@ public Authentication convert(HttpServletRequest request) {
}
}

// code (REQUIRED)
String code = parameters.getFirst(OAuth2ParameterNames.CODE);
if (!StringUtils.hasText(code) ||
parameters.get(OAuth2ParameterNames.CODE).size() != 1) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CODE);
}

// redirect_uri (REQUIRED)
// Required only if the "redirect_uri" parameter was included in the authorization request
String redirectUri = parameters.getFirst(OAuth2ParameterNames.REDIRECT_URI);
if (StringUtils.hasText(redirectUri) &&
parameters.get(OAuth2ParameterNames.REDIRECT_URI).size() != 1) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI);
}

Map<String, Object> additionalParameters = parameters
.entrySet()
.stream()
Expand Down
Loading

0 comments on commit 3c74c7e

Please sign in to comment.