Commit
Add ModelCard infrastructure (including implementation and testing) (#243)

* Add ModelCard infrastructure (including implementation and testing)

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Fix dependencies and language level of ModelCard

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Fix dependency issues

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Add license header to all new files

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Fix formatting issues

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Fix dependency issues

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Make all classes constituting ModelCard final

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Replace assert statements with assertEquals

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Use Path object instead of String and use temp files when testing

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Have list items display to user 0-indexed

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Make schema version a static final string

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Use self-generated testing data files

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Refactor constructor of TrainingDetails

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Switch field type of configuredParams from Map to JsonNode

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Refactor code for checking whether a value is numeric and use entrySet

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Implement equals and hashCode in all ModelCard classes

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Implement deserializeFromJson

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Make mapper package private

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Accept File as a type in UsageDetails

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Add more detail to command explanation in UsageDetails

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Return an unmodifiable view of collections in getters

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Fix styling issues

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Flatten configuredParams map

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Separate UsageDetails and ModelCardCLI and create UsageDetailsBuilder

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Remove addMetric from TestingDetails to make the class immutable

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Enable deserialization of model cards that do not have UsageDetails specified

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Remove redundant code from constructors and add license header to all files

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Include UsageDetails in json schema for all possible scenarios

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Allow user to use the same shell to write the UsageDetails for different model cards

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Remove processNestedParams method to avoid redundancy

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>

* Return an unmodifiable view of a map

Signed-off-by: Romina Mahinpei <mahinpei@student.ubc.ca>
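
Several items above mention returning unmodifiable views of collections from getters. The classes that do this are not part of the diff shown here, so the following is only a minimal sketch of that idea under assumed names (`MetricsSketch` and `getMetrics` are hypothetical, not taken from the commit):

```java
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical stand-in for a details class that holds a map of metrics.
final class MetricsSketch {
    private final Map<String, Double> metrics;

    MetricsSketch(Map<String, Double> metrics) {
        // Defensive copy: later changes to the caller's map do not leak in.
        this.metrics = new LinkedHashMap<>(metrics);
    }

    // Returns a read-only view; mutating calls throw UnsupportedOperationException.
    Map<String, Double> getMetrics() {
        return Collections.unmodifiableMap(metrics);
    }
}
```

With this shape, a caller can read `getMetrics()` freely, but a call such as `getMetrics().put(...)` fails instead of silently changing the object's state.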
rmahinpei authored Jul 12, 2022
1 parent 252d5b8 commit 1a4f6e8
Showing 17 changed files with 1,909 additions and 0 deletions.
122 changes: 122 additions & 0 deletions Interop/ModelCard/pom.xml
@@ -0,0 +1,122 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <artifactId>tribuo-interop</artifactId>
        <groupId>org.tribuo</groupId>
        <version>4.3.0-SNAPSHOT</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>
    <name>ModelCard</name>
    <artifactId>tribuo-modelcard</artifactId>

    <properties>
        <maven.compiler.source>17</maven.compiler.source>
        <maven.compiler.target>17</maven.compiler.target>
        <maven.compiler.release>17</maven.compiler.release>
    </properties>

    <dependencies>
        <!-- runtime dependencies -->
        <dependency>
            <groupId>org.tribuo</groupId>
            <artifactId>tribuo-core</artifactId>
            <version>${project.version}</version>
        </dependency>
        <dependency>
            <groupId>org.tribuo</groupId>
            <artifactId>tribuo-data</artifactId>
            <version>${project.version}</version>
        </dependency>
        <dependency>
            <groupId>org.tribuo</groupId>
            <artifactId>tribuo-interop-core</artifactId>
            <version>${project.version}</version>
        </dependency>
        <dependency>
            <groupId>org.tribuo</groupId>
            <artifactId>tribuo-regression-core</artifactId>
            <version>${project.version}</version>
        </dependency>
        <dependency>
            <groupId>org.tribuo</groupId>
            <artifactId>tribuo-json</artifactId>
            <version>${project.version}</version>
        </dependency>
        <!-- test time dependencies -->
        <dependency>
            <groupId>org.junit.jupiter</groupId>
            <artifactId>junit-jupiter</artifactId>
            <version>${junit.version}</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.tribuo</groupId>
            <artifactId>tribuo-classification-core</artifactId>
            <version>${project.version}</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.tribuo</groupId>
            <artifactId>tribuo-classification-libsvm</artifactId>
            <version>${project.version}</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.tribuo</groupId>
            <artifactId>tribuo-classification-sgd</artifactId>
            <version>${project.version}</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.tribuo</groupId>
            <artifactId>tribuo-common-sgd</artifactId>
            <version>${project.version}</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.tribuo</groupId>
            <artifactId>tribuo-multilabel-core</artifactId>
            <version>${project.version}</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.tribuo</groupId>
            <artifactId>tribuo-multilabel-sgd</artifactId>
            <version>${project.version}</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.tribuo</groupId>
            <artifactId>tribuo-regression-sgd</artifactId>
            <version>${project.version}</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.tribuo</groupId>
            <artifactId>tribuo-classification-xgboost</artifactId>
            <version>${project.version}</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.tribuo</groupId>
            <artifactId>tribuo-anomaly-libsvm</artifactId>
            <version>${project.version}</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.tribuo</groupId>
            <artifactId>tribuo-clustering-core</artifactId>
            <version>${project.version}</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.tribuo</groupId>
            <artifactId>tribuo-clustering-kmeans</artifactId>
            <version>${project.version}</version>
            <scope>test</scope>
        </dependency>
    </dependencies>

</project>
@@ -0,0 +1,138 @@
/*
 * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.interop.modelcard;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.databind.node.ObjectNode;
import org.tribuo.Model;
import org.tribuo.evaluation.Evaluation;
import org.tribuo.interop.ExternalModel;

import java.io.IOException;
import java.nio.file.Path;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;

/**
 * A model card bundling model, training, testing and (optional) usage details,
 * with JSON serialization and deserialization.
 */
public final class ModelCard {
    // Package-private so the other ModelCard classes can share one mapper.
    static final ObjectMapper mapper = new ObjectMapper().enable(SerializationFeature.INDENT_OUTPUT);
    private final ModelDetails modelDetails;
    private final TrainingDetails trainingDetails;
    private final TestingDetails testingDetails;
    // May be null when no usage details were supplied.
    private final UsageDetails usageDetails;

    public ModelCard(Model<?> model, Evaluation<?> evaluation, Map<String, Double> testingMetrics, UsageDetails usage) {
        if (model instanceof ExternalModel) {
            throw new IllegalArgumentException("External models currently not supported by ModelCard.");
        }
        modelDetails = new ModelDetails(model);
        trainingDetails = new TrainingDetails(model);
        testingDetails = new TestingDetails(evaluation, testingMetrics);
        usageDetails = usage;
    }

    public ModelCard(Model<?> model, Evaluation<?> evaluation, UsageDetails usage) {
        this(model, evaluation, Collections.emptyMap(), usage);
    }

    public ModelCard(Model<?> model, Evaluation<?> evaluation, Map<String, Double> testingMetrics) {
        this(model, evaluation, testingMetrics, null);
    }

    public ModelCard(Model<?> model, Evaluation<?> evaluation) {
        this(model, evaluation, Collections.emptyMap(), null);
    }

    private ModelCard(JsonNode modelCard) throws JsonProcessingException {
        modelDetails = new ModelDetails(modelCard.get("ModelDetails"));
        trainingDetails = new TrainingDetails(modelCard.get("TrainingDetails"));
        testingDetails = new TestingDetails(modelCard.get("TestingDetails"));
        // UsageDetails is optional: treat an absent field the same as an explicit JSON null.
        JsonNode usage = modelCard.get("UsageDetails");
        if (usage == null || usage.isNull()) {
            usageDetails = null;
        } else {
            usageDetails = new UsageDetails(usage);
        }
    }

    public static ModelCard deserializeFromJson(Path sourceFile) throws IOException {
        JsonNode modelCard = mapper.readTree(sourceFile.toFile());
        return new ModelCard(modelCard);
    }

    public static ModelCard deserializeFromJson(JsonNode modelCard) throws JsonProcessingException {
        return new ModelCard(modelCard);
    }

    public ModelDetails getModelDetails() {
        return modelDetails;
    }

    public TrainingDetails getTrainingDetails() {
        return trainingDetails;
    }

    public TestingDetails getTestingDetails() {
        return testingDetails;
    }

    public UsageDetails getUsageDetails() {
        return usageDetails;
    }

    public ObjectNode toJson() {
        ObjectNode modelCardObject = mapper.createObjectNode();
        modelCardObject.set("ModelDetails", modelDetails.toJson());
        modelCardObject.set("TrainingDetails", trainingDetails.toJson());
        modelCardObject.set("TestingDetails", testingDetails.toJson());
        // Always emit the UsageDetails key, writing JSON null when it is unset.
        if (usageDetails != null) {
            modelCardObject.set("UsageDetails", usageDetails.toJson());
        } else {
            modelCardObject.putNull("UsageDetails");
        }
        return modelCardObject;
    }

    public void saveToFile(Path destinationFile) throws IOException {
        ObjectNode modelCardObject = toJson();
        mapper.writeValue(destinationFile.toFile(), modelCardObject);
    }

    @Override
    public String toString() {
        return toJson().toPrettyString();
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        ModelCard modelCard = (ModelCard) o;
        // usageDetails may be null, so it is compared with Objects.equals.
        return modelDetails.equals(modelCard.modelDetails) &&
                trainingDetails.equals(modelCard.trainingDetails) &&
                testingDetails.equals(modelCard.testingDetails) &&
                Objects.equals(usageDetails, modelCard.usageDetails);
    }

    @Override
    public int hashCode() {
        return Objects.hash(modelDetails, trainingDetails, testingDetails, usageDetails);
    }
}
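
The `equals` and `hashCode` methods above mix required fields with one nullable field (`usageDetails`). As a rough illustration of that pattern in isolation, here is a minimal sketch using only the standard library; `CardSketch` and its `String` fields are hypothetical stand-ins, not types from this commit:

```java
import java.util.Objects;

// Hypothetical stand-in for a card with one required and one optional part.
final class CardSketch {
    private final String details; // required, never null
    private final String usage;   // optional, may be null

    CardSketch(String details, String usage) {
        this.details = Objects.requireNonNull(details, "details");
        this.usage = usage;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        CardSketch other = (CardSketch) o;
        // Objects.equals handles the case where either usage value is null.
        return details.equals(other.details) && Objects.equals(usage, other.usage);
    }

    @Override
    public int hashCode() {
        // Objects.hash accepts null elements, so the optional field is safe here.
        return Objects.hash(details, usage);
    }
}
```

Two `CardSketch` instances built with the same `details` and a null `usage` compare equal and share a hash code, which is the behavior a serialize-then-deserialize round trip of a card without usage details would rely on.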