Merging staging branch into prod branch
Bento007 committed Aug 1, 2024
2 parents 4da68a2 + 633c1c1 commit 8604996
Showing 49 changed files with 194 additions and 625 deletions.
4 changes: 4 additions & 0 deletions .happy/terraform/modules/wmg-batch/main.tf
@@ -34,6 +34,10 @@ resource aws_batch_job_definition batch_job_def {
{
"name": "REMOTE_DEV_PREFIX",
"value": "${var.remote_dev_prefix}"
},
{
"name": "CELLXGENE_CENSUS_USERAGENT",
"value": "${var.census_user_agent}"
}
],
"vcpus": ${var.desired_vcpus},
8 changes: 7 additions & 1 deletion .happy/terraform/modules/wmg-batch/variables.tf
@@ -55,4 +55,10 @@ variable desired_vcpus {
variable "api_url" {
type = string
description = "URL for the backend api."
}
}

variable "census_user_agent" {
type = string
description = "User agent for the census API"
default = "CZI-wmg"
}
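
These two Terraform changes surface a census user agent to the WMG batch container. A minimal Python sketch of how a job inside the container might exercise it, assuming (as the diff suggests, though the cellxgene-census docs should confirm it) that the library reads CELLXGENE_CENSUS_USERAGENT from the environment:

    import os

    import cellxgene_census

    # CELLXGENE_CENSUS_USERAGENT is injected by the batch job definition above;
    # the default "CZI-wmg" comes from variables.tf. Setting it here mimics the
    # container environment (assumption: cellxgene-census appends this value to
    # its HTTP user agent).
    os.environ.setdefault("CELLXGENE_CENSUS_USERAGENT", "CZI-wmg")

    with cellxgene_census.open_soma(census_version="latest") as census:
        datasets_df = census["census_info"]["datasets"].read().concat().to_pandas()
        print(datasets_df.shape)
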
2 changes: 1 addition & 1 deletion backend/cellguide/pipeline/source_collections/__init__.py
@@ -17,4 +17,4 @@ def get_source_collections_data():
# for convenient imports
snapshot = sn.load_snapshot(snapshot_schema_version=CELLGUIDE_CENSUS_CUBE_DATA_SCHEMA_VERSION)
all_cell_type_ids_in_corpus = get_all_cell_type_ids_in_corpus(snapshot)
return generate_source_collections_data(all_cell_type_ids_in_corpus)
return generate_source_collections_data(all_cell_type_ids_in_corpus, snapshot.cell_counts_df)
@@ -1,22 +1,43 @@
import cellxgene_census
from pandas import DataFrame

from backend.cellguide.pipeline.source_collections.types import SourceCollectionsData
from backend.common.census_cube.utils import (
descendants,
get_collections_from_discover_api,
get_datasets_from_discover_api,
)
from backend.common.citation import format_citation_dp
from backend.common.census_cube.utils import descendants


def generate_source_collections_data(all_cell_type_ids_in_corpus: list[str]) -> dict[str, list[SourceCollectionsData]]:
def generate_source_collections_data(
all_cell_type_ids_in_corpus: list[str], cell_counts_df: DataFrame
) -> dict[str, list[SourceCollectionsData]]:
"""
For each cell type id in the corpus, we want to generate a SourceCollectionsData object, which contains
metadata about the source data for each cell type
"""
all_datasets = get_datasets_from_discover_api()
all_collections = get_collections_from_discover_api()
dataset_id_to_cell_type_ids_map = {}
dataset_id_to_tissue_ids_map = {}
dataset_id_to_disease_ids_map = {}
dataset_id_to_organism_ids_map = {}

for map_dict, column_name in zip(
[
dataset_id_to_cell_type_ids_map,
dataset_id_to_tissue_ids_map,
dataset_id_to_disease_ids_map,
dataset_id_to_organism_ids_map,
],
[
"cell_type_ontology_term_id",
"tissue_ontology_term_id",
"disease_ontology_term_id",
"organism_ontology_term_id",
],
strict=False,
):
df_agg = cell_counts_df.groupby("dataset_id").agg({column_name: lambda x: ",".join(set(x.values))})
df_dict = {df_agg.index[i]: df_agg.values[i][0].split(",") for i in range(len(df_agg))}
map_dict.update(df_dict)

collections_dict = {collection["collection_id"]: collection for collection in all_collections}
datasets_dict = {dataset["dataset_id"]: dataset for dataset in all_datasets}
with cellxgene_census.open_soma(census_version="latest") as census:
datasets_metadata_df = census["census_info"]["datasets"].read().concat().to_pandas()

source_collections_data: dict[str, list[SourceCollectionsData]] = {}
for cell_id in all_cell_type_ids_in_corpus:
@@ -25,57 +25,41 @@ def generate_source_collections_data(all_cell_type_ids_in_corpus: list[str]) ->

# Generate a list of unique dataset ids that contain cell type ids in the lineage
dataset_ids: list[str] = []
for dataset_id in datasets_dict:
for cell_type in datasets_dict[dataset_id]["cell_type"]:
if cell_type["ontology_term_id"] in lineage:
for dataset_id in dataset_id_to_cell_type_ids_map:
for cell_type in dataset_id_to_cell_type_ids_map[dataset_id]:
if cell_type in lineage:
dataset_ids.append(dataset_id)
break
assert len(set(dataset_ids)) == len(dataset_ids)

# Generate a mapping from collection id to SourceCollectionsData
collections_to_source_data = {}
for dataset_id in dataset_ids:
dataset = datasets_dict[dataset_id]
collection_id = dataset["collection_id"]

# If we don't have any source data on this collection yet, create a new SourceCollectionsData object
unique_dataset_ids = cell_counts_df["dataset_id"].unique()
for i in range(len(datasets_metadata_df)):
dataset_id = datasets_metadata_df.iloc[i]["dataset_id"]

# dataset_ids is coming from cell_counts_df and dataset_id is coming from census
# cell_counts_df also is derived from census, so this condition should never trigger
# but we add it here anyway to be safe in case WMG decides to filter out datasets
# from the census
if dataset_id not in dataset_ids:
continue

assert dataset_id in unique_dataset_ids, f"{dataset_id} not in cell_counts_df"

collection_id = datasets_metadata_df.iloc[i]["collection_id"]
if collection_id not in collections_to_source_data:
collection = collections_dict.get(collection_id)
source_data = SourceCollectionsData(
collection_name=collection["name"],
collection_url=collection["collection_url"],
publication_url=collection["doi"],
publication_title=(
format_citation_dp(collection["publisher_metadata"])
if collection["publisher_metadata"]
else "No publication"
),
tissue=[],
disease=[],
organism=[],
collection_name=datasets_metadata_df.iloc[i]["collection_name"],
collection_url=f"https://cellxgene.cziscience.com/collections/{collection_id}",
publication_url=datasets_metadata_df.iloc[i]["collection_doi"],
publication_title=datasets_metadata_df.iloc[i]["collection_doi_label"],
tissue=dataset_id_to_tissue_ids_map[dataset_id],
disease=dataset_id_to_disease_ids_map[dataset_id],
organism=dataset_id_to_organism_ids_map[dataset_id],
)
collections_to_source_data[collection_id] = source_data

# Add the tissue, disease, and organism metadata from the dataset to the SourceCollectionsData object. If we
# previously found a SourceCollectionsData object, we'll want to add the tissue/disease/organism from this
# dataset to the existing object. If not, we'll want to create a new list for each of these fields.
source_data = collections_to_source_data[collection_id]

for tissue in dataset["tissue"]:
if tissue["ontology_term_id"] not in [t["ontology_term_id"] for t in source_data.tissue]:
source_data.tissue.append(tissue)

for disease in dataset["disease"]:
if disease["ontology_term_id"] not in [d["ontology_term_id"] for d in source_data.disease]:
source_data.disease.append(disease)

for organism in dataset["organism"]:
if organism["ontology_term_id"] not in [o["ontology_term_id"] for o in source_data.organism]:
source_data.organism.append(organism)

# Add the SourceCollectionsData to the mapping
collections_to_source_data[collection_id] = source_data

source_collections_data[cell_id] = list(collections_to_source_data.values())

return source_collections_data
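
The rewritten generator drops the Discover API entirely: the snapshot's cell_counts_df supplies the per-dataset cell type, tissue, disease, and organism terms, and census["census_info"]["datasets"] supplies collection names, DOIs, and titles. A toy illustration of the groupby/agg mapping idiom used above, with hypothetical data:

    import pandas as pd

    # Toy stand-in for the snapshot's cell_counts_df (hypothetical values).
    cell_counts_df = pd.DataFrame(
        {
            "dataset_id": ["d1", "d1", "d2"],
            "tissue_ontology_term_id": ["UBERON:0002048", "UBERON:0000955", "UBERON:0002048"],
        }
    )

    # Same idiom as the pipeline: collapse to one comma-joined set of terms
    # per dataset, then split back into a dataset_id -> list-of-terms map.
    df_agg = cell_counts_df.groupby("dataset_id").agg(
        {"tissue_ontology_term_id": lambda x: ",".join(set(x.values))}
    )
    dataset_id_to_tissue_ids_map = {
        df_agg.index[i]: df_agg.values[i][0].split(",") for i in range(len(df_agg))
    }
    print(dataset_id_to_tissue_ids_map)
    # {'d1': ['UBERON:0002048', 'UBERON:0000955'], 'd2': ['UBERON:0002048']}
    # (set iteration order may vary)

The zip over the four map dicts applies this same transform to each ontology column in turn.
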
1 change: 0 additions & 1 deletion backend/common/census_cube/data/constants.py
Expand Up @@ -14,7 +14,6 @@
"EFO:0010010": "CEL-seq2",
}

CENSUS_CUBE_PINNED_SCHEMA_VERSION = "5.1.0"

CENSUS_CUBE_DATA_SCHEMA_VERSION = "v5"

37 changes: 0 additions & 37 deletions backend/common/census_cube/utils.py
@@ -1,4 +1,3 @@
import os
from functools import lru_cache
from typing import Dict, Optional

@@ -10,9 +9,7 @@
from requests.adapters import HTTPAdapter
from urllib3.util import Retry

from backend.common.census_cube.data.constants import CENSUS_CUBE_PINNED_SCHEMA_VERSION
from backend.common.census_cube.data.snapshot import CensusCubeSnapshot
from backend.common.constants import DEPLOYMENT_STAGE_TO_API_URL

# exported and used by all modules related to the census cube
ontology_parser = OntologyParser()
@@ -118,40 +115,6 @@ def setup_retry_session(retries=3, backoff_factor=2, status_forcelist=(500, 502,
return session


def get_datasets_from_discover_api():
# hardcode to staging backend if deployment is rdev or test
deployment_stage = os.environ.get("DEPLOYMENT_STAGE")
API_URL = DEPLOYMENT_STAGE_TO_API_URL.get(
deployment_stage, "https://api.cellxgene.staging.single-cell.czi.technology"
)

datasets = {}
if API_URL:
session = setup_retry_session()
dataset_metadata_url = f"{API_URL}/curation/v1/datasets?schema_version={CENSUS_CUBE_PINNED_SCHEMA_VERSION}"
response = session.get(dataset_metadata_url)
if response.status_code == 200:
datasets = response.json()
return datasets


def get_collections_from_discover_api():
# hardcode to staging backend if deployment is rdev or test
deployment_stage = os.environ.get("DEPLOYMENT_STAGE")
API_URL = DEPLOYMENT_STAGE_TO_API_URL.get(
deployment_stage, "https://api.cellxgene.staging.single-cell.czi.technology"
)

collections = {}
if API_URL:
session = setup_retry_session()
dataset_metadata_url = f"{API_URL}/curation/v1/collections"
response = session.get(dataset_metadata_url)
if response.status_code == 200:
collections = response.json()
return collections


def build_filter_relationships(cell_counts_df: pd.DataFrame):
# get a dataframe of the columns that are not numeric
df_filters = cell_counts_df.select_dtypes(exclude="number")
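
With the pinned-schema Discover API helpers removed, setup_retry_session is this module's remaining HTTP utility. Its body is mostly elided above; a sketch of the conventional Retry/HTTPAdapter wiring implied by the signature and the retained imports (the actual body may differ):

    import requests
    from requests.adapters import HTTPAdapter
    from urllib3.util import Retry


    def setup_retry_session(retries=3, backoff_factor=2, status_forcelist=(500, 502, 503, 504)):
        # Retry transient 5xx responses with exponential backoff on both schemes.
        session = requests.Session()
        retry = Retry(total=retries, backoff_factor=backoff_factor, status_forcelist=status_forcelist)
        adapter = HTTPAdapter(max_retries=retry)
        session.mount("http://", adapter)
        session.mount("https://", adapter)
        return session
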
12 changes: 9 additions & 3 deletions backend/wmg/pipeline/cell_counts.py
@@ -2,6 +2,7 @@
import os

import tiledb
from pandas import DataFrame
from tiledbsoma import ExperimentAxisQuery

from backend.common.census_cube.data.schemas.cube_schema import (
@@ -18,14 +19,15 @@
create_empty_cube_if_needed,
log_func_runtime,
remove_accents,
return_dataset_dict_w_publications,
)

logger = logging.getLogger(__name__)


@log_func_runtime
def create_cell_counts_cube(*, query: ExperimentAxisQuery, corpus_path: str, organismId: str):
def create_cell_counts_cube(
*, dataset_metadata: DataFrame, query: ExperimentAxisQuery, corpus_path: str, organismId: str
):
"""
Create cell count cube and write to disk
"""
@@ -46,7 +48,11 @@ def create_cell_counts_cube(*, query: ExperimentAxisQuery, corpus_path: str, org
).size()
).rename(columns={"size": "n_cells"})

dataset_dict = return_dataset_dict_w_publications()
dataset_dict = {
row["dataset_id"]: row["collection_doi_label"]
for _, row in dataset_metadata.iterrows()
if row["collection_doi_label"]
}
df["publication_citation"] = [
remove_accents(dataset_dict.get(dataset_id, "No Publication")) for dataset_id in df["dataset_id"]
]
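
create_cell_counts_cube now receives the census datasets table as dataset_metadata and builds the citation lookup inline, replacing the removed return_dataset_dict_w_publications helper. A standalone sketch of the lookup with hypothetical rows:

    import pandas as pd

    # Minimal stand-in for census["census_info"]["datasets"] (hypothetical rows).
    dataset_metadata = pd.DataFrame(
        {
            "dataset_id": ["d1", "d2"],
            "collection_doi_label": ["Smith et al. (2023) Nature", ""],
        }
    )

    # Same filter as the pipeline: datasets without a DOI label are omitted
    # from the dict and fall back to "No Publication" at lookup time.
    dataset_dict = {
        row["dataset_id"]: row["collection_doi_label"]
        for _, row in dataset_metadata.iterrows()
        if row["collection_doi_label"]
    }
    print(dataset_dict.get("d2", "No Publication"))  # -> No Publication
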
21 changes: 21 additions & 0 deletions backend/wmg/pipeline/constants.py
@@ -119,3 +119,24 @@
"UBERON:0000014", # zone of skin
"UBERON:0000916", # abdomen
]


ORGANISM_INFO = [
{"label": "homo_sapiens", "id": "NCBITaxon:9606"},
{"label": "mus_musculus", "id": "NCBITaxon:10090"},
]


class CensusParameters:
census_version = "latest"

@staticmethod
def value_filter(organism: str) -> str:
organism_mapping = {
"homo_sapiens": f"is_primary_data == True and nnz >= {GENE_EXPRESSION_COUNT_MIN_THRESHOLD} and cell_type_ontology_term_id != 'unknown'",
"mus_musculus": f"is_primary_data == True and nnz >= {GENE_EXPRESSION_COUNT_MIN_THRESHOLD} and cell_type_ontology_term_id != 'unknown'",
}
value_filter = organism_mapping[organism]
# Filter out system-level tissues. Census filters out organoids + cell cultures
value_filter += " and tissue_general_ontology_term_id != 'UBERON:0001017' and tissue_general_ontology_term_id != 'UBERON:0001007' and tissue_general_ontology_term_id != 'UBERON:0002405' and tissue_general_ontology_term_id != 'UBERON:0000990' and tissue_general_ontology_term_id != 'UBERON:0001004' and tissue_general_ontology_term_id != 'UBERON:0001434'"
return value_filter
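
CensusParameters centralizes the census version and the per-organism obs value_filter (primary data only, an nnz threshold, known cell types, minus the system-level tissues). A sketch of how a pipeline step might consume it via the tiledbsoma axis-query API; the wiring is illustrative, not this commit's exact driver code:

    import cellxgene_census
    import tiledbsoma

    from backend.wmg.pipeline.constants import ORGANISM_INFO, CensusParameters

    with cellxgene_census.open_soma(census_version=CensusParameters.census_version) as census:
        for organism in ORGANISM_INFO:
            experiment = census["census_data"][organism["label"]]
            # Apply the organism-specific obs filter defined above.
            with experiment.axis_query(
                measurement_name="RNA",
                obs_query=tiledbsoma.AxisQuery(
                    value_filter=CensusParameters.value_filter(organism["label"])
                ),
            ) as query:
                obs_df = query.obs().concat().to_pandas()
                print(organism["label"], len(obs_df))
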
39 changes: 31 additions & 8 deletions backend/wmg/pipeline/dataset_metadata.py
@@ -1,9 +1,17 @@
import json
import logging
import os

from backend.common.census_cube.data.snapshot import DATASET_METADATA_FILENAME
from backend.common.census_cube.utils import get_datasets_from_discover_api
from backend.wmg.pipeline.constants import DATASET_METADATA_CREATED_FLAG
import cellxgene_census
import tiledb

from backend.common.census_cube.data.snapshot import CELL_COUNTS_CUBE_NAME, DATASET_METADATA_FILENAME
from backend.wmg.pipeline.constants import (
DATASET_METADATA_CREATED_FLAG,
EXPRESSION_SUMMARY_AND_CELL_COUNTS_CUBE_CREATED_FLAG,
CensusParameters,
)
from backend.wmg.pipeline.errors import PipelineStepMissing
from backend.wmg.pipeline.utils import load_pipeline_state, log_func_runtime, write_pipeline_state

logger = logging.getLogger(__name__)
@@ -18,16 +26,31 @@ def create_dataset_metadata(corpus_path: str) -> None:
"""
logger.info("Generating dataset metadata file")
pipeline_state = load_pipeline_state(corpus_path)
datasets = get_datasets_from_discover_api()

if not pipeline_state.get(EXPRESSION_SUMMARY_AND_CELL_COUNTS_CUBE_CREATED_FLAG):
raise PipelineStepMissing("cell_counts")

with cellxgene_census.open_soma(census_version=CensusParameters.census_version) as census:
dataset_metadata = census["census_info"]["datasets"].read().concat().to_pandas()

# read in the cell counts df and only keep the dataset_ids that are in the cube
with tiledb.open(os.path.join(corpus_path, CELL_COUNTS_CUBE_NAME)) as cc_cube:
cell_counts_df = cc_cube.df[:]
unique_dataset_ids = cell_counts_df["dataset_id"].unique()

dataset_metadata = dataset_metadata[dataset_metadata["dataset_id"].isin(unique_dataset_ids)]

datasets = dataset_metadata.to_dict(orient="records")

dataset_dict = {}
for dataset in datasets:
dataset_id = dataset["dataset_id"]
dataset_dict[dataset_id] = dict(
id=dataset_id,
label=dataset["title"],
dataset_dict[dataset["dataset_id"]] = dict(
id=dataset["dataset_id"],
label=dataset["dataset_title"],
collection_id=dataset["collection_id"],
collection_label=dataset["collection_name"],
)

logger.info("Writing dataset metadata file")
with open(f"{corpus_path}/{DATASET_METADATA_FILENAME}", "w") as f:
json.dump(dataset_dict, f)
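
The metadata file now covers only datasets present in the cell counts cube, with labels taken from census fields (dataset_title, collection_name) rather than the Discover API. A read-back sketch of the per-dataset shape written above; the path and filename are assumptions:

    import json

    corpus_path = "/path/to/snapshot"  # hypothetical snapshot location
    with open(f"{corpus_path}/dataset_metadata.json") as f:  # assumed filename
        dataset_metadata = json.load(f)

    # Each entry follows the dict literal in create_dataset_metadata:
    # {"id": ..., "label": <dataset_title>, "collection_id": ...,
    #  "collection_label": <collection_name>}
    first_id = next(iter(dataset_metadata))
    print(dataset_metadata[first_id]["label"], dataset_metadata[first_id]["collection_label"])
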
12 changes: 9 additions & 3 deletions backend/wmg/pipeline/expression_summary.py
@@ -27,7 +27,6 @@
load_pipeline_state,
log_func_runtime,
remove_accents,
return_dataset_dict_w_publications,
)

logger = logging.getLogger(__name__)
@@ -37,7 +36,9 @@


class ExpressionSummaryCubeBuilder:
def __init__(self, *, query: ExperimentAxisQuery, corpus_path: str, organismId: str):
def __init__(
self, *, dataset_metadata: pd.DataFrame, query: ExperimentAxisQuery, corpus_path: str, organismId: str
):
self.obs_df = query.obs().concat().to_pandas()
self.obs_df = self.obs_df.rename(columns=DIMENSION_NAME_MAP_CENSUS_TO_WMG)
self.obs_df["organism_ontology_term_id"] = organismId
Expand All @@ -51,6 +52,7 @@ def __init__(self, *, query: ExperimentAxisQuery, corpus_path: str, organismId:
self.corpus_path = corpus_path

self.pipeline_state = load_pipeline_state(corpus_path)
self.dataset_metadata = dataset_metadata

@log_func_runtime
def create_expression_summary_cube(self):
@@ -289,7 +291,11 @@ def _build_in_mem_cube(
idx = 0

if "publication_citation" in other_cube_attrs:
dataset_dict = return_dataset_dict_w_publications()
dataset_dict = {
row["dataset_id"]: row["collection_doi_label"]
for _, row in self.dataset_metadata.iterrows()
if row["collection_doi_label"]
}

for grp in cube_index.to_records():
(
(Diff truncated: the remaining changed files in this commit are not shown.)
