Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add destination cluster info to response cookie #466

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public class HaGatewayConfiguration
private OAuth2GatewayCookieConfiguration oauth2GatewayCookieConfiguration = new OAuth2GatewayCookieConfiguration();
private GatewayCookieConfiguration gatewayCookieConfiguration = new GatewayCookieConfiguration();
private List<String> statementPaths = ImmutableList.of(V1_STATEMENT_PATH);
private boolean includeClusterHostInResponse;

private RequestAnalyzerConfig requestAnalyzerConfig = new RequestAnalyzerConfig();

Expand Down Expand Up @@ -244,6 +245,16 @@ public void setAdditionalStatementPaths(List<String> statementPaths)
statementPaths.stream().peek(s -> validateStatementPath(s, statementPaths)).map(s -> s.replaceAll("/+$", ""))).toList();
}

public boolean isIncludeClusterHostInResponse()
{
return includeClusterHostInResponse;
}

public void setIncludeClusterHostInResponse(boolean includeClusterHostInResponse)
{
this.includeClusterHostInResponse = includeClusterHostInResponse;
}

private void validateStatementPath(String statementPath, List<String> statementPaths)
{
if (statementPath.startsWith(V1_STATEMENT_PATH) ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ public class ProxyRequestHandler
private final boolean cookiesEnabled;
private final boolean addXForwardedHeaders;
private final List<String> statementPaths;
private final boolean includeClusterInfoInResponse;

@Inject
public ProxyRequestHandler(
Expand All @@ -100,6 +101,7 @@ public ProxyRequestHandler(
asyncTimeout = haGatewayConfiguration.getRouting().getAsyncTimeout();
addXForwardedHeaders = haGatewayConfiguration.getRouting().isAddXForwardedHeaders();
statementPaths = haGatewayConfiguration.getStatementPaths();
this.includeClusterInfoInResponse = haGatewayConfiguration.isIncludeClusterHostInResponse();
}

@PreDestroy
Expand Down Expand Up @@ -160,7 +162,8 @@ private void performRequest(
addXForwardedHeaders(servletRequest, requestBuilder);
}

ImmutableList<NewCookie> oauth2GatewayCookie = getOAuth2GatewayCookie(remoteUri, servletRequest);
ImmutableList.Builder<NewCookie> cookieBuilder = ImmutableList.builder();
cookieBuilder.addAll(getOAuth2GatewayCookie(remoteUri, servletRequest));

Request request = requestBuilder
.setPreserveAuthorizationOnRedirect(true)
Expand All @@ -171,11 +174,14 @@ private void performRequest(

if (statementPaths.stream().anyMatch(request.getUri().getPath()::startsWith) && request.getMethod().equals(HttpMethod.POST)) {
future = future.transform(response -> recordBackendForQueryId(request, response), executor);
if (includeClusterInfoInResponse) {
cookieBuilder.add(new NewCookie.Builder("trinoClusterHost").value(remoteUri.getHost()).build());
}
}

setupAsyncResponse(
asyncResponse,
future.transform(response -> buildResponse(response, oauth2GatewayCookie), executor)
future.transform(response -> buildResponse(response, cookieBuilder.build()), executor)
.catching(ProxyException.class, e -> handleProxyException(request, e), directExecutor()));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import java.util.Optional;
import java.util.concurrent.TimeUnit;

import static com.google.common.collect.MoreCollectors.onlyElement;
import static org.assertj.core.api.Assertions.assertThat;
import static org.testcontainers.utility.MountableFile.forClasspathResource;

Expand Down Expand Up @@ -158,6 +159,43 @@ public void testQueryDeliveryToMultipleRoutingGroups()
assertThat(response4.body().string()).contains("http://localhost:" + routerPort);
}

@Test
public void testTrinoClusterHostCookie()
throws Exception
{
RequestBody requestBody = RequestBody.create("SELECT 1", MEDIA_TYPE);

// When X-Trino-Routing-Group is set in header, query should be routed to cluster under the routing group
Request requestWithoutCookie =
new Request.Builder()
.url("http://localhost:" + routerPort + "/v1/statement")
.addHeader("X-Trino-User", "test")
.post(requestBody)
.addHeader("X-Trino-Routing-Group", "scheduled")
.build();
Response responseWithoutCookie = httpClient.newCall(requestWithoutCookie).execute();
assertThat(responseWithoutCookie.body().string()).contains("http://localhost:" + routerPort);
List<Cookie> cookies = Cookie.parseAll(responseWithoutCookie.request().url(), responseWithoutCookie.headers());
Cookie clusterHostCookie = cookies.stream().filter(c -> c.name().equals("trinoClusterHost")).collect(onlyElement());
assertThat(clusterHostCookie.value()).isEqualTo("localhost");

// test with sending the request which includes trinoClusterHost in the cookie
// when X-Trino-Routing-Group is set in header, query should be routed to cluster under the routing group
Request requestWithCookie =
new Request.Builder()
.url("http://localhost:" + routerPort + "/v1/statement")
.addHeader("X-Trino-User", "test")
.post(requestBody)
.addHeader("X-Trino-Routing-Group", "scheduled")
.addHeader("Cookie", "trinoClientHost=foo.example.com")
.build();
Response responseWithCookie = httpClient.newCall(requestWithCookie).execute();
assertThat(responseWithCookie.body().string()).contains("http://localhost:" + routerPort);
List<Cookie> overridenCookies = Cookie.parseAll(responseWithCookie.request().url(), responseWithCookie.headers());
Cookie overridenClusterHostCookie = overridenCookies.stream().filter(c -> c.name().equals("trinoClusterHost")).collect(onlyElement());
assertThat(overridenClusterHostCookie.value()).isEqualTo("localhost");
}

@Test
public void testDeleteQueryId()
throws IOException
Expand Down
1 change: 1 addition & 0 deletions gateway-ha/src/test/resources/test-config-template.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ serverConfig:
node.environment: test
http-server.http.port: REQUEST_ROUTER_PORT

includeClusterHostInResponse: true
dataStore:
jdbcUrl: jdbc:h2:DB_FILE_PATH
user: sa
Expand Down