Commit

Add comments; Remove draft evaluation function
LLehner committed Oct 9, 2024
1 parent 342375c commit cb4e4d1
Showing 1 changed file with 63 additions and 106 deletions.
169 changes: 63 additions & 106 deletions src/squidpy/gr/_niche.py
@@ -25,11 +25,11 @@

def calculate_niche(
adata: AnnData | SpatialData,
flavor: str = "neighborhood",
library_key: str | None = None,
table_key: str | None = None,
mask: pd.Series | None = None,
groups: str | None = None,
n_neighbors: int = 15,
resolutions: int | list[float] | None = None,
subset_groups: list[str] | None = None,
@@ -40,26 +40,21 @@ def calculate_niche(
aggregation: str = "mean",
n_components: int = 3,
random_state: int = 42,
spatial_key: str = "spatial",
spatial_connectivities_key: str = "spatial_connectivities",
spatial_distances_key: str = "spatial_distances",
copy: bool = False,
) -> AnnData | pd.DataFrame:
"""Calculate niches (spatial clusters) based on a user-defined method in 'flavor'.
The resulting niche labels with be stored in 'adata.obs'. If flavor = 'all' then all available methods
will be applied and additionally compared using cluster validation scores.
Parameters
----------
%(adata)s
flavor
Method to use for niche calculation. Available options are:
- `{c.NEIGHBORHOOD.s!r}` - cluster the neighborhood profile.
- `{c.UTAG.s!r}` - use the UTAG algorithm (matrix multiplication).
- `{c.CELLCHARTER.s!r}` - cluster adjacency matrix with Gaussian Mixture Model (GMM) using CellCharter's approach.
- `{c.SPOT.s!r}` - calculate niches using optimal transport. (coming soon)
- `{c.BANKSY.s!r}` - use the Banksy algorithm. (coming soon)
%(library_key)s
subset
Restrict niche calculation to a subset of the data.
@@ -68,6 +63,9 @@
mask
Boolean mask to filter out cells that should not be assigned to a niche.
Note that if you want to exclude these cells from the neighborhood calculation as well, you should subset your AnnData table before running 'sq.gr.spatial_neighbors'.
groups
Groups based on which to calculate the neighborhood profile (e.g. a column of cell type annotations in `adata.obs`).
Required if flavor == 'neighborhood'.
n_neighbors
Number of neighbors to use for 'scanpy.pp.neighbors' before clustering with the Leiden algorithm.
Required if flavor == 'neighborhood' or flavor == 'utag'.
@@ -77,7 +75,6 @@
subset_groups
Groups (e.g. cell type categories) to ignore when calculating the neighborhood profile.
Optional if flavor == 'neighborhood'.
min_niche_size
Minimum required size of a niche. Niches with fewer cells will be labeled as 'not_a_niche'.
Optional if flavor == 'neighborhood'.
@@ -88,7 +85,7 @@
If 'True', calculate niches based on absolute neighborhood profile.
Optional if flavor == 'neighborhood'.
adj_subsets
List of adjacency matrices to use, e.g. [1, 2, 3] for 1-hop, 2-hop and 3-hop neighbors respectively, or 5 for 1-hop, ..., 5-hop neighbors. 0 (self) is always included.
Required if flavor == 'cellcharter'.
aggregation
How to aggregate count matrices. Either 'mean' or 'variance'.
@@ -98,74 +95,63 @@
Required if flavor == 'cellcharter'.
random_state
Random state to use for GMM.
Required if flavor == 'cellcharter'.
spatial_key
Location of spatial coordinates in `adata.obsm`.
Optional if flavor == 'cellcharter'.
spatial_connectivities_key
Key in `adata.obsp` where spatial connectivities are stored.
spatial_distances_key
Key in `adata.obsp` where spatial distances are stored.
%(copy)s
"""

# check whether anndata or spatialdata is provided and if spatialdata, check whether table_key is provided
if isinstance(adata, SpatialData):
if table_key is not None:
adata = adata.tables[table_key].copy()
else:
raise ValueError("Please specify which table to use with `table_key`.")
else:
adata = adata

if flavor == "neighborhood":
"""adapted from https://github.com/immunitastx/monkeybread/blob/main/src/monkeybread/calc/_neighborhood_profile.py"""

# calculate the neighborhood profile for each cell (relative and absolute proportion of e.g. each cell type in the neighborhood)
rel_nhood_profile, abs_nhood_profile = _calculate_neighborhood_profile(
adata, groups, subset_groups, spatial_connectivities_key
)
# create AnnData object from neighborhood profile to perform scanpy functions
if not abs_nhood:
adata_neighborhood = ad.AnnData(X=rel_nhood_profile)
else:
adata_neighborhood = ad.AnnData(X=abs_nhood_profile)

# reason for scaling see https://monkeybread.readthedocs.io/en/latest/notebooks/tutorial.html#niche-analysis
if scale:
sc.pp.scale(adata_neighborhood, zero_center=True)

# mask obs to exclude cells for which no niche shall be assigned
if mask is not None:
if subset_groups is not None:
mask = mask[mask.index.isin(adata_neighborhood.obs.index)]
adata_neighborhood = adata_neighborhood[mask]

# required for leiden clustering (note: no dim reduction performed in original implementation)
print("calculating neighbors...")
sc.pp.neighbors(adata_neighborhood, n_neighbors=n_neighbors, use_rep="X")
print("finished calculating neighbors")

if resolutions is not None:
if not isinstance(resolutions, list):
resolutions = [resolutions]
else:
raise ValueError("Please provide resolutions for leiden clustering.")

# For each resolution, apply leiden on the neighborhood profile. Each cluster label corresponds to a niche label
print("starting clustering...")
for res in resolutions:
sc.tl.leiden(adata_neighborhood, resolution=res, key_added=f"neighborhood_niche_res={res}")
adata.obs[f"neighborhood_niche_res={res}"] = adata.obs.index.map(
adata_neighborhood.obs[f"neighborhood_niche_res={res}"]
).fillna("not_a_niche")
print(f"finished clustering at resolution {res}")

# filter niches with n_cells < min_niche_size
if min_niche_size is not None:
counts_by_niche = adata.obs[f"neighborhood_niche_res={res}"].value_counts()
to_filter = counts_by_niche[counts_by_niche < min_niche_size].index
@@ -174,9 +160,11 @@ def calculate_niche(
)

elif flavor == "utag":
"""adapted from https://github.com/ElementoLab/utag/blob/main/utag/segmentation.py"""

new_feature_matrix = _utag(adata, normalize_adj=True, spatial_connectivity_key=spatial_connectivities_key)
adata_utag = ad.AnnData(X=new_feature_matrix)
sc.tl.pca(adata_utag) # note: unlike with flavor 'neighborhood' dim reduction is performed here
sc.pp.neighbors(adata_utag, n_neighbors=n_neighbors, use_rep="X_pca")

if resolutions is not None:
@@ -185,11 +173,15 @@ def calculate_niche(
else:
raise ValueError("Please provide resolutions for leiden clustering.")

# For each resolution, apply leiden on the UTAG feature matrix. Each cluster label corresponds to a niche label
for res in resolutions:
sc.tl.leiden(adata_utag, resolution=res, key_added=f"utag_res={res}")
adata.obs[f"utag_res={res}"] = adata_utag.obs[f"utag_res={res}"].values

elif flavor == "cellcharter":
"""adapted from https://github.com/CSOgroup/cellcharter/blob/main/src/cellcharter/gr/_aggr.py
and https://github.com/CSOgroup/cellcharter/blob/main/src/cellcharter/tl/_gmm.py"""

adjacency_matrix = adata.obsp[spatial_connectivities_key]
if not isinstance(adj_subsets, list):
if adj_subsets is not None:
@@ -198,45 +190,50 @@ def calculate_niche(
raise ValueError(
"flavor 'cellcharter' requires adj_subsets to not be None. Specify list of values or maximum value of neighbors to use."
)
else:
if 0 not in adj_subsets:
adj_subsets.insert(0, 0)
if any(x < 0 for x in adj_subsets):
raise ValueError("adj_subsets must contain non-negative integers.")

aggregated_matrices = []
adj_hop = _setdiag(adjacency_matrix, 0) # Remove self-loops, set diagonal to 0
adj_visited = _setdiag(adjacency_matrix.copy(), 1) # Track visited neighbors
for k in adj_subsets:
if k == 0:
# If k == 0, we're using the original cell features (no neighbors)
# get original count matrix (not aggregated)
aggregated_matrices.append(adata.X)
else:
# get count and adjacency matrix for k-hop (neighbor of neighbor of neighbor ...) and aggregate them
if k > 1:
adj_hop, adj_visited = _hop(adj_hop, adjacency_matrix, adj_visited)

adj_hop_norm = _normalize(adj_hop) # Normalize adjacency matrix for current hop

# Apply aggregation, default to "mean" unless specified otherwise
aggregated_matrix = _aggregate(adata, adj_hop_norm, aggregation)

# Collect the aggregated matrices
aggregated_matrices.append(aggregated_matrix)

concatenated_matrix = hstack(aggregated_matrices) # Stack all matrices horizontally
arr = concatenated_matrix.toarray() # Densify the sparse matrix

# cluster concatenated matrix with GMM, each cluster label corresponds to a niche label
niches = _get_GMM_clusters(arr, n_components, random_state)

adata.obs[f"{flavor}_niche"] = pd.Categorical(niches)


def _calculate_neighborhood_profile(
adata: AnnData,
groups: str | None,
subset_groups: list[str] | None,
spatial_connectivities_key: str,
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""returns an obs x category matrix where each column is the absolute/relative frequency of a category in the neighborhood"""

if groups is None:
raise ValueError("Please specify 'groups' based on which to calculate neighborhood profile.")
if subset_groups:
adjacency_matrix = adata.obsp[spatial_connectivities_key].tocsc()
obs_mask = ~adata.obs[groups].isin(subset_groups)
adata = adata[obs_mask]

# Update adjacency matrix such that it only contains connections to filtered observations
adjacency_matrix = adjacency_matrix[obs_mask, :][:, obs_mask]
@@ -274,14 +271,7 @@ def _calculate_neighborhood_profile(
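
The rest of this helper is collapsed in the hunk above. Conceptually the profile is a neighbor-count matrix per category; a standalone sketch of that idea (the one-hot formulation and all names here are illustrative, not the hidden implementation):

import numpy as np
import pandas as pd
import scipy.sparse as sps

def neighborhood_profile_sketch(adj: sps.csr_matrix, categories: pd.Series) -> tuple[sps.spmatrix, sps.spmatrix]:
    """Absolute and relative frequency of each category among each cell's neighbors."""
    one_hot = sps.csr_matrix(pd.get_dummies(categories).to_numpy(dtype=float))
    abs_profile = adj @ one_hot  # obs x category neighbor counts
    n_neighbors = np.asarray(abs_profile.sum(axis=1)).ravel()
    n_neighbors[n_neighbors == 0] = 1.0  # guard against isolated cells
    rel_profile = abs_profile.multiply(1.0 / n_neighbors[:, None]).tocsr()
    return rel_profile, abs_profile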

def _utag(adata: AnnData, normalize_adj: bool, spatial_connectivity_key: str) -> AnnData:
"""Performs inner product of adjacency matrix and feature matrix,
such that each observation inherits features from its immediate neighbors as described in UTAG paper.
Parameters
----------
adata
Annotated data matrix.
normalize
If 'True', aggregate by the mean, else aggregate by the sum."""
such that each observation inherits features from its immediate neighbors as described in UTAG paper."""

adjacency_matrix = adata.obsp[spatial_connectivity_key]

@@ -292,6 +282,8 @@ def _utag(adata: AnnData, normalize_adj: bool, spatial_connectivity_key: str) ->
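
For intuition, a dense toy version of this smoothing step (illustrative; the real function operates on the sparse matrix stored in adata.obsp):

import numpy as np

A = np.array([[1.0, 1.0, 0.0],
              [1.0, 1.0, 1.0],
              [0.0, 1.0, 1.0]])            # adjacency including self-connections
A_norm = A / A.sum(axis=1, keepdims=True)  # normalize_adj=True -> neighborhood mean
X = np.array([[0.0, 2.0],
              [4.0, 0.0],
              [2.0, 2.0]])                 # per-cell features
X_smoothed = A_norm @ X                    # each cell averages itself and its neighbors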


def _setdiag(adjacency_matrix: sps.spmatrix, value: int) -> sps.spmatrix:
"""remove self-loops"""

if issparse(adjacency_matrix):
adjacency_matrix = adjacency_matrix.tolil()
adjacency_matrix.setdiag(value)
@@ -304,6 +296,8 @@ def _setdiag(adjacency_matrix: sps.spmatrix, value: int) -> sps.spmatrix:
def _hop(
adj_hop: sps.spmatrix, adj: sps.spmatrix, adj_visited: sps.spmatrix = None
) -> tuple[sps.spmatrix, sps.spmatrix]:
"""get nearest neighbor of neighbors"""

adj_hop = adj_hop @ adj

if adj_visited is not None:
@@ -314,6 +308,8 @@ def _hop(
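
For intuition about the hop bookkeeping, a dense toy version (illustrative; the real code works on sparse matrices and also returns the updated adj_visited so earlier hops are not revisited):

import numpy as np

adj = np.array([[0, 1, 0, 0],
                [1, 0, 1, 0],
                [0, 1, 0, 1],
                [0, 0, 1, 0]])                # path graph 0-1-2-3, no self-loops
visited = adj + np.eye(4, dtype=int)          # 0-hop (self) and 1-hop already seen
two_hop = ((adj @ adj) > 0) & (visited == 0)  # nodes first reached in exactly 2 steps
# two_hop links 0<->2 and 1<->3 only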


def _normalize(adj: sps.spmatrix) -> sps.spmatrix:
"""normalize adjacency matrix such that nodes with high degree don't disproportionately affect aggregation"""

deg = np.array(np.sum(adj, axis=1)).squeeze()
with np.errstate(divide="ignore"):
deg_inv = 1 / deg
@@ -323,6 +319,8 @@ def _normalize(adj: sps.spmatrix) -> sps.spmatrix:
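
The collapsed remainder presumably zeroes the infinite entries and rescales the rows; a sketch of that degree normalization (the isinf handling is an assumption about the hidden lines):

import numpy as np
import scipy.sparse as sps

adj = sps.csr_matrix(np.array([[0.0, 1.0, 1.0],
                               [1.0, 0.0, 0.0],
                               [1.0, 0.0, 0.0]]))
deg = np.asarray(adj.sum(axis=1)).ravel()  # node degrees: [2, 1, 1]
with np.errstate(divide="ignore"):
    deg_inv = 1.0 / deg
deg_inv[np.isinf(deg_inv)] = 0.0           # isolated nodes keep all-zero rows
adj_norm = sps.diags(deg_inv) @ adj        # rows now sum to 1 (or 0)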


def _aggregate(adata: AnnData, normalized_adjacency_matrix: sps.spmatrix, aggregation: str = "mean") -> Any:
"""aggregate count and adjacency matrix either by mean or variance"""
# TODO: add support for other aggregation methods
if aggregation == "mean":
aggregated_matrix = normalized_adjacency_matrix @ adata.X
elif aggregation == "variance":
@@ -339,58 +337,17 @@ def _get_GMM_clusters(A: np.ndarray[np.float64, Any], n_components: int, random_
"""Returns niche labels generated by GMM clustering.
Compared to cellcharter this approach is simplified by using sklearn's GaussianMixture model without stability analysis."""

print("initializing GMM...")
gmm = GaussianMixture(n_components=n_components, random_state=random_state, init_params="random_from_data")
print("fitting GMM...")
gmm.fit(A)
print("predicting labels...")
labels = gmm.predict(A)
print("done")

return labels
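
A quick smoke test of this helper on well-separated synthetic data (illustrative):

import numpy as np

rng = np.random.default_rng(42)
A = np.vstack([rng.normal(0.0, 1.0, (50, 8)),   # cluster centered at 0
               rng.normal(5.0, 1.0, (50, 8))])  # cluster centered at 5
labels = _get_GMM_clusters(A, n_components=2, random_state=42)
assert len(labels) == 100  # one niche label per observation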




def _aggregate_var(product: csr_matrix, connectivities: csr_matrix, adata: AnnData) -> csr_matrix:
"""aggregate by variance using Var(X) = E[X^2] - E[X]^2, where 'product' is the precomputed neighborhood mean"""
mean_squared = connectivities.dot(adata.X.multiply(adata.X))
return mean_squared - (product.multiply(product))
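
The helper relies on the identity Var(X) = E[X^2] - (E[X])^2, with the row-stochastic connectivity matrix acting as the expectation over each neighborhood. A dense numpy check of that identity (illustrative):

import numpy as np

rng = np.random.default_rng(0)
W = rng.random((5, 5))
W /= W.sum(axis=1, keepdims=True)  # row-stochastic neighborhood weights
X = rng.random((5, 3))             # per-cell features
mean = W @ X                       # E[X] per neighborhood
var = W @ (X * X) - mean * mean    # same formula as _aggregate_var
direct = np.einsum("ij,ijk->ik", W, (X[None, :, :] - mean[:, None, :]) ** 2)
assert np.allclose(var, direct)    # matches the weighted-variance definition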




def mean_fide_score(
adatas: AnnData | list[AnnData],
library_key: str,
