From 1a4f6e87f540f81bf485be6df6a3a82871696bdc Mon Sep 17 00:00:00 2001 From: Romina Mahinpei Date: Tue, 12 Jul 2022 13:01:09 -0700 Subject: [PATCH] Add ModelCard infrastructure (including implementation and testing) (#243) * Add ModelCard infrastructure (including implementation and testing) Signed-off-by: Romina Mahinpei * Fix dependencies and language level of ModelCard Signed-off-by: Romina Mahinpei * Fix dependency issues Signed-off-by: Romina Mahinpei * Add license header to all new files Signed-off-by: Romina Mahinpei * Fix formatting issues Signed-off-by: Romina Mahinpei * Fix dependency issues Signed-off-by: Romina Mahinpei * Make all classes constituting ModelCard final Signed-off-by: Romina Mahinpei * Replace assert statements with assertEquals Signed-off-by: Romina Mahinpei * Use Path object instead of String and use temp files when testing Signed-off-by: Romina Mahinpei * Have list items display to user 0-indexed Signed-off-by: Romina Mahinpei * Make schema version a static final string Signed-off-by: Romina Mahinpei * Use self-generated testing data files Signed-off-by: Romina Mahinpei * Refactor constructor of TrainingDetails Signed-off-by: Romina Mahinpei * Switch field type of configuredParams from Map to JsonNode Signed-off-by: Romina Mahinpei * Refactor code for checking whether a value is numeric and use entrySet Signed-off-by: Romina Mahinpei * Implement equals and hashCode in all ModelCard classes Signed-off-by: Romina Mahinpei * Implement deserializeFromJson Signed-off-by: Romina Mahinpei * Make mapper package private Signed-off-by: Romina Mahinpei * Accept File as a type in UsageDetails Signed-off-by: Romina Mahinpei * Add more detail to command explanation in UsageDetails Signed-off-by: Romina Mahinpei * Return an unmodifiable view of collections in getters Signed-off-by: Romina Mahinpei * Fix styling issues Signed-off-by: Romina Mahinpei * Flatten configuredParams map Signed-off-by: Romina Mahinpei * Separate UsageDetails and ModelCardCLI 
and create UsageDetailsBuilder Signed-off-by: Romina Mahinpei * Remove addMetric from TestingDetails to make the class immutable Signed-off-by: Romina Mahinpei * Enable deserialization of model cards that do not have UsageDetails specified Signed-off-by: Romina Mahinpei * Remove redundant code from constructors and add license header to all files Signed-off-by: Romina Mahinpei * Include UsageDetails in json schema for all possible scenarios Signed-off-by: Romina Mahinpei * Allow user to use the same shell to write the UsageDetails for different model cards Signed-off-by: Romina Mahinpei * Remove processNestedParams method to avoid redundancy Signed-off-by: Romina Mahinpei * Return an unmodifiable view of a map Signed-off-by: Romina Mahinpei --- Interop/ModelCard/pom.xml | 122 ++++++++ .../tribuo/interop/modelcard/ModelCard.java | 138 +++++++++ .../interop/modelcard/ModelCardCLI.java | 285 +++++++++++++++++ .../interop/modelcard/ModelDetails.java | 117 +++++++ .../interop/modelcard/TestingDetails.java | 102 +++++++ .../interop/modelcard/TrainingDetails.java | 151 +++++++++ .../interop/modelcard/UsageDetails.java | 213 +++++++++++++ .../modelcard/UsageDetailsBuilder.java | 163 ++++++++++ .../interop/modelcard/EnsembleModelsTest.java | 288 ++++++++++++++++++ .../interop/modelcard/ExternalModelsTest.java | 68 +++++ .../interop/modelcard/NativeModelsTest.java | 207 +++++++++++++ .../resources/classificationSampleData.csv | 11 + .../src/test/resources/externalModelPath.xgb | Bin 0 -> 631604 bytes .../resources/externalModelSampleData.csv | 11 + .../multiClassificationSampleData.svm | 10 + .../test/resources/regressionSampleData.csv | 11 + Interop/pom.xml | 12 + 17 files changed, 1909 insertions(+) create mode 100644 Interop/ModelCard/pom.xml create mode 100644 Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/ModelCard.java create mode 100644 Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/ModelCardCLI.java create mode 100644 
Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/ModelDetails.java create mode 100644 Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/TestingDetails.java create mode 100644 Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/TrainingDetails.java create mode 100644 Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/UsageDetails.java create mode 100644 Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/UsageDetailsBuilder.java create mode 100644 Interop/ModelCard/src/test/java/org/tribuo/interop/modelcard/EnsembleModelsTest.java create mode 100644 Interop/ModelCard/src/test/java/org/tribuo/interop/modelcard/ExternalModelsTest.java create mode 100644 Interop/ModelCard/src/test/java/org/tribuo/interop/modelcard/NativeModelsTest.java create mode 100644 Interop/ModelCard/src/test/resources/classificationSampleData.csv create mode 100644 Interop/ModelCard/src/test/resources/externalModelPath.xgb create mode 100644 Interop/ModelCard/src/test/resources/externalModelSampleData.csv create mode 100644 Interop/ModelCard/src/test/resources/multiClassificationSampleData.svm create mode 100644 Interop/ModelCard/src/test/resources/regressionSampleData.csv diff --git a/Interop/ModelCard/pom.xml b/Interop/ModelCard/pom.xml new file mode 100644 index 000000000..14780e8aa --- /dev/null +++ b/Interop/ModelCard/pom.xml @@ -0,0 +1,122 @@ + + + + tribuo-interop + org.tribuo + 4.3.0-SNAPSHOT + + 4.0.0 + ModelCard + tribuo-modelcard + + + 17 + 17 + 17 + + + + + + org.tribuo + tribuo-core + ${project.version} + + + org.tribuo + tribuo-data + ${project.version} + + + org.tribuo + tribuo-interop-core + ${project.version} + + + org.tribuo + tribuo-regression-core + ${project.version} + + + org.tribuo + tribuo-json + ${project.version} + + + + org.junit.jupiter + junit-jupiter + ${junit.version} + test + + + org.tribuo + tribuo-classification-core + ${project.version} + test + + + org.tribuo + tribuo-classification-libsvm + ${project.version} 
+ test + + + org.tribuo + tribuo-classification-sgd + ${project.version} + test + + + org.tribuo + tribuo-common-sgd + ${project.version} + test + + + org.tribuo + tribuo-multilabel-core + ${project.version} + test + + + org.tribuo + tribuo-multilabel-sgd + ${project.version} + test + + + org.tribuo + tribuo-regression-sgd + ${project.version} + test + + + org.tribuo + tribuo-classification-xgboost + ${project.version} + test + + + org.tribuo + tribuo-anomaly-libsvm + ${project.version} + test + + + org.tribuo + tribuo-clustering-core + ${project.version} + test + + + org.tribuo + tribuo-clustering-kmeans + ${project.version} + test + + + + \ No newline at end of file diff --git a/Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/ModelCard.java b/Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/ModelCard.java new file mode 100644 index 000000000..1f8f9ba0b --- /dev/null +++ b/Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/ModelCard.java @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.tribuo.interop.modelcard; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.fasterxml.jackson.databind.node.ObjectNode; +import org.tribuo.Model; +import org.tribuo.evaluation.Evaluation; +import org.tribuo.interop.ExternalModel; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.Collections; +import java.util.Map; +import java.util.Objects; + +public class ModelCard { + static final ObjectMapper mapper = new ObjectMapper().enable(SerializationFeature.INDENT_OUTPUT); + private final ModelDetails modelDetails; + private final TrainingDetails trainingDetails; + private final TestingDetails testingDetails; + private final UsageDetails usageDetails; + + public ModelCard(Model model, Evaluation evaluation, Map testingMetrics, UsageDetails usage) { + if (model instanceof ExternalModel) { + throw new IllegalArgumentException("External models currently not supported by ModelCard."); + } + modelDetails = new ModelDetails(model); + trainingDetails = new TrainingDetails(model); + testingDetails = new TestingDetails(evaluation, testingMetrics); + usageDetails = usage; + } + + public ModelCard(Model model, Evaluation evaluation, UsageDetails usage) { + this(model, evaluation, Collections.emptyMap(), usage); + } + + public ModelCard(Model model, Evaluation evaluation, Map testingMetrics) { + this(model, evaluation, testingMetrics, null); + } + + public ModelCard(Model model, Evaluation evaluation) { + this(model, evaluation, Collections.emptyMap(), null); + } + + private ModelCard(JsonNode modelCard) throws JsonProcessingException { + modelDetails = new ModelDetails(modelCard.get("ModelDetails")); + trainingDetails = new TrainingDetails(modelCard.get("TrainingDetails")); + testingDetails = new TestingDetails(modelCard.get("TestingDetails")); + if 
(modelCard.get("UsageDetails").isNull()) { + usageDetails = null; + } else { + usageDetails = new UsageDetails(modelCard.get("UsageDetails")); + } + } + + public static ModelCard deserializeFromJson(Path sourceFile) throws IOException { + JsonNode modelCard = mapper.readTree(sourceFile.toFile()); + return new ModelCard(modelCard); + } + + public static ModelCard deserializeFromJson(JsonNode modelCard) throws JsonProcessingException { + return new ModelCard(modelCard); + } + + public ModelDetails getModelDetails() { + return modelDetails; + } + + public TrainingDetails getTrainingDetails() { + return trainingDetails; + } + + public TestingDetails getTestingDetails() { + return testingDetails; + } + + public UsageDetails getUsageDetails() { + return usageDetails; + } + + public ObjectNode toJson() { + ObjectNode modelCardObject = mapper.createObjectNode(); + modelCardObject.set("ModelDetails", modelDetails.toJson()); + modelCardObject.set("TrainingDetails", trainingDetails.toJson()); + modelCardObject.set("TestingDetails", testingDetails.toJson()); + if (usageDetails != null) { + modelCardObject.set("UsageDetails", usageDetails.toJson()); + } else { + modelCardObject.putNull("UsageDetails"); + } + return modelCardObject; + } + + public void saveToFile(Path destinationFile) throws IOException { + ObjectNode modelCardObject = toJson(); + mapper.writeValue(destinationFile.toFile(), modelCardObject); + } + + @Override + public String toString() { + return toJson().toPrettyString(); + } + + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ModelCard modelCard = (ModelCard) o; + return modelDetails.equals(modelCard.modelDetails) && + trainingDetails.equals(modelCard.trainingDetails) && + testingDetails.equals(modelCard.testingDetails) && + Objects.equals(usageDetails, modelCard.usageDetails); + } + + @Override + public int hashCode() { + return Objects.hash(modelDetails, 
trainingDetails, testingDetails, usageDetails); + } +} diff --git a/Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/ModelCardCLI.java b/Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/ModelCardCLI.java new file mode 100644 index 000000000..cd08cd46e --- /dev/null +++ b/Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/ModelCardCLI.java @@ -0,0 +1,285 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.tribuo.interop.modelcard; + +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.oracle.labs.mlrg.olcut.command.Command; +import com.oracle.labs.mlrg.olcut.command.CommandGroup; +import com.oracle.labs.mlrg.olcut.command.CommandInterpreter; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static org.tribuo.interop.modelcard.ModelCard.mapper; + +public class ModelCardCLI implements CommandGroup { + private final CommandInterpreter shell = new CommandInterpreter(); + private UsageDetailsBuilder builder = new UsageDetailsBuilder(); + private final List outOfScopeUses = new ArrayList<>(); + private final List preProcessingSteps = new ArrayList<>(); + private final List considerations = new ArrayList<>(); + private final List factors = new ArrayList<>(); + private final List resources = new ArrayList<>(); + + public void startShell() { + shell.setPrompt("CLI% "); + shell.add(this); + shell.start(); + } + + @Override + public String getName() { + return "ModelCardCLI"; + } + + @Override + public String getDescription() { + return "CLI for building a UsageDetails for a model card."; + } + + @Command( + usage = " Records intended use of model." + ) + public String intendedUse(CommandInterpreter ci, String use) { + builder.intendedUse(use); + return("Recorded intended use as " + use + "."); + } + + @Command( + usage = " Records intended users of model." + ) + public String intendedUsers(CommandInterpreter ci, String users) { + builder.intendedUsers(users); + return("Recorded intended users as " + users + "."); + } + + @Command( + usage = " Adds an out-of-scope use to list of out-of-scope uses." + ) + public String addOutOfScopeUse(CommandInterpreter ci, String use) { + outOfScopeUses.add(use); + return("Added an out-of-scope use to list of out-of-scope uses."); + } + + @Command( + usage = " Remove out-of-scope use at specified index (0-indexed)." 
+ ) + public String removeOutOfScopeUse(CommandInterpreter ci, int index) { + outOfScopeUses.remove(index); + return("Removed out-of-scope use at specified index."); + } + + @Command( + usage = "Displays all added out-of-scope uses." + ) + public String viewOutOfScopeUse(CommandInterpreter ci) { + for (int i = 0; i < outOfScopeUses.size(); i++) { + System.out.println("\t" + i + ") "+ outOfScopeUses.get(i)); + } + return("Displayed all added out-of-scope uses."); + } + + @Command( + usage = " Adds pre-processing step to list of steps." + ) + public String addPreProcessingStep(CommandInterpreter ci, String step) { + preProcessingSteps.add(step); + return("Added pre-processing step to list of steps."); + } + + @Command( + usage = " Remove pro-processing step at specified index (0-indexed)." + ) + public String removePreProcessingStep(CommandInterpreter ci, int index) { + preProcessingSteps.remove(index); + return("Removed pre-processing step at specified index."); + } + + @Command( + usage = "Displays all added pre-processing steps." + ) + public String viewPreProcessingSteps(CommandInterpreter ci) { + for (int i = 0; i < preProcessingSteps.size(); i++) { + System.out.println("\t" + i + ") "+ preProcessingSteps.get(i)); + } + return("Displayed all added pre-processing steps."); + } + + @Command( + usage = " Adds consideration to list of considerations." + ) + public String addConsideration(CommandInterpreter ci, String consideration) { + considerations.add(consideration); + return("Added consideration to list of considerations."); + } + + @Command( + usage = " Remove consideration at specified index (0-indexed)." + ) + public String removeConsideration(CommandInterpreter ci, int index) { + considerations.remove(index); + return("Removed consideration at specified index."); + } + + @Command( + usage = "Displays all added considerations." 
+ ) + public String viewConsiderations(CommandInterpreter ci) { + for (int i = 0; i < considerations.size(); i++) { + System.out.println("\t" + i + ") "+ considerations.get(i)); + } + return("Displayed all added considerations."); + } + + @Command( + usage = " Adds relevant factor to list of factors." + ) + public String addFactor(CommandInterpreter ci, String factor) { + factors.add(factor); + return("Added factor to list of factors."); + } + + @Command( + usage = " Remove factor at specified index (0-indexed)." + ) + public String removeFactor(CommandInterpreter ci, int index) { + factors.remove(index); + return("Removed factor at specified index."); + } + + @Command( + usage = "Displays all added factors." + ) + public String viewFactors(CommandInterpreter ci) { + for (int i = 0; i < factors.size(); i++) { + System.out.println("\t" + i + ") "+ factors.get(i)); + } + return("Displayed all added factors."); + } + + @Command( + usage = " Adds resource to list of resources." + ) + public String addResource(CommandInterpreter ci, String resource) { + resources.add(resource); + return("Added resource to list of resources."); + } + + @Command( + usage = " Remove resource at specified index (0-indexed)." + ) + public String removeResource(CommandInterpreter ci, int index) { + resources.remove(index); + return("Removed resource at specified index."); + } + + @Command( + usage = "Displays all added resources." + ) + public String viewResources(CommandInterpreter ci) { + for (int i = 0; i < resources.size(); i++) { + System.out.println("\t" + i + ") "+ resources.get(i)); + } + return("Displayed all added resources."); + } + + @Command( + usage = " Records primary contact in case of questions or comments." + ) + public String primaryContact(CommandInterpreter ci, String contact) { + builder.primaryContact(contact); + return("Recorded primary contact as " + contact + "."); + } + + @Command( + usage = " Records model's citation." 
+ ) + public String modelCitation(CommandInterpreter ci, String citation) { + builder.modelCitation(citation); + return("Recorded model citation as " + citation + "."); + } + + @Command( + usage = " Records model's license." + ) + public String modelLicense(CommandInterpreter ci, String license) { + builder.modelLicense(license); + return("Recorded model license as " + license + "."); + } + + private UsageDetails createUsageDetails() { + builder.outOfScopeUses(outOfScopeUses); + builder.preProcessingSteps(preProcessingSteps); + builder.considerations(considerations); + builder.factors(factors); + builder.resources(resources); + return builder.build(); + } + + @Command( + usage = " Saves UsageDetails to an existing ModelCard file." + ) + public String saveUsageDetails(CommandInterpreter ci, File destinationFile) throws IOException { + UsageDetails usageDetails = createUsageDetails(); + + ObjectNode usageDetailsObject = usageDetails.toJson(); + ObjectNode modelCardObject = mapper.readValue(destinationFile, ObjectNode.class); + if (!modelCardObject.get("UsageDetails").isNull()) { + throw new IllegalArgumentException("This ModelCard already contains a UsageDetails."); + } + modelCardObject.set("UsageDetails", usageDetailsObject); + mapper.writeValue(destinationFile, modelCardObject); + + return "Saved UsageDetails to destination file."; + } + + @Command( + usage = "Removes all previously written fields for UsageDetails to write a new UsageDetails." + ) + public String newUsageDetails(CommandInterpreter ci) { + builder = new UsageDetailsBuilder(); + outOfScopeUses.clear(); + preProcessingSteps.clear(); + considerations.clear(); + factors.clear(); + resources.clear(); + return "Started a new UsageDetails."; + } + + @Command( + usage = "Displays current state of UsageDetails." 
+ ) + public String viewUsageDetails(CommandInterpreter ci) { + System.out.println(createUsageDetails()); + return "Displayed current state of UsageDetails."; + } + + + @Command( + usage = "Closes CLI without explicitly saving anything recorded." + ) + public String close(CommandInterpreter ci) { + shell.close(); + return "Closed ClI."; + } + + public static void main(String[] args) { + ModelCardCLI driver = new ModelCardCLI(); + driver.startShell(); + } +} diff --git a/Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/ModelDetails.java b/Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/ModelDetails.java new file mode 100644 index 000000000..cfac6276e --- /dev/null +++ b/Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/ModelDetails.java @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.tribuo.interop.modelcard; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil; +import org.tribuo.Model; + +import java.util.Collections; +import java.util.Map; +import java.util.Objects; + +import static org.tribuo.interop.modelcard.ModelCard.mapper; + +public final class ModelDetails { + private static final String schemaVersion = "1.0"; + private final String modelType; + private final String modelPackage; + private final String tribuoVersion; + private final String javaVersion; + private final Map configuredParams; + + public ModelDetails(Model model) { + modelType = model.getClass().getSimpleName(); + modelPackage = model.getClass().getTypeName(); + tribuoVersion = model.getProvenance().getTribuoVersion(); + javaVersion = model.getProvenance().getJavaVersion(); + configuredParams = ProvenanceUtil.convertToMap(model.getProvenance().getTrainerProvenance()); + } + + public ModelDetails(JsonNode modelDetailsJson) throws JsonProcessingException { + modelType = modelDetailsJson.get("model-type").textValue(); + modelPackage = modelDetailsJson.get("model-package").textValue(); + tribuoVersion = modelDetailsJson.get("tribuo-version").textValue(); + javaVersion = modelDetailsJson.get("java-version").textValue(); + TypeReference> typeRef = new TypeReference<>() {}; + configuredParams = Collections.unmodifiableMap(mapper.readValue(modelDetailsJson.get("configured-parameters").toString(), typeRef)); + } + + public String getSchemaVersion() { + return schemaVersion; + } + + public String getModelType() { + return modelType; + } + + public String getModelPackage() { + return modelPackage; + } + + public String getTribuoVersion() { + return tribuoVersion; + } + + public String getJavaVersion() { + return javaVersion; + } + + 
public Map getConfiguredParams() { + return Collections.unmodifiableMap(configuredParams); + } + + public ObjectNode toJson() { + ObjectNode modelDetailsObject = mapper.createObjectNode(); + modelDetailsObject.put("schema-version", schemaVersion); + modelDetailsObject.put("model-type", modelType); + modelDetailsObject.put("model-package", modelPackage); + modelDetailsObject.put("tribuo-version", tribuoVersion); + modelDetailsObject.put("java-version", javaVersion); + modelDetailsObject.set("configured-parameters", mapper.convertValue(configuredParams, ObjectNode.class)); + return modelDetailsObject; + } + + @Override + public String toString() { + return toJson().toPrettyString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ModelDetails that = (ModelDetails) o; + return modelType.equals(that.modelType) && + modelPackage.equals(that.modelPackage) && + tribuoVersion.equals(that.tribuoVersion) && + javaVersion.equals(that.javaVersion) && + configuredParams.equals(that.configuredParams); + } + + @Override + public int hashCode() { + return Objects.hash(modelType, modelPackage, tribuoVersion, javaVersion, configuredParams); + } +} diff --git a/Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/TestingDetails.java b/Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/TestingDetails.java new file mode 100644 index 000000000..49c3f5900 --- /dev/null +++ b/Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/TestingDetails.java @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tribuo.interop.modelcard; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import org.tribuo.evaluation.Evaluation; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.tribuo.interop.modelcard.ModelCard.mapper; + +public final class TestingDetails { + private static final String schemaVersion = "1.0"; + private final int testingSetSize; + private final Map metrics = new HashMap<>(); + + public TestingDetails(Evaluation evaluation, Map testingMetrics) { + testingSetSize = evaluation.getPredictions().size(); + metrics.putAll(testingMetrics); + } + + public TestingDetails(Evaluation evaluation) { + this(evaluation, Collections.emptyMap()); + } + + public TestingDetails(JsonNode testingDetailsJson) throws JsonProcessingException { + testingSetSize = testingDetailsJson.get("testing-set-size").intValue(); + Map parsed = mapper.readValue(testingDetailsJson.get("metrics").toString(), Map.class); + for (Map.Entry entry : parsed.entrySet()) { + metrics.put((String) entry.getKey(), (Double) entry.getValue()); + } + + } + + public String getSchemaVersion() { + return schemaVersion; + } + + public int getTestingSetSize() { + return testingSetSize; + } + + public Map getMetrics() { + return Collections.unmodifiableMap(metrics); + } + + public ObjectNode toJson() { + ObjectNode testingDetailsObject = mapper.createObjectNode(); 
testingDetailsObject.put("schema-version", schemaVersion); + testingDetailsObject.put("testing-set-size", testingSetSize); + + ObjectNode testingMetricsObject = mapper.createObjectNode(); + for (Map.Entry entry : metrics.entrySet()) { + testingMetricsObject.put(entry.getKey(), entry.getValue()); + } + testingDetailsObject.set("metrics", testingMetricsObject); + + return testingDetailsObject; + } + + @Override + public String toString() { + return toJson().toPrettyString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TestingDetails that = (TestingDetails) o; + return testingSetSize == that.testingSetSize && metrics.equals(that.metrics); + } + + @Override + public int hashCode() { + return Objects.hash(testingSetSize, metrics); + } +} + diff --git a/Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/TrainingDetails.java b/Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/TrainingDetails.java new file mode 100644 index 000000000..0f09d99a6 --- /dev/null +++ b/Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/TrainingDetails.java @@ -0,0 +1,151 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.tribuo.interop.modelcard; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import org.tribuo.Model; +import org.tribuo.regression.Regressor; + +import java.util.*; + +import static org.tribuo.interop.modelcard.ModelCard.mapper; + +public final class TrainingDetails { + public static final String schemaVersion = "1.0"; + private final String trainingTime; + private final int trainingSetSize; + private final int numFeatures; + private final List features = new ArrayList<>();; + private final int numOutputs; + private final Map outputsDistribution = new HashMap<>(); + + public TrainingDetails(Model model) { + trainingTime = model.getProvenance().getTrainingTime().toString(); + trainingSetSize = model.getProvenance().getDatasetProvenance().getNumExamples(); + + numFeatures = model.getProvenance().getDatasetProvenance().getNumFeatures(); + for (int i = 0; i < model.getFeatureIDMap().size(); i++) { + features.add(model.getFeatureIDMap().get(i).getName()); + } + + numOutputs = model.getProvenance().getDatasetProvenance().getNumOutputs(); + + if (!model.validate(Regressor.class)) { + for (var pair : model.getOutputIDInfo().outputCountsIterable()) { + outputsDistribution.put(pair.getA(), pair.getB()); + } + } + } + + public TrainingDetails(JsonNode trainingDetailsJson) throws JsonProcessingException { + trainingTime = trainingDetailsJson.get("training-time").textValue(); + trainingSetSize = trainingDetailsJson.get("training-set-size").intValue(); + + numFeatures = trainingDetailsJson.get("num-features").intValue(); + for (int i = 0; i < trainingDetailsJson.get("features-list").size(); i++) { + features.add(trainingDetailsJson.get("features-list").get(i).textValue()); + } + + numOutputs = trainingDetailsJson.get("num-outputs").intValue(); + Map parsed = 
mapper.readValue(trainingDetailsJson.get("outputs-distribution").toString(), Map.class); + for (Map.Entry entry : parsed.entrySet()) { + Integer val = (Integer) entry.getValue(); + outputsDistribution.put((String)entry.getKey(), val.longValue()); + } + } + + public String getSchemaVersion() { + return schemaVersion; + } + + public String getTrainingTime() { + return trainingTime; + } + + public int getTrainingSetSize() { + return trainingSetSize; + } + + public int getNumFeatures() { + return numFeatures; + } + + public List getFeatures() { + return Collections.unmodifiableList(features); + } + + public int getNumOutputs() { + return numOutputs; + } + + public Map getOutputsDistribution() { + return Collections.unmodifiableMap(outputsDistribution); + } + + public ObjectNode toJson() { + ObjectNode datasetDetailsObject = mapper.createObjectNode(); + datasetDetailsObject.put("schema-version", schemaVersion); + datasetDetailsObject.put("training-time", trainingTime); + datasetDetailsObject.put("training-set-size", trainingSetSize); + + datasetDetailsObject.put("num-features", numFeatures); + ArrayNode featuresArr = mapper.createArrayNode(); + for (String s : features) { + featuresArr.add(s); + } + datasetDetailsObject.set("features-list", featuresArr); + + datasetDetailsObject.put("num-outputs", numOutputs); + ObjectNode outputsArr = mapper.createObjectNode(); + for (Map.Entry entry : outputsDistribution.entrySet()) { + outputsArr.put(entry.getKey(), entry.getValue()); + } + datasetDetailsObject.set("outputs-distribution", outputsArr); + + return datasetDetailsObject; + } + + @Override + public String toString() { + return toJson().toPrettyString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TrainingDetails that = (TrainingDetails) o; + return trainingSetSize == that.trainingSetSize && + numFeatures == that.numFeatures && + numOutputs == 
that.numOutputs && + trainingTime.equals(that.trainingTime) && + features.equals(that.features) && + outputsDistribution.equals(that.outputsDistribution); + } + + @Override + public int hashCode() { + return Objects.hash(trainingTime, trainingSetSize, numFeatures, features, numOutputs, outputsDistribution); + } +} \ No newline at end of file diff --git a/Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/UsageDetails.java b/Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/UsageDetails.java new file mode 100644 index 000000000..15147ab88 --- /dev/null +++ b/Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/UsageDetails.java @@ -0,0 +1,213 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.tribuo.interop.modelcard; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +import static org.tribuo.interop.modelcard.ModelCard.mapper; + +public final class UsageDetails { + public static final String schemaVersion = "1.0"; + private final String intendedUse; + private final String intendedUsers; + private final List outOfScopeUses; + private final List preProcessingSteps; + private final List considerations; + private final List factors; + private final List resources; + private final String primaryContact; + private final String modelCitation; + private final String modelLicense; + + public UsageDetails( + String intendedUse, + String intendedUsers, + List outOfScopeUses, + List preProcessingSteps, + List considerations, + List factors, + List resources, + String primaryContact, + String modelCitation, + String modelLicense) + { + this.intendedUse = intendedUse; + this.intendedUsers = intendedUsers; + this.outOfScopeUses = outOfScopeUses; + this.preProcessingSteps = preProcessingSteps; + this.considerations = considerations; + this.factors = factors; + this.resources = resources; + this.primaryContact = primaryContact; + this.modelCitation = modelCitation; + this.modelLicense = modelLicense; + } + + public UsageDetails(JsonNode usageDetailsJson) { + intendedUse = usageDetailsJson.get("intended-use").textValue(); + intendedUsers = usageDetailsJson.get("intended-users").textValue(); + + outOfScopeUses = new ArrayList<>(); + for (int i = 0; i < usageDetailsJson.get("out-of-scope-uses").size(); i++) { + outOfScopeUses.add(usageDetailsJson.get("out-of-scope-uses").get(i).textValue()); + } + preProcessingSteps = new ArrayList<>(); + for (int i = 0; i < usageDetailsJson.get("pre-processing-steps").size(); i++) { + 
preProcessingSteps.add(usageDetailsJson.get("pre-processing-steps").get(i).textValue()); + } + considerations = new ArrayList<>(); + for (int i = 0; i < usageDetailsJson.get("considerations-list").size(); i++) { + considerations.add(usageDetailsJson.get("considerations-list").get(i).textValue()); + } + factors = new ArrayList<>(); + for (int i = 0; i < usageDetailsJson.get("relevant-factors-list").size(); i++) { + factors.add(usageDetailsJson.get("relevant-factors-list").get(i).textValue()); + } + resources = new ArrayList<>(); + for (int i = 0; i < usageDetailsJson.get("resources-list").size(); i++) { + resources.add(usageDetailsJson.get("resources-list").get(i).textValue()); + } + primaryContact = usageDetailsJson.get("primary-contact").textValue(); + modelCitation = usageDetailsJson.get("model-citation").textValue(); + modelLicense = usageDetailsJson.get("model-license").textValue(); + } + + public String getSchemaVersion() { + return schemaVersion; + } + + public String getIntendedUse() { + return intendedUse; + } + + public String getIntendedUsers() { + return intendedUsers; + } + + public List getOutOfScopeUses() { + return Collections.unmodifiableList(outOfScopeUses); + } + + public List getPreProcessingSteps() { + return Collections.unmodifiableList(preProcessingSteps); + } + + public List getConsiderations() { + return Collections.unmodifiableList(considerations); + } + + public List getFactors() { + return Collections.unmodifiableList(factors); + } + + public List getResources() { + return Collections.unmodifiableList(resources); + } + + public String getPrimaryContact() { + return primaryContact; + } + + public String getModelCitation() { + return modelCitation; + } + + public String getModelLicense() { + return modelLicense; + } + + public ObjectNode toJson() { + ObjectNode usageDetailsObject = mapper.createObjectNode(); + usageDetailsObject.put("schema-version", schemaVersion); + usageDetailsObject.put("intended-use", intendedUse); + 
usageDetailsObject.put("intended-users", intendedUsers); + + ArrayNode usesArr = mapper.createArrayNode(); + for (String s : outOfScopeUses) { + usesArr.add(s); + } + usageDetailsObject.set("out-of-scope-uses", usesArr); + + ArrayNode processingArr = mapper.createArrayNode(); + for (String s : preProcessingSteps) { + processingArr.add(s); + } + usageDetailsObject.set("pre-processing-steps", processingArr); + + ArrayNode considerationsArr = mapper.createArrayNode(); + for (String s : considerations) { + considerationsArr.add(s); + } + usageDetailsObject.set("considerations-list", considerationsArr); + + ArrayNode factorsArr = mapper.createArrayNode(); + for (String s : factors) { + factorsArr.add(s); + } + usageDetailsObject.set("relevant-factors-list", factorsArr); + + ArrayNode resourcesArr = mapper.createArrayNode(); + for (String s : resources) { + resourcesArr.add(s); + } + usageDetailsObject.set("resources-list", resourcesArr); + + usageDetailsObject.put("primary-contact", primaryContact); + usageDetailsObject.put("model-citation", modelCitation); + usageDetailsObject.put("model-license", modelLicense); + + return usageDetailsObject; + } + + @Override + public String toString() { + return toJson().toPrettyString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + UsageDetails that = (UsageDetails) o; + return intendedUse.equals(that.intendedUse) && + intendedUsers.equals(that.intendedUsers) && + outOfScopeUses.equals(that.outOfScopeUses) && + preProcessingSteps.equals(that.preProcessingSteps) && + considerations.equals(that.considerations) && + factors.equals(that.factors) && + resources.equals(that.resources) && + primaryContact.equals(that.primaryContact) && + modelCitation.equals(that.modelCitation) && + modelLicense.equals(that.modelLicense); + } + + @Override + public int hashCode() { + return Objects.hash(intendedUse, intendedUsers, 
outOfScopeUses, preProcessingSteps, considerations, factors, resources, primaryContact, modelCitation, modelLicense); + } +} \ No newline at end of file diff --git a/Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/UsageDetailsBuilder.java b/Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/UsageDetailsBuilder.java new file mode 100644 index 000000000..560daac59 --- /dev/null +++ b/Interop/ModelCard/src/main/java/org/tribuo/interop/modelcard/UsageDetailsBuilder.java @@ -0,0 +1,163 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.tribuo.interop.modelcard; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects;

