Skip to content

Commit

Permalink
Defer ExchangeFilterFunction execution in WebClient
Browse files Browse the repository at this point in the history
Prior to this commit, the `DefaultWebClient` would execute the configured
`ExchangeFilterFunction` as the reactive pipeline is assembled during
subscription. This means that if imperative code is executed in a filter
function, it won't be aware of the current observation through the local
scope.

For example, when automatic context propagation is enabled for Reactor
operators, the logger MDC will not know about the current
traceId/spanId.

This commit ensures that client filter functions execution is deferred
during the actual client exchange.

Fixes gh-33559
  • Loading branch information
bclozel committed Sep 18, 2024
1 parent de4ff4b commit 776811b
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
1 change: 1 addition & 0 deletions spring-webflux/spring-webflux.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dependencies {
testImplementation(testFixtures(project(":spring-web")))
testImplementation("com.fasterxml:aalto-xml")
testImplementation("com.squareup.okhttp3:mockwebserver")
testImplementation("io.micrometer:context-propagation")
testImplementation("io.micrometer:micrometer-observation-test")
testImplementation("io.projectreactor:reactor-test")
testImplementation("io.reactivex.rxjava3:rxjava")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -453,8 +453,9 @@ public Mono<ClientResponse> exchange() {
ClientRequest request = requestBuilder.build();
observationContext.setUriTemplate((String) request.attribute(URI_TEMPLATE_ATTRIBUTE).orElse(null));
observationContext.setRequest(request);
Mono<ClientResponse> responseMono = filterFunction.apply(exchangeFunction)
.exchange(request)
final ExchangeFilterFunction finalFilterFunction = filterFunction;
Mono<ClientResponse> responseMono = Mono.defer(
() -> finalFilterFunction.apply(exchangeFunction).exchange(request))
.checkpoint("Request to " +
WebClientUtils.getRequestDescription(request.method(), request.url()) +
" [DefaultWebClient]")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import io.micrometer.observation.tck.TestObservationRegistry;
import io.micrometer.observation.tck.TestObservationRegistryAssert;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Hooks;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;

Expand Down Expand Up @@ -63,6 +65,7 @@ class WebClientObservationTests {

@BeforeEach
void setup() {
Hooks.enableAutomaticContextPropagation();
ClientResponse mockResponse = mock();
when(mockResponse.statusCode()).thenReturn(HttpStatus.OK);
when(mockResponse.headers()).thenReturn(new MockClientHeaders());
Expand All @@ -74,6 +77,11 @@ void setup() {
this.observationRegistry.observationConfig().observationHandler(new HeaderInjectingHandler());
}

@AfterEach
void cleanUp() {
Hooks.disableAutomaticContextPropagation();
}

@Test
void recordsObservationForSuccessfulExchange() {
this.builder.build().get().uri("/resource/{id}", 42)
Expand Down Expand Up @@ -148,6 +156,19 @@ void setsCurrentObservationInReactorContext() {
verifyAndGetRequest();
}

@Test
void setsCurrentObservationInScope() {
ExchangeFilterFunction assertionFilter = (request, chain) -> {
Observation currentObservation = observationRegistry.getCurrentObservation();
assertThat(currentObservation).isNotNull();
assertThat(currentObservation.getContext()).isInstanceOf(ClientRequestObservationContext.class);
return chain.exchange(request);
};
this.builder.filter(assertionFilter).build().get().uri("/resource/{id}", 42)
.retrieve().bodyToMono(Void.class).block(Duration.ofSeconds(5));
verifyAndGetRequest();
}

@Test
void recordsObservationWithResponseDetailsWhenFilterFunctionErrors() {
ExchangeFilterFunction errorFunction = (req, next) -> next.exchange(req).then(Mono.error(new IllegalStateException()));
Expand Down

0 comments on commit 776811b

Please sign in to comment.