Skip to content

Commit

Permalink
Merge pull request #11 from himanishk/update-version-fix-vector
Browse files Browse the repository at this point in the history
Update version fix vector
  • Loading branch information
mahithsuresh authored Dec 28, 2020
2 parents 5b5b3e6 + 3e64d69 commit 884f066
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 37 deletions.
8 changes: 7 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,13 @@
<dependency>
<groupId>ml.combust.mleap</groupId>
<artifactId>mleap-runtime_2.11</artifactId>
<version>0.14.0</version>
<version>0.15.0</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.apache.spark/spark-mllib-local -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib-local_2.11</artifactId>
<version>2.4.5</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,14 @@

package com.amazonaws.sagemaker.helper;

import com.amazonaws.sagemaker.dto.DataSchema;
import com.amazonaws.sagemaker.dto.ColumnSchema;
import com.amazonaws.sagemaker.dto.DataSchema;
import com.amazonaws.sagemaker.type.BasicDataType;
import com.amazonaws.sagemaker.type.DataStructureType;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.io.IOException;
import java.io.StringReader;
import java.util.List;
import java.util.stream.Collectors;
import ml.combust.mleap.core.types.BasicType;
import ml.combust.mleap.core.types.DataType;
import ml.combust.mleap.core.types.ListType;
import ml.combust.mleap.core.types.ScalarType;
import ml.combust.mleap.core.types.StructField;
import ml.combust.mleap.core.types.StructType;
import ml.combust.mleap.core.types.TensorType;
import ml.combust.mleap.core.types.*;
import ml.combust.mleap.runtime.frame.ArrayRow;
import ml.combust.mleap.runtime.frame.DefaultLeapFrame;
import ml.combust.mleap.runtime.frame.Row;
Expand All @@ -43,9 +33,15 @@
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.ml.linalg.Vectors;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.io.IOException;
import java.io.StringReader;
import java.util.List;
import java.util.stream.Collectors;

