Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEATURE][ML] Strengthen source dest validations for DF analytics #43399

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1216,6 +1216,8 @@ public void testPutDataFrameAnalyticsConfig() throws Exception {
.setAnalysis(OutlierDetection.createDefault())
.build();

createIndex("put-test-source-index", defaultMappingForTest());

PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute(
new PutDataFrameAnalyticsRequest(config),
machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync);
Expand Down Expand Up @@ -1243,6 +1245,8 @@ public void testGetDataFrameAnalyticsConfig_SingleConfig() throws Exception {
.setAnalysis(OutlierDetection.createDefault())
.build();

createIndex("get-test-source-index", defaultMappingForTest());

PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute(
new PutDataFrameAnalyticsRequest(config),
machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync);
Expand All @@ -1256,6 +1260,8 @@ public void testGetDataFrameAnalyticsConfig_SingleConfig() throws Exception {
}

public void testGetDataFrameAnalyticsConfig_MultipleConfigs() throws Exception {
createIndex("get-test-source-index", defaultMappingForTest());

MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
String configIdPrefix = "get-test-config-";
int numberOfConfigs = 10;
Expand Down Expand Up @@ -1461,6 +1467,8 @@ public void testDeleteDataFrameAnalyticsConfig() throws Exception {
.setAnalysis(OutlierDetection.createDefault())
.build();

createIndex("delete-test-source-index", defaultMappingForTest());

GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse = execute(
new GetDataFrameAnalyticsRequest(configId + "*"),
machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2802,6 +2802,8 @@ public void onFailure(Exception e) {
}

public void testGetDataFrameAnalytics() throws Exception {
createIndex(DF_ANALYTICS_CONFIG.getSource().getIndex());

RestHighLevelClient client = highLevelClient();
client.machineLearning().putDataFrameAnalytics(new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG), RequestOptions.DEFAULT);
{
Expand Down Expand Up @@ -2849,6 +2851,8 @@ public void onFailure(Exception e) {
}

public void testGetDataFrameAnalyticsStats() throws Exception {
createIndex(DF_ANALYTICS_CONFIG.getSource().getIndex());

RestHighLevelClient client = highLevelClient();
client.machineLearning().putDataFrameAnalytics(new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG), RequestOptions.DEFAULT);
{
Expand Down Expand Up @@ -2897,6 +2901,8 @@ public void onFailure(Exception e) {
}

public void testPutDataFrameAnalytics() throws Exception {
createIndex(DF_ANALYTICS_CONFIG.getSource().getIndex());

RestHighLevelClient client = highLevelClient();
{
// tag::put-data-frame-analytics-query-config
Expand Down Expand Up @@ -2988,6 +2994,8 @@ public void onFailure(Exception e) {
}

public void testDeleteDataFrameAnalytics() throws Exception {
createIndex(DF_ANALYTICS_CONFIG.getSource().getIndex());

RestHighLevelClient client = highLevelClient();
client.machineLearning().putDataFrameAnalytics(new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG), RequestOptions.DEFAULT);
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,15 @@
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.indices.InvalidIndexNameException;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Locale;
import java.util.Objects;

import static org.elasticsearch.cluster.metadata.MetaDataCreateIndexService.validateIndexOrAliasName;

public class DataFrameAnalyticsDest implements Writeable, ToXContentObject {

public static final ParseField INDEX = new ParseField("index");
Expand Down Expand Up @@ -90,4 +94,13 @@ public String getIndex() {
/**
 * @return the value held in the {@code resultsField} field of this destination
 */
public String getResultsField() {
return resultsField;
}

/**
 * Validates the destination index name held in {@code index}.
 *
 * <p>A {@code null} index is left unchecked. Otherwise the name is first passed to
 * {@code validateIndexOrAliasName} (with {@link InvalidIndexNameException} as the
 * exception factory) and is then required to be equal to its own
 * {@link Locale#ROOT} lowercase form.
 *
 * @throws InvalidIndexNameException if the index name contains uppercase characters
 *         (and presumably also when {@code validateIndexOrAliasName} rejects the name)
 */
public void validate() {
    if (index == null) {
        // Nothing to validate when no destination index is set
        return;
    }

    validateIndexOrAliasName(index, InvalidIndexNameException::new);

    // Locale.ROOT keeps the lowercase comparison independent of the JVM default locale
    final String lowercaseIndex = index.toLowerCase(Locale.ROOT);
    if (lowercaseIndex.equals(index) == false) {
        throw new InvalidIndexNameException(index, "dest.index must be lowercase");
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.indices.InvalidIndexNameException;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.AbstractSerializingTestCase;

import java.io.IOException;

import static org.hamcrest.Matchers.equalTo;

public class DataFrameAnalyticsDestTests extends AbstractSerializingTestCase<DataFrameAnalyticsDest> {

@Override
Expand All @@ -33,4 +37,19 @@ public static DataFrameAnalyticsDest createRandom() {
protected Writeable.Reader<DataFrameAnalyticsDest> instanceReader() {
// The reader is a constructor reference on DataFrameAnalyticsDest
return DataFrameAnalyticsDest::new;
}

/**
 * An index name containing characters such as {@code <} and {@code >} must be rejected
 * with an {@link InvalidIndexNameException}.
 */
public void testValidate_GivenIndexWithFunkyChars() {
    expectThrows(InvalidIndexNameException.class, () -> {
        // Construct inside the lambda so that any exception is attributed to this call
        DataFrameAnalyticsDest dest = new DataFrameAnalyticsDest("<script>foo", null);
        dest.validate();
    });
}

/**
 * An index name with uppercase characters must be rejected with an
 * {@link InvalidIndexNameException} carrying a 400 status and the lowercase message.
 */
public void testValidate_GivenIndexWithUppercaseChars() {
    InvalidIndexNameException exception = expectThrows(InvalidIndexNameException.class, () -> {
        DataFrameAnalyticsDest dest = new DataFrameAnalyticsDest("Foo", null);
        dest.validate();
    });

    assertThat(exception.status(), equalTo(RestStatus.BAD_REQUEST));
    assertThat(exception.getMessage(), equalTo("Invalid index name [Foo], dest.index must be lowercase"));
}

/**
 * A lowercase index name made of letters, digits and underscores must pass validation;
 * the test succeeds when {@code validate()} returns without throwing.
 */
public void testValidate_GivenValidIndexName() {
    DataFrameAnalyticsDest dest = new DataFrameAnalyticsDest("foo_bar_42", null);
    dest.validate();
}
}
7 changes: 7 additions & 0 deletions x-pack/plugin/ml/qa/ml-with-security/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ integTestRunner {
'ml/data_frame_analytics_crud/Test put config with security headers in the body',
'ml/data_frame_analytics_crud/Test put config with inconsistent body/param ids',
'ml/data_frame_analytics_crud/Test put config with invalid id',
'ml/data_frame_analytics_crud/Test put config with invalid dest index name',
'ml/data_frame_analytics_crud/Test put config with pattern dest index name',
'ml/data_frame_analytics_crud/Test put config with missing concrete source index',
'ml/data_frame_analytics_crud/Test put config with missing wildcard source index',
'ml/data_frame_analytics_crud/Test put config with dest index same as source index',
'ml/data_frame_analytics_crud/Test put config with dest index matching multiple indices',
'ml/data_frame_analytics_crud/Test put config with dest index included in source via alias',
'ml/data_frame_analytics_crud/Test put config with unknown top level field',
'ml/data_frame_analytics_crud/Test put config with unknown field in outlier detection analysis',
'ml/data_frame_analytics_crud/Test put config given missing source',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.inject.Inject;
Expand Down Expand Up @@ -36,9 +37,11 @@
import org.elasticsearch.xpack.core.security.authz.RoleDescriptor;
import org.elasticsearch.xpack.core.security.authz.permission.ResourcePrivileges;
import org.elasticsearch.xpack.core.security.support.Exceptions;
import org.elasticsearch.xpack.ml.dataframe.SourceDestValidator;
import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;

import java.io.IOException;
import java.util.Objects;
import java.util.function.Supplier;

public class TransportPutDataFrameAnalyticsAction
Expand All @@ -49,13 +52,16 @@ public class TransportPutDataFrameAnalyticsAction
private final ThreadPool threadPool;
private final SecurityContext securityContext;
private final Client client;
private final ClusterService clusterService;
private final IndexNameExpressionResolver indexNameExpressionResolver;

private volatile ByteSizeValue maxModelMemoryLimit;

@Inject
public TransportPutDataFrameAnalyticsAction(Settings settings, TransportService transportService, ActionFilters actionFilters,
XPackLicenseState licenseState, Client client, ThreadPool threadPool,
ClusterService clusterService, DataFrameAnalyticsConfigProvider configProvider) {
ClusterService clusterService, IndexNameExpressionResolver indexNameExpressionResolver,
DataFrameAnalyticsConfigProvider configProvider) {
super(PutDataFrameAnalyticsAction.NAME, transportService, actionFilters,
(Supplier<PutDataFrameAnalyticsAction.Request>) PutDataFrameAnalyticsAction.Request::new);
this.licenseState = licenseState;
Expand All @@ -64,6 +70,8 @@ public TransportPutDataFrameAnalyticsAction(Settings settings, TransportService
this.securityContext = XPackSettings.SECURITY_ENABLED.get(settings) ?
new SecurityContext(settings, threadPool.getThreadContext()) : null;
this.client = client;
this.clusterService = clusterService;
this.indexNameExpressionResolver = Objects.requireNonNull(indexNameExpressionResolver);

maxModelMemoryLimit = MachineLearningField.MAX_MODEL_MEMORY_LIMIT.get(settings);
clusterService.getClusterSettings()
Expand Down Expand Up @@ -146,5 +154,7 @@ private void validateConfig(DataFrameAnalyticsConfig config) {
throw ExceptionsHelper.badRequestException("id [{}] is too long; must not contain more than {} characters", config.getId(),
MlStrings.ID_LENGTH_LIMIT);
}
config.getDest().validate();
new SourceDestValidator(clusterService.state(), indexNameExpressionResolver).check(config);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsManager;
import org.elasticsearch.xpack.ml.dataframe.SourceDestValidator;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
import org.elasticsearch.xpack.ml.job.JobNodeSelector;
Expand Down Expand Up @@ -157,7 +158,10 @@ public void onFailure(Exception e) {

// Validate config
ActionListener<DataFrameAnalyticsConfig> configListener = ActionListener.wrap(
config -> DataFrameDataExtractorFactory.validateConfigAndSourceIndex(client, config, validateListener),
config -> {
new SourceDestValidator(clusterService.state(), indexNameExpressionResolver).check(config);
DataFrameDataExtractorFactory.validateConfigAndSourceIndex(client, config, validateListener);
},
listener::onFailure
);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe;

import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.regex.Regex;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;

/**
 * Checks the source and destination index settings of a {@link DataFrameAnalyticsConfig}
 * against a {@link ClusterState}, using an {@link IndexNameExpressionResolver} to turn
 * index expressions into concrete index names.
 */
public class SourceDestValidator {

// Cluster state that index expressions are resolved against
private final ClusterState clusterState;
// Resolver used to obtain the concrete index names for source and dest
private final IndexNameExpressionResolver indexNameExpressionResolver;

/**
 * @param clusterState the cluster state to resolve index expressions against; must not be {@code null}
 * @param indexNameExpressionResolver the resolver for index expressions; must not be {@code null}
 * @throws NullPointerException if either argument is {@code null}
 */
public SourceDestValidator(ClusterState clusterState, IndexNameExpressionResolver indexNameExpressionResolver) {
this.clusterState = Objects.requireNonNull(clusterState);
this.indexNameExpressionResolver = Objects.requireNonNull(indexNameExpressionResolver);
}

/**
 * Verifies that the source and destination indices of the given config can be used together.
 *
 * <p>A bad request exception (built by {@code ExceptionsHelper.badRequestException}) is thrown when:
 * <ul>
 *   <li>any of the comma-separated source expressions matches the destination index name,</li>
 *   <li>the source expressions resolve to no concrete index,</li>
 *   <li>the destination resolves to more than one concrete index, or</li>
 *   <li>the single concrete destination index is also one of the concrete source indices.</li>
 * </ul>
 *
 * @param config the analytics config whose source and dest index settings are checked
 */
public void check(DataFrameAnalyticsConfig config) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose it is just a Java thing, but I have trouble with creating an object with state, but an internal method is only called once. Seems like a static check(ClusterState clusterState, IndexNameExpressionResolver indexNameExpressionResolver, DataFrameAnalyticsConfig config) is fewer lines of code and less state to break.

Of course, I am hypocritical with this :). I often make these classes myself.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know what you mean. I insist going for objects though because I think they lend themselves better for expansion. On the other hand, when one sees a static method it discourages refactoring to make an object. Of course, only time will tell. My view is that when paradigms are not clearly better or worse than alternatives, we have to try them out and wait for empirical evidence to reward us or slap us in the face :-)

String sourceIndex = config.getSource().getIndex();
String destIndex = config.getDest().getIndex();

// The source index setting may hold several comma-separated expressions
String[] sourceExpressions = Strings.tokenizeToStringArray(sourceIndex, ",");

// First pass works on the raw names only: each source expression is used as the pattern
// and the dest name as the candidate (presumably a wildcard-style match; confirm against Regex.simpleMatch)
for (String sourceExpression : sourceExpressions) {
if (Regex.simpleMatch(sourceExpression, destIndex)) {
throw ExceptionsHelper.badRequestException("Destination index [{}] must not be included in source index [{}]",
destIndex, sourceExpression);
}
}

// Resolve the source expressions to concrete index names; kept in a set for the membership test below
Set<String> concreteSourceIndexNames = new HashSet<>(Arrays.asList(indexNameExpressionResolver.concreteIndexNames(clusterState,
IndicesOptions.lenientExpandOpen(), sourceExpressions)));

if (concreteSourceIndexNames.isEmpty()) {
throw ExceptionsHelper.badRequestException("No index matches source index [{}]", sourceIndex);
}

// Resolve the dest name the same way; zero results are accepted here, only one-or-more are inspected further
final String[] concreteDestIndexNames = indexNameExpressionResolver.concreteIndexNames(clusterState,
IndicesOptions.lenientExpandOpen(), destIndex);

if (concreteDestIndexNames.length > 1) {
// In case it is an alias, it may match multiple indices
throw ExceptionsHelper.badRequestException("Destination index [{}] should match a single index; matches {}", destIndex,
Arrays.toString(concreteDestIndexNames));
}
if (concreteDestIndexNames.length == 1 && concreteSourceIndexNames.contains(concreteDestIndexNames[0])) {
// In case the dest index is an alias, we need to check the concrete index is not matched by source
throw ExceptionsHelper.badRequestException("Destination index [{}], which is an alias for [{}], " +
"must not be included in source index [{}]", destIndex, concreteDestIndexNames[0], sourceIndex);
}
}
}
Loading