Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for multioutputs to ResponseProcesser, with tests. #150

Merged
merged 7 commits into from
Jul 15, 2021
31 changes: 30 additions & 1 deletion Core/src/test/java/org/tribuo/test/MockMultiOutputFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.tribuo.test;

import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.MutableOutputInfo;
Expand Down Expand Up @@ -84,7 +85,7 @@ public boolean equals(Object obj) {

@Override
public OutputFactoryProvenance getProvenance() {
return null;
return new MockMultiOutputFactoryProvenance();
}

/**
Expand Down Expand Up @@ -163,4 +164,32 @@ public static MockMultiOutput createFromPairList(List<Pair<String,Boolean>> dime
}
return new MockMultiOutput(labels);
}

public static class MockMultiOutputFactoryProvenance implements OutputFactoryProvenance {
private static final long serialVersionUID=1L;

MockMultiOutputFactoryProvenance() {}

public MockMultiOutputFactoryProvenance(Map<String, Provenance> map) {}

@Override
public String getClassName() {
return MockMultiOutputFactory.class.getName();
}

@Override
public String toString() {
return generateString("MockMultiOutputFactory");
}

@Override
public boolean equals(Object other) {
return other instanceof MockMultiOutputFactoryProvenance;
}

@Override
public int hashCode() {
return 32;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.tribuo.Output;
import org.tribuo.OutputFactory;

import java.util.List;
import java.util.Optional;

/**
Expand All @@ -36,9 +37,11 @@ public interface ResponseProcessor<T extends Output<T>> extends Configurable, Pr
public OutputFactory<T> getOutputFactory();

/**
* @deprecated use {@link #getFieldNames()} and support multiple values instead.
* Gets the field name this ResponseProcessor uses.
* @return The field name.
*/
@Deprecated
public String getFieldName();

/**
Expand All @@ -47,13 +50,38 @@ public interface ResponseProcessor<T extends Output<T>> extends Configurable, Pr
* @param fieldName The field name.
*
*/
@Deprecated()
@Deprecated
public void setFieldName(String fieldName);

/**
* @deprecated use {@link #process(List)} and support multiple values instead.
* Returns Optional.empty() if it failed to process out a response.
* @param value The value to process.
* @return The response value if found.
*/
@Deprecated
public Optional<T> process(String value);

/**
* Returns Optional.empty() if it failed to process out a response.This method has a default
* implementation for backwards compatibility with Tribuo 4.0 and 4.1. This method should be
* overridden by code which depends on newer versions of Tribuo. The default implementation
* will be removed when the deprecated members are removed. Unless is is overridden it will
* throw an {@link IllegalArgumentException} when called with multiple values.
* @param values The value to process.
* @return The response values if found.
*/
default Optional<T> process(List<String> values) {
if (values.size() != 1) {
throw new IllegalArgumentException(getClass().getSimpleName() + " does not implement support for multiple response values");
} else {
return process(values.get(0));
}
}

/**
* Gets the field names this ResponseProcessor uses.
* @return The field names.
*/
List<String> getFieldNames();
}
7 changes: 5 additions & 2 deletions Data/src/main/java/org/tribuo/data/columnar/RowProcessor.java
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,11 @@ public Set<FeatureProcessor> getFeatureProcessors() {
* @return An Optional containing an Example if the row was valid, an empty Optional otherwise.
*/
public Optional<Example<T>> generateExample(ColumnarIterator.Row row, boolean outputRequired) {
String responseValue = row.getRowData().get(responseProcessor.getFieldName());
Optional<T> labelOpt = responseProcessor.process(responseValue);

List<String> responseValues = responseProcessor.getFieldNames().stream()
.map(f -> row.getRowData().getOrDefault(f, ""))
.collect(Collectors.toList());
Optional<T> labelOpt = responseProcessor.process(responseValues);
if (!labelOpt.isPresent() && outputRequired) {
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,17 @@
package org.tribuo.data.columnar.processors.response;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.ConfigurableName;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.data.columnar.ResponseProcessor;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;

/**
Expand All @@ -32,20 +37,64 @@
*/
public class BinaryResponseProcessor<T extends Output<T>> implements ResponseProcessor<T> {

@Config(mandatory = true,description="The field name to read.")
@Config(description="The field name to read, you should use only one of this or fieldNames")
@Deprecated
private String fieldName;

@Config(mandatory = true,description="The string which triggers a positive response.")
@Config(description="The string which triggers a positive response.")
private String positiveResponse;

@Config(mandatory = true,description="Output factory to use to create the response.")
@Config(mandatory = true, description="Output factory to use to create the response.")
private OutputFactory<T> outputFactory;

public static final String POSITIVE_NAME = "1";
public static final String NEGATIVE_NAME = "0";

@Config(description="The positive response to emit.")
private String positiveName = "1";
private String positiveName = POSITIVE_NAME;

@Config(description="The negative response to emit.")
private String negativeName = "0";
private String negativeName = NEGATIVE_NAME;

@Config(description = "A list of field names to read, you should use only one of this or fieldName.")
private List<String> fieldNames;

@Config(description = "A list of strings that trigger positive responses; it should be the same length as fieldNames or empty")
private List<String> positiveResponses;

@Config(description = "Whether to display field names as part of the generated output, defaults to false")
private boolean displayField;

@ConfigurableName
private String configName;

@Override
public void postConfig() {
if (fieldName != null && fieldNames != null) { // we can only have one path
throw new PropertyException(configName, "fieldName, FieldNames", "only one of fieldName or fieldNames can be populated");
} else if (fieldNames != null) {
if(positiveResponse != null) {
positiveResponses = positiveResponses == null ? Collections.nCopies(fieldNames.size(), positiveResponse) : positiveResponses;
if(positiveResponses.size() != fieldNames.size()) {
throw new PropertyException(configName, "positiveResponses", "must either be empty or match the length of fieldNames");
}
} else {
throw new PropertyException(configName, "positiveResponse, positiveResponses", "one of positiveResponse or positiveResponses must be populated");
}
} else if (fieldName != null) {
if(positiveResponses != null) {
throw new PropertyException(configName, "positiveResponses", "if fieldName is populated, positiveResponses must be blank");
}
fieldNames = Collections.singletonList(fieldName);
if(positiveResponse != null) {
positiveResponses = Collections.singletonList(positiveResponse);
} else {
throw new PropertyException(configName, "positiveResponse", "if fieldName is populated positiveResponse must be populated");
}
} else {
throw new PropertyException(configName, "fieldName, fieldNames", "One of fieldName or fieldNames must be populated");
}
}

/**
* for OLCUT.
Expand All @@ -54,25 +103,87 @@ private BinaryResponseProcessor() {}

/**
* Constructs a binary response processor which emits a positive value for a single string
* and a negative value for all other field values.
* and a negative value for all other field values. Defaults to {@link #POSITIVE_NAME} for positive outputs and {@link #NEGATIVE_NAME}
* for negative outputs.
* @param fieldName The field name to read.
* @param positiveResponse The positive response to look for.
* @param outputFactory The output factory to use.
*/
public BinaryResponseProcessor(String fieldName, String positiveResponse, OutputFactory<T> outputFactory) {
this.fieldName = fieldName;
this.positiveResponse = positiveResponse;
this(Collections.singletonList(fieldName), positiveResponse, outputFactory);
}

/**
* Constructs a binary response processor which emits a positive value for a single string
* and a negative value for all other field values. Defaults to {@link #POSITIVE_NAME} for positive outputs and {@link #NEGATIVE_NAME}
* for negative outputs.
* @param fieldNames The field names to read.
* @param positiveResponse The positive response to look for.
* @param outputFactory The output factory to use.
*/
public BinaryResponseProcessor(List<String> fieldNames, String positiveResponse, OutputFactory<T> outputFactory) {
this(fieldNames, Collections.nCopies(fieldNames.size(), positiveResponse), outputFactory);
}

/**
* Constructs a binary response processor which emits a positive value for a single string
* and a negative value for all other field values. The lengths of fieldNames and positiveResponses
* must be the same. Defaults to {@link #POSITIVE_NAME} for positive outputs and {@link #NEGATIVE_NAME}
* for negative outputs.
* @param fieldNames The field names to read.
* @param positiveResponses The positive responses to look for.
* @param outputFactory The output factory to use.
*/
public BinaryResponseProcessor(List<String> fieldNames, List<String> positiveResponses, OutputFactory<T> outputFactory) {
this(fieldNames, positiveResponses, outputFactory, false);
}

/**
* Constructs a binary response processor which emits a positive value for a single string
* and a negative value for all other field values. The lengths of fieldNames and positiveResponses
* must be the same. Defaults to {@link #POSITIVE_NAME} for positive outputs and {@link #NEGATIVE_NAME}
* for negative outputs.
* @param fieldNames The field names to read.
* @param positiveResponses The positive responses to look for.
* @param outputFactory The output factory to use.
* @param displayField whether to include field names in the generated labels.
*/
public BinaryResponseProcessor(List<String> fieldNames, List<String> positiveResponses, OutputFactory<T> outputFactory, boolean displayField) {
this(fieldNames, positiveResponses, outputFactory, POSITIVE_NAME, NEGATIVE_NAME, displayField);
}

/**
* Constructs a binary response processor which emits a positive value for a single string
* and a negative value for all other field values. The lengths of fieldNames and positiveResponses
* must be the same.
* @param fieldNames The field names to read.
* @param positiveResponses The positive responses to look for.
* @param outputFactory The output factory to use.
* @param positiveName The value of a 'positive' output
* @param negativeName the value of a 'negative' output
* @param displayField whether to include field names in the generated labels.
*/
public BinaryResponseProcessor(List<String> fieldNames, List<String> positiveResponses, OutputFactory<T> outputFactory, String positiveName, String negativeName, boolean displayField) {
if(fieldNames.size() != positiveResponses.size()) {
throw new IllegalArgumentException("fieldNames and positiveResponses must be the same length");
}
this.fieldNames = fieldNames;
this.positiveResponses = positiveResponses;
this.outputFactory = outputFactory;
this.positiveName = positiveName;
this.negativeName = negativeName;
this.displayField = displayField;
}

@Override
public OutputFactory<T> getOutputFactory() {
return outputFactory;
}

@Deprecated
@Override
public String getFieldName() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we tag this method deprecated too? I think that means that Maven doesn't warn you for using a deprecated method.

return fieldName;
return fieldNames.get(0);
}

@Deprecated
Expand All @@ -83,12 +194,33 @@ public void setFieldName(String fieldName) {

@Override
public Optional<T> process(String value) {
return Optional.of(outputFactory.generateOutput(positiveResponse.equals(value) ? positiveName : negativeName));
return process(Collections.singletonList(value));
}

@Override
public Optional<T> process(List<String> values) {
if(values.size() != fieldNames.size()) {
throw new IllegalArgumentException("values must have the same length as fieldNames. Got values: " + values.size() + " fieldNames: " + fieldNames.size());
}
List<String> responses = new ArrayList<>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should check that values.size()==fieldNames.size()?

String prefix = "";
for(int i=0; i < fieldNames.size(); i++) {
if(displayField) {
prefix = fieldNames.get(i) + "=";
}
responses.add(prefix + (positiveResponses.get(i).equals(values.get(i)) ? positiveName : negativeName));
}
return Optional.of(outputFactory.generateOutput(fieldNames.size() == 1 ? responses.get(0) : responses));
}

@Override
public List<String> getFieldNames() {
return fieldNames;
}

@Override
public String toString() {
return "BinaryResponseProcessor(fieldName="+ fieldName +", positiveResponse="+ positiveResponse +", positiveName="+positiveName +", negativeName="+negativeName+")";
return "BinaryResponseProcessor(fieldNames="+ fieldNames.toString() +", positiveResponses="+ positiveResponses.toString() +", positiveName="+positiveName +", negativeName="+negativeName+")";
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import org.tribuo.OutputFactory;
import org.tribuo.data.columnar.ResponseProcessor;

import java.util.Collections;
import java.util.List;
import java.util.Optional;

/**
Expand Down Expand Up @@ -84,6 +86,16 @@ public Optional<T> process(String value) {
return Optional.empty();
}

@Override
public Optional<T> process(List<String> values) {
return Optional.empty();
}

@Override
public List<String> getFieldNames() {
return Collections.singletonList(FIELD_NAME);
}

@Override
public String toString() {
return "EmptyResponseProcessor(outputFactory="+outputFactory.toString()+")";
Expand Down
Loading