Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: add progress indicator of cellpose segmentation #8

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 66 additions & 47 deletions src/napari_serialcellpose/_tests/test_widget.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,30 @@
from napari_serialcellpose import SerialWidget
import numpy as np
import pandas as pd

import pytest
from pathlib import Path
import time
import os
import tempfile
import shutil

def test_load_single_image(make_napari_viewer):

viewer = make_napari_viewer()
widget = SerialWidget(viewer)

mypath = Path('src/napari_serialcellpose/_tests/data/single_file_singlechannel/')

widget.file_list.update_from_path(mypath)
assert len(viewer.layers) == 0
widget.file_list.setCurrentRow(0)
assert len(viewer.layers) == 1

def test_analyse_single_image_no_save(make_napari_viewer):
def test_analyse_single_image_no_save(qtbot, make_napari_viewer):

viewer = make_napari_viewer()
widget = SerialWidget(viewer)

mypath = Path('src/napari_serialcellpose/_tests/data/single_file_singlechannel/')

widget.file_list.update_from_path(mypath)
Expand All @@ -38,24 +41,24 @@ def test_analyse_single_image_no_save(make_napari_viewer):
# set diameter and run segmentation
widget.spinbox_diameter.setValue(70)
widget._on_click_run_on_current()

# check that segmentatio has been added, named 'mask' and results in 33 objects
assert len(viewer.layers) == 2
def check_layers():
assert len(viewer.layers) == 2

qtbot.waitUntil(check_layers, timeout=30000)
assert viewer.layers[1].name == 'mask'
assert viewer.layers[1].data.max() == 33

def test_analyse_single_image_save(make_napari_viewer):
def test_analyse_single_image_save(qtbot, make_napari_viewer):

viewer = make_napari_viewer()
widget = SerialWidget(viewer)

mypath = Path('src/napari_serialcellpose/_tests/data/single_file_multichannel')

output_dir = Path('src/napari_serialcellpose/_tests/data/analyzed_single')
if output_dir.exists():
shutil.rmtree(output_dir)
output_dir.mkdir(exist_ok=True)

output_dir = Path(tempfile.mkdtemp())

widget.file_list.update_from_path(mypath)
widget.output_folder = output_dir
widget.file_list.setCurrentRow(0)
Expand All @@ -76,23 +79,26 @@ def test_analyse_single_image_save(make_napari_viewer):
widget.check_props['size'].setChecked(True)
widget.check_props['intensity'].setChecked(True)
widget.qcbox_channel_analysis.setCurrentRow(1)

widget._on_click_run_on_current()

assert len(list(output_dir.glob('*mask.tif'))) == 1
def check_outputs():
assert len(list(output_dir.glob('*mask.tif'))) == 1

qtbot.waitUntil(check_outputs, timeout=30000)

assert len(list(output_dir.joinpath('tables').glob('*_props.csv'))) == 1
shutil.rmtree(output_dir)

def test_analyse_multi_image(make_napari_viewer):
def test_analyse_multi_image(qtbot, make_napari_viewer):
"""Test analysis of multiple images in a folder. No properties are analyzed."""
viewer = make_napari_viewer()
widget = SerialWidget(viewer)


mypath = Path('src/napari_serialcellpose/_tests/data/multifile/')
output_dir = Path('src/napari_serialcellpose/_tests/data/analyzed_multiple')
if output_dir.exists():
shutil.rmtree(output_dir)
output_dir.mkdir(exist_ok=True)

output_dir = Path(tempfile.mkdtemp())


widget.file_list.update_from_path(mypath)
widget.output_folder = output_dir
Expand All @@ -101,23 +107,22 @@ def test_analyse_multi_image(make_napari_viewer):
widget.qcbox_model_choice.setCurrentIndex(
[widget.qcbox_model_choice.itemText(i) for i in range(widget.qcbox_model_choice.count())].index('cyto2'))
widget.spinbox_diameter.setValue(70)
widget._on_click_run_on_current()
widget._on_click_run_on_folder()

assert len(list(output_dir.glob('*mask.tif'))) == 1
def check_output():
assert len(list(output_dir.glob('*mask.tif'))) == 4

widget._on_click_run_on_folder()
assert len(list(output_dir.glob('*mask.tif'))) == 4
qtbot.waitUntil(check_output, timeout=30000)
shutil.rmtree(output_dir)

def test_analyse_multi_image_props(make_napari_viewer):
def test_analyse_multi_image_props(qtbot, make_napari_viewer):

viewer = make_napari_viewer()
widget = SerialWidget(viewer)

