Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ResponseProcessor State-setting and Tests #178

Merged
merged 7 commits into from
Oct 14, 2021
Merged
28 changes: 27 additions & 1 deletion Core/src/test/java/org/tribuo/test/Helpers.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@

package org.tribuo.test;

import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.config.ConfigurationData;
import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance;
import org.junit.jupiter.api.Assertions;
Expand All @@ -27,7 +32,6 @@
import org.tribuo.MutableFeatureMap;
import org.tribuo.Output;
import org.tribuo.impl.ListExample;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.sequence.SequenceModel;

import java.io.BufferedInputStream;
Expand All @@ -40,7 +44,9 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.logging.Logger;
import java.util.stream.Collectors;

/**
* Test helpers
Expand Down Expand Up @@ -76,6 +82,26 @@ public static Example<MockOutput> mkExample(MockOutput label, String... features
return ex;
}

/**
 * Checks that an object which is both {@link Provenancable} and {@link Configurable} is described the same way
 * by its configuration and by its provenance.
 * <p>
 * The object is imported into a fresh {@link ConfigurationManager} under the name "item", and every component
 * configuration that manager then reports is collected. That list is compared against the configuration
 * extracted from the object's provenance using
 * {@link ConfigurationData#structuralEquals(List, List, String, String)}; the name returned by the import is
 * used for the configuration side and the name of the first extracted element for the provenance side.
 * @param itm The object whose configuration and provenance representations are compared.
 * @param <P> The provenance type produced by {@code itm}.
 * @param <C> The type of {@code itm}.
 */
public static <P extends ConfiguredObjectProvenance, C extends Configurable & Provenancable<P>> void testConfigurableRoundtrip(C itm) {
    ConfigurationManager cm = new ConfigurationManager();
    String name = cm.importConfigurable(itm, "item");
    // Component names without configuration data are dropped rather than treated as a failure.
    List<ConfigurationData> configData = cm.getComponentNames().stream()
            .map(cm::getConfigurationData)
            .filter(Optional::isPresent)
            .map(Optional::get)
            .collect(Collectors.toList());

    List<ConfigurationData> provenData = ProvenanceUtil.extractConfiguration(itm.getProvenance());

    // Report an empty extraction as a test failure instead of an IndexOutOfBoundsException from get(0) below.
    Assertions.assertFalse(provenData.isEmpty(),
            "No configuration could be extracted from the provenance of " + itm.getClass().getName());

    String provenName = provenData.get(0).getName();
    Assertions.assertTrue(ConfigurationData.structuralEquals(configData, provenData, name, provenName),
            "Configuration '" + name + "' and provenance '" + provenName + "' of "
                    + itm.getClass().getName() + " are not structurally equal");
}


