Skip to content

Commit

Permalink
[api] Standardizes CV output format (#3493)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Oct 4, 2024
1 parent 3059991 commit f88f2d6
Show file tree
Hide file tree
Showing 12 changed files with 255 additions and 113 deletions.
54 changes: 13 additions & 41 deletions api/src/main/java/ai/djl/modality/Classifications.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,8 @@
import ai.djl.util.JsonSerializable;
import ai.djl.util.JsonUtils;

import com.google.gson.Gson;
import com.google.gson.JsonElement;
import com.google.gson.JsonSerializationContext;
import com.google.gson.JsonSerializer;

import java.lang.reflect.Type;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
Expand All @@ -42,11 +36,6 @@ public class Classifications implements JsonSerializable, Ensembleable<Classific

private static final long serialVersionUID = 1L;

private static final Gson GSON =
JsonUtils.builder()
.registerTypeAdapter(Classifications.class, new ClassificationsSerializer())
.create();

@SuppressWarnings("serial")
protected List<String> classNames;

Expand Down Expand Up @@ -210,31 +199,25 @@ public <T extends Classification> T get(String className) {

/** {@inheritDoc} */
@Override
public String toJson() {
return GSON.toJson(this) + '\n';
}

/** {@inheritDoc} */
@Override
public String getAsString() {
return toJson();
}

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
return ByteBuffer.wrap(toJson().getBytes(StandardCharsets.UTF_8));
public JsonElement serialize() {
return JsonUtils.GSON.toJsonTree(topK());
}

/** {@inheritDoc} */
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append('[').append(System.lineSeparator());
for (Classification item : topK(topK)) {
sb.append('\t').append(item).append(System.lineSeparator());
sb.append("[\n");
List<Classification> list = topK();
int index = 0;
for (Classification item : list) {
sb.append('\t').append(item);
if (++index < list.size()) {
sb.append(',');
}
sb.append('\n');
}
sb.append(']');
sb.append("]\n");
return sb.toString();
}

Expand Down Expand Up @@ -306,7 +289,7 @@ public double getProbability() {
@Override
public String toString() {
StringBuilder sb = new StringBuilder(100);
sb.append("{\"class\": \"").append(className).append("\", \"probability\": ");
sb.append("{\"className\": \"").append(className).append("\", \"probability\": ");
if (probability < 0.00001) {
sb.append(String.format("%.1e", probability));
} else {
Expand All @@ -317,15 +300,4 @@ public String toString() {
return sb.toString();
}
}

/** A customized Gson serializer to serialize the {@code Classifications} object. */
public static final class ClassificationsSerializer implements JsonSerializer<Classifications> {

/** {@inheritDoc} */
@Override
public JsonElement serialize(Classifications src, Type type, JsonSerializationContext ctx) {
List<?> list = src.topK();
return ctx.serialize(list);
}
}
}
41 changes: 15 additions & 26 deletions api/src/main/java/ai/djl/modality/cv/output/CategoryMask.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,11 @@
import ai.djl.util.JsonUtils;
import ai.djl.util.RandomUtils;

import com.google.gson.Gson;
import com.google.gson.JsonElement;
import com.google.gson.JsonSerializationContext;
import com.google.gson.JsonSerializer;
import com.google.gson.JsonObject;

import java.lang.reflect.Type;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.stream.Collectors;

/**
* A class representing the segmentation result of an image in an {@link
Expand All @@ -38,12 +34,7 @@ public class CategoryMask implements JsonSerializable {

private static final int COLOR_BLACK = 0xFF000000;

private static final Gson GSON =
JsonUtils.builder()
.registerTypeAdapter(CategoryMask.class, new SegmentationSerializer())
.create();

private transient List<String> classes;
private List<String> classes;
private int[][] mask;

/**
Expand Down Expand Up @@ -77,14 +68,22 @@ public int[][] getMask() {

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
return ByteBuffer.wrap(toJson().getBytes(StandardCharsets.UTF_8));
public JsonElement serialize() {
JsonObject ret = new JsonObject();
ret.add("classes", JsonUtils.GSON.toJsonTree(classes));
ret.add("mask", JsonUtils.GSON.toJsonTree(mask));
return ret;
}

/** {@inheritDoc} */
@Override
public String toJson() {
return GSON.toJson(this) + '\n';
public String toString() {
StringBuilder sb = new StringBuilder(4096);
String list = classes.stream().map(s -> '"' + s + '"').collect(Collectors.joining(", "));
sb.append("{\n\t\"classes\": [").append(list).append("],\n\t\"mask\": ");
sb.append(JsonUtils.GSON_COMPACT.toJson(mask));
sb.append("\n}");
return sb.toString();
}

/**
Expand Down Expand Up @@ -195,14 +194,4 @@ private int[] generateColors(int background, int opacity) {
}
return colors;
}

/** A customized Gson serializer to serialize the {@code Segmentation} object. */
public static final class SegmentationSerializer implements JsonSerializer<CategoryMask> {

/** {@inheritDoc} */
@Override
public JsonElement serialize(CategoryMask src, Type type, JsonSerializationContext ctx) {
return ctx.serialize(src.getMask());
}
}
}
20 changes: 3 additions & 17 deletions api/src/main/java/ai/djl/modality/cv/output/DetectedObjects.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
package ai.djl.modality.cv.output;

import ai.djl.modality.Classifications;
import ai.djl.util.JsonUtils;

import com.google.gson.Gson;

import java.util.List;

Expand All @@ -27,11 +24,6 @@ public class DetectedObjects extends Classifications {

private static final long serialVersionUID = 1L;

private static final Gson GSON =
JsonUtils.builder()
.registerTypeAdapter(DetectedObjects.class, new ClassificationsSerializer())
.create();

@SuppressWarnings("serial")
private List<BoundingBox> boundingBoxes;

Expand Down Expand Up @@ -69,12 +61,6 @@ public int getNumberOfObjects() {
return boundingBoxes.size();
}

/** {@inheritDoc} */
@Override
public String toJson() {
return GSON.toJson(this) + '\n';
}

/** A {@code DetectedObject} represents a single potential detected Object for an image. */
public static final class DetectedObject extends Classification {

Expand Down Expand Up @@ -106,15 +92,15 @@ public BoundingBox getBoundingBox() {
public String toString() {
double probability = getProbability();
StringBuilder sb = new StringBuilder(200);
sb.append("{\"class\": \"").append(getClassName()).append("\", \"probability\": ");
sb.append("{\"className\": \"").append(getClassName()).append("\", \"probability\": ");
if (probability < 0.00001) {
sb.append(String.format("%.1e", probability));
} else {
probability = (int) (probability * 100000) / 100000f;
sb.append(String.format("%.5f", probability));
}
if (getBoundingBox() != null) {
sb.append(", \"bounds\": ").append(getBoundingBox());
if (boundingBox != null) {
sb.append(", \"boundingBox\": ").append(boundingBox);
}
sb.append('}');
return sb.toString();
Expand Down
26 changes: 5 additions & 21 deletions api/src/main/java/ai/djl/modality/cv/output/Joints.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
*/
package ai.djl.modality.cv.output;

import ai.djl.util.JsonUtils;

import java.io.Serializable;
import java.util.List;

Expand Down Expand Up @@ -48,19 +50,7 @@ public List<Joint> getJoints() {
/** {@inheritDoc} */
@Override
public String toString() {
StringBuilder sb = new StringBuilder(4000);
sb.append("\n[\n\t");
boolean first = true;
for (Joint joint : joints) {
if (first) {
first = false;
} else {
sb.append(",\n\t");
}
sb.append(joint);
}
sb.append("\n]");
return sb.toString();
return JsonUtils.GSON_PRETTY.toJson(this) + "\n";
}

/**
Expand All @@ -69,7 +59,9 @@ public String toString() {
* @see Joints
*/
public static class Joint extends Point {

private static final long serialVersionUID = 1L;

private double confidence;

/**
Expand All @@ -92,13 +84,5 @@ public Joint(double x, double y, double confidence) {
public double getConfidence() {
return confidence;
}

/** {@inheritDoc} */
@Override
public String toString() {
return String.format(
"{\"Joint\": {\"x\"=%.3f, \"y\"=%.3f}, \"confidence\": %.4f}",
getX(), getY(), getConfidence());
}
}
}
12 changes: 12 additions & 0 deletions api/src/main/java/ai/djl/modality/cv/output/Landmark.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
*/
package ai.djl.modality.cv.output;

import ai.djl.util.JsonUtils;

import com.google.gson.JsonObject;

import java.util.List;

/** {@code Landmark} is the container that stores the key points for landmark on a single face. */
Expand Down Expand Up @@ -41,4 +45,12 @@ public Landmark(double x, double y, double width, double height, List<Point> poi
public Iterable<Point> getPath() {
return points;
}

/** {@inheritDoc} */
@Override
public JsonObject serialize() {
JsonObject ret = super.serialize();
ret.add("landmarks", JsonUtils.GSON.toJsonTree(points));
return ret;
}
}
15 changes: 15 additions & 0 deletions api/src/main/java/ai/djl/modality/cv/output/Mask.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.JsonUtils;

import com.google.gson.JsonObject;
import com.google.gson.JsonPrimitive;

/**
* A mask with a probability for each pixel within a bounding rectangle.
Expand Down Expand Up @@ -79,6 +83,17 @@ public boolean isFullImageMask() {
return fullImageMask;
}

/** {@inheritDoc} */
@Override
public JsonObject serialize() {
JsonObject ret = super.serialize();
if (fullImageMask) {
ret.add("fullImageMask", new JsonPrimitive(true));
}
ret.add("mask", JsonUtils.GSON.toJsonTree(probDist));
return ret;
}

/**
* Converts the mask tensor to a mask array.
*
Expand Down
9 changes: 9 additions & 0 deletions api/src/main/java/ai/djl/modality/cv/output/Point.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
*/
package ai.djl.modality.cv.output;

import ai.djl.util.JsonUtils;

import java.io.Serializable;

/**
Expand All @@ -20,6 +22,7 @@
public class Point implements Serializable {

private static final long serialVersionUID = 1L;

private double x;
private double y;

Expand Down Expand Up @@ -52,4 +55,10 @@ public double getX() {
public double getY() {
return y;
}

/** {@inheritDoc} */
@Override
public String toString() {
return JsonUtils.GSON_COMPACT.toJson(this);
}
}
Loading

0 comments on commit f88f2d6

Please sign in to comment.