Skip to content

Commit

Permalink
Merge pull request #855 from MouseLand/stitch
Browse files Browse the repository at this point in the history
updating stitching to improve speed (#845)
  • Loading branch information
carsen-stringer authored Feb 12, 2024
2 parents 4f56619 + 6ce3bee commit 4741976
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
20 changes: 15 additions & 5 deletions cellpose/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def get_image_files(folder, mask_filter, imf=None, look_one_level_down=False):
igood &= imfile[-len(imf):]==imf
if igood:
imn.append(im)

image_names = imn

# remove duplicates
Expand Down Expand Up @@ -240,29 +241,38 @@ def get_label_files(image_names, mask_filter, imf=None):
#elif os.path.exists(label_names[0] + '_seg.npy'):
# io_logger.info('labels found as _seg.npy files, converting to tif')
else:
raise ValueError('labels not provided with correct --mask_filter')
if not flow_names:
raise ValueError('labels not provided with correct --mask_filter')
else:
label_names = None
if not all([os.path.exists(label) for label in label_names]):
raise ValueError('labels not provided for all images in train and/or test set')
if not flow_names:
raise ValueError('labels not provided for all images in train and/or test set')
else:
label_names = None

return label_names, flow_names


def load_images_labels(tdir, mask_filter='_masks', image_filter=None, look_one_level_down=False, unet=False):
image_names = get_image_files(tdir, mask_filter, image_filter, look_one_level_down)
nimg = len(image_names)

# training data
label_names, flow_names = get_label_files(image_names, mask_filter, imf=image_filter)

images = []
labels = []
k = 0
for n in range(nimg):
if os.path.isfile(label_names[n]):
if os.path.isfile(label_names[n]) or os.path.isfile(flow_names[0]):
print(image_names[n])
image = imread(image_names[n])
label = imread(label_names[n])
if label_names is not None:
label = imread(label_names[n])
if not unet:
if flow_names is not None and not unet:
print(flow_names[n])
flow = imread(flow_names[n])
if flow.shape[0]<4:
label = np.concatenate((label[np.newaxis,:,:], flow), axis=0)
Expand Down
4 changes: 2 additions & 2 deletions cellpose/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
import logging
import os, warnings, time, tempfile, datetime, pathlib, shutil, subprocess
from tqdm import tqdm
from tqdm import tqdm, trange
from urllib.request import urlopen
from urllib.parse import urlparse
import cv2
Expand Down Expand Up @@ -403,7 +403,7 @@ def stitch3D(masks, stitch_threshold=0.25):
mmax = masks[0].max()
empty = 0

for i in range(len(masks)-1):
for i in trange(len(masks)-1):
iou = metrics._intersection_over_union(masks[i+1], masks[i])[1:,1:]
if not iou.size and empty == 0:
masks[i+1] = masks[i+1]
Expand Down

0 comments on commit 4741976

Please sign in to comment.