Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moving ONNX export utils out into a separate module #203

Merged
merged 1 commit into from
Dec 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.ensemble.EnsembleCombiner;
import org.tribuo.onnx.ONNXNode;
import org.tribuo.onnx.ONNXOperators;
import org.tribuo.onnx.ONNXRef;
import org.tribuo.onnx.ONNXInitializer;
import org.tribuo.util.onnx.ONNXInitializer;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXRef;

import java.util.Collections;
import java.util.HashMap;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.ensemble.EnsembleCombiner;
import org.tribuo.onnx.ONNXNode;
import org.tribuo.onnx.ONNXOperators;
import org.tribuo.onnx.ONNXRef;
import org.tribuo.onnx.ONNXInitializer;
import org.tribuo.util.onnx.ONNXInitializer;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXRef;

import java.util.Collections;
import java.util.HashMap;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,18 @@
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.ONNXExportable;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.common.liblinear.LibLinearModel;
import org.tribuo.common.liblinear.LibLinearTrainer;
import org.tribuo.onnx.ONNXContext;
import org.tribuo.onnx.ONNXExportable;
import org.tribuo.onnx.ONNXNode;
import org.tribuo.onnx.ONNXOperators;
import org.tribuo.onnx.ONNXPlaceholder;
import org.tribuo.onnx.ONNXRef;
import org.tribuo.onnx.ONNXInitializer;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXInitializer;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXPlaceholder;
import org.tribuo.util.onnx.ONNXRef;

import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -298,7 +298,7 @@ public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) {
// Build graph
writeONNXGraph(input).assignTo(output);

return onnx.model(domain, modelVersion, this);
return ONNXExportable.buildModel(onnx, domain, modelVersion, this);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,20 @@
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.ONNXExportable;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.common.libsvm.KernelType;
import org.tribuo.common.libsvm.LibSVMModel;
import org.tribuo.common.libsvm.LibSVMTrainer;
import org.tribuo.onnx.ONNXContext;
import org.tribuo.onnx.ONNXExportable;
import org.tribuo.onnx.ONNXInitializer;
import org.tribuo.onnx.ONNXNode;
import org.tribuo.onnx.ONNXOperators;
import org.tribuo.onnx.ONNXPlaceholder;
import org.tribuo.onnx.ONNXRef;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.Util;
import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXInitializer;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXPlaceholder;
import org.tribuo.util.onnx.ONNXRef;

import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -176,7 +176,7 @@ public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) {
onnx.setName("Classification-LibSVM");

writeONNXGraph(input).assignTo(output);
return onnx.model(domain, modelVersion, this);
return ONNXExportable.buildModel(onnx, domain, modelVersion, this);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.ONNXExportable;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.common.sgd.AbstractFMModel;
import org.tribuo.common.sgd.FMParameters;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.util.VectorNormalizer;
import org.tribuo.onnx.ONNXExportable;
import org.tribuo.onnx.ONNXNode;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.onnx.ONNXNode;

import java.util.LinkedHashMap;
import java.util.Map;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.ONNXExportable;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.common.sgd.AbstractLinearSGDModel;
import org.tribuo.math.LinearParameters;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.util.VectorNormalizer;
import org.tribuo.onnx.ONNXExportable;
import org.tribuo.onnx.ONNXNode;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.onnx.ONNXNode;

import java.io.IOException;
import java.util.LinkedHashMap;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,21 @@
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.ONNXExportable;
import org.tribuo.Output;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.Matrix;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.onnx.ONNXMathUtils;
import org.tribuo.onnx.ONNXContext;
import org.tribuo.onnx.ONNXNode;
import org.tribuo.onnx.ONNXOperators;
import org.tribuo.onnx.ONNXPlaceholder;
import org.tribuo.onnx.ONNXRef;
import org.tribuo.onnx.ONNXInitializer;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXInitializer;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXPlaceholder;
import org.tribuo.util.onnx.ONNXRef;

import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -267,7 +268,7 @@ public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) {
ONNXPlaceholder input = onnx.floatInput("input", featureIDMap.size());
ONNXPlaceholder output = onnx.floatOutput("output", outputIDInfo.size());
writeONNXGraph(input).assignTo(output);
return onnx.model(domain, modelVersion, this);
return ONNXExportable.buildModel(onnx, domain, modelVersion, this);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,19 @@
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.ONNXExportable;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.math.LinearParameters;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.Matrix;
import org.tribuo.onnx.ONNXContext;
import org.tribuo.onnx.ONNXNode;
import org.tribuo.onnx.ONNXOperators;
import org.tribuo.onnx.ONNXPlaceholder;
import org.tribuo.onnx.ONNXRef;
import org.tribuo.onnx.ONNXInitializer;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXInitializer;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXPlaceholder;
import org.tribuo.util.onnx.ONNXRef;

import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -213,7 +214,7 @@ public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) {
ONNXPlaceholder input = onnx.floatInput("input", featureIDMap.size());
ONNXPlaceholder output = onnx.floatOutput("output", outputIDInfo.size());
writeONNXGraph(input).assignTo(output);
return onnx.model(domain, modelVersion, this);
return ONNXExportable.buildModel(onnx, domain, modelVersion, this);
}

}
5 changes: 5 additions & 0 deletions Core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@
<groupId>com.oracle.labs.olcut</groupId>
<artifactId>olcut-core</artifactId>
</dependency>
<dependency>
<groupId>org.tribuo</groupId>
<artifactId>tribuo-util-onnx</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.oracle.labs.olcut</groupId>
<artifactId>olcut-config-protobuf</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,17 @@
* limitations under the License.
*/