/**
 * A mutable builder for {@link UsageDetails}.
 * <p>
 * String fields start as the empty string and list fields start empty. The string setters
 * replace the current value, while the list setters append to the current contents.
 */
public final class UsageDetailsBuilder {
    private String intendedUse = "";
    private String intendedUsers = "";
    private final List<String> outOfScopeUses = new ArrayList<>();
    private final List<String> preProcessingSteps = new ArrayList<>();
    private final List<String> considerations = new ArrayList<>();
    private final List<String> factors = new ArrayList<>();
    private final List<String> resources = new ArrayList<>();
    private String primaryContact = "";
    private String modelCitation = "";
    private String modelLicense = "";

    /**
     * Creates a builder with empty strings and empty lists.
     */
    public UsageDetailsBuilder() { }

    /**
     * @return the intended use currently set
     */
    public String getIntendedUse() {
        return intendedUse;
    }

    /**
     * @return the intended users currently set
     */
    public String getIntendedUsers() {
        return intendedUsers;
    }

    /**
     * @return an unmodifiable view of the out-of-scope uses added so far
     */
    public List<String> getOutOfScopeUses() {
        return Collections.unmodifiableList(outOfScopeUses);
    }

    /**
     * @return an unmodifiable view of the pre-processing steps added so far
     */
    public List<String> getPreProcessingSteps() {
        return Collections.unmodifiableList(preProcessingSteps);
    }

