Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
Signed-off-by: Kaushal Kumar <ravi.kaushal97@gmail.com>
  • Loading branch information
kaushalmahi12 committed Jul 22, 2024
1 parent 48fea42 commit be7e98d
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ private void executeRequest(

// At this point either the QUERY_GROUP_ID header will be present in ThreadContext either via ActionFilter
// or HTTP header (HTTP header will be deprecated once ActionFilter is implemented)
task.addQueryGroupHeadersTo(threadPool.getThreadContext());
task.addQueryGroupHeaders(threadPool.getThreadContext());

PipelinedRequest searchRequest;
ActionListener<SearchResponse> listener;
Expand Down
12 changes: 6 additions & 6 deletions server/src/main/java/org/opensearch/search/SearchService.java
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ public void executeDfsPhase(
ActionListener<SearchPhaseResult> listener
) {
final IndexShard shard = getShard(request);
task.addQueryGroupHeadersTo(threadPool.getThreadContext());
task.addQueryGroupHeaders(threadPool.getThreadContext());
rewriteAndFetchShardRequest(shard, request, new ActionListener<ShardSearchRequest>() {
@Override
public void onResponse(ShardSearchRequest rewritten) {
Expand Down Expand Up @@ -611,7 +611,7 @@ public void executeQueryPhase(
) {
assert request.canReturnNullResponseIfMatchNoDocs() == false || request.numberOfShards() > 1
: "empty responses require more than one shard";
task.addQueryGroupHeadersTo(threadPool.getThreadContext());
task.addQueryGroupHeaders(threadPool.getThreadContext());
final IndexShard shard = getShard(request);
rewriteAndFetchShardRequest(shard, request, new ActionListener<ShardSearchRequest>() {
@Override
Expand Down Expand Up @@ -721,7 +721,7 @@ public void executeQueryPhase(
freeReaderContext(readerContext.id());
throw e;
}
task.addQueryGroupHeadersTo(threadPool.getThreadContext());
task.addQueryGroupHeaders(threadPool.getThreadContext());
runAsync(getExecutor(readerContext.indexShard()), () -> {
final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(null);
try (
Expand All @@ -748,7 +748,7 @@ public void executeQueryPhase(QuerySearchRequest request, SearchShardTask task,
final ReaderContext readerContext = findReaderContext(request.contextId(), request.shardSearchRequest());
final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.shardSearchRequest());
final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest));
task.addQueryGroupHeadersTo(threadPool.getThreadContext());
task.addQueryGroupHeaders(threadPool.getThreadContext());
runAsync(getExecutor(readerContext.indexShard()), () -> {
readerContext.setAggregatedDfs(request.dfs());
try (
Expand Down Expand Up @@ -799,7 +799,7 @@ public void executeFetchPhase(
) {
final LegacyReaderContext readerContext = (LegacyReaderContext) findReaderContext(request.contextId(), request);
final Releasable markAsUsed;
task.addQueryGroupHeadersTo(threadPool.getThreadContext());
task.addQueryGroupHeaders(threadPool.getThreadContext());
try {
markAsUsed = readerContext.markAsUsed(getScrollKeepAlive(request.scroll()));
} catch (Exception e) {
Expand Down Expand Up @@ -835,7 +835,7 @@ public void executeFetchPhase(ShardFetchRequest request, SearchShardTask task, A
final ReaderContext readerContext = findReaderContext(request.contextId(), request);
final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.getShardSearchRequest());
final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest));
task.addQueryGroupHeadersTo(threadPool.getThreadContext());
task.addQueryGroupHeaders(threadPool.getThreadContext());
runAsync(getExecutor(readerContext.indexShard()), () -> {
try (SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, false)) {
if (request.lastEmittedDoc() != null) {
Expand Down
14 changes: 11 additions & 3 deletions server/src/main/java/org/opensearch/tasks/Task.java
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,6 @@ public TaskId getParentTaskId() {
return parentTask;
}


/**
* Build a status for this task or null if this task doesn't have status.
* Since most tasks don't have status this defaults to returning null. While
Expand Down Expand Up @@ -525,12 +524,21 @@ public String getHeader(String header) {
return headers.get(header);
}

public void addQueryGroupHeadersTo(final ThreadContext threadContext) {
/**
* This method adds the queryGroupHeader in the task headers, We need this method since the query group is not determined at the task creation time
* hence it is not possible to copy this header from request headers. This header is required to group the tasks into queryGroups to account for the QueryGroup level resource footprint
* @param threadContext
*/
public void addQueryGroupHeaders(final ThreadContext threadContext) {
// For now this header will be coming from HTTP headers but in second phase this header

// We will use this constant from QueryGroup Service once the framework changes are done
final String QUERY_GROUP_ID_HEADER = "queryGroupId";
final String requestQueryGroupId = threadContext.getHeader(QUERY_GROUP_ID_HEADER);
String requestQueryGroupId = threadContext.getHeader(QUERY_GROUP_ID_HEADER);

if (requestQueryGroupId == null) {
requestQueryGroupId = "DEFAULT_QUERY_GROUP_ID"; // TODO: move this constant either to QueryGroupService or Tracking equivalent
}

final Map<String, String> newHeaders = new HashMap<>(headers);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ public void testTaskResourceStats() {
}
}

public void testAddQueryGroupHeadersTo() {
public void testAddQueryGroupHeaders() {
ThreadPool threadPool = new TestThreadPool(getClass().getName());
try {
Task task = new Task(
Expand All @@ -253,7 +253,7 @@ public void testAddQueryGroupHeadersTo() {

threadPool.getThreadContext().putHeader("queryGroupId", "afakgkagj09532059");

task.addQueryGroupHeadersTo(threadPool.getThreadContext());
task.addQueryGroupHeaders(threadPool.getThreadContext());

String queryGroupId = task.getHeader("queryGroupId");

Expand All @@ -262,4 +262,26 @@ public void testAddQueryGroupHeadersTo() {
threadPool.shutdown();
}
}

public void testAddQueryGroupHeadersWhenHeaderIsNotPresentInThreadContext() {
ThreadPool threadPool = new TestThreadPool(getClass().getName());
try {
Task task = new Task(
randomLong(),
"transport",
SearchAction.NAME,
"description",
new TaskId(randomLong() + ":" + randomLong()),
Collections.emptyMap()
);

task.addQueryGroupHeaders(threadPool.getThreadContext());

String queryGroupId = task.getHeader("queryGroupId");

assertEquals("DEFAULT_QUERY_GROUP_ID", queryGroupId);
} finally {
threadPool.shutdown();
}
}
}

0 comments on commit be7e98d

Please sign in to comment.