Skip to content

Commit

Permalink
remoe user info and read remote address from RestRequest
Browse files Browse the repository at this point in the history
Signed-off-by: Chenyang Ji <cyji@amazon.com>
  • Loading branch information
ansjcy committed Mar 7, 2024
1 parent 21dd954 commit fe9c669
Show file tree
Hide file tree
Showing 10 changed files with 17 additions and 139 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Support for returning scores in matched queries ([#11626](https://github.com/opensearch-project/OpenSearch/pull/11626))
- Add shard id property to SearchLookup for use in field types provided by plugins ([#1063](https://github.com/opensearch-project/OpenSearch/pull/1063))
- Add kuromoji_completion analyzer and filter ([#4835](https://github.com/opensearch-project/OpenSearch/issues/4835))
- [Query insights] Add user info in top queries ([#12529](https://github.com/opensearch-project/OpenSearch/pull/12529))
- [Query insights] Add remote address info in top queries ([#12529](https://github.com/opensearch-project/OpenSearch/pull/12529))

### Dependencies
- Bump `peter-evans/find-comment` from 2 to 3 ([#12288](https://github.com/opensearch-project/OpenSearch/pull/12288))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public Collection<Object> createComponents(
) {
// create top n queries service
final QueryInsightsService queryInsightsService = new QueryInsightsService(threadPool);
return List.of(queryInsightsService, new QueryInsightsListener(clusterService, queryInsightsService, threadPool));
return List.of(queryInsightsService, new QueryInsightsListener(clusterService, queryInsightsService));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
import org.opensearch.plugin.insights.rules.model.Attribute;
import org.opensearch.plugin.insights.rules.model.MetricType;
import org.opensearch.plugin.insights.rules.model.SearchQueryRecord;
import org.opensearch.plugin.insights.utils.ThreadContextParser;
import org.opensearch.threadpool.ThreadPool;

import java.util.Collections;
import java.util.HashMap;
Expand All @@ -47,23 +45,16 @@ public final class QueryInsightsListener extends SearchRequestOperationsListener
private static final Logger log = LogManager.getLogger(QueryInsightsListener.class);

private final QueryInsightsService queryInsightsService;
private final ThreadPool threadPool;

/**
* Constructor for QueryInsightsListener
*
* @param clusterService The Node's cluster service.
* @param queryInsightsService The topQueriesByLatencyService associated with this listener
* @param threadPool The OpenSearch thread pool to run async tasks
*/
@Inject
public QueryInsightsListener(
final ClusterService clusterService,
final QueryInsightsService queryInsightsService,
final ThreadPool threadPool
) {
public QueryInsightsListener(final ClusterService clusterService, final QueryInsightsService queryInsightsService) {
this.queryInsightsService = queryInsightsService;
this.threadPool = threadPool;
clusterService.getClusterSettings()
.addSettingsUpdateConsumer(TOP_N_LATENCY_QUERIES_ENABLED, v -> this.setEnableTopQueries(MetricType.LATENCY, v));
clusterService.getClusterSettings()
Expand Down Expand Up @@ -147,8 +138,7 @@ public void onRequestEnd(final SearchPhaseContext context, final SearchRequestCo
attributes.put(Attribute.TOTAL_SHARDS, context.getNumShards());
attributes.put(Attribute.INDICES, request.indices());
attributes.put(Attribute.PHASE_LATENCY_MAP, searchRequestContext.phaseTookMap());
// add user related information
attributes.putAll(ThreadContextParser.getUserInfoFromThreadContext(threadPool.getThreadContext()));
attributes.put(Attribute.REMOTE_ADDRESS, searchRequestContext.getRequestRemoteAddress());
SearchQueryRecord record = new SearchQueryRecord(request.getOrCreateAbsoluteStartMillis(), measurements, attributes);
queryInsightsService.addRecord(record);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,7 @@ public enum Attribute {
/**
* The remote address of this request
*/
REMOTE_ADDRESS,
/**
* Username of the user who sent this request
*/
USER_NAME,
/**
* Backend roles of the user who sent this request
*/
USER_BACKEND_ROLES,
/**
* Roles of the user who sent this request
*/
USER_ROLES,
/**
* Tenant info of the user who sent this request
*/
USER_TENANT;
REMOTE_ADDRESS;

/**
* Read an Attribute from a StreamInput
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,6 @@
* @opensearch.experimental
*/
public class QueryInsightsSettings {
/**
* Constant setting for user info header key that are injected during authentication
*/
public static final String REQUEST_HEADER_USER_INFO = "_opendistro_security_user_info";
/**
* Constant setting for remote address info header key that are injected during authentication
*/
public static final String REQUEST_HEADER_REMOTE_ADDRESS = "_opendistro_security_remote_address";

/**
* Executors settings
*/
Expand Down

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.plugin.insights.core.service.QueryInsightsService;
import org.opensearch.plugin.insights.core.service.TopQueriesService;
import org.opensearch.plugin.insights.rules.model.Attribute;
Expand All @@ -26,7 +25,6 @@
import org.opensearch.search.aggregations.support.ValueType;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
import org.junit.Before;

import java.util.ArrayList;
Expand Down Expand Up @@ -54,14 +52,9 @@ public class QueryInsightsListenerTests extends OpenSearchTestCase {
private final SearchRequest searchRequest = mock(SearchRequest.class);
private final QueryInsightsService queryInsightsService = mock(QueryInsightsService.class);
private final TopQueriesService topQueriesService = mock(TopQueriesService.class);
private final ThreadPool threadPool = mock(ThreadPool.class);
private final Settings.Builder settingsBuilder = Settings.builder();
private final Settings settings = settingsBuilder.build();
private final String remoteAddress = "1.2.3.4";
private final String userName = "user1";
private final List<String> userBackendRoles = List.of("bk-role1", "bk-role2");
private final List<String> userRoles = List.of("role1", "role2");
private final String userTenant = "tenant1";
private ClusterService clusterService;

@Before
Expand All @@ -73,15 +66,7 @@ public void setup() {
clusterService = new ClusterService(settings, clusterSettings, null);
when(queryInsightsService.isCollectionEnabled(MetricType.LATENCY)).thenReturn(true);
when(queryInsightsService.getTopQueriesService(MetricType.LATENCY)).thenReturn(topQueriesService);

// inject user info
ThreadContext threadContext = new ThreadContext(settings);
threadContext.putTransient(
QueryInsightsSettings.REQUEST_HEADER_USER_INFO,
userName + '|' + String.join(",", userBackendRoles) + "|" + String.join(",", userRoles) + "|" + userTenant
);
threadContext.putTransient(QueryInsightsSettings.REQUEST_HEADER_REMOTE_ADDRESS, remoteAddress);
when(threadPool.getThreadContext()).thenReturn(threadContext);
when(searchRequestContext.getRequestRemoteAddress()).thenReturn(remoteAddress);
}

public void testOnRequestEnd() {
Expand All @@ -101,7 +86,7 @@ public void testOnRequestEnd() {

int numberOfShards = 10;

QueryInsightsListener queryInsightsListener = new QueryInsightsListener(clusterService, queryInsightsService, threadPool);
QueryInsightsListener queryInsightsListener = new QueryInsightsListener(clusterService, queryInsightsService);

when(searchRequest.getOrCreateAbsoluteStartMillis()).thenReturn(timestamp);
when(searchRequest.searchType()).thenReturn(searchType);
Expand All @@ -122,10 +107,6 @@ public void testOnRequestEnd() {
assertEquals(numberOfShards, attrs.get(Attribute.TOTAL_SHARDS));
assertEquals(indices, attrs.get(Attribute.INDICES));
assertEquals(phaseLatencyMap, attrs.get(Attribute.PHASE_LATENCY_MAP));
assertEquals(userName, attrs.get(Attribute.USER_NAME));
assertEquals(userBackendRoles, attrs.get(Attribute.USER_BACKEND_ROLES));
assertEquals(userRoles, attrs.get(Attribute.USER_ROLES));
assertEquals(userTenant, attrs.get(Attribute.USER_TENANT));
assertEquals(remoteAddress, attrs.get(Attribute.REMOTE_ADDRESS));
}

Expand Down Expand Up @@ -162,7 +143,7 @@ public void testConcurrentOnRequestEnd() throws InterruptedException {
CountDownLatch countDownLatch = new CountDownLatch(numRequests);

for (int i = 0; i < numRequests; i++) {
searchListenersList.add(new QueryInsightsListener(clusterService, queryInsightsService, threadPool));
searchListenersList.add(new QueryInsightsListener(clusterService, queryInsightsService));
}

for (int i = 0; i < numRequests; i++) {
Expand All @@ -183,7 +164,7 @@ public void testConcurrentOnRequestEnd() throws InterruptedException {

public void testSetEnabled() {
when(queryInsightsService.isCollectionEnabled(MetricType.LATENCY)).thenReturn(true);
QueryInsightsListener queryInsightsListener = new QueryInsightsListener(clusterService, queryInsightsService, threadPool);
QueryInsightsListener queryInsightsListener = new QueryInsightsListener(clusterService, queryInsightsService);
queryInsightsListener.setEnableTopQueries(MetricType.LATENCY, true);
assertTrue(queryInsightsListener.isEnabled());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ public Map<String, Long> phaseTookMap() {
return phaseTookMap;
}

public String getRequestRemoteAddress() {
return searchRequest.remoteAddress().toString();

Check warning on line 55 in server/src/main/java/org/opensearch/action/search/SearchRequestContext.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/action/search/SearchRequestContext.java#L55

Added line #L55 was not covered by tests
}

SearchResponse.PhaseTook getPhaseTook() {
if (searchRequest != null && searchRequest.isPhaseTook() != null && searchRequest.isPhaseTook()) {
return new SearchResponse.PhaseTook(phaseTookMap);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.opensearch.common.Booleans;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.common.transport.TransportAddress;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.rest.BaseRestHandler;
Expand Down Expand Up @@ -223,6 +224,9 @@ public static void parseSearchRequest(
}

searchRequest.setCancelAfterTimeInterval(request.paramAsTime("cancel_after_time_interval", null));

// set remote address for searchRequest
searchRequest.remoteAddress(new TransportAddress(request.getHttpChannel().getRemoteAddress()));
}

/**
Expand Down

0 comments on commit fe9c669

Please sign in to comment.