Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update version fix vector #11

Merged
merged 3 commits into from
Dec 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,13 @@
<dependency>
<groupId>ml.combust.mleap</groupId>
<artifactId>mleap-runtime_2.11</artifactId>
<version>0.14.0</version>
<version>0.15.0</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.apache.spark/spark-mllib-local -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib-local_2.11</artifactId>
<version>2.4.5</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,14 @@

package com.amazonaws.sagemaker.helper;

import com.amazonaws.sagemaker.dto.DataSchema;
import com.amazonaws.sagemaker.dto.ColumnSchema;
import com.amazonaws.sagemaker.dto.DataSchema;
import com.amazonaws.sagemaker.type.BasicDataType;
import com.amazonaws.sagemaker.type.DataStructureType;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.io.IOException;
import java.io.StringReader;
import java.util.List;
import java.util.stream.Collectors;
import ml.combust.mleap.core.types.BasicType;
import ml.combust.mleap.core.types.DataType;
import ml.combust.mleap.core.types.ListType;
import ml.combust.mleap.core.types.ScalarType;
import ml.combust.mleap.core.types.StructField;
import ml.combust.mleap.core.types.StructType;
import ml.combust.mleap.core.types.TensorType;
import ml.combust.mleap.core.types.*;
import ml.combust.mleap.runtime.frame.ArrayRow;
import ml.combust.mleap.runtime.frame.DefaultLeapFrame;
import ml.combust.mleap.runtime.frame.Row;
Expand All @@ -43,9 +33,15 @@
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.ml.linalg.Vectors;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.io.IOException;
import java.io.StringReader;
import java.util.List;
import java.util.stream.Collectors;

