From 39bc378b343635a864258ac95f54d2439a46cc9b Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Wed, 2 Oct 2024 20:58:17 -0700 Subject: [PATCH] [api] Visualize sam2 output for Sam2ServingTranslator --- .../cv/translator/Sam2ServingTranslator.java | 33 +++++++++++-- .../cv/translator/Sam2Translator.java | 46 ++++++++++++++++++- .../java/ai/djl/util/JsonSerializable.java | 6 +-- .../modality/cv/translator/Sam2InputTest.java | 4 +- 4 files changed, 78 insertions(+), 11 deletions(-) diff --git a/api/src/main/java/ai/djl/modality/cv/translator/Sam2ServingTranslator.java b/api/src/main/java/ai/djl/modality/cv/translator/Sam2ServingTranslator.java index 909c3d01f53..97432dd036e 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/Sam2ServingTranslator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/Sam2ServingTranslator.java @@ -14,7 +14,9 @@ import ai.djl.modality.Input; import ai.djl.modality.Output; +import ai.djl.modality.cv.Image; import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.output.Rectangle; import ai.djl.modality.cv.translator.Sam2Translator.Sam2Input; import ai.djl.ndarray.BytesSupplier; import ai.djl.ndarray.NDList; @@ -23,7 +25,13 @@ import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; +import org.apache.commons.codec.binary.Base64OutputStream; + +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.LinkedHashMap; +import java.util.Map; /** A {@link Translator} that can serve SAM2 model. 
*/ public class Sam2ServingTranslator implements Translator { @@ -47,11 +55,29 @@ public Batchifier getBatchifier() { /** {@inheritDoc} */ @Override - public Output processOutput(TranslatorContext ctx, NDList list) throws Exception { + public Output processOutput(TranslatorContext ctx, NDList list) throws IOException { Output output = new Output(); + Sam2Input sam2 = (Sam2Input) ctx.getAttachment("input"); output.addProperty("Content-Type", "application/json"); - DetectedObjects obj = translator.processOutput(ctx, list); - output.add(BytesSupplier.wrapAsJson(obj)); + DetectedObjects detection = translator.processOutput(ctx, list); + Map ret = new LinkedHashMap<>(); // NOPMD + ret.put("result", detection); + if (sam2.isVisualize()) { + Image img = sam2.getImage(); + img.drawBoundingBoxes(detection, 0.8f); + img.drawMarks(sam2.getPoints()); + for (Rectangle rect : sam2.getBoxes()) { + img.drawRectangle(rect, 0xff0000, 6); + } + ByteArrayOutputStream os = new ByteArrayOutputStream(); + os.write("data:image/png;base64,".getBytes(StandardCharsets.UTF_8)); + Base64OutputStream bos = new Base64OutputStream(os, true, 0, null); + img.save(bos, "png"); + bos.close(); + os.close(); + ret.put("image", os.toString(StandardCharsets.UTF_8.name())); + } + output.add(BytesSupplier.wrapAsJson(ret)); return output; } @@ -64,6 +90,7 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception throw new TranslateException("Input data is empty."); } Sam2Input sam2 = Sam2Input.fromJson(data.getAsString()); + ctx.setAttachment("input", sam2); return translator.processInput(ctx, sam2); } catch (IOException e) { throw new TranslateException("Input is not an Image data type", e); diff --git a/api/src/main/java/ai/djl/modality/cv/translator/Sam2Translator.java b/api/src/main/java/ai/djl/modality/cv/translator/Sam2Translator.java index 0c05c569f6d..f1b75dd95d8 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/Sam2Translator.java +++ 
b/api/src/main/java/ai/djl/modality/cv/translator/Sam2Translator.java @@ -205,6 +205,7 @@ public static final class Sam2Input { private Image image; private Point[] points; private int[] labels; + private boolean visualize; /** * Constructs a {@code Sam2Input} instance. @@ -214,9 +215,22 @@ public static final class Sam2Input { * @param labels the labels for the locations (0: background, 1: foreground) */ public Sam2Input(Image image, Point[] points, int[] labels) { + this(image, points, labels, false); + } + + /** + * Constructs a {@code Sam2Input} instance. + * + * @param image the image + * @param points the locations on the image + * @param labels the labels for the locations (0: background, 1: foreground) + * @param visualize {@code true} to include a visualized image in the output + */ + public Sam2Input(Image image, Point[] points, int[] labels, boolean visualize) { this.image = image; this.points = points; this.labels = labels; + this.visualize = visualize; } /** @@ -228,6 +242,15 @@ public Image getImage() { return image; } + /** + * Returns {@code true} if a visualized image should be included in the output. + * + * @return {@code true} if a visualized image should be included in the output + */ + public boolean isVisualize() { + return visualize; + } + /** * Returns the locations. 
* @@ -288,13 +311,16 @@ float[][] getLabels() { public static Sam2Input fromJson(String input) throws IOException { Prompt prompt = JsonUtils.GSON.fromJson(input, Prompt.class); if (prompt.image == null) { - throw new IllegalArgumentException("Missing url value"); + throw new IllegalArgumentException("Missing image value"); } if (prompt.prompt == null || prompt.prompt.length == 0) { throw new IllegalArgumentException("Missing prompt value"); } Image image = ImageFactory.getInstance().fromUrl(prompt.image); Builder builder = builder(image); + if (prompt.visualize) { + builder.visualize(); + } for (Location location : prompt.prompt) { int[] data = location.data; if ("point".equals(location.type)) { @@ -322,6 +348,7 @@ public static final class Builder { private Image image; private List points; private List labels; + private boolean visualize; Builder(Image image) { this.image = image; @@ -380,6 +407,16 @@ public Builder addBox(int x, int y, int right, int bottom) { return this; } + /** + * Sets the visualize for the {@code Sam2Input}. + * + * @return the builder + */ + public Builder visualize() { + visualize = true; + return this; + } + /** * Builds the {@code Sam2Input}. 
* @@ -388,7 +425,7 @@ public Builder addBox(int x, int y, int right, int bottom) { public Sam2Input build() { Point[] location = points.toArray(new Point[0]); int[] array = labels.stream().mapToInt(Integer::intValue).toArray(); - return new Sam2Input(image, location, array); + return new Sam2Input(image, location, array, visualize); } } @@ -413,6 +450,7 @@ public void setLabel(int label) { private static final class Prompt { String image; Location[] prompt; + boolean visualize; public void setImage(String image) { this.image = image; @@ -421,6 +459,10 @@ public void setImage(String image) { public void setPrompt(Location[] prompt) { this.prompt = prompt; } + + public void setVisualize(boolean visualize) { + this.visualize = visualize; + } } } } diff --git a/api/src/main/java/ai/djl/util/JsonSerializable.java b/api/src/main/java/ai/djl/util/JsonSerializable.java index b998e94e51a..e321ec21efd 100644 --- a/api/src/main/java/ai/djl/util/JsonSerializable.java +++ b/api/src/main/java/ai/djl/util/JsonSerializable.java @@ -49,11 +49,7 @@ default ByteBuffer toByteBuffer() { return ByteBuffer.wrap(toJson().getBytes(StandardCharsets.UTF_8)); } - /** - * Serializes the object to the {@code JsonElement}. - * - * @return the {@code JsonElement} - */ + /** {@inheritDoc} */ JsonElement serialize(); /** A customized Gson serializer to serialize the {@code Segmentation} object. 
*/ diff --git a/api/src/test/java/ai/djl/modality/cv/translator/Sam2InputTest.java b/api/src/test/java/ai/djl/modality/cv/translator/Sam2InputTest.java index 9f9811bdaac..a742e92d49d 100644 --- a/api/src/test/java/ai/djl/modality/cv/translator/Sam2InputTest.java +++ b/api/src/test/java/ai/djl/modality/cv/translator/Sam2InputTest.java @@ -33,15 +33,17 @@ public void test() throws IOException { "{\"image\": \"" + file.toUri().toURL() + "\",\n" + + "\"visualize\": true,\n" + "\"prompt\": [\n" + " {\"type\": \"point\", \"data\": [575, 750], \"label\": 0},\n" + " {\"type\": \"rectangle\", \"data\": [425, 600, 700, 875]}\n" + "]}"; Sam2Input input = Sam2Input.fromJson(json); + Assert.assertTrue(input.isVisualize()); Assert.assertEquals(input.getPoints().size(), 1); Assert.assertEquals(input.getBoxes().size(), 1); - input = Sam2Input.builder(img).addPoint(0, 1).addBox(0, 0, 1, 1).build(); + input = Sam2Input.builder(img).visualize().addPoint(0, 1).addBox(0, 0, 1, 1).build(); Assert.assertEquals(input.getPoints().size(), 1); Assert.assertEquals(input.getBoxes().size(), 1); }