package org.tribuo.onnx;
package org.tribuo;

import ai.onnx.proto.OnnxMl;
import com.oracle.labs.mlrg.olcut.config.protobuf.ProtoProvenanceSerialization;
import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
import com.oracle.labs.mlrg.olcut.provenance.io.ProvenanceSerialization;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXRef;

import java.io.BufferedOutputStream;
import java.io.FileOutputStream;
Expand All @@ -33,6 +38,12 @@
* Tribuo models export with a single input of size [-1, numFeatures] and a
* single output of size [-1, numOutputDimensions]. The first dimension in both
* is defined to be an unbound dimension called "batch", which denotes the batch size.
* <p>
* ONNX exported models use floats where Tribuo uses doubles, this is due
* to comparatively poor support for fp64 in ONNX deployment environments
* as compared to fp32. In addition, fp32 executes better on the various
* accelerator backends available in
* <a href="https://onnxruntime.ai">ONNX Runtime</a>.
*/
public interface ONNXExportable {

Expand All @@ -47,9 +58,39 @@ public interface ONNXExportable {
*/
public static final String PROVENANCE_METADATA_FIELD = "TRIBUO_PROVENANCE";

/**
* Creates an ONNX model protobuf for the supplied context.
*
* @param onnxContext The context which contains the ONNX graph.
* @param domain Domain for the produced model.
* @param modelVersion Model version for the produced model.
* @param model Provenanced Tribuo model from which this model is derived - the DocString and Tribuo Provenance data
* from this model will be written into the ONNX Model proto.
* @param <M> The type of the provenanced model.
* @return An ONNX model proto of the graph represented by the supplied ONNXContext.
*/
public static <M extends Provenancable<ModelProvenance>> OnnxMl.ModelProto buildModel(ONNXContext onnxContext, String domain, long modelVersion, M model) {
return OnnxMl.ModelProto.newBuilder()
.setGraph(onnxContext.buildGraph())
.setDomain(domain)
.setProducerName("Tribuo")
.setProducerVersion(Tribuo.VERSION)
.setModelVersion(modelVersion)
.addOpsetImport(ONNXOperators.getOpsetProto())
.setIrVersion(6)
.setDocString(model.toString())
.addMetadataProps(OnnxMl.StringStringEntryProto
.newBuilder()
.setKey(PROVENANCE_METADATA_FIELD)
.setValue(SERIALIZER.marshalAndSerialize(model.getProvenance()))
.build())
.build();
}

/**
* Exports this {@link org.tribuo.Model} as an ONNX protobuf.
* @param domain A reverse-DNS name to namespace the model (e.g., org.tribuo.classification.sgd.linear).
*
* @param domain A reverse-DNS name to namespace the model (e.g., org.tribuo.classification.sgd.linear).
* @param modelVersion A version number for this model.
* @return The ONNX ModelProto representing this Tribuo Model.
*/
Craigacp marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -58,27 +99,30 @@ public interface ONNXExportable {
/**
* Writes this {@link org.tribuo.Model} into {@link OnnxMl.GraphProto.Builder} inside the input's
* {@link ONNXContext}.
*
* @param input The input to the model graph.
* @return the output node of the model graph.
*/
public ONNXNode writeONNXGraph(ONNXRef<?> input);

/**
* Exports this {@link org.tribuo.Model} as an ONNX file.
* @param domain A reverse-DNS name to namespace the model (e.g., org.tribuo.classification.sgd.linear).
*
* @param domain A reverse-DNS name to namespace the model (e.g., org.tribuo.classification.sgd.linear).
* @param modelVersion A version number for this model.
* @param outputPath The path to write to.
* @param outputPath The path to write to.
* @throws IOException if the file could not be written to.
*/
default public void saveONNXModel(String domain, long modelVersion, Path outputPath) throws IOException {
OnnxMl.ModelProto proto = exportONNXModel(domain,modelVersion);
OnnxMl.ModelProto proto = exportONNXModel(domain, modelVersion);
try (BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(outputPath.toFile()))) {
proto.writeTo(bos);
}
}

/**
* Serializes the model provenance to a String.
*
* @param provenance The provenance to serialize.
* @return The serialized form of the ModelProvenance.
*/
Expand Down
4 changes: 2 additions & 2 deletions Core/src/main/java/org/tribuo/ensemble/EnsembleCombiner.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.onnx.ONNXNode;
import org.tribuo.onnx.ONNXRef;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXRef;

import java.io.Serializable;
import java.util.List;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,18 @@
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.ONNXExportable;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.onnx.ONNXContext;
import org.tribuo.onnx.ONNXExportable;
import org.tribuo.onnx.ONNXNode;
import org.tribuo.onnx.ONNXOperators;
import org.tribuo.onnx.ONNXPlaceholder;
import org.tribuo.onnx.ONNXRef;
import org.tribuo.onnx.ONNXInitializer;
import org.tribuo.provenance.EnsembleModelProvenance;
import org.tribuo.provenance.impl.TimestampedTrainerProvenance;
import org.tribuo.util.Util;
import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXInitializer;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXPlaceholder;
import org.tribuo.util.onnx.ONNXRef;

import java.time.OffsetDateTime;
import java.util.ArrayList;
Expand Down Expand Up @@ -247,7 +247,7 @@ public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) {
// Build graph
writeONNXGraph(input).assignTo(output);

return onnx.model(domain, modelVersion, this);
return ONNXExportable.buildModel(onnx, domain, modelVersion, this);
}

@Override
Expand Down
Loading