Skip to content

Commit

Permalink
Implement Proof Key for Code Exchange (PKCE) RFC 7636
Browse files Browse the repository at this point in the history
  • Loading branch information
Kehrlann authored and jgrandja committed Sep 30, 2020
1 parent 8541f6b commit ab09044
Show file tree
Hide file tree
Showing 13 changed files with 1,008 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
*/
public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthenticationToken {
private static final long serialVersionUID = SpringSecurityCoreVersion2.SERIAL_VERSION_UID;
private final RegisteredClient registeredClient;
private final Authentication clientPrincipal;
private RegisteredClient registeredClient;
private Authentication clientPrincipal;
private final OAuth2AccessToken accessToken;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.jose.JoseHeader;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
Expand All @@ -33,15 +34,20 @@
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Base64;
import java.util.Collections;

/**
Expand Down Expand Up @@ -85,29 +91,30 @@ public Authentication authenticate(Authentication authentication) throws Authent
(OAuth2AuthorizationCodeAuthenticationToken) authentication;

OAuth2ClientAuthenticationToken clientPrincipal = null;
RegisteredClient registeredClient = null;
if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(authorizationCodeAuthentication.getPrincipal().getClass())) {
clientPrincipal = (OAuth2ClientAuthenticationToken) authorizationCodeAuthentication.getPrincipal();
}
if (clientPrincipal == null || !clientPrincipal.isAuthenticated()) {
registeredClient = clientPrincipal.getRegisteredClient();
} else if (StringUtils.hasText(authorizationCodeAuthentication.getClientId())) {
// When the principal is a string, it is the clientId, REQUIRED for public clients
String clientId = authorizationCodeAuthentication.getClientId();
registeredClient = this.registeredClientRepository.findByClientId(clientId);
if (registeredClient == null) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
}
} else {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
}

// TODO Authenticate public client
// A client MAY use the "client_id" request parameter to identify itself
// when sending requests to the token endpoint.
// In the "authorization_code" "grant_type" request to the token endpoint,
// an unauthenticated client MUST send its "client_id" to prevent itself
// from inadvertently accepting a code intended for a client with a different "client_id".
// This protects the client from substitution of the authentication code.
if (clientPrincipal != null && !clientPrincipal.isAuthenticated()) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
}

OAuth2Authorization authorization = this.authorizationService.findByToken(
authorizationCodeAuthentication.getCode(), TokenType.AUTHORIZATION_CODE);
if (authorization == null) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}
if (!clientPrincipal.getRegisteredClient().getId().equals(authorization.getRegisteredClientId())) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}

OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
Expand All @@ -116,6 +123,35 @@ public Authentication authenticate(Authentication authentication) throws Authent
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}

if (!registeredClient.getClientId().equals(authorizationRequest.getClientId())) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}


String codeChallenge;
Object codeChallengeParameter = authorizationRequest
.getAdditionalParameters()
.get(PkceParameterNames.CODE_CHALLENGE);

if (codeChallengeParameter != null) {
codeChallenge = (String) codeChallengeParameter;

String codeChallengeMethod = (String) authorizationRequest
.getAdditionalParameters()
.get(PkceParameterNames.CODE_CHALLENGE_METHOD);

String codeVerifier = (String) authorizationCodeAuthentication
.getAdditionalParameters()
.get(PkceParameterNames.CODE_VERIFIER);

if (!pkceCodeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}
} else if (registeredClient.getClientSettings().requireProofKey()){
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}


JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();

// TODO Allow configuration for issuer claim
Expand All @@ -130,7 +166,7 @@ public Authentication authenticate(Authentication authentication) throws Authent
JwtClaimsSet jwtClaimsSet = JwtClaimsSet.withClaims()
.issuer(issuer)
.subject(authorization.getPrincipalName())
.audience(Collections.singletonList(clientPrincipal.getRegisteredClient().getClientId()))
.audience(Collections.singletonList(registeredClient.getClientId()))
.issuedAt(issuedAt)
.expiresAt(expiresAt)
.notBefore(issuedAt)
Expand All @@ -148,8 +184,30 @@ public Authentication authenticate(Authentication authentication) throws Authent
.build();
this.authorizationService.save(authorization);

return new OAuth2AccessTokenAuthenticationToken(
clientPrincipal.getRegisteredClient(), clientPrincipal, accessToken);
return clientPrincipal != null ?
new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken) :
new OAuth2AccessTokenAuthenticationToken(registeredClient, new OAuth2ClientAuthenticationToken(registeredClient), accessToken);
}

private boolean pkceCodeVerifierValid(String codeVerifier, String codeChallenge, String codeChallengeMethod) {
if (codeVerifier == null) {
return false;
} else if (codeChallengeMethod == null || codeChallengeMethod.equals("plain")) {
return codeVerifier.equals(codeChallenge);
} else if ("S256".equals(codeChallengeMethod)) {
try {
MessageDigest md = MessageDigest.getInstance("SHA-256");
byte[] digest = md.digest(codeVerifier.getBytes(StandardCharsets.US_ASCII));
String encodedVerifier = Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
return codeChallenge.equals(encodedVerifier);
} catch (NoSuchAlgorithmException e) {
// It is unlikely that SHA-256 is not available on the server. If it is not available,
// there will likely be bigger issues as well. We default to SERVER_ERROR.
}
}

// Unsupported algorithm should be caught in OAuth2AuthorizationEndpointFilter
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.springframework.util.Assert;

import java.util.Collections;
import java.util.Map;

/**
* An {@link Authentication} implementation used for the OAuth 2.0 Authorization Code Grant.
Expand All @@ -35,26 +36,36 @@
*/
public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenticationToken {
private static final long serialVersionUID = SpringSecurityCoreVersion2.SERIAL_VERSION_UID;
private String code;
private final String code;
private Authentication clientPrincipal;
private String clientId;
private String redirectUri;
private final String clientId;
private final String redirectUri;
private final Map<String, Object> additionalParameters;

/**
* Constructs an {@code OAuth2AuthorizationCodeAuthenticationToken} using the provided parameters.
*
* @param code the authorization code
* @param clientPrincipal the authenticated client principal
* @param redirectUri the redirect uri
* @param additionalParameters the additional parameters
*/
public OAuth2AuthorizationCodeAuthenticationToken(String code,
Authentication clientPrincipal, @Nullable String redirectUri) {
Authentication clientPrincipal, @Nullable String redirectUri,
Map<String, Object> additionalParameters) {
super(Collections.emptyList());
Assert.hasText(code, "code cannot be empty");
Assert.notNull(clientPrincipal, "clientPrincipal cannot be null");
this.code = code;
this.clientPrincipal = clientPrincipal;
this.redirectUri = redirectUri;
this.additionalParameters = Collections.unmodifiableMap(additionalParameters != null ? additionalParameters : Collections.emptyMap());

if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(this.clientPrincipal.getClass())) {
this.clientId = (String) this.clientPrincipal.getPrincipal();
} else {
this.clientId = null;
}
}

/**
Expand All @@ -63,15 +74,18 @@ public OAuth2AuthorizationCodeAuthenticationToken(String code,
* @param code the authorization code
* @param clientId the client identifier
* @param redirectUri the redirect uri
* @param additionalParameters the additional parameters
*/
public OAuth2AuthorizationCodeAuthenticationToken(String code,
String clientId, @Nullable String redirectUri) {
String clientId, @Nullable String redirectUri,
Map<String, Object> additionalParameters) {
super(Collections.emptyList());
Assert.hasText(code, "code cannot be empty");
Assert.hasText(clientId, "clientId cannot be empty");
this.code = code;
this.clientId = clientId;
this.redirectUri = redirectUri;
this.additionalParameters = Collections.unmodifiableMap(additionalParameters != null ? additionalParameters : Collections.emptyMap());
}

@Override
Expand Down Expand Up @@ -101,4 +115,22 @@ public String getCode() {
public @Nullable String getRedirectUri() {
return this.redirectUri;
}

/**
* Returns the additional parameters
*
* @return the additional parameters
*/
public Map<String, Object> getAdditionalParameters() {
return this.additionalParameters;
}

/**
* Returns the client id
*
* @return the client id
*/
public @Nullable String getClientId() {
return this.clientId;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,6 @@ public RegisteredClient build() {
Assert.hasText(this.clientId, "clientId cannot be empty");
Assert.notEmpty(this.authorizationGrantTypes, "authorizationGrantTypes cannot be empty");
if (this.authorizationGrantTypes.contains(AuthorizationGrantType.AUTHORIZATION_CODE)) {
Assert.hasText(this.clientSecret, "clientSecret cannot be empty");
Assert.notEmpty(this.redirectUris, "redirectUris cannot be empty");
}
if (CollectionUtils.isEmpty(this.clientAuthenticationMethods)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
Expand Down Expand Up @@ -78,6 +79,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
private final RequestMatcher authorizationEndpointMatcher;
private final StringKeyGenerator codeGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
private final String PKCE_ERROR_URI = "https://tools.ietf.org/html/rfc7636#section-4.4.1";

/**
* Constructs an {@code OAuth2AuthorizationEndpointFilter} using the provided parameters.
Expand Down Expand Up @@ -174,6 +176,34 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
return;
}

// code_challenge (REQUIRED for public clients) - RFC 7636 (PKCE)
String codeChallenge = parameters.getFirst(PkceParameterNames.CODE_CHALLENGE);
if (StringUtils.hasText(codeChallenge)) {
if (parameters.get(PkceParameterNames.CODE_CHALLENGE).size() != 1) {
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE, PKCE_ERROR_URI);
sendErrorResponse(request, response, error, stateParameter, redirectUri);
return;
}

if (parameters.get(PkceParameterNames.CODE_CHALLENGE_METHOD) != null &&
parameters.get(PkceParameterNames.CODE_CHALLENGE_METHOD).size() > 1) {
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD, PKCE_ERROR_URI);
sendErrorResponse(request, response, error, stateParameter, redirectUri);
return;
}

String codeChallengeMethod = parameters.getFirst(PkceParameterNames.CODE_CHALLENGE_METHOD);
if (codeChallengeMethod != null && !Arrays.asList("plain", "S256").contains(codeChallengeMethod)) {
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD, PKCE_ERROR_URI);
sendErrorResponse(request, response, error, stateParameter, redirectUri);
return;
}
} else if (registeredClient.getClientSettings().requireProofKey()) {
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE, PKCE_ERROR_URI);
sendErrorResponse(request, response, error, stateParameter, redirectUri);
return;
}

// ---------------
// The request is valid - ensure the resource owner is authenticated
// ---------------
Expand Down Expand Up @@ -245,8 +275,11 @@ private void sendErrorResponse(HttpServletRequest request, HttpServletResponse r
}

private static OAuth2Error createError(String errorCode, String parameterName) {
return new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName,
"https://tools.ietf.org/html/rfc6749#section-4.1.2.1");
return createError(errorCode, parameterName, "https://tools.ietf.org/html/rfc6749#section-4.1.2.1");
}

private static OAuth2Error createError(String errorCode, String parameterName, String errorUri) {
return new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, errorUri);
}

private static boolean isPrincipalAuthenticated(Authentication principal) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientCredentialsAuthenticationToken;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
Expand All @@ -54,6 +56,7 @@
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
* A {@code Filter} for the OAuth 2.0 Authorization Code Grant,
Expand Down Expand Up @@ -198,14 +201,22 @@ public Authentication convert(HttpServletRequest request) {
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);

// client_id (REQUIRED)
String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);
Authentication clientPrincipal = null;
if (StringUtils.hasText(clientId)) {
if (parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) {
Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
String clientId = null;
if (clientPrincipal == null ||
!OAuth2ClientAuthenticationToken.class.isAssignableFrom(clientPrincipal.getClass())) {
clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);
if (!StringUtils.hasText(clientId) ||
parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID);
}
} else {
clientPrincipal = SecurityContextHolder.getContext().getAuthentication();

// code_verifier (REQUIRED for public clients)
String codeVerifier = parameters.getFirst(PkceParameterNames.CODE_VERIFIER);
if (!StringUtils.hasText(codeVerifier) ||
parameters.get(PkceParameterNames.CODE_VERIFIER).size() != 1) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_VERIFIER);
}
}

// code (REQUIRED)
Expand All @@ -223,9 +234,19 @@ public Authentication convert(HttpServletRequest request) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI);
}

return clientPrincipal != null ?
new OAuth2AuthorizationCodeAuthenticationToken(code, clientPrincipal, redirectUri) :
new OAuth2AuthorizationCodeAuthenticationToken(code, clientId, redirectUri);
Map<String, Object> additionalParameters = parameters
.entrySet()
.stream()
.filter(e -> !e.getKey().equals(OAuth2ParameterNames.GRANT_TYPE) &&
!e.getKey().equals(OAuth2ParameterNames.CLIENT_ID) &&
!e.getKey().equals(OAuth2ParameterNames.CODE) &&
!e.getKey().equals(OAuth2ParameterNames.REDIRECT_URI))
.collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().get(0)));


return clientId != null ?
new OAuth2AuthorizationCodeAuthenticationToken(code, clientId, redirectUri, additionalParameters) :
new OAuth2AuthorizationCodeAuthenticationToken(code, clientPrincipal, redirectUri, additionalParameters);
}
}

Expand Down
Loading

0 comments on commit ab09044

Please sign in to comment.