Skip to content

Commit

Permalink
[FEATURE][ML] Only write numeric fields to data frame (#35961)
Browse files Browse the repository at this point in the history
  • Loading branch information
dimitris-athanasiou authored Nov 28, 2018
1 parent 3f49eef commit 801665a
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ private void runPipelineAnalytics(String index, ActionListener<AcknowledgedRespo
listener::onFailure
);

// TODO This could fail with errors. In that case we get stuck with the copied index.
// We could delete the index in case of failure or we could try building the factory before reindexing
// to catch the error early on.
DataFrameDataExtractorFactory.create(client, Collections.emptyMap(), index, dataExtractorFactoryListener);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,29 @@

import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.fieldcaps.FieldCapabilities;
import org.elasticsearch.action.fieldcaps.FieldCapabilitiesAction;
import org.elasticsearch.action.fieldcaps.FieldCapabilitiesRequest;
import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.index.IndexNotFoundException;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField;
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class DataFrameDataExtractorFactory {

Expand All @@ -33,6 +39,20 @@ public class DataFrameDataExtractorFactory {
private static final List<String> IGNORE_FIELDS = Arrays.asList("_id", "_field_names", "_index", "_parent", "_routing", "_seq_no",
"_source", "_type", "_uid", "_version", "_feature", "_ignored");

/**
 * The field types supported by data frames: the type name of every
 * {@link NumberFieldMapper.NumberType} value, plus {@code scaled_float}.
 */
private static final Set<String> COMPATIBLE_FIELD_TYPES;

static {
    // scaled_float has to be listed manually since it is defined in a module.
    // It is concatenated into the stream instead of being added to the collected set,
    // because Collectors.toSet() gives no guarantee that the returned set is mutable.
    Set<String> compatibleTypes = Stream.concat(
            Stream.of(NumberFieldMapper.NumberType.values()).map(NumberFieldMapper.NumberType::typeName),
            Stream.of("scaled_float"))
        .collect(Collectors.toSet());

    COMPATIBLE_FIELD_TYPES = Collections.unmodifiableSet(compatibleTypes);
}

private final Client client;
private final String index;
private final ExtractedFields extractedFields;
Expand Down Expand Up @@ -82,10 +102,27 @@ public static void create(Client client, Map<String, String> headers, String ind
});
}

private static ExtractedFields detectExtractedFields(FieldCapabilitiesResponse fieldCapabilitiesResponse) {
// Visible for testing
static ExtractedFields detectExtractedFields(FieldCapabilitiesResponse fieldCapabilitiesResponse) {
Set<String> fields = fieldCapabilitiesResponse.get().keySet();
fields.removeAll(IGNORE_FIELDS);
return ExtractedFields.build(new ArrayList<>(fields), Collections.emptySet(), fieldCapabilitiesResponse)
removeFieldsWithIncompatibleTypes(fields, fieldCapabilitiesResponse);
ExtractedFields extractedFields = ExtractedFields.build(new ArrayList<>(fields), Collections.emptySet(), fieldCapabilitiesResponse)
.filterFields(ExtractedField.ExtractionMethod.DOC_VALUE);
if (extractedFields.getAllFields().isEmpty()) {
throw ExceptionsHelper.badRequestException("No compatible fields could be detected");
}
return extractedFields;
}

/**
 * Removes from {@code fields}, in place, every field that has no entry in the field capabilities
 * response or whose capabilities map has a key that is not in {@code COMPATIBLE_FIELD_TYPES}.
 *
 * @param fields the candidate field names; modified by this method
 * @param fieldCapabilitiesResponse the response the per-field capabilities are looked up in
 */
private static void removeFieldsWithIncompatibleTypes(Set<String> fields, FieldCapabilitiesResponse fieldCapabilitiesResponse) {
    for (Iterator<String> remaining = fields.iterator(); remaining.hasNext(); ) {
        Map<String, FieldCapabilities> capsByType = fieldCapabilitiesResponse.getField(remaining.next());
        // A field is kept only when it is known to the response and all of its types are supported
        boolean compatible = capsByType != null && COMPATIBLE_FIELD_TYPES.containsAll(capsByType.keySet());
        if (compatible == false) {
            remaining.remove();
        }
    }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.analytics;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.fieldcaps.FieldCapabilities;
import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField;
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
 * Tests for {@code DataFrameDataExtractorFactory.detectExtractedFields}, driven by mocked
 * field capabilities responses.
 */
public class DataFrameDataExtractorFactoryTests extends ESTestCase {

    public void testDetectExtractedFields_GivenFloatField() {
        FieldCapabilitiesResponse response = new MockFieldCapsResponseBuilder()
            .addAggregatableField("some_float", "float")
            .build();

        List<ExtractedField> detected = DataFrameDataExtractorFactory.detectExtractedFields(response).getAllFields();

        assertThat(detected.size(), equalTo(1));
        assertThat(detected.get(0).getName(), equalTo("some_float"));
    }

    public void testDetectExtractedFields_GivenNumericFieldWithMultipleTypes() {
        FieldCapabilitiesResponse response = new MockFieldCapsResponseBuilder()
            .addAggregatableField("some_number", "long", "integer", "short", "byte", "double", "float", "half_float", "scaled_float")
            .build();

        List<ExtractedField> detected = DataFrameDataExtractorFactory.detectExtractedFields(response).getAllFields();

        assertThat(detected.size(), equalTo(1));
        assertThat(detected.get(0).getName(), equalTo("some_number"));
    }

    public void testDetectExtractedFields_GivenNonNumericField() {
        FieldCapabilitiesResponse response = new MockFieldCapsResponseBuilder()
            .addAggregatableField("some_keyword", "keyword")
            .build();

        assertNoCompatibleFieldsDetected(response);
    }

    public void testDetectExtractedFields_GivenFieldWithNumericAndNonNumericTypes() {
        FieldCapabilitiesResponse response = new MockFieldCapsResponseBuilder()
            .addAggregatableField("indecisive_field", "float", "keyword")
            .build();

        assertNoCompatibleFieldsDetected(response);
    }

    public void testDetectExtractedFields_GivenMultipleFields() {
        FieldCapabilitiesResponse response = new MockFieldCapsResponseBuilder()
            .addAggregatableField("some_float", "float")
            .addAggregatableField("some_long", "long")
            .addAggregatableField("some_keyword", "keyword")
            .build();

        List<ExtractedField> detected = DataFrameDataExtractorFactory.detectExtractedFields(response).getAllFields();

        assertThat(detected.size(), equalTo(2));
        assertThat(detected.stream().map(ExtractedField::getName).collect(Collectors.toSet()),
            containsInAnyOrder("some_float", "some_long"));
    }

    public void testDetectExtractedFields_GivenIgnoredField() {
        FieldCapabilitiesResponse response = new MockFieldCapsResponseBuilder()
            .addAggregatableField("_id", "float")
            .build();

        assertNoCompatibleFieldsDetected(response);
    }

    /**
     * Asserts that field detection on the given response fails with the "no compatible fields" error.
     */
    private void assertNoCompatibleFieldsDetected(FieldCapabilitiesResponse response) {
        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
            () -> DataFrameDataExtractorFactory.detectExtractedFields(response));
        assertThat(e.getMessage(), equalTo("No compatible fields could be detected"));
    }

    /**
     * Builds a mocked {@link FieldCapabilitiesResponse} whose {@code get()} and {@code getField(...)}
     * answers come from the fields registered on the builder.
     */
    private static class MockFieldCapsResponseBuilder {

        // field name -> (type name -> capabilities of the field for that type)
        private final Map<String, Map<String, FieldCapabilities>> fieldCaps = new HashMap<>();

        /**
         * Registers a field that is searchable and aggregatable under each of the given types.
         */
        private MockFieldCapsResponseBuilder addAggregatableField(String field, String... types) {
            Map<String, FieldCapabilities> capsByType = new HashMap<>();
            for (String typeName : types) {
                capsByType.put(typeName, new FieldCapabilities(field, typeName, true, true));
            }
            fieldCaps.put(field, capsByType);
            return this;
        }

        /**
         * Creates the mock and stubs {@code get()} plus {@code getField(...)} for every registered field.
         */
        private FieldCapabilitiesResponse build() {
            FieldCapabilitiesResponse mockedResponse = mock(FieldCapabilitiesResponse.class);
            when(mockedResponse.get()).thenReturn(fieldCaps);
            fieldCaps.forEach((field, capsByType) -> when(mockedResponse.getField(field)).thenReturn(capsByType));
            return mockedResponse;
        }
    }
}

0 comments on commit 801665a

Please sign in to comment.