-
Notifications
You must be signed in to change notification settings - Fork 178
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added support for multioutputs to ResponseProcessor, with tests. #150
Merged
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
acc4edf
Added support for multioutputs to ResponseProcessor, with tests.
JackSullivan a01be13
Fixed an issue where ResponseProcessors could be fed nulls from `RowP…
JackSullivan 201e4f6
changed header/value delimiters on multilabel `ResponseProcessor`s to…
JackSullivan c02ec67
added extra tests for missing values
JackSullivan 749ceeb
updates based on feedback
JackSullivan 8c28789
fixing more change requests I missed the first time
JackSullivan 11a78eb
further updates based on feedback
JackSullivan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,12 +17,17 @@ | |
package org.tribuo.data.columnar.processors.response; | ||
|
||
import com.oracle.labs.mlrg.olcut.config.Config; | ||
import com.oracle.labs.mlrg.olcut.config.ConfigurableName; | ||
import com.oracle.labs.mlrg.olcut.config.PropertyException; | ||
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance; | ||
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl; | ||
import org.tribuo.Output; | ||
import org.tribuo.OutputFactory; | ||
import org.tribuo.data.columnar.ResponseProcessor; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Collections; | ||
import java.util.List; | ||
import java.util.Optional; | ||
|
||
/** | ||
|
@@ -32,20 +37,64 @@ | |
*/ | ||
public class BinaryResponseProcessor<T extends Output<T>> implements ResponseProcessor<T> { | ||
|
||
@Config(mandatory = true,description="The field name to read.") | ||
@Config(description="The field name to read, you should use only one of this or fieldNames") | ||
@Deprecated | ||
private String fieldName; | ||
|
||
@Config(mandatory = true,description="The string which triggers a positive response.") | ||
@Config(description="The string which triggers a positive response.") | ||
private String positiveResponse; | ||
|
||
@Config(mandatory = true,description="Output factory to use to create the response.") | ||
@Config(mandatory = true, description="Output factory to use to create the response.") | ||
private OutputFactory<T> outputFactory; | ||
|
||
public static final String POSITIVE_NAME = "1"; | ||
public static final String NEGATIVE_NAME = "0"; | ||
|
||
@Config(description="The positive response to emit.") | ||
private String positiveName = "1"; | ||
private String positiveName = POSITIVE_NAME; | ||
|
||
@Config(description="The negative response to emit.") | ||
private String negativeName = "0"; | ||
private String negativeName = NEGATIVE_NAME; | ||
|
||
@Config(description = "A list of field names to read, you should use only one of this or fieldName.") | ||
private List<String> fieldNames; | ||
|
||
@Config(description = "A list of strings that trigger positive responses; it should be the same length as fieldNames or empty") | ||
private List<String> positiveResponses; | ||
|
||
@Config(description = "Whether to display field names as part of the generated output, defaults to false") | ||
private boolean displayField; | ||
|
||
@ConfigurableName | ||
private String configName; | ||
|
||
@Override | ||
public void postConfig() { | ||
if (fieldName != null && fieldNames != null) { // we can only have one path | ||
throw new PropertyException(configName, "fieldName, FieldNames", "only one of fieldName or fieldNames can be populated"); | ||
} else if (fieldNames != null) { | ||
if(positiveResponse != null) { | ||
positiveResponses = positiveResponses == null ? Collections.nCopies(fieldNames.size(), positiveResponse) : positiveResponses; | ||
if(positiveResponses.size() != fieldNames.size()) { | ||
throw new PropertyException(configName, "positiveResponses", "must either be empty or match the length of fieldNames"); | ||
} | ||
} else { | ||
throw new PropertyException(configName, "positiveResponse, positiveResponses", "one of positiveResponse or positiveResponses must be populated"); | ||
} | ||
} else if (fieldName != null) { | ||
if(positiveResponses != null) { | ||
throw new PropertyException(configName, "positiveResponses", "if fieldName is populated, positiveResponses must be blank"); | ||
} | ||
fieldNames = Collections.singletonList(fieldName); | ||
if(positiveResponse != null) { | ||
positiveResponses = Collections.singletonList(positiveResponse); | ||
} else { | ||
throw new PropertyException(configName, "positiveResponse", "if fieldName is populated positiveResponse must be populated"); | ||
} | ||
} else { | ||
throw new PropertyException(configName, "fieldName, fieldNames", "One of fieldName or fieldNames must be populated"); | ||
} | ||
} | ||
|
||
/** | ||
* for OLCUT. | ||
|
@@ -54,25 +103,87 @@ private BinaryResponseProcessor() {} | |
|
||
/** | ||
* Constructs a binary response processor which emits a positive value for a single string | ||
* and a negative value for all other field values. | ||
* and a negative value for all other field values. Defaults to {@link #POSITIVE_NAME} for positive outputs and {@link #NEGATIVE_NAME} | ||
* for negative outputs. | ||
* @param fieldName The field name to read. | ||
* @param positiveResponse The positive response to look for. | ||
* @param outputFactory The output factory to use. | ||
*/ | ||
public BinaryResponseProcessor(String fieldName, String positiveResponse, OutputFactory<T> outputFactory) { | ||
this.fieldName = fieldName; | ||
this.positiveResponse = positiveResponse; | ||
this(Collections.singletonList(fieldName), positiveResponse, outputFactory); | ||
} | ||
|
||
/** | ||
* Constructs a binary response processor which reads several fields and compares every one of them | ||
* against the same positive response string. | ||
* <p> | ||
* The single {@code positiveResponse} is repeated once per entry of {@code fieldNames} and construction | ||
* is delegated to the list-based constructor. Through that delegation the processor emits | ||
* {@link #POSITIVE_NAME} for positive outputs and {@link #NEGATIVE_NAME} for negative outputs, and | ||
* field names are not included in the generated outputs. | ||
* @param fieldNames The field names to read; its size is read here, so it must not be null. | ||
* @param positiveResponse The positive response to look for in every field. | ||
* @param outputFactory The output factory to use. | ||
*/ | ||
public BinaryResponseProcessor(List<String> fieldNames, String positiveResponse, OutputFactory<T> outputFactory) { | ||
this(fieldNames, Collections.nCopies(fieldNames.size(), positiveResponse), outputFactory); | ||
} | ||
|
||
/** | ||
* Constructs a binary response processor which reads several fields, each with its own positive | ||
* response string; any other value in a field produces the negative output for that field. | ||
* <p> | ||
* The lengths of {@code fieldNames} and {@code positiveResponses} must be the same; entry {@code i} of | ||
* one list pairs with entry {@code i} of the other. Delegates with {@code displayField} set to | ||
* {@code false}, and defaults to {@link #POSITIVE_NAME} for positive outputs and {@link #NEGATIVE_NAME} | ||
* for negative outputs. | ||
* @param fieldNames The field names to read. | ||
* @param positiveResponses The positive responses to look for, one per field name. | ||
* @param outputFactory The output factory to use. | ||
* @throws IllegalArgumentException if the two lists differ in length (raised by the delegated constructor). | ||
*/ | ||
public BinaryResponseProcessor(List<String> fieldNames, List<String> positiveResponses, OutputFactory<T> outputFactory) { | ||
this(fieldNames, positiveResponses, outputFactory, false); | ||
} | ||
|
||
/** | ||
* Constructs a binary response processor which reads several fields, each with its own positive | ||
* response string, optionally prefixing each generated value with its field name. | ||
* <p> | ||
* The lengths of {@code fieldNames} and {@code positiveResponses} must be the same. Defaults to | ||
* {@link #POSITIVE_NAME} for positive outputs and {@link #NEGATIVE_NAME} for negative outputs by | ||
* passing those constants to the delegated constructor. | ||
* @param fieldNames The field names to read. | ||
* @param positiveResponses The positive responses to look for, one per field name. | ||
* @param outputFactory The output factory to use. | ||
* @param displayField whether to include field names in the generated labels. | ||
* @throws IllegalArgumentException if the two lists differ in length (raised by the delegated constructor). | ||
*/ | ||
public BinaryResponseProcessor(List<String> fieldNames, List<String> positiveResponses, OutputFactory<T> outputFactory, boolean displayField) { | ||
this(fieldNames, positiveResponses, outputFactory, POSITIVE_NAME, NEGATIVE_NAME, displayField); | ||
} | ||
|
||
/**
 * Constructs a binary response processor which reads several fields, each with its own positive
 * response string, and emits {@code positiveName} when a field's value equals its positive response
 * and {@code negativeName} otherwise.
 * <p>
 * The lengths of {@code fieldNames} and {@code positiveResponses} must be the same. Both lists are
 * copied into unmodifiable snapshots, so changes the caller makes to its own lists afterwards cannot
 * break the equal-length invariant checked here.
 * @param fieldNames The field names to read.
 * @param positiveResponses The positive responses to look for, one per field name.
 * @param outputFactory The output factory to use.
 * @param positiveName The value of a 'positive' output
 * @param negativeName the value of a 'negative' output
 * @param displayField whether to include field names in the generated labels.
 * @throws IllegalArgumentException if {@code fieldNames} and {@code positiveResponses} differ in length.
 */
public BinaryResponseProcessor(List<String> fieldNames, List<String> positiveResponses, OutputFactory<T> outputFactory, String positiveName, String negativeName, boolean displayField) {
    if (fieldNames.size() != positiveResponses.size()) {
        throw new IllegalArgumentException("fieldNames and positiveResponses must be the same length");
    }
    // Snapshot the inputs: holding the caller's lists by reference would let later external
    // mutation desynchronise the two lists after the length check above has passed.
    this.fieldNames = Collections.unmodifiableList(new ArrayList<>(fieldNames));
    this.positiveResponses = Collections.unmodifiableList(new ArrayList<>(positiveResponses));
    this.outputFactory = outputFactory;
    this.positiveName = positiveName;
    this.negativeName = negativeName;
    this.displayField = displayField;
}
|
||
@Override | ||
public OutputFactory<T> getOutputFactory() { | ||
return outputFactory; | ||
} | ||
|
||
@Deprecated | ||
@Override | ||
public String getFieldName() { | ||
return fieldName; | ||
return fieldNames.get(0); | ||
} | ||
|
||
@Deprecated | ||
|
@@ -83,12 +194,33 @@ public void setFieldName(String fieldName) { | |
|
||
@Override | ||
public Optional<T> process(String value) { | ||
return Optional.of(outputFactory.generateOutput(positiveResponse.equals(value) ? positiveName : negativeName)); | ||
return process(Collections.singletonList(value)); | ||
} | ||
|
||
/** | ||
* Converts a list of raw field values into a single output, using one value per configured field name. | ||
* <p> | ||
* Each value is compared with the positive response at the same index; a match contributes | ||
* {@code positiveName} and anything else contributes {@code negativeName}. When {@code displayField} | ||
* is set, each contribution is prefixed with {@code fieldName=}. | ||
* @param values The raw values, in the same order as the field names. | ||
* @return An optional wrapping the output generated by the output factory; this method always returns a populated optional. | ||
* @throws IllegalArgumentException if {@code values} and the field names differ in length. | ||
*/ | ||
@Override | ||
public Optional<T> process(List<String> values) { | ||
if(values.size() != fieldNames.size()) { | ||
throw new IllegalArgumentException("values must have the same length as fieldNames. Got values: " + values.size() + " fieldNames: " + fieldNames.size()); | ||
} | ||
List<String> responses = new ArrayList<>(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we should check that |
||
// The prefix stays empty unless displayField is set, in which case it is rebuilt for each field. | ||
String prefix = ""; | ||
for(int i=0; i < fieldNames.size(); i++) { | ||
if(displayField) { | ||
prefix = fieldNames.get(i) + "="; | ||
} | ||
// equals is invoked on the configured response, so a null value counts as a non-match instead of throwing. | ||
responses.add(prefix + (positiveResponses.get(i).equals(values.get(i)) ? positiveName : negativeName)); | ||
} | ||
// A single field hands the bare string to the factory; several fields hand over the whole list. | ||
return Optional.of(outputFactory.generateOutput(fieldNames.size() == 1 ? responses.get(0) : responses)); | ||
} | ||
|
||
/** | ||
* Gets the field names this processor reads. The stored list is returned directly, without copying. | ||
* @return The list of field names. | ||
*/ | ||
@Override | ||
public List<String> getFieldNames() { | ||
return fieldNames; | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return "BinaryResponseProcessor(fieldName="+ fieldName +", positiveResponse="+ positiveResponse +", positiveName="+positiveName +", negativeName="+negativeName+")"; | ||
return "BinaryResponseProcessor(fieldNames="+ fieldNames.toString() +", positiveResponses="+ positiveResponses.toString() +", positiveName="+positiveName +", negativeName="+negativeName+")"; | ||
} | ||
|
||
@Override | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we tag this method deprecated too? I think that means that Maven doesn't warn you for using a deprecated method.