Skip to content

Commit

Permalink
Move PKCE to OAuth2ClientAuthenticationProvider
Browse files Browse the repository at this point in the history
  • Loading branch information
jgrandja committed Oct 6, 2020
1 parent e5fdee3 commit 5c31fb1
Show file tree
Hide file tree
Showing 21 changed files with 774 additions and 533 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ public List<RequestMatcher> getEndpointMatchers() {
public void init(B builder) {
OAuth2ClientAuthenticationProvider clientAuthenticationProvider =
new OAuth2ClientAuthenticationProvider(
getRegisteredClientRepository(builder));
getRegisteredClientRepository(builder),
getAuthorizationService(builder));
builder.authenticationProvider(postProcess(clientAuthenticationProvider));

NimbusJwsEncoder jwtEncoder = new NimbusJwsEncoder(getKeyManager(builder));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.jose.JoseHeader;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.Jwt;
Expand All @@ -42,12 +41,8 @@
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Base64;
import java.util.Collections;

/**
Expand Down Expand Up @@ -92,22 +87,13 @@ public Authentication authenticate(Authentication authentication) throws Authent
(OAuth2AuthorizationCodeAuthenticationToken) authentication;

OAuth2ClientAuthenticationToken clientPrincipal = null;
RegisteredClient registeredClient;
if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(authorizationCodeAuthentication.getPrincipal().getClass())) {
clientPrincipal = (OAuth2ClientAuthenticationToken) authorizationCodeAuthentication.getPrincipal();
if (!clientPrincipal.isAuthenticated()) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
}
registeredClient = clientPrincipal.getRegisteredClient();
} else if (StringUtils.hasText(authorizationCodeAuthentication.getClientId())) {
String clientId = authorizationCodeAuthentication.getClientId();
registeredClient = this.registeredClientRepository.findByClientId(clientId);
if (registeredClient == null) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
}
} else {
}
if (clientPrincipal == null || !clientPrincipal.isAuthenticated()) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
}
RegisteredClient registeredClient = clientPrincipal.getRegisteredClient();

OAuth2Authorization authorization = this.authorizationService.findByToken(
authorizationCodeAuthentication.getCode(), TokenType.AUTHORIZATION_CODE);
Expand All @@ -127,26 +113,6 @@ public Authentication authenticate(Authentication authentication) throws Authent
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}

// Validate PKCE parameters
String codeChallenge = (String) authorizationRequest
.getAdditionalParameters()
.get(PkceParameterNames.CODE_CHALLENGE);
if (StringUtils.hasText(codeChallenge)) {
String codeChallengeMethod = (String) authorizationRequest
.getAdditionalParameters()
.get(PkceParameterNames.CODE_CHALLENGE_METHOD);

String codeVerifier = (String) authorizationCodeAuthentication
.getAdditionalParameters()
.get(PkceParameterNames.CODE_VERIFIER);

if (!pkceCodeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}
} else if (registeredClient.getClientSettings().requireProofKey()) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}

JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();

// TODO Allow configuration for issuer claim
Expand Down Expand Up @@ -179,28 +145,7 @@ public Authentication authenticate(Authentication authentication) throws Authent
.build();
this.authorizationService.save(authorization);

return clientPrincipal != null ?
new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken) :
new OAuth2AccessTokenAuthenticationToken(registeredClient, new OAuth2ClientAuthenticationToken(registeredClient), accessToken);
}

private static boolean pkceCodeVerifierValid(String codeVerifier, String codeChallenge, String codeChallengeMethod) {
if (!StringUtils.hasText(codeVerifier)) {
return false;
} else if (!StringUtils.hasText(codeChallengeMethod) || "plain".equals(codeChallengeMethod)) {
return codeVerifier.equals(codeChallenge);
} else if ("S256".equals(codeChallengeMethod)) {
try {
MessageDigest md = MessageDigest.getInstance("SHA-256");
byte[] digest = md.digest(codeVerifier.getBytes(StandardCharsets.US_ASCII));
String encodedVerifier = Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
return encodedVerifier.equals(codeChallenge);
} catch (NoSuchAlgorithmException ex) {
// It is unlikely that SHA-256 is not available on the server. If it is not available,
// there will likely be bigger issues as well. We default to SERVER_ERROR.
}
}
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR));
return new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti
private static final long serialVersionUID = SpringSecurityCoreVersion2.SERIAL_VERSION_UID;
private final String code;
private final Authentication clientPrincipal;
private final String clientId;
private final String redirectUri;
private final Map<String, Object> additionalParameters;

Expand All @@ -58,32 +57,6 @@ public OAuth2AuthorizationCodeAuthenticationToken(String code, Authentication cl
Assert.notNull(clientPrincipal, "clientPrincipal cannot be null");
this.code = code;
this.clientPrincipal = clientPrincipal;
this.clientId = OAuth2ClientAuthenticationToken.class.isAssignableFrom(this.clientPrincipal.getClass()) ?
(String) this.clientPrincipal.getPrincipal() :
null;
this.redirectUri = redirectUri;
this.additionalParameters = Collections.unmodifiableMap(
additionalParameters != null ?
additionalParameters :
Collections.emptyMap());
}

/**
* Constructs an {@code OAuth2AuthorizationCodeAuthenticationToken} using the provided parameters.
*
* @param code the authorization code
* @param clientId the client identifier
* @param redirectUri the redirect uri
* @param additionalParameters the additional parameters
*/
public OAuth2AuthorizationCodeAuthenticationToken(String code, String clientId,
@Nullable String redirectUri, @Nullable Map<String, Object> additionalParameters) {
super(Collections.emptyList());
Assert.hasText(code, "code cannot be empty");
Assert.hasText(clientId, "clientId cannot be empty");
this.code = code;
this.clientPrincipal = null;
this.clientId = clientId;
this.redirectUri = redirectUri;
this.additionalParameters = Collections.unmodifiableMap(
additionalParameters != null ?
Expand All @@ -93,7 +66,7 @@ public OAuth2AuthorizationCodeAuthenticationToken(String code, String clientId,

@Override
public Object getPrincipal() {
return this.clientPrincipal != null ? this.clientPrincipal : this.clientId;
return this.clientPrincipal;
}

@Override
Expand All @@ -110,15 +83,6 @@ public String getCode() {
return this.code;
}

/**
* Returns the client identifier
*
* @return the client identifier
*/
public @Nullable String getClientId() {
return this.clientId;
}

/**
* Returns the redirect uri.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,54 +18,146 @@
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import java.util.Map;

/**
* An {@link AuthenticationProvider} implementation that validates {@link OAuth2ClientAuthenticationToken}'s.
* An {@link AuthenticationProvider} implementation used for authenticating an OAuth 2.0 Client.
*
* @author Joe Grandja
* @author Patryk Kostrzewa
* @author Daniel Garnier-Moiroux
* @since 0.0.1
* @see AuthenticationProvider
* @see OAuth2ClientAuthenticationToken
* @see RegisteredClientRepository
* @see OAuth2AuthorizationService
*/
public class OAuth2ClientAuthenticationProvider implements AuthenticationProvider {
private final RegisteredClientRepository registeredClientRepository;
private final OAuth2AuthorizationService authorizationService;

/**
* Constructs an {@code OAuth2ClientAuthenticationProvider} using the provided parameters.
*
* @param registeredClientRepository the repository of registered clients
* @param authorizationService the authorization service
*/
public OAuth2ClientAuthenticationProvider(RegisteredClientRepository registeredClientRepository) {
public OAuth2ClientAuthenticationProvider(RegisteredClientRepository registeredClientRepository,
OAuth2AuthorizationService authorizationService) {
Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null");
Assert.notNull(authorizationService, "authorizationService cannot be null");
this.registeredClientRepository = registeredClientRepository;
this.authorizationService = authorizationService;
}

@Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException {
String clientId = authentication.getPrincipal().toString();
OAuth2ClientAuthenticationToken clientAuthentication =
(OAuth2ClientAuthenticationToken) authentication;

String clientId = clientAuthentication.getPrincipal().toString();
RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
if (registeredClient == null) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
throwInvalidClient();
}

String clientSecret = authentication.getCredentials().toString();
if (!registeredClient.getClientSecret().equals(clientSecret)) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
if (clientAuthentication.getCredentials() != null) {
String clientSecret = clientAuthentication.getCredentials().toString();
// TODO Use PasswordEncoder.matches()
if (!registeredClient.getClientSecret().equals(clientSecret)) {
throwInvalidClient();
}
}

authenticatePkceIfAvailable(clientAuthentication, registeredClient);

return new OAuth2ClientAuthenticationToken(registeredClient);
}

@Override
public boolean supports(Class<?> authentication) {
return OAuth2ClientAuthenticationToken.class.isAssignableFrom(authentication);
}

private void authenticatePkceIfAvailable(OAuth2ClientAuthenticationToken clientAuthentication,
RegisteredClient registeredClient) {

Map<String, Object> parameters = clientAuthentication.getAdditionalParameters();
if (CollectionUtils.isEmpty(parameters) || !authorizationCodeGrant(parameters)) {
return;
}

OAuth2Authorization authorization = this.authorizationService.findByToken(
(String) parameters.get(OAuth2ParameterNames.CODE),
TokenType.AUTHORIZATION_CODE);
if (authorization == null) {
throwInvalidClient();
}

OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);

String codeChallenge = (String) authorizationRequest.getAdditionalParameters()
.get(PkceParameterNames.CODE_CHALLENGE);
if (StringUtils.hasText(codeChallenge)) {
String codeChallengeMethod = (String) authorizationRequest.getAdditionalParameters()
.get(PkceParameterNames.CODE_CHALLENGE_METHOD);
String codeVerifier = (String) parameters.get(PkceParameterNames.CODE_VERIFIER);
if (!codeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) {
throwInvalidClient();
}
} else if (registeredClient.getClientSettings().requireProofKey()) {
throwInvalidClient();
}
}

private static boolean authorizationCodeGrant(Map<String, Object> parameters) {
return AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(
parameters.get(OAuth2ParameterNames.GRANT_TYPE)) &&
parameters.get(OAuth2ParameterNames.CODE) != null;
}

private static boolean codeVerifierValid(String codeVerifier, String codeChallenge, String codeChallengeMethod) {
if (!StringUtils.hasText(codeVerifier)) {
return false;
} else if (!StringUtils.hasText(codeChallengeMethod) || "plain".equals(codeChallengeMethod)) {
return codeVerifier.equals(codeChallenge);
} else if ("S256".equals(codeChallengeMethod)) {
try {
MessageDigest md = MessageDigest.getInstance("SHA-256");
byte[] digest = md.digest(codeVerifier.getBytes(StandardCharsets.US_ASCII));
String encodedVerifier = Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
return encodedVerifier.equals(codeChallenge);
} catch (NoSuchAlgorithmException ex) {
// It is unlikely that SHA-256 is not available on the server. If it is not available,
// there will likely be bigger issues as well. We default to SERVER_ERROR.
}
}
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR));
}

private static void throwInvalidClient() {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.springframework.util.Assert;

import java.util.Collections;
import java.util.Map;

/**
* An {@link Authentication} implementation used for OAuth 2.0 Client Authentication.
Expand All @@ -38,20 +39,36 @@ public class OAuth2ClientAuthenticationToken extends AbstractAuthenticationToken
private static final long serialVersionUID = SpringSecurityCoreVersion2.SERIAL_VERSION_UID;
private String clientId;
private String clientSecret;
private Map<String, Object> additionalParameters;
private RegisteredClient registeredClient;

/**
* Constructs an {@code OAuth2ClientAuthenticationToken} using the provided parameters.
*
* @param clientId the client identifier
* @param clientSecret the client secret
* @param additionalParameters the additional parameters
*/
public OAuth2ClientAuthenticationToken(String clientId, String clientSecret) {
public OAuth2ClientAuthenticationToken(String clientId, String clientSecret,
@Nullable Map<String, Object> additionalParameters) {
this(clientId, additionalParameters);
Assert.hasText(clientSecret, "clientSecret cannot be empty");
this.clientSecret = clientSecret;
}

/**
* Constructs an {@code OAuth2ClientAuthenticationToken} using the provided parameters.
*
* @param clientId the client identifier
* @param additionalParameters the additional parameters
*/
public OAuth2ClientAuthenticationToken(String clientId,
@Nullable Map<String, Object> additionalParameters) {
super(Collections.emptyList());
Assert.hasText(clientId, "clientId cannot be empty");
Assert.hasText(clientSecret, "clientSecret cannot be empty");
this.clientId = clientId;
this.clientSecret = clientSecret;
this.additionalParameters = additionalParameters != null ?
Collections.unmodifiableMap(additionalParameters) : null;
}

/**
Expand All @@ -78,6 +95,15 @@ public Object getCredentials() {
return this.clientSecret;
}

/**
* Returns the additional parameters
*
* @return the additional parameters
*/
public @Nullable Map<String, Object> getAdditionalParameters() {
return this.additionalParameters;
}

/**
* Returns the {@link RegisteredClient registered client}.
*
Expand Down
Loading

0 comments on commit 5c31fb1

Please sign in to comment.