    /**
     * @return an unmodifiable view of the considerations added so far
     */
    public List<String> getConsiderations() {
        return Collections.unmodifiableList(considerations);
    }

    /**
     * @return an unmodifiable view of the relevant factors added so far
     */
    public List<String> getFactors() {
        return Collections.unmodifiableList(factors);
    }

    /**
     * @return an unmodifiable view of the resources added so far
     */
    public List<String> getResources() {
        return Collections.unmodifiableList(resources);
    }

    /**
     * @return the primary contact currently set
     */
    public String getPrimaryContact() {
        return primaryContact;
    }

    /**
     * @return the model citation currently set
     */
    public String getModelCitation() {
        return modelCitation;
    }

    /**
     * @return the model license currently set
     */
    public String getModelLicense() {
        return modelLicense;
    }

    /**
     * Sets the intended use, replacing any previous value.
     *
     * @param intendedUse the intended use
     * @return this builder
     */
    public UsageDetailsBuilder intendedUse(String intendedUse) {
        this.intendedUse = intendedUse;
        return this;
    }

    /**
     * Sets the intended users, replacing any previous value.
     *
     * @param intendedUsers the intended users
     * @return this builder
     */
    public UsageDetailsBuilder intendedUsers(String intendedUsers) {
        this.intendedUsers = intendedUsers;
        return this;
    }

