From 3eeaf352d1baf50ae40dd85327472f07ea58f32b Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 24 Sep 2021 17:34:21 -0400 Subject: [PATCH 1/8] Stubbing out onnx provenance serialization. --- Core/pom.xml | 4 + .../java/org/tribuo/onnx/ONNXExportable.java | 23 +++ Interop/ONNX/pom.xml | 5 + .../interop/onnx/ONNXExternalModel.java | 162 ++++++++++++------ pom.xml | 11 +- 5 files changed, 147 insertions(+), 58 deletions(-) diff --git a/Core/pom.xml b/Core/pom.xml index 4d03ecffd..c047464d1 100644 --- a/Core/pom.xml +++ b/Core/pom.xml @@ -37,6 +37,10 @@ com.oracle.labs.olcut olcut-core + + com.oracle.labs.olcut + olcut-config-protobuf + com.google.protobuf protobuf-java diff --git a/Core/src/main/java/org/tribuo/onnx/ONNXExportable.java b/Core/src/main/java/org/tribuo/onnx/ONNXExportable.java index 0c1c54ebe..b10978426 100644 --- a/Core/src/main/java/org/tribuo/onnx/ONNXExportable.java +++ b/Core/src/main/java/org/tribuo/onnx/ONNXExportable.java @@ -17,6 +17,9 @@ package org.tribuo.onnx; import ai.onnx.proto.OnnxMl; +import com.oracle.labs.mlrg.olcut.config.protobuf.ProtoProvenanceSerialization; +import com.oracle.labs.mlrg.olcut.provenance.io.ProvenanceSerialization; +import org.tribuo.provenance.ModelProvenance; import java.io.BufferedOutputStream; import java.io.FileOutputStream; @@ -29,6 +32,17 @@ */ public interface ONNXExportable { + /** + * The provenance serializer. + */ + public static final ProvenanceSerialization SERIALIZER = new ProtoProvenanceSerialization(true); + + /** + * The name of the ONNX metadata field where the provenance information is stored + * in exported models. + */ + public static final String PROVENANCE_METADATA_FIELD = "TRIBUO_PROVENANCE"; + /** * Exports this {@link org.tribuo.Model} as an ONNX protobuf. * @param domain A reverse-DNS name to namespace the model (e.g., org.tribuo.classification.sgd.linear). @@ -74,4 +88,13 @@ default public void saveONNXModel(String domain, long modelVersion, Path outputP } } + /** + * Serializes the model provenance to a String. + * @param provenance The provenance to serialize. + * @return The serialized form of the ModelProvenance. + */ + default public String serializeProvenance(ModelProvenance provenance) { + return SERIALIZER.marshalAndSerialize(provenance); + } + } diff --git a/Interop/ONNX/pom.xml b/Interop/ONNX/pom.xml index f80b135e0..f69d4d472 100644 --- a/Interop/ONNX/pom.xml +++ b/Interop/ONNX/pom.xml @@ -67,6 +67,11 @@ tribuo-util-tokenization ${project.version} + + com.oracle.labs.olcut + olcut-config-protobuf + ${olcut.version} + com.microsoft.onnxruntime onnxruntime diff --git a/Interop/ONNX/src/main/java/org/tribuo/interop/onnx/ONNXExternalModel.java b/Interop/ONNX/src/main/java/org/tribuo/interop/onnx/ONNXExternalModel.java index a5dd3ed03..d5a77cd9c 100644 --- a/Interop/ONNX/src/main/java/org/tribuo/interop/onnx/ONNXExternalModel.java +++ b/Interop/ONNX/src/main/java/org/tribuo/interop/onnx/ONNXExternalModel.java @@ -22,7 +22,10 @@ import ai.onnxruntime.OrtEnvironment; import ai.onnxruntime.OrtException; import ai.onnxruntime.OrtSession; +import com.oracle.labs.mlrg.olcut.config.protobuf.ProtoProvenanceSerialization; import com.oracle.labs.mlrg.olcut.provenance.Provenance; +import com.oracle.labs.mlrg.olcut.provenance.io.ProvenanceSerialization; +import com.oracle.labs.mlrg.olcut.provenance.io.ProvenanceSerializationException; import com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance; import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance; import com.oracle.labs.mlrg.olcut.util.Pair; @@ -37,6 +40,7 @@ import org.tribuo.interop.ExternalModel; import org.tribuo.interop.ExternalTrainerProvenance; import org.tribuo.math.la.SparseVector; +import org.tribuo.onnx.ONNXExportable; import org.tribuo.provenance.DatasetProvenance; import org.tribuo.provenance.ModelProvenance; @@ -52,6 +56,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.logging.Level; import java.util.logging.Logger; @@ -90,8 +95,8 @@ private ONNXExternalModel(String name, ModelProvenance provenance, this.inputName = inputName; this.featureTransformer = featureTransformer; this.outputTransformer = outputTransformer; - this.env = OrtEnvironment.getEnvironment("tribuo-"+name); - this.session = env.createSession(modelArray,options); + this.env = OrtEnvironment.getEnvironment("tribuo-" + name); + this.session = env.createSession(modelArray, options); } private ONNXExternalModel(String name, ModelProvenance provenance, @@ -99,21 +104,22 @@ private ONNXExternalModel(String name, ModelProvenance provenance, int[] featureForwardMapping, int[] featureBackwardMapping, byte[] modelArray, OrtSession.SessionOptions options, String inputName, ExampleTransformer featureTransformer, OutputTransformer outputTransformer) throws OrtException { - super(name,provenance,featureIDMap,outputIDInfo,featureForwardMapping,featureBackwardMapping, + super(name, provenance, featureIDMap, outputIDInfo, featureForwardMapping, featureBackwardMapping, outputTransformer.generatesProbabilities()); this.modelArray = modelArray; this.options = options; this.inputName = inputName; this.featureTransformer = featureTransformer; this.outputTransformer = outputTransformer; - this.env = OrtEnvironment.getEnvironment("tribuo-"+name); - this.session = env.createSession(modelArray,options); + this.env = OrtEnvironment.getEnvironment("tribuo-" + name); + this.session = env.createSession(modelArray, options); } /** * Closes the session and rebuilds it using the supplied options. *

* Used to select a different backend, or change the number of inference threads etc. + * * @param newOptions The new session options. * @throws OrtException If the model failed to rebuild the session with the supplied options. */ @@ -123,7 +129,7 @@ public synchronized void rebuild(OrtSession.SessionOptions newOptions) throws Or options.close(); } options = newOptions; - env.createSession(modelArray,newOptions); + env.createSession(modelArray, newOptions); } @Override @@ -131,7 +137,7 @@ protected OnnxTensor convertFeatures(SparseVector input) { try { return featureTransformer.transform(env, input); } catch (OrtException e) { - throw new IllegalStateException("Failed to construct input OnnxTensor",e); + throw new IllegalStateException("Failed to construct input OnnxTensor", e); } } @@ -140,7 +146,7 @@ protected OnnxTensor convertFeaturesList(List input) { try { return featureTransformer.transform(env, input); } catch (OrtException e) { - throw new IllegalStateException("Failed to construct input OnnxTensor",e); + throw new IllegalStateException("Failed to construct input OnnxTensor", e); } } @@ -148,6 +154,7 @@ protected OnnxTensor convertFeaturesList(List input) { * Runs the session to make a prediction. *

* Closes the input tensor after the prediction has been made. + * * @param input The input in the external model's format. * @return A tensor representing the output. */ @@ -155,29 +162,30 @@ protected OnnxTensor convertFeaturesList(List input) { protected List externalPrediction(OnnxTensor input) { try { // Note the output of the session is closed by the conversion methods, and should not be closed by the result object. - OrtSession.Result output = session.run(Collections.singletonMap(inputName,input)); + OrtSession.Result output = session.run(Collections.singletonMap(inputName, input)); input.close(); ArrayList outputs = new ArrayList<>(); - for (Map.Entry v : output) { + for (Map.Entry v : output) { outputs.add(v.getValue()); } return outputs; } catch (OrtException e) { - throw new IllegalStateException("Failed to execute ONNX model",e); + throw new IllegalStateException("Failed to execute ONNX model", e); } } /** * Converts a tensor into a prediction. * Closes the output tensor after it's been converted. - * @param output The output of the external model. + * + * @param output The output of the external model. * @param numValidFeatures The number of valid features in the input. - * @param example The input example, used to construct the Prediction. + * @param example The input example, used to construct the Prediction. * @return A {@link Prediction} representing this tensor output. */ @Override protected Prediction convertOutput(List output, int numValidFeatures, Example example) { - Prediction pred = outputTransformer.transformToPrediction(output,outputIDInfo,numValidFeatures,example); + Prediction pred = outputTransformer.transformToPrediction(output, outputIDInfo, numValidFeatures, example); OnnxValue.close(output); return pred; } @@ -185,14 +193,15 @@ protected Prediction convertOutput(List output, int numValidFeatur /** * Converts a tensor into a prediction. * Closes the output tensor after it's been converted. - * @param output The output of the external model. + * + * @param output The output of the external model. * @param numValidFeatures An array with the number of valid features in each example. - * @param examples The input examples, used to construct the Predictions. + * @param examples The input examples, used to construct the Predictions. * @return A list of {@link Prediction} representing this tensor output. */ @Override protected List> convertOutput(List output, int[] numValidFeatures, List> examples) { - List> predictions = outputTransformer.transformToBatchPrediction(output,outputIDInfo,numValidFeatures,examples); + List> predictions = outputTransformer.transformToBatchPrediction(output, outputIDInfo, numValidFeatures, examples); OnnxValue.close(output); return predictions; } @@ -204,13 +213,13 @@ public Map>> getTopFeatures(int n) { @Override protected synchronized Model copy(String newName, ModelProvenance newProvenance) { - byte[] newModelArray = Arrays.copyOf(modelArray,modelArray.length); + byte[] newModelArray = Arrays.copyOf(modelArray, modelArray.length); try { return new ONNXExternalModel<>(newName, newProvenance, featureIDMap, outputIDInfo, featureForwardMapping, featureBackwardMapping, newModelArray, options, inputName, featureTransformer, outputTransformer); } catch (OrtException e) { - throw new IllegalStateException("Failed to copy ONNX model",e); + throw new IllegalStateException("Failed to copy ONNX model", e); } } @@ -220,7 +229,7 @@ public void close() { try { session.close(); } catch (OrtException e) { - logger.log(Level.SEVERE,"Exception thrown when closing session",e); + logger.log(Level.SEVERE, "Exception thrown when closing session", e); } } if (options != null) { @@ -230,54 +239,97 @@ public void close() { try { env.close(); } catch (OrtException e) { - logger.log(Level.SEVERE,"Exception thrown when closing environment",e); + logger.log(Level.SEVERE, "Exception thrown when closing environment", e); } } } + /** + * Returns the model provenance from the ONNX model if that + * model was trained in Tribuo. + *

+ * Tribuo's ONNX export functionality stores the model provenance inside the + * ONNX file in the metadata field {@link ONNXExportable#PROVENANCE_METADATA_FIELD}, + * and this method provides the access point for it. + *

+ * Note it is different from the {@link Model#getProvenance()} call which + * returns information about the ONNX file itself, and when the {@code ONNXExternalModel} + * was created. It does not replace that provenance because instantiating this provenance + * may require classes which are not present on the classpath at deployment time. + * + * @return The model provenance from the original Tribuo training run, if it exists, and + * returns {@link Optional#empty()} otherwise. + */ + public Optional getTribuoProvenance() { + try { + OnnxModelMetadata metadata = session.getMetadata(); + Optional value = metadata.getCustomMetadataValue(ONNXExportable.PROVENANCE_METADATA_FIELD); + if (value.isPresent()) { + Provenance prov = ONNXExportable.SERIALIZER.deserializeAndUnmarshal(value.get()); + + if (prov instanceof ModelProvenance) { + return Optional.of((ModelProvenance) prov); + } else { + logger.log(Level.WARNING, "Found invalid provenance object, " + prov.toString()); + return Optional.empty(); + } + } else { + return Optional.empty(); + } + } catch (OrtException e) { + logger.log(Level.WARNING,"ORTException when reading session metadata",e); + return Optional.empty(); + } catch (ProvenanceSerializationException e) { + logger.log(Level.WARNING, "Failed to parse provenance from value.",e); + return Optional.empty(); + } + } + /** * Creates an {@code ONNXExternalModel} by loading the model from disk. - * @param factory The output factory to use. - * @param featureMapping The feature mapping between Tribuo names and ONNX integer ids. - * @param outputMapping The output mapping between Tribuo outputs and ONNX integer ids. + * + * @param factory The output factory to use. + * @param featureMapping The feature mapping between Tribuo names and ONNX integer ids. + * @param outputMapping The output mapping between Tribuo outputs and ONNX integer ids. * @param featureTransformer The transformation function for the features. - * @param outputTransformer The transformation function for the outputs. - * @param opts The session options for the ONNX model. - * @param filename The model path. - * @param inputName The name of the input node. - * @param The type of the output. + * @param outputTransformer The transformation function for the outputs. + * @param opts The session options for the ONNX model. + * @param filename The model path. + * @param inputName The name of the input node. + * @param The type of the output. * @return An ONNXExternalModel ready to score new inputs. * @throws OrtException If the onnx-runtime native library call failed. */ public static > ONNXExternalModel createOnnxModel(OutputFactory factory, Map featureMapping, - Map outputMapping, + Map outputMapping, ExampleTransformer featureTransformer, OutputTransformer outputTransformer, OrtSession.SessionOptions opts, String filename, String inputName) throws OrtException { Path path = Paths.get(filename); - return createOnnxModel(factory,featureMapping,outputMapping,featureTransformer,outputTransformer, - opts,path,inputName); + return createOnnxModel(factory, featureMapping, outputMapping, featureTransformer, outputTransformer, + opts, path, inputName); } /** * Creates an {@code ONNXExternalModel} by loading the model from disk. - * @param factory The output factory to use. - * @param featureMapping The feature mapping between Tribuo names and ONNX integer ids. - * @param outputMapping The output mapping between Tribuo outputs and ONNX integer ids. + * + * @param factory The output factory to use. + * @param featureMapping The feature mapping between Tribuo names and ONNX integer ids. + * @param outputMapping The output mapping between Tribuo outputs and ONNX integer ids. * @param featureTransformer The transformation function for the features. - * @param outputTransformer The transformation function for the outputs. - * @param opts The session options for the ONNX model. - * @param path The model path. - * @param inputName The name of the input node. - * @param The type of the output. + * @param outputTransformer The transformation function for the outputs. + * @param opts The session options for the ONNX model. + * @param path The model path. + * @param inputName The name of the input node. + * @param The type of the output. * @return An ONNXExternalModel ready to score new inputs. * @throws OrtException If the onnx-runtime native library call failed. */ public static > ONNXExternalModel createOnnxModel(OutputFactory factory, Map featureMapping, - Map outputMapping, + Map outputMapping, ExampleTransformer featureTransformer, OutputTransformer outputTransformer, OrtSession.SessionOptions opts, @@ -286,30 +338,30 @@ public static > ONNXExternalModel createOnnxModel(OutputF byte[] modelArray = Files.readAllBytes(path); URL provenanceLocation = path.toUri().toURL(); ImmutableFeatureMap featureMap = ExternalModel.createFeatureMap(featureMapping.keySet()); - ImmutableOutputInfo outputInfo = ExternalModel.createOutputInfo(factory,outputMapping); + ImmutableOutputInfo outputInfo = ExternalModel.createOutputInfo(factory, outputMapping); OffsetDateTime now = OffsetDateTime.now(); ExternalTrainerProvenance trainerProvenance = new ExternalTrainerProvenance(provenanceLocation); - DatasetProvenance datasetProvenance = new ExternalDatasetProvenance("unknown-external-data",factory,false,featureMapping.size(),outputMapping.size()); + DatasetProvenance datasetProvenance = new ExternalDatasetProvenance("unknown-external-data", factory, false, featureMapping.size(), outputMapping.size()); HashMap runProvenance = new HashMap<>(); runProvenance.put("input-name", new StringProvenance("input-name", inputName)); try (OrtEnvironment env = OrtEnvironment.getEnvironment(); OrtSession session = env.createSession(modelArray)) { OnnxModelMetadata metadata = session.getMetadata(); - runProvenance.put("model-producer", new StringProvenance("model-producer",metadata.getProducerName())); - runProvenance.put("model-domain", new StringProvenance("model-domain",metadata.getDomain())); - runProvenance.put("model-description", new StringProvenance("model-description",metadata.getDescription())); - runProvenance.put("model-graphname", new StringProvenance("model-graphname",metadata.getGraphName())); - runProvenance.put("model-version", new LongProvenance("model-version",metadata.getVersion())); - for (Map.Entry e : metadata.getCustomMetadata().entrySet()) { - String keyName = "model-metadata-"+e.getKey(); - runProvenance.put(keyName, new StringProvenance(keyName,e.getValue())); + runProvenance.put("model-producer", new StringProvenance("model-producer", metadata.getProducerName())); + runProvenance.put("model-domain", new StringProvenance("model-domain", metadata.getDomain())); + runProvenance.put("model-description", new StringProvenance("model-description", metadata.getDescription())); + runProvenance.put("model-graphname", new StringProvenance("model-graphname", metadata.getGraphName())); + runProvenance.put("model-version", new LongProvenance("model-version", metadata.getVersion())); + for (Map.Entry e : metadata.getCustomMetadata().entrySet()) { + String keyName = "model-metadata-" + e.getKey(); + runProvenance.put(keyName, new StringProvenance(keyName, e.getValue())); } } catch (OrtException e) { throw new IllegalArgumentException("Failed to load model and read metadata from path " + path, e); } - ModelProvenance provenance = new ModelProvenance(ONNXExternalModel.class.getName(),now,datasetProvenance,trainerProvenance,runProvenance); - return new ONNXExternalModel<>("external-model",provenance,featureMap,outputInfo, - featureMapping,modelArray,opts,inputName,featureTransformer,outputTransformer); + ModelProvenance provenance = new ModelProvenance(ONNXExternalModel.class.getName(), now, datasetProvenance, trainerProvenance, runProvenance); + return new ONNXExternalModel<>("external-model", provenance, featureMap, outputInfo, + featureMapping, modelArray, opts, inputName, featureTransformer, outputTransformer); } catch (IOException e) { throw new IllegalArgumentException("Unable to load model from path " + path, e); } @@ -320,7 +372,7 @@ private void readObject(java.io.ObjectInputStream in) throws IOException, ClassN try { this.env = OrtEnvironment.getEnvironment(); this.options = new OrtSession.SessionOptions(); - this.session = env.createSession(modelArray,options); + this.session = env.createSession(modelArray, options); } catch (OrtException e) { throw new IllegalStateException("Could not construct ONNX Runtime session during deserialization."); } diff --git a/pom.xml b/pom.xml index 719eba40d..6e3c21dff 100644 --- a/pom.xml +++ b/pom.xml @@ -42,12 +42,12 @@ UTF-8 - 5.1.6 + 5.2.0 2.43 3.25 - 1.7.0 + 1.9.0 0.3.1 1.4.1 @@ -55,7 +55,7 @@ 5.7.1 5.4 3.6.1 - 3.16.0 + 3.17.3 @@ -125,6 +125,11 @@ olcut-core ${olcut.version} + + com.oracle.labs.olcut + olcut-config-protobuf + ${olcut.version} + com.google.protobuf protobuf-java From 7d549574ab2acc49688c4f0add86887075e6ca72 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Thu, 30 Sep 2021 14:27:00 -0400 Subject: [PATCH 2/8] Adding Model.castModel so users don't have an unchecked cast in their code. --- Core/src/main/java/org/tribuo/Model.java | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/Core/src/main/java/org/tribuo/Model.java b/Core/src/main/java/org/tribuo/Model.java index ea14efe85..1f42fbcf8 100644 --- a/Core/src/main/java/org/tribuo/Model.java +++ b/Core/src/main/java/org/tribuo/Model.java @@ -280,7 +280,7 @@ public Model copy() { } /** - * Copies a model, replacing it's provenance and name with the supplied values. + * Copies a model, replacing its provenance and name with the supplied values. *

* Used to provide the provenance removal functionality. * @param newName The new name. @@ -297,5 +297,24 @@ public String toString() { return provenanceOutput; } } + + /** + * Casts the model to the specified output type, assuming it is valid. + *

+ * If it's not valid, throws {@link ClassCastException}. + * @param inputModel The model to cast. + * @param outputType The output type to cast to. + * @param The output type. + * @return The model cast to the correct value. + */ + public static > Model castModel(Model inputModel, Class outputType) { + if (inputModel.validate(outputType)) { + @SuppressWarnings("unchecked") // guarded by validate + Model castedModel = (Model) inputModel; + return castedModel; + } else { + throw new ClassCastException("Attempted to cast model to " + outputType.getName() + " which is not valid for model " + inputModel.toString()); + } + } } From 172f9b0c49a3f66bf4c6ddaadb5ce0de75c37f15 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 8 Oct 2021 12:34:16 -0400 Subject: [PATCH 3/8] Fixing a bug in IndependentMultiLabelTrainer's provenance. --- .../baseline/IndependentMultiLabelTrainer.java | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/MultiLabel/Core/src/main/java/org/tribuo/multilabel/baseline/IndependentMultiLabelTrainer.java b/MultiLabel/Core/src/main/java/org/tribuo/multilabel/baseline/IndependentMultiLabelTrainer.java index a814cca63..83e12d065 100644 --- a/MultiLabel/Core/src/main/java/org/tribuo/multilabel/baseline/IndependentMultiLabelTrainer.java +++ b/MultiLabel/Core/src/main/java/org/tribuo/multilabel/baseline/IndependentMultiLabelTrainer.java @@ -81,8 +81,19 @@ public Model train(Dataset examples, Map> modelsList = new ArrayList<>(); ArrayList