/**
* Converter class to convert data between input to MLeap expected types and convert back MLeap helper to Java types
* for output.
Expand Down Expand Up @@ -168,12 +164,12 @@ protected Object convertInputDataToJavaType(final String type, final String stru
default:
throw new IllegalArgumentException("Given type is not supported");
}
} else {
} else if (!StringUtils.isBlank(structure) && StringUtils.equals(structure, DataStructureType.ARRAY)) {
List<Object> listOfObjects;
try {
listOfObjects = (List<Object>) value;
} catch (ClassCastException cce) {
throw new IllegalArgumentException("Input val is not a list but struct passed is vector or array");
throw new IllegalArgumentException("Input val is not a list but struct passed is array");
}
switch (type) {
case BasicDataType.INTEGER:
Expand All @@ -194,7 +190,17 @@ protected Object convertInputDataToJavaType(final String type, final String stru
default:
throw new IllegalArgumentException("Given type is not supported");
}

} else {
if(!type.equals(BasicDataType.DOUBLE))
throw new IllegalArgumentException("Only Double type is supported for vector");
List<Double> vectorValues;
try {
vectorValues = (List<Double>)value;
} catch (ClassCastException cce) {
throw new IllegalArgumentException("Input val is not a list but struct passed is vector");
}
double[] primitiveVectorValues = vectorValues.stream().mapToDouble(d -> d).toArray();
return Vectors.dense(primitiveVectorValues);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.Lists;
import java.io.IOException;
import org.apache.commons.io.IOUtils;
import org.junit.Assert;
import org.junit.Test;

import java.io.IOException;

public class SageMakerRequestObjectTest {

private ObjectMapper mapper = new ObjectMapper();
Expand Down Expand Up @@ -80,14 +81,14 @@ public void testParseCompleteInputJson() throws IOException {
Assert.assertEquals(sro.getSchema().getInput().get(0).getName(), "name_1");
Assert.assertEquals(sro.getSchema().getInput().get(1).getName(), "name_2");
Assert.assertEquals(sro.getSchema().getInput().get(2).getName(), "name_3");
Assert.assertEquals(sro.getSchema().getInput().get(0).getType(), "int");
Assert.assertEquals(sro.getSchema().getInput().get(0).getType(), "double");
Assert.assertEquals(sro.getSchema().getInput().get(1).getType(), "string");
Assert.assertEquals(sro.getSchema().getInput().get(2).getType(), "double");
Assert.assertEquals(sro.getSchema().getInput().get(0).getStruct(), "vector");
Assert.assertEquals(sro.getSchema().getInput().get(1).getStruct(), "basic");
Assert.assertEquals(sro.getSchema().getInput().get(2).getStruct(), "array");
Assert.assertEquals(sro.getData(),
Lists.newArrayList(Lists.newArrayList(1, 2, 3), "C", Lists.newArrayList(38.0, 24.0)));
Lists.newArrayList(Lists.newArrayList(1.0, 2.0, 3.0), "C", Lists.newArrayList(38.0, 24.0)));
Assert.assertEquals(sro.getSchema().getOutput().getName(), "features");
Assert.assertEquals(sro.getSchema().getOutput().getType(), "double");
Assert.assertEquals(sro.getSchema().getOutput().getStruct(), "vector");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
import com.amazonaws.sagemaker.type.DataStructureType;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.Lists;
import java.io.IOException;
import java.util.List;
import ml.combust.mleap.core.types.ListType;
import ml.combust.mleap.core.types.ScalarType;
import ml.combust.mleap.core.types.TensorType;
Expand All @@ -32,9 +30,13 @@
import ml.combust.mleap.runtime.javadsl.LeapFrameBuilder;
import ml.combust.mleap.runtime.javadsl.LeapFrameBuilderSupport;
import org.apache.commons.io.IOUtils;
import org.apache.spark.ml.linalg.Vectors;
import org.junit.Assert;
import org.junit.Test;

import java.io.IOException;
import java.util.List;

public class DataConversionHelperTest {

private ObjectMapper mapper = new ObjectMapper();
Expand Down Expand Up @@ -143,21 +145,11 @@ public void testCastingInputToJavaTypeSingle() {

@Test
public void testCastingInputToJavaTypeList() {
Assert.assertEquals(Lists.newArrayList(1, 2), dataConversionHelper
.convertInputDataToJavaType(BasicDataType.INTEGER, DataStructureType.VECTOR,
Lists.newArrayList(new Integer("1"), new Integer("2"))));

Assert.assertEquals(Lists.newArrayList(1.0, 2.0), dataConversionHelper
.convertInputDataToJavaType(BasicDataType.FLOAT, DataStructureType.VECTOR,
Lists.newArrayList(new Double("1.0"), new Double("2.0"))));

Assert.assertEquals(Lists.newArrayList(1.0, 2.0), dataConversionHelper
.convertInputDataToJavaType(BasicDataType.DOUBLE, DataStructureType.VECTOR,
Lists.newArrayList(new Double("1.0"), new Double("2.0"))));

Assert.assertEquals(Lists.newArrayList(new Byte("1")), dataConversionHelper
.convertInputDataToJavaType(BasicDataType.BYTE, DataStructureType.VECTOR,
Lists.newArrayList(new Byte("1"))));
//Check vector struct and double type returns a Spark vector
Assert.assertEquals(Vectors.dense(new double[]{1.0, 2.0}),dataConversionHelper
.convertInputDataToJavaType(BasicDataType.DOUBLE, DataStructureType.VECTOR,
Lists.newArrayList(new Double("1.0"), new Double("2.0"))));

Assert.assertEquals(Lists.newArrayList(1L, 2L), dataConversionHelper
.convertInputDataToJavaType(BasicDataType.LONG, DataStructureType.ARRAY,
Expand All @@ -175,6 +167,12 @@ public void testCastingInputToJavaTypeList() {
Lists.newArrayList(Boolean.valueOf("1"))));
}

@Test(expected = IllegalArgumentException.class)
public void testConvertInputToJavaTypeNonDoibleVector() {
dataConversionHelper
.convertInputDataToJavaType(BasicDataType.INTEGER, DataStructureType.VECTOR, new Integer("1"));
}

@Test(expected = IllegalArgumentException.class)
public void testCastingInputToJavaTypeNonList() {
dataConversionHelper
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"input": [
{
"name": "name_1",
"type": "int",
"type": "double",
"struct": "vector"
},
{
Expand All @@ -23,5 +23,5 @@
"struct": "vector"
}
},
"data": [[1, 2, 3], "C", [38.0, 24.0]]
"data": [[1.0, 2.0, 3.0], "C", [38.0, 24.0]]
}