Skip to content

Commit

Permalink
Merge pull request #9 from petebankhead/dnn
Browse files Browse the repository at this point in the history
Update to use simplified `DnnModel` interface
  • Loading branch information
petebankhead authored Nov 15, 2023
2 parents c32cdae + 6383340 commit 1e698ee
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 99 deletions.
145 changes: 56 additions & 89 deletions src/main/java/qupath/ext/djl/DjlDnnModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,31 @@

package qupath.ext.djl;

import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import org.bytedeco.opencv.opencv_core.Mat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Batchifier;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.TranslatorContext;
import org.bytedeco.opencv.opencv_core.Mat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.lib.io.UriResource;
import qupath.opencv.dnn.BlobFunction;
import qupath.opencv.dnn.DnnModel;
import qupath.opencv.dnn.DnnShape;
import qupath.opencv.dnn.PredictionFunction;

class DjlDnnModel implements DnnModel<NDList>, AutoCloseable, UriResource {
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

class DjlDnnModel implements DnnModel, AutoCloseable, UriResource {

private static final Logger logger = LoggerFactory.getLogger(DjlDnnModel.class);

Expand All @@ -58,10 +54,13 @@ class DjlDnnModel implements DnnModel<NDList>, AutoCloseable, UriResource {
private boolean lazyInitialize;

private transient boolean failed;
private transient ZooModel<NDList, NDList> model;
private transient Predictor<NDList, NDList> predictor;
private transient BlobFunction<NDList> blobFun;
private transient PredictionFunction<NDList> predictFun;
private transient ZooModel<Mat[], Mat[]> model;
private transient Predictor<Mat[], Mat[]> predictor;

/**
* Default layout for an OpenCV Mat
*/
private static final String DEFAULT_MAT_LAYOUT = getLayout(LayoutType.HEIGHT, LayoutType.WIDTH, LayoutType.CHANNEL);

DjlDnnModel(String engine, Collection<URI> uris, String ndLayout, Map<String, DnnShape> inputs, Map<String, DnnShape> outputs, boolean lazyInitialize) {
this.engine = engine;
Expand All @@ -85,11 +84,13 @@ private void ensureInitialized() {
if (!failed && model == null) {
try {
logger.debug("Initializing DjlDnnModel");
model = DjlTools.loadModel(engine, uris.toArray(URI[]::new));
if (ndLayout != null && ndLayout.contains("N"))
predictor = model.newPredictor();
else
predictor = model.newPredictor(new NoopTranslator(Batchifier.STACK));
model = DjlTools.loadModel(engine,
Mat[].class, Mat[].class,
new ModelMatTranslator(),
uris.toArray(URI[]::new));
// if (ndLayout != null && ndLayout.contains("N"))
predictor = model.newPredictor();


// TODO: Better handling of missing inputs/outputs - we may need to run a prediction for this to work
if (this.inputs == null || this.inputs.isEmpty()) {
Expand All @@ -111,9 +112,6 @@ private void ensureInitialized() {
if (this.outputs == null || this.outputs.isEmpty())
outputs = Map.of(DnnModel.DEFAULT_OUTPUT_NAME, DnnShape.UNKNOWN_SHAPE);
}

blobFun = new BlobFun();
predictFun = new PredictFun();
} catch (Exception e) {
failed = true;
logger.debug("Failed to create DjlDnnModel");
Expand All @@ -125,36 +123,36 @@ private void ensureInitialized() {
}

@Override
public BlobFunction<NDList> getBlobFunction() {
ensureInitialized();
return blobFun;
public Map<String, Mat> predict(Map<String, Mat> blobs) {
synchronized (predictor) {
try {
var result = predictor.predict(blobs.values().stream().toArray(Mat[]::new));
return Map.of(DnnModel.DEFAULT_OUTPUT_NAME, result[0]);
} catch (TranslateException e) {
throw new RuntimeException(e);
}
}
}

@Override
public BlobFunction<NDList> getBlobFunction(String name) {
ensureInitialized();
return blobFun;
public Mat predict(Mat mat) {
return DnnModel.super.predict(mat);
}

@Override
public PredictionFunction<NDList> getPredictionFunction() {
ensureInitialized();
return predictFun;
public List<Mat> batchPredict(List<? extends Mat> mats) {
return DnnModel.super.batchPredict(mats);
}

@Override
public synchronized void close() throws Exception {
if (model != null) {
model.close();
model = null;
blobFun = null;
predictFun = null;
logger.debug("Closed DjlDnnModel");
}
}

private static final String DEFAULT_MAT_LAYOUT = getLayout(LayoutType.HEIGHT, LayoutType.WIDTH, LayoutType.CHANNEL);

private static String getLayout(LayoutType... layouts) {
return LayoutType.toString(layouts);
}
Expand Down Expand Up @@ -192,63 +190,32 @@ private static String estimateOutputLayout(NDArray array) {



private class BlobFun implements BlobFunction<NDList> {
private class ModelMatTranslator implements NoBatchifyTranslator<Mat[], Mat[]> {

@Override
public NDList toBlob(Mat... mats) {
NDList list = new NDList();
String layout = ndLayout;
for (var mat : mats) {
// Try to figure out the layout
if (layout == null) {
layout = estimateInputLayout(mat);
}
list.add(DjlTools.matToNDArray(model.getNDManager(), mat, layout));
}
return list;
}

@Override
public List<Mat> fromBlob(NDList blob) {
public Mat[] processOutput(TranslatorContext ctx, NDList list) throws Exception {
String layout;
if ((ndLayout == null || ndLayout.length() != blob.singletonOrThrow().getShape().dimension()) && !blob.isEmpty())
layout = estimateOutputLayout(blob.get(0));
if ((ndLayout == null || ndLayout.length() != list.singletonOrThrow().getShape().dimension()) && !list.isEmpty())
layout = estimateOutputLayout(list.get(0));
else
layout = ndLayout;
var output = blob.stream().map(b -> DjlTools.ndArrayToMat(b, layout)).collect(Collectors.toList());
blob.close();
var output = list.stream().map(b -> DjlTools.ndArrayToMat(b, layout)).toArray(Mat[]::new);
return output;
}

}

private class PredictFun implements PredictionFunction<NDList> {

@Override
public NDList predict(NDList input) {
try {
NDList output;
// TODO: Check whether to support per-thread predictors
synchronized (predictor) {
output = predictor.batchPredict(Collections.singletonList(input)).get(0);
public NDList processInput(TranslatorContext ctx, Mat... input) throws Exception {
NDList list = new NDList();
String layout = ndLayout;
for (var mat : input) {
// Try to figure out the layout
if (layout == null) {
layout = estimateInputLayout(mat);
}
input.close();
return output;
} catch (TranslateException e) {
throw new RuntimeException(e);
list.add(DjlTools.matToNDArray(ctx.getNDManager(), mat, layout));
}
return list;
}

@Override
public Map<String, DnnShape> getInputs() {
return inputs;
}

@Override
public Map<String, DnnShape> getOutputs(DnnShape... inputShapes) {
return outputs;
}

}

@Override
Expand Down
6 changes: 4 additions & 2 deletions src/main/java/qupath/ext/djl/DjlDnnModelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
*
* @author Pete Bankhead
*/
public class DjlDnnModelBuilder implements DnnModelBuilder<NDList> {
public class DjlDnnModelBuilder implements DnnModelBuilder {

private static String getEngineName(String framework) {
if (DjlTools.ALL_ENGINES.contains(framework))
Expand All @@ -41,6 +41,8 @@ private static String getEngineName(String framework) {
switch(framework) {
case DnnModelParams.FRAMEWORK_TENSORFLOW:
return DjlTools.ENGINE_TENSORFLOW;
case DnnModelParams.FRAMEWORK_TF_LITE:
return DjlTools.ENGINE_TFLITE;
case DnnModelParams.FRAMEWORK_ONNX_RUNTIME:
return DjlTools.ENGINE_ONNX_RUNTIME;
case DnnModelParams.FRAMEWORK_PYTORCH:
Expand Down Expand Up @@ -112,7 +114,7 @@ private static String axesToLayout(String axes) {
}

@Override
public DnnModel<NDList> buildModel(DnnModelParams params) {
public DnnModel buildModel(DnnModelParams params) {
var framework = params.getFramework();
String engineName = null;
if (framework == null) {
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/qupath/ext/djl/DjlExtension.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class DjlExtension implements QuPathExtension, GitHubProject {

private final static Logger logger = LoggerFactory.getLogger(DjlExtension.class);

private final static DnnModelBuilder<?> builder = new DjlDnnModelBuilder();
private final static DnnModelBuilder builder = new DjlDnnModelBuilder();

static {
// Prevent downloading engines automatically
Expand Down
21 changes: 14 additions & 7 deletions src/main/java/qupath/ext/djl/DjlTools.java
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ public class DjlTools {
* @param inputShape expected input shape, according to ndLayout
* @return
*/
public static DnnModel<NDList> createDnnModel(URI uri, String ndLayout, int[] inputShape) {
public static DnnModel createDnnModel(URI uri, String ndLayout, int[] inputShape) {
DnnShape shape = null;
if (inputShape != null)
shape = DnnShape.of(Arrays.stream(inputShape).mapToLong(i -> i).toArray());
Expand All @@ -169,7 +169,7 @@ public static DnnModel<NDList> createDnnModel(URI uri, String ndLayout, int[] in
* @param outputs outputs shapes, if known; if these are null, an attempt will be made to get them from DJL (but this does not always work)
* @return
*/
public static DnnModel<NDList> createDnnModel(String engine, URI uri, String ndLayout, Map<String, DnnShape> inputs, Map<String, DnnShape> outputs) {
public static DnnModel createDnnModel(String engine, URI uri, String ndLayout, Map<String, DnnShape> inputs, Map<String, DnnShape> outputs) {
return createDnnModel(engine, Collections.singletonList(uri), ndLayout, inputs, outputs);
}

Expand All @@ -182,7 +182,7 @@ public static DnnModel<NDList> createDnnModel(String engine, URI uri, String ndL
* @param outputs outputs shapes, if known; if these are null, an attempt will be made to get them from DJL (but this does not always work)
* @return
*/
private static DnnModel<NDList> createDnnModel(String engine, Collection<URI> uris, String ndLayout, Map<String, DnnShape> inputs, Map<String, DnnShape> outputs) {
private static DnnModel createDnnModel(String engine, Collection<URI> uris, String ndLayout, Map<String, DnnShape> inputs, Map<String, DnnShape> outputs) {
return new DjlDnnModel(engine, uris, ndLayout, inputs, outputs, false); // Eagerly initialize (so we know if it doesn't work sooner)
}

Expand Down Expand Up @@ -291,8 +291,12 @@ public static Engine getEngine(String name, boolean downloadIfNeeded) throws Ill
static DnnShape convertShape(Shape shape) {
return DnnShape.of(shape.getShape());
}

static ZooModel<NDList, NDList> loadModel(String engineName, URI... uris) throws ModelNotFoundException, MalformedModelException, IOException {
return loadModel(engineName, NDList.class, NDList.class, null, uris);
}

static <P, Q> ZooModel<P, Q> loadModel(String engineName, Class<P> inputClass, Class<Q> outputClass, Translator<P, Q> translator, URI... uris) throws ModelNotFoundException, MalformedModelException, IOException {
var sb = new StringBuilder();
boolean isFirst = true;
for (var uri : uris) {
Expand All @@ -307,13 +311,14 @@ static ZooModel<NDList, NDList> loadModel(String engineName, URI... uris) throws
}
sb.append(uri.toString());
}
return loadModel(engineName, sb.toString());
return loadModel(engineName, inputClass, outputClass, translator, sb.toString());
}

private static ZooModel<NDList, NDList> loadModel(String engineName, String urls) throws ModelNotFoundException, MalformedModelException, IOException {
private static <P, Q> ZooModel<P, Q> loadModel(String engineName, Class<P> inputClass, Class<Q> outputClass, Translator<P, Q> translator, String urls) throws ModelNotFoundException, MalformedModelException, IOException {
var builder = Criteria.builder()
.setTypes(NDList.class, NDList.class)
.setTypes(inputClass, outputClass)
.optModelUrls(urls)
.optTranslator(translator)
.optProgress(new ProgressBar());

String selectedEngine = null;
Expand All @@ -330,6 +335,8 @@ private static ZooModel<NDList, NDList> loadModel(String engineName, String urls
selectedEngine = "OnnxRuntime";
else if ((urlString.endsWith("pytorch") || urlString.endsWith(".pt")) && Engine.hasEngine("PyTorch"))
selectedEngine = "PyTorch";
else if (urlString.endsWith(".tflite") && Engine.hasEngine("TFLite"))
selectedEngine = "TFLite";
else if ((urlString.endsWith(".pb") || urlString.endsWith("tf_savedmodel.zip") || urlString.endsWith("tf_savedmodel")) && Engine.hasEngine("TensorFlow"))
selectedEngine = "TensorFlow";
}
Expand Down

0 comments on commit 1e698ee

Please sign in to comment.