Skip to content

Commit

Permalink
Refactor async executor service dependencies using guice framework
Browse files Browse the repository at this point in the history
Signed-off-by: Vamsi Manohar <reddyvam@amazon.com>
  • Loading branch information
vmmusings committed Feb 1, 2024
1 parent e59bf75 commit e2d0ac8
Show file tree
Hide file tree
Showing 17 changed files with 522 additions and 167 deletions.
116 changes: 21 additions & 95 deletions plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,14 @@

package org.opensearch.sql.plugin;

import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG;
import static java.util.Collections.singletonList;
import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata;
import static org.opensearch.sql.spark.execution.statestore.StateStore.ALL_DATASOURCE;

import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.services.emrserverless.AWSEMRServerless;
import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.time.Clock;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;
Expand Down Expand Up @@ -68,7 +61,6 @@
import org.opensearch.sql.datasources.transport.*;
import org.opensearch.sql.legacy.esdomain.LocalClusterState;
import org.opensearch.sql.legacy.executor.AsyncRestExecutor;
import org.opensearch.sql.legacy.metrics.GaugeMetric;
import org.opensearch.sql.legacy.metrics.Metrics;
import org.opensearch.sql.legacy.plugin.RestSqlAction;
import org.opensearch.sql.legacy.plugin.RestSqlStatsAction;
Expand All @@ -87,26 +79,13 @@
import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse;
import org.opensearch.sql.prometheus.storage.PrometheusStorageFactory;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl;
import org.opensearch.sql.spark.asyncquery.AsyncQueryJobMetadataStorageService;
import org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryJobMetadataStorageService;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.client.EmrServerlessClientImpl;
import org.opensearch.sql.spark.cluster.ClusterManagerEventListener;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl;
import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
import org.opensearch.sql.spark.execution.session.SessionManager;
import org.opensearch.sql.spark.execution.statestore.StateStore;
import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl;
import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;
import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction;
import org.opensearch.sql.spark.storage.SparkStorageFactory;
import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction;
import org.opensearch.sql.spark.transport.TransportCreateAsyncQueryRequestAction;
import org.opensearch.sql.spark.transport.TransportGetAsyncQueryResultAction;
import org.opensearch.sql.spark.transport.config.AsyncExecutorServiceModule;
import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse;
import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionResponse;
import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionResponse;
Expand All @@ -127,7 +106,6 @@ public class SQLPlugin extends Plugin implements ActionPlugin, ScriptPlugin {

private NodeClient client;
private DataSourceServiceImpl dataSourceService;
private AsyncQueryExecutorService asyncQueryExecutorService;
private Injector injector;

public String name() {
Expand Down Expand Up @@ -223,32 +201,16 @@ public Collection<Object> createComponents(
dataSourceService.createDataSource(defaultOpenSearchDataSourceMetadata());
LocalClusterState.state().setClusterService(clusterService);
LocalClusterState.state().setPluginSettings((OpenSearchSettings) pluginSettings);
SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier =
new SparkExecutionEngineConfigSupplierImpl(pluginSettings);
SparkExecutionEngineConfig sparkExecutionEngineConfig =
sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig();
if (StringUtils.isEmpty(sparkExecutionEngineConfig.getRegion())) {
LOGGER.warn(
String.format(
"Async Query APIs are disabled as %s is not configured properly in cluster settings. "
+ "Please configure and restart the domain to enable Async Query APIs",
SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue()));
this.asyncQueryExecutorService = new AsyncQueryExecutorServiceImpl();
} else {
this.asyncQueryExecutorService =
createAsyncQueryExecutorService(
sparkExecutionEngineConfigSupplier, sparkExecutionEngineConfig);
}

ModulesBuilder modules = new ModulesBuilder();
modules.add(new OpenSearchPluginModule());
modules.add(
b -> {
b.bind(NodeClient.class).toInstance((NodeClient) client);
b.bind(org.opensearch.sql.common.setting.Settings.class).toInstance(pluginSettings);
b.bind(DataSourceService.class).toInstance(dataSourceService);
b.bind(ClusterService.class).toInstance(clusterService);
});

modules.add(new AsyncExecutorServiceModule());
injector = modules.createInjector();
ClusterManagerEventListener clusterManagerEventListener =
new ClusterManagerEventListener(
Expand All @@ -261,12 +223,15 @@ public Collection<Object> createComponents(
OpenSearchSettings.AUTO_INDEX_MANAGEMENT_ENABLED_SETTING,
environment.settings());
return ImmutableList.of(
dataSourceService, asyncQueryExecutorService, clusterManagerEventListener, pluginSettings);
dataSourceService,
injector.getInstance(AsyncQueryExecutorService.class),
clusterManagerEventListener,
pluginSettings);
}

@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
return Collections.singletonList(
return singletonList(
new FixedExecutorBuilder(
settings,
AsyncRestExecutor.SQL_WORKER_THREAD_POOL_NAME,
Expand Down Expand Up @@ -319,56 +284,17 @@ private DataSourceServiceImpl createDataSourceService() {
dataSourceUserAuthorizationHelper);
}

private AsyncQueryExecutorService createAsyncQueryExecutorService(
SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier,
SparkExecutionEngineConfig sparkExecutionEngineConfig) {
StateStore stateStore = new StateStore(client, clusterService);
registerStateStoreMetrics(stateStore);
AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService =
new OpensearchAsyncQueryJobMetadataStorageService(stateStore);
EMRServerlessClient emrServerlessClient =
createEMRServerlessClient(sparkExecutionEngineConfig.getRegion());
JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client);
SparkQueryDispatcher sparkQueryDispatcher =
new SparkQueryDispatcher(
emrServerlessClient,
this.dataSourceService,
new DataSourceUserAuthorizationHelperImpl(client),
jobExecutionResponseReader,
new FlintIndexMetadataReaderImpl(client),
client,
new SessionManager(stateStore, emrServerlessClient, pluginSettings),
new DefaultLeaseManager(pluginSettings, stateStore),
stateStore);
return new AsyncQueryExecutorServiceImpl(
asyncQueryJobMetadataStorageService,
sparkQueryDispatcher,
sparkExecutionEngineConfigSupplier);
}

private void registerStateStoreMetrics(StateStore stateStore) {
GaugeMetric<Long> activeSessionMetric =
new GaugeMetric<>(
"active_async_query_sessions_count",
StateStore.activeSessionsCount(stateStore, ALL_DATASOURCE));
GaugeMetric<Long> activeStatementMetric =
new GaugeMetric<>(
"active_async_query_statements_count",
StateStore.activeStatementsCount(stateStore, ALL_DATASOURCE));
Metrics.getInstance().registerMetric(activeSessionMetric);
Metrics.getInstance().registerMetric(activeStatementMetric);
}

private EMRServerlessClient createEMRServerlessClient(String region) {
return AccessController.doPrivileged(
(PrivilegedAction<EMRServerlessClient>)
() -> {
AWSEMRServerless awsemrServerless =
AWSEMRServerlessClientBuilder.standard()
.withRegion(region)
.withCredentials(new DefaultAWSCredentialsProviderChain())
.build();
return new EmrServerlessClientImpl(awsemrServerless);
});
private AsyncQueryExecutorService createAsyncExecutorService() {
ModulesBuilder modulesBuilder = new ModulesBuilder();
modulesBuilder.add(
b -> {
b.bind(NodeClient.class).toInstance(client);
b.bind(org.opensearch.sql.common.setting.Settings.class)
.toInstance(new OpenSearchSettings(clusterService.getClusterSettings()));
b.bind(DataSourceService.class).toInstance(dataSourceService);
});
modulesBuilder.add(new AsyncExecutorServiceModule());
this.injector = modulesBuilder.createInjector();
return this.injector.getInstance(AsyncQueryExecutorService.class);
}
}
3 changes: 2 additions & 1 deletion spark/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ jacocoTestCoverageVerification {
// TODO: add tests for purging flint indices
'org.opensearch.sql.spark.cluster.ClusterManagerEventListener*',
'org.opensearch.sql.spark.cluster.FlintIndexRetention',
'org.opensearch.sql.spark.cluster.IndexCleanup'
'org.opensearch.sql.spark.cluster.IndexCleanup',
'org.opensearch.sql.spark.transport.config.AsyncExecutorServiceModule'
]
limit {
counter = 'LINE'
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package org.opensearch.sql.spark.client;

/** Interface for EMRServerlessClientFactory. */
public interface EMRServerlessClientFactory {

/**
* @return {@link EMRServerlessClient}
*/
EMRServerlessClient getClient();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package org.opensearch.sql.spark.client;

import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG;

import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.services.emrserverless.AWSEMRServerless;
import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder;
import java.security.AccessController;
import java.security.PrivilegedAction;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;

@RequiredArgsConstructor
public class EMRServerlessClientFactoryImpl implements EMRServerlessClientFactory {

private final SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier;
private EMRServerlessClient emrServerlessClient;
private String region;

@Override
public EMRServerlessClient getClient() {
SparkExecutionEngineConfig sparkExecutionEngineConfig =
this.sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig();
validateSparkExecutionEngineConfig(sparkExecutionEngineConfig);
if (isNewClientCreationRequired(sparkExecutionEngineConfig.getRegion())) {
region = sparkExecutionEngineConfig.getRegion();
this.emrServerlessClient = createEMRServerlessClient(this.region);
}
return this.emrServerlessClient;
}

private boolean isNewClientCreationRequired(String region) {
return !region.equals(this.region) || emrServerlessClient == null;
}

private void validateSparkExecutionEngineConfig(
SparkExecutionEngineConfig sparkExecutionEngineConfig) {
if (sparkExecutionEngineConfig == null || sparkExecutionEngineConfig.getRegion() == null) {
throw new IllegalArgumentException(
String.format(
"Async Query APIs are disabled as %s is not configured in cluster settings. Please"
+ " configure the setting and restart the domain to enable Async Query APIs",
SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue()));
}
}

private EMRServerlessClient createEMRServerlessClient(String awsRegion) {
return AccessController.doPrivileged(
(PrivilegedAction<EMRServerlessClient>)
() -> {
AWSEMRServerless awsemrServerless =
AWSEMRServerlessClientBuilder.standard()
.withRegion(awsRegion)
.withCredentials(new DefaultAWSCredentialsProviderChain())
.build();
return new EmrServerlessClientImpl(awsemrServerless);
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import java.util.HashMap;
import java.util.Map;
import lombok.AllArgsConstructor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.json.JSONObject;
import org.opensearch.client.Client;
import org.opensearch.sql.datasource.DataSourceService;
Expand All @@ -18,6 +16,7 @@
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
Expand All @@ -35,13 +34,12 @@
@AllArgsConstructor
public class SparkQueryDispatcher {

private static final Logger LOG = LogManager.getLogger();
public static final String INDEX_TAG_KEY = "index";
public static final String DATASOURCE_TAG_KEY = "datasource";
public static final String CLUSTER_NAME_TAG_KEY = "domain_ident";
public static final String JOB_TYPE_TAG_KEY = "type";

private EMRServerlessClient emrServerlessClient;
private EMRServerlessClientFactory emrServerlessClientFactory;

private DataSourceService dataSourceService;

Expand All @@ -60,10 +58,10 @@ public class SparkQueryDispatcher {
private StateStore stateStore;

public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) {
EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient();
DataSourceMetadata dataSourceMetadata =
this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource());
dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata);

AsyncQueryHandler asyncQueryHandler =
sessionManager.isEnabled()
? new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager)
Expand All @@ -83,7 +81,7 @@ public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest)
contextBuilder.indexQueryDetails(indexQueryDetails);

if (IndexQueryActionType.DROP.equals(indexQueryDetails.getIndexQueryActionType())) {
asyncQueryHandler = createIndexDMLHandler();
asyncQueryHandler = createIndexDMLHandler(emrServerlessClient);
} else if (IndexQueryActionType.CREATE.equals(indexQueryDetails.getIndexQueryActionType())
&& indexQueryDetails.isAutoRefresh()) {
asyncQueryHandler =
Expand All @@ -99,32 +97,34 @@ public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest)
}

public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) {
EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient();
if (asyncQueryJobMetadata.getSessionId() != null) {
return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager)
.getQueryResponse(asyncQueryJobMetadata);
} else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) {
return createIndexDMLHandler().getQueryResponse(asyncQueryJobMetadata);
return createIndexDMLHandler(emrServerlessClient).getQueryResponse(asyncQueryJobMetadata);
} else {
return new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager)
.getQueryResponse(asyncQueryJobMetadata);
}
}

public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient();
AsyncQueryHandler queryHandler;
if (asyncQueryJobMetadata.getSessionId() != null) {
queryHandler =
new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager);
} else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) {
queryHandler = createIndexDMLHandler();
queryHandler = createIndexDMLHandler(emrServerlessClient);
} else {
queryHandler =
new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager);
}
return queryHandler.cancelJob(asyncQueryJobMetadata);
}

private IndexDMLHandler createIndexDMLHandler() {
private IndexDMLHandler createIndexDMLHandler(EMRServerlessClient emrServerlessClient) {
return new IndexDMLHandler(
emrServerlessClient,
dataSourceService,
Expand Down
Loading

0 comments on commit e2d0ac8

Please sign in to comment.