Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Protobuf serialization for MultinomialNaiveBayesModel #267

Merged
merged 4 commits into from
Sep 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,8 @@

package org.tribuo.classification.mnb;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
Expand All @@ -25,12 +27,16 @@
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.mnb.protos.MultinomialNaiveBayesProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.math.util.ExpNormalizer;
import org.tribuo.math.util.VectorNormalizer;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;

import java.util.ArrayList;
Expand All @@ -56,6 +62,11 @@
public class MultinomialNaiveBayesModel extends Model<Label> {
// Java serialization version identifier for this class.
private static final long serialVersionUID = 1L;

/**
* Protobuf serialization version.
* <p>
* This value is written into the {@code ModelProto} by {@code serialize()}, and
* {@code deserializeFromProto} rejects any version outside the range
* {@code [0, CURRENT_VERSION]}.
*/
public static final int CURRENT_VERSION = 0;

// Per-label feature weights. On deserialization dimension 1 must equal the output
// domain size and dimension 2 must equal the feature domain size.
private final DenseSparseMatrix labelWordProbs;
// Model hyperparameter stored alongside the weights; deserialization rejects negative
// values. NOTE(review): presumably the smoothing constant used in training - confirm.
private final double alpha;

Expand All @@ -67,6 +78,46 @@ public class MultinomialNaiveBayesModel extends Model<Label> {
this.alpha = alpha;
}

/**
 * Deserialization factory.
 * <p>
 * Unpacks the supplied message as a {@code MultinomialNaiveBayesProto}, checks that
 * the stored output domain contains {@link Label}s, that the weight matrix is a
 * {@link DenseSparseMatrix} whose dimensions agree with the output and feature
 * domains, and that alpha is a non-negative number, then builds the model.
 * @param version The serialized object version.
 * @param className The class name (not inspected by this method).
 * @param message The serialized data.
 * @return The deserialized model.
 * @throws InvalidProtocolBufferException If the message cannot be unpacked as a {@code MultinomialNaiveBayesProto}.
 * @throws IllegalArgumentException If the version is negative or greater than {@link #CURRENT_VERSION}.
 * @throws IllegalStateException If the protobuf contents are inconsistent or invalid.
 */
public static MultinomialNaiveBayesModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
    if (version < 0 || version > CURRENT_VERSION) {
        throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
    }
    MultinomialNaiveBayesProto proto = message.unpack(MultinomialNaiveBayesProto.class);

    ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
    // The domain is typed with a wildcard, so the element type is probed through its first output.
    // The null guard turns a missing id 0 (e.g. an empty domain) into the documented
    // IllegalStateException instead of a bare NullPointerException.
    Object firstOutput = carrier.outputDomain().getOutput(0);
    if (firstOutput == null || !firstOutput.getClass().equals(Label.class)) {
        throw new IllegalStateException("Invalid protobuf, output domain is not a label domain, found " + carrier.outputDomain().getClass());
    }
    @SuppressWarnings("unchecked") // guarded by getClass
    ImmutableOutputInfo<Label> outputDomain = (ImmutableOutputInfo<Label>) carrier.outputDomain();

    Tensor weights = Tensor.deserialize(proto.getLabelWordProbs());
    if (!(weights instanceof DenseSparseMatrix)) {
        throw new IllegalStateException("Invalid protobuf, label word probs must be a sparse matrix, found " + weights.getClass());
    }
    DenseSparseMatrix labelWordProbs = (DenseSparseMatrix) weights;
    // One row per label, one column per feature.
    if (labelWordProbs.getDimension1Size() != carrier.outputDomain().size()) {
        throw new IllegalStateException("Invalid protobuf, labelWordProbs not the right size, expected " + carrier.outputDomain().size() + ", found " + labelWordProbs.getDimension1Size());
    }
    if (labelWordProbs.getDimension2Size() != carrier.featureDomain().size()) {
        throw new IllegalStateException("Invalid protobuf, labelWordProbs not the right size, expected " + carrier.featureDomain().size() + ", found " + labelWordProbs.getDimension2Size());
    }

    double alpha = proto.getAlpha();

    // Written as a negated ">=" so that NaN, for which every ordered comparison is false,
    // is rejected together with negative values.
    if (!(alpha >= 0.0)) {
        throw new IllegalStateException("Invalid protobuf, alpha must be non-negative, found " + alpha);
    }

    return new MultinomialNaiveBayesModel(carrier.name(),carrier.provenance(),carrier.featureDomain(),outputDomain,labelWordProbs,alpha);
}

@Override
public Prediction<Label> predict(Example<Label> example) {
SparseVector exVector = SparseVector.createSparseVector(example, featureIDMap, false);
Expand Down Expand Up @@ -157,4 +208,21 @@ public Optional<Excuse<Label>> getExcuse(Example<Label> example) {
protected MultinomialNaiveBayesModel copy(String newName, ModelProvenance newProvenance) {
    // The copy gets its own weight matrix, built from this model's matrix; the feature
    // and output domains and alpha are passed through unchanged.
    DenseSparseMatrix probsCopy = new DenseSparseMatrix(labelWordProbs);
    return new MultinomialNaiveBayesModel(newName, newProvenance, featureIDMap, outputIDInfo, probsCopy, alpha);
}

/**
 * Serializes this model to a {@link ModelProto}.
 * <p>
 * The model metadata, the label/word weight matrix and alpha are stored in a
 * {@code MultinomialNaiveBayesProto}, which is packed into the returned proto
 * together with this class's name and {@link #CURRENT_VERSION}.
 * @return The protobuf representation of this model.
 */
@Override
public ModelProto serialize() {
    ModelDataCarrier<Label> metadata = createDataCarrier();

    // Model-specific payload: metadata, weights and alpha.
    MultinomialNaiveBayesProto payload = MultinomialNaiveBayesProto.newBuilder()
            .setMetadata(metadata.serialize())
            .setLabelWordProbs(labelWordProbs.serialize())
            .setAlpha(alpha)
            .build();

    // Generic envelope: packed payload plus the class name and version.
    return ModelProto.newBuilder()
            .setSerializedData(Any.pack(payload))
            .setClassName(MultinomialNaiveBayesModel.class.getName())
            .setVersion(CURRENT_VERSION)
            .build();
}
}
Loading