Skip to content

Commit

Permalink
get matches and mips writer only once
Browse files Browse the repository at this point in the history
  • Loading branch information
Cristian Goina committed Apr 3, 2024
1 parent 2184f59 commit b7890cc
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,4 @@ public long write(List<R> matches) {
public long writeUpdates(List<R> matches, List<Function<R, Pair<String, ?>>> fieldSelectors) {
return neuronMatchesDao.updateExistingMatches(matches, fieldSelectors);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ private void calculateAllGradientScores() {
excludedRegions
);
NeuronMatchesReader<CDMatchEntity<EMNeuronEntity, LMNeuronEntity>> cdMatchesReader = getCDMatchesReader();
NeuronMatchesWriter<CDMatchEntity<EMNeuronEntity, LMNeuronEntity>> matchesWriter = getCDMatchesWriter();
CDMIPsWriter cdmipsWriter = getCDMipsWriter();
List<String> matchesMasksToProcess = cdMatchesReader.listMatchesLocations(
args.masksLibraries.stream()
.map(larg -> new DataSourceParam()
Expand All @@ -149,7 +151,7 @@ private void calculateAllGradientScores() {
partitionId,
partionMasks.size());
long startProcessingPartitionTime = System.currentTimeMillis();
// process each item from the current partition sequentially
// process each item from the current partition sequentially
partionMasks.forEach(maskIdToProcess -> {
// read all matches for the current mask
List<CDMatchEntity<EMNeuronEntity, LMNeuronEntity>> cdMatchesForMask = getCDMatchesForMask(cdMatchesReader, maskIdToProcess);
Expand All @@ -172,21 +174,21 @@ private void calculateAllGradientScores() {
cdMatchesWithGradScores.size(), maskIdToProcess,
(Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory()) / _1M + 1, // round up
(Runtime.getRuntime().totalMemory() / _1M));
long writtenUpdates = updateCDMatches(cdMatchesWithGradScores);
long writtenUpdates = updateCDMatches(cdMatchesWithGradScores, matchesWriter);
LOG.info("Partition {} - updated {} grad scores for {} matches of {} - memory usage {}M out of {}M",
partitionId,
writtenUpdates, cdMatchesWithGradScores.size(), maskIdToProcess,
(Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory()) / _1M + 1, // round up
(Runtime.getRuntime().totalMemory() / _1M));
if (StringUtils.isNotBlank(args.processingTag)) {
long updatesWithProcessedTag = updateProcessingTag(cdMatchesForMask);
long updatesWithProcessedTag = updateProcessingTag(cdMatchesForMask, cdmipsWriter);
LOG.info("Partition {} - set processing tag {} for {} mips - memory usage {}M out of {}M",
partitionId, args.getProcessingTag(), updatesWithProcessedTag,
(Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory()) / _1M + 1, // round up
(Runtime.getRuntime().totalMemory() / _1M));
}
System.gc(); // explicitly garbage collect
});
System.gc(); // explicitly garbage collect
LOG.info("Finished partition {} ({} items) in {}s - memory usage {}M out of {}M",
partitionId,
partionMasks.size(),
Expand Down Expand Up @@ -241,11 +243,11 @@ private <M extends AbstractNeuronEntity, T extends AbstractNeuronEntity> NeuronM
}
}

private Optional<CDMIPsWriter> getCDMipsWriter() {
private CDMIPsWriter getCDMipsWriter() {
if (args.commonArgs.resultsStorage == StorageType.DB) {
return Optional.of(new DBCheckedCDMIPsWriter(getDaosProvider().getNeuronMetadataDao()));
return new DBCheckedCDMIPsWriter(getDaosProvider().getNeuronMetadataDao());
} else {
return Optional.empty();
return null;
}
}

Expand Down Expand Up @@ -292,8 +294,8 @@ List<CDMatchEntity<M, T>> calculateGradientScores(
return matchesWithGradScores;
}

private <M extends AbstractNeuronEntity, T extends AbstractNeuronEntity> long updateCDMatches(List<CDMatchEntity<M, T>> cdMatches) {
NeuronMatchesWriter<CDMatchEntity<M, T>> matchesWriter = getCDMatchesWriter();
private <M extends AbstractNeuronEntity, T extends AbstractNeuronEntity> long updateCDMatches(List<CDMatchEntity<M, T>> cdMatches,
NeuronMatchesWriter<CDMatchEntity<M, T>> matchesWriter) {
return matchesWriter.writeUpdates(
cdMatches,
Arrays.asList(
Expand All @@ -304,19 +306,20 @@ private <M extends AbstractNeuronEntity, T extends AbstractNeuronEntity> long up
));
}

private <M extends AbstractNeuronEntity, T extends AbstractNeuronEntity> long updateProcessingTag(List<CDMatchEntity<M, T>> cdMatches) {
Set<String> processingTags = Collections.singleton(args.getProcessingTag());
return getCDMipsWriter()
.map(cdmipsWriter -> {
Set<M> masksToUpdate = cdMatches.stream()
.map(AbstractMatchEntity::getMaskImage).collect(Collectors.toSet());
Set<T> targetsToUpdate = cdMatches.stream()
.map(AbstractMatchEntity::getMatchedImage).collect(Collectors.toSet());
cdmipsWriter.addProcessingTags(masksToUpdate, ProcessingType.GradientScore, processingTags);
cdmipsWriter.addProcessingTags(targetsToUpdate, ProcessingType.GradientScore, processingTags);
return masksToUpdate.size() + targetsToUpdate.size();
})
.orElse(0);
private <M extends AbstractNeuronEntity, T extends AbstractNeuronEntity> long updateProcessingTag(List<CDMatchEntity<M, T>> cdMatches,
CDMIPsWriter cdmipsWriter) {
if (cdmipsWriter != null) {
Set<String> processingTags = Collections.singleton(args.getProcessingTag());
Set<M> masksToUpdate = cdMatches.stream()
.map(AbstractMatchEntity::getMaskImage).collect(Collectors.toSet());
Set<T> targetsToUpdate = cdMatches.stream()
.map(AbstractMatchEntity::getMatchedImage).collect(Collectors.toSet());
cdmipsWriter.addProcessingTags(masksToUpdate, ProcessingType.GradientScore, processingTags);
cdmipsWriter.addProcessingTags(targetsToUpdate, ProcessingType.GradientScore, processingTags);
return masksToUpdate.size() + targetsToUpdate.size();
} else {
return 0;
}
}

private <M extends AbstractNeuronEntity, T extends AbstractNeuronEntity>
Expand Down

0 comments on commit b7890cc

Please sign in to comment.