    /**
     * Appends the given out-of-scope uses to those already added.
     *
     * @param uses the uses to append
     * @return this builder
     */
    public UsageDetailsBuilder outOfScopeUses(List<String> uses) {
        this.outOfScopeUses.addAll(uses);
        return this;
    }

    /**
     * Appends the given pre-processing steps to those already added.
     *
     * @param steps the steps to append
     * @return this builder
     */
    public UsageDetailsBuilder preProcessingSteps(List<String> steps) {
        this.preProcessingSteps.addAll(steps);
        return this;
    }

    /**
     * Appends the given considerations to those already added.
     *
     * @param considerations the considerations to append
     * @return this builder
     */
    public UsageDetailsBuilder considerations(List<String> considerations) {
        this.considerations.addAll(considerations);
        return this;
    }

    /**
     * Appends the given relevant factors to those already added.
     *
     * @param factors the factors to append
     * @return this builder
     */
    public UsageDetailsBuilder factors(List<String> factors) {
        this.factors.addAll(factors);
        return this;
    }

    /**
     * Appends the given resources to those already added.
     *
     * @param resources the resources to append
     * @return this builder
     */
    public UsageDetailsBuilder resources(List<String> resources) {
        this.resources.addAll(resources);
        return this;
    }

    /**
     * Sets the primary contact, replacing any previous value.
     *
     * @param primaryContact the primary contact
     * @return this builder
     */
    public UsageDetailsBuilder primaryContact(String primaryContact) {
        this.primaryContact = primaryContact;
        return this;
    }

