diff --git a/api/src/main/java/ai/djl/modality/cv/ImageFactory.java b/api/src/main/java/ai/djl/modality/cv/ImageFactory.java index 4868e15a9d2..ab11a6fee97 100644 --- a/api/src/main/java/ai/djl/modality/cv/ImageFactory.java +++ b/api/src/main/java/ai/djl/modality/cv/ImageFactory.java @@ -17,12 +17,16 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.net.URI; import java.net.URL; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.Base64; +import java.util.regex.Matcher; +import java.util.regex.Pattern; /** * {@code ImageFactory} contains image creation mechanism on top of different platforms like PC and @@ -38,6 +42,8 @@ public abstract class ImageFactory { "ai.djl.android.core.BitmapImageFactory" }; + private static final Pattern URL_PATTERN = Pattern.compile("^data:image/\\w+;base64,(.+)"); + private static ImageFactory factory = newInstance(); private static ImageFactory newInstance() { @@ -98,13 +104,22 @@ public Image fromUrl(URL url) throws IOException { } /** - * Gets {@link Image} from URL. + * Gets {@link Image} from string representation. * - * @param url the String represent URL to load from + * @param url the String represent URL or base64 encoded image to load from * @return {@link Image} * @throws IOException URL is not valid. */ public Image fromUrl(String url) throws IOException { + Matcher m = URL_PATTERN.matcher(url); + if (m.matches()) { + // url="data:image/png;base64,..." + byte[] buf = Base64.getDecoder().decode(m.group(1)); + try (InputStream is = new ByteArrayInputStream(buf)) { + return fromInputStream(is); + } + } + URI uri = URI.create(url); if (uri.isAbsolute()) { return fromUrl(uri.toURL()); diff --git a/api/src/main/java/ai/djl/modality/cv/translator/BaseImageTranslatorFactory.java b/api/src/main/java/ai/djl/modality/cv/translator/BaseImageTranslatorFactory.java index d48176aa4e3..59b14f5288c 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/BaseImageTranslatorFactory.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/BaseImageTranslatorFactory.java @@ -15,8 +15,9 @@ import ai.djl.modality.Input; import ai.djl.modality.Output; import ai.djl.modality.cv.Image; -import ai.djl.modality.cv.translator.wrapper.FileImagePreProcesor; +import ai.djl.modality.cv.translator.wrapper.FileImagePreProcessor; import ai.djl.modality.cv.translator.wrapper.InputStreamImagePreProcessor; +import ai.djl.modality.cv.translator.wrapper.StringImagePreProcessor; import ai.djl.modality.cv.translator.wrapper.UrlImagePreProcessor; import ai.djl.translate.ExpansionTranslatorFactory; import ai.djl.translate.PreProcessor; @@ -44,8 +45,9 @@ public abstract class BaseImageTranslatorFactory extends ExpansionTranslatorF getPreprocessorExpansions() { Map, PreProcessor>> expansions = new ConcurrentHashMap<>(); - expansions.put(Path.class, FileImagePreProcesor::new); + expansions.put(Path.class, FileImagePreProcessor::new); expansions.put(URL.class, UrlImagePreProcessor::new); + expansions.put(String.class, StringImagePreProcessor::new); expansions.put(InputStream.class, InputStreamImagePreProcessor::new); return expansions; } diff --git a/api/src/main/java/ai/djl/modality/cv/translator/ImageServingTranslator.java b/api/src/main/java/ai/djl/modality/cv/translator/ImageServingTranslator.java index 4df9b7e2f75..688662a6d44 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/ImageServingTranslator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/ImageServingTranslator.java @@ -23,6 +23,12 @@ import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; import ai.djl.util.JsonSerializable; +import ai.djl.util.JsonUtils; + +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonParseException; +import com.google.gson.JsonPrimitive; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -80,7 +86,34 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception if (data == null) { throw new TranslateException("Input data is empty."); } - Image image = factory.fromInputStream(new ByteArrayInputStream(data.getAsBytes())); + String contentType = input.getProperty("Content-Type", null); + if (contentType != null) { + int pos = contentType.indexOf(';'); + if (pos > 0) { + contentType = contentType.substring(0, pos); + } + } + Image image; + if ("application/json".equalsIgnoreCase(contentType)) { + try { + JsonElement element = + JsonUtils.GSON.fromJson(data.getAsString(), JsonElement.class); + if (element == null || !element.isJsonObject()) { + throw new TranslateException("Invalid JsonObject input."); + } + JsonObject obj = element.getAsJsonObject(); + JsonPrimitive url = obj.getAsJsonPrimitive("image_url"); + if (url == null) { + throw new TranslateException("Missing \"image_url\" in json."); + } + + image = factory.fromUrl(url.getAsString()); + } catch (JsonParseException e) { + throw new TranslateException("Input is not a valid json.", e); + } + } else { + image = factory.fromInputStream(new ByteArrayInputStream(data.getAsBytes())); + } return translator.processInput(ctx, image); } catch (IOException e) { throw new TranslateException("Input is not an Image data type", e); diff --git a/api/src/main/java/ai/djl/modality/cv/translator/wrapper/FileImagePreProcesor.java b/api/src/main/java/ai/djl/modality/cv/translator/wrapper/FileImagePreProcessor.java similarity index 82% rename from api/src/main/java/ai/djl/modality/cv/translator/wrapper/FileImagePreProcesor.java rename to api/src/main/java/ai/djl/modality/cv/translator/wrapper/FileImagePreProcessor.java index ad91e8f5255..1f611ad8158 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/wrapper/FileImagePreProcesor.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/wrapper/FileImagePreProcessor.java @@ -21,16 +21,16 @@ import java.nio.file.Path; /** Built-in {@code PreProcessor} that provides image pre-processing from file path. */ -public class FileImagePreProcesor implements PreProcessor { +public class FileImagePreProcessor implements PreProcessor { private PreProcessor preProcessor; /** - * Creates a {@code FileImagePreProcesor} instance. + * Creates a {@code FileImagePreProcessor} instance. * - * @param preProcessor a {@code {@link PreProcessor}} that can process image + * @param preProcessor a {@code PreProcessor} that can process image */ - public FileImagePreProcesor(PreProcessor preProcessor) { + public FileImagePreProcessor(PreProcessor preProcessor) { this.preProcessor = preProcessor; } diff --git a/api/src/main/java/ai/djl/modality/cv/translator/wrapper/StringImagePreProcessor.java b/api/src/main/java/ai/djl/modality/cv/translator/wrapper/StringImagePreProcessor.java new file mode 100644 index 00000000000..8532ee3a449 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/cv/translator/wrapper/StringImagePreProcessor.java @@ -0,0 +1,44 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.translator.wrapper; + +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.ndarray.NDList; +import ai.djl.translate.PreProcessor; +import ai.djl.translate.TranslatorContext; + +/** + * Built-in {@code PreProcessor} that provides image pre-processing from url or base64 encoded + * string. + */ +public class StringImagePreProcessor implements PreProcessor { + + private PreProcessor preProcessor; + + /** + * Creates a {@code StringImagePreProcessor} instance. + * + * @param preProcessor a {@code PreProcessor} that can process image + */ + public StringImagePreProcessor(PreProcessor preProcessor) { + this.preProcessor = preProcessor; + } + + /** {@inheritDoc} */ + @Override + public NDList processInput(TranslatorContext ctx, String input) throws Exception { + Image image = ImageFactory.getInstance().fromUrl(input); + return preProcessor.processInput(ctx, image); + } +} diff --git a/api/src/test/java/ai/djl/modality/cv/translator/ImageClassificationTranslatorFactoryTest.java b/api/src/test/java/ai/djl/modality/cv/translator/ImageClassificationTranslatorFactoryTest.java index 5427cebe7e4..3bb221f01f4 100644 --- a/api/src/test/java/ai/djl/modality/cv/translator/ImageClassificationTranslatorFactoryTest.java +++ b/api/src/test/java/ai/djl/modality/cv/translator/ImageClassificationTranslatorFactoryTest.java @@ -41,7 +41,7 @@ public void setUp() { @Test public void testGetSupportedTypes() { - Assert.assertEquals(factory.getSupportedTypes().size(), 5); + Assert.assertEquals(factory.getSupportedTypes().size(), 6); } @Test @@ -68,6 +68,10 @@ public void testNewInstance() { factory.newInstance(Input.class, Output.class, model, arguments); Assert.assertTrue(translator5 instanceof ImageServingTranslator); + Translator translator6 = + factory.newInstance(String.class, Classifications.class, model, arguments); + Assert.assertTrue(translator6 instanceof BasicTranslator); + Assert.assertThrows( IllegalArgumentException.class, () -> factory.newInstance(Image.class, Output.class, model, arguments)); diff --git a/api/src/test/java/ai/djl/modality/cv/translator/InstanceSegmentationTranslatorFactoryTest.java b/api/src/test/java/ai/djl/modality/cv/translator/InstanceSegmentationTranslatorFactoryTest.java index 21eaedadc82..9ee64c172a5 100644 --- a/api/src/test/java/ai/djl/modality/cv/translator/InstanceSegmentationTranslatorFactoryTest.java +++ b/api/src/test/java/ai/djl/modality/cv/translator/InstanceSegmentationTranslatorFactoryTest.java @@ -41,7 +41,7 @@ public void setUp() { @Test public void testGetSupportedTypes() { - Assert.assertEquals(factory.getSupportedTypes().size(), 5); + Assert.assertEquals(factory.getSupportedTypes().size(), 6); } @Test diff --git a/api/src/test/java/ai/djl/modality/cv/translator/YoloTranslatorFactoryTest.java b/api/src/test/java/ai/djl/modality/cv/translator/YoloTranslatorFactoryTest.java index cfbcfb829b2..103ef382d98 100644 --- a/api/src/test/java/ai/djl/modality/cv/translator/YoloTranslatorFactoryTest.java +++ b/api/src/test/java/ai/djl/modality/cv/translator/YoloTranslatorFactoryTest.java @@ -41,7 +41,7 @@ public void setUp() { @Test public void testGetSupportedTypes() { - Assert.assertEquals(factory.getSupportedTypes().size(), 5); + Assert.assertEquals(factory.getSupportedTypes().size(), 6); } @Test diff --git a/api/src/test/java/ai/djl/modality/cv/translator/YoloV5TranslatorFactoryTest.java b/api/src/test/java/ai/djl/modality/cv/translator/YoloV5TranslatorFactoryTest.java index 1ac8bf7b805..77b566be20d 100644 --- a/api/src/test/java/ai/djl/modality/cv/translator/YoloV5TranslatorFactoryTest.java +++ b/api/src/test/java/ai/djl/modality/cv/translator/YoloV5TranslatorFactoryTest.java @@ -41,7 +41,7 @@ public void setUp() { @Test public void testGetSupportedTypes() { - Assert.assertEquals(factory.getSupportedTypes().size(), 5); + Assert.assertEquals(factory.getSupportedTypes().size(), 6); } @Test diff --git a/api/src/test/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactoryTest.java b/api/src/test/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactoryTest.java index 8fbbae7301b..c6702497f31 100644 --- a/api/src/test/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactoryTest.java +++ b/api/src/test/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactoryTest.java @@ -41,7 +41,7 @@ public void setUp() { @Test public void testGetSupportedTypes() { - Assert.assertEquals(factory.getSupportedTypes().size(), 5); + Assert.assertEquals(factory.getSupportedTypes().size(), 6); } @Test diff --git a/examples/src/main/java/ai/djl/examples/inference/cv/ActionRecognition.java b/examples/src/main/java/ai/djl/examples/inference/cv/ActionRecognition.java index a069ed358b2..ba90f45ad01 100644 --- a/examples/src/main/java/ai/djl/examples/inference/cv/ActionRecognition.java +++ b/examples/src/main/java/ai/djl/examples/inference/cv/ActionRecognition.java @@ -15,8 +15,6 @@ import ai.djl.ModelException; import ai.djl.inference.Predictor; import ai.djl.modality.Classifications; -import ai.djl.modality.cv.Image; -import ai.djl.modality.cv.ImageFactory; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ZooModel; import ai.djl.training.util.ProgressBar; @@ -26,6 +24,7 @@ import org.slf4j.LoggerFactory; import java.io.IOException; +import java.net.URL; /** * An example of inference using an action recognition model. @@ -46,22 +45,20 @@ public static void main(String[] args) throws IOException, ModelException, Trans } public static Classifications predict() throws IOException, ModelException, TranslateException { - String url = "https://resources.djl.ai/images/action_dance.jpg"; - Image img = ImageFactory.getInstance().fromUrl(url); - + URL url = new URL("https://resources.djl.ai/images/action_dance.jpg"); // Use DJL PyTorch model zoo model - Criteria criteria = + Criteria criteria = Criteria.builder() - .setTypes(Image.class, Classifications.class) + .setTypes(URL.class, Classifications.class) .optModelUrls( "djl://ai.djl.pytorch/Human-Action-Recognition-VIT-Base-patch16-224") .optEngine("PyTorch") .optProgress(new ProgressBar()) .build(); - try (ZooModel inception = criteria.loadModel(); - Predictor action = inception.newPredictor()) { - return action.predict(img); + try (ZooModel inception = criteria.loadModel(); + Predictor action = inception.newPredictor()) { + return action.predict(url); } } } diff --git a/examples/src/main/java/ai/djl/examples/inference/cv/MaskDetection.java b/examples/src/main/java/ai/djl/examples/inference/cv/MaskDetection.java index b9d71ed677b..9de62f05e9c 100644 --- a/examples/src/main/java/ai/djl/examples/inference/cv/MaskDetection.java +++ b/examples/src/main/java/ai/djl/examples/inference/cv/MaskDetection.java @@ -56,9 +56,9 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran // modelUrl can be replaced to local onnx model file String modelUrl = "https://resources.djl.ai/demo/onnxruntime/face_mask_detection.zip"; - Criteria criteria = + Criteria criteria = Criteria.builder() - .setTypes(Image.class, DetectedObjects.class) + .setTypes(String.class, DetectedObjects.class) .optModelUrls(modelUrl) .optEngine("OnnxRuntime") .optTranslatorFactory(new YoloV5TranslatorFactory()) @@ -67,9 +67,9 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran .optArgument("rescale", true) // post process .build(); - try (ZooModel model = criteria.loadModel()) { - try (Predictor predictor = model.newPredictor()) { - DetectedObjects detection = predictor.predict(img); + try (ZooModel model = criteria.loadModel()) { + try (Predictor predictor = model.newPredictor()) { + DetectedObjects detection = predictor.predict(imageUrl); String outputDir = "build/output"; saveBoundingBoxImage(img, detection, outputDir); return detection; diff --git a/examples/src/main/java/ai/djl/examples/inference/cv/Yolov8Detection.java b/examples/src/main/java/ai/djl/examples/inference/cv/Yolov8Detection.java index 5aee1246ce5..eab5bdbfc00 100644 --- a/examples/src/main/java/ai/djl/examples/inference/cv/Yolov8Detection.java +++ b/examples/src/main/java/ai/djl/examples/inference/cv/Yolov8Detection.java @@ -50,9 +50,9 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran // Use DJL OnnxRuntime model zoo model, model can be found: // https://mlrepo.djl.ai/model/cv/object_detection/ai/djl/onnxruntime/yolov8n/0.0.1/yolov8n.zip - Criteria criteria = + Criteria criteria = Criteria.builder() - .setTypes(Image.class, DetectedObjects.class) + .setTypes(Path.class, DetectedObjects.class) .optModelUrls("djl://ai.djl.onnxruntime/yolov8n") .optEngine("OnnxRuntime") .optArgument("width", 640) @@ -68,12 +68,12 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran .optProgress(new ProgressBar()) .build(); - try (ZooModel model = criteria.loadModel(); - Predictor predictor = model.newPredictor()) { + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { Path outputPath = Paths.get("build/output"); Files.createDirectories(outputPath); - DetectedObjects detection = predictor.predict(img); + DetectedObjects detection = predictor.predict(imgPath); if (detection.getNumberOfObjects() > 0) { img.drawBoundingBoxes(detection); Path output = outputPath.resolve("yolov8_detected.png"); diff --git a/integration/src/main/java/ai/djl/integration/tests/modality/cv/BufferedImageFactoryTest.java b/integration/src/main/java/ai/djl/integration/tests/modality/cv/BufferedImageFactoryTest.java index fcc9fd3a1ff..6a010c89641 100644 --- a/integration/src/main/java/ai/djl/integration/tests/modality/cv/BufferedImageFactoryTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/modality/cv/BufferedImageFactoryTest.java @@ -24,7 +24,9 @@ import org.testng.Assert; import org.testng.annotations.Test; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.util.Base64; public class BufferedImageFactoryTest { @@ -32,11 +34,16 @@ public class BufferedImageFactoryTest { public void testLoadImage() throws IOException { try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) { ImageFactory factory = ImageFactory.getInstance(); - Image img = - factory.fromUrl( - "https://github.com/deepjavalibrary/djl/raw/master/examples/src/test/resources/dog_bike_car.jpg"); + Image img = factory.fromUrl("https://resources.djl.ai/images/dog_bike_car.jpg"); NDArray array = img.toNDArray(manager); Assert.assertEquals(new Shape(img.getHeight(), img.getWidth(), 3), array.getShape()); + + try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) { + img.save(bos, "png"); + String data = Base64.getEncoder().encodeToString(bos.toByteArray()); + Image img2 = factory.fromUrl("data:image/png;base64," + data); + Assert.assertEquals(img2.getWidth(), 768); + } } }