Collect warnings in compute service (#103031) (#103079)
We have implemented #99927 in DriverRunner. However, we also need to
implement it in ComputeService, where we spawn multiple requests, so that
response headers are not lost.

Relates #99927

Closes #100163
Closes #102982
Closes #102871
Closes #103028
dnhatn authored Dec 6, 2023
1 parent 268565e commit bdd2b42
Showing 6 changed files with 253 additions and 33 deletions.
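
For context, here is a condensed, illustrative sketch of the collect/finish pattern this commit wires into ComputeService. The startChild method is a hypothetical stand-in for dispatching one child compute request; the listener wiring mirrors the diff below.

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.RefCountingListener;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.compute.operator.ResponseHeadersCollector;

final class CollectWarningsSketch {
    // Illustrative only: run `children` child requests and make sure the warnings
    // (response headers) they produce on other threads reach the caller's listener.
    void runChildren(ThreadContext threadContext, ActionListener<Void> done, int children) {
        var collector = new ResponseHeadersCollector(threadContext);
        // Merge everything collected back onto the coordinating thread just before
        // the caller's listener observes the final response.
        ActionListener<Void> finalListener = ActionListener.runBefore(done, collector::finish);
        try (RefCountingListener refs = new RefCountingListener(finalListener)) {
            for (int i = 0; i < children; i++) {
                // Each child captures the headers of whichever thread completed it.
                startChild(i, ActionListener.runBefore(refs.acquire(), collector::collect));
            }
        }
    }

    // Hypothetical stand-in for dispatching one child compute request.
    void startChild(int index, ActionListener<Void> listener) {
        listener.onResponse(null);
    }
}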
9 changes: 9 additions & 0 deletions docs/changelog/103031.yaml
@@ -0,0 +1,9 @@
pr: 103031
summary: Collect warnings in compute service
area: ES|QL
type: bug
issues:
- 100163
- 103028
- 102871
- 102982
@@ -9,16 +9,11 @@

import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.common.util.concurrent.CountDown;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.tasks.TaskCancelledException;

import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;

/**
@@ -41,11 +36,10 @@ public DriverRunner(ThreadContext threadContext) {
*/
public void runToCompletion(List<Driver> drivers, ActionListener<Void> listener) {
AtomicReference<Exception> failure = new AtomicReference<>();
AtomicArray<Map<String, List<String>>> responseHeaders = new AtomicArray<>(drivers.size());
var responseHeadersCollector = new ResponseHeadersCollector(threadContext);
CountDown counter = new CountDown(drivers.size());
for (int i = 0; i < drivers.size(); i++) {
Driver driver = drivers.get(i);
int driverIndex = i;
ActionListener<Void> driverListener = new ActionListener<>() {
@Override
public void onResponse(Void unused) {
@@ -80,9 +74,9 @@ public void onFailure(Exception e) {
}

private void done() {
responseHeaders.setOnce(driverIndex, threadContext.getResponseHeaders());
responseHeadersCollector.collect();
if (counter.countDown()) {
mergeResponseHeaders(responseHeaders);
responseHeadersCollector.finish();
Exception error = failure.get();
if (error != null) {
listener.onFailure(error);
@@ -96,23 +90,4 @@ private void done() {
start(driver, driverListener);
}
}

private void mergeResponseHeaders(AtomicArray<Map<String, List<String>>> responseHeaders) {
final Map<String, Set<String>> merged = new HashMap<>();
for (int i = 0; i < responseHeaders.length(); i++) {
final Map<String, List<String>> resp = responseHeaders.get(i);
if (resp == null || resp.isEmpty()) {
continue;
}
for (Map.Entry<String, List<String>> e : resp.entrySet()) {
// Use LinkedHashSet to retain the order of the values
merged.computeIfAbsent(e.getKey(), k -> new LinkedHashSet<>(e.getValue().size())).addAll(e.getValue());
}
}
for (Map.Entry<String, Set<String>> e : merged.entrySet()) {
for (String v : e.getValue()) {
threadContext.addResponseHeader(e.getKey(), v);
}
}
}
}
@@ -0,0 +1,60 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.compute.operator;

import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.common.util.concurrent.ThreadContext;

import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;

/**
* A helper class that can be used to collect and merge response headers from multiple child requests.
*/
public final class ResponseHeadersCollector {
private final ThreadContext threadContext;
private final Queue<Map<String, List<String>>> collected = ConcurrentCollections.newQueue();

public ResponseHeadersCollector(ThreadContext threadContext) {
this.threadContext = threadContext;
}

/**
* Called when a child request is completed to collect the response headers of the responding thread
*/
public void collect() {
Map<String, List<String>> responseHeaders = threadContext.getResponseHeaders();
if (responseHeaders.isEmpty() == false) {
collected.add(responseHeaders);
}
}

/**
 * Called when all child requests are completed. This merges all collected response headers
 * from the child requests and restores them to the current thread.
*/
public void finish() {
final Map<String, Set<String>> merged = new HashMap<>();
Map<String, List<String>> resp;
while ((resp = collected.poll()) != null) {
for (Map.Entry<String, List<String>> e : resp.entrySet()) {
// Use LinkedHashSet to retain the order of the values
merged.computeIfAbsent(e.getKey(), k -> new LinkedHashSet<>(e.getValue().size())).addAll(e.getValue());
}
}
for (Map.Entry<String, Set<String>> e : merged.entrySet()) {
for (String v : e.getValue()) {
threadContext.addResponseHeader(e.getKey(), v);
}
}
}
}
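
A minimal usage sketch of the new helper (illustrative only, not part of the commit): response headers added inside a stashed context on a worker thread would be lost when that context is restored, but collect() captures them there and finish() replays them on the calling thread. The "Warning" header name and message are placeholders.

import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.compute.operator.ResponseHeadersCollector;

public final class ResponseHeadersCollectorExample {
    public static void main(String[] args) throws Exception {
        ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
        ResponseHeadersCollector collector = new ResponseHeadersCollector(threadContext);
        Thread worker = new Thread(() -> {
            try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
                threadContext.addResponseHeader("Warning", "line 1:1: example warning");
                collector.collect(); // capture while the worker's headers are still visible
            }
        });
        worker.start();
        worker.join();
        collector.finish(); // replay the collected headers onto this (calling) thread
        System.out.println(threadContext.getResponseHeaders().get("Warning"));
    }
}

This is essentially what the test below exercises with multiple threads and a RefCountingListener.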
@@ -0,0 +1,72 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.compute.operator;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.RefCountingListener;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.threadpool.TestThreadPool;

import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.TimeUnit;

import static org.hamcrest.Matchers.equalTo;

public class ResponseHeadersCollectorTests extends ESTestCase {

public void testCollect() {
int numThreads = randomIntBetween(1, 10);
TestThreadPool threadPool = new TestThreadPool(
getTestClass().getSimpleName(),
new FixedExecutorBuilder(Settings.EMPTY, "test", numThreads, 1024, "test", EsExecutors.TaskTrackingConfig.DEFAULT)
);
Set<String> expectedWarnings = new HashSet<>();
try {
ThreadContext threadContext = threadPool.getThreadContext();
var collector = new ResponseHeadersCollector(threadContext);
PlainActionFuture<Void> future = new PlainActionFuture<>();
Runnable mergeAndVerify = () -> {
collector.finish();
List<String> actualWarnings = threadContext.getResponseHeaders().getOrDefault("Warnings", List.of());
assertThat(Sets.newHashSet(actualWarnings), equalTo(expectedWarnings));
};
try (RefCountingListener refs = new RefCountingListener(ActionListener.runAfter(future, mergeAndVerify))) {
CyclicBarrier barrier = new CyclicBarrier(numThreads);
for (int i = 0; i < numThreads; i++) {
String warning = "warning-" + i;
expectedWarnings.add(warning);
ActionListener<Void> listener = ActionListener.runBefore(refs.acquire(), collector::collect);
threadPool.schedule(new ActionRunnable<>(listener) {
@Override
protected void doRun() throws Exception {
barrier.await(30, TimeUnit.SECONDS);
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
threadContext.addResponseHeader("Warnings", warning);
listener.onResponse(null);
}
}
}, TimeValue.timeValueNanos(between(0, 1000_000)), threadPool.executor("test"));
}
}
future.actionGet(TimeValue.timeValueSeconds(30));
} finally {
terminate(threadPool);
}
}
}
@@ -0,0 +1,85 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.action;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.test.junit.annotations.TestLogging;
import org.elasticsearch.transport.TransportService;

import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;

@TestLogging(value = "org.elasticsearch.xpack.esql:TRACE", reason = "debug")
public class WarningsIT extends AbstractEsqlIntegTestCase {

public void testCollectWarnings() {
final String node1, node2;
if (randomBoolean()) {
internalCluster().ensureAtLeastNumDataNodes(2);
node1 = randomDataNode().getName();
node2 = randomValueOtherThan(node1, () -> randomDataNode().getName());
} else {
node1 = randomDataNode().getName();
node2 = randomDataNode().getName();
}

int numDocs1 = randomIntBetween(1, 15);
assertAcked(
client().admin()
.indices()
.prepareCreate("index-1")
.setSettings(Settings.builder().put("index.routing.allocation.require._name", node1))
.setMapping("host", "type=keyword")
);
for (int i = 0; i < numDocs1; i++) {
client().prepareIndex("index-1").setSource("host", "192." + i).get();
}
int numDocs2 = randomIntBetween(1, 15);
assertAcked(
client().admin()
.indices()
.prepareCreate("index-2")
.setSettings(Settings.builder().put("index.routing.allocation.require._name", node2))
.setMapping("host", "type=keyword")
);
for (int i = 0; i < numDocs2; i++) {
client().prepareIndex("index-2").setSource("host", "10." + i).get();
}

DiscoveryNode coordinator = randomFrom(clusterService().state().nodes().stream().toList());
client().admin().indices().prepareRefresh("index-1", "index-2").get();

EsqlQueryRequest request = new EsqlQueryRequest();
request.query("FROM index-* | EVAL ip = to_ip(host) | STATS s = COUNT(*) by ip | KEEP ip | LIMIT 100");
request.pragmas(randomPragmas());
PlainActionFuture<EsqlQueryResponse> future = new PlainActionFuture<>();
client(coordinator.getName()).execute(EsqlQueryAction.INSTANCE, request, ActionListener.runBefore(future, () -> {
var threadpool = internalCluster().getInstance(TransportService.class, coordinator.getName()).getThreadPool();
Map<String, List<String>> responseHeaders = threadpool.getThreadContext().getResponseHeaders();
List<String> warnings = responseHeaders.getOrDefault("Warning", List.of())
.stream()
.filter(w -> w.contains("is not an IP string literal"))
.toList();
int expectedWarnings = Math.min(20, numDocs1 + numDocs2);
// we cap the number of warnings per node
assertThat(warnings.size(), greaterThanOrEqualTo(expectedWarnings));
}));
future.actionGet(30, TimeUnit.SECONDS).close();
}

private DiscoveryNode randomDataNode() {
return randomFrom(clusterService().state().nodes().getDataNodes().values());
}
}
@@ -29,6 +29,7 @@
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverTaskRunner;
import org.elasticsearch.compute.operator.ResponseHeadersCollector;
import org.elasticsearch.compute.operator.exchange.ExchangeResponse;
import org.elasticsearch.compute.operator.exchange.ExchangeService;
import org.elasticsearch.compute.operator.exchange.ExchangeSinkHandler;
@@ -150,6 +151,8 @@ public void execute(
LOGGER.debug("Sending data node plan\n{}\n with filter [{}]", dataNodePlan, requestFilter);

String[] originalIndices = PlannerUtils.planOriginalIndices(physicalPlan);
var responseHeadersCollector = new ResponseHeadersCollector(transportService.getThreadPool().getThreadContext());
listener = ActionListener.runBefore(listener, responseHeadersCollector::finish);
computeTargetNodes(
rootTask,
requestFilter,
@@ -170,7 +173,16 @@
exchangeSource.addCompletionListener(requestRefs.acquire());
// run compute on the coordinator
var computeContext = new ComputeContext(sessionId, List.of(), configuration, exchangeSource, null);
runCompute(rootTask, computeContext, coordinatorPlan, cancelOnFailure(rootTask, cancelled, requestRefs.acquire()));
runCompute(
rootTask,
computeContext,
coordinatorPlan,
cancelOnFailure(
rootTask,
cancelled,
ActionListener.runBefore(requestRefs.acquire(), responseHeadersCollector::collect)
)
);
// run compute on remote nodes
// TODO: This is wrong, we need to be able to cancel
runComputeOnRemoteNodes(
@@ -180,7 +192,11 @@
dataNodePlan,
exchangeSource,
targetNodes,
() -> cancelOnFailure(rootTask, cancelled, requestRefs.acquire()).map(unused -> null)
() -> cancelOnFailure(
rootTask,
cancelled,
ActionListener.runBefore(requestRefs.acquire(), responseHeadersCollector::collect)
)
);
}
})
@@ -194,7 +210,7 @@ private void runComputeOnRemoteNodes(
PhysicalPlan dataNodePlan,
ExchangeSourceHandler exchangeSource,
List<TargetNode> targetNodes,
Supplier<ActionListener<DataNodeResponse>> listener
Supplier<ActionListener<Void>> listener
) {
// Do not complete the exchange sources until we have linked all remote sinks
final SubscribableListener<Void> blockingSinkFuture = new SubscribableListener<>();
@@ -223,7 +239,7 @@ private void runComputeOnRemoteNodes(
new DataNodeRequest(sessionId, configuration, targetNode.shardIds, targetNode.aliasFilters, dataNodePlan),
rootTask,
TransportRequestOptions.EMPTY,
new ActionListenerResponseHandler<>(delegate, DataNodeResponse::new, esqlExecutor)
new ActionListenerResponseHandler<>(delegate.map(ignored -> null), DataNodeResponse::new, esqlExecutor)
);
})
);
@@ -442,7 +458,10 @@ public void messageReceived(DataNodeRequest request, TransportChannel channel, T
runCompute(parentTask, computeContext, request.plan(), ActionListener.wrap(unused -> {
// don't return until all pages are fetched
exchangeSink.addCompletionListener(
ActionListener.releaseAfter(listener, () -> exchangeService.finishSinkHandler(sessionId, null))
ContextPreservingActionListener.wrapPreservingContext(
ActionListener.releaseAfter(listener, () -> exchangeService.finishSinkHandler(sessionId, null)),
transportService.getThreadPool().getThreadContext()
)
);
}, e -> {
exchangeService.finishSinkHandler(sessionId, e);
