diff --git a/pyproject.toml b/pyproject.toml index c7fba64ef..ad8a963bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,6 @@ dependencies = [ "alpineer==0.1.12", "anndata", "Cython>=3", - "dask[distributed]", "datasets>=2.6,<3.0", "dill>=0.3.5,<0.4", "feather-format>=0.4.1,<1", diff --git a/src/ark/segmentation/ez_seg/merge_masks.py b/src/ark/segmentation/ez_seg/merge_masks.py index 5d2e54de5..4a98205f3 100644 --- a/src/ark/segmentation/ez_seg/merge_masks.py +++ b/src/ark/segmentation/ez_seg/merge_masks.py @@ -5,6 +5,8 @@ import os from skimage.io import imread from skimage.morphology import label +from skimage.measure import regionprops_table +import pandas as pd from alpineer import load_utils, image_utils from ark.segmentation.ez_seg.ez_seg_utils import log_creator @@ -16,6 +18,7 @@ def merge_masks_seq( cell_mask_dir: Union[pathlib.Path, str], cell_mask_suffix: str, overlap_percent_threshold: int, + expansion_factor: int, save_path: Union[pathlib.Path, str], log_dir: Union[pathlib.Path, str] ) -> None: @@ -31,6 +34,7 @@ def merge_masks_seq( cell_mask_dir (Union[str, pathlib.Path]): Path to where the original cell masks are located. cell_mask_suffix (str): Name of the cell type you are merging. Usually "whole_cell". overlap_percent_threshold (int): Percent overlap of total pixel area needed fo object to be merged to a cell. + expansion_factor (int): How many pixels out from an objects bbox a cell should be looked for. save_path (Union[str, pathlib.Path]): The directory where merged masks and remaining cell mask will be saved. log_dir (Union[str, pathlib.Path]): The directory to save log information to. """ @@ -61,13 +65,9 @@ def merge_masks_seq( # for each object type in the fov, merge with cell masks for obj in fov_object_names: curr_object_mask = imread(fname=(object_mask_dir / obj)) - remaining_cells = merge_masks_single( - object_mask=curr_object_mask, - cell_mask=curr_cell_mask, - overlap_thresh=overlap_percent_threshold, - object_name=obj, - mask_save_path=save_path, - ) + remaining_cells = merge_masks_single(object_mask=curr_object_mask, cell_mask=curr_cell_mask, + overlap_thresh=overlap_percent_threshold, object_name=obj, + mask_save_path=save_path, expansion_factor=expansion_factor) curr_cell_mask = remaining_cells # save the unmerged cells as a tiff. @@ -93,6 +93,7 @@ def merge_masks_single( overlap_thresh: int, object_name: str, mask_save_path: str, + expansion_factor: int ) -> np.ndarray: """ Combines overlapping object and cell masks. For any combination which represents has at least `overlap` percentage @@ -104,6 +105,7 @@ def merge_masks_single( overlap_thresh (int): The percentage overlap required for a cell to be merged. object_name (str): The name of the object. mask_save_path (str): The path to save the mask. + expansion_factor (int): How many pixels out from an objects bbox a cell should be looked for. Returns: np.ndarray: The cells remaining mask, which will be used for the next cycle in merging while there are objects. @@ -123,6 +125,12 @@ def merge_masks_single( # Set up list to store merged cell labels remove_cells_list = [0] + # Create a dictionary of the bounding boxes for all object labels + object_labels_bounding_boxes = get_bounding_boxes(object_labels) + + # Calculate all cell regionprops for filtering, convert to DataFrame + cell_props = pd.DataFrame(regionprops_table(cell_labels, properties=('label', 'centroid'))) + # Find connected components in object and cell masks. Merge only those with highest overlap that meets threshold. for obj_label in range(1, num_object_labels + 1): # Extract a connected component from object_mask @@ -132,7 +140,11 @@ def merge_masks_single( best_cell_mask_component = None cell_to_merge_label = None - for cell_label in range(1, num_cell_labels + 1): + # Filter for cell_labels that fall within the expanded bounding box of the obj_label + cell_labels_in_range = filter_labels_in_bbox( + object_labels_bounding_boxes[obj_label], cell_props, expansion_factor) + + for cell_label in cell_labels_in_range: # Extract a connected component from cell_mask cell_mask_component = cell_labels == cell_label @@ -165,3 +177,51 @@ def merge_masks_single( # Return unmerged cells return cell_labels + + +def get_bounding_boxes(object_labels: np.ndarray): + """ + Gets the bounding boxes of labeled images based on object major axis length. + + Args: + object_labels (np.ndarray): label array + Returns: + dict: Dictionary containing labels as keys and bounding box as values + """ + bounding_boxes = {} + + # Get region properties as a DataFrame + props_df = pd.DataFrame(regionprops_table(object_labels, properties=('label', 'bbox'))) + + # Return closed interval bounding box + # label_id, min_row, min_col, max_row, max_col used to define bbox + props_df.apply(lambda row: bounding_boxes.update( + {row['label']: ((row['bbox-0'], row['bbox-1']), (row['bbox-2'] - 1, row['bbox-3'] - 1))}), axis=1) + + return bounding_boxes + + +def filter_labels_in_bbox(bounding_box: List, cell_props: pd.DataFrame, expansion_factor: int): + """ + Gets the cell labels that fall within the expanded bounding box of a given object. + + Args: + bounding_box (List): The bounding box values for the input obj_label + cell_props (pd.DataFrame): The cell label regionprops DataFrame. + expansion_factor: how many pixels from the bounding box you want to expand the search for compatible cells. + + Returns: + List: The cell labels that fall within the expanded bounding box. + + """ + min_row, min_col = bounding_box[0] + max_row, max_col = bounding_box[1] + + # Filter labels based on bounding box + filtered_labels = cell_props[(cell_props['centroid-0'] >= min_row-expansion_factor) & + (cell_props['centroid-0'] <= max_row+expansion_factor) & + (cell_props['centroid-1'] >= min_col-expansion_factor) & + (cell_props['centroid-1'] <= max_col+expansion_factor)]['label'].tolist() + + return filtered_labels + diff --git a/src/ark/utils/data_utils.py b/src/ark/utils/data_utils.py index 6587aee71..dd2b978ec 100644 --- a/src/ark/utils/data_utils.py +++ b/src/ark/utils/data_utils.py @@ -19,11 +19,10 @@ from ark import settings from skimage.segmentation import find_boundaries import dask.dataframe as dd -from dask import delayed +from pandas.core.groupby.generic import DataFrameGroupBy from anndata import AnnData, read_zarr from anndata.experimental import AnnCollection from anndata.experimental.multi_files._anncollection import ConvertType -from tqdm.dask import TqdmCallback from torchdata.datapipes.iter import IterDataPipe from typing import Iterator, Optional try: @@ -823,14 +822,13 @@ def stitch_images_by_shape(data_dir, stitched_dir, img_sub_folder=None, channels current_img) -@delayed -def _convert_ct_fov_to_adata(fov_dd: dd.DataFrame, var_names: list[str], obs_names: list[str], save_dir: os.PathLike) -> str: +def _convert_ct_fov_to_adata(fov_group: DataFrameGroupBy, var_names: list[str], obs_names: list[str], save_dir: os.PathLike) -> str: """Converts the cell table for a single FOV to an `AnnData` object and saves it to disk as a `Zarr` store. Parameters ---------- - fov_dd : dd.DataFrame + fov_group : DataFrameGroupBy The cell table subset on a single FOV. var_names: list[str] The marker names to extract from the cell table. @@ -845,7 +843,7 @@ def _convert_ct_fov_to_adata(fov_dd: dd.DataFrame, var_names: list[str], obs_nam The path of the saved `AnnData` object. """ - fov_dd: dd.DataFrame = fov_dd.sort_values(by=settings.CELL_LABEL, key=ns.natsort_key).reset_index() + fov_dd: dd.DataFrame = fov_group.sort_values(by=settings.CELL_LABEL, key=ns.natsort_key).reset_index() fov_id: str = fov_dd[settings.FOV_ID].iloc[0] # Set the index to be the FOV and the segmentation label to create a unique index @@ -906,11 +904,10 @@ def __init__(self, cell_table_path: os.PathLike, io_utils.validate_paths(paths=cell_table_path) - # Read in the cell table - cell_table: dd.DataFrame = dd.read_csv(cell_table_path) + cell_table: pd.DataFrame = pd.read_csv(cell_table_path) ct_columns = cell_table.columns - + # Get the marker column indices marker_index_start: int = ct_columns.get_loc(settings.PRE_CHANNEL_COL) + 1 marker_index_stop: int = ct_columns.get_loc(settings.POST_CHANNEL_COL) @@ -970,20 +967,16 @@ def convert_to_adata( if not save_dir.exists(): save_dir.mkdir(parents=True, exist_ok=True) + n_unique_fovs = self.cell_table[settings.FOV_ID].nunique() - with TqdmCallback(desc="Converting to AnnData"): - g: pd.Series = ( - self.cell_table.groupby(by=settings.FOV_ID, sort=True) - .apply( - _convert_ct_fov_to_adata, - var_names=self.var_names, - obs_names=self.obs_names, - save_dir=save_dir, - meta=("anndata_save_results", str), - ) - ).compute() + tqdm.pandas(desc="Converting Cell Table to AnnData Tables", total=n_unique_fovs, unit="FOVs") - return g.to_dict() + result: pd.Series = self.cell_table.groupby(by=settings.FOV_ID, sort=True).progress_apply( + lambda x: _convert_ct_fov_to_adata( + x, var_names=self.var_names, obs_names=self.obs_names, save_dir=save_dir + ), + ) + return result.to_dict() class AnnCollectionKwargs(TypedDict): diff --git a/templates/anndata_conversion.ipynb b/templates/anndata_conversion.ipynb index bbee809e1..924db61d2 100644 --- a/templates/anndata_conversion.ipynb +++ b/templates/anndata_conversion.ipynb @@ -33,26 +33,11 @@ }, "outputs": [], "source": [ - "from dask.distributed import Client\n", "from anndata import read_zarr\n", "from ark.utils.data_utils import ConvertToAnnData\n", "import os" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2023-11-30T17:09:02.298445Z", - "start_time": "2023-11-30T17:09:00.528933Z" - } - }, - "outputs": [], - "source": [ - "Client(threads_per_worker = 2)" - ] - }, { "cell_type": "code", "execution_count": null, @@ -181,6 +166,13 @@ "source": [ "We recommend reading both a brief overview of the `AnnData` datatype documentation [here](https://ark-analysis.readthedocs.io/en/latest/_rtd/data_types.html), and the official documentation [here](https://anndata.readthedocs.io/en/latest/index.html)." ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -199,7 +191,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.11.0" } }, "nbformat": 4, diff --git a/templates/ez_segmenter.ipynb b/templates/ez_segmenter.ipynb index 8e9bb0c99..96d03f75e 100644 --- a/templates/ez_segmenter.ipynb +++ b/templates/ez_segmenter.ipynb @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { "tags": [ "import" @@ -552,7 +552,11 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "pycharm": { + "is_executing": true + } + }, "source": [ "#### Here you can merge object segmentation masks with cell masks (or any other type of mask).\n", "Here you will provide a list of what objects you would like to merge with previously segmented cell masks (or other base mask).\n", @@ -563,10 +567,12 @@ "\n", "* `merge_masks_list`: list of object masks to merge to the base (cell) `image.List` of object masks to merge to the base (cell) image.\n", "* `percent_overlap`: percent threshold required for a cell mask to be merged into an object mask\n", + "* `expansion_factor`: number of pixels out from an objects bounding box a cell should be looked for. (Default 10 pixels)\n", "* `cell_dir`: the final mask directory\n", "* `cell_mask_suffix`: Suffix name of the cell mask files. Usually \"whole_cell\"\n", "* `merged_masks_dir`: the directory to store the merged masks" - ] + ], + "outputs": [] }, { "cell_type": "code", @@ -580,6 +586,7 @@ "source": [ "merge_masks_list = [\"microglia-arms\", \"astrocyte-arms\"]\n", "percent_overlap = 30\n", + "expansion_factor = 10\n", "\n", "# Overwrite if different from above\n", "cell_dir = os.path.join(segmentation_dir, \"deepcell_output\")\n", @@ -633,6 +640,7 @@ " cell_mask_dir=cell_dir,\n", " cell_mask_suffix=cell_mask_suffix,\n", " overlap_percent_threshold=percent_overlap,\n", + " expansion_factor=expansion_factor,\n", " save_path=merged_masks_dir,\n", " log_dir=log_dir\n", ")" diff --git a/tests/segmentation/ez_seg/merge_masks_test.py b/tests/segmentation/ez_seg/merge_masks_test.py index 695c2f72a..7b42ce80d 100644 --- a/tests/segmentation/ez_seg/merge_masks_test.py +++ b/tests/segmentation/ez_seg/merge_masks_test.py @@ -4,10 +4,11 @@ import skimage.io as io import tempfile import xarray as xr - +import pytest from alpineer import image_utils from ark.segmentation.ez_seg import merge_masks -from skimage.morphology import label +from skimage.measure import regionprops_table, label +import pandas as pd from skimage.draw import disk from typing import List, Union @@ -26,6 +27,7 @@ def test_merge_masks_seq(): os.mkdir(directory) overlap_thresh: int = 10 + expansion_factor: int = 10 for fov in fov_list: cell_mask_data: np.ndarray = np.random.randint(0, 16, (32, 32)) @@ -42,10 +44,9 @@ def test_merge_masks_seq(): image_utils.save_image(object_mask_fov_file, object_mask_data) # we're only testing functionality, for in-depth merge testing see test_merge_masks_single - merge_masks.merge_masks_seq( - fov_list, object_list, object_mask_dir, cell_mask_dir, cell_mask_suffix, - overlap_thresh, merged_mask_dir, log_dir - ) + merge_masks.merge_masks_seq(fov_list, object_list, object_mask_dir, cell_mask_dir, + cell_mask_suffix, overlap_thresh, expansion_factor, + merged_mask_dir, log_dir) for fov in fov_list: print("checking fov") @@ -77,6 +78,7 @@ def test_merge_masks_single(): expected_cell_mask: np.ndarray = np.zeros((32, 32)) overlap_thresh: int = 10 + expansion_factor: int = 10 merged_mask_name: str = "merged_mask" # case 1: overlap below threshold, don't merge @@ -107,9 +109,12 @@ def test_merge_masks_single(): mask_save_dir: Union[str, pathlib.Path] = os.path.join(td, "mask_save_dir") os.mkdir(mask_save_dir) - created_cell_mask: np.ndarray = merge_masks.merge_masks_single( - object_mask, cell_mask, overlap_thresh, merged_mask_name, mask_save_dir - ) + created_cell_mask: np.ndarray = merge_masks.merge_masks_single(object_mask, + cell_mask, + overlap_thresh, + merged_mask_name, + mask_save_dir, + expansion_factor) created_merged_mask: np.ndarray = io.imread( os.path.join(mask_save_dir, merged_mask_name + "_merged.tiff") @@ -117,3 +122,69 @@ def test_merge_masks_single(): assert np.all(created_merged_mask == expected_merged_mask) assert np.all(created_cell_mask == expected_cell_mask) + + +def test_get_bounding_boxes(): + # Create a labeled array + labels = np.array([[1, 1, 0, 0], + [0, 1, 0, 0], + [0, 0, 2, 2]]) + + # Call the function to get bounding boxes + bounding_boxes = merge_masks.get_bounding_boxes(labels) + + # Expected bounding boxes + expected_bounding_boxes = {1: ((0, 0), (1, 1)), + 2: ((2, 2), (2, 3))} + + assert bounding_boxes == expected_bounding_boxes + + +def test_filter_labels_in_bbox(): + # Create a labeled array + labels = np.array([[1, 1, 0, 0], + [0, 1, 0, 0], + [0, 0, 2, 2]]) + + # Get regionprops df + label_df = pd.DataFrame(regionprops_table( + label(labels), properties=('label', 'centroid', 'major_axis_length'))) + + # Get the bounding boxes + bounding_boxes = merge_masks.get_bounding_boxes(label(labels)) + + # Filter labels within the bounding box of label 1 + filtered_labels = merge_masks.filter_labels_in_bbox( + bounding_boxes[1], label_df, expansion_factor=0) + + # Expected filtered labels for label 1 + expected_filtered_labels_1 = [1] + + assert filtered_labels == expected_filtered_labels_1 + + # Filter labels within the bounding box of label 2 + filtered_labels = merge_masks.filter_labels_in_bbox( + bounding_boxes[2], label_df, expansion_factor=0) + + # Expected filtered labels for label 2 + expected_filtered_labels_2 = [2] + + assert filtered_labels == expected_filtered_labels_2 + + # Filter labels within the bounding box of an empty label (should return an empty list) + filtered_labels = merge_masks.filter_labels_in_bbox( + ((0, 0), (0, 0)), label_df, expansion_factor=0) + + # Expected filtered labels for an empty label + expected_filtered_labels_empty = [] + + assert filtered_labels == expected_filtered_labels_empty + + # Filter labels within the bounding box of label 1 and expansion + filtered_labels = merge_masks.filter_labels_in_bbox( + bounding_boxes[1], label_df, expansion_factor=10) + + # Expected filtered labels for label 1 and 2 + expected_filtered_labels_expanded = [1, 2] + + assert filtered_labels == expected_filtered_labels_expanded diff --git a/tests/utils/data_utils_test.py b/tests/utils/data_utils_test.py index 09fa8e626..a5ef8396b 100644 --- a/tests/utils/data_utils_test.py +++ b/tests/utils/data_utils_test.py @@ -778,38 +778,35 @@ def test_convert_ct_fov_to_adata(tmp_path: pytest.TempPathFactory): n_cells = 100 n_markers = 10 ct = ark_test_utils.make_cell_table(n_cells=n_cells, n_markers=n_markers) - ct_dd = dd.from_pandas(ct, npartitions=2) - fov1_dd = ct_dd[ct_dd[settings.FOV_ID] == 1] + + ct_gb = ct.groupby(by=settings.FOV_ID) + fov1_ct = ct_gb.get_group(1) var_names = [f"marker_{i}" for i in range(n_markers)] - obs_names = fov1_dd.drop(columns=var_names).columns.to_list() + obs_names = fov1_ct.drop(columns=var_names).columns.to_list() fov1_adata_save_path = data_utils._convert_ct_fov_to_adata( - fov_dd=fov1_dd, + fov_group=fov1_ct, var_names=var_names, obs_names=obs_names, save_dir=tmp_path ) - save_path = fov1_adata_save_path.compute() # Assert that the file exists assert (tmp_path / "1.zarr").exists() # Load the AnnData Zarr Store - fov1_adata = read_zarr(save_path) - - # compute fov1_dd for asserts - fov1_df = fov1_dd.compute() + fov1_adata = read_zarr(fov1_adata_save_path) # Assert that the obs_names follow "{fov_id}_{cell_label}" - true_obs_names = list(map(lambda label: f"1_{int(label)}", fov1_df[settings.CELL_LABEL])) + true_obs_names = list(map(lambda label: f"1_{int(label)}", fov1_ct[settings.CELL_LABEL])) assert fov1_adata.obs_names.tolist() == true_obs_names # Assert that the X / Markers values are correct - np.testing.assert_allclose(actual=fov1_adata.X, desired=fov1_df[var_names].values) + np.testing.assert_allclose(actual=fov1_adata.X, desired=fov1_ct[var_names].values) # Assert that the obs columns are correct - expected_obs_columns = fov1_df.drop( + expected_obs_columns = fov1_ct.drop( columns=[*var_names, settings.CENTROID_0, settings.CENTROID_1] ).columns assert fov1_adata.obs.columns.tolist() == expected_obs_columns.tolist() @@ -817,7 +814,7 @@ def test_convert_ct_fov_to_adata(tmp_path: pytest.TempPathFactory): # Assert that the obsm values are correct np.testing.assert_allclose( actual=fov1_adata.obsm["spatial"].values, - desired=fov1_df[[settings.CENTROID_0, settings.CENTROID_1]].values + desired=fov1_ct[[settings.CENTROID_0, settings.CENTROID_1]].values )