diff --git a/colormipsearch-tools/pom.xml b/colormipsearch-tools/pom.xml index 83d222fe..88357424 100644 --- a/colormipsearch-tools/pom.xml +++ b/colormipsearch-tools/pom.xml @@ -105,6 +105,10 @@ org.glassfish jakarta.el + + io.projectreactor + reactor-core + diff --git a/colormipsearch-tools/src/main/java/org/janelia/colormipsearch/cmd/CalculateGradientScoresCmd.java b/colormipsearch-tools/src/main/java/org/janelia/colormipsearch/cmd/CalculateGradientScoresCmd.java index 46e65194..d862f0a5 100644 --- a/colormipsearch-tools/src/main/java/org/janelia/colormipsearch/cmd/CalculateGradientScoresCmd.java +++ b/colormipsearch-tools/src/main/java/org/janelia/colormipsearch/cmd/CalculateGradientScoresCmd.java @@ -8,8 +8,8 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicLong; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -58,6 +58,11 @@ import org.janelia.colormipsearch.results.MatchEntitiesGrouping; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.slf4j.MDC; +import reactor.core.publisher.Flux; +import reactor.core.publisher.ParallelFlux; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; /** * Command to calculate the gradient scores. @@ -68,10 +73,6 @@ class CalculateGradientScoresCmd extends AbstractCmd { @Parameters(commandDescription = "Calculate gradient scores") static class CalculateGradientScoresArgs extends AbstractGradientScoresArgs { - @Parameter(names = {"--gradscore-parallelism"}, - description = "Specifies the degree of parallelism used for computing the gradscore") - int gradScoreParallelism; - @Parameter(names = {"--nBestLines"}, description = "Specifies the number of the top distinct lines to be used for gradient score") int numberOfBestLines; @@ -131,6 +132,7 @@ void execute() { private void calculateAllGradientScores() { long startTime = System.currentTimeMillis(); ImageRegionDefinition excludedRegions = args.getRegionGeneratorForTextLabels(); + ExecutorService executorService = CmdUtils.createCmdExecutor(args.commonArgs); ColorDepthSearchAlgorithmProvider gradScoreAlgorithmProvider = ColorDepthSearchAlgorithmProviderFactory.createShapeMatchCDSAlgorithmProvider( args.mirrorMask, args.negativeRadius, @@ -157,7 +159,6 @@ private void ca .setSize(larg.length)) .collect(Collectors.toList())); int size = matchesMasksToProcess.size(); - Executor executor = CmdUtils.createCmdExecutor(args.commonArgs); Stream>> masksPartitionedStream; // partition masks if (args.processPartitionsConcurrently) { @@ -178,7 +179,7 @@ private void ca cdMatchesWriter, cdmipsWriter, gradScoreAlgorithmProvider, - executor, + executorService, String.format("Partition %d", partitionId) ); }); @@ -195,13 +196,13 @@ void processMasks(List masksIds, NeuronMatchesWriter> cdMatchesWriter, CDMIPsWriter cdmipsWriter, ColorDepthSearchAlgorithmProvider shapeScoreAlgorithmProvider, - Executor executor, + ExecutorService executorService, String processingContext) { - LOG.info("Start {} - process {} masks", processingContext, masksIds.size()); + LOG.info("{} - start processing {} masks", processingContext, masksIds.size()); long startProcessingPartitionTime = System.currentTimeMillis(); long updatedMatches = 0; for (String maskId : masksIds) { - updatedMatches += processMask(maskId, cdMatchesReader, cdMatchesWriter, cdmipsWriter, shapeScoreAlgorithmProvider, executor, processingContext); + updatedMatches += processMask(maskId, cdMatchesReader, cdMatchesWriter, cdmipsWriter, shapeScoreAlgorithmProvider, executorService, processingContext); } LOG.info("Finished {} - completed {} masks, updated {} matches in {}s - memory usage {}M out of {}M", processingContext, @@ -218,8 +219,10 @@ long processMask(String maskId, NeuronMatchesWriter> cdMatchesWriter, CDMIPsWriter cdmipsWriter, ColorDepthSearchAlgorithmProvider gradScoreAlgorithmProvider, - Executor executor, + ExecutorService executorService, String processingContext) { + long startProcessingMask = System.currentTimeMillis(); + LOG.info("{} process mask {}", processingContext, maskId); // read all matches for the current mask List> cdMatchesForMask = getCDMatchesForMask(cdMatchesReader, maskId); long nPublishedNames = cdMatchesForMask.stream() @@ -239,25 +242,32 @@ long processMask(String maskId, maskId, (Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory()) / _1M + 1, // round up (Runtime.getRuntime().totalMemory() / _1M)); - List> cdMatchesWithGradScores = calculateGradientScores( + Scheduler scheduler = Schedulers.fromExecutorService(executorService); + Flux> cdMatchesWithGradScoresPublisher = calculateGradientScores( gradScoreAlgorithmProvider, cdMatchesForMask, - args.gradScoreParallelism, - executor); - LOG.info("{} - completed grad scores for {} matches of {} - memory usage {}M out of {}M", - processingContext, - cdMatchesWithGradScores.size(), - maskId, - (Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory()) / _1M + 1, // round up - (Runtime.getRuntime().totalMemory() / _1M)); - long writtenUpdates = updateCDMatches(cdMatchesWithGradScores, cdMatchesWriter); - LOG.info("{} - updated {} grad scores for {} matches of {} - memory usage {}M out of {}M", - processingContext, - writtenUpdates, - cdMatchesWithGradScores.size(), - maskId, - (Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory()) / _1M + 1, // round up - (Runtime.getRuntime().totalMemory() / _1M)); + scheduler); + AtomicLong nupdates = new AtomicLong(0); + cdMatchesWithGradScoresPublisher.collectList() + .map(cdMatchesWithGradScores -> { + LOG.info("{} - completed grad scores for {} matches of {} - memory usage {}M out of {}M", + processingContext, + cdMatchesWithGradScores.size(), + maskId, + (Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory()) / _1M + 1, // round up + (Runtime.getRuntime().totalMemory() / _1M)); + long writtenUpdates = updateCDMatches(cdMatchesWithGradScores, cdMatchesWriter); + LOG.info("{} - updated {} grad scores for {} matches of {} - memory usage {}M out of {}M", + processingContext, + writtenUpdates, + cdMatchesWithGradScores.size(), + maskId, + (Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory()) / _1M + 1, // round up + (Runtime.getRuntime().totalMemory() / _1M)); + nupdates.addAndGet(writtenUpdates); + return writtenUpdates; + }) + .block(); if (StringUtils.isNotBlank(args.processingTag)) { long updatesWithProcessedTag = updateProcessingTag(cdMatchesForMask, cdmipsWriter); LOG.info("{} - set processing tag {} for {} mips - memory usage {}M out of {}M", @@ -268,7 +278,7 @@ long processMask(String maskId, (Runtime.getRuntime().totalMemory() / _1M)); } System.gc(); // explicitly garbage collect - return writtenUpdates; + return nupdates.get(); } /** @@ -322,50 +332,34 @@ private CDMIPsWriter getCDMipsWriter() { /** * The method calculates and updates the gradient scores for all color depth matches of the given mask MIP ID. * - * @param shapeScoreAlgorithmProvider grad score algorithm provider - * @param cdMatches color depth matches for which the grad score will be computed - * @param gradScoreParallelism the degree of parallelism used for calculating the grad score(s) - * @param executor task executor - * @param mask type - * @param target type + * @param shapeScoreAlgorithmProvider shape score algorithm provider + * @param cdMatches color depth matches for which the grad score will be computed + * @param scheduler task scheduler + * @param mask type + * @param target type */ private - List> calculateGradientScores( + Flux> calculateGradientScores( ColorDepthSearchAlgorithmProvider shapeScoreAlgorithmProvider, List> cdMatches, - int gradScoreParallelism, - Executor executor) { - try { - // group the matches by the mask input file - this is because we do not want to mix FL and non-FL neuron images for example - List>> selectedMatchesGroupedByInput = - MatchEntitiesGrouping.simpleGroupByMaskFields( - cdMatches, - Arrays.asList( - AbstractNeuronEntity::getMipId, - m -> m.getComputeFileName(ComputeFileType.InputColorDepthImage) - ) - ); - List>>> gradScoreComputations = selectedMatchesGroupedByInput.stream() - .flatMap(selectedMaskMatches -> startGradScoreComputations( - selectedMaskMatches.getKey(), - selectedMaskMatches.getItems(), - shapeScoreAlgorithmProvider, - gradScoreParallelism, - executor - ).stream()) - .collect(Collectors.toList()); - // wait for all computation to finish - CompletableFuture[] allGradScoreComputations = gradScoreComputations.toArray(new CompletableFuture[0]); - CompletableFuture.allOf(allGradScoreComputations).join(); - - return gradScoreComputations.stream() - .map(CompletableFuture::join) - .flatMap(Collection::stream) - .filter(CDMatchEntity::hasGradScore) - .collect(Collectors.toList()); - } finally { - System.gc(); // force gc - } + Scheduler scheduler) { + // group the matches by the mask input file - this is because we do not want to mix FL and non-FL neuron images for example + List>> selectedMatchesGroupedByInput = + MatchEntitiesGrouping.simpleGroupByMaskFields( + cdMatches, + Arrays.asList( + AbstractNeuronEntity::getMipId, + m -> m.getComputeFileName(ComputeFileType.InputColorDepthImage) + ) + ); + return Flux.fromIterable(selectedMatchesGroupedByInput) + .flatMap(selectedMaskMatches -> startGradScoreComputations( + selectedMaskMatches.getKey(), + selectedMaskMatches.getItems(), + shapeScoreAlgorithmProvider, + scheduler)) + .filter(CDMatchEntity::hasGradScore) + ; } private long updateCDMatches(List> cdMatches, @@ -442,70 +436,63 @@ List> getCDMatchesForMask(NeuronMatchesReader - List>>> startGradScoreComputations(M mask, - List> selectedMatches, - ColorDepthSearchAlgorithmProvider shapeScoreAlgorithmProvider, - int gradScoreParallelism, - Executor executor) { + ParallelFlux> startGradScoreComputations(M mask, + List> selectedMatches, + ColorDepthSearchAlgorithmProvider shapeScoreAlgorithmProvider, + Scheduler scheduler) { if (CollectionUtils.isEmpty(selectedMatches)) { LOG.error("No matches were selected for {}", mask); - return Collections.emptyList(); + return ParallelFlux.from(); } LOG.info("Prepare gradient score computations for {} with {} matches", mask, selectedMatches.size()); LOG.info("Load query image {}", mask); NeuronMIP maskImage = NeuronMIPUtils.loadComputeFile(mask, ComputeFileType.InputColorDepthImage); if (NeuronMIPUtils.hasNoImageArray(maskImage)) { LOG.error("No image found for {}", mask); - return Collections.emptyList(); + return ParallelFlux.from(); } - ColorDepthSearchAlgorithm gradScoreAlgorithm = + ColorDepthSearchAlgorithm shapeScoreAlgorithm = shapeScoreAlgorithmProvider.createColorDepthQuerySearchAlgorithmWithDefaultParams( maskImage.getImageArray(), args.maskThreshold, args.borderSize); - Set requiredVariantTypes = gradScoreAlgorithm.getRequiredTargetVariantTypes(); - // partition size is the inverse of the parallelism (if zero process all in parallel) - int selectedMatchesPartitionSize = gradScoreParallelism > 0 - ? Math.max(selectedMatches.size() / gradScoreParallelism, 1) - : 1; - Map>> selectedMatchesParallelPartitions = - ItemsHandling.partitionCollection(selectedMatches, selectedMatchesPartitionSize); - LOG.info("Split {} matches of {} in {} partitions of size {}", - selectedMatches.size(), mask, selectedMatchesParallelPartitions.size(), selectedMatchesPartitionSize); - return selectedMatchesParallelPartitions - .values().stream() - .map(cdsMatches -> CompletableFuture.supplyAsync(() -> { - for (CDMatchEntity cdsMatch : cdsMatches) { - long startCalcTime = System.currentTimeMillis(); - T matchedTarget = cdsMatch.getMatchedImage(); - NeuronMIP matchedTargetImage = CachedMIPsUtils.loadMIP(matchedTarget, ComputeFileType.InputColorDepthImage); - if (NeuronMIPUtils.hasImageArray(matchedTargetImage)) { - LOG.debug("Calculate grad score between {} and {}", - cdsMatch.getMaskImage(), cdsMatch.getMatchedImage()); - ShapeMatchScore gradScore = gradScoreAlgorithm.calculateMatchingScore( - matchedTargetImage.getImageArray(), - NeuronMIPUtils.getImageLoaders( - matchedTarget, - requiredVariantTypes, - (n, cft) -> NeuronMIPUtils.getImageArray(CachedMIPsUtils.loadMIP(n, cft)) - ) - ); - cdsMatch.setBidirectionalAreaGap(gradScore.getBidirectionalAreaGap()); - cdsMatch.setGradientAreaGap(gradScore.getGradientAreaGap()); - cdsMatch.setHighExpressionArea(gradScore.getHighExpressionArea()); - cdsMatch.setNormalizedScore(gradScore.getNormalizedScore()); - LOG.debug("Finished calculating negative score between {} and {} in {}ms", - cdsMatch.getMaskImage(), cdsMatch.getMatchedImage(), System.currentTimeMillis() - startCalcTime); - } else { - cdsMatch.setBidirectionalAreaGap(-1L); - cdsMatch.setGradientAreaGap(-1L); - cdsMatch.setHighExpressionArea(-1L); - } + Set requiredVariantTypes = shapeScoreAlgorithm.getRequiredTargetVariantTypes(); + return Flux.fromIterable(selectedMatches) + .parallel() + .runOn(scheduler) + .doOnNext(cdsMatch -> { + long startCalcTime = System.currentTimeMillis(); + T matchedTarget = cdsMatch.getMatchedImage(); + MDC.put("maskId", mask.getMipId() + "/" + mask.getEntityId()); + MDC.put("targetId", matchedTarget.getMipId() + "/" + matchedTarget.getEntityId()); + NeuronMIP matchedTargetImage = CachedMIPsUtils.loadMIP(matchedTarget, ComputeFileType.InputColorDepthImage); + if (NeuronMIPUtils.hasImageArray(matchedTargetImage)) { + LOG.debug("Calculate grad score between {} and {}", + cdsMatch.getMaskImage(), cdsMatch.getMatchedImage()); + ShapeMatchScore gradScore = shapeScoreAlgorithm.calculateMatchingScore( + matchedTargetImage.getImageArray(), + NeuronMIPUtils.getImageLoaders( + matchedTarget, + requiredVariantTypes, + (n, cft) -> NeuronMIPUtils.getImageArray(CachedMIPsUtils.loadMIP(n, cft)) + ) + ); + cdsMatch.setBidirectionalAreaGap(gradScore.getBidirectionalAreaGap()); + cdsMatch.setGradientAreaGap(gradScore.getGradientAreaGap()); + cdsMatch.setHighExpressionArea(gradScore.getHighExpressionArea()); + cdsMatch.setNormalizedScore(gradScore.getNormalizedScore()); + LOG.debug("Finished calculating negative score between {} and {} in {}ms", + cdsMatch.getMaskImage(), cdsMatch.getMatchedImage(), System.currentTimeMillis() - startCalcTime); + } else { + LOG.info("No image found for {}", matchedTarget); + cdsMatch.setBidirectionalAreaGap(-1L); + cdsMatch.setGradientAreaGap(-1L); + cdsMatch.setHighExpressionArea(-1L); } - System.gc(); - return cdsMatches; - }, executor)) - .collect(Collectors.toList()); + MDC.remove("maskId"); + MDC.remove("targetId"); + }) + ; } private void updateNormalizedScores(List> cdMatches) { diff --git a/colormipsearch-tools/src/main/java/org/janelia/colormipsearch/cmd/CmdUtils.java b/colormipsearch-tools/src/main/java/org/janelia/colormipsearch/cmd/CmdUtils.java index b41ddf3b..59b927cd 100644 --- a/colormipsearch-tools/src/main/java/org/janelia/colormipsearch/cmd/CmdUtils.java +++ b/colormipsearch-tools/src/main/java/org/janelia/colormipsearch/cmd/CmdUtils.java @@ -1,6 +1,6 @@ package org.janelia.colormipsearch.cmd; -import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import com.google.common.util.concurrent.ThreadFactoryBuilder; @@ -11,7 +11,7 @@ public class CmdUtils { private static final Logger LOG = LoggerFactory.getLogger(CmdUtils.class); - static Executor createCmdExecutor(CommonArgs args) { + static ExecutorService createCmdExecutor(CommonArgs args) { if (args.taskConcurrency > 0) { LOG.info("Create a thread pool with {} worker threads ({} available processors for workstealing pool)", args.taskConcurrency, Runtime.getRuntime().availableProcessors()); diff --git a/pom.xml b/pom.xml index 447d3182..c31dbdc7 100644 --- a/pom.xml +++ b/pom.xml @@ -339,6 +339,14 @@ 5.1.2 + + io.projectreactor + reactor-bom + 2023.0.8 + pom + import + + junit