Skip to content

Commit

Permalink
Add an array filter to our serialize/deserialize methods and narrow d…
Browse files Browse the repository at this point in the history
…own previous filter (#849)

* Add an array filter to our serialize/deserialize methods and narrow down previous filter

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Further narrowing down accept list

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Keep narrowing down accept list

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

* Add test for deserialization methods in all built-in models

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>

---------

Signed-off-by: Sicheng Song <sicheng.song@outlook.com>
  • Loading branch information
b4sjoo authored and rbhavna committed Jun 16, 2023
1 parent 42fa5cd commit 5765ca9
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 6 deletions.
3 changes: 3 additions & 0 deletions ml-algorithms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ repositories {
dependencies {
compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
implementation project(':opensearch-ml-common')
implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}"
implementation "org.opensearch:common-utils:${common_utils_version}"
implementation group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
implementation group: 'org.reflections', name: 'reflections', version: '0.9.12'
implementation group: 'org.tribuo', name: 'tribuo-clustering-kmeans', version: '4.2.1'
Expand All @@ -34,6 +36,7 @@ dependencies {
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.3.1'
implementation group: 'com.google.guava', name: 'guava', version: '31.0.1-jre'
implementation group: 'com.google.code.gson', name: 'gson', version: '2.9.1'
implementation platform("ai.djl:bom:0.19.0")
implementation group: 'ai.djl.pytorch', name: 'pytorch-model-zoo'
implementation group: 'ai.djl', name: 'api'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,56 @@
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.ObjectInputFilter;
import java.util.Base64;

@UtilityClass
public class ModelSerDeSer {
// Welcome list includes OpenSearch ml plugin classes, JDK common classes and Tribuo libraries.
// Accept list includes OpenSearch ml plugin classes, JDK common classes and Tribuo libraries.
public static final String[] ACCEPT_CLASS_PATTERNS = {
"java.lang.*",
"java.util.*",
"java.time.*",
"org.tribuo.*",
"com.oracle.labs.mlrg.olcut.provenance.*",
"com.oracle.labs.mlrg.olcut.util.*",
"[I",
"[Z",
"[J",
"[C",
"[D",
"[F",
"[Ljava.lang.*",
"[Lorg.tribuo.*",
"[Llibsvm.*",
"[[I",
"[[Z",
"[[J",
"[[C",
"[[D",
"[[F",
"[[Ljava.lang.*",
"[[Lorg.tribuo.*",
"[[Llibsvm.*",
"org.opensearch.ml.*",
"*org.tribuo.*",
"libsvm.*",
"com.oracle.labs.*",
"[*",
"com.amazon.randomcutforest.*"
};

public static final String[] REJECT_CLASS_PATTERNS = {
"java.util.logging.*",
"java.util.zip.*",
"java.util.jar.*",
"java.util.random.*",
"java.util.spi.*",
"java.util.stream.*",
"java.util.regex.*",
"java.util.concurrent.*",
"java.util.function.*",
"java.util.prefs.*",
"java.time.zone.*",
"java.time.format.*",
"java.time.temporal.*",
"java.time.chrono.*",
};

public static String serializeToBase64(Object model) {
Expand All @@ -47,11 +82,15 @@ public static byte[] serialize(Object model) {
}
}

// This method has been tested in K-means, Linear Regression, Logistic regression, Anomaly Detection and Random Cut Forest summarization and passed.
public static Object deserialize(byte[] modelBin) {
try (ByteArrayInputStream inputStream = new ByteArrayInputStream(modelBin);
ValidatingObjectInputStream validatingObjectInputStream = new ValidatingObjectInputStream(inputStream)){
// Validate the model class type to avoid deserialization attack.
validatingObjectInputStream.accept(ACCEPT_CLASS_PATTERNS);
validatingObjectInputStream
.accept(ACCEPT_CLASS_PATTERNS)
.reject(REJECT_CLASS_PATTERNS)
.setObjectInputFilter(ObjectInputFilter.Config.createFilter("maxdepth=20;maxrefs=5000;maxbytes=10000000;maxarray=100000"));
return validatingObjectInputStream.readObject();
} catch (IOException | ClassNotFoundException e) {
throw new ModelSerDeSerException("Failed to deserialize model.", e.getCause());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@
import org.opensearch.ml.common.input.parameter.ad.AnomalyDetectionLibSVMParams;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.engine.utils.ModelSerDeSer;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.anomaly.Event;
import org.tribuo.anomaly.example.AnomalyDataGenerator;
import org.tribuo.common.libsvm.LibSVMModel;

import java.util.ArrayList;
import java.util.Iterator;
Expand Down Expand Up @@ -117,6 +119,13 @@ public void train() {
Assert.assertNotNull(model.getContent());
}

@Test
public void testModelSerDeSer() {
MLModel model = anomalyDetection.train(trainDataFrameInput);
LibSVMModel deserializedModel = (LibSVMModel) ModelSerDeSer.deserialize(model);
Assert.assertNotNull(deserializedModel);
}

@Test
public void trainWithFullParams() {
AnomalyDetectionLibSVMParams parameters = AnomalyDetectionLibSVMParams.builder().gamma(gamma).nu(nu).cost(1.0).coeff(0.01).epsilon(0.001).degree(1).kernelType(AnomalyDetectionLibSVMParams.ADKernelType.LINEAR).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.engine.algorithms.clustering;

import com.amazon.randomcutforest.returntypes.SampleSummary;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
Expand All @@ -17,6 +18,7 @@
import org.opensearch.ml.common.input.parameter.clustering.RCFSummarizeParams;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.engine.utils.ModelSerDeSer;

import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame;

Expand Down Expand Up @@ -61,6 +63,13 @@ public void predictWithTrivalModelExpectBoNorminalOutput() {
predictions.forEach(row -> Assert.assertTrue(row.getValue(0).intValue() == 0 || row.getValue(0).intValue() == 1));
}

@Test
public void testModelSerDeSer() {
MLModel model = rcfSummarize.train(trainDataFrameInput);
SampleSummary deserializedModel = ((SerializableSummary) ModelSerDeSer.deserialize(model)).getSummary();
Assert.assertNotNull(deserializedModel);
}

@Test
public void trainAndPredictWithRegularInputExpectNotNullOutput() {
RCFSummarizeParams parameters = RCFSummarizeParams.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.engine.utils.ModelSerDeSer;
import org.tribuo.classification.Label;

import static org.opensearch.ml.engine.helper.LogisticRegressionHelper.constructLogisticRegressionPredictionDataFrame;
import static org.opensearch.ml.engine.helper.LogisticRegressionHelper.constructLogisticRegressionTrainDataFrame;
Expand Down Expand Up @@ -109,6 +111,14 @@ public void predict() {
Assert.assertEquals(2, predictions.size());
}

@Test
public void testModelSerDeSer() {
LogisticRegression classification = new LogisticRegression(parameters);
MLModel model = classification.train(trainDataFrameInput);
org.tribuo.Model<Label> deserializedModel = (org.tribuo.Model<Label>) ModelSerDeSer.deserialize(model);
Assert.assertNotNull(deserializedModel);
}

@Test
public void predictWithoutModel() {
exceptionRule.expect(IllegalArgumentException.class);
Expand Down

0 comments on commit 5765ca9

Please sign in to comment.