Skip to content

Commit

Permalink
Use AssertingPartyMetadata
Browse files Browse the repository at this point in the history
Issue gh-15394
  • Loading branch information
jzheaux committed Jul 20, 2024
1 parent f9d5dda commit 366ab7e
Show file tree
Hide file tree
Showing 25 changed files with 300 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.springframework.core.io.ResourceLoader;
import org.springframework.security.converter.RsaKeyConverters;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.registration.AssertingPartyMetadata;
import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrations;
Expand Down Expand Up @@ -153,7 +154,7 @@ private static Map<String, Map<String, Object>> getAssertingParties(Element elem
}

private static void addVerificationCredentials(Map<String, Object> assertingParty,
RelyingPartyRegistration.AssertingPartyDetails.Builder builder) {
AssertingPartyMetadata.Builder<?> builder) {
List<String> verificationCertificateLocations = (List<String>) assertingParty.get(ELT_VERIFICATION_CREDENTIAL);
List<Saml2X509Credential> verificationCredentials = new ArrayList<>();
for (String certificateLocation : verificationCertificateLocations) {
Expand All @@ -163,7 +164,7 @@ private static void addVerificationCredentials(Map<String, Object> assertingPart
}

private static void addEncryptionCredentials(Map<String, Object> assertingParty,
RelyingPartyRegistration.AssertingPartyDetails.Builder builder) {
AssertingPartyMetadata.Builder<?> builder) {
List<String> encryptionCertificateLocations = (List<String>) assertingParty.get(ELT_ENCRYPTION_CREDENTIAL);
List<Saml2X509Credential> encryptionCredentials = new ArrayList<>();
for (String certificateLocation : encryptionCertificateLocations) {
Expand Down Expand Up @@ -220,8 +221,8 @@ private static RelyingPartyRegistration.Builder getBuilderFromMetadataLocationIf
}
else {
builder = RelyingPartyRegistration.withRegistrationId(registrationId)
.assertingPartyDetails((apBuilder) -> buildAssertingParty(relyingPartyRegistrationElt, assertingParties,
apBuilder, parserContext));
.assertingPartyMetadata((apBuilder) -> buildAssertingParty(relyingPartyRegistrationElt,
assertingParties, apBuilder, parserContext));
}
addRemainingProperties(relyingPartyRegistrationElt, builder);
return builder;
Expand Down Expand Up @@ -260,7 +261,7 @@ private static void addRemainingProperties(Element relyingPartyRegistrationElt,
}

private static void buildAssertingParty(Element relyingPartyElt, Map<String, Map<String, Object>> assertingParties,
RelyingPartyRegistration.AssertingPartyDetails.Builder builder, ParserContext parserContext) {
AssertingPartyMetadata.Builder<?> builder, ParserContext parserContext) {
String assertingPartyId = relyingPartyElt.getAttribute(ATT_ASSERTING_PARTY_ID);
if (!assertingParties.containsKey(assertingPartyId)) {
Object source = parserContext.extractSource(relyingPartyElt);
Expand Down Expand Up @@ -293,7 +294,7 @@ private static void buildAssertingParty(Element relyingPartyElt, Map<String, Map
}

private static void addSigningAlgorithms(Map<String, Object> assertingParty,
RelyingPartyRegistration.AssertingPartyDetails.Builder builder) {
AssertingPartyMetadata.Builder<?> builder) {
String signingAlgorithmsAttr = getAsString(assertingParty, ATT_SIGNING_ALGORITHMS);
if (StringUtils.hasText(signingAlgorithmsAttr)) {
List<String> signingAlgorithms = Arrays.asList(signingAlgorithmsAttr.split(","));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ public static Converter<ResponseToken, Saml2ResponseValidatorResult> createDefau
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_DESTINATION, message));
}
String assertingPartyEntityId = token.getRelyingPartyRegistration()
.getAssertingPartyDetails()
.getAssertingPartyMetadata()
.getEntityId();
if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) {
String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID());
Expand Down Expand Up @@ -775,7 +775,7 @@ private static ValidationContext createValidationContext(AssertionToken assertio
RelyingPartyRegistration relyingPartyRegistration = token.getRelyingPartyRegistration();
String audience = relyingPartyRegistration.getEntityId();
String recipient = relyingPartyRegistration.getAssertionConsumerServiceLocation();
String assertingPartyEntityId = relyingPartyRegistration.getAssertingPartyDetails().getEntityId();
String assertingPartyEntityId = relyingPartyRegistration.getAssertingPartyMetadata().getEntityId();
Map<String, Object> params = new HashMap<>();
Assertion assertion = assertionToken.getAssertion();
if (assertionContainsInResponseTo(assertion)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ static QueryParametersPartial sign(RelyingPartyRegistration registration) {
private static SignatureSigningParameters resolveSigningParameters(
RelyingPartyRegistration relyingPartyRegistration) {
List<Credential> credentials = resolveSigningCredentials(relyingPartyRegistration);
List<String> algorithms = relyingPartyRegistration.getAssertingPartyDetails().getSigningAlgorithms();
List<String> algorithms = relyingPartyRegistration.getAssertingPartyMetadata().getSigningAlgorithms();
List<String> digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256);
String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS;
SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,12 @@ static VerifierPartial verifySignature(RequestAbstractType object, RelyingPartyR

static SignatureTrustEngine trustEngine(RelyingPartyRegistration registration) {
Set<Credential> credentials = new HashSet<>();
Collection<Saml2X509Credential> keys = registration.getAssertingPartyDetails().getVerificationX509Credentials();
Collection<Saml2X509Credential> keys = registration.getAssertingPartyMetadata()
.getVerificationX509Credentials();
for (Saml2X509Credential key : keys) {
BasicX509Credential cred = new BasicX509Credential(key.getCertificate());
cred.setUsageType(UsageType.SIGNING);
cred.setEntityId(registration.getAssertingPartyDetails().getEntityId());
cred.setEntityId(registration.getAssertingPartyMetadata().getEntityId());
credentials.add(cred);
}
CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public Saml2MessageBinding getBinding() {
* @since 5.7
*/
public static Builder withRelyingPartyRegistration(RelyingPartyRegistration registration) {
String location = registration.getAssertingPartyDetails().getSingleSignOnServiceLocation();
String location = registration.getAssertingPartyMetadata().getSingleSignOnServiceLocation();
return new Builder(registration).authenticationRequestUri(location);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public Saml2MessageBinding getBinding() {
* @since 5.7
*/
public static Builder withRelyingPartyRegistration(RelyingPartyRegistration registration) {
String location = registration.getAssertingPartyDetails().getSingleSignOnServiceLocation();
String location = registration.getAssertingPartyMetadata().getSingleSignOnServiceLocation();
return new Builder(registration).authenticationRequestUri(location);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ private Consumer<Collection<Saml2Error>> validateIssuer(LogoutRequest request,
return;
}
String issuer = request.getIssuer().getValue();
if (!issuer.equals(registration.getAssertingPartyDetails().getEntityId())) {
if (!issuer.equals(registration.getAssertingPartyMetadata().getEntityId())) {
errors
.add(new Saml2Error(Saml2ErrorCodes.INVALID_ISSUER, "Failed to match issuer to configured issuer"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ private Consumer<Collection<Saml2Error>> validateIssuer(LogoutResponse response,
return;
}
String issuer = response.getIssuer().getValue();
if (!issuer.equals(registration.getAssertingPartyDetails().getEntityId())) {
if (!issuer.equals(registration.getAssertingPartyMetadata().getEntityId())) {
errors
.add(new Saml2Error(Saml2ErrorCodes.INVALID_ISSUER, "Failed to match issuer to configured issuer"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,12 @@ private CriteriaSet verificationCriteria(Issuer issuer) {

private SignatureTrustEngine trustEngine(RelyingPartyRegistration registration) {
Set<Credential> credentials = new HashSet<>();
Collection<Saml2X509Credential> keys = registration.getAssertingPartyDetails()
Collection<Saml2X509Credential> keys = registration.getAssertingPartyMetadata()
.getVerificationX509Credentials();
for (Saml2X509Credential key : keys) {
BasicX509Credential cred = new BasicX509Credential(key.getCertificate());
cred.setUsageType(UsageType.SIGNING);
cred.setEntityId(registration.getAssertingPartyDetails().getEntityId());
cred.setEntityId(registration.getAssertingPartyMetadata().getEntityId());
credentials.add(cred);
}
CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,8 @@ public static final class Builder {

private Builder(RelyingPartyRegistration registration) {
this.registration = registration;
this.location = registration.getAssertingPartyDetails().getSingleLogoutServiceLocation();
this.binding = registration.getAssertingPartyDetails().getSingleLogoutServiceBinding();
this.location = registration.getAssertingPartyMetadata().getSingleLogoutServiceLocation();
this.binding = registration.getAssertingPartyMetadata().getSingleLogoutServiceBinding();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ public static final class Builder {
private Function<Map<String, String>, String> encoder = DEFAULT_ENCODER;

private Builder(RelyingPartyRegistration registration) {
this.location = registration.getAssertingPartyDetails().getSingleLogoutServiceResponseLocation();
this.binding = registration.getAssertingPartyDetails().getSingleLogoutServiceBinding();
this.location = registration.getAssertingPartyMetadata().getSingleLogoutServiceResponseLocation();
this.binding = registration.getAssertingPartyMetadata().getSingleLogoutServiceBinding();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ static String serialize(XMLObject object) {
}

static <O extends SignableXMLObject> O sign(O object, RelyingPartyRegistration relyingPartyRegistration) {
List<String> algorithms = relyingPartyRegistration.getAssertingPartyDetails().getSigningAlgorithms();
List<String> algorithms = relyingPartyRegistration.getAssertingPartyMetadata().getSigningAlgorithms();
List<Credential> credentials = resolveSigningCredentials(relyingPartyRegistration);
return sign(object, algorithms, credentials);
}
Expand All @@ -103,7 +103,7 @@ static QueryParametersPartial sign(RelyingPartyRegistration registration) {
private static SignatureSigningParameters resolveSigningParameters(
RelyingPartyRegistration relyingPartyRegistration) {
List<Credential> credentials = resolveSigningCredentials(relyingPartyRegistration);
List<String> algorithms = relyingPartyRegistration.getAssertingPartyDetails().getSigningAlgorithms();
List<String> algorithms = relyingPartyRegistration.getAssertingPartyMetadata().getSigningAlgorithms();
return resolveSigningParameters(algorithms, credentials);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ private static Map<String, List<RelyingPartyRegistration>> createMappingByAssert
Collection<RelyingPartyRegistration> rps) {
MultiValueMap<String, RelyingPartyRegistration> result = new LinkedMultiValueMap<>();
for (RelyingPartyRegistration rp : rps) {
result.add(rp.getAssertingPartyDetails().getEntityId(), rp);
result.add(rp.getAssertingPartyMetadata().getEntityId(), rp);
}
return Collections.unmodifiableMap(result);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
* EntityDescriptor descriptor = openSamlRegistration.getAssertingPartyDetails.getEntityDescriptor();
* }
* </pre> do instead: <pre>
* if (registration.getAssertingPartyDetails() instanceof openSamlAssertingPartyDetails) {
* if (registration.getAssertingPartyMetadata() instanceof openSamlAssertingPartyDetails) {
* EntityDescriptor descriptor = openSamlAssertingPartyDetails.getEntityDescriptor();
* }
* </pre>
Expand Down Expand Up @@ -170,6 +170,11 @@ public Builder assertingPartyDetails(Consumer<AssertingPartyDetails.Builder> ass
return (Builder) super.assertingPartyDetails(assertingPartyDetails);
}

@Override
public Builder assertingPartyMetadata(Consumer<AssertingPartyMetadata.Builder<?>> assertingPartyMetadata) {
return (Builder) super.assertingPartyMetadata(assertingPartyMetadata);
}

/**
* Build an {@link OpenSamlRelyingPartyRegistration}
* {@link org.springframework.security.saml2.provider.service.registration.OpenSamlRelyingPartyRegistration}
Expand Down
Loading

0 comments on commit 366ab7e

Please sign in to comment.