From 6e5ead97ac279b26ccd7819435b91e8ec321702b Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 10 Dec 2021 14:18:27 -0500 Subject: [PATCH] Moving ONNX export utils out into a separate module. --- .../ensemble/FullyWeightedVotingCombiner.java | 8 +-- .../ensemble/VotingCombiner.java | 8 +-- .../LibLinearClassificationModel.java | 16 ++--- .../libsvm/LibSVMClassificationModel.java | 16 ++--- .../sgd/fm/FMClassificationModel.java | 4 +- .../sgd/linear/LinearSGDModel.java | 4 +- .../tribuo/common/sgd/AbstractFMModel.java | 15 ++--- .../common/sgd/AbstractLinearSGDModel.java | 15 ++--- Core/pom.xml | 5 ++ .../org/tribuo/{onnx => }/ONNXExportable.java | 54 ++++++++++++++-- .../org/tribuo/ensemble/EnsembleCombiner.java | 4 +- .../ensemble/WeightedEnsembleModel.java | 16 ++--- .../interop/onnx/ONNXExternalModel.java | 4 +- .../org/tribuo/math/onnx/ONNXMathUtils.java | 4 +- .../org/tribuo/math/util/ExpNormalizer.java | 4 +- .../org/tribuo/math/util/NoopNormalizer.java | 2 +- .../java/org/tribuo/math/util/Normalizer.java | 8 +-- .../tribuo/math/util/SigmoidNormalizer.java | 4 +- .../tribuo/math/util/VectorNormalizer.java | 5 +- .../ensemble/MultiLabelVotingCombiner.java | 10 +-- .../multilabel/sgd/fm/FMMultiLabelModel.java | 4 +- .../multilabel/sgd/linear/LinearSGDModel.java | 4 +- .../ensemble/AveragingCombiner.java | 13 ++-- .../liblinear/LibLinearRegressionModel.java | 16 ++--- .../libsvm/LibSVMRegressionModel.java | 16 ++--- .../regression/sgd/fm/FMRegressionModel.java | 8 +-- .../regression/sgd/linear/LinearSGDModel.java | 4 +- .../regression/slm/SparseLinearModel.java | 16 ++--- Util/ONNXExport/pom.xml | 63 +++++++++++++++++++ .../src/main/java/ai/onnx/proto/NOTICE.md | 0 .../src/main/java/ai/onnx/proto/OnnxMl.java | 0 .../org/tribuo/util}/onnx/ONNXAttribute.java | 2 +- .../org/tribuo/util}/onnx/ONNXContext.java | 49 ++++----------- .../tribuo/util}/onnx/ONNXInitializer.java | 2 +- .../java/org/tribuo/util}/onnx/ONNXNode.java | 2 +- .../org/tribuo/util}/onnx/ONNXOperators.java | 4 +- .../tribuo/util}/onnx/ONNXPlaceholder.java | 2 +- .../java/org/tribuo/util}/onnx/ONNXRef.java | 2 +- .../java/org/tribuo/util}/onnx/ONNXShape.java | 2 +- .../java/org/tribuo/util}/onnx/ONNXUtils.java | 2 +- .../org/tribuo/util}/onnx/package-info.java | 12 ++-- .../ONNXExport/src/main/resources/LICENSE.txt | 1 + .../main/resources/THIRD_PARTY_LICENSES.txt | 1 + Util/pom.xml | 6 ++ distribution/pom.xml | 5 ++ pom.xml | 9 +-- 46 files changed, 273 insertions(+), 178 deletions(-) rename Core/src/main/java/org/tribuo/{onnx => }/ONNXExportable.java (55%) create mode 100644 Util/ONNXExport/pom.xml rename {Core => Util/ONNXExport}/src/main/java/ai/onnx/proto/NOTICE.md (100%) rename {Core => Util/ONNXExport}/src/main/java/ai/onnx/proto/OnnxMl.java (100%) rename {Core/src/main/java/org/tribuo => Util/ONNXExport/src/main/java/org/tribuo/util}/onnx/ONNXAttribute.java (99%) rename {Core/src/main/java/org/tribuo => Util/ONNXExport/src/main/java/org/tribuo/util}/onnx/ONNXContext.java (89%) rename {Core/src/main/java/org/tribuo => Util/ONNXExport/src/main/java/org/tribuo/util}/onnx/ONNXInitializer.java (97%) rename {Core/src/main/java/org/tribuo => Util/ONNXExport/src/main/java/org/tribuo/util}/onnx/ONNXNode.java (97%) rename {Core/src/main/java/org/tribuo => Util/ONNXExport/src/main/java/org/tribuo/util}/onnx/ONNXOperators.java (99%) rename {Core/src/main/java/org/tribuo => Util/ONNXExport/src/main/java/org/tribuo/util}/onnx/ONNXPlaceholder.java (97%) rename {Core/src/main/java/org/tribuo => Util/ONNXExport/src/main/java/org/tribuo/util}/onnx/ONNXRef.java (99%) rename {Core/src/main/java/org/tribuo => Util/ONNXExport/src/main/java/org/tribuo/util}/onnx/ONNXShape.java (99%) rename {Core/src/main/java/org/tribuo => Util/ONNXExport/src/main/java/org/tribuo/util}/onnx/ONNXUtils.java (99%) rename {Core/src/main/java/org/tribuo => Util/ONNXExport/src/main/java/org/tribuo/util}/onnx/package-info.java (58%) create mode 120000 Util/ONNXExport/src/main/resources/LICENSE.txt create mode 120000 Util/ONNXExport/src/main/resources/THIRD_PARTY_LICENSES.txt diff --git a/Classification/Core/src/main/java/org/tribuo/classification/ensemble/FullyWeightedVotingCombiner.java b/Classification/Core/src/main/java/org/tribuo/classification/ensemble/FullyWeightedVotingCombiner.java index 6ec741356..527cef994 100644 --- a/Classification/Core/src/main/java/org/tribuo/classification/ensemble/FullyWeightedVotingCombiner.java +++ b/Classification/Core/src/main/java/org/tribuo/classification/ensemble/FullyWeightedVotingCombiner.java @@ -23,10 +23,10 @@ import org.tribuo.Prediction; import org.tribuo.classification.Label; import org.tribuo.ensemble.EnsembleCombiner; -import org.tribuo.onnx.ONNXNode; -import org.tribuo.onnx.ONNXOperators; -import org.tribuo.onnx.ONNXRef; -import org.tribuo.onnx.ONNXInitializer; +import org.tribuo.util.onnx.ONNXInitializer; +import org.tribuo.util.onnx.ONNXNode; +import org.tribuo.util.onnx.ONNXOperators; +import org.tribuo.util.onnx.ONNXRef; import java.util.Collections; import java.util.HashMap; diff --git a/Classification/Core/src/main/java/org/tribuo/classification/ensemble/VotingCombiner.java b/Classification/Core/src/main/java/org/tribuo/classification/ensemble/VotingCombiner.java index 991e42cdf..f73ff135f 100644 --- a/Classification/Core/src/main/java/org/tribuo/classification/ensemble/VotingCombiner.java +++ b/Classification/Core/src/main/java/org/tribuo/classification/ensemble/VotingCombiner.java @@ -23,10 +23,10 @@ import org.tribuo.Prediction; import org.tribuo.classification.Label; import org.tribuo.ensemble.EnsembleCombiner; -import org.tribuo.onnx.ONNXNode; -import org.tribuo.onnx.ONNXOperators; -import org.tribuo.onnx.ONNXRef; -import org.tribuo.onnx.ONNXInitializer; +import org.tribuo.util.onnx.ONNXInitializer; +import org.tribuo.util.onnx.ONNXNode; +import org.tribuo.util.onnx.ONNXOperators; +import org.tribuo.util.onnx.ONNXRef; import java.util.Collections; import java.util.HashMap; diff --git a/Classification/LibLinear/src/main/java/org/tribuo/classification/liblinear/LibLinearClassificationModel.java b/Classification/LibLinear/src/main/java/org/tribuo/classification/liblinear/LibLinearClassificationModel.java index 1368183ec..462e23b09 100644 --- a/Classification/LibLinear/src/main/java/org/tribuo/classification/liblinear/LibLinearClassificationModel.java +++ b/Classification/LibLinear/src/main/java/org/tribuo/classification/liblinear/LibLinearClassificationModel.java @@ -26,18 +26,18 @@ import org.tribuo.ImmutableFeatureMap; import org.tribuo.ImmutableOutputInfo; import org.tribuo.Model; +import org.tribuo.ONNXExportable; import org.tribuo.Prediction; import org.tribuo.classification.Label; import org.tribuo.common.liblinear.LibLinearModel; import org.tribuo.common.liblinear.LibLinearTrainer; -import org.tribuo.onnx.ONNXContext; -import org.tribuo.onnx.ONNXExportable; -import org.tribuo.onnx.ONNXNode; -import org.tribuo.onnx.ONNXOperators; -import org.tribuo.onnx.ONNXPlaceholder; -import org.tribuo.onnx.ONNXRef; -import org.tribuo.onnx.ONNXInitializer; import org.tribuo.provenance.ModelProvenance; +import org.tribuo.util.onnx.ONNXContext; +import org.tribuo.util.onnx.ONNXInitializer; +import org.tribuo.util.onnx.ONNXNode; +import org.tribuo.util.onnx.ONNXOperators; +import org.tribuo.util.onnx.ONNXPlaceholder; +import org.tribuo.util.onnx.ONNXRef; import java.util.ArrayList; import java.util.Arrays; @@ -298,7 +298,7 @@ public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) { // Build graph writeONNXGraph(input).assignTo(output); - return onnx.model(domain, modelVersion, this); + return ONNXExportable.buildModel(onnx, domain, modelVersion, this); } @Override diff --git a/Classification/LibSVM/src/main/java/org/tribuo/classification/libsvm/LibSVMClassificationModel.java b/Classification/LibSVM/src/main/java/org/tribuo/classification/libsvm/LibSVMClassificationModel.java index 0c1b40e38..b6f256df3 100644 --- a/Classification/LibSVM/src/main/java/org/tribuo/classification/libsvm/LibSVMClassificationModel.java +++ b/Classification/LibSVM/src/main/java/org/tribuo/classification/libsvm/LibSVMClassificationModel.java @@ -24,20 +24,20 @@ import org.tribuo.Example; import org.tribuo.ImmutableFeatureMap; import org.tribuo.ImmutableOutputInfo; +import org.tribuo.ONNXExportable; import org.tribuo.Prediction; import org.tribuo.classification.Label; import org.tribuo.common.libsvm.KernelType; import org.tribuo.common.libsvm.LibSVMModel; import org.tribuo.common.libsvm.LibSVMTrainer; -import org.tribuo.onnx.ONNXContext; -import org.tribuo.onnx.ONNXExportable; -import org.tribuo.onnx.ONNXInitializer; -import org.tribuo.onnx.ONNXNode; -import org.tribuo.onnx.ONNXOperators; -import org.tribuo.onnx.ONNXPlaceholder; -import org.tribuo.onnx.ONNXRef; import org.tribuo.provenance.ModelProvenance; import org.tribuo.util.Util; +import org.tribuo.util.onnx.ONNXContext; +import org.tribuo.util.onnx.ONNXInitializer; +import org.tribuo.util.onnx.ONNXNode; +import org.tribuo.util.onnx.ONNXOperators; +import org.tribuo.util.onnx.ONNXPlaceholder; +import org.tribuo.util.onnx.ONNXRef; import java.util.ArrayList; import java.util.Arrays; @@ -176,7 +176,7 @@ public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) { onnx.setName("Classification-LibSVM"); writeONNXGraph(input).assignTo(output); - return onnx.model(domain, modelVersion, this); + return ONNXExportable.buildModel(onnx, domain, modelVersion, this); } @Override diff --git a/Classification/SGD/src/main/java/org/tribuo/classification/sgd/fm/FMClassificationModel.java b/Classification/SGD/src/main/java/org/tribuo/classification/sgd/fm/FMClassificationModel.java index e82f9acbf..42f5a65b9 100644 --- a/Classification/SGD/src/main/java/org/tribuo/classification/sgd/fm/FMClassificationModel.java +++ b/Classification/SGD/src/main/java/org/tribuo/classification/sgd/fm/FMClassificationModel.java @@ -19,15 +19,15 @@ import org.tribuo.Example; import org.tribuo.ImmutableFeatureMap; import org.tribuo.ImmutableOutputInfo; +import org.tribuo.ONNXExportable; import org.tribuo.Prediction; import org.tribuo.classification.Label; import org.tribuo.common.sgd.AbstractFMModel; import org.tribuo.common.sgd.FMParameters; import org.tribuo.math.la.DenseVector; import org.tribuo.math.util.VectorNormalizer; -import org.tribuo.onnx.ONNXExportable; -import org.tribuo.onnx.ONNXNode; import org.tribuo.provenance.ModelProvenance; +import org.tribuo.util.onnx.ONNXNode; import java.util.LinkedHashMap; import java.util.Map; diff --git a/Classification/SGD/src/main/java/org/tribuo/classification/sgd/linear/LinearSGDModel.java b/Classification/SGD/src/main/java/org/tribuo/classification/sgd/linear/LinearSGDModel.java index dd36db3e4..1157086a1 100644 --- a/Classification/SGD/src/main/java/org/tribuo/classification/sgd/linear/LinearSGDModel.java +++ b/Classification/SGD/src/main/java/org/tribuo/classification/sgd/linear/LinearSGDModel.java @@ -19,6 +19,7 @@ import org.tribuo.Example; import org.tribuo.ImmutableFeatureMap; import org.tribuo.ImmutableOutputInfo; +import org.tribuo.ONNXExportable; import org.tribuo.Prediction; import org.tribuo.classification.Label; import org.tribuo.common.sgd.AbstractLinearSGDModel; @@ -26,9 +27,8 @@ import org.tribuo.math.la.DenseMatrix; import org.tribuo.math.la.DenseVector; import org.tribuo.math.util.VectorNormalizer; -import org.tribuo.onnx.ONNXExportable; -import org.tribuo.onnx.ONNXNode; import org.tribuo.provenance.ModelProvenance; +import org.tribuo.util.onnx.ONNXNode; import java.io.IOException; import java.util.LinkedHashMap; diff --git a/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractFMModel.java b/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractFMModel.java index 8e57e3f30..ee706522b 100644 --- a/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractFMModel.java +++ b/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractFMModel.java @@ -22,6 +22,7 @@ import org.tribuo.Excuse; import org.tribuo.ImmutableFeatureMap; import org.tribuo.ImmutableOutputInfo; +import org.tribuo.ONNXExportable; import org.tribuo.Output; import org.tribuo.math.la.DenseMatrix; import org.tribuo.math.la.DenseVector; @@ -29,13 +30,13 @@ import org.tribuo.math.la.SGDVector; import org.tribuo.math.la.Tensor; import org.tribuo.math.onnx.ONNXMathUtils; -import org.tribuo.onnx.ONNXContext; -import org.tribuo.onnx.ONNXNode; -import org.tribuo.onnx.ONNXOperators; -import org.tribuo.onnx.ONNXPlaceholder; -import org.tribuo.onnx.ONNXRef; -import org.tribuo.onnx.ONNXInitializer; import org.tribuo.provenance.ModelProvenance; +import org.tribuo.util.onnx.ONNXContext; +import org.tribuo.util.onnx.ONNXInitializer; +import org.tribuo.util.onnx.ONNXNode; +import org.tribuo.util.onnx.ONNXOperators; +import org.tribuo.util.onnx.ONNXPlaceholder; +import org.tribuo.util.onnx.ONNXRef; import java.util.ArrayList; import java.util.Arrays; @@ -267,7 +268,7 @@ public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) { ONNXPlaceholder input = onnx.floatInput("input", featureIDMap.size()); ONNXPlaceholder output = onnx.floatOutput("output", outputIDInfo.size()); writeONNXGraph(input).assignTo(output); - return onnx.model(domain, modelVersion, this); + return ONNXExportable.buildModel(onnx, domain, modelVersion, this); } } diff --git a/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractLinearSGDModel.java b/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractLinearSGDModel.java index 3d6d28d53..e47f5a020 100644 --- a/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractLinearSGDModel.java +++ b/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractLinearSGDModel.java @@ -24,18 +24,19 @@ import org.tribuo.ImmutableFeatureMap; import org.tribuo.ImmutableOutputInfo; import org.tribuo.Model; +import org.tribuo.ONNXExportable; import org.tribuo.Output; import org.tribuo.Prediction; import org.tribuo.math.LinearParameters; import org.tribuo.math.la.DenseMatrix; import org.tribuo.math.la.Matrix; -import org.tribuo.onnx.ONNXContext; -import org.tribuo.onnx.ONNXNode; -import org.tribuo.onnx.ONNXOperators; -import org.tribuo.onnx.ONNXPlaceholder; -import org.tribuo.onnx.ONNXRef; -import org.tribuo.onnx.ONNXInitializer; import org.tribuo.provenance.ModelProvenance; +import org.tribuo.util.onnx.ONNXContext; +import org.tribuo.util.onnx.ONNXInitializer; +import org.tribuo.util.onnx.ONNXNode; +import org.tribuo.util.onnx.ONNXOperators; +import org.tribuo.util.onnx.ONNXPlaceholder; +import org.tribuo.util.onnx.ONNXRef; import java.util.ArrayList; import java.util.Arrays; @@ -213,7 +214,7 @@ public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) { ONNXPlaceholder input = onnx.floatInput("input", featureIDMap.size()); ONNXPlaceholder output = onnx.floatOutput("output", outputIDInfo.size()); writeONNXGraph(input).assignTo(output); - return onnx.model(domain, modelVersion, this); + return ONNXExportable.buildModel(onnx, domain, modelVersion, this); } } diff --git a/Core/pom.xml b/Core/pom.xml index c047464d1..875a7d665 100644 --- a/Core/pom.xml +++ b/Core/pom.xml @@ -37,6 +37,11 @@ com.oracle.labs.olcut olcut-core + + org.tribuo + tribuo-util-onnx + ${project.version} + com.oracle.labs.olcut olcut-config-protobuf diff --git a/Core/src/main/java/org/tribuo/onnx/ONNXExportable.java b/Core/src/main/java/org/tribuo/ONNXExportable.java similarity index 55% rename from Core/src/main/java/org/tribuo/onnx/ONNXExportable.java rename to Core/src/main/java/org/tribuo/ONNXExportable.java index b88a3d46e..e1c8130d7 100644 --- a/Core/src/main/java/org/tribuo/onnx/ONNXExportable.java +++ b/Core/src/main/java/org/tribuo/ONNXExportable.java @@ -14,12 +14,17 @@ * limitations under the License. */ -package org.tribuo.onnx; +package org.tribuo; import ai.onnx.proto.OnnxMl; import com.oracle.labs.mlrg.olcut.config.protobuf.ProtoProvenanceSerialization; +import com.oracle.labs.mlrg.olcut.provenance.Provenancable; import com.oracle.labs.mlrg.olcut.provenance.io.ProvenanceSerialization; import org.tribuo.provenance.ModelProvenance; +import org.tribuo.util.onnx.ONNXContext; +import org.tribuo.util.onnx.ONNXNode; +import org.tribuo.util.onnx.ONNXOperators; +import org.tribuo.util.onnx.ONNXRef; import java.io.BufferedOutputStream; import java.io.FileOutputStream; @@ -33,6 +38,12 @@ * Tribuo models export with a single input of size [-1, numFeatures] and a * single output of size [-1, numOutputDimensions]. The first dimension in both * is defined to be an unbound dimension called "batch", which denotes the batch size. + *

+ * ONNX exported models use floats where Tribuo uses doubles, this is due + * to comparatively poor support for fp64 in ONNX deployment environments + * as compared to fp32. In addition, fp32 executes better on the various + * accelerator backends available in + * ONNX Runtime. */ public interface ONNXExportable { @@ -47,9 +58,39 @@ public interface ONNXExportable { */ public static final String PROVENANCE_METADATA_FIELD = "TRIBUO_PROVENANCE"; + /** + * Creates an ONNX model protobuf for the supplied context. + * + * @param onnxContext The context which contains the ONNX graph. + * @param domain Domain for the produced model. + * @param modelVersion Model version for the produced model. + * @param model Provenanced Tribuo model from which this model is derived - the DocString and Tribuo Provenance data + * from this model will be written into the ONNX Model proto. + * @param The type of the provenanced model. + * @return An ONNX model proto of the graph represented by the supplied ONNXContext. + */ + public static > OnnxMl.ModelProto buildModel(ONNXContext onnxContext, String domain, long modelVersion, M model) { + return OnnxMl.ModelProto.newBuilder() + .setGraph(onnxContext.buildGraph()) + .setDomain(domain) + .setProducerName("Tribuo") + .setProducerVersion(Tribuo.VERSION) + .setModelVersion(modelVersion) + .addOpsetImport(ONNXOperators.getOpsetProto()) + .setIrVersion(6) + .setDocString(model.toString()) + .addMetadataProps(OnnxMl.StringStringEntryProto + .newBuilder() + .setKey(PROVENANCE_METADATA_FIELD) + .setValue(SERIALIZER.marshalAndSerialize(model.getProvenance())) + .build()) + .build(); + } + /** * Exports this {@link org.tribuo.Model} as an ONNX protobuf. - * @param domain A reverse-DNS name to namespace the model (e.g., org.tribuo.classification.sgd.linear). + * + * @param domain A reverse-DNS name to namespace the model (e.g., org.tribuo.classification.sgd.linear). * @param modelVersion A version number for this model. * @return The ONNX ModelProto representing this Tribuo Model. */ @@ -58,6 +99,7 @@ public interface ONNXExportable { /** * Writes this {@link org.tribuo.Model} into {@link OnnxMl.GraphProto.Builder} inside the input's * {@link ONNXContext}. + * * @param input The input to the model graph. * @return the output node of the model graph. */ @@ -65,13 +107,14 @@ public interface ONNXExportable { /** * Exports this {@link org.tribuo.Model} as an ONNX file. - * @param domain A reverse-DNS name to namespace the model (e.g., org.tribuo.classification.sgd.linear). + * + * @param domain A reverse-DNS name to namespace the model (e.g., org.tribuo.classification.sgd.linear). * @param modelVersion A version number for this model. - * @param outputPath The path to write to. + * @param outputPath The path to write to. * @throws IOException if the file could not be written to. */ default public void saveONNXModel(String domain, long modelVersion, Path outputPath) throws IOException { - OnnxMl.ModelProto proto = exportONNXModel(domain,modelVersion); + OnnxMl.ModelProto proto = exportONNXModel(domain, modelVersion); try (BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(outputPath.toFile()))) { proto.writeTo(bos); } @@ -79,6 +122,7 @@ default public void saveONNXModel(String domain, long modelVersion, Path outputP /** * Serializes the model provenance to a String. + * * @param provenance The provenance to serialize. * @return The serialized form of the ModelProvenance. */ diff --git a/Core/src/main/java/org/tribuo/ensemble/EnsembleCombiner.java b/Core/src/main/java/org/tribuo/ensemble/EnsembleCombiner.java index 7ec788cb6..af3a11e01 100644 --- a/Core/src/main/java/org/tribuo/ensemble/EnsembleCombiner.java +++ b/Core/src/main/java/org/tribuo/ensemble/EnsembleCombiner.java @@ -22,8 +22,8 @@ import org.tribuo.ImmutableOutputInfo; import org.tribuo.Output; import org.tribuo.Prediction; -import org.tribuo.onnx.ONNXNode; -import org.tribuo.onnx.ONNXRef; +import org.tribuo.util.onnx.ONNXNode; +import org.tribuo.util.onnx.ONNXRef; import java.io.Serializable; import java.util.List; diff --git a/Core/src/main/java/org/tribuo/ensemble/WeightedEnsembleModel.java b/Core/src/main/java/org/tribuo/ensemble/WeightedEnsembleModel.java index 0d96d0981..40ff5a1f5 100644 --- a/Core/src/main/java/org/tribuo/ensemble/WeightedEnsembleModel.java +++ b/Core/src/main/java/org/tribuo/ensemble/WeightedEnsembleModel.java @@ -24,18 +24,18 @@ import org.tribuo.ImmutableFeatureMap; import org.tribuo.ImmutableOutputInfo; import org.tribuo.Model; +import org.tribuo.ONNXExportable; import org.tribuo.Output; import org.tribuo.Prediction; -import org.tribuo.onnx.ONNXContext; -import org.tribuo.onnx.ONNXExportable; -import org.tribuo.onnx.ONNXNode; -import org.tribuo.onnx.ONNXOperators; -import org.tribuo.onnx.ONNXPlaceholder; -import org.tribuo.onnx.ONNXRef; -import org.tribuo.onnx.ONNXInitializer; import org.tribuo.provenance.EnsembleModelProvenance; import org.tribuo.provenance.impl.TimestampedTrainerProvenance; import org.tribuo.util.Util; +import org.tribuo.util.onnx.ONNXContext; +import org.tribuo.util.onnx.ONNXInitializer; +import org.tribuo.util.onnx.ONNXNode; +import org.tribuo.util.onnx.ONNXOperators; +import org.tribuo.util.onnx.ONNXPlaceholder; +import org.tribuo.util.onnx.ONNXRef; import java.time.OffsetDateTime; import java.util.ArrayList; @@ -247,7 +247,7 @@ public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) { // Build graph writeONNXGraph(input).assignTo(output); - return onnx.model(domain, modelVersion, this); + return ONNXExportable.buildModel(onnx, domain, modelVersion, this); } @Override diff --git a/Interop/ONNX/src/main/java/org/tribuo/interop/onnx/ONNXExternalModel.java b/Interop/ONNX/src/main/java/org/tribuo/interop/onnx/ONNXExternalModel.java index 046929174..668fc7716 100644 --- a/Interop/ONNX/src/main/java/org/tribuo/interop/onnx/ONNXExternalModel.java +++ b/Interop/ONNX/src/main/java/org/tribuo/interop/onnx/ONNXExternalModel.java @@ -22,9 +22,7 @@ import ai.onnxruntime.OrtEnvironment; import ai.onnxruntime.OrtException; import ai.onnxruntime.OrtSession; -import com.oracle.labs.mlrg.olcut.config.protobuf.ProtoProvenanceSerialization; import com.oracle.labs.mlrg.olcut.provenance.Provenance; -import com.oracle.labs.mlrg.olcut.provenance.io.ProvenanceSerialization; import com.oracle.labs.mlrg.olcut.provenance.io.ProvenanceSerializationException; import com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance; import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance; @@ -33,6 +31,7 @@ import org.tribuo.ImmutableFeatureMap; import org.tribuo.ImmutableOutputInfo; import org.tribuo.Model; +import org.tribuo.ONNXExportable; import org.tribuo.Output; import org.tribuo.OutputFactory; import org.tribuo.Prediction; @@ -40,7 +39,6 @@ import org.tribuo.interop.ExternalModel; import org.tribuo.interop.ExternalTrainerProvenance; import org.tribuo.math.la.SparseVector; -import org.tribuo.onnx.ONNXExportable; import org.tribuo.provenance.DatasetProvenance; import org.tribuo.provenance.ModelProvenance; diff --git a/Math/src/main/java/org/tribuo/math/onnx/ONNXMathUtils.java b/Math/src/main/java/org/tribuo/math/onnx/ONNXMathUtils.java index d3e6c4b87..9203fbfc5 100644 --- a/Math/src/main/java/org/tribuo/math/onnx/ONNXMathUtils.java +++ b/Math/src/main/java/org/tribuo/math/onnx/ONNXMathUtils.java @@ -18,8 +18,8 @@ import org.tribuo.math.la.Matrix; import org.tribuo.math.la.SGDVector; -import org.tribuo.onnx.ONNXContext; -import org.tribuo.onnx.ONNXInitializer; +import org.tribuo.util.onnx.ONNXContext; +import org.tribuo.util.onnx.ONNXInitializer; import java.nio.FloatBuffer; import java.util.Arrays; diff --git a/Math/src/main/java/org/tribuo/math/util/ExpNormalizer.java b/Math/src/main/java/org/tribuo/math/util/ExpNormalizer.java index 1ac4869a8..1f7aadf49 100644 --- a/Math/src/main/java/org/tribuo/math/util/ExpNormalizer.java +++ b/Math/src/main/java/org/tribuo/math/util/ExpNormalizer.java @@ -16,8 +16,8 @@ package org.tribuo.math.util; -import org.tribuo.onnx.ONNXNode; -import org.tribuo.onnx.ONNXOperators; +import org.tribuo.util.onnx.ONNXNode; +import org.tribuo.util.onnx.ONNXOperators; import java.io.Serializable; import java.util.Arrays; diff --git a/Math/src/main/java/org/tribuo/math/util/NoopNormalizer.java b/Math/src/main/java/org/tribuo/math/util/NoopNormalizer.java index d86885f77..b7e2a7398 100644 --- a/Math/src/main/java/org/tribuo/math/util/NoopNormalizer.java +++ b/Math/src/main/java/org/tribuo/math/util/NoopNormalizer.java @@ -16,7 +16,7 @@ package org.tribuo.math.util; -import org.tribuo.onnx.ONNXNode; +import org.tribuo.util.onnx.ONNXNode; import java.io.Serializable; import java.util.Arrays; diff --git a/Math/src/main/java/org/tribuo/math/util/Normalizer.java b/Math/src/main/java/org/tribuo/math/util/Normalizer.java index 12559fced..608bfe74f 100644 --- a/Math/src/main/java/org/tribuo/math/util/Normalizer.java +++ b/Math/src/main/java/org/tribuo/math/util/Normalizer.java @@ -16,10 +16,10 @@ package org.tribuo.math.util; -import org.tribuo.onnx.ONNXContext; -import org.tribuo.onnx.ONNXNode; -import org.tribuo.onnx.ONNXOperators; -import org.tribuo.onnx.ONNXInitializer; +import org.tribuo.util.onnx.ONNXContext; +import org.tribuo.util.onnx.ONNXInitializer; +import org.tribuo.util.onnx.ONNXNode; +import org.tribuo.util.onnx.ONNXOperators; import java.io.Serializable; import java.util.Arrays; diff --git a/Math/src/main/java/org/tribuo/math/util/SigmoidNormalizer.java b/Math/src/main/java/org/tribuo/math/util/SigmoidNormalizer.java index d5cd23c41..0e1057708 100644 --- a/Math/src/main/java/org/tribuo/math/util/SigmoidNormalizer.java +++ b/Math/src/main/java/org/tribuo/math/util/SigmoidNormalizer.java @@ -16,8 +16,8 @@ package org.tribuo.math.util; -import org.tribuo.onnx.ONNXNode; -import org.tribuo.onnx.ONNXOperators; +import org.tribuo.util.onnx.ONNXNode; +import org.tribuo.util.onnx.ONNXOperators; import java.io.Serializable; import java.util.Arrays; diff --git a/Math/src/main/java/org/tribuo/math/util/VectorNormalizer.java b/Math/src/main/java/org/tribuo/math/util/VectorNormalizer.java index 255b6d699..742a87122 100644 --- a/Math/src/main/java/org/tribuo/math/util/VectorNormalizer.java +++ b/Math/src/main/java/org/tribuo/math/util/VectorNormalizer.java @@ -16,7 +16,8 @@ package org.tribuo.math.util; -import org.tribuo.onnx.ONNXNode; +import org.tribuo.util.onnx.ONNXContext; +import org.tribuo.util.onnx.ONNXNode; import java.io.Serializable; import java.util.logging.Logger; @@ -48,7 +49,7 @@ default public void normalizeInPlace(double[] input) { /** * Exports this normalizer to ONNX, returning the leaf of the appended graph - * and writing the nodes needed for normalization into the {@link org.tribuo.onnx.ONNXContext} + * and writing the nodes needed for normalization into the {@link ONNXContext} * that {@code input} belongs to. *

* For compatibility reasons this method has a default implementation, though diff --git a/MultiLabel/Core/src/main/java/org/tribuo/multilabel/ensemble/MultiLabelVotingCombiner.java b/MultiLabel/Core/src/main/java/org/tribuo/multilabel/ensemble/MultiLabelVotingCombiner.java index 4089e7bb4..2a4785d43 100644 --- a/MultiLabel/Core/src/main/java/org/tribuo/multilabel/ensemble/MultiLabelVotingCombiner.java +++ b/MultiLabel/Core/src/main/java/org/tribuo/multilabel/ensemble/MultiLabelVotingCombiner.java @@ -25,11 +25,11 @@ import org.tribuo.ensemble.EnsembleCombiner; import org.tribuo.math.la.DenseVector; import org.tribuo.multilabel.MultiLabel; -import org.tribuo.onnx.ONNXContext; -import org.tribuo.onnx.ONNXNode; -import org.tribuo.onnx.ONNXOperators; -import org.tribuo.onnx.ONNXRef; -import org.tribuo.onnx.ONNXInitializer; +import org.tribuo.util.onnx.ONNXContext; +import org.tribuo.util.onnx.ONNXInitializer; +import org.tribuo.util.onnx.ONNXNode; +import org.tribuo.util.onnx.ONNXOperators; +import org.tribuo.util.onnx.ONNXRef; import java.util.Arrays; import java.util.Collections; diff --git a/MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/fm/FMMultiLabelModel.java b/MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/fm/FMMultiLabelModel.java index da95f4327..b355d5a9a 100644 --- a/MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/fm/FMMultiLabelModel.java +++ b/MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/fm/FMMultiLabelModel.java @@ -19,6 +19,7 @@ import org.tribuo.Example; import org.tribuo.ImmutableFeatureMap; import org.tribuo.ImmutableOutputInfo; +import org.tribuo.ONNXExportable; import org.tribuo.Prediction; import org.tribuo.classification.Label; import org.tribuo.common.sgd.AbstractFMModel; @@ -26,9 +27,8 @@ import org.tribuo.math.la.DenseVector; import org.tribuo.math.util.VectorNormalizer; import org.tribuo.multilabel.MultiLabel; -import org.tribuo.onnx.ONNXExportable; -import org.tribuo.onnx.ONNXNode; import org.tribuo.provenance.ModelProvenance; +import org.tribuo.util.onnx.ONNXNode; import java.util.HashMap; import java.util.HashSet; diff --git a/MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/linear/LinearSGDModel.java b/MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/linear/LinearSGDModel.java index 1342b8094..c54b7e654 100644 --- a/MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/linear/LinearSGDModel.java +++ b/MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/linear/LinearSGDModel.java @@ -19,6 +19,7 @@ import org.tribuo.Example; import org.tribuo.ImmutableFeatureMap; import org.tribuo.ImmutableOutputInfo; +import org.tribuo.ONNXExportable; import org.tribuo.Prediction; import org.tribuo.classification.Label; import org.tribuo.common.sgd.AbstractLinearSGDModel; @@ -26,9 +27,8 @@ import org.tribuo.math.la.DenseVector; import org.tribuo.math.util.VectorNormalizer; import org.tribuo.multilabel.MultiLabel; -import org.tribuo.onnx.ONNXExportable; -import org.tribuo.onnx.ONNXNode; import org.tribuo.provenance.ModelProvenance; +import org.tribuo.util.onnx.ONNXNode; import java.util.HashMap; import java.util.HashSet; diff --git a/Regression/Core/src/main/java/org/tribuo/regression/ensemble/AveragingCombiner.java b/Regression/Core/src/main/java/org/tribuo/regression/ensemble/AveragingCombiner.java index 48d6964b7..d74565714 100644 --- a/Regression/Core/src/main/java/org/tribuo/regression/ensemble/AveragingCombiner.java +++ b/Regression/Core/src/main/java/org/tribuo/regression/ensemble/AveragingCombiner.java @@ -22,11 +22,12 @@ import org.tribuo.ImmutableOutputInfo; import org.tribuo.Prediction; import org.tribuo.ensemble.EnsembleCombiner; -import org.tribuo.onnx.ONNXNode; -import org.tribuo.onnx.ONNXOperators; -import org.tribuo.onnx.ONNXRef; -import org.tribuo.onnx.ONNXInitializer; import org.tribuo.regression.Regressor; +import org.tribuo.util.onnx.ONNXContext; +import org.tribuo.util.onnx.ONNXInitializer; +import org.tribuo.util.onnx.ONNXNode; +import org.tribuo.util.onnx.ONNXOperators; +import org.tribuo.util.onnx.ONNXRef; import java.util.Arrays; import java.util.Collections; @@ -129,7 +130,7 @@ public ConfiguredObjectProvenance getProvenance() { } /** - * Exports this averaging combiner, writing constructed nodes into the {@link org.tribuo.onnx.ONNXContext} + * Exports this averaging combiner, writing constructed nodes into the {@link ONNXContext} * governing {@code input} and returning the leaf node of the combiner. *

* The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members]. @@ -145,7 +146,7 @@ public ONNXNode exportCombiner(ONNXNode input) { } /** - * Exports this averaging combiner, writing constructed nodes into the {@link org.tribuo.onnx.ONNXContext} + * Exports this averaging combiner, writing constructed nodes into the {@link ONNXContext} * governing {@code input} and returning the leaf node of the combiner. *

* The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members]. diff --git a/Regression/LibLinear/src/main/java/org/tribuo/regression/liblinear/LibLinearRegressionModel.java b/Regression/LibLinear/src/main/java/org/tribuo/regression/liblinear/LibLinearRegressionModel.java index 4390a2f65..b9fc27d21 100644 --- a/Regression/LibLinear/src/main/java/org/tribuo/regression/liblinear/LibLinearRegressionModel.java +++ b/Regression/LibLinear/src/main/java/org/tribuo/regression/liblinear/LibLinearRegressionModel.java @@ -26,19 +26,19 @@ import org.tribuo.ImmutableFeatureMap; import org.tribuo.ImmutableOutputInfo; import org.tribuo.Model; +import org.tribuo.ONNXExportable; import org.tribuo.Prediction; import org.tribuo.common.liblinear.LibLinearModel; import org.tribuo.common.liblinear.LibLinearTrainer; -import org.tribuo.onnx.ONNXContext; -import org.tribuo.onnx.ONNXExportable; -import org.tribuo.onnx.ONNXNode; -import org.tribuo.onnx.ONNXOperators; -import org.tribuo.onnx.ONNXPlaceholder; -import org.tribuo.onnx.ONNXRef; -import org.tribuo.onnx.ONNXInitializer; import org.tribuo.provenance.ModelProvenance; import org.tribuo.regression.ImmutableRegressionInfo; import org.tribuo.regression.Regressor; +import org.tribuo.util.onnx.ONNXContext; +import org.tribuo.util.onnx.ONNXInitializer; +import org.tribuo.util.onnx.ONNXNode; +import org.tribuo.util.onnx.ONNXOperators; +import org.tribuo.util.onnx.ONNXPlaceholder; +import org.tribuo.util.onnx.ONNXRef; import java.io.IOException; import java.util.ArrayList; @@ -199,7 +199,7 @@ public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) { ONNXPlaceholder output = onnx.floatOutput(outputIDInfo.size()); onnx.setName("Regression-LibLinear"); - return writeONNXGraph(input).assignTo(output).onnxContext().model(domain, modelVersion, this); + return ONNXExportable.buildModel(writeONNXGraph(input).assignTo(output).onnxContext(), domain, modelVersion, this); } @Override diff --git a/Regression/LibSVM/src/main/java/org/tribuo/regression/libsvm/LibSVMRegressionModel.java b/Regression/LibSVM/src/main/java/org/tribuo/regression/libsvm/LibSVMRegressionModel.java index 40f74214b..ccf3b3bbd 100644 --- a/Regression/LibSVM/src/main/java/org/tribuo/regression/libsvm/LibSVMRegressionModel.java +++ b/Regression/LibSVM/src/main/java/org/tribuo/regression/libsvm/LibSVMRegressionModel.java @@ -23,21 +23,21 @@ import org.tribuo.Example; import org.tribuo.ImmutableFeatureMap; import org.tribuo.ImmutableOutputInfo; +import org.tribuo.ONNXExportable; import org.tribuo.Prediction; import org.tribuo.common.libsvm.KernelType; import org.tribuo.common.libsvm.LibSVMModel; import org.tribuo.common.libsvm.LibSVMTrainer; -import org.tribuo.onnx.ONNXContext; -import org.tribuo.onnx.ONNXExportable; -import org.tribuo.onnx.ONNXInitializer; -import org.tribuo.onnx.ONNXNode; -import org.tribuo.onnx.ONNXOperators; -import org.tribuo.onnx.ONNXPlaceholder; -import org.tribuo.onnx.ONNXRef; import org.tribuo.provenance.ModelProvenance; import org.tribuo.regression.ImmutableRegressionInfo; import org.tribuo.regression.Regressor; import org.tribuo.util.Util; +import org.tribuo.util.onnx.ONNXContext; +import org.tribuo.util.onnx.ONNXInitializer; +import org.tribuo.util.onnx.ONNXNode; +import org.tribuo.util.onnx.ONNXOperators; +import org.tribuo.util.onnx.ONNXPlaceholder; +import org.tribuo.util.onnx.ONNXRef; import java.io.IOException; import java.util.ArrayList; @@ -204,7 +204,7 @@ public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) { ONNXPlaceholder output = onnx.floatOutput(outputIDInfo.size()); onnx.setName("Regression-LibSVM"); - return writeONNXGraph(input).assignTo(output).onnxContext().model(domain, modelVersion, this); + return ONNXExportable.buildModel(writeONNXGraph(input).assignTo(output).onnxContext(), domain, modelVersion, this); } private static ONNXNode buildONNXSVMRegressor(int numFeatures, ONNXRef input, svm_model model) { diff --git a/Regression/SGD/src/main/java/org/tribuo/regression/sgd/fm/FMRegressionModel.java b/Regression/SGD/src/main/java/org/tribuo/regression/sgd/fm/FMRegressionModel.java index 7308644f5..9b1eaeb4f 100644 --- a/Regression/SGD/src/main/java/org/tribuo/regression/sgd/fm/FMRegressionModel.java +++ b/Regression/SGD/src/main/java/org/tribuo/regression/sgd/fm/FMRegressionModel.java @@ -19,16 +19,16 @@ import org.tribuo.Example; import org.tribuo.ImmutableFeatureMap; import org.tribuo.ImmutableOutputInfo; +import org.tribuo.ONNXExportable; import org.tribuo.Prediction; import org.tribuo.common.sgd.AbstractFMModel; import org.tribuo.common.sgd.FMParameters; -import org.tribuo.onnx.ONNXExportable; -import org.tribuo.onnx.ONNXNode; -import org.tribuo.onnx.ONNXOperators; -import org.tribuo.onnx.ONNXInitializer; import org.tribuo.provenance.ModelProvenance; import org.tribuo.regression.ImmutableRegressionInfo; import org.tribuo.regression.Regressor; +import org.tribuo.util.onnx.ONNXInitializer; +import org.tribuo.util.onnx.ONNXNode; +import org.tribuo.util.onnx.ONNXOperators; import java.util.Arrays; diff --git a/Regression/SGD/src/main/java/org/tribuo/regression/sgd/linear/LinearSGDModel.java b/Regression/SGD/src/main/java/org/tribuo/regression/sgd/linear/LinearSGDModel.java index ec72ece1a..8f32b1e80 100644 --- a/Regression/SGD/src/main/java/org/tribuo/regression/sgd/linear/LinearSGDModel.java +++ b/Regression/SGD/src/main/java/org/tribuo/regression/sgd/linear/LinearSGDModel.java @@ -19,14 +19,14 @@ import org.tribuo.Example; import org.tribuo.ImmutableFeatureMap; import org.tribuo.ImmutableOutputInfo; +import org.tribuo.ONNXExportable; import org.tribuo.Prediction; import org.tribuo.common.sgd.AbstractLinearSGDModel; import org.tribuo.math.LinearParameters; import org.tribuo.math.la.DenseMatrix; -import org.tribuo.onnx.ONNXExportable; -import org.tribuo.onnx.ONNXNode; import org.tribuo.provenance.ModelProvenance; import org.tribuo.regression.Regressor; +import org.tribuo.util.onnx.ONNXNode; import java.io.IOException; import java.util.Arrays; diff --git a/Regression/SLM/src/main/java/org/tribuo/regression/slm/SparseLinearModel.java b/Regression/SLM/src/main/java/org/tribuo/regression/slm/SparseLinearModel.java index 03e91c818..4048ca794 100644 --- a/Regression/SLM/src/main/java/org/tribuo/regression/slm/SparseLinearModel.java +++ b/Regression/SLM/src/main/java/org/tribuo/regression/slm/SparseLinearModel.java @@ -23,24 +23,24 @@ import org.tribuo.ImmutableFeatureMap; import org.tribuo.ImmutableOutputInfo; import org.tribuo.Model; +import org.tribuo.ONNXExportable; import org.tribuo.Prediction; import org.tribuo.VariableInfo; import org.tribuo.math.la.DenseVector; import org.tribuo.math.la.SparseVector; import org.tribuo.math.la.VectorTuple; -import org.tribuo.onnx.ONNXContext; -import org.tribuo.onnx.ONNXExportable; -import org.tribuo.onnx.ONNXNode; -import org.tribuo.onnx.ONNXOperators; -import org.tribuo.onnx.ONNXPlaceholder; -import org.tribuo.onnx.ONNXRef; -import org.tribuo.onnx.ONNXInitializer; import org.tribuo.provenance.ModelProvenance; import org.tribuo.provenance.TrainerProvenance; import org.tribuo.regression.ImmutableRegressionInfo; import org.tribuo.regression.Regressor; import org.tribuo.regression.Regressor.DimensionTuple; import org.tribuo.regression.impl.SkeletalIndependentRegressionSparseModel; +import org.tribuo.util.onnx.ONNXContext; +import org.tribuo.util.onnx.ONNXNode; +import org.tribuo.util.onnx.ONNXOperators; +import org.tribuo.util.onnx.ONNXPlaceholder; +import org.tribuo.util.onnx.ONNXRef; +import org.tribuo.util.onnx.ONNXInitializer; import java.io.IOException; import java.nio.FloatBuffer; @@ -226,7 +226,7 @@ public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) { ONNXPlaceholder output = onnx.floatOutput(outputIDInfo.size()); onnx.setName("Regression-SparseLinearModel"); - return writeONNXGraph(input).assignTo(output).onnxContext().model(domain, modelVersion, this); + return ONNXExportable.buildModel(writeONNXGraph(input).assignTo(output).onnxContext(), domain, modelVersion, this); } @Override diff --git a/Util/ONNXExport/pom.xml b/Util/ONNXExport/pom.xml new file mode 100644 index 000000000..1462c062c --- /dev/null +++ b/Util/ONNXExport/pom.xml @@ -0,0 +1,63 @@ + + + + + 4.0.0 + + org.tribuo + tribuo-util + 4.2.0-SNAPSHOT + ../pom.xml + + ONNXExport + tribuo-util-onnx + jar + + 1.8 + 1.8 + + + + + com.google.protobuf + protobuf-java + + + org.junit.jupiter + junit-jupiter + ${junit.version} + test + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + + org.tribuo.util.onnx + + + + + + + diff --git a/Core/src/main/java/ai/onnx/proto/NOTICE.md b/Util/ONNXExport/src/main/java/ai/onnx/proto/NOTICE.md similarity index 100% rename from Core/src/main/java/ai/onnx/proto/NOTICE.md rename to Util/ONNXExport/src/main/java/ai/onnx/proto/NOTICE.md diff --git a/Core/src/main/java/ai/onnx/proto/OnnxMl.java b/Util/ONNXExport/src/main/java/ai/onnx/proto/OnnxMl.java similarity index 100% rename from Core/src/main/java/ai/onnx/proto/OnnxMl.java rename to Util/ONNXExport/src/main/java/ai/onnx/proto/OnnxMl.java diff --git a/Core/src/main/java/org/tribuo/onnx/ONNXAttribute.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXAttribute.java similarity index 99% rename from Core/src/main/java/org/tribuo/onnx/ONNXAttribute.java rename to Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXAttribute.java index 337699b6e..00d7431bd 100644 --- a/Core/src/main/java/org/tribuo/onnx/ONNXAttribute.java +++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXAttribute.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.tribuo.onnx; +package org.tribuo.util.onnx; import com.google.protobuf.ByteString; import ai.onnx.proto.OnnxMl; diff --git a/Core/src/main/java/org/tribuo/onnx/ONNXContext.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXContext.java similarity index 89% rename from Core/src/main/java/org/tribuo/onnx/ONNXContext.java rename to Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXContext.java index 9e7f172b7..8a175e5d3 100644 --- a/Core/src/main/java/org/tribuo/onnx/ONNXContext.java +++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXContext.java @@ -14,13 +14,9 @@ * limitations under the License. */ -package org.tribuo.onnx; +package org.tribuo.util.onnx; import ai.onnx.proto.OnnxMl; -import com.oracle.labs.mlrg.olcut.provenance.Provenancable; -import com.oracle.labs.mlrg.olcut.util.MutableLong; -import org.tribuo.Tribuo; -import org.tribuo.provenance.ModelProvenance; import java.nio.FloatBuffer; import java.util.Collections; @@ -42,7 +38,7 @@ */ public final class ONNXContext { - private final Map nameMap; + private final Map nameMap; private final OnnxMl.GraphProto.Builder protoBuilder; @@ -292,42 +288,15 @@ public ONNXInitializer constant(String baseName, long value) { return new ONNXInitializer(this, constant, baseName); } - /** - * Creates an ONNX model protobuf for this context. - * @param domain Domain for the produced model. - * @param modelVersion Model version for the produced model. - * @param model Provenanced Tribuo model from which this model is derived - the DocString and Tribuo Provenance data - * from this model will be written into the ONNX Model proto. - * @param The type of the provenanced model. - * @return An ONNX model proto of the graph represented by this ONNXContext. - */ - public > OnnxMl.ModelProto model(String domain, long modelVersion, M model) { - return OnnxMl.ModelProto.newBuilder() - .setGraph(protoBuilder.build()) - .setDomain(domain) - .setProducerName("Tribuo") - .setProducerVersion(Tribuo.VERSION) - .setModelVersion(modelVersion) - .addOpsetImport(ONNXOperators.getOpsetProto()) - .setIrVersion(6) - .setDocString(model.toString()) - .addMetadataProps(OnnxMl.StringStringEntryProto - .newBuilder() - .setKey(ONNXExportable.PROVENANCE_METADATA_FIELD) - .setValue(ONNXExportable.SERIALIZER.marshalAndSerialize(model.getProvenance())) - .build()) - .build(); - } - /** * Generates a unique name by appending the counter for that name. * @param name The name. * @return A unique version of that name. */ String generateUniqueName(String name) { - MutableLong counter = nameMap.computeIfAbsent(name,k -> new MutableLong()); - String newName = name + "_" + counter.longValue(); - counter.increment(); + long counter = nameMap.computeIfAbsent(name,k -> 0L); + String newName = name + "_" + counter; + nameMap.put(name,counter + 1); return newName; } @@ -338,4 +307,12 @@ String generateUniqueName(String name) { public void setName(String name) { protoBuilder.setName(name); } + + /** + * Builds the ONNX graph represented by this context. + * @return The ONNX graph proto. + */ + public OnnxMl.GraphProto buildGraph() { + return protoBuilder.build(); + } } diff --git a/Core/src/main/java/org/tribuo/onnx/ONNXInitializer.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXInitializer.java similarity index 97% rename from Core/src/main/java/org/tribuo/onnx/ONNXInitializer.java rename to Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXInitializer.java index e9316b67a..fe746e723 100644 --- a/Core/src/main/java/org/tribuo/onnx/ONNXInitializer.java +++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXInitializer.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.tribuo.onnx; +package org.tribuo.util.onnx; import ai.onnx.proto.OnnxMl; diff --git a/Core/src/main/java/org/tribuo/onnx/ONNXNode.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXNode.java similarity index 97% rename from Core/src/main/java/org/tribuo/onnx/ONNXNode.java rename to Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXNode.java index cf5b7b742..e86864f29 100644 --- a/Core/src/main/java/org/tribuo/onnx/ONNXNode.java +++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXNode.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.tribuo.onnx; +package org.tribuo.util.onnx; import ai.onnx.proto.OnnxMl; diff --git a/Core/src/main/java/org/tribuo/onnx/ONNXOperators.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXOperators.java similarity index 99% rename from Core/src/main/java/org/tribuo/onnx/ONNXOperators.java rename to Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXOperators.java index 220bd8f5f..028a62318 100644 --- a/Core/src/main/java/org/tribuo/onnx/ONNXOperators.java +++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXOperators.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.tribuo.onnx; +package org.tribuo.util.onnx; import ai.onnx.proto.OnnxMl; @@ -26,7 +26,7 @@ import java.util.Map; import java.util.Set; -import static org.tribuo.onnx.ONNXAttribute.VARIADIC_INPUT; +import static org.tribuo.util.onnx.ONNXAttribute.VARIADIC_INPUT; /** * The supported ONNX operators. diff --git a/Core/src/main/java/org/tribuo/onnx/ONNXPlaceholder.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXPlaceholder.java similarity index 97% rename from Core/src/main/java/org/tribuo/onnx/ONNXPlaceholder.java rename to Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXPlaceholder.java index 2aa4141c3..47fddb5da 100644 --- a/Core/src/main/java/org/tribuo/onnx/ONNXPlaceholder.java +++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXPlaceholder.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.tribuo.onnx; +package org.tribuo.util.onnx; import ai.onnx.proto.OnnxMl; diff --git a/Core/src/main/java/org/tribuo/onnx/ONNXRef.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXRef.java similarity index 99% rename from Core/src/main/java/org/tribuo/onnx/ONNXRef.java rename to Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXRef.java index fc085a659..5f7d26bf7 100644 --- a/Core/src/main/java/org/tribuo/onnx/ONNXRef.java +++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXRef.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.tribuo.onnx; +package org.tribuo.util.onnx; import ai.onnx.proto.OnnxMl; import com.google.protobuf.GeneratedMessageV3; diff --git a/Core/src/main/java/org/tribuo/onnx/ONNXShape.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXShape.java similarity index 99% rename from Core/src/main/java/org/tribuo/onnx/ONNXShape.java rename to Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXShape.java index fab4a3b22..dbb15ea26 100644 --- a/Core/src/main/java/org/tribuo/onnx/ONNXShape.java +++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXShape.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.tribuo.onnx; +package org.tribuo.util.onnx; import ai.onnx.proto.OnnxMl; diff --git a/Core/src/main/java/org/tribuo/onnx/ONNXUtils.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXUtils.java similarity index 99% rename from Core/src/main/java/org/tribuo/onnx/ONNXUtils.java rename to Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXUtils.java index 8e8c59149..2e3e47add 100644 --- a/Core/src/main/java/org/tribuo/onnx/ONNXUtils.java +++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXUtils.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.tribuo.onnx; +package org.tribuo.util.onnx; import ai.onnx.proto.OnnxMl; import com.google.protobuf.ByteString; diff --git a/Core/src/main/java/org/tribuo/onnx/package-info.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/package-info.java similarity index 58% rename from Core/src/main/java/org/tribuo/onnx/package-info.java rename to Util/ONNXExport/src/main/java/org/tribuo/util/onnx/package-info.java index 64364b070..4abf7339f 100644 --- a/Core/src/main/java/org/tribuo/onnx/package-info.java +++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/package-info.java @@ -15,13 +15,9 @@ */ /** - * Interfaces and utilities for exporting Tribuo {@link org.tribuo.Model}s in - * ONNX format. + * Interfaces and utilities for writing ONNX Runtime. + * Developed to support Tribuo, but can be used to export + * other machine learning models from JVM languages. */ -package org.tribuo.onnx; \ No newline at end of file +package org.tribuo.util.onnx; diff --git a/Util/ONNXExport/src/main/resources/LICENSE.txt b/Util/ONNXExport/src/main/resources/LICENSE.txt new file mode 120000 index 000000000..385aeaf43 --- /dev/null +++ b/Util/ONNXExport/src/main/resources/LICENSE.txt @@ -0,0 +1 @@ +../../../../../LICENSE.txt \ No newline at end of file diff --git a/Util/ONNXExport/src/main/resources/THIRD_PARTY_LICENSES.txt b/Util/ONNXExport/src/main/resources/THIRD_PARTY_LICENSES.txt new file mode 120000 index 000000000..7505da00b --- /dev/null +++ b/Util/ONNXExport/src/main/resources/THIRD_PARTY_LICENSES.txt @@ -0,0 +1 @@ +../../../../../THIRD_PARTY_LICENSES.txt \ No newline at end of file diff --git a/Util/pom.xml b/Util/pom.xml index 8d9204e94..ec97f0ba1 100644 --- a/Util/pom.xml +++ b/Util/pom.xml @@ -28,6 +28,7 @@ pom InformationTheory + ONNXExport Tokenization @@ -40,6 +41,11 @@ olcut-core ${olcut.version} + + com.google.protobuf + protobuf-java + ${protobuf.version} + org.junit.jupiter junit-jupiter diff --git a/distribution/pom.xml b/distribution/pom.xml index 805267c77..f8ea68e58 100644 --- a/distribution/pom.xml +++ b/distribution/pom.xml @@ -220,6 +220,11 @@ tribuo-util-infotheory ${project.version} + + org.tribuo + tribuo-util-onnx + ${project.version} + org.tribuo tribuo-util-tokenization diff --git a/pom.xml b/pom.xml index 452466336..0deb3bb80 100644 --- a/pom.xml +++ b/pom.xml @@ -198,11 +198,6 @@ maven-site-plugin 3.9.1 - - org.apache.maven.plugins - maven-source-plugin - 3.1.0 - org.apache.maven.plugins maven-install-plugin @@ -245,7 +240,7 @@ Core Packages - org.tribuo:org.tribuo.dataset:org.tribuo.datasource:org.tribuo.ensemble:org.tribuo.evaluation*:org.tribuo.hash:org.tribuo.impl:org.tribuo.onnx:org.tribuo.provenance*:org.tribuo.sequence:org.tribuo.transform*:org.tribuo.util:org.tribuo.data*:org.tribuo.json:org.tribuo.math* + org.tribuo:org.tribuo.dataset:org.tribuo.datasource:org.tribuo.ensemble:org.tribuo.evaluation*:org.tribuo.hash:org.tribuo.impl:org.tribuo.provenance*:org.tribuo.sequence:org.tribuo.transform*:org.tribuo.util:org.tribuo.data*:org.tribuo.json:org.tribuo.math* @@ -296,7 +291,7 @@ Core Packages - org.tribuo:org.tribuo.dataset:org.tribuo.datasource:org.tribuo.ensemble:org.tribuo.evaluation*:org.tribuo.hash:org.tribuo.impl:org.tribuo.onnx:org.tribuo.provenance*:org.tribuo.sequence:org.tribuo.transform*:org.tribuo.util:org.tribuo.data*:org.tribuo.json:org.tribuo.math* + org.tribuo:org.tribuo.dataset:org.tribuo.datasource:org.tribuo.ensemble:org.tribuo.evaluation*:org.tribuo.hash:org.tribuo.impl:org.tribuo.provenance*:org.tribuo.sequence:org.tribuo.transform*:org.tribuo.util:org.tribuo.data*:org.tribuo.json:org.tribuo.math*