Skip to content

Commit

Permalink
replaced explicit CDMatchEntity mask and target types with generics -…
Browse files Browse the repository at this point in the history
… in order to allow running flywire vs hemibrain which is EM vs EM
  • Loading branch information
Cristian Goina committed Sep 11, 2024
1 parent 542d008 commit ff96b36
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@
import org.janelia.colormipsearch.model.AbstractNeuronEntity;
import org.janelia.colormipsearch.model.CDMatchEntity;
import org.janelia.colormipsearch.model.ComputeFileType;
import org.janelia.colormipsearch.model.EMNeuronEntity;
import org.janelia.colormipsearch.model.FileData;
import org.janelia.colormipsearch.model.LMNeuronEntity;
import org.janelia.colormipsearch.model.ProcessingType;
import org.janelia.colormipsearch.results.GroupedMatchedEntities;
import org.janelia.colormipsearch.results.ItemsHandling;
Expand Down Expand Up @@ -125,7 +123,7 @@ void execute() {
calculateAllGradientScores();
}

private void calculateAllGradientScores() {
private <M extends AbstractNeuronEntity, T extends AbstractNeuronEntity> void calculateAllGradientScores() {
long startTime = System.currentTimeMillis();
ImageRegionDefinition excludedRegions = args.getRegionGeneratorForTextLabels();
ColorDepthSearchAlgorithmProvider<ShapeMatchScore> gradScoreAlgorithmProvider = ColorDepthSearchAlgorithmProviderFactory.createShapeMatchCDSAlgorithmProvider(
Expand All @@ -135,8 +133,8 @@ private void calculateAllGradientScores() {
loadQueryROIMask(args.queryROIMaskName),
excludedRegions
);
NeuronMatchesReader<CDMatchEntity<EMNeuronEntity, LMNeuronEntity>> cdMatchesReader = getCDMatchesReader();
NeuronMatchesWriter<CDMatchEntity<EMNeuronEntity, LMNeuronEntity>> cdMatchesWriter = getCDMatchesWriter();
NeuronMatchesReader<CDMatchEntity<M, T>> cdMatchesReader = getCDMatchesReader();
NeuronMatchesWriter<CDMatchEntity<M, T>> cdMatchesWriter = getCDMatchesWriter();
CDMIPsWriter cdmipsWriter = getCDMipsWriter();
Collection<String> matchesMasksToProcess = cdMatchesReader.listMatchesLocations(
args.masksLibraries.stream()
Expand Down Expand Up @@ -186,13 +184,13 @@ private void calculateAllGradientScores() {
(Runtime.getRuntime().totalMemory() / _1M));
}

private void processMasks(List<String> masksIds,
NeuronMatchesReader<CDMatchEntity<EMNeuronEntity, LMNeuronEntity>> cdMatchesReader,
NeuronMatchesWriter<CDMatchEntity<EMNeuronEntity, LMNeuronEntity>> cdMatchesWriter,
CDMIPsWriter cdmipsWriter,
ColorDepthSearchAlgorithmProvider<ShapeMatchScore> gradScoreAlgorithmProvider,
Executor executor,
String processingContext) {
private <M extends AbstractNeuronEntity, T extends AbstractNeuronEntity> void processMasks(List<String> masksIds,
NeuronMatchesReader<CDMatchEntity<M, T>> cdMatchesReader,
NeuronMatchesWriter<CDMatchEntity<M, T>> cdMatchesWriter,
CDMIPsWriter cdmipsWriter,
ColorDepthSearchAlgorithmProvider<ShapeMatchScore> gradScoreAlgorithmProvider,
Executor executor,
String processingContext) {
LOG.info("Start {} - process {} masks", processingContext, masksIds.size());
long startProcessingPartitionTime = System.currentTimeMillis();
long updatedMatches = 0;
Expand All @@ -208,15 +206,15 @@ private void processMasks(List<String> masksIds,
(Runtime.getRuntime().totalMemory() / _1M));
}

private long processMask(String maskId,
NeuronMatchesReader<CDMatchEntity<EMNeuronEntity, LMNeuronEntity>> cdMatchesReader,
NeuronMatchesWriter<CDMatchEntity<EMNeuronEntity, LMNeuronEntity>> cdMatchesWriter,
private <M extends AbstractNeuronEntity, T extends AbstractNeuronEntity> long processMask(String maskId,
NeuronMatchesReader<CDMatchEntity<M, T>> cdMatchesReader,
NeuronMatchesWriter<CDMatchEntity<M, T>> cdMatchesWriter,
CDMIPsWriter cdmipsWriter,
ColorDepthSearchAlgorithmProvider<ShapeMatchScore> gradScoreAlgorithmProvider,
Executor executor,
String processingContext) {
// read all matches for the current mask
List<CDMatchEntity<EMNeuronEntity, LMNeuronEntity>> cdMatchesForMask = getCDMatchesForMask(cdMatchesReader, maskId);
List<CDMatchEntity<M, T>> cdMatchesForMask = getCDMatchesForMask(cdMatchesReader, maskId);
long nPublishedNames = cdMatchesForMask.stream()
.map(cdm -> cdm.getMatchedImage().getPublishedName())
.distinct()
Expand All @@ -234,7 +232,7 @@ private long processMask(String maskId,
maskId,
(Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory()) / _1M + 1, // round up
(Runtime.getRuntime().totalMemory() / _1M));
List<CDMatchEntity<EMNeuronEntity, LMNeuronEntity>> cdMatchesWithGradScores = calculateGradientScores(
List<CDMatchEntity<M, T>> cdMatchesWithGradScores = calculateGradientScores(
gradScoreAlgorithmProvider,
cdMatchesForMask,
args.gradScoreParallelism,
Expand Down Expand Up @@ -437,10 +435,10 @@ List<CDMatchEntity<M, T>> getCDMatchesForMask(NeuronMatchesReader<CDMatchEntity<

private <M extends AbstractNeuronEntity, T extends AbstractNeuronEntity>
List<CompletableFuture<List<CDMatchEntity<M, T>>>> startGradScoreComputations(M mask,
List<CDMatchEntity<M, T>> selectedMatches,
ColorDepthSearchAlgorithmProvider<ShapeMatchScore> gradScoreAlgorithmProvider,
int gradScoreParallelism,
Executor executor) {
List<CDMatchEntity<M, T>> selectedMatches,
ColorDepthSearchAlgorithmProvider<ShapeMatchScore> gradScoreAlgorithmProvider,
int gradScoreParallelism,
Executor executor) {
if (CollectionUtils.isEmpty(selectedMatches)) {
LOG.error("No matches were selected for {}", mask);
return Collections.emptyList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@
import org.janelia.colormipsearch.model.AbstractNeuronEntity;
import org.janelia.colormipsearch.model.CDMatchEntity;
import org.janelia.colormipsearch.model.ComputeFileType;
import org.janelia.colormipsearch.model.EMNeuronEntity;
import org.janelia.colormipsearch.model.LMNeuronEntity;
import org.janelia.colormipsearch.model.ProcessingType;
import org.janelia.colormipsearch.results.ItemsHandling;
import org.slf4j.Logger;
Expand Down Expand Up @@ -89,9 +87,9 @@ void execute() {
normalizeAllGradientScores();
}

private void normalizeAllGradientScores() {
private <M extends AbstractNeuronEntity, T extends AbstractNeuronEntity> void normalizeAllGradientScores() {
long startTime = System.currentTimeMillis();
NeuronMatchesReader<CDMatchEntity<EMNeuronEntity, LMNeuronEntity>> cdMatchesReader = getCDMatchesReader();
NeuronMatchesReader<CDMatchEntity<M, T>> cdMatchesReader = getCDMatchesReader();
Collection<String> matchesMasksToProcess = cdMatchesReader.listMatchesLocations(
args.masksLibraries.stream()
.map(larg -> new DataSourceParam()
Expand All @@ -115,7 +113,7 @@ private void normalizeAllGradientScores() {
// process each item from the current partition sequentially
indexedPartition.getValue().forEach(maskIdToProcess -> {
// read all matches for the current mask
List<CDMatchEntity<EMNeuronEntity, LMNeuronEntity>> cdMatchesForMask = getCDMatchesForMask(cdMatchesReader, maskIdToProcess);
List<CDMatchEntity<M, T>> cdMatchesForMask = getCDMatchesForMask(cdMatchesReader, maskIdToProcess);
// normalize the grad scores
LOG.info("Normalize grad scores for {} matches of {}", cdMatchesForMask.size(), maskIdToProcess);
updateNormalizedScores(cdMatchesForMask);
Expand Down

0 comments on commit ff96b36

Please sign in to comment.