Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve @CurrentSecurityContext meta-annotations #15553

Merged
merged 1 commit into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ public void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentRes
CurrentSecurityContextArgumentResolver currentSecurityContextArgumentResolver = new CurrentSecurityContextArgumentResolver();
currentSecurityContextArgumentResolver.setBeanResolver(this.beanResolver);
currentSecurityContextArgumentResolver.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
currentSecurityContextArgumentResolver.setTemplateDefaults(this.templateDefaults);
argumentResolvers.add(currentSecurityContextArgumentResolver);
argumentResolvers.add(new CsrfTokenArgumentResolver());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,14 @@ void setCompromisedPasswordChecker(ReactiveCompromisedPasswordChecker compromise

@Bean
static WebFluxConfigurer authenticationPrincipalArgumentResolverConfigurer(
ObjectProvider<AuthenticationPrincipalArgumentResolver> authenticationPrincipalArgumentResolver) {
ObjectProvider<AuthenticationPrincipalArgumentResolver> authenticationPrincipalArgumentResolver,
ObjectProvider<CurrentSecurityContextArgumentResolver> currentSecurityContextArgumentResolvers) {
return new WebFluxConfigurer() {

@Override
public void configureArgumentResolvers(ArgumentResolverConfigurer configurer) {
configurer.addCustomResolver(authenticationPrincipalArgumentResolver.getObject());
configurer.addCustomResolver(authenticationPrincipalArgumentResolver.getObject(),
currentSecurityContextArgumentResolvers.getObject());
}

};
Expand All @@ -133,12 +135,14 @@ AuthenticationPrincipalArgumentResolver authenticationPrincipalArgumentResolver(
}

@Bean
CurrentSecurityContextArgumentResolver reactiveCurrentSecurityContextArgumentResolver() {
CurrentSecurityContextArgumentResolver reactiveCurrentSecurityContextArgumentResolver(
ObjectProvider<AnnotationTemplateExpressionDefaults> templateDefaults) {
CurrentSecurityContextArgumentResolver resolver = new CurrentSecurityContextArgumentResolver(
this.adapterRegistry);
if (this.beanFactory != null) {
resolver.setBeanResolver(new BeanFactoryResolver(this.beanFactory));
}
templateDefaults.ifAvailable(resolver::setTemplateDefaults);
return resolver;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.springframework.security.core.Authentication;
import org.springframework.security.core.annotation.AnnotationTemplateExpressionDefaults;
import org.springframework.security.core.annotation.AuthenticationPrincipal;
import org.springframework.security.core.annotation.CurrentSecurityContext;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.csrf.CsrfToken;
Expand Down Expand Up @@ -115,6 +116,15 @@ public void metaAnnotationWhenTemplateDefaultsBeanThenResolvesExpression() throw
this.mockMvc.perform(get("/hi")).andExpect(content().string("Hi, Harold!"));
}

@Test
public void resolveMetaAnnotationWhenTemplateDefaultsBeanThenResolvesExpression() throws Exception {
this.mockMvc.perform(get("/hello")).andExpect(content().string("user"));
Authentication harold = new TestingAuthenticationToken("harold", "password",
AuthorityUtils.createAuthorityList("ROLE_USER"));
SecurityContextHolder.getContext().setAuthentication(harold);
this.mockMvc.perform(get("/hello")).andExpect(content().string("harold"));
}

private ResultMatcher assertResult(Object expected) {
return model().attribute("result", expected);
}
Expand All @@ -128,6 +138,15 @@ private ResultMatcher assertResult(Object expected) {

}

@Target({ ElementType.PARAMETER })
@Retention(RetentionPolicy.RUNTIME)
@CurrentSecurityContext(expression = "authentication.{property}")
@interface CurrentAuthenticationProperty {

String property();

}

@Controller
static class TestController {

Expand Down Expand Up @@ -158,6 +177,13 @@ String ifUser(@IsUser("harold") boolean isHarold) {
}
}

@GetMapping("/hello")
@ResponseBody
String getCurrentAuthenticationProperty(
@CurrentAuthenticationProperty(property = "principal") String principal) {
return principal;
}

}

@Configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.springframework.security.core.Authentication;
import org.springframework.security.core.annotation.AnnotationTemplateExpressionDefaults;
import org.springframework.security.core.annotation.AuthenticationPrincipal;
import org.springframework.security.core.annotation.CurrentSecurityContext;
import org.springframework.security.core.userdetails.MapReactiveUserDetailsService;
import org.springframework.security.core.userdetails.PasswordEncodedUser;
import org.springframework.security.core.userdetails.ReactiveUserDetailsService;
Expand Down Expand Up @@ -183,6 +184,27 @@ public void metaAnnotationWhenTemplateDefaultsBeanThenResolvesExpression() throw
.isEqualTo("Hi, Harold!");
}

@Test
public void resoleMetaAnnotationWhenTemplateDefaultsBeanThenResolvesExpression() throws Exception {
this.spring.register(MetaAnnotationPlaceholderConfig.class).autowire();
Authentication user = new TestingAuthenticationToken("user", "password", "ROLE_USER");
this.webClient.mutateWith(mockAuthentication(user))
.get()
.uri("/hello")
.exchange()
.expectStatus()
.isOk()
.expectBody(String.class)
.isEqualTo("user");
Authentication harold = new TestingAuthenticationToken("harold", "password", "ROLE_USER");
this.webClient.mutateWith(mockAuthentication(harold))
.get()
.uri("/hello")
.exchange()
.expectBody(String.class)
.isEqualTo("harold");
}

@Configuration
static class SubclassConfig extends ServerHttpSecurityConfiguration {

Expand Down Expand Up @@ -283,6 +305,15 @@ public Mono<CompromisedPasswordDecision> check(String password) {

}

@Target({ ElementType.PARAMETER })
@Retention(RetentionPolicy.RUNTIME)
@CurrentSecurityContext(expression = "authentication.{property}")
@interface CurrentAuthenticationProperty {

String property();

}

@RestController
static class TestController {

Expand All @@ -296,6 +327,12 @@ String ifUser(@IsUser("harold") boolean isHarold) {
}
}

@GetMapping("/hello")
String getCurrentAuthenticationProperty(
@CurrentAuthenticationProperty(property = "principal") String principal) {
return principal;
}

}

@Configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package org.springframework.security.messaging.handler.invocation.reactive;

import java.lang.annotation.Annotation;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.reactivestreams.Publisher;
import reactor.core.publisher.Mono;
Expand All @@ -25,7 +27,6 @@
import org.springframework.core.ReactiveAdapter;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.ResolvableType;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.expression.BeanResolver;
import org.springframework.expression.Expression;
import org.springframework.expression.ExpressionParser;
Expand All @@ -34,6 +35,9 @@
import org.springframework.messaging.Message;
import org.springframework.messaging.handler.invocation.reactive.HandlerMethodArgumentResolver;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.annotation.AnnotationSynthesizer;
import org.springframework.security.core.annotation.AnnotationSynthesizers;
import org.springframework.security.core.annotation.AnnotationTemplateExpressionDefaults;
import org.springframework.security.core.annotation.CurrentSecurityContext;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContext;
Expand Down Expand Up @@ -88,12 +92,18 @@
* </pre>
*
* @author Rob Winch
* @author DingHao
* @since 5.2
*/
public class CurrentSecurityContextArgumentResolver implements HandlerMethodArgumentResolver {

private final Map<MethodParameter, Annotation> cachedAttributes = new ConcurrentHashMap<>();

private ExpressionParser parser = new SpelExpressionParser();

private AnnotationSynthesizer<CurrentSecurityContext> synthesizer = AnnotationSynthesizers
.requireUnique(CurrentSecurityContext.class);

private BeanResolver beanResolver;

private ReactiveAdapterRegistry adapterRegistry = ReactiveAdapterRegistry.getSharedInstance();
Expand All @@ -118,8 +128,7 @@ public void setAdapterRegistry(ReactiveAdapterRegistry adapterRegistry) {

@Override
public boolean supportsParameter(MethodParameter parameter) {
return isMonoSecurityContext(parameter)
|| findMethodAnnotation(CurrentSecurityContext.class, parameter) != null;
return isMonoSecurityContext(parameter) || findMethodAnnotation(parameter) != null;
}

private boolean isMonoSecurityContext(MethodParameter parameter) {
Expand Down Expand Up @@ -149,7 +158,7 @@ public Mono<Object> resolveArgument(MethodParameter parameter, Message<?> messag
}

private Object resolveSecurityContext(MethodParameter parameter, Object securityContext) {
CurrentSecurityContext contextAnno = findMethodAnnotation(CurrentSecurityContext.class, parameter);
CurrentSecurityContext contextAnno = findMethodAnnotation(parameter);
if (contextAnno != null) {
return resolveSecurityContextFromAnnotation(contextAnno, parameter, securityContext);
}
Expand Down Expand Up @@ -193,26 +202,28 @@ private boolean isInvalidType(MethodParameter parameter, Object value) {
return !typeToCheck.isAssignableFrom(value.getClass());
}

/**
* Configure CurrentSecurityContext template resolution
* <p>
* By default, this value is <code>null</code>, which indicates that templates should
* not be resolved.
* @param templateDefaults - whether to resolve CurrentSecurityContext templates
* parameters
* @since 6.4
*/
public void setTemplateDefaults(AnnotationTemplateExpressionDefaults templateDefaults) {
this.synthesizer = AnnotationSynthesizers.requireUnique(CurrentSecurityContext.class, templateDefaults);
}

/**
* Obtains the specified {@link Annotation} on the specified {@link MethodParameter}.
* @param annotationClass the class of the {@link Annotation} to find on the
* {@link MethodParameter}
* @param parameter the {@link MethodParameter} to search for an {@link Annotation}
* @return the {@link Annotation} that was found or null.
*/
private <T extends Annotation> T findMethodAnnotation(Class<T> annotationClass, MethodParameter parameter) {
T annotation = parameter.getParameterAnnotation(annotationClass);
if (annotation != null) {
return annotation;
}
Annotation[] annotationsToSearch = parameter.getParameterAnnotations();
for (Annotation toSearch : annotationsToSearch) {
annotation = AnnotationUtils.findAnnotation(toSearch.annotationType(), annotationClass);
if (annotation != null) {
return annotation;
}
}
return null;
@SuppressWarnings("unchecked")
private <T extends Annotation> T findMethodAnnotation(MethodParameter parameter) {
return (T) this.cachedAttributes.computeIfAbsent(parameter,
(methodParameter) -> this.synthesizer.synthesize(methodParameter.getParameter()));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,20 @@

package org.springframework.security.messaging.handler.invocation.reactive;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;

import org.springframework.core.MethodParameter;
import org.springframework.core.annotation.AliasFor;
import org.springframework.core.annotation.SynthesizingMethodParameter;
import org.springframework.security.authentication.TestAuthentication;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.annotation.AnnotationTemplateExpressionDefaults;
import org.springframework.security.core.annotation.CurrentSecurityContext;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContext;
Expand Down Expand Up @@ -171,6 +175,39 @@ public void resolveArgumentWhenMonoCustomSecurityContextNoAnnotationThenFound()
assertThat(result.block().getAuthentication().getPrincipal()).isEqualTo(authentication.getPrincipal());
}

@Test
public void resolveArgumentCustomMetaAnnotation() {
Authentication authentication = TestAuthentication.authenticatedUser();
CustomSecurityContext securityContext = new CustomSecurityContext();
securityContext.setAuthentication(authentication);
Mono<UserDetails> result = (Mono<UserDetails>) this.resolver
.resolveArgument(arg0("showUserCustomMetaAnnotation"), null)
.contextWrite(ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext)))
.block();
assertThat(result.block()).isEqualTo(authentication.getPrincipal());
}

@Test
public void resolveArgumentCustomMetaAnnotationTpl() {
this.resolver.setTemplateDefaults(new AnnotationTemplateExpressionDefaults());
Authentication authentication = TestAuthentication.authenticatedUser();
CustomSecurityContext securityContext = new CustomSecurityContext();
securityContext.setAuthentication(authentication);
Mono<UserDetails> result = (Mono<UserDetails>) this.resolver
.resolveArgument(arg0("showUserCustomMetaAnnotationTpl"), null)
.contextWrite(ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext)))
.block();
assertThat(result.block()).isEqualTo(authentication.getPrincipal());
}

private void showUserCustomMetaAnnotation(
@AliasedCurrentSecurityContext(expression = "authentication.principal") Mono<UserDetails> user) {
}

private void showUserCustomMetaAnnotationTpl(
@CurrentAuthenticationProperty(property = "principal") Mono<UserDetails> user) {
}

@SuppressWarnings("unused")
private void monoCustomSecurityContext(Mono<CustomSecurityContext> securityContext) {
}
Expand All @@ -186,6 +223,25 @@ private MethodParameter arg0(String methodName) {

}

@Target({ ElementType.PARAMETER })
@Retention(RetentionPolicy.RUNTIME)
@CurrentSecurityContext
@interface AliasedCurrentSecurityContext {

@AliasFor(annotation = CurrentSecurityContext.class)
String expression() default "";

}

@Target({ ElementType.PARAMETER })
@Retention(RetentionPolicy.RUNTIME)
@CurrentSecurityContext(expression = "authentication.{property}")
@interface CurrentAuthenticationProperty {

String property() default "";

}

static class CustomSecurityContext implements SecurityContext {

private Authentication authentication;
Expand Down
Loading