Commit
Lk pd 2736 doublets (#136)
Adding doublets
ekiernan authored Oct 17, 2024
1 parent dfa6b3b commit 74c8b3f
Showing 4 changed files with 191 additions and 25 deletions.
29 changes: 6 additions & 23 deletions 3rd-party-tools/star-merge-npz/scripts/combine_shard_metrics.py
@@ -3,6 +3,8 @@
import numpy as np


# Function that takes in STARsolo alignment metrics and filtered matrix to produce library-level metrics
# The filtered matrix is produced by STARsolo and contains UMIs for barcodes flagged as actual cells
def merge_matrices(summary_file, align_file, cell_reads, counting_mode, uniform_barcodes, uniform_mtx, expected_cells):
# Read the whitelist into a set.
expected_cells = int(expected_cells)
@@ -92,23 +94,9 @@ def merge_matrices(summary_file, align_file, cell_reads, counting_mode, uniform_
unique_rows = filtered[0].unique()
total_genes_unique_detected = len(unique_rows)

# Keeper cell metrics
# Expected number of cells is unknown, so this commented out for now
#expected_cells = 3000 # Placeholder, replace with actual value
percent_target = estimated_cells/expected_cells
percent_intronic_reads = reads_mapped_confidently_to_intronic_regions/n_reads

if counting_mode == "sc_rna":
gene_threshold = 1500
else:
gene_threshold = 1000

keeper_cells = cells_filtered[cells_filtered["nGenesUnique"] > gene_threshold]
keeper_mean_reads_per_cell = keeper_cells["countedU"].mean()
keeper_median_genes = keeper_cells["nGenesUnique"].median()
keeper_cells_count = len(keeper_cells)
percent_keeper = keeper_cells_count/estimated_cells
percent_usable = keeper_cells_count/expected_cells

percent_target = estimated_cells/expected_cells*100
percent_intronic_reads = reads_mapped_confidently_to_intronic_regions/n_reads*100

data = {
"number_of_reads": [n_reads],
@@ -135,12 +123,7 @@ def merge_matrices(summary_file, align_file, cell_reads, counting_mode, uniform_
"median_gene_per_cell": [median_gene_per_cell],
"total_genes_unique_detected": [total_genes_unique_detected],
"percent_target": [percent_target],
"percent_intronic_reads": [percent_intronic_reads],
"keeper_mean_reads_per_cell": [keeper_mean_reads_per_cell],
"keeper_median_genes": [keeper_median_genes],
"keeper_cells": [keeper_cells_count],
"percent_keeper": [percent_keeper],
"percent_usable": [percent_usable]
"percent_intronic_reads": [percent_intronic_reads]
}

df = pd.DataFrame(data)
2 changes: 1 addition & 1 deletion README.md
@@ -2,7 +2,7 @@
This repository has the container that hosts all the scripts and tools that [WARP](https://github.com/broadinstitute/warp) uses.

The project structure is straightforward and contains essentially just two types of directories: a tool directory, containing all our in-house tools, and a 3rd-party-tools directory which hosts all the third-party containers we use in our pipelines.
Each directory contains its own README that describes the tool or scripts, along with a usage guide.

## .github/workflows
This contains all YML files for automated container builds.
2 changes: 1 addition & 1 deletion tools/Dockerfile
@@ -17,7 +17,7 @@ RUN set -eux && \
mkdir -p /warptools && \
apt-get update && apt-get upgrade -y && apt-get install -y libhdf5-dev vim apt-utils liblzma-dev libbz2-dev tini && \
pip install --upgrade pip && \
pip install loompy==3.0.6 anndata==0.7.8 numpy==1.23.0 pandas==1.3.5 scipy pysam==0.21 && \
pip install loompy==3.0.6 anndata==0.10.8 numpy==1.23.0 pandas==2.2.2 scikit-learn==1.5.1 scanpy==1.10.2 scipy pysam==0.21 && \
curl -sSL https://sdk.cloud.google.com | bash

COPY . /warptools
183 changes: 183 additions & 0 deletions tools/scripts/add_library_tso_doublets.py
@@ -0,0 +1,183 @@
import numpy as np
import anndata as ad
import pandas as pd
from sklearn.neighbors import KNeighborsTransformer
import argparse
import scanpy as sc


# Function to flag cell barcodes that are filtered as cells by STARsolo
def call_cells(cellbarcodes, gex_h5ad):
cells=pd.read_csv(cellbarcodes, sep="\t", header=None)
adata=ad.read_h5ad(gex_h5ad)
adata.obs["star_IsCell"] = adata.obs.index.isin(cells[0])
return adata

# Function to compute doublet scores using a modified version of DoubletFinder
# This python implementation was provided by the Allen Institute
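# Outline of the approach implemented below:
#   1. Keep only barcodes flagged as cells (star_IsCell == True).
#   2. Simulate artificial doublets by summing the counts of random pairs of real cells,
#      so that doublets make up `proportion_artificial` of the combined matrix.
#   3. Normalize, log-transform, select highly variable genes, scale, and run PCA.
#   4. Build a k-nearest-neighbor graph in PCA space and score each cell by the fraction
#      of its neighbors (within a distance threshold) that are artificial doublets.
#   5. Return per-barcode doublet scores for the real cells and the percentage of cells
#      with a doublet score above 0.3.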
def compute_doublet_scores(gex_h5ad_modified, proportion_artificial=0.2):
adata = gex_h5ad_modified
adata.var_names_make_unique()
adata = adata[adata.obs["star_IsCell"] == True, :]
print("adata with star_IsCell == True", adata)
k = np.int64(np.round(np.min([100, adata.shape[0] * 0.01])))
n_doublets = np.int64(np.round(adata.shape[0] / (1 - proportion_artificial) - adata.shape[0]))
real_cells_1 = np.random.choice(adata.obs_names, size=n_doublets, replace=True)
real_cells_2 = np.random.choice(adata.obs_names, size=n_doublets, replace=True)
doublet_X = adata[real_cells_1, :].X + adata[real_cells_2, :].X
doublet_obs_names = [f"X{i}" for i in range(n_doublets)]
doublet_adata = ad.AnnData(X=doublet_X, obs=pd.DataFrame(index=doublet_obs_names), var=pd.DataFrame(index=adata.var_names))
adata = adata.concatenate(doublet_adata, index_unique=None)

adata.obs["doublet_cell"] = adata.obs_names.isin(doublet_obs_names)
adata.obs["doublet_cell"] = adata.obs["doublet_cell"].astype("category")
adata.layers["UMIs"] = adata.X.copy()
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

try:
sc.pp.highly_variable_genes(adata, n_top_genes=5000, flavor="seurat_v3", layer="UMIs")
adata.layers["UMIs"]

except:
sc.pp.highly_variable_genes(adata, min_mean=1, min_disp=0.5)
del adata.layers["UMIs"]

adata_sub = adata[:, adata.var["highly_variable"]].copy()
sc.pp.scale(adata_sub)
sc.pp.pca(adata_sub)
v = adata_sub.uns['pca']['variance']
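# Choose the number of PCs from the components whose explained variance is above
# the mean across components (positive z-score), a simple elbow-style heuristic.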
n_pcs = np.max(np.where(((v - np.mean(v)) / np.std(v)) > 0)[0])
knn = KNeighborsTransformer(
n_neighbors=k,
algorithm="kd_tree",
n_jobs=-1,
)

knn = knn.fit(adata_sub.obsm["X_pca"][:, :n_pcs])
dist, idx = knn.kneighbors()
knn_mapper = KNeighborsTransformer(
n_neighbors=10,
algorithm="kd_tree",
n_jobs=-1,
)

knn_mapper = knn_mapper.fit(adata_sub[adata_sub.obs["doublet_cell"] == False, :].obsm["X_pca"][:, :n_pcs])
dist1, _ = knn_mapper.kneighbors(adata_sub[adata_sub.obs["doublet_cell"] == True, :].obsm["X_pca"][:, :n_pcs])
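# Distance threshold = mean + 1.64*SD of each artificial doublet's distances to its 10
# nearest real cells (roughly a one-sided 95% cutoff under a normal assumption).
# A neighbor counts toward the doublet score only if it is an artificial doublet
# and lies closer than this threshold.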
dist_th = np.mean(dist1) + (1.64 * np.std(dist1))
freq = (dist < dist_th) & (idx > adata[adata.obs["doublet_cell"] == False, :].shape[0])
score1 = freq.mean(axis=1)
score2 = freq[:, :int(np.ceil(k/2))].mean(axis=1)
adata.obs["doublet_score"] = np.maximum(score1, score2)
doublet_csv=adata.obs.loc[~adata.obs_names.isin(doublet_obs_names), ["doublet_score"]]

# Calculate the percentage of doublets with a doublet_score > 0.3
num_doublets = doublet_csv[doublet_csv["doublet_score"] > 0.3].shape[0]
total_cells = doublet_csv.shape[0]
percent_doublets = num_doublets / total_cells * 100

return doublet_csv, percent_doublets


# Function to calculate additional library metrics such as keeper cells
def process_gex_data(gex_h5ad_modified, gex_nhash_id, library_csv, input_id, doublets, doublet_scores, counting_mode, expected_cells):
print("Reading Optimus h5ad:")
gex_data = gex_h5ad_modified
# NHashID is optional input, so the logic below sets it if undefined
if gex_nhash_id is not None:
gex_data.uns['NHashID'] = gex_nhash_id
else:
gex_nhash_id = "NA"
gex_data.uns['NHashID'] = gex_nhash_id

#gex_data.write(f"{input_id}.h5ad")

print("Reading library metrics")
library = pd.read_csv(library_csv, header=None)

# Calculates total library TSO metrics
# TSO reads refer to reads derived from the Template Switch Oligo
# TSO reads per cell are calculated from the cN BAM tag
print("Calculating TSO frac")
tso_reads = gex_data.obs.tso_reads.sum() / gex_data.obs.n_reads.sum()
print("TSO reads:")
print(tso_reads)

print("Calclating keeper metrics based on doublets and n_genes")
if counting_mode == "sc_rna":
gene_threshold = 1500
else:
gene_threshold = 1000

estimated_cells = len(gex_data[gex_data.obs["star_IsCell"]==True])
# Expected cells is the number expected from the experiment; usually 10,000 with 10x data
expected_cells = int(expected_cells)

# Adding doublet scores to barcodes that have been called as cells
all_barcodes = pd.DataFrame(index=gex_data.obs_names)
# Merge doublet scores with all barcodes, filling missing values with NA
all_barcodes = all_barcodes.join(doublet_scores, how='left')
# Assign the doublet scores back to the adata object
gex_data.obs['doublet_score'] = all_barcodes['doublet_score']

# Adding keeper metrics
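# Keeper cells: barcodes called as cells by STARsolo with a doublet_score below 0.3
# and more detected genes than the counting-mode threshold set above.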
subset = gex_data[gex_data.obs['star_IsCell'] & (gex_data.obs['doublet_score'] < 0.3) & (gex_data.obs['n_genes'] > gene_threshold)]
keeper_cells = subset.shape[0]
keeper_mean_reads_per_cell = subset.obs["n_reads"].mean()
keeper_median_genes = subset.obs["n_genes"].median()
percent_keeper = keeper_cells/estimated_cells
percent_usable = keeper_cells/expected_cells

# Updating library metrics
dictionary = library.set_index(0)[1].to_dict()
dictionary['frac_tso'] = tso_reads
dictionary['percent_doublets'] = doublets
dictionary['keeper_cells'] = keeper_cells
dictionary['keeper_mean_reads_per_cell'] = keeper_mean_reads_per_cell
dictionary['keeper_median_genes'] = keeper_median_genes
dictionary['percent_keeper'] = percent_keeper*100
dictionary['percent_usable'] = percent_usable*100

new_dictionary = {"NHashID": [gex_nhash_id]}  # NHashID is already wrapped in a list
# Update other scalar values to lists
dictionary = {key: [value] for key, value in dictionary.items()}
new_dictionary.update(dictionary)
new_dictionary = pd.DataFrame(new_dictionary)
new_dictionary.transpose().to_csv("library_metrics.csv", header=None)
return gex_data

def main():
description = """This script converts the some of the Optimus outputs in to
h5ad format.
This script can be used as a module or run as a command line script."""
parser = argparse.ArgumentParser(description="Process single-cell RNA-seq data and compute doublet scores.")
parser.add_argument("--proportion_artificial", type=float, default=0.2, help="Proportion of artificial doublets to be generated (default is 0.2).")
parser.add_argument("--gex_h5ad", type=str, required=True, help="Path to the GEX h5ad file.")
parser.add_argument("--cellbarcodes", type=str, required=True, help="Path to the cell barcodes file.")
parser.add_argument("--gex_nhash_id", type=str, required=False, help="NHashID for the GEX data.")
parser.add_argument("--library_csv", type=str, required=True, help="Path to the library metrics CSV file.")
parser.add_argument("--input_id", type=str, required=True, help="Input ID for output files.")
parser.add_argument("--counting_mode", type=str, required=True, help="Counting mode for STARsolo alignment.")
parser.add_argument("--expected_cells", type=int, required=True, help="Expected number of cells.")

args = parser.parse_args()
# Compute cell calls and doublet scores
print("Calculating cell calls")
cell_h5ad=call_cells(args.cellbarcodes, args.gex_h5ad)
print("Calculating doublets based on cell calls")
doublet_scores, percent_doublets = compute_doublet_scores(cell_h5ad, proportion_artificial=args.proportion_artificial)
print("Adding doublet scores, NHashID to h5ad and calculating library metrics")
revised_adata = process_gex_data(cell_h5ad, args.gex_nhash_id, args.library_csv, args.input_id, percent_doublets, doublet_scores, args.counting_mode, args.expected_cells)
# Output the results
output_path = args.gex_h5ad.replace(".h5ad", "_doublet_scores.csv")
print("Output path: ", output_path)
doublet_scores.to_csv(output_path)
print("Saving revised adata object")
revised_adata.write(f"{args.input_id}.h5ad")
print(f"Doublet scores saved to {output_path}")
print("Percent_doublets: ", percent_doublets)
print("Done!")

if __name__ == "__main__":
main()
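
For reference, here is a minimal sketch of using the new script as a module rather than through its command-line interface; the file paths, NHashID, input ID, and expected-cell count below are hypothetical placeholders:

# Hypothetical module-style usage of add_library_tso_doublets.py
from add_library_tso_doublets import call_cells, compute_doublet_scores, process_gex_data

adata = call_cells("barcodes.tsv", "sample.h5ad")  # flag STARsolo-called cells
scores, pct_doublets = compute_doublet_scores(adata, proportion_artificial=0.2)
# process_gex_data also writes library_metrics.csv to the working directory
revised = process_gex_data(adata, "NH123", "library_metrics.csv", "sample1",
                           pct_doublets, scores, "sc_rna", 10000)
revised.write("sample1.h5ad")                      # h5ad with doublet scores and NHashID
scores.to_csv("sample1_doublet_scores.csv")        # per-barcode doublet scores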
