Skip to content

Commit

Permalink
To support Yolov8 (#2776)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: KexinFeng <fenkexin@amazon.com>
  • Loading branch information
2 people authored and frankfliu committed Apr 26, 2024
1 parent 2eb77e5 commit 8e38c83
Show file tree
Hide file tree
Showing 7 changed files with 428 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ protected double overlap(double x1, double w1, double x2, double w2) {
return right - left;
}

private DetectedObjects processFromBoxOutput(NDList list) {
protected DetectedObjects processFromBoxOutput(NDList list) {
float[] flattened = list.get(0).toFloatArray();
ArrayList<IntermediateResult> intermediateResults = new ArrayList<>();
int sizeClasses = classes.size();
Expand Down Expand Up @@ -280,7 +280,7 @@ public YoloV5Translator build() {
}
}

private static final class IntermediateResult {
protected static final class IntermediateResult {

/**
* A sortable score for how good the recognition is relative to others. Higher should be
Expand Down
103 changes: 103 additions & 0 deletions api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.cv.translator;

import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;

import java.util.ArrayList;
import java.util.Map;

/**
* A translator for YoloV8 models. This was tested with ONNX exported Yolo models. For details check
* here: https://github.com/ultralytics/ultralytics
*/
public class YoloV8Translator extends YoloV5Translator {

/**
* Constructs an ImageTranslator with the provided builder.
*
* @param builder the data to build with
*/
protected YoloV8Translator(Builder builder) {
super(builder);
}

/**
* Creates a builder to build a {@code YoloV8Translator} with specified arguments.
*
* @param arguments arguments to specify builder options
* @return a new builder
*/
public static YoloV8Translator.Builder builder(Map<String, ?> arguments) {
YoloV8Translator.Builder builder = new YoloV8Translator.Builder();
builder.configPreProcess(arguments);
builder.configPostProcess(arguments);

return builder;
}

@Override
protected DetectedObjects processFromBoxOutput(NDList list) {
NDArray features4OneImg = list.get(0);
int sizeClasses = classes.size();
long sizeBoxes = features4OneImg.size(1);
ArrayList<IntermediateResult> intermediateResults = new ArrayList<>();

for (long b = 0; b < sizeBoxes; b++) {
float maxClass = 0;
int maxIndex = 0;
for (int c = 4; c < sizeClasses; c++) {
float classProb = features4OneImg.getFloat(c, b);
if (classProb > maxClass) {
maxClass = classProb;
maxIndex = c;
}
}

if (maxClass > threshold) {
float xPos = features4OneImg.getFloat(0, b); // center x
float yPos = features4OneImg.getFloat(1, b); // center y
float w = features4OneImg.getFloat(2, b);
float h = features4OneImg.getFloat(3, b);
Rectangle rect =
new Rectangle(Math.max(0, xPos - w / 2), Math.max(0, yPos - h / 2), w, h);
intermediateResults.add(
new IntermediateResult(classes.get(maxIndex), maxClass, maxIndex, rect));
}
}

return nms(intermediateResults);
}

/** The builder for {@link YoloV8Translator}. */
public static class Builder extends YoloV5Translator.Builder {
/**
* Builds the translator.
*
* @return the new translator
*/
@Override
public YoloV8Translator build() {
if (pipeline == null) {
addTransform(
array -> array.transpose(2, 0, 1).toType(DataType.FLOAT32, false).div(255));
}
validate();
return new YoloV8Translator(this);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.cv.translator;

import ai.djl.Model;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.translate.Translator;

import java.io.Serializable;
import java.util.Map;

/** A translatorFactory that creates a {@link YoloV8Translator} instance. */
public class YoloV8TranslatorFactory extends ObjectDetectionTranslatorFactory
implements Serializable {

private static final long serialVersionUID = 1L;

/** {@inheritDoc} */
@Override
protected Translator<Image, DetectedObjects> buildBaseTranslator(
Model model, Map<String, ?> arguments) {
return YoloV8Translator.builder(arguments).build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.cv.translator;

import ai.djl.Model;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.translate.BasicTranslator;
import ai.djl.translate.Translator;

import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.io.InputStream;
import java.net.URL;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;

public class YoloV8TranslatorFactoryTest {

private YoloV8TranslatorFactory factory;

@BeforeClass
public void setUp() {
factory = new YoloV8TranslatorFactory();
}

@Test
public void testGetSupportedTypes() {
Assert.assertEquals(factory.getSupportedTypes().size(), 5);
}

@Test
public void testNewInstance() {
Map<String, String> arguments = new HashMap<>();
try (Model model = Model.newInstance("test")) {
Translator<Image, DetectedObjects> translator1 =
factory.newInstance(Image.class, DetectedObjects.class, model, arguments);
Assert.assertTrue(translator1 instanceof YoloV8Translator);

Translator<Path, DetectedObjects> translator2 =
factory.newInstance(Path.class, DetectedObjects.class, model, arguments);
Assert.assertTrue(translator2 instanceof BasicTranslator);

Translator<URL, DetectedObjects> translator3 =
factory.newInstance(URL.class, DetectedObjects.class, model, arguments);
Assert.assertTrue(translator3 instanceof BasicTranslator);

Translator<InputStream, DetectedObjects> translator4 =
factory.newInstance(InputStream.class, DetectedObjects.class, model, arguments);
Assert.assertTrue(translator4 instanceof BasicTranslator);

Translator<Input, Output> translator5 =
factory.newInstance(Input.class, Output.class, model, arguments);
Assert.assertTrue(translator5 instanceof ImageServingTranslator);

Assert.assertThrows(
IllegalArgumentException.class,
() -> factory.newInstance(Image.class, Output.class, model, arguments));
}
}
}
132 changes: 132 additions & 0 deletions examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.examples.inference;

import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.DetectedObjects.DetectedObject;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.YoloV8TranslatorFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/** An example of inference using an yolov8 model. */
public final class Yolov8Detection {

private static final Logger logger = LoggerFactory.getLogger(Yolov8Detection.class);

private Yolov8Detection() {}

public static void main(String[] args) throws IOException, ModelException, TranslateException {
DetectedObjects detection = Yolov8Detection.predict();
logger.info("{}", detection);
}

public static DetectedObjects predict() throws IOException, ModelException, TranslateException {
String classPath = System.getProperty("java.class.path");
String pathSeparator = System.getProperty("path.separator");
classPath = classPath.split(pathSeparator)[0];
Path modelPath = Paths.get(classPath + "/yolov8n.onnx");
Path imgPath = Paths.get(classPath + "/yolov8_test.jpg");
Image img = ImageFactory.getInstance().fromFile(imgPath);

Map<String, Object> arguments = new ConcurrentHashMap<>();
arguments.put("width", 640);
arguments.put("height", 640);
arguments.put("resize", "true");
arguments.put("toTensor", true);
arguments.put("applyRatio", true);
arguments.put("threshold", 0.6f);
arguments.put("synsetFileName", "yolov8_synset.txt");

YoloV8TranslatorFactory yoloV8TranslatorFactory = new YoloV8TranslatorFactory();
Translator<Image, DetectedObjects> translator =
yoloV8TranslatorFactory.newInstance(
Image.class, DetectedObjects.class, null, arguments);

Criteria<Image, DetectedObjects> criteria =
Criteria.builder()
.setTypes(Image.class, DetectedObjects.class)
.optModelPath(modelPath)
.optEngine("OnnxRuntime")
.optTranslator(translator)
.optProgress(new ProgressBar())
.build();

DetectedObjects detectedObjects;
DetectedObject detectedObject;
try (ZooModel<Image, DetectedObjects> model = criteria.loadModel();
Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
Path outputPath = Paths.get(classPath + "/output");
Files.createDirectories(outputPath);

detectedObjects = predictor.predict(img);

if (detectedObjects.getNumberOfObjects() > 0) {
List<DetectedObject> detectedObjectList = detectedObjects.items();
for (DetectedObject object : detectedObjectList) {
detectedObject = object;
BoundingBox boundingBox = detectedObject.getBoundingBox();
Rectangle tectangle = boundingBox.getBounds();
logger.info(
detectedObject.getClassName()
+ " "
+ detectedObject.getProbability()
+ " "
+ tectangle.getX()
+ " "
+ tectangle.getY()
+ " "
+ tectangle.getWidth()
+ " "
+ tectangle.getHeight());
}

saveBoundingBoxImage(
img.resize(640, 640, false),
detectedObjects,
outputPath,
imgPath.toFile().getName());
}

return detectedObjects;
}
}

private static void saveBoundingBoxImage(
Image img, DetectedObjects detectedObjects, Path outputPath, String outputFileName)
throws IOException {
img.drawBoundingBoxes(detectedObjects);

Path imagePath = outputPath.resolve(outputFileName);
img.save(Files.newOutputStream(imagePath), "png");
}
}
Loading

0 comments on commit 8e38c83

Please sign in to comment.