diff --git a/docs/remote_inference_blueprints/bedrock_connector_anthropic_claude3_blueprint.md b/docs/remote_inference_blueprints/bedrock_connector_anthropic_claude3_blueprint.md index 33f26e080e..e3e0b8bd5e 100644 --- a/docs/remote_inference_blueprints/bedrock_connector_anthropic_claude3_blueprint.md +++ b/docs/remote_inference_blueprints/bedrock_connector_anthropic_claude3_blueprint.md @@ -84,7 +84,7 @@ POST /_plugins/_ml/connectors/_create "content-type": "application/json" }, "url": "https://bedrock-runtime.us-east-1.amazonaws.com/model/${parameters.model}/invoke", - "request_body": "{\"messages\":[{\"role\":\"user\",\"content\":[{\"type\":\"text\",\"text\":\"${parameters.prompt}\"}]}],\"anthropic_version\":\"${parameters.anthropic_version}\",\"max_tokens\":${parameters.max_tokens_to_sample}}" + "request_body": "{\"messages\":[{\"role\":\"user\",\"content\":[{\"type\":\"text\",\"text\":\"${parameters.inputs}\"}]}],\"anthropic_version\":\"${parameters.anthropic_version}\",\"max_tokens\":${parameters.max_tokens_to_sample}}" } ] } diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java index 5439d73619..9dbffa918e 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.common.FunctionName.REMOTE; import static org.opensearch.ml.common.FunctionName.TEXT_EMBEDDING; import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage; +import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG; import java.util.ArrayList; import java.util.Arrays; @@ -51,6 +52,7 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -68,6 +70,7 @@ public class CreateControllerTransportAction extends HandledTransportAction wrappedListener = ActionListener.runBefore(actionListener, context::restore); mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { FunctionName functionName = mlModel.getAlgorithm(); diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java index e1bcda2b05..00d0cb429f 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java @@ -8,6 +8,7 @@ import static org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX; import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage; +import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG; import java.util.ArrayList; import java.util.Arrays; @@ -41,6 +42,7 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -58,6 +60,7 @@ public class DeleteControllerTransportAction extends HandledTransportAction wrappedListener = ActionListener.runBefore(actionListener, context::restore); mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { Boolean isHidden = mlModel.getIsHidden(); diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java index 26c59decdf..b10c1b3f4d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java @@ -8,6 +8,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX; import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage; +import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; @@ -33,6 +34,7 @@ import org.opensearch.ml.common.transport.controller.MLControllerGetResponse; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.tasks.Task; @@ -50,6 +52,7 @@ public class GetControllerTransportAction extends HandledTransportAction wrappedListener = ActionListener.runBefore(actionListener, context::restore); client.get(getRequest, ActionListener.wrap(r -> { if (r != null && r.isExists()) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java index dab8410ad0..552be4c342 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.common.FunctionName.REMOTE; import static org.opensearch.ml.common.FunctionName.TEXT_EMBEDDING; import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage; +import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG; import java.util.ArrayList; import java.util.Arrays; @@ -46,6 +47,7 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -62,6 +64,7 @@ public class UpdateControllerTransportAction extends HandledTransportAction wrappedListener = ActionListener.runBefore(actionListener, context::restore); mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { FunctionName functionName = mlModel.getAlgorithm(); diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 070d0def34..c55d8d4aca 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -45,6 +45,7 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE; import static org.opensearch.ml.stats.ActionName.REGISTER; import static org.opensearch.ml.stats.MLActionLevelStat.ML_ACTION_REQUEST_COUNT; +import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.MLExceptionUtils.logException; import static org.opensearch.ml.utils.MLNodeUtils.checkOpenCircuitBreaker; import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; @@ -1254,6 +1255,9 @@ public synchronized void updateModelCache(String modelId, ActionListener */ public synchronized void deployControllerWithDeployedModel(String modelId, ActionListener listener) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + if (!mlFeatureEnabledSetting.isControllerEnabled()) { + throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG); + } if (!modelCacheHelper.isModelDeployed(modelId)) { throw new OpenSearchStatusException( "The model of this model controller has not deployed yet, please deploy the model first.", @@ -1423,6 +1427,9 @@ private synchronized void deployControllerWithDeployingModel( * @param mlModel ml model */ public void deployControllerWithDeployingModel(MLModel mlModel, Integer eligibleNodeCount) { + if (!mlFeatureEnabledSetting.isControllerEnabled()) { + throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG); + } if (mlModel.getModelState() != MLModelState.DEPLOYING) { throw new OpenSearchStatusException( "This method should only be called when model is in DEPLOYING state, but the model is in state: " + mlModel.getModelState(), diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 39aaf05ff9..5f5d9ef4cf 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -755,10 +755,10 @@ public List getRestHandlers( RestMemorySearchInteractionsAction restSearchInteractionsAction = new RestMemorySearchInteractionsAction(); RestMemoryGetConversationAction restGetConversationAction = new RestMemoryGetConversationAction(); RestMemoryGetInteractionAction restGetInteractionAction = new RestMemoryGetInteractionAction(); - RestMLCreateControllerAction restMLCreateControllerAction = new RestMLCreateControllerAction(); - RestMLGetControllerAction restMLGetControllerAction = new RestMLGetControllerAction(); - RestMLUpdateControllerAction restMLUpdateControllerAction = new RestMLUpdateControllerAction(); - RestMLDeleteControllerAction restMLDeleteControllerAction = new RestMLDeleteControllerAction(); + RestMLCreateControllerAction restMLCreateControllerAction = new RestMLCreateControllerAction(mlFeatureEnabledSetting); + RestMLGetControllerAction restMLGetControllerAction = new RestMLGetControllerAction(mlFeatureEnabledSetting); + RestMLUpdateControllerAction restMLUpdateControllerAction = new RestMLUpdateControllerAction(mlFeatureEnabledSetting); + RestMLDeleteControllerAction restMLDeleteControllerAction = new RestMLDeleteControllerAction(mlFeatureEnabledSetting); RestMLGetAgentAction restMLGetAgentAction = new RestMLGetAgentAction(mlFeatureEnabledSetting); RestMLDeleteAgentAction restMLDeleteAgentAction = new RestMLDeleteAgentAction(mlFeatureEnabledSetting); RestMemoryUpdateConversationAction restMemoryUpdateConversationAction = new RestMemoryUpdateConversationAction(); @@ -969,7 +969,8 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX, MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX, MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX, - MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX + MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX, + MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED ); return settings; } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateControllerAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateControllerAction.java index 6eb0041edd..8144080e1e 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateControllerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateControllerAction.java @@ -7,6 +7,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; import static org.opensearch.ml.utils.RestActionUtils.getParameterId; @@ -20,6 +21,7 @@ import org.opensearch.ml.common.controller.MLController; import org.opensearch.ml.common.transport.controller.MLCreateControllerAction; import org.opensearch.ml.common.transport.controller.MLCreateControllerRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -29,11 +31,14 @@ public class RestMLCreateControllerAction extends BaseRestHandler { public final static String ML_CREATE_CONTROLLER_ACTION = "ml_create_controller_action"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; /** * Constructor */ - public RestMLCreateControllerAction() {} + public RestMLCreateControllerAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } @Override public String getName() { @@ -61,6 +66,10 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client * @return MLCreateControllerRequest */ private MLCreateControllerRequest getRequest(RestRequest request) throws IOException { + if (!mlFeatureEnabledSetting.isControllerEnabled()) { + throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG); + } + if (!request.hasContent()) { throw new OpenSearchParseException("Create model controller request has empty body"); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteControllerAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteControllerAction.java index 1524c20d6e..3cb8745ea9 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteControllerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteControllerAction.java @@ -6,6 +6,7 @@ package org.opensearch.ml.rest; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; import java.io.IOException; @@ -15,6 +16,7 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.ml.common.transport.controller.MLControllerDeleteAction; import org.opensearch.ml.common.transport.controller.MLControllerDeleteRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -25,9 +27,13 @@ * This class consists of the REST handler to delete ML Model. */ public class RestMLDeleteControllerAction extends BaseRestHandler { + private static final String ML_DELETE_CONTROLLER_ACTION = "ml_delete_controller_action"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; - public void RestMLDeleteControllerAction() {} + public RestMLDeleteControllerAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } @Override public String getName() { @@ -42,6 +48,9 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!mlFeatureEnabledSetting.isControllerEnabled()) { + throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG); + } String modelId = request.param(PARAMETER_MODEL_ID); MLControllerDeleteRequest mlControllerDeleteRequest = new MLControllerDeleteRequest(modelId); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetControllerAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetControllerAction.java index 0df1ce2893..881d75067f 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetControllerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetControllerAction.java @@ -6,6 +6,7 @@ package org.opensearch.ml.rest; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; import static org.opensearch.ml.utils.RestActionUtils.getParameterId; import static org.opensearch.ml.utils.RestActionUtils.returnContent; @@ -17,6 +18,7 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.ml.common.transport.controller.MLControllerGetAction; import org.opensearch.ml.common.transport.controller.MLControllerGetRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -25,12 +27,16 @@ import com.google.common.collect.ImmutableList; public class RestMLGetControllerAction extends BaseRestHandler { + private static final String ML_GET_CONTROLLER_ACTION = "ml_get_controller_action"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; /** * Constructor */ - public RestMLGetControllerAction() {} + public RestMLGetControllerAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } @Override public String getName() { @@ -57,6 +63,10 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client */ @VisibleForTesting MLControllerGetRequest getRequest(RestRequest request) throws IOException { + if (!mlFeatureEnabledSetting.isControllerEnabled()) { + throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG); + } + String modelId = getParameterId(request, PARAMETER_MODEL_ID); boolean returnContent = returnContent(request); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateControllerAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateControllerAction.java index 07fa1cc8a9..fd7966f31b 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateControllerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateControllerAction.java @@ -7,6 +7,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; import static org.opensearch.ml.utils.RestActionUtils.getParameterId; @@ -20,6 +21,7 @@ import org.opensearch.ml.common.controller.MLController; import org.opensearch.ml.common.transport.controller.MLUpdateControllerAction; import org.opensearch.ml.common.transport.controller.MLUpdateControllerRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -29,11 +31,14 @@ public class RestMLUpdateControllerAction extends BaseRestHandler { public final static String ML_UPDATE_CONTROLLER_ACTION = "ml_update_controller_action"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; /** * Constructor */ - public RestMLUpdateControllerAction() {} + public RestMLUpdateControllerAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } @Override public String getName() { @@ -62,6 +67,10 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client * @throws IOException if an error occurs while parsing the request */ private MLUpdateControllerRequest getRequest(RestRequest request) throws IOException { + if (!mlFeatureEnabledSetting.isControllerEnabled()) { + throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG); + } + if (!request.hasContent()) { throw new OpenSearchParseException("Update model controller request has empty body"); } diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index b9d7f2a9fc..4b034c035a 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -243,4 +243,7 @@ private MLCommonsSettings() {} Setting.Property.NodeScope, Setting.Property.Dynamic ); + + public static final Setting ML_COMMONS_CONTROLLER_ENABLED = Setting + .boolSetting("plugins.ml_commons.controller_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic); } diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java b/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java index e393b97d24..0fd11e7c72 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED; @@ -25,11 +26,14 @@ public class MLFeatureEnabledSetting { private volatile Boolean isLocalModelEnabled; private volatile AtomicBoolean isConnectorPrivateIpEnabled; + private volatile Boolean isControllerEnabled; + public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) { isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings); isAgentFrameworkEnabled = ML_COMMONS_AGENT_FRAMEWORK_ENABLED.get(settings); isLocalModelEnabled = ML_COMMONS_LOCAL_MODEL_ENABLED.get(settings); isConnectorPrivateIpEnabled = new AtomicBoolean(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED.get(settings)); + isControllerEnabled = ML_COMMONS_CONTROLLER_ENABLED.get(settings); clusterService .getClusterSettings() @@ -41,6 +45,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) clusterService .getClusterSettings() .addSettingsUpdateConsumer(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED, it -> isConnectorPrivateIpEnabled.set(it)); + clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_CONTROLLER_ENABLED, it -> isControllerEnabled = it); } /** @@ -71,4 +76,12 @@ public AtomicBoolean isConnectorPrivateIpEnabled() { return isConnectorPrivateIpEnabled; } + /** + * Whether the controller feature is enabled. If disabled, APIs in ml-commons will block controller. + * @return whether the controller is enabled. + */ + public Boolean isControllerEnabled() { + return isControllerEnabled; + } + } diff --git a/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java index 68fee24fba..5340edba0f 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java @@ -24,6 +24,8 @@ public class MLExceptionUtils { "Local Model is currently disabled. To enable it, update the setting \"plugins.ml_commons.local_model.enabled\" to true."; public static final String AGENT_FRAMEWORK_DISABLED_ERR_MSG = "Agent Framework is currently disabled. To enable it, update the setting \"plugins.ml_commons.agent_framework_enabled\" to true."; + public static final String CONTROLLER_DISABLED_ERR_MSG = + "Controller is currently disabled. To enable it, update the setting \"plugins.ml_commons.controller_enabled\" to true."; public static String getRootCauseMessage(final Throwable throwable) { String message = ExceptionUtils.getRootCauseMessage(throwable); diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java index c9a4a1a6d5..da31166c6f 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java @@ -57,6 +57,7 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -104,6 +105,9 @@ public class CreateControllerTransportActionTests extends OpenSearchTestCase { @Mock MLDeployControllerNodesResponse mlDeployControllerNodesResponse; + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -138,7 +142,7 @@ public void setup() throws IOException { DiscoveryNodes nodes = DiscoveryNodes.builder().add(node1).add(node2).build(); String[] targetNodeIds = new String[] { node1.getId(), node2.getId() }; - + when(mlFeatureEnabledSetting.isControllerEnabled()).thenReturn(true); createControllerTransportAction = spy( new CreateControllerTransportAction( transportService, @@ -148,7 +152,8 @@ public void setup() throws IOException { clusterService, modelAccessControlHelper, mlModelCacheHelper, - mlModelManager + mlModelManager, + mlFeatureEnabledSetting ) ); @@ -206,6 +211,18 @@ public void testCreateControllerSuccess() { verify(actionListener).onResponse(any(MLCreateControllerResponse.class)); } + @Test + public void testCreateControllerFailedWithControllerFeatureFlagDisabled() { + when(mlFeatureEnabledSetting.isControllerEnabled()).thenReturn(false); + createControllerTransportAction.doExecute(null, createControllerRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Controller is currently disabled. To enable it, update the setting \"plugins.ml_commons.controller_enabled\" to true.", + argumentCaptor.getValue().getMessage() + ); + } + @Test public void testCreateControllerWithTextEmbeddingModelSuccess() { when(mlModel.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING); diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java index 1e49ab2fd7..d661ea72bd 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java @@ -54,6 +54,7 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -104,6 +105,9 @@ public class DeleteControllerTransportActionTests extends OpenSearchTestCase { @Mock MLUndeployControllerNodesResponse mlUndeployControllerNodesResponse; + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -137,7 +141,7 @@ public void setup() throws IOException { ); DiscoveryNodes nodes = DiscoveryNodes.builder().add(node1).add(node2).build(); - + when(mlFeatureEnabledSetting.isControllerEnabled()).thenReturn(true); deleteControllerTransportAction = spy( new DeleteControllerTransportAction( transportService, @@ -147,7 +151,8 @@ public void setup() throws IOException { clusterService, mlModelManager, mlModelCacheHelper, - modelAccessControlHelper + modelAccessControlHelper, + mlFeatureEnabledSetting ) ); @@ -193,6 +198,18 @@ public void testDeleteControllerSuccess() { verify(actionListener).onResponse(deleteResponse); } + @Test + public void testDeleteControllerFailedWithControllerFeatureFlagDisabled() { + when(mlFeatureEnabledSetting.isControllerEnabled()).thenReturn(false); + deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Controller is currently disabled. To enable it, update the setting \"plugins.ml_commons.controller_enabled\" to true.", + argumentCaptor.getValue().getMessage() + ); + } + @Test public void testDeleteControllerWithModelAccessControlNoPermission() { doAnswer(invocation -> { diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java index 489e71e080..1af7b81357 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java @@ -46,6 +46,7 @@ import org.opensearch.ml.common.transport.controller.MLControllerGetResponse; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -81,6 +82,9 @@ public class GetControllerTransportActionTests extends OpenSearchTestCase { @Mock MLModel mlModel; + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -93,6 +97,7 @@ public void setup() throws IOException { MockitoAnnotations.openMocks(this); Settings settings = Settings.builder().build(); + when(mlFeatureEnabledSetting.isControllerEnabled()).thenReturn(true); getControllerTransportAction = spy( new GetControllerTransportAction( transportService, @@ -101,7 +106,8 @@ public void setup() throws IOException { xContentRegistry, clusterService, mlModelManager, - modelAccessControlHelper + modelAccessControlHelper, + mlFeatureEnabledSetting ) ); mlControllerGetRequest = MLControllerGetRequest.builder().modelId("testModelId").build(); @@ -136,6 +142,18 @@ public void testGetControllerSuccess() { verify(actionListener).onResponse(any(MLControllerGetResponse.class)); } + @Test + public void testGetControllerFailedWithControllerFeatureFlagDisabled() { + when(mlFeatureEnabledSetting.isControllerEnabled()).thenReturn(false); + getControllerTransportAction.doExecute(null, mlControllerGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Controller is currently disabled. To enable it, update the setting \"plugins.ml_commons.controller_enabled\" to true.", + argumentCaptor.getValue().getMessage() + ); + } + @Test public void testGetControllerWithModelAccessControlNoPermission() { doAnswer(invocation -> { diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java index fd378647e9..bc54de344b 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java @@ -58,6 +58,7 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -105,6 +106,9 @@ public class UpdateControllerTransportActionTests extends OpenSearchTestCase { @Mock MLDeployControllerNodesResponse mlDeployControllerNodesResponse; + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -141,7 +145,7 @@ public void setup() throws IOException { DiscoveryNodes nodes = DiscoveryNodes.builder().add(node1).add(node2).build(); String[] targetNodeIds = new String[] { node1.getId(), node2.getId() }; - + when(mlFeatureEnabledSetting.isControllerEnabled()).thenReturn(true); updateControllerTransportAction = spy( new UpdateControllerTransportAction( transportService, @@ -150,7 +154,8 @@ public void setup() throws IOException { clusterService, modelAccessControlHelper, mlModelCacheHelper, - mlModelManager + mlModelManager, + mlFeatureEnabledSetting ) ); @@ -216,6 +221,18 @@ public void testUpdateControllerSuccess() { verify(actionListener).onResponse(updateResponse); } + @Test + public void testUpdateControllerFailedWithControllerFeatureFlagDisabled() { + when(mlFeatureEnabledSetting.isControllerEnabled()).thenReturn(false); + updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Controller is currently disabled. To enable it, update the setting \"plugins.ml_commons.controller_enabled\" to true.", + argumentCaptor.getValue().getMessage() + ); + } + @Test public void testUpdateControllerWithTextEmbeddingModelSuccess() { when(mlModel.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING); diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 8868aaa0b4..a379325cc0 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -256,7 +256,7 @@ public void setup() throws URISyntaxException { threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); - + when(mlFeatureEnabledSetting.isControllerEnabled()).thenReturn(true); when(mlFeatureEnabledSetting.isConnectorPrivateIpEnabled()).thenReturn(new AtomicBoolean(false)); modelManager = spy( diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateControllerActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateControllerActionTests.java index 42b4cbe92c..71965761de 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateControllerActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateControllerActionTests.java @@ -11,6 +11,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.util.HashMap; import java.util.List; @@ -34,6 +35,7 @@ import org.opensearch.ml.common.controller.MLController; import org.opensearch.ml.common.transport.controller.MLCreateControllerAction; import org.opensearch.ml.common.transport.controller.MLCreateControllerRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -50,6 +52,9 @@ public class RestMLCreateControllerActionTests extends OpenSearchTestCase { private NodeClient client; private ThreadPool threadPool; + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Mock RestChannel channel; @@ -58,7 +63,8 @@ public void setup() { MockitoAnnotations.openMocks(this); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); - restMLCreateControllerAction = new RestMLCreateControllerAction(); + when(mlFeatureEnabledSetting.isControllerEnabled()).thenReturn(true); + restMLCreateControllerAction = new RestMLCreateControllerAction(mlFeatureEnabledSetting); doAnswer(invocation -> { invocation.getArgument(2); return null; @@ -74,7 +80,7 @@ public void tearDown() throws Exception { @Test public void testConstructor() { - RestMLCreateControllerAction CreateModelAction = new RestMLCreateControllerAction(); + RestMLCreateControllerAction CreateModelAction = new RestMLCreateControllerAction(mlFeatureEnabledSetting); assertNotNull(CreateModelAction); } @@ -129,6 +135,18 @@ public void testCreateControllerRequestWithNullField() throws Exception { restMLCreateControllerAction.handleRequest(request, channel, client); } + @Test + public void testCreateControllerRequestWithControllerDisabled() throws Exception { + exceptionRule.expect(IllegalStateException.class); + exceptionRule + .expectMessage( + "Controller is currently disabled. To enable it, update the setting \"plugins.ml_commons.controller_enabled\" to true." + ); + when(mlFeatureEnabledSetting.isControllerEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + restMLCreateControllerAction.handleRequest(request, channel, client); + } + private RestRequest getRestRequest() { RestRequest.Method method = RestRequest.Method.POST; String requestContent = "{\"user_rate_limiter\":{\"testUser\":{}}}"; diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteControllerActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteControllerActionTests.java index d3ba17ae5a..7d8b6b30d0 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteControllerActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteControllerActionTests.java @@ -11,6 +11,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; import java.util.HashMap; @@ -22,6 +23,7 @@ import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.client.node.NodeClient; import org.opensearch.common.settings.Settings; @@ -30,6 +32,7 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.transport.controller.MLControllerDeleteAction; import org.opensearch.ml.common.transport.controller.MLControllerDeleteRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -48,12 +51,17 @@ public class RestMLDeleteControllerActionTests extends OpenSearchTestCase { NodeClient client; private ThreadPool threadPool; + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Mock RestChannel channel; @Before public void setup() { - restMLDeleteControllerAction = new RestMLDeleteControllerAction(); + MockitoAnnotations.openMocks(this); + when(mlFeatureEnabledSetting.isControllerEnabled()).thenReturn(true); + restMLDeleteControllerAction = new RestMLDeleteControllerAction(mlFeatureEnabledSetting); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); @@ -73,7 +81,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLDeleteControllerAction mlDeleteControllerAction = new RestMLDeleteControllerAction(); + RestMLDeleteControllerAction mlDeleteControllerAction = new RestMLDeleteControllerAction(mlFeatureEnabledSetting); assertNotNull(mlDeleteControllerAction); } @@ -92,6 +100,17 @@ public void testRoutes() { assertEquals("/_plugins/_ml/controllers/{model_id}", route.getPath()); } + public void testDeleteControllerRequestWithControllerDisabled() throws Exception { + thrown.expect(IllegalStateException.class); + thrown + .expectMessage( + "Controller is currently disabled. To enable it, update the setting \"plugins.ml_commons.controller_enabled\" to true." + ); + when(mlFeatureEnabledSetting.isControllerEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + restMLDeleteControllerAction.handleRequest(request, channel, client); + } + public void test_PrepareRequest() throws Exception { RestRequest request = getRestRequest(); restMLDeleteControllerAction.handleRequest(request, channel, client); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetControllerActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetControllerActionTests.java index 2c16533d65..1351ca059c 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetControllerActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetControllerActionTests.java @@ -11,6 +11,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; import java.util.HashMap; @@ -19,9 +20,11 @@ import org.junit.Before; import org.junit.Rule; +import org.junit.Test; import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.client.node.NodeClient; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; @@ -30,6 +33,7 @@ import org.opensearch.ml.common.transport.controller.MLControllerGetAction; import org.opensearch.ml.common.transport.controller.MLControllerGetRequest; import org.opensearch.ml.common.transport.controller.MLControllerGetResponse; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -48,12 +52,17 @@ public class RestMLGetControllerActionTests extends OpenSearchTestCase { NodeClient client; private ThreadPool threadPool; + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Mock RestChannel channel; @Before public void setup() { - restMLGetControllerAction = new RestMLGetControllerAction(); + MockitoAnnotations.openMocks(this); + when(mlFeatureEnabledSetting.isControllerEnabled()).thenReturn(true); + restMLGetControllerAction = new RestMLGetControllerAction(mlFeatureEnabledSetting); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); @@ -73,7 +82,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLGetControllerAction mlGetControllerAction = new RestMLGetControllerAction(); + RestMLGetControllerAction mlGetControllerAction = new RestMLGetControllerAction(mlFeatureEnabledSetting); assertNotNull(mlGetControllerAction); } @@ -92,6 +101,18 @@ public void testRoutes() { assertEquals("/_plugins/_ml/controllers/{model_id}", route.getPath()); } + @Test + public void testGetControllerRequestWithControllerDisabled() throws Exception { + thrown.expect(IllegalStateException.class); + thrown + .expectMessage( + "Controller is currently disabled. To enable it, update the setting \"plugins.ml_commons.controller_enabled\" to true." + ); + when(mlFeatureEnabledSetting.isControllerEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + restMLGetControllerAction.handleRequest(request, channel, client); + } + public void test_PrepareRequest() throws Exception { RestRequest request = getRestRequest(); restMLGetControllerAction.handleRequest(request, channel, client); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateControllerActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateControllerActionTests.java index 98ab0f1e73..f2b8a9028a 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateControllerActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateControllerActionTests.java @@ -11,6 +11,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.util.HashMap; import java.util.List; @@ -34,6 +35,7 @@ import org.opensearch.ml.common.controller.MLController; import org.opensearch.ml.common.transport.controller.MLUpdateControllerAction; import org.opensearch.ml.common.transport.controller.MLUpdateControllerRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -53,12 +55,16 @@ public class RestMLUpdateControllerActionTests extends OpenSearchTestCase { @Mock RestChannel channel; + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Before public void setup() { MockitoAnnotations.openMocks(this); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); - restMLUpdateControllerAction = new RestMLUpdateControllerAction(); + when(mlFeatureEnabledSetting.isControllerEnabled()).thenReturn(true); + restMLUpdateControllerAction = new RestMLUpdateControllerAction(mlFeatureEnabledSetting); doAnswer(invocation -> { invocation.getArgument(2); return null; @@ -74,7 +80,7 @@ public void tearDown() throws Exception { @Test public void testConstructor() { - RestMLUpdateControllerAction UpdateModelAction = new RestMLUpdateControllerAction(); + RestMLUpdateControllerAction UpdateModelAction = new RestMLUpdateControllerAction(mlFeatureEnabledSetting); assertNotNull(UpdateModelAction); } @@ -129,6 +135,18 @@ public void testUpdateControllerRequestWithNullField() throws Exception { restMLUpdateControllerAction.handleRequest(request, channel, client); } + @Test + public void testUpdateControllerRequestWithControllerDisabled() throws Exception { + exceptionRule.expect(IllegalStateException.class); + exceptionRule + .expectMessage( + "Controller is currently disabled. To enable it, update the setting \"plugins.ml_commons.controller_enabled\" to true." + ); + when(mlFeatureEnabledSetting.isControllerEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + restMLUpdateControllerAction.handleRequest(request, channel, client); + } + private RestRequest getRestRequest() { RestRequest.Method method = RestRequest.Method.PUT; String requestContent = "{\"user_rate_limiter\":{\"testUser\":{}}}";