mypath = Path('src/napari_serialcellpose/_tests/data/multifile/')
output_dir = Path('src/napari_serialcellpose/_tests/data/analyzed_multiple3')
if output_dir.exists():
shutil.rmtree(output_dir)
output_dir.mkdir(exist_ok=True)
output_dir = Path(tempfile.mkdtemp())


widget.file_list.update_from_path(mypath)
widget.output_folder = output_dir
Expand All @@ -133,27 +138,27 @@ def test_analyse_multi_image_props(make_napari_viewer):
widget.qcbox_channel_analysis.setCurrentRow(1)

widget._on_click_run_on_folder()
assert len(list(output_dir.glob('*mask.tif'))) == 4

# check that the properties are correct
def check_outputs():
assert len(list(output_dir.glob('*mask.tif'))) == 4

qtbot.waitUntil(check_outputs, timeout=30000)
# check that the properties are correct
df = pd.read_csv(output_dir.joinpath(
'tables',
Path(widget.file_list.currentItem().text()).stem + '_props.csv'
)
'tables',
Path(widget.file_list.currentItem().text()).stem + '_props.csv')
)
# check number of columns in df
assert df.shape[1] == 8
# check number of columns in df
assert df.shape[1] == 8


def test_analyse_multichannels(make_napari_viewer):
def test_analyse_multichannels(qtbot, make_napari_viewer):
"""Test that multiple channels can be used for intensity measurements"""
viewer = make_napari_viewer()
widget = SerialWidget(viewer)

mypath = Path('src/napari_serialcellpose/_tests/data/single_file_multichannel/')
output_dir = Path('src/napari_serialcellpose/_tests/data/analyzed_single_multichannelprops')
if output_dir.exists():
shutil.rmtree(output_dir)
output_dir.mkdir(exist_ok=True)
output_dir = Path(tempfile.mkdtemp())

widget.file_list.update_from_path(mypath)
widget.output_folder = output_dir
Expand All @@ -171,6 +176,11 @@ def test_analyse_multichannels(make_napari_viewer):

widget._on_click_run_on_folder()

def check_outputs():
assert len(list(output_dir.glob('*mask.tif'))) == 1

qtbot.waitUntil(check_outputs, timeout=30000)

