Skip to content

Commit

Permalink
ONNX export support for LibLinear and LibSVM (#191)
Browse files Browse the repository at this point in the history
* Stubbing out LibSVM ONNX export.

* Adding a first draft of LibLinear ONNX export support.

* Code formatting.

* Fixing LibLinear classification ONNX export.

* Tidying up Liblinear and libsvm export after rebase.

* Adding SVM Regression ONNX export.

* Adding graph names to Liblinear and LibSVM ONNX export.

* Adding LibSVM classification support, but it only works for binary classes at the moment due to a label permutation issue.

* Adding LibSVM export support for multi-class problems. Binary non-probabilistic models still don't work.

* Modifying the tests for a binary libsvm.

* Adding provenance serialization to libsvm and liblinear onnx export.

* Uncommenting rearranged indices test.

* Adding fixes for multi-class Libsvm export. Binary class is still non-functional.

* Fixing LibSVM binary classification.

* Updates after rebase to pick up ONNX ensemble export. Added LibSVM to the ensemble test.

* Cleaning up unused imports.

* Update Interop/ONNX/src/main/java/org/tribuo/interop/onnx/LabelOneVOneTransformer.java

Co-authored-by: Jack Sullivan <john.t.sullivan@gmail.com>

* Updating the docs for LabelOneVOneTransformer after the reviews.

Co-authored-by: Jack Sullivan <john.t.sullivan@gmail.com>
  • Loading branch information
Craigacp and JackSullivan authored Nov 29, 2021
1 parent ccba01b commit 3dfb1a2
Show file tree
Hide file tree
Showing 19 changed files with 1,231 additions and 92 deletions.
15 changes: 14 additions & 1 deletion Classification/LibLinear/pom.xml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<!--
~ Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
~ Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
~
~ Licensed under the Apache License, Version 2.0 (the "License");
~ you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -50,6 +50,19 @@
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>tribuo-onnx</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>tribuo-onnx</artifactId>
<version>${project.version}</version>
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,8 @@

package org.tribuo.classification.liblinear;

import ai.onnx.proto.OnnxMl;
import com.google.protobuf.ByteString;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
Expand All @@ -24,13 +26,22 @@
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.Tribuo;
import org.tribuo.classification.Label;
import org.tribuo.common.liblinear.LibLinearModel;
import org.tribuo.common.liblinear.LibLinearTrainer;
import org.tribuo.onnx.ONNXContext;
import org.tribuo.onnx.ONNXExportable;
import org.tribuo.onnx.ONNXOperators;
import org.tribuo.onnx.ONNXShape;
import org.tribuo.onnx.ONNXUtils;
import org.tribuo.provenance.ModelProvenance;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
Expand Down Expand Up @@ -61,7 +72,7 @@
* Machine Learning, 1995.
* </pre>
*/
public class LibLinearClassificationModel extends LibLinearModel<Label> {
public class LibLinearClassificationModel extends LibLinearModel<Label> implements ONNXExportable {
private static final long serialVersionUID = 3L;

private static final Logger logger = Logger.getLogger(LibLinearClassificationModel.class.getName());
Expand Down Expand Up @@ -277,4 +288,122 @@ protected Excuse<Label> innerGetExcuse(Example<Label> e, double[][] allFeatureWe

return new Excuse<>(e, prediction, weightMap);
}

@Override
public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) {
ONNXContext context = new ONNXContext();

// Make inputs and outputs
OnnxMl.TypeProto inputType = ONNXUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,featureIDMap.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT);
OnnxMl.ValueInfoProto inputValueProto = OnnxMl.ValueInfoProto.newBuilder().setType(inputType).setName("input").build();
context.addInput(inputValueProto);
OnnxMl.TypeProto outputType = ONNXUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,outputIDInfo.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT);
OnnxMl.ValueInfoProto outputValueProto = OnnxMl.ValueInfoProto.newBuilder().setType(outputType).setName("output").build();
context.addOutput(outputValueProto);
context.setName("Classification-LibLinear");

// Build graph
writeONNXGraph(context, inputValueProto.getName(), outputValueProto.getName());

// Build model
OnnxMl.ModelProto.Builder builder = OnnxMl.ModelProto.newBuilder();
builder.setGraph(context.buildGraph());
builder.setDomain(domain);
builder.setProducerName("Tribuo");
builder.setProducerVersion(Tribuo.VERSION);
builder.setModelVersion(modelVersion);
builder.setDocString(toString());
builder.addOpsetImport(ONNXOperators.getOpsetProto());
builder.setIrVersion(6);

// Extract provenance and store in metadata
OnnxMl.StringStringEntryProto.Builder metaBuilder = OnnxMl.StringStringEntryProto.newBuilder();
metaBuilder.setKey(ONNXExportable.PROVENANCE_METADATA_FIELD);
metaBuilder.setValue(serializeProvenance(getProvenance()));
builder.addMetadataProps(metaBuilder.build());

return builder.build();
}

