Skip to content

Commit

Permalink
Yelov8 Translator optimization (#2908)
Browse files Browse the repository at this point in the history
* Yelov8 Translator optimization

- Improved post-processing performance up to 40x by reducing expensive native calls
- Additional argument 'maxBox' added to improve post-processing performance by reducing number of considered bounding boxes
- Sunset file fixed,, previous version ignored first 4 rows, so recognized classes were 4 off. Adding 4 rows header fixes the problem. New headers pointing to ultralytics doc pages and original coco dataset page.

* addressed PR comments

- reformatted code
- removed final

* Refactor Yolov8 example

1. Fixes windows line return
2. Furthe reduce NDArray operations
3. Refactor example code
4. Add unittest for Yolov8Detection

---------

Co-authored-by: Frank Liu <[email protected]>
  • Loading branch information
gevant and frankfliu authored Dec 31, 2023
1 parent e16ef9d commit 0e6f143
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 170 deletions.
227 changes: 124 additions & 103 deletions api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java
Original file line number Diff line number Diff line change
@@ -1,103 +1,124 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.cv.translator;

import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;

import java.util.ArrayList;
import java.util.Map;

/**
* A translator for YoloV8 models. This was tested with ONNX exported Yolo models. For details check
* here: https://github.com/ultralytics/ultralytics
*/
public class YoloV8Translator extends YoloV5Translator {

/**
* Constructs an ImageTranslator with the provided builder.
*
* @param builder the data to build with
*/
protected YoloV8Translator(Builder builder) {
super(builder);
}

/**
* Creates a builder to build a {@code YoloV8Translator} with specified arguments.
*
* @param arguments arguments to specify builder options
* @return a new builder
*/
public static YoloV8Translator.Builder builder(Map<String, ?> arguments) {
YoloV8Translator.Builder builder = new YoloV8Translator.Builder();
builder.configPreProcess(arguments);
builder.configPostProcess(arguments);

return builder;
}

@Override
protected DetectedObjects processFromBoxOutput(NDList list) {
NDArray features4OneImg = list.get(0);
int sizeClasses = classes.size();
long sizeBoxes = features4OneImg.size(1);
ArrayList<IntermediateResult> intermediateResults = new ArrayList<>();

for (long b = 0; b < sizeBoxes; b++) {
float maxClass = 0;
int maxIndex = 0;
for (int c = 4; c < sizeClasses; c++) {
float classProb = features4OneImg.getFloat(c, b);
if (classProb > maxClass) {
maxClass = classProb;
maxIndex = c;
}
}

if (maxClass > threshold) {
float xPos = features4OneImg.getFloat(0, b); // center x
float yPos = features4OneImg.getFloat(1, b); // center y
float w = features4OneImg.getFloat(2, b);
float h = features4OneImg.getFloat(3, b);
Rectangle rect =
new Rectangle(Math.max(0, xPos - w / 2), Math.max(0, yPos - h / 2), w, h);
intermediateResults.add(
new IntermediateResult(classes.get(maxIndex), maxClass, maxIndex, rect));
}
}

return nms(intermediateResults);
}

/** The builder for {@link YoloV8Translator}. */
public static class Builder extends YoloV5Translator.Builder {
/**
* Builds the translator.
*
* @return the new translator
*/
@Override
public YoloV8Translator build() {
if (pipeline == null) {
addTransform(
array -> array.transpose(2, 0, 1).toType(DataType.FLOAT32, false).div(255));
}
validate();
return new YoloV8Translator(this);
}
}
}
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.cv.translator;

import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.ArgumentsUtil;

import java.util.ArrayList;
import java.util.Map;

/**
* A translator for YoloV8 models. This was tested with ONNX exported Yolo models. For details check
* here: https://github.com/ultralytics/ultralytics
*/
public class YoloV8Translator extends YoloV5Translator {

private int maxBoxes;

/**
* Constructs an ImageTranslator with the provided builder.
*
* @param builder the data to build with
*/
protected YoloV8Translator(Builder builder) {
super(builder);
maxBoxes = builder.maxBox;
}

/**
* Creates a builder to build a {@code YoloV8Translator} with specified arguments.
*
* @param arguments arguments to specify builder options
* @return a new builder
*/
public static YoloV8Translator.Builder builder(Map<String, ?> arguments) {
YoloV8Translator.Builder builder = new YoloV8Translator.Builder();
builder.configPreProcess(arguments);
builder.configPostProcess(arguments);

return builder;
}

/** {@inheritDoc} */
@Override
protected DetectedObjects processFromBoxOutput(NDList list) {
NDArray rawResult = list.get(0);
NDArray reshapedResult = rawResult.transpose();
Shape shape = reshapedResult.getShape();
float[] buf = reshapedResult.toFloatArray();
int numberRows = Math.toIntExact(shape.get(0));
int nClasses = Math.toIntExact(shape.get(1));

ArrayList<IntermediateResult> intermediateResults = new ArrayList<>();
// reverse order search in heap; searches through #maxBoxes for optimization when set
for (int i = numberRows - 1; i > numberRows - maxBoxes; --i) {
int index = i * nClasses;
float maxClassProb = -1f;
int maxIndex = -1;
for (int c = 4; c < nClasses; c++) {
float classProb = buf[index + c];
if (classProb > maxClassProb) {
maxClassProb = classProb;
maxIndex = c;
}
}

if (maxClassProb > threshold) {
float xPos = buf[index]; // center x
float yPos = buf[index + 1]; // center y
float w = buf[index + 2];
float h = buf[index + 3];
Rectangle rect =
new Rectangle(Math.max(0, xPos - w / 2), Math.max(0, yPos - h / 2), w, h);
intermediateResults.add(
new IntermediateResult(
classes.get(maxIndex), maxClassProb, maxIndex, rect));
}
}
return nms(intermediateResults);
}

/** The builder for {@link YoloV8Translator}. */
public static class Builder extends YoloV5Translator.Builder {

private int maxBox = 8400;

/**
* Builds the translator.
*
* @return the new translator
*/
@Override
public YoloV8Translator build() {
if (pipeline == null) {
addTransform(
array -> array.transpose(2, 0, 1).toType(DataType.FLOAT32, false).div(255));
}
validate();
return new YoloV8Translator(this);
}

/** {@inheritDoc} */
@Override
protected void configPostProcess(Map<String, ?> arguments) {
super.configPostProcess(arguments);
maxBox = ArgumentsUtil.intValue(arguments, "maxBox", 8400);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,21 @@
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.DetectedObjects.DetectedObject;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.YoloV8TranslatorFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/** An example of inference using an yolov8 model. */
public final class Yolov8Detection {
Expand All @@ -51,82 +45,44 @@ public static void main(String[] args) throws IOException, ModelException, Trans
}

public static DetectedObjects predict() throws IOException, ModelException, TranslateException {
String classPath = System.getProperty("java.class.path");
String pathSeparator = System.getProperty("path.separator");
classPath = classPath.split(pathSeparator)[0];
Path modelPath = Paths.get(classPath + "/yolov8n.onnx");
Path imgPath = Paths.get(classPath + "/yolov8_test.jpg");
Path modelPath = Paths.get("src/test/resources/yolov8n.onnx");
Path imgPath = Paths.get("src/test/resources/yolov8_test.jpg");
Image img = ImageFactory.getInstance().fromFile(imgPath);

Map<String, Object> arguments = new ConcurrentHashMap<>();
arguments.put("width", 640);
arguments.put("height", 640);
arguments.put("resize", "true");
arguments.put("toTensor", true);
arguments.put("applyRatio", true);
arguments.put("threshold", 0.6f);
arguments.put("synsetFileName", "yolov8_synset.txt");

YoloV8TranslatorFactory yoloV8TranslatorFactory = new YoloV8TranslatorFactory();
Translator<Image, DetectedObjects> translator =
yoloV8TranslatorFactory.newInstance(
Image.class, DetectedObjects.class, null, arguments);

Criteria<Image, DetectedObjects> criteria =
Criteria.builder()
.setTypes(Image.class, DetectedObjects.class)
.optModelPath(modelPath)
.optEngine("OnnxRuntime")
.optTranslator(translator)
.optArgument("width", 640)
.optArgument("height", 640)
.optArgument("resize", true)
.optArgument("toTensor", true)
.optArgument("applyRatio", true)
.optArgument("threshold", 0.6f)
// for performance optimization maxBox parameter can reduce number of
// considered boxes from 8400
.optArgument("maxBox", 1000)
.optArgument("synsetFileName", "yolov8_synset.txt")
.optTranslatorFactory(new YoloV8TranslatorFactory())
.optProgress(new ProgressBar())
.build();

DetectedObjects detectedObjects;
DetectedObject detectedObject;
try (ZooModel<Image, DetectedObjects> model = criteria.loadModel();
Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
Path outputPath = Paths.get(classPath + "/output");
Path outputPath = Paths.get("build/output");
Files.createDirectories(outputPath);

detectedObjects = predictor.predict(img);

if (detectedObjects.getNumberOfObjects() > 0) {
List<DetectedObject> detectedObjectList = detectedObjects.items();
for (DetectedObject object : detectedObjectList) {
detectedObject = object;
BoundingBox boundingBox = detectedObject.getBoundingBox();
Rectangle tectangle = boundingBox.getBounds();
logger.info(
detectedObject.getClassName()
+ " "
+ detectedObject.getProbability()
+ " "
+ tectangle.getX()
+ " "
+ tectangle.getY()
+ " "
+ tectangle.getWidth()
+ " "
+ tectangle.getHeight());
DetectedObjects detection = predictor.predict(img);
if (detection.getNumberOfObjects() > 0) {
img.drawBoundingBoxes(detection);
Path output = outputPath.resolve("yolov8_detected.png");
try (OutputStream os = Files.newOutputStream(output)) {
img.save(os, "png");
}

saveBoundingBoxImage(
img.resize(640, 640, false),
detectedObjects,
outputPath,
imgPath.toFile().getName());
logger.info("Detected object saved in: {}", output);
}

return detectedObjects;
return detection;
}
}

private static void saveBoundingBoxImage(
Image img, DetectedObjects detectedObjects, Path outputPath, String outputFileName)
throws IOException {
img.drawBoundingBoxes(detectedObjects);

Path imagePath = outputPath.resolve(outputFileName);
img.save(Files.newOutputStream(imagePath), "png");
}
}
Loading

0 comments on commit 0e6f143

Please sign in to comment.