diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java index 867c598da90d9..be7c3c00af2c2 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java @@ -18,6 +18,8 @@ */ package org.elasticsearch.client.ml.inference; +import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding; import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.client.ml.inference.preprocessing.PreProcessor; @@ -42,6 +44,10 @@ public List getNamedXContentParsers() { TargetMeanEncoding::fromXContent)); namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(FrequencyEncoding.NAME), FrequencyEncoding::fromXContent)); + + // Model + namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Tree.NAME), Tree::fromXContent)); + return namedXContent; } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TrainedModel.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TrainedModel.java new file mode 100644 index 0000000000000..fb1f5c3b4ab92 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TrainedModel.java @@ -0,0 +1,36 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference.trainedmodel; + +import org.elasticsearch.common.xcontent.ToXContentObject; + +import java.util.List; + +public interface TrainedModel extends ToXContentObject { + + /** + * @return List of featureNames expected by the model. In the order that they are expected + */ + List getFeatureNames(); + + /** + * @return The name of the model + */ + String getName(); +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/Tree.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/Tree.java new file mode 100644 index 0000000000000..de040ec6f9ed7 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/Tree.java @@ -0,0 +1,192 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference.trainedmodel.tree; + +import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +public class Tree implements TrainedModel { + + public static final String NAME = "tree"; + + public static final ParseField FEATURE_NAMES = new ParseField("feature_names"); + public static final ParseField TREE_STRUCTURE = new ParseField("tree_structure"); + + private static final ObjectParser PARSER = new ObjectParser<>(NAME, true, Builder::new); + + static { + PARSER.declareStringArray(Builder::setFeatureNames, FEATURE_NAMES); + PARSER.declareObjectArray(Builder::setNodes, (p, c) -> TreeNode.fromXContent(p), TREE_STRUCTURE); + } + + public static Tree fromXContent(XContentParser parser) { + return PARSER.apply(parser, null).build(); + } + + private final List featureNames; + private final List nodes; + + Tree(List featureNames, List nodes) { + this.featureNames = Collections.unmodifiableList(Objects.requireNonNull(featureNames)); + this.nodes = Collections.unmodifiableList(Objects.requireNonNull(nodes)); + } + + @Override + public String getName() { + return NAME; + } + + @Override + public List getFeatureNames() { + return featureNames; + } + + public List getNodes() { + return nodes; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(FEATURE_NAMES.getPreferredName(), featureNames); + builder.field(TREE_STRUCTURE.getPreferredName(), nodes); + builder.endObject(); + return builder; + } + + @Override + public String toString() { + return Strings.toString(this); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Tree that = (Tree) o; + return Objects.equals(featureNames, that.featureNames) + && Objects.equals(nodes, that.nodes); + } + + @Override + public int hashCode() { + return Objects.hash(featureNames, nodes); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private List featureNames; + private ArrayList nodes; + private int numNodes; + + public Builder() { + nodes = new ArrayList<>(); + // allocate space in the root node and set to a leaf + nodes.add(null); + addLeaf(0, 0.0); + numNodes = 1; + } + + public Builder 
setFeatureNames(List featureNames) { + this.featureNames = featureNames; + return this; + } + + public Builder addNode(TreeNode.Builder node) { + nodes.add(node); + return this; + } + + public Builder setNodes(List nodes) { + this.nodes = new ArrayList<>(nodes); + return this; + } + + public Builder setNodes(TreeNode.Builder... nodes) { + return setNodes(Arrays.asList(nodes)); + } + + /** + * Add a decision node. Space for the child nodes is allocated + * @param nodeIndex Where to place the node. This is either 0 (root) or an existing child node index + * @param featureIndex The feature index the decision is made on + * @param isDefaultLeft Default left branch if the feature is missing + * @param decisionThreshold The decision threshold + * @return The created node + */ + public TreeNode.Builder addJunction(int nodeIndex, int featureIndex, boolean isDefaultLeft, double decisionThreshold) { + int leftChild = numNodes++; + int rightChild = numNodes++; + nodes.ensureCapacity(nodeIndex + 1); + for (int i = nodes.size(); i < nodeIndex + 1; i++) { + nodes.add(null); + } + + TreeNode.Builder node = TreeNode.builder(nodeIndex) + .setDefaultLeft(isDefaultLeft) + .setLeftChild(leftChild) + .setRightChild(rightChild) + .setSplitFeature(featureIndex) + .setThreshold(decisionThreshold); + nodes.set(nodeIndex, node); + + // allocate space for the child nodes + while (nodes.size() <= rightChild) { + nodes.add(null); + } + + return node; + } + + /** + * Sets the node at {@code nodeIndex} to a leaf node. + * @param nodeIndex The index as allocated by a call to {@link #addJunction(int, int, boolean, double)} + * @param value The prediction value + * @return this + */ + public Builder addLeaf(int nodeIndex, double value) { + for (int i = nodes.size(); i < nodeIndex + 1; i++) { + nodes.add(null); + } + nodes.set(nodeIndex, TreeNode.builder(nodeIndex).setLeafValue(value)); + return this; + } + + public Tree build() { + return new Tree(featureNames, + nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList())); + } + } + +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeNode.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeNode.java new file mode 100644 index 0000000000000..020aaa097169e --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeNode.java @@ -0,0 +1,280 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml.inference.trainedmodel.tree; + +import org.elasticsearch.client.ml.job.config.Operator; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +public class TreeNode implements ToXContentObject { + + public static final String NAME = "tree_node"; + + public static final ParseField DECISION_TYPE = new ParseField("decision_type"); + public static final ParseField THRESHOLD = new ParseField("threshold"); + public static final ParseField LEFT_CHILD = new ParseField("left_child"); + public static final ParseField RIGHT_CHILD = new ParseField("right_child"); + public static final ParseField DEFAULT_LEFT = new ParseField("default_left"); + public static final ParseField SPLIT_FEATURE = new ParseField("split_feature"); + public static final ParseField NODE_INDEX = new ParseField("node_index"); + public static final ParseField SPLIT_GAIN = new ParseField("split_gain"); + public static final ParseField LEAF_VALUE = new ParseField("leaf_value"); + + + private static final ObjectParser PARSER = new ObjectParser<>( + NAME, + true, + Builder::new); + static { + PARSER.declareDouble(Builder::setThreshold, THRESHOLD); + PARSER.declareField(Builder::setOperator, + p -> Operator.fromString(p.text()), + DECISION_TYPE, + ObjectParser.ValueType.STRING); + PARSER.declareInt(Builder::setLeftChild, LEFT_CHILD); + PARSER.declareInt(Builder::setRightChild, RIGHT_CHILD); + PARSER.declareBoolean(Builder::setDefaultLeft, DEFAULT_LEFT); + PARSER.declareInt(Builder::setSplitFeature, SPLIT_FEATURE); + PARSER.declareInt(Builder::setNodeIndex, NODE_INDEX); + PARSER.declareDouble(Builder::setSplitGain, SPLIT_GAIN); + PARSER.declareDouble(Builder::setLeafValue, LEAF_VALUE); + } + + public static Builder fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final Operator operator; + private final Double threshold; + private final Integer splitFeature; + private final int nodeIndex; + private final Double splitGain; + private final Double leafValue; + private final Boolean defaultLeft; + private final Integer leftChild; + private final Integer rightChild; + + + TreeNode(Operator operator, + Double threshold, + Integer splitFeature, + int nodeIndex, + Double splitGain, + Double leafValue, + Boolean defaultLeft, + Integer leftChild, + Integer rightChild) { + this.operator = operator; + this.threshold = threshold; + this.splitFeature = splitFeature; + this.nodeIndex = nodeIndex; + this.splitGain = splitGain; + this.leafValue = leafValue; + this.defaultLeft = defaultLeft; + this.leftChild = leftChild; + this.rightChild = rightChild; + } + + public Operator getOperator() { + return operator; + } + + public Double getThreshold() { + return threshold; + } + + public Integer getSplitFeature() { + return splitFeature; + } + + public Integer getNodeIndex() { + return nodeIndex; + } + + public Double getSplitGain() { + return splitGain; + } + + public Double getLeafValue() { + return leafValue; + } + + public Boolean isDefaultLeft() { + return defaultLeft; + } + + public Integer getLeftChild() { + return leftChild; + } + + public Integer getRightChild() { + return rightChild; + } + + @Override + public XContentBuilder 
toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + addOptionalField(builder, DECISION_TYPE, operator); + addOptionalField(builder, THRESHOLD, threshold); + addOptionalField(builder, SPLIT_FEATURE, splitFeature); + addOptionalField(builder, SPLIT_GAIN, splitGain); + addOptionalField(builder, NODE_INDEX, nodeIndex); + addOptionalField(builder, LEAF_VALUE, leafValue); + addOptionalField(builder, DEFAULT_LEFT, defaultLeft ); + addOptionalField(builder, LEFT_CHILD, leftChild); + addOptionalField(builder, RIGHT_CHILD, rightChild); + builder.endObject(); + return builder; + } + + private void addOptionalField(XContentBuilder builder, ParseField field, Object value) throws IOException { + if (value != null) { + builder.field(field.getPreferredName(), value); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TreeNode that = (TreeNode) o; + return Objects.equals(operator, that.operator) + && Objects.equals(threshold, that.threshold) + && Objects.equals(splitFeature, that.splitFeature) + && Objects.equals(nodeIndex, that.nodeIndex) + && Objects.equals(splitGain, that.splitGain) + && Objects.equals(leafValue, that.leafValue) + && Objects.equals(defaultLeft, that.defaultLeft) + && Objects.equals(leftChild, that.leftChild) + && Objects.equals(rightChild, that.rightChild); + } + + @Override + public int hashCode() { + return Objects.hash(operator, + threshold, + splitFeature, + splitGain, + nodeIndex, + leafValue, + defaultLeft, + leftChild, + rightChild); + } + + @Override + public String toString() { + return Strings.toString(this); + } + + public static Builder builder(int nodeIndex) { + return new Builder(nodeIndex); + } + + public static class Builder { + private Operator operator; + private Double threshold; + private Integer splitFeature; + private int nodeIndex; + private Double splitGain; + private Double leafValue; + private Boolean defaultLeft; + private Integer leftChild; + private Integer rightChild; + + public Builder(int nodeIndex) { + nodeIndex = nodeIndex; + } + + private Builder() { + } + + public Builder setOperator(Operator operator) { + this.operator = operator; + return this; + } + + public Builder setThreshold(Double threshold) { + this.threshold = threshold; + return this; + } + + public Builder setSplitFeature(Integer splitFeature) { + this.splitFeature = splitFeature; + return this; + } + + public Builder setNodeIndex(int nodeIndex) { + this.nodeIndex = nodeIndex; + return this; + } + + public Builder setSplitGain(Double splitGain) { + this.splitGain = splitGain; + return this; + } + + public Builder setLeafValue(Double leafValue) { + this.leafValue = leafValue; + return this; + } + + public Builder setDefaultLeft(Boolean defaultLeft) { + this.defaultLeft = defaultLeft; + return this; + } + + public Builder setLeftChild(Integer leftChild) { + this.leftChild = leftChild; + return this; + } + + public Integer getLeftChild() { + return leftChild; + } + + public Builder setRightChild(Integer rightChild) { + this.rightChild = rightChild; + return this; + } + + public Integer getRightChild() { + return rightChild; + } + + public TreeNode build() { + return new TreeNode(operator, + threshold, + splitFeature, + nodeIndex, + splitGain, + leafValue, + defaultLeft, + leftChild, + rightChild); + } + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java 
b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index cad58a02af4b3..7641dd3032c83 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -65,6 +65,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; +import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding; import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding; @@ -680,7 +681,7 @@ public void testDefaultNamedXContents() { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(40, namedXContents.size()); + assertEquals(41, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -690,7 +691,7 @@ public void testProvidedNamedXContents() { categories.put(namedXContent.categoryClass, counter + 1); } } - assertEquals("Had: " + categories, 10, categories.size()); + assertEquals("Had: " + categories, 11, categories.size()); assertEquals(Integer.valueOf(3), categories.get(Aggregation.class)); assertTrue(names.contains(ChildrenAggregationBuilder.NAME)); assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME)); @@ -739,6 +740,8 @@ public void testProvidedNamedXContents() { RSquaredMetric.NAME)); assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class)); assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME)); + assertEquals(Integer.valueOf(1), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class)); + assertThat(names, hasItems(Tree.NAME)); } public void testApiNamingConventions() throws Exception { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeNodeTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeNodeTests.java new file mode 100644 index 0000000000000..733a9ddc3d943 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeNodeTests.java @@ -0,0 +1,72 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml.inference.trainedmodel.tree; + +import org.elasticsearch.client.ml.job.config.Operator; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class TreeNodeTests extends AbstractXContentTestCase { + + @Override + protected TreeNode doParseInstance(XContentParser parser) throws IOException { + return TreeNode.fromXContent(parser).build(); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected TreeNode createTestInstance() { + Integer lft = randomBoolean() ? null : randomInt(100); + Integer rgt = randomBoolean() ? randomInt(100) : null; + Double threshold = lft != null || randomBoolean() ? randomDouble() : null; + Integer featureIndex = lft != null || randomBoolean() ? randomInt(100) : null; + return createRandom(randomInt(), lft, rgt, threshold, featureIndex, randomBoolean() ? null : randomFrom(Operator.values())).build(); + } + + public static TreeNode createRandomLeafNode(double internalValue) { + return TreeNode.builder(randomInt(100)) + .setDefaultLeft(randomBoolean() ? null : randomBoolean()) + .setLeafValue(internalValue) + .build(); + } + + public static TreeNode.Builder createRandom(int nodeIndex, + Integer left, + Integer right, + Double threshold, + Integer featureIndex, + Operator operator) { + return TreeNode.builder(nodeIndex) + .setLeafValue(left == null ? randomDouble() : null) + .setDefaultLeft(randomBoolean() ? null : randomBoolean()) + .setLeftChild(left) + .setRightChild(right) + .setThreshold(threshold) + .setOperator(operator) + .setSplitFeature(featureIndex) + .setSplitGain(randomBoolean() ? null : randomDouble()); + } + +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java new file mode 100644 index 0000000000000..66cdb44b10073 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java @@ -0,0 +1,87 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml.inference.trainedmodel.tree; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Predicate; + + +public class TreeTests extends AbstractXContentTestCase { + + @Override + protected Tree doParseInstance(XContentParser parser) throws IOException { + return Tree.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return field -> field.startsWith("feature_names"); + } + + @Override + protected Tree createTestInstance() { + return createRandom(); + } + + public static Tree createRandom() { + return buildRandomTree(randomIntBetween(2, 15), 6); + } + + public static Tree buildRandomTree(int numFeatures, int depth) { + + Tree.Builder builder = Tree.builder(); + List featureNames = new ArrayList<>(numFeatures); + for(int i = 0; i < numFeatures; i++) { + featureNames.add(randomAlphaOfLength(10)); + } + builder.setFeatureNames(featureNames); + + TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble()); + List childNodes = List.of(node.getLeftChild(), node.getRightChild()); + + for (int i = 0; i < depth -1; i++) { + + List nextNodes = new ArrayList<>(); + for (int nodeId : childNodes) { + if (i == depth -2) { + builder.addLeaf(nodeId, randomDouble()); + } else { + TreeNode.Builder childNode = + builder.addJunction(nodeId, randomInt(numFeatures), true, randomDouble()); + nextNodes.add(childNode.getLeftChild()); + nextNodes.add(childNode.getRightChild()); + } + } + childNodes = nextNodes; + } + + return builder.build(); + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index eba706a7026a5..19451e5833e94 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -138,6 +138,8 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; @@ -441,10 +443,12 @@ public List getNamedWriteables() { new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ScoreByThresholdResult.NAME, ScoreByThresholdResult::new), new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(), ConfusionMatrix.Result::new), - // ML - Inference + // ML - Inference preprocessing new NamedWriteableRegistry.Entry(PreProcessor.class, FrequencyEncoding.NAME.getPreferredName(), FrequencyEncoding::new), new NamedWriteableRegistry.Entry(PreProcessor.class, OneHotEncoding.NAME.getPreferredName(), OneHotEncoding::new), new 
NamedWriteableRegistry.Entry(PreProcessor.class, TargetMeanEncoding.NAME.getPreferredName(), TargetMeanEncoding::new), + // ML - Inference models + new NamedWriteableRegistry.Entry(TrainedModel.class, Tree.NAME.getPreferredName(), Tree::new), // monitoring new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MONITORING, MonitoringFeatureSetUsage::new), diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index d7da457b64c44..7f14077a1504e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -8,6 +8,10 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.plugins.spi.NamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding; import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; @@ -40,6 +44,12 @@ public List getNamedXContentParsers() { namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, FrequencyEncoding.NAME, FrequencyEncoding::fromXContentStrict)); + // Model Lenient + namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentLenient)); + + // Model Strict + namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentStrict)); + return namedXContent; } @@ -54,6 +64,9 @@ public List getNamedWriteables() { namedWriteables.add(new NamedWriteableRegistry.Entry(PreProcessor.class, FrequencyEncoding.NAME.getPreferredName(), FrequencyEncoding::new)); + // Model + namedWriteables.add(new NamedWriteableRegistry.Entry(TrainedModel.class, Tree.NAME.getPreferredName(), Tree::new)); + return namedWriteables; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LenientlyParsedTrainedModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LenientlyParsedTrainedModel.java new file mode 100644 index 0000000000000..208e07de17b62 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LenientlyParsedTrainedModel.java @@ -0,0 +1,9 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+public interface LenientlyParsedTrainedModel extends TrainedModel {
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/StrictlyParsedTrainedModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/StrictlyParsedTrainedModel.java
new file mode 100644
index 0000000000000..48b38c161942f
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/StrictlyParsedTrainedModel.java
@@ -0,0 +1,9 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+public interface StrictlyParsedTrainedModel extends TrainedModel {
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java
new file mode 100644
index 0000000000000..1d68e3d6d3f46
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java
@@ -0,0 +1,44 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.io.stream.NamedWriteable;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
+
+import java.util.List;
+import java.util.Map;
+
+public interface TrainedModel extends NamedXContentObject, NamedWriteable {
+
+    /**
+     * @return The feature names expected by the model, in the order they are expected
+     */
+    List<String> getFeatureNames();
+
+    /**
+     * Infer against the provided fields
+     *
+     * @param fields The fields and their values to infer against
+     * @return The predicted value. For classification this will be discrete values (e.g. 0.0 or 1.0).
+     *         For regression this is continuous.
+     */
+    double infer(Map<String, Object> fields);
+
+    /**
+     * @return {@code true} if the model performs classification, {@code false} otherwise.
+     */
+    boolean isClassification();
+
+    /**
+     * Gathers the probabilities for each potential classification value.
+     *
+     * This should only return values if the implementing model performs classification rather than regression.
+     * @param fields The fields and their values to infer against
+     * @return The probabilities of each classification value
+     */
+    List<Double> inferProbabilities(Map<String, Object> fields);
+
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java
new file mode 100644
index 0000000000000..8e48fa488a0a8
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java
@@ -0,0 +1,311 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements.
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Queue; +import java.util.Set; +import java.util.stream.Collectors; + +public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel { + + public static final ParseField NAME = new ParseField("tree"); + + public static final ParseField FEATURE_NAMES = new ParseField("feature_names"); + public static final ParseField TREE_STRUCTURE = new ParseField("tree_structure"); + + private static final ObjectParser LENIENT_PARSER = createParser(true); + private static final ObjectParser STRICT_PARSER = createParser(false); + + private static ObjectParser createParser(boolean lenient) { + ObjectParser parser = new ObjectParser<>( + NAME.getPreferredName(), + lenient, + Tree.Builder::new); + parser.declareStringArray(Tree.Builder::setFeatureNames, FEATURE_NAMES); + parser.declareObjectArray(Tree.Builder::setNodes, (p, c) -> TreeNode.fromXContent(p, lenient), TREE_STRUCTURE); + return parser; + } + + public static Tree fromXContentStrict(XContentParser parser) { + return STRICT_PARSER.apply(parser, null).build(); + } + + public static Tree fromXContentLenient(XContentParser parser) { + return LENIENT_PARSER.apply(parser, null).build(); + } + + private final List featureNames; + private final List nodes; + + Tree(List featureNames, List nodes) { + this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES)); + this.nodes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE)); + } + + public Tree(StreamInput in) throws IOException { + this.featureNames = Collections.unmodifiableList(in.readStringList()); + this.nodes = Collections.unmodifiableList(in.readList(TreeNode::new)); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public List getFeatureNames() { + return featureNames; + } + + public List getNodes() { + return nodes; + } + + @Override + public double infer(Map fields) { + List features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList()); + return infer(features); + } + + private double infer(List features) { + TreeNode node = nodes.get(0); + while(node.isLeaf() == false) { + node = nodes.get(node.compare(features)); + } + return node.getLeafValue(); + } + + /** + * Trace the route predicting on the feature vector takes. 
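+     * <p>
+     * Illustrative sketch only (a hypothetical single-feature tree whose root splits at 0.5):
+     * <pre>{@code
+     * List<TreeNode> path = tree.trace(Arrays.asList(0.2));
+     * // path.get(0) is the root node; path.get(path.size() - 1) is the leaf that was reached
+     * }</pre>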
+ * @param features The feature vector + * @return The list of traversed nodes ordered from root to leaf + */ + public List trace(List features) { + List visited = new ArrayList<>(); + TreeNode node = nodes.get(0); + visited.add(node); + while(node.isLeaf() == false) { + node = nodes.get(node.compare(features)); + visited.add(node); + } + return visited; + } + + @Override + public boolean isClassification() { + return false; + } + + @Override + public List inferProbabilities(Map fields) { + throw new UnsupportedOperationException("Cannot infer probabilities against a regression model."); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeStringCollection(featureNames); + out.writeCollection(nodes); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(FEATURE_NAMES.getPreferredName(), featureNames); + builder.field(TREE_STRUCTURE.getPreferredName(), nodes); + builder.endObject(); + return builder; + } + + @Override + public String toString() { + return Strings.toString(this); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Tree that = (Tree) o; + return Objects.equals(featureNames, that.featureNames) + && Objects.equals(nodes, that.nodes); + } + + @Override + public int hashCode() { + return Objects.hash(featureNames, nodes); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private List featureNames; + private ArrayList nodes; + private int numNodes; + + public Builder() { + nodes = new ArrayList<>(); + // allocate space in the root node and set to a leaf + nodes.add(null); + addLeaf(0, 0.0); + numNodes = 1; + } + + public Builder setFeatureNames(List featureNames) { + this.featureNames = featureNames; + return this; + } + + public Builder addNode(TreeNode.Builder node) { + nodes.add(node); + return this; + } + + public Builder setNodes(List nodes) { + this.nodes = new ArrayList<>(nodes); + return this; + } + + public Builder setNodes(TreeNode.Builder... nodes) { + return setNodes(Arrays.asList(nodes)); + } + + /** + * Add a decision node. Space for the child nodes is allocated + * @param nodeIndex Where to place the node. 
This is either 0 (root) or an existing child node index + * @param featureIndex The feature index the decision is made on + * @param isDefaultLeft Default left branch if the feature is missing + * @param decisionThreshold The decision threshold + * @return The created node + */ + TreeNode.Builder addJunction(int nodeIndex, int featureIndex, boolean isDefaultLeft, double decisionThreshold) { + int leftChild = numNodes++; + int rightChild = numNodes++; + nodes.ensureCapacity(nodeIndex + 1); + for (int i = nodes.size(); i < nodeIndex + 1; i++) { + nodes.add(null); + } + + TreeNode.Builder node = TreeNode.builder(nodeIndex) + .setDefaultLeft(isDefaultLeft) + .setLeftChild(leftChild) + .setRightChild(rightChild) + .setSplitFeature(featureIndex) + .setThreshold(decisionThreshold); + nodes.set(nodeIndex, node); + + // allocate space for the child nodes + while (nodes.size() <= rightChild) { + nodes.add(null); + } + + return node; + } + + void detectCycle(List nodes) { + if (nodes.isEmpty()) { + return; + } + Set visited = new HashSet<>(); + Queue toVisit = new ArrayDeque<>(nodes.size()); + toVisit.add(0); + while(toVisit.isEmpty() == false) { + Integer nodeIdx = toVisit.remove(); + if (visited.contains(nodeIdx)) { + throw new IllegalArgumentException("[tree] contains cycle at node " + nodeIdx); + } + visited.add(nodeIdx); + TreeNode.Builder treeNode = nodes.get(nodeIdx); + if (treeNode.getLeftChild() != null) { + toVisit.add(treeNode.getLeftChild()); + } + if (treeNode.getRightChild() != null) { + toVisit.add(treeNode.getRightChild()); + } + } + } + + void detectNullOrMissingNode(List nodes) { + if (nodes.isEmpty()) { + return; + } + if (nodes.get(0) == null) { + throw new IllegalArgumentException("[tree] must have non-null root node."); + } + List nullOrMissingNodes = new ArrayList<>(); + for (int i = 0; i < nodes.size(); i++) { + TreeNode.Builder currentNode = nodes.get(i); + if (currentNode == null) { + continue; + } + if (nodeNullOrMissing(currentNode.getLeftChild())) { + nullOrMissingNodes.add(currentNode.getLeftChild()); + } + if (nodeNullOrMissing(currentNode.getRightChild())) { + nullOrMissingNodes.add(currentNode.getRightChild()); + } + } + if (nullOrMissingNodes.isEmpty() == false) { + throw new IllegalArgumentException("[tree] contains null or missing nodes " + nullOrMissingNodes); + } + } + + private boolean nodeNullOrMissing(Integer nodeIdx) { + if (nodeIdx == null) { + return false; + } + return nodeIdx >= nodes.size() || nodes.get(nodeIdx) == null; + } + + /** + * Sets the node at {@code nodeIndex} to a leaf node. 
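+         * <p>
+         * A minimal illustrative sketch (the feature name and values are made up):
+         * <pre>{@code
+         * Tree.Builder builder = Tree.builder().setFeatureNames(Arrays.asList("foo"));
+         * TreeNode.Builder root = builder.addJunction(0, 0, true, 0.5); // split feature "foo" at 0.5
+         * builder.addLeaf(root.getLeftChild(), 0.1);                    // reached when foo <= 0.5
+         * builder.addLeaf(root.getRightChild(), 0.2);                   // reached when foo > 0.5
+         * Tree tree = builder.build();
+         * }</pre>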
+ * @param nodeIndex The index as allocated by a call to {@link #addJunction(int, int, boolean, double)} + * @param value The prediction value + * @return this + */ + Tree.Builder addLeaf(int nodeIndex, double value) { + for (int i = nodes.size(); i < nodeIndex + 1; i++) { + nodes.add(null); + } + nodes.set(nodeIndex, TreeNode.builder(nodeIndex).setLeafValue(value)); + return this; + } + + public Tree build() { + detectNullOrMissingNode(nodes); + detectCycle(nodes); + return new Tree(featureNames, + nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList())); + } + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java new file mode 100644 index 0000000000000..f0dbb0617503b --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java @@ -0,0 +1,346 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.job.config.Operator; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public class TreeNode implements ToXContentObject, Writeable { + + public static final String NAME = "tree_node"; + + public static final ParseField DECISION_TYPE = new ParseField("decision_type"); + public static final ParseField THRESHOLD = new ParseField("threshold"); + public static final ParseField LEFT_CHILD = new ParseField("left_child"); + public static final ParseField RIGHT_CHILD = new ParseField("right_child"); + public static final ParseField DEFAULT_LEFT = new ParseField("default_left"); + public static final ParseField SPLIT_FEATURE = new ParseField("split_feature"); + public static final ParseField NODE_INDEX = new ParseField("node_index"); + public static final ParseField SPLIT_GAIN = new ParseField("split_gain"); + public static final ParseField LEAF_VALUE = new ParseField("leaf_value"); + + private static final ObjectParser LENIENT_PARSER = createParser(true); + private static final ObjectParser STRICT_PARSER = createParser(false); + + private static ObjectParser createParser(boolean lenient) { + ObjectParser parser = new ObjectParser<>( + NAME, + lenient, + TreeNode.Builder::new); + parser.declareDouble(TreeNode.Builder::setThreshold, THRESHOLD); + parser.declareField(TreeNode.Builder::setOperator, + p -> Operator.fromString(p.text()), + DECISION_TYPE, + ObjectParser.ValueType.STRING); + parser.declareInt(TreeNode.Builder::setLeftChild, LEFT_CHILD); + parser.declareInt(TreeNode.Builder::setRightChild, RIGHT_CHILD); + parser.declareBoolean(TreeNode.Builder::setDefaultLeft, 
DEFAULT_LEFT); + parser.declareInt(TreeNode.Builder::setSplitFeature, SPLIT_FEATURE); + parser.declareInt(TreeNode.Builder::setNodeIndex, NODE_INDEX); + parser.declareDouble(TreeNode.Builder::setSplitGain, SPLIT_GAIN); + parser.declareDouble(TreeNode.Builder::setLeafValue, LEAF_VALUE); + return parser; + } + + public static TreeNode.Builder fromXContent(XContentParser parser, boolean lenient) { + return lenient ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null); + } + + private final Operator operator; + private final Double threshold; + private final Integer splitFeature; + private final int nodeIndex; + private final Double splitGain; + private final Double leafValue; + private final boolean defaultLeft; + private final int leftChild; + private final int rightChild; + + + TreeNode(Operator operator, + Double threshold, + Integer splitFeature, + Integer nodeIndex, + Double splitGain, + Double leafValue, + Boolean defaultLeft, + Integer leftChild, + Integer rightChild) { + this.operator = operator == null ? Operator.LTE : operator; + this.threshold = threshold; + this.splitFeature = splitFeature; + this.nodeIndex = ExceptionsHelper.requireNonNull(nodeIndex, NODE_INDEX.getPreferredName()); + this.splitGain = splitGain; + this.leafValue = leafValue; + this.defaultLeft = defaultLeft == null ? false : defaultLeft; + this.leftChild = leftChild == null ? -1 : leftChild; + this.rightChild = rightChild == null ? -1 : rightChild; + } + + public TreeNode(StreamInput in) throws IOException { + operator = Operator.readFromStream(in); + threshold = in.readOptionalDouble(); + splitFeature = in.readOptionalInt(); + splitGain = in.readOptionalDouble(); + nodeIndex = in.readInt(); + leafValue = in.readOptionalDouble(); + defaultLeft = in.readBoolean(); + leftChild = in.readInt(); + rightChild = in.readInt(); + } + + + public Operator getOperator() { + return operator; + } + + public Double getThreshold() { + return threshold; + } + + public Integer getSplitFeature() { + return splitFeature; + } + + public Integer getNodeIndex() { + return nodeIndex; + } + + public Double getSplitGain() { + return splitGain; + } + + public Double getLeafValue() { + return leafValue; + } + + public boolean isDefaultLeft() { + return defaultLeft; + } + + public int getLeftChild() { + return leftChild; + } + + public int getRightChild() { + return rightChild; + } + + public boolean isLeaf() { + return leftChild < 1; + } + + public int compare(List features) { + if (isLeaf()) { + throw new IllegalArgumentException("cannot call compare against a leaf node."); + } + Double feature = features.get(splitFeature); + if (isMissing(feature)) { + return defaultLeft ? leftChild : rightChild; + } + return operator.test(feature, threshold) ? 
leftChild : rightChild; + } + + private boolean isMissing(Double feature) { + return feature == null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + operator.writeTo(out); + out.writeOptionalDouble(threshold); + out.writeOptionalInt(splitFeature); + out.writeOptionalDouble(splitGain); + out.writeInt(nodeIndex); + out.writeOptionalDouble(leafValue); + out.writeBoolean(defaultLeft); + out.writeInt(leftChild); + out.writeInt(rightChild); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + addOptionalField(builder, DECISION_TYPE, operator); + addOptionalField(builder, THRESHOLD, threshold); + addOptionalField(builder, SPLIT_FEATURE, splitFeature); + addOptionalField(builder, SPLIT_GAIN, splitGain); + builder.field(NODE_INDEX.getPreferredName(), nodeIndex); + addOptionalField(builder, LEAF_VALUE, leafValue); + builder.field(DEFAULT_LEFT.getPreferredName(), defaultLeft); + if (leftChild >= 0) { + builder.field(LEFT_CHILD.getPreferredName(), leftChild); + } + if (rightChild >= 0) { + builder.field(RIGHT_CHILD.getPreferredName(), rightChild); + } + builder.endObject(); + return builder; + } + + private void addOptionalField(XContentBuilder builder, ParseField field, Object value) throws IOException { + if (value != null) { + builder.field(field.getPreferredName(), value); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TreeNode that = (TreeNode) o; + return Objects.equals(operator, that.operator) + && Objects.equals(threshold, that.threshold) + && Objects.equals(splitFeature, that.splitFeature) + && Objects.equals(nodeIndex, that.nodeIndex) + && Objects.equals(splitGain, that.splitGain) + && Objects.equals(leafValue, that.leafValue) + && Objects.equals(defaultLeft, that.defaultLeft) + && Objects.equals(leftChild, that.leftChild) + && Objects.equals(rightChild, that.rightChild); + } + + @Override + public int hashCode() { + return Objects.hash(operator, + threshold, + splitFeature, + splitGain, + nodeIndex, + leafValue, + defaultLeft, + leftChild, + rightChild); + } + + @Override + public String toString() { + return Strings.toString(this); + } + + public static Builder builder(int nodeIndex) { + return new Builder(nodeIndex); + } + + public static class Builder { + private Operator operator; + private Double threshold; + private Integer splitFeature; + private int nodeIndex; + private Double splitGain; + private Double leafValue; + private Boolean defaultLeft; + private Integer leftChild; + private Integer rightChild; + + public Builder(int nodeIndex) { + this.nodeIndex = nodeIndex; + } + + private Builder() { + } + + public Builder setOperator(Operator operator) { + this.operator = operator; + return this; + } + + public Builder setThreshold(Double threshold) { + this.threshold = threshold; + return this; + } + + public Builder setSplitFeature(Integer splitFeature) { + this.splitFeature = splitFeature; + return this; + } + + public Builder setNodeIndex(Integer nodeIndex) { + this.nodeIndex = nodeIndex; + return this; + } + + public Builder setSplitGain(Double splitGain) { + this.splitGain = splitGain; + return this; + } + + public Builder setLeafValue(Double leafValue) { + this.leafValue = leafValue; + return this; + } + + public Builder setDefaultLeft(Boolean defaultLeft) { + this.defaultLeft = defaultLeft; + return this; + } + + public Builder setLeftChild(Integer 
leftChild) { + this.leftChild = leftChild; + return this; + } + + Integer getLeftChild() { + return leftChild; + } + + public Builder setRightChild(Integer rightChild) { + this.rightChild = rightChild; + return this; + } + + Integer getRightChild() { + return rightChild; + } + + public void validate() { + if (nodeIndex < 0) { + throw new IllegalArgumentException("[node_index] must be a non-negative integer."); + } + if (leftChild == null) { // leaf validations + if (leafValue == null) { + throw new IllegalArgumentException("[leaf_value] is required for a leaf node."); + } + } else { + if (leftChild < 0) { + throw new IllegalArgumentException("[left_child] must be a non-negative integer."); + } + if (rightChild != null && rightChild < 0) { + throw new IllegalArgumentException("[right_child] must be a non-negative integer."); + } + if (threshold == null) { + throw new IllegalArgumentException("[threshold] must exist for non-leaf node."); + } + } + } + + public TreeNode build() { + validate(); + return new TreeNode(operator, + threshold, + splitFeature, + nodeIndex, + splitGain, + leafValue, + defaultLeft, + leftChild, + rightChild); + } + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/NamedXContentObjectsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/NamedXContentObjectsTests.java index ad74cf01fde82..3a3856cbe95a4 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/NamedXContentObjectsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/NamedXContentObjectsTests.java @@ -14,6 +14,10 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests; import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests; import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests; @@ -36,6 +40,7 @@ public class NamedXContentObjectsTests extends AbstractXContentTestCase STRICT_PARSER = createParser(false); static final ObjectParser LENIENT_PARSER = createParser(true); @@ -51,16 +56,30 @@ private static ObjectParser createParser(boolean len lenient ? p.namedObject(LenientlyParsedPreProcessor.class, n, null) : p.namedObject(StrictlyParsedPreProcessor.class, n, null), (noc) -> noc.setUseExplicitPreprocessorOrder(true), PRE_PROCESSORS); + parser.declareNamedObjects(NamedObjectContainer::setTrainedModel, + (p, c, n) -> + lenient ? 
p.namedObject(LenientlyParsedTrainedModel.class, n, null) : + p.namedObject(StrictlyParsedTrainedModel.class, n, null), + TRAINED_MODEL); return parser; } private boolean useExplicitPreprocessorOrder = false; private List preProcessors; + private TrainedModel trainedModel; void setPreProcessors(List preProcessors) { this.preProcessors = preProcessors; } + void setTrainedModel(List trainedModel) { + this.trainedModel = trainedModel.get(0); + } + + void setModel(TrainedModel trainedModel) { + this.trainedModel = trainedModel; + } + void setUseExplicitPreprocessorOrder(boolean value) { this.useExplicitPreprocessorOrder = value; } @@ -73,6 +92,7 @@ static NamedObjectContainer fromXContent(XContentParser parser, boolean lenient) public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); writeNamedObjects(builder, params, useExplicitPreprocessorOrder, PRE_PROCESSORS.getPreferredName(), preProcessors); + writeNamedObjects(builder, params, false, TRAINED_MODEL.getPreferredName(), Collections.singletonList(trainedModel)); builder.endObject(); return builder; } @@ -109,7 +129,7 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; NamedObjectContainer that = (NamedObjectContainer) o; - return Objects.equals(preProcessors, that.preProcessors); + return Objects.equals(preProcessors, that.preProcessors) && Objects.equals(trainedModel, that.trainedModel); } @Override @@ -137,6 +157,7 @@ public NamedObjectContainer createTestInstance() { NamedObjectContainer container = new NamedObjectContainer(); container.setPreProcessors(preProcessors); container.setUseExplicitPreprocessorOrder(true); + container.setModel(TreeTests.buildRandomTree(5, 4)); return container; } @@ -157,6 +178,7 @@ protected Predicate getRandomFieldsExcludeFilter() { (field.endsWith("frequency_encoding") || field.endsWith("one_hot_encoding") || field.endsWith("target_mean_encoding") || + field.startsWith("tree.tree_structure") || field.isEmpty()) == false; } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNodeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNodeTests.java new file mode 100644 index 0000000000000..dd87270b95fc6 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNodeTests.java @@ -0,0 +1,100 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNodeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNodeTests.java
new file mode 100644
index 0000000000000..dd87270b95fc6
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNodeTests.java
@@ -0,0 +1,100 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.job.config.Operator;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class TreeNodeTests extends AbstractSerializingTestCase<TreeNode> {
+
+    private boolean lenient;
+
+    @Before
+    public void chooseStrictOrLenient() {
+        lenient = randomBoolean();
+    }
+
+    @Override
+    protected TreeNode doParseInstance(XContentParser parser) throws IOException {
+        return TreeNode.fromXContent(parser, lenient).build();
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return lenient;
+    }
+
+    @Override
+    protected TreeNode createTestInstance() {
+        Integer lft = randomBoolean() ? null : randomInt(100);
+        Integer rgt = randomBoolean() ? randomInt(100) : null;
+        Double threshold = lft != null || randomBoolean() ? randomDouble() : null;
+        Integer featureIndex = lft != null || randomBoolean() ? randomInt(100) : null;
+        return createRandom(randomInt(100),
+            lft,
+            rgt,
+            threshold,
+            featureIndex,
+            randomBoolean() ? null : randomFrom(Operator.values())).build();
+    }
+
+    public static TreeNode createRandomLeafNode(double internalValue) {
+        return TreeNode.builder(randomInt(100))
+            .setDefaultLeft(randomBoolean() ? null : randomBoolean())
+            .setLeafValue(internalValue)
+            .build();
+    }
+
+    public static TreeNode.Builder createRandom(int nodeId,
+                                                Integer left,
+                                                Integer right,
+                                                Double threshold,
+                                                Integer featureIndex,
+                                                Operator operator) {
+        return TreeNode.builder(nodeId)
+            .setLeafValue(left == null ? randomDouble() : null)
+            .setDefaultLeft(randomBoolean() ? null : randomBoolean())
+            .setLeftChild(left)
+            .setRightChild(right)
+            .setThreshold(threshold)
+            .setOperator(operator)
+            .setSplitFeature(randomBoolean() ? null : randomInt())
+            .setSplitGain(randomBoolean() ? null : randomDouble())
+            .setSplitFeature(featureIndex);
+    }
+
+    @Override
+    protected Writeable.Reader<TreeNode> instanceReader() {
+        return TreeNode::new;
+    }
+
+    public void testCompare() {
+        expectThrows(IllegalArgumentException.class,
+            () -> createRandomLeafNode(randomDouble()).compare(Collections.singletonList(randomDouble())));
+
+        List<Double> featureValues = Arrays.asList(0.1, null);
+        assertThat(createRandom(0, 2, 3, 0.0, 0, null).build().compare(featureValues),
+            equalTo(3));
+        assertThat(createRandom(0, 2, 3, 0.0, 0, Operator.GT).build().compare(featureValues),
+            equalTo(2));
+        assertThat(createRandom(0, 2, 3, 0.2, 0, null).build().compare(featureValues),
+            equalTo(2));
+        assertThat(createRandom(0, 2, 3, 0.0, 1, null).setDefaultLeft(true).build().compare(featureValues),
+            equalTo(2));
+        assertThat(createRandom(0, 2, 3, 0.0, 1, null).setDefaultLeft(false).build().compare(featureValues),
+            equalTo(3));
+    }
+}
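testCompare pins down how a node picks its next child: the split feature is looked up in the feature vector, compared against the threshold (left on a match, right otherwise, with GT inverting the sense), and a missing value falls back to default_left. A rough sketch of that rule, using the field names from the builder; the "<=" default when no operator is set is inferred from the expected child indices rather than from documented behaviour:

    int compare(List<Double> featureValues) {
        Double feature = featureValues.get(splitFeature);
        if (feature == null) {
            return defaultLeft ? leftChild : rightChild;   // missing value: follow default_left
        }
        boolean goLeft = operator == Operator.GT ? feature > threshold : feature <= threshold;
        return goLeft ? leftChild : rightChild;
    }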
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java
new file mode 100644
index 0000000000000..391f2e4b7e59a
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java
@@ -0,0 +1,172 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import static org.hamcrest.Matchers.equalTo;
+
+
+public class TreeTests extends AbstractSerializingTestCase<Tree> {
+
+    private boolean lenient;
+
+    @Before
+    public void chooseStrictOrLenient() {
+        lenient = randomBoolean();
+    }
+
+    @Override
+    protected Tree doParseInstance(XContentParser parser) throws IOException {
+        return lenient ? Tree.fromXContentLenient(parser) : Tree.fromXContentStrict(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return lenient;
+    }
+
+    @Override
+    protected Predicate<String> getRandomFieldsExcludeFilter() {
+        return field -> field.startsWith("feature_names");
+    }
+
+    @Override
+    protected Tree createTestInstance() {
+        return createRandom();
+    }
+
+    public static Tree createRandom() {
+        return buildRandomTree(randomIntBetween(2, 15), 6);
+    }
+
+    public static Tree buildRandomTree(int numFeatures, int depth) {
+
+        Tree.Builder builder = Tree.builder();
+        List<String> featureNames = new ArrayList<>(numFeatures);
+        for (int i = 0; i < numFeatures; i++) {
+            featureNames.add(randomAlphaOfLength(10));
+        }
+        builder.setFeatureNames(featureNames);
+
+        TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble());
+        List<Integer> childNodes = List.of(node.getLeftChild(), node.getRightChild());
+
+        for (int i = 0; i < depth - 1; i++) {
+
+            List<Integer> nextNodes = new ArrayList<>();
+            for (int nodeId : childNodes) {
+                if (i == depth - 2) {
+                    builder.addLeaf(nodeId, randomDouble());
+                } else {
+                    TreeNode.Builder childNode =
+                        builder.addJunction(nodeId, randomInt(numFeatures), true, randomDouble());
+                    nextNodes.add(childNode.getLeftChild());
+                    nextNodes.add(childNode.getRightChild());
+                }
+            }
+            childNodes = nextNodes;
+        }
+
+        return builder.build();
+    }
+
+    @Override
+    protected Writeable.Reader<Tree> instanceReader() {
+        return Tree::new;
+    }
+
+    public void testInfer() {
+        // Build a tree with 2 nodes and 3 leaves using 2 features
+        // The leaves have unique values 0.1, 0.2, 0.3
+        Tree.Builder builder = Tree.builder();
+        TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
+        builder.addLeaf(rootNode.getRightChild(), 0.3);
+        TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);
+        builder.addLeaf(leftChildNode.getLeftChild(), 0.1);
+        builder.addLeaf(leftChildNode.getRightChild(), 0.2);
+
+        List<String> featureNames = Arrays.asList("foo", "bar");
+        Tree tree = builder.setFeatureNames(featureNames).build();
+
+        // This feature vector should hit the right child of the root node
+        List<Double> featureVector = Arrays.asList(0.6, 0.0);
+        Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
+        assertEquals(0.3, tree.infer(featureMap), 0.00001);
+
+        // This should hit the left child of the left child of the root node
+        // i.e. it takes the path left, left
+        featureVector = Arrays.asList(0.3, 0.7);
+        featureMap = zipObjMap(featureNames, featureVector);
+        assertEquals(0.1, tree.infer(featureMap), 0.00001);
+
+        // This should hit the right child of the left child of the root node
+        // i.e. it takes the path left, right
+        featureVector = Arrays.asList(0.3, 0.9);
+        featureMap = zipObjMap(featureNames, featureVector);
+        assertEquals(0.2, tree.infer(featureMap), 0.00001);
+    }
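testInfer fixes the traversal semantics: the named fields are projected onto the feature_names order and the node list is walked from the root until a leaf is reached. A minimal sketch of that walk; isLeaf() and getLeafValue() are assumed accessor names, since only compare(...) is exercised directly by these tests:

    double infer(List<String> featureNames, List<TreeNode> nodes, Map<String, Object> fields) {
        List<Double> features = featureNames.stream()      // project named fields onto the
            .map(f -> (Double) fields.get(f))               // order the model was trained with
            .collect(Collectors.toList());
        TreeNode node = nodes.get(0);                       // start at the root
        while (node.isLeaf() == false) {                    // assumed accessor
            node = nodes.get(node.compare(features));       // compare() returns the next node index
        }
        return node.getLeafValue();                         // assumed accessor
    }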
+
+    public void testTreeWithNullRoot() {
+        IllegalArgumentException ex = expectThrows(IllegalArgumentException.class,
+            () -> Tree.builder().setNodes(Collections.singletonList(null))
+                .build());
+        assertThat(ex.getMessage(), equalTo("[tree] must have non-null root node."));
+    }
+
+    public void testTreeWithInvalidNode() {
+        IllegalArgumentException ex = expectThrows(IllegalArgumentException.class,
+            () -> Tree.builder().setNodes(TreeNode.builder(0)
+                .setLeftChild(1)
+                .setSplitFeature(1)
+                .setThreshold(randomDouble()))
+                .build());
+        assertThat(ex.getMessage(), equalTo("[tree] contains null or missing nodes [1]"));
+    }
+
+    public void testTreeWithNullNode() {
+        IllegalArgumentException ex = expectThrows(IllegalArgumentException.class,
+            () -> Tree.builder().setNodes(TreeNode.builder(0)
+                .setLeftChild(1)
+                .setSplitFeature(1)
+                .setThreshold(randomDouble()),
+                null)
+                .build());
+        assertThat(ex.getMessage(), equalTo("[tree] contains null or missing nodes [1]"));
+    }
+
+    public void testTreeWithCycle() {
+        IllegalArgumentException ex = expectThrows(IllegalArgumentException.class,
+            () -> Tree.builder().setNodes(TreeNode.builder(0)
+                .setLeftChild(1)
+                .setSplitFeature(1)
+                .setThreshold(randomDouble()),
+                TreeNode.builder(0)
+                .setLeftChild(0)
+                .setSplitFeature(1)
+                .setThreshold(randomDouble()))
+                .build());
+        assertThat(ex.getMessage(), equalTo("[tree] contains cycle at node 0"));
+    }
+
+    private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
+        return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
+    }
+}
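The three failure messages asserted above outline the structural validation the Tree builder performs: a non-null root, no child pointers to null or absent slots, and no cycles. One way such a walk could look (purely illustrative; the getLeftChild()/getRightChild() accessors on the built node are assumed, and the real implementation may differ):

    // uses java.util.ArrayDeque, Deque, HashSet, Set, TreeSet
    void validateStructure(List<TreeNode> nodes) {
        if (nodes.isEmpty() || nodes.get(0) == null) {
            throw new IllegalArgumentException("[tree] must have non-null root node.");
        }
        Set<Integer> visited = new HashSet<>();
        Set<Integer> missing = new TreeSet<>();              // sorted, so it prints as [1]
        Deque<Integer> toVisit = new ArrayDeque<>();
        toVisit.push(0);
        while (toVisit.isEmpty() == false) {
            int current = toVisit.pop();
            if (visited.add(current) == false) {
                throw new IllegalArgumentException("[tree] contains cycle at node " + current);
            }
            TreeNode node = current < nodes.size() ? nodes.get(current) : null;
            if (node == null) {
                missing.add(current);                        // child points at a null or absent slot
                continue;
            }
            if (node.getLeftChild() != null) {               // assumed accessors
                toVisit.push(node.getLeftChild());
            }
            if (node.getRightChild() != null) {
                toVisit.push(node.getRightChild());
            }
        }
        if (missing.isEmpty() == false) {
            throw new IllegalArgumentException("[tree] contains null or missing nodes " + missing);
        }
    }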