Adding SLM proto serialization. (#269)
Craigacp authored Sep 13, 2022
1 parent 2a04a30 commit af1db52
Showing 6 changed files with 2,316 additions and 0 deletions.
@@ -17,6 +17,8 @@
package org.tribuo.regression.slm;

import ai.onnx.proto.OnnxMl;
import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
@@ -26,15 +28,22 @@
import org.tribuo.ONNXExportable;
import org.tribuo.Prediction;
import org.tribuo.VariableInfo;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.math.protos.TensorProto;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.regression.ImmutableRegressionInfo;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.Regressor.DimensionTuple;
import org.tribuo.regression.impl.SkeletalIndependentRegressionSparseModel;
import org.tribuo.regression.slm.protos.SparseLinearModelProto;
import org.tribuo.util.Util;
import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;
@@ -54,6 +63,7 @@
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.logging.Logger;
import java.util.stream.Collectors;

/**
* The inference time version of a sparse linear regression model.
@@ -64,6 +74,11 @@ public class SparseLinearModel extends SkeletalIndependentRegressionSparseModel
private static final long serialVersionUID = 3L;
private static final Logger logger = Logger.getLogger(SparseLinearModel.class.getName());

/**
* Protobuf serialization version.
*/
public static final int CURRENT_VERSION = 0;

private SparseVector[] weights;
private final DenseVector featureMeans;
/**
@@ -93,6 +108,81 @@ public class SparseLinearModel extends SkeletalIndependentRegressionSparseModel
this.enet41MappingFix = true;
}

/**
* Deserialization factory.
* @param version The serialized object version.
* @param className The class name.
* @param message The serialized data.
* @return The deserialized model.
* @throws InvalidProtocolBufferException If the protobuf could not be parsed from the {@code message}.
*/
public static SparseLinearModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
if (version < 0 || version > CURRENT_VERSION) {
throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
}
SparseLinearModelProto proto = message.unpack(SparseLinearModelProto.class);

ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
if (!carrier.outputDomain().getOutput(0).getClass().equals(Regressor.class)) {
throw new IllegalStateException("Invalid protobuf, output domain is not a regression domain, found " + carrier.outputDomain().getClass());
}
@SuppressWarnings("unchecked") // guarded by getClass
ImmutableOutputInfo<Regressor> outputDomain = (ImmutableOutputInfo<Regressor>) carrier.outputDomain();

String[] dimensions = new String[proto.getDimensionsCount()];
if (dimensions.length != outputDomain.size()) {
throw new IllegalStateException("Invalid protobuf, found insufficient dimension names, expected " + outputDomain.size() + ", found " + dimensions.length);
}
for (int i = 0; i < dimensions.length; i++) {
dimensions[i] = proto.getDimensions(i);
}

SparseVector[] weights = new SparseVector[outputDomain.size()];
if (weights.length != proto.getWeightsCount()) {
throw new IllegalStateException("Invalid protobuf, expected same weight dimension as output domain size, found " + proto.getWeightsCount() + " weights and " + outputDomain.size() + " output dimensions");
}
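// Models trained with a bias term store one extra element in the weight, mean and norm
// vectors beyond the feature domain, so the size checks below account for it.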
int featureSize = proto.getBias() ? carrier.featureDomain().size() + 1 : carrier.featureDomain().size();
for (int i = 0; i < weights.length; i++) {
Tensor deser = Tensor.deserialize(proto.getWeights(i));
if (deser instanceof SparseVector) {
SparseVector v = (SparseVector) deser;
if (v.size() == featureSize) {
weights[i] = v;
} else {
throw new IllegalStateException("Invalid protobuf, weights size and feature domain do not match, expected " + featureSize + ", found " + v.size());
}
} else {
throw new IllegalStateException("Invalid protobuf, expected a SparseVector, found " + deser.getClass());
}
}

Tensor featureMeansTensor = Tensor.deserialize(proto.getFeatureMeans());
if (!(featureMeansTensor instanceof DenseVector)) {
throw new IllegalStateException("Invalid protobuf, feature means must be a dense vector, found " + featureMeansTensor.getClass());
}
DenseVector featureMeans = (DenseVector) featureMeansTensor;
if (featureMeans.size() != featureSize) {
throw new IllegalStateException("Invalid protobuf, feature means not the right size, expected " + featureSize + ", found " + featureMeans.size());
}
Tensor featureNormsTensor = Tensor.deserialize(proto.getFeatureNorms());
if (!(featureNormsTensor instanceof DenseVector)) {
throw new IllegalStateException("Invalid protobuf, feature means must be a dense vector, found " + featureNormsTensor.getClass());
}
DenseVector featureNorms = (DenseVector) featureNormsTensor;
if (featureNorms.size() != featureSize) {
throw new IllegalStateException("Invalid protobuf, feature means not the right size, expected " + featureSize + ", found " + featureNorms.size());
}
double[] yMean = Util.toPrimitiveDouble(proto.getYMeanList());
if (yMean.length != outputDomain.size()) {
throw new IllegalStateException("Invalid protobuf, y means not the right size, expected " + carrier.outputDomain().size() + " found " + yMean.length);
}
double[] yNorm = Util.toPrimitiveDouble(proto.getYNormList());
if (yNorm.length != outputDomain.size()) {
throw new IllegalStateException("Invalid protobuf, y norms not the right size, expected " + carrier.outputDomain().size() + " found " + yNorm.length);
}

return new SparseLinearModel(carrier.name(), dimensions, carrier.provenance(), carrier.featureDomain(), outputDomain,
weights, featureMeans, featureNorms, yMean, yNorm, proto.getBias());
}
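
// Illustrative usage sketch, not part of this commit: a previously written ModelProto
// (the message produced by serialize() below) can be parsed back and routed to this
// factory. `protoBytes` is an assumed byte[] holding such a message.
//
// ModelProto modelProto = ModelProto.parseFrom(protoBytes);
// SparseLinearModel slm = SparseLinearModel.deserializeFromProto(
//         modelProto.getVersion(), modelProto.getClassName(), modelProto.getSerializedData());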

private static Map<String, List<String>> generateActiveFeatures(String[] dimensionNames, ImmutableFeatureMap featureMap, SparseVector[] weightsArray) {
Map<String, List<String>> map = new HashMap<>();

@@ -224,6 +314,30 @@ public Map<String, SparseVector> getWeights() {
return output;
}

@Override
public ModelProto serialize() {
ModelDataCarrier<Regressor> carrier = createDataCarrier();

SparseLinearModelProto.Builder modelBuilder = SparseLinearModelProto.newBuilder();
modelBuilder.setMetadata(carrier.serialize());
modelBuilder.addAllDimensions(Arrays.asList(dimensions));
for (SparseVector v : weights) {
modelBuilder.addWeights(v.serialize());
}
modelBuilder.setFeatureMeans(featureMeans.serialize());
modelBuilder.setFeatureNorms(featureVariance.serialize());
modelBuilder.setBias(bias);
modelBuilder.addAllYMean(Arrays.stream(yMean).boxed().collect(Collectors.toList()));
modelBuilder.addAllYNorm(Arrays.stream(yVariance).boxed().collect(Collectors.toList()));

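// Wrap the SLM-specific payload in the generic ModelProto envelope, recording the class
// name and version that the read path uses to dispatch back to deserializeFromProto.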
ModelProto.Builder builder = ModelProto.newBuilder();
builder.setSerializedData(Any.pack(modelBuilder.build()));
builder.setClassName(SparseLinearModel.class.getName());
builder.setVersion(CURRENT_VERSION);

return builder.build();
}
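
// Illustrative usage sketch, not part of this commit: the returned ModelProto is a plain
// protobuf message, so it can be persisted with the standard protobuf APIs. `model` and
// `path` (a java.nio.file.Path) are assumed to exist in the caller.
//
// ModelProto modelProto = model.serialize();
// try (java.io.OutputStream os = java.nio.file.Files.newOutputStream(path)) {
//     modelProto.writeTo(os);
// }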

@Override
public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) {
ONNXContext onnx = new ONNXContext();