@Override
public void writeONNXGraph(ONNXContext context, String inputName, String outputName) {
de.bwaldvogel.liblinear.Model model = models.get(0);
double[] weights = model.getFeatureWeights();
int[] labels = model.getLabels();
int numFeatures = featureIDMap.size();
int numLabels = labels.length;
if (numLabels != outputIDInfo.size()) {
throw new IllegalStateException("Unexpected number of labels, output domain = " + outputIDInfo.size() + ", LibLinear's internal count = " + numLabels);
}

// setup weight arrays for easy processing
if (model.getNrClass() == 2) {
// Replicate weights in binary problems
double[] newWeights = new double[weights.length*2];
for (int i = 0; i < weights.length; i++) {
if (labels[0] == 0) {
newWeights[i * 2] = weights[i];
newWeights[(i * 2) + 1] = -weights[i];
} else {
newWeights[i * 2] = -weights[i];
newWeights[(i * 2) + 1] = weights[i];
}
}
weights = newWeights;
} else {
double[] newWeights = new double[weights.length];
for (int j = 0; j < numFeatures + 1; j++) {
for (int i = 0; i < numLabels; i++) {
int newIdx = (j * numLabels) + labels[i];
int oldIdx = (j * numLabels) + i;
newWeights[newIdx] = weights[oldIdx];
}
}
weights = newWeights;
}

// Add weights
OnnxMl.TensorProto.Builder weightBuilder = OnnxMl.TensorProto.newBuilder();
weightBuilder.setName(context.generateUniqueName("liblinear-weights"));
weightBuilder.addDims(numFeatures);
weightBuilder.addDims(numLabels);
weightBuilder.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber());
ByteBuffer buffer = ByteBuffer.allocate(numFeatures * numLabels * 4).order(ByteOrder.LITTLE_ENDIAN);
FloatBuffer floatBuffer = buffer.asFloatBuffer();
// Biases are stored last in the weight matrix, and it's in column major order
for (int i = 0; i < weights.length - numLabels; i++) {
floatBuffer.put((float) weights[i]);
}
floatBuffer.rewind();
weightBuilder.setRawData(ByteString.copyFrom(buffer));
context.addInitializer(weightBuilder.build());

// Add biases
OnnxMl.TensorProto.Builder biasBuilder = OnnxMl.TensorProto.newBuilder();
biasBuilder.setName(context.generateUniqueName("liblinear-biases"));
biasBuilder.addDims(numLabels);
biasBuilder.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber());
ByteBuffer biasBuffer = ByteBuffer.allocate(numLabels * 4).order(ByteOrder.LITTLE_ENDIAN);
FloatBuffer floatBiasBuffer = biasBuffer.asFloatBuffer();
// Biases are stored last in the weight matrix, and it's in column major order
for (int i = numFeatures * numLabels; i < weights.length; i++) {
floatBiasBuffer.put((float) weights[i]);
}
floatBiasBuffer.rewind();
biasBuilder.setRawData(ByteString.copyFrom(biasBuffer));
context.addInitializer(biasBuilder.build());

String gemmOutput = model.isProbabilityModel() ? context.generateUniqueName("gemm_output") : outputName;

// Make gemm
String[] gemmInputs = new String[]{inputName,weightBuilder.getName(),biasBuilder.getName()};
OnnxMl.NodeProto gemm = ONNXOperators.GEMM.build(context,gemmInputs,gemmOutput);
context.addNode(gemm);

if (model.isProbabilityModel()) {
// Make output normalizer if producing probabilities
context.addNode(ONNXOperators.SOFTMAX.build(context,gemm.getOutput(0),outputName, Collections.singletonMap("axis",1)));
}
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,7 @@

package org.tribuo.classification.liblinear;

import ai.onnxruntime.OrtException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.CategoricalIDInfo;
import org.tribuo.CategoricalInfo;
Expand Down Expand Up @@ -43,13 +44,15 @@
import de.bwaldvogel.liblinear.FeatureNode;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.tribuo.interop.onnx.OnnxTestUtils;
import org.tribuo.test.Helpers;
import org.tribuo.util.tokens.impl.BreakIteratorTokenizer;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
Expand Down Expand Up @@ -259,6 +262,20 @@ public void testEmptyExample() {
});
}

@Test
public void testOnnxSerialization() throws IOException, OrtException {
Pair<Dataset<Label>,Dataset<Label>> p = LabelledDataGenerator.denseTrainTest();
LibLinearClassificationModel model = (LibLinearClassificationModel) t.train(p.getA());

// Write out model
Path onnxFile = Files.createTempFile("tribuo-liblinear-test",".onnx");
model.saveONNXModel("org.tribuo.classification.liblinear.test",1,onnxFile);

OnnxTestUtils.onnxLabelComparison(model,onnxFile,p.getB(),1e-6);

onnxFile.toFile().delete();
}

private static int[] getIndices(FeatureNode[] nodes) {
int[] indices = new int[nodes.length];

Expand Down
13 changes: 13 additions & 0 deletions Classification/LibSVM/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,19 @@
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>tribuo-onnx</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>tribuo-onnx</artifactId>
<version>${project.version}</version>
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
Expand Down
Loading

0 comments on commit 3dfb1a2

Please sign in to comment.