Skip to content

Commit

Permalink
[api] Adds base64 image support for ImageTranslator (#3456)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Sep 6, 2024
1 parent 3e936cf commit 1ff8bc5
Show file tree
Hide file tree
Showing 14 changed files with 139 additions and 37 deletions.
19 changes: 17 additions & 2 deletions api/src/main/java/ai/djl/modality/cv/ImageFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URL;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Base64;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
* {@code ImageFactory} contains image creation mechanism on top of different platforms like PC and
Expand All @@ -38,6 +42,8 @@ public abstract class ImageFactory {
"ai.djl.android.core.BitmapImageFactory"
};

private static final Pattern URL_PATTERN = Pattern.compile("^data:image/\\w+;base64,(.+)");

private static ImageFactory factory = newInstance();

private static ImageFactory newInstance() {
Expand Down Expand Up @@ -98,13 +104,22 @@ public Image fromUrl(URL url) throws IOException {
}

/**
* Gets {@link Image} from URL.
* Gets {@link Image} from string representation.
*
* @param url the String represent URL to load from
* @param url the string representing a URL or a base64 encoded image to load from
* @return {@link Image}
* @throws IOException URL is not valid.
*/
public Image fromUrl(String url) throws IOException {
Matcher m = URL_PATTERN.matcher(url);
if (m.matches()) {
// url="data:image/png;base64,..."
byte[] buf = Base64.getDecoder().decode(m.group(1));
try (InputStream is = new ByteArrayInputStream(buf)) {
return fromInputStream(is);
}
}

URI uri = URI.create(url);
if (uri.isAbsolute()) {
return fromUrl(uri.toURL());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.translator.wrapper.FileImagePreProcesor;
import ai.djl.modality.cv.translator.wrapper.FileImagePreProcessor;
import ai.djl.modality.cv.translator.wrapper.InputStreamImagePreProcessor;
import ai.djl.modality.cv.translator.wrapper.StringImagePreProcessor;
import ai.djl.modality.cv.translator.wrapper.UrlImagePreProcessor;
import ai.djl.translate.ExpansionTranslatorFactory;
import ai.djl.translate.PreProcessor;
Expand Down Expand Up @@ -44,8 +45,9 @@ public abstract class BaseImageTranslatorFactory<O> extends ExpansionTranslatorF
getPreprocessorExpansions() {
Map<Type, Function<PreProcessor<Image>, PreProcessor<?>>> expansions =
new ConcurrentHashMap<>();
expansions.put(Path.class, FileImagePreProcesor::new);
expansions.put(Path.class, FileImagePreProcessor::new);
expansions.put(URL.class, UrlImagePreProcessor::new);
expansions.put(String.class, StringImagePreProcessor::new);
expansions.put(InputStream.class, InputStreamImagePreProcessor::new);
return expansions;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.JsonSerializable;
import ai.djl.util.JsonUtils;

import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParseException;
import com.google.gson.JsonPrimitive;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
Expand Down Expand Up @@ -80,7 +86,34 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception
if (data == null) {
throw new TranslateException("Input data is empty.");
}
Image image = factory.fromInputStream(new ByteArrayInputStream(data.getAsBytes()));
String contentType = input.getProperty("Content-Type", null);
if (contentType != null) {
int pos = contentType.indexOf(';');
if (pos > 0) {
contentType = contentType.substring(0, pos);
}
}
Image image;
if ("application/json".equalsIgnoreCase(contentType)) {
try {
JsonElement element =
JsonUtils.GSON.fromJson(data.getAsString(), JsonElement.class);
if (element == null || !element.isJsonObject()) {
throw new TranslateException("Invalid JsonObject input.");
}
JsonObject obj = element.getAsJsonObject();
JsonPrimitive url = obj.getAsJsonPrimitive("image_url");
if (url == null) {
throw new TranslateException("Missing \"image_url\" in json.");
}

image = factory.fromUrl(url.getAsString());
} catch (JsonParseException e) {
throw new TranslateException("Input is not a valid json.", e);
}
} else {
image = factory.fromInputStream(new ByteArrayInputStream(data.getAsBytes()));
}
return translator.processInput(ctx, image);
} catch (IOException e) {
throw new TranslateException("Input is not an Image data type", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@
import java.nio.file.Path;

/** Built-in {@code PreProcessor} that provides image pre-processing from file path. */
public class FileImagePreProcesor implements PreProcessor<Path> {
public class FileImagePreProcessor implements PreProcessor<Path> {

private PreProcessor<Image> preProcessor;

/**
* Creates a {@code FileImagePreProcesor} instance.
* Creates a {@code FileImagePreProcessor} instance.
*
* @param preProcessor a {@code {@link PreProcessor}} that can process image
* @param preProcessor a {@code PreProcessor} that can process image
*/
public FileImagePreProcesor(PreProcessor<Image> preProcessor) {
public FileImagePreProcessor(PreProcessor<Image> preProcessor) {
this.preProcessor = preProcessor;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.cv.translator.wrapper;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.ndarray.NDList;
import ai.djl.translate.PreProcessor;
import ai.djl.translate.TranslatorContext;

/**
 * Built-in {@code PreProcessor} that accepts an image given as a {@code String} (a URL or a
 * base64 encoded image string) and forwards the loaded {@link Image} to a wrapped image
 * pre-processor.
 */
public class StringImagePreProcessor implements PreProcessor<String> {

    /** The wrapped pre-processor that performs the actual work on the loaded image. */
    private PreProcessor<Image> preProcessor;

    /**
     * Creates a {@code StringImagePreProcessor} instance.
     *
     * @param preProcessor a {@code PreProcessor} that can process image
     */
    public StringImagePreProcessor(PreProcessor<Image> preProcessor) {
        this.preProcessor = preProcessor;
    }

    /**
     * Loads the image described by {@code input} through the default {@link ImageFactory} and
     * hands it over to the wrapped pre-processor.
     *
     * @param ctx the translator context passed through to the wrapped pre-processor
     * @param input the string form of the image, as accepted by {@code ImageFactory.fromUrl}
     * @return the {@link NDList} produced by the wrapped pre-processor
     * @throws Exception if the image cannot be loaded or the wrapped pre-processor fails
     */
    @Override
    public NDList processInput(TranslatorContext ctx, String input) throws Exception {
        ImageFactory imageFactory = ImageFactory.getInstance();
        Image loaded = imageFactory.fromUrl(input);
        return this.preProcessor.processInput(ctx, loaded);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public void setUp() {

@Test
public void testGetSupportedTypes() {
Assert.assertEquals(factory.getSupportedTypes().size(), 5);
Assert.assertEquals(factory.getSupportedTypes().size(), 6);
}

@Test
Expand All @@ -68,6 +68,10 @@ public void testNewInstance() {
factory.newInstance(Input.class, Output.class, model, arguments);
Assert.assertTrue(translator5 instanceof ImageServingTranslator);

Translator<String, Classifications> translator6 =
factory.newInstance(String.class, Classifications.class, model, arguments);
Assert.assertTrue(translator6 instanceof BasicTranslator);

Assert.assertThrows(
IllegalArgumentException.class,
() -> factory.newInstance(Image.class, Output.class, model, arguments));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public void setUp() {

@Test
public void testGetSupportedTypes() {
Assert.assertEquals(factory.getSupportedTypes().size(), 5);
Assert.assertEquals(factory.getSupportedTypes().size(), 6);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public void setUp() {

@Test
public void testGetSupportedTypes() {
Assert.assertEquals(factory.getSupportedTypes().size(), 5);
Assert.assertEquals(factory.getSupportedTypes().size(), 6);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public void setUp() {

@Test
public void testGetSupportedTypes() {
Assert.assertEquals(factory.getSupportedTypes().size(), 5);
Assert.assertEquals(factory.getSupportedTypes().size(), 6);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public void setUp() {

@Test
public void testGetSupportedTypes() {
Assert.assertEquals(factory.getSupportedTypes().size(), 5);
Assert.assertEquals(factory.getSupportedTypes().size(), 6);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
Expand All @@ -26,6 +24,7 @@
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.net.URL;

/**
* An example of inference using an action recognition model.
Expand All @@ -46,22 +45,20 @@ public static void main(String[] args) throws IOException, ModelException, Trans
}

public static Classifications predict() throws IOException, ModelException, TranslateException {
String url = "https://resources.djl.ai/images/action_dance.jpg";
Image img = ImageFactory.getInstance().fromUrl(url);

URL url = new URL("https://resources.djl.ai/images/action_dance.jpg");
// Use DJL PyTorch model zoo model
Criteria<Image, Classifications> criteria =
Criteria<URL, Classifications> criteria =
Criteria.builder()
.setTypes(Image.class, Classifications.class)
.setTypes(URL.class, Classifications.class)
.optModelUrls(
"djl://ai.djl.pytorch/Human-Action-Recognition-VIT-Base-patch16-224")
.optEngine("PyTorch")
.optProgress(new ProgressBar())
.build();

try (ZooModel<Image, Classifications> inception = criteria.loadModel();
Predictor<Image, Classifications> action = inception.newPredictor()) {
return action.predict(img);
try (ZooModel<URL, Classifications> inception = criteria.loadModel();
Predictor<URL, Classifications> action = inception.newPredictor()) {
return action.predict(url);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran

// modelUrl can be replaced to local onnx model file
String modelUrl = "https://resources.djl.ai/demo/onnxruntime/face_mask_detection.zip";
Criteria<Image, DetectedObjects> criteria =
Criteria<String, DetectedObjects> criteria =
Criteria.builder()
.setTypes(Image.class, DetectedObjects.class)
.setTypes(String.class, DetectedObjects.class)
.optModelUrls(modelUrl)
.optEngine("OnnxRuntime")
.optTranslatorFactory(new YoloV5TranslatorFactory())
Expand All @@ -67,9 +67,9 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran
.optArgument("rescale", true) // post process
.build();

try (ZooModel<Image, DetectedObjects> model = criteria.loadModel()) {
try (Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
DetectedObjects detection = predictor.predict(img);
try (ZooModel<String, DetectedObjects> model = criteria.loadModel()) {
try (Predictor<String, DetectedObjects> predictor = model.newPredictor()) {
DetectedObjects detection = predictor.predict(imageUrl);
String outputDir = "build/output";
saveBoundingBoxImage(img, detection, outputDir);
return detection;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran

// Use DJL OnnxRuntime model zoo model, model can be found:
// https://mlrepo.djl.ai/model/cv/object_detection/ai/djl/onnxruntime/yolov8n/0.0.1/yolov8n.zip
Criteria<Image, DetectedObjects> criteria =
Criteria<Path, DetectedObjects> criteria =
Criteria.builder()
.setTypes(Image.class, DetectedObjects.class)
.setTypes(Path.class, DetectedObjects.class)
.optModelUrls("djl://ai.djl.onnxruntime/yolov8n")
.optEngine("OnnxRuntime")
.optArgument("width", 640)
Expand All @@ -68,12 +68,12 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran
.optProgress(new ProgressBar())
.build();

try (ZooModel<Image, DetectedObjects> model = criteria.loadModel();
Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
try (ZooModel<Path, DetectedObjects> model = criteria.loadModel();
Predictor<Path, DetectedObjects> predictor = model.newPredictor()) {
Path outputPath = Paths.get("build/output");
Files.createDirectories(outputPath);

DetectedObjects detection = predictor.predict(img);
DetectedObjects detection = predictor.predict(imgPath);
if (detection.getNumberOfObjects() > 0) {
img.drawBoundingBoxes(detection);
Path output = outputPath.resolve("yolov8_detected.png");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,26 @@
import org.testng.Assert;
import org.testng.annotations.Test;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Base64;

public class BufferedImageFactoryTest {

@Test
public void testLoadImage() throws IOException {
try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
ImageFactory factory = ImageFactory.getInstance();
Image img =
factory.fromUrl(
"https://github.com/deepjavalibrary/djl/raw/master/examples/src/test/resources/dog_bike_car.jpg");
Image img = factory.fromUrl("https://resources.djl.ai/images/dog_bike_car.jpg");
NDArray array = img.toNDArray(manager);
Assert.assertEquals(new Shape(img.getHeight(), img.getWidth(), 3), array.getShape());

try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) {
img.save(bos, "png");
String data = Base64.getEncoder().encodeToString(bos.toByteArray());
Image img2 = factory.fromUrl("data:image/png;base64," + data);
Assert.assertEquals(img2.getWidth(), 768);
}
}
}

Expand Down

0 comments on commit 1ff8bc5

Please sign in to comment.