Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add protobuf serialization for LibSVM based models #272

Merged
merged 2 commits into from
Sep 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,13 +16,20 @@

package org.tribuo.anomaly.libsvm;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.anomaly.Event;
import org.tribuo.anomaly.libsvm.protos.LibSVMAnomalyModelProto;
import org.tribuo.common.libsvm.LibSVMModel;
import org.tribuo.common.libsvm.LibSVMTrainer;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.Tensor;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import libsvm.svm;
import libsvm.svm_model;
Expand All @@ -32,7 +39,7 @@
import java.util.List;

/**
* A anomaly detection model that uses an underlying libSVM model to make the
* An anomaly detection model that uses an underlying libSVM model to make the
* predictions.
* <p>
* See:
Expand All @@ -52,10 +59,39 @@
public class LibSVMAnomalyModel extends LibSVMModel<Event> {
private static final long serialVersionUID = 1L;

/**
* Protobuf serialization version.
*/
public static final int CURRENT_VERSION = 0;

LibSVMAnomalyModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Event> labelIDMap, List<svm_model> models) {
super(name, description, featureIDMap, labelIDMap, models.get(0).param.probability == 1, models);
}

/**
* Deserialization factory.
* @param version The serialized object version.
* @param className The class name.
* @param message The serialized data.
*/
public static LibSVMAnomalyModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
if (version < 0 || version > CURRENT_VERSION) {
throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
}
LibSVMAnomalyModelProto proto = message.unpack(LibSVMAnomalyModelProto.class);

ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
if (!carrier.outputDomain().getOutput(0).getClass().equals(Event.class)) {
throw new IllegalStateException("Invalid protobuf, output domain is not an anomaly detection domain, found " + carrier.outputDomain().getClass());
}
@SuppressWarnings("unchecked") // guarded by getClass
ImmutableOutputInfo<Event> outputDomain = (ImmutableOutputInfo<Event>) carrier.outputDomain();

svm_model model = deserializeModel(proto.getModel());

return new LibSVMAnomalyModel(carrier.name(),carrier.provenance(),carrier.featureDomain(),outputDomain,Collections.singletonList(model));
}

/**
* Returns the number of support vectors.
* @return The number of support vectors.
Expand Down Expand Up @@ -85,4 +121,19 @@ protected LibSVMAnomalyModel copy(String newName, ModelProvenance newProvenance)
return new LibSVMAnomalyModel(newName,newProvenance,featureIDMap,outputIDInfo, Collections.singletonList(LibSVMModel.copyModel(models.get(0))));
}

@Override
public ModelProto serialize() {
ModelDataCarrier<Event> carrier = createDataCarrier();

LibSVMAnomalyModelProto.Builder modelBuilder = LibSVMAnomalyModelProto.newBuilder();
modelBuilder.setMetadata(carrier.serialize());
modelBuilder.setModel(serializeModel(models.get(0)));

ModelProto.Builder builder = ModelProto.newBuilder();
builder.setSerializedData(Any.pack(modelBuilder.build()));
builder.setClassName(LibSVMAnomalyModel.class.getName());
builder.setVersion(CURRENT_VERSION);

return builder.build();
}
}
Loading