diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index 98766633c0f3f..8dc62e2efecbb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -15,10 +15,14 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.mapper.FieldAliasMapper; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; import java.io.IOException; import java.util.Arrays; @@ -46,6 +50,7 @@ public class Classification implements DataFrameAnalysis { public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); public static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed"); + public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors"); private static final String STATE_DOC_ID_SUFFIX = "_classification_state#1"; @@ -59,6 +64,7 @@ public class Classification implements DataFrameAnalysis { */ public static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30; + @SuppressWarnings("unchecked") private static ConstructingObjectParser createParser(boolean lenient) { ConstructingObjectParser parser = new ConstructingObjectParser<>( NAME.getPreferredName(), @@ -70,7 +76,8 @@ private static ConstructingObjectParser createParser(boole (ClassAssignmentObjective) a[8], (Integer) a[9], (Double) a[10], - (Long) a[11])); + (Long) a[11], + (List) a[12])); parser.declareString(constructorArg(), DEPENDENT_VARIABLE); BoostedTreeParams.declareFields(parser); parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME); @@ -78,6 +85,12 @@ private static ConstructingObjectParser createParser(boole parser.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES); parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT); parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED); + parser.declareNamedObjects(optionalConstructorArg(), + (p, c, n) -> lenient ? 
+ p.namedObject(LenientlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)) : + p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)), + (classification) -> {/*TODO should we throw if this is not set?*/}, + FEATURE_PROCESSORS); return parser; } @@ -117,6 +130,7 @@ public static Classification fromXContent(XContentParser parser, boolean ignoreU private final int numTopClasses; private final double trainingPercent; private final long randomizeSeed; + private final List featureProcessors; public Classification(String dependentVariable, BoostedTreeParams boostedTreeParams, @@ -124,7 +138,8 @@ public Classification(String dependentVariable, @Nullable ClassAssignmentObjective classAssignmentObjective, @Nullable Integer numTopClasses, @Nullable Double trainingPercent, - @Nullable Long randomizeSeed) { + @Nullable Long randomizeSeed, + @Nullable List featureProcessors) { if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) { throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName()); } @@ -139,10 +154,11 @@ public Classification(String dependentVariable, this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses; this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent; this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed; + this.featureProcessors = featureProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(featureProcessors); } public Classification(String dependentVariable) { - this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null); + this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null); } public Classification(StreamInput in) throws IOException { @@ -161,6 +177,11 @@ public Classification(StreamInput in) throws IOException { } else { randomizeSeed = Randomness.get().nextLong(); } + if (in.getVersion().onOrAfter(Version.V_8_0_0)) { + featureProcessors = Collections.unmodifiableList(in.readNamedWriteableList(PreProcessor.class)); + } else { + featureProcessors = Collections.emptyList(); + } } public String getDependentVariable() { @@ -191,6 +212,10 @@ public long getRandomizeSeed() { return randomizeSeed; } + public List getFeatureProcessors() { + return featureProcessors; + } + @Override public String getWriteableName() { return NAME.getPreferredName(); @@ -209,6 +234,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getVersion().onOrAfter(Version.V_7_6_0)) { out.writeOptionalLong(randomizeSeed); } + if (out.getVersion().onOrAfter(Version.V_8_0_0)) { + out.writeNamedWriteableList(featureProcessors); + } } @Override @@ -227,6 +255,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (version.onOrAfter(Version.V_7_6_0)) { builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed); } + if (featureProcessors.isEmpty() == false) { + NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors); + } builder.endObject(); return builder; } @@ -247,6 +278,10 @@ public Map getParams(FieldInfo fieldInfo) { } params.put(NUM_CLASSES, fieldInfo.getCardinality(dependentVariable)); params.put(TRAINING_PERCENT.getPreferredName(), trainingPercent); + if (featureProcessors.isEmpty() == false) { + params.put(FEATURE_PROCESSORS.getPreferredName(), + 
featureProcessors.stream().map(p -> Collections.singletonMap(p.getName(), p)).collect(Collectors.toList())); + } return params; } @@ -388,6 +423,7 @@ public boolean equals(Object o) { && Objects.equals(predictionFieldName, that.predictionFieldName) && Objects.equals(classAssignmentObjective, that.classAssignmentObjective) && Objects.equals(numTopClasses, that.numTopClasses) + && Objects.equals(featureProcessors, that.featureProcessors) && trainingPercent == that.trainingPercent && randomizeSeed == that.randomizeSeed; } @@ -395,7 +431,7 @@ public boolean equals(Object o) { @Override public int hashCode() { return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, classAssignmentObjective, - numTopClasses, trainingPercent, randomizeSeed); + numTopClasses, trainingPercent, randomizeSeed, featureProcessors); } public enum ClassAssignmentObjective { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index 352807b78bbf0..6f7a47f291234 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -15,9 +15,13 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.mapper.NumberFieldMapper; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; import java.io.IOException; import java.util.Arrays; @@ -28,6 +32,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.stream.Collectors; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -42,12 +47,14 @@ public class Regression implements DataFrameAnalysis { public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed"); public static final ParseField LOSS_FUNCTION = new ParseField("loss_function"); public static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField("loss_function_parameter"); + public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors"); private static final String STATE_DOC_ID_SUFFIX = "_regression_state#1"; private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + @SuppressWarnings("unchecked") private static ConstructingObjectParser createParser(boolean lenient) { ConstructingObjectParser parser = new ConstructingObjectParser<>( NAME.getPreferredName(), @@ -59,7 +66,8 @@ private static ConstructingObjectParser createParser(boolean l (Double) a[8], (Long) a[9], (LossFunction) a[10], - (Double) a[11])); + (Double) a[11], + (List) a[12])); parser.declareString(constructorArg(), 
DEPENDENT_VARIABLE); BoostedTreeParams.declareFields(parser); parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME); @@ -67,6 +75,12 @@ private static ConstructingObjectParser createParser(boolean l parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED); parser.declareString(optionalConstructorArg(), LossFunction::fromString, LOSS_FUNCTION); parser.declareDouble(optionalConstructorArg(), LOSS_FUNCTION_PARAMETER); + parser.declareNamedObjects(optionalConstructorArg(), + (p, c, n) -> lenient ? + p.namedObject(LenientlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)) : + p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)), + (regression) -> {/*TODO should we throw if this is not set?*/}, + FEATURE_PROCESSORS); return parser; } @@ -90,6 +104,7 @@ public static Regression fromXContent(XContentParser parser, boolean ignoreUnkno private final long randomizeSeed; private final LossFunction lossFunction; private final Double lossFunctionParameter; + private final List featureProcessors; public Regression(String dependentVariable, BoostedTreeParams boostedTreeParams, @@ -97,7 +112,8 @@ public Regression(String dependentVariable, @Nullable Double trainingPercent, @Nullable Long randomizeSeed, @Nullable LossFunction lossFunction, - @Nullable Double lossFunctionParameter) { + @Nullable Double lossFunctionParameter, + @Nullable List featureProcessors) { if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) { throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName()); } @@ -112,10 +128,11 @@ public Regression(String dependentVariable, throw ExceptionsHelper.badRequestException("[{}] must be a positive double", LOSS_FUNCTION_PARAMETER.getPreferredName()); } this.lossFunctionParameter = lossFunctionParameter; + this.featureProcessors = featureProcessors == null ? 
Collections.emptyList() : Collections.unmodifiableList(featureProcessors); } public Regression(String dependentVariable) { - this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null); + this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null); } public Regression(StreamInput in) throws IOException { @@ -126,6 +143,11 @@ public Regression(StreamInput in) throws IOException { randomizeSeed = in.readOptionalLong(); lossFunction = in.readEnum(LossFunction.class); lossFunctionParameter = in.readOptionalDouble(); + if (in.getVersion().onOrAfter(Version.V_8_0_0)) { + featureProcessors = Collections.unmodifiableList(in.readNamedWriteableList(PreProcessor.class)); + } else { + featureProcessors = Collections.emptyList(); + } } public String getDependentVariable() { @@ -156,6 +178,10 @@ public Double getLossFunctionParameter() { return lossFunctionParameter; } + public List getFeatureProcessors() { + return featureProcessors; + } + @Override public String getWriteableName() { return NAME.getPreferredName(); @@ -170,6 +196,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalLong(randomizeSeed); out.writeEnum(lossFunction); out.writeOptionalDouble(lossFunctionParameter); + if (out.getVersion().onOrAfter(Version.V_8_0_0)) { + out.writeNamedWriteableList(featureProcessors); + } } @Override @@ -190,6 +219,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (lossFunctionParameter != null) { builder.field(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter); } + if (featureProcessors.isEmpty() == false) { + NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors); + } builder.endObject(); return builder; } @@ -207,6 +239,10 @@ public Map getParams(FieldInfo fieldInfo) { if (lossFunctionParameter != null) { params.put(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter); } + if (featureProcessors.isEmpty() == false) { + params.put(FEATURE_PROCESSORS.getPreferredName(), + featureProcessors.stream().map(p -> Collections.singletonMap(p.getName(), p)).collect(Collectors.toList())); + } return params; } @@ -290,13 +326,14 @@ public boolean equals(Object o) { && trainingPercent == that.trainingPercent && randomizeSeed == that.randomizeSeed && lossFunction == that.lossFunction + && Objects.equals(featureProcessors, that.featureProcessors) && Objects.equals(lossFunctionParameter, that.lossFunctionParameter); } @Override public int hashCode() { return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed, lossFunction, - lossFunctionParameter); + lossFunctionParameter, featureProcessors); } public enum LossFunction { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index 9f225a997dbc4..2ba7f114e8b41 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -57,23 +57,23 @@ public List getNamedXContentParsers() { // PreProcessing Lenient namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, OneHotEncoding.NAME, - 
OneHotEncoding::fromXContentLenient)); + (p, c) -> OneHotEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c))); namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, TargetMeanEncoding.NAME, - TargetMeanEncoding::fromXContentLenient)); + (p, c) -> TargetMeanEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c))); namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, FrequencyEncoding.NAME, - FrequencyEncoding::fromXContentLenient)); + (p, c) -> FrequencyEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c))); namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, CustomWordEmbedding.NAME, - CustomWordEmbedding::fromXContentLenient)); + (p, c) -> CustomWordEmbedding.fromXContentLenient(p))); // PreProcessing Strict namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, OneHotEncoding.NAME, - OneHotEncoding::fromXContentStrict)); + (p, c) -> OneHotEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c))); namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, TargetMeanEncoding.NAME, - TargetMeanEncoding::fromXContentStrict)); + (p, c) -> TargetMeanEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c))); namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, FrequencyEncoding.NAME, - FrequencyEncoding::fromXContentStrict)); + (p, c) -> FrequencyEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c))); namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, CustomWordEmbedding.NAME, - CustomWordEmbedding::fromXContentStrict)); + (p, c) -> CustomWordEmbedding.fromXContentStrict(p))); // Model Lenient namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentLenient)); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java index 9493572b8386a..24ec52b3650fe 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java @@ -56,8 +56,8 @@ private static ObjectParser createParser(b TRAINED_MODEL); parser.declareNamedObjects(TrainedModelDefinition.Builder::setPreProcessors, (p, c, n) -> ignoreUnknownFields ? 
- p.namedObject(LenientlyParsedPreProcessor.class, n, null) : - p.namedObject(StrictlyParsedPreProcessor.class, n, null), + p.namedObject(LenientlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT) : + p.namedObject(StrictlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT), (trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true), PREPROCESSORS); return parser; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java index 8518caf66eb70..bac61d9b8ef94 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java @@ -50,15 +50,15 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl public static final ParseField EMBEDDING_WEIGHTS = new ParseField("embedding_weights"); public static final ParseField EMBEDDING_QUANT_SCALES = new ParseField("embedding_quant_scales"); - public static final ConstructingObjectParser STRICT_PARSER = createParser(false); - public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); @SuppressWarnings("unchecked") - private static ConstructingObjectParser createParser(boolean lenient) { - ConstructingObjectParser parser = new ConstructingObjectParser<>( + private static ConstructingObjectParser createParser(boolean lenient) { + ConstructingObjectParser parser = new ConstructingObjectParser<>( NAME.getPreferredName(), lenient, - a -> new CustomWordEmbedding((short[][])a[0], (byte[][])a[1], (String)a[2], (String)a[3])); + (a, c) -> new CustomWordEmbedding((short[][])a[0], (byte[][])a[1], (String)a[2], (String)a[3])); parser.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> { @@ -123,11 +123,11 @@ private static List> parseArrays(String fieldName, } public static CustomWordEmbedding fromXContentStrict(XContentParser parser) { - return STRICT_PARSER.apply(parser, null); + return STRICT_PARSER.apply(parser, PreProcessorParseContext.DEFAULT); } public static CustomWordEmbedding fromXContentLenient(XContentParser parser) { - return LENIENT_PARSER.apply(parser, null); + return LENIENT_PARSER.apply(parser, PreProcessorParseContext.DEFAULT); } private static final int CONCAT_LAYER_SIZE = 80; @@ -256,6 +256,11 @@ public boolean isCustom() { return false; } + @Override + public String getOutputFieldType(String outputField) { + return "dense_vector"; + } + @Override public long ramBytesUsed() { long size = SHALLOW_SIZE; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java index 4179e09be0ed7..f04c2f291b69b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; 
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.mapper.NumberFieldMapper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -36,15 +37,18 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP public static final ParseField FREQUENCY_MAP = new ParseField("frequency_map"); public static final ParseField CUSTOM = new ParseField("custom"); - public static final ConstructingObjectParser STRICT_PARSER = createParser(false); - public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); @SuppressWarnings("unchecked") - private static ConstructingObjectParser createParser(boolean lenient) { - ConstructingObjectParser parser = new ConstructingObjectParser<>( + private static ConstructingObjectParser createParser(boolean lenient) { + ConstructingObjectParser parser = new ConstructingObjectParser<>( NAME.getPreferredName(), lenient, - a -> new FrequencyEncoding((String)a[0], (String)a[1], (Map)a[2], (Boolean)a[3])); + (a, c) -> new FrequencyEncoding((String)a[0], + (String)a[1], + (Map)a[2], + a[3] == null ? c.isCustomByDefault() : (Boolean)a[3])); parser.declareString(ConstructingObjectParser.constructorArg(), FIELD); parser.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME); parser.declareObject(ConstructingObjectParser.constructorArg(), @@ -54,12 +58,12 @@ private static ConstructingObjectParser createParser(bo return parser; } - public static FrequencyEncoding fromXContentStrict(XContentParser parser) { - return STRICT_PARSER.apply(parser, null); + public static FrequencyEncoding fromXContentStrict(XContentParser parser, PreProcessorParseContext context) { + return STRICT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context); } - public static FrequencyEncoding fromXContentLenient(XContentParser parser) { - return LENIENT_PARSER.apply(parser, null); + public static FrequencyEncoding fromXContentLenient(XContentParser parser, PreProcessorParseContext context) { + return LENIENT_PARSER.apply(parser, context == null ? 
PreProcessorParseContext.DEFAULT : context); } private final String field; @@ -112,6 +116,11 @@ public boolean isCustom() { return custom; } + @Override + public String getOutputFieldType(String outputField) { + return NumberFieldMapper.NumberType.DOUBLE.typeName(); + } + @Override public String getName() { return NAME.getPreferredName(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java index 6b56c488ce6bd..39d8d90dea7bc 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.mapper.NumberFieldMapper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -35,27 +36,29 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars public static final ParseField HOT_MAP = new ParseField("hot_map"); public static final ParseField CUSTOM = new ParseField("custom"); - public static final ConstructingObjectParser STRICT_PARSER = createParser(false); - public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); @SuppressWarnings("unchecked") - private static ConstructingObjectParser createParser(boolean lenient) { - ConstructingObjectParser parser = new ConstructingObjectParser<>( + private static ConstructingObjectParser createParser(boolean lenient) { + ConstructingObjectParser parser = new ConstructingObjectParser<>( NAME.getPreferredName(), lenient, - a -> new OneHotEncoding((String)a[0], (Map)a[1], (Boolean)a[2])); + (a, c) -> new OneHotEncoding((String)a[0], + (Map)a[1], + a[2] == null ? c.isCustomByDefault() : (Boolean)a[2])); parser.declareString(ConstructingObjectParser.constructorArg(), FIELD); parser.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> p.mapStrings(), HOT_MAP); parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM); return parser; } - public static OneHotEncoding fromXContentStrict(XContentParser parser) { - return STRICT_PARSER.apply(parser, null); + public static OneHotEncoding fromXContentStrict(XContentParser parser, PreProcessorParseContext context) { + return STRICT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context); } - public static OneHotEncoding fromXContentLenient(XContentParser parser) { - return LENIENT_PARSER.apply(parser, null); + public static OneHotEncoding fromXContentLenient(XContentParser parser, PreProcessorParseContext context) { + return LENIENT_PARSER.apply(parser, context == null ? 
PreProcessorParseContext.DEFAULT : context); } private final String field; @@ -98,6 +101,11 @@ public boolean isCustom() { return custom; } + @Override + public String getOutputFieldType(String outputField) { + return NumberFieldMapper.NumberType.INTEGER.typeName(); + } + @Override public String getName() { return NAME.getPreferredName(); @@ -119,8 +127,9 @@ public void process(Map fields) { if (value == null) { return; } + final String stringValue = value.toString(); hotMap.forEach((val, col) -> { - int encoding = value.toString().equals(val) ? 1 : 0; + int encoding = stringValue.equals(val) ? 1 : 0; fields.put(col, encoding); }); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java index c5605af6295b0..596664773704c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java @@ -18,6 +18,18 @@ */ public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accountable { + class PreProcessorParseContext { + public static final PreProcessorParseContext DEFAULT = new PreProcessorParseContext(false); + final boolean defaultIsCustomValue; + public PreProcessorParseContext(boolean defaultIsCustomValue) { + this.defaultIsCustomValue = defaultIsCustomValue; + } + + public boolean isCustomByDefault() { + return defaultIsCustomValue; + } + } + /** * The expected input fields */ @@ -48,4 +60,6 @@ public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accou */ boolean isCustom(); + String getOutputFieldType(String outputField); + } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java index 0c0fd0814c2cb..ccaf14984cfae 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.mapper.NumberFieldMapper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -36,15 +37,19 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly public static final ParseField DEFAULT_VALUE = new ParseField("default_value"); public static final ParseField CUSTOM = new ParseField("custom"); - public static final ConstructingObjectParser STRICT_PARSER = createParser(false); - public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); @SuppressWarnings("unchecked") - private static ConstructingObjectParser createParser(boolean lenient) { - ConstructingObjectParser parser = new ConstructingObjectParser<>( + private static ConstructingObjectParser createParser(boolean lenient) { + 
ConstructingObjectParser parser = new ConstructingObjectParser<>( NAME.getPreferredName(), lenient, - a -> new TargetMeanEncoding((String)a[0], (String)a[1], (Map)a[2], (Double)a[3], (Boolean)a[4])); + (a, c) -> new TargetMeanEncoding((String)a[0], + (String)a[1], + (Map)a[2], + (Double)a[3], + a[4] == null ? c.isCustomByDefault() : (Boolean)a[4])); parser.declareString(ConstructingObjectParser.constructorArg(), FIELD); parser.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME); parser.declareObject(ConstructingObjectParser.constructorArg(), @@ -55,12 +60,12 @@ private static ConstructingObjectParser createParser(b return parser; } - public static TargetMeanEncoding fromXContentStrict(XContentParser parser) { - return STRICT_PARSER.apply(parser, null); + public static TargetMeanEncoding fromXContentStrict(XContentParser parser, PreProcessorParseContext context) { + return STRICT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context); } - public static TargetMeanEncoding fromXContentLenient(XContentParser parser) { - return LENIENT_PARSER.apply(parser, null); + public static TargetMeanEncoding fromXContentLenient(XContentParser parser, PreProcessorParseContext context) { + return LENIENT_PARSER.apply(parser, context == null ? PreProcessorParseContext.DEFAULT : context); } private final String field; @@ -123,6 +128,11 @@ public boolean isCustom() { return custom; } + @Override + public String getOutputFieldType(String outputField) { + return NumberFieldMapper.NumberType.DOUBLE.typeName(); + } + @Override public String getName() { return NAME.getPreferredName(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java index 7739acf1ad749..0e95b66dd1532 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java @@ -41,7 +41,7 @@ public class InferenceDefinition { (p, c, n) -> p.namedObject(InferenceModel.class, n, null), TRAINED_MODEL); PARSER.declareNamedObjects(InferenceDefinition.Builder::setPreProcessors, - (p, c, n) -> p.namedObject(LenientlyParsedPreProcessor.class, n, null), + (p, c, n) -> p.namedObject(LenientlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT), (trainedModelDefBuilder) -> {}, PREPROCESSORS); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java index ecd21df1f115c..f9bfcf1203a71 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java @@ -326,12 +326,14 @@ public final class ReservedFieldNames { Regression.LOSS_FUNCTION_PARAMETER.getPreferredName(), Regression.PREDICTION_FIELD_NAME.getPreferredName(), Regression.TRAINING_PERCENT.getPreferredName(), + Regression.FEATURE_PROCESSORS.getPreferredName(), Classification.NAME.getPreferredName(), Classification.DEPENDENT_VARIABLE.getPreferredName(), Classification.PREDICTION_FIELD_NAME.getPreferredName(), 
Classification.CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), Classification.NUM_TOP_CLASSES.getPreferredName(), Classification.TRAINING_PERCENT.getPreferredName(), + Classification.FEATURE_PROCESSORS.getPreferredName(), BoostedTreeParams.LAMBDA.getPreferredName(), BoostedTreeParams.GAMMA.getPreferredName(), BoostedTreeParams.ETA.getPreferredName(), diff --git a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/config_index_mappings.json b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/config_index_mappings.json index 927e9092ada66..747de52a5aec7 100644 --- a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/config_index_mappings.json +++ b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/config_index_mappings.json @@ -34,6 +34,9 @@ "feature_bag_fraction" : { "type" : "double" }, + "feature_processors": { + "enabled": false + }, "gamma" : { "type" : "double" }, @@ -84,6 +87,9 @@ "feature_bag_fraction" : { "type" : "double" }, + "feature_processors": { + "enabled": false + }, "gamma" : { "type" : "double" }, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java index a82c3e6b957a2..fdbb06fcbe384 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import java.util.ArrayList; import java.util.Collections; @@ -27,6 +28,7 @@ public class GetDataFrameAnalyticsActionResponseTests extends AbstractWireSerial protected NamedWriteableRegistry getNamedWriteableRegistry() { List namedWriteables = new ArrayList<>(); namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); + namedWriteables.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); namedWriteables.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables()); return new NamedWriteableRegistry(namedWriteables); } @@ -35,6 +37,7 @@ protected NamedWriteableRegistry getNamedWriteableRegistry() { protected NamedXContentRegistry xContentRegistry() { List namedXContent = new ArrayList<>(); namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); return new NamedXContentRegistry(namedXContent); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java index 038c35fdd2930..cc01ffaf214d2 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java +++ 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.junit.Before; import java.util.ArrayList; @@ -43,6 +44,7 @@ public void setUpId() { protected NamedWriteableRegistry getNamedWriteableRegistry() { List namedWriteables = new ArrayList<>(); namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); + namedWriteables.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); namedWriteables.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables()); return new NamedWriteableRegistry(namedWriteables); } @@ -51,6 +53,7 @@ protected NamedWriteableRegistry getNamedWriteableRegistry() { protected NamedXContentRegistry xContentRegistry() { List namedXContent = new ArrayList<>(); namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); return new NamedXContentRegistry(namedXContent); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java index 6338e031070b5..949409a772e0f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction.Response; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import java.util.ArrayList; import java.util.Collections; @@ -24,6 +25,7 @@ public class PutDataFrameAnalyticsActionResponseTests extends AbstractWireSerial protected NamedWriteableRegistry getNamedWriteableRegistry() { List namedWriteables = new ArrayList<>(); namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); + namedWriteables.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); namedWriteables.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables()); return new NamedWriteableRegistry(namedWriteables); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java index 8068d03456b29..c3645b5a2b0ff 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java @@ -42,6
+42,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.dataframe.analyses.RegressionTests; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.junit.Before; @@ -78,6 +79,7 @@ protected DataFrameAnalyticsConfig doParseInstance(XContentParser parser) throws protected NamedWriteableRegistry getNamedWriteableRegistry() { List namedWriteables = new ArrayList<>(); namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); + namedWriteables.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); namedWriteables.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables()); return new NamedWriteableRegistry(namedWriteables); } @@ -86,6 +88,7 @@ protected NamedWriteableRegistry getNamedWriteableRegistry() { protected NamedXContentRegistry xContentRegistry() { List namedXContent = new ArrayList<>(); namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); return new NamedXContentRegistry(namedXContent); } @@ -147,14 +150,16 @@ protected void assertOnBWCObject(DataFrameAnalyticsConfig bwcSerializedObject, D bwcRegression.getTrainingPercent(), 42L, bwcRegression.getLossFunction(), - bwcRegression.getLossFunctionParameter()); + bwcRegression.getLossFunctionParameter(), + bwcRegression.getFeatureProcessors()); testAnalysis = new Regression(testRegression.getDependentVariable(), testRegression.getBoostedTreeParams(), testRegression.getPredictionFieldName(), testRegression.getTrainingPercent(), 42L, testRegression.getLossFunction(), - testRegression.getLossFunctionParameter()); + testRegression.getLossFunctionParameter(), + bwcRegression.getFeatureProcessors()); } else { Classification testClassification = (Classification)testInstance.getAnalysis(); Classification bwcClassification = (Classification)bwcSerializedObject.getAnalysis(); @@ -164,14 +169,16 @@ protected void assertOnBWCObject(DataFrameAnalyticsConfig bwcSerializedObject, D bwcClassification.getClassAssignmentObjective(), bwcClassification.getNumTopClasses(), bwcClassification.getTrainingPercent(), - 42L); + 42L, + bwcClassification.getFeatureProcessors()); testAnalysis = new Classification(testClassification.getDependentVariable(), testClassification.getBoostedTreeParams(), testClassification.getPredictionFieldName(), testClassification.getClassAssignmentObjective(), testClassification.getNumTopClasses(), testClassification.getTrainingPercent(), - 42L); + 42L, + testClassification.getFeatureProcessors()); } super.assertOnBWCObject(new DataFrameAnalyticsConfig.Builder(bwcSerializedObject) .setAnalysis(bwcAnalysis) diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index d0a3ea2c71886..081fbe0c6ec9e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -8,25 +8,41 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.Version; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.DeprecationHandler; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.index.mapper.BooleanFieldMapper; import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.NumberFieldMapper; +import org.elasticsearch.search.SearchModule; import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.containsString; @@ -55,6 +71,21 @@ protected Classification createTestInstance() { return createRandom(); } + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } + public static Classification createRandom() { String dependentVariableName = randomAlphaOfLength(10); BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom(); @@ -65,7 +96,14 @@ public static Classification createRandom() { Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true); Long randomizeSeed = randomBoolean() ? 
null : randomLong(); return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective, - numTopClasses, trainingPercent, randomizeSeed); + numTopClasses, trainingPercent, randomizeSeed, + randomBoolean() ? + null : + Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(true), + OneHotEncodingTests.createRandom(true), + TargetMeanEncodingTests.createRandom(true))) + .limit(randomIntBetween(0, 5)) + .collect(Collectors.toList())); } public static Classification mutateForVersion(Classification instance, Version version) { @@ -75,7 +113,8 @@ public static Classification mutateForVersion(Classification instance, Version v version.onOrAfter(Version.V_7_7_0) ? instance.getClassAssignmentObjective() : null, instance.getNumTopClasses(), instance.getTrainingPercent(), - instance.getRandomizeSeed()); + instance.getRandomizeSeed(), + version.onOrAfter(Version.V_8_0_0) ? instance.getFeatureProcessors() : Collections.emptyList()); } @Override @@ -91,14 +130,16 @@ protected void assertOnBWCObject(Classification bwcSerializedObject, Classificat bwcSerializedObject.getClassAssignmentObjective(), bwcSerializedObject.getNumTopClasses(), bwcSerializedObject.getTrainingPercent(), - 42L); + 42L, + bwcSerializedObject.getFeatureProcessors()); Classification newInstance = new Classification(testInstance.getDependentVariable(), testInstance.getBoostedTreeParams(), testInstance.getPredictionFieldName(), testInstance.getClassAssignmentObjective(), testInstance.getNumTopClasses(), testInstance.getTrainingPercent(), - 42L); + 42L, + testInstance.getFeatureProcessors()); super.assertOnBWCObject(newBwc, newInstance, version); } @@ -107,87 +148,138 @@ protected Writeable.Reader instanceReader() { return Classification::new; } + public void testDeserialization() throws IOException { + String toDeserialize = "{\n" + + " \"dependent_variable\": \"FlightDelayMin\",\n" + + " \"feature_processors\": [\n" + + " {\n" + + " \"one_hot_encoding\": {\n" + + " \"field\": \"OriginWeather\",\n" + + " \"hot_map\": {\n" + + " \"sunny_col\": \"Sunny\",\n" + + " \"clear_col\": \"Clear\",\n" + + " \"rainy_col\": \"Rain\"\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"one_hot_encoding\": {\n" + + " \"field\": \"DestWeather\",\n" + + " \"hot_map\": {\n" + + " \"dest_sunny_col\": \"Sunny\",\n" + + " \"dest_clear_col\": \"Clear\",\n" + + " \"dest_rainy_col\": \"Rain\"\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"frequency_encoding\": {\n" + + " \"field\": \"OriginWeather\",\n" + + " \"feature_name\": \"mean\",\n" + + " \"frequency_map\": {\n" + + " \"Sunny\": 0.8,\n" + + " \"Rain\": 0.2\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }" + + ""; + + try(XContentParser parser = XContentHelper.createParser(xContentRegistry(), + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + new BytesArray(toDeserialize), + XContentType.JSON)) { + Classification parsed = Classification.fromXContent(parser, false); + assertThat(parsed.getDependentVariable(), equalTo("FlightDelayMin")); + for (PreProcessor preProcessor : parsed.getFeatureProcessors()) { + assertThat(preProcessor.isCustom(), is(true)); + } + } + } + + public void testConstructor_GivenTrainingPercentIsLessThanOne() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.999, randomLong())); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.999, randomLong(), null)); 
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } public void testConstructor_GivenTrainingPercentIsGreaterThan100() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong())); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong(), null)); assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } public void testConstructor_GivenNumTopClassesIsLessThanZero() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong())); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null)); assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]")); } public void testConstructor_GivenNumTopClassesIsGreaterThan1000() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong())); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong(), null)); assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]")); } public void testGetPredictionFieldName() { - Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong()); + Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null); assertThat(classification.getPredictionFieldName(), equalTo("result")); - classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong(), null); assertThat(classification.getPredictionFieldName(), equalTo("foo_prediction")); } public void testClassAssignmentObjective() { Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", - Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong()); + Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong(), null); assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY)); classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", - Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong()); + Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong(), null); assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL)); // class_assignment_objective == null, default applied - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null); assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL)); } public void testGetNumTopClasses() { - Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong()); + Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 
null, 7, 1.0, randomLong(), null); assertThat(classification.getNumTopClasses(), equalTo(7)); // Boundary condition: num_top_classes == 0 - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong(), null); assertThat(classification.getNumTopClasses(), equalTo(0)); // Boundary condition: num_top_classes == 1000 - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong(), null); assertThat(classification.getNumTopClasses(), equalTo(1000)); // num_top_classes == null, default applied - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong(), null); assertThat(classification.getNumTopClasses(), equalTo(2)); } public void testGetTrainingPercent() { - Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong()); + Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null); assertThat(classification.getTrainingPercent(), equalTo(50.0)); // Boundary condition: training_percent == 1.0 - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong(), null); assertThat(classification.getTrainingPercent(), equalTo(1.0)); // Boundary condition: training_percent == 100.0 - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong(), null); assertThat(classification.getTrainingPercent(), equalTo(100.0)); // training_percent == null, default applied - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong(), null); assertThat(classification.getTrainingPercent(), equalTo(100.0)); } @@ -231,6 +323,7 @@ public void testGetParams() { null, null, 50.0, + null, null).getParams(fieldInfo), equalTo( Map.of( diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java index 0d695d0fbbde4..7ac70f8c4ea2c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java @@ -8,18 +8,35 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.Version; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.DeprecationHandler; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import 
org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.search.SearchModule; import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; +import java.util.List; import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; @@ -45,6 +62,21 @@ protected Regression createTestInstance() { return createRandom(); } + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } + public static Regression createRandom() { return createRandom(BoostedTreeParamsTests.createRandom()); } @@ -57,7 +89,14 @@ private static Regression createRandom(BoostedTreeParams boostedTreeParams) { Regression.LossFunction lossFunction = randomBoolean() ? null : randomFrom(Regression.LossFunction.values()); Double lossFunctionParameter = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, false); return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed, lossFunction, - lossFunctionParameter); + lossFunctionParameter, + randomBoolean() ? + null : + Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(true), + OneHotEncodingTests.createRandom(true), + TargetMeanEncodingTests.createRandom(true))) + .limit(randomIntBetween(0, 5)) + .collect(Collectors.toList())); } public static Regression mutateForVersion(Regression instance, Version version) { @@ -67,7 +106,8 @@ public static Regression mutateForVersion(Regression instance, Version version) instance.getTrainingPercent(), instance.getRandomizeSeed(), instance.getLossFunction(), - instance.getLossFunctionParameter()); + instance.getLossFunctionParameter(), + version.onOrAfter(Version.V_8_0_0) ? 
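// feature_processors are only written to 8.0.0+ streams (mirroring Classification's
// version-gated writeTo), so an instance mutated for an older wire version has to
// carry an empty processor list: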
instance.getFeatureProcessors() : Collections.emptyList()); } @Override @@ -83,14 +123,16 @@ protected void assertOnBWCObject(Regression bwcSerializedObject, Regression test bwcSerializedObject.getTrainingPercent(), 42L, bwcSerializedObject.getLossFunction(), - bwcSerializedObject.getLossFunctionParameter()); + bwcSerializedObject.getLossFunctionParameter(), + bwcSerializedObject.getFeatureProcessors()); Regression newInstance = new Regression(testInstance.getDependentVariable(), testInstance.getBoostedTreeParams(), testInstance.getPredictionFieldName(), testInstance.getTrainingPercent(), 42L, testInstance.getLossFunction(), - testInstance.getLossFunctionParameter()); + testInstance.getLossFunctionParameter(), + testInstance.getFeatureProcessors()); super.assertOnBWCObject(newBwc, newInstance, version); } @@ -104,56 +146,122 @@ protected Writeable.Reader instanceReader() { return Regression::new; } + public void testDeserialization() throws IOException { + String toDeserialize = "{\n" + + " \"dependent_variable\": \"FlightDelayMin\",\n" + + " \"feature_processors\": [\n" + + " {\n" + + " \"one_hot_encoding\": {\n" + + " \"field\": \"OriginWeather\",\n" + + " \"hot_map\": {\n" + + " \"sunny_col\": \"Sunny\",\n" + + " \"clear_col\": \"Clear\",\n" + + " \"rainy_col\": \"Rain\"\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"one_hot_encoding\": {\n" + + " \"field\": \"DestWeather\",\n" + + " \"hot_map\": {\n" + + " \"dest_sunny_col\": \"Sunny\",\n" + + " \"dest_clear_col\": \"Clear\",\n" + + " \"dest_rainy_col\": \"Rain\"\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"frequency_encoding\": {\n" + + " \"field\": \"OriginWeather\",\n" + + " \"feature_name\": \"mean\",\n" + + " \"frequency_map\": {\n" + + " \"Sunny\": 0.8,\n" + + " \"Rain\": 0.2\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }" + + ""; + + try(XContentParser parser = XContentHelper.createParser(xContentRegistry(), + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + new BytesArray(toDeserialize), + XContentType.JSON)) { + Regression parsed = Regression.fromXContent(parser, false); + assertThat(parsed.getDependentVariable(), equalTo("FlightDelayMin")); + for (PreProcessor preProcessor : parsed.getFeatureProcessors()) { + assertThat(preProcessor.isCustom(), is(true)); + } + } + } + public void testConstructor_GivenTrainingPercentIsLessThanOne() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999, randomLong(), Regression.LossFunction.MSE, null)); + () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999, randomLong(), Regression.LossFunction.MSE, null, null)); assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } public void testConstructor_GivenTrainingPercentIsGreaterThan100() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong(), Regression.LossFunction.MSE, null)); + () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong(), Regression.LossFunction.MSE, null, null)); + assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } public void testConstructor_GivenLossFunctionParameterIsZero() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, 0.0)); + () -> new 
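// loss_function_parameter must be strictly positive; zero is rejected with the
// same message as the negative value in the test that follows.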
Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, 0.0, null)); assertThat(e.getMessage(), equalTo("[loss_function_parameter] must be a positive double")); } public void testConstructor_GivenLossFunctionParameterIsNegative() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, -1.0)); + () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, -1.0, null)); assertThat(e.getMessage(), equalTo("[loss_function_parameter] must be a positive double")); } public void testGetPredictionFieldName() { - Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong(), Regression.LossFunction.MSE, 1.0); + Regression regression = new Regression( + "foo", + BOOSTED_TREE_PARAMS, + "result", + 50.0, + randomLong(), + Regression.LossFunction.MSE, + 1.0, + null); assertThat(regression.getPredictionFieldName(), equalTo("result")); - regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0, randomLong(), Regression.LossFunction.MSE, null); + regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0, randomLong(), Regression.LossFunction.MSE, null, null); assertThat(regression.getPredictionFieldName(), equalTo("foo_prediction")); } public void testGetTrainingPercent() { - Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong(), Regression.LossFunction.MSE, 1.0); + Regression regression = new Regression("foo", + BOOSTED_TREE_PARAMS, + "result", + 50.0, + randomLong(), + Regression.LossFunction.MSE, + 1.0, + null); assertThat(regression.getTrainingPercent(), equalTo(50.0)); // Boundary condition: training_percent == 1.0 - regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0, randomLong(), Regression.LossFunction.MSE, null); + regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0, randomLong(), Regression.LossFunction.MSE, null, null); assertThat(regression.getTrainingPercent(), equalTo(1.0)); // Boundary condition: training_percent == 100.0 - regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, null); + regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, null, null); assertThat(regression.getTrainingPercent(), equalTo(100.0)); // training_percent == null, default applied - regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null, randomLong(), Regression.LossFunction.MSE, null); + regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null, randomLong(), Regression.LossFunction.MSE, null, null); assertThat(regression.getTrainingPercent(), equalTo(100.0)); } @@ -165,6 +273,7 @@ public void testGetParams_ShouldIncludeBoostedTreeParams() { 100.0, 0L, Regression.LossFunction.MSE, + null, null); Map params = regression.getParams(null); @@ -182,7 +291,9 @@ public void testGetParams_GivenRandomWithoutBoostedTreeParams() { Map params = regression.getParams(null); - int expectedParamsCount = 4 + (regression.getLossFunctionParameter() == null ? 0 : 1); + int expectedParamsCount = 4 + + (regression.getLossFunctionParameter() == null ? 0 : 1) + + (regression.getFeatureProcessors().isEmpty() ? 
0 : 1); assertThat(params.size(), equalTo(expectedParamsCount)); assertThat(params.get("dependent_variable"), equalTo(regression.getDependentVariable())); assertThat(params.get("prediction_field_name"), equalTo(regression.getPredictionFieldName())); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncodingTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncodingTests.java index 1e63aac086969..254ffa6962da2 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncodingTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncodingTests.java @@ -24,7 +24,9 @@ public class FrequencyEncodingTests extends PreProcessingTests valueMap = new HashMap<>(); for (int i = 0; i < valuesSize; i++) { @@ -41,7 +47,7 @@ public static FrequencyEncoding createRandom() { return new FrequencyEncoding(randomAlphaOfLength(10), randomAlphaOfLength(10), valueMap, - randomBoolean() ? null : randomBoolean()); + isCustom); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncodingTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncodingTests.java index 53887cc5ae911..4512532f76296 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncodingTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncodingTests.java @@ -24,7 +24,9 @@ public class OneHotEncodingTests extends PreProcessingTests { @Override protected OneHotEncoding doParseInstance(XContentParser parser) throws IOException { - return lenient ? OneHotEncoding.fromXContentLenient(parser) : OneHotEncoding.fromXContentStrict(parser); + return lenient ? + OneHotEncoding.fromXContentLenient(parser, PreProcessor.PreProcessorParseContext.DEFAULT) : + OneHotEncoding.fromXContentStrict(parser, PreProcessor.PreProcessorParseContext.DEFAULT); } @Override @@ -33,6 +35,10 @@ protected OneHotEncoding createTestInstance() { } public static OneHotEncoding createRandom() { + return createRandom(randomBoolean() ? randomBoolean() : null); + } + + public static OneHotEncoding createRandom(Boolean isCustom) { int valuesSize = randomIntBetween(1, 10); Map valueMap = new HashMap<>(); for (int i = 0; i < valuesSize; i++) { @@ -40,7 +46,7 @@ public static OneHotEncoding createRandom() { } return new OneHotEncoding(randomAlphaOfLength(10), valueMap, - randomBoolean() ? 
randomBoolean() : null); + isCustom); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncodingTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncodingTests.java index 60765c83e11e0..9a31da55fc2c3 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncodingTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncodingTests.java @@ -24,7 +24,9 @@ public class TargetMeanEncodingTests extends PreProcessingTests valueMap = new HashMap<>(); for (int i = 0; i < valuesSize; i++) { @@ -42,7 +49,7 @@ public static TargetMeanEncoding createRandom() { randomAlphaOfLength(10), valueMap, randomDoubleBetween(0.0, 1.0, false), - randomBoolean() ? randomBoolean() : null); + isCustom); } @Override diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index 5cdac65370040..3c7a9cfd8d14b 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -20,28 +20,37 @@ import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchModule; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.action.NodeAcknowledgedResponse; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigUpdate; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import org.junit.After; import org.junit.Before; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import 
java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; @@ -106,6 +115,15 @@ public void cleanup() { .get(); } + @Override + protected NamedXContentRegistry xContentRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, Collections.emptyList()); + List entries = new ArrayList<>(searchModule.getNamedXContents()); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + entries.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + return new NamedXContentRegistry(entries); + } + public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception { initialize("classification_single_numeric_feature_and_mixed_data_set"); String predictedClassField = KEYWORD_FIELD + "_prediction"; @@ -119,6 +137,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws null, null, null, + null, null)); putAnalytics(config); @@ -174,6 +193,7 @@ public void testWithDatastreams() throws Exception { null, null, null, + null, null)); putAnalytics(config); @@ -266,6 +286,76 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField); } + public void testWithCustomFeatureProcessors() throws Exception { + initialize("classification_with_custom_feature_processors"); + String predictedClassField = KEYWORD_FIELD + "_prediction"; + indexData(sourceIndex, 300, 50, KEYWORD_FIELD); + + DataFrameAnalyticsConfig config = + buildAnalytics(jobId, sourceIndex, destIndex, null, + new Classification( + KEYWORD_FIELD, + BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(), + null, + null, + null, + null, + null, + Arrays.asList( + new OneHotEncoding(TEXT_FIELD, Collections.singletonMap(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom"), true) + ))); + putAnalytics(config); + + assertIsStopped(jobId); + assertProgressIsZero(jobId); + + startAnalytics(jobId); + waitUntilAnalyticsIsStopped(jobId); + + client().admin().indices().refresh(new RefreshRequest(destIndex)); + SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get(); + for (SearchHit hit : sourceData.getHits()) { + Map destDoc = getDestDoc(config, hit); + Map resultsObject = getFieldValue(destDoc, "ml"); + assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES))); + assertThat(getFieldValue(resultsObject, "is_training"), is(destDoc.containsKey(KEYWORD_FIELD))); + assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES); + @SuppressWarnings("unchecked") + List> importanceArray = (List>)resultsObject.get("feature_importance"); + assertThat(importanceArray, hasSize(greaterThan(0))); + } + + assertProgressComplete(jobId); + assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); + assertModelStatePersisted(stateDocId()); + assertInferenceModelPersisted(jobId); + assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword"); + assertThatAuditMessagesMatch(jobId, + "Created analytics with analysis type [classification]", + "Estimated memory usage for this analytics to be", + "Starting analytics on node", + "Started analytics", + expectedDestIndexAuditMessage(), + "Started reindexing to destination index [" + destIndex + "]", + "Finished reindexing to destination index [" + destIndex + "]", + "Started loading data", + "Started analyzing", + "Started writing results", 
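// NOTE: once the audit assertions pass, the block below fetches the persisted
// trained model and verifies that exactly the first pre-processor in its
// definition is the custom, user-supplied one (isCustom is true only at index 0).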
+ "Finished analysis"); + assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField); + + GetTrainedModelsAction.Response response = client().execute(GetTrainedModelsAction.INSTANCE, + new GetTrainedModelsAction.Request(jobId + "*", true, Collections.emptyList())).actionGet(); + assertThat(response.getResources().results().size(), equalTo(1)); + TrainedModelConfig modelConfig = response.getResources().results().get(0); + modelConfig.ensureParsedDefinition(xContentRegistry()); + assertThat(modelConfig.getModelDefinition().getPreProcessors().size(), greaterThan(0)); + for (int i = 0; i < modelConfig.getModelDefinition().getPreProcessors().size(); i++) { + PreProcessor preProcessor = modelConfig.getModelDefinition().getPreProcessors().get(i); + assertThat(preProcessor.isCustom(), equalTo(i == 0)); + } + } + public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(String jobId, String dependentVariable, List dependentVariableValues, @@ -281,7 +371,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(String jobId, sourceIndex, destIndex, null, - new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, null, numTopClasses, 50.0, null)); + new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, null, numTopClasses, 50.0, null, null)); putAnalytics(config); assertIsStopped(jobId); @@ -350,7 +440,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableI "classification_training_percent_is_50_integer", DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES, "integer"); } - public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsDouble() throws Exception { + public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsDouble() { ElasticsearchStatusException e = expectThrows( ElasticsearchStatusException.class, () -> testWithOnlyTrainingRowsAndTrainingPercentIsFifty( @@ -358,7 +448,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableI assertThat(e.getMessage(), startsWith("invalid types [double] for required field [numerical-field];")); } - public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsText() throws Exception { + public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsText() { ElasticsearchStatusException e = expectThrows( ElasticsearchStatusException.class, () -> testWithOnlyTrainingRowsAndTrainingPercentIsFifty( @@ -547,7 +637,7 @@ public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exceptio .build(); DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null, - new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, null)); + new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, null, null)); putAnalytics(firstJob); String secondJobId = "classification_two_jobs_with_same_randomize_seed_2"; @@ -555,7 +645,7 @@ public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exceptio long randomizeSeed = ((Classification) firstJob.getAnalysis()).getRandomizeSeed(); DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null, - new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, randomizeSeed)); + new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, randomizeSeed, null)); putAnalytics(secondJob); diff --git 
a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ExplainDataFrameAnalyticsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ExplainDataFrameAnalyticsIT.java index a61065d4445a6..20b82f8caa9c5 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ExplainDataFrameAnalyticsIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ExplainDataFrameAnalyticsIT.java @@ -104,6 +104,7 @@ public void testTrainingPercentageIsApplied() throws IOException { 100.0, null, null, + null, null)) .buildForExplain(); @@ -122,6 +123,7 @@ public void testTrainingPercentageIsApplied() throws IOException { 50.0, null, null, + null, null)) .buildForExplain(); @@ -149,6 +151,7 @@ public void testSimultaneousExplainSameConfig() throws IOException { 100.0, null, null, + null, null)) .buildForExplain(); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java index a436072e8975d..47b53010f636c 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java @@ -14,23 +14,34 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.action.NodeAcknowledgedResponse; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import org.junit.After; import java.io.IOException; import java.time.Instant; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; @@ -64,6 +75,15 @@ public void cleanup() { cleanUp(); } + @Override + protected NamedXContentRegistry xContentRegistry() { + SearchModule searchModule = new 
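// NOTE: the inference named-XContent entries registered here are what allow
// modelConfig.ensureParsedDefinition(xContentRegistry()) to parse the persisted
// model's pre-processors later in this class.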
SearchModule(Settings.EMPTY, Collections.emptyList()); + List entries = new ArrayList<>(searchModule.getNamedXContents()); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + entries.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + return new NamedXContentRegistry(entries); + } + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/60340") public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception { initialize("regression_single_numeric_feature_and_mixed_data_set"); @@ -78,6 +98,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws null, null, null, + null, null) ); putAnalytics(config); @@ -216,7 +237,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception sourceIndex, destIndex, null, - new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(), null, 50.0, null, null, null)); + new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(), null, 50.0, null, null, null, null)); putAnalytics(config); assertIsStopped(jobId); @@ -343,7 +364,7 @@ public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exceptio .build(); DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null, - new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null, null, null)); + new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null, null, null, null)); putAnalytics(firstJob); String secondJobId = "regression_two_jobs_with_same_randomize_seed_2"; @@ -351,7 +372,7 @@ public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exceptio long randomizeSeed = ((Regression) firstJob.getAnalysis()).getRandomizeSeed(); DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null, - new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, randomizeSeed, null, null)); + new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, randomizeSeed, null, null, null)); putAnalytics(secondJob); @@ -412,7 +433,7 @@ public void testDependentVariableIsLong() throws Exception { sourceIndex, destIndex, null, - new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(), null, null, null, null, null)); + new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(), null, null, null, null, null, null)); putAnalytics(config); assertIsStopped(jobId); @@ -439,6 +460,7 @@ public void testWithDatastream() throws Exception { null, null, null, + null, null) ); putAnalytics(config); @@ -535,6 +557,7 @@ public void testAliasFields() throws Exception { 90.0, null, null, + null, null); DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder() .setId(jobId) @@ -590,6 +613,73 @@ public void testAliasFields() throws Exception { "Finished analysis"); } + public void testWithCustomFeatureProcessors() throws Exception { + initialize("regression_with_custom_feature_processors"); + String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction"; + indexData(sourceIndex, 300, 50); + + DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, + new Regression( + DEPENDENT_VARIABLE_FIELD, + BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(), + null, + null, + null, + null, + null, + Arrays.asList( + new OneHotEncoding(DISCRETE_NUMERICAL_FEATURE_FIELD, + 
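// hot_map: maps one observed input value to its output column name ("tenner");
// the final true flags the encoding as custom, i.e. user-supplied.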
Collections.singletonMap(DISCRETE_NUMERICAL_FEATURE_VALUES.get(0).toString(), "tenner"), true) + )) + ); + putAnalytics(config); + + assertIsStopped(jobId); + assertProgressIsZero(jobId); + + startAnalytics(jobId); + waitUntilAnalyticsIsStopped(jobId); + + // for debugging + SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get(); + for (SearchHit hit : sourceData.getHits()) { + Map destDoc = getDestDoc(config, hit); + Map resultsObject = getMlResultsObjectFromDestDoc(destDoc); + + assertThat(resultsObject.containsKey(predictedClassField), is(true)); + assertThat(resultsObject.containsKey("is_training"), is(true)); + assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD))); + } + + assertProgressComplete(jobId); + assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); + assertModelStatePersisted(stateDocId()); + assertInferenceModelPersisted(jobId); + assertMlResultsFieldMappings(destIndex, predictedClassField, "double"); + assertThatAuditMessagesMatch(jobId, + "Created analytics with analysis type [regression]", + "Estimated memory usage for this analytics to be", + "Starting analytics on node", + "Started analytics", + "Creating destination index [" + destIndex + "]", + "Started reindexing to destination index [" + destIndex + "]", + "Finished reindexing to destination index [" + destIndex + "]", + "Started loading data", + "Started analyzing", + "Started writing results", + "Finished analysis"); + GetTrainedModelsAction.Response response = client().execute(GetTrainedModelsAction.INSTANCE, + new GetTrainedModelsAction.Request(jobId + "*", true, Collections.emptyList())).actionGet(); + assertThat(response.getResources().results().size(), equalTo(1)); + TrainedModelConfig modelConfig = response.getResources().results().get(0); + modelConfig.ensureParsedDefinition(xContentRegistry()); + assertThat(modelConfig.getModelDefinition().getPreProcessors().size(), greaterThan(0)); + for (int i = 0; i < modelConfig.getModelDefinition().getPreProcessors().size(); i++) { + PreProcessor preProcessor = modelConfig.getModelDefinition().getPreProcessors().get(i); + assertThat(preProcessor.isCustom(), equalTo(i == 0)); + } + } + private void initialize(String jobId) { this.jobId = jobId; this.sourceIndex = jobId + "_source_index"; diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java index 74581ac3d45ad..624aee9e41be3 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java @@ -71,7 +71,7 @@ public void testStoreModelViaChunkedPersister() throws IOException { analyticsConfig, new DataFrameAnalyticsAuditor(client(), "test-node"), (ex) -> { throw new ElasticsearchException(ex); }, - new ExtractedFields(extractedFieldList, Collections.emptyMap()) + new ExtractedFields(extractedFieldList, Collections.emptyList(), Collections.emptyMap()) ); //Accuracy for size is not tested here diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsConfigProviderIT.java 
b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsConfigProviderIT.java index 38a9705deeec5..e66178fb77e25 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsConfigProviderIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsConfigProviderIT.java @@ -27,6 +27,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; @@ -171,9 +172,9 @@ public void testUpdate() throws Exception { blockingCall( actionListener -> configProvider.put(initialConfig, emptyMap(), actionListener), configHolder, exceptionHolder); + assertNoException(exceptionHolder); assertThat(configHolder.get(), is(notNullValue())); assertThat(configHolder.get(), is(equalTo(initialConfig))); - assertThat(exceptionHolder.get(), is(nullValue())); } { // Update that changes description AtomicReference updatedConfigHolder = new AtomicReference<>(); @@ -188,7 +189,7 @@ public void testUpdate() throws Exception { actionListener -> configProvider.update(configUpdate, emptyMap(), ClusterState.EMPTY_STATE, actionListener), updatedConfigHolder, exceptionHolder); - + assertNoException(exceptionHolder); assertThat(updatedConfigHolder.get(), is(notNullValue())); assertThat( updatedConfigHolder.get(), @@ -196,7 +197,6 @@ public void testUpdate() throws Exception { new DataFrameAnalyticsConfig.Builder(initialConfig) .setDescription("description-1") .build()))); - assertThat(exceptionHolder.get(), is(nullValue())); } { // Update that changes model memory limit AtomicReference updatedConfigHolder = new AtomicReference<>(); @@ -212,6 +212,7 @@ public void testUpdate() throws Exception { updatedConfigHolder, exceptionHolder); + assertNoException(exceptionHolder); assertThat(updatedConfigHolder.get(), is(notNullValue())); assertThat( updatedConfigHolder.get(), @@ -220,7 +221,6 @@ public void testUpdate() throws Exception { .setDescription("description-1") .setModelMemoryLimit(new ByteSizeValue(1024)) .build()))); - assertThat(exceptionHolder.get(), is(nullValue())); } { // Noop update AtomicReference updatedConfigHolder = new AtomicReference<>(); @@ -233,6 +233,7 @@ public void testUpdate() throws Exception { updatedConfigHolder, exceptionHolder); + assertNoException(exceptionHolder); assertThat(updatedConfigHolder.get(), is(notNullValue())); assertThat( updatedConfigHolder.get(), @@ -241,7 +242,6 @@ public void testUpdate() throws Exception { .setDescription("description-1") .setModelMemoryLimit(new ByteSizeValue(1024)) .build()))); - assertThat(exceptionHolder.get(), is(nullValue())); } { // Update that changes both description and model memory limit AtomicReference updatedConfigHolder = new AtomicReference<>(); @@ -258,6 +258,7 @@ public void testUpdate() throws Exception { updatedConfigHolder, exceptionHolder); + assertNoException(exceptionHolder); assertThat(updatedConfigHolder.get(), is(notNullValue())); assertThat( updatedConfigHolder.get(), @@ -266,7 +267,6 @@ public 
void testUpdate() throws Exception { .setDescription("description-2") .setModelMemoryLimit(new ByteSizeValue(2048)) .build()))); - assertThat(exceptionHolder.get(), is(nullValue())); } { // Update that applies security headers Map securityHeaders = Collections.singletonMap("_xpack_security_authentication", "dummy"); @@ -281,6 +281,7 @@ public void testUpdate() throws Exception { updatedConfigHolder, exceptionHolder); + assertNoException(exceptionHolder); assertThat(updatedConfigHolder.get(), is(notNullValue())); assertThat( updatedConfigHolder.get(), @@ -290,7 +291,6 @@ public void testUpdate() throws Exception { .setModelMemoryLimit(new ByteSizeValue(2048)) .setHeaders(securityHeaders) .build()))); - assertThat(exceptionHolder.get(), is(nullValue())); } } @@ -370,6 +370,7 @@ private static ClusterState clusterStateWithRunningAnalyticsTask(String analytic public NamedXContentRegistry xContentRegistry() { List namedXContent = new ArrayList<>(); namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new SearchModule(Settings.EMPTY, emptyList()).getNamedXContents()); return new NamedXContentRegistry(namedXContent); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/TimeBasedExtractedFields.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/TimeBasedExtractedFields.java index eb13f2395ef8d..580f0c349a1ca 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/TimeBasedExtractedFields.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/TimeBasedExtractedFields.java @@ -28,7 +28,9 @@ public class TimeBasedExtractedFields extends ExtractedFields { private final ExtractedField timeField; public TimeBasedExtractedFields(ExtractedField timeField, List allFields) { - super(allFields, Collections.emptyMap()); + super(allFields, + Collections.emptyList(), + Collections.emptyMap()); if (!allFields.contains(timeField)) { throw new IllegalArgumentException("timeField should also be contained in allFields"); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java index a82fc92a67549..6872585929b8f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java @@ -28,15 +28,18 @@ import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.dataframe.DestinationIndex; import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitter; import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; +import org.elasticsearch.xpack.ml.extractor.ProcessedField; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.NoSuchElementException; 
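[Reviewer note — illustrative, not part of the patch] The hunks below split the extractor's columns into "organic" features (extracted fields that no feature_processor consumes) and "processed" features (the output names of each ProcessedField). A minimal sketch of the resulting column layout, using only accessors introduced in this change:

    // Organic features: extracted fields that are not inputs to any processor.
    Set<String> processedInputs = extractedFields.getProcessedFieldInputs();
    List<String> organic = extractedFields.getAllFields().stream()
        .map(ExtractedField::getName)
        .filter(name -> processedInputs.contains(name) == false)
        .collect(Collectors.toList());
    // Processed features: concatenated output field names, in processor order.
    List<String> processed = extractedFields.getProcessedFields().stream()
        .flatMap(p -> p.getOutputFieldNames().stream())
        .collect(Collectors.toList());
    // The header row sent to the native process is organic ++ processed, plus the
    // two extra dot-named control columns noted in AnalyticsProcessManager.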
@@ -46,6 +49,7 @@ import java.util.concurrent.TimeUnit; import java.util.function.Supplier; import java.util.stream.Collectors; +import java.util.stream.Stream; /** * An implementation that extracts data from elasticsearch using search and scroll on a client. @@ -67,10 +71,29 @@ public class DataFrameDataExtractor { private boolean hasNext; private boolean searchHasShardFailure; private final CachedSupplier trainTestSplitter; + // These are fields that are sent directly to the analytics process + // They are not passed through a feature_processor + private final String[] organicFeatures; + // These are the output field names for the feature_processors + private final String[] processedFeatures; + private final Map extractedFieldsByName; DataFrameDataExtractor(Client client, DataFrameDataExtractorContext context) { this.client = Objects.requireNonNull(client); this.context = Objects.requireNonNull(context); + Set processedFieldInputs = context.extractedFields.getProcessedFieldInputs(); + this.organicFeatures = context.extractedFields.getAllFields() + .stream() + .map(ExtractedField::getName) + .filter(f -> processedFieldInputs.contains(f) == false) + .toArray(String[]::new); + this.processedFeatures = context.extractedFields.getProcessedFields() + .stream() + .map(ProcessedField::getOutputFieldNames) + .flatMap(List::stream) + .toArray(String[]::new); + this.extractedFieldsByName = new LinkedHashMap<>(); + context.extractedFields.getAllFields().forEach(f -> this.extractedFieldsByName.put(f.getName(), f)); hasNext = true; searchHasShardFailure = false; this.trainTestSplitter = new CachedSupplier<>(context.trainTestSplitterFactory::create); @@ -188,26 +211,78 @@ private List processSearchResponse(SearchResponse searchResponse) { return rows; } + private String extractNonProcessedValues(SearchHit hit, String organicFeature) { + ExtractedField field = extractedFieldsByName.get(organicFeature); + Object[] values = field.value(hit); + if (values.length == 1 && isValidValue(values[0])) { + return Objects.toString(values[0]); + } + if (values.length == 0 && context.supportsRowsWithMissingValues) { + // if values is empty then it means it's a missing value + return NULL_VALUE; + } + // we are here if we have a missing value but the analysis does not support those + // or the value type is not supported (e.g. arrays, etc.) 
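// a null return propagates to createRow, which then emits a Row carrying no values
// for this hit instead of a partially extracted feature vector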
+ return null; + } + + private String[] extractProcessedValue(ProcessedField processedField, SearchHit hit) { + Object[] values = processedField.value(hit, extractedFieldsByName::get); + if (values.length == 0 && context.supportsRowsWithMissingValues == false) { + return null; + } + final String[] extractedValue = new String[processedField.getOutputFieldNames().size()]; + for (int i = 0; i < processedField.getOutputFieldNames().size(); i++) { + extractedValue[i] = NULL_VALUE; + } + // if values is empty then it means it's a missing value + if (values.length == 0) { + return extractedValue; + } + + if (values.length != processedField.getOutputFieldNames().size()) { + throw ExceptionsHelper.badRequestException( + "field_processor [{}] output size expected to be [{}], instead it was [{}]", + processedField.getProcessorName(), + processedField.getOutputFieldNames().size(), + values.length); + } + + for (int i = 0; i < processedField.getOutputFieldNames().size(); ++i) { + Object value = values[i]; + if (value == null && context.supportsRowsWithMissingValues) { + continue; + } + if (isValidValue(value) == false) { + // we are here if we have a missing value but the analysis does not support those + // or the value type is not supported (e.g. arrays, etc.) + return null; + } + extractedValue[i] = Objects.toString(value); + } + return extractedValue; + } + private Row createRow(SearchHit hit) { - String[] extractedValues = new String[context.extractedFields.getAllFields().size()]; - for (int i = 0; i < extractedValues.length; ++i) { - ExtractedField field = context.extractedFields.getAllFields().get(i); - Object[] values = field.value(hit); - if (values.length == 1 && (values[0] instanceof Number || values[0] instanceof String)) { - extractedValues[i] = Objects.toString(values[0]); - } else { - if (values.length == 0 && context.supportsRowsWithMissingValues) { - // if values is empty then it means it's a missing value - extractedValues[i] = NULL_VALUE; - } else { - // we are here if we have a missing value but the analysis does not support those - // or the value type is not supported (e.g. arrays, etc.) - extractedValues = null; - break; - } + String[] extractedValues = new String[organicFeatures.length + processedFeatures.length]; + int i = 0; + for (String organicFeature : organicFeatures) { + String extractedValue = extractNonProcessedValues(hit, organicFeature); + if (extractedValue == null) { + return new Row(null, hit, true); } + extractedValues[i++] = extractedValue; } - boolean isTraining = extractedValues == null ? 
false : trainTestSplitter.get().isTraining(extractedValues); + for (ProcessedField processedField : context.extractedFields.getProcessedFields()) { + String[] processedValues = extractProcessedValue(processedField, hit); + if (processedValues == null) { + return new Row(null, hit, true); + } + for (String processedValue : processedValues) { + extractedValues[i++] = processedValue; + } + } + boolean isTraining = trainTestSplitter.get().isTraining(extractedValues); return new Row(extractedValues, hit, isTraining); } @@ -241,7 +316,7 @@ private void clearScroll(String scrollId) { } public List getFieldNames() { - return context.extractedFields.getAllFields().stream().map(ExtractedField::getName).collect(Collectors.toList()); + return Stream.concat(Arrays.stream(organicFeatures), Arrays.stream(processedFeatures)).collect(Collectors.toList()); } public ExtractedFields getExtractedFields() { @@ -253,12 +328,12 @@ public DataSummary collectDataSummary() { SearchResponse searchResponse = executeSearchRequest(searchRequestBuilder); long rows = searchResponse.getHits().getTotalHits().value; LOGGER.debug("[{}] Data summary rows [{}]", context.jobId, rows); - return new DataSummary(rows, context.extractedFields.getAllFields().size()); + return new DataSummary(rows, organicFeatures.length + processedFeatures.length); } public void collectDataSummaryAsync(ActionListener dataSummaryActionListener) { SearchRequestBuilder searchRequestBuilder = buildDataSummarySearchRequestBuilder(); - final int numberOfFields = context.extractedFields.getAllFields().size(); + final int numberOfFields = organicFeatures.length + processedFeatures.length; ClientHelper.executeWithHeadersAsync(context.headers, ClientHelper.ML_ORIGIN, @@ -298,7 +373,11 @@ private QueryBuilder allExtractedFieldsExistQuery() { } public Set getCategoricalFields(DataFrameAnalysis analysis) { - return ExtractedFieldsDetector.getCategoricalFields(context.extractedFields, analysis); + return ExtractedFieldsDetector.getCategoricalOutputFields(context.extractedFields, analysis); + } + + private static boolean isValidValue(Object value) { + return value instanceof Number || value instanceof String; } public static class DataSummary { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java index daaa371dda3a1..1b03544d015e5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java @@ -13,27 +13,33 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.regex.Regex; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.mapper.BooleanFieldMapper; import org.elasticsearch.index.mapper.ObjectMapper; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import org.elasticsearch.xpack.core.ml.dataframe.analyses.FieldCardinalityConstraint; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import 
org.elasticsearch.xpack.core.ml.dataframe.analyses.RequiredField; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Types; import org.elasticsearch.xpack.core.ml.dataframe.explain.FieldSelection; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.NameResolver; import org.elasticsearch.xpack.ml.dataframe.DestinationIndex; import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; +import org.elasticsearch.xpack.ml.extractor.ProcessedField; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Comparator; +import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; @@ -60,7 +66,9 @@ public class ExtractedFieldsDetector { private final FieldCapabilitiesResponse fieldCapabilitiesResponse; private final Map cardinalitiesForFieldsWithConstraints; - ExtractedFieldsDetector(DataFrameAnalyticsConfig config, int docValueFieldsLimit, FieldCapabilitiesResponse fieldCapabilitiesResponse, + ExtractedFieldsDetector(DataFrameAnalyticsConfig config, + int docValueFieldsLimit, + FieldCapabilitiesResponse fieldCapabilitiesResponse, Map cardinalitiesForFieldsWithConstraints) { this.config = Objects.requireNonNull(config); this.docValueFieldsLimit = docValueFieldsLimit; @@ -69,23 +77,39 @@ public class ExtractedFieldsDetector { } public Tuple> detect() { + List processedFields = extractFeatureProcessors() + .stream() + .map(ProcessedField::new) + .collect(Collectors.toList()); TreeSet fieldSelection = new TreeSet<>(Comparator.comparing(FieldSelection::getName)); - Set fields = getIncludedFields(fieldSelection); + Set fields = getIncludedFields(fieldSelection, + processedFields.stream() + .map(ProcessedField::getInputFieldNames) + .flatMap(List::stream) + .collect(Collectors.toSet())); checkFieldsHaveCompatibleTypes(fields); checkRequiredFields(fields); checkFieldsWithCardinalityLimit(); - ExtractedFields extractedFields = detectExtractedFields(fields, fieldSelection); + ExtractedFields extractedFields = detectExtractedFields(fields, fieldSelection, processedFields); addIncludedFields(extractedFields, fieldSelection); + checkOutputFeatureUniqueness(processedFields, fields); + return Tuple.tuple(extractedFields, Collections.unmodifiableList(new ArrayList<>(fieldSelection))); } - private Set getIncludedFields(Set fieldSelection) { + private Set getIncludedFields(Set fieldSelection, Set requiredFieldsForProcessors) { Set fields = new TreeSet<>(fieldCapabilitiesResponse.get().keySet()); + validateFieldsRequireForProcessors(requiredFieldsForProcessors); fields.removeAll(IGNORE_FIELDS); removeFieldsUnderResultsField(fields); removeObjects(fields); applySourceFiltering(fields); + if (fields.containsAll(requiredFieldsForProcessors) == false) { + throw ExceptionsHelper.badRequestException( + "fields {} required by field_processors are not included in source filtering.", + Sets.difference(requiredFieldsForProcessors, fields)); + } FetchSourceContext analyzedFields = config.getAnalyzedFields(); // If the user has not explicitly included fields we'll include all compatible fields @@ -93,20 +117,63 @@ private Set getIncludedFields(Set fieldSelection) { removeFieldsWithIncompatibleTypes(fields, fieldSelection); } includeAndExcludeFields(fields, fieldSelection); + if 
(fields.containsAll(requiredFieldsForProcessors) == false) { + throw ExceptionsHelper.badRequestException( + "fields {} required by field_processors are not included in the analyzed_fields.", + Sets.difference(requiredFieldsForProcessors, fields)); + } return fields; } + private void validateFieldsRequireForProcessors(Set processorFields) { + Set fieldsForProcessor = new HashSet<>(processorFields); + removeFieldsUnderResultsField(fieldsForProcessor); + if (fieldsForProcessor.size() < processorFields.size()) { + throw ExceptionsHelper.badRequestException("fields contained in results field [{}] cannot be used in a feature_processor", + config.getDest().getResultsField()); + } + removeObjects(fieldsForProcessor); + if (fieldsForProcessor.size() < processorFields.size()) { + throw ExceptionsHelper.badRequestException("fields for feature_processors must not be objects"); + } + fieldsForProcessor.removeAll(IGNORE_FIELDS); + if (fieldsForProcessor.size() < processorFields.size()) { + throw ExceptionsHelper.badRequestException("the following fields cannot be used in feature_processors {}", IGNORE_FIELDS); + } + List fieldsMissingInMapping = processorFields.stream() + .filter(f -> fieldCapabilitiesResponse.get().containsKey(f) == false) + .collect(Collectors.toList()); + if (fieldsMissingInMapping.isEmpty() == false) { + throw ExceptionsHelper.badRequestException( + "the fields {} were not found in the field capabilities of the source indices [{}]. " + + "Fields must exist and be mapped to be used in feature_processors.", + fieldsMissingInMapping, + Strings.arrayToCommaDelimitedString(config.getSource().getIndex())); + } + List processedRequiredFields = config.getAnalysis() + .getRequiredFields() + .stream() + .map(RequiredField::getName) + .filter(processorFields::contains) + .collect(Collectors.toList()); + if (processedRequiredFields.isEmpty() == false) { + throw ExceptionsHelper.badRequestException( + "required analysis fields {} cannot be used in a feature_processor", + processedRequiredFields); + } + } + private void removeFieldsUnderResultsField(Set fields) { - String resultsField = config.getDest().getResultsField(); + final String resultsFieldPrefix = config.getDest().getResultsField() + "."; Iterator fieldsIterator = fields.iterator(); while (fieldsIterator.hasNext()) { String field = fieldsIterator.next(); - if (field.startsWith(resultsField + ".")) { + if (field.startsWith(resultsFieldPrefix)) { fieldsIterator.remove(); } } - fields.removeIf(field -> field.startsWith(resultsField + ".")); + fields.removeIf(field -> field.startsWith(resultsFieldPrefix)); } private void removeObjects(Set fields) { @@ -287,9 +354,23 @@ private void checkFieldsWithCardinalityLimit() { } } - private ExtractedFields detectExtractedFields(Set fields, Set fieldSelection) { - ExtractedFields extractedFields = ExtractedFields.build(fields, Collections.emptySet(), fieldCapabilitiesResponse, - cardinalitiesForFieldsWithConstraints); + private List extractFeatureProcessors() { + if (config.getAnalysis() instanceof Classification) { + return ((Classification)config.getAnalysis()).getFeatureProcessors(); + } else if (config.getAnalysis() instanceof Regression) { + return ((Regression)config.getAnalysis()).getFeatureProcessors(); + } + return Collections.emptyList(); + } + + private ExtractedFields detectExtractedFields(Set fields, + Set fieldSelection, + List processedFields) { + ExtractedFields extractedFields = ExtractedFields.build(fields, + Collections.emptySet(), + fieldCapabilitiesResponse, + 
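// NOTE: build(...) now also receives the processed fields so that later steps,
// e.g. the multi-field deduplication below, can see the feature_processor inputs: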
cardinalitiesForFieldsWithConstraints, + processedFields); boolean preferSource = extractedFields.getDocValueFields().size() > docValueFieldsLimit; extractedFields = deduplicateMultiFields(extractedFields, preferSource, fieldSelection); if (preferSource) { @@ -304,10 +385,15 @@ private ExtractedFields detectExtractedFields(Set fields, Set fieldSelection) { - Set requiredFields = config.getAnalysis().getRequiredFields().stream().map(RequiredField::getName) + Set requiredFields = config.getAnalysis() + .getRequiredFields() + .stream() + .map(RequiredField::getName) .collect(Collectors.toSet()); + Set processorInputFields = extractedFields.getProcessedFieldInputs(); Map nameOrParentToField = new LinkedHashMap<>(); for (ExtractedField currentField : extractedFields.getAllFields()) { String nameOrParent = currentField.isMultiField() ? currentField.getParentField() : currentField.getName(); @@ -315,15 +401,37 @@ private ExtractedFields deduplicateMultiFields(ExtractedFields extractedFields, if (existingField != null) { ExtractedField parent = currentField.isMultiField() ? existingField : currentField; ExtractedField multiField = currentField.isMultiField() ? currentField : existingField; + // If the required fields contain the parent or the multi-field while the processor input fields reference the other, + // that is an error: fields required by the analysis must not be pre-processed. + if ((requiredFields.contains(parent.getName()) && processorInputFields.contains(multiField.getName())) + || (requiredFields.contains(multiField.getName()) && processorInputFields.contains(parent.getName()))) { + throw ExceptionsHelper.badRequestException( + "feature_processors cannot be applied to required fields for analysis; multi-field [{}] parent [{}]", + multiField.getName(), + parent.getName()); + } + // If the processor input fields refer to BOTH the parent and the multi-field we cannot pick one to keep, so reject. + if (processorInputFields.contains(parent.getName()) && processorInputFields.contains(multiField.getName())) { + throw ExceptionsHelper.badRequestException( + "feature_processors refer to both multi-field [{}] and parent [{}]. 
Please only refer to one or the other", + multiField.getName(), + parent.getName()); + } nameOrParentToField.put(nameOrParent, - chooseMultiFieldOrParent(preferSource, requiredFields, parent, multiField, fieldSelection)); + chooseMultiFieldOrParent(preferSource, requiredFields, processorInputFields, parent, multiField, fieldSelection)); } } - return new ExtractedFields(new ArrayList<>(nameOrParentToField.values()), cardinalitiesForFieldsWithConstraints); + return new ExtractedFields(new ArrayList<>(nameOrParentToField.values()), + extractedFields.getProcessedFields(), + cardinalitiesForFieldsWithConstraints); } - private ExtractedField chooseMultiFieldOrParent(boolean preferSource, Set requiredFields, ExtractedField parent, - ExtractedField multiField, Set fieldSelection) { + private ExtractedField chooseMultiFieldOrParent(boolean preferSource, + Set requiredFields, + Set processorInputFields, + ExtractedField parent, + ExtractedField multiField, + Set fieldSelection) { // Check requirements first if (requiredFields.contains(parent.getName())) { addExcludedField(multiField.getName(), "[" + parent.getName() + "] is required instead", fieldSelection); @@ -333,6 +441,19 @@ private ExtractedField chooseMultiFieldOrParent(boolean preferSource, Set fieldSelection) { Set requiredFields = config.getAnalysis().getRequiredFields().stream().map(RequiredField::getName) .collect(Collectors.toSet()); - Set categoricalFields = getCategoricalFields(extractedFields, config.getAnalysis()); + Set categoricalFields = getCategoricalInputFields(extractedFields, config.getAnalysis()); for (ExtractedField includedField : extractedFields.getAllFields()) { FieldSelection.FeatureType featureType = categoricalFields.contains(includedField.getName()) ? FieldSelection.FeatureType.CATEGORICAL : FieldSelection.FeatureType.NUMERICAL; @@ -402,12 +527,62 @@ private void addIncludedFields(ExtractedFields extractedFields, Set getCategoricalFields(ExtractedFields extractedFields, DataFrameAnalysis analysis) { + static void checkOutputFeatureUniqueness(List processedFields, Set selectedFields) { + Set processInputs = processedFields.stream() + .map(ProcessedField::getInputFieldNames) + .flatMap(List::stream) + .collect(Collectors.toSet()); + // All analysis fields that we include that are NOT processed + // This indicates that they are sent as is + Set organicFields = Sets.difference(selectedFields, processInputs); + + Set processedFeatures = new HashSet<>(); + Set duplicatedFields = new HashSet<>(); + for (ProcessedField processedField : processedFields) { + for (String output : processedField.getOutputFieldNames()) { + if (processedFeatures.add(output) == false) { + duplicatedFields.add(output); + } + } + } + if (duplicatedFields.isEmpty() == false) { + throw ExceptionsHelper.badRequestException( + "feature_processors must define unique output field names; duplicate fields {}", + duplicatedFields); + } + Set duplicateOrganicAndProcessed = Sets.intersection(organicFields, processedFeatures); + if (duplicateOrganicAndProcessed.isEmpty() == false) { + throw ExceptionsHelper.badRequestException( + "feature_processors output fields must not include non-processed analysis fields; duplicate fields {}", + duplicateOrganicAndProcessed); + } + } + + static Set getCategoricalInputFields(ExtractedFields extractedFields, DataFrameAnalysis analysis) { return extractedFields.getAllFields().stream() .filter(extractedField -> analysis.getAllowedCategoricalTypes(extractedField.getName()) .containsAll(extractedField.getTypes())) 
.map(ExtractedField::getName) - .collect(Collectors.toUnmodifiableSet()); + .collect(Collectors.toSet()); + } + + static Set getCategoricalOutputFields(ExtractedFields extractedFields, DataFrameAnalysis analysis) { + Set processInputFields = extractedFields.getProcessedFieldInputs(); + Set categoricalFields = extractedFields.getAllFields().stream() + .filter(extractedField -> analysis.getAllowedCategoricalTypes(extractedField.getName()) + .containsAll(extractedField.getTypes())) + .map(ExtractedField::getName) + .filter(name -> processInputFields.contains(name) == false) + .collect(Collectors.toSet()); + + extractedFields.getProcessedFields().forEach(processedField -> + processedField.getOutputFieldNames().forEach(outputField -> { + if (analysis.getAllowedCategoricalTypes(outputField).containsAll(processedField.getOutputFieldType(outputField))) { + categoricalFields.add(outputField); + } + }) + ); + return Collections.unmodifiableSet(categoricalFields); } private static boolean isBoolean(Set types) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index 13358461e4393..97cd053b13f27 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -178,7 +178,7 @@ private void processData(DataFrameAnalyticsTask task, ProcessContext processCont AnalyticsProcess process = processContext.process.get(); AnalyticsResultProcessor resultProcessor = processContext.resultProcessor.get(); try { - writeHeaderRecord(dataExtractor, process); + writeHeaderRecord(dataExtractor, process, task); writeDataRows(dataExtractor, process, task); process.writeEndOfDataMessage(); process.flushStream(); @@ -268,8 +268,11 @@ private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProces } } - private void writeHeaderRecord(DataFrameDataExtractor dataExtractor, AnalyticsProcess process) throws IOException { + private void writeHeaderRecord(DataFrameDataExtractor dataExtractor, + AnalyticsProcess process, + DataFrameAnalyticsTask task) throws IOException { List fieldNames = dataExtractor.getFieldNames(); + LOGGER.debug(() -> new ParameterizedMessage("[{}] header row fields {}", task.getParams().getId(), fieldNames)); // We add 2 extra fields, both named dot: // - the document hash diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java index 725627ab21cf8..213fa1d369ffb 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java @@ -9,6 +9,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.LatchedActionListener; @@ -22,6 +23,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import 
org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.security.user.XPackUser; import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk; @@ -34,6 +36,7 @@ import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; import java.time.Instant; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; @@ -191,8 +194,21 @@ private CountDownLatch storeTrainedModelMetadata(TrainedModelConfig trainedModel return latch; } + private long customProcessorSize() { + List preProcessors = new ArrayList<>(); + if (analytics.getAnalysis() instanceof Classification) { + preProcessors = ((Classification) analytics.getAnalysis()).getFeatureProcessors(); + } else if (analytics.getAnalysis() instanceof Regression) { + preProcessors = ((Regression) analytics.getAnalysis()).getFeatureProcessors(); + } + return preProcessors.stream().mapToLong(PreProcessor::ramBytesUsed).sum() + + RamUsageEstimator.NUM_BYTES_OBJECT_REF * preProcessors.size(); + } + private TrainedModelConfig createTrainedModelConfig(ModelSizeInfo modelSize) { Instant createTime = Instant.now(); + // The native process does not provide estimates for the custom feature_processor objects + long customProcessorSize = customProcessorSize(); String modelId = analytics.getId() + "-" + createTime.toEpochMilli(); currentModelId.set(modelId); List fieldNames = extractedFields.getAllFields(); @@ -214,7 +230,7 @@ private TrainedModelConfig createTrainedModelConfig(ModelSizeInfo modelSize) { .setDescription(analytics.getDescription()) .setMetadata(Collections.singletonMap("analytics_config", XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true))) - .setEstimatedHeapMemory(modelSize.ramBytesUsed()) + .setEstimatedHeapMemory(modelSize.ramBytesUsed() + customProcessorSize) .setEstimatedOperations(modelSize.numOperations()) .setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable)) .setLicenseLevel(License.OperationMode.PLATINUM.description()) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java index ab314a5d21851..3853ea2629af7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java @@ -12,7 +12,7 @@ import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.ml.utils.MlStrings; -import java.util.Collection; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; @@ -21,27 +21,39 @@ import java.util.stream.Collectors; /** - * The fields the datafeed has to extract + * The fields the data[feed|frame] has to extract */ public class ExtractedFields { private final List allFields; private final List docValueFields; + private final List processedFields; private final String[] sourceFields; private final Map cardinalitiesForFieldsWithConstraints; - public ExtractedFields(List allFields, Map cardinalitiesForFieldsWithConstraints) { - this.allFields = Collections.unmodifiableList(allFields); + public ExtractedFields(List allFields, + List 
processedFields, + Map cardinalitiesForFieldsWithConstraints) { + this.allFields = new ArrayList<>(allFields); this.docValueFields = filterFields(ExtractedField.Method.DOC_VALUE, allFields); this.sourceFields = filterFields(ExtractedField.Method.SOURCE, allFields).stream().map(ExtractedField::getSearchField) .toArray(String[]::new); this.cardinalitiesForFieldsWithConstraints = Collections.unmodifiableMap(cardinalitiesForFieldsWithConstraints); + this.processedFields = processedFields == null ? Collections.emptyList() : processedFields; + } + + public List getProcessedFields() { + return processedFields; } public List getAllFields() { return allFields; } + public Set getProcessedFieldInputs() { + return processedFields.stream().map(ProcessedField::getInputFieldNames).flatMap(List::stream).collect(Collectors.toSet()); + } + public String[] getSourceFields() { return sourceFields; } @@ -58,11 +70,15 @@ private static List filterFields(ExtractedField.Method method, L return fields.stream().filter(field -> field.getMethod() == method).collect(Collectors.toList()); } - public static ExtractedFields build(Collection allFields, Set scriptFields, + public static ExtractedFields build(Set allFields, + Set scriptFields, FieldCapabilitiesResponse fieldsCapabilities, - Map cardinalitiesForFieldsWithConstraints) { + Map cardinalitiesForFieldsWithConstraints, + List processedFields) { ExtractionMethodDetector extractionMethodDetector = new ExtractionMethodDetector(scriptFields, fieldsCapabilities); - return new ExtractedFields(allFields.stream().map(field -> extractionMethodDetector.detect(field)).collect(Collectors.toList()), + return new ExtractedFields( + allFields.stream().map(extractionMethodDetector::detect).collect(Collectors.toList()), + processedFields, cardinalitiesForFieldsWithConstraints); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ProcessedField.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ProcessedField.java new file mode 100644 index 0000000000000..50f13f9408658 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ProcessedField.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */
+package org.elasticsearch.xpack.ml.extractor;
+
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.function.Function;
+
+/**
+ * Wraps a user-supplied {@link PreProcessor} so that its input fields can be read
+ * from a {@link SearchHit} and its output features handed to the analysis.
+ */
+public class ProcessedField {
+    private final PreProcessor preProcessor;
+
+    public ProcessedField(PreProcessor processor) {
+        this.preProcessor = Objects.requireNonNull(processor);
+    }
+
+    public List<String> getInputFieldNames() {
+        return preProcessor.inputFields();
+    }
+
+    public List<String> getOutputFieldNames() {
+        return preProcessor.outputFields();
+    }
+
+    public Set<String> getOutputFieldType(String outputField) {
+        return Collections.singleton(preProcessor.getOutputFieldType(outputField));
+    }
+
+    public Object[] value(SearchHit hit, Function<String, ExtractedField> fieldExtractor) {
+        Map<String, Object> inputs = new HashMap<>(preProcessor.inputFields().size(), 1.0f);
+        for (String field : preProcessor.inputFields()) {
+            ExtractedField extractedField = fieldExtractor.apply(field);
+            if (extractedField == null) {
+                // No extractor exists for a declared input: the processor cannot be evaluated at all.
+                return new Object[0];
+            }
+            Object[] values = extractedField.value(hit);
+            if (values == null || values.length == 0) {
+                // The input is missing in this document; the processor decides how to handle absent inputs.
+                continue;
+            }
+            final Object value = values[0];
+            // Only single-valued String or Number inputs are fed to the processor.
+            if (values.length == 1 && (value instanceof String || value instanceof Number)) {
+                inputs.put(field, value);
+            }
+        }
+        preProcessor.process(inputs);
+        // The processor writes its outputs into the same map; read them back in declared order.
+        return preProcessor.outputFields().stream().map(inputs::get).toArray();
+    }
+
+    public String getProcessorName() {
+        return preProcessor.getName();
+    }
+
+}
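Because `ProcessedField` is new in this patch, a short usage sketch may help. The hit and the extractor lookup below are illustrative stand-ins; the `OneHotEncoding` map goes from input value to output feature name, matching the shape the tests later in this patch use.

```java
// Illustrative only: encode the keyword field "animal" into two indicator features.
Map<String, String> hotMap = new HashMap<>();
hotMap.put("cat", "animal_cat");
hotMap.put("dog", "animal_dog");
ProcessedField processed = new ProcessedField(new OneHotEncoding("animal", hotMap, true));

// For a hit whose "animal" field extracts to "cat", the returned array is aligned
// with getOutputFieldNames(); the matching feature is 1, the other 0.
// `hit` and `extractedFieldsByName` are stand-ins for a real SearchHit and lookup.
Object[] features = processed.value(hit, extractedFieldsByName::get);
```
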
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlSingleNodeTestCase.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlSingleNodeTestCase.java
index 25e75040fa81f..68381fb113f67 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlSingleNodeTestCase.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlSingleNodeTestCase.java
@@ -128,4 +128,11 @@ protected <T> T blockingCall(Consumer<ActionListener<T>> function) throws Exception {
         return responseHolder.get();
     }
 
+    public static void assertNoException(AtomicReference<Exception> error) throws Exception {
+        if (error.get() == null) {
+            return;
+        }
+        throw error.get();
+    }
+
 }
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java
index a20fff05b3f9b..5d910c5fca7a1 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java
@@ -15,8 +15,10 @@
 import org.elasticsearch.action.search.ShardSearchFailure;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.rest.RestStatus;
@@ -27,10 +29,13 @@
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
 import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitterFactory;
 import org.elasticsearch.xpack.ml.extractor.DocValueField;
 import org.elasticsearch.xpack.ml.extractor.ExtractedField;
 import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
+import org.elasticsearch.xpack.ml.extractor.ProcessedField;
 import org.elasticsearch.xpack.ml.extractor.SourceField;
 import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
 import org.junit.Before;
@@ -45,8 +50,10 @@
 import java.util.Map;
 import java.util.Optional;
 import java.util.Queue;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 
+import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.empty;
@@ -83,7 +90,9 @@ public void setUpTests() {
         query = QueryBuilders.matchAllQuery();
         extractedFields = new ExtractedFields(Arrays.asList(
             new DocValueField("field_1", Collections.singleton("keyword")),
-            new DocValueField("field_2", Collections.singleton("keyword"))), Collections.emptyMap());
+            new DocValueField("field_2", Collections.singleton("keyword"))),
+            Collections.emptyList(),
+            Collections.emptyMap());
 
         scrollSize = 1000;
         headers = Collections.emptyMap();
@@ -304,7 +313,9 @@ public void testIncludeSourceIsFalseAndAtLeastOneSourceField() throws IOException {
         // Explicit cast of ExtractedField args necessary for Eclipse due to https://bugs.eclipse.org/bugs/show_bug.cgi?id=530915
         extractedFields = new ExtractedFields(Arrays.asList(
             (ExtractedField) new DocValueField("field_1", Collections.singleton("keyword")),
-            (ExtractedField) new SourceField("field_2", Collections.singleton("text"))), Collections.emptyMap());
+            (ExtractedField) new SourceField("field_2", Collections.singleton("text"))),
+            Collections.emptyList(),
+            Collections.emptyMap());
 
         TestExtractor dataExtractor = createExtractor(false, false);
 
@@ -446,7 +457,9 @@ public void testGetCategoricalFields() {
             (ExtractedField) new DocValueField("field_integer", Collections.singleton("integer")),
             (ExtractedField) new DocValueField("field_long", Collections.singleton("long")),
             (ExtractedField) new DocValueField("field_keyword", Collections.singleton("keyword")),
-            (ExtractedField) new SourceField("field_text", Collections.singleton("text"))), Collections.emptyMap());
+            (ExtractedField) new SourceField("field_text", Collections.singleton("text"))),
+            Collections.emptyList(),
+            Collections.emptyMap());
 
         TestExtractor dataExtractor = createExtractor(true, true);
 
         assertThat(dataExtractor.getCategoricalFields(OutlierDetectionTests.createRandom()), empty());
@@ -466,12 +479,100 @@ public void testGetCategoricalFields() {
             containsInAnyOrder("field_keyword", "field_text", "field_boolean"));
     }
 
+    public void testGetFieldNames_GivenProcessedFeatures() {
+        // Explicit cast of ExtractedField args necessary for Eclipse due to https://bugs.eclipse.org/bugs/show_bug.cgi?id=530915
+        extractedFields = new ExtractedFields(Arrays.asList(
+            (ExtractedField) new DocValueField("field_boolean", Collections.singleton("boolean")),
+            (ExtractedField) new DocValueField("field_float", Collections.singleton("float")),
+            (ExtractedField) new DocValueField("field_double", Collections.singleton("double")),
+            (ExtractedField) new 
DocValueField("field_byte", Collections.singleton("byte")),
+            (ExtractedField) new DocValueField("field_short", Collections.singleton("short")),
+            (ExtractedField) new DocValueField("field_integer", Collections.singleton("integer")),
+            (ExtractedField) new DocValueField("field_long", Collections.singleton("long")),
+            (ExtractedField) new DocValueField("field_keyword", Collections.singleton("keyword")),
+            (ExtractedField) new SourceField("field_text", Collections.singleton("text"))),
+            Arrays.asList(
+                new ProcessedField(new CategoricalPreProcessor("field_long", "animal")),
+                buildProcessedField("field_short", "field_1", "field_2")
+            ),
+            Collections.emptyMap());
+        TestExtractor dataExtractor = createExtractor(true, true);
+
+        assertThat(dataExtractor.getCategoricalFields(new Regression("field_double")),
+            containsInAnyOrder("field_keyword", "field_text", "animal"));
+
+        List<String> fieldNames = dataExtractor.getFieldNames();
+        assertThat(fieldNames, containsInAnyOrder(
+            "animal",
+            "field_1",
+            "field_2",
+            "field_boolean",
+            "field_float",
+            "field_double",
+            "field_byte",
+            "field_integer",
+            "field_keyword",
+            "field_text"));
+        assertThat(dataExtractor.getFieldNames(), contains(fieldNames.toArray(String[]::new)));
+    }
+
+    public void testExtractionWithProcessedFeatures() throws IOException {
+        extractedFields = new ExtractedFields(Arrays.asList(
+            new DocValueField("field_1", Collections.singleton("keyword")),
+            new DocValueField("field_2", Collections.singleton("keyword"))),
+            Arrays.asList(
+                new ProcessedField(new CategoricalPreProcessor("field_1", "animal")),
+                new ProcessedField(new OneHotEncoding("field_1",
+                    Arrays.asList("11", "12")
+                        .stream()
+                        .collect(Collectors.toMap(Function.identity(), s -> s.equals("11") ? "field_11" : "field_12")),
+                    true))
+            ),
+            Collections.emptyMap());
+
+        TestExtractor dataExtractor = createExtractor(true, true);
+
+        // First and only batch
+        SearchResponse response1 = createSearchResponse(Arrays.asList(1_1, null, 1_3), Arrays.asList(2_1, 2_2, 2_3));
+        dataExtractor.setNextResponse(response1);
+
+        // Empty
+        SearchResponse lastAndEmptyResponse = createEmptySearchResponse();
+        dataExtractor.setNextResponse(lastAndEmptyResponse);
+
+        assertThat(dataExtractor.hasNext(), is(true));
+
+        // First batch
+        Optional<List<DataFrameDataExtractor.Row>> rows = dataExtractor.next();
+        assertThat(rows.isPresent(), is(true));
+        assertThat(rows.get().size(), equalTo(3));
+
+        assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"21", "dog", "1", "0"}));
+        assertThat(rows.get().get(1).getValues(),
+            equalTo(new String[] {"22", "dog", DataFrameDataExtractor.NULL_VALUE, DataFrameDataExtractor.NULL_VALUE}));
+        assertThat(rows.get().get(2).getValues(), equalTo(new String[] {"23", "dog", "0", "0"}));
+
+        assertThat(rows.get().get(0).shouldSkip(), is(false));
+        assertThat(rows.get().get(1).shouldSkip(), is(false));
+        assertThat(rows.get().get(2).shouldSkip(), is(false));
+    }
+
     private TestExtractor createExtractor(boolean includeSource, boolean supportsRowsWithMissingValues) {
         DataFrameDataExtractorContext context = new DataFrameDataExtractorContext(JOB_ID, extractedFields, indices, query, scrollSize,
             headers, includeSource, supportsRowsWithMissingValues, trainTestSplitterFactory);
         return new TestExtractor(client, context);
     }
 
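The hot map handed to `OneHotEncoding` above goes from input value to output feature name, which is what lines the row values up with `field_11`/`field_12`. A standalone sketch of that behaviour, with illustrative values mirroring this test; when the input is absent the encoder writes nothing, so the extractor falls back to `DataFrameDataExtractor.NULL_VALUE`, as asserted for the second row:

```java
// The encoder reads its input from the map and writes integer indicators back into it.
Map<String, String> hotMap = new HashMap<>();
hotMap.put("11", "field_11");
hotMap.put("12", "field_12");
PreProcessor oneHot = new OneHotEncoding("field_1", hotMap, true);

Map<String, Object> row = new HashMap<>();
row.put("field_1", "11");
oneHot.process(row); // row now also maps field_11 -> 1 and field_12 -> 0
```

+    private static ProcessedField buildProcessedField(String inputField, String... outputFields) {
+        return new ProcessedField(buildPreProcessor(inputField, outputFields));
+    }
+
+    private static PreProcessor buildPreProcessor(String inputField, String... 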
outputFields) { + return new OneHotEncoding(inputField, + Arrays.stream(outputFields).collect(Collectors.toMap((s) -> randomAlphaOfLength(10), Function.identity())), + true); + } + private SearchResponse createSearchResponse(List field1Values, List field2Values) { assertThat(field1Values.size(), equalTo(field2Values.size())); SearchResponse searchResponse = mock(SearchResponse.class); @@ -545,4 +646,70 @@ protected SearchResponse executeSearchScrollRequest(String scrollId) { return searchResponse; } } + + private static class CategoricalPreProcessor implements PreProcessor { + + private final List inputFields; + private final List outputFields; + + CategoricalPreProcessor(String inputField, String outputField) { + this.inputFields = Arrays.asList(inputField); + this.outputFields = Arrays.asList(outputField); + } + + @Override + public List inputFields() { + return inputFields; + } + + @Override + public List outputFields() { + return outputFields; + } + + @Override + public void process(Map fields) { + fields.put(outputFields.get(0), "dog"); + } + + @Override + public Map reverseLookup() { + return null; + } + + @Override + public boolean isCustom() { + return true; + } + + @Override + public String getOutputFieldType(String outputField) { + return "text"; + } + + @Override + public long ramBytesUsed() { + return 0; + } + + @Override + public String getWriteableName() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + + } + + @Override + public String getName() { + return null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return null; + } + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java index c0b5f19803f17..744452439acc8 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java @@ -15,10 +15,13 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.dataframe.explain.FieldSelection; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import org.elasticsearch.xpack.ml.test.SearchHitBuilder; @@ -30,11 +33,14 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.function.Function; import java.util.stream.Collectors; import static org.hamcrest.Matchers.arrayContaining; +import static org.hamcrest.Matchers.arrayContainingInAnyOrder; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; +import static 
org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; @@ -929,12 +935,23 @@ public void testDetect_GivenAnalyzedFieldIncludesObjectField() { assertThat(e.getMessage(), equalTo("analyzed_fields must not include or exclude object fields: [object_field]")); } + private static FieldCapabilitiesResponse simpleFieldResponse() { + return new MockFieldCapsResponseBuilder() + .addAggregatableField("field_11", "float") + .addNonAggregatableField("field_21", "float") + .addAggregatableField("field_21.child", "float") + .addNonAggregatableField("field_31", "float") + .addAggregatableField("field_31.child", "float") + .addNonAggregatableField("object_field", "object") + .build(); + } + public void testDetect_GivenAnalyzedFieldExcludesObjectField() { FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField("float_field", "float") .addNonAggregatableField("object_field", "object").build(); - analyzedFields = new FetchSourceContext(true, null, new String[] { "object_field" }); + analyzedFields = new FetchSourceContext(true, null, new String[]{"object_field"}); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap()); @@ -943,6 +960,177 @@ public void testDetect_GivenAnalyzedFieldExcludesObjectField() { assertThat(e.getMessage(), equalTo("analyzed_fields must not include or exclude object fields: [object_field]")); } + public void testDetect_givenFeatureProcessorsFailures_ResultsField() { + FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse(); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("ml.result", "foo"))), + 100, + fieldCapabilities, + Collections.emptyMap()); + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); + assertThat(ex.getMessage(), + containsString("fields contained in results field [ml] cannot be used in a feature_processor")); + } + + public void testDetect_givenFeatureProcessorsFailures_Objects() { + FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse(); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("object_field", "foo"))), + 100, + fieldCapabilities, + Collections.emptyMap()); + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); + assertThat(ex.getMessage(), + containsString("fields for feature_processors must not be objects")); + } + + public void testDetect_givenFeatureProcessorsFailures_ReservedFields() { + FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse(); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("_id", "foo"))), + 100, + fieldCapabilities, + Collections.emptyMap()); + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); + assertThat(ex.getMessage(), + containsString("the following fields cannot be used in feature_processors")); + } + + public void testDetect_givenFeatureProcessorsFailures_MissingFieldFromIndex() { + FieldCapabilitiesResponse fieldCapabilities = 
simpleFieldResponse(); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("bar", "foo"))), + 100, + fieldCapabilities, + Collections.emptyMap()); + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); + assertThat(ex.getMessage(), + containsString("the fields [bar] were not found in the field capabilities of the source indices")); + } + + public void testDetect_givenFeatureProcessorsFailures_UsingRequiredField() { + FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse(); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("field_31", "foo"))), + 100, + fieldCapabilities, + Collections.emptyMap()); + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); + assertThat(ex.getMessage(), + containsString("required analysis fields [field_31] cannot be used in a feature_processor")); + } + + public void testDetect_givenFeatureProcessorsFailures_BadSourceFiltering() { + FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse(); + sourceFiltering = new FetchSourceContext(true, null, new String[]{"field_1*"}); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("field_11", "foo"))), + 100, + fieldCapabilities, + Collections.emptyMap()); + + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); + assertThat(ex.getMessage(), + containsString("fields [field_11] required by field_processors are not included in source filtering.")); + } + + public void testDetect_givenFeatureProcessorsFailures_MissingAnalyzedField() { + FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse(); + analyzedFields = new FetchSourceContext(true, null, new String[]{"field_1*"}); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("field_11", "foo"))), + 100, + fieldCapabilities, + Collections.emptyMap()); + + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); + assertThat(ex.getMessage(), + containsString("fields [field_11] required by field_processors are not included in the analyzed_fields")); + } + + public void testDetect_givenFeatureProcessorsFailures_RequiredMultiFields() { + FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse(); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + buildRegressionConfig("field_31", Arrays.asList(buildPreProcessor("field_31.child", "foo"))), + 100, + fieldCapabilities, + Collections.emptyMap()); + + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); + assertThat(ex.getMessage(), + containsString("feature_processors cannot be applied to required fields for analysis; ")); + + extractedFieldsDetector = new ExtractedFieldsDetector( + buildRegressionConfig("field_31.child", Arrays.asList(buildPreProcessor("field_31", "foo"))), + 100, + fieldCapabilities, + Collections.emptyMap()); + + ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); + assertThat(ex.getMessage(), + 
containsString("feature_processors cannot be applied to required fields for analysis; ")); + } + + public void testDetect_givenFeatureProcessorsFailures_BothMultiFields() { + FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse(); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + buildRegressionConfig("field_31", + Arrays.asList( + buildPreProcessor("field_21", "foo"), + buildPreProcessor("field_21.child", "bar") + )), + 100, + fieldCapabilities, + Collections.emptyMap()); + + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); + assertThat(ex.getMessage(), + containsString("feature_processors refer to both multi-field ")); + } + + public void testDetect_givenFeatureProcessorsFailures_DuplicateOutputFields() { + FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse(); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + buildRegressionConfig("field_31", + Arrays.asList( + buildPreProcessor("field_11", "foo"), + buildPreProcessor("field_21", "foo") + )), + 100, + fieldCapabilities, + Collections.emptyMap()); + + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); + assertThat(ex.getMessage(), + containsString("feature_processors must define unique output field names; duplicate fields [foo]")); + } + + public void testDetect_withFeatureProcessors() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("field_11", "float") + .addAggregatableField("field_21", "float") + .addNonAggregatableField("field_31", "float") + .addAggregatableField("field_31.child", "float") + .addNonAggregatableField("object_field", "object") + .build(); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + buildRegressionConfig("field_11", + Arrays.asList(buildPreProcessor("field_31", "foo", "bar"))), + 100, + fieldCapabilities, + Collections.emptyMap()); + + ExtractedFields extracted = extractedFieldsDetector.detect().v1(); + + assertThat(extracted.getProcessedFieldInputs(), containsInAnyOrder("field_31")); + assertThat(extracted.getAllFields().stream().map(ExtractedField::getName).collect(Collectors.toSet()), + containsInAnyOrder("field_11", "field_21", "field_31")); + assertThat(extracted.getSourceFields(), arrayContainingInAnyOrder("field_31")); + assertThat(extracted.getDocValueFields().stream().map(ExtractedField::getName).collect(Collectors.toSet()), + containsInAnyOrder("field_21", "field_11")); + assertThat(extracted.getProcessedFields(), hasSize(1)); + } + private DataFrameAnalyticsConfig buildOutlierDetectionConfig() { return new DataFrameAnalyticsConfig.Builder() .setId("foo") @@ -954,24 +1142,41 @@ private DataFrameAnalyticsConfig buildOutlierDetectionConfig() { } private DataFrameAnalyticsConfig buildRegressionConfig(String dependentVariable) { + return buildRegressionConfig(dependentVariable, Collections.emptyList()); + } + + private DataFrameAnalyticsConfig buildClassificationConfig(String dependentVariable) { return new DataFrameAnalyticsConfig.Builder() .setId("foo") .setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null, sourceFiltering)) .setDest(new DataFrameAnalyticsDest(DEST_INDEX, RESULTS_FIELD)) - .setAnalyzedFields(analyzedFields) - .setAnalysis(new Regression(dependentVariable)) + .setAnalysis(new Classification(dependentVariable)) .build(); } - private DataFrameAnalyticsConfig 
buildClassificationConfig(String dependentVariable) {
+    private DataFrameAnalyticsConfig buildRegressionConfig(String dependentVariable, List<PreProcessor> featureProcessors) {
         return new DataFrameAnalyticsConfig.Builder()
             .setId("foo")
             .setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null, sourceFiltering))
             .setDest(new DataFrameAnalyticsDest(DEST_INDEX, RESULTS_FIELD))
-            .setAnalysis(new Classification(dependentVariable))
+            .setAnalyzedFields(analyzedFields)
+            .setAnalysis(new Regression(dependentVariable,
+                BoostedTreeParams.builder().build(),
+                null,
+                null,
+                null,
+                null,
+                null,
+                featureProcessors))
             .build();
     }
 
+    private static PreProcessor buildPreProcessor(String inputField, String... outputFields) {
+        return new OneHotEncoding(inputField,
+            Arrays.stream(outputFields).collect(Collectors.toMap((s) -> randomAlphaOfLength(10), Function.identity())),
+            true);
+    }
+
     /**
      * We assert each field individually to get useful error messages in case of failure
      */
@@ -987,6 +1192,23 @@ private static void assertFieldSelectionContains(List actual, Fi
         }
     }
 
+    public void testDetect_givenFeatureProcessorsFailures_DuplicateOutputFieldsWithUnProcessedField() {
+        FieldCapabilitiesResponse fieldCapabilities = simpleFieldResponse();
+        ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
+            buildRegressionConfig("field_31",
+                Arrays.asList(
+                    buildPreProcessor("field_11", "field_21")
+                )),
+            100,
+            fieldCapabilities,
+            Collections.emptyMap());
+
+        ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
+        assertThat(ex.getMessage(),
+            containsString(
+                "feature_processors output fields must not include non-processed analysis fields; duplicate fields [field_21]"));
+    }
+
     private static class MockFieldCapsResponseBuilder {
 
         private final Map<String, Map<String, FieldCapabilities>> fieldCaps = new HashMap<>();
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunnerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunnerTests.java
index da01a75e20cc7..ae21775da7a4e 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunnerTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunnerTests.java
@@ -80,6 +80,7 @@ public void setupTests() {
     public void testInferTestDocs() {
         ExtractedFields extractedFields = new ExtractedFields(
             Collections.singletonList(new SourceField("key", Collections.singleton("integer"))),
+            Collections.emptyList(),
             Collections.emptyMap());
 
         Map<String, Object> doc1 = new HashMap<>();
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfigTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfigTests.java
index a4db8de032af5..976b03a6a0d9a 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfigTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfigTests.java
@@ -63,7 +63,9 @@ public void setUpConfigParams() {
     public void testToXContent_GivenOutlierDetection() throws IOException {
         ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
             new DocValueField("field_1", Collections.singleton("double")),
-            new DocValueField("field_2", 
Collections.singleton("float"))), + Collections.emptyList(), + Collections.emptyMap()); DataFrameAnalysis analysis = new OutlierDetection.Builder().build(); AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields); @@ -82,7 +84,9 @@ public void testToXContent_GivenRegression() throws IOException { ExtractedFields extractedFields = new ExtractedFields(Arrays.asList( new DocValueField("field_1", Collections.singleton("double")), new DocValueField("field_2", Collections.singleton("float")), - new DocValueField("test_dep_var", Collections.singleton("keyword"))), Collections.emptyMap()); + new DocValueField("test_dep_var", Collections.singleton("keyword"))), + Collections.emptyList(), + Collections.emptyMap()); DataFrameAnalysis analysis = new Regression("test_dep_var"); AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields); @@ -103,7 +107,9 @@ public void testToXContent_GivenClassificationAndDepVarIsKeyword() throws IOExce ExtractedFields extractedFields = new ExtractedFields(Arrays.asList( new DocValueField("field_1", Collections.singleton("double")), new DocValueField("field_2", Collections.singleton("float")), - new DocValueField("test_dep_var", Collections.singleton("keyword"))), Collections.singletonMap("test_dep_var", 5L)); + new DocValueField("test_dep_var", Collections.singleton("keyword"))), + Collections.emptyList(), + Collections.singletonMap("test_dep_var", 5L)); DataFrameAnalysis analysis = new Classification("test_dep_var"); AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields); @@ -126,7 +132,9 @@ public void testToXContent_GivenClassificationAndDepVarIsInteger() throws IOExce ExtractedFields extractedFields = new ExtractedFields(Arrays.asList( new DocValueField("field_1", Collections.singleton("double")), new DocValueField("field_2", Collections.singleton("float")), - new DocValueField("test_dep_var", Collections.singleton("integer"))), Collections.singletonMap("test_dep_var", 8L)); + new DocValueField("test_dep_var", Collections.singleton("integer"))), + Collections.emptyList(), + Collections.singletonMap("test_dep_var", 8L)); DataFrameAnalysis analysis = new Classification("test_dep_var"); AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java index 7a9144ccbc0d3..803148e188d2f 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java @@ -105,7 +105,9 @@ public void setUpMocks() { OutlierDetectionTests.createRandom()).build(); dataExtractor = mock(DataFrameDataExtractor.class); when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(NUM_ROWS, NUM_COLS)); - when(dataExtractor.getExtractedFields()).thenReturn(new ExtractedFields(Collections.emptyList(), Collections.emptyMap())); + when(dataExtractor.getExtractedFields()).thenReturn(new ExtractedFields(Collections.emptyList(), + Collections.emptyList(), + Collections.emptyMap())); dataExtractorFactory = mock(DataFrameDataExtractorFactory.class); when(dataExtractorFactory.newExtractor(anyBoolean())).thenReturn(dataExtractor); 
when(dataExtractorFactory.getExtractedFields()).thenReturn(mock(ExtractedFields.class)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index 1e404360ae738..a77637bc59bd1 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -314,6 +314,6 @@ private AnalyticsResultProcessor createResultProcessor(List fiel trainedModelProvider, auditor, statsPersister, - new ExtractedFields(fieldNames, Collections.emptyMap())); + new ExtractedFields(fieldNames, Collections.emptyList(), Collections.emptyMap())); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java index ee01e297907d6..5c450df29b360 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java @@ -144,7 +144,7 @@ private ChunkedTrainedModelPersister createChunkedTrainedModelPersister(List{}, - new ExtractedFields(fieldNames, Collections.emptyMap())); + new ExtractedFields(fieldNames, Collections.emptyList(), Collections.emptyMap())); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ExtractedFieldsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ExtractedFieldsTests.java index a51eafd1d8b3d..d5c27f781036c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ExtractedFieldsTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ExtractedFieldsTests.java @@ -16,6 +16,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Map; +import java.util.TreeSet; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; @@ -31,8 +32,10 @@ public void testAllTypesOfFields() { ExtractedField scriptField2 = new ScriptField("scripted2"); ExtractedField sourceField1 = new SourceField("src1", Collections.singleton("text")); ExtractedField sourceField2 = new SourceField("src2", Collections.singleton("text")); - ExtractedFields extractedFields = new ExtractedFields(Arrays.asList( - docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2), Collections.emptyMap()); + ExtractedFields extractedFields = new ExtractedFields( + Arrays.asList(docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2), + Collections.emptyList(), + Collections.emptyMap()); assertThat(extractedFields.getAllFields().size(), equalTo(6)); assertThat(extractedFields.getDocValueFields().stream().map(ExtractedField::getName).toArray(String[]::new), @@ -53,8 +56,11 @@ public void testBuildGivenMixtureOfTypes() { when(fieldCapabilitiesResponse.getField("value")).thenReturn(valueCaps); when(fieldCapabilitiesResponse.getField("airline")).thenReturn(airlineCaps); - ExtractedFields extractedFields = ExtractedFields.build(Arrays.asList("time", "value", "airline", "airport"), - new HashSet<>(Collections.singletonList("airport")), 
fieldCapabilitiesResponse, Collections.emptyMap()); + ExtractedFields extractedFields = ExtractedFields.build(new TreeSet<>(Arrays.asList("time", "value", "airline", "airport")), + new HashSet<>(Collections.singletonList("airport")), + fieldCapabilitiesResponse, + Collections.emptyMap(), + Collections.emptyList()); assertThat(extractedFields.getDocValueFields().size(), equalTo(2)); assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("time")); @@ -76,8 +82,8 @@ public void testBuildGivenMultiFields() { when(fieldCapabilitiesResponse.getField("airport")).thenReturn(text); when(fieldCapabilitiesResponse.getField("airport.keyword")).thenReturn(keyword); - ExtractedFields extractedFields = ExtractedFields.build(Arrays.asList("airline.text", "airport.keyword"), - Collections.emptySet(), fieldCapabilitiesResponse, Collections.emptyMap()); + ExtractedFields extractedFields = ExtractedFields.build(new TreeSet<>(Arrays.asList("airline.text", "airport.keyword")), + Collections.emptySet(), fieldCapabilitiesResponse, Collections.emptyMap(), Collections.emptyList()); assertThat(extractedFields.getDocValueFields().size(), equalTo(1)); assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("airport.keyword")); @@ -112,14 +118,18 @@ public void testApplyBooleanMapping() { assertThat(mapped.getName(), equalTo(aBool.getName())); assertThat(mapped.getMethod(), equalTo(aBool.getMethod())); assertThat(mapped.supportsFromSource(), is(false)); - expectThrows(UnsupportedOperationException.class, () -> mapped.newFromSource()); + expectThrows(UnsupportedOperationException.class, mapped::newFromSource); } public void testBuildGivenFieldWithoutMappings() { FieldCapabilitiesResponse fieldCapabilitiesResponse = mock(FieldCapabilitiesResponse.class); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> ExtractedFields.build( - Collections.singletonList("value"), Collections.emptySet(), fieldCapabilitiesResponse, Collections.emptyMap())); + Collections.singleton("value"), + Collections.emptySet(), + fieldCapabilitiesResponse, + Collections.emptyMap(), + Collections.emptyList())); assertThat(e.getMessage(), equalTo("cannot retrieve field [value] because it has no mappings")); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ProcessedFieldTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ProcessedFieldTests.java new file mode 100644 index 0000000000000..48604833f0814 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ProcessedFieldTests.java @@ -0,0 +1,76 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.extractor; + +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; +import org.elasticsearch.xpack.ml.test.SearchHitBuilder; + +import java.util.Arrays; +import java.util.Collections; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static org.hamcrest.Matchers.arrayContaining; +import static org.hamcrest.Matchers.emptyArray; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasItems; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ProcessedFieldTests extends ESTestCase { + + public void testOneHotGetters() { + String inputField = "foo"; + ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz")); + assertThat(processedField.getInputFieldNames(), hasItems(inputField)); + assertThat(processedField.getOutputFieldNames(), hasItems("bar_column", "baz_column")); + assertThat(processedField.getOutputFieldType("bar_column"), equalTo(Collections.singleton("integer"))); + assertThat(processedField.getOutputFieldType("baz_column"), equalTo(Collections.singleton("integer"))); + assertThat(processedField.getProcessorName(), equalTo(OneHotEncoding.NAME.getPreferredName())); + } + + public void testMissingExtractor() { + String inputField = "foo"; + ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz")); + assertThat(processedField.value(makeHit(), (s) -> null), emptyArray()); + } + + public void testMissingInputValues() { + String inputField = "foo"; + ExtractedField extractedField = makeExtractedField(new Object[0]); + ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz")); + assertThat(processedField.value(makeHit(), (s) -> extractedField), arrayContaining(is(nullValue()), is(nullValue()))); + } + + public void testProcessedField() { + ProcessedField processedField = new ProcessedField(makePreProcessor("foo", "bar", "baz")); + assertThat(processedField.value(makeHit(), (s) -> makeExtractedField(new Object[] { "bar" })), arrayContaining(1, 0)); + assertThat(processedField.value(makeHit(), (s) -> makeExtractedField(new Object[] { "baz" })), arrayContaining(0, 1)); + } + + private static PreProcessor makePreProcessor(String inputField, String... 
expectedExtractedValues) { + return new OneHotEncoding(inputField, + Arrays.stream(expectedExtractedValues).collect(Collectors.toMap(Function.identity(), (s) -> s + "_column")), + true); + } + + private static ExtractedField makeExtractedField(Object[] value) { + ExtractedField extractedField = mock(ExtractedField.class); + when(extractedField.value(any())).thenReturn(value); + return extractedField; + } + + private static SearchHit makeHit() { + return new SearchHitBuilder(42).addField("a_keyword", "bar").build(); + } + +} diff --git a/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MlConfigIndexMappingsFullClusterRestartIT.java b/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MlConfigIndexMappingsFullClusterRestartIT.java index d6651382e4e4d..8e5bfba0e7da1 100644 --- a/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MlConfigIndexMappingsFullClusterRestartIT.java +++ b/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MlConfigIndexMappingsFullClusterRestartIT.java @@ -60,6 +60,7 @@ public void waitForMlTemplates() throws Exception { XPackRestTestHelper.waitForTemplates(client(), templatesToWaitFor); } + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/pull/60528") public void testMlConfigIndexMappingsAfterMigration() throws Exception { Map expectedConfigIndexMappings = loadConfigIndexMappings(); if (isRunningAgainstOldCluster()) {