    /**
     * Sets the model citation, replacing any previous value.
     *
     * @param modelCitation the model citation
     * @return this builder
     */
    public UsageDetailsBuilder modelCitation(String modelCitation) {
        this.modelCitation = modelCitation;
        return this;
    }

    /**
     * Sets the model license, replacing any previous value.
     *
     * @param modelLicense the model license
     * @return this builder
     */
    public UsageDetailsBuilder modelLicense(String modelLicense) {
        this.modelLicense = modelLicense;
        return this;
    }

    /**
     * Builds a {@link UsageDetails} from the current state of this builder.
     * <p>
     * Snapshots of the lists are handed over, so the builder can keep being modified and reused
     * without the lists it passed to an earlier {@code build()} call changing.
     *
     * @return a new {@link UsageDetails}
     */
    public UsageDetails build() {
        return new UsageDetails(
            intendedUse,
            intendedUsers,
            new ArrayList<>(outOfScopeUses),
            new ArrayList<>(preProcessingSteps),
            new ArrayList<>(considerations),
            new ArrayList<>(factors),
            new ArrayList<>(resources),
            primaryContact,
            modelCitation,
            modelLicense
        );
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        UsageDetailsBuilder that = (UsageDetailsBuilder) o;
        return intendedUse.equals(that.intendedUse) &&
                intendedUsers.equals(that.intendedUsers) &&
                outOfScopeUses.equals(that.outOfScopeUses) &&
                preProcessingSteps.equals(that.preProcessingSteps) &&
                considerations.equals(that.considerations) &&
                factors.equals(that.factors) &&
                resources.equals(that.resources) &&
                primaryContact.equals(that.primaryContact) &&
                modelCitation.equals(that.modelCitation) &&
                modelLicense.equals(that.modelLicense);
    }

