From af1db521786940b7af44094f86c423414829cb3b Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 13 Sep 2022 11:59:39 -0400 Subject: [PATCH] Adding SLM proto serialization. (#269) --- .../regression/slm/SparseLinearModel.java | 114 + .../slm/protos/SparseLinearModelProto.java | 1946 +++++++++++++++++ .../SparseLinearModelProtoOrBuilder.java | 143 ++ .../slm/protos/TribuoRegressionSlm.java | 60 + .../protos/tribuo-regression-slm.proto | 45 + .../org/tribuo/regression/slm/TestSLM.java | 8 + 6 files changed, 2316 insertions(+) create mode 100644 Regression/SLM/src/main/java/org/tribuo/regression/slm/protos/SparseLinearModelProto.java create mode 100644 Regression/SLM/src/main/java/org/tribuo/regression/slm/protos/SparseLinearModelProtoOrBuilder.java create mode 100644 Regression/SLM/src/main/java/org/tribuo/regression/slm/protos/TribuoRegressionSlm.java create mode 100644 Regression/SLM/src/main/resources/protos/tribuo-regression-slm.proto diff --git a/Regression/SLM/src/main/java/org/tribuo/regression/slm/SparseLinearModel.java b/Regression/SLM/src/main/java/org/tribuo/regression/slm/SparseLinearModel.java index 015dfc569..bf74a6640 100644 --- a/Regression/SLM/src/main/java/org/tribuo/regression/slm/SparseLinearModel.java +++ b/Regression/SLM/src/main/java/org/tribuo/regression/slm/SparseLinearModel.java @@ -17,6 +17,8 @@ package org.tribuo.regression.slm; import ai.onnx.proto.OnnxMl; +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; import com.oracle.labs.mlrg.olcut.util.Pair; import org.tribuo.Example; import org.tribuo.Excuse; @@ -26,15 +28,22 @@ import org.tribuo.ONNXExportable; import org.tribuo.Prediction; import org.tribuo.VariableInfo; +import org.tribuo.impl.ModelDataCarrier; +import org.tribuo.math.la.DenseSparseMatrix; import org.tribuo.math.la.DenseVector; import org.tribuo.math.la.SparseVector; +import org.tribuo.math.la.Tensor; import org.tribuo.math.la.VectorTuple; +import org.tribuo.math.protos.TensorProto; +import org.tribuo.protos.core.ModelProto; import org.tribuo.provenance.ModelProvenance; import org.tribuo.provenance.TrainerProvenance; import org.tribuo.regression.ImmutableRegressionInfo; import org.tribuo.regression.Regressor; import org.tribuo.regression.Regressor.DimensionTuple; import org.tribuo.regression.impl.SkeletalIndependentRegressionSparseModel; +import org.tribuo.regression.slm.protos.SparseLinearModelProto; +import org.tribuo.util.Util; import org.tribuo.util.onnx.ONNXContext; import org.tribuo.util.onnx.ONNXNode; import org.tribuo.util.onnx.ONNXOperators; @@ -54,6 +63,7 @@ import java.util.Optional; import java.util.PriorityQueue; import java.util.logging.Logger; +import java.util.stream.Collectors; /** * The inference time version of a sparse linear regression model. @@ -64,6 +74,11 @@ public class SparseLinearModel extends SkeletalIndependentRegressionSparseModel private static final long serialVersionUID = 3L; private static final Logger logger = Logger.getLogger(SparseLinearModel.class.getName()); + /** + * Protobuf serialization version. + */ + public static final int CURRENT_VERSION = 0; + private SparseVector[] weights; private final DenseVector featureMeans; /** @@ -93,6 +108,81 @@ public class SparseLinearModel extends SkeletalIndependentRegressionSparseModel this.enet41MappingFix = true; } + /** + * Deserialization factory. + * @param version The serialized object version. + * @param className The class name. + * @param message The serialized data. + */ + public static SparseLinearModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException { + if (version < 0 || version > CURRENT_VERSION) { + throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION); + } + SparseLinearModelProto proto = message.unpack(SparseLinearModelProto.class); + + ModelDataCarrier carrier = ModelDataCarrier.deserialize(proto.getMetadata()); + if (!carrier.outputDomain().getOutput(0).getClass().equals(Regressor.class)) { + throw new IllegalStateException("Invalid protobuf, output domain is not a regression domain, found " + carrier.outputDomain().getClass()); + } + @SuppressWarnings("unchecked") // guarded by getClass + ImmutableOutputInfo outputDomain = (ImmutableOutputInfo) carrier.outputDomain(); + + String[] dimensions = new String[proto.getDimensionsCount()]; + if (dimensions.length != outputDomain.size()) { + throw new IllegalStateException("Invalid protobuf, found insufficient dimension names, expected " + outputDomain.size() + ", found " + dimensions.length); + } + for (int i = 0; i < dimensions.length; i++) { + dimensions[i] = proto.getDimensions(i); + } + + SparseVector[] weights = new SparseVector[outputDomain.size()]; + if (weights.length != proto.getWeightsCount()) { + throw new IllegalStateException("Invalid protobuf, expected same weight dimension as output domain size, found " + proto.getWeightsCount() + " weights and " + outputDomain.size() + " output dimensions"); + } + int featureSize = proto.getBias() ? carrier.featureDomain().size() + 1 : carrier.featureDomain().size(); + for (int i = 0; i < weights.length; i++) { + Tensor deser = Tensor.deserialize(proto.getWeights(i)); + if (deser instanceof SparseVector) { + SparseVector v = (SparseVector) deser; + if (v.size() == featureSize) { + weights[i] = v; + } else { + throw new IllegalStateException("Invalid protobuf, weights size and feature domain do not match, expected " + featureSize + ", found " + v.size()); + } + } else { + throw new IllegalStateException("Invalid protobuf, expected a SparseVector, found " + deser.getClass()); + } + } + + Tensor featureMeansTensor = Tensor.deserialize(proto.getFeatureMeans()); + if (!(featureMeansTensor instanceof DenseVector)) { + throw new IllegalStateException("Invalid protobuf, feature means must be a dense vector, found " + featureMeansTensor.getClass()); + } + DenseVector featureMeans = (DenseVector) featureMeansTensor; + if (featureMeans.size() != featureSize) { + throw new IllegalStateException("Invalid protobuf, feature means not the right size, expected " + featureSize + ", found " + featureMeans.size()); + } + Tensor featureNormsTensor = Tensor.deserialize(proto.getFeatureNorms()); + if (!(featureNormsTensor instanceof DenseVector)) { + throw new IllegalStateException("Invalid protobuf, feature means must be a dense vector, found " + featureNormsTensor.getClass()); + } + DenseVector featureNorms = (DenseVector) featureNormsTensor; + if (featureNorms.size() != featureSize) { + throw new IllegalStateException("Invalid protobuf, feature means not the right size, expected " + featureSize + ", found " + featureNorms.size()); + } + double[] yMean = Util.toPrimitiveDouble(proto.getYMeanList()); + if (yMean.length != outputDomain.size()) { + throw new IllegalStateException("Invalid protobuf, y means not the right size, expected " + carrier.outputDomain().size() + " found " + yMean.length); + } + double[] yNorm = Util.toPrimitiveDouble(proto.getYNormList()); + if (yNorm.length != outputDomain.size()) { + throw new IllegalStateException("Invalid protobuf, y norms not the right size, expected " + carrier.outputDomain().size() + " found " + yNorm.length); + } + + return new SparseLinearModel(carrier.name(),dimensions, carrier.provenance(),carrier.featureDomain(),outputDomain, + weights, featureMeans, featureNorms, yMean, yNorm, proto.getBias()); + } + private static Map> generateActiveFeatures(String[] dimensionNames, ImmutableFeatureMap featureMap, SparseVector[] weightsArray) { Map> map = new HashMap<>(); @@ -224,6 +314,30 @@ public Map getWeights() { return output; } + @Override + public ModelProto serialize() { + ModelDataCarrier carrier = createDataCarrier(); + + SparseLinearModelProto.Builder modelBuilder = SparseLinearModelProto.newBuilder(); + modelBuilder.setMetadata(carrier.serialize()); + modelBuilder.addAllDimensions(Arrays.asList(dimensions)); + for (SparseVector v : weights) { + modelBuilder.addWeights(v.serialize()); + } + modelBuilder.setFeatureMeans(featureMeans.serialize()); + modelBuilder.setFeatureNorms(featureVariance.serialize()); + modelBuilder.setBias(bias); + modelBuilder.addAllYMean(Arrays.stream(yMean).boxed().collect(Collectors.toList())); + modelBuilder.addAllYNorm(Arrays.stream(yVariance).boxed().collect(Collectors.toList())); + + ModelProto.Builder builder = ModelProto.newBuilder(); + builder.setSerializedData(Any.pack(modelBuilder.build())); + builder.setClassName(SparseLinearModel.class.getName()); + builder.setVersion(CURRENT_VERSION); + + return builder.build(); + } + @Override public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) { ONNXContext onnx = new ONNXContext(); diff --git a/Regression/SLM/src/main/java/org/tribuo/regression/slm/protos/SparseLinearModelProto.java b/Regression/SLM/src/main/java/org/tribuo/regression/slm/protos/SparseLinearModelProto.java new file mode 100644 index 000000000..5587d7fca --- /dev/null +++ b/Regression/SLM/src/main/java/org/tribuo/regression/slm/protos/SparseLinearModelProto.java @@ -0,0 +1,1946 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tribuo-regression-slm.proto + +package org.tribuo.regression.slm.protos; + +/** + *
+ *SparseLinearModelProto proto
+ * 
+ * + * Protobuf type {@code tribuo.regression.mnb.SparseLinearModelProto} + */ +public final class SparseLinearModelProto extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:tribuo.regression.mnb.SparseLinearModelProto) + SparseLinearModelProtoOrBuilder { +private static final long serialVersionUID = 0L; + // Use SparseLinearModelProto.newBuilder() to construct. + private SparseLinearModelProto(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private SparseLinearModelProto() { + dimensions_ = com.google.protobuf.LazyStringArrayList.EMPTY; + weights_ = java.util.Collections.emptyList(); + yMean_ = emptyDoubleList(); + yNorm_ = emptyDoubleList(); + } + + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new SparseLinearModelProto(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + private SparseLinearModelProto( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + this(); + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + int mutable_bitField0_ = 0; + com.google.protobuf.UnknownFieldSet.Builder unknownFields = + com.google.protobuf.UnknownFieldSet.newBuilder(); + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 10: { + org.tribuo.protos.core.ModelDataProto.Builder subBuilder = null; + if (metadata_ != null) { + subBuilder = metadata_.toBuilder(); + } + metadata_ = input.readMessage(org.tribuo.protos.core.ModelDataProto.parser(), extensionRegistry); + if (subBuilder != null) { + subBuilder.mergeFrom(metadata_); + metadata_ = subBuilder.buildPartial(); + } + + break; + } + case 18: { + java.lang.String s = input.readStringRequireUtf8(); + if (!((mutable_bitField0_ & 0x00000001) != 0)) { + dimensions_ = new com.google.protobuf.LazyStringArrayList(); + mutable_bitField0_ |= 0x00000001; + } + dimensions_.add(s); + break; + } + case 26: { + if (!((mutable_bitField0_ & 0x00000002) != 0)) { + weights_ = new java.util.ArrayList(); + mutable_bitField0_ |= 0x00000002; + } + weights_.add( + input.readMessage(org.tribuo.math.protos.TensorProto.parser(), extensionRegistry)); + break; + } + case 34: { + org.tribuo.math.protos.TensorProto.Builder subBuilder = null; + if (featureMeans_ != null) { + subBuilder = featureMeans_.toBuilder(); + } + featureMeans_ = input.readMessage(org.tribuo.math.protos.TensorProto.parser(), extensionRegistry); + if (subBuilder != null) { + subBuilder.mergeFrom(featureMeans_); + featureMeans_ = subBuilder.buildPartial(); + } + + break; + } + case 42: { + org.tribuo.math.protos.TensorProto.Builder subBuilder = null; + if (featureNorms_ != null) { + subBuilder = featureNorms_.toBuilder(); + } + featureNorms_ = input.readMessage(org.tribuo.math.protos.TensorProto.parser(), extensionRegistry); + if (subBuilder != null) { + subBuilder.mergeFrom(featureNorms_); + featureNorms_ = subBuilder.buildPartial(); + } + + break; + } + case 48: { + + bias_ = input.readBool(); + break; + } + case 57: { + if (!((mutable_bitField0_ & 0x00000004) != 0)) { + yMean_ = newDoubleList(); + mutable_bitField0_ |= 0x00000004; + } + yMean_.addDouble(input.readDouble()); + break; + } + case 58: { + int length = input.readRawVarint32(); + int limit = input.pushLimit(length); + if (!((mutable_bitField0_ & 0x00000004) != 0) && input.getBytesUntilLimit() > 0) { + yMean_ = newDoubleList(); + mutable_bitField0_ |= 0x00000004; + } + while (input.getBytesUntilLimit() > 0) { + yMean_.addDouble(input.readDouble()); + } + input.popLimit(limit); + break; + } + case 65: { + if (!((mutable_bitField0_ & 0x00000008) != 0)) { + yNorm_ = newDoubleList(); + mutable_bitField0_ |= 0x00000008; + } + yNorm_.addDouble(input.readDouble()); + break; + } + case 66: { + int length = input.readRawVarint32(); + int limit = input.pushLimit(length); + if (!((mutable_bitField0_ & 0x00000008) != 0) && input.getBytesUntilLimit() > 0) { + yNorm_ = newDoubleList(); + mutable_bitField0_ |= 0x00000008; + } + while (input.getBytesUntilLimit() > 0) { + yNorm_.addDouble(input.readDouble()); + } + input.popLimit(limit); + break; + } + default: { + if (!parseUnknownField( + input, unknownFields, extensionRegistry, tag)) { + done = true; + } + break; + } + } + } + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(this); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException( + e).setUnfinishedMessage(this); + } finally { + if (((mutable_bitField0_ & 0x00000001) != 0)) { + dimensions_ = dimensions_.getUnmodifiableView(); + } + if (((mutable_bitField0_ & 0x00000002) != 0)) { + weights_ = java.util.Collections.unmodifiableList(weights_); + } + if (((mutable_bitField0_ & 0x00000004) != 0)) { + yMean_.makeImmutable(); // C + } + if (((mutable_bitField0_ & 0x00000008) != 0)) { + yNorm_.makeImmutable(); // C + } + this.unknownFields = unknownFields.build(); + makeExtensionsImmutable(); + } + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.tribuo.regression.slm.protos.TribuoRegressionSlm.internal_static_tribuo_regression_mnb_SparseLinearModelProto_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.tribuo.regression.slm.protos.TribuoRegressionSlm.internal_static_tribuo_regression_mnb_SparseLinearModelProto_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.tribuo.regression.slm.protos.SparseLinearModelProto.class, org.tribuo.regression.slm.protos.SparseLinearModelProto.Builder.class); + } + + public static final int METADATA_FIELD_NUMBER = 1; + private org.tribuo.protos.core.ModelDataProto metadata_; + /** + * .tribuo.core.ModelDataProto metadata = 1; + * @return Whether the metadata field is set. + */ + @java.lang.Override + public boolean hasMetadata() { + return metadata_ != null; + } + /** + * .tribuo.core.ModelDataProto metadata = 1; + * @return The metadata. + */ + @java.lang.Override + public org.tribuo.protos.core.ModelDataProto getMetadata() { + return metadata_ == null ? org.tribuo.protos.core.ModelDataProto.getDefaultInstance() : metadata_; + } + /** + * .tribuo.core.ModelDataProto metadata = 1; + */ + @java.lang.Override + public org.tribuo.protos.core.ModelDataProtoOrBuilder getMetadataOrBuilder() { + return getMetadata(); + } + + public static final int DIMENSIONS_FIELD_NUMBER = 2; + private com.google.protobuf.LazyStringList dimensions_; + /** + * repeated string dimensions = 2; + * @return A list containing the dimensions. + */ + public com.google.protobuf.ProtocolStringList + getDimensionsList() { + return dimensions_; + } + /** + * repeated string dimensions = 2; + * @return The count of dimensions. + */ + public int getDimensionsCount() { + return dimensions_.size(); + } + /** + * repeated string dimensions = 2; + * @param index The index of the element to return. + * @return The dimensions at the given index. + */ + public java.lang.String getDimensions(int index) { + return dimensions_.get(index); + } + /** + * repeated string dimensions = 2; + * @param index The index of the value to return. + * @return The bytes of the dimensions at the given index. + */ + public com.google.protobuf.ByteString + getDimensionsBytes(int index) { + return dimensions_.getByteString(index); + } + + public static final int WEIGHTS_FIELD_NUMBER = 3; + private java.util.List weights_; + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + @java.lang.Override + public java.util.List getWeightsList() { + return weights_; + } + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + @java.lang.Override + public java.util.List + getWeightsOrBuilderList() { + return weights_; + } + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + @java.lang.Override + public int getWeightsCount() { + return weights_.size(); + } + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + @java.lang.Override + public org.tribuo.math.protos.TensorProto getWeights(int index) { + return weights_.get(index); + } + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + @java.lang.Override + public org.tribuo.math.protos.TensorProtoOrBuilder getWeightsOrBuilder( + int index) { + return weights_.get(index); + } + + public static final int FEATURE_MEANS_FIELD_NUMBER = 4; + private org.tribuo.math.protos.TensorProto featureMeans_; + /** + * .tribuo.math.TensorProto feature_means = 4; + * @return Whether the featureMeans field is set. + */ + @java.lang.Override + public boolean hasFeatureMeans() { + return featureMeans_ != null; + } + /** + * .tribuo.math.TensorProto feature_means = 4; + * @return The featureMeans. + */ + @java.lang.Override + public org.tribuo.math.protos.TensorProto getFeatureMeans() { + return featureMeans_ == null ? org.tribuo.math.protos.TensorProto.getDefaultInstance() : featureMeans_; + } + /** + * .tribuo.math.TensorProto feature_means = 4; + */ + @java.lang.Override + public org.tribuo.math.protos.TensorProtoOrBuilder getFeatureMeansOrBuilder() { + return getFeatureMeans(); + } + + public static final int FEATURE_NORMS_FIELD_NUMBER = 5; + private org.tribuo.math.protos.TensorProto featureNorms_; + /** + * .tribuo.math.TensorProto feature_norms = 5; + * @return Whether the featureNorms field is set. + */ + @java.lang.Override + public boolean hasFeatureNorms() { + return featureNorms_ != null; + } + /** + * .tribuo.math.TensorProto feature_norms = 5; + * @return The featureNorms. + */ + @java.lang.Override + public org.tribuo.math.protos.TensorProto getFeatureNorms() { + return featureNorms_ == null ? org.tribuo.math.protos.TensorProto.getDefaultInstance() : featureNorms_; + } + /** + * .tribuo.math.TensorProto feature_norms = 5; + */ + @java.lang.Override + public org.tribuo.math.protos.TensorProtoOrBuilder getFeatureNormsOrBuilder() { + return getFeatureNorms(); + } + + public static final int BIAS_FIELD_NUMBER = 6; + private boolean bias_; + /** + * bool bias = 6; + * @return The bias. + */ + @java.lang.Override + public boolean getBias() { + return bias_; + } + + public static final int Y_MEAN_FIELD_NUMBER = 7; + private com.google.protobuf.Internal.DoubleList yMean_; + /** + * repeated double y_mean = 7; + * @return A list containing the yMean. + */ + @java.lang.Override + public java.util.List + getYMeanList() { + return yMean_; + } + /** + * repeated double y_mean = 7; + * @return The count of yMean. + */ + public int getYMeanCount() { + return yMean_.size(); + } + /** + * repeated double y_mean = 7; + * @param index The index of the element to return. + * @return The yMean at the given index. + */ + public double getYMean(int index) { + return yMean_.getDouble(index); + } + private int yMeanMemoizedSerializedSize = -1; + + public static final int Y_NORM_FIELD_NUMBER = 8; + private com.google.protobuf.Internal.DoubleList yNorm_; + /** + * repeated double y_norm = 8; + * @return A list containing the yNorm. + */ + @java.lang.Override + public java.util.List + getYNormList() { + return yNorm_; + } + /** + * repeated double y_norm = 8; + * @return The count of yNorm. + */ + public int getYNormCount() { + return yNorm_.size(); + } + /** + * repeated double y_norm = 8; + * @param index The index of the element to return. + * @return The yNorm at the given index. + */ + public double getYNorm(int index) { + return yNorm_.getDouble(index); + } + private int yNormMemoizedSerializedSize = -1; + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + getSerializedSize(); + if (metadata_ != null) { + output.writeMessage(1, getMetadata()); + } + for (int i = 0; i < dimensions_.size(); i++) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 2, dimensions_.getRaw(i)); + } + for (int i = 0; i < weights_.size(); i++) { + output.writeMessage(3, weights_.get(i)); + } + if (featureMeans_ != null) { + output.writeMessage(4, getFeatureMeans()); + } + if (featureNorms_ != null) { + output.writeMessage(5, getFeatureNorms()); + } + if (bias_ != false) { + output.writeBool(6, bias_); + } + if (getYMeanList().size() > 0) { + output.writeUInt32NoTag(58); + output.writeUInt32NoTag(yMeanMemoizedSerializedSize); + } + for (int i = 0; i < yMean_.size(); i++) { + output.writeDoubleNoTag(yMean_.getDouble(i)); + } + if (getYNormList().size() > 0) { + output.writeUInt32NoTag(66); + output.writeUInt32NoTag(yNormMemoizedSerializedSize); + } + for (int i = 0; i < yNorm_.size(); i++) { + output.writeDoubleNoTag(yNorm_.getDouble(i)); + } + unknownFields.writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (metadata_ != null) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(1, getMetadata()); + } + { + int dataSize = 0; + for (int i = 0; i < dimensions_.size(); i++) { + dataSize += computeStringSizeNoTag(dimensions_.getRaw(i)); + } + size += dataSize; + size += 1 * getDimensionsList().size(); + } + for (int i = 0; i < weights_.size(); i++) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(3, weights_.get(i)); + } + if (featureMeans_ != null) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(4, getFeatureMeans()); + } + if (featureNorms_ != null) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(5, getFeatureNorms()); + } + if (bias_ != false) { + size += com.google.protobuf.CodedOutputStream + .computeBoolSize(6, bias_); + } + { + int dataSize = 0; + dataSize = 8 * getYMeanList().size(); + size += dataSize; + if (!getYMeanList().isEmpty()) { + size += 1; + size += com.google.protobuf.CodedOutputStream + .computeInt32SizeNoTag(dataSize); + } + yMeanMemoizedSerializedSize = dataSize; + } + { + int dataSize = 0; + dataSize = 8 * getYNormList().size(); + size += dataSize; + if (!getYNormList().isEmpty()) { + size += 1; + size += com.google.protobuf.CodedOutputStream + .computeInt32SizeNoTag(dataSize); + } + yNormMemoizedSerializedSize = dataSize; + } + size += unknownFields.getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof org.tribuo.regression.slm.protos.SparseLinearModelProto)) { + return super.equals(obj); + } + org.tribuo.regression.slm.protos.SparseLinearModelProto other = (org.tribuo.regression.slm.protos.SparseLinearModelProto) obj; + + if (hasMetadata() != other.hasMetadata()) return false; + if (hasMetadata()) { + if (!getMetadata() + .equals(other.getMetadata())) return false; + } + if (!getDimensionsList() + .equals(other.getDimensionsList())) return false; + if (!getWeightsList() + .equals(other.getWeightsList())) return false; + if (hasFeatureMeans() != other.hasFeatureMeans()) return false; + if (hasFeatureMeans()) { + if (!getFeatureMeans() + .equals(other.getFeatureMeans())) return false; + } + if (hasFeatureNorms() != other.hasFeatureNorms()) return false; + if (hasFeatureNorms()) { + if (!getFeatureNorms() + .equals(other.getFeatureNorms())) return false; + } + if (getBias() + != other.getBias()) return false; + if (!getYMeanList() + .equals(other.getYMeanList())) return false; + if (!getYNormList() + .equals(other.getYNormList())) return false; + if (!unknownFields.equals(other.unknownFields)) return false; + return true; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + if (hasMetadata()) { + hash = (37 * hash) + METADATA_FIELD_NUMBER; + hash = (53 * hash) + getMetadata().hashCode(); + } + if (getDimensionsCount() > 0) { + hash = (37 * hash) + DIMENSIONS_FIELD_NUMBER; + hash = (53 * hash) + getDimensionsList().hashCode(); + } + if (getWeightsCount() > 0) { + hash = (37 * hash) + WEIGHTS_FIELD_NUMBER; + hash = (53 * hash) + getWeightsList().hashCode(); + } + if (hasFeatureMeans()) { + hash = (37 * hash) + FEATURE_MEANS_FIELD_NUMBER; + hash = (53 * hash) + getFeatureMeans().hashCode(); + } + if (hasFeatureNorms()) { + hash = (37 * hash) + FEATURE_NORMS_FIELD_NUMBER; + hash = (53 * hash) + getFeatureNorms().hashCode(); + } + hash = (37 * hash) + BIAS_FIELD_NUMBER; + hash = (53 * hash) + com.google.protobuf.Internal.hashBoolean( + getBias()); + if (getYMeanCount() > 0) { + hash = (37 * hash) + Y_MEAN_FIELD_NUMBER; + hash = (53 * hash) + getYMeanList().hashCode(); + } + if (getYNormCount() > 0) { + hash = (37 * hash) + Y_NORM_FIELD_NUMBER; + hash = (53 * hash) + getYNormList().hashCode(); + } + hash = (29 * hash) + unknownFields.hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static org.tribuo.regression.slm.protos.SparseLinearModelProto parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.tribuo.regression.slm.protos.SparseLinearModelProto parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.tribuo.regression.slm.protos.SparseLinearModelProto parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.tribuo.regression.slm.protos.SparseLinearModelProto parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.tribuo.regression.slm.protos.SparseLinearModelProto parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.tribuo.regression.slm.protos.SparseLinearModelProto parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.tribuo.regression.slm.protos.SparseLinearModelProto parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.tribuo.regression.slm.protos.SparseLinearModelProto parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static org.tribuo.regression.slm.protos.SparseLinearModelProto parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static org.tribuo.regression.slm.protos.SparseLinearModelProto parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static org.tribuo.regression.slm.protos.SparseLinearModelProto parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.tribuo.regression.slm.protos.SparseLinearModelProto parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(org.tribuo.regression.slm.protos.SparseLinearModelProto prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + *
+   *SparseLinearModelProto proto
+   * 
+ * + * Protobuf type {@code tribuo.regression.mnb.SparseLinearModelProto} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:tribuo.regression.mnb.SparseLinearModelProto) + org.tribuo.regression.slm.protos.SparseLinearModelProtoOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.tribuo.regression.slm.protos.TribuoRegressionSlm.internal_static_tribuo_regression_mnb_SparseLinearModelProto_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.tribuo.regression.slm.protos.TribuoRegressionSlm.internal_static_tribuo_regression_mnb_SparseLinearModelProto_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.tribuo.regression.slm.protos.SparseLinearModelProto.class, org.tribuo.regression.slm.protos.SparseLinearModelProto.Builder.class); + } + + // Construct using org.tribuo.regression.slm.protos.SparseLinearModelProto.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + maybeForceBuilderInitialization(); + } + private void maybeForceBuilderInitialization() { + if (com.google.protobuf.GeneratedMessageV3 + .alwaysUseFieldBuilders) { + getWeightsFieldBuilder(); + } + } + @java.lang.Override + public Builder clear() { + super.clear(); + if (metadataBuilder_ == null) { + metadata_ = null; + } else { + metadata_ = null; + metadataBuilder_ = null; + } + dimensions_ = com.google.protobuf.LazyStringArrayList.EMPTY; + bitField0_ = (bitField0_ & ~0x00000001); + if (weightsBuilder_ == null) { + weights_ = java.util.Collections.emptyList(); + bitField0_ = (bitField0_ & ~0x00000002); + } else { + weightsBuilder_.clear(); + } + if (featureMeansBuilder_ == null) { + featureMeans_ = null; + } else { + featureMeans_ = null; + featureMeansBuilder_ = null; + } + if (featureNormsBuilder_ == null) { + featureNorms_ = null; + } else { + featureNorms_ = null; + featureNormsBuilder_ = null; + } + bias_ = false; + + yMean_ = emptyDoubleList(); + bitField0_ = (bitField0_ & ~0x00000004); + yNorm_ = emptyDoubleList(); + bitField0_ = (bitField0_ & ~0x00000008); + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return org.tribuo.regression.slm.protos.TribuoRegressionSlm.internal_static_tribuo_regression_mnb_SparseLinearModelProto_descriptor; + } + + @java.lang.Override + public org.tribuo.regression.slm.protos.SparseLinearModelProto getDefaultInstanceForType() { + return org.tribuo.regression.slm.protos.SparseLinearModelProto.getDefaultInstance(); + } + + @java.lang.Override + public org.tribuo.regression.slm.protos.SparseLinearModelProto build() { + org.tribuo.regression.slm.protos.SparseLinearModelProto result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public org.tribuo.regression.slm.protos.SparseLinearModelProto buildPartial() { + org.tribuo.regression.slm.protos.SparseLinearModelProto result = new org.tribuo.regression.slm.protos.SparseLinearModelProto(this); + int from_bitField0_ = bitField0_; + if (metadataBuilder_ == null) { + result.metadata_ = metadata_; + } else { + result.metadata_ = metadataBuilder_.build(); + } + if (((bitField0_ & 0x00000001) != 0)) { + dimensions_ = dimensions_.getUnmodifiableView(); + bitField0_ = (bitField0_ & ~0x00000001); + } + result.dimensions_ = dimensions_; + if (weightsBuilder_ == null) { + if (((bitField0_ & 0x00000002) != 0)) { + weights_ = java.util.Collections.unmodifiableList(weights_); + bitField0_ = (bitField0_ & ~0x00000002); + } + result.weights_ = weights_; + } else { + result.weights_ = weightsBuilder_.build(); + } + if (featureMeansBuilder_ == null) { + result.featureMeans_ = featureMeans_; + } else { + result.featureMeans_ = featureMeansBuilder_.build(); + } + if (featureNormsBuilder_ == null) { + result.featureNorms_ = featureNorms_; + } else { + result.featureNorms_ = featureNormsBuilder_.build(); + } + result.bias_ = bias_; + if (((bitField0_ & 0x00000004) != 0)) { + yMean_.makeImmutable(); + bitField0_ = (bitField0_ & ~0x00000004); + } + result.yMean_ = yMean_; + if (((bitField0_ & 0x00000008) != 0)) { + yNorm_.makeImmutable(); + bitField0_ = (bitField0_ & ~0x00000008); + } + result.yNorm_ = yNorm_; + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof org.tribuo.regression.slm.protos.SparseLinearModelProto) { + return mergeFrom((org.tribuo.regression.slm.protos.SparseLinearModelProto)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(org.tribuo.regression.slm.protos.SparseLinearModelProto other) { + if (other == org.tribuo.regression.slm.protos.SparseLinearModelProto.getDefaultInstance()) return this; + if (other.hasMetadata()) { + mergeMetadata(other.getMetadata()); + } + if (!other.dimensions_.isEmpty()) { + if (dimensions_.isEmpty()) { + dimensions_ = other.dimensions_; + bitField0_ = (bitField0_ & ~0x00000001); + } else { + ensureDimensionsIsMutable(); + dimensions_.addAll(other.dimensions_); + } + onChanged(); + } + if (weightsBuilder_ == null) { + if (!other.weights_.isEmpty()) { + if (weights_.isEmpty()) { + weights_ = other.weights_; + bitField0_ = (bitField0_ & ~0x00000002); + } else { + ensureWeightsIsMutable(); + weights_.addAll(other.weights_); + } + onChanged(); + } + } else { + if (!other.weights_.isEmpty()) { + if (weightsBuilder_.isEmpty()) { + weightsBuilder_.dispose(); + weightsBuilder_ = null; + weights_ = other.weights_; + bitField0_ = (bitField0_ & ~0x00000002); + weightsBuilder_ = + com.google.protobuf.GeneratedMessageV3.alwaysUseFieldBuilders ? + getWeightsFieldBuilder() : null; + } else { + weightsBuilder_.addAllMessages(other.weights_); + } + } + } + if (other.hasFeatureMeans()) { + mergeFeatureMeans(other.getFeatureMeans()); + } + if (other.hasFeatureNorms()) { + mergeFeatureNorms(other.getFeatureNorms()); + } + if (other.getBias() != false) { + setBias(other.getBias()); + } + if (!other.yMean_.isEmpty()) { + if (yMean_.isEmpty()) { + yMean_ = other.yMean_; + bitField0_ = (bitField0_ & ~0x00000004); + } else { + ensureYMeanIsMutable(); + yMean_.addAll(other.yMean_); + } + onChanged(); + } + if (!other.yNorm_.isEmpty()) { + if (yNorm_.isEmpty()) { + yNorm_ = other.yNorm_; + bitField0_ = (bitField0_ & ~0x00000008); + } else { + ensureYNormIsMutable(); + yNorm_.addAll(other.yNorm_); + } + onChanged(); + } + this.mergeUnknownFields(other.unknownFields); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + org.tribuo.regression.slm.protos.SparseLinearModelProto parsedMessage = null; + try { + parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + parsedMessage = (org.tribuo.regression.slm.protos.SparseLinearModelProto) e.getUnfinishedMessage(); + throw e.unwrapIOException(); + } finally { + if (parsedMessage != null) { + mergeFrom(parsedMessage); + } + } + return this; + } + private int bitField0_; + + private org.tribuo.protos.core.ModelDataProto metadata_; + private com.google.protobuf.SingleFieldBuilderV3< + org.tribuo.protos.core.ModelDataProto, org.tribuo.protos.core.ModelDataProto.Builder, org.tribuo.protos.core.ModelDataProtoOrBuilder> metadataBuilder_; + /** + * .tribuo.core.ModelDataProto metadata = 1; + * @return Whether the metadata field is set. + */ + public boolean hasMetadata() { + return metadataBuilder_ != null || metadata_ != null; + } + /** + * .tribuo.core.ModelDataProto metadata = 1; + * @return The metadata. + */ + public org.tribuo.protos.core.ModelDataProto getMetadata() { + if (metadataBuilder_ == null) { + return metadata_ == null ? org.tribuo.protos.core.ModelDataProto.getDefaultInstance() : metadata_; + } else { + return metadataBuilder_.getMessage(); + } + } + /** + * .tribuo.core.ModelDataProto metadata = 1; + */ + public Builder setMetadata(org.tribuo.protos.core.ModelDataProto value) { + if (metadataBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + metadata_ = value; + onChanged(); + } else { + metadataBuilder_.setMessage(value); + } + + return this; + } + /** + * .tribuo.core.ModelDataProto metadata = 1; + */ + public Builder setMetadata( + org.tribuo.protos.core.ModelDataProto.Builder builderForValue) { + if (metadataBuilder_ == null) { + metadata_ = builderForValue.build(); + onChanged(); + } else { + metadataBuilder_.setMessage(builderForValue.build()); + } + + return this; + } + /** + * .tribuo.core.ModelDataProto metadata = 1; + */ + public Builder mergeMetadata(org.tribuo.protos.core.ModelDataProto value) { + if (metadataBuilder_ == null) { + if (metadata_ != null) { + metadata_ = + org.tribuo.protos.core.ModelDataProto.newBuilder(metadata_).mergeFrom(value).buildPartial(); + } else { + metadata_ = value; + } + onChanged(); + } else { + metadataBuilder_.mergeFrom(value); + } + + return this; + } + /** + * .tribuo.core.ModelDataProto metadata = 1; + */ + public Builder clearMetadata() { + if (metadataBuilder_ == null) { + metadata_ = null; + onChanged(); + } else { + metadata_ = null; + metadataBuilder_ = null; + } + + return this; + } + /** + * .tribuo.core.ModelDataProto metadata = 1; + */ + public org.tribuo.protos.core.ModelDataProto.Builder getMetadataBuilder() { + + onChanged(); + return getMetadataFieldBuilder().getBuilder(); + } + /** + * .tribuo.core.ModelDataProto metadata = 1; + */ + public org.tribuo.protos.core.ModelDataProtoOrBuilder getMetadataOrBuilder() { + if (metadataBuilder_ != null) { + return metadataBuilder_.getMessageOrBuilder(); + } else { + return metadata_ == null ? + org.tribuo.protos.core.ModelDataProto.getDefaultInstance() : metadata_; + } + } + /** + * .tribuo.core.ModelDataProto metadata = 1; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.tribuo.protos.core.ModelDataProto, org.tribuo.protos.core.ModelDataProto.Builder, org.tribuo.protos.core.ModelDataProtoOrBuilder> + getMetadataFieldBuilder() { + if (metadataBuilder_ == null) { + metadataBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.tribuo.protos.core.ModelDataProto, org.tribuo.protos.core.ModelDataProto.Builder, org.tribuo.protos.core.ModelDataProtoOrBuilder>( + getMetadata(), + getParentForChildren(), + isClean()); + metadata_ = null; + } + return metadataBuilder_; + } + + private com.google.protobuf.LazyStringList dimensions_ = com.google.protobuf.LazyStringArrayList.EMPTY; + private void ensureDimensionsIsMutable() { + if (!((bitField0_ & 0x00000001) != 0)) { + dimensions_ = new com.google.protobuf.LazyStringArrayList(dimensions_); + bitField0_ |= 0x00000001; + } + } + /** + * repeated string dimensions = 2; + * @return A list containing the dimensions. + */ + public com.google.protobuf.ProtocolStringList + getDimensionsList() { + return dimensions_.getUnmodifiableView(); + } + /** + * repeated string dimensions = 2; + * @return The count of dimensions. + */ + public int getDimensionsCount() { + return dimensions_.size(); + } + /** + * repeated string dimensions = 2; + * @param index The index of the element to return. + * @return The dimensions at the given index. + */ + public java.lang.String getDimensions(int index) { + return dimensions_.get(index); + } + /** + * repeated string dimensions = 2; + * @param index The index of the value to return. + * @return The bytes of the dimensions at the given index. + */ + public com.google.protobuf.ByteString + getDimensionsBytes(int index) { + return dimensions_.getByteString(index); + } + /** + * repeated string dimensions = 2; + * @param index The index to set the value at. + * @param value The dimensions to set. + * @return This builder for chaining. + */ + public Builder setDimensions( + int index, java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + ensureDimensionsIsMutable(); + dimensions_.set(index, value); + onChanged(); + return this; + } + /** + * repeated string dimensions = 2; + * @param value The dimensions to add. + * @return This builder for chaining. + */ + public Builder addDimensions( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + ensureDimensionsIsMutable(); + dimensions_.add(value); + onChanged(); + return this; + } + /** + * repeated string dimensions = 2; + * @param values The dimensions to add. + * @return This builder for chaining. + */ + public Builder addAllDimensions( + java.lang.Iterable values) { + ensureDimensionsIsMutable(); + com.google.protobuf.AbstractMessageLite.Builder.addAll( + values, dimensions_); + onChanged(); + return this; + } + /** + * repeated string dimensions = 2; + * @return This builder for chaining. + */ + public Builder clearDimensions() { + dimensions_ = com.google.protobuf.LazyStringArrayList.EMPTY; + bitField0_ = (bitField0_ & ~0x00000001); + onChanged(); + return this; + } + /** + * repeated string dimensions = 2; + * @param value The bytes of the dimensions to add. + * @return This builder for chaining. + */ + public Builder addDimensionsBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + checkByteStringIsUtf8(value); + ensureDimensionsIsMutable(); + dimensions_.add(value); + onChanged(); + return this; + } + + private java.util.List weights_ = + java.util.Collections.emptyList(); + private void ensureWeightsIsMutable() { + if (!((bitField0_ & 0x00000002) != 0)) { + weights_ = new java.util.ArrayList(weights_); + bitField0_ |= 0x00000002; + } + } + + private com.google.protobuf.RepeatedFieldBuilderV3< + org.tribuo.math.protos.TensorProto, org.tribuo.math.protos.TensorProto.Builder, org.tribuo.math.protos.TensorProtoOrBuilder> weightsBuilder_; + + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + public java.util.List getWeightsList() { + if (weightsBuilder_ == null) { + return java.util.Collections.unmodifiableList(weights_); + } else { + return weightsBuilder_.getMessageList(); + } + } + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + public int getWeightsCount() { + if (weightsBuilder_ == null) { + return weights_.size(); + } else { + return weightsBuilder_.getCount(); + } + } + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + public org.tribuo.math.protos.TensorProto getWeights(int index) { + if (weightsBuilder_ == null) { + return weights_.get(index); + } else { + return weightsBuilder_.getMessage(index); + } + } + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + public Builder setWeights( + int index, org.tribuo.math.protos.TensorProto value) { + if (weightsBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + ensureWeightsIsMutable(); + weights_.set(index, value); + onChanged(); + } else { + weightsBuilder_.setMessage(index, value); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + public Builder setWeights( + int index, org.tribuo.math.protos.TensorProto.Builder builderForValue) { + if (weightsBuilder_ == null) { + ensureWeightsIsMutable(); + weights_.set(index, builderForValue.build()); + onChanged(); + } else { + weightsBuilder_.setMessage(index, builderForValue.build()); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + public Builder addWeights(org.tribuo.math.protos.TensorProto value) { + if (weightsBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + ensureWeightsIsMutable(); + weights_.add(value); + onChanged(); + } else { + weightsBuilder_.addMessage(value); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + public Builder addWeights( + int index, org.tribuo.math.protos.TensorProto value) { + if (weightsBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + ensureWeightsIsMutable(); + weights_.add(index, value); + onChanged(); + } else { + weightsBuilder_.addMessage(index, value); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + public Builder addWeights( + org.tribuo.math.protos.TensorProto.Builder builderForValue) { + if (weightsBuilder_ == null) { + ensureWeightsIsMutable(); + weights_.add(builderForValue.build()); + onChanged(); + } else { + weightsBuilder_.addMessage(builderForValue.build()); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + public Builder addWeights( + int index, org.tribuo.math.protos.TensorProto.Builder builderForValue) { + if (weightsBuilder_ == null) { + ensureWeightsIsMutable(); + weights_.add(index, builderForValue.build()); + onChanged(); + } else { + weightsBuilder_.addMessage(index, builderForValue.build()); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + public Builder addAllWeights( + java.lang.Iterable values) { + if (weightsBuilder_ == null) { + ensureWeightsIsMutable(); + com.google.protobuf.AbstractMessageLite.Builder.addAll( + values, weights_); + onChanged(); + } else { + weightsBuilder_.addAllMessages(values); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + public Builder clearWeights() { + if (weightsBuilder_ == null) { + weights_ = java.util.Collections.emptyList(); + bitField0_ = (bitField0_ & ~0x00000002); + onChanged(); + } else { + weightsBuilder_.clear(); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + public Builder removeWeights(int index) { + if (weightsBuilder_ == null) { + ensureWeightsIsMutable(); + weights_.remove(index); + onChanged(); + } else { + weightsBuilder_.remove(index); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + public org.tribuo.math.protos.TensorProto.Builder getWeightsBuilder( + int index) { + return getWeightsFieldBuilder().getBuilder(index); + } + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + public org.tribuo.math.protos.TensorProtoOrBuilder getWeightsOrBuilder( + int index) { + if (weightsBuilder_ == null) { + return weights_.get(index); } else { + return weightsBuilder_.getMessageOrBuilder(index); + } + } + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + public java.util.List + getWeightsOrBuilderList() { + if (weightsBuilder_ != null) { + return weightsBuilder_.getMessageOrBuilderList(); + } else { + return java.util.Collections.unmodifiableList(weights_); + } + } + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + public org.tribuo.math.protos.TensorProto.Builder addWeightsBuilder() { + return getWeightsFieldBuilder().addBuilder( + org.tribuo.math.protos.TensorProto.getDefaultInstance()); + } + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + public org.tribuo.math.protos.TensorProto.Builder addWeightsBuilder( + int index) { + return getWeightsFieldBuilder().addBuilder( + index, org.tribuo.math.protos.TensorProto.getDefaultInstance()); + } + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + public java.util.List + getWeightsBuilderList() { + return getWeightsFieldBuilder().getBuilderList(); + } + private com.google.protobuf.RepeatedFieldBuilderV3< + org.tribuo.math.protos.TensorProto, org.tribuo.math.protos.TensorProto.Builder, org.tribuo.math.protos.TensorProtoOrBuilder> + getWeightsFieldBuilder() { + if (weightsBuilder_ == null) { + weightsBuilder_ = new com.google.protobuf.RepeatedFieldBuilderV3< + org.tribuo.math.protos.TensorProto, org.tribuo.math.protos.TensorProto.Builder, org.tribuo.math.protos.TensorProtoOrBuilder>( + weights_, + ((bitField0_ & 0x00000002) != 0), + getParentForChildren(), + isClean()); + weights_ = null; + } + return weightsBuilder_; + } + + private org.tribuo.math.protos.TensorProto featureMeans_; + private com.google.protobuf.SingleFieldBuilderV3< + org.tribuo.math.protos.TensorProto, org.tribuo.math.protos.TensorProto.Builder, org.tribuo.math.protos.TensorProtoOrBuilder> featureMeansBuilder_; + /** + * .tribuo.math.TensorProto feature_means = 4; + * @return Whether the featureMeans field is set. + */ + public boolean hasFeatureMeans() { + return featureMeansBuilder_ != null || featureMeans_ != null; + } + /** + * .tribuo.math.TensorProto feature_means = 4; + * @return The featureMeans. + */ + public org.tribuo.math.protos.TensorProto getFeatureMeans() { + if (featureMeansBuilder_ == null) { + return featureMeans_ == null ? org.tribuo.math.protos.TensorProto.getDefaultInstance() : featureMeans_; + } else { + return featureMeansBuilder_.getMessage(); + } + } + /** + * .tribuo.math.TensorProto feature_means = 4; + */ + public Builder setFeatureMeans(org.tribuo.math.protos.TensorProto value) { + if (featureMeansBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + featureMeans_ = value; + onChanged(); + } else { + featureMeansBuilder_.setMessage(value); + } + + return this; + } + /** + * .tribuo.math.TensorProto feature_means = 4; + */ + public Builder setFeatureMeans( + org.tribuo.math.protos.TensorProto.Builder builderForValue) { + if (featureMeansBuilder_ == null) { + featureMeans_ = builderForValue.build(); + onChanged(); + } else { + featureMeansBuilder_.setMessage(builderForValue.build()); + } + + return this; + } + /** + * .tribuo.math.TensorProto feature_means = 4; + */ + public Builder mergeFeatureMeans(org.tribuo.math.protos.TensorProto value) { + if (featureMeansBuilder_ == null) { + if (featureMeans_ != null) { + featureMeans_ = + org.tribuo.math.protos.TensorProto.newBuilder(featureMeans_).mergeFrom(value).buildPartial(); + } else { + featureMeans_ = value; + } + onChanged(); + } else { + featureMeansBuilder_.mergeFrom(value); + } + + return this; + } + /** + * .tribuo.math.TensorProto feature_means = 4; + */ + public Builder clearFeatureMeans() { + if (featureMeansBuilder_ == null) { + featureMeans_ = null; + onChanged(); + } else { + featureMeans_ = null; + featureMeansBuilder_ = null; + } + + return this; + } + /** + * .tribuo.math.TensorProto feature_means = 4; + */ + public org.tribuo.math.protos.TensorProto.Builder getFeatureMeansBuilder() { + + onChanged(); + return getFeatureMeansFieldBuilder().getBuilder(); + } + /** + * .tribuo.math.TensorProto feature_means = 4; + */ + public org.tribuo.math.protos.TensorProtoOrBuilder getFeatureMeansOrBuilder() { + if (featureMeansBuilder_ != null) { + return featureMeansBuilder_.getMessageOrBuilder(); + } else { + return featureMeans_ == null ? + org.tribuo.math.protos.TensorProto.getDefaultInstance() : featureMeans_; + } + } + /** + * .tribuo.math.TensorProto feature_means = 4; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.tribuo.math.protos.TensorProto, org.tribuo.math.protos.TensorProto.Builder, org.tribuo.math.protos.TensorProtoOrBuilder> + getFeatureMeansFieldBuilder() { + if (featureMeansBuilder_ == null) { + featureMeansBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.tribuo.math.protos.TensorProto, org.tribuo.math.protos.TensorProto.Builder, org.tribuo.math.protos.TensorProtoOrBuilder>( + getFeatureMeans(), + getParentForChildren(), + isClean()); + featureMeans_ = null; + } + return featureMeansBuilder_; + } + + private org.tribuo.math.protos.TensorProto featureNorms_; + private com.google.protobuf.SingleFieldBuilderV3< + org.tribuo.math.protos.TensorProto, org.tribuo.math.protos.TensorProto.Builder, org.tribuo.math.protos.TensorProtoOrBuilder> featureNormsBuilder_; + /** + * .tribuo.math.TensorProto feature_norms = 5; + * @return Whether the featureNorms field is set. + */ + public boolean hasFeatureNorms() { + return featureNormsBuilder_ != null || featureNorms_ != null; + } + /** + * .tribuo.math.TensorProto feature_norms = 5; + * @return The featureNorms. + */ + public org.tribuo.math.protos.TensorProto getFeatureNorms() { + if (featureNormsBuilder_ == null) { + return featureNorms_ == null ? org.tribuo.math.protos.TensorProto.getDefaultInstance() : featureNorms_; + } else { + return featureNormsBuilder_.getMessage(); + } + } + /** + * .tribuo.math.TensorProto feature_norms = 5; + */ + public Builder setFeatureNorms(org.tribuo.math.protos.TensorProto value) { + if (featureNormsBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + featureNorms_ = value; + onChanged(); + } else { + featureNormsBuilder_.setMessage(value); + } + + return this; + } + /** + * .tribuo.math.TensorProto feature_norms = 5; + */ + public Builder setFeatureNorms( + org.tribuo.math.protos.TensorProto.Builder builderForValue) { + if (featureNormsBuilder_ == null) { + featureNorms_ = builderForValue.build(); + onChanged(); + } else { + featureNormsBuilder_.setMessage(builderForValue.build()); + } + + return this; + } + /** + * .tribuo.math.TensorProto feature_norms = 5; + */ + public Builder mergeFeatureNorms(org.tribuo.math.protos.TensorProto value) { + if (featureNormsBuilder_ == null) { + if (featureNorms_ != null) { + featureNorms_ = + org.tribuo.math.protos.TensorProto.newBuilder(featureNorms_).mergeFrom(value).buildPartial(); + } else { + featureNorms_ = value; + } + onChanged(); + } else { + featureNormsBuilder_.mergeFrom(value); + } + + return this; + } + /** + * .tribuo.math.TensorProto feature_norms = 5; + */ + public Builder clearFeatureNorms() { + if (featureNormsBuilder_ == null) { + featureNorms_ = null; + onChanged(); + } else { + featureNorms_ = null; + featureNormsBuilder_ = null; + } + + return this; + } + /** + * .tribuo.math.TensorProto feature_norms = 5; + */ + public org.tribuo.math.protos.TensorProto.Builder getFeatureNormsBuilder() { + + onChanged(); + return getFeatureNormsFieldBuilder().getBuilder(); + } + /** + * .tribuo.math.TensorProto feature_norms = 5; + */ + public org.tribuo.math.protos.TensorProtoOrBuilder getFeatureNormsOrBuilder() { + if (featureNormsBuilder_ != null) { + return featureNormsBuilder_.getMessageOrBuilder(); + } else { + return featureNorms_ == null ? + org.tribuo.math.protos.TensorProto.getDefaultInstance() : featureNorms_; + } + } + /** + * .tribuo.math.TensorProto feature_norms = 5; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.tribuo.math.protos.TensorProto, org.tribuo.math.protos.TensorProto.Builder, org.tribuo.math.protos.TensorProtoOrBuilder> + getFeatureNormsFieldBuilder() { + if (featureNormsBuilder_ == null) { + featureNormsBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.tribuo.math.protos.TensorProto, org.tribuo.math.protos.TensorProto.Builder, org.tribuo.math.protos.TensorProtoOrBuilder>( + getFeatureNorms(), + getParentForChildren(), + isClean()); + featureNorms_ = null; + } + return featureNormsBuilder_; + } + + private boolean bias_ ; + /** + * bool bias = 6; + * @return The bias. + */ + @java.lang.Override + public boolean getBias() { + return bias_; + } + /** + * bool bias = 6; + * @param value The bias to set. + * @return This builder for chaining. + */ + public Builder setBias(boolean value) { + + bias_ = value; + onChanged(); + return this; + } + /** + * bool bias = 6; + * @return This builder for chaining. + */ + public Builder clearBias() { + + bias_ = false; + onChanged(); + return this; + } + + private com.google.protobuf.Internal.DoubleList yMean_ = emptyDoubleList(); + private void ensureYMeanIsMutable() { + if (!((bitField0_ & 0x00000004) != 0)) { + yMean_ = mutableCopy(yMean_); + bitField0_ |= 0x00000004; + } + } + /** + * repeated double y_mean = 7; + * @return A list containing the yMean. + */ + public java.util.List + getYMeanList() { + return ((bitField0_ & 0x00000004) != 0) ? + java.util.Collections.unmodifiableList(yMean_) : yMean_; + } + /** + * repeated double y_mean = 7; + * @return The count of yMean. + */ + public int getYMeanCount() { + return yMean_.size(); + } + /** + * repeated double y_mean = 7; + * @param index The index of the element to return. + * @return The yMean at the given index. + */ + public double getYMean(int index) { + return yMean_.getDouble(index); + } + /** + * repeated double y_mean = 7; + * @param index The index to set the value at. + * @param value The yMean to set. + * @return This builder for chaining. + */ + public Builder setYMean( + int index, double value) { + ensureYMeanIsMutable(); + yMean_.setDouble(index, value); + onChanged(); + return this; + } + /** + * repeated double y_mean = 7; + * @param value The yMean to add. + * @return This builder for chaining. + */ + public Builder addYMean(double value) { + ensureYMeanIsMutable(); + yMean_.addDouble(value); + onChanged(); + return this; + } + /** + * repeated double y_mean = 7; + * @param values The yMean to add. + * @return This builder for chaining. + */ + public Builder addAllYMean( + java.lang.Iterable values) { + ensureYMeanIsMutable(); + com.google.protobuf.AbstractMessageLite.Builder.addAll( + values, yMean_); + onChanged(); + return this; + } + /** + * repeated double y_mean = 7; + * @return This builder for chaining. + */ + public Builder clearYMean() { + yMean_ = emptyDoubleList(); + bitField0_ = (bitField0_ & ~0x00000004); + onChanged(); + return this; + } + + private com.google.protobuf.Internal.DoubleList yNorm_ = emptyDoubleList(); + private void ensureYNormIsMutable() { + if (!((bitField0_ & 0x00000008) != 0)) { + yNorm_ = mutableCopy(yNorm_); + bitField0_ |= 0x00000008; + } + } + /** + * repeated double y_norm = 8; + * @return A list containing the yNorm. + */ + public java.util.List + getYNormList() { + return ((bitField0_ & 0x00000008) != 0) ? + java.util.Collections.unmodifiableList(yNorm_) : yNorm_; + } + /** + * repeated double y_norm = 8; + * @return The count of yNorm. + */ + public int getYNormCount() { + return yNorm_.size(); + } + /** + * repeated double y_norm = 8; + * @param index The index of the element to return. + * @return The yNorm at the given index. + */ + public double getYNorm(int index) { + return yNorm_.getDouble(index); + } + /** + * repeated double y_norm = 8; + * @param index The index to set the value at. + * @param value The yNorm to set. + * @return This builder for chaining. + */ + public Builder setYNorm( + int index, double value) { + ensureYNormIsMutable(); + yNorm_.setDouble(index, value); + onChanged(); + return this; + } + /** + * repeated double y_norm = 8; + * @param value The yNorm to add. + * @return This builder for chaining. + */ + public Builder addYNorm(double value) { + ensureYNormIsMutable(); + yNorm_.addDouble(value); + onChanged(); + return this; + } + /** + * repeated double y_norm = 8; + * @param values The yNorm to add. + * @return This builder for chaining. + */ + public Builder addAllYNorm( + java.lang.Iterable values) { + ensureYNormIsMutable(); + com.google.protobuf.AbstractMessageLite.Builder.addAll( + values, yNorm_); + onChanged(); + return this; + } + /** + * repeated double y_norm = 8; + * @return This builder for chaining. + */ + public Builder clearYNorm() { + yNorm_ = emptyDoubleList(); + bitField0_ = (bitField0_ & ~0x00000008); + onChanged(); + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:tribuo.regression.mnb.SparseLinearModelProto) + } + + // @@protoc_insertion_point(class_scope:tribuo.regression.mnb.SparseLinearModelProto) + private static final org.tribuo.regression.slm.protos.SparseLinearModelProto DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new org.tribuo.regression.slm.protos.SparseLinearModelProto(); + } + + public static org.tribuo.regression.slm.protos.SparseLinearModelProto getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public SparseLinearModelProto parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return new SparseLinearModelProto(input, extensionRegistry); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public org.tribuo.regression.slm.protos.SparseLinearModelProto getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + +} + diff --git a/Regression/SLM/src/main/java/org/tribuo/regression/slm/protos/SparseLinearModelProtoOrBuilder.java b/Regression/SLM/src/main/java/org/tribuo/regression/slm/protos/SparseLinearModelProtoOrBuilder.java new file mode 100644 index 000000000..d0d110b1f --- /dev/null +++ b/Regression/SLM/src/main/java/org/tribuo/regression/slm/protos/SparseLinearModelProtoOrBuilder.java @@ -0,0 +1,143 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tribuo-regression-slm.proto + +package org.tribuo.regression.slm.protos; + +public interface SparseLinearModelProtoOrBuilder extends + // @@protoc_insertion_point(interface_extends:tribuo.regression.mnb.SparseLinearModelProto) + com.google.protobuf.MessageOrBuilder { + + /** + * .tribuo.core.ModelDataProto metadata = 1; + * @return Whether the metadata field is set. + */ + boolean hasMetadata(); + /** + * .tribuo.core.ModelDataProto metadata = 1; + * @return The metadata. + */ + org.tribuo.protos.core.ModelDataProto getMetadata(); + /** + * .tribuo.core.ModelDataProto metadata = 1; + */ + org.tribuo.protos.core.ModelDataProtoOrBuilder getMetadataOrBuilder(); + + /** + * repeated string dimensions = 2; + * @return A list containing the dimensions. + */ + java.util.List + getDimensionsList(); + /** + * repeated string dimensions = 2; + * @return The count of dimensions. + */ + int getDimensionsCount(); + /** + * repeated string dimensions = 2; + * @param index The index of the element to return. + * @return The dimensions at the given index. + */ + java.lang.String getDimensions(int index); + /** + * repeated string dimensions = 2; + * @param index The index of the value to return. + * @return The bytes of the dimensions at the given index. + */ + com.google.protobuf.ByteString + getDimensionsBytes(int index); + + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + java.util.List + getWeightsList(); + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + org.tribuo.math.protos.TensorProto getWeights(int index); + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + int getWeightsCount(); + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + java.util.List + getWeightsOrBuilderList(); + /** + * repeated .tribuo.math.TensorProto weights = 3; + */ + org.tribuo.math.protos.TensorProtoOrBuilder getWeightsOrBuilder( + int index); + + /** + * .tribuo.math.TensorProto feature_means = 4; + * @return Whether the featureMeans field is set. + */ + boolean hasFeatureMeans(); + /** + * .tribuo.math.TensorProto feature_means = 4; + * @return The featureMeans. + */ + org.tribuo.math.protos.TensorProto getFeatureMeans(); + /** + * .tribuo.math.TensorProto feature_means = 4; + */ + org.tribuo.math.protos.TensorProtoOrBuilder getFeatureMeansOrBuilder(); + + /** + * .tribuo.math.TensorProto feature_norms = 5; + * @return Whether the featureNorms field is set. + */ + boolean hasFeatureNorms(); + /** + * .tribuo.math.TensorProto feature_norms = 5; + * @return The featureNorms. + */ + org.tribuo.math.protos.TensorProto getFeatureNorms(); + /** + * .tribuo.math.TensorProto feature_norms = 5; + */ + org.tribuo.math.protos.TensorProtoOrBuilder getFeatureNormsOrBuilder(); + + /** + * bool bias = 6; + * @return The bias. + */ + boolean getBias(); + + /** + * repeated double y_mean = 7; + * @return A list containing the yMean. + */ + java.util.List getYMeanList(); + /** + * repeated double y_mean = 7; + * @return The count of yMean. + */ + int getYMeanCount(); + /** + * repeated double y_mean = 7; + * @param index The index of the element to return. + * @return The yMean at the given index. + */ + double getYMean(int index); + + /** + * repeated double y_norm = 8; + * @return A list containing the yNorm. + */ + java.util.List getYNormList(); + /** + * repeated double y_norm = 8; + * @return The count of yNorm. + */ + int getYNormCount(); + /** + * repeated double y_norm = 8; + * @param index The index of the element to return. + * @return The yNorm at the given index. + */ + double getYNorm(int index); +} diff --git a/Regression/SLM/src/main/java/org/tribuo/regression/slm/protos/TribuoRegressionSlm.java b/Regression/SLM/src/main/java/org/tribuo/regression/slm/protos/TribuoRegressionSlm.java new file mode 100644 index 000000000..6dffa4b2b --- /dev/null +++ b/Regression/SLM/src/main/java/org/tribuo/regression/slm/protos/TribuoRegressionSlm.java @@ -0,0 +1,60 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tribuo-regression-slm.proto + +package org.tribuo.regression.slm.protos; + +public final class TribuoRegressionSlm { + private TribuoRegressionSlm() {} + public static void registerAllExtensions( + com.google.protobuf.ExtensionRegistryLite registry) { + } + + public static void registerAllExtensions( + com.google.protobuf.ExtensionRegistry registry) { + registerAllExtensions( + (com.google.protobuf.ExtensionRegistryLite) registry); + } + static final com.google.protobuf.Descriptors.Descriptor + internal_static_tribuo_regression_mnb_SparseLinearModelProto_descriptor; + static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_tribuo_regression_mnb_SparseLinearModelProto_fieldAccessorTable; + + public static com.google.protobuf.Descriptors.FileDescriptor + getDescriptor() { + return descriptor; + } + private static com.google.protobuf.Descriptors.FileDescriptor + descriptor; + static { + java.lang.String[] descriptorData = { + "\n\033tribuo-regression-slm.proto\022\025tribuo.re" + + "gression.mnb\032\021tribuo-core.proto\032\021tribuo-" + + "math.proto\"\226\002\n\026SparseLinearModelProto\022-\n" + + "\010metadata\030\001 \001(\0132\033.tribuo.core.ModelDataP" + + "roto\022\022\n\ndimensions\030\002 \003(\t\022)\n\007weights\030\003 \003(" + + "\0132\030.tribuo.math.TensorProto\022/\n\rfeature_m" + + "eans\030\004 \001(\0132\030.tribuo.math.TensorProto\022/\n\r" + + "feature_norms\030\005 \001(\0132\030.tribuo.math.Tensor" + + "Proto\022\014\n\004bias\030\006 \001(\010\022\016\n\006y_mean\030\007 \003(\001\022\016\n\006y" + + "_norm\030\010 \003(\001B$\n org.tribuo.regression.slm" + + ".protosP\001b\006proto3" + }; + descriptor = com.google.protobuf.Descriptors.FileDescriptor + .internalBuildGeneratedFileFrom(descriptorData, + new com.google.protobuf.Descriptors.FileDescriptor[] { + org.tribuo.protos.core.TribuoCore.getDescriptor(), + org.tribuo.math.protos.TribuoMath.getDescriptor(), + }); + internal_static_tribuo_regression_mnb_SparseLinearModelProto_descriptor = + getDescriptor().getMessageTypes().get(0); + internal_static_tribuo_regression_mnb_SparseLinearModelProto_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_tribuo_regression_mnb_SparseLinearModelProto_descriptor, + new java.lang.String[] { "Metadata", "Dimensions", "Weights", "FeatureMeans", "FeatureNorms", "Bias", "YMean", "YNorm", }); + org.tribuo.protos.core.TribuoCore.getDescriptor(); + org.tribuo.math.protos.TribuoMath.getDescriptor(); + } + + // @@protoc_insertion_point(outer_class_scope) +} diff --git a/Regression/SLM/src/main/resources/protos/tribuo-regression-slm.proto b/Regression/SLM/src/main/resources/protos/tribuo-regression-slm.proto new file mode 100644 index 000000000..a88e493db --- /dev/null +++ b/Regression/SLM/src/main/resources/protos/tribuo-regression-slm.proto @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +/* + * Protos for serializing Tribuo's sparse linear models. + */ +package tribuo.regression.mnb; + +option java_multiple_files = true; +option java_package = "org.tribuo.regression.slm.protos"; + +// Import Tribuo's core protos +import "tribuo-core.proto"; + +// Import Tribuo's math protos +import "tribuo-math.proto"; + +/* +SparseLinearModelProto proto + */ +message SparseLinearModelProto { + tribuo.core.ModelDataProto metadata = 1; + repeated string dimensions = 2; + repeated tribuo.math.TensorProto weights = 3; + tribuo.math.TensorProto feature_means = 4; + tribuo.math.TensorProto feature_norms = 5; + bool bias = 6; + repeated double y_mean = 7; + repeated double y_norm = 8; +} diff --git a/Regression/SLM/src/test/java/org/tribuo/regression/slm/TestSLM.java b/Regression/SLM/src/test/java/org/tribuo/regression/slm/TestSLM.java index b1b23727a..b6d7e4ea6 100644 --- a/Regression/SLM/src/test/java/org/tribuo/regression/slm/TestSLM.java +++ b/Regression/SLM/src/test/java/org/tribuo/regression/slm/TestSLM.java @@ -129,14 +129,19 @@ public void testDenseData() { Pair,Dataset> p = RegressionDataGenerator.denseTrainTest(); Model sfs = testSFS(p,false); Helpers.testModelSerialization(sfs,Regressor.class); + Helpers.testModelProtoSerialization(sfs,Regressor.class,p.getB()); Model sfsn = testSFSN(p,false); Helpers.testModelSerialization(sfsn,Regressor.class); + Helpers.testModelProtoSerialization(sfsn,Regressor.class,p.getB()); Model lars = testLARS(p,false); Helpers.testModelSerialization(lars,Regressor.class); + Helpers.testModelProtoSerialization(lars,Regressor.class,p.getB()); Model lasso = testLASSO(p,false); Helpers.testModelSerialization(lasso,Regressor.class); + Helpers.testModelProtoSerialization(lasso,Regressor.class,p.getB()); Model elastic = testElasticNet(p,false); Helpers.testModelSerialization(elastic,Regressor.class); + Helpers.testModelProtoSerialization(elastic,Regressor.class,p.getB()); } @Test @@ -220,9 +225,11 @@ public void testNegativeInvocationCount(){ }); } + @Test public void testThreeDenseDataLARS() { Pair,Dataset> p = RegressionDataGenerator.threeDimDenseTrainTest(1.0, false); SparseModel llModel = LARS.train(p.getA()); + Helpers.testModelProtoSerialization(llModel, Regressor.class, p.getB()); RegressionEvaluation llEval = e.evaluate(llModel,p.getB()); double expectedDim1 = 0.5671244360433836; double expectedDim2 = 0.5671244360433927; @@ -248,6 +255,7 @@ public void testThreeDenseDataLARS() { public void testThreeDenseDataENet() { Pair,Dataset> p = RegressionDataGenerator.threeDimDenseTrainTest(1.0, false); SparseModel llModel = ELASTIC_NET.train(p.getA()); + Helpers.testModelProtoSerialization(llModel, Regressor.class, p.getB()); RegressionEvaluation llEval = e.evaluate(llModel,p.getB()); double expectedDim1 = 0.5902193395184064; double expectedDim2 = 0.5902193395184064;