Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1119 reducing time to merge in merge maskspy wi ez segmenter #1123

Merged
79 changes: 71 additions & 8 deletions src/ark/segmentation/ez_seg/merge_masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import os
from skimage.io import imread
from skimage.morphology import label
from skimage.measure import regionprops_table
import pandas as pd
from alpineer import load_utils, image_utils
from ark.segmentation.ez_seg.ez_seg_utils import log_creator

Expand All @@ -16,6 +18,7 @@ def merge_masks_seq(
cell_mask_dir: Union[pathlib.Path, str],
cell_mask_suffix: str,
overlap_percent_threshold: int,
expansion_factor: int,
save_path: Union[pathlib.Path, str],
log_dir: Union[pathlib.Path, str]
) -> None:
Expand All @@ -31,6 +34,7 @@ def merge_masks_seq(
cell_mask_dir (Union[str, pathlib.Path]): Path to where the original cell masks are located.
cell_mask_suffix (str): Name of the cell type you are merging. Usually "whole_cell".
overlap_percent_threshold (int): Percent overlap of total pixel area needed fo object to be merged to a cell.
expansion_factor (int): How many pixels out from an objects bbox a cell should be looked for.
save_path (Union[str, pathlib.Path]): The directory where merged masks and remaining cell mask will be saved.
log_dir (Union[str, pathlib.Path]): The directory to save log information to.
"""
Expand Down Expand Up @@ -61,13 +65,9 @@ def merge_masks_seq(
# for each object type in the fov, merge with cell masks
for obj in fov_object_names:
curr_object_mask = imread(fname=(object_mask_dir / obj))
remaining_cells = merge_masks_single(
object_mask=curr_object_mask,
cell_mask=curr_cell_mask,
overlap_thresh=overlap_percent_threshold,
object_name=obj,
mask_save_path=save_path,
)
remaining_cells = merge_masks_single(object_mask=curr_object_mask, cell_mask=curr_cell_mask,
overlap_thresh=overlap_percent_threshold, object_name=obj,
mask_save_path=save_path, expansion_factor=expansion_factor)
curr_cell_mask = remaining_cells

# save the unmerged cells as a tiff.
Expand All @@ -93,6 +93,7 @@ def merge_masks_single(
overlap_thresh: int,
object_name: str,
mask_save_path: str,
expansion_factor: int
) -> np.ndarray:
"""
Combines overlapping object and cell masks. For any combination which represents has at least `overlap` percentage
Expand All @@ -104,6 +105,7 @@ def merge_masks_single(
overlap_thresh (int): The percentage overlap required for a cell to be merged.
object_name (str): The name of the object.
mask_save_path (str): The path to save the mask.
expansion_factor (int): How many pixels out from an objects bbox a cell should be looked for.

Returns:
np.ndarray: The cells remaining mask, which will be used for the next cycle in merging while there are objects.
Expand All @@ -123,6 +125,12 @@ def merge_masks_single(
# Set up list to store merged cell labels
remove_cells_list = [0]

# Create a dictionary of the bounding boxes for all object labels
object_labels_bounding_boxes = get_bounding_boxes(object_labels)

# Calculate all cell regionprops for filtering, convert to DataFrame
cell_props = pd.DataFrame(regionprops_table(cell_labels, properties=('label', 'centroid', 'major_axis_length')))
bryjcannon marked this conversation as resolved.
Show resolved Hide resolved

# Find connected components in object and cell masks. Merge only those with highest overlap that meets threshold.
for obj_label in range(1, num_object_labels + 1):
# Extract a connected component from object_mask
Expand All @@ -132,7 +140,11 @@ def merge_masks_single(
best_cell_mask_component = None
cell_to_merge_label = None

for cell_label in range(1, num_cell_labels + 1):
# Filter for cell_labels that fall within the expanded bounding box of the obj_label
cell_labels_in_range = filter_labels_in_bbox(
object_labels_bounding_boxes.pop(obj_label), cell_props, expansion_factor)
bryjcannon marked this conversation as resolved.
Show resolved Hide resolved

for cell_label in cell_labels_in_range:
# Extract a connected component from cell_mask
cell_mask_component = cell_labels == cell_label

Expand Down Expand Up @@ -165,3 +177,54 @@ def merge_masks_single(

# Return unmerged cells
return cell_labels


def get_bounding_boxes(object_labels: np.ndarray):
"""
Gets the bounding boxes of labeled images based on object major axis length.

Args:
object_labels (np.ndarray): label array
Returns:
dict: Dictionary containing labels as keys and bounding box as values
"""
bounding_boxes = {}

# Get region properties as a DataFrame
props = regionprops_table(object_labels, properties=('label', 'bbox'))

# Convert to DataFrame
df = pd.DataFrame(props)
bryjcannon marked this conversation as resolved.
Show resolved Hide resolved

for _, row in df.iterrows():
bryjcannon marked this conversation as resolved.
Show resolved Hide resolved
label_id, min_row, min_col, max_row, max_col = row.values
# Return closed interval bounding box
bounding_boxes[label_id] = ((min_row, min_col), (max_row-1, max_col-1))

return bounding_boxes


def filter_labels_in_bbox(bounding_box: List, cell_props: pd.DataFrame, expansion_factor: int):
bryjcannon marked this conversation as resolved.
Show resolved Hide resolved
"""
Gets the cell labels that fall within the expanded bounding box of a given object.

Args:
bounding_box (List): The bounding box values for the input obj_label
cell_props (pd.DataFrame): The cell label regionprops DataFrame.
expansion_factor: how many pixels from the bounding box you want to expand the search for compatible cells.

Returns:
List: The cell labels that fall within the expanded bounding box.

"""
min_row, min_col = bounding_box[0]
max_row, max_col = bounding_box[1]

# Filter labels based on bounding box
filtered_labels = cell_props[(cell_props['centroid-0'] >= min_row-expansion_factor) &
(cell_props['centroid-0'] <= max_row+expansion_factor) &
(cell_props['centroid-1'] >= min_col-expansion_factor) &
(cell_props['centroid-1'] <= max_col+expansion_factor)]['label'].tolist()

return filtered_labels

14 changes: 11 additions & 3 deletions templates/ez_segmenter.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {
"tags": [
"import"
Expand Down Expand Up @@ -552,7 +552,11 @@
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"pycharm": {
"is_executing": true
}
},
"source": [
"#### Here you can merge object segmentation masks with cell masks (or any other type of mask).\n",
"Here you will provide a list of what objects you would like to merge with previously segmented cell masks (or other base mask).\n",
Expand All @@ -563,10 +567,12 @@
"\n",
"* `merge_masks_list`: list of object masks to merge to the base (cell) `image.List` of object masks to merge to the base (cell) image.\n",
"* `percent_overlap`: percent threshold required for a cell mask to be merged into an object mask\n",
"* `expansion_factor`: number of pixels out from an objects bounding box a cell should be looked for. (Default 10 pixels)\n",
"* `cell_dir`: the final mask directory\n",
"* `cell_mask_suffix`: Suffix name of the cell mask files. Usually \"whole_cell\"\n",
"* `merged_masks_dir`: the directory to store the merged masks"
]
],
"outputs": []
},
{
"cell_type": "code",
Expand All @@ -580,6 +586,7 @@
"source": [
"merge_masks_list = [\"microglia-arms\", \"astrocyte-arms\"]\n",
"percent_overlap = 30\n",
"expansion_factor = 10\n",
"\n",
"# Overwrite if different from above\n",
"cell_dir = os.path.join(segmentation_dir, \"deepcell_output\")\n",
Expand Down Expand Up @@ -633,6 +640,7 @@
" cell_mask_dir=cell_dir,\n",
" cell_mask_suffix=cell_mask_suffix,\n",
" overlap_percent_threshold=percent_overlap,\n",
" expansion_factor=expansion_factor,\n",
" save_path=merged_masks_dir,\n",
" log_dir=log_dir\n",
")"
Expand Down
89 changes: 80 additions & 9 deletions tests/segmentation/ez_seg/merge_masks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import skimage.io as io
import tempfile
import xarray as xr

import pytest
from alpineer import image_utils
from ark.segmentation.ez_seg import merge_masks
from skimage.morphology import label
from skimage.measure import regionprops_table, label
import pandas as pd
from skimage.draw import disk
from typing import List, Union

Expand All @@ -26,6 +27,7 @@ def test_merge_masks_seq():
os.mkdir(directory)

overlap_thresh: int = 10
expansion_factor: int = 10

for fov in fov_list:
cell_mask_data: np.ndarray = np.random.randint(0, 16, (32, 32))
Expand All @@ -42,10 +44,9 @@ def test_merge_masks_seq():
image_utils.save_image(object_mask_fov_file, object_mask_data)

# we're only testing functionality, for in-depth merge testing see test_merge_masks_single
merge_masks.merge_masks_seq(
fov_list, object_list, object_mask_dir, cell_mask_dir, cell_mask_suffix,
overlap_thresh, merged_mask_dir, log_dir
)
merge_masks.merge_masks_seq(fov_list, object_list, object_mask_dir, cell_mask_dir,
cell_mask_suffix, overlap_thresh, expansion_factor,
merged_mask_dir, log_dir)

for fov in fov_list:
print("checking fov")
Expand Down Expand Up @@ -77,6 +78,7 @@ def test_merge_masks_single():
expected_cell_mask: np.ndarray = np.zeros((32, 32))

overlap_thresh: int = 10
expansion_factor: int = 10
merged_mask_name: str = "merged_mask"

# case 1: overlap below threshold, don't merge
Expand Down Expand Up @@ -107,13 +109,82 @@ def test_merge_masks_single():
mask_save_dir: Union[str, pathlib.Path] = os.path.join(td, "mask_save_dir")
os.mkdir(mask_save_dir)

created_cell_mask: np.ndarray = merge_masks.merge_masks_single(
object_mask, cell_mask, overlap_thresh, merged_mask_name, mask_save_dir
)
created_cell_mask: np.ndarray = merge_masks.merge_masks_single(object_mask,
cell_mask,
overlap_thresh,
merged_mask_name,
mask_save_dir,
expansion_factor)

created_merged_mask: np.ndarray = io.imread(
os.path.join(mask_save_dir, merged_mask_name + "_merged.tiff")
)

assert np.all(created_merged_mask == expected_merged_mask)
assert np.all(created_cell_mask == expected_cell_mask)


def test_get_bounding_boxes():
# Create a labeled array
labels = np.array([[1, 1, 0, 0],
[0, 1, 0, 0],
[0, 0, 2, 2]])

# Call the function to get bounding boxes
bounding_boxes = merge_masks.get_bounding_boxes(labels)

# Expected bounding boxes
expected_bounding_boxes = {1: ((0, 0), (1, 1)),
2: ((2, 2), (2, 3))}

assert bounding_boxes == expected_bounding_boxes


def test_filter_labels_in_bbox():
# Create a labeled array
labels = np.array([[1, 1, 0, 0],
[0, 1, 0, 0],
[0, 0, 2, 2]])

# Get regionprops df
label_df = pd.DataFrame(regionprops_table(
label(labels), properties=('label', 'centroid', 'major_axis_length')))

# Get the bounding boxes
bounding_boxes = merge_masks.get_bounding_boxes(label(labels))

# Filter labels within the bounding box of label 1
filtered_labels = merge_masks.filter_labels_in_bbox(
bounding_boxes[1], label_df, expansion_factor=0)

# Expected filtered labels for label 1
expected_filtered_labels_1 = [1]

assert filtered_labels == expected_filtered_labels_1

# Filter labels within the bounding box of label 2
filtered_labels = merge_masks.filter_labels_in_bbox(
bounding_boxes[2], label_df, expansion_factor=0)

# Expected filtered labels for label 2
expected_filtered_labels_2 = [2]

assert filtered_labels == expected_filtered_labels_2

# Filter labels within the bounding box of an empty label (should return an empty list)
filtered_labels = merge_masks.filter_labels_in_bbox(
((0, 0), (0, 0)), label_df, expansion_factor=0)

# Expected filtered labels for an empty label
expected_filtered_labels_empty = []

assert filtered_labels == expected_filtered_labels_empty

# Filter labels within the bounding box of label 1 and expansion
filtered_labels = merge_masks.filter_labels_in_bbox(
bounding_boxes[1], label_df, expansion_factor=10)

# Expected filtered labels for label 1 and 2
expected_filtered_labels_expanded = [1, 2]

assert filtered_labels == expected_filtered_labels_expanded
Loading