/**
* Converter class that converts input data to the types MLeap expects and converts MLeap data back to Java types
* for output.
Expand Down Expand Up @@ -168,12 +164,12 @@ protected Object convertInputDataToJavaType(final String type, final String stru
default:
throw new IllegalArgumentException("Given type is not supported");
}
} else {
} else if (!StringUtils.isBlank(structure) && StringUtils.equals(structure, DataStructureType.ARRAY)) {
List<Object> listOfObjects;
try {
listOfObjects = (List<Object>) value;
} catch (ClassCastException cce) {
throw new IllegalArgumentException("Input val is not a list but struct passed is vector or array");
throw new IllegalArgumentException("Input val is not a list but struct passed is array");
}
switch (type) {
case BasicDataType.INTEGER:
Expand All @@ -194,7 +190,17 @@ protected Object convertInputDataToJavaType(final String type, final String stru
default:
throw new IllegalArgumentException("Given type is not supported");
}

} else {
if(!type.equals(BasicDataType.DOUBLE))
throw new IllegalArgumentException("Only Double type is supported for vector");
List<Double> vectorValues;
try {
vectorValues = (List<Double>)value;
} catch (ClassCastException cce) {
throw new IllegalArgumentException("Input val is not a list but struct passed is vector");
}
double[] primitiveVectorValues = vectorValues.stream().mapToDouble(d -> d).toArray();
return Vectors.dense(primitiveVectorValues);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.Lists;
import java.io.IOException;
import org.apache.commons.io.IOUtils;
import org.junit.Assert;
import org.junit.Test;

import java.io.IOException;

public class SageMakerRequestObjectTest {

private ObjectMapper mapper = new ObjectMapper();
Expand Down Expand Up @@ -80,14 +81,14 @@ public void testParseCompleteInputJson() throws IOException {
Assert.assertEquals(sro.getSchema().getInput().get(0).getName(), "name_1");
Assert.assertEquals(sro.getSchema().getInput().get(1).getName(), "name_2");
Assert.assertEquals(sro.getSchema().getInput().get(2).getName(), "name_3");
Assert.assertEquals(sro.getSchema().getInput().get(0).getType(), "int");
Assert.assertEquals(sro.getSchema().getInput().get(0).getType(), "double");
Assert.assertEquals(sro.getSchema().getInput().get(1).getType(), "string");
Assert.assertEquals(sro.getSchema().getInput().get(2).getType(), "double");
Assert.assertEquals(sro.getSchema().getInput().get(0).getStruct(), "vector");
Assert.assertEquals(sro.getSchema().getInput().get(1).getStruct(), "basic");
Assert.assertEquals(sro.getSchema().getInput().get(2).getStruct(), "array");
Assert.assertEquals(sro.getData(),
Lists.newArrayList(Lists.newArrayList(1, 2, 3), "C", Lists.newArrayList(38.0, 24.0)));
Lists.newArrayList(Lists.newArrayList(1.0, 2.0, 3.0), "C", Lists.newArrayList(38.0, 24.0)));
Assert.assertEquals(sro.getSchema().getOutput().getName(), "features");
Assert.assertEquals(sro.getSchema().getOutput().getType(), "double");
Assert.assertEquals(sro.getSchema().getOutput().getStruct(), "vector");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
import com.amazonaws.sagemaker.type.DataStructureType;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.Lists;
import java.io.IOException;
import java.util.List;
import ml.combust.mleap.core.types.ListType;
import ml.combust.mleap.core.types.ScalarType;
import ml.combust.mleap.core.types.TensorType;
Expand All @@ -32,9 +30,13 @@
import ml.combust.mleap.runtime.javadsl.LeapFrameBuilder;
import ml.combust.mleap.runtime.javadsl.LeapFrameBuilderSupport;
import org.apache.commons.io.IOUtils;
import org.apache.spark.ml.linalg.Vectors;
import org.junit.Assert;
import org.junit.Test;

import java.io.IOException;
import java.util.List;

public class DataConversionHelperTest {

private ObjectMapper mapper = new ObjectMapper();
Expand Down Expand Up @@ -143,21 +145,11 @@ public void testCastingInputToJavaTypeSingle() {

@Test
public void testCastingInputToJavaTypeList() {
Assert.assertEquals(Lists.newArrayList(1, 2), dataConversionHelper
.convertInputDataToJavaType(BasicDataType.INTEGER, DataStructureType.VECTOR,
Lists.newArrayList(new Integer("1"), new Integer("2"))));

Assert.assertEquals(Lists.newArrayList(1.0, 2.0), dataConversionHelper
.convertInputDataToJavaType(BasicDataType.FLOAT, DataStructureType.VECTOR,
Lists.newArrayList(new Double("1.0"), new Double("2.0"))));

Assert.assertEquals(Lists.newArrayList(1.0, 2.0), dataConversionHelper
.convertInputDataToJavaType(BasicDataType.DOUBLE, DataStructureType.VECTOR,
Lists.newArrayList(new Double("1.0"), new Double("2.0"))));

Assert.assertEquals(Lists.newArrayList(new Byte("1")), dataConversionHelper
.convertInputDataToJavaType(BasicDataType.BYTE, DataStructureType.VECTOR,
Lists.newArrayList(new Byte("1"))));
//Check vector struct and double type returns a Spark vector
Assert.assertEquals(Vectors.dense(new double[]{1.0, 2.0}),dataConversionHelper
.convertInputDataToJavaType(BasicDataType.DOUBLE, DataStructureType.VECTOR,
Lists.newArrayList(new Double("1.0"), new Double("2.0"))));

Assert.assertEquals(Lists.newArrayList(1L, 2L), dataConversionHelper
.convertInputDataToJavaType(BasicDataType.LONG, DataStructureType.ARRAY,
Expand All @@ -175,6 +167,12 @@ public void testCastingInputToJavaTypeList() {
Lists.newArrayList(Boolean.valueOf("1"))));
}

/**
 * Verifies that requesting a vector struct with a non-double element type is rejected with an
 * {@link IllegalArgumentException}.
 *
 * <p>The value supplied is a real list, so a failed list cast cannot be the source of the
 * exception; the test can only pass because the element type is not double.
 */
@Test(expected = IllegalArgumentException.class)
public void testConvertInputToJavaTypeNonDoibleVector() {
    dataConversionHelper.convertInputDataToJavaType(
        BasicDataType.INTEGER, DataStructureType.VECTOR, Lists.newArrayList(1, 2));
}

@Test(expected = IllegalArgumentException.class)
public void testCastingInputToJavaTypeNonList() {
dataConversionHelper
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"input": [
{
"name": "name_1",
"type": "int",
"type": "double",
"struct": "vector"
},
{
Expand All @@ -23,5 +23,5 @@
"struct": "vector"
}
},
"data": [[1, 2, 3], "C", [38.0, 24.0]]
"data": [[1.0, 2.0, 3.0], "C", [38.0, 24.0]]
}

0 comments on commit 884f066

Please sign in to comment.