From 87f0439dd834094b822499500e925f94489bdd55 Mon Sep 17 00:00:00 2001 From: bryjc Date: Thu, 29 Feb 2024 09:55:01 -0800 Subject: [PATCH 1/8] Added updates for merge_mask speedup + testing. --- src/ark/segmentation/ez_seg/merge_masks.py | 69 ++++++++++++++++++- tests/segmentation/ez_seg/merge_masks_test.py | 52 +++++++++++++- 2 files changed, 118 insertions(+), 3 deletions(-) diff --git a/src/ark/segmentation/ez_seg/merge_masks.py b/src/ark/segmentation/ez_seg/merge_masks.py index 5d2e54de5..27748528d 100644 --- a/src/ark/segmentation/ez_seg/merge_masks.py +++ b/src/ark/segmentation/ez_seg/merge_masks.py @@ -5,6 +5,7 @@ import os from skimage.io import imread from skimage.morphology import label +from skimage.measure import regionprops from alpineer import load_utils, image_utils from ark.segmentation.ez_seg.ez_seg_utils import log_creator @@ -123,6 +124,9 @@ def merge_masks_single( # Set up list to store merged cell labels remove_cells_list = [0] + # Create a dictionary of the bounding boxes for all object labels + object_labels_bounding_boxes = get_bounding_boxes(object_labels) + # Find connected components in object and cell masks. Merge only those with highest overlap that meets threshold. for obj_label in range(1, num_object_labels + 1): # Extract a connected component from object_mask @@ -132,9 +136,12 @@ def merge_masks_single( best_cell_mask_component = None cell_to_merge_label = None - for cell_label in range(1, num_cell_labels + 1): + # Filter for cell_labels that fall within the expanded bounding box of the obj_label + cell_labels_in_range = filter_labels_in_bbox(object_labels_bounding_boxes.pop(str(obj_label)), cell_labels) + + for cell_label in cell_labels_in_range: # Extract a connected component from cell_mask - cell_mask_component = cell_labels == cell_label + cell_mask_component = cell_labels == int(cell_label) # Calculate the overlap between cell_mask_component and object_mask_component intersection = np.logical_and(cell_mask_component, object_mask_component) @@ -165,3 +172,61 @@ def merge_masks_single( # Return unmerged cells return cell_labels + + +def get_bounding_boxes(object_labels: np.ndarray): + """ + Gets the bounding boxes of labeled images based on object major axis length. + + Args: + object_labels (np.ndarray): label array + Returns: + dict: Dictionary containing labels as keys and bounding box as values + """ + bounding_boxes = {} + + props = regionprops(object_labels) + + for prop in props: + # Get major axis length + major_axis_length = prop.major_axis_length + + # Define bounding box based on major axis length + centroid = prop.centroid + radius = int(major_axis_length / 2) + min_row = max(0, int(centroid[0]) - radius) + max_row = min(object_labels.shape[0] - 1, int(centroid[0]) + radius) + min_col = max(0, int(centroid[1]) - radius) + max_col = min(object_labels.shape[1] - 1, int(centroid[1]) + radius) + + bounding_boxes[prop.label] = ((min_row, min_col), (max_row, max_col)) + + return bounding_boxes + + +def filter_labels_in_bbox(bounding_box: List, cell_labels: np.ndarray): + """ + Gets the cell labels that fall within the expanded bounding box of a given object. + + Args: + bounding_box (List): The bounding box values for the input obj_label + cell_labels (np.ndarray): The cell label array. + + Returns: + List: The cell labels that fall within the expanded bounding box. + + """ + min_row, min_col = bounding_box[0] + max_row, max_col = bounding_box[1] + + filtered_labels = [] + + props = regionprops(cell_labels) + + for prop in props: + centroid = prop.centroid + if min_row <= centroid[0] <= max_row and min_col <= centroid[1] <= max_col: + filtered_labels.append(prop.label) + + return filtered_labels + diff --git a/tests/segmentation/ez_seg/merge_masks_test.py b/tests/segmentation/ez_seg/merge_masks_test.py index 695c2f72a..4012028d7 100644 --- a/tests/segmentation/ez_seg/merge_masks_test.py +++ b/tests/segmentation/ez_seg/merge_masks_test.py @@ -7,7 +7,7 @@ from alpineer import image_utils from ark.segmentation.ez_seg import merge_masks -from skimage.morphology import label +from skimage.measure import label from skimage.draw import disk from typing import List, Union @@ -117,3 +117,53 @@ def test_merge_masks_single(): assert np.all(created_merged_mask == expected_merged_mask) assert np.all(created_cell_mask == expected_cell_mask) + + +def test_get_bounding_boxes(): + # Create a labeled array + labels = np.array([[1, 1, 0, 0], + [0, 1, 0, 0], + [0, 0, 2, 2]]) + + # Call the function to get bounding boxes + bounding_boxes = merge_masks.get_bounding_boxes(labels) + + # Expected bounding boxes + expected_bounding_boxes = {1: ((0, 0), (1, 1)), + 2: ((2, 2), (2, 3))} + + assert bounding_boxes == expected_bounding_boxes + + +def test_filter_labels_in_bbox(): + # Create a labeled array + labels = np.array([[1, 1, 0, 0], + [0, 1, 0, 0], + [0, 0, 2, 2]]) + + # Get the bounding boxes + bounding_boxes = merge_masks.get_bounding_boxes(labels) + + # Filter labels within the bounding box of label 1 + filtered_labels = merge_masks.filter_labels_in_bbox(bounding_boxes[1], labels) + + # Expected filtered labels for label 1 + expected_filtered_labels_1 = [1] + + assert filtered_labels == expected_filtered_labels_1 + + # Filter labels within the bounding box of label 2 + filtered_labels = merge_masks.filter_labels_in_bbox(bounding_boxes[2], labels) + + # Expected filtered labels for label 2 + expected_filtered_labels_2 = [2] + + assert filtered_labels == expected_filtered_labels_2 + + # Filter labels within the bounding box of an empty label (should return an empty list) + filtered_labels = merge_masks.filter_labels_in_bbox(((0, 0), (0, 0)), labels) + + # Expected filtered labels for an empty label + expected_filtered_labels_empty = [] + + assert filtered_labels == expected_filtered_labels_empty From 9523e8a723b9c2435668ce5e881849281041bfaf Mon Sep 17 00:00:00 2001 From: bryjc Date: Thu, 29 Feb 2024 20:31:50 -0800 Subject: [PATCH 2/8] Updated localized merging strat and added tests. --- src/ark/segmentation/ez_seg/merge_masks.py | 64 +++++++++---------- templates/ez_segmenter.ipynb | 3 + tests/segmentation/ez_seg/merge_masks_test.py | 40 ++++++++---- 3 files changed, 60 insertions(+), 47 deletions(-) diff --git a/src/ark/segmentation/ez_seg/merge_masks.py b/src/ark/segmentation/ez_seg/merge_masks.py index 27748528d..8e8bc4cac 100644 --- a/src/ark/segmentation/ez_seg/merge_masks.py +++ b/src/ark/segmentation/ez_seg/merge_masks.py @@ -5,7 +5,8 @@ import os from skimage.io import imread from skimage.morphology import label -from skimage.measure import regionprops +from skimage.measure import regionprops_table +import pandas as pd from alpineer import load_utils, image_utils from ark.segmentation.ez_seg.ez_seg_utils import log_creator @@ -17,6 +18,7 @@ def merge_masks_seq( cell_mask_dir: Union[pathlib.Path, str], cell_mask_suffix: str, overlap_percent_threshold: int, + expansion_factor: int, save_path: Union[pathlib.Path, str], log_dir: Union[pathlib.Path, str] ) -> None: @@ -32,6 +34,7 @@ def merge_masks_seq( cell_mask_dir (Union[str, pathlib.Path]): Path to where the original cell masks are located. cell_mask_suffix (str): Name of the cell type you are merging. Usually "whole_cell". overlap_percent_threshold (int): Percent overlap of total pixel area needed fo object to be merged to a cell. + expansion_factor (int): How many pixels out from an objects bbox a cell should be looked for. save_path (Union[str, pathlib.Path]): The directory where merged masks and remaining cell mask will be saved. log_dir (Union[str, pathlib.Path]): The directory to save log information to. """ @@ -62,13 +65,9 @@ def merge_masks_seq( # for each object type in the fov, merge with cell masks for obj in fov_object_names: curr_object_mask = imread(fname=(object_mask_dir / obj)) - remaining_cells = merge_masks_single( - object_mask=curr_object_mask, - cell_mask=curr_cell_mask, - overlap_thresh=overlap_percent_threshold, - object_name=obj, - mask_save_path=save_path, - ) + remaining_cells = merge_masks_single(object_mask=curr_object_mask, cell_mask=curr_cell_mask, + overlap_thresh=overlap_percent_threshold, object_name=obj, + mask_save_path=save_path, expansion_factor=expansion_factor) curr_cell_mask = remaining_cells # save the unmerged cells as a tiff. @@ -94,6 +93,7 @@ def merge_masks_single( overlap_thresh: int, object_name: str, mask_save_path: str, + expansion_factor: int ) -> np.ndarray: """ Combines overlapping object and cell masks. For any combination which represents has at least `overlap` percentage @@ -105,6 +105,7 @@ def merge_masks_single( overlap_thresh (int): The percentage overlap required for a cell to be merged. object_name (str): The name of the object. mask_save_path (str): The path to save the mask. + expansion_factor (int): How many pixels out from an objects bbox a cell should be looked for. Returns: np.ndarray: The cells remaining mask, which will be used for the next cycle in merging while there are objects. @@ -127,6 +128,9 @@ def merge_masks_single( # Create a dictionary of the bounding boxes for all object labels object_labels_bounding_boxes = get_bounding_boxes(object_labels) + # Calculate all cell regionprops for filtering, convert to DataFrame + cell_props = pd.DataFrame(regionprops_table(cell_labels, properties=('label', 'centroid', 'major_axis_length'))) + # Find connected components in object and cell masks. Merge only those with highest overlap that meets threshold. for obj_label in range(1, num_object_labels + 1): # Extract a connected component from object_mask @@ -137,11 +141,12 @@ def merge_masks_single( cell_to_merge_label = None # Filter for cell_labels that fall within the expanded bounding box of the obj_label - cell_labels_in_range = filter_labels_in_bbox(object_labels_bounding_boxes.pop(str(obj_label)), cell_labels) + cell_labels_in_range = filter_labels_in_bbox( + object_labels_bounding_boxes.pop(obj_label), cell_props, expansion_factor) for cell_label in cell_labels_in_range: # Extract a connected component from cell_mask - cell_mask_component = cell_labels == int(cell_label) + cell_mask_component = cell_labels == cell_label # Calculate the overlap between cell_mask_component and object_mask_component intersection = np.logical_and(cell_mask_component, object_mask_component) @@ -185,32 +190,28 @@ def get_bounding_boxes(object_labels: np.ndarray): """ bounding_boxes = {} - props = regionprops(object_labels) - - for prop in props: - # Get major axis length - major_axis_length = prop.major_axis_length + # Get region properties as a DataFrame + props = regionprops_table(object_labels, properties=('label', 'bbox')) - # Define bounding box based on major axis length - centroid = prop.centroid - radius = int(major_axis_length / 2) - min_row = max(0, int(centroid[0]) - radius) - max_row = min(object_labels.shape[0] - 1, int(centroid[0]) + radius) - min_col = max(0, int(centroid[1]) - radius) - max_col = min(object_labels.shape[1] - 1, int(centroid[1]) + radius) + # Convert to DataFrame + df = pd.DataFrame(props) - bounding_boxes[prop.label] = ((min_row, min_col), (max_row, max_col)) + for _, row in df.iterrows(): + label_id, min_row, min_col, max_row, max_col = row.values + # Return closed interval bounding box + bounding_boxes[label_id] = ((min_row, min_col), (max_row-1, max_col-1)) return bounding_boxes -def filter_labels_in_bbox(bounding_box: List, cell_labels: np.ndarray): +def filter_labels_in_bbox(bounding_box: List, cell_props: pd.DataFrame, expansion_factor: int): """ Gets the cell labels that fall within the expanded bounding box of a given object. Args: bounding_box (List): The bounding box values for the input obj_label - cell_labels (np.ndarray): The cell label array. + cell_props (pd.DataFrame): The cell label regionprops DataFrame. + expansion_factor: how many pixels from the bounding box you want to expand the search for compatible cells. Returns: List: The cell labels that fall within the expanded bounding box. @@ -219,14 +220,11 @@ def filter_labels_in_bbox(bounding_box: List, cell_labels: np.ndarray): min_row, min_col = bounding_box[0] max_row, max_col = bounding_box[1] - filtered_labels = [] - - props = regionprops(cell_labels) - - for prop in props: - centroid = prop.centroid - if min_row <= centroid[0] <= max_row and min_col <= centroid[1] <= max_col: - filtered_labels.append(prop.label) + # Filter labels based on bounding box + filtered_labels = cell_props[(cell_props['centroid-0'] >= min_row-expansion_factor) & + (cell_props['centroid-0'] <= max_row+expansion_factor) & + (cell_props['centroid-1'] >= min_col-expansion_factor) & + (cell_props['centroid-1'] <= max_col+expansion_factor)]['label'].tolist() return filtered_labels diff --git a/templates/ez_segmenter.ipynb b/templates/ez_segmenter.ipynb index 8e9bb0c99..d3ca1f09a 100644 --- a/templates/ez_segmenter.ipynb +++ b/templates/ez_segmenter.ipynb @@ -563,6 +563,7 @@ "\n", "* `merge_masks_list`: list of object masks to merge to the base (cell) `image.List` of object masks to merge to the base (cell) image.\n", "* `percent_overlap`: percent threshold required for a cell mask to be merged into an object mask\n", + "* `expansion_factor`: number of pixels out from an objects bbox a cell should be looked for. (Deafualt 10 pixels)\n", "* `cell_dir`: the final mask directory\n", "* `cell_mask_suffix`: Suffix name of the cell mask files. Usually \"whole_cell\"\n", "* `merged_masks_dir`: the directory to store the merged masks" @@ -580,6 +581,7 @@ "source": [ "merge_masks_list = [\"microglia-arms\", \"astrocyte-arms\"]\n", "percent_overlap = 30\n", + "expansion_factor = 10\n", "\n", "# Overwrite if different from above\n", "cell_dir = os.path.join(segmentation_dir, \"deepcell_output\")\n", @@ -633,6 +635,7 @@ " cell_mask_dir=cell_dir,\n", " cell_mask_suffix=cell_mask_suffix,\n", " overlap_percent_threshold=percent_overlap,\n", + " expansion_factor=expansion_factor,\n", " save_path=merged_masks_dir,\n", " log_dir=log_dir\n", ")" diff --git a/tests/segmentation/ez_seg/merge_masks_test.py b/tests/segmentation/ez_seg/merge_masks_test.py index 4012028d7..af9388007 100644 --- a/tests/segmentation/ez_seg/merge_masks_test.py +++ b/tests/segmentation/ez_seg/merge_masks_test.py @@ -4,10 +4,11 @@ import skimage.io as io import tempfile import xarray as xr - +import pytest from alpineer import image_utils from ark.segmentation.ez_seg import merge_masks -from skimage.measure import label +from skimage.measure import regionprops_table, label +import pandas as pd from skimage.draw import disk from typing import List, Union @@ -26,6 +27,7 @@ def test_merge_masks_seq(): os.mkdir(directory) overlap_thresh: int = 10 + expansion_factor: int = 10 for fov in fov_list: cell_mask_data: np.ndarray = np.random.randint(0, 16, (32, 32)) @@ -42,10 +44,8 @@ def test_merge_masks_seq(): image_utils.save_image(object_mask_fov_file, object_mask_data) # we're only testing functionality, for in-depth merge testing see test_merge_masks_single - merge_masks.merge_masks_seq( - fov_list, object_list, object_mask_dir, cell_mask_dir, cell_mask_suffix, - overlap_thresh, merged_mask_dir, log_dir - ) + merge_masks.merge_masks_seq(fov_list, object_list, object_mask_dir, cell_mask_dir, cell_mask_suffix, + overlap_thresh, merged_mask_dir, log_dir, expansion_factor) for fov in fov_list: print("checking fov") @@ -77,6 +77,7 @@ def test_merge_masks_single(): expected_cell_mask: np.ndarray = np.zeros((32, 32)) overlap_thresh: int = 10 + expansion_factor: int = 10 merged_mask_name: str = "merged_mask" # case 1: overlap below threshold, don't merge @@ -107,9 +108,9 @@ def test_merge_masks_single(): mask_save_dir: Union[str, pathlib.Path] = os.path.join(td, "mask_save_dir") os.mkdir(mask_save_dir) - created_cell_mask: np.ndarray = merge_masks.merge_masks_single( - object_mask, cell_mask, overlap_thresh, merged_mask_name, mask_save_dir - ) + created_cell_mask: np.ndarray = merge_masks.merge_masks_single(object_mask, cell_mask, + overlap_thresh, merged_mask_name, + mask_save_dir, expansion_factor) created_merged_mask: np.ndarray = io.imread( os.path.join(mask_save_dir, merged_mask_name + "_merged.tiff") @@ -130,7 +131,7 @@ def test_get_bounding_boxes(): # Expected bounding boxes expected_bounding_boxes = {1: ((0, 0), (1, 1)), - 2: ((2, 2), (2, 3))} + 2: ((2, 2), (2, 3))} assert bounding_boxes == expected_bounding_boxes @@ -141,11 +142,14 @@ def test_filter_labels_in_bbox(): [0, 1, 0, 0], [0, 0, 2, 2]]) + # Get regionprops df + label_df = pd.DataFrame(regionprops_table(label(labels), properties=('label', 'centroid', 'major_axis_length'))) + # Get the bounding boxes - bounding_boxes = merge_masks.get_bounding_boxes(labels) + bounding_boxes = merge_masks.get_bounding_boxes(label(labels)) # Filter labels within the bounding box of label 1 - filtered_labels = merge_masks.filter_labels_in_bbox(bounding_boxes[1], labels) + filtered_labels = merge_masks.filter_labels_in_bbox(bounding_boxes[1], label_df, expansion_factor=0) # Expected filtered labels for label 1 expected_filtered_labels_1 = [1] @@ -153,7 +157,7 @@ def test_filter_labels_in_bbox(): assert filtered_labels == expected_filtered_labels_1 # Filter labels within the bounding box of label 2 - filtered_labels = merge_masks.filter_labels_in_bbox(bounding_boxes[2], labels) + filtered_labels = merge_masks.filter_labels_in_bbox(bounding_boxes[2], label_df, expansion_factor=0) # Expected filtered labels for label 2 expected_filtered_labels_2 = [2] @@ -161,9 +165,17 @@ def test_filter_labels_in_bbox(): assert filtered_labels == expected_filtered_labels_2 # Filter labels within the bounding box of an empty label (should return an empty list) - filtered_labels = merge_masks.filter_labels_in_bbox(((0, 0), (0, 0)), labels) + filtered_labels = merge_masks.filter_labels_in_bbox(((0, 0), (0, 0)), label_df, expansion_factor=0) # Expected filtered labels for an empty label expected_filtered_labels_empty = [] assert filtered_labels == expected_filtered_labels_empty + + # Filter labels within the bounding box of label 1 and expansion + filtered_labels = merge_masks.filter_labels_in_bbox(bounding_boxes[1], label_df, expansion_factor=10) + + # Expected filtered labels for label 1 and 2 + expected_filtered_labels_expanded = [1,2] + + assert filtered_labels == expected_filtered_labels_expanded From 0bbda94fcd592f05415746f18bc1b2cdc56ed55d Mon Sep 17 00:00:00 2001 From: bryjc Date: Mon, 4 Mar 2024 15:31:28 -0800 Subject: [PATCH 3/8] Tweaked a couple words in notebook. --- templates/ez_segmenter.ipynb | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/templates/ez_segmenter.ipynb b/templates/ez_segmenter.ipynb index d3ca1f09a..96d03f75e 100644 --- a/templates/ez_segmenter.ipynb +++ b/templates/ez_segmenter.ipynb @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { "tags": [ "import" @@ -552,7 +552,11 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "pycharm": { + "is_executing": true + } + }, "source": [ "#### Here you can merge object segmentation masks with cell masks (or any other type of mask).\n", "Here you will provide a list of what objects you would like to merge with previously segmented cell masks (or other base mask).\n", @@ -563,11 +567,12 @@ "\n", "* `merge_masks_list`: list of object masks to merge to the base (cell) `image.List` of object masks to merge to the base (cell) image.\n", "* `percent_overlap`: percent threshold required for a cell mask to be merged into an object mask\n", - "* `expansion_factor`: number of pixels out from an objects bbox a cell should be looked for. (Deafualt 10 pixels)\n", + "* `expansion_factor`: number of pixels out from an objects bounding box a cell should be looked for. (Default 10 pixels)\n", "* `cell_dir`: the final mask directory\n", "* `cell_mask_suffix`: Suffix name of the cell mask files. Usually \"whole_cell\"\n", "* `merged_masks_dir`: the directory to store the merged masks" - ] + ], + "outputs": [] }, { "cell_type": "code", From 2fd085e1b35742066cd801513fad97d07526e09c Mon Sep 17 00:00:00 2001 From: bryjc Date: Thu, 7 Mar 2024 17:40:00 -0800 Subject: [PATCH 4/8] Fixed merge_masks_test.py error. --- tests/segmentation/ez_seg/merge_masks_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/segmentation/ez_seg/merge_masks_test.py b/tests/segmentation/ez_seg/merge_masks_test.py index af9388007..e106f1e36 100644 --- a/tests/segmentation/ez_seg/merge_masks_test.py +++ b/tests/segmentation/ez_seg/merge_masks_test.py @@ -45,7 +45,7 @@ def test_merge_masks_seq(): # we're only testing functionality, for in-depth merge testing see test_merge_masks_single merge_masks.merge_masks_seq(fov_list, object_list, object_mask_dir, cell_mask_dir, cell_mask_suffix, - overlap_thresh, merged_mask_dir, log_dir, expansion_factor) + overlap_thresh, expansion_factor, merged_mask_dir, log_dir) for fov in fov_list: print("checking fov") @@ -176,6 +176,6 @@ def test_filter_labels_in_bbox(): filtered_labels = merge_masks.filter_labels_in_bbox(bounding_boxes[1], label_df, expansion_factor=10) # Expected filtered labels for label 1 and 2 - expected_filtered_labels_expanded = [1,2] + expected_filtered_labels_expanded = [1, 2] assert filtered_labels == expected_filtered_labels_expanded From 65fe3f3879491dd247b35b87788b903c66a86298 Mon Sep 17 00:00:00 2001 From: bryjc Date: Thu, 7 Mar 2024 18:25:52 -0800 Subject: [PATCH 5/8] Fixed pycodestyle --- tests/segmentation/ez_seg/merge_masks_test.py | 29 ++++++++++++------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/tests/segmentation/ez_seg/merge_masks_test.py b/tests/segmentation/ez_seg/merge_masks_test.py index e106f1e36..7b42ce80d 100644 --- a/tests/segmentation/ez_seg/merge_masks_test.py +++ b/tests/segmentation/ez_seg/merge_masks_test.py @@ -44,8 +44,9 @@ def test_merge_masks_seq(): image_utils.save_image(object_mask_fov_file, object_mask_data) # we're only testing functionality, for in-depth merge testing see test_merge_masks_single - merge_masks.merge_masks_seq(fov_list, object_list, object_mask_dir, cell_mask_dir, cell_mask_suffix, - overlap_thresh, expansion_factor, merged_mask_dir, log_dir) + merge_masks.merge_masks_seq(fov_list, object_list, object_mask_dir, cell_mask_dir, + cell_mask_suffix, overlap_thresh, expansion_factor, + merged_mask_dir, log_dir) for fov in fov_list: print("checking fov") @@ -108,9 +109,12 @@ def test_merge_masks_single(): mask_save_dir: Union[str, pathlib.Path] = os.path.join(td, "mask_save_dir") os.mkdir(mask_save_dir) - created_cell_mask: np.ndarray = merge_masks.merge_masks_single(object_mask, cell_mask, - overlap_thresh, merged_mask_name, - mask_save_dir, expansion_factor) + created_cell_mask: np.ndarray = merge_masks.merge_masks_single(object_mask, + cell_mask, + overlap_thresh, + merged_mask_name, + mask_save_dir, + expansion_factor) created_merged_mask: np.ndarray = io.imread( os.path.join(mask_save_dir, merged_mask_name + "_merged.tiff") @@ -143,13 +147,15 @@ def test_filter_labels_in_bbox(): [0, 0, 2, 2]]) # Get regionprops df - label_df = pd.DataFrame(regionprops_table(label(labels), properties=('label', 'centroid', 'major_axis_length'))) + label_df = pd.DataFrame(regionprops_table( + label(labels), properties=('label', 'centroid', 'major_axis_length'))) # Get the bounding boxes bounding_boxes = merge_masks.get_bounding_boxes(label(labels)) # Filter labels within the bounding box of label 1 - filtered_labels = merge_masks.filter_labels_in_bbox(bounding_boxes[1], label_df, expansion_factor=0) + filtered_labels = merge_masks.filter_labels_in_bbox( + bounding_boxes[1], label_df, expansion_factor=0) # Expected filtered labels for label 1 expected_filtered_labels_1 = [1] @@ -157,7 +163,8 @@ def test_filter_labels_in_bbox(): assert filtered_labels == expected_filtered_labels_1 # Filter labels within the bounding box of label 2 - filtered_labels = merge_masks.filter_labels_in_bbox(bounding_boxes[2], label_df, expansion_factor=0) + filtered_labels = merge_masks.filter_labels_in_bbox( + bounding_boxes[2], label_df, expansion_factor=0) # Expected filtered labels for label 2 expected_filtered_labels_2 = [2] @@ -165,7 +172,8 @@ def test_filter_labels_in_bbox(): assert filtered_labels == expected_filtered_labels_2 # Filter labels within the bounding box of an empty label (should return an empty list) - filtered_labels = merge_masks.filter_labels_in_bbox(((0, 0), (0, 0)), label_df, expansion_factor=0) + filtered_labels = merge_masks.filter_labels_in_bbox( + ((0, 0), (0, 0)), label_df, expansion_factor=0) # Expected filtered labels for an empty label expected_filtered_labels_empty = [] @@ -173,7 +181,8 @@ def test_filter_labels_in_bbox(): assert filtered_labels == expected_filtered_labels_empty # Filter labels within the bounding box of label 1 and expansion - filtered_labels = merge_masks.filter_labels_in_bbox(bounding_boxes[1], label_df, expansion_factor=10) + filtered_labels = merge_masks.filter_labels_in_bbox( + bounding_boxes[1], label_df, expansion_factor=10) # Expected filtered labels for label 1 and 2 expected_filtered_labels_expanded = [1, 2] From 85c75cf45342ba9e349932eec0d2515ad20ac1a8 Mon Sep 17 00:00:00 2001 From: bryjc Date: Mon, 11 Mar 2024 14:21:01 -0700 Subject: [PATCH 6/8] updated merge_masks.py after review suggestions. --- src/ark/segmentation/ez_seg/merge_masks.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/ark/segmentation/ez_seg/merge_masks.py b/src/ark/segmentation/ez_seg/merge_masks.py index 8e8bc4cac..4a98205f3 100644 --- a/src/ark/segmentation/ez_seg/merge_masks.py +++ b/src/ark/segmentation/ez_seg/merge_masks.py @@ -129,7 +129,7 @@ def merge_masks_single( object_labels_bounding_boxes = get_bounding_boxes(object_labels) # Calculate all cell regionprops for filtering, convert to DataFrame - cell_props = pd.DataFrame(regionprops_table(cell_labels, properties=('label', 'centroid', 'major_axis_length'))) + cell_props = pd.DataFrame(regionprops_table(cell_labels, properties=('label', 'centroid'))) # Find connected components in object and cell masks. Merge only those with highest overlap that meets threshold. for obj_label in range(1, num_object_labels + 1): @@ -142,7 +142,7 @@ def merge_masks_single( # Filter for cell_labels that fall within the expanded bounding box of the obj_label cell_labels_in_range = filter_labels_in_bbox( - object_labels_bounding_boxes.pop(obj_label), cell_props, expansion_factor) + object_labels_bounding_boxes[obj_label], cell_props, expansion_factor) for cell_label in cell_labels_in_range: # Extract a connected component from cell_mask @@ -191,15 +191,12 @@ def get_bounding_boxes(object_labels: np.ndarray): bounding_boxes = {} # Get region properties as a DataFrame - props = regionprops_table(object_labels, properties=('label', 'bbox')) + props_df = pd.DataFrame(regionprops_table(object_labels, properties=('label', 'bbox'))) - # Convert to DataFrame - df = pd.DataFrame(props) - - for _, row in df.iterrows(): - label_id, min_row, min_col, max_row, max_col = row.values - # Return closed interval bounding box - bounding_boxes[label_id] = ((min_row, min_col), (max_row-1, max_col-1)) + # Return closed interval bounding box + # label_id, min_row, min_col, max_row, max_col used to define bbox + props_df.apply(lambda row: bounding_boxes.update( + {row['label']: ((row['bbox-0'], row['bbox-1']), (row['bbox-2'] - 1, row['bbox-3'] - 1))}), axis=1) return bounding_boxes From 281b8466151dbff3010d9122208d16426b883bfb Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Fri, 15 Mar 2024 11:37:10 -0700 Subject: [PATCH 7/8] bye bye dask, see you another day --- pyproject.toml | 1 - src/ark/utils/data_utils.py | 35 ++++++--------- templates/anndata_conversion.ipynb | 72 +++++++++++++++++++----------- tests/utils/data_utils_test.py | 23 +++++----- 4 files changed, 71 insertions(+), 60 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c7fba64ef..ad8a963bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,6 @@ dependencies = [ "alpineer==0.1.12", "anndata", "Cython>=3", - "dask[distributed]", "datasets>=2.6,<3.0", "dill>=0.3.5,<0.4", "feather-format>=0.4.1,<1", diff --git a/src/ark/utils/data_utils.py b/src/ark/utils/data_utils.py index 6587aee71..dd2b978ec 100644 --- a/src/ark/utils/data_utils.py +++ b/src/ark/utils/data_utils.py @@ -19,11 +19,10 @@ from ark import settings from skimage.segmentation import find_boundaries import dask.dataframe as dd -from dask import delayed +from pandas.core.groupby.generic import DataFrameGroupBy from anndata import AnnData, read_zarr from anndata.experimental import AnnCollection from anndata.experimental.multi_files._anncollection import ConvertType -from tqdm.dask import TqdmCallback from torchdata.datapipes.iter import IterDataPipe from typing import Iterator, Optional try: @@ -823,14 +822,13 @@ def stitch_images_by_shape(data_dir, stitched_dir, img_sub_folder=None, channels current_img) -@delayed -def _convert_ct_fov_to_adata(fov_dd: dd.DataFrame, var_names: list[str], obs_names: list[str], save_dir: os.PathLike) -> str: +def _convert_ct_fov_to_adata(fov_group: DataFrameGroupBy, var_names: list[str], obs_names: list[str], save_dir: os.PathLike) -> str: """Converts the cell table for a single FOV to an `AnnData` object and saves it to disk as a `Zarr` store. Parameters ---------- - fov_dd : dd.DataFrame + fov_group : DataFrameGroupBy The cell table subset on a single FOV. var_names: list[str] The marker names to extract from the cell table. @@ -845,7 +843,7 @@ def _convert_ct_fov_to_adata(fov_dd: dd.DataFrame, var_names: list[str], obs_nam The path of the saved `AnnData` object. """ - fov_dd: dd.DataFrame = fov_dd.sort_values(by=settings.CELL_LABEL, key=ns.natsort_key).reset_index() + fov_dd: dd.DataFrame = fov_group.sort_values(by=settings.CELL_LABEL, key=ns.natsort_key).reset_index() fov_id: str = fov_dd[settings.FOV_ID].iloc[0] # Set the index to be the FOV and the segmentation label to create a unique index @@ -906,11 +904,10 @@ def __init__(self, cell_table_path: os.PathLike, io_utils.validate_paths(paths=cell_table_path) - # Read in the cell table - cell_table: dd.DataFrame = dd.read_csv(cell_table_path) + cell_table: pd.DataFrame = pd.read_csv(cell_table_path) ct_columns = cell_table.columns - + # Get the marker column indices marker_index_start: int = ct_columns.get_loc(settings.PRE_CHANNEL_COL) + 1 marker_index_stop: int = ct_columns.get_loc(settings.POST_CHANNEL_COL) @@ -970,20 +967,16 @@ def convert_to_adata( if not save_dir.exists(): save_dir.mkdir(parents=True, exist_ok=True) + n_unique_fovs = self.cell_table[settings.FOV_ID].nunique() - with TqdmCallback(desc="Converting to AnnData"): - g: pd.Series = ( - self.cell_table.groupby(by=settings.FOV_ID, sort=True) - .apply( - _convert_ct_fov_to_adata, - var_names=self.var_names, - obs_names=self.obs_names, - save_dir=save_dir, - meta=("anndata_save_results", str), - ) - ).compute() + tqdm.pandas(desc="Converting Cell Table to AnnData Tables", total=n_unique_fovs, unit="FOVs") - return g.to_dict() + result: pd.Series = self.cell_table.groupby(by=settings.FOV_ID, sort=True).progress_apply( + lambda x: _convert_ct_fov_to_adata( + x, var_names=self.var_names, obs_names=self.obs_names, save_dir=save_dir + ), + ) + return result.to_dict() class AnnCollectionKwargs(TypedDict): diff --git a/templates/anndata_conversion.ipynb b/templates/anndata_conversion.ipynb index bbee809e1..e3047d045 100644 --- a/templates/anndata_conversion.ipynb +++ b/templates/anndata_conversion.ipynb @@ -24,7 +24,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2023-11-30T17:09:00.527535Z", @@ -33,7 +33,6 @@ }, "outputs": [], "source": [ - "from dask.distributed import Client\n", "from anndata import read_zarr\n", "from ark.utils.data_utils import ConvertToAnnData\n", "import os" @@ -41,21 +40,7 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2023-11-30T17:09:02.298445Z", - "start_time": "2023-11-30T17:09:00.528933Z" - } - }, - "outputs": [], - "source": [ - "Client(threads_per_worker = 2)" - ] - }, - { - "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2023-11-30T17:08:54.069774Z", @@ -80,14 +65,29 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2023-11-30T17:09:04.139994Z", "start_time": "2023-11-30T17:09:02.283610Z" } }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/srivarra/Angelo Lab/Internal/ark-analysis/src/ark/utils/example_dataset.py:144: UserWarning: Files exist in ../data/example_dataset/image_data. They will be overwritten by the downloaded example dataset.\n", + " warnings.warn(UserWarning(f\"Files exist in {dst_path}. \\\n", + "/Users/srivarra/Angelo Lab/Internal/ark-analysis/src/ark/utils/example_dataset.py:144: UserWarning: Files exist in ../data/example_dataset/segmentation/cell_table. They will be overwritten by the downloaded example dataset.\n", + " warnings.warn(UserWarning(f\"Files exist in {dst_path}. \\\n", + "/Users/srivarra/Angelo Lab/Internal/ark-analysis/src/ark/utils/example_dataset.py:144: UserWarning: Files exist in ../data/example_dataset/segmentation/deepcell_output. They will be overwritten by the downloaded example dataset.\n", + " warnings.warn(UserWarning(f\"Files exist in {dst_path}. \\\n", + "/Users/srivarra/Angelo Lab/Internal/ark-analysis/src/ark/utils/example_dataset.py:144: UserWarning: Files exist in ../data/example_dataset/pixie/example_cell_output_dir. They will be overwritten by the downloaded example dataset.\n", + " warnings.warn(UserWarning(f\"Files exist in {dst_path}. \\\n" + ] + } + ], "source": [ "from ark.utils.example_dataset import get_example_dataset\n", "\n", @@ -108,7 +108,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2023-11-30T17:09:55.614306Z", @@ -131,7 +131,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2023-11-30T17:09:56.466954Z", @@ -148,7 +148,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2023-11-30T17:09:56.879793Z", @@ -162,7 +162,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2023-11-30T17:10:01.796185Z", @@ -170,7 +170,22 @@ }, "collapsed": false }, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1451b0f016224bb99bc89fc95c35b2b7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Converting Cell Table to AnnData Tables: 0%| | 0/11 [00:00 Date: Fri, 15 Mar 2024 11:37:54 -0700 Subject: [PATCH 8/8] notebook cleanup --- templates/anndata_conversion.ipynb | 48 ++++++------------------------ 1 file changed, 9 insertions(+), 39 deletions(-) diff --git a/templates/anndata_conversion.ipynb b/templates/anndata_conversion.ipynb index e3047d045..924db61d2 100644 --- a/templates/anndata_conversion.ipynb +++ b/templates/anndata_conversion.ipynb @@ -24,7 +24,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-11-30T17:09:00.527535Z", @@ -40,7 +40,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-11-30T17:08:54.069774Z", @@ -65,29 +65,14 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-11-30T17:09:04.139994Z", "start_time": "2023-11-30T17:09:02.283610Z" } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/srivarra/Angelo Lab/Internal/ark-analysis/src/ark/utils/example_dataset.py:144: UserWarning: Files exist in ../data/example_dataset/image_data. They will be overwritten by the downloaded example dataset.\n", - " warnings.warn(UserWarning(f\"Files exist in {dst_path}. \\\n", - "/Users/srivarra/Angelo Lab/Internal/ark-analysis/src/ark/utils/example_dataset.py:144: UserWarning: Files exist in ../data/example_dataset/segmentation/cell_table. They will be overwritten by the downloaded example dataset.\n", - " warnings.warn(UserWarning(f\"Files exist in {dst_path}. \\\n", - "/Users/srivarra/Angelo Lab/Internal/ark-analysis/src/ark/utils/example_dataset.py:144: UserWarning: Files exist in ../data/example_dataset/segmentation/deepcell_output. They will be overwritten by the downloaded example dataset.\n", - " warnings.warn(UserWarning(f\"Files exist in {dst_path}. \\\n", - "/Users/srivarra/Angelo Lab/Internal/ark-analysis/src/ark/utils/example_dataset.py:144: UserWarning: Files exist in ../data/example_dataset/pixie/example_cell_output_dir. They will be overwritten by the downloaded example dataset.\n", - " warnings.warn(UserWarning(f\"Files exist in {dst_path}. \\\n" - ] - } - ], + "outputs": [], "source": [ "from ark.utils.example_dataset import get_example_dataset\n", "\n", @@ -108,7 +93,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-11-30T17:09:55.614306Z", @@ -131,7 +116,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-11-30T17:09:56.466954Z", @@ -148,7 +133,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-11-30T17:09:56.879793Z", @@ -162,7 +147,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-11-30T17:10:01.796185Z", @@ -170,22 +155,7 @@ }, "collapsed": false }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "1451b0f016224bb99bc89fc95c35b2b7", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Converting Cell Table to AnnData Tables: 0%| | 0/11 [00:00