com.oracle.labs.olcut
olcut-config-protobuf
diff --git a/Core/src/main/java/org/tribuo/onnx/ONNXExportable.java b/Core/src/main/java/org/tribuo/ONNXExportable.java
similarity index 55%
rename from Core/src/main/java/org/tribuo/onnx/ONNXExportable.java
rename to Core/src/main/java/org/tribuo/ONNXExportable.java
index b88a3d46e..e1c8130d7 100644
--- a/Core/src/main/java/org/tribuo/onnx/ONNXExportable.java
+++ b/Core/src/main/java/org/tribuo/ONNXExportable.java
@@ -14,12 +14,17 @@
* limitations under the License.
*/
-package org.tribuo.onnx;
+package org.tribuo;
import ai.onnx.proto.OnnxMl;
import com.oracle.labs.mlrg.olcut.config.protobuf.ProtoProvenanceSerialization;
+import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
import com.oracle.labs.mlrg.olcut.provenance.io.ProvenanceSerialization;
import org.tribuo.provenance.ModelProvenance;
+import org.tribuo.util.onnx.ONNXContext;
+import org.tribuo.util.onnx.ONNXNode;
+import org.tribuo.util.onnx.ONNXOperators;
+import org.tribuo.util.onnx.ONNXRef;
import java.io.BufferedOutputStream;
import java.io.FileOutputStream;
@@ -33,6 +38,12 @@
* Tribuo models export with a single input of size [-1, numFeatures] and a
* single output of size [-1, numOutputDimensions]. The first dimension in both
* is defined to be an unbound dimension called "batch", which denotes the batch size.
+ *
+ * ONNX exported models use floats where Tribuo uses doubles, this is due
+ * to comparatively poor support for fp64 in ONNX deployment environments
+ * as compared to fp32. In addition, fp32 executes better on the various
+ * accelerator backends available in
+ * ONNX Runtime.
*/
public interface ONNXExportable {
@@ -47,9 +58,39 @@ public interface ONNXExportable {
*/
public static final String PROVENANCE_METADATA_FIELD = "TRIBUO_PROVENANCE";
+ /**
+ * Creates an ONNX model protobuf for the supplied context.
+ *
+ * @param onnxContext The context which contains the ONNX graph.
+ * @param domain Domain for the produced model.
+ * @param modelVersion Model version for the produced model.
+ * @param model Provenanced Tribuo model from which this model is derived - the DocString and Tribuo Provenance data
+ * from this model will be written into the ONNX Model proto.
+ * @param The type of the provenanced model.
+ * @return An ONNX model proto of the graph represented by the supplied ONNXContext.
+ */
+ public static > OnnxMl.ModelProto buildModel(ONNXContext onnxContext, String domain, long modelVersion, M model) {
+ return OnnxMl.ModelProto.newBuilder()
+ .setGraph(onnxContext.buildGraph())
+ .setDomain(domain)
+ .setProducerName("Tribuo")
+ .setProducerVersion(Tribuo.VERSION)
+ .setModelVersion(modelVersion)
+ .addOpsetImport(ONNXOperators.getOpsetProto())
+ .setIrVersion(6)
+ .setDocString(model.toString())
+ .addMetadataProps(OnnxMl.StringStringEntryProto
+ .newBuilder()
+ .setKey(PROVENANCE_METADATA_FIELD)
+ .setValue(SERIALIZER.marshalAndSerialize(model.getProvenance()))
+ .build())
+ .build();
+ }
+
/**
* Exports this {@link org.tribuo.Model} as an ONNX protobuf.
- * @param domain A reverse-DNS name to namespace the model (e.g., org.tribuo.classification.sgd.linear).
+ *
+ * @param domain A reverse-DNS name to namespace the model (e.g., org.tribuo.classification.sgd.linear).
* @param modelVersion A version number for this model.
* @return The ONNX ModelProto representing this Tribuo Model.
*/
@@ -58,6 +99,7 @@ public interface ONNXExportable {
/**
* Writes this {@link org.tribuo.Model} into {@link OnnxMl.GraphProto.Builder} inside the input's
* {@link ONNXContext}.
+ *
* @param input The input to the model graph.
* @return the output node of the model graph.
*/
@@ -65,13 +107,14 @@ public interface ONNXExportable {
/**
* Exports this {@link org.tribuo.Model} as an ONNX file.
- * @param domain A reverse-DNS name to namespace the model (e.g., org.tribuo.classification.sgd.linear).
+ *
+ * @param domain A reverse-DNS name to namespace the model (e.g., org.tribuo.classification.sgd.linear).
* @param modelVersion A version number for this model.
- * @param outputPath The path to write to.
+ * @param outputPath The path to write to.
* @throws IOException if the file could not be written to.
*/
default public void saveONNXModel(String domain, long modelVersion, Path outputPath) throws IOException {
- OnnxMl.ModelProto proto = exportONNXModel(domain,modelVersion);
+ OnnxMl.ModelProto proto = exportONNXModel(domain, modelVersion);
try (BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(outputPath.toFile()))) {
proto.writeTo(bos);
}
@@ -79,6 +122,7 @@ default public void saveONNXModel(String domain, long modelVersion, Path outputP
/**
* Serializes the model provenance to a String.
+ *
* @param provenance The provenance to serialize.
* @return The serialized form of the ModelProvenance.
*/
diff --git a/Core/src/main/java/org/tribuo/ensemble/EnsembleCombiner.java b/Core/src/main/java/org/tribuo/ensemble/EnsembleCombiner.java
index 7ec788cb6..af3a11e01 100644
--- a/Core/src/main/java/org/tribuo/ensemble/EnsembleCombiner.java
+++ b/Core/src/main/java/org/tribuo/ensemble/EnsembleCombiner.java
@@ -22,8 +22,8 @@
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Prediction;
-import org.tribuo.onnx.ONNXNode;
-import org.tribuo.onnx.ONNXRef;
+import org.tribuo.util.onnx.ONNXNode;
+import org.tribuo.util.onnx.ONNXRef;
import java.io.Serializable;
import java.util.List;
diff --git a/Core/src/main/java/org/tribuo/ensemble/WeightedEnsembleModel.java b/Core/src/main/java/org/tribuo/ensemble/WeightedEnsembleModel.java
index 0d96d0981..40ff5a1f5 100644
--- a/Core/src/main/java/org/tribuo/ensemble/WeightedEnsembleModel.java
+++ b/Core/src/main/java/org/tribuo/ensemble/WeightedEnsembleModel.java
@@ -24,18 +24,18 @@
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
+import org.tribuo.ONNXExportable;
import org.tribuo.Output;
import org.tribuo.Prediction;
-import org.tribuo.onnx.ONNXContext;
-import org.tribuo.onnx.ONNXExportable;
-import org.tribuo.onnx.ONNXNode;
-import org.tribuo.onnx.ONNXOperators;
-import org.tribuo.onnx.ONNXPlaceholder;
-import org.tribuo.onnx.ONNXRef;
-import org.tribuo.onnx.ONNXInitializer;
import org.tribuo.provenance.EnsembleModelProvenance;
import org.tribuo.provenance.impl.TimestampedTrainerProvenance;
import org.tribuo.util.Util;
+import org.tribuo.util.onnx.ONNXContext;
+import org.tribuo.util.onnx.ONNXInitializer;
+import org.tribuo.util.onnx.ONNXNode;
+import org.tribuo.util.onnx.ONNXOperators;
+import org.tribuo.util.onnx.ONNXPlaceholder;
+import org.tribuo.util.onnx.ONNXRef;
import java.time.OffsetDateTime;
import java.util.ArrayList;
@@ -247,7 +247,7 @@ public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) {
// Build graph
writeONNXGraph(input).assignTo(output);
- return onnx.model(domain, modelVersion, this);
+ return ONNXExportable.buildModel(onnx, domain, modelVersion, this);
}
@Override
diff --git a/Interop/ONNX/src/main/java/org/tribuo/interop/onnx/ONNXExternalModel.java b/Interop/ONNX/src/main/java/org/tribuo/interop/onnx/ONNXExternalModel.java
index 046929174..668fc7716 100644
--- a/Interop/ONNX/src/main/java/org/tribuo/interop/onnx/ONNXExternalModel.java
+++ b/Interop/ONNX/src/main/java/org/tribuo/interop/onnx/ONNXExternalModel.java
@@ -22,9 +22,7 @@
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
-import com.oracle.labs.mlrg.olcut.config.protobuf.ProtoProvenanceSerialization;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
-import com.oracle.labs.mlrg.olcut.provenance.io.ProvenanceSerialization;
import com.oracle.labs.mlrg.olcut.provenance.io.ProvenanceSerializationException;
import com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
@@ -33,6 +31,7 @@
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
+import org.tribuo.ONNXExportable;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.Prediction;
@@ -40,7 +39,6 @@
import org.tribuo.interop.ExternalModel;
import org.tribuo.interop.ExternalTrainerProvenance;
import org.tribuo.math.la.SparseVector;
-import org.tribuo.onnx.ONNXExportable;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
diff --git a/Math/src/main/java/org/tribuo/math/onnx/ONNXMathUtils.java b/Math/src/main/java/org/tribuo/math/onnx/ONNXMathUtils.java
index d3e6c4b87..9203fbfc5 100644
--- a/Math/src/main/java/org/tribuo/math/onnx/ONNXMathUtils.java
+++ b/Math/src/main/java/org/tribuo/math/onnx/ONNXMathUtils.java
@@ -18,8 +18,8 @@
import org.tribuo.math.la.Matrix;
import org.tribuo.math.la.SGDVector;
-import org.tribuo.onnx.ONNXContext;
-import org.tribuo.onnx.ONNXInitializer;
+import org.tribuo.util.onnx.ONNXContext;
+import org.tribuo.util.onnx.ONNXInitializer;
import java.nio.FloatBuffer;
import java.util.Arrays;
diff --git a/Math/src/main/java/org/tribuo/math/util/ExpNormalizer.java b/Math/src/main/java/org/tribuo/math/util/ExpNormalizer.java
index 1ac4869a8..1f7aadf49 100644
--- a/Math/src/main/java/org/tribuo/math/util/ExpNormalizer.java
+++ b/Math/src/main/java/org/tribuo/math/util/ExpNormalizer.java
@@ -16,8 +16,8 @@
package org.tribuo.math.util;
-import org.tribuo.onnx.ONNXNode;
-import org.tribuo.onnx.ONNXOperators;
+import org.tribuo.util.onnx.ONNXNode;
+import org.tribuo.util.onnx.ONNXOperators;
import java.io.Serializable;
import java.util.Arrays;
diff --git a/Math/src/main/java/org/tribuo/math/util/NoopNormalizer.java b/Math/src/main/java/org/tribuo/math/util/NoopNormalizer.java
index d86885f77..b7e2a7398 100644
--- a/Math/src/main/java/org/tribuo/math/util/NoopNormalizer.java
+++ b/Math/src/main/java/org/tribuo/math/util/NoopNormalizer.java
@@ -16,7 +16,7 @@
package org.tribuo.math.util;
-import org.tribuo.onnx.ONNXNode;
+import org.tribuo.util.onnx.ONNXNode;
import java.io.Serializable;
import java.util.Arrays;
diff --git a/Math/src/main/java/org/tribuo/math/util/Normalizer.java b/Math/src/main/java/org/tribuo/math/util/Normalizer.java
index 12559fced..608bfe74f 100644
--- a/Math/src/main/java/org/tribuo/math/util/Normalizer.java
+++ b/Math/src/main/java/org/tribuo/math/util/Normalizer.java
@@ -16,10 +16,10 @@
package org.tribuo.math.util;
-import org.tribuo.onnx.ONNXContext;
-import org.tribuo.onnx.ONNXNode;
-import org.tribuo.onnx.ONNXOperators;
-import org.tribuo.onnx.ONNXInitializer;
+import org.tribuo.util.onnx.ONNXContext;
+import org.tribuo.util.onnx.ONNXInitializer;
+import org.tribuo.util.onnx.ONNXNode;
+import org.tribuo.util.onnx.ONNXOperators;
import java.io.Serializable;
import java.util.Arrays;
diff --git a/Math/src/main/java/org/tribuo/math/util/SigmoidNormalizer.java b/Math/src/main/java/org/tribuo/math/util/SigmoidNormalizer.java
index d5cd23c41..0e1057708 100644
--- a/Math/src/main/java/org/tribuo/math/util/SigmoidNormalizer.java
+++ b/Math/src/main/java/org/tribuo/math/util/SigmoidNormalizer.java
@@ -16,8 +16,8 @@
package org.tribuo.math.util;
-import org.tribuo.onnx.ONNXNode;
-import org.tribuo.onnx.ONNXOperators;
+import org.tribuo.util.onnx.ONNXNode;
+import org.tribuo.util.onnx.ONNXOperators;
import java.io.Serializable;
import java.util.Arrays;
diff --git a/Math/src/main/java/org/tribuo/math/util/VectorNormalizer.java b/Math/src/main/java/org/tribuo/math/util/VectorNormalizer.java
index 255b6d699..742a87122 100644
--- a/Math/src/main/java/org/tribuo/math/util/VectorNormalizer.java
+++ b/Math/src/main/java/org/tribuo/math/util/VectorNormalizer.java
@@ -16,7 +16,8 @@
package org.tribuo.math.util;
-import org.tribuo.onnx.ONNXNode;
+import org.tribuo.util.onnx.ONNXContext;
+import org.tribuo.util.onnx.ONNXNode;
import java.io.Serializable;
import java.util.logging.Logger;
@@ -48,7 +49,7 @@ default public void normalizeInPlace(double[] input) {
/**
* Exports this normalizer to ONNX, returning the leaf of the appended graph
- * and writing the nodes needed for normalization into the {@link org.tribuo.onnx.ONNXContext}
+ * and writing the nodes needed for normalization into the {@link ONNXContext}
* that {@code input} belongs to.
*
* For compatibility reasons this method has a default implementation, though
diff --git a/MultiLabel/Core/src/main/java/org/tribuo/multilabel/ensemble/MultiLabelVotingCombiner.java b/MultiLabel/Core/src/main/java/org/tribuo/multilabel/ensemble/MultiLabelVotingCombiner.java
index 4089e7bb4..2a4785d43 100644
--- a/MultiLabel/Core/src/main/java/org/tribuo/multilabel/ensemble/MultiLabelVotingCombiner.java
+++ b/MultiLabel/Core/src/main/java/org/tribuo/multilabel/ensemble/MultiLabelVotingCombiner.java
@@ -25,11 +25,11 @@
import org.tribuo.ensemble.EnsembleCombiner;
import org.tribuo.math.la.DenseVector;
import org.tribuo.multilabel.MultiLabel;
-import org.tribuo.onnx.ONNXContext;
-import org.tribuo.onnx.ONNXNode;
-import org.tribuo.onnx.ONNXOperators;
-import org.tribuo.onnx.ONNXRef;
-import org.tribuo.onnx.ONNXInitializer;
+import org.tribuo.util.onnx.ONNXContext;
+import org.tribuo.util.onnx.ONNXInitializer;
+import org.tribuo.util.onnx.ONNXNode;
+import org.tribuo.util.onnx.ONNXOperators;
+import org.tribuo.util.onnx.ONNXRef;
import java.util.Arrays;
import java.util.Collections;
diff --git a/MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/fm/FMMultiLabelModel.java b/MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/fm/FMMultiLabelModel.java
index da95f4327..b355d5a9a 100644
--- a/MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/fm/FMMultiLabelModel.java
+++ b/MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/fm/FMMultiLabelModel.java
@@ -19,6 +19,7 @@
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
+import org.tribuo.ONNXExportable;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.common.sgd.AbstractFMModel;
@@ -26,9 +27,8 @@
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.util.VectorNormalizer;
import org.tribuo.multilabel.MultiLabel;
-import org.tribuo.onnx.ONNXExportable;
-import org.tribuo.onnx.ONNXNode;
import org.tribuo.provenance.ModelProvenance;
+import org.tribuo.util.onnx.ONNXNode;
import java.util.HashMap;
import java.util.HashSet;
diff --git a/MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/linear/LinearSGDModel.java b/MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/linear/LinearSGDModel.java
index 1342b8094..c54b7e654 100644
--- a/MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/linear/LinearSGDModel.java
+++ b/MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/linear/LinearSGDModel.java
@@ -19,6 +19,7 @@
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
+import org.tribuo.ONNXExportable;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.common.sgd.AbstractLinearSGDModel;
@@ -26,9 +27,8 @@
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.util.VectorNormalizer;
import org.tribuo.multilabel.MultiLabel;
-import org.tribuo.onnx.ONNXExportable;
-import org.tribuo.onnx.ONNXNode;
import org.tribuo.provenance.ModelProvenance;
+import org.tribuo.util.onnx.ONNXNode;
import java.util.HashMap;
import java.util.HashSet;
diff --git a/Regression/Core/src/main/java/org/tribuo/regression/ensemble/AveragingCombiner.java b/Regression/Core/src/main/java/org/tribuo/regression/ensemble/AveragingCombiner.java
index 48d6964b7..d74565714 100644
--- a/Regression/Core/src/main/java/org/tribuo/regression/ensemble/AveragingCombiner.java
+++ b/Regression/Core/src/main/java/org/tribuo/regression/ensemble/AveragingCombiner.java
@@ -22,11 +22,12 @@
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.ensemble.EnsembleCombiner;
-import org.tribuo.onnx.ONNXNode;
-import org.tribuo.onnx.ONNXOperators;
-import org.tribuo.onnx.ONNXRef;
-import org.tribuo.onnx.ONNXInitializer;
import org.tribuo.regression.Regressor;
+import org.tribuo.util.onnx.ONNXContext;
+import org.tribuo.util.onnx.ONNXInitializer;
+import org.tribuo.util.onnx.ONNXNode;
+import org.tribuo.util.onnx.ONNXOperators;
+import org.tribuo.util.onnx.ONNXRef;
import java.util.Arrays;
import java.util.Collections;
@@ -129,7 +130,7 @@ public ConfiguredObjectProvenance getProvenance() {
}
/**
- * Exports this averaging combiner, writing constructed nodes into the {@link org.tribuo.onnx.ONNXContext}
+ * Exports this averaging combiner, writing constructed nodes into the {@link ONNXContext}
* governing {@code input} and returning the leaf node of the combiner.
*
* The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members].
@@ -145,7 +146,7 @@ public ONNXNode exportCombiner(ONNXNode input) {
}
/**
- * Exports this averaging combiner, writing constructed nodes into the {@link org.tribuo.onnx.ONNXContext}
+ * Exports this averaging combiner, writing constructed nodes into the {@link ONNXContext}
* governing {@code input} and returning the leaf node of the combiner.
*
* The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members].
diff --git a/Regression/LibLinear/src/main/java/org/tribuo/regression/liblinear/LibLinearRegressionModel.java b/Regression/LibLinear/src/main/java/org/tribuo/regression/liblinear/LibLinearRegressionModel.java
index 4390a2f65..b9fc27d21 100644
--- a/Regression/LibLinear/src/main/java/org/tribuo/regression/liblinear/LibLinearRegressionModel.java
+++ b/Regression/LibLinear/src/main/java/org/tribuo/regression/liblinear/LibLinearRegressionModel.java
@@ -26,19 +26,19 @@
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
+import org.tribuo.ONNXExportable;
import org.tribuo.Prediction;
import org.tribuo.common.liblinear.LibLinearModel;
import org.tribuo.common.liblinear.LibLinearTrainer;
-import org.tribuo.onnx.ONNXContext;
-import org.tribuo.onnx.ONNXExportable;
-import org.tribuo.onnx.ONNXNode;
-import org.tribuo.onnx.ONNXOperators;
-import org.tribuo.onnx.ONNXPlaceholder;
-import org.tribuo.onnx.ONNXRef;
-import org.tribuo.onnx.ONNXInitializer;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.ImmutableRegressionInfo;
import org.tribuo.regression.Regressor;
+import org.tribuo.util.onnx.ONNXContext;
+import org.tribuo.util.onnx.ONNXInitializer;
+import org.tribuo.util.onnx.ONNXNode;
+import org.tribuo.util.onnx.ONNXOperators;
+import org.tribuo.util.onnx.ONNXPlaceholder;
+import org.tribuo.util.onnx.ONNXRef;
import java.io.IOException;
import java.util.ArrayList;
@@ -199,7 +199,7 @@ public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) {
ONNXPlaceholder output = onnx.floatOutput(outputIDInfo.size());
onnx.setName("Regression-LibLinear");
- return writeONNXGraph(input).assignTo(output).onnxContext().model(domain, modelVersion, this);
+ return ONNXExportable.buildModel(writeONNXGraph(input).assignTo(output).onnxContext(), domain, modelVersion, this);
}
@Override
diff --git a/Regression/LibSVM/src/main/java/org/tribuo/regression/libsvm/LibSVMRegressionModel.java b/Regression/LibSVM/src/main/java/org/tribuo/regression/libsvm/LibSVMRegressionModel.java
index 40f74214b..ccf3b3bbd 100644
--- a/Regression/LibSVM/src/main/java/org/tribuo/regression/libsvm/LibSVMRegressionModel.java
+++ b/Regression/LibSVM/src/main/java/org/tribuo/regression/libsvm/LibSVMRegressionModel.java
@@ -23,21 +23,21 @@
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
+import org.tribuo.ONNXExportable;
import org.tribuo.Prediction;
import org.tribuo.common.libsvm.KernelType;
import org.tribuo.common.libsvm.LibSVMModel;
import org.tribuo.common.libsvm.LibSVMTrainer;
-import org.tribuo.onnx.ONNXContext;
-import org.tribuo.onnx.ONNXExportable;
-import org.tribuo.onnx.ONNXInitializer;
-import org.tribuo.onnx.ONNXNode;
-import org.tribuo.onnx.ONNXOperators;
-import org.tribuo.onnx.ONNXPlaceholder;
-import org.tribuo.onnx.ONNXRef;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.ImmutableRegressionInfo;
import org.tribuo.regression.Regressor;
import org.tribuo.util.Util;
+import org.tribuo.util.onnx.ONNXContext;
+import org.tribuo.util.onnx.ONNXInitializer;
+import org.tribuo.util.onnx.ONNXNode;
+import org.tribuo.util.onnx.ONNXOperators;
+import org.tribuo.util.onnx.ONNXPlaceholder;
+import org.tribuo.util.onnx.ONNXRef;
import java.io.IOException;
import java.util.ArrayList;
@@ -204,7 +204,7 @@ public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) {
ONNXPlaceholder output = onnx.floatOutput(outputIDInfo.size());
onnx.setName("Regression-LibSVM");
- return writeONNXGraph(input).assignTo(output).onnxContext().model(domain, modelVersion, this);
+ return ONNXExportable.buildModel(writeONNXGraph(input).assignTo(output).onnxContext(), domain, modelVersion, this);
}
private static ONNXNode buildONNXSVMRegressor(int numFeatures, ONNXRef> input, svm_model model) {
diff --git a/Regression/SGD/src/main/java/org/tribuo/regression/sgd/fm/FMRegressionModel.java b/Regression/SGD/src/main/java/org/tribuo/regression/sgd/fm/FMRegressionModel.java
index 7308644f5..9b1eaeb4f 100644
--- a/Regression/SGD/src/main/java/org/tribuo/regression/sgd/fm/FMRegressionModel.java
+++ b/Regression/SGD/src/main/java/org/tribuo/regression/sgd/fm/FMRegressionModel.java
@@ -19,16 +19,16 @@
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
+import org.tribuo.ONNXExportable;
import org.tribuo.Prediction;
import org.tribuo.common.sgd.AbstractFMModel;
import org.tribuo.common.sgd.FMParameters;
-import org.tribuo.onnx.ONNXExportable;
-import org.tribuo.onnx.ONNXNode;
-import org.tribuo.onnx.ONNXOperators;
-import org.tribuo.onnx.ONNXInitializer;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.ImmutableRegressionInfo;
import org.tribuo.regression.Regressor;
+import org.tribuo.util.onnx.ONNXInitializer;
+import org.tribuo.util.onnx.ONNXNode;
+import org.tribuo.util.onnx.ONNXOperators;
import java.util.Arrays;
diff --git a/Regression/SGD/src/main/java/org/tribuo/regression/sgd/linear/LinearSGDModel.java b/Regression/SGD/src/main/java/org/tribuo/regression/sgd/linear/LinearSGDModel.java
index ec72ece1a..8f32b1e80 100644
--- a/Regression/SGD/src/main/java/org/tribuo/regression/sgd/linear/LinearSGDModel.java
+++ b/Regression/SGD/src/main/java/org/tribuo/regression/sgd/linear/LinearSGDModel.java
@@ -19,14 +19,14 @@
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
+import org.tribuo.ONNXExportable;
import org.tribuo.Prediction;
import org.tribuo.common.sgd.AbstractLinearSGDModel;
import org.tribuo.math.LinearParameters;
import org.tribuo.math.la.DenseMatrix;
-import org.tribuo.onnx.ONNXExportable;
-import org.tribuo.onnx.ONNXNode;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.Regressor;
+import org.tribuo.util.onnx.ONNXNode;
import java.io.IOException;
import java.util.Arrays;
diff --git a/Regression/SLM/src/main/java/org/tribuo/regression/slm/SparseLinearModel.java b/Regression/SLM/src/main/java/org/tribuo/regression/slm/SparseLinearModel.java
index 03e91c818..4048ca794 100644
--- a/Regression/SLM/src/main/java/org/tribuo/regression/slm/SparseLinearModel.java
+++ b/Regression/SLM/src/main/java/org/tribuo/regression/slm/SparseLinearModel.java
@@ -23,24 +23,24 @@
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
+import org.tribuo.ONNXExportable;
import org.tribuo.Prediction;
import org.tribuo.VariableInfo;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
-import org.tribuo.onnx.ONNXContext;
-import org.tribuo.onnx.ONNXExportable;
-import org.tribuo.onnx.ONNXNode;
-import org.tribuo.onnx.ONNXOperators;
-import org.tribuo.onnx.ONNXPlaceholder;
-import org.tribuo.onnx.ONNXRef;
-import org.tribuo.onnx.ONNXInitializer;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.regression.ImmutableRegressionInfo;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.Regressor.DimensionTuple;
import org.tribuo.regression.impl.SkeletalIndependentRegressionSparseModel;
+import org.tribuo.util.onnx.ONNXContext;
+import org.tribuo.util.onnx.ONNXNode;
+import org.tribuo.util.onnx.ONNXOperators;
+import org.tribuo.util.onnx.ONNXPlaceholder;
+import org.tribuo.util.onnx.ONNXRef;
+import org.tribuo.util.onnx.ONNXInitializer;
import java.io.IOException;
import java.nio.FloatBuffer;
@@ -226,7 +226,7 @@ public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) {
ONNXPlaceholder output = onnx.floatOutput(outputIDInfo.size());
onnx.setName("Regression-SparseLinearModel");
- return writeONNXGraph(input).assignTo(output).onnxContext().model(domain, modelVersion, this);
+ return ONNXExportable.buildModel(writeONNXGraph(input).assignTo(output).onnxContext(), domain, modelVersion, this);
}
@Override
diff --git a/Util/ONNXExport/pom.xml b/Util/ONNXExport/pom.xml
new file mode 100644
index 000000000..1462c062c
--- /dev/null
+++ b/Util/ONNXExport/pom.xml
@@ -0,0 +1,63 @@
+
+
+
+
+ 4.0.0
+
+ org.tribuo
+ tribuo-util
+ 4.2.0-SNAPSHOT
+ ../pom.xml
+
+ ONNXExport
+ tribuo-util-onnx
+ jar
+
+ 1.8
+ 1.8
+
+
+
+
+ com.google.protobuf
+ protobuf-java
+
+
+ org.junit.jupiter
+ junit-jupiter
+ ${junit.version}
+ test
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-jar-plugin
+
+
+
+ org.tribuo.util.onnx
+
+
+
+
+
+
+
diff --git a/Core/src/main/java/ai/onnx/proto/NOTICE.md b/Util/ONNXExport/src/main/java/ai/onnx/proto/NOTICE.md
similarity index 100%
rename from Core/src/main/java/ai/onnx/proto/NOTICE.md
rename to Util/ONNXExport/src/main/java/ai/onnx/proto/NOTICE.md
diff --git a/Core/src/main/java/ai/onnx/proto/OnnxMl.java b/Util/ONNXExport/src/main/java/ai/onnx/proto/OnnxMl.java
similarity index 100%
rename from Core/src/main/java/ai/onnx/proto/OnnxMl.java
rename to Util/ONNXExport/src/main/java/ai/onnx/proto/OnnxMl.java
diff --git a/Core/src/main/java/org/tribuo/onnx/ONNXAttribute.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXAttribute.java
similarity index 99%
rename from Core/src/main/java/org/tribuo/onnx/ONNXAttribute.java
rename to Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXAttribute.java
index 337699b6e..00d7431bd 100644
--- a/Core/src/main/java/org/tribuo/onnx/ONNXAttribute.java
+++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXAttribute.java
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-package org.tribuo.onnx;
+package org.tribuo.util.onnx;
import com.google.protobuf.ByteString;
import ai.onnx.proto.OnnxMl;
diff --git a/Core/src/main/java/org/tribuo/onnx/ONNXContext.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXContext.java
similarity index 89%
rename from Core/src/main/java/org/tribuo/onnx/ONNXContext.java
rename to Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXContext.java
index 9e7f172b7..8a175e5d3 100644
--- a/Core/src/main/java/org/tribuo/onnx/ONNXContext.java
+++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXContext.java
@@ -14,13 +14,9 @@
* limitations under the License.
*/
-package org.tribuo.onnx;
+package org.tribuo.util.onnx;
import ai.onnx.proto.OnnxMl;
-import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
-import com.oracle.labs.mlrg.olcut.util.MutableLong;
-import org.tribuo.Tribuo;
-import org.tribuo.provenance.ModelProvenance;
import java.nio.FloatBuffer;
import java.util.Collections;
@@ -42,7 +38,7 @@
*/
public final class ONNXContext {
- private final Map nameMap;
+ private final Map nameMap;
private final OnnxMl.GraphProto.Builder protoBuilder;
@@ -292,42 +288,15 @@ public ONNXInitializer constant(String baseName, long value) {
return new ONNXInitializer(this, constant, baseName);
}
- /**
- * Creates an ONNX model protobuf for this context.
- * @param domain Domain for the produced model.
- * @param modelVersion Model version for the produced model.
- * @param model Provenanced Tribuo model from which this model is derived - the DocString and Tribuo Provenance data
- * from this model will be written into the ONNX Model proto.
- * @param The type of the provenanced model.
- * @return An ONNX model proto of the graph represented by this ONNXContext.
- */
- public > OnnxMl.ModelProto model(String domain, long modelVersion, M model) {
- return OnnxMl.ModelProto.newBuilder()
- .setGraph(protoBuilder.build())
- .setDomain(domain)
- .setProducerName("Tribuo")
- .setProducerVersion(Tribuo.VERSION)
- .setModelVersion(modelVersion)
- .addOpsetImport(ONNXOperators.getOpsetProto())
- .setIrVersion(6)
- .setDocString(model.toString())
- .addMetadataProps(OnnxMl.StringStringEntryProto
- .newBuilder()
- .setKey(ONNXExportable.PROVENANCE_METADATA_FIELD)
- .setValue(ONNXExportable.SERIALIZER.marshalAndSerialize(model.getProvenance()))
- .build())
- .build();
- }
-
/**
* Generates a unique name by appending the counter for that name.
* @param name The name.
* @return A unique version of that name.
*/
String generateUniqueName(String name) {
- MutableLong counter = nameMap.computeIfAbsent(name,k -> new MutableLong());
- String newName = name + "_" + counter.longValue();
- counter.increment();
+ long counter = nameMap.computeIfAbsent(name,k -> 0L);
+ String newName = name + "_" + counter;
+ nameMap.put(name,counter + 1);
return newName;
}
@@ -338,4 +307,12 @@ String generateUniqueName(String name) {
public void setName(String name) {
protoBuilder.setName(name);
}
+
+ /**
+ * Builds the ONNX graph represented by this context.
+ * @return The ONNX graph proto.
+ */
+ public OnnxMl.GraphProto buildGraph() {
+ return protoBuilder.build();
+ }
}
diff --git a/Core/src/main/java/org/tribuo/onnx/ONNXInitializer.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXInitializer.java
similarity index 97%
rename from Core/src/main/java/org/tribuo/onnx/ONNXInitializer.java
rename to Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXInitializer.java
index e9316b67a..fe746e723 100644
--- a/Core/src/main/java/org/tribuo/onnx/ONNXInitializer.java
+++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXInitializer.java
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-package org.tribuo.onnx;
+package org.tribuo.util.onnx;
import ai.onnx.proto.OnnxMl;
diff --git a/Core/src/main/java/org/tribuo/onnx/ONNXNode.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXNode.java
similarity index 97%
rename from Core/src/main/java/org/tribuo/onnx/ONNXNode.java
rename to Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXNode.java
index cf5b7b742..e86864f29 100644
--- a/Core/src/main/java/org/tribuo/onnx/ONNXNode.java
+++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXNode.java
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-package org.tribuo.onnx;
+package org.tribuo.util.onnx;
import ai.onnx.proto.OnnxMl;
diff --git a/Core/src/main/java/org/tribuo/onnx/ONNXOperators.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXOperators.java
similarity index 99%
rename from Core/src/main/java/org/tribuo/onnx/ONNXOperators.java
rename to Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXOperators.java
index 220bd8f5f..028a62318 100644
--- a/Core/src/main/java/org/tribuo/onnx/ONNXOperators.java
+++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXOperators.java
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-package org.tribuo.onnx;
+package org.tribuo.util.onnx;
import ai.onnx.proto.OnnxMl;
@@ -26,7 +26,7 @@
import java.util.Map;
import java.util.Set;
-import static org.tribuo.onnx.ONNXAttribute.VARIADIC_INPUT;
+import static org.tribuo.util.onnx.ONNXAttribute.VARIADIC_INPUT;
/**
* The supported ONNX operators.
diff --git a/Core/src/main/java/org/tribuo/onnx/ONNXPlaceholder.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXPlaceholder.java
similarity index 97%
rename from Core/src/main/java/org/tribuo/onnx/ONNXPlaceholder.java
rename to Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXPlaceholder.java
index 2aa4141c3..47fddb5da 100644
--- a/Core/src/main/java/org/tribuo/onnx/ONNXPlaceholder.java
+++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXPlaceholder.java
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-package org.tribuo.onnx;
+package org.tribuo.util.onnx;
import ai.onnx.proto.OnnxMl;
diff --git a/Core/src/main/java/org/tribuo/onnx/ONNXRef.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXRef.java
similarity index 99%
rename from Core/src/main/java/org/tribuo/onnx/ONNXRef.java
rename to Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXRef.java
index fc085a659..5f7d26bf7 100644
--- a/Core/src/main/java/org/tribuo/onnx/ONNXRef.java
+++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXRef.java
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-package org.tribuo.onnx;
+package org.tribuo.util.onnx;
import ai.onnx.proto.OnnxMl;
import com.google.protobuf.GeneratedMessageV3;
diff --git a/Core/src/main/java/org/tribuo/onnx/ONNXShape.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXShape.java
similarity index 99%
rename from Core/src/main/java/org/tribuo/onnx/ONNXShape.java
rename to Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXShape.java
index fab4a3b22..dbb15ea26 100644
--- a/Core/src/main/java/org/tribuo/onnx/ONNXShape.java
+++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXShape.java
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-package org.tribuo.onnx;
+package org.tribuo.util.onnx;
import ai.onnx.proto.OnnxMl;
diff --git a/Core/src/main/java/org/tribuo/onnx/ONNXUtils.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXUtils.java
similarity index 99%
rename from Core/src/main/java/org/tribuo/onnx/ONNXUtils.java
rename to Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXUtils.java
index 8e8c59149..2e3e47add 100644
--- a/Core/src/main/java/org/tribuo/onnx/ONNXUtils.java
+++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXUtils.java
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-package org.tribuo.onnx;
+package org.tribuo.util.onnx;
import ai.onnx.proto.OnnxMl;
import com.google.protobuf.ByteString;
diff --git a/Core/src/main/java/org/tribuo/onnx/package-info.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/package-info.java
similarity index 58%
rename from Core/src/main/java/org/tribuo/onnx/package-info.java
rename to Util/ONNXExport/src/main/java/org/tribuo/util/onnx/package-info.java
index 64364b070..4abf7339f 100644
--- a/Core/src/main/java/org/tribuo/onnx/package-info.java
+++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/package-info.java
@@ -15,13 +15,9 @@
*/
/**
- * Interfaces and utilities for exporting Tribuo {@link org.tribuo.Model}s in
- * ONNX format.
+ * Interfaces and utilities for writing ONNX Runtime.
+ * Developed to support Tribuo, but can be used to export
+ * other machine learning models from JVM languages.
*/
-package org.tribuo.onnx;
\ No newline at end of file
+package org.tribuo.util.onnx;
diff --git a/Util/ONNXExport/src/main/resources/LICENSE.txt b/Util/ONNXExport/src/main/resources/LICENSE.txt
new file mode 120000
index 000000000..385aeaf43
--- /dev/null
+++ b/Util/ONNXExport/src/main/resources/LICENSE.txt
@@ -0,0 +1 @@
+../../../../../LICENSE.txt
\ No newline at end of file
diff --git a/Util/ONNXExport/src/main/resources/THIRD_PARTY_LICENSES.txt b/Util/ONNXExport/src/main/resources/THIRD_PARTY_LICENSES.txt
new file mode 120000
index 000000000..7505da00b
--- /dev/null
+++ b/Util/ONNXExport/src/main/resources/THIRD_PARTY_LICENSES.txt
@@ -0,0 +1 @@
+../../../../../THIRD_PARTY_LICENSES.txt
\ No newline at end of file
diff --git a/Util/pom.xml b/Util/pom.xml
index 8d9204e94..ec97f0ba1 100644
--- a/Util/pom.xml
+++ b/Util/pom.xml
@@ -28,6 +28,7 @@
pom
InformationTheory
+ ONNXExport
Tokenization
@@ -40,6 +41,11 @@
olcut-core
${olcut.version}
+