From 5d6b60e0cbf17aaf8d5b4c9775d6863dcc01a809 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Wed, 14 Dec 2022 10:38:40 -0500 Subject: [PATCH] Improving the multiple response documentation in CSVLoader, adding a test for the documented behaviour and fixing a small toString bug in MockMultiOutput and MultiLabel. (#306) --- .../java/org/tribuo/test/MockMultiOutput.java | 10 +++--- .../java/org/tribuo/data/csv/CSVLoader.java | 32 ++++++++++++++++++- .../org/tribuo/data/csv/CSVLoaderTest.java | 12 +++++-- .../csv/test-multioutput-singlecolumn.csv | 7 ++++ .../org/tribuo/multilabel/MultiLabel.java | 10 +++--- 5 files changed, 60 insertions(+), 11 deletions(-) create mode 100644 Data/src/test/resources/org/tribuo/data/csv/test-multioutput-singlecolumn.csv diff --git a/Core/src/test/java/org/tribuo/test/MockMultiOutput.java b/Core/src/test/java/org/tribuo/test/MockMultiOutput.java index 9fce1691c..4cb375cc4 100644 --- a/Core/src/test/java/org/tribuo/test/MockMultiOutput.java +++ b/Core/src/test/java/org/tribuo/test/MockMultiOutput.java @@ -231,11 +231,13 @@ public String toString() { StringBuilder builder = new StringBuilder(); builder.append("(LabelSet={"); - for (MockOutput l : labels) { - builder.append(l.toString()); - builder.append(','); + if (labels.size() > 0) { + for (MockOutput l : labels) { + builder.append(l.toString()); + builder.append(','); + } + builder.deleteCharAt(builder.length() - 1); } - builder.deleteCharAt(builder.length()-1); builder.append('}'); if (!Double.isNaN(score)) { builder.append(",OverallScore="); diff --git a/Data/src/main/java/org/tribuo/data/csv/CSVLoader.java b/Data/src/main/java/org/tribuo/data/csv/CSVLoader.java index b50b7127b..7b5cd3886 100644 --- a/Data/src/main/java/org/tribuo/data/csv/CSVLoader.java +++ b/Data/src/main/java/org/tribuo/data/csv/CSVLoader.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2022, Oracle and/or its affiliates. 
All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -68,6 +68,12 @@ * {@link org.tribuo.data.columnar.RowProcessor} to cope with your specific input format. *

* CSVLoader is thread safe and immutable. + *

+ * Multi-output responses such as {@code MultiLabel} or {@code Regressor} can be processed in + * two different ways: either as a single column of separated values, or as multiple columns. If + * there is a single column, the value is passed directly to the {@link OutputFactory}. If + * there are multiple response columns, each column name is joined to its value in the form + * 'column-name=column-value', and the list of joined values is passed to the {@link OutputFactory}. * @param <T> The type of the output generated. */ public class CSVLoader<T extends Output<T>> { @@ -139,6 +145,10 @@ public MutableDataset<T> load(Path csvPath, String responseName, String[] header *

* The {@code responseNames} set is traversed in iteration order to emit outputs, * and should be an ordered set to ensure reproducibility. + *

+ * If there are multiple elements in {@code responseNames} then the responses are + * processed into the form 'column-name=column-value' before being passed to the + * {@link OutputFactory} for conversion into an {@link Output}. * * @param csvPath The path to load. * @param responseNames The names of the response variables. @@ -154,6 +164,10 @@ public MutableDataset load(Path csvPath, Set responseNames) throws IO *

* The {@code responseNames} set is traversed in iteration order to emit outputs, * and should be an ordered set to ensure reproducibility. + *

+ * If there are multiple elements in {@code responseNames} then the responses are + * processed into the form 'column-name=column-value' before being passed to the + * {@link OutputFactory} for conversion into an {@link Output}. * * @param csvPath The path to load. * @param responseNames The names of the response variables. @@ -220,6 +234,10 @@ public DataSource loadDataSource(URL csvPath, String responseName, String[] h *

* The {@code responseNames} set is traversed in iteration order to emit outputs, * and should be an ordered set to ensure reproducibility. + *

+ * If there are multiple elements in {@code responseNames} then the responses are + * processed into the form 'column-name=column-value' before being passed to the + * {@link OutputFactory} for conversion into an {@link Output}. * * @param csvPath The csv to load from. * @param responseNames The names of the response variables. @@ -235,6 +253,10 @@ public DataSource loadDataSource(Path csvPath, Set responseNames) thr *

* The {@code responseNames} set is traversed in iteration order to emit outputs, * and should be an ordered set to ensure reproducibility. + *

+ * If there are multiple elements in {@code responseNames} then the responses are + * processed into the form 'column-name=column-value' before being passed to the + * {@link OutputFactory} for conversion into an {@link Output}. * * @param csvPath The csv to load from. * @param responseNames The names of the response variables. @@ -250,6 +272,10 @@ public DataSource loadDataSource(URL csvPath, Set responseNames) thro *

* The {@code responseNames} set is traversed in iteration order to emit outputs, * and should be an ordered set to ensure reproducibility. + *

+ * If there are multiple elements in {@code responseNames} then the responses are + * processed into the form 'column-name=column-value' before being passed to the + * {@link OutputFactory} for conversion into an {@link Output}. * * @param csvPath The csv to load from. * @param responseNames The names of the response variables. @@ -266,6 +292,10 @@ public DataSource loadDataSource(Path csvPath, Set responseNames, Str *

* The {@code responseNames} set is traversed in iteration order to emit outputs, * and should be an ordered set to ensure reproducibility. + *

+ * If there are multiple elements in {@code responseNames} then the responses are + * processed into the form 'column-name=column-value' before being passed to the + * {@link OutputFactory} for conversion into an {@link Output}. * * @param csvPath The csv to load from. * @param responseNames The names of the response variables. diff --git a/Data/src/test/java/org/tribuo/data/csv/CSVLoaderTest.java b/Data/src/test/java/org/tribuo/data/csv/CSVLoaderTest.java index ffb9d61c1..865bcdc15 100644 --- a/Data/src/test/java/org/tribuo/data/csv/CSVLoaderTest.java +++ b/Data/src/test/java/org/tribuo/data/csv/CSVLoaderTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -82,13 +82,21 @@ public void testLoadMultiOutput() throws IOException { assertTrue(data.getExample(1).getOutput().contains("R1")); assertTrue(data.getExample(1).getOutput().contains("R2")); - // // Row #2: R1=False and R2=False. // In this case, the labelSet is empty and the labelString is the empty string. 
assertEquals(0, data.getExample(2).getOutput().getLabelSet().size()); assertEquals("", data.getExample(2).getOutput().getLabelString()); assertTrue(data.getExample(2).validateExample()); + + URL singlePath = CSVLoaderTest.class.getResource("/org/tribuo/data/csv/test-multioutput-singlecolumn.csv"); + DataSource singleSource = loader.loadDataSource(singlePath, "Label"); + MutableDataset singleData = new MutableDataset<>(singleSource); + assertEquals(6, singleData.size()); + + for (int i = 0; i < 6; i++) { + assertEquals(data.getExample(i).getOutput().getLabelString(), singleData.getExample(i).getOutput().getLabelString()); + } } @Test diff --git a/Data/src/test/resources/org/tribuo/data/csv/test-multioutput-singlecolumn.csv b/Data/src/test/resources/org/tribuo/data/csv/test-multioutput-singlecolumn.csv new file mode 100644 index 000000000..52b98bfb4 --- /dev/null +++ b/Data/src/test/resources/org/tribuo/data/csv/test-multioutput-singlecolumn.csv @@ -0,0 +1,7 @@ +A,B,C,D,Label +1,2,3,4,"R1" +6,7,8,9,"R1,R2" +6,7,8,9, +2,5,3,4,"R1" +1,2,5,9,"R2" +0,2,5,9,"R2" diff --git a/MultiLabel/Core/src/main/java/org/tribuo/multilabel/MultiLabel.java b/MultiLabel/Core/src/main/java/org/tribuo/multilabel/MultiLabel.java index 796a6a799..7dd4d7717 100644 --- a/MultiLabel/Core/src/main/java/org/tribuo/multilabel/MultiLabel.java +++ b/MultiLabel/Core/src/main/java/org/tribuo/multilabel/MultiLabel.java @@ -314,11 +314,13 @@ public String toString() { StringBuilder builder = new StringBuilder(); builder.append("(LabelSet={"); - for (Label l : labels) { - builder.append(l.toString()); - builder.append(','); + if (labels.size() > 0) { + for (Label l : labels) { + builder.append(l.toString()); + builder.append(','); + } + builder.deleteCharAt(builder.length() - 1); } - builder.deleteCharAt(builder.length()-1); builder.append('}'); if (!Double.isNaN(score)) { builder.append(",OverallScore=");