# check that the properties are correct
df = pd.read_csv(output_dir.joinpath(
'tables',
Expand All @@ -180,16 +190,14 @@ def test_analyse_multichannels(make_napari_viewer):
# check number of columns in df
assert df.shape[1] == 11

def test_mask_loading(make_napari_viewer):
def test_mask_loading(qtbot, make_napari_viewer):

viewer = make_napari_viewer()
widget = SerialWidget(viewer)

mypath = Path('src/napari_serialcellpose/_tests/data/multifile/')
output_dir = Path('src/napari_serialcellpose/_tests/data/analyzed_multiple2')
if output_dir.exists():
shutil.rmtree(output_dir)
output_dir.mkdir(exist_ok=True)
output_dir = Path(tempfile.mkdtemp())


widget.file_list.update_from_path(mypath)
widget.output_folder = output_dir
Expand All @@ -200,6 +208,12 @@ def test_mask_loading(make_napari_viewer):
widget.spinbox_diameter.setValue(70)
widget._on_click_run_on_current()

# check that segmentation has been added
def check_layers():
assert len(viewer.layers) == 3

qtbot.waitUntil(check_layers, timeout=30000)

# check that when selecting the second file, we get only 2 channels and no mask
widget.file_list.setCurrentRow(1)
assert len(viewer.layers) == 2
Expand All @@ -208,7 +222,7 @@ def test_mask_loading(make_napari_viewer):
widget.file_list.setCurrentRow(0)
assert len(viewer.layers) == 3

def test_analyse_single_image_options_yml(make_napari_viewer):
def test_analyse_single_image_options_yml(qtbot, make_napari_viewer):

viewer = make_napari_viewer()
widget = SerialWidget(viewer)
Expand All @@ -232,5 +246,10 @@ def test_analyse_single_image_options_yml(make_napari_viewer):

widget._on_click_run_on_current()

# check that because of small diameter from yml file, we get only 5 elements
# check that segmentation has been added
def check_layers():
assert len(viewer.layers) == 3

qtbot.waitUntil(check_layers, timeout=30000)
# check that because of small diameter from yml file, we get only 7 elements
assert viewer.layers[2].data.max() == 7
99 changes: 57 additions & 42 deletions src/napari_serialcellpose/serial_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from qtpy.QtCore import Qt
import magicgui.widgets
from napari.layers import Image
from napari.qt import create_worker, thread_worker
from napari.utils.notifications import show_info


from .folder_list_widget import FolderList
from .serial_analysis import run_cellpose, load_props, load_allprops
Expand Down Expand Up @@ -313,31 +316,37 @@ def _on_click_run_on_current(self):
channel_analysis_names = [x.text() for x in self.qcbox_channel_analysis.selectedItems()]
reg_props = [k for k in self.check_props.keys() if self.check_props[k].isChecked()]

# run cellpose
segmented, props = run_cellpose(
image_path=image_path,
cellpose_model=self.cellpose_model,
output_path=self.output_folder,
diameter=diameter,
flow_threshold=self.flow_threshold.value(),
cellprob_threshold=self.cellprob_threshold.value(),
clear_border=self.check_clear_border.isChecked(),
channel_to_segment=channel_to_segment,
channel_helper=channel_helper,
channel_measure=channel_analysis,
channel_measure_names=channel_analysis_names,
properties=reg_props,
options_file=self.options_file_path,
force_no_rgb=self.check_no_rgb.isChecked(),
)

self.viewer.layers.events.inserted.disconnect(self._on_change_layers)
self.viewer.add_labels(segmented, name='mask')

if len(reg_props) > 0:
self.add_table_props(props)
# run cellpose
seg_worker = create_worker(run_cellpose,
image_path=image_path,
cellpose_model=self.cellpose_model,
output_path=self.output_folder,
diameter=diameter,
flow_threshold=self.flow_threshold.value(),
cellprob_threshold=self.cellprob_threshold.value(),
clear_border=self.check_clear_border.isChecked(),
channel_to_segment=channel_to_segment,
channel_helper=channel_helper,
channel_measure=channel_analysis,
channel_measure_names=channel_analysis_names,
properties=reg_props,
options_file=self.options_file_path,
force_no_rgb=self.check_no_rgb.isChecked(),
_progress=True
)

def get_seg_worker(output):
self.viewer.add_labels(output[0], name='mask')
if len(reg_props) > 0:
self.add_table_props(output[1])
self.viewer.layers.events.inserted.connect(self._on_change_layers)

show_info('Running Segmentation...')
seg_worker.start()
seg_worker.returned.connect(get_seg_worker)

self.viewer.layers.events.inserted.connect(self._on_change_layers)

def _on_click_run_on_folder(self):
"""Run cellpose on all images in folder"""
Expand All @@ -359,25 +368,31 @@ def _on_click_run_on_folder(self):
channel_analysis_names = [x.text() for x in self.qcbox_channel_analysis.selectedItems()]
reg_props = [k for k in self.check_props.keys() if self.check_props[k].isChecked()]

for batch in file_list_partition:
_, _ = run_cellpose(
image_path=batch,
cellpose_model=self.cellpose_model,
output_path=self.output_folder,
diameter=diameter,
flow_threshold=self.flow_threshold.value(),
cellprob_threshold=self.cellprob_threshold.value(),
clear_border=self.check_clear_border.isChecked(),
channel_to_segment=channel_to_segment,
channel_helper=channel_helper,
channel_measure=channel_analysis,
channel_measure_names=channel_analysis_names,
properties=reg_props,
options_file=self.options_file_path,
force_no_rgb=self.check_no_rgb.isChecked(),
)

self._on_click_load_summary()
@thread_worker(progress={'total': len(file_list_partition), 'desc': 'Running batch segmentation'})
def run_batch(file_list_partition):
for batch in file_list_partition:
yield run_cellpose(
image_path=batch,
cellpose_model=self.cellpose_model,
output_path=self.output_folder,
diameter=diameter,
flow_threshold=self.flow_threshold.value(),
cellprob_threshold=self.cellprob_threshold.value(),
clear_border=self.check_clear_border.isChecked(),
channel_to_segment=channel_to_segment,
channel_helper=channel_helper,
channel_measure=channel_analysis,
channel_measure_names=channel_analysis_names,
properties=reg_props,
options_file=self.options_file_path,
force_no_rgb=self.check_no_rgb.isChecked(),
)

show_info('Running Segmentation...')
batch_worker = run_batch(file_list_partition)
batch_worker.start()

batch_worker.returned.connect(self._on_click_load_summary)

def get_channels_to_use(self):
"""Translate selected channels in QCombox into indices.
Expand Down Expand Up @@ -609,4 +624,4 @@ def __init__(self, parent=None, col=1, row=1, width=6, height=4, dpi=100):
for j in range(col):
self.ax[i,j] = fig.add_subplot(row, col, count)
count+=1
super(MplCanvas, self).__init__(fig)
super(MplCanvas, self).__init__(fig)