    @Override
    public int hashCode() {
        return Objects.hash(intendedUse, intendedUsers, outOfScopeUses, preProcessingSteps, considerations, factors, resources, primaryContact, modelCitation, modelLicense);
    }
}
diff --git a/Interop/ModelCard/src/test/java/org/tribuo/interop/modelcard/EnsembleModelsTest.java 
b/Interop/ModelCard/src/test/java/org/tribuo/interop/modelcard/EnsembleModelsTest.java new file mode 100644 index 000000000..afbf66298 --- /dev/null +++ b/Interop/ModelCard/src/test/java/org/tribuo/interop/modelcard/EnsembleModelsTest.java @@ -0,0 +1,288 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tribuo.interop.modelcard; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.tribuo.DataSource; +import org.tribuo.MutableDataset; +import org.tribuo.classification.Label; +import org.tribuo.classification.ensemble.FullyWeightedVotingCombiner; +import org.tribuo.classification.ensemble.VotingCombiner; +import org.tribuo.classification.evaluation.LabelEvaluator; +import org.tribuo.classification.example.NoisyInterlockingCrescentsDataSource; +import org.tribuo.classification.libsvm.LibSVMClassificationModel; +import org.tribuo.classification.libsvm.LibSVMClassificationTrainer; +import org.tribuo.classification.libsvm.SVMClassificationType; +import org.tribuo.classification.sgd.fm.FMClassificationTrainer; +import org.tribuo.classification.sgd.linear.LogisticRegressionTrainer; +import org.tribuo.classification.sgd.objectives.LogMulticlass; +import org.tribuo.common.libsvm.KernelType; +import org.tribuo.common.libsvm.SVMParameters; +import org.tribuo.common.sgd.AbstractFMModel; +import 
org.tribuo.common.sgd.AbstractFMTrainer; +import org.tribuo.common.sgd.AbstractSGDTrainer; +import org.tribuo.ensemble.BaggingTrainer; +import org.tribuo.ensemble.EnsembleModel; +import org.tribuo.ensemble.WeightedEnsembleModel; +import org.tribuo.math.optimisers.AdaGrad; +import org.tribuo.multilabel.MultiLabel; +import org.tribuo.multilabel.ensemble.MultiLabelVotingCombiner; +import org.tribuo.multilabel.evaluation.MultiLabelEvaluator; +import org.tribuo.multilabel.example.MultiLabelGaussianDataSource; +import org.tribuo.multilabel.sgd.fm.FMMultiLabelTrainer; +import org.tribuo.multilabel.sgd.objectives.BinaryCrossEntropy; +import org.tribuo.regression.Regressor; +import org.tribuo.regression.ensemble.AveragingCombiner; +import org.tribuo.regression.evaluation.RegressionEvaluator; +import org.tribuo.regression.example.NonlinearGaussianDataSource; +import org.tribuo.regression.sgd.fm.FMRegressionTrainer; +import org.tribuo.regression.sgd.objectives.SquaredLoss; + +import java.io.File; +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.logging.Level; +import java.util.logging.Logger; + +public class EnsembleModelsTest { + // Classification combiners + private static final VotingCombiner VOTING = new VotingCombiner(); + private static final FullyWeightedVotingCombiner FULL_VOTING = new FullyWeightedVotingCombiner(); + // Regression combiners + private static final AveragingCombiner AVERAGING = new AveragingCombiner(); + // Multi-label combiners + private static final MultiLabelVotingCombiner ML_VOTING = new MultiLabelVotingCombiner(); + + @BeforeAll + public static void setup() { + Class[] classes = new Class[]{ + BaggingTrainer.class, + AbstractSGDTrainer.class, + org.tribuo.classification.sgd.linear.LinearSGDTrainer.class, + AbstractFMTrainer.class, + FMClassificationTrainer.class + }; + for (Class c : classes) { + Logger logger = Logger.getLogger(c.getName()); + 
logger.setLevel(Level.WARNING); + } + } + + @Test + public void testHomogenousClassificationModelCard() throws IOException { + DataSource