-
Notifications
You must be signed in to change notification settings - Fork 24.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ML] handle new model metadata stream from native process (#59725)
This adds the serialization handling for the new model_metadata object from the native process.
- Loading branch information
Showing
18 changed files
with
834 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
242 changes: 242 additions & 0 deletions
242
...g/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,242 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License; | ||
* you may not use this file except in compliance with the Elastic License. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata; | ||
|
||
import org.elasticsearch.common.Nullable; | ||
import org.elasticsearch.common.ParseField; | ||
import org.elasticsearch.common.io.stream.StreamInput; | ||
import org.elasticsearch.common.io.stream.StreamOutput; | ||
import org.elasticsearch.common.io.stream.Writeable; | ||
import org.elasticsearch.common.xcontent.ConstructingObjectParser; | ||
import org.elasticsearch.common.xcontent.ToXContentObject; | ||
import org.elasticsearch.common.xcontent.XContentBuilder; | ||
import org.elasticsearch.common.xcontent.XContentParser; | ||
|
||
import java.io.IOException; | ||
import java.util.Collections; | ||
import java.util.List; | ||
import java.util.Objects; | ||
|
||
public class TotalFeatureImportance implements ToXContentObject, Writeable { | ||
|
||
private static final String NAME = "total_feature_importance"; | ||
public static final ParseField FEATURE_NAME = new ParseField("feature_name"); | ||
public static final ParseField IMPORTANCE = new ParseField("importance"); | ||
public static final ParseField CLASSES = new ParseField("classes"); | ||
public static final ParseField MEAN_MAGNITUDE = new ParseField("mean_magnitude"); | ||
public static final ParseField MIN = new ParseField("min"); | ||
public static final ParseField MAX = new ParseField("max"); | ||
|
||
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly | ||
public static final ConstructingObjectParser<TotalFeatureImportance, Void> LENIENT_PARSER = createParser(true); | ||
public static final ConstructingObjectParser<TotalFeatureImportance, Void> STRICT_PARSER = createParser(false); | ||
|
||
@SuppressWarnings("unchecked") | ||
private static ConstructingObjectParser<TotalFeatureImportance, Void> createParser(boolean ignoreUnknownFields) { | ||
ConstructingObjectParser<TotalFeatureImportance, Void> parser = new ConstructingObjectParser<>(NAME, | ||
ignoreUnknownFields, | ||
a -> new TotalFeatureImportance((String)a[0], (Importance)a[1], (List<ClassImportance>)a[2])); | ||
parser.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME); | ||
parser.declareObject(ConstructingObjectParser.optionalConstructorArg(), | ||
ignoreUnknownFields ? Importance.LENIENT_PARSER : Importance.STRICT_PARSER, | ||
IMPORTANCE); | ||
parser.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), | ||
ignoreUnknownFields ? ClassImportance.LENIENT_PARSER : ClassImportance.STRICT_PARSER, | ||
CLASSES); | ||
return parser; | ||
} | ||
|
||
public static TotalFeatureImportance fromXContent(XContentParser parser, boolean lenient) throws IOException { | ||
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null); | ||
} | ||
|
||
public final String featureName; | ||
public final Importance importance; | ||
public final List<ClassImportance> classImportances; | ||
|
||
public TotalFeatureImportance(StreamInput in) throws IOException { | ||
this.featureName = in.readString(); | ||
this.importance = in.readOptionalWriteable(Importance::new); | ||
this.classImportances = in.readList(ClassImportance::new); | ||
} | ||
|
||
TotalFeatureImportance(String featureName, @Nullable Importance importance, @Nullable List<ClassImportance> classImportances) { | ||
this.featureName = featureName; | ||
this.importance = importance; | ||
this.classImportances = classImportances == null ? Collections.emptyList() : classImportances; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
out.writeString(featureName); | ||
out.writeOptionalWriteable(importance); | ||
out.writeList(classImportances); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
builder.field(FEATURE_NAME.getPreferredName(), featureName); | ||
if (importance != null) { | ||
builder.field(IMPORTANCE.getPreferredName(), importance); | ||
} | ||
if (classImportances.isEmpty() == false) { | ||
builder.field(CLASSES.getPreferredName(), classImportances); | ||
} | ||
builder.endObject(); | ||
return builder; | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
TotalFeatureImportance that = (TotalFeatureImportance) o; | ||
return Objects.equals(that.importance, importance) | ||
&& Objects.equals(featureName, that.featureName) | ||
&& Objects.equals(classImportances, that.classImportances); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(featureName, importance, classImportances); | ||
} | ||
|
||
public static class Importance implements ToXContentObject, Writeable { | ||
private static final String NAME = "importance"; | ||
|
||
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly | ||
public static final ConstructingObjectParser<Importance, Void> LENIENT_PARSER = createParser(true); | ||
public static final ConstructingObjectParser<Importance, Void> STRICT_PARSER = createParser(false); | ||
|
||
private static ConstructingObjectParser<Importance, Void> createParser(boolean ignoreUnknownFields) { | ||
ConstructingObjectParser<Importance, Void> parser = new ConstructingObjectParser<>(NAME, | ||
ignoreUnknownFields, | ||
a -> new Importance((double)a[0], (double)a[1], (double)a[2])); | ||
parser.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE); | ||
parser.declareDouble(ConstructingObjectParser.constructorArg(), MIN); | ||
parser.declareDouble(ConstructingObjectParser.constructorArg(), MAX); | ||
return parser; | ||
} | ||
|
||
private final double meanMagnitude; | ||
private final double min; | ||
private final double max; | ||
|
||
public Importance(double meanMagnitude, double min, double max) { | ||
this.meanMagnitude = meanMagnitude; | ||
this.min = min; | ||
this.max = max; | ||
} | ||
|
||
public Importance(StreamInput in) throws IOException { | ||
this.meanMagnitude = in.readDouble(); | ||
this.min = in.readDouble(); | ||
this.max = in.readDouble(); | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
Importance that = (Importance) o; | ||
return Double.compare(that.meanMagnitude, meanMagnitude) == 0 && | ||
Double.compare(that.min, min) == 0 && | ||
Double.compare(that.max, max) == 0; | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(meanMagnitude, min, max); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
out.writeDouble(meanMagnitude); | ||
out.writeDouble(min); | ||
out.writeDouble(max); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude); | ||
builder.field(MIN.getPreferredName(), min); | ||
builder.field(MAX.getPreferredName(), max); | ||
builder.endObject(); | ||
return builder; | ||
} | ||
} | ||
|
||
public static class ClassImportance implements ToXContentObject, Writeable { | ||
private static final String NAME = "total_class_importance"; | ||
|
||
public static final ParseField CLASS_NAME = new ParseField("class_name"); | ||
public static final ParseField IMPORTANCE = new ParseField("importance"); | ||
|
||
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly | ||
public static final ConstructingObjectParser<ClassImportance, Void> LENIENT_PARSER = createParser(true); | ||
public static final ConstructingObjectParser<ClassImportance, Void> STRICT_PARSER = createParser(false); | ||
|
||
private static ConstructingObjectParser<ClassImportance, Void> createParser(boolean ignoreUnknownFields) { | ||
ConstructingObjectParser<ClassImportance, Void> parser = new ConstructingObjectParser<>(NAME, | ||
ignoreUnknownFields, | ||
a -> new ClassImportance((String)a[0], (Importance)a[1])); | ||
parser.declareString(ConstructingObjectParser.constructorArg(), CLASS_NAME); | ||
parser.declareObject(ConstructingObjectParser.constructorArg(), | ||
ignoreUnknownFields ? Importance.LENIENT_PARSER : Importance.STRICT_PARSER, | ||
IMPORTANCE); | ||
return parser; | ||
} | ||
|
||
public static ClassImportance fromXContent(XContentParser parser, boolean lenient) throws IOException { | ||
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null); | ||
} | ||
|
||
public final String className; | ||
public final Importance importance; | ||
|
||
public ClassImportance(StreamInput in) throws IOException { | ||
this.className = in.readString(); | ||
this.importance = new Importance(in); | ||
} | ||
|
||
ClassImportance(String className, Importance importance) { | ||
this.className = className; | ||
this.importance = importance; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
out.writeString(className); | ||
importance.writeTo(out); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
builder.field(CLASS_NAME.getPreferredName(), className); | ||
builder.field(IMPORTANCE.getPreferredName(), importance); | ||
builder.endObject(); | ||
return builder; | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
ClassImportance that = (ClassImportance) o; | ||
return Objects.equals(that.importance, importance) && Objects.equals(className, that.className); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(className, importance); | ||
} | ||
|
||
} | ||
} |
112 changes: 112 additions & 0 deletions
112
...org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadata.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License; | ||
* you may not use this file except in compliance with the Elastic License. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata; | ||
|
||
import org.elasticsearch.common.ParseField; | ||
import org.elasticsearch.common.io.stream.StreamInput; | ||
import org.elasticsearch.common.io.stream.StreamOutput; | ||
import org.elasticsearch.common.io.stream.Writeable; | ||
import org.elasticsearch.common.xcontent.ConstructingObjectParser; | ||
import org.elasticsearch.common.xcontent.ToXContentObject; | ||
import org.elasticsearch.common.xcontent.XContentBuilder; | ||
import org.elasticsearch.common.xcontent.XContentParser; | ||
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; | ||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; | ||
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; | ||
|
||
import java.io.IOException; | ||
import java.util.Collections; | ||
import java.util.List; | ||
import java.util.Objects; | ||
|
||
public class TrainedModelMetadata implements ToXContentObject, Writeable { | ||
|
||
public static final String NAME = "trained_model_metadata"; | ||
public static final ParseField TOTAL_FEATURE_IMPORTANCE = new ParseField("total_feature_importance"); | ||
public static final ParseField MODEL_ID = new ParseField("model_id"); | ||
|
||
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly | ||
public static final ConstructingObjectParser<TrainedModelMetadata, Void> LENIENT_PARSER = createParser(true); | ||
public static final ConstructingObjectParser<TrainedModelMetadata, Void> STRICT_PARSER = createParser(false); | ||
|
||
@SuppressWarnings("unchecked") | ||
private static ConstructingObjectParser<TrainedModelMetadata, Void> createParser(boolean ignoreUnknownFields) { | ||
ConstructingObjectParser<TrainedModelMetadata, Void> parser = new ConstructingObjectParser<>(NAME, | ||
ignoreUnknownFields, | ||
a -> new TrainedModelMetadata((String)a[0], (List<TotalFeatureImportance>)a[1])); | ||
parser.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID); | ||
parser.declareObjectArray(ConstructingObjectParser.constructorArg(), | ||
ignoreUnknownFields ? TotalFeatureImportance.LENIENT_PARSER : TotalFeatureImportance.STRICT_PARSER, | ||
TOTAL_FEATURE_IMPORTANCE); | ||
return parser; | ||
} | ||
|
||
public static TrainedModelMetadata fromXContent(XContentParser parser, boolean lenient) throws IOException { | ||
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null); | ||
} | ||
|
||
public static String docId(String modelId) { | ||
return NAME + "-" + modelId; | ||
} | ||
|
||
private final List<TotalFeatureImportance> totalFeatureImportances; | ||
private final String modelId; | ||
|
||
public TrainedModelMetadata(StreamInput in) throws IOException { | ||
this.modelId = in.readString(); | ||
this.totalFeatureImportances = in.readList(TotalFeatureImportance::new); | ||
} | ||
|
||
public TrainedModelMetadata(String modelId, List<TotalFeatureImportance> totalFeatureImportances) { | ||
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); | ||
this.totalFeatureImportances = Collections.unmodifiableList(totalFeatureImportances); | ||
} | ||
|
||
public String getModelId() { | ||
return modelId; | ||
} | ||
|
||
public String getDocId() { | ||
return docId(modelId); | ||
} | ||
|
||
public List<TotalFeatureImportance> getTotalFeatureImportances() { | ||
return totalFeatureImportances; | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
TrainedModelMetadata that = (TrainedModelMetadata) o; | ||
return Objects.equals(totalFeatureImportances, that.totalFeatureImportances) && | ||
Objects.equals(modelId, that.modelId); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(totalFeatureImportances, modelId); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
out.writeString(modelId); | ||
out.writeList(totalFeatureImportances); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) { | ||
builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME); | ||
} | ||
builder.field(MODEL_ID.getPreferredName(), modelId); | ||
builder.field(TOTAL_FEATURE_IMPORTANCE.getPreferredName(), totalFeatureImportances); | ||
builder.endObject(); | ||
return builder; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.