diff --git a/colormipsearch-tools/pom.xml b/colormipsearch-tools/pom.xml
index 83d222fe..88357424 100644
--- a/colormipsearch-tools/pom.xml
+++ b/colormipsearch-tools/pom.xml
@@ -105,6 +105,10 @@
org.glassfish
jakarta.el
+
+ io.projectreactor
+ reactor-core
+
diff --git a/colormipsearch-tools/src/main/java/org/janelia/colormipsearch/cmd/CalculateGradientScoresCmd.java b/colormipsearch-tools/src/main/java/org/janelia/colormipsearch/cmd/CalculateGradientScoresCmd.java
index 46e65194..d862f0a5 100644
--- a/colormipsearch-tools/src/main/java/org/janelia/colormipsearch/cmd/CalculateGradientScoresCmd.java
+++ b/colormipsearch-tools/src/main/java/org/janelia/colormipsearch/cmd/CalculateGradientScoresCmd.java
@@ -8,8 +8,8 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.Executor;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@@ -58,6 +58,11 @@
import org.janelia.colormipsearch.results.MatchEntitiesGrouping;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.slf4j.MDC;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.ParallelFlux;
+import reactor.core.scheduler.Scheduler;
+import reactor.core.scheduler.Schedulers;
/**
* Command to calculate the gradient scores.
@@ -68,10 +73,6 @@ class CalculateGradientScoresCmd extends AbstractCmd {
@Parameters(commandDescription = "Calculate gradient scores")
static class CalculateGradientScoresArgs extends AbstractGradientScoresArgs {
- @Parameter(names = {"--gradscore-parallelism"},
- description = "Specifies the degree of parallelism used for computing the gradscore")
- int gradScoreParallelism;
-
@Parameter(names = {"--nBestLines"},
description = "Specifies the number of the top distinct lines to be used for gradient score")
int numberOfBestLines;
@@ -131,6 +132,7 @@ void execute() {
private void calculateAllGradientScores() {
long startTime = System.currentTimeMillis();
ImageRegionDefinition excludedRegions = args.getRegionGeneratorForTextLabels();
+ ExecutorService executorService = CmdUtils.createCmdExecutor(args.commonArgs);
ColorDepthSearchAlgorithmProvider gradScoreAlgorithmProvider = ColorDepthSearchAlgorithmProviderFactory.createShapeMatchCDSAlgorithmProvider(
args.mirrorMask,
args.negativeRadius,
@@ -157,7 +159,6 @@ private void ca
.setSize(larg.length))
.collect(Collectors.toList()));
int size = matchesMasksToProcess.size();
- Executor executor = CmdUtils.createCmdExecutor(args.commonArgs);
Stream>> masksPartitionedStream;
// partition masks
if (args.processPartitionsConcurrently) {
@@ -178,7 +179,7 @@ private void ca
cdMatchesWriter,
cdmipsWriter,
gradScoreAlgorithmProvider,
- executor,
+ executorService,
String.format("Partition %d", partitionId)
);
});
@@ -195,13 +196,13 @@ void processMasks(List masksIds,
NeuronMatchesWriter> cdMatchesWriter,
CDMIPsWriter cdmipsWriter,
ColorDepthSearchAlgorithmProvider shapeScoreAlgorithmProvider,
- Executor executor,
+ ExecutorService executorService,
String processingContext) {
- LOG.info("Start {} - process {} masks", processingContext, masksIds.size());
+ LOG.info("{} - start processing {} masks", processingContext, masksIds.size());
long startProcessingPartitionTime = System.currentTimeMillis();
long updatedMatches = 0;
for (String maskId : masksIds) {
- updatedMatches += processMask(maskId, cdMatchesReader, cdMatchesWriter, cdmipsWriter, shapeScoreAlgorithmProvider, executor, processingContext);
+ updatedMatches += processMask(maskId, cdMatchesReader, cdMatchesWriter, cdmipsWriter, shapeScoreAlgorithmProvider, executorService, processingContext);
}
LOG.info("Finished {} - completed {} masks, updated {} matches in {}s - memory usage {}M out of {}M",
processingContext,
@@ -218,8 +219,10 @@ long processMask(String maskId,
NeuronMatchesWriter> cdMatchesWriter,
CDMIPsWriter cdmipsWriter,
ColorDepthSearchAlgorithmProvider gradScoreAlgorithmProvider,
- Executor executor,
+ ExecutorService executorService,
String processingContext) {
+ long startProcessingMask = System.currentTimeMillis();
+ LOG.info("{} process mask {}", processingContext, maskId);
// read all matches for the current mask
List> cdMatchesForMask = getCDMatchesForMask(cdMatchesReader, maskId);
long nPublishedNames = cdMatchesForMask.stream()
@@ -239,25 +242,32 @@ long processMask(String maskId,
maskId,
(Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory()) / _1M + 1, // round up
(Runtime.getRuntime().totalMemory() / _1M));
- List> cdMatchesWithGradScores = calculateGradientScores(
+ Scheduler scheduler = Schedulers.fromExecutorService(executorService);
+ Flux> cdMatchesWithGradScoresPublisher = calculateGradientScores(
gradScoreAlgorithmProvider,
cdMatchesForMask,
- args.gradScoreParallelism,
- executor);
- LOG.info("{} - completed grad scores for {} matches of {} - memory usage {}M out of {}M",
- processingContext,
- cdMatchesWithGradScores.size(),
- maskId,
- (Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory()) / _1M + 1, // round up
- (Runtime.getRuntime().totalMemory() / _1M));
- long writtenUpdates = updateCDMatches(cdMatchesWithGradScores, cdMatchesWriter);
- LOG.info("{} - updated {} grad scores for {} matches of {} - memory usage {}M out of {}M",
- processingContext,
- writtenUpdates,
- cdMatchesWithGradScores.size(),
- maskId,
- (Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory()) / _1M + 1, // round up
- (Runtime.getRuntime().totalMemory() / _1M));
+ scheduler);
+ AtomicLong nupdates = new AtomicLong(0);
+ cdMatchesWithGradScoresPublisher.collectList()
+ .map(cdMatchesWithGradScores -> {
+ LOG.info("{} - completed grad scores for {} matches of {} - memory usage {}M out of {}M",
+ processingContext,
+ cdMatchesWithGradScores.size(),
+ maskId,
+ (Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory()) / _1M + 1, // round up
+ (Runtime.getRuntime().totalMemory() / _1M));
+ long writtenUpdates = updateCDMatches(cdMatchesWithGradScores, cdMatchesWriter);
+ LOG.info("{} - updated {} grad scores for {} matches of {} - memory usage {}M out of {}M",
+ processingContext,
+ writtenUpdates,
+ cdMatchesWithGradScores.size(),
+ maskId,
+ (Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory()) / _1M + 1, // round up
+ (Runtime.getRuntime().totalMemory() / _1M));
+ nupdates.addAndGet(writtenUpdates);
+ return writtenUpdates;
+ })
+ .block();
if (StringUtils.isNotBlank(args.processingTag)) {
long updatesWithProcessedTag = updateProcessingTag(cdMatchesForMask, cdmipsWriter);
LOG.info("{} - set processing tag {} for {} mips - memory usage {}M out of {}M",
@@ -268,7 +278,7 @@ long processMask(String maskId,
(Runtime.getRuntime().totalMemory() / _1M));
}
System.gc(); // explicitly garbage collect
- return writtenUpdates;
+ return nupdates.get();
}
/**
@@ -322,50 +332,34 @@ private CDMIPsWriter getCDMipsWriter() {
/**
* The method calculates and updates the gradient scores for all color depth matches of the given mask MIP ID.
*
- * @param shapeScoreAlgorithmProvider grad score algorithm provider
- * @param cdMatches color depth matches for which the grad score will be computed
- * @param gradScoreParallelism the degree of parallelism used for calculating the grad score(s)
- * @param executor task executor
- * @param mask type
- * @param target type
+ * @param shapeScoreAlgorithmProvider shape score algorithm provider, handed unchanged to startGradScoreComputations
+ * @param cdMatches color depth matches for which the grad score will be computed; they are first grouped by mask MIP ID and input color depth image file name
+ * @param scheduler task scheduler, handed unchanged to startGradScoreComputations for every group of matches
+ * @param mask type
+ * @param target type
*/
private
- List> calculateGradientScores(
+ Flux> calculateGradientScores(
ColorDepthSearchAlgorithmProvider shapeScoreAlgorithmProvider,
List> cdMatches,
- int gradScoreParallelism,
- Executor executor) {
- try {
- // group the matches by the mask input file - this is because we do not want to mix FL and non-FL neuron images for example
- List>> selectedMatchesGroupedByInput =
- MatchEntitiesGrouping.simpleGroupByMaskFields(
- cdMatches,
- Arrays.asList(
- AbstractNeuronEntity::getMipId,
- m -> m.getComputeFileName(ComputeFileType.InputColorDepthImage)
- )
- );
- List>>> gradScoreComputations = selectedMatchesGroupedByInput.stream()
- .flatMap(selectedMaskMatches -> startGradScoreComputations(
- selectedMaskMatches.getKey(),
- selectedMaskMatches.getItems(),
- shapeScoreAlgorithmProvider,
- gradScoreParallelism,
- executor
- ).stream())
- .collect(Collectors.toList());
- // wait for all computation to finish
- CompletableFuture>[] allGradScoreComputations = gradScoreComputations.toArray(new CompletableFuture>[0]);
- CompletableFuture.allOf(allGradScoreComputations).join();
-
- return gradScoreComputations.stream()
- .map(CompletableFuture::join)
- .flatMap(Collection::stream)
- .filter(CDMatchEntity::hasGradScore)
- .collect(Collectors.toList());
- } finally {
- System.gc(); // force gc
- }
+ Scheduler scheduler) {
+ // group the matches by the mask input file - this is because we do not want to mix FL and non-FL neuron images for example
+ List>> selectedMatchesGroupedByInput =
+ MatchEntitiesGrouping.simpleGroupByMaskFields(
+ cdMatches,
+ Arrays.asList(
+ AbstractNeuronEntity::getMipId,
+ m -> m.getComputeFileName(ComputeFileType.InputColorDepthImage)
+ )
+ );
+ return Flux.fromIterable(selectedMatchesGroupedByInput) // one startGradScoreComputations publisher per group, flattened into a single Flux
+ .flatMap(selectedMaskMatches -> startGradScoreComputations(
+ selectedMaskMatches.getKey(),
+ selectedMaskMatches.getItems(),
+ shapeScoreAlgorithmProvider,
+ scheduler))
+ .filter(CDMatchEntity::hasGradScore)
+ ;
}
private long updateCDMatches(List> cdMatches,
@@ -442,70 +436,67 @@ List> getCDMatchesForMask(NeuronMatchesReader
- List>>> startGradScoreComputations(M mask,
- List> selectedMatches,
- ColorDepthSearchAlgorithmProvider shapeScoreAlgorithmProvider,
- int gradScoreParallelism,
- Executor executor) {
+ ParallelFlux> startGradScoreComputations(M mask,
+ List> selectedMatches,
+ ColorDepthSearchAlgorithmProvider shapeScoreAlgorithmProvider,
+ Scheduler scheduler) {
if (CollectionUtils.isEmpty(selectedMatches)) {
LOG.error("No matches were selected for {}", mask);
- return Collections.emptyList();
+ return ParallelFlux.from(Flux.empty()); // the zero-argument from() is the varargs overload, which rejects an empty publisher array
}
LOG.info("Prepare gradient score computations for {} with {} matches", mask, selectedMatches.size());
LOG.info("Load query image {}", mask);
NeuronMIP maskImage = NeuronMIPUtils.loadComputeFile(mask, ComputeFileType.InputColorDepthImage);
if (NeuronMIPUtils.hasNoImageArray(maskImage)) {
LOG.error("No image found for {}", mask);
- return Collections.emptyList();
+ return ParallelFlux.from(Flux.empty()); // nothing to score without the mask image: emit no matches instead of throwing
}
- ColorDepthSearchAlgorithm gradScoreAlgorithm =
+ ColorDepthSearchAlgorithm shapeScoreAlgorithm =
shapeScoreAlgorithmProvider.createColorDepthQuerySearchAlgorithmWithDefaultParams(
maskImage.getImageArray(),
args.maskThreshold,
args.borderSize);
- Set requiredVariantTypes = gradScoreAlgorithm.getRequiredTargetVariantTypes();
- // partition size is the inverse of the parallelism (if zero process all in parallel)
- int selectedMatchesPartitionSize = gradScoreParallelism > 0
- ? Math.max(selectedMatches.size() / gradScoreParallelism, 1)
- : 1;
- Map>> selectedMatchesParallelPartitions =
- ItemsHandling.partitionCollection(selectedMatches, selectedMatchesPartitionSize);
- LOG.info("Split {} matches of {} in {} partitions of size {}",
- selectedMatches.size(), mask, selectedMatchesParallelPartitions.size(), selectedMatchesPartitionSize);
- return selectedMatchesParallelPartitions
- .values().stream()
- .map(cdsMatches -> CompletableFuture.supplyAsync(() -> {
- for (CDMatchEntity cdsMatch : cdsMatches) {
- long startCalcTime = System.currentTimeMillis();
- T matchedTarget = cdsMatch.getMatchedImage();
- NeuronMIP matchedTargetImage = CachedMIPsUtils.loadMIP(matchedTarget, ComputeFileType.InputColorDepthImage);
- if (NeuronMIPUtils.hasImageArray(matchedTargetImage)) {
- LOG.debug("Calculate grad score between {} and {}",
- cdsMatch.getMaskImage(), cdsMatch.getMatchedImage());
- ShapeMatchScore gradScore = gradScoreAlgorithm.calculateMatchingScore(
- matchedTargetImage.getImageArray(),
- NeuronMIPUtils.getImageLoaders(
- matchedTarget,
- requiredVariantTypes,
- (n, cft) -> NeuronMIPUtils.getImageArray(CachedMIPsUtils.loadMIP(n, cft))
- )
- );
- cdsMatch.setBidirectionalAreaGap(gradScore.getBidirectionalAreaGap());
- cdsMatch.setGradientAreaGap(gradScore.getGradientAreaGap());
- cdsMatch.setHighExpressionArea(gradScore.getHighExpressionArea());
- cdsMatch.setNormalizedScore(gradScore.getNormalizedScore());
- LOG.debug("Finished calculating negative score between {} and {} in {}ms",
- cdsMatch.getMaskImage(), cdsMatch.getMatchedImage(), System.currentTimeMillis() - startCalcTime);
- } else {
- cdsMatch.setBidirectionalAreaGap(-1L);
- cdsMatch.setGradientAreaGap(-1L);
- cdsMatch.setHighExpressionArea(-1L);
- }
+ Set requiredVariantTypes = shapeScoreAlgorithm.getRequiredTargetVariantTypes();
+ // score every selected match on the given scheduler; the scores are stored on the match itself, which is then emitted downstream
+ return Flux.fromIterable(selectedMatches)
+ .parallel()
+ .runOn(scheduler)
+ .doOnNext(cdsMatch -> {
+ long startCalcTime = System.currentTimeMillis();
+ T matchedTarget = cdsMatch.getMatchedImage();
+ try {
+ MDC.put("maskId", mask.getMipId() + "/" + mask.getEntityId());
+ MDC.put("targetId", matchedTarget.getMipId() + "/" + matchedTarget.getEntityId());
+ NeuronMIP matchedTargetImage = CachedMIPsUtils.loadMIP(matchedTarget, ComputeFileType.InputColorDepthImage);
+ if (NeuronMIPUtils.hasImageArray(matchedTargetImage)) {
+ LOG.debug("Calculate grad score between {} and {}",
+ cdsMatch.getMaskImage(), cdsMatch.getMatchedImage());
+ ShapeMatchScore gradScore = shapeScoreAlgorithm.calculateMatchingScore(
+ matchedTargetImage.getImageArray(),
+ NeuronMIPUtils.getImageLoaders(
+ matchedTarget,
+ requiredVariantTypes,
+ (n, cft) -> NeuronMIPUtils.getImageArray(CachedMIPsUtils.loadMIP(n, cft))
+ )
+ );
+ cdsMatch.setBidirectionalAreaGap(gradScore.getBidirectionalAreaGap());
+ cdsMatch.setGradientAreaGap(gradScore.getGradientAreaGap());
+ cdsMatch.setHighExpressionArea(gradScore.getHighExpressionArea());
+ cdsMatch.setNormalizedScore(gradScore.getNormalizedScore());
+ LOG.debug("Finished calculating negative score between {} and {} in {}ms",
+ cdsMatch.getMaskImage(), cdsMatch.getMatchedImage(), System.currentTimeMillis() - startCalcTime);
+ } else {
+ LOG.info("No image found for {}", matchedTarget);
+ cdsMatch.setBidirectionalAreaGap(-1L);
+ cdsMatch.setGradientAreaGap(-1L);
+ cdsMatch.setHighExpressionArea(-1L);
}
+ } finally { // clear the MDC even when loading or scoring throws, so the ids do not stay on the worker thread
- System.gc();
- return cdsMatches;
- }, executor))
- .collect(Collectors.toList());
+ MDC.remove("maskId");
+ MDC.remove("targetId");
+ }
+ })
+ ;
}
private void updateNormalizedScores(List> cdMatches) {
diff --git a/colormipsearch-tools/src/main/java/org/janelia/colormipsearch/cmd/CmdUtils.java b/colormipsearch-tools/src/main/java/org/janelia/colormipsearch/cmd/CmdUtils.java
index b41ddf3b..59b927cd 100644
--- a/colormipsearch-tools/src/main/java/org/janelia/colormipsearch/cmd/CmdUtils.java
+++ b/colormipsearch-tools/src/main/java/org/janelia/colormipsearch/cmd/CmdUtils.java
@@ -1,6 +1,6 @@
package org.janelia.colormipsearch.cmd;
-import java.util.concurrent.Executor;
+import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
@@ -11,7 +11,7 @@
public class CmdUtils {
private static final Logger LOG = LoggerFactory.getLogger(CmdUtils.class);
- static Executor createCmdExecutor(CommonArgs args) {
+ static ExecutorService createCmdExecutor(CommonArgs args) {
if (args.taskConcurrency > 0) {
LOG.info("Create a thread pool with {} worker threads ({} available processors for workstealing pool)",
args.taskConcurrency, Runtime.getRuntime().availableProcessors());
diff --git a/pom.xml b/pom.xml
index 447d3182..c31dbdc7 100644
--- a/pom.xml
+++ b/pom.xml
@@ -339,6 +339,14 @@
5.1.2
+
+ io.projectreactor
+ reactor-bom
+ 2023.0.8
+ pom
+ import
+
+
junit