Skip to content

Commit

Permalink
fix: enable spring subscription callback module using RC version (#397)
Browse files Browse the repository at this point in the history
  • Loading branch information
dariuszkuc authored May 17, 2024
1 parent 3efd208 commit 1cd6e03
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 55 deletions.
4 changes: 2 additions & 2 deletions gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ graphQLJavaVersion = 22.0
mockWebServerVersion = 4.12.0
protobufVersion = 4.26.1
slf4jVersion = 2.0.13
springBootVersion = 3.2.5
springGraphQLVersion = 1.2.6
springBootVersion = 3.3.0-RC1
springGraphQLVersion = 1.3.0-RC1
reactorVersion = 3.6.6

# test dependencies
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
* A GraphQL Java Instrumentation that computes a max age for an operation based on @cacheControl
* directives.
*
* By default, this instrumentation will only set the `Cache-Control` `max-age` value IF positive int
* value is provided. If you would rather want to return explicit `max-age=0` values, you need to
* explicitly opt-in to this behavior by specifying `allowZeroMaxAge=true` constructor value.
* <p>By default, this instrumentation will only set the `Cache-Control` `max-age` value IF positive
* int value is provided. If you would rather want to return explicit `max-age=0` values, you need
* to explicitly opt-in to this behavior by specifying `allowZeroMaxAge=true` constructor value.
*
* <p>You can retrieve the "max-age=..." header value with a {@link GraphQLContext}: <code>
* String cacheControlHeader = CacheControlInstrumentation.cacheControlContext(context);
Expand Down
6 changes: 2 additions & 4 deletions settings.gradle.kts
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
rootProject.name = "federation-jvm"

include(":federation-graphql-java-support")
// TODO: disabling spring-subscription-callback module until Spring Boot 3.3 is released
//include(":federation-spring-subscription-callback")
include(":federation-spring-subscription-callback")

project(":federation-graphql-java-support").projectDir = file("graphql-java-support")
// TODO: disabling spring-subscription-callback module until Spring Boot 3.3 is released
//project(":federation-spring-subscription-callback").projectDir = file("spring-subscription-callback")
project(":federation-spring-subscription-callback").projectDir = file("spring-subscription-callback")
8 changes: 7 additions & 1 deletion spring-subscription-callback/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@ plugins {
id("com.apollographql.federation.java-conventions")
}

repositories {
mavenCentral()
maven {
url = uri("https://repo.spring.io/milestone")
}
}

val annotationsVersion: String by project
val graphQLJavaVersion: String by project
val mockWebServerVersion: String by project
Expand All @@ -27,7 +34,6 @@ dependencies {
testImplementation("org.springframework.boot", "spring-boot-starter-websocket", springBootVersion)
testImplementation("org.springframework.graphql", "spring-graphql-test", springGraphQLVersion)
testImplementation("io.projectreactor", "reactor-test", reactorVersion)

}

java {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ public Mono<ServerResponse> handleRequest(@NotNull ServerRequest serverRequest)
serverRequest.uri(),
serverRequest.headers().asHttpHeaders(),
serverRequest.cookies(),
serverRequest.remoteAddress().orElse(null),
serverRequest.attributes(),
body,
serverRequest.exchange().getRequest().getId(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import com.apollographql.subscription.callback.SubscriptionCallbackHandler;
import graphql.ExecutionResult;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.Cookie;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
Expand All @@ -20,13 +18,9 @@
import org.springframework.graphql.server.WebGraphQlHandler;
import org.springframework.graphql.server.WebGraphQlRequest;
import org.springframework.graphql.server.webmvc.GraphQlHttpHandler;
import org.springframework.http.HttpCookie;
import org.springframework.http.MediaType;
import org.springframework.util.AlternativeJdkIdGenerator;
import org.springframework.util.IdGenerator;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.server.ServerWebInputException;
import org.springframework.web.servlet.function.ServerRequest;
import org.springframework.web.servlet.function.ServerResponse;
import reactor.core.publisher.Mono;
Expand Down Expand Up @@ -79,6 +73,7 @@ public CallbackGraphQlHttpHandler(
serverRequest.uri(),
serverRequest.headers().asHttpHeaders(),
initCookies(serverRequest),
serverRequest.remoteAddress().orElse(null),
serverRequest.attributes(),
readBody(serverRequest),
this.idGenerator.generateId().toString(),
Expand Down Expand Up @@ -145,29 +140,6 @@ public CallbackGraphQlHttpHandler(
}
}

private static MultiValueMap<String, HttpCookie> initCookies(ServerRequest serverRequest) {
MultiValueMap<String, Cookie> source = serverRequest.cookies();
MultiValueMap<String, HttpCookie> target = new LinkedMultiValueMap<>(source.size());
source
.values()
.forEach(
cookieList ->
cookieList.forEach(
cookie -> {
HttpCookie httpCookie = new HttpCookie(cookie.getName(), cookie.getValue());
target.add(cookie.getName(), httpCookie);
}));
return target;
}

private static Map<String, Object> readBody(ServerRequest request) throws ServletException {
try {
return request.body(MAP_PARAMETERIZED_TYPE_REF);
} catch (IOException ex) {
throw new ServerWebInputException("I/O error while reading request body", null, ex);
}
}

private static MediaType selectResponseMediaType(ServerRequest serverRequest) {
for (MediaType accepted : serverRequest.headers().accept()) {
if (SUPPORTED_MEDIA_TYPES.contains(accepted)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.springframework.graphql.ExecutionGraphQlRequest;
import org.springframework.graphql.ExecutionGraphQlResponse;
import org.springframework.graphql.ExecutionGraphQlService;
import org.springframework.graphql.server.WebGraphQlHandler;
import org.springframework.graphql.server.WebGraphQlRequest;
import org.springframework.graphql.server.WebGraphQlResponse;
import org.springframework.graphql.server.WebSocketGraphQlInterceptor;
import org.springframework.graphql.support.DefaultExecutionGraphQlResponse;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
Expand All @@ -39,21 +41,16 @@

public class SubscriptionCallbackHandlerTest {

static class MockWebHandler implements WebGraphQlHandler {
static class MockExecutionEngine implements ExecutionGraphQlService {

private final Flux subscriptionFlux;

public MockWebHandler(Flux subscriptionFlux) {
public MockExecutionEngine(Flux subscriptionFlux) {
this.subscriptionFlux = subscriptionFlux;
}

@Override
public WebSocketGraphQlInterceptor getWebSocketInterceptor() {
return null;
}

@Override
public @NotNull Mono<WebGraphQlResponse> handleRequest(@NotNull WebGraphQlRequest request) {
public Mono<ExecutionGraphQlResponse> execute(ExecutionGraphQlRequest request) {
var executionResult = ExecutionResult.newExecutionResult().data(subscriptionFlux).build();
var executionResponse =
new DefaultExecutionGraphQlResponse(request.toExecutionInput(), executionResult);
Expand All @@ -62,6 +59,10 @@ public WebSocketGraphQlInterceptor getWebSocketInterceptor() {
}
}

static WebGraphQlHandler mockHandler(Flux subscriptionFlux) {
return WebGraphQlHandler.builder(new MockExecutionEngine(subscriptionFlux)).build();
}

@Test
public void init_successful() {
var capturedRequests = new ArrayList<String>();
Expand All @@ -81,7 +82,7 @@ public MockResponse dispatch(@NotNull RecordedRequest recordedRequest) {
var data =
Flux.just(1, 2)
.map((i) -> ExecutionResult.newExecutionResult().data(Map.of("counter", i)).build());
var handler = new SubscriptionCallbackHandler(new MockWebHandler(data));
var handler = new SubscriptionCallbackHandler(mockHandler(data));

var subscriptionId = UUID.randomUUID().toString();
var callbackUrl = server.url("/callback/" + subscriptionId).toString();
Expand Down Expand Up @@ -132,7 +133,7 @@ public MockResponse dispatch(@NotNull RecordedRequest recordedRequest) {
var data =
Flux.just(1, 2)
.map((i) -> ExecutionResult.newExecutionResult().data(Map.of("counter", i)).build());
var handler = new SubscriptionCallbackHandler(new MockWebHandler(data));
var handler = new SubscriptionCallbackHandler(mockHandler(data));

var subscriptionId = UUID.randomUUID().toString();
var callbackUrl = server.url("/callback/" + subscriptionId).toString();
Expand Down Expand Up @@ -184,7 +185,7 @@ public MockResponse dispatch(@NotNull RecordedRequest recordedRequest) {
Flux.just(1, 2)
.delayElements(Duration.ofMillis(3000))
.map((i) -> ExecutionResult.newExecutionResult().data(Map.of("counter", i)).build());
var handler = new SubscriptionCallbackHandler(new MockWebHandler(data));
var handler = new SubscriptionCallbackHandler(mockHandler(data));

// note: heartbeat goes into infinite recursion and does not emit value
// TODO update to use virtual timer
Expand Down Expand Up @@ -249,7 +250,7 @@ public MockResponse dispatch(@NotNull RecordedRequest recordedRequest) {
Flux.just(1, 2)
.delayElements(Duration.ofMillis(3000))
.map((i) -> ExecutionResult.newExecutionResult().data(Map.of("counter", i)).build());
var handler = new SubscriptionCallbackHandler(new MockWebHandler(data));
var handler = new SubscriptionCallbackHandler(mockHandler(data));

// note: heartbeat goes into infinite recursion and does not emit value
// TODO update to use virtual timer
Expand Down Expand Up @@ -280,7 +281,7 @@ public void subscription_success() {
Flux.just(1, 2)
.delayElements(Duration.ofMillis(50))
.map((i) -> ExecutionResult.newExecutionResult().data(Map.of("counter", i)).build());
var handler = new SubscriptionCallbackHandler(new MockWebHandler(data));
var handler = new SubscriptionCallbackHandler(mockHandler(data));

var subscriptionId = UUID.randomUUID().toString();
var callbackUrl = server.url("/callback/" + subscriptionId).toString();
Expand Down Expand Up @@ -309,7 +310,7 @@ public void subscription_success_without_heartbeats() {
Flux.just(1, 2)
.delayElements(Duration.ofMillis(50))
.map((i) -> ExecutionResult.newExecutionResult().data(Map.of("counter", i)).build());
var handler = new SubscriptionCallbackHandler(new MockWebHandler(data));
var handler = new SubscriptionCallbackHandler(mockHandler(data));

var subscriptionId = UUID.randomUUID().toString();
var callbackUrl = server.url("/callback/" + subscriptionId).toString();
Expand Down Expand Up @@ -339,7 +340,7 @@ public void subscription_exception() {
.delayElements(Duration.ofMillis(50))
.map((i) -> ExecutionResult.newExecutionResult().data(Map.of("counter", i)).build())
.concatWith(Mono.error(new RuntimeException("JUNIT_FAILURE")));
var handler = new SubscriptionCallbackHandler(new MockWebHandler(data));
var handler = new SubscriptionCallbackHandler(mockHandler(data));

var subscriptionId = UUID.randomUUID().toString();
var callbackUrl = server.url("/callback/" + subscriptionId).toString();
Expand Down Expand Up @@ -385,6 +386,7 @@ private WebGraphQlRequest stubWebGraphQLRequest(String subscriptionId, String ca
URI.create(callbackUrl),
HttpHeaders.EMPTY,
null,
null,
Collections.emptyMap(),
createMockGraphQLRequest(subscriptionId, callbackUrl),
UUID.randomUUID().toString(),
Expand Down

0 comments on commit 1cd6e03

Please sign in to comment.