From 89f254483b6e2979f9d77e46e61ebf5ed8d08630 Mon Sep 17 00:00:00 2001 From: gevant <39420213+gevant@users.noreply.github.com> Date: Sat, 30 Dec 2023 18:43:49 -0500 Subject: [PATCH 1/3] Yelov8 Translator optimization - Improved post-processing performance up to 40x by reducing expensive native calls - Additional argument 'maxBox' added to improve post-processing performance by reducing number of considered bounding boxes - Sunset file fixed,, previous version ignored first 4 rows, so recognized classes were 4 off. Adding 4 rows header fixes the problem. New headers pointing to ultralytics doc pages and original coco dataset page. --- .../cv/translator/YoloV8Translator.java | 55 ++++++++++++------- .../examples/inference/Yolov8Detection.java | 2 + examples/src/test/resources/yolov8_synset.txt | 4 ++ 3 files changed, 42 insertions(+), 19 deletions(-) diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java index dc160ba754b..037fc0a6d52 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java @@ -17,6 +17,8 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.translate.ArgumentsUtil; import java.util.ArrayList; import java.util.Map; @@ -27,6 +29,9 @@ */ public class YoloV8Translator extends YoloV5Translator { + + final int maxBoxes; + /** * Constructs an ImageTranslator with the provided builder. * @@ -34,6 +39,7 @@ public class YoloV8Translator extends YoloV5Translator { */ protected YoloV8Translator(Builder builder) { super(builder); + maxBoxes= builder.maxBox; } /** @@ -52,39 +58,44 @@ public static YoloV8Translator.Builder builder(Map arguments) { @Override protected DetectedObjects processFromBoxOutput(NDList list) { - NDArray features4OneImg = list.get(0); - int sizeClasses = classes.size(); - long sizeBoxes = features4OneImg.size(1); - ArrayList intermediateResults = new ArrayList<>(); + final NDArray rawResult = list.get(0); + final NDArray reshapedResult = rawResult.transpose(); + final Shape preparedResult = reshapedResult.getShape(); + final long numberRows = preparedResult.get(0); + final long sizeClasses = preparedResult.get(1); - for (long b = 0; b < sizeBoxes; b++) { - float maxClass = 0; - int maxIndex = 0; + final ArrayList intermediateResults = new ArrayList<>(); + // reverse order search in heap; searches through #maxBoxes for optimization when set + for (int i = (int) numberRows - 1; i > numberRows - maxBoxes; i--) { + final float[] row = reshapedResult.get(i).toFloatArray(); + + float maxClassProb = -1f; + int maxIndex = -1; for (int c = 4; c < sizeClasses; c++) { - float classProb = features4OneImg.getFloat(c, b); - if (classProb > maxClass) { - maxClass = classProb; + float classProb = row[c]; + if (classProb > maxClassProb) { + maxClassProb = classProb; maxIndex = c; } } - if (maxClass > threshold) { - float xPos = features4OneImg.getFloat(0, b); // center x - float yPos = features4OneImg.getFloat(1, b); // center y - float w = features4OneImg.getFloat(2, b); - float h = features4OneImg.getFloat(3, b); - Rectangle rect = + if (maxClassProb > threshold) { + float xPos = row[0]; // center x + float yPos = row[1]; // center y + float w = row[2]; + float h = row[3]; + final Rectangle rect = new Rectangle(Math.max(0, xPos - w / 2), Math.max(0, yPos - h / 2), w, h); intermediateResults.add( - new IntermediateResult(classes.get(maxIndex), maxClass, maxIndex, rect)); + new IntermediateResult(classes.get(maxIndex), maxClassProb, maxIndex, rect)); } } - return nms(intermediateResults); } /** The builder for {@link YoloV8Translator}. */ public static class Builder extends YoloV5Translator.Builder { + int maxBox = 8400; /** * Builds the translator. * @@ -99,5 +110,11 @@ public YoloV8Translator build() { validate(); return new YoloV8Translator(this); } + + @Override + protected void configPostProcess(Map arguments) { + super.configPostProcess(arguments); + maxBox = ArgumentsUtil.intValue(arguments, "maxBox", 8400); + } } -} +} \ No newline at end of file diff --git a/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java b/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java index 5a474c1da31..cb0ea89baae 100644 --- a/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java +++ b/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java @@ -65,6 +65,8 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran arguments.put("toTensor", true); arguments.put("applyRatio", true); arguments.put("threshold", 0.6f); + // for performance optimization maxBox parameter can reduce number of considered boxes from 8400 + arguments.put("maxBox", 1000); arguments.put("synsetFileName", "yolov8_synset.txt"); YoloV8TranslatorFactory yoloV8TranslatorFactory = new YoloV8TranslatorFactory(); diff --git a/examples/src/test/resources/yolov8_synset.txt b/examples/src/test/resources/yolov8_synset.txt index 7139f0cc628..ffba2064933 100644 --- a/examples/src/test/resources/yolov8_synset.txt +++ b/examples/src/test/resources/yolov8_synset.txt @@ -1,3 +1,7 @@ +# Classes for coco dataset on which yelov8 is trained +# source config https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/coco.yaml. +# COCO dataset website: https://cocodataset.org/#home +# Ultralytics Coco doc page: https://docs.ultralytics.com/datasets/detect/coco/ person bicycle car From 83c792507b53d6d244faffe81881a08a5f4b67e3 Mon Sep 17 00:00:00 2001 From: gevant <39420213+gevant@users.noreply.github.com> Date: Sun, 31 Dec 2023 14:34:55 -0500 Subject: [PATCH 2/3] addressed PR comments - reformatted code - removed final --- .../cv/translator/YoloV8Translator.java | 22 +++++++++---------- .../examples/inference/Yolov8Detection.java | 3 ++- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java index 037fc0a6d52..f9cf999acaf 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java @@ -29,8 +29,7 @@ */ public class YoloV8Translator extends YoloV5Translator { - - final int maxBoxes; + int maxBoxes; /** * Constructs an ImageTranslator with the provided builder. @@ -39,7 +38,7 @@ public class YoloV8Translator extends YoloV5Translator { */ protected YoloV8Translator(Builder builder) { super(builder); - maxBoxes= builder.maxBox; + maxBoxes = builder.maxBox; } /** @@ -58,13 +57,13 @@ public static YoloV8Translator.Builder builder(Map arguments) { @Override protected DetectedObjects processFromBoxOutput(NDList list) { - final NDArray rawResult = list.get(0); - final NDArray reshapedResult = rawResult.transpose(); - final Shape preparedResult = reshapedResult.getShape(); - final long numberRows = preparedResult.get(0); - final long sizeClasses = preparedResult.get(1); + NDArray rawResult = list.get(0); + NDArray reshapedResult = rawResult.transpose(); + Shape preparedResult = reshapedResult.getShape(); + long numberRows = preparedResult.get(0); + long sizeClasses = preparedResult.get(1); - final ArrayList intermediateResults = new ArrayList<>(); + ArrayList intermediateResults = new ArrayList<>(); // reverse order search in heap; searches through #maxBoxes for optimization when set for (int i = (int) numberRows - 1; i > numberRows - maxBoxes; i--) { final float[] row = reshapedResult.get(i).toFloatArray(); @@ -87,7 +86,8 @@ protected DetectedObjects processFromBoxOutput(NDList list) { final Rectangle rect = new Rectangle(Math.max(0, xPos - w / 2), Math.max(0, yPos - h / 2), w, h); intermediateResults.add( - new IntermediateResult(classes.get(maxIndex), maxClassProb, maxIndex, rect)); + new IntermediateResult( + classes.get(maxIndex), maxClassProb, maxIndex, rect)); } } return nms(intermediateResults); @@ -117,4 +117,4 @@ protected void configPostProcess(Map arguments) { maxBox = ArgumentsUtil.intValue(arguments, "maxBox", 8400); } } -} \ No newline at end of file +} diff --git a/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java b/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java index cb0ea89baae..ef94ebb2273 100644 --- a/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java +++ b/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java @@ -65,7 +65,8 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran arguments.put("toTensor", true); arguments.put("applyRatio", true); arguments.put("threshold", 0.6f); - // for performance optimization maxBox parameter can reduce number of considered boxes from 8400 + // for performance optimization maxBox parameter can reduce number of considered boxes from + // 8400 arguments.put("maxBox", 1000); arguments.put("synsetFileName", "yolov8_synset.txt"); From 805c448ac240abe365e6926421b1262b3b179146 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Sun, 31 Dec 2023 13:19:16 -0800 Subject: [PATCH 3/3] Refactor Yolov8 example 1. Fixes windows line return 2. Furthe reduce NDArray operations 3. Refactor example code 4. Add unittest for Yolov8Detection --- .../cv/translator/YoloV8Translator.java | 244 +++++++++--------- .../examples/inference/Yolov8Detection.java | 93 ++----- .../inference/Yolov8DetectionTest.java | 40 +++ 3 files changed, 187 insertions(+), 190 deletions(-) create mode 100644 examples/src/test/java/ai/djl/examples/inference/Yolov8DetectionTest.java diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java index f9cf999acaf..d47f7a4a14a 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java @@ -1,120 +1,124 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.modality.cv.translator; - -import ai.djl.modality.cv.output.DetectedObjects; -import ai.djl.modality.cv.output.Rectangle; -import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.NDList; -import ai.djl.ndarray.types.DataType; -import ai.djl.ndarray.types.Shape; -import ai.djl.translate.ArgumentsUtil; - -import java.util.ArrayList; -import java.util.Map; - -/** - * A translator for YoloV8 models. This was tested with ONNX exported Yolo models. For details check - * here: https://github.com/ultralytics/ultralytics - */ -public class YoloV8Translator extends YoloV5Translator { - - int maxBoxes; - - /** - * Constructs an ImageTranslator with the provided builder. - * - * @param builder the data to build with - */ - protected YoloV8Translator(Builder builder) { - super(builder); - maxBoxes = builder.maxBox; - } - - /** - * Creates a builder to build a {@code YoloV8Translator} with specified arguments. - * - * @param arguments arguments to specify builder options - * @return a new builder - */ - public static YoloV8Translator.Builder builder(Map arguments) { - YoloV8Translator.Builder builder = new YoloV8Translator.Builder(); - builder.configPreProcess(arguments); - builder.configPostProcess(arguments); - - return builder; - } - - @Override - protected DetectedObjects processFromBoxOutput(NDList list) { - NDArray rawResult = list.get(0); - NDArray reshapedResult = rawResult.transpose(); - Shape preparedResult = reshapedResult.getShape(); - long numberRows = preparedResult.get(0); - long sizeClasses = preparedResult.get(1); - - ArrayList intermediateResults = new ArrayList<>(); - // reverse order search in heap; searches through #maxBoxes for optimization when set - for (int i = (int) numberRows - 1; i > numberRows - maxBoxes; i--) { - final float[] row = reshapedResult.get(i).toFloatArray(); - - float maxClassProb = -1f; - int maxIndex = -1; - for (int c = 4; c < sizeClasses; c++) { - float classProb = row[c]; - if (classProb > maxClassProb) { - maxClassProb = classProb; - maxIndex = c; - } - } - - if (maxClassProb > threshold) { - float xPos = row[0]; // center x - float yPos = row[1]; // center y - float w = row[2]; - float h = row[3]; - final Rectangle rect = - new Rectangle(Math.max(0, xPos - w / 2), Math.max(0, yPos - h / 2), w, h); - intermediateResults.add( - new IntermediateResult( - classes.get(maxIndex), maxClassProb, maxIndex, rect)); - } - } - return nms(intermediateResults); - } - - /** The builder for {@link YoloV8Translator}. */ - public static class Builder extends YoloV5Translator.Builder { - int maxBox = 8400; - /** - * Builds the translator. - * - * @return the new translator - */ - @Override - public YoloV8Translator build() { - if (pipeline == null) { - addTransform( - array -> array.transpose(2, 0, 1).toType(DataType.FLOAT32, false).div(255)); - } - validate(); - return new YoloV8Translator(this); - } - - @Override - protected void configPostProcess(Map arguments) { - super.configPostProcess(arguments); - maxBox = ArgumentsUtil.intValue(arguments, "maxBox", 8400); - } - } -} +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.translator; + +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.output.Rectangle; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.translate.ArgumentsUtil; + +import java.util.ArrayList; +import java.util.Map; + +/** + * A translator for YoloV8 models. This was tested with ONNX exported Yolo models. For details check + * here: https://github.com/ultralytics/ultralytics + */ +public class YoloV8Translator extends YoloV5Translator { + + private int maxBoxes; + + /** + * Constructs an ImageTranslator with the provided builder. + * + * @param builder the data to build with + */ + protected YoloV8Translator(Builder builder) { + super(builder); + maxBoxes = builder.maxBox; + } + + /** + * Creates a builder to build a {@code YoloV8Translator} with specified arguments. + * + * @param arguments arguments to specify builder options + * @return a new builder + */ + public static YoloV8Translator.Builder builder(Map arguments) { + YoloV8Translator.Builder builder = new YoloV8Translator.Builder(); + builder.configPreProcess(arguments); + builder.configPostProcess(arguments); + + return builder; + } + + /** {@inheritDoc} */ + @Override + protected DetectedObjects processFromBoxOutput(NDList list) { + NDArray rawResult = list.get(0); + NDArray reshapedResult = rawResult.transpose(); + Shape shape = reshapedResult.getShape(); + float[] buf = reshapedResult.toFloatArray(); + int numberRows = Math.toIntExact(shape.get(0)); + int nClasses = Math.toIntExact(shape.get(1)); + + ArrayList intermediateResults = new ArrayList<>(); + // reverse order search in heap; searches through #maxBoxes for optimization when set + for (int i = numberRows - 1; i > numberRows - maxBoxes; --i) { + int index = i * nClasses; + float maxClassProb = -1f; + int maxIndex = -1; + for (int c = 4; c < nClasses; c++) { + float classProb = buf[index + c]; + if (classProb > maxClassProb) { + maxClassProb = classProb; + maxIndex = c; + } + } + + if (maxClassProb > threshold) { + float xPos = buf[index]; // center x + float yPos = buf[index + 1]; // center y + float w = buf[index + 2]; + float h = buf[index + 3]; + Rectangle rect = + new Rectangle(Math.max(0, xPos - w / 2), Math.max(0, yPos - h / 2), w, h); + intermediateResults.add( + new IntermediateResult( + classes.get(maxIndex), maxClassProb, maxIndex, rect)); + } + } + return nms(intermediateResults); + } + + /** The builder for {@link YoloV8Translator}. */ + public static class Builder extends YoloV5Translator.Builder { + + private int maxBox = 8400; + + /** + * Builds the translator. + * + * @return the new translator + */ + @Override + public YoloV8Translator build() { + if (pipeline == null) { + addTransform( + array -> array.transpose(2, 0, 1).toType(DataType.FLOAT32, false).div(255)); + } + validate(); + return new YoloV8Translator(this); + } + + /** {@inheritDoc} */ + @Override + protected void configPostProcess(Map arguments) { + super.configPostProcess(arguments); + maxBox = ArgumentsUtil.intValue(arguments, "maxBox", 8400); + } + } +} diff --git a/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java b/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java index ef94ebb2273..27ee1211d05 100644 --- a/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java +++ b/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java @@ -16,27 +16,21 @@ import ai.djl.inference.Predictor; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.ImageFactory; -import ai.djl.modality.cv.output.BoundingBox; import ai.djl.modality.cv.output.DetectedObjects; -import ai.djl.modality.cv.output.DetectedObjects.DetectedObject; -import ai.djl.modality.cv.output.Rectangle; import ai.djl.modality.cv.translator.YoloV8TranslatorFactory; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ZooModel; import ai.djl.training.util.ProgressBar; import ai.djl.translate.TranslateException; -import ai.djl.translate.Translator; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; +import java.io.OutputStream; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; /** An example of inference using an yolov8 model. */ public final class Yolov8Detection { @@ -51,85 +45,44 @@ public static void main(String[] args) throws IOException, ModelException, Trans } public static DetectedObjects predict() throws IOException, ModelException, TranslateException { - String classPath = System.getProperty("java.class.path"); - String pathSeparator = System.getProperty("path.separator"); - classPath = classPath.split(pathSeparator)[0]; - Path modelPath = Paths.get(classPath + "/yolov8n.onnx"); - Path imgPath = Paths.get(classPath + "/yolov8_test.jpg"); + Path modelPath = Paths.get("src/test/resources/yolov8n.onnx"); + Path imgPath = Paths.get("src/test/resources/yolov8_test.jpg"); Image img = ImageFactory.getInstance().fromFile(imgPath); - Map arguments = new ConcurrentHashMap<>(); - arguments.put("width", 640); - arguments.put("height", 640); - arguments.put("resize", "true"); - arguments.put("toTensor", true); - arguments.put("applyRatio", true); - arguments.put("threshold", 0.6f); - // for performance optimization maxBox parameter can reduce number of considered boxes from - // 8400 - arguments.put("maxBox", 1000); - arguments.put("synsetFileName", "yolov8_synset.txt"); - - YoloV8TranslatorFactory yoloV8TranslatorFactory = new YoloV8TranslatorFactory(); - Translator translator = - yoloV8TranslatorFactory.newInstance( - Image.class, DetectedObjects.class, null, arguments); - Criteria criteria = Criteria.builder() .setTypes(Image.class, DetectedObjects.class) .optModelPath(modelPath) .optEngine("OnnxRuntime") - .optTranslator(translator) + .optArgument("width", 640) + .optArgument("height", 640) + .optArgument("resize", true) + .optArgument("toTensor", true) + .optArgument("applyRatio", true) + .optArgument("threshold", 0.6f) + // for performance optimization maxBox parameter can reduce number of + // considered boxes from 8400 + .optArgument("maxBox", 1000) + .optArgument("synsetFileName", "yolov8_synset.txt") + .optTranslatorFactory(new YoloV8TranslatorFactory()) .optProgress(new ProgressBar()) .build(); - DetectedObjects detectedObjects; - DetectedObject detectedObject; try (ZooModel model = criteria.loadModel(); Predictor predictor = model.newPredictor()) { - Path outputPath = Paths.get(classPath + "/output"); + Path outputPath = Paths.get("build/output"); Files.createDirectories(outputPath); - detectedObjects = predictor.predict(img); - - if (detectedObjects.getNumberOfObjects() > 0) { - List detectedObjectList = detectedObjects.items(); - for (DetectedObject object : detectedObjectList) { - detectedObject = object; - BoundingBox boundingBox = detectedObject.getBoundingBox(); - Rectangle tectangle = boundingBox.getBounds(); - logger.info( - detectedObject.getClassName() - + " " - + detectedObject.getProbability() - + " " - + tectangle.getX() - + " " - + tectangle.getY() - + " " - + tectangle.getWidth() - + " " - + tectangle.getHeight()); + DetectedObjects detection = predictor.predict(img); + if (detection.getNumberOfObjects() > 0) { + img.drawBoundingBoxes(detection); + Path output = outputPath.resolve("yolov8_detected.png"); + try (OutputStream os = Files.newOutputStream(output)) { + img.save(os, "png"); } - - saveBoundingBoxImage( - img.resize(640, 640, false), - detectedObjects, - outputPath, - imgPath.toFile().getName()); + logger.info("Detected object saved in: {}", output); } - - return detectedObjects; + return detection; } } - - private static void saveBoundingBoxImage( - Image img, DetectedObjects detectedObjects, Path outputPath, String outputFileName) - throws IOException { - img.drawBoundingBoxes(detectedObjects); - - Path imagePath = outputPath.resolve(outputFileName); - img.save(Files.newOutputStream(imagePath), "png"); - } } diff --git a/examples/src/test/java/ai/djl/examples/inference/Yolov8DetectionTest.java b/examples/src/test/java/ai/djl/examples/inference/Yolov8DetectionTest.java new file mode 100644 index 00000000000..35e3fc434aa --- /dev/null +++ b/examples/src/test/java/ai/djl/examples/inference/Yolov8DetectionTest.java @@ -0,0 +1,40 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.examples.inference; + +import ai.djl.ModelException; +import ai.djl.modality.Classifications; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.testing.TestRequirements; +import ai.djl.translate.TranslateException; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.io.IOException; + +public class Yolov8DetectionTest { + + @Test + public void testYolov8Detection() throws ModelException, TranslateException, IOException { + TestRequirements.engine("MXNet", "PyTorch"); + + DetectedObjects result = Yolov8Detection.predict(); + + Assert.assertTrue(result.getNumberOfObjects() >= 1); + Classifications.Classification obj = result.best(); + String className = obj.getClassName(); + Assert.assertEquals(className, "dog"); + Assert.assertTrue(obj.getProbability() > 0.6); + } +}