public static void testProvenanceMarshalling(ObjectProvenance inputProvenance) {
List<ObjectMarshalledProvenance> provenanceList = ProvenanceUtil.marshalProvenance(inputProvenance);
ObjectProvenance unmarshalledProvenance = ProvenanceUtil.unmarshalProvenance(provenanceList);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@
* A {@link ResponseProcessor} that takes a single value of the
* field as the positive class and all other values as the negative
* class.
* <p>
* We support specifying field names and positive responses both singly through {@link #fieldName} and {@link #positiveResponse}
* and in a list through {@link #fieldNames} and {@link #positiveResponses}. The constructors and configuration preprocessing
* have differing behaviors based on which fields are populated:
* <ul>
* <li> {@link #fieldNames} and {@link #positiveResponses} are both populated and the same length: fieldNames[i]'s positiveResponse is positiveResponses[i]
* <li> {@link #fieldNames} and {@link #positiveResponse} are both populated: positiveResponse is broadcast across all fieldNames
* <li> {@link #fieldName} and {@link #positiveResponse} are both populated: fieldNames[0] == fieldName, positiveResponses[0] == positiveResponse
* </ul>
* All other settings are invalid.
*/
public class BinaryResponseProcessor<T extends Output<T>> implements ResponseProcessor<T> {

Expand Down Expand Up @@ -77,30 +87,37 @@ public class BinaryResponseProcessor<T extends Output<T>> implements ResponsePro

@Override
public void postConfig() {
if (fieldName != null && fieldNames != null) { // we can only have one path
throw new PropertyException(configName, "fieldName, FieldNames", "only one of fieldName or fieldNames can be populated");
} else if (fieldNames != null) {
if(positiveResponse != null) {
positiveResponses = positiveResponses == null ? Collections.nCopies(fieldNames.size(), positiveResponse) : positiveResponses;
if(positiveResponses.size() != fieldNames.size()) {
throw new PropertyException(configName, "positiveResponses", "must either be empty or match the length of fieldNames");
}
} else {
throw new PropertyException(configName, "positiveResponse, positiveResponses", "one of positiveResponse or positiveResponses must be populated");
}
} else if (fieldName != null) {
if(positiveResponses != null) {
throw new PropertyException(configName, "positiveResponses", "if fieldName is populated, positiveResponses must be blank");
}
/*
* Canonically all internal logic is driven by fieldNames and positiveResponses, so this method takes values
* populated in fieldName and positiveResponse and sets them appropriately.
*/
boolean bothFieldNamesPopulated = fieldName != null && fieldNames != null;
boolean neitherFieldNamesPopulated = fieldName == null && fieldNames == null;
boolean multipleFieldNamesPopulated = fieldNames != null;
boolean singleFieldNamePopulated = fieldName != null;

boolean bothPositiveResponsesPopulated = positiveResponses != null && positiveResponse != null;
boolean neitherPositiveResponsesPopulated = positiveResponse == null && positiveResponses == null;
boolean multiplePositiveResponsesPopulated = positiveResponses != null;
boolean singlePositiveResponsePopulated = positiveResponse != null;

if (bothFieldNamesPopulated || neitherFieldNamesPopulated) {
throw new PropertyException(configName, "fieldName, FieldNames", "exactly one of fieldName or fieldNames must be populated");
} else if (bothPositiveResponsesPopulated || neitherPositiveResponsesPopulated) {
throw new PropertyException(configName, "positiveResponse, positiveResponses", "exactly one of positiveResponse or positiveResponses must be populated");
} else if(multipleFieldNamesPopulated && multiplePositiveResponsesPopulated && fieldNames.size() != positiveResponses.size()) { //sizes don't match
throw new PropertyException(configName, "positiveResponses", "must match the length of fieldNames");
} else if(multipleFieldNamesPopulated && singlePositiveResponsePopulated) {
positiveResponses = Collections.nCopies(fieldNames.size(), positiveResponse);
positiveResponse = null;
} else if(singleFieldNamePopulated && multiplePositiveResponsesPopulated) {
throw new PropertyException(configName, "positiveResponses", "if fieldName is populated, positiveResponses must be blank");
} else if(singleFieldNamePopulated && singlePositiveResponsePopulated) {
fieldNames = Collections.singletonList(fieldName);
if(positiveResponse != null) {
positiveResponses = Collections.singletonList(positiveResponse);
} else {
throw new PropertyException(configName, "positiveResponse", "if fieldName is populated positiveResponse must be populated");
}
} else {
throw new PropertyException(configName, "fieldName, fieldNames", "One of fieldName or fieldNames must be populated");
}
fieldName = null;
positiveResponses = Collections.singletonList(positiveResponse);
positiveResponse = null;
} // the case where both positiveResponses and fieldNames are populated and their sizes match requires no action
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,17 @@
import java.util.Optional;

/**
* A response processor that returns the value in a given field.
* A response processor that returns the value(s) in a given (set of) fields.
* <p>
* We support specifying field names and default values both singly through {@link #fieldName} and {@link #defaultValue}
* and in a list through {@link #fieldNames} and {@link #defaultValues}. The constructors and configuration preprocessing
* have differing behaviors based on which fields are populated:
* <ul>
* <li> {@link #fieldNames} and {@link #defaultValues} are both populated and the same length: fieldNames[i]'s defaultValue is defaultValues[i]
* <li> {@link #fieldNames} and {@link #defaultValue} are both populated: defaultValue is broadcast across all fieldNames
* <li> {@link #fieldName} and {@link #defaultValue} are both populated: fieldNames[0] == fieldName, defaultValues[0] == defaultValue
* </ul>
* All other settings are invalid.
*/
public class FieldResponseProcessor<T extends Output<T>> implements ResponseProcessor<T> {

Expand Down Expand Up @@ -62,30 +72,38 @@ public class FieldResponseProcessor<T extends Output<T>> implements ResponseProc

@Override
public void postConfig() {
if (fieldName != null && fieldNames != null) {
throw new PropertyException(configName, "fieldName, FieldNames", "only one of fieldName or fieldNames can be populated");
} else if (fieldNames != null) {
if(defaultValue != null) {
defaultValues = defaultValues == null ? Collections.nCopies(fieldNames.size(), defaultValue) : defaultValues;
} else {
throw new PropertyException(configName, "defaultValue, defaultValues", "one of defaultValue or defaultValues must be populated");
}
if(defaultValues.size() != fieldNames.size()) {
throw new PropertyException(configName, "defaultValues", "must either be empty or match the length of fieldNames");
}
} else if (fieldName != null) {
if(defaultValues != null) {
throw new PropertyException(configName, "defaultValues", "if fieldName is populated, defaultValues must be blank");
}
/*
* Canonically all internal logic is driven by
* fieldNames and defaultValues, so this method takes values populated in fieldName and defaultValue and sets them
* appropriately.
*/
boolean bothFieldNamesPopulated = fieldName != null && fieldNames != null;
boolean neitherFieldNamesPopulated = fieldName == null && fieldNames == null;
boolean multipleFieldNamesPopulated = fieldNames != null;
boolean singleFieldNamePopulated = fieldName != null;

boolean bothDefaultValuesPopulated = defaultValues != null && defaultValue != null;
boolean neitherDefaultValuesPopulated = defaultValue == null && defaultValues == null;
boolean multipleDefaultValuesPopulated = defaultValues != null;
boolean singleDefaultValuePopulated = defaultValue != null;

if (bothFieldNamesPopulated || neitherFieldNamesPopulated) {
throw new PropertyException(configName, "fieldName, FieldNames", "exactly one of fieldName or fieldNames must be populated");
} else if (bothDefaultValuesPopulated || neitherDefaultValuesPopulated) {
throw new PropertyException(configName, "defaultValue, defaultValues", "exactly one of defaultValue or defaultValues must be populated");
} else if(multipleFieldNamesPopulated && multipleDefaultValuesPopulated && fieldNames.size() != defaultValues.size()) { //sizes don't match
throw new PropertyException(configName, "defaultValues", "must match the length of fieldNames");
} else if(multipleFieldNamesPopulated && singleDefaultValuePopulated) {
defaultValues = Collections.nCopies(fieldNames.size(), defaultValue);
defaultValue = null;
} else if(singleFieldNamePopulated && multipleDefaultValuesPopulated) {
throw new PropertyException(configName, "defaultValues", "if fieldName is populated, defaultValues must be blank");
} else if(singleFieldNamePopulated && singleDefaultValuePopulated) {
fieldNames = Collections.singletonList(fieldName);
if (defaultValue != null) {
defaultValues = Collections.singletonList(defaultValue);
} else {
throw new PropertyException(configName, "defaultValue", "if fieldName is populated, defaultValue must be populated");
}
} else {
throw new PropertyException(configName, "fieldName, fieldNames", "One of fieldName or fieldNames must be populated");
}
fieldName = null;
defaultValues = Collections.singletonList(defaultValue);
defaultValue = null;
} // the case where both defaultValues and fieldNames are populated and their sizes match requires no action
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package org.tribuo.data.columnar.processors.response;

import org.junit.jupiter.api.Test;
import org.tribuo.test.Helpers;
import org.tribuo.test.MockMultiOutput;
import org.tribuo.test.MockMultiOutputFactory;
import org.tribuo.test.MockOutput;
import org.tribuo.test.MockOutputFactory;

import java.util.Arrays;

/**
 * Runs {@link Helpers#testConfigurableRoundtrip} over response processors built with both the
 * multi-field and the single-field constructors.
 */
public class ResponseProcessorRoundtripTest {

    /**
     * Round-trips a two-field and then a single-field {@link BinaryResponseProcessor}.
     */
    @Test
    public void binaryTest() {
        BinaryResponseProcessor<MockMultiOutput> twoFieldProcessor =
                new BinaryResponseProcessor<MockMultiOutput>(
                        Arrays.asList("R1", "R2"),
                        Arrays.asList("TRUE", "TRUE"),
                        new MockMultiOutputFactory(),
                        "true",
                        "false",
                        true);
        Helpers.testConfigurableRoundtrip(twoFieldProcessor);

        BinaryResponseProcessor<MockOutput> oneFieldProcessor =
                new BinaryResponseProcessor<MockOutput>("R1", "TRUE", new MockOutputFactory());
        Helpers.testConfigurableRoundtrip(oneFieldProcessor);
    }

    /**
     * Round-trips a two-field and then a single-field {@link FieldResponseProcessor}.
     */
    @Test
    public void fieldTest() {
        FieldResponseProcessor<MockMultiOutput> twoFieldProcessor =
                new FieldResponseProcessor<MockMultiOutput>(
                        Arrays.asList("R1", "R2"),
                        Arrays.asList("A", "B"),
                        new MockMultiOutputFactory(),
                        true,
                        false);
        Helpers.testConfigurableRoundtrip(twoFieldProcessor);

        FieldResponseProcessor<MockOutput> oneFieldProcessor =
                new FieldResponseProcessor<MockOutput>("R1", "A", new MockOutputFactory());
        Helpers.testConfigurableRoundtrip(oneFieldProcessor);
    }
}
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>

<!-- MLRG dependencies -->
<olcut.version>5.1.6</olcut.version>
<olcut.version>5.2.0</olcut.version>

<!-- 3rd party backend dependencies -->
<liblinear.version>2.43</liblinear.version>
Expand Down