From 8e60ca64b24914f92ee998018f9cf6b6c1b0a823 Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Tue, 6 Feb 2024 09:09:54 -0800 Subject: [PATCH 01/21] Update to codecov-action v4 (#2469) --- .github/workflows/test_linux.yml | 2 +- .github/workflows/test_linux_cuda.yml | 2 +- .github/workflows/test_linux_pre.yml | 2 +- .github/workflows/test_linux_private.yml | 2 +- .github/workflows/test_macos.yml | 2 +- .github/workflows/test_macos_m1.yml | 2 +- .github/workflows/test_windows.yml | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/test_linux.yml b/.github/workflows/test_linux.yml index a9efb39b20..5338f5f532 100644 --- a/.github/workflows/test_linux.yml +++ b/.github/workflows/test_linux.yml @@ -57,4 +57,4 @@ jobs: run: | coverage report - name: Upload coverage - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 diff --git a/.github/workflows/test_linux_cuda.yml b/.github/workflows/test_linux_cuda.yml index 954059ef42..c15d02bd86 100644 --- a/.github/workflows/test_linux_cuda.yml +++ b/.github/workflows/test_linux_cuda.yml @@ -61,4 +61,4 @@ jobs: run: | coverage report - name: Upload coverage - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 diff --git a/.github/workflows/test_linux_pre.yml b/.github/workflows/test_linux_pre.yml index 69d944a72c..607e893db0 100644 --- a/.github/workflows/test_linux_pre.yml +++ b/.github/workflows/test_linux_pre.yml @@ -67,4 +67,4 @@ jobs: run: | coverage report - name: Upload coverage - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 diff --git a/.github/workflows/test_linux_private.yml b/.github/workflows/test_linux_private.yml index 35d598d5ad..8d6d68a077 100644 --- a/.github/workflows/test_linux_private.yml +++ b/.github/workflows/test_linux_private.yml @@ -80,4 +80,4 @@ jobs: run: | coverage report - name: Upload coverage - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 diff --git a/.github/workflows/test_macos.yml b/.github/workflows/test_macos.yml index e7c48c4836..a57c2b3783 100644 --- a/.github/workflows/test_macos.yml +++ b/.github/workflows/test_macos.yml @@ -56,4 +56,4 @@ jobs: run: | coverage report - name: Upload coverage - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 diff --git a/.github/workflows/test_macos_m1.yml b/.github/workflows/test_macos_m1.yml index 5b80ef267f..12b6f7678f 100644 --- a/.github/workflows/test_macos_m1.yml +++ b/.github/workflows/test_macos_m1.yml @@ -56,4 +56,4 @@ jobs: run: | coverage report - name: Upload coverage - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 diff --git a/.github/workflows/test_windows.yml b/.github/workflows/test_windows.yml index 144411d8ef..a8b92a7871 100644 --- a/.github/workflows/test_windows.yml +++ b/.github/workflows/test_windows.yml @@ -56,4 +56,4 @@ jobs: run: | coverage report - name: Upload coverage - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 From d102254cb27c3065ca8d0d900756db69d30792cb Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Wed, 7 Feb 2024 13:46:36 -0800 Subject: [PATCH 02/21] Temporarily remove cellxgene-census dependency, remove chex upperbound (#2474) See #2472 for more details. Opening a new PR as there were too many changes in the previous one. 
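For scope: the `census` extra itself stays defined in `pyproject.toml` (see the unchanged `census = ["cellxgene-census"]` line in a later patch); this change only drops it from the `optional` meta-extra. A hedged sketch of the user-facing effect — the `find_spec` probe is illustrative, not part of the patch:

```python
# pip install "scvi-tools[optional]"  ->  no longer pulls in cellxgene-census
# pip install "scvi-tools[census]"    ->  still installs it explicitly
from importlib.util import find_spec

if find_spec("cellxgene_census") is None:
    print("cellxgene-census is missing, so scvi.data.cellxgene is unavailable")
```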
--- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6906a3fdf3..785e97cb09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ classifiers = [ ] dependencies = [ "anndata>=0.7.5", - "chex<=0.1.8", # see https://github.com/scverse/scvi-tools/pull/2187 + "chex", "docrep>=0.3.2", "flax", "jax>=0.4.4", @@ -104,7 +104,7 @@ regseq = ["biopython>=1.81", "genomepy"] # scvi.data.add_dna_sequence loompy = ["loompy>=3.0.6"] # read loom scanpy = ["scanpy>=1.6"] # scvi.criticism and read 10x optional = [ - "scvi-tools[autotune,aws,census,hub,loompy,pymde,regseq,scanpy]" + "scvi-tools[autotune,aws,hub,loompy,pymde,regseq,scanpy]" ] # all optional user functionality tutorials = [ From b4ff6e2455fab9dc1a600ad2a838482de45d6a1e Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Wed, 7 Feb 2024 14:03:28 -0800 Subject: [PATCH 03/21] Remove Pandas pin (#2475) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 785e97cb09..1180ab51b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ "jaxlib>=0.4.3", "optax", "numpy>=1.21.0", - "pandas>=1.0,!=2.1.2", + "pandas>=1.0", "scipy", "scikit-learn>=0.21.2", "rich>=12.0.0", From 98e5596e69681e812e59d3c088738c69f7bc27c6 Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Wed, 7 Feb 2024 14:21:26 -0800 Subject: [PATCH 04/21] Add support for Ray 2.8-2.9 (#2478) --- docs/release_notes/index.md | 1 + pyproject.toml | 2 +- scvi/autotune/_manager.py | 4 +++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/release_notes/index.md b/docs/release_notes/index.md index ed4fe0bfa7..3d3e9f0853 100644 --- a/docs/release_notes/index.md +++ b/docs/release_notes/index.md @@ -59,6 +59,7 @@ is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/ {meth}`scvi.hub.HubModel.pull_from_s3` and {meth}`scvi.hub.HubModel.push_to_s3` {pr}`2378`. - Add clearer error message for {func}`scvi.data.poisson_gene_selection` when input data does not contain raw counts {pr}`2422`. +- Add support for Ray 2.8 - 2.9 in {class}`scvi.autotune.ModelTuner` {pr}`2478`. 
#### Fixed diff --git a/pyproject.toml b/pyproject.toml index 1180ab51b3..c8b516a320 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,7 +91,7 @@ docsbuild = ["scvi-tools[docs,optional]"] # docs build dependencies autotune = [ "hyperopt>=0.2", - "ray[tune]>=2.5.0,<2.8.0", + "ray[tune]>=2.5.0,<2.10.0", "ipython", "scib-metrics>=0.4.1", "tensorboard", diff --git a/scvi/autotune/_manager.py b/scvi/autotune/_manager.py index b970aebce8..ffecfb61de 100644 --- a/scvi/autotune/_manager.py +++ b/scvi/autotune/_manager.py @@ -8,6 +8,7 @@ import warnings from collections import OrderedDict from datetime import datetime +from pathlib import Path from typing import Any, Callable import lightning.pytorch as pl @@ -446,7 +447,8 @@ def _trainable( callbacks = [callback_cls(metric, on="validation_end")] logs_dir = os.path.join(logging_dir, experiment_name) - trial_name = air.session.get_trial_name() + "_lightning" + Path(logs_dir).mkdir(parents=True, exist_ok=True) + trial_name = ray.train.get_context().get_trial_name() + "_lightning" logger = pl.loggers.TensorBoardLogger(logs_dir, name=trial_name) if monitor_device_stats: From bd303be4164ced0fb9763f5b412a7bace7a1c136 Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Wed, 7 Feb 2024 14:35:34 -0800 Subject: [PATCH 05/21] Make xarray and sparse optional deps, lazy load criticism (#2480) --- docs/release_notes/index.md | 1 + pyproject.toml | 5 ++--- scvi/criticism/__init__.py | 7 ++++++- scvi/model/base/_rnamixin.py | 13 ++++++++++--- 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/docs/release_notes/index.md b/docs/release_notes/index.md index 3d3e9f0853..59924377f4 100644 --- a/docs/release_notes/index.md +++ b/docs/release_notes/index.md @@ -103,6 +103,7 @@ is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/ computations to use `"micro"` reduction rather than `"macro"` {pr}`2339`. - Internal refactoring of {meth}`scvi.module.VAE.sample` and {meth}`scvi.model.base.RNASeqMixin.posterior_predictive_sample` {pr}`2377`. +- Change `xarray` and `sparse` from mandatory to optional dependencies {pr}`2480`. 
#### Removed diff --git a/pyproject.toml b/pyproject.toml index c8b516a320..bdc0f8038d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,8 +54,6 @@ dependencies = [ "numpyro>=0.12.1", "ml-collections>=0.1.1", "mudata>=0.1.2", - "sparse>=0.14.0", - "xarray>=2023.2.0", ] @@ -98,13 +96,14 @@ autotune = [ ] # scvi.autotune aws = ["boto3"] # scvi.hub.HubModel.pull_from_s3 census = ["cellxgene-census"] # scvi.data.cellxgene +criticism = ["sparse>=0.14.0", "xarray>=2023.2.0"] # scvi.criticism hub = ["huggingface_hub"] # scvi.hub dependencies pymde = ["pymde"] # scvi.model.utils.mde dependencies regseq = ["biopython>=1.81", "genomepy"] # scvi.data.add_dna_sequence loompy = ["loompy>=3.0.6"] # read loom scanpy = ["scanpy>=1.6"] # scvi.criticism and read 10x optional = [ - "scvi-tools[autotune,aws,hub,loompy,pymde,regseq,scanpy]" + "scvi-tools[autotune,aws,criticism,hub,loompy,pymde,regseq,scanpy]" ] # all optional user functionality tutorials = [ diff --git a/scvi/criticism/__init__.py b/scvi/criticism/__init__.py index ed08672238..ba20ee4a46 100644 --- a/scvi/criticism/__init__.py +++ b/scvi/criticism/__init__.py @@ -1,3 +1,8 @@ -from ._ppc import PosteriorPredictiveCheck +from scvi.utils import error_on_missing_dependencies + +error_on_missing_dependencies("sparse", "xarray") + + +from ._ppc import PosteriorPredictiveCheck # noqa __all__ = ["PosteriorPredictiveCheck"] diff --git a/scvi/model/base/_rnamixin.py b/scvi/model/base/_rnamixin.py index 6f47abde1b..00a4ab53ed 100644 --- a/scvi/model/base/_rnamixin.py +++ b/scvi/model/base/_rnamixin.py @@ -8,7 +8,6 @@ import numpy as np import pandas as pd -import sparse import torch import torch.distributions as db from anndata import AnnData @@ -19,10 +18,15 @@ from scvi.distributions._utils import DistributionConcatenator, subset_distribution from scvi.model._utils import _get_batch_code_from_category, scrna_raw_counts_properties from scvi.module.base._decorators import _move_data_to_device -from scvi.utils import de_dsp, unsupported_if_adata_minified +from scvi.utils import de_dsp, dependencies, unsupported_if_adata_minified from ._utils import _de_core +try: + from sparse import GCXS +except ImportError: + GCXS = type(None) + logger = logging.getLogger(__name__) @@ -409,6 +413,7 @@ def differential_expression( return result + @dependencies("sparse") def posterior_predictive_sample( self, adata: AnnData | None = None, @@ -416,7 +421,7 @@ def posterior_predictive_sample( n_samples: int = 1, gene_list: list[str] | None = None, batch_size: int | None = None, - ) -> sparse.GCXS: + ) -> GCXS: r"""Generate predictive samples from the posterior predictive distribution. The posterior predictive distribution is denoted as :math:`p(\hat{x} \mid x)`, where @@ -449,6 +454,8 @@ def posterior_predictive_sample( Sparse multidimensional array of shape ``(n_obs, n_vars)`` if ``n_samples == 1``, else ``(n_obs, n_vars, n_samples)``. 
""" + import sparse + adata = self._validate_anndata(adata) dataloader = self._make_data_loader( adata=adata, indices=indices, batch_size=batch_size From 7f669091002595f51674ae9c2ffe49348ef56e8d Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Wed, 7 Feb 2024 14:49:54 -0800 Subject: [PATCH 06/21] Remove chex dependency (#2482) --- docs/release_notes/index.md | 1 + pyproject.toml | 1 - scvi/autotune/_manager.py | 2 +- scvi/external/tangram/_module.py | 2 -- scvi/module/base/_base_module.py | 40 +++++++++++++++++++------------- 5 files changed, 26 insertions(+), 20 deletions(-) diff --git a/docs/release_notes/index.md b/docs/release_notes/index.md index 59924377f4..020c465dd8 100644 --- a/docs/release_notes/index.md +++ b/docs/release_notes/index.md @@ -110,6 +110,7 @@ is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/ - Remove deprecated `use_gpu argument in favor of PyTorch Lightning arguments `accelerator`and`devices` {pr}`2114`. - Remove deprecated `scvi._compat.Literal` class {pr}`2115`. +- Remove chex dependency {pr}`2482`. ## Version 1.0 diff --git a/pyproject.toml b/pyproject.toml index bdc0f8038d..b050a2ad79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,6 @@ classifiers = [ ] dependencies = [ "anndata>=0.7.5", - "chex", "docrep>=0.3.2", "flax", "jax>=0.4.4", diff --git a/scvi/autotune/_manager.py b/scvi/autotune/_manager.py index ffecfb61de..898d523cb2 100644 --- a/scvi/autotune/_manager.py +++ b/scvi/autotune/_manager.py @@ -7,6 +7,7 @@ import os import warnings from collections import OrderedDict +from dataclasses import dataclass from datetime import datetime from pathlib import Path from typing import Any, Callable @@ -14,7 +15,6 @@ import lightning.pytorch as pl import ray import rich -from chex import dataclass from ray import air, tune from ray.tune.integration.pytorch_lightning import TuneReportCallback diff --git a/scvi/external/tangram/_module.py b/scvi/external/tangram/_module.py index 2efdb7fb61..bfce474061 100644 --- a/scvi/external/tangram/_module.py +++ b/scvi/external/tangram/_module.py @@ -1,6 +1,5 @@ from typing import NamedTuple, Optional -import chex import jax import jax.numpy as jnp @@ -111,7 +110,6 @@ def loss( sc = sc * filter g_pred = mapper.transpose() @ sc - chex.assert_equal_shape([sp, g_pred]) # Expression term if self.lambda_g1 > 0: diff --git a/scvi/module/base/_base_module.py b/scvi/module/base/_base_module.py index 2c483a0825..387b5ae7f8 100644 --- a/scvi/module/base/_base_module.py +++ b/scvi/module/base/_base_module.py @@ -5,7 +5,6 @@ from dataclasses import field from typing import Any, Callable -import chex import flax import jax import jax.numpy as jnp @@ -28,7 +27,7 @@ from ._pyro import AutoMoveDataPredictive -@chex.dataclass +@flax.struct.dataclass class LossOutput: """Loss signature for models. 
@@ -84,12 +83,12 @@ class LossOutput: true_labels: Tensor | None = None extra_metrics: dict[str, Tensor] | None = field(default_factory=dict) n_obs_minibatch: int | None = None - reconstruction_loss_sum: Tensor = field(default=None, init=False) - kl_local_sum: Tensor = field(default=None, init=False) - kl_global_sum: Tensor = field(default=None, init=False) + reconstruction_loss_sum: Tensor = field(default=None) + kl_local_sum: Tensor = field(default=None) + kl_global_sum: Tensor = field(default=None) def __post_init__(self): - self.loss = self.dict_sum(self.loss) + object.__setattr__(self, "loss", self.dict_sum(self.loss)) if self.n_obs_minibatch is None and self.reconstruction_loss is None: raise ValueError( @@ -98,21 +97,30 @@ def __post_init__(self): default = 0 * self.loss if self.reconstruction_loss is None: - self.reconstruction_loss = default + object.__setattr__(self, "reconstruction_loss", default) if self.kl_local is None: - self.kl_local = default + object.__setattr__(self, "kl_local", default) if self.kl_global is None: - self.kl_global = default - self.reconstruction_loss = self._as_dict("reconstruction_loss") - self.kl_local = self._as_dict("kl_local") - self.kl_global = self._as_dict("kl_global") - self.reconstruction_loss_sum = self.dict_sum(self.reconstruction_loss).sum() - self.kl_local_sum = self.dict_sum(self.kl_local).sum() - self.kl_global_sum = self.dict_sum(self.kl_global) + object.__setattr__(self, "kl_global", default) + + object.__setattr__( + self, "reconstruction_loss", self._as_dict("reconstruction_loss") + ) + object.__setattr__(self, "kl_local", self._as_dict("kl_local")) + object.__setattr__(self, "kl_global", self._as_dict("kl_global")) + object.__setattr__( + self, + "reconstruction_loss_sum", + self.dict_sum(self.reconstruction_loss).sum(), + ) + object.__setattr__(self, "kl_local_sum", self.dict_sum(self.kl_local).sum()) + object.__setattr__(self, "kl_global_sum", self.dict_sum(self.kl_global)) if self.reconstruction_loss is not None and self.n_obs_minibatch is None: rec_loss = self.reconstruction_loss - self.n_obs_minibatch = list(rec_loss.values())[0].shape[0] + object.__setattr__( + self, "n_obs_minibatch", list(rec_loss.values())[0].shape[0] + ) if self.classification_loss is not None and ( self.logits is None or self.true_labels is None From c32b1713a84063e43e1b1b0a6b57fe2b91e143cc Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Wed, 7 Feb 2024 15:15:57 -0800 Subject: [PATCH 07/21] Add API for custom dataloaders in `SCVI` (#2467) This commit enables users to bypass the default workflow of setting up an `AnnData` and initializing a model with it in order to train, and instead being able to train a model directly with a custom dataloader. - Make `adata` argument optional in `SCVI`, which sets a `_module_init_on_train` flag (through `BaseModelClass`) if no `adata` is passed in - This causes the module init to be delayed until train time since we won't have access to certain arguments (*e.g.* `n_input` and `n_batch`) that must be inferred from the data - Add an optional `data_module` argument to `UnsupervisedTrainingMixin.train` that must be a `LightningDataModule` and must be passed in if the model was not initialized with `adata` Set as an experimental feature for now. Will expand to models other than `SCVI` when the implementation and API stabilize. 
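A minimal sketch of the new flow, adapted from the `test_scvi_no_anndata` test added in this patch; `max_epochs` is given explicitly because this data module defines no `n_obs` for the epoch heuristic:

```python
from scvi.data import synthetic_iid
from scvi.dataloaders import DataSplitter
from scvi.model import SCVI

adata = synthetic_iid(n_batches=3)
SCVI.setup_anndata(adata, batch_key="batch")
manager = SCVI._get_most_recent_anndata_manager(adata)

# DataSplitter subclasses LightningDataModule, so it can stand in for a custom
# dataloader here; shape attributes are set by hand since the model has no
# AnnData to infer them from.
data_module = DataSplitter(manager)
data_module.n_vars = adata.n_vars
data_module.n_batch = 3

model = SCVI(n_latent=5)  # no adata: module init is deferred until train()
model.train(max_epochs=1, data_module=data_module)
```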
--- docs/api/developer.md | 1 + docs/release_notes/index.md | 3 + scvi/model/__init__.py | 1 + scvi/model/_scvi.py | 136 +++++++++++++++++------------ scvi/model/_utils.py | 12 +-- scvi/model/base/_base_model.py | 1 + scvi/model/base/_training_mixin.py | 107 ++++++++++++++++------- tests/hub/test_hub_metadata.py | 2 +- tests/model/test_scvi.py | 39 +++++++++ 9 files changed, 206 insertions(+), 96 deletions(-) diff --git a/docs/api/developer.md b/docs/api/developer.md index 448597d8e1..2c0b437ed8 100644 --- a/docs/api/developer.md +++ b/docs/api/developer.md @@ -287,6 +287,7 @@ Utility functions used by scvi-tools. utils.track utils.setup_anndata_dsp utils.attrdict + model.get_max_epochs_heuristic ``` [ray tune]: https://docs.ray.io/en/latest/tune/index.html diff --git a/docs/release_notes/index.md b/docs/release_notes/index.md index 020c465dd8..cd1d6ca52e 100644 --- a/docs/release_notes/index.md +++ b/docs/release_notes/index.md @@ -59,6 +59,9 @@ is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/ {meth}`scvi.hub.HubModel.pull_from_s3` and {meth}`scvi.hub.HubModel.push_to_s3` {pr}`2378`. - Add clearer error message for {func}`scvi.data.poisson_gene_selection` when input data does not contain raw counts {pr}`2422`. +- Add API for using custom dataloaders with {class}`scvi.model.SCVI` by making `adata` argument + optional on initialization and adding optional argument `data_module` to + {meth}`scvi.model.base.UnsupervisedTrainingMixin.train` {pr}`2467`. - Add support for Ray 2.8 - 2.9 in {class}`scvi.autotune.ModelTuner` {pr}`2478`. #### Fixed diff --git a/scvi/model/__init__.py b/scvi/model/__init__.py index aa2e13ca2b..b16b905db3 100644 --- a/scvi/model/__init__.py +++ b/scvi/model/__init__.py @@ -10,6 +10,7 @@ from ._scanvi import SCANVI from ._scvi import SCVI from ._totalvi import TOTALVI +from ._utils import get_max_epochs_heuristic __all__ = [ "SCVI", diff --git a/scvi/model/_scvi.py b/scvi/model/_scvi.py index c2459b1329..d6bce7b488 100644 --- a/scvi/model/_scvi.py +++ b/scvi/model/_scvi.py @@ -1,10 +1,13 @@ +from __future__ import annotations + import logging -from typing import Literal, Optional +import warnings +from typing import Literal import numpy as np from anndata import AnnData -from scvi import REGISTRY_KEYS +from scvi import REGISTRY_KEYS, settings from scvi._types import MinifiedDataType from scvi.data import AnnDataManager from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE @@ -46,7 +49,10 @@ class SCVI( Parameters ---------- adata - AnnData object that has been registered via :meth:`~scvi.model.SCVI.setup_anndata`. + AnnData object that has been registered via :meth:`~scvi.model.SCVI.setup_anndata`. If + ``None``, then the underlying module will not be initialized until training, and a + :class:`~lightning.pytorch.core.LightningDataModule` must be passed in during training + (``EXPERIMENTAL``). n_hidden Number of nodes per hidden layer. n_latent @@ -73,8 +79,8 @@ class SCVI( * ``'normal'`` - Normal distribution * ``'ln'`` - Logistic normal distribution (Normal(0, I) transformed by softmax) - **model_kwargs - Keyword args for :class:`~scvi.module.VAE` + **kwargs + Additional keyword arguments for :class:`~scvi.module.VAE`. 
Examples -------- @@ -99,7 +105,7 @@ class SCVI( def __init__( self, - adata: AnnData, + adata: AnnData | None = None, n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1, @@ -107,58 +113,72 @@ def __init__( dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene", gene_likelihood: Literal["zinb", "nb", "poisson"] = "zinb", latent_distribution: Literal["normal", "ln"] = "normal", - **model_kwargs, + **kwargs, ): super().__init__(adata) - n_cats_per_cov = ( - self.adata_manager.get_state_registry( - REGISTRY_KEYS.CAT_COVS_KEY - ).n_cats_per_key - if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry - else None - ) - n_batch = self.summary_stats.n_batch - use_size_factor_key = ( - REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry + self._module_kwargs = { + "n_hidden": n_hidden, + "n_latent": n_latent, + "n_layers": n_layers, + "dropout_rate": dropout_rate, + "dispersion": dispersion, + "gene_likelihood": gene_likelihood, + "latent_distribution": latent_distribution, + **kwargs, + } + self._model_summary_string = ( + "SCVI model with the following parameters: \n" + f"n_hidden: {n_hidden}, n_latent: {n_latent}, n_layers: {n_layers}, " + f"dropout_rate: {dropout_rate}, dispersion: {dispersion}, " + f"gene_likelihood: {gene_likelihood}, latent_distribution: {latent_distribution}." ) - library_log_means, library_log_vars = None, None - if not use_size_factor_key and self.minified_data_type is None: - library_log_means, library_log_vars = _init_library_size( - self.adata_manager, n_batch + + if self._module_init_on_train: + self.module = None + warnings.warn( + "Model was initialized without `adata`. The module will be initialized when " + "calling `train`. This behavior is experimental and may change in the future.", + UserWarning, + stacklevel=settings.warnings_stacklevel, ) + else: + n_cats_per_cov = ( + self.adata_manager.get_state_registry( + REGISTRY_KEYS.CAT_COVS_KEY + ).n_cats_per_key + if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry + else None + ) + n_batch = self.summary_stats.n_batch + use_size_factor_key = ( + REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry + ) + library_log_means, library_log_vars = None, None + if not use_size_factor_key and self.minified_data_type is None: + library_log_means, library_log_vars = _init_library_size( + self.adata_manager, n_batch + ) + self.module = self._module_cls( + n_input=self.summary_stats.n_vars, + n_batch=n_batch, + n_labels=self.summary_stats.n_labels, + n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs", 0), + n_cats_per_cov=n_cats_per_cov, + n_hidden=n_hidden, + n_latent=n_latent, + n_layers=n_layers, + dropout_rate=dropout_rate, + dispersion=dispersion, + gene_likelihood=gene_likelihood, + latent_distribution=latent_distribution, + use_size_factor_key=use_size_factor_key, + library_log_means=library_log_means, + library_log_vars=library_log_vars, + **kwargs, + ) + self.module.minified_data_type = self.minified_data_type - self.module = self._module_cls( - n_input=self.summary_stats.n_vars, - n_batch=n_batch, - n_labels=self.summary_stats.n_labels, - n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs", 0), - n_cats_per_cov=n_cats_per_cov, - n_hidden=n_hidden, - n_latent=n_latent, - n_layers=n_layers, - dropout_rate=dropout_rate, - dispersion=dispersion, - gene_likelihood=gene_likelihood, - latent_distribution=latent_distribution, - use_size_factor_key=use_size_factor_key, - library_log_means=library_log_means, - 
library_log_vars=library_log_vars, - **model_kwargs, - ) - self.module.minified_data_type = self.minified_data_type - self._model_summary_string = ( - "SCVI Model with the following params: \nn_hidden: {}, n_latent: {}, n_layers: {}, dropout_rate: " - "{}, dispersion: {}, gene_likelihood: {}, latent_distribution: {}" - ).format( - n_hidden, - n_latent, - n_layers, - dropout_rate, - dispersion, - gene_likelihood, - latent_distribution, - ) self.init_params_ = self._get_init_params(locals()) @classmethod @@ -166,12 +186,12 @@ def __init__( def setup_anndata( cls, adata: AnnData, - layer: Optional[str] = None, - batch_key: Optional[str] = None, - labels_key: Optional[str] = None, - size_factor_key: Optional[str] = None, - categorical_covariate_keys: Optional[list[str]] = None, - continuous_covariate_keys: Optional[list[str]] = None, + layer: str | None = None, + batch_key: str | None = None, + labels_key: str | None = None, + size_factor_key: str | None = None, + categorical_covariate_keys: list[str] | None = None, + continuous_covariate_keys: list[str] | None = None, **kwargs, ): """%(summary)s. diff --git a/scvi/model/_utils.py b/scvi/model/_utils.py index 42bb0151ae..e573301773 100644 --- a/scvi/model/_utils.py +++ b/scvi/model/_utils.py @@ -34,13 +34,14 @@ def use_distributed_sampler(strategy: Union[str, Strategy]) -> bool: def get_max_epochs_heuristic( - n_obs: int, epochs_cap: int = 400, decay_at_n_obs: int = 20000 + n_obs: int, epochs_cap: int = 400, decay_at_n_obs: int = 20_000 ) -> int: """Compute a heuristic for the default number of maximum epochs. - If `n_obs <= decay_at_n_obs`, the number of maximum epochs is set to - `epochs_cap`. Otherwise, the number of maximum epochs decays according to - `(decay_at_n_obs / n_obs) * epochs_cap`, with a minimum of 1. + If ``n_obs <= decay_at_n_obs``, the number of maximum epochs is set to + ``epochs_cap``. Otherwise, the number of maximum epochs decays according to + ``(decay_at_n_obs / n_obs) * epochs_cap``, with a minimum of 1. Raises a + warning if the number of maximum epochs is set to 1. Parameters ---------- @@ -53,8 +54,7 @@ def get_max_epochs_heuristic( Returns ------- - `int` - A heuristic for the default number of maximum epochs. + A heuristic for the number of maximum training epochs. 
""" max_epochs = min(round((decay_at_n_obs / n_obs) * epochs_cap), epochs_cap) max_epochs = max(max_epochs, 1) diff --git a/scvi/model/base/_base_model.py b/scvi/model/base/_base_model.py index c29fda27d7..b1dccb0ebf 100644 --- a/scvi/model/base/_base_model.py +++ b/scvi/model/base/_base_model.py @@ -98,6 +98,7 @@ def __init__(self, adata: AnnOrMuData | None = None): self.registry_ = self._adata_manager.registry self.summary_stats = self._adata_manager.summary_stats + self._module_init_on_train = adata is None self.is_trained_ = False self._model_summary_string = "" self.train_indices_ = None diff --git a/scvi/model/base/_training_mixin.py b/scvi/model/base/_training_mixin.py index 1c6a01b1dc..ca050ca823 100644 --- a/scvi/model/base/_training_mixin.py +++ b/scvi/model/base/_training_mixin.py @@ -1,5 +1,7 @@ from __future__ import annotations +from lightning import LightningDataModule + from scvi._types import Tunable from scvi.dataloaders import DataSplitter from scvi.model._utils import get_max_epochs_heuristic, use_distributed_sampler @@ -28,6 +30,7 @@ def train( early_stopping: bool = False, datasplitter_kwargs: dict | None = None, plan_kwargs: dict | None = None, + data_module: LightningDataModule | None = None, **trainer_kwargs, ): """Train the model. @@ -35,53 +38,95 @@ def train( Parameters ---------- max_epochs - Number of passes through the dataset. If `None`, defaults to - `np.min([round((20000 / n_cells) * 400), 400])` + The maximum number of epochs to train the model. The actual number of epochs may be + less if early stopping is enabled. If ``None``, defaults to a heuristic based on + :func:`~scvi.model.get_max_epochs_heuristic`. Must be passed in if ``data_module`` is + passed in, and it does not have an ``n_obs`` attribute. %(param_accelerator)s %(param_devices)s train_size - Size of training set in the range [0.0, 1.0]. + Size of training set in the range ``[0.0, 1.0]``. Passed into + :class:`~scvi.dataloaders.DataSplitter`. Not used if ``data_module`` is passed in. validation_size - Size of the test set. If `None`, defaults to 1 - `train_size`. If - `train_size + validation_size < 1`, the remaining cells belong to a test set. + Size of the test set. If ``None``, defaults to ``1 - train_size``. If + ``train_size + validation_size < 1``, the remaining cells belong to a test set. Passed + into :class:`~scvi.dataloaders.DataSplitter`. Not used if ``data_module`` is passed in. shuffle_set_split - Whether to shuffle indices before splitting. If `False`, the val, train, and test set are split in the - sequential order of the data according to `validation_size` and `train_size` percentages. + Whether to shuffle indices before splitting. If ``False``, the val, train, and test set + are split in the sequential order of the data according to ``validation_size`` and + ``train_size`` percentages. Passed into :class:`~scvi.dataloaders.DataSplitter`. Not + used if ``data_module`` is passed in. load_sparse_tensor ``EXPERIMENTAL`` If ``True``, loads data with sparse CSR or CSC layout as a :class:`~torch.Tensor` with the same layout. Can lead to speedups in data transfers to - GPUs, depending on the sparsity of the data. + GPUs, depending on the sparsity of the data. Passed into + :class:`~scvi.dataloaders.DataSplitter`. Not used if ``data_module`` is passed in. batch_size - Minibatch size to use during training. + Minibatch size to use during training. Passed into + :class:`~scvi.dataloaders.DataSplitter`. Not used if ``data_module`` is passed in. early_stopping - Perform early stopping. 
Additional arguments can be passed in `**kwargs`. + Perform early stopping. Additional arguments can be passed in through ``**kwargs``. See :class:`~scvi.train.Trainer` for further options. datasplitter_kwargs - Additional keyword arguments passed into :class:`~scvi.dataloaders.DataSplitter`. + Additional keyword arguments passed into :class:`~scvi.dataloaders.DataSplitter`. Values + in this argument can be overwritten by arguments directly passed into this method, when + appropriate. Not used if ``data_module`` is passed in. plan_kwargs - Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to - `train()` will overwrite values present in `plan_kwargs`, when appropriate. - **trainer_kwargs - Other keyword args for :class:`~scvi.train.Trainer`. + Additional keyword arguments passed into :class:`~scvi.train.TrainingPlan`. Values in + this argument can be overwritten by arguments directly passed into this method, when + appropriate. + data_module + ``EXPERIMENTAL`` A :class:`~lightning.pytorch.core.LightningDataModule` instance to use + for training in place of the default :class:`~scvi.dataloaders.DataSplitter`. Can only + be passed in if the model was not initialized with :class:`~anndata.AnnData`. + **kwargs + Additional keyword arguments passed into :class:`~scvi.train.Trainer`. """ + if data_module is not None and not self._module_init_on_train: + raise ValueError( + "Cannot pass in `data_module` if the model was initialized with `adata`." + ) + elif data_module is None and self._module_init_on_train: + raise ValueError( + "If the model was not initialized with `adata`, a `data_module` must be passed in." + ) + if max_epochs is None: - max_epochs = get_max_epochs_heuristic(self.adata.n_obs) + if data_module is None: + max_epochs = get_max_epochs_heuristic(self.adata.n_obs) + elif hasattr(data_module, "n_obs"): + max_epochs = get_max_epochs_heuristic(data_module.n_obs) + else: + raise ValueError( + "If `data_module` does not have `n_obs` attribute, `max_epochs` must be passed " + "in." 
+ ) - plan_kwargs = plan_kwargs or {} - datasplitter_kwargs = datasplitter_kwargs or {} + if data_module is None: + datasplitter_kwargs = datasplitter_kwargs or {} + data_module = self._data_splitter_cls( + self.adata_manager, + train_size=train_size, + validation_size=validation_size, + batch_size=batch_size, + shuffle_set_split=shuffle_set_split, + distributed_sampler=use_distributed_sampler( + trainer_kwargs.get("strategy", None) + ), + load_sparse_tensor=load_sparse_tensor, + **datasplitter_kwargs, + ) + elif self.module is None: + self.module = self._module_cls( + data_module.n_vars, + n_batch=data_module.n_batch, + n_labels=getattr(data_module, "n_labels", 1), + n_continuous_cov=getattr(data_module, "n_continuous_cov", 0), + n_cats_per_cov=getattr(data_module, "n_cats_per_cov", None), + **self._module_kwargs, + ) - data_splitter = self._data_splitter_cls( - self.adata_manager, - train_size=train_size, - validation_size=validation_size, - batch_size=batch_size, - shuffle_set_split=shuffle_set_split, - distributed_sampler=use_distributed_sampler( - trainer_kwargs.get("strategy", None) - ), - load_sparse_tensor=load_sparse_tensor, - **datasplitter_kwargs, - ) + plan_kwargs = plan_kwargs or {} training_plan = self._training_plan_cls(self.module, **plan_kwargs) es = "early_stopping" @@ -91,7 +136,7 @@ def train( runner = self._train_runner_cls( self, training_plan=training_plan, - data_splitter=data_splitter, + data_splitter=data_module, max_epochs=max_epochs, accelerator=accelerator, devices=devices, diff --git a/tests/hub/test_hub_metadata.py b/tests/hub/test_hub_metadata.py index b51b92e177..7da0c173cd 100644 --- a/tests/hub/test_hub_metadata.py +++ b/tests/hub/test_hub_metadata.py @@ -87,7 +87,7 @@ def test_hub_modelcardhelper(request, save_path): assert hmch.license_info == "cc-by-4.0" assert hmch.model_cls_name == "SCVI" assert hmch.model_init_params == { - "kwargs": {"model_kwargs": {}}, + "kwargs": {"kwargs": {}}, "non_kwargs": { "n_hidden": 128, "n_latent": 10, diff --git a/tests/model/test_scvi.py b/tests/model/test_scvi.py index f7e2f018c9..7b8b583b4d 100644 --- a/tests/model/test_scvi.py +++ b/tests/model/test_scvi.py @@ -909,3 +909,42 @@ def test_set_seed(n_latent: int = 5, seed: int = 1): model1.module.z_encoder.encoder.fc_layers[0][0].weight, model2.module.z_encoder.encoder.fc_layers[0][0].weight, ) + + +def test_scvi_no_anndata(n_batches: int = 3, n_latent: int = 5): + from scvi.dataloaders import DataSplitter + + adata = synthetic_iid(n_batches=n_batches) + SCVI.setup_anndata(adata, batch_key="batch") + manager = SCVI._get_most_recent_anndata_manager(adata) + + data_module = DataSplitter(manager) + data_module.n_vars = adata.n_vars + data_module.n_batch = n_batches + + model = SCVI(n_latent=5) + assert model._module_init_on_train + assert model.module is None + + # cannot infer default max_epochs without n_obs set in data_module + with pytest.raises(ValueError): + model.train(data_module=data_module) + + # must pass in data_module if not initialized with adata + with pytest.raises(ValueError): + model.train() + + model.train(max_epochs=1, data_module=data_module) + + # must set n_obs for defaulting max_epochs + data_module.n_obs = 100_000_000 # large number for fewer default epochs + model.train(data_module=data_module) + + model = SCVI(adata, n_latent=5) + assert not model._module_init_on_train + assert model.module is not None + assert hasattr(model, "adata") + + # initialized with adata, cannot pass in data_module + with pytest.raises(ValueError): + 
model.train(data_module=data_module) From 42439808a12bfaee4569ada93e0aa9bc8e431337 Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Wed, 7 Feb 2024 16:43:07 -0800 Subject: [PATCH 08/21] Add token to codecov action (#2487) As of v4 of codecov-action, it requires a token for uploading. --- .github/workflows/test_linux.yml | 2 ++ .github/workflows/test_linux_cuda.yml | 2 ++ .github/workflows/test_linux_pre.yml | 2 ++ .github/workflows/test_linux_private.yml | 2 ++ .github/workflows/test_macos.yml | 2 ++ .github/workflows/test_macos_m1.yml | 2 ++ .github/workflows/test_windows.yml | 2 ++ 7 files changed, 14 insertions(+) diff --git a/.github/workflows/test_linux.yml b/.github/workflows/test_linux.yml index 5338f5f532..e13266f0ab 100644 --- a/.github/workflows/test_linux.yml +++ b/.github/workflows/test_linux.yml @@ -58,3 +58,5 @@ jobs: coverage report - name: Upload coverage uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/test_linux_cuda.yml b/.github/workflows/test_linux_cuda.yml index c15d02bd86..153254bdf1 100644 --- a/.github/workflows/test_linux_cuda.yml +++ b/.github/workflows/test_linux_cuda.yml @@ -62,3 +62,5 @@ jobs: coverage report - name: Upload coverage uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/test_linux_pre.yml b/.github/workflows/test_linux_pre.yml index 607e893db0..21b3f0e056 100644 --- a/.github/workflows/test_linux_pre.yml +++ b/.github/workflows/test_linux_pre.yml @@ -68,3 +68,5 @@ jobs: coverage report - name: Upload coverage uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/test_linux_private.yml b/.github/workflows/test_linux_private.yml index 8d6d68a077..76ce45e49f 100644 --- a/.github/workflows/test_linux_private.yml +++ b/.github/workflows/test_linux_private.yml @@ -81,3 +81,5 @@ jobs: coverage report - name: Upload coverage uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/test_macos.yml b/.github/workflows/test_macos.yml index a57c2b3783..9d7bd7950b 100644 --- a/.github/workflows/test_macos.yml +++ b/.github/workflows/test_macos.yml @@ -57,3 +57,5 @@ jobs: coverage report - name: Upload coverage uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/test_macos_m1.yml b/.github/workflows/test_macos_m1.yml index 12b6f7678f..8dad921269 100644 --- a/.github/workflows/test_macos_m1.yml +++ b/.github/workflows/test_macos_m1.yml @@ -57,3 +57,5 @@ jobs: coverage report - name: Upload coverage uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/test_windows.yml b/.github/workflows/test_windows.yml index a8b92a7871..04b197055e 100644 --- a/.github/workflows/test_windows.yml +++ b/.github/workflows/test_windows.yml @@ -57,3 +57,5 @@ jobs: coverage report - name: Upload coverage uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} From b1e2aab74a60661ced5aad00dc99a7136cc4e759 Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Wed, 7 Feb 2024 17:19:21 -0800 Subject: [PATCH 09/21] Use anndata `CSCDataset` and `CSRDataset` for type checks (#2485) Addresses the following warning message: ``` :119: FutureWarning: SparseDataset is deprecated and will be removed in late 2024. 
It has been replaced by the public classes CSRDataset and CSCDataset. For instance checks, use `isinstance(X, (anndata.experimental.CSRDataset, anndata.experimental.CSCDataset))` instead. ``` --- docs/release_notes/index.md | 100 ++++++++++++++++------------------ scvi/data/_anntorchdataset.py | 6 +- scvi/data/_utils.py | 6 +- 3 files changed, 53 insertions(+), 59 deletions(-) diff --git a/docs/release_notes/index.md b/docs/release_notes/index.md index cd1d6ca52e..fbe645fd98 100644 --- a/docs/release_notes/index.md +++ b/docs/release_notes/index.md @@ -13,42 +13,40 @@ is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/ - Add {class}`scvi.external.ContrastiveVI` for contrastiveVI {pr}`2242`. - Add {class}`scvi.dataloaders.BatchDistributedSampler` for distributed training {pr}`2102`. -- Add `additional_val_metrics` argument to {class}`scvi.train.Trainer`, allowing to - specify additional metrics to compute and log during the validation loop using +- Add `additional_val_metrics` argument to {class}`scvi.train.Trainer`, allowing to specify + additional metrics to compute and log during the validation loop using {class}`scvi.train._callbacks.MetricsCallback` {pr}`2136`. -- Expose `accelerator` and `device` arguments in {meth}`scvi.hub.HubModel.load_model` - `pr`{2166}. -- Add `load_sparse_tensor` argument in {class}`scvi.data.AnnTorchDataset` for directly - loading SciPy CSR and CSC data structures to their PyTorch counterparts, leading to - faster data loading depending on the sparsity of the data {pr}`2158`. -- Add per-group LFC information to the - {meth}`scvi.criticism.PosteriorPredictiveCheck.differential_expression` - method {pr}`2173`. `metrics["diff_exp"]` is now a dictionary where the `summary` - stores the summary dataframe, and the `lfc_per_model_per_group` key stores the - per-group LFC. +- Expose `accelerator` and `device` arguments in {meth}`scvi.hub.HubModel.load_model` `pr`{2166}. +- Add `load_sparse_tensor` argument in {class}`scvi.data.AnnTorchDataset` for directly loading + SciPy CSR and CSC data structures to their PyTorch counterparts, leading to faster data loading + depending on the sparsity of the data {pr}`2158`. +- Add per-group LFC information to + {meth}`scvi.criticism.PosteriorPredictiveCheck.differential_expression`. `metrics["diff_exp"]` + is now a dictionary where `summary` stores the summary dataframe, and `lfc_per_model_per_group` + stores the per-group LFC {pr}`2173`. - Expose {meth}`torch.save` keyword arguments in {class}`scvi.model.base.BaseModelClass.save` and {class}`scvi.external.GIMVI.save` {pr}`2200`. - Add `model_kwargs` and `train_kwargs` arguments to {meth}`scvi.autotune.ModelTuner.fit` {pr}`2203`. - Add `datasplitter_kwargs` to model `train` methods {pr}`2204`. -- Add `use_posterior_mean` argument to {meth}`scvi.model.SCANVI.predict` for - stochastic prediction of celltype labels {pr}`2224`. +- Add `use_posterior_mean` argument to {meth}`scvi.model.SCANVI.predict` for stochastic prediction + of celltype labels {pr}`2224`. - Add support for Python 3.10+ type annotations in {class}`scvi.autotune.ModelTuner` {pr}`2239`. -- Add option to log device statistics in {meth}`scvi.autotune.ModelTuner.fit` - with argument `monitor_device_stats` {pr}`2260`. -- Add option to pass in a random seed to {meth}`scvi.autotune.ModelTuner.fit` - with argument `seed` {pr}`2260`. -- Automatically log the learning rate when `reduce_lr_on_plateau=True` in - training plans {pr}`2280`. 
+- Add option to log device statistics in {meth}`scvi.autotune.ModelTuner.fit` with argument + `monitor_device_stats` {pr}`2260`. +- Add option to pass in a random seed to {meth}`scvi.autotune.ModelTuner.fit` with argument `seed` + {pr}`2260`. +- Automatically log the learning rate when `reduce_lr_on_plateau=True` in training plans + {pr}`2280`. - Add {class}`scvi.external.POISSONVI` to model scATAC-seq fragment counts with a Poisson distribution {pr}`2249` -- {class}`scvi.train.SemiSupervisedTrainingPlan` now logs the classifier - calibration error {pr}`2299`. -- Passing `enable_checkpointing=True` into `train` methods is now - compatible with our model saves. Additional options can be specified - by initializing with {class}`scvi.train.SaveCheckpoint` {pr}`2317`. -- {attr}`scvi.settings.dl_num_workers` is now correctly applied as the default - `num_workers` in {class}`scvi.dataloaders.AnnDataLoader` {pr}`2322`. +- {class}`scvi.train.SemiSupervisedTrainingPlan` now logs the classifier calibration error + {pr}`2299`. +- Passing `enable_checkpointing=True` into `train` methods is now compatible with our model saves. + Additional options can be specified by initializing with {class}`scvi.train.SaveCheckpoint` + {pr}`2317`. +- {attr}`scvi.settings.dl_num_workers` is now correctly applied as the default `num_workers` in + {class}`scvi.dataloaders.AnnDataLoader` {pr}`2322`. - Passing in `indices` to {class}`scvi.criticism.PosteriorPredictiveCheck` allows for running metrics on a subset of the data {pr}`2361`. - Add `seed` argument to {func}`scvi.model.utils.mde` for reproducibility {pr}`2373`. @@ -62,17 +60,17 @@ is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/ - Add API for using custom dataloaders with {class}`scvi.model.SCVI` by making `adata` argument optional on initialization and adding optional argument `data_module` to {meth}`scvi.model.base.UnsupervisedTrainingMixin.train` {pr}`2467`. -- Add support for Ray 2.8 - 2.9 in {class}`scvi.autotune.ModelTuner` {pr}`2478`. +- Add support for Ray 2.8-2.9 in {class}`scvi.autotune.ModelTuner` {pr}`2478`. #### Fixed -- Fix bug where `n_hidden` was not being passed into {class}`scvi.nn.Encoder` - in {class}`scvi.model.AmortizedLDA` {pr}`2229` -- Fix bug in {class}`scvi.module.SCANVAE` where classifier probabilities - were interpreted as logits. This is backwards compatible as loading older - models will use the old code path {pr}`2301`. -- Fix bug in {class}`scvi.external.GIMVI` where `batch_size` was not - properly used in inference methods {pr}`2366`. +- Fix bug where `n_hidden` was not being passed into {class}`scvi.nn.Encoder` in + {class}`scvi.model.AmortizedLDA` {pr}`2229` +- Fix bug in {class}`scvi.module.SCANVAE` where classifier probabilities were interpreted as + logits. This is backwards compatible as loading older models will use the old code path + {pr}`2301`. +- Fix bug in {class}`scvi.external.GIMVI` where `batch_size` was not properly used in inference + methods {pr}`2366`. - Fix error message formatting in {meth}`scvi.data.fields.LayerField.transfer_field` {pr}`2368`. - Fix ambiguous error raised in {meth}`scvi.distributions.NegativeBinomial.log_prob` and {meth}`scvi.distributions.ZeroInflatedNegativeBinomial.log_prob` when `scale` not passed in @@ -84,34 +82,30 @@ is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/ #### Changed -- Replace `sparse` with `sparse_format` argument in {meth}`scvi.data.synthetic_iid` - for increased flexibility over dataset format {pr}`2163`. 
-- Add per-group LFC information to the - {meth}`scvi.criticism.PosteriorPredictiveCheck.differential_expression` - method {pr}`2173`. `metrics["diff_exp"]` is now a dictionary where the `summary` - stores the summary dataframe, and the `lfc_per_model_per_group` key stores the - per-group LFC. -- Revalidate `devices` when automatically switching from MPS to CPU - accelerator in {func}`scvi.model._utils.parse_device_args` {pr}`2247`. -- Refactor {class}`scvi.data.AnnTorchDataset`, now loads continuous data as - {class}`numpy.float32` and categorical data as {class}`numpy.int64` by - default {pr}`2250`. +- Replace `sparse` with `sparse_format` argument in {meth}`scvi.data.synthetic_iid` for increased + flexibility over dataset format {pr}`2163`. +- Revalidate `devices` when automatically switching from MPS to CPU accelerator in + {func}`scvi.model._utils.parse_device_args` {pr}`2247`. +- Refactor {class}`scvi.data.AnnTorchDataset`, now loads continuous data as {class}`numpy.float32` + and categorical data as {class}`numpy.int64` by default {pr}`2250`. - Support fractional GPU usage in {class}`scvi.autotune.ModelTuner` `pr`{2252}. -- Tensorboard is now the default logger in {class}`scvi.autotune.ModelTuner` - `pr`{2260}. -- Match `momentum` and `epsilon` in {class}`scvi.module.JaxVAE` to the - default values in PyTorch {pr}`2309`. +- Tensorboard is now the default logger in {class}`scvi.autotune.ModelTuner` `pr`{2260}. +- Match `momentum` and `epsilon` in {class}`scvi.module.JaxVAE` to the default values in PyTorch + {pr}`2309`. - Change {class}`scvi.train.SemiSupervisedTrainingPlan` and {class}`scvi.train.ClassifierTrainingPlan` accuracy and F1 score computations to use `"micro"` reduction rather than `"macro"` {pr}`2339`. - Internal refactoring of {meth}`scvi.module.VAE.sample` and {meth}`scvi.model.base.RNASeqMixin.posterior_predictive_sample` {pr}`2377`. - Change `xarray` and `sparse` from mandatory to optional dependencies {pr}`2480`. +- Use {class}`anndata.experimental.CSCDataset` and {class}`anndata.experimental.CSRDataset` + instead of the deprecated {class}`anndata._core.sparse_dataset.SparseDataset` for type checks + {pr}`2485`. #### Removed -- Remove deprecated `use_gpu argument in favor of PyTorch Lightning arguments -`accelerator`and`devices` {pr}`2114`. +- Remove deprecated `use_gpu` argument in favor of PyTorch Lightning arguments `accelerator` and + `devices` {pr}`2114`. - Remove deprecated `scvi._compat.Literal` class {pr}`2115`. - Remove chex dependency {pr}`2482`. 
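The compat pattern adopted in the two modules below, as a self-contained sketch; rebinding the old `SparseDataset` name to a tuple keeps existing call sites untouched, because `isinstance` accepts either a single type or a tuple of types:

```python
try:
    # anndata < 0.10.0: private class, since deprecated upstream
    from anndata._core.sparse_dataset import SparseDataset
except ImportError:
    # anndata >= 0.10.0: replaced by the public CSRDataset/CSCDataset classes
    from anndata.experimental import CSCDataset, CSRDataset

    SparseDataset = (CSRDataset, CSCDataset)

# Call sites need no change either way:
#     isinstance(X, SparseDataset)
```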
diff --git a/scvi/data/_anntorchdataset.py b/scvi/data/_anntorchdataset.py index 5abd7f514d..124ca2ef09 100644 --- a/scvi/data/_anntorchdataset.py +++ b/scvi/data/_anntorchdataset.py @@ -12,9 +12,9 @@ from anndata._core.sparse_dataset import SparseDataset except ImportError: # anndata >= 0.10.0 - from anndata._core.sparse_dataset import ( - BaseCompressedSparseDataset as SparseDataset, - ) + from anndata.experimental import CSCDataset, CSRDataset + + SparseDataset = (CSRDataset, CSCDataset) from scipy.sparse import issparse from torch.utils.data import Dataset diff --git a/scvi/data/_utils.py b/scvi/data/_utils.py index 71bf342a4a..d52dd6000f 100644 --- a/scvi/data/_utils.py +++ b/scvi/data/_utils.py @@ -16,9 +16,9 @@ from anndata._core.sparse_dataset import SparseDataset except ImportError: # anndata >= 0.10.0 - from anndata._core.sparse_dataset import ( - BaseCompressedSparseDataset as SparseDataset, - ) + from anndata.experimental import CSCDataset, CSRDataset + + SparseDataset = (CSRDataset, CSCDataset) # TODO use the experimental api once we lower bound to anndata 0.8 try: From a773087f6f3586763a3ef36e601b47b7c29ea955 Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Fri, 9 Feb 2024 11:24:58 -0800 Subject: [PATCH 10/21] Remove default criticism import (#2491) --- scvi/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scvi/__init__.py b/scvi/__init__.py index 2471b3df17..ab29410fe0 100644 --- a/scvi/__init__.py +++ b/scvi/__init__.py @@ -8,7 +8,7 @@ from ._settings import settings # this import needs to come after prior imports to prevent circular import -from . import data, model, external, utils, criticism +from . import data, model, external, utils from importlib.metadata import version From 632e0a11e9ab7df2cf884d077639c9135e14f666 Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Fri, 9 Feb 2024 18:27:59 -0800 Subject: [PATCH 11/21] Make `use_observed_lib_size` arg adjustable in `LDVAE` (#2494) Closes #2493 --- docs/release_notes/index.md | 1 + scvi/module/_vae.py | 15 ++++++++++++--- tests/model/test_linear_scvi.py | 10 ++++++++++ 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/docs/release_notes/index.md b/docs/release_notes/index.md index fbe645fd98..063426243f 100644 --- a/docs/release_notes/index.md +++ b/docs/release_notes/index.md @@ -101,6 +101,7 @@ is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/ - Use {class}`anndata.experimental.CSCDataset` and {class}`anndata.experimental.CSRDataset` instead of the deprecated {class}`anndata._core.sparse_dataset.SparseDataset` for type checks {pr}`2485`. +- Make `use_observed_lib_size` argument adjustable in {class}`scvi.module.LDVAE` `pr`{2494}. #### Removed diff --git a/scvi/module/_vae.py b/scvi/module/_vae.py index 0b42c45292..6b2eb8c7c2 100644 --- a/scvi/module/_vae.py +++ b/scvi/module/_vae.py @@ -669,6 +669,14 @@ class LDVAE(VAE): Bool whether to use batch norm in decoder bias Bool whether to have bias term in linear decoder + latent_distribution + One of + + * ``'normal'`` - Isotropic normal + * ``'ln'`` - Logistic normal with normal params N(0, 1) + use_observed_lib_size + Use observed library size for RNA as scaling factor in mean of conditional distribution. 
+ **kwargs """ def __init__( @@ -686,7 +694,8 @@ def __init__( use_batch_norm: bool = True, bias: bool = False, latent_distribution: str = "normal", - **vae_kwargs, + use_observed_lib_size: bool = False, + **kwargs, ): super().__init__( n_input=n_input, @@ -700,8 +709,8 @@ def __init__( log_variational=log_variational, gene_likelihood=gene_likelihood, latent_distribution=latent_distribution, - use_observed_lib_size=False, - **vae_kwargs, + use_observed_lib_size=use_observed_lib_size, + **kwargs, ) self.use_batch_norm = use_batch_norm self.z_encoder = Encoder( diff --git a/tests/model/test_linear_scvi.py b/tests/model/test_linear_scvi.py index 48df69c158..141acfb007 100644 --- a/tests/model/test_linear_scvi.py +++ b/tests/model/test_linear_scvi.py @@ -123,3 +123,13 @@ def test_linear_scvi(): model.get_loadings() model.differential_expression(groupby="labels", group1="label_1") model.differential_expression(groupby="labels", group1="label_1", group2="label_2") + + +def test_linear_scvi_use_observed_lib_size(): + adata = synthetic_iid() + LinearSCVI.setup_anndata(adata) + model = LinearSCVI(adata, n_latent=10, use_observed_lib_size=True) + model.train(max_epochs=1) + model.get_loadings() + model.differential_expression(groupby="labels", group1="label_1") + model.differential_expression(groupby="labels", group1="label_1", group2="label_2") From 3322b04062cf9f81a4555785fdcdcc78b91688ec Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Sat, 10 Feb 2024 21:26:04 -0800 Subject: [PATCH 12/21] Increase line length to 99 (#2497) --- pyproject.toml | 2 +- scvi/_settings.py | 4 +- scvi/autotune/_manager.py | 40 ++----- scvi/criticism/_ppc.py | 59 +++------- scvi/data/_anntorchdataset.py | 11 +- scvi/data/_built_in_data/_brain_large.py | 22 ++-- scvi/data/_built_in_data/_cite_seq.py | 14 +-- scvi/data/_built_in_data/_csv.py | 8 +- scvi/data/_built_in_data/_loom.py | 4 +- scvi/data/_built_in_data/_pbmc.py | 4 +- scvi/data/_built_in_data/_smfish.py | 8 +- scvi/data/_compat.py | 31 ++--- scvi/data/_download.py | 8 +- scvi/data/_manager.py | 31 ++--- scvi/data/_preprocessing.py | 32 ++---- scvi/data/_read.py | 6 +- scvi/data/_utils.py | 23 +--- scvi/data/fields/_arraylike_field.py | 31 ++--- scvi/data/fields/_base_field.py | 8 +- scvi/data/fields/_dataframe_field.py | 16 +-- scvi/data/fields/_layer_field.py | 8 +- scvi/data/fields/_mudata.py | 12 +- scvi/data/fields/_protein.py | 12 +- scvi/data/fields/_scanvi.py | 4 +- scvi/data/fields/_uns_field.py | 4 +- scvi/dataloaders/_concat_dataloader.py | 4 +- scvi/dataloaders/_data_splitting.py | 20 +--- scvi/dataloaders/_semi_dataloader.py | 8 +- scvi/distributions/_negative_binomial.py | 24 +--- scvi/distributions/_utils.py | 12 +- scvi/external/cellassign/_model.py | 28 ++--- scvi/external/cellassign/_module.py | 15 +-- .../_contrastive_data_splitting.py | 8 +- .../contrastivevi/_contrastive_dataloader.py | 4 +- scvi/external/contrastivevi/_model.py | 40 ++----- scvi/external/contrastivevi/_module.py | 27 ++--- scvi/external/gimvi/_model.py | 40 ++----- scvi/external/gimvi/_module.py | 22 ++-- scvi/external/gimvi/_task.py | 3 +- scvi/external/gimvi/_utils.py | 32 ++---- scvi/external/poissonvi/_model.py | 32 ++---- scvi/external/scar/_model.py | 32 ++---- scvi/external/scar/_module.py | 4 +- scvi/external/scbasset/_model.py | 22 ++-- scvi/external/scbasset/_module.py | 4 +- scvi/external/solo/_model.py | 24 +--- scvi/external/stereoscope/_model.py | 16 +-- scvi/external/stereoscope/_module.py | 12 +- 
scvi/external/tangram/_model.py | 36 ++---- scvi/hub/_metadata.py | 16 +-- scvi/hub/_model.py | 4 +- scvi/hub/_url.py | 4 +- scvi/model/_amortizedlda.py | 11 +- scvi/model/_autozi.py | 12 +- scvi/model/_condscvi.py | 8 +- scvi/model/_destvi.py | 16 +-- scvi/model/_jaxscvi.py | 4 +- scvi/model/_linear_scvi.py | 12 +- scvi/model/_metrics.py | 4 +- scvi/model/_multivi.py | 68 +++-------- scvi/model/_peakvi.py | 28 ++--- scvi/model/_scanvi.py | 60 +++------- scvi/model/_scvi.py | 36 ++---- scvi/model/_totalvi.py | 92 ++++----------- scvi/model/_utils.py | 20 +--- scvi/model/base/_archesmixin.py | 8 +- scvi/model/base/_base_model.py | 52 +++------ scvi/model/base/_differential.py | 31 ++--- scvi/model/base/_jaxmixin.py | 4 +- scvi/model/base/_pyromixin.py | 34 ++---- scvi/model/base/_rnamixin.py | 44 ++----- scvi/model/base/_training_mixin.py | 4 +- scvi/model/base/_utils.py | 16 +-- scvi/model/base/_vaemixin.py | 12 +- scvi/model/utils/_mde.py | 4 +- scvi/module/_amortizedlda.py | 14 +-- scvi/module/_autozivae.py | 84 ++++---------- scvi/module/_jaxvae.py | 4 +- scvi/module/_mrdeconv.py | 24 +--- scvi/module/_multivae.py | 108 ++++++------------ scvi/module/_peakvae.py | 12 +- scvi/module/_scanvae.py | 28 ++--- scvi/module/_totalvae.py | 48 ++------ scvi/module/_utils.py | 4 +- scvi/module/_vae.py | 48 ++------ scvi/module/_vaec.py | 12 +- scvi/module/base/_base_module.py | 31 ++--- scvi/module/base/_decorators.py | 9 +- scvi/nn/_base_components.py | 28 ++--- scvi/train/_callbacks.py | 3 +- scvi/train/_logger.py | 4 +- scvi/train/_metrics.py | 4 +- scvi/train/_trainer.py | 4 +- scvi/train/_trainingplans.py | 52 +++------ scvi/utils/_dependencies.py | 4 +- tests/criticism/test_criticism.py | 4 +- tests/data/test_anndata.py | 26 +---- tests/data/test_anntorchdataset.py | 4 +- tests/data/test_dataset10X.py | 4 +- tests/data/test_mudata.py | 31 ++--- tests/data/utils.py | 20 +--- tests/dataloaders/sparse_utils.py | 8 +- tests/dataloaders/test_dataloaders.py | 4 +- tests/dataloaders/test_samplers.py | 4 +- tests/distributions/test_negative_binomial.py | 4 +- .../test_contrastive_dataloaders.py | 16 +-- .../contrastivevi/test_contrastivevae.py | 46 +++----- .../contrastivevi/test_contrastivevi.py | 29 ++--- tests/external/gimvi/test_gimvi.py | 20 +--- tests/external/scbasset/test_scbasset.py | 12 +- tests/external/solo/test_solo.py | 4 +- .../external/stereoscope/test_stereoscope.py | 4 +- tests/external/tangram/test_tangram.py | 4 +- tests/hub/test_hub_metadata.py | 15 +-- tests/hub/test_hub_model.py | 16 +-- tests/model/base/test_base_model.py | 8 +- tests/model/test_amortizedlda.py | 8 +- tests/model/test_autozi.py | 8 +- tests/model/test_differential.py | 16 +-- tests/model/test_jaxscvi.py | 8 +- tests/model/test_linear_scvi.py | 4 +- tests/model/test_models_with_minified_data.py | 32 ++---- tests/model/test_peakvi.py | 4 +- tests/model/test_pyro.py | 56 ++++----- tests/model/test_scanvi.py | 36 ++---- tests/model/test_scvi.py | 36 ++---- tests/model/test_totalvi.py | 32 ++---- tests/train/test_trainingplans.py | 16 +-- 128 files changed, 676 insertions(+), 1863 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b050a2ad79..b7c6074c98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -138,7 +138,7 @@ xfail_strict = true [tool.ruff] src = ["."] -line-length = 89 +line-length = 99 indent-width = 4 target-version = "py39" diff --git a/scvi/_settings.py b/scvi/_settings.py index 3539c31cb1..2644737f55 100644 --- a/scvi/_settings.py +++ b/scvi/_settings.py @@ -164,9 +164,7 @@ def 
verbosity(self, level: Union[str, int]): console = Console(force_terminal=True) if console.is_jupyter is True: console.is_jupyter = False - ch = RichHandler( - level=level, show_path=False, console=console, show_time=False - ) + ch = RichHandler(level=level, show_path=False, console=console, show_time=False) formatter = logging.Formatter("%(message)s") ch.setFormatter(formatter) scvi_logger.addHandler(ch) diff --git a/scvi/autotune/_manager.py b/scvi/autotune/_manager.py index 898d523cb2..f9c5a29359 100644 --- a/scvi/autotune/_manager.py +++ b/scvi/autotune/_manager.py @@ -112,9 +112,7 @@ def _parse_func_params(func: Callable, parent: Any, tunable_type: str) -> dict: tunables = {} for param, metadata in inspect.signature(func).parameters.items(): cond1 = isinstance(metadata.annotation, TunableMeta) - cond2 = "Tunable" in str( - metadata.annotation - ) # needed for new type annotation + cond2 = "Tunable" in str(metadata.annotation) # needed for new type annotation if not cond1 and not cond2: continue @@ -133,9 +131,7 @@ def _parse_func_params(func: Callable, parent: Any, tunable_type: str) -> dict: annotation = str(annotation) else: annotation = metadata.annotation - annotation = annotation[ - annotation.find("[") + 1 : annotation.rfind("]") - ] + annotation = annotation[annotation.find("[") + 1 : annotation.rfind("]")] tunables[param] = { "tunable_type": tunable_type, @@ -145,17 +141,13 @@ def _parse_func_params(func: Callable, parent: Any, tunable_type: str) -> dict: } return tunables - def _get_tunables( - attr: Any, parent: Any = None, tunable_type: str | None = None - ) -> dict: + def _get_tunables(attr: Any, parent: Any = None, tunable_type: str | None = None) -> dict: tunables = {} if inspect.isfunction(attr): return _parse_func_params(attr, parent, tunable_type) for child in getattr(attr, "_tunables", []): tunables.update( - _get_tunables( - child, parent=attr, tunable_type=_cls_to_tunable_type(attr) - ) + _get_tunables(child, parent=attr, tunable_type=_cls_to_tunable_type(attr)) ) return tunables @@ -210,17 +202,13 @@ def _validate_search_space(self, search_space: dict, use_defaults: bool) -> dict _search_space[param] = sample_fn(*fn_args, **fn_kwargs) # exclude defaults if requested - logger.info( - f"Merging search space with defaults for {self._model_cls.__name__}." 
- ) + logger.info(f"Merging search space with defaults for {self._model_cls.__name__}.") # priority given to user-provided search space _search_space.update(search_space) return _search_space - def _validate_metrics( - self, metric: str, additional_metrics: list[str] - ) -> OrderedDict: + def _validate_metrics(self, metric: str, additional_metrics: list[str]) -> OrderedDict: """Validates a set of metrics against the metric registry.""" registry_metrics = self._registry["metrics"] _metrics = OrderedDict() @@ -354,15 +342,11 @@ def _validate_scheduler_and_search_algorithm( ) _scheduler = self._validate_scheduler(scheduler, metrics, scheduler_kwargs) - _searcher = self._validate_search_algorithm( - searcher, metrics, searcher_kwargs, seed - ) + _searcher = self._validate_search_algorithm(searcher, metrics, searcher_kwargs, seed) return _scheduler, _searcher @staticmethod - def _validate_reporter( - reporter: bool, search_space: dict, metrics: OrderedDict - ) -> Any: + def _validate_reporter(reporter: bool, search_space: dict, metrics: OrderedDict) -> Any: """Validates a reporter depending on the execution environment.""" _metric_keys = list(metrics.keys()) _param_keys = list(search_space.keys()) @@ -441,9 +425,7 @@ def _trainable( model = model_cls(adata, **model_kwargs) # This is to get around lightning import changes - callback_cls = type( - "_TuneReportCallback", (TuneReportCallback, pl.Callback), {} - ) + callback_cls = type("_TuneReportCallback", (TuneReportCallback, pl.Callback), {}) callbacks = [callback_cls(metric, on="validation_end")] logs_dir = os.path.join(logging_dir, experiment_name) @@ -522,9 +504,7 @@ def _get_tuner( logging_dir: str | None = None, monitor_device_stats: bool = False, ) -> tuple[Any, dict]: - metric = ( - metric or self._get_primary_metric_and_mode(self._registry["metrics"])[0] - ) + metric = metric or self._get_primary_metric_and_mode(self._registry["metrics"])[0] additional_metrics = additional_metrics or [] search_space = search_space or {} model_kwargs = model_kwargs or {} diff --git a/scvi/criticism/_ppc.py b/scvi/criticism/_ppc.py index 4ab6bcdbab..7851a10200 100644 --- a/scvi/criticism/_ppc.py +++ b/scvi/criticism/_ppc.py @@ -38,9 +38,7 @@ def _make_dataset_dense(dataset: Dataset) -> Dataset: """Make a dataset dense, converting sparse arrays to dense arrays.""" - dataset = dataset.map( - lambda x: x.data.todense() if isinstance(x.data, SparseArray) else x - ) + dataset = dataset.map(lambda x: x.data.todense() if isinstance(x.data, SparseArray) else x) return dataset @@ -82,16 +80,12 @@ def __init__( adata = adata[indices] self.adata = adata self.count_layer_key = count_layer_key - raw_counts = ( - adata.layers[count_layer_key] if count_layer_key is not None else adata.X - ) + raw_counts = adata.layers[count_layer_key] if count_layer_key is not None else adata.X # Compressed axis is rows, like csr if isinstance(raw_counts, np.ndarray): self.raw_counts = GCXS.from_numpy(raw_counts, compressed_axes=(0,)) elif issparse(raw_counts): - self.raw_counts = GCXS.from_scipy_sparse(raw_counts).change_compressed_axes( - (0,) - ) + self.raw_counts = GCXS.from_scipy_sparse(raw_counts).change_compressed_axes((0,)) else: raise ValueError("raw_counts must be a numpy array or scipy sparse matrix") self.samples_dataset = None @@ -197,17 +191,11 @@ def coefficient_of_variation(self, dim: Dims = "cells") -> None: def zero_fraction(self) -> None: """Fraction of zeros in raw counts for a specific gene""" pp_samples = self.samples_dataset - mean = ( - (pp_samples != 0) - 
.mean(dim="cells", skipna=False) - .mean(dim="samples", skipna=False) - ) + mean = (pp_samples != 0).mean(dim="cells", skipna=False).mean(dim="samples", skipna=False) mean = _make_dataset_dense(mean) self.metrics[METRIC_ZERO_FRACTION] = mean.to_dataframe() - def calibration_error( - self, confidence_intervals: list[float] | float = None - ) -> None: + def calibration_error(self, confidence_intervals: list[float] | float = None) -> None: """Calibration error for each observed count. For a series of credible intervals of the samples, the fraction of observed counts that fall @@ -249,12 +237,8 @@ def calibration_error( start = interval[0] end = interval[1] true_width = ps[end] - ps[start] - greater_than = ( - quants[DATA_VAR_RAW] >= quants.model1.isel(quantile=start) - ).data - less_than = ( - quants[DATA_VAR_RAW] <= quants.model1.isel(quantile=end) - ).data + greater_than = (quants[DATA_VAR_RAW] >= quants.model1.isel(quantile=start)).data + less_than = (quants[DATA_VAR_RAW] <= quants.model1.isel(quantile=end)).data # Logical and ci = greater_than * less_than pci_features = ci.mean() @@ -345,9 +329,7 @@ def differential_expression( # run DE with the imputed normalized data with warnings.catch_warnings(): - warnings.simplefilter( - action="ignore", category=pd.errors.PerformanceWarning - ) + warnings.simplefilter(action="ignore", category=pd.errors.PerformanceWarning) key_added = f"{UNS_NAME_RGG_PPC}_{model}_sample_{k}" de_keys[model].append(key_added) sc.tl.rank_genes_groups( @@ -376,9 +358,7 @@ def differential_expression( self.metrics[METRIC_DIFF_EXP] = {} self.metrics[METRIC_DIFF_EXP]["lfc_per_model_per_group"] = {} for g in groups: - raw_group_data = sc.get.rank_genes_groups_df( - adata_de, group=g, key=UNS_NAME_RGG_RAW - ) + raw_group_data = sc.get.rank_genes_groups_df(adata_de, group=g, key=UNS_NAME_RGG_RAW) raw_group_data.set_index("names", inplace=True) for model in de_keys.keys(): gene_overlap_f1s = [] @@ -399,15 +379,11 @@ def differential_expression( all_genes = raw_group_data.index # order doesn't matter here top_genes_raw = raw_group_data[:n_top_genes_fallback].index top_genes_sample = sample_group_data[:n_top_genes_fallback].index - true_genes = np.array( - [0 if g not in top_genes_raw else 1 for g in all_genes] - ) + true_genes = np.array([0 if g not in top_genes_raw else 1 for g in all_genes]) pred_genes = np.array( [0 if g not in top_genes_sample else 1 for g in all_genes] ) - gene_overlap_f1s.append( - _get_precision_recall_f1(true_genes, pred_genes)[2] - ) + gene_overlap_f1s.append(_get_precision_recall_f1(true_genes, pred_genes)[2]) # compute lfc correlations sample_group_data = sample_group_data.loc[raw_group_data.index] rgd, sgd = ( @@ -442,16 +418,11 @@ def differential_expression( pd.DataFrame(rgds).mean(axis=0), pd.DataFrame(sgds).mean(axis=0), ) - if ( - model - not in self.metrics[METRIC_DIFF_EXP][ - "lfc_per_model_per_group" - ].keys() - ): + if model not in self.metrics[METRIC_DIFF_EXP]["lfc_per_model_per_group"].keys(): self.metrics[METRIC_DIFF_EXP]["lfc_per_model_per_group"][model] = {} - self.metrics[METRIC_DIFF_EXP]["lfc_per_model_per_group"][model][ - g - ] = pd.DataFrame([rgd, sgd], index=["raw", "approx"]).T + self.metrics[METRIC_DIFF_EXP]["lfc_per_model_per_group"][model][g] = pd.DataFrame( + [rgd, sgd], index=["raw", "approx"] + ).T i += 1 self.metrics[METRIC_DIFF_EXP]["summary"] = df diff --git a/scvi/data/_anntorchdataset.py b/scvi/data/_anntorchdataset.py index 124ca2ef09..350cdaf135 100644 --- a/scvi/data/_anntorchdataset.py +++ 
b/scvi/data/_anntorchdataset.py @@ -61,9 +61,7 @@ def __init__( super().__init__() if adata_manager.adata is None: - raise ValueError( - "Please run ``register_fields`` on ``adata_manager`` first." - ) + raise ValueError("Please run ``register_fields`` on ``adata_manager`` first.") self.adata_manager = adata_manager self.keys_and_dtypes = getitem_tensors self.load_sparse_tensor = load_sparse_tensor @@ -85,9 +83,7 @@ def keys_and_dtypes(self, getitem_tensors: list | dict[str, type] | None): Raises an error if any of the keys are not in the data registry. """ if isinstance(getitem_tensors, list): - keys_to_dtypes = { - key: registry_key_to_default_dtype(key) for key in getitem_tensors - } + keys_to_dtypes = {key: registry_key_to_default_dtype(key) for key in getitem_tensors} elif isinstance(getitem_tensors, dict): keys_to_dtypes = getitem_tensors elif getitem_tensors is None: @@ -117,8 +113,7 @@ def data(self): """ if not hasattr(self, "_data"): self._data = { - key: self.adata_manager.get_from_registry(key) - for key in self.keys_and_dtypes + key: self.adata_manager.get_from_registry(key) for key in self.keys_and_dtypes } return self._data diff --git a/scvi/data/_built_in_data/_brain_large.py b/scvi/data/_built_in_data/_brain_large.py index 3d710a4fc9..fe75eb536c 100644 --- a/scvi/data/_built_in_data/_brain_large.py +++ b/scvi/data/_built_in_data/_brain_large.py @@ -45,9 +45,7 @@ def _load_brainlarge_file( with h5py.File(path_to_file, "r") as f: data = f["mm10"] nb_genes, nb_cells = f["mm10"]["shape"] - n_cells_to_keep = ( - max_cells_to_keep if max_cells_to_keep is not None else nb_cells - ) + n_cells_to_keep = max_cells_to_keep if max_cells_to_keep is not None else nb_cells index_partitioner = data["indptr"][...] # estimate gene variance using a subset of cells. 
index_partitioner_gene_var = index_partitioner[: (sample_size_gene_var + 1)] @@ -61,9 +59,9 @@ def _load_brainlarge_file( shape=(nb_genes, len(index_partitioner_gene_var) - 1), ) mean = gene_var_sample_matrix.mean(axis=1) - var = gene_var_sample_matrix.multiply(gene_var_sample_matrix).mean( - axis=1 - ) - np.multiply(mean, mean) + var = gene_var_sample_matrix.multiply(gene_var_sample_matrix).mean(axis=1) - np.multiply( + mean, mean + ) subset_genes = np.squeeze(np.asarray(var)).argsort()[-n_genes_to_keep:][::-1] del gene_var_sample_matrix, mean, var @@ -76,16 +74,12 @@ def _load_brainlarge_file( ] first_index_batch = index_partitioner_batch[0] last_index_batch = index_partitioner_batch[-1] - index_partitioner_batch = ( - index_partitioner_batch - first_index_batch - ).astype(np.int32) - n_cells_batch = len(index_partitioner_batch) - 1 - data_batch = data["data"][first_index_batch:last_index_batch].astype( - np.float32 - ) - indices_batch = data["indices"][first_index_batch:last_index_batch].astype( + index_partitioner_batch = (index_partitioner_batch - first_index_batch).astype( np.int32 ) + n_cells_batch = len(index_partitioner_batch) - 1 + data_batch = data["data"][first_index_batch:last_index_batch].astype(np.float32) + indices_batch = data["indices"][first_index_batch:last_index_batch].astype(np.int32) matrix_batch = sp_sparse.csr_matrix( (data_batch, indices_batch, index_partitioner_batch), shape=(n_cells_batch, nb_genes), diff --git a/scvi/data/_built_in_data/_cite_seq.py b/scvi/data/_built_in_data/_cite_seq.py index 11fd594666..e591ff750a 100644 --- a/scvi/data/_built_in_data/_cite_seq.py +++ b/scvi/data/_built_in_data/_cite_seq.py @@ -157,16 +157,10 @@ def _load_pbmc_seurat_v4_cite_seq( if apply_filters: adata.obs["total_counts"] = np.ravel(adata.X.sum(axis=1).A) adata.var["mt"] = adata.var_names.str.startswith("MT-") - adata.obs["total_counts_mt"] = np.ravel( - adata.X[:, adata.var["mt"].values].sum(axis=1).A - ) - adata.obs["pct_counts_mt"] = ( - adata.obs["total_counts_mt"] / adata.obs["total_counts"] * 100 - ) - - adata.obs["Protein log library size"] = np.log( - adata.obsm["protein_counts"].sum(1) - ) + adata.obs["total_counts_mt"] = np.ravel(adata.X[:, adata.var["mt"].values].sum(axis=1).A) + adata.obs["pct_counts_mt"] = adata.obs["total_counts_mt"] / adata.obs["total_counts"] * 100 + + adata.obs["Protein log library size"] = np.log(adata.obsm["protein_counts"].sum(1)) adata.obs["Number proteins detected"] = (adata.obsm["protein_counts"] > 0).sum(1) adata.obs["RNA log library size"] = np.log(adata.X.sum(1).A) diff --git a/scvi/data/_built_in_data/_csv.py b/scvi/data/_built_in_data/_csv.py index e93cffddc1..7576cb7b17 100644 --- a/scvi/data/_built_in_data/_csv.py +++ b/scvi/data/_built_in_data/_csv.py @@ -14,9 +14,7 @@ def _load_breast_cancer_dataset(save_path: str = "data/"): url = "https://www.spatialresearch.org/wp-content/uploads/2016/07/Layer2_BC_count_matrix-1.tsv" save_fn = "Layer2_BC_count_matrix-1.tsv" _download(url, save_path, save_fn) - adata = _load_csv( - os.path.join(save_path, save_fn), delimiter="\t", gene_by_cell=False - ) + adata = _load_csv(os.path.join(save_path, save_fn), delimiter="\t", gene_by_cell=False) adata.obs["batch"] = np.zeros(adata.shape[0]).astype(int) adata.obs["labels"] = np.zeros(adata.shape[0]).astype(int) @@ -28,9 +26,7 @@ def _load_mouse_ob_dataset(save_path: str = "data/"): url = "https://www.spatialresearch.org/wp-content/uploads/2016/07/Rep11_MOB_count_matrix-1.tsv" save_fn = "Rep11_MOB_count_matrix-1.tsv" _download(url, save_path, 
save_fn) - adata = _load_csv( - os.path.join(save_path, save_fn), delimiter="\t", gene_by_cell=False - ) + adata = _load_csv(os.path.join(save_path, save_fn), delimiter="\t", gene_by_cell=False) adata.obs["batch"] = np.zeros(adata.shape[0]).astype(int) adata.obs["labels"] = np.zeros(adata.shape[0]).astype(int) diff --git a/scvi/data/_built_in_data/_loom.py b/scvi/data/_built_in_data/_loom.py index 8250082171..725013f589 100644 --- a/scvi/data/_built_in_data/_loom.py +++ b/scvi/data/_built_in_data/_loom.py @@ -121,9 +121,7 @@ def _load_loom(path_to_file: str, gene_names_attribute_name: str = "Gene") -> An dataset = loompy.connect(path_to_file) select = dataset[:, :].sum(axis=0) > 0 # Take out cells that don't express any gene if not all(select): - warnings.warn( - "Removing empty cells", UserWarning, stacklevel=settings.warnings_stacklevel - ) + warnings.warn("Removing empty cells", UserWarning, stacklevel=settings.warnings_stacklevel) var_dict, obs_dict, uns_dict, obsm_dict = {}, {}, {}, {} for row_key in dataset.ra: diff --git a/scvi/data/_built_in_data/_pbmc.py b/scvi/data/_built_in_data/_pbmc.py index ea1ac715c8..8ef297d6b8 100644 --- a/scvi/data/_built_in_data/_pbmc.py +++ b/scvi/data/_built_in_data/_pbmc.py @@ -79,9 +79,7 @@ def _load_pbmc_dataset( subset_cells = [] barcodes_metadata = pbmc_metadata["barcodes"].index.values.ravel().astype(str) for barcode in barcodes_metadata: - if ( - barcode in dict_barcodes - ): # barcodes with end -11 filtered on 10X website (49 cells) + if barcode in dict_barcodes: # barcodes with end -11 filtered on 10X website (49 cells) subset_cells += [dict_barcodes[barcode]] adata = adata[np.asarray(subset_cells), :].copy() idx_metadata = np.asarray( diff --git a/scvi/data/_built_in_data/_smfish.py b/scvi/data/_built_in_data/_smfish.py index 0e4b02d6a8..a60e5fd36f 100644 --- a/scvi/data/_built_in_data/_smfish.py +++ b/scvi/data/_built_in_data/_smfish.py @@ -58,9 +58,7 @@ def _load_smfish( return adata -def _load_smfish_data( - path_to_file: str, use_high_level_cluster: bool -) -> anndata.AnnData: +def _load_smfish_data(path_to_file: str, use_high_level_cluster: bool) -> anndata.AnnData: import loompy logger.info("Loading smFISH dataset") @@ -86,9 +84,7 @@ def _load_smfish_data( "Pyramidals", ] row_indices = [ - i - for i in range(data.shape[0]) - if ds.ca["ClusterName"][i] in cell_types_to_keep + i for i in range(data.shape[0]) if ds.ca["ClusterName"][i] in cell_types_to_keep ] str_labels = str_labels[row_indices] data = data[row_indices, :] diff --git a/scvi/data/_compat.py b/scvi/data/_compat.py index 83ec020935..3560d9d4bc 100644 --- a/scvi/data/_compat.py +++ b/scvi/data/_compat.py @@ -26,9 +26,7 @@ } -def _infer_setup_args( - model_cls, setup_dict: dict, unlabeled_category: Optional[str] -) -> dict: +def _infer_setup_args(model_cls, setup_dict: dict, unlabeled_category: Optional[str]) -> dict: setup_args = {} data_registry = setup_dict[_constants._DATA_REGISTRY_KEY] categorical_mappings = setup_dict["categorical_mappings"] @@ -127,20 +125,17 @@ def registry_from_setup_dict( elif attr_name == _constants._ADATA_ATTRS.OBS: categorical_mapping = categorical_mappings[attr_key] # Default labels field for TOTALVI - if ( - model_cls.__name__ == "TOTALVI" - and new_registry_key == REGISTRY_KEYS.LABELS_KEY - ): - field_state_registry[ - CategoricalObsField.CATEGORICAL_MAPPING_KEY - ] = np.zeros(1, dtype=np.int64) + if model_cls.__name__ == "TOTALVI" and new_registry_key == REGISTRY_KEYS.LABELS_KEY: + 
field_state_registry[CategoricalObsField.CATEGORICAL_MAPPING_KEY] = np.zeros( + 1, dtype=np.int64 + ) else: field_state_registry[ CategoricalObsField.CATEGORICAL_MAPPING_KEY ] = categorical_mapping["mapping"] - field_state_registry[ - CategoricalObsField.ORIGINAL_ATTR_KEY - ] = categorical_mapping["original_key"] + field_state_registry[CategoricalObsField.ORIGINAL_ATTR_KEY] = categorical_mapping[ + "original_key" + ] if new_registry_key == REGISTRY_KEYS.BATCH_KEY: field_summary_stats[f"n_{new_registry_key}"] = summary_stats["n_batch"] elif new_registry_key == REGISTRY_KEYS.LABELS_KEY: @@ -168,12 +163,10 @@ def registry_from_setup_dict( "protein_names" ].copy() if "totalvi_batch_mask" in setup_dict: - field_state_registry[ - ProteinObsmField.PROTEIN_BATCH_MASK - ] = setup_dict["totalvi_batch_mask"].copy() - field_summary_stats[f"n_{new_registry_key}"] = len( - setup_dict["protein_names"] - ) + field_state_registry[ProteinObsmField.PROTEIN_BATCH_MASK] = setup_dict[ + "totalvi_batch_mask" + ].copy() + field_summary_stats[f"n_{new_registry_key}"] = len(setup_dict["protein_names"]) registry.update(_infer_setup_args(model_cls, setup_dict, unlabeled_category)) diff --git a/scvi/data/_download.py b/scvi/data/_download.py index b98468f0b5..c8a0aad532 100644 --- a/scvi/data/_download.py +++ b/scvi/data/_download.py @@ -16,9 +16,7 @@ def _download(url: Optional[str], save_path: str, filename: str): logger.info(f"File {os.path.join(save_path, filename)} already downloaded") return elif url is None: - logger.info( - f"No backup URL provided for missing file {os.path.join(save_path, filename)}" - ) + logger.info(f"No backup URL provided for missing file {os.path.join(save_path, filename)}") return req = urllib.request.Request(url, headers={"User-Agent": "Magic Browser"}) try: @@ -55,7 +53,5 @@ def read_iter(file, block_size=1000): filesize = np.rint(filesize / block_size) with open(os.path.join(save_path, filename), "wb") as f: iterator = read_iter(r, block_size=block_size) - for data in track( - iterator, style="tqdm", total=filesize, description="Downloading..." - ): + for data in track(iterator, style="tqdm", total=filesize, description="Downloading..."): f.write(data) diff --git a/scvi/data/_manager.py b/scvi/data/_manager.py index 7e3e9b309a..460777af7f 100644 --- a/scvi/data/_manager.py +++ b/scvi/data/_manager.py @@ -102,19 +102,14 @@ def __init__( def _assert_anndata_registered(self): """Asserts that an AnnData object has been registered with this instance.""" if self.adata is None: - raise AssertionError( - "AnnData object not registered. Please call register_fields." - ) + raise AssertionError("AnnData object not registered. Please call register_fields.") def _validate_anndata_object(self, adata: AnnOrMuData): """For a given AnnData object, runs general scvi-tools compatibility checks.""" if self.validation_checks.check_if_view: _check_if_view(adata, copy_if_view=False) - if ( - isinstance(adata, MuData) - and self.validation_checks.check_fully_paired_mudata - ): + if isinstance(adata, MuData) and self.validation_checks.check_fully_paired_mudata: _check_mudata_fully_paired(adata) def _get_setup_method_args(self) -> dict: @@ -166,9 +161,7 @@ def register_fields( Additional keywords which modify transfer behavior. Only applicable if ``source_registry`` is set. """ if self.adata is not None: - raise AssertionError( - "Existing AnnData object registered with this Manager instance." 
- ) + raise AssertionError("Existing AnnData object registered with this Manager instance.") if source_registry is None and transfer_kwargs: raise TypeError( @@ -214,21 +207,17 @@ def _add_field( # Transfer case: Source registry is used for validation and/or setup. if source_registry is not None: field_registry[_constants._STATE_REGISTRY_KEY] = field.transfer_field( - source_registry[_constants._FIELD_REGISTRIES_KEY][ - field.registry_key - ][_constants._STATE_REGISTRY_KEY], + source_registry[_constants._FIELD_REGISTRIES_KEY][field.registry_key][ + _constants._STATE_REGISTRY_KEY + ], adata, **transfer_kwargs, ) else: - field_registry[_constants._STATE_REGISTRY_KEY] = field.register_field( - adata - ) + field_registry[_constants._STATE_REGISTRY_KEY] = field.register_field(adata) # Compute and set summary stats for the given field. state_registry = field_registry[_constants._STATE_REGISTRY_KEY] - field_registry[_constants._SUMMARY_STATS_KEY] = field.get_summary_stats( - state_registry - ) + field_registry[_constants._SUMMARY_STATS_KEY] = field.get_summary_stats(state_registry) def register_new_fields(self, fields: list[AnnDataField]): """Register new fields to a manager instance. @@ -364,9 +353,7 @@ def create_torch_dataset( @staticmethod def _get_data_registry_from_registry(registry: dict) -> attrdict: data_registry = {} - for registry_key, field_registry in registry[ - _constants._FIELD_REGISTRIES_KEY - ].items(): + for registry_key, field_registry in registry[_constants._FIELD_REGISTRIES_KEY].items(): field_data_registry = field_registry[_constants._DATA_REGISTRY_KEY] if field_data_registry: data_registry[registry_key] = field_data_registry diff --git a/scvi/data/_preprocessing.py b/scvi/data/_preprocessing.py index f16c624060..ca4892a9c3 100644 --- a/scvi/data/_preprocessing.py +++ b/scvi/data/_preprocessing.py @@ -132,9 +132,7 @@ def poisson_gene_selection( expected_fraction_zeros = torch.zeros(scaled_means.shape, device=device) for i in range(n_batches): - total_counts_batch = total_counts[ - i * minibatch_size : (i + 1) * minibatch_size - ] + total_counts_batch = total_counts[i * minibatch_size : (i + 1) * minibatch_size] # Use einsum for outer product. expected_fraction_zeros += torch.exp( -torch.einsum("i,j->ij", [scaled_means, total_counts_batch]) @@ -218,16 +216,12 @@ def poisson_gene_selection( adata.var["highly_variable"] = df["highly_variable"].values adata.var["observed_fraction_zeros"] = df["observed_fraction_zeros"].values adata.var["expected_fraction_zeros"] = df["expected_fraction_zeros"].values - adata.var["prob_zero_enriched_nbatches"] = df[ - "prob_zero_enriched_nbatches" - ].values + adata.var["prob_zero_enriched_nbatches"] = df["prob_zero_enriched_nbatches"].values adata.var["prob_zero_enrichment"] = df["prob_zero_enrichment"].values adata.var["prob_zero_enrichment_rank"] = df["prob_zero_enrichment_rank"].values if batch_key is not None: - adata.var["prob_zero_enriched_nbatches"] = df[ - "prob_zero_enriched_nbatches" - ].values + adata.var["prob_zero_enriched_nbatches"] = df["prob_zero_enriched_nbatches"].values if subset: adata._inplace_subset_var(df["highly_variable"].values) else: @@ -236,9 +230,7 @@ def poisson_gene_selection( return df -def organize_cite_seq_10x( - adata: anndata.AnnData, copy: bool = False -) -> Optional[anndata.AnnData]: +def organize_cite_seq_10x(adata: anndata.AnnData, copy: bool = False) -> Optional[anndata.AnnData]: """Organize anndata object loaded from 10x for scvi models. 
Parameters @@ -270,9 +262,7 @@ def organize_cite_seq_10x( adata = adata.copy() pro_array = adata[:, adata.var["feature_types"] == "Antibody Capture"].X.copy().A - pro_names = np.array( - adata.var_names[adata.var["feature_types"] == "Antibody Capture"] - ) + pro_names = np.array(adata.var_names[adata.var["feature_types"] == "Antibody Capture"]) genes = (adata.var["feature_types"] != "Antibody Capture").values adata._inplace_subset_var(genes) @@ -323,9 +313,7 @@ def organize_multiome_anndatas( obs_names = list(multi_anndata.obs.index.values) def _concat_anndata(multi_anndata, other): - shared_features = np.intersect1d( - other.var.index.values, multi_anndata.var.index.values - ) + shared_features = np.intersect1d(other.var.index.values, multi_anndata.var.index.values) if not len(shared_features) > 0: raise ValueError("No shared features between Multiome and other AnnData.") @@ -346,9 +334,7 @@ def _concat_anndata(multi_anndata, other): # set .obs stuff res_anndata.obs[modality_key] = modality_ann - res_anndata.obs.index = ( - pd.Series(obs_names) + "_" + res_anndata.obs[modality_key].values - ) + res_anndata.obs.index = pd.Series(obs_names) + "_" + res_anndata.obs[modality_key].values # keep the feature order as the original order in the multiomic anndata res_anndata = res_anndata[:, multi_anndata.var.index.values] @@ -481,8 +467,6 @@ def reads_to_fragments( adata.layers[read_layer].copy() if read_layer else adata.X.copy() ) if issparse(adata.layers[fragment_layer]): - adata.layers[fragment_layer].data = np.ceil( - adata.layers[fragment_layer].data / 2 - ) + adata.layers[fragment_layer].data = np.ceil(adata.layers[fragment_layer].data / 2) else: adata.layers[fragment_layer] = np.ceil(adata.layers[fragment_layer] / 2) diff --git a/scvi/data/_read.py b/scvi/data/_read.py index 1fcd1de7fb..4947fd975d 100644 --- a/scvi/data/_read.py +++ b/scvi/data/_read.py @@ -24,11 +24,7 @@ def read_10x_atac(base_path: Union[str, Path]) -> AnnData: ) coords.rename({0: "chr", 1: "start", 2: "end"}, axis="columns", inplace=True) coords.set_index( - coords.chr.astype(str) - + ":" - + coords.start.astype(str) - + "-" - + coords.end.astype(str), + coords.chr.astype(str) + ":" + coords.start.astype(str) + "-" + coords.end.astype(str), inplace=True, ) coords.index = coords.index.astype(str) diff --git a/scvi/data/_utils.py b/scvi/data/_utils.py index d52dd6000f..48445e900f 100644 --- a/scvi/data/_utils.py +++ b/scvi/data/_utils.py @@ -90,8 +90,7 @@ def scipy_to_torch_sparse(x: ScipySparse) -> torch.Tensor: ) else: raise TypeError( - "`x` must be of type `scipy.sparse.csr_matrix` or " - "`scipy.sparse.csc_matrix`." + "`x` must be of type `scipy.sparse.csr_matrix` or " "`scipy.sparse.csc_matrix`." ) @@ -114,9 +113,7 @@ def get_anndata_attribute( else: if isinstance(adata_attr, pd.DataFrame): if attr_key not in adata_attr.columns: - raise ValueError( - f"{attr_key} is not a valid column in adata.{attr_name}." - ) + raise ValueError(f"{attr_key} is not a valid column in adata.{attr_name}.") field = adata_attr.loc[:, attr_key] else: if attr_key not in adata_attr.keys(): @@ -162,9 +159,7 @@ def _set_data_in_registry( setattr(adata, attr_name, attribute) -def _verify_and_correct_data_format( - adata: AnnData, attr_name: str, attr_key: Optional[str] -): +def _verify_and_correct_data_format(adata: AnnData, attr_name: str, attr_key: Optional[str]): """Will make sure that the user's AnnData field is C_CONTIGUOUS and csr if it is dense numpy or sparse respectively. 
Parameters @@ -178,9 +173,7 @@ def _verify_and_correct_data_format( """ data = get_anndata_attribute(adata, attr_name, attr_key) data_loc_str = ( - f"adata.{attr_name}[{attr_key}]" - if attr_key is not None - else f"adata.{attr_name}" + f"adata.{attr_name}[{attr_key}]" if attr_key is not None else f"adata.{attr_name}" ) if sp_sparse.isspmatrix(data) and (data.getformat() != "csr"): warnings.warn( @@ -192,9 +185,7 @@ def _verify_and_correct_data_format( logger.debug(f"{data_loc_str} is not C_CONTIGUOUS. Overwriting to C_CONTIGUOUS.") data = np.asarray(data, order="C") _set_data_in_registry(adata, data, attr_name, attr_key) - elif isinstance(data, pd.DataFrame) and ( - data.to_numpy().flags["C_CONTIGUOUS"] is False - ): + elif isinstance(data, pd.DataFrame) and (data.to_numpy().flags["C_CONTIGUOUS"] is False): logger.debug(f"{data_loc_str} is not C_CONTIGUOUS. Overwriting to C_CONTIGUOUS.") index = data.index vals = data.to_numpy() @@ -306,9 +297,7 @@ def _check_if_view(adata: AnnOrMuData, copy_if_view: bool = False): def _check_mudata_fully_paired(mdata: MuData): if isinstance(mdata, AnnData): - raise AssertionError( - "Cannot call ``_check_mudata_fully_paired`` with AnnData object." - ) + raise AssertionError("Cannot call ``_check_mudata_fully_paired`` with AnnData object.") for mod_key in mdata.mod: if not mdata.obsm[mod_key].all(): raise ValueError( diff --git a/scvi/data/fields/_arraylike_field.py b/scvi/data/fields/_arraylike_field.py index f4eb310a53..729f2218dc 100644 --- a/scvi/data/fields/_arraylike_field.py +++ b/scvi/data/fields/_arraylike_field.py @@ -149,9 +149,7 @@ def register_field(self, adata: AnnData) -> dict: return {self.COLUMN_NAMES_KEY: column_names} - def transfer_field( - self, state_registry: dict, adata_target: AnnData, **kwargs - ) -> dict: + def transfer_field(self, state_registry: dict, adata_target: AnnData, **kwargs) -> dict: """Transfer the field.""" super().transfer_field(state_registry, adata_target, **kwargs) self.validate_field(adata_target) @@ -163,9 +161,7 @@ def transfer_field( f"the source adata.{self.attr_name}['{self.attr_key}'] column count of {len(source_cols)}." ) - if isinstance(target_data, pd.DataFrame) and source_cols != list( - target_data.columns - ): + if isinstance(target_data, pd.DataFrame) and source_cols != list(target_data.columns): raise ValueError( f"Target adata.{self.attr_name}['{self.attr_key}'] column names do not match " f"the source adata.{self.attr_name}['{self.attr_key}'] column names." 
@@ -301,11 +297,7 @@ def register_field(self, adata: AnnData) -> dict: """Register the field.""" super().register_field(adata) self._combine_fields(adata) - return { - self.COLUMNS_KEY: getattr(adata, self.attr_name)[ - self.attr_key - ].columns.to_numpy() - } + return {self.COLUMNS_KEY: getattr(adata, self.attr_name)[self.attr_key].columns.to_numpy()} def transfer_field( self, @@ -398,10 +390,7 @@ def _make_array_categorical( self, adata: AnnData, category_dict: Optional[dict[str, list[str]]] = None ) -> dict: """Make the .obsm categorical.""" - if ( - self.attr_keys - != getattr(adata, self.attr_name)[self.attr_key].columns.tolist() - ): + if self.attr_keys != getattr(adata, self.attr_name)[self.attr_key].columns.tolist(): raise ValueError( f"Original .{self.source_attr_name} keys do not match the columns in the ", f"generated .{self.attr_name} field.", @@ -415,9 +404,7 @@ def _make_array_categorical( if category_dict is not None else None ) - mapping = _make_column_categorical( - df, key, key, categorical_dtype=categorical_dtype - ) + mapping = _make_column_categorical(df, key, key, categorical_dtype=categorical_dtype) categories[key] = mapping store_cats = categories if category_dict is None else category_dict @@ -481,9 +468,7 @@ def view_state_registry(self, state_registry: dict) -> Optional[rich.table.Table no_wrap=True, overflow="fold", ) - t.add_column( - "Categories", justify="center", style="green", no_wrap=True, overflow="fold" - ) + t.add_column("Categories", justify="center", style="green", no_wrap=True, overflow="fold") t.add_column( "scvi-tools Encoding", justify="center", @@ -494,9 +479,7 @@ def view_state_registry(self, state_registry: dict) -> Optional[rich.table.Table for key, mappings in state_registry[self.MAPPINGS_KEY].items(): for i, mapping in enumerate(mappings): if i == 0: - t.add_row( - f"adata.{self.source_attr_name}['{key}']", str(mapping), str(i) - ) + t.add_row(f"adata.{self.source_attr_name}['{key}']", str(mapping), str(i)) else: t.add_row("", str(mapping), str(i)) t.add_row("", "") diff --git a/scvi/data/fields/_base_field.py b/scvi/data/fields/_base_field.py index d6acc1f7a6..27e877bdc6 100644 --- a/scvi/data/fields/_base_field.py +++ b/scvi/data/fields/_base_field.py @@ -65,9 +65,7 @@ def register_field(self, adata: AnnOrMuData) -> dict: return {} @abstractmethod - def transfer_field( - self, state_registry: dict, adata_target: AnnOrMuData, **kwargs - ) -> dict: + def transfer_field(self, state_registry: dict, adata_target: AnnOrMuData, **kwargs) -> dict: """Takes an existing scvi-tools setup dictionary and transfers the same setup to the target AnnData. Used when one is running a pretrained model on a new AnnData object, which @@ -128,9 +126,7 @@ def get_field_data(self, adata: AnnOrMuData) -> Union[np.ndarray, pd.DataFrame]: """Returns the requested data as determined by the field for a given AnnData/MuData object.""" if self.is_empty: raise AssertionError(f"The {self.registry_key} field is empty.") - return get_anndata_attribute( - adata, self.attr_name, self.attr_key, mod_key=self.mod_key - ) + return get_anndata_attribute(adata, self.attr_name, self.attr_key, mod_key=self.mod_key) def get_data_registry(self) -> dict: """Returns a nested dictionary which describes the mapping to the data field. 
diff --git a/scvi/data/fields/_dataframe_field.py b/scvi/data/fields/_dataframe_field.py index c35770c4ea..2a56f3bc46 100644 --- a/scvi/data/fields/_dataframe_field.py +++ b/scvi/data/fields/_dataframe_field.py @@ -91,9 +91,7 @@ def register_field(self, adata: AnnData) -> dict: """Register field.""" return super().register_field(adata) - def transfer_field( - self, state_registry: dict, adata_target: AnnData, **kwargs - ) -> dict: + def transfer_field(self, state_registry: dict, adata_target: AnnData, **kwargs) -> dict: """Transfer field from registry to target AnnData.""" super().transfer_field(state_registry, adata_target, **kwargs) return self.register_field(adata_target) @@ -162,9 +160,7 @@ def _setup_default_attr(self, adata: AnnData) -> None: """Setup default attr.""" self._original_attr_key = self.attr_key length = ( - adata.shape[0] - if self._attr_name == _constants._ADATA_ATTRS.OBS - else adata.shape[1] + adata.shape[0] if self._attr_name == _constants._ADATA_ATTRS.OBS else adata.shape[1] ) getattr(adata, self.attr_name)[self.attr_key] = np.zeros(length, dtype=np.int64) @@ -176,9 +172,7 @@ def validate_field(self, adata: AnnData) -> None: """Validate field.""" super().validate_field(adata) if self._original_attr_key not in getattr(adata, self.attr_name): - raise KeyError( - f"{self._original_attr_key} not found in adata.{self.attr_name}." - ) + raise KeyError(f"{self._original_attr_key} not found in adata.{self.attr_name}.") def register_field(self, adata: AnnData) -> dict: """Register field.""" @@ -253,9 +247,7 @@ def view_state_registry(self, state_registry: dict) -> Optional[rich.table.Table no_wrap=True, overflow="fold", ) - t.add_column( - "Categories", justify="center", style="green", no_wrap=True, overflow="fold" - ) + t.add_column("Categories", justify="center", style="green", no_wrap=True, overflow="fold") t.add_column( "scvi-tools Encoding", justify="center", diff --git a/scvi/data/fields/_layer_field.py b/scvi/data/fields/_layer_field.py index fd06806f25..e726580284 100644 --- a/scvi/data/fields/_layer_field.py +++ b/scvi/data/fields/_layer_field.py @@ -51,9 +51,7 @@ def __init__( super().__init__() self._registry_key = registry_key self._attr_name = ( - _constants._ADATA_ATTRS.X - if layer is None - else _constants._ADATA_ATTRS.LAYERS + _constants._ADATA_ATTRS.X if layer is None else _constants._ADATA_ATTRS.LAYERS ) self._attr_key = layer self.is_count_data = is_count_data @@ -121,9 +119,7 @@ def register_field(self, adata: AnnData) -> dict: self.COLUMN_NAMES_KEY: np.asarray(adata.var_names), } - def transfer_field( - self, state_registry: dict, adata_target: AnnData, **kwargs - ) -> dict: + def transfer_field(self, state_registry: dict, adata_target: AnnData, **kwargs) -> dict: """Transfer the field.""" super().transfer_field(state_registry, adata_target, **kwargs) n_vars = state_registry[self.N_VARS_KEY] diff --git a/scvi/data/fields/_mudata.py b/scvi/data/fields/_mudata.py index 90b66fe2e9..609874a068 100644 --- a/scvi/data/fields/_mudata.py +++ b/scvi/data/fields/_mudata.py @@ -23,14 +23,10 @@ class BaseMuDataWrapperClass(BaseAnnDataField): If ``True``, raises ``ValueError`` when ``mod_key`` is ``None``. """ - def __init__( - self, mod_key: Optional[str] = None, mod_required: bool = False - ) -> None: + def __init__(self, mod_key: Optional[str] = None, mod_required: bool = False) -> None: super().__init__() if mod_required and mod_key is None: - raise ValueError( - f"Modality required for {self.__class__.__name__} but not provided." 
- ) + raise ValueError(f"Modality required for {self.__class__.__name__} but not provided.") self._mod_key = mod_key self._preregister = lambda _self, _mdata: None @@ -96,9 +92,7 @@ def register_field(self, mdata: MuData) -> dict: bdata = self.get_modality(mdata) return self.adata_field.register_field(bdata) - def transfer_field( - self, state_registry: dict, mdata_target: MuData, **kwargs - ) -> dict: + def transfer_field(self, state_registry: dict, mdata_target: MuData, **kwargs) -> dict: """Transfer the field.""" self.preregister(mdata_target) bdata_target = self.get_modality(mdata_target) diff --git a/scvi/data/fields/_protein.py b/scvi/data/fields/_protein.py index e61f4574a5..5000f93494 100644 --- a/scvi/data/fields/_protein.py +++ b/scvi/data/fields/_protein.py @@ -83,13 +83,9 @@ def register_field(self, adata: AnnData) -> dict: return state_registry - def transfer_field( - self, state_registry: dict, adata_target: AnnData, **kwargs - ) -> dict: + def transfer_field(self, state_registry: dict, adata_target: AnnData, **kwargs) -> dict: """Transfer the field.""" - transfer_state_registry = super().transfer_field( - state_registry, adata_target, **kwargs - ) + transfer_state_registry = super().transfer_field(state_registry, adata_target, **kwargs) batch_mask = self._get_batch_mask_protein_data(adata_target) if batch_mask is not None: transfer_state_registry[self.PROTEIN_BATCH_MASK] = batch_mask @@ -167,6 +163,4 @@ def copy_over_batch_attr(self, mdata: MuData): bdata_attr[self.batch_field.attr_key] = batch_data -MuDataProteinLayerField = MuDataWrapper( - ProteinLayerField, preregister_fn=copy_over_batch_attr -) +MuDataProteinLayerField = MuDataWrapper(ProteinLayerField, preregister_fn=copy_over_batch_attr) diff --git a/scvi/data/fields/_scanvi.py b/scvi/data/fields/_scanvi.py index ec2e9bcab0..65ef6357fd 100644 --- a/scvi/data/fields/_scanvi.py +++ b/scvi/data/fields/_scanvi.py @@ -38,9 +38,7 @@ def __init__( super().__init__(registry_key, obs_key) self._unlabeled_category = unlabeled_category - def _remap_unlabeled_to_final_category( - self, adata: AnnData, mapping: np.ndarray - ) -> dict: + def _remap_unlabeled_to_final_category(self, adata: AnnData, mapping: np.ndarray) -> dict: labels = self._get_original_column(adata) if self._unlabeled_category in labels: diff --git a/scvi/data/fields/_uns_field.py b/scvi/data/fields/_uns_field.py index 7cf82a2b8b..799657caa9 100644 --- a/scvi/data/fields/_uns_field.py +++ b/scvi/data/fields/_uns_field.py @@ -26,9 +26,7 @@ class BaseUnsField(BaseAnnDataField): _attr_name = _constants._ADATA_ATTRS.UNS - def __init__( - self, registry_key: str, uns_key: Optional[str], required: bool = True - ) -> None: + def __init__(self, registry_key: str, uns_key: Optional[str], required: bool = True) -> None: super().__init__() if required and uns_key is None: raise ValueError( diff --git a/scvi/dataloaders/_concat_dataloader.py b/scvi/dataloaders/_concat_dataloader.py index ef7c8066c6..dcb41b0a80 100644 --- a/scvi/dataloaders/_concat_dataloader.py +++ b/scvi/dataloaders/_concat_dataloader.py @@ -74,7 +74,5 @@ def __iter__(self): the data in the other dataloaders. The order of data in returned iter_list is the same as indices_list. 
""" - iter_list = [ - cycle(dl) if dl != self.largest_dl else dl for dl in self.dataloaders - ] + iter_list = [cycle(dl) if dl != self.largest_dl else dl for dl in self.dataloaders] return zip(*iter_list) diff --git a/scvi/dataloaders/_data_splitting.py b/scvi/dataloaders/_data_splitting.py index 52d4ca4330..9a1d2f56f7 100644 --- a/scvi/dataloaders/_data_splitting.py +++ b/scvi/dataloaders/_data_splitting.py @@ -246,9 +246,7 @@ def __init__( self.data_loader_kwargs = kwargs self.n_samples_per_label = n_samples_per_label - labels_state_registry = adata_manager.get_state_registry( - REGISTRY_KEYS.LABELS_KEY - ) + labels_state_registry = adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) labels = get_anndata_attribute( adata_manager.adata, adata_manager.data_registry.labels.attr_name, @@ -304,9 +302,7 @@ def setup(self, stage: Optional[str] = None): unlabeled_idx_train = unlabeled_permutation[ n_unlabeled_val : (n_unlabeled_val + n_unlabeled_train) ] - unlabeled_idx_test = unlabeled_permutation[ - (n_unlabeled_val + n_unlabeled_train) : - ] + unlabeled_idx_test = unlabeled_permutation[(n_unlabeled_val + n_unlabeled_train) :] else: unlabeled_idx_train = [] unlabeled_idx_val = [] @@ -440,16 +436,10 @@ def setup(self, stage: Optional[str] = None): if self.shuffle is False: self.train_idx = np.sort(self.train_idx) - self.val_idx = ( - np.sort(self.val_idx) if len(self.val_idx) > 0 else self.val_idx - ) - self.test_idx = ( - np.sort(self.test_idx) if len(self.test_idx) > 0 else self.test_idx - ) + self.val_idx = np.sort(self.val_idx) if len(self.val_idx) > 0 else self.val_idx + self.test_idx = np.sort(self.test_idx) if len(self.test_idx) > 0 else self.test_idx - self.train_tensor_dict = self._get_tensor_dict( - self.train_idx, device=self.device - ) + self.train_tensor_dict = self._get_tensor_dict(self.train_idx, device=self.device) self.test_tensor_dict = self._get_tensor_dict(self.test_idx, device=self.device) self.val_tensor_dict = self._get_tensor_dict(self.val_idx, device=self.device) diff --git a/scvi/dataloaders/_semi_dataloader.py b/scvi/dataloaders/_semi_dataloader.py index da4fed4f8b..5a7c79c3c2 100644 --- a/scvi/dataloaders/_semi_dataloader.py +++ b/scvi/dataloaders/_semi_dataloader.py @@ -56,9 +56,7 @@ def __init__( self.n_samples_per_label = n_samples_per_label - labels_state_registry = adata_manager.get_state_registry( - REGISTRY_KEYS.LABELS_KEY - ) + labels_state_registry = adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) labels = get_anndata_attribute( adata_manager.adata, adata_manager.data_registry.labels.attr_name, @@ -109,9 +107,7 @@ def subsample_labels(self): if len(loc) < self.n_samples_per_label: sample_idx.append(loc) else: - label_subset = np.random.choice( - loc, self.n_samples_per_label, replace=False - ) + label_subset = np.random.choice(loc, self.n_samples_per_label, replace=False) sample_idx.append(label_subset) sample_idx = np.concatenate(sample_idx) return sample_idx diff --git a/scvi/distributions/_negative_binomial.py b/scvi/distributions/_negative_binomial.py index bfa3837b83..6a6e0eb492 100644 --- a/scvi/distributions/_negative_binomial.py +++ b/scvi/distributions/_negative_binomial.py @@ -51,9 +51,7 @@ def log_zinb_positive( """ # theta is the dispersion rate. 
If .ndimension() == 1, it is shared for all cells (regardless of batch or labels) if theta.ndimension() == 1: - theta = theta.view( - 1, theta.size(0) - ) # In this case, we reshape theta for broadcasting + theta = theta.view(1, theta.size(0)) # In this case, we reshape theta for broadcasting # Uses log(sigmoid(x)) = -softplus(-x) softplus_pi = F.softplus(-pi) @@ -157,9 +155,7 @@ def log_mixture_nb( else: theta = theta_1 if theta.ndimension() == 1: - theta = theta.view( - 1, theta.size(0) - ) # In this case, we reshape theta for broadcasting + theta = theta.view(1, theta.size(0)) # In this case, we reshape theta for broadcasting log_theta_mu_1_eps = torch.log(theta + mu_1 + eps) log_theta_mu_2_eps = torch.log(theta + mu_2 + eps) @@ -343,9 +339,7 @@ def __init__( "Please use one of the two possible parameterizations. Refer to the documentation for more information." ) - using_param_1 = total_count is not None and ( - logits is not None or probs is not None - ) + using_param_1 = total_count is not None and (logits is not None or probs is not None) if using_param_1: logits = logits if logits is not None else probs_to_logits(probs) total_count = total_count.type_as(logits) @@ -379,9 +373,7 @@ def sample( # Clamping as distributions objects can have buggy behaviors when # their parameters are too high l_train = torch.clamp(p_means, max=1e8) - counts = PoissonTorch( - l_train - ).sample() # Shape : (n_samples, n_cells_batch, n_vars) + counts = PoissonTorch(l_train).sample() # Shape : (n_samples, n_cells_batch, n_vars) return counts def log_prob(self, value: torch.Tensor) -> torch.Tensor: @@ -473,9 +465,7 @@ def __init__( scale=scale, validate_args=validate_args, ) - self.zi_logits, self.mu, self.theta = broadcast_all( - zi_logits, self.mu, self.theta - ) + self.zi_logits, self.mu, self.theta = broadcast_all(zi_logits, self.mu, self.theta) @property def mean(self) -> torch.Tensor: @@ -603,9 +593,7 @@ def sample( # Clamping as distributions objects can have buggy behaviors when # their parameters are too high l_train = torch.clamp(p_means, max=1e8) - counts = PoissonTorch( - l_train - ).sample() # Shape : (n_samples, n_cells_batch, n_features) + counts = PoissonTorch(l_train).sample() # Shape : (n_samples, n_cells_batch, n_features) return counts def log_prob(self, value: torch.Tensor) -> torch.Tensor: diff --git a/scvi/distributions/_utils.py b/scvi/distributions/_utils.py index 6f12f8ca8c..b169569cbe 100644 --- a/scvi/distributions/_utils.py +++ b/scvi/distributions/_utils.py @@ -9,9 +9,7 @@ def subset_distribution( """Utility function to subset the parameter of a Pytorch distribution.""" return my_distribution.__class__( **{ - name: torch.index_select( - getattr(my_distribution, name), dim=dim, index=index - ) + name: torch.index_select(getattr(my_distribution, name), dim=dim, index=index) for name in my_distribution.arg_constraints.keys() } ) @@ -38,15 +36,11 @@ def store_distribution(self, dist: torch.distributions.Distribution): if self._params is None: self._params = {name: [] for name in dist.arg_constraints.keys()} self.distribution_cls = dist.__class__ - new_params = { - name: getattr(dist, name).cpu() for name in dist.arg_constraints.keys() - } + new_params = {name: getattr(dist, name).cpu() for name in dist.arg_constraints.keys()} for param_name, param in new_params.items(): self._params[param_name].append(param) def get_concatenated_distributions(self, axis=0): """Returns a concatenated `Distribution` object along the specified axis.""" - concat_params = { - key: torch.cat(value, 
dim=axis) for key, value in self._params.items() - } + concat_params = {key: torch.cat(value, dim=axis) for key, value in self._params.items()} return self.distribution_cls(**concat_params) diff --git a/scvi/external/cellassign/_model.py b/scvi/external/cellassign/_model.py index dcb3f6f6bf..38dbc03529 100644 --- a/scvi/external/cellassign/_model.py +++ b/scvi/external/cellassign/_model.py @@ -77,18 +77,14 @@ def __init__( try: cell_type_markers = cell_type_markers.loc[adata.var_names] except KeyError as err: - raise KeyError( - "Anndata and cell type markers do not contain the same genes." - ) from err + raise KeyError("Anndata and cell type markers do not contain the same genes.") from err super().__init__(adata) self.n_genes = self.summary_stats.n_vars self.cell_type_markers = cell_type_markers rho = torch.Tensor(cell_type_markers.to_numpy()) n_cats_per_cov = ( - self.adata_manager.get_state_registry( - REGISTRY_KEYS.CAT_COVS_KEY - ).n_cats_per_key + self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry else None ) @@ -112,7 +108,9 @@ def __init__( n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs", 0), **model_kwargs, ) - self._model_summary_string = f"CellAssign Model with params: \nn_genes: {self.n_genes}, n_labels: {rho.shape[1]}" + self._model_summary_string = ( + f"CellAssign Model with params: \nn_genes: {self.n_genes}, n_labels: {rho.shape[1]}" + ) self.init_params_ = self._get_init_params(locals()) @torch.inference_mode() @@ -126,9 +124,7 @@ def predict(self) -> pd.DataFrame: outputs = self.module.generative(**generative_inputs) gamma = outputs["gamma"] predictions += [gamma.cpu()] - return pd.DataFrame( - torch.cat(predictions).numpy(), columns=self.cell_type_markers.columns - ) + return pd.DataFrame(torch.cat(predictions).numpy(), columns=self.cell_type_markers.columns) @devices_dsp.dedent def train( @@ -264,16 +260,10 @@ def setup_anndata( LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key), CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), - CategoricalJointObsField( - REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys - ), - NumericalJointObsField( - REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys - ), + CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), + NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), ] - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/scvi/external/cellassign/_module.py b/scvi/external/cellassign/_module.py index e3b13cf5be..5f8c8c5f8e 100644 --- a/scvi/external/cellassign/_module.py +++ b/scvi/external/cellassign/_module.py @@ -182,8 +182,7 @@ def generative(self, x, size_factor, design_matrix=None): n_cells, self.n_genes, self.n_labels, B ) # (n, g, c, B) phi = ( # (n, g, c) - torch.sum(a * torch.exp(-b * torch.square(mu_ngcb - basis_means)), 3) - + LOWER_BOUND + torch.sum(a * torch.exp(-b * torch.square(mu_ngcb - basis_means)), 3) + LOWER_BOUND ) # compute gamma @@ -193,9 +192,7 @@ def generative(self, x, size_factor, design_matrix=None): theta_log = theta_log.expand(n_cells, self.n_labels) p_x_c = torch.sum(x_log_prob_raw, 1) + theta_log # (n, c) 
normalizer_over_c = torch.logsumexp(p_x_c, 1) - normalizer_over_c = normalizer_over_c.unsqueeze(-1).expand( - n_cells, self.n_labels - ) + normalizer_over_c = normalizer_over_c.unsqueeze(-1).expand(n_cells, self.n_labels) gamma = torch.exp(p_x_c - normalizer_over_c) # (n, c) return { @@ -226,13 +223,9 @@ def loss( # third term is log prob of prior terms in Q theta_log = F.log_softmax(self.theta_logit, dim=-1) theta_log_prior = Dirichlet(self.dirichlet_concentration) - theta_log_prob = -theta_log_prior.log_prob( - torch.exp(theta_log) + THETA_LOWER_BOUND - ) + theta_log_prob = -theta_log_prior.log_prob(torch.exp(theta_log) + THETA_LOWER_BOUND) prior_log_prob = theta_log_prob - delta_log_prior = Normal( - self.delta_log_mean, self.delta_log_log_scale.exp().sqrt() - ) + delta_log_prior = Normal(self.delta_log_mean, self.delta_log_log_scale.exp().sqrt()) delta_log_prob = torch.masked_select( delta_log_prior.log_prob(self.delta_log), (self.rho > 0) ) diff --git a/scvi/external/contrastivevi/_contrastive_data_splitting.py b/scvi/external/contrastivevi/_contrastive_data_splitting.py index 9597805b95..c86da8b8fb 100644 --- a/scvi/external/contrastivevi/_contrastive_data_splitting.py +++ b/scvi/external/contrastivevi/_contrastive_data_splitting.py @@ -99,14 +99,10 @@ def setup(self, stage: Optional[str] = None): self.background_train_idx = background_indices[ n_background_val : (n_background_val + n_background_train) ] - self.background_test_idx = background_indices[ - (n_background_val + n_background_train) : - ] + self.background_test_idx = background_indices[(n_background_val + n_background_train) :] self.target_val_idx = target_indices[:n_target_val] - self.target_train_idx = target_indices[ - n_target_val : (n_target_val + n_target_train) - ] + self.target_train_idx = target_indices[n_target_val : (n_target_val + n_target_train)] self.target_test_idx = target_indices[(n_target_val + n_target_train) :] self.val_idx = self.background_val_idx + self.target_val_idx diff --git a/scvi/external/contrastivevi/_contrastive_dataloader.py b/scvi/external/contrastivevi/_contrastive_dataloader.py index 41246c5add..0c88ee10a1 100644 --- a/scvi/external/contrastivevi/_contrastive_dataloader.py +++ b/scvi/external/contrastivevi/_contrastive_dataloader.py @@ -120,7 +120,5 @@ def __iter__(self): Will iter over the dataloader with the most data while cycling through the data in the other dataloader. 
""" - iter_list = [ - cycle(dl) if dl != self.largest_dl else dl for dl in self.dataloaders - ] + iter_list = [cycle(dl) if dl != self.largest_dl else dl for dl in self.dataloaders] return _ContrastiveIterator(background=iter_list[0], target=iter_list[1]) diff --git a/scvi/external/contrastivevi/_model.py b/scvi/external/contrastivevi/_model.py index 90a8418797..3d35e74542 100644 --- a/scvi/external/contrastivevi/_model.py +++ b/scvi/external/contrastivevi/_model.py @@ -95,9 +95,7 @@ def __init__( super().__init__(adata) n_cats_per_cov = ( - self.adata_manager.get_state_registry( - REGISTRY_KEYS.CAT_COVS_KEY - ).n_cats_per_key + self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry else None ) @@ -105,9 +103,7 @@ def __init__( library_log_means, library_log_vars = None, None if not use_observed_lib_size: - library_log_means, library_log_vars = _init_library_size( - self.adata_manager, n_batch - ) + library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch) self.module = self._module_cls( n_input=self.summary_stats.n_vars, @@ -203,9 +199,7 @@ def train( validation_size=validation_size, batch_size=batch_size, shuffle_set_split=shuffle_set_split, - distributed_sampler=use_distributed_sampler( - trainer_kwargs.get("strategy", None) - ), + distributed_sampler=use_distributed_sampler(trainer_kwargs.get("strategy", None)), load_sparse_tensor=load_sparse_tensor, **datasplitter_kwargs, ) @@ -256,19 +250,11 @@ def setup_anndata( LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), - NumericalObsField( - REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False - ), - CategoricalJointObsField( - REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys - ), - NumericalJointObsField( - REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys - ), + NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False), + CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), + NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), ] - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) @@ -321,9 +307,7 @@ def get_latent_representation( for tensors in data_loader: x = tensors[REGISTRY_KEYS.X_KEY] batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] - outputs = self.module._generic_inference( - x=x, batch_index=batch_index, n_samples=1 - ) + outputs = self.module._generic_inference(x=x, batch_index=batch_index, n_samples=1) if representation_kind == "background": latent_m = outputs["qz_m"] @@ -489,9 +473,7 @@ def get_normalized_expression( if return_numpy is None or return_numpy is False: genes = adata.var_names[gene_mask] samples = adata.obs_names[indices] - background_exprs = pd.DataFrame( - background_exprs, columns=genes, index=samples - ) + background_exprs = pd.DataFrame(background_exprs, columns=genes, index=samples) salient_exprs = pd.DataFrame(salient_exprs, columns=genes, index=samples) return {"background": background_exprs, "salient": salient_exprs} @@ -852,9 +834,7 @@ def get_latent_library_size( self._check_if_trained(warn=False) adata = self._validate_anndata(adata) - scdl = 
self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) libraries = [] for tensors in scdl: x = tensors[REGISTRY_KEYS.X_KEY] diff --git a/scvi/external/contrastivevi/_module.py b/scvi/external/contrastivevi/_module.py index 2a2e6a26ed..1d5e054e0c 100644 --- a/scvi/external/contrastivevi/_module.py +++ b/scvi/external/contrastivevi/_module.py @@ -85,12 +85,8 @@ def __init__( "If not using observed_lib_size, " "must provide library_log_means and library_log_vars." ) - self.register_buffer( - "library_log_means", torch.from_numpy(library_log_means).float() - ) - self.register_buffer( - "library_log_vars", torch.from_numpy(library_log_vars).float() - ) + self.register_buffer("library_log_means", torch.from_numpy(library_log_means).float()) + self.register_buffer("library_log_vars", torch.from_numpy(library_log_vars).float()) cat_list = [n_batch] # Background encoder. @@ -158,12 +154,8 @@ def _compute_local_library_params( log library sizes in the batch the cell corresponds to. """ n_batch = self.library_log_means.shape[1] - local_library_log_means = F.linear( - one_hot(batch_index, n_batch), self.library_log_means - ) - local_library_log_vars = F.linear( - one_hot(batch_index, n_batch), self.library_log_vars - ) + local_library_log_means = F.linear(one_hot(batch_index, n_batch), self.library_log_means) + local_library_log_vars = F.linear(one_hot(batch_index, n_batch), self.library_log_vars) return local_library_log_means, local_library_log_vars @staticmethod @@ -193,9 +185,7 @@ def _get_inference_input_from_concat_tensors( def _get_inference_input( self, concat_tensors: dict[str, dict[str, torch.Tensor]] ) -> dict[str, dict[str, torch.Tensor]]: - background = self._get_inference_input_from_concat_tensors( - concat_tensors, "background" - ) + background = self._get_inference_input_from_concat_tensors(concat_tensors, "background") target = self._get_inference_input_from_concat_tensors(concat_tensors, "target") # Ensure batch sizes are the same. 
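# Sketch of the lookup pattern in _compute_local_library_params above: a
# one-hot batch encoding times a (1, n_batch) parameter row selects the
# per-batch prior for every cell (toy values; 3 batches, 4 cells; the module
# uses scvi's own one_hot helper, but torch's behaves the same here).
import torch
import torch.nn.functional as F

library_log_means = torch.tensor([[6.0, 7.5, 5.0]])      # (1, n_batch)
batch_index = torch.tensor([0, 2, 1, 0])                 # one id per cell
one_hot = F.one_hot(batch_index, num_classes=3).float()  # (4, 3)
local_means = F.linear(one_hot, library_log_means)       # (4, 1)
print(local_means.squeeze())  # tensor([6.0000, 5.0000, 7.5000, 6.0000])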
min_batch_size = self._get_min_batch_size(concat_tensors) @@ -363,9 +353,7 @@ def generative( target_batch_size = target["z"].shape[batch_size_dim] generative_input = {} for key in ["z", "s", "library"]: - generative_input[key] = torch.cat( - [background[key], target[key]], dim=batch_size_dim - ) + generative_input[key] = torch.cat([background[key], target[key]], dim=batch_size_dim) generative_input["batch_index"] = torch.cat( [background["batch_index"], target["batch_index"]], dim=0 ) @@ -594,8 +582,7 @@ def loss( kl_local_no_warmup = kl_divergence_l weighted_kl_local = ( - kl_weight - * (self.wasserstein_penalty * wasserstein_loss + kl_local_for_warmup) + kl_weight * (self.wasserstein_penalty * wasserstein_loss + kl_local_for_warmup) + kl_local_no_warmup ) diff --git a/scvi/external/gimvi/_model.py b/scvi/external/gimvi/_model.py index 15c8ad6b9c..dc62bd7c8e 100644 --- a/scvi/external/gimvi/_model.py +++ b/scvi/external/gimvi/_model.py @@ -94,9 +94,7 @@ def __init__( self.adatas = [adata_seq, adata_spatial] self.adata_managers = { "seq": self._get_most_recent_anndata_manager(adata_seq, required=True), - "spatial": self._get_most_recent_anndata_manager( - adata_spatial, required=True - ), + "spatial": self._get_most_recent_anndata_manager(adata_spatial, required=True), } self.registries_ = [] for adm in self.adata_managers.values(): @@ -109,9 +107,7 @@ def __init__( if not set(spatial_var_names) <= set(seq_var_names): raise ValueError("spatial genes needs to be subset of seq genes") - spatial_gene_loc = [ - np.argwhere(seq_var_names == g)[0] for g in spatial_var_names - ] + spatial_gene_loc = [np.argwhere(seq_var_names == g)[0] for g in spatial_var_names] spatial_gene_loc = np.concatenate(spatial_gene_loc) gene_mappings = [slice(None), spatial_gene_loc] sum_stats = [adm.summary_stats for adm in self.adata_managers.values()] @@ -121,18 +117,14 @@ def __init__( adata_seq_n_batches = sum_stats[0]["n_batch"] adata_spatial_batch = adata_spatial.obs[ - self.adata_managers["spatial"] - .data_registry[REGISTRY_KEYS.BATCH_KEY] - .attr_key + self.adata_managers["spatial"].data_registry[REGISTRY_KEYS.BATCH_KEY].attr_key ] if np.min(adata_spatial_batch) == 0: # see #2446 # since we are combining datasets, we need to increment the batch_idx of one of the # datasets. we only need to do this once so we check if the min is 0 adata_spatial.obs[ - self.adata_managers["spatial"] - .data_registry[REGISTRY_KEYS.BATCH_KEY] - .attr_key + self.adata_managers["spatial"].data_registry[REGISTRY_KEYS.BATCH_KEY].attr_key ] += adata_seq_n_batches n_batches = sum(s["n_batch"] for s in sum_stats) @@ -435,9 +427,7 @@ def save( seq_save_path = os.path.join(dir_path, f"{file_name_prefix}adata_seq.h5ad") seq_adata.write(seq_save_path) - spatial_save_path = os.path.join( - dir_path, f"{file_name_prefix}adata_spatial.h5ad" - ) + spatial_save_path = os.path.join(dir_path, f"{file_name_prefix}adata_spatial.h5ad") spatial_adata.write(spatial_save_path) # save the model state dict and the trainer state dict only @@ -547,9 +537,7 @@ def load( registries = attr_dict.pop("registries_") for adata, registry in zip(adatas, registries): if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__: - raise ValueError( - "It appears you are loading a model from a different class." - ) + raise ValueError("It appears you are loading a model from a different class.") if _SETUP_ARGS_KEY not in registry: raise ValueError( @@ -557,9 +545,7 @@ def load( "Cannot load the original setup." 
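# Sketch of the batch-offset logic in the "see #2446" comment above: when two
# datasets are combined, the second dataset's batch codes are shifted by the
# first dataset's batch count so the combined codes stay unique (toy arrays).
import numpy as np

seq_batches = np.array([0, 0, 1])       # adata_seq has 2 batches
spatial_batches = np.array([0, 1, 0])   # adata_spatial also starts at 0
n_seq_batches = 2
if spatial_batches.min() == 0:          # only apply the offset once
    spatial_batches = spatial_batches + n_seq_batches
print(spatial_batches)                  # [2 3 2]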
) - cls.setup_anndata( - adata, source_registry=registry, **registry[_SETUP_ARGS_KEY] - ) + cls.setup_anndata(adata, source_registry=registry, **registry[_SETUP_ARGS_KEY]) # get the parameters for the class init signature init_params = attr_dict.pop("init_params_") @@ -574,9 +560,7 @@ def load( kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()} else: # grab all the parameters except for kwargs (is a dict) - non_kwargs = { - k: v for k, v in init_params.items() if not isinstance(v, dict) - } + non_kwargs = {k: v for k, v in init_params.items() if not isinstance(v, dict)} kwargs = {k: v for k, v in init_params.items() if isinstance(v, dict)} kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()} model = cls(adata_seq, adata_spatial, **non_kwargs, **kwargs) @@ -678,9 +662,7 @@ def setup_anndata( CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), ] - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) @@ -690,9 +672,7 @@ class TrainDL(DataLoader): def __init__(self, data_loader_list, **kwargs): self.data_loader_list = data_loader_list - self.largest_train_dl_idx = np.argmax( - [len(dl.indices) for dl in data_loader_list] - ) + self.largest_train_dl_idx = np.argmax([len(dl.indices) for dl in data_loader_list]) self.largest_dl = self.data_loader_list[self.largest_train_dl_idx] super().__init__(self.largest_dl, **kwargs) diff --git a/scvi/external/gimvi/_module.py b/scvi/external/gimvi/_module.py index 2a9190ec26..37d8b6695d 100644 --- a/scvi/external/gimvi/_module.py +++ b/scvi/external/gimvi/_module.py @@ -354,16 +354,12 @@ def reconstruction_loss( reconstruction_loss = None if self.gene_likelihoods[mode] == "zinb": reconstruction_loss = ( - -ZeroInflatedNegativeBinomial( - mu=px_rate, theta=px_r, zi_logits=px_dropout - ) + -ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout) .log_prob(x) .sum(dim=-1) ) elif self.gene_likelihoods[mode] == "nb": - reconstruction_loss = ( - -NegativeBinomial(mu=px_rate, theta=px_r).log_prob(x).sum(dim=-1) - ) + reconstruction_loss = -NegativeBinomial(mu=px_rate, theta=px_r).log_prob(x).sum(dim=-1) elif self.gene_likelihoods[mode] == "poisson": reconstruction_loss = -Poisson(px_rate).log_prob(x).sum(dim=1) return reconstruction_loss @@ -417,9 +413,9 @@ def generative( px_r = self.px_r.view(1, self.px_r.size(0)) px_r = torch.exp(px_r) - px_scale = px_scale / torch.sum( - px_scale[:, self.indices_mappings[mode]], dim=1 - ).view(-1, 1) + px_scale = px_scale / torch.sum(px_scale[:, self.indices_mappings[mode]], dim=1).view( + -1, 1 + ) px_rate = px_scale * torch.exp(library) return { @@ -492,9 +488,7 @@ def loss( local_library_log_means = F.linear( one_hot(batch_index, self.n_batch), library_log_means ) - local_library_log_vars = F.linear( - one_hot(batch_index, self.n_batch), library_log_vars - ) + local_library_log_vars = F.linear(one_hot(batch_index, self.n_batch), library_log_vars) kl_divergence_l = kl( ql, Normal(local_library_log_means, local_library_log_vars.sqrt()), @@ -506,6 +500,4 @@ def loss( loss = torch.mean(reconstruction_loss + kl_weight * kl_local) * x.size(0) - return LossOutput( - loss=loss, reconstruction_loss=reconstruction_loss, kl_local=kl_local - ) + return LossOutput(loss=loss, 
reconstruction_loss=reconstruction_loss, kl_local=kl_local) diff --git a/scvi/external/gimvi/_task.py b/scvi/external/gimvi/_task.py index 943dfd0cfd..9868f18eb4 100644 --- a/scvi/external/gimvi/_task.py +++ b/scvi/external/gimvi/_task.py @@ -91,8 +91,7 @@ def training_step(self, batch, batch_idx): zs.append(outputs["z"]) batch_tensor = [ - torch.zeros((z.shape[0], 1), device=z.device) + i - for i, z in enumerate(zs) + torch.zeros((z.shape[0], 1), device=z.device) + i for i, z in enumerate(zs) ] loss = self.loss_adversarial_classifier( torch.cat(zs).detach(), torch.cat(batch_tensor), True diff --git a/scvi/external/gimvi/_utils.py b/scvi/external/gimvi/_utils.py index 1874f4b320..e1e0b3c0f6 100644 --- a/scvi/external/gimvi/_utils.py +++ b/scvi/external/gimvi/_utils.py @@ -18,9 +18,7 @@ def _load_legacy_saved_gimvi_files( model_path = os.path.join(dir_path, f"{file_name_prefix}model_params.pt") setup_dict_path = os.path.join(dir_path, f"{file_name_prefix}attr.pkl") seq_var_names_path = os.path.join(dir_path, f"{file_name_prefix}var_names_seq.csv") - spatial_var_names_path = os.path.join( - dir_path, f"{file_name_prefix}var_names_spatial.csv" - ) + spatial_var_names_path = os.path.join(dir_path, f"{file_name_prefix}var_names_spatial.csv") model_state_dict = torch.load(model_path, map_location="cpu") @@ -36,19 +34,13 @@ def _load_legacy_saved_gimvi_files( if os.path.exists(seq_data_path): adata_seq = read(seq_data_path) elif not os.path.exists(seq_data_path): - raise ValueError( - "Save path contains no saved anndata and no adata was passed." - ) + raise ValueError("Save path contains no saved anndata and no adata was passed.") if load_spatial_adata: - spatial_data_path = os.path.join( - dir_path, f"{file_name_prefix}adata_spatial.h5ad" - ) + spatial_data_path = os.path.join(dir_path, f"{file_name_prefix}adata_spatial.h5ad") if os.path.exists(spatial_data_path): adata_spatial = read(spatial_data_path) elif not os.path.exists(spatial_data_path): - raise ValueError( - "Save path contains no saved anndata and no adata was passed." - ) + raise ValueError("Save path contains no saved anndata and no adata was passed.") return ( model_state_dict, @@ -67,9 +59,7 @@ def _load_saved_gimvi_files( prefix: Optional[str] = None, map_location: Optional[Literal["cpu", "cuda"]] = None, backup_url: Optional[str] = None, -) -> tuple[ - dict, dict, np.ndarray, np.ndarray, dict, Optional[AnnData], Optional[AnnData] -]: +) -> tuple[dict, dict, np.ndarray, np.ndarray, dict, Optional[AnnData], Optional[AnnData]]: file_name_prefix = prefix or "" model_file_name = f"{file_name_prefix}model.pt" @@ -95,19 +85,13 @@ def _load_saved_gimvi_files( if os.path.exists(seq_data_path): adata_seq = read(seq_data_path) elif not os.path.exists(seq_data_path): - raise ValueError( - "Save path contains no saved anndata and no adata was passed." - ) + raise ValueError("Save path contains no saved anndata and no adata was passed.") if load_spatial_adata: - spatial_data_path = os.path.join( - dir_path, f"{file_name_prefix}adata_spatial.h5ad" - ) + spatial_data_path = os.path.join(dir_path, f"{file_name_prefix}adata_spatial.h5ad") if os.path.exists(spatial_data_path): adata_spatial = read(spatial_data_path) elif not os.path.exists(spatial_data_path): - raise ValueError( - "Save path contains no saved anndata and no adata was passed." 
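# Sketch of the "poisson" branch of gimVI's reconstruction_loss above: the
# loss per cell is the negative log-likelihood of its counts under the
# decoded rates, summed over genes (toy tensors; the zinb/nb branches apply
# scvi's (ZeroInflated)NegativeBinomial distributions the same way).
import torch
from torch.distributions import Poisson

x = torch.tensor([[3.0, 0.0, 1.0]])        # observed counts, one cell
px_rate = torch.tensor([[2.5, 0.1, 1.2]])  # decoded Poisson rates
reconstruction_loss = -Poisson(px_rate).log_prob(x).sum(dim=-1)
print(reconstruction_loss)                 # one value per cell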
- ) + raise ValueError("Save path contains no saved anndata and no adata was passed.") return ( attr_dict, diff --git a/scvi/external/poissonvi/_model.py b/scvi/external/poissonvi/_model.py index b9d2ff9f57..eef36a6a85 100644 --- a/scvi/external/poissonvi/_model.py +++ b/scvi/external/poissonvi/_model.py @@ -89,14 +89,10 @@ def __init__( ) n_batch = self.summary_stats.n_batch - use_size_factor_key = ( - REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry - ) + use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry library_log_means, library_log_vars = None, None if use_size_factor_key is not None: - library_log_means, library_log_vars = _init_library_size( - self.adata_manager, n_batch - ) + library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch) self._module_cls = VAE @@ -231,9 +227,7 @@ def get_accessibility_estimates( ) if not normalize_regions: # reset region_factors (bias) - self.module.decoder.px_scale_decoder[-2].bias = torch.nn.Parameter( - region_factors - ) + self.module.decoder.px_scale_decoder[-2].bias = torch.nn.Parameter(region_factors) return accs def get_normalized_expression( @@ -337,9 +331,7 @@ def differential_accessibility( weights=weights, **importance_weighting_kwargs, ) - representation_fn = ( - self.get_latent_representation if filter_outlier_cells else None - ) + representation_fn = self.get_latent_representation if filter_outlier_cells else None if two_sided: @@ -422,18 +414,10 @@ def setup_anndata( LayerField(REGISTRY_KEYS.X_KEY, layer, check_fragment_counts=True), CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), - NumericalObsField( - REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False - ), - CategoricalJointObsField( - REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys - ), - NumericalJointObsField( - REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys - ), + NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False), + CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), + NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), ] - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/scvi/external/scar/_model.py b/scvi/external/scar/_model.py index abc70fe9e9..950d183ddb 100644 --- a/scvi/external/scar/_model.py +++ b/scvi/external/scar/_model.py @@ -94,14 +94,10 @@ def __init__( super().__init__(adata) n_batch = self.summary_stats.n_batch - use_size_factor_key = ( - REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry - ) + use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry library_log_means, library_log_vars = None, None if not use_size_factor_key: - library_log_means, library_log_vars = _init_library_size( - self.adata_manager, n_batch - ) + library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch) # self.summary_stats provides information about anndata dimensions and other tensor info if not torch.is_tensor(ambient_profile): @@ -118,9 +114,7 @@ def __init__( raise TypeError( f"Expecting str / np.array / None / pd.DataFrame, but get a {type(ambient_profile)}" ) - ambient_profile = ( - 
torch.from_numpy(np.asarray(ambient_profile)).float().reshape(1, -1) - ) + ambient_profile = torch.from_numpy(np.asarray(ambient_profile)).float().reshape(1, -1) self.module = SCAR_VAE( ambient_profile=ambient_profile, @@ -173,13 +167,9 @@ def setup_anndata( LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, None), CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, None), - NumericalObsField( - REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False - ), + NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False), ] - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) @@ -243,9 +233,7 @@ def get_ambient_profile( try: count_batch = raw_adata[batch_idx == b].X.astype(int).A except MemoryError as err: - raise MemoryError( - "use more batches by setting a higher n_batch" - ) from err + raise MemoryError("use more batches by setting a higher n_batch") from err log_prob_batch = Multinomial( probs=torch.tensor(ambient_prof), validate_args=False ).log_prob(torch.Tensor(count_batch)) @@ -254,9 +242,7 @@ def get_ambient_profile( raw_adata.obs["log_prob"] = log_prob raw_adata.obs["droplets"] = "other droplets" # cell-containing droplets - raw_adata.obs.loc[ - raw_adata.obs_names.isin(adata.obs_names), "droplets" - ] = "cells" + raw_adata.obs.loc[raw_adata.obs_names.isin(adata.obs_names), "droplets"] = "cells" # identify cell-free droplets raw_adata.obs["droplets"] = raw_adata.obs["droplets"].mask( raw_adata.obs["log_prob"] >= np.log(prob) * raw_adata.shape[1], @@ -318,9 +304,7 @@ def get_denoised_counts( px_scale = generative_outputs["px"].scale expected_counts = total_count_per_cell * px_scale.cpu() - b = torch.distributions.Binomial( - probs=expected_counts - expected_counts.floor() - ) + b = torch.distributions.Binomial(probs=expected_counts - expected_counts.floor()) expected_counts = expected_counts.floor() + b.sample() if n_samples > 1: diff --git a/scvi/external/scar/_module.py b/scvi/external/scar/_module.py index 7ebbfd8c65..0b86e32a6b 100644 --- a/scvi/external/scar/_module.py +++ b/scvi/external/scar/_module.py @@ -334,9 +334,7 @@ def loss( ): """Compute the loss function for the model.""" x = tensors[REGISTRY_KEYS.X_KEY] - kl_divergence_z = kl(inference_outputs["qz"], generative_outputs["pz"]).sum( - dim=1 - ) + kl_divergence_z = kl(inference_outputs["qz"], generative_outputs["pz"]).sum(dim=1) if not self.use_observed_lib_size: kl_divergence_l = kl( inference_outputs["ql"], diff --git a/scvi/external/scbasset/_model.py b/scvi/external/scbasset/_model.py index 1363a89071..83f19c789b 100644 --- a/scvi/external/scbasset/_model.py +++ b/scvi/external/scbasset/_model.py @@ -157,9 +157,7 @@ def train( """ custom_plan_kwargs = { "optimizer": "Custom", - "optimizer_creator": lambda p: torch.optim.Adam( - p, lr=lr, betas=(0.95, 0.9995) - ), + "optimizer_creator": lambda p: torch.optim.Adam(p, lr=lr, betas=(0.95, 0.9995)), } if plan_kwargs is not None: custom_plan_kwargs.update(plan_kwargs) @@ -188,9 +186,7 @@ def train( "early_stopping_min_delta": early_stopping_min_delta, } for k, v in es.items(): - trainer_kwargs[k] = ( - v if k not in trainer_kwargs.keys() else trainer_kwargs[k] - ) + trainer_kwargs[k] = v if k not in trainer_kwargs.keys() else trainer_kwargs[k] runner = TrainRunner( self, 
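# Sketch of the stochastic rounding in get_denoised_counts above: the
# fractional part of each expected count becomes the success probability of
# a single-trial Binomial, so samples stay integer-valued while matching the
# expected counts in mean (toy values).
import torch

expected_counts = torch.tensor([[1.3, 0.7, 2.0]])
frac = expected_counts - expected_counts.floor()
b = torch.distributions.Binomial(probs=frac)  # total_count defaults to 1
denoised = expected_counts.floor() + b.sample()
print(denoised)  # e.g. tensor([[2., 1., 2.]]); 1.3 rounds up 30% of the time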
training_plan=training_plan, @@ -348,9 +344,7 @@ def get_tf_activity( # download if none is found # `motif_seqs` is a List of str sequences where each char is in "ACTGN". # `bg_seqs` is the same, but background sequences rather than motif injected - motif_seqs, bg_seqs = self._get_motif_library( - tf=tf, genome=genome, motif_dir=motif_dir - ) + motif_seqs, bg_seqs = self._get_motif_library(tf=tf, genome=genome, motif_dir=motif_dir) # SCBASSET.module.inference(...) takes `dna_code: torch.Tensor` as input # where `dna_code` is [batch_size, seq_length] and each value is [0,1,2,3] @@ -364,9 +358,9 @@ def get_tf_activity( # NOTE: SCBASSET uses a fixed size of 1344 bp. If motifs from a different source # than the above are used, we may need to truncate to match the model size. # We should be cautious about doing this, so we throw a warning to the user. - model_input_size = self.adata_manager.get_from_registry( - REGISTRY_KEYS.DNA_CODE_KEY - ).shape[1] + model_input_size = self.adata_manager.get_from_registry(REGISTRY_KEYS.DNA_CODE_KEY).shape[ + 1 + ] n_diff = motif_codes.shape[1] - model_input_size if n_diff > 0: n_cut = n_diff // 2 @@ -441,8 +435,6 @@ def setup_anndata( ObsmField(REGISTRY_KEYS.DNA_CODE_KEY, dna_code_key, is_count_data=True), CategoricalVarField(REGISTRY_KEYS.BATCH_KEY, batch_key), ] - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/scvi/external/scbasset/_module.py b/scvi/external/scbasset/_module.py index 66b32c12ce..4923be501a 100644 --- a/scvi/external/scbasset/_module.py +++ b/scvi/external/scbasset/_module.py @@ -253,9 +253,7 @@ def __init__( tower_layers = [] curr_n_filters = n_filters_init for i in range(n_repeat_blocks_tower): - new_n_filters = ( - _round(curr_n_filters * filters_mult) if i > 0 else curr_n_filters - ) + new_n_filters = _round(curr_n_filters * filters_mult) if i > 0 else curr_n_filters tower_layers.append( _ConvBlock( in_channels=curr_n_filters, diff --git a/scvi/external/solo/_model.py b/scvi/external/solo/_model.py index 181a274111..51153c3abe 100644 --- a/scvi/external/solo/_model.py +++ b/scvi/external/solo/_model.py @@ -135,12 +135,8 @@ def from_scvi_model( """ _validate_scvi_model(scvi_model, restrict_to_batch=restrict_to_batch) orig_adata_manager = scvi_model.adata_manager - orig_batch_key_registry = orig_adata_manager.get_state_registry( - REGISTRY_KEYS.BATCH_KEY - ) - orig_labels_key_registry = orig_adata_manager.get_state_registry( - REGISTRY_KEYS.LABELS_KEY - ) + orig_batch_key_registry = orig_adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY) + orig_labels_key_registry = orig_adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) orig_batch_key = orig_batch_key_registry.original_key orig_labels_key = orig_labels_key_registry.original_key @@ -221,9 +217,7 @@ def from_scvi_model( latent_adata.obs[LABELS_KEY] = "singlet" orig_obs_names = adata.obs_names latent_adata.obs_names = ( - orig_obs_names[batch_indices] - if batch_indices is not None - else orig_obs_names + orig_obs_names[batch_indices] if batch_indices is not None else orig_obs_names ) logger.info("Creating doublets, preparing SOLO model.") @@ -391,9 +385,7 @@ def train( return runner() @torch.inference_mode() - def predict( - self, soft: bool = True, include_simulated_doublets: bool = False - ) -> pd.DataFrame: + def predict(self, soft: bool 
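# Sketch of the center-truncation in the NOTE above: sequences longer than
# the model's fixed input length are trimmed equally from both ends (toy
# sizes; scBasset's real input is 1344 bp, and the exact slice used in
# get_tf_activity may differ slightly from this illustration).
import torch

model_input_size = 8
motif_codes = torch.randint(0, 4, (2, 12))  # 2 encoded sequences, length 12
n_diff = motif_codes.shape[1] - model_input_size
if n_diff > 0:
    n_cut = n_diff // 2
    motif_codes = motif_codes[:, n_cut : n_cut + model_input_size]
print(motif_codes.shape)  # torch.Size([2, 8])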
= True, include_simulated_doublets: bool = False) -> pd.DataFrame: """Return doublet predictions. Parameters @@ -430,9 +422,7 @@ def auto_forward(module, x): preds = y_pred[mask] - cols = self.adata_manager.get_state_registry( - REGISTRY_KEYS.LABELS_KEY - ).categorical_mapping + cols = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).categorical_mapping preds_df = pd.DataFrame(preds, columns=cols, index=self.adata.obs_names[mask]) if not soft: @@ -461,9 +451,7 @@ def setup_anndata( LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=False), CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), ] - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/scvi/external/stereoscope/_model.py b/scvi/external/stereoscope/_model.py index 9ad89af04f..b6b8239d57 100644 --- a/scvi/external/stereoscope/_model.py +++ b/scvi/external/stereoscope/_model.py @@ -59,7 +59,9 @@ def __init__( n_labels=self.n_labels, **model_kwargs, ) - self._model_summary_string = f"RNADeconv Model with params: \nn_genes: {self.n_genes}, n_labels: {self.n_labels}" + self._model_summary_string = ( + f"RNADeconv Model with params: \nn_genes: {self.n_genes}, n_labels: {self.n_labels}" + ) self.init_params_ = self._get_init_params(locals()) @devices_dsp.dedent @@ -147,9 +149,7 @@ def setup_anndata( LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), ] - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) @@ -207,9 +207,7 @@ def __init__( prior_weight=prior_weight, **model_kwargs, ) - self._model_summary_string = ( - f"RNADeconv Model with params: \nn_spots: {st_adata.n_obs}" - ) + self._model_summary_string = f"RNADeconv Model with params: \nn_spots: {st_adata.n_obs}" self.cell_type_mapping = cell_type_mapping self.init_params_ = self._get_init_params(locals()) @@ -367,8 +365,6 @@ def setup_anndata( LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"), ] - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/scvi/external/stereoscope/_module.py b/scvi/external/stereoscope/_module.py index 40f1eb4bac..bfc48fb089 100644 --- a/scvi/external/stereoscope/_module.py +++ b/scvi/external/stereoscope/_module.py @@ -76,9 +76,7 @@ def inference(self): @auto_move_data def generative(self, x, y): """Simply build the negative binomial parameters for every cell in the minibatch.""" - px_scale = torch.nn.functional.softplus(self.W)[ - :, y.long().ravel() - ].T # cells per gene + px_scale = torch.nn.functional.softplus(self.W)[:, y.long().ravel()].T # cells per gene library = torch.sum(x, dim=1, keepdim=True) px_rate = library * px_scale scaling_factor = self.ct_weight[y.long().ravel()] @@ -164,9 +162,7 @@ def __init__( def get_proportions(self, keep_noise=False) -> np.ndarray: """Returns the loadings.""" # get estimated 
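# Sketch of how predict() above packages its output: per-cell class
# probabilities become a DataFrame whose columns come from the registered
# label mapping, and soft=False reduces to the argmax label (toy values).
import numpy as np
import pandas as pd

preds = np.array([[0.9, 0.1], [0.2, 0.8]])
cols = ["singlet", "doublet"]
preds_df = pd.DataFrame(preds, columns=cols, index=["cell_0", "cell_1"])
hard_calls = preds_df.idxmax(axis=1)
print(hard_calls.tolist())  # ['singlet', 'doublet']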
unadjusted proportions - res = ( - torch.nn.functional.softplus(self.V).cpu().numpy().T - ) # n_spots, n_labels + 1 + res = torch.nn.functional.softplus(self.V).cpu().numpy().T # n_spots, n_labels + 1 # remove dummy cell type proportion values if not keep_noise: res = res[:, :-1] @@ -204,9 +200,7 @@ def generative(self, x, ind_x): ) # n_genes, n_labels + 1 # subsample observations v_ind = v[:, ind_x] # labels + 1, batch_size - px_rate = torch.transpose( - torch.matmul(r_hat, v_ind), 0, 1 - ) # batch_size, n_genes + px_rate = torch.transpose(torch.matmul(r_hat, v_ind), 0, 1) # batch_size, n_genes return {"px_o": self.px_o, "px_rate": px_rate, "eta": self.eta} diff --git a/scvi/external/tangram/_model.py b/scvi/external/tangram/_model.py index c4d26f9de6..920678dbbf 100644 --- a/scvi/external/tangram/_model.py +++ b/scvi/external/tangram/_model.py @@ -74,26 +74,16 @@ def __init__( **model_kwargs, ): super().__init__(sc_adata) - self.n_obs_sc = self.adata_manager.get_from_registry( - TANGRAM_REGISTRY_KEYS.SC_KEY - ).shape[0] - self.n_obs_sp = self.adata_manager.get_from_registry( - TANGRAM_REGISTRY_KEYS.SP_KEY - ).shape[0] + self.n_obs_sc = self.adata_manager.get_from_registry(TANGRAM_REGISTRY_KEYS.SC_KEY).shape[0] + self.n_obs_sp = self.adata_manager.get_from_registry(TANGRAM_REGISTRY_KEYS.SP_KEY).shape[0] if constrained and target_count is None: - raise ValueError( - "Please specify `target_count` when using constrained Tangram." - ) + raise ValueError("Please specify `target_count` when using constrained Tangram.") has_density_prior = not self.adata_manager.fields[-1].is_empty if has_density_prior: - prior = self.adata_manager.get_from_registry( - TANGRAM_REGISTRY_KEYS.DENSITY_KEY - ) + prior = self.adata_manager.get_from_registry(TANGRAM_REGISTRY_KEYS.DENSITY_KEY) if np.abs(prior.ravel().sum() - 1) > 1e-3: - raise ValueError( - "Density prior must sum to 1. Please normalize the prior." - ) + raise ValueError("Density prior must sum to 1. Please normalize the prior.") self.module = TangramMapper( n_obs_sc=self.n_obs_sc, @@ -113,9 +103,7 @@ def get_mapper_matrix(self) -> np.ndarray: ------- Mapping matrix of shape (n_obs_sp, n_obs_sc) """ - return jax.device_get( - jax.nn.softmax(self.module.params["mapper_unconstrained"], axis=1) - ) + return jax.device_get(jax.nn.softmax(self.module.params["mapper_unconstrained"], axis=1)) @devices_dsp.dedent def train( @@ -247,9 +235,7 @@ def setup_mudata( adata_manager = AnnDataManager( fields=mudata_fields, setup_method_args=setup_method_args, - validation_checks=AnnDataManagerValidationCheck( - check_fully_paired_mudata=False - ), + validation_checks=AnnDataManagerValidationCheck(check_fully_paired_mudata=False), ) adata_manager.register_fields(mdata, **kwargs) sc_state = adata_manager.get_state_registry(TANGRAM_REGISTRY_KEYS.SC_KEY) @@ -268,9 +254,7 @@ def setup_mudata( @classmethod def setup_anndata(cls): """Not implemented, use `setup_mudata`.""" - raise NotImplementedError( - "Use `setup_mudata` to setup a MuData object for training." - ) + raise NotImplementedError("Use `setup_mudata` to setup a MuData object for training.") def _get_tensor_dict( self, @@ -330,9 +314,7 @@ def project_cell_annotations( ) @staticmethod - def project_genes( - adata_sc: AnnData, adata_sp: AnnData, mapper: np.ndarray - ) -> AnnData: + def project_genes(adata_sc: AnnData, adata_sp: AnnData, mapper: np.ndarray) -> AnnData: """Project gene expression to spatial data. 
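# Sketch of what the softmax mapper above enables: per the docstring, the
# mapping matrix has shape (n_obs_sp, n_obs_sc) with rows that sum to one,
# so a matrix product transfers per-cell values onto spatial spots
# (toy numbers).
import numpy as np

mapper = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.1, 0.8]])         # rows sum to 1 over cells
sc_values = np.array([[1.0], [0.0], [2.0]])  # one value per cell
sp_values = mapper @ sc_values
print(sp_values.ravel())                     # [0.9 1.7]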
Parameters diff --git a/scvi/hub/_metadata.py b/scvi/hub/_metadata.py index d852f2dc9b..fbbc654997 100644 --- a/scvi/hub/_metadata.py +++ b/scvi/hub/_metadata.py @@ -212,12 +212,8 @@ def from_dir( model_cls_name = registry["model_name"] scvi_version = registry["scvi_version"] model_setup_anndata_args = registry["setup_args"] - model_summary_stats = dict( - AnnDataManager._get_summary_stats_from_registry(registry) - ) - model_data_registry = dict( - AnnDataManager._get_data_registry_from_registry(registry) - ) + model_summary_stats = dict(AnnDataManager._get_summary_stats_from_registry(registry)) + model_data_registry = dict(AnnDataManager._get_data_registry_from_registry(registry)) # get `is_minified` from the param if it is given, else from adata if it on disk, else set it to None is_minified = data_is_minified @@ -269,13 +265,9 @@ def _to_model_card(self) -> ModelCard: kwargs = self.model_init_params["kwargs"] else: non_kwargs = { - k: v - for k, v in self.model_init_params.items() - if not isinstance(v, dict) - } - kwargs = { - k: v for k, v in self.model_init_params.items() if isinstance(v, dict) + k: v for k, v in self.model_init_params.items() if not isinstance(v, dict) } + kwargs = {k: v for k, v in self.model_init_params.items() if isinstance(v, dict)} kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()} # kwargs and non_kwargs keys should be disjoint but if not, we'll just use the original model_init_params if len(set(kwargs.keys()).intersection(set(non_kwargs.keys()))) == 0: diff --git a/scvi/hub/_model.py b/scvi/hub/_model.py index 0e1d6dbaf5..6b6382eb52 100644 --- a/scvi/hub/_model.py +++ b/scvi/hub/_model.py @@ -485,8 +485,6 @@ def read_large_training_adata(self) -> None: else: _download(training_data_url, dn, fn) logger.info("Reading large training data...") - self._large_training_adata = anndata.read_h5ad( - self._large_training_adata_path - ) + self._large_training_adata = anndata.read_h5ad(self._large_training_adata_path) else: logger.info("No training_data_url found in the model card. Skipping...") diff --git a/scvi/hub/_url.py b/scvi/hub/_url.py index 9d85907bf9..68fbdbdef2 100644 --- a/scvi/hub/_url.py +++ b/scvi/hub/_url.py @@ -3,9 +3,7 @@ import requests -def validate_url( - url: str, error_format: bool = False, error_response: bool = False -) -> bool: +def validate_url(url: str, error_format: bool = False, error_response: bool = False) -> bool: """Validates a URL. 
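# Sketch of the init-parameter split above (the same idiom appears in
# gimVI's load): dict-valued entries are nested keyword groups that get
# flattened, while everything else passes through unchanged (toy dict).
model_init_params = {"n_latent": 10, "model_kwargs": {"dropout_rate": 0.1}}
non_kwargs = {k: v for k, v in model_init_params.items() if not isinstance(v, dict)}
kwargs = {k: v for k, v in model_init_params.items() if isinstance(v, dict)}
kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()}
print(non_kwargs, kwargs)  # {'n_latent': 10} {'dropout_rate': 0.1}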
Source: https://stackoverflow.com/questions/7160737/how-to-validate-a-url-in-python-malformed-or-not diff --git a/scvi/model/_amortizedlda.py b/scvi/model/_amortizedlda.py index 4dfd579ee2..80a195935b 100644 --- a/scvi/model/_amortizedlda.py +++ b/scvi/model/_amortizedlda.py @@ -123,9 +123,7 @@ def setup_anndata( anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), ] - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) @@ -188,9 +186,7 @@ def get_latent_representation( transformed_xs = [] for tensors in dl: x = tensors[REGISTRY_KEYS.X_KEY] - transformed_xs.append( - self.module.get_topic_distribution(x, n_samples=n_samples) - ) + transformed_xs.append(self.module.get_topic_distribution(x, n_samples=n_samples)) transformed_x = torch.cat(transformed_xs).numpy() return pd.DataFrame( @@ -267,6 +263,5 @@ def get_perplexity( total_counts = sum(tensors[REGISTRY_KEYS.X_KEY].sum().item() for tensors in dl) return np.exp( - self.get_elbo(adata=adata, indices=indices, batch_size=batch_size) - / total_counts + self.get_elbo(adata=adata, indices=indices, batch_size=batch_size) / total_counts ) diff --git a/scvi/model/_autozi.py b/scvi/model/_autozi.py index c1cb1a4ddd..f1f9627910 100644 --- a/scvi/model/_autozi.py +++ b/scvi/model/_autozi.py @@ -112,9 +112,7 @@ def __init__( self.use_observed_lib_size = use_observed_lib_size n_batch = self.summary_stats.n_batch - library_log_means, library_log_vars = _init_library_size( - self.adata_manager, n_batch - ) + library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch) self.module = self._module_cls( n_input=self.summary_stats.n_vars, @@ -188,9 +186,7 @@ def get_marginal_ll( if indices is None: indices = np.arange(adata.n_obs) - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) log_lkl = 0 to_sum = torch.zeros((n_mc_samples,)).to(self.device) @@ -290,8 +286,6 @@ def setup_anndata( CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), ] - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/scvi/model/_condscvi.py b/scvi/model/_condscvi.py index 4e8c7a42b0..22f712431d 100644 --- a/scvi/model/_condscvi.py +++ b/scvi/model/_condscvi.py @@ -134,9 +134,7 @@ def get_vamp_prior(self, adata: AnnData | None = None, p: int = 10) -> np.ndarra var_vprior = np.ones((self.summary_stats.n_labels, p, self.module.n_latent)) mp_vprior = np.zeros((self.summary_stats.n_labels, p)) - labels_state_registry = self.adata_manager.get_state_registry( - REGISTRY_KEYS.LABELS_KEY - ) + labels_state_registry = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) key = labels_state_registry.original_key mapping = labels_state_registry.categorical_mapping @@ -285,8 +283,6 @@ def setup_anndata( LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), ] - adata_manager = AnnDataManager( - fields=anndata_fields, 
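# Worked sketch of the perplexity formula above: the value returned by
# get_elbo is divided by the total number of observed counts and
# exponentiated (hypothetical numbers).
import numpy as np

elbo = 1.25e6          # summed over all held-out cells
total_counts = 5.0e5   # sum of the count matrix over the same cells
perplexity = np.exp(elbo / total_counts)
print(perplexity)      # ~12.18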
setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/scvi/model/_destvi.py b/scvi/model/_destvi.py index 44c0b105ab..16dc080137 100644 --- a/scvi/model/_destvi.py +++ b/scvi/model/_destvi.py @@ -184,9 +184,7 @@ def get_proportions( column_names = np.append(column_names, "noise_term") if self.module.amortization in ["both", "proportion"]: - stdl = self._make_data_loader( - adata=self.adata, indices=indices, batch_size=batch_size - ) + stdl = self._make_data_loader(adata=self.adata, indices=indices, batch_size=batch_size) prop_ = [] for tensors in stdl: generative_inputs = self.module._get_generative_input(tensors, None) @@ -233,9 +231,7 @@ def get_gamma( index_names = self.adata.obs.index if self.module.amortization in ["both", "latent"]: - stdl = self._make_data_loader( - adata=self.adata, indices=indices, batch_size=batch_size - ) + stdl = self._make_data_loader(adata=self.adata, indices=indices, batch_size=batch_size) gamma_ = [] for tensors in stdl: generative_inputs = self.module._get_generative_input(tensors, None) @@ -257,9 +253,7 @@ def get_gamma( else: res = {} for i, ct in enumerate(self.cell_type_mapping): - res[ct] = pd.DataFrame( - data=data[:, :, i], columns=column_names, index=index_names - ) + res[ct] = pd.DataFrame(data=data[:, :, i], columns=column_names, index=index_names) return res def get_scale_for_ct( @@ -396,8 +390,6 @@ def setup_anndata( LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"), ] - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/scvi/model/_jaxscvi.py b/scvi/model/_jaxscvi.py index 4b6701df0e..f703caf63d 100644 --- a/scvi/model/_jaxscvi.py +++ b/scvi/model/_jaxscvi.py @@ -99,9 +99,7 @@ def setup_anndata( LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), ] - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/scvi/model/_linear_scvi.py b/scvi/model/_linear_scvi.py index 48d04b2ebe..649c2d9a33 100644 --- a/scvi/model/_linear_scvi.py +++ b/scvi/model/_linear_scvi.py @@ -85,9 +85,7 @@ def __init__( super().__init__(adata) n_batch = self.summary_stats.n_batch - library_log_means, library_log_vars = _init_library_size( - self.adata_manager, n_batch - ) + library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch) self.module = self._module_cls( n_input=self.summary_stats.n_vars, @@ -126,9 +124,7 @@ def get_loadings(self) -> pd.DataFrame: """ cols = [f"Z_{i}" for i in range(self.n_latent)] var_names = self.adata.var_names - loadings = pd.DataFrame( - self.module.get_loadings(), index=var_names, columns=cols - ) + loadings = pd.DataFrame(self.module.get_loadings(), index=var_names, columns=cols) return loadings @@ -157,8 +153,6 @@ def setup_anndata( CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), ] - 
adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/scvi/model/_metrics.py b/scvi/model/_metrics.py index a76d721d6b..e808211821 100644 --- a/scvi/model/_metrics.py +++ b/scvi/model/_metrics.py @@ -28,9 +28,7 @@ def nearest_neighbor_overlap(x1, x2, k=100): kmatrix_2 = nne.kneighbors_graph(x2) - scipy.sparse.identity(n_samples) # 1 - spearman correlation from knn graphs - spearman_correlation = scipy.stats.spearmanr( - kmatrix_1.A.flatten(), kmatrix_2.A.flatten() - )[0] + spearman_correlation = scipy.stats.spearmanr(kmatrix_1.A.flatten(), kmatrix_2.A.flatten())[0] # 2 - fold enrichment set_1 = set(np.where(kmatrix_1.A.flatten() == 1)[0]) set_2 = set(np.where(kmatrix_2.A.flatten() == 1)[0]) diff --git a/scvi/model/_multivi.py b/scvi/model/_multivi.py index ec1bd46b39..2ad574029c 100644 --- a/scvi/model/_multivi.py +++ b/scvi/model/_multivi.py @@ -155,25 +155,19 @@ def __init__( deeply_inject_covariates: bool = False, encode_covariates: bool = False, fully_paired: bool = False, - protein_dispersion: Literal[ - "protein", "protein-batch", "protein-label" - ] = "protein", + protein_dispersion: Literal["protein", "protein-batch", "protein-label"] = "protein", **model_kwargs, ): super().__init__(adata) prior_mean, prior_scale = None, None n_cats_per_cov = ( - self.adata_manager.get_state_registry( - REGISTRY_KEYS.CAT_COVS_KEY - ).n_cats_per_key + self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry else [] ) - use_size_factor_key = ( - REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry - ) + use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry if "n_proteins" in self.summary_stats: n_proteins = self.summary_stats.n_proteins @@ -330,9 +324,7 @@ def train( if save_best: if "callbacks" not in kwargs.keys(): kwargs["callbacks"] = [] - kwargs["callbacks"].append( - SaveBestState(monitor="reconstruction_loss_validation") - ) + kwargs["callbacks"].append(SaveBestState(monitor="reconstruction_loss_validation")) data_splitter = self._data_splitter_cls( self.adata_manager, @@ -383,9 +375,7 @@ def get_library_size_factors( """ self._check_adata_modality_weights(adata) adata = self._validate_anndata(adata) - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) lib_exp = [] lib_acc = [] @@ -444,9 +434,7 @@ def get_latent_representation( self._check_adata_modality_weights(adata) keys = {"z": "z", "qz_m": "qz_m", "qz_v": "qz_v"} if self.fully_paired and modality != "joint": - raise RuntimeError( - "A fully paired model only has a joint latent representation." 
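# Runnable sketch of nearest_neighbor_overlap above: build a kNN adjacency
# matrix for two embeddings of the same cells, drop self-edges, and compare
# the flattened graphs (random toy embeddings, so the correlation is ~0).
import numpy as np
import scipy.sparse
import scipy.stats
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
x1, x2 = rng.normal(size=(50, 5)), rng.normal(size=(50, 5))
n_samples = x1.shape[0]

def knn_graph(x, k=10):
    nne = NearestNeighbors(n_neighbors=k + 1).fit(x)  # +1 accounts for self
    return nne.kneighbors_graph(x) - scipy.sparse.identity(n_samples)

k1, k2 = knn_graph(x1), knn_graph(x2)
rho = scipy.stats.spearmanr(k1.toarray().ravel(), k2.toarray().ravel())[0]
print(rho)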
- ) + raise RuntimeError("A fully paired model only has a joint latent representation.") if not self.fully_paired and modality != "joint": if modality == "expression": keys = {"z": "z_expr", "qz_m": "qzm_expr", "qz_v": "qzv_expr"} @@ -460,9 +448,7 @@ def get_latent_representation( ) adata = self._validate_anndata(adata) - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) latent = [] for tensors in scdl: inference_inputs = self.module._get_inference_input(tensors) @@ -543,17 +529,13 @@ def get_accessibility_estimates( indices = np.arange(adata.n_obs) if n_samples_overall is not None: indices = np.random.choice(indices, n_samples_overall) - post = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + post = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) transform_batch = _get_batch_code_from_category(adata_manager, transform_batch) if region_list is None: region_mask = slice(None) else: - region_mask = [ - region in region_list for region in adata.var_names[self.n_genes :] - ] + region_mask = [region in region_list for region in adata.var_names[self.n_genes :]] if threshold is not None and (threshold < 0 or threshold > 1): raise ValueError("the provided threshold must be between 0 and 1") @@ -661,9 +643,7 @@ def get_normalized_expression( indices = np.arange(adata.n_obs) if n_samples_overall is not None: indices = np.random.choice(indices, n_samples_overall) - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) transform_batch = _get_batch_code_from_category(adata_manager, transform_batch) @@ -679,9 +659,7 @@ def get_normalized_expression( for batch in transform_batch: if batch is not None: batch_indices = tensors[REGISTRY_KEYS.BATCH_KEY] - tensors[REGISTRY_KEYS.BATCH_KEY] = ( - torch.ones_like(batch_indices) * batch - ) + tensors[REGISTRY_KEYS.BATCH_KEY] = torch.ones_like(batch_indices) * batch _, generative_outputs = self.module.forward( tensors=tensors, inference_kwargs={"n_samples": n_samples}, @@ -991,9 +969,7 @@ def get_protein_foreground_probability( Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True. 
""" adata = self._validate_anndata(adata) - post = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + post = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) if protein_list is None: protein_mask = slice(None) @@ -1017,9 +993,7 @@ def get_protein_foreground_probability( if not isinstance(transform_batch, IterableClass): transform_batch = [transform_batch] - transform_batch = _get_batch_code_from_category( - self.adata_manager, transform_batch - ) + transform_batch = _get_batch_code_from_category(self.adata_manager, transform_batch) for tensors in post: y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY] py_mixing = torch.zeros_like(y[..., protein_mask]) @@ -1102,15 +1076,9 @@ def setup_anndata( batch_field, CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, None), CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), - NumericalObsField( - REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False - ), - CategoricalJointObsField( - REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys - ), - NumericalJointObsField( - REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys - ), + NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False), + CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), + NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"), ] if protein_expression_obsm_key is not None: @@ -1125,9 +1093,7 @@ def setup_anndata( ) ) - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/scvi/model/_peakvi.py b/scvi/model/_peakvi.py index a106848d57..17503537f6 100644 --- a/scvi/model/_peakvi.py +++ b/scvi/model/_peakvi.py @@ -104,9 +104,7 @@ def __init__( super().__init__(adata) n_cats_per_cov = ( - self.adata_manager.get_state_registry( - REGISTRY_KEYS.CAT_COVS_KEY - ).n_cats_per_key + self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry else [] ) @@ -234,9 +232,7 @@ def train( if save_best: if "callbacks" not in kwargs.keys(): kwargs["callbacks"] = [] - kwargs["callbacks"].append( - SaveBestState(monitor="reconstruction_loss_validation") - ) + kwargs["callbacks"].append(SaveBestState(monitor="reconstruction_loss_validation")) super().train( max_epochs=max_epochs, @@ -279,9 +275,7 @@ def get_library_size_factors( Library size factor for expression and accessibility """ adata = self._validate_anndata(adata) - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) library_sizes = [] for tensors in scdl: @@ -362,9 +356,7 @@ def get_accessibility_estimates( indices = np.arange(adata.n_obs) if n_samples_overall is not None: indices = np.random.choice(indices, n_samples_overall) - post = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + post = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) transform_batch = _get_batch_code_from_category(adata_manager, transform_batch) if region_list is None: @@ -581,15 +573,9 @@ def setup_anndata( LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), 
CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), - CategoricalJointObsField( - REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys - ), - NumericalJointObsField( - REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys - ), + CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), + NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), ] - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/scvi/model/_scanvi.py b/scvi/model/_scanvi.py index 35838b4509..a1c1b0ee62 100644 --- a/scvi/model/_scanvi.py +++ b/scvi/model/_scanvi.py @@ -127,22 +127,16 @@ def __init__( # ignores unlabeled catgegory n_labels = self.summary_stats.n_labels - 1 n_cats_per_cov = ( - self.adata_manager.get_state_registry( - REGISTRY_KEYS.CAT_COVS_KEY - ).n_cats_per_key + self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry else None ) n_batch = self.summary_stats.n_batch - use_size_factor_key = ( - REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry - ) + use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry library_log_means, library_log_vars = None, None if not use_size_factor_key and self.minified_data_type is None: - library_log_means, library_log_vars = _init_library_size( - self.adata_manager, n_batch - ) + library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch) self.module = self._module_cls( n_input=self.summary_stats.n_vars, @@ -209,9 +203,7 @@ def from_scvi_model( scanvi_kwargs kwargs for scANVI model """ - scvi_model._check_if_trained( - message="Passed in scvi model hasn't been trained yet." - ) + scvi_model._check_if_trained(message="Passed in scvi model hasn't been trained yet.") scanvi_kwargs = dict(scanvi_kwargs) init_params = scvi_model.init_params_ @@ -237,9 +229,7 @@ def from_scvi_model( adata = scvi_model.adata else: if _is_minified(adata): - raise ValueError( - "Please provide a non-minified `adata` to initialize scanvi." 
- ) + raise ValueError("Please provide a non-minified `adata` to initialize scanvi.") # validate new anndata against old model scvi_model._validate_anndata(adata) @@ -265,9 +255,7 @@ def from_scvi_model( def _set_indices_and_labels(self): """Set indices for labeled and unlabeled cells.""" - labels_state_registry = self.adata_manager.get_state_registry( - REGISTRY_KEYS.LABELS_KEY - ) + labels_state_registry = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) self.original_label_key = labels_state_registry.original_key self.unlabeled_category_ = labels_state_registry.unlabeled_category @@ -431,9 +419,7 @@ def train( batch_size=batch_size, **datasplitter_kwargs, ) - training_plan = self._training_plan_cls( - self.module, self.n_labels, **plan_kwargs - ) + training_plan = self._training_plan_cls(self.module, self.n_labels, **plan_kwargs) if "callbacks" in trainer_kwargs.keys(): trainer_kwargs["callbacks"] + [sampler_callback] else: @@ -482,26 +468,16 @@ def setup_anndata( anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), - LabelsWithUnlabeledObsField( - REGISTRY_KEYS.LABELS_KEY, labels_key, unlabeled_category - ), - NumericalObsField( - REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False - ), - CategoricalJointObsField( - REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys - ), - NumericalJointObsField( - REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys - ), + LabelsWithUnlabeledObsField(REGISTRY_KEYS.LABELS_KEY, labels_key, unlabeled_category), + NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False), + CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), + NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), ] # register new fields if the adata is minified adata_minify_type = _get_adata_minify_type(adata) if adata_minify_type is not None: anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type) - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) @@ -570,18 +546,12 @@ def minify_adata( raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") if self.module.use_observed_lib_size is False: - raise ValueError( - "Cannot minify the data if `use_observed_lib_size` is False" - ) + raise ValueError("Cannot minify the data if `use_observed_lib_size` is False") minified_adata = get_minified_adata_scrna(self.adata, minified_data_type) minified_adata.obsm[_SCANVI_LATENT_QZM] = self.adata.obsm[use_latent_qzm_key] minified_adata.obsm[_SCANVI_LATENT_QZV] = self.adata.obsm[use_latent_qzv_key] counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) - minified_adata.obs[_SCANVI_OBSERVED_LIB_SIZE] = np.squeeze( - np.asarray(counts.sum(axis=1)) - ) - self._update_adata_and_manager_post_minification( - minified_adata, minified_data_type - ) + minified_adata.obs[_SCANVI_OBSERVED_LIB_SIZE] = np.squeeze(np.asarray(counts.sum(axis=1))) + self._update_adata_and_manager_post_minification(minified_adata, minified_data_type) self.module.minified_data_type = minified_data_type diff --git a/scvi/model/_scvi.py b/scvi/model/_scvi.py index d6bce7b488..c919659feb 100644 --- a/scvi/model/_scvi.py +++ b/scvi/model/_scvi.py @@ -144,16 +144,12 @@ def __init__( ) 
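# Conceptual sketch of the labeled/unlabeled split behind
# _set_indices_and_labels above: cells whose label equals the registered
# unlabeled category are held out of the supervised loss (toy labels; the
# real implementation works through the labels state registry).
import numpy as np

labels = np.array(["B", "Unknown", "T", "Unknown"])
unlabeled_category = "Unknown"
labeled_idx = np.where(labels != unlabeled_category)[0]
unlabeled_idx = np.where(labels == unlabeled_category)[0]
print(labeled_idx, unlabeled_idx)  # [0 2] [1 3]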
else: n_cats_per_cov = ( - self.adata_manager.get_state_registry( - REGISTRY_KEYS.CAT_COVS_KEY - ).n_cats_per_key + self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry else None ) n_batch = self.summary_stats.n_batch - use_size_factor_key = ( - REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry - ) + use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry library_log_means, library_log_vars = None, None if not use_size_factor_key and self.minified_data_type is None: library_log_means, library_log_vars = _init_library_size( @@ -211,23 +207,15 @@ def setup_anndata( LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), - NumericalObsField( - REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False - ), - CategoricalJointObsField( - REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys - ), - NumericalJointObsField( - REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys - ), + NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False), + CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), + NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), ] # register new fields if the adata is minified adata_minify_type = _get_adata_minify_type(adata) if adata_minify_type is not None: anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type) - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) @@ -298,18 +286,12 @@ def minify_adata( raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") if self.module.use_observed_lib_size is False: - raise ValueError( - "Cannot minify the data if `use_observed_lib_size` is False" - ) + raise ValueError("Cannot minify the data if `use_observed_lib_size` is False") minified_adata = get_minified_adata_scrna(self.adata, minified_data_type) minified_adata.obsm[_SCVI_LATENT_QZM] = self.adata.obsm[use_latent_qzm_key] minified_adata.obsm[_SCVI_LATENT_QZV] = self.adata.obsm[use_latent_qzv_key] counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) - minified_adata.obs[_SCVI_OBSERVED_LIB_SIZE] = np.squeeze( - np.asarray(counts.sum(axis=1)) - ) - self._update_adata_and_manager_post_minification( - minified_adata, minified_data_type - ) + minified_adata.obs[_SCVI_OBSERVED_LIB_SIZE] = np.squeeze(np.asarray(counts.sum(axis=1))) + self._update_adata_and_manager_post_minification(minified_adata, minified_data_type) self.module.minified_data_type = minified_data_type diff --git a/scvi/model/_totalvi.py b/scvi/model/_totalvi.py index 67f65e1a60..9f7f893ad9 100644 --- a/scvi/model/_totalvi.py +++ b/scvi/model/_totalvi.py @@ -102,12 +102,8 @@ def __init__( self, adata: AnnData, n_latent: int = 20, - gene_dispersion: Literal[ - "gene", "gene-batch", "gene-label", "gene-cell" - ] = "gene", - protein_dispersion: Literal[ - "protein", "protein-batch", "protein-label" - ] = "protein", + gene_dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene", + protein_dispersion: Literal["protein", "protein-batch", "protein-label"] = "protein", gene_likelihood: 
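# Sketch of the bookkeeping in minify_adata above: before the count matrix
# is dropped, each cell's total count is stored so the model can keep using
# the observed library size (toy sparse counts).
import numpy as np
import scipy.sparse

counts = scipy.sparse.csr_matrix(np.array([[1, 0, 2], [0, 3, 0]]))
observed_lib_size = np.squeeze(np.asarray(counts.sum(axis=1)))
print(observed_lib_size)  # [3 3]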
Literal["zinb", "nb"] = "nb", latent_distribution: Literal["normal", "ln"] = "normal", empirical_protein_background_prior: bool | None = None, @@ -155,14 +151,10 @@ def __init__( ) n_batch = self.summary_stats.n_batch - use_size_factor_key = ( - REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry - ) + use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry library_log_means, library_log_vars = None, None if not use_size_factor_key: - library_log_means, library_log_vars = _init_library_size( - self.adata_manager, n_batch - ) + library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch) self.module = self._module_cls( n_input_genes=self.summary_stats.n_vars, @@ -266,9 +258,7 @@ def train( if adversarial_classifier is None: adversarial_classifier = self._use_adversarial_classifier n_steps_kl_warmup = ( - n_steps_kl_warmup - if n_steps_kl_warmup is not None - else int(0.75 * self.adata.n_obs) + n_steps_kl_warmup if n_steps_kl_warmup is not None else int(0.75 * self.adata.n_obs) ) if reduce_lr_on_plateau: check_val_every_n_epoch = 1 @@ -340,9 +330,7 @@ def get_latent_library_size( self._check_if_trained(warn=False) adata = self._validate_anndata(adata) - post = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + post = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) libraries = [] for tensors in post: @@ -438,9 +426,7 @@ def get_normalized_expression( indices = np.arange(adata.n_obs) if n_samples_overall is not None: indices = np.random.choice(indices, n_samples_overall) - post = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + post = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) if gene_list is None: gene_mask = slice(None) @@ -499,9 +485,7 @@ def get_normalized_expression( # probability of background protein_mixing = 1 / (1 + torch.exp(-py_["mixing"].cpu())) if sample_protein_mixing is True: - protein_mixing = torch.distributions.Bernoulli( - protein_mixing - ).sample() + protein_mixing = torch.distributions.Bernoulli(protein_mixing).sample() protein_val = py_["rate_fore"].cpu() * (1 - protein_mixing) if include_protein_background is True: protein_val += py_["rate_back"].cpu() * protein_mixing @@ -602,9 +586,7 @@ def get_protein_foreground_probability( Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True. 
""" adata = self._validate_anndata(adata) - post = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + post = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) if protein_list is None: protein_mask = slice(None) @@ -628,9 +610,7 @@ def get_protein_foreground_probability( if not isinstance(transform_batch, IterableClass): transform_batch = [transform_batch] - transform_batch = _get_batch_code_from_category( - self.adata_manager, transform_batch - ) + transform_batch = _get_batch_code_from_category(self.adata_manager, transform_batch) for tensors in post: y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY] py_mixing = torch.zeros_like(y[..., protein_mask]) @@ -855,9 +835,7 @@ def posterior_predictive_sample( all_proteins = self.protein_state_registry.column_names protein_mask = [True if p in protein_list else False for p in all_proteins] - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) scdl_list = [] for tensors in scdl: @@ -902,9 +880,7 @@ def _get_denoised_samples( int of which batch to condition on for all cells """ adata = self._validate_anndata(adata) - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) scdl_list = [] for tensors in scdl: @@ -948,9 +924,7 @@ def _get_denoised_samples( l_train = torch.distributions.Gamma(r, (1 - p) / p).sample() data = l_train.cpu().numpy() # make background 0 - data[:, :, x.shape[1] :] = ( - data[:, :, x.shape[1] :] * (1 - mixing_sample).cpu().numpy() - ) + data[:, :, x.shape[1] :] = data[:, :, x.shape[1] :] * (1 - mixing_sample).cpu().numpy() scdl_list += [data] scdl_list[-1] = np.transpose(scdl_list[-1], (1, 2, 0)) @@ -1020,17 +994,13 @@ def get_feature_correlation_matrix( rna_size_factor=rna_size_factor, transform_batch=b, ) - flattened = np.zeros( - (denoised_data.shape[0] * n_samples, denoised_data.shape[1]) - ) + flattened = np.zeros((denoised_data.shape[0] * n_samples, denoised_data.shape[1])) for i in range(n_samples): flattened[ denoised_data.shape[0] * (i) : denoised_data.shape[0] * (i + 1) ] = denoised_data[:, :, i] if log_transform is True: - flattened[:, : self.n_genes] = np.log( - flattened[:, : self.n_genes] + 1e-8 - ) + flattened[:, : self.n_genes] = np.log(flattened[:, : self.n_genes] + 1e-8) flattened[:, self.n_genes :] = np.log1p(flattened[:, self.n_genes :]) if correlation_type == "pearson": corr_matrix = np.corrcoef(flattened, rowvar=False) @@ -1102,19 +1072,15 @@ def _get_totalvi_protein_priors(self, adata, n_cells=100): with warnings.catch_warnings(): warnings.filterwarnings("error") - logger.info( - "Computing empirical prior initialization for protein background." 
- ) + logger.info("Computing empirical prior initialization for protein background.") adata = self._validate_anndata(adata) adata_manager = self.get_anndata_manager(adata) pro_exp = adata_manager.get_from_registry(REGISTRY_KEYS.PROTEIN_EXP_KEY) - pro_exp = ( - pro_exp.to_numpy() if isinstance(pro_exp, pd.DataFrame) else pro_exp + pro_exp = pro_exp.to_numpy() if isinstance(pro_exp, pd.DataFrame) else pro_exp + batch_mask = adata_manager.get_state_registry(REGISTRY_KEYS.PROTEIN_EXP_KEY).get( + fields.ProteinObsmField.PROTEIN_BATCH_MASK ) - batch_mask = adata_manager.get_state_registry( - REGISTRY_KEYS.PROTEIN_EXP_KEY - ).get(fields.ProteinObsmField.PROTEIN_BATCH_MASK) batch = adata_manager.get_from_registry(REGISTRY_KEYS.BATCH_KEY).ravel() cats = adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY)[ fields.CategoricalObsField.CATEGORICAL_MAPPING_KEY @@ -1181,9 +1147,7 @@ def _get_totalvi_protein_priors(self, adata, n_cells=100): # repeat prior for each protein batch_avg_mus = np.array(batch_avg_mus, dtype=np.float32).reshape(1, -1) - batch_avg_scales = np.array(batch_avg_scales, dtype=np.float32).reshape( - 1, -1 - ) + batch_avg_scales = np.array(batch_avg_scales, dtype=np.float32).reshape(1, -1) batch_avg_mus = np.tile(batch_avg_mus, (pro_exp.shape[1], 1)) batch_avg_scales = np.tile(batch_avg_scales, (pro_exp.shape[1], 1)) @@ -1193,9 +1157,7 @@ def _get_totalvi_protein_priors(self, adata, n_cells=100): def get_protein_background_mean(self, adata, indices, batch_size): """Get protein background mean.""" adata = self._validate_anndata(adata) - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) background_mean = [] for tensors in scdl: _, inference_outputs, _ = self.module.forward(tensors) @@ -1251,9 +1213,7 @@ def setup_anndata( fields.CategoricalJointObsField( REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys ), - fields.NumericalJointObsField( - REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys - ), + fields.NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), fields.ProteinObsmField( REGISTRY_KEYS.PROTEIN_EXP_KEY, protein_expression_obsm_key, @@ -1263,9 +1223,7 @@ def setup_anndata( is_count_data=True, ), ] - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) @@ -1355,8 +1313,6 @@ def setup_mudata( mod_required=True, ), ] - adata_manager = AnnDataManager( - fields=mudata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=mudata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(mdata, **kwargs) cls.register_manager(adata_manager) diff --git a/scvi/model/_utils.py b/scvi/model/_utils.py index e573301773..e8d3a7575d 100644 --- a/scvi/model/_utils.py +++ b/scvi/model/_utils.py @@ -89,9 +89,7 @@ def parse_device_args( """ valid = [None, "torch", "jax"] if return_device not in valid: - raise InvalidParameterError( - param="return_device", value=return_device, valid=valid - ) + raise InvalidParameterError(param="return_device", value=return_device, valid=valid) _validate_single_device = validate_single_device and devices != "auto" cond1 = isinstance(devices, list) and len(devices) > 1 @@ -252,12 +250,8 @@ def cite_seq_raw_counts_properties( properties = { 
"raw_mean1": np.concatenate([gp["raw_mean1"], mean1_pro]), "raw_mean2": np.concatenate([gp["raw_mean2"], mean2_pro]), - "non_zeros_proportion1": np.concatenate( - [gp["non_zeros_proportion1"], nonz1_pro] - ), - "non_zeros_proportion2": np.concatenate( - [gp["non_zeros_proportion2"], nonz2_pro] - ), + "non_zeros_proportion1": np.concatenate([gp["non_zeros_proportion1"], nonz1_pro]), + "non_zeros_proportion2": np.concatenate([gp["non_zeros_proportion2"], nonz2_pro]), "raw_normalized_mean1": np.concatenate([gp["raw_normalized_mean1"], nan]), "raw_normalized_mean2": np.concatenate([gp["raw_normalized_mean2"], nan]), } @@ -307,9 +301,7 @@ def _get_batch_code_from_category( if not isinstance(category, IterableClass) or isinstance(category, str): category = [category] - batch_mappings = adata_manager.get_state_registry( - REGISTRY_KEYS.BATCH_KEY - ).categorical_mapping + batch_mappings = adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).categorical_mapping batch_code = [] for cat in category: if cat is None: @@ -352,9 +344,7 @@ def _init_library_size( for i_batch in np.unique(batch_indices): idx_batch = np.squeeze(batch_indices == i_batch) - batch_data = data[ - idx_batch.nonzero()[0] - ] # h5ad requires integer indexing arrays. + batch_data = data[idx_batch.nonzero()[0]] # h5ad requires integer indexing arrays. sum_counts = batch_data.sum(axis=1) masked_log_sum = np.ma.log(sum_counts) if np.ma.is_masked(masked_log_sum): diff --git a/scvi/model/base/_archesmixin.py b/scvi/model/base/_archesmixin.py index d5cefcf660..15d9169cc6 100644 --- a/scvi/model/base/_archesmixin.py +++ b/scvi/model/base/_archesmixin.py @@ -83,9 +83,7 @@ def load_query_data( validate_single_device=True, ) - attr_dict, var_names, load_state_dict = _get_loaded_data( - reference_model, device=device - ) + attr_dict, var_names, load_state_dict = _get_loaded_data(reference_model, device=device) if inplace_subset_query_vars: logger.debug("Subsetting query vars to reference vars.") @@ -94,9 +92,7 @@ def load_query_data( registry = attr_dict.pop("registry_") if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__: - raise ValueError( - "It appears you are loading a model from a different class." - ) + raise ValueError("It appears you are loading a model from a different class.") if _SETUP_ARGS_KEY not in registry: raise ValueError( diff --git a/scvi/model/base/_base_model.py b/scvi/model/base/_base_model.py index b1dccb0ebf..9c7a866996 100644 --- a/scvi/model/base/_base_model.py +++ b/scvi/model/base/_base_model.py @@ -36,7 +36,9 @@ logger = logging.getLogger(__name__) -_UNTRAINED_WARNING_MESSAGE = "Trying to query inferred values from an untrained model. Please train the model first." +_UNTRAINED_WARNING_MESSAGE = ( + "Trying to query inferred values from an untrained model. Please train the model first." +) _SETUP_INPUTS_EXCLUDED_PARAMS = {"adata", "mdata", "kwargs"} @@ -80,9 +82,7 @@ def __init__(self, adata: AnnOrMuData | None = None): # supports minified-data mode (i.e. inherits from the abstract BaseMinifiedModeModelClass). # If not, raise an error to inform the user of the lack of minified-data functionality # for this model - data_is_minified = ( - adata is not None and _get_adata_minify_type(adata) is not None - ) + data_is_minified = adata is not None and _get_adata_minify_type(adata) is not None if data_is_minified and not issubclass(type(self), BaseMinifiedModeModelClass): raise NotImplementedError( f"The {type(self).__name__} model currently does not support minified data." 
@@ -90,9 +90,7 @@ def __init__(self, adata: AnnOrMuData | None = None): self.id = str(uuid4()) # Used for cls._manager_store keys. if adata is not None: self._adata = adata - self._adata_manager = self._get_most_recent_anndata_manager( - adata, required=True - ) + self._adata_manager = self._get_most_recent_anndata_manager(adata, required=True) self._register_manager_for_instance(self.adata_manager) # Suffix registry instance variable with _ to include it when saving the model. self.registry_ = self._adata_manager.registry @@ -198,9 +196,7 @@ def _create_modalities_attr_dict( } extra_modalities = set(modalities) - set(filtered_modalities) if len(extra_modalities) > 0: - raise ValueError( - f"Extraneous modality mapping(s) detected: {extra_modalities}" - ) + raise ValueError(f"Extraneous modality mapping(s) detected: {extra_modalities}") return attrdict(filtered_modalities) @classmethod @@ -253,18 +249,14 @@ def deregister_manager(self, adata: AnnData | None = None): for adata_id in cls_managers_to_clear: # don't clear the current manager by default - is_current_adata = ( - adata is None and adata_id == self.adata_manager.adata_uuid - ) + is_current_adata = adata is None and adata_id == self.adata_manager.adata_uuid if is_current_adata or adata_id not in cls_manager_store: continue del cls_manager_store[adata_id] for adata_id in instance_managers_to_clear: # don't clear the current manager by default - is_current_adata = ( - adata is None and adata_id == self.adata_manager.adata_uuid - ) + is_current_adata = adata is None and adata_id == self.adata_manager.adata_uuid if is_current_adata or adata_id not in instance_manager_store: continue del instance_manager_store[adata_id] @@ -352,9 +344,7 @@ def get_anndata_manager( adata_manager = cls._per_instance_manager_store[self.id][adata_id] if adata_manager.adata is not adata: - logger.info( - "AnnData object appears to be a copy. Attempting to transfer setup." - ) + logger.info("AnnData object appears to be a copy. Attempting to transfer setup.") _assign_adata_uuid(adata, overwrite=True) adata_manager = self.adata_manager.transfer_fields(adata) self._register_manager_for_instance(adata_manager) @@ -458,27 +448,21 @@ def _validate_anndata( "Input AnnData not setup with scvi-tools. " + "attempting to transfer AnnData setup" ) - self._register_manager_for_instance( - self.adata_manager.transfer_fields(adata) - ) + self._register_manager_for_instance(self.adata_manager.transfer_fields(adata)) else: # Case where correct AnnDataManager is found, replay registration as necessary. adata_manager.validate() return adata - def _check_if_trained( - self, warn: bool = True, message: str = _UNTRAINED_WARNING_MESSAGE - ): + def _check_if_trained(self, warn: bool = True, message: str = _UNTRAINED_WARNING_MESSAGE): """Check if the model is trained. If not trained and `warn` is True, raise a warning, else raise a RuntimeError. 
""" if not self.is_trained_: if warn: - warnings.warn( - message, UserWarning, stacklevel=settings.warnings_stacklevel - ) + warnings.warn(message, UserWarning, stacklevel=settings.warnings_stacklevel) else: raise RuntimeError(message) @@ -526,9 +510,7 @@ def history(self): def _get_user_attributes(self): """Returns all the self attributes defined in a model class, e.g., `self.is_trained_`.""" attributes = inspect.getmembers(self, lambda a: not (inspect.isroutine(a))) - attributes = [ - a for a in attributes if not (a[0].startswith("__") and a[0].endswith("__")) - ] + attributes = [a for a in attributes if not (a[0].startswith("__") and a[0].endswith("__"))] attributes = [a for a in attributes if not a[0].startswith("_abc_")] return attributes @@ -702,9 +684,7 @@ def load( registry = attr_dict.pop("registry_") if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__: - raise ValueError( - "It appears you are loading a model from a different class." - ) + raise ValueError("It appears you are loading a model from a different class.") if _SETUP_ARGS_KEY not in registry: raise ValueError( @@ -716,9 +696,7 @@ def load( # the saved model. This enables simple backwards compatibility in the case of # newly introduced fields or parameters. method_name = registry.get(_SETUP_METHOD_NAME, "setup_anndata") - getattr(cls, method_name)( - adata, source_registry=registry, **registry[_SETUP_ARGS_KEY] - ) + getattr(cls, method_name)(adata, source_registry=registry, **registry[_SETUP_ARGS_KEY]) model = _initialize_model(cls, adata, attr_dict) model.module.on_load(model) diff --git a/scvi/model/base/_differential.py b/scvi/model/base/_differential.py index 08b05b56a2..d714040f1c 100644 --- a/scvi/model/base/_differential.py +++ b/scvi/model/base/_differential.py @@ -236,22 +236,14 @@ def get_bayes_factors( # First case: same batch normalization in two groups logger.debug("Same batches in both cell groups") n_batches = len(set(batchid1_vals)) - n_samples_per_batch = ( - m_permutation // n_batches if m_permutation is not None else None - ) - logger.debug( - f"Using {n_samples_per_batch} samples per batch for pair matching" - ) + n_samples_per_batch = m_permutation // n_batches if m_permutation is not None else None + logger.debug(f"Using {n_samples_per_batch} samples per batch for pair matching") scales_1 = [] scales_2 = [] for batch_val in set(batchid1_vals): # Select scale samples that originate from the same batch id - scales_1_batch = scales_batches_1["scale"][ - scales_batches_1["batch"] == batch_val - ] - scales_2_batch = scales_batches_2["scale"][ - scales_batches_2["batch"] == batch_val - ] + scales_1_batch = scales_batches_1["scale"][scales_batches_1["batch"] == batch_val] + scales_2_batch = scales_batches_2["scale"][scales_batches_2["batch"] == batch_val] # Create more pairs scales_1_local, scales_2_local = pairs_sampler( @@ -324,9 +316,7 @@ def lfc(x, y): def m1_domain_fn(samples): delta_ = ( - delta - if delta is not None - else estimate_delta(lfc_means=samples.mean(0)) + delta if delta is not None else estimate_delta(lfc_means=samples.mean(0)) ) logger.debug(f"Using delta ~ {delta_:.2f}") return np.abs(samples) >= delta_ @@ -423,9 +413,7 @@ def scale_sampler( """ # Get overall number of desired samples and desired batches if batchid is None and not use_observed_batches: - batch_registry = self.adata_manager.get_state_registry( - REGISTRY_KEYS.BATCH_KEY - ) + batch_registry = self.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY) batchid = 
batch_registry.categorical_mapping if use_observed_batches: if batchid is not None: @@ -534,8 +522,7 @@ def estimate_pseudocounts_offset( max_scales_a = np.max(scales_a, 0) max_scales_b = np.max(scales_b, 0) asserts = ( - (max_scales_a.shape == where_zero_a.shape) - and (max_scales_b.shape == where_zero_b.shape) + (max_scales_a.shape == where_zero_a.shape) and (max_scales_b.shape == where_zero_b.shape) ) and (where_zero_a.shape == where_zero_b.shape) if not asserts: raise ValueError( @@ -704,9 +691,7 @@ def describe_continuous_distrib( return dist_props -def save_cluster_xlsx( - filepath: str, de_results: list[pd.DataFrame], cluster_names: list -): +def save_cluster_xlsx(filepath: str, de_results: list[pd.DataFrame], cluster_names: list): """Saves multi-clusters DE in an xlsx sheet. Parameters diff --git a/scvi/model/base/_jaxmixin.py b/scvi/model/base/_jaxmixin.py index fe7145981a..5115408949 100644 --- a/scvi/model/base/_jaxmixin.py +++ b/scvi/model/base/_jaxmixin.py @@ -99,9 +99,7 @@ def train( # Ignore Pytorch Lightning warnings for Jax workarounds. with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", category=UserWarning, module=r"pytorch_lightning.*" - ) + warnings.filterwarnings("ignore", category=UserWarning, module=r"pytorch_lightning.*") runner = self._train_runner_cls( self, training_plan=self.training_plan, diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index 4dfb0aef01..6daf86d939 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -226,9 +226,9 @@ def _get_one_posterior_sample( sample = self.module.guide(*args, **kwargs) else: guide_trace = poutine.trace(self.module.guide).get_trace(*args, **kwargs) - model_trace = poutine.trace( - poutine.replay(self.module.model, guide_trace) - ).get_trace(*args, **kwargs) + model_trace = poutine.trace(poutine.replay(self.module.model, guide_trace)).get_trace( + *args, **kwargs + ) sample = { name: site["value"] for name, site in model_trace.nodes.items() @@ -398,9 +398,7 @@ def _posterior_samples_minibatch( batch_size = batch_size if batch_size is not None else settings.batch_size - train_dl = AnnDataLoader( - self.adata_manager, shuffle=False, batch_size=batch_size - ) + train_dl = AnnDataLoader(self.adata_manager, shuffle=False, batch_size=batch_size) # sample local parameters i = 0 for tensor_dict in track( @@ -424,20 +422,14 @@ def _posterior_samples_minibatch( obs_plate_dim = list(obs_plate_sites.values())[0] sample_kwargs_obs_plate = sample_kwargs.copy() - sample_kwargs_obs_plate[ - "return_sites" - ] = self._get_obs_plate_return_sites( + sample_kwargs_obs_plate["return_sites"] = self._get_obs_plate_return_sites( sample_kwargs["return_sites"], list(obs_plate_sites.keys()) ) sample_kwargs_obs_plate["show_progress"] = False - samples = self._get_posterior_samples( - args, kwargs, **sample_kwargs_obs_plate - ) + samples = self._get_posterior_samples(args, kwargs, **sample_kwargs_obs_plate) else: - samples_ = self._get_posterior_samples( - args, kwargs, **sample_kwargs_obs_plate - ) + samples_ = self._get_posterior_samples(args, kwargs, **sample_kwargs_obs_plate) samples = { k: np.array( @@ -446,9 +438,7 @@ def _posterior_samples_minibatch( [samples[k][j], samples_[k][j]], axis=obs_plate_dim, ) - for j in range( - len(samples[k]) - ) # for each sample (in 0 dimension + for j in range(len(samples[k])) # for each sample (in 0 dimension ] ) for k in samples.keys() # for each variable @@ -458,9 +448,7 @@ def _posterior_samples_minibatch( # sample global parameters 
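A sketch of the differential-expression API that drives the `_differential.py` code above, assuming a trained model and an illustrative `"cell_type"` grouping; passing `delta=None` would fall back to the `estimate_delta` heuristic shown above:

    de = model.differential_expression(
        groupby="cell_type",
        group1="B cells",
        group2="T cells",
        mode="change",
        delta=0.25,
    )
    hits = de[de["is_de_fdr_0.05"]]  # posterior FDR filter, cf. _fdr_de_prediction below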
global_samples = self._get_posterior_samples(args, kwargs, **sample_kwargs) global_samples = { - k: v - for k, v in global_samples.items() - if k not in list(obs_plate_sites.keys()) + k: v for k, v in global_samples.items() if k not in list(obs_plate_sites.keys()) } for k in global_samples.keys(): @@ -549,8 +537,6 @@ def sample_posterior( "q95": lambda x, axis: np.quantile(x, 0.95, axis=axis), } for k, fun in summary_fun.items(): - results[f"post_sample_{k}"] = { - v: fun(samples[v], axis=0) for v in param_names - } + results[f"post_sample_{k}"] = {v: fun(samples[v], axis=0) for v in param_names} return results diff --git a/scvi/model/base/_rnamixin.py b/scvi/model/base/_rnamixin.py index 00a4ab53ed..24fe5ba7fe 100644 --- a/scvi/model/base/_rnamixin.py +++ b/scvi/model/base/_rnamixin.py @@ -37,9 +37,7 @@ def _get_transform_batch_gen_kwargs(self, batch): if "transform_batch" in inspect.signature(self.module.generative).parameters: return {"transform_batch": batch} else: - raise NotImplementedError( - "Transforming batches is not implemented for this model." - ) + raise NotImplementedError("Transforming batches is not implemented for this model.") def _get_importance_weights( self, @@ -111,9 +109,7 @@ def _get_importance_weights( ) mask = torch.tensor(anchor_cells) qz_anchor = subset_distribution(qz, mask, 0) # n_anchors, n_latent - log_qz = qz_anchor.log_prob(zs.unsqueeze(-2)).sum( - dim=-1 - ) # n_samples, n_cells, n_anchors + log_qz = qz_anchor.log_prob(zs.unsqueeze(-2)).sum(dim=-1) # n_samples, n_cells, n_anchors log_px_z = [] distributions_px = deep_to(px, device=device) @@ -139,9 +135,7 @@ def _get_importance_weights( dim=1, ) if truncation: - tau = torch.logsumexp(importance_weight, 0) - np.log( - importance_weight.shape[0] - ) + tau = torch.logsumexp(importance_weight, 0) - np.log(importance_weight.shape[0]) importance_weight = torch.clamp(importance_weight, min=tau) log_probs = importance_weight - torch.logsumexp(importance_weight, 0) @@ -222,9 +216,7 @@ def get_normalized_expression( if n_samples_overall is not None: assert n_samples == 1 # default value n_samples = n_samples_overall // len(indices) + 1 - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) transform_batch = _get_batch_code_from_category( self.get_anndata_manager(adata, required=True), transform_batch @@ -385,9 +377,7 @@ def differential_expression( weights=weights, **importance_weighting_kwargs, ) - representation_fn = ( - self.get_latent_representation if filter_outlier_cells else None - ) + representation_fn = self.get_latent_representation if filter_outlier_cells else None result = _de_core( self.get_anndata_manager(adata, required=True), @@ -457,9 +447,7 @@ def posterior_predictive_sample( import sparse adata = self._validate_anndata(adata) - dataloader = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + dataloader = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) if gene_list is None: gene_mask = slice(None) @@ -513,9 +501,7 @@ def _get_denoised_samples( denoised_samples """ adata = self._validate_anndata(adata) - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) data_loader_list = [] for tensors in scdl: @@ -616,9 +602,7 @@ def get_feature_correlation_matrix( rna_size_factor=rna_size_factor, 
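For Pyro-based models, the `_pyromixin.py` machinery above is reached through `sample_posterior`; a hedged sketch (exact keyword names can vary by version, and the return keys follow the `post_sample_{k}` pattern visible above):

    samples = model.sample_posterior(num_samples=100, batch_size=2048)
    means = samples["post_sample_means"]  # per-site posterior means
    q95 = samples["post_sample_q95"]      # 95% quantiles, per the summary_fun above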
transform_batch=b, ) - flattened = np.zeros( - (denoised_data.shape[0] * n_samples, denoised_data.shape[1]) - ) + flattened = np.zeros((denoised_data.shape[0] * n_samples, denoised_data.shape[1])) for i in range(n_samples): if n_samples == 1: flattened[ @@ -633,9 +617,7 @@ def get_feature_correlation_matrix( elif correlation_type == "spearman": corr_matrix, _ = spearmanr(flattened) else: - raise ValueError( - "Unknown correlation type. Choose one of 'spearman', 'pearson'." - ) + raise ValueError("Unknown correlation type. Choose one of 'spearman', 'pearson'.") corr_mats.append(corr_matrix) corr_matrix = np.mean(np.stack(corr_mats), axis=0) var_names = adata.var_names @@ -668,9 +650,7 @@ def get_likelihood_parameters( """ adata = self._validate_anndata(adata) - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) dropout_list = [] mean_list = [] @@ -747,9 +727,7 @@ def get_latent_library_size( self._check_if_trained(warn=False) adata = self._validate_anndata(adata) - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) libraries = [] for tensors in scdl: inference_inputs = self.module._get_inference_input(tensors) diff --git a/scvi/model/base/_training_mixin.py b/scvi/model/base/_training_mixin.py index ca050ca823..aeeb3d9395 100644 --- a/scvi/model/base/_training_mixin.py +++ b/scvi/model/base/_training_mixin.py @@ -110,9 +110,7 @@ def train( validation_size=validation_size, batch_size=batch_size, shuffle_set_split=shuffle_set_split, - distributed_sampler=use_distributed_sampler( - trainer_kwargs.get("strategy", None) - ), + distributed_sampler=use_distributed_sampler(trainer_kwargs.get("strategy", None)), load_sparse_tensor=load_sparse_tensor, **datasplitter_kwargs, ) diff --git a/scvi/model/base/_utils.py b/scvi/model/base/_utils.py index c2cf6c8f50..41632197e4 100644 --- a/scvi/model/base/_utils.py +++ b/scvi/model/base/_utils.py @@ -42,9 +42,7 @@ def _load_legacy_saved_files( if os.path.exists(adata_path): adata = read(adata_path) elif not os.path.exists(adata_path): - raise ValueError( - "Save path contains no saved anndata and no adata was passed." - ) + raise ValueError("Save path contains no saved anndata and no adata was passed.") else: adata = None @@ -87,9 +85,7 @@ def _load_saved_files( else: adata = anndata.read(adata_path) else: - raise ValueError( - "Save path contains no saved anndata and no adata was passed." - ) + raise ValueError("Save path contains no saved anndata and no adata was passed.") else: adata = None @@ -219,9 +215,7 @@ def _de_core( if group1 is None and idx1 is None: group1 = adata.obs[groupby].astype("category").cat.categories.tolist() if len(group1) == 1: - raise ValueError( - "Only a single group in the data. Can't run DE on a single group." - ) + raise ValueError("Only a single group in the data. 
Can't run DE on a single group.") if not isinstance(group1, IterableClass) or isinstance(group1, str): group1 = [group1] @@ -290,9 +284,7 @@ def _fdr_de_prediction(posterior_probas: pd.Series, fdr: float = 0.05) -> pd.Ser sorted_pgs = posterior_probas.sort_values(ascending=False) cumulative_fdr = (1.0 - sorted_pgs).cumsum() / (1.0 + np.arange(len(sorted_pgs))) d = (cumulative_fdr <= fdr).sum() - is_pred_de = pd.Series( - np.zeros_like(cumulative_fdr).astype(bool), index=sorted_pgs.index - ) + is_pred_de = pd.Series(np.zeros_like(cumulative_fdr).astype(bool), index=sorted_pgs.index) is_pred_de.iloc[:d] = True is_pred_de = is_pred_de.loc[original_index] return is_pred_de diff --git a/scvi/model/base/_vaemixin.py b/scvi/model/base/_vaemixin.py index 29506655af..3de6fd9d43 100644 --- a/scvi/model/base/_vaemixin.py +++ b/scvi/model/base/_vaemixin.py @@ -40,9 +40,7 @@ def get_elbo( Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. """ adata = self._validate_anndata(adata) - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) elbo = compute_elbo(self.module, scdl) return -elbo @@ -131,9 +129,7 @@ def get_reconstruction_error( Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. """ adata = self._validate_anndata(adata) - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) reconstruction_error = compute_reconstruction_error(self.module, scdl) return reconstruction_error @@ -178,9 +174,7 @@ def get_latent_representation( self._check_if_trained(warn=False) adata = self._validate_anndata(adata) - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) latent = [] latent_qzm = [] latent_qzv = [] diff --git a/scvi/model/utils/_mde.py b/scvi/model/utils/_mde.py index 7dedeff92b..dcf7a20bb4 100644 --- a/scvi/model/utils/_mde.py +++ b/scvi/model/utils/_mde.py @@ -65,9 +65,7 @@ def mde( try: import pymde except ImportError as err: - raise ImportError( - "Please install pymde package via `pip install pymde`" - ) from err + raise ImportError("Please install pymde package via `pip install pymde`") from err if seed is None and settings.seed is not None: seed = settings.seed diff --git a/scvi/module/_amortizedlda.py b/scvi/module/_amortizedlda.py index a66bfe3933..65946725a5 100644 --- a/scvi/module/_amortizedlda.py +++ b/scvi/module/_amortizedlda.py @@ -117,9 +117,9 @@ def forward( with pyro.plate("topics", self.n_topics), poutine.scale(None, kl_weight): log_topic_feature_dist = pyro.sample( "log_topic_feature_dist", - dist.Normal( - self.topic_feature_prior_mu, self.topic_feature_prior_sigma - ).to_event(1), + dist.Normal(self.topic_feature_prior_mu, self.topic_feature_prior_sigma).to_event( + 1 + ), ) topic_feature_dist = F.softmax(log_topic_feature_dist, dim=1) @@ -129,9 +129,7 @@ def forward( with poutine.scale(None, kl_weight): log_cell_topic_dist = pyro.sample( "log_cell_topic_dist", - dist.Normal( - self.cell_topic_prior_mu, self.cell_topic_prior_sigma - ).to_event(1), + dist.Normal(self.cell_topic_prior_mu, self.cell_topic_prior_sigma).to_event(1), ) cell_topic_dist = F.softmax(log_cell_topic_dist, dim=1) @@ -210,9 +208,7 @@ def forward( cell_topic_posterior_sigma = 
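The `_vaemixin.py` and `_mde.py` hunks above combine into the usual embedding workflow; a sketch assuming a trained model (pymde must be installed, per the ImportError above):

    import scvi

    latent = model.get_latent_representation()
    adata.obsm["X_scVI"] = latent
    adata.obsm["X_mde"] = scvi.model.utils.mde(latent)  # minimum-distortion embedding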
cell_topic_posterior.scale**2 pyro.sample( "log_cell_topic_dist", - dist.Normal( - cell_topic_posterior_mu, cell_topic_posterior_sigma - ).to_event(1), + dist.Normal(cell_topic_posterior_mu, cell_topic_posterior_sigma).to_event(1), ) diff --git a/scvi/module/_autozivae.py b/scvi/module/_autozivae.py index 278b86fd92..36d6f0d473 100644 --- a/scvi/module/_autozivae.py +++ b/scvi/module/_autozivae.py @@ -62,9 +62,7 @@ def __init__( alpha_prior: Tunable[float] = 0.5, beta_prior: Tunable[float] = 0.5, minimal_dropout: Tunable[float] = 0.01, - zero_inflation: Tunable[ - Literal["gene", "gene-batch", "gene-label", "gene-cell"] - ] = "gene", + zero_inflation: Tunable[Literal["gene", "gene-batch", "gene-label", "gene-cell"]] = "gene", **kwargs, ) -> None: if "reconstruction_loss" in kwargs: @@ -92,57 +90,35 @@ def __init__( if alpha_prior is None: self.alpha_prior_logit = torch.nn.Parameter(torch.randn(1)) else: - self.register_buffer( - "alpha_prior_logit", torch.tensor([logit(alpha_prior)]) - ) + self.register_buffer("alpha_prior_logit", torch.tensor([logit(alpha_prior)])) if beta_prior is None: self.beta_prior_logit = torch.nn.Parameter(torch.randn(1)) else: - self.register_buffer( - "beta_prior_logit", torch.tensor([logit(alpha_prior)]) - ) + self.register_buffer("beta_prior_logit", torch.tensor([logit(alpha_prior)])) elif self.zero_inflation == "gene-batch": - self.alpha_posterior_logit = torch.nn.Parameter( - torch.randn(n_input, self.n_batch) - ) - self.beta_posterior_logit = torch.nn.Parameter( - torch.randn(n_input, self.n_batch) - ) + self.alpha_posterior_logit = torch.nn.Parameter(torch.randn(n_input, self.n_batch)) + self.beta_posterior_logit = torch.nn.Parameter(torch.randn(n_input, self.n_batch)) if alpha_prior is None: self.alpha_prior_logit = torch.nn.parameter(torch.randn(1, self.n_batch)) else: - self.register_buffer( - "alpha_prior_logit", torch.tensor([logit(alpha_prior)]) - ) + self.register_buffer("alpha_prior_logit", torch.tensor([logit(alpha_prior)])) if beta_prior is None: self.beta_prior_logit = torch.nn.parameter(torch.randn(1, self.n_batch)) else: - self.register_buffer( - "beta_prior_logit", torch.tensor([logit(beta_prior)]) - ) + self.register_buffer("beta_prior_logit", torch.tensor([logit(beta_prior)])) elif self.zero_inflation == "gene-label": - self.alpha_posterior_logit = torch.nn.Parameter( - torch.randn(n_input, self.n_labels) - ) - self.beta_posterior_logit = torch.nn.Parameter( - torch.randn(n_input, self.n_labels) - ) + self.alpha_posterior_logit = torch.nn.Parameter(torch.randn(n_input, self.n_labels)) + self.beta_posterior_logit = torch.nn.Parameter(torch.randn(n_input, self.n_labels)) if alpha_prior is None: - self.alpha_prior_logit = torch.nn.parameter( - torch.randn(1, self.n_labels) - ) + self.alpha_prior_logit = torch.nn.parameter(torch.randn(1, self.n_labels)) else: - self.register_buffer( - "alpha_prior_logit", torch.tensor([logit(alpha_prior)]) - ) + self.register_buffer("alpha_prior_logit", torch.tensor([logit(alpha_prior)])) if beta_prior is None: self.beta_prior_logit = torch.nn.parameter(torch.randn(1, self.n_labels)) else: - self.register_buffer( - "beta_prior_logit", torch.tensor([logit(beta_prior)]) - ) + self.register_buffer("beta_prior_logit", torch.tensor([logit(beta_prior)])) else: # gene-cell raise Exception("Gene-cell not implemented yet for AutoZI") @@ -161,9 +137,7 @@ def get_alphas_betas( if as_numpy: for key, value in outputs.items(): outputs[key] = ( - value.detach().cpu().numpy() - if value.requires_grad - else value.cpu().numpy() 
+ value.detach().cpu().numpy() if value.requires_grad else value.cpu().numpy() ) return outputs @@ -212,9 +186,7 @@ def reshape_bernoulli( else: bernoulli_params_res = [] for sample in range(bernoulli_params.shape[0]): - bernoulli_params_res.append( - F.linear(one_hot_label, bernoulli_params[sample]) - ) + bernoulli_params_res.append(F.linear(one_hot_label, bernoulli_params[sample])) bernoulli_params = torch.stack(bernoulli_params_res) elif self.zero_inflation == "gene-batch": one_hot_batch = one_hot(batch_index, self.n_batch) @@ -224,9 +196,7 @@ def reshape_bernoulli( else: bernoulli_params_res = [] for sample in range(bernoulli_params.shape[0]): - bernoulli_params_res.append( - F.linear(one_hot_batch, bernoulli_params[sample]) - ) + bernoulli_params_res.append(F.linear(one_hot_batch, bernoulli_params[sample])) bernoulli_params = torch.stack(bernoulli_params_res) return bernoulli_params @@ -258,16 +228,12 @@ def sample_bernoulli_params( ) ) - bernoulli_params = self.sample_from_beta_distribution( - alpha_posterior, beta_posterior - ) + bernoulli_params = self.sample_from_beta_distribution(alpha_posterior, beta_posterior) bernoulli_params = self.reshape_bernoulli(bernoulli_params, batch_index, y) return bernoulli_params - def rescale_dropout( - self, px_dropout: torch.Tensor, eps_log: float = 1e-8 - ) -> torch.Tensor: + def rescale_dropout(self, px_dropout: torch.Tensor, eps_log: float = 1e-8) -> torch.Tensor: """Rescale dropout rate.""" if self.minimal_dropout > 0.0: dropout_prob_rescaled = self.minimal_dropout + ( @@ -325,9 +291,7 @@ def compute_global_kl_divergence(self) -> torch.Tensor: alpha_prior = outputs["alpha_prior"] beta_prior = outputs["beta_prior"] - return kl( - Beta(alpha_posterior, beta_posterior), Beta(alpha_prior, beta_prior) - ).sum() + return kl(Beta(alpha_posterior, beta_posterior), Beta(alpha_prior, beta_prior)).sum() def get_reconstruction_loss( self, @@ -341,9 +305,7 @@ def get_reconstruction_loss( ) -> torch.Tensor: """Compute the reconstruction loss.""" # LLs for NB and ZINB - ll_zinb = torch.log( - 1.0 - bernoulli_params + eps_log - ) + ZeroInflatedNegativeBinomial( + ll_zinb = torch.log(1.0 - bernoulli_params + eps_log) + ZeroInflatedNegativeBinomial( mu=px_rate, theta=px_r, zi_logits=px_dropout ).log_prob(x) ll_nb = torch.log(bernoulli_params + eps_log) + NegativeBinomial( @@ -352,9 +314,7 @@ def get_reconstruction_loss( # Reconstruction loss using a logsumexp-type computation ll_max = torch.max(ll_zinb, ll_nb) - ll_tot = ll_max + torch.log( - torch.exp(ll_nb - ll_max) + torch.exp(ll_zinb - ll_max) - ) + ll_tot = ll_max + torch.log(torch.exp(ll_nb - ll_max) + torch.exp(ll_zinb - ll_max)) reconst_loss = -ll_tot.sum(dim=-1) return reconst_loss @@ -402,9 +362,7 @@ def loss( kl_divergence_bernoulli = self.compute_global_kl_divergence() # Reconstruction loss - reconst_loss = self.get_reconstruction_loss( - x, px_rate, px_r, px_dropout, bernoulli_params - ) + reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout, bernoulli_params) kl_global = kl_divergence_bernoulli kl_local_for_warmup = kl_divergence_z diff --git a/scvi/module/_jaxvae.py b/scvi/module/_jaxvae.py index 8ed48e0a89..7d234f6d85 100644 --- a/scvi/module/_jaxvae.py +++ b/scvi/module/_jaxvae.py @@ -92,9 +92,7 @@ def setup(self): "disp", lambda rng, shape: jax.random.normal(rng, shape), (self.n_input, 1) ) - def __call__( - self, z: jnp.ndarray, batch: jnp.ndarray, training: Optional[bool] = None - ): + def __call__(self, z: jnp.ndarray, batch: jnp.ndarray, training: Optional[bool] = 
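A sketch of querying the AutoZI posteriors handled by `get_alphas_betas` above; keep the default `alpha_prior`/`beta_prior` values, since the `None` branches earlier in this file call lowercase `torch.nn.parameter`, which is a module rather than a callable and would raise a TypeError:

    import scvi

    scvi.model.AUTOZI.setup_anndata(adata)  # raw counts assumed in .X
    model = scvi.model.AUTOZI(adata)
    model.train()

    ab = model.get_alphas_betas()
    # Beta posterior mean; per the reconstruction loss above, values near 1
    # favor the NB component, i.e. no zero-inflation for that gene
    nb_prob = ab["alpha_posterior"] / (ab["alpha_posterior"] + ab["beta_posterior"])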
None): """Forward pass.""" # TODO(adamgayoso): Test this training = nn.merge_param("training", self.training, training) diff --git a/scvi/module/_mrdeconv.py b/scvi/module/_mrdeconv.py index 3213df7424..9c92830bf1 100644 --- a/scvi/module/_mrdeconv.py +++ b/scvi/module/_mrdeconv.py @@ -134,9 +134,7 @@ def __init__( self.V = torch.nn.Parameter(torch.randn(self.n_labels + 1, self.n_spots)) # within cell_type factor loadings - self.gamma = torch.nn.Parameter( - torch.randn(n_latent, self.n_labels, self.n_spots) - ) + self.gamma = torch.nn.Parameter(torch.randn(n_latent, self.n_labels, self.n_spots)) if mean_vprior is not None: self.p = mean_vprior.shape[1] self.register_buffer("mean_vprior", torch.tensor(mean_vprior)) @@ -223,9 +221,7 @@ def generative(self, x, ind_x): v_ind = torch.nn.functional.softplus(v_ind) # reshape and get gene expression value for all minibatch - gamma_ind = torch.transpose( - gamma_ind, 2, 0 - ) # minibatch_size, n_labels, n_latent + gamma_ind = torch.transpose(gamma_ind, 2, 0) # minibatch_size, n_labels, n_latent gamma_reshape = gamma_ind.reshape( (-1, self.n_latent) ) # minibatch_size * n_labels, n_latent @@ -238,9 +234,7 @@ def generative(self, x, ind_x): ) # (minibatch, n_labels, n_genes) # add the dummy cell type - eps = eps.repeat((m, 1)).view( - m, 1, -1 - ) # (M, 1, n_genes) <- this is the dummy cell type + eps = eps.repeat((m, 1)).view(m, 1, -1) # (M, 1, n_genes) <- this is the dummy cell type # account for gene specific bias and add noise r_hat = torch.cat( @@ -278,9 +272,7 @@ def loss( # eta prior likelihood mean = torch.zeros_like(self.eta) scale = torch.ones_like(self.eta) - glo_neg_log_likelihood_prior = ( - -self.eta_reg * Normal(mean, scale).log_prob(self.eta).sum() - ) + glo_neg_log_likelihood_prior = -self.eta_reg * Normal(mean, scale).log_prob(self.eta).sum() glo_neg_log_likelihood_prior += self.beta_reg * torch.var(self.beta) v_sparsity_loss = self.l1_reg * torch.abs(v).mean(1) @@ -312,9 +304,7 @@ def loss( # High v_sparsity_loss is detrimental early in training, scaling by kl_weight to increase over training epochs. 
loss = n_obs * ( - torch.mean( - reconst_loss + kl_weight * (neg_log_likelihood_prior + v_sparsity_loss) - ) + torch.mean(reconst_loss + kl_weight * (neg_log_likelihood_prior + v_sparsity_loss)) + glo_neg_log_likelihood_prior ) @@ -344,9 +334,7 @@ def get_proportions(self, x=None, keep_noise=False) -> np.ndarray: x_ = torch.log(1 + x) res = torch.nn.functional.softplus(self.V_encoder(x_)) else: - res = ( - torch.nn.functional.softplus(self.V).cpu().numpy().T - ) # n_spots, n_labels + 1 + res = torch.nn.functional.softplus(self.V).cpu().numpy().T # n_spots, n_labels + 1 # remove dummy cell type proportion values if not keep_noise: res = res[:, :-1] diff --git a/scvi/module/_multivae.py b/scvi/module/_multivae.py index 95eeadef60..5e444a36f6 100644 --- a/scvi/module/_multivae.py +++ b/scvi/module/_multivae.py @@ -50,9 +50,7 @@ def __init__( inject_covariates=deep_inject_covariates, **kwargs, ) - self.output = torch.nn.Sequential( - torch.nn.Linear(n_hidden, 1), torch.nn.LeakyReLU() - ) + self.output = torch.nn.Sequential(torch.nn.Linear(n_hidden, 1), torch.nn.LeakyReLU()) def forward(self, x: torch.Tensor, *cat_list: int): """Forward pass.""" @@ -161,17 +159,13 @@ def forward(self, z: torch.Tensor, *cat_list: int): py_back_cat_z = torch.cat([py_back, z], dim=-1) py_["back_alpha"] = self.py_back_mean_log_alpha(py_back_cat_z, *cat_list) - py_["back_beta"] = torch.exp( - self.py_back_mean_log_beta(py_back_cat_z, *cat_list) - ) + py_["back_beta"] = torch.exp(self.py_back_mean_log_beta(py_back_cat_z, *cat_list)) log_pro_back_mean = Normal(py_["back_alpha"], py_["back_beta"]).rsample() py_["rate_back"] = torch.exp(log_pro_back_mean) py_fore = self.py_fore_decoder(z, *cat_list) py_fore_cat_z = torch.cat([py_fore, z], dim=-1) - py_["fore_scale"] = ( - self.py_fore_scale_decoder(py_fore_cat_z, *cat_list) + 1 + 1e-8 - ) + py_["fore_scale"] = self.py_fore_scale_decoder(py_fore_cat_z, *cat_list) + 1 + 1e-8 py_["rate_fore"] = py_["rate_back"] * py_["fore_scale"] p_mixing = self.sigmoid_decoder(z, *cat_list) @@ -463,9 +457,7 @@ def __init__( torch.clamp(torch.randn(n_input_proteins, n_batch), -10, 1) ) else: - self.background_pro_alpha = torch.nn.Parameter( - torch.randn(n_input_proteins) - ) + self.background_pro_alpha = torch.nn.Parameter(torch.randn(n_input_proteins)) self.background_pro_log_beta = torch.nn.Parameter( torch.clamp(torch.randn(n_input_proteins), -10, 1) ) @@ -520,13 +512,9 @@ def __init__( if self.protein_dispersion == "protein": self.py_r = torch.nn.Parameter(2 * torch.rand(self.n_input_proteins)) elif self.protein_dispersion == "protein-batch": - self.py_r = torch.nn.Parameter( - 2 * torch.rand(self.n_input_proteins, n_batch) - ) + self.py_r = torch.nn.Parameter(2 * torch.rand(self.n_input_proteins, n_batch)) elif self.protein_dispersion == "protein-label": - self.py_r = torch.nn.Parameter( - 2 * torch.rand(self.n_input_proteins, n_labels) - ) + self.py_r = torch.nn.Parameter(2 * torch.rand(self.n_input_proteins, n_labels)) else: # protein-cell pass @@ -588,9 +576,7 @@ def inference( if self.n_input_regions == 0: x_chr = torch.zeros(x.shape[0], 1, device=x.device, requires_grad=False) else: - x_chr = x[ - :, self.n_input_genes : (self.n_input_genes + self.n_input_regions) - ] + x_chr = x[:, self.n_input_genes : (self.n_input_genes + self.n_input_regions)] mask_expr = x_rna.sum(dim=1) > 0 mask_acc = x_chr.sum(dim=1) > 0 @@ -691,9 +677,7 @@ def _get_generative_input(self, tensors, inference_outputs, transform_batch=None size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY size_factor = ( - 
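`get_proportions` above backs DestVI's deconvolution output; a sketch assuming `sc_model` is a trained `CondSCVI` reference and `st_adata` holds raw spatial counts (epoch count illustrative):

    import scvi

    scvi.model.DestVI.setup_anndata(st_adata)
    st_model = scvi.model.DestVI.from_rna_model(st_adata, sc_model)
    st_model.train(max_epochs=2500)
    # the dummy cell type is dropped unless keep_noise=True, per the hunk above
    proportions = st_model.get_proportions()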
torch.log(tensors[size_factor_key]) - if size_factor_key in tensors.keys() - else None + torch.log(tensors[size_factor_key]) if size_factor_key in tensors.keys() else None ) batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] @@ -775,9 +759,7 @@ def generative( px_r = torch.exp(px_r) # Protein Decoder - py_, log_pro_back_mean = self.z_decoder_pro( - decoder_input, batch_index, *categorical_input - ) + py_, log_pro_back_mean = self.z_decoder_pro(decoder_input, batch_index, *categorical_input) # Protein Dispersion if self.protein_dispersion == "protein-label": # py_r gets transposed - last dimension is n_proteins @@ -799,9 +781,7 @@ def generative( "log_pro_back_mean": log_pro_back_mean, } - def loss( - self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0 - ): + def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0): """Computes the loss function for the model.""" # Get the data x = tensors[REGISTRY_KEYS.X_KEY] @@ -821,9 +801,7 @@ def loss( # Compute Accessibility loss p = generative_outputs["p"] libsize_acc = inference_outputs["libsize_acc"] - rl_accessibility = self.get_reconstruction_loss_accessibility( - x_chr, p, libsize_acc - ) + rl_accessibility = self.get_reconstruction_loss_accessibility(x_chr, p, libsize_acc) # Compute Expression loss px_rate = generative_outputs["px_rate"] @@ -846,9 +824,7 @@ def loss( recon_loss_expression = rl_expression * mask_expr recon_loss_accessibility = rl_accessibility * mask_acc recon_loss_protein = rl_protein * mask_pro - recon_loss = ( - recon_loss_expression + recon_loss_accessibility + recon_loss_protein - ) + recon_loss = recon_loss_expression + recon_loss_accessibility + recon_loss_protein # Compute KLD between Z and N(0,I) qz_m = inference_outputs["qz_m"] @@ -891,9 +867,7 @@ def get_reconstruction_loss_expression(self, x, px_rate, px_r, px_dropout): rl = 0.0 if self.gene_likelihood == "zinb": rl = ( - -ZeroInflatedNegativeBinomial( - mu=px_rate, theta=px_r, zi_logits=px_dropout - ) + -ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout) .log_prob(x) .sum(dim=-1) ) @@ -905,16 +879,10 @@ def get_reconstruction_loss_expression(self, x, px_rate, px_r, px_dropout): def get_reconstruction_loss_accessibility(self, x, p, d): """Computes the reconstruction loss for the accessibility data.""" - reg_factor = ( - torch.sigmoid(self.region_factors) if self.region_factors is not None else 1 - ) - return torch.nn.BCELoss(reduction="none")( - p * d * reg_factor, (x > 0).float() - ).sum(dim=-1) + reg_factor = torch.sigmoid(self.region_factors) if self.region_factors is not None else 1 + return torch.nn.BCELoss(reduction="none")(p * d * reg_factor, (x > 0).float()).sum(dim=-1) - def _compute_mod_penalty( - self, mod_params1, mod_params2, mod_params3, mask1, mask2, mask3 - ): + def _compute_mod_penalty(self, mod_params1, mod_params2, mod_params3, mask1, mask2, mask3): """Computes Similarity Penalty across modalities given selection (None, Jeffreys, MMD). 
Parameters @@ -931,9 +899,7 @@ def _compute_mod_penalty( if self.modality_penalty == "None": return 0 elif self.modality_penalty == "Jeffreys": - pair_penalty = torch.zeros( - mask1.shape[0], device=mask1.device, requires_grad=True - ) + pair_penalty = torch.zeros(mask1.shape[0], device=mask1.device, requires_grad=True) if mask12.sum().gt(0): penalty12 = sym_kld( mod_params1[0], @@ -941,9 +907,9 @@ def _compute_mod_penalty( mod_params2[0], mod_params2[1].sqrt(), ) - penalty12 = torch.where( - mask12, penalty12.T, torch.zeros_like(penalty12).T - ).sum(dim=0) + penalty12 = torch.where(mask12, penalty12.T, torch.zeros_like(penalty12).T).sum( + dim=0 + ) pair_penalty = pair_penalty + penalty12 if mask13.sum().gt(0): penalty13 = sym_kld( @@ -952,9 +918,9 @@ def _compute_mod_penalty( mod_params3[0], mod_params3[1].sqrt(), ) - penalty13 = torch.where( - mask13, penalty13.T, torch.zeros_like(penalty13).T - ).sum(dim=0) + penalty13 = torch.where(mask13, penalty13.T, torch.zeros_like(penalty13).T).sum( + dim=0 + ) pair_penalty = pair_penalty + penalty13 if mask23.sum().gt(0): penalty23 = sym_kld( @@ -963,32 +929,30 @@ def _compute_mod_penalty( mod_params3[0], mod_params3[1].sqrt(), ) - penalty23 = torch.where( - mask23, penalty23.T, torch.zeros_like(penalty23).T - ).sum(dim=0) + penalty23 = torch.where(mask23, penalty23.T, torch.zeros_like(penalty23).T).sum( + dim=0 + ) pair_penalty = pair_penalty + penalty23 elif self.modality_penalty == "MMD": - pair_penalty = torch.zeros( - mask1.shape[0], device=mask1.device, requires_grad=True - ) + pair_penalty = torch.zeros(mask1.shape[0], device=mask1.device, requires_grad=True) if mask12.sum().gt(0): penalty12 = torch.linalg.norm(mod_params1[0] - mod_params2[0], dim=1) - penalty12 = torch.where( - mask12, penalty12.T, torch.zeros_like(penalty12).T - ).sum(dim=0) + penalty12 = torch.where(mask12, penalty12.T, torch.zeros_like(penalty12).T).sum( + dim=0 + ) pair_penalty = pair_penalty + penalty12 if mask13.sum().gt(0): penalty13 = torch.linalg.norm(mod_params1[0] - mod_params3[0], dim=1) - penalty13 = torch.where( - mask13, penalty13.T, torch.zeros_like(penalty13).T - ).sum(dim=0) + penalty13 = torch.where(mask13, penalty13.T, torch.zeros_like(penalty13).T).sum( + dim=0 + ) pair_penalty = pair_penalty + penalty13 if mask23.sum().gt(0): penalty23 = torch.linalg.norm(mod_params2[0] - mod_params3[0], dim=1) - penalty23 = torch.where( - mask23, penalty23.T, torch.zeros_like(penalty23).T - ).sum(dim=0) + penalty23 = torch.where(mask23, penalty23.T, torch.zeros_like(penalty23).T).sum( + dim=0 + ) pair_penalty = pair_penalty + penalty23 else: raise ValueError("modality penalty not supported") diff --git a/scvi/module/_peakvae.py b/scvi/module/_peakvae.py index 7a877f6806..bdbb21a2bb 100644 --- a/scvi/module/_peakvae.py +++ b/scvi/module/_peakvae.py @@ -70,9 +70,7 @@ def __init__( inject_covariates=deep_inject_covariates, **kwargs, ) - self.output = torch.nn.Sequential( - torch.nn.Linear(n_hidden, n_output), torch.nn.Sigmoid() - ) + self.output = torch.nn.Sequential(torch.nn.Linear(n_hidden, n_output), torch.nn.Sigmoid()) def forward(self, z: torch.Tensor, *cat_list: int): """Forward pass.""" @@ -161,9 +159,7 @@ def __init__( super().__init__() self.n_input_regions = n_input_regions - self.n_hidden = ( - int(np.sqrt(self.n_input_regions)) if n_hidden is None else n_hidden - ) + self.n_hidden = int(np.sqrt(self.n_input_regions)) if n_hidden is None else n_hidden self.n_latent = int(np.sqrt(self.n_hidden)) if n_latent is None else n_latent self.n_layers_encoder = 
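`PEAKVAE` above is exposed through the PEAKVI model; a sketch assuming a cell-by-peak ATAC AnnData with accessibility counts in `.X` (batch column illustrative):

    import scvi

    scvi.model.PEAKVI.setup_anndata(adata, batch_key="batch")
    model = scvi.model.PEAKVI(adata)
    model.train()
    # denoised accessibility probabilities from the Sigmoid-terminated decoder above
    acc = model.get_accessibility_estimates()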
n_layers_encoder self.n_layers_decoder = n_layers_decoder @@ -329,9 +325,7 @@ def generative( return {"p": p} - def loss( - self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0 - ): + def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0): """Compute the loss.""" x = tensors[REGISTRY_KEYS.X_KEY] qz = inference_outputs["qz"] diff --git a/scvi/module/_scanvae.py b/scvi/module/_scanvae.py index 1b76852db1..896fbf55b2 100644 --- a/scvi/module/_scanvae.py +++ b/scvi/module/_scanvae.py @@ -92,9 +92,7 @@ def __init__( n_continuous_cov: int = 0, n_cats_per_cov: Optional[Iterable[int]] = None, dropout_rate: Tunable[float] = 0.1, - dispersion: Tunable[ - Literal["gene", "gene-batch", "gene-label", "gene-cell"] - ] = "gene", + dispersion: Tunable[Literal["gene", "gene-batch", "gene-label", "gene-cell"]] = "gene", log_variational: Tunable[bool] = True, gene_likelihood: Tunable[Literal["zinb", "nb"]] = "zinb", y_prior=None, @@ -173,9 +171,7 @@ def __init__( requires_grad=False, ) self.use_labels_groups = use_labels_groups - self.labels_groups = ( - np.array(labels_groups) if labels_groups is not None else None - ) + self.labels_groups = np.array(labels_groups) if labels_groups is not None else None if self.use_labels_groups: if labels_groups is None: raise ValueError("Specify label groups") @@ -251,9 +247,7 @@ def classify( w_y = torch.zeros_like(unw_y) for i, group_index in enumerate(self.groups_index): unw_y_g = unw_y[:, group_index] - w_y[:, group_index] = unw_y_g / ( - unw_y_g.sum(dim=-1, keepdim=True) + 1e-8 - ) + w_y[:, group_index] = unw_y_g / (unw_y_g.sum(dim=-1, keepdim=True) + 1e-8) w_y[:, group_index] *= w_g[:, [i]] else: w_y = self.classifier(z) @@ -265,14 +259,10 @@ def classification_loss(self, labelled_dataset): y = labelled_dataset[REGISTRY_KEYS.LABELS_KEY] # (n_obs, 1) batch_idx = labelled_dataset[REGISTRY_KEYS.BATCH_KEY] cont_key = REGISTRY_KEYS.CONT_COVS_KEY - cont_covs = ( - labelled_dataset[cont_key] if cont_key in labelled_dataset.keys() else None - ) + cont_covs = labelled_dataset[cont_key] if cont_key in labelled_dataset.keys() else None cat_key = REGISTRY_KEYS.CAT_COVS_KEY - cat_covs = ( - labelled_dataset[cat_key] if cat_key in labelled_dataset.keys() else None - ) + cat_covs = labelled_dataset[cat_key] if cat_key in labelled_dataset.keys() else None # NOTE: prior to v1.1, this method returned probabilities per label by # default, see #2301 for more details logits = self.classify( @@ -351,9 +341,7 @@ def loss( true_labels=true_labels, logits=logits, extra_metrics={ - "n_labelled_tensors": labelled_tensors[ - REGISTRY_KEYS.X_KEY - ].shape[0], + "n_labelled_tensors": labelled_tensors[REGISTRY_KEYS.X_KEY].shape[0], }, ) return LossOutput( @@ -391,9 +379,7 @@ def loss( true_labels=true_labels, logits=logits, ) - return LossOutput( - loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_divergence - ) + return LossOutput(loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_divergence) def on_load(self, model: BaseModelClass): manager = model.get_anndata_manager(model.adata, required=True) diff --git a/scvi/module/_totalvae.py b/scvi/module/_totalvae.py index fbb0771dc7..75a499ba15 100644 --- a/scvi/module/_totalvae.py +++ b/scvi/module/_totalvae.py @@ -158,12 +158,8 @@ def __init__( "must provide library_log_means and library_log_vars." 
) - self.register_buffer( - "library_log_means", torch.from_numpy(library_log_means).float() - ) - self.register_buffer( - "library_log_vars", torch.from_numpy(library_log_vars).float() - ) + self.register_buffer("library_log_means", torch.from_numpy(library_log_means).float()) + self.register_buffer("library_log_vars", torch.from_numpy(library_log_vars).float()) # parameters for prior on rate_back (background protein mean) if protein_background_prior_mean is None: @@ -175,9 +171,7 @@ def __init__( torch.clamp(torch.randn(n_input_proteins, n_batch), -10, 1) ) else: - self.background_pro_alpha = torch.nn.Parameter( - torch.randn(n_input_proteins) - ) + self.background_pro_alpha = torch.nn.Parameter(torch.randn(n_input_proteins)) self.background_pro_log_beta = torch.nn.Parameter( torch.clamp(torch.randn(n_input_proteins), -10, 1) ) @@ -207,13 +201,9 @@ def __init__( if self.protein_dispersion == "protein": self.py_r = torch.nn.Parameter(2 * torch.rand(self.n_input_proteins)) elif self.protein_dispersion == "protein-batch": - self.py_r = torch.nn.Parameter( - 2 * torch.rand(self.n_input_proteins, n_batch) - ) + self.py_r = torch.nn.Parameter(2 * torch.rand(self.n_input_proteins, n_batch)) elif self.protein_dispersion == "protein-label": - self.py_r = torch.nn.Parameter( - 2 * torch.rand(self.n_input_proteins, n_labels) - ) + self.py_r = torch.nn.Parameter(2 * torch.rand(self.n_input_proteins, n_labels)) else: # protein-cell pass @@ -284,9 +274,7 @@ def get_sample_dispersion( type tensors of dispersions of the negative binomial distribution """ - outputs = self.inference( - x, y, batch_index=batch_index, label=label, n_samples=n_samples - ) + outputs = self.inference(x, y, batch_index=batch_index, label=label, n_samples=n_samples) px_r = outputs["px_"]["r"] py_r = outputs["py_"]["r"] return px_r, py_r @@ -324,9 +312,7 @@ def get_reconstruction_loss( ) reconst_loss_protein_full = -py_conditional.log_prob(y) if pro_batch_mask_minibatch is not None: - temp_pro_loss_full = ( - pro_batch_mask_minibatch.bool() * reconst_loss_protein_full - ) + temp_pro_loss_full = pro_batch_mask_minibatch.bool() * reconst_loss_protein_full reconst_loss_protein = temp_pro_loss_full.sum(dim=-1) else: reconst_loss_protein = reconst_loss_protein_full.sum(dim=-1) @@ -366,9 +352,7 @@ def _get_generative_input(self, tensors, inference_outputs): cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY - size_factor = ( - tensors[size_factor_key] if size_factor_key in tensors.keys() else None - ) + size_factor = tensors[size_factor_key] if size_factor_key in tensors.keys() else None return { "z": z, @@ -571,9 +555,7 @@ def loss( generative_outputs, pro_recons_weight=1.0, # double check these defaults kl_weight=1.0, - ) -> tuple[ - torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor - ]: + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: """Returns the reconstruction loss and the Kullback divergences. 
         Parameters
@@ -623,9 +605,7 @@ def loss(
             local_library_log_means = F.linear(
                 one_hot(batch_index, n_batch), self.library_log_means
             )
-            local_library_log_vars = F.linear(
-                one_hot(batch_index, n_batch), self.library_log_vars
-            )
+            local_library_log_vars = F.linear(one_hot(batch_index, n_batch), self.library_log_vars)
             kl_div_l_gene = kl(
                 ql,
                 Normal(local_library_log_means, torch.sqrt(local_library_log_vars)),
@@ -659,9 +639,7 @@ def loss(
             "kl_div_back_pro": kl_div_back_pro,
         }
 
-        return LossOutput(
-            loss=loss, reconstruction_loss=reconst_losses, kl_local=kl_local
-        )
+        return LossOutput(loss=loss, reconstruction_loss=reconst_losses, kl_local=kl_local)
 
     @torch.inference_mode()
     def sample(self, tensors, n_samples=1):
@@ -742,9 +720,7 @@ def marginal_ll(self, tensors, n_mc_samples, return_mean: bool = True):
             p_xy_zl = -(reconst_loss_gene + reconst_loss_protein)
             q_z_x = qz.log_prob(z).sum(dim=-1)
             q_mu_back = (
-                Normal(py_["back_alpha"], py_["back_beta"])
-                .log_prob(log_pro_back_mean)
-                .sum(dim=-1)
+                Normal(py_["back_alpha"], py_["back_beta"]).log_prob(log_pro_back_mean).sum(dim=-1)
             )
             log_prob_sum += p_z + p_mu_back + p_xy_zl - q_z_x - q_mu_back
 
diff --git a/scvi/module/_utils.py b/scvi/module/_utils.py
index bdd7f30247..81416de30c 100644
--- a/scvi/module/_utils.py
+++ b/scvi/module/_utils.py
@@ -24,9 +24,7 @@ def broadcast_labels(y, *o, n_broadcast=-1):
         ys = enumerate_discrete(o[0], n_broadcast)
         new_o = iterate(
             o,
-            lambda x: x.repeat(n_broadcast, 1)
-            if len(x.size()) == 2
-            else x.repeat(n_broadcast),
+            lambda x: x.repeat(n_broadcast, 1) if len(x.size()) == 2 else x.repeat(n_broadcast),
         )
     else:
         ys = one_hot(y, n_broadcast)
diff --git a/scvi/module/_vae.py b/scvi/module/_vae.py
index 6b2eb8c7c2..17f9645652 100644
--- a/scvi/module/_vae.py
+++ b/scvi/module/_vae.py
@@ -107,9 +107,7 @@ def __init__(
         n_continuous_cov: int = 0,
         n_cats_per_cov: Optional[Iterable[int]] = None,
         dropout_rate: Tunable[float] = 0.1,
-        dispersion: Tunable[
-            Literal["gene", "gene-batch", "gene-label", "gene-cell"]
-        ] = "gene",
+        dispersion: Tunable[Literal["gene", "gene-batch", "gene-label", "gene-cell"]] = "gene",
         log_variational: Tunable[bool] = True,
         gene_likelihood: Tunable[Literal["zinb", "nb", "poisson"]] = "zinb",
         latent_distribution: Tunable[Literal["normal", "ln"]] = "normal",
@@ -145,12 +143,8 @@ def __init__(
                     "must provide library_log_means and library_log_vars."
                )
-            self.register_buffer(
-                "library_log_means", torch.from_numpy(library_log_means).float()
-            )
-            self.register_buffer(
-                "library_log_vars", torch.from_numpy(library_log_vars).float()
-            )
+            self.register_buffer("library_log_means", torch.from_numpy(library_log_means).float())
+            self.register_buffer("library_log_vars", torch.from_numpy(library_log_vars).float())
 
         if self.dispersion == "gene":
             self.px_r = torch.nn.Parameter(torch.randn(n_input))
@@ -255,9 +249,7 @@ def _get_inference_input(
                 "observed_lib_size": observed_lib_size,
             }
         else:
-            raise NotImplementedError(
-                f"Unknown minified-data type: {self.minified_data_type}"
-            )
+            raise NotImplementedError(f"Unknown minified-data type: {self.minified_data_type}")
 
         return input_dict
 
@@ -275,9 +267,7 @@ def _get_generative_input(self, tensors, inference_outputs):
 
         size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY
         size_factor = (
-            torch.log(tensors[size_factor_key])
-            if size_factor_key in tensors.keys()
-            else None
+            torch.log(tensors[size_factor_key]) if size_factor_key in tensors.keys() else None
         )
 
         input_dict = {
@@ -299,12 +289,8 @@ def _compute_local_library_params(self, batch_index):
         log library sizes in the batch the cell corresponds to.
         """
         n_batch = self.library_log_means.shape[1]
-        local_library_log_means = F.linear(
-            one_hot(batch_index, n_batch), self.library_log_means
-        )
-        local_library_log_vars = F.linear(
-            one_hot(batch_index, n_batch), self.library_log_vars
-        )
+        local_library_log_means = F.linear(one_hot(batch_index, n_batch), self.library_log_means)
+        local_library_log_vars = F.linear(one_hot(batch_index, n_batch), self.library_log_vars)
         return local_library_log_means, local_library_log_vars
 
     @auto_move_data
@@ -337,9 +323,7 @@ def _regular_inference(
         qz, z = self.z_encoder(encoder_input, batch_index, *categorical_input)
         ql = None
         if not self.use_observed_lib_size:
-            ql, library_encoded = self.l_encoder(
-                encoder_input, batch_index, *categorical_input
-            )
+            ql, library_encoded = self.l_encoder(encoder_input, batch_index, *categorical_input)
             library = library_encoded
 
         if n_samples > 1:
@@ -367,9 +351,7 @@ def _cached_inference(self, qzm, qzv, observed_lib_size, n_samples=1):
                 (n_samples, library.size(0), library.size(1))
             )
         else:
-            raise NotImplementedError(
-                f"Unknown minified-data type: {self.minified_data_type}"
-            )
+            raise NotImplementedError(f"Unknown minified-data type: {self.minified_data_type}")
         outputs = {"z": z, "qz_m": qzm, "qz_v": qzv, "ql": None, "library": library}
         return outputs
 
@@ -464,9 +446,7 @@ def loss(
     ):
         """Computes the loss function for the model."""
         x = tensors[REGISTRY_KEYS.X_KEY]
-        kl_divergence_z = kl(inference_outputs["qz"], generative_outputs["pz"]).sum(
-            dim=-1
-        )
+        kl_divergence_z = kl(inference_outputs["qz"], generative_outputs["pz"]).sum(dim=-1)
         if not self.use_observed_lib_size:
             kl_divergence_l = kl(
                 inference_outputs["ql"],
@@ -530,9 +510,7 @@ def sample(
 
         dist = generative_outputs["px"]
         if self.gene_likelihood == "poisson":
-            dist = torch.distributions.Poisson(
-                torch.clamp(dist.rate, max=max_poisson_rate)
-            )
+            dist = torch.distributions.Poisson(torch.clamp(dist.rate, max=max_poisson_rate))
 
         # (n_obs, n_vars) if n_samples == 1, else (n_samples, n_obs, n_vars)
         samples = dist.sample()
@@ -587,9 +565,7 @@ def marginal_ll(
             # Log-probabilities
             p_z = (
-                Normal(torch.zeros_like(qz.loc), torch.ones_like(qz.scale))
-                .log_prob(z)
-                .sum(dim=-1)
+                Normal(torch.zeros_like(qz.loc), torch.ones_like(qz.scale)).log_prob(z).sum(dim=-1)
             )
             p_x_zl = -reconst_loss
             q_z_x = qz.log_prob(z).sum(dim=-1)
diff --git a/scvi/module/_vaec.py b/scvi/module/_vaec.py
index 94a14a2c6e..99b995d4fd 100644
--- a/scvi/module/_vaec.py
+++ b/scvi/module/_vaec.py
@@ -151,9 +151,7 @@ def inference(self, x, y, n_samples=1):
         if n_samples > 1:
             untran_z = qz.sample((n_samples,))
             z = self.z_encoder.z_transformation(untran_z)
-            library = library.unsqueeze(0).expand(
-                (n_samples, library.size(0), library.size(1))
-            )
+            library = library.unsqueeze(0).expand((n_samples, library.size(0), library.size(1)))
 
         outputs = {"z": z, "qz": qz, "library": library}
         return outputs
@@ -189,9 +187,7 @@ def loss(
         scaling_factor = self.ct_weight[y.long()[:, 0]]
         loss = torch.mean(scaling_factor * (reconst_loss + kl_weight * kl_divergence_z))
 
-        return LossOutput(
-            loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_divergence_z
-        )
+        return LossOutput(loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_divergence_z)
 
     @torch.inference_mode()
     def sample(
@@ -227,9 +223,7 @@ def sample(
         dist = NegativeBinomial(px_rate, logits=px_r)
 
         if n_samples > 1:
-            exprs = dist.sample().permute(
-                [1, 2, 0]
-            )  # Shape : (n_cells_batch, n_genes, n_samples)
+            exprs = dist.sample().permute([1, 2, 0])  # Shape : (n_cells_batch, n_genes, n_samples)
         else:
             exprs = dist.sample()
 
diff --git a/scvi/module/base/_base_module.py b/scvi/module/base/_base_module.py
index 387b5ae7f8..ddab88ff56 100644
--- a/scvi/module/base/_base_module.py
+++ b/scvi/module/base/_base_module.py
@@ -91,9 +91,7 @@ def __post_init__(self):
             object.__setattr__(self, "loss", self.dict_sum(self.loss))
 
         if self.n_obs_minibatch is None and self.reconstruction_loss is None:
-            raise ValueError(
-                "Must provide either n_obs_minibatch or reconstruction_loss"
-            )
+            raise ValueError("Must provide either n_obs_minibatch or reconstruction_loss")
 
         default = 0 * self.loss
         if self.reconstruction_loss is None:
@@ -103,9 +101,7 @@ def __post_init__(self):
         if self.kl_global is None:
             object.__setattr__(self, "kl_global", default)
 
-        object.__setattr__(
-            self, "reconstruction_loss", self._as_dict("reconstruction_loss")
-        )
+        object.__setattr__(self, "reconstruction_loss", self._as_dict("reconstruction_loss"))
         object.__setattr__(self, "kl_local", self._as_dict("kl_local"))
         object.__setattr__(self, "kl_global", self._as_dict("kl_global"))
         object.__setattr__(
@@ -118,16 +114,13 @@ def __post_init__(self):
 
         if self.reconstruction_loss is not None and self.n_obs_minibatch is None:
             rec_loss = self.reconstruction_loss
-            object.__setattr__(
-                self, "n_obs_minibatch", list(rec_loss.values())[0].shape[0]
-            )
+            object.__setattr__(self, "n_obs_minibatch", list(rec_loss.values())[0].shape[0])
 
         if self.classification_loss is not None and (
             self.logits is None or self.true_labels is None
         ):
             raise ValueError(
-                "Must provide `logits` and `true_labels` if `classification_loss` is "
-                "provided."
+                "Must provide `logits` and `true_labels` if `classification_loss` is " "provided."
             )
 
     @staticmethod
@@ -186,9 +179,7 @@ def forward(
         generative_kwargs: dict | None = None,
         loss_kwargs: dict | None = None,
         compute_loss=True,
-    ) -> (
-        tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, LossOutput]
-    ):
+    ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, LossOutput]:
         """Forward pass through the network.
 
         Parameters
@@ -649,9 +640,7 @@ def load_state_dict(self, state_dict: dict[str, Any]):
             raise RuntimeError(
                 "Train state is not set. Train for one iteration prior to loading state dict."
            )
-        self.train_state = flax.serialization.from_state_dict(
-            self.train_state, state_dict
-        )
+        self.train_state = flax.serialization.from_state_dict(self.train_state, state_dict)
 
     def to(self, device: Device):
         """Move module to device."""
@@ -679,9 +668,7 @@ def get_jit_inference_fn(
         self,
         get_inference_input_kwargs: dict[str, Any] | None = None,
         inference_kwargs: dict[str, Any] | None = None,
-    ) -> Callable[
-        [dict[str, jnp.ndarray], dict[str, jnp.ndarray]], dict[str, jnp.ndarray]
-    ]:
+    ) -> Callable[[dict[str, jnp.ndarray], dict[str, jnp.ndarray]], dict[str, jnp.ndarray]]:
         """Create a method to run inference using the bound module.
 
         Parameters
@@ -755,9 +742,7 @@ def _generic_forward(
     )
     generative_outputs = module.generative(**generative_inputs, **generative_kwargs)
     if compute_loss:
-        losses = module.loss(
-            tensors, inference_outputs, generative_outputs, **loss_kwargs
-        )
+        losses = module.loss(tensors, inference_outputs, generative_outputs, **loss_kwargs)
        return inference_outputs, generative_outputs, losses
    else:
        return inference_outputs, generative_outputs
diff --git a/scvi/module/base/_decorators.py b/scvi/module/base/_decorators.py
index d384e1f970..5e6a51daeb 100644
--- a/scvi/module/base/_decorators.py
+++ b/scvi/module/base/_decorators.py
@@ -104,19 +104,14 @@ def _apply_to_collection(
     # Recursively apply to collection items
     elif isinstance(data, Mapping):
         return elem_type(
-            {
-                k: _apply_to_collection(v, dtype, function, *args, **kwargs)
-                for k, v in data.items()
-            }
+            {k: _apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()}
         )
     elif isinstance(data, tuple) and hasattr(data, "_fields"):  # named tuple
         return elem_type(
             *(_apply_to_collection(d, dtype, function, *args, **kwargs) for d in data)
         )
     elif isinstance(data, Sequence) and not isinstance(data, str):
-        return elem_type(
-            [_apply_to_collection(d, dtype, function, *args, **kwargs) for d in data]
-        )
+        return elem_type([_apply_to_collection(d, dtype, function, *args, **kwargs) for d in data])
 
     # data is neither of dtype, nor a collection
     return data
diff --git a/scvi/nn/_base_components.py b/scvi/nn/_base_components.py
index 52954bf13c..4b84514f9d 100644
--- a/scvi/nn/_base_components.py
+++ b/scvi/nn/_base_components.py
@@ -95,9 +95,7 @@ def __init__(
                             nn.Dropout(p=dropout_rate) if dropout_rate > 0 else None,
                         ),
                     )
-                    for i, (n_in, n_out) in enumerate(
-                        zip(layers_dim[:-1], layers_dim[1:])
-                    )
+                    for i, (n_in, n_out) in enumerate(zip(layers_dim[:-1], layers_dim[1:]))
                 ]
             )
         )
@@ -167,18 +165,14 @@ def forward(self, x: torch.Tensor, *cat_list: int):
                 if layer is not None:
                     if isinstance(layer, nn.BatchNorm1d):
                         if x.dim() == 3:
-                            x = torch.cat(
-                                [(layer(slice_x)).unsqueeze(0) for slice_x in x], dim=0
-                            )
+                            x = torch.cat([(layer(slice_x)).unsqueeze(0) for slice_x in x], dim=0)
                         else:
                             x = layer(x)
                     else:
                         if isinstance(layer, nn.Linear) and self.inject_into_layer(i):
                             if x.dim() == 3:
                                 one_hot_cat_list_layer = [
-                                    o.unsqueeze(0).expand(
-                                        (x.size(0), o.size(0), o.size(1))
-                                    )
+                                    o.unsqueeze(0).expand((x.size(0), o.size(0), o.size(1)))
                                     for o in one_hot_cat_list
                                 ]
                             else:
@@ -457,9 +451,7 @@ def __init__(
             **kwargs,
         )
 
-    def forward(
-        self, dispersion: str, z: torch.Tensor, library: torch.Tensor, *cat_list: int
-    ):
+    def forward(self, dispersion: str, z: torch.Tensor, library: torch.Tensor, *cat_list: int):
         """Forward pass."""
         # The decoder returns values for the parameters of the ZINB distribution
         raw_px_scale = self.factor_regressor(z, *cat_list)
@@ -659,9 +651,7 @@ def __init__(
         else:
             self.px_decoder_final = None
 
-        self.px_scale_decoder = nn.Sequential(
-            nn.Linear(n_in, n_output), nn.Softmax(dim=-1)
-        )
+        self.px_scale_decoder = nn.Sequential(nn.Linear(n_in, n_output), nn.Softmax(dim=-1))
         self.px_r_decoder = nn.Linear(n_in, n_output)
         self.px_dropout_decoder = nn.Linear(n_in, n_output)
@@ -881,17 +871,13 @@ def forward(self, z: torch.Tensor, library_gene: torch.Tensor, *cat_list: int):
         py_back_cat_z = torch.cat([py_back, z], dim=-1)
 
         py_["back_alpha"] = self.py_back_mean_log_alpha(py_back_cat_z, *cat_list)
-        py_["back_beta"] = torch.exp(
-            self.py_back_mean_log_beta(py_back_cat_z, *cat_list)
-        )
+        py_["back_beta"] = torch.exp(self.py_back_mean_log_beta(py_back_cat_z, *cat_list))
         log_pro_back_mean = Normal(py_["back_alpha"], py_["back_beta"]).rsample()
         py_["rate_back"] = torch.exp(log_pro_back_mean)
 
         py_fore = self.py_fore_decoder(z, *cat_list)
         py_fore_cat_z = torch.cat([py_fore, z], dim=-1)
-        py_["fore_scale"] = (
-            self.py_fore_scale_decoder(py_fore_cat_z, *cat_list) + 1 + 1e-8
-        )
+        py_["fore_scale"] = self.py_fore_scale_decoder(py_fore_cat_z, *cat_list) + 1 + 1e-8
         py_["rate_fore"] = py_["rate_back"] * py_["fore_scale"]
 
         p_mixing = self.sigmoid_decoder(z, *cat_list)
diff --git a/scvi/train/_callbacks.py b/scvi/train/_callbacks.py
index 1a25f2a3d2..0741f3585f 100644
--- a/scvi/train/_callbacks.py
+++ b/scvi/train/_callbacks.py
@@ -281,8 +281,7 @@ def on_validation_epoch_end(self, trainer, pl_module):
         if current is None:
             warnings.warn(
-                f"Can save best module state only with {self.monitor} available, "
-                "skipping.",
+                f"Can save best module state only with {self.monitor} available, " "skipping.",
                 RuntimeWarning,
                 stacklevel=settings.warnings_stacklevel,
             )
diff --git a/scvi/train/_logger.py b/scvi/train/_logger.py
index eebf905b07..357078cfdc 100644
--- a/scvi/train/_logger.py
+++ b/scvi/train/_logger.py
@@ -45,9 +45,7 @@ def save(self) -> None:
 class SimpleLogger(Logger):
     """Simple logger class."""
 
-    def __init__(
-        self, name: str = "lightning_logs", version: Optional[Union[int, str]] = None
-    ):
+    def __init__(self, name: str = "lightning_logs", version: Optional[Union[int, str]] = None):
         super().__init__()
         self._name = name
         self._experiment = None
diff --git a/scvi/train/_metrics.py b/scvi/train/_metrics.py
index 6be04c82ba..0e1a4fccae 100644
--- a/scvi/train/_metrics.py
+++ b/scvi/train/_metrics.py
@@ -75,9 +75,7 @@ def update(
         Filters for the relevant metric's value and updates this metric.
         """
         if self._N_OBS_MINIBATCH_KEY not in kwargs:
-            raise ValueError(
-                f"Missing {self._N_OBS_MINIBATCH_KEY} value in metrics update."
-            )
+            raise ValueError(f"Missing {self._N_OBS_MINIBATCH_KEY} value in metrics update.")
         if self._name not in kwargs:
             raise ValueError(f"Missing {self._name} value in metrics update.")
 
diff --git a/scvi/train/_trainer.py b/scvi/train/_trainer.py
index ff7819c3bc..8cf0c5288f 100644
--- a/scvi/train/_trainer.py
+++ b/scvi/train/_trainer.py
@@ -138,9 +138,7 @@ def __init__(
             callbacks.append(early_stopping_callback)
             check_val_every_n_epoch = 1
 
-        if enable_checkpointing and not any(
-            isinstance(c, SaveCheckpoint) for c in callbacks
-        ):
+        if enable_checkpointing and not any(isinstance(c, SaveCheckpoint) for c in callbacks):
             callbacks.append(SaveCheckpoint(monitor=checkpointing_monitor))
             check_val_every_n_epoch = 1
         elif any(isinstance(c, SaveCheckpoint) for c in callbacks):
diff --git a/scvi/train/_trainingplans.py b/scvi/train/_trainingplans.py
index 8c1601c37a..fac48ddfd0 100644
--- a/scvi/train/_trainingplans.py
+++ b/scvi/train/_trainingplans.py
@@ -182,9 +182,7 @@ def __init__(
         self.optimizer_creator = optimizer_creator
 
         if self.optimizer_name == "Custom" and self.optimizer_creator is None:
-            raise ValueError(
-                "If optimizer is 'Custom', `optimizer_creator` must be provided."
-            )
+            raise ValueError("If optimizer is 'Custom', `optimizer_creator` must be provided.")
 
         self._n_obs_training = None
         self._n_obs_validation = None
@@ -221,9 +219,7 @@ def initialize_train_metrics(self):
             self.kl_local_train,
             self.kl_global_train,
             self.train_metrics,
-        ) = self._create_elbo_metric_components(
-            mode="train", n_total=self.n_obs_training
-        )
+        ) = self._create_elbo_metric_components(mode="train", n_total=self.n_obs_training)
         self.elbo_train.reset()
 
     def initialize_val_metrics(self):
@@ -234,9 +230,7 @@ def initialize_val_metrics(self):
             self.kl_local_val,
             self.kl_global_val,
             self.val_metrics,
-        ) = self._create_elbo_metric_components(
-            mode="validation", n_total=self.n_obs_validation
-        )
+        ) = self._create_elbo_metric_components(mode="validation", n_total=self.n_obs_validation)
         self.elbo_val.reset()
 
     @property
@@ -372,9 +366,7 @@ def validation_step(self, batch, batch_idx):
         )
         self.compute_and_log_metrics(scvi_loss, self.val_metrics, "validation")
 
-    def _optimizer_creator_fn(
-        self, optimizer_cls: Union[torch.optim.Adam, torch.optim.AdamW]
-    ):
+    def _optimizer_creator_fn(self, optimizer_cls: Union[torch.optim.Adam, torch.optim.AdamW]):
         """Create optimizer for the model.
 
         This type of function can be passed as the `optimizer_creator`
@@ -572,9 +564,7 @@ def training_step(self, batch, batch_idx):
         else:
             opt1, opt2 = opts
 
-        inference_outputs, _, scvi_loss = self.forward(
-            batch, loss_kwargs=self.loss_kwargs
-        )
+        inference_outputs, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
         z = inference_outputs["z"]
         loss = scvi_loss.loss
         # fool classifier if doing adversarial training
@@ -638,9 +628,7 @@ def configure_optimizers(self):
         )
 
         if self.adversarial_classifier is not False:
-            params2 = filter(
-                lambda p: p.requires_grad, self.adversarial_classifier.parameters()
-            )
+            params2 = filter(lambda p: p.requires_grad, self.adversarial_classifier.parameters())
             optimizer2 = torch.optim.Adam(
                 params2, lr=1e-3, eps=0.01, weight_decay=self.weight_decay
             )
@@ -907,16 +895,12 @@ def __init__(
         self.n_epochs_kl_warmup = n_epochs_kl_warmup
         self.use_kl_weight = False
         if isinstance(self.module.model, PyroModule):
-            self.use_kl_weight = (
-                "kl_weight" in signature(self.module.model.forward).parameters
-            )
+            self.use_kl_weight = "kl_weight" in signature(self.module.model.forward).parameters
         elif callable(self.module.model):
             self.use_kl_weight = "kl_weight" in signature(self.module.model).parameters
         self.scale_elbo = scale_elbo
         self.scale_fn = (
-            lambda obj: pyro.poutine.scale(obj, self.scale_elbo)
-            if self.scale_elbo != 1
-            else obj
+            lambda obj: pyro.poutine.scale(obj, self.scale_elbo) if self.scale_elbo != 1 else obj
         )
         self.differentiable_loss_fn = self.loss_fn.differentiable_loss
         self.training_step_outputs = []
@@ -1133,9 +1117,7 @@ def __init__(
         self.loss_fn = loss()
 
         if self.module.logits is False and loss == torch.nn.CrossEntropyLoss:
-            raise UserWarning(
-                "classifier should return logits when using CrossEntropyLoss."
-            )
+            raise UserWarning("classifier should return logits when using CrossEntropyLoss.")
 
     def forward(self, *args, **kwargs):
         """Passthrough to the module's forward function."""
@@ -1165,9 +1147,7 @@ def configure_optimizers(self):
             optim_cls = torch.optim.AdamW
         else:
             raise ValueError("Optimizer not understood.")
-        optimizer = optim_cls(
-            params, lr=self.lr, eps=self.eps, weight_decay=self.weight_decay
-        )
+        optimizer = optim_cls(params, lr=self.lr, eps=self.eps, weight_decay=self.weight_decay)
 
         return optimizer
 
@@ -1233,11 +1213,7 @@ def __init__(
 
     def get_optimizer_creator(self) -> JaxOptimizerCreator:
         """Get optimizer creator for the model."""
-        clip_by = (
-            optax.clip_by_global_norm(self.max_norm)
-            if self.max_norm
-            else optax.identity()
-        )
+        clip_by = optax.clip_by_global_norm(self.max_norm) if self.max_norm else optax.identity()
         if self.optimizer_name == "Adam":
             # Replicates PyTorch Adam defaults
             optim = optax.chain(
@@ -1291,9 +1267,9 @@ def loss_fn(params):
             loss = loss_output.loss
             return loss, (loss_output, new_model_state)
 
-        (loss, (loss_output, new_model_state)), grads = jax.value_and_grad(
-            loss_fn, has_aux=True
-        )(state.params)
+        (loss, (loss_output, new_model_state)), grads = jax.value_and_grad(loss_fn, has_aux=True)(
+            state.params
+        )
         new_state = state.apply_gradients(grads=grads, state=new_model_state)
         return new_state, loss, loss_output
 
diff --git a/scvi/utils/_dependencies.py b/scvi/utils/_dependencies.py
index d094760dab..c37f827031 100644
--- a/scvi/utils/_dependencies.py
+++ b/scvi/utils/_dependencies.py
@@ -11,9 +11,7 @@ def error_on_missing_dependencies(*modules):
         except ImportError:
             missing_modules.append(module)
     if len(missing_modules) > 0:
-        raise ModuleNotFoundError(
-            f"Please install {missing_modules} to use this functionality."
-        )
+        raise ModuleNotFoundError(f"Please install {missing_modules} to use this functionality.")
 
 
 def dependencies(*modules) -> Callable:
diff --git a/tests/criticism/test_criticism.py b/tests/criticism/test_criticism.py
index 677a513d4b..315bffea64 100644
--- a/tests/criticism/test_criticism.py
+++ b/tests/criticism/test_criticism.py
@@ -11,9 +11,7 @@
 from scvi.model import SCVI
 
 
-def get_ppc_with_samples(
-    adata: AnnData, n_samples: int = 2, indices: list[int] | None = None
-):
+def get_ppc_with_samples(adata: AnnData, n_samples: int = 2, indices: list[int] | None = None):
     # create and train models
     SCVI.setup_anndata(
         adata,
diff --git a/tests/data/test_anndata.py b/tests/data/test_anndata.py
index 439198a0e9..876b0dfb76 100644
--- a/tests/data/test_anndata.py
+++ b/tests/data/test_anndata.py
@@ -192,18 +192,13 @@ def test_register_new_fields_with_transferred_manager(adata):
     # Should have protein field
     cdata_manager.get_from_registry(REGISTRY_KEYS.PROTEIN_EXP_KEY)
 
-    np.testing.assert_array_equal(
-        cdata.obs["_scvi_batch"].values, adata.obs["_scvi_batch"].values
-    )
+    np.testing.assert_array_equal(cdata.obs["_scvi_batch"].values, adata.obs["_scvi_batch"].values)
 
 
 def test_update_setup_args(adata):
     adata_manager = generic_setup_adata_manager(adata)
     adata_manager.update_setup_method_args({"test_arg": "test_val"})
-    assert (
-        "test_arg"
-        in adata_manager._get_setup_method_args()[_constants._SETUP_ARGS_KEY].keys()
-    )
+    assert "test_arg" in adata_manager._get_setup_method_args()[_constants._SETUP_ARGS_KEY].keys()
 
 
 def test_data_format(adata):
@@ -268,9 +263,7 @@ def test_setup_anndata(adata):
         adata_manager.get_from_registry(REGISTRY_KEYS.LABELS_KEY),
         np.array(adata.obs["labels"].cat.codes).reshape((-1, 1)),
     )
-    np.testing.assert_array_equal(
-        adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY), adata.X
-    )
+    np.testing.assert_array_equal(adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY), adata.X)
     np.testing.assert_array_equal(
         adata_manager.get_from_registry(REGISTRY_KEYS.PROTEIN_EXP_KEY),
         adata.obsm["protein_expression"],
@@ -311,9 +304,7 @@ def test_setup_anndata_layer(adata):
     adata.layers["X"] = true_x
     adata.X = np.ones_like(adata.X)
     adata_manager = generic_setup_adata_manager(adata, layer="X")
-    np.testing.assert_array_equal(
-        adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY), true_x
-    )
+    np.testing.assert_array_equal(adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY), true_x)
 
 
 def test_setup_anndata_create_label_batch(adata):
@@ -394,10 +385,7 @@ def test_extra_covariates_transfer(adata):
     # give it a new category
     bdata.obs["cat1"] = 6
     bdata_manager = adata_manager.transfer_fields(bdata, extend_categories=True)
-    assert (
-        bdata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).mappings["cat1"][-1]
-        == 6
-    )
+    assert bdata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).mappings["cat1"][-1] == 6
 
 
 def test_anntorchdataset_getitem(adata):
@@ -463,9 +451,7 @@ def test_anntorchdataset_numpy_sparse(adata):
 
 def test_anntorchdataset_getitem_numpy_sparse(adata):
     # check AnnTorchDataset returns numpy array if pro exp was sparse
-    adata.obsm["protein_expression"] = sparse.csr_matrix(
-        adata.obsm["protein_expression"]
-    )
+    adata.obsm["protein_expression"] = sparse.csr_matrix(adata.obsm["protein_expression"])
     adata_manager = generic_setup_adata_manager(
         adata, batch_key="batch", protein_expression_obsm_key="protein_expression"
     )
diff --git a/tests/data/test_anntorchdataset.py b/tests/data/test_anntorchdataset.py
index 441db86379..9a5048cd24 100644
--- a/tests/data/test_anntorchdataset.py
+++ b/tests/data/test_anntorchdataset.py
@@ -59,9 +59,7 @@ def test_getitem_tensors():
     assert list(dataset.keys_and_dtypes.keys()) == [REGISTRY_KEYS.X_KEY]
 
     # dict
-    dataset = manager.create_torch_dataset(
-        data_and_attributes={REGISTRY_KEYS.X_KEY: np.float64}
-    )
+    dataset = manager.create_torch_dataset(data_and_attributes={REGISTRY_KEYS.X_KEY: np.float64})
     assert isinstance(dataset.keys_and_dtypes, dict)
     assert list(dataset.keys_and_dtypes.keys()) == [REGISTRY_KEYS.X_KEY]
     assert dataset.keys_and_dtypes[REGISTRY_KEYS.X_KEY] == np.float64
diff --git a/tests/data/test_dataset10X.py b/tests/data/test_dataset10X.py
index 2ae32f63fe..d310432a18 100644
--- a/tests/data/test_dataset10X.py
+++ b/tests/data/test_dataset10X.py
@@ -37,9 +37,7 @@ def test_pbmc_cite(save_path):
     tar = tarfile.open(file_path, "r:gz")
     tar.extractall(path=sp)
     tar.close()
-    dataset = sc.read_10x_mtx(
-        os.path.join(sp, "filtered_feature_bc_matrix"), gex_only=False
-    )
+    dataset = sc.read_10x_mtx(os.path.join(sp, "filtered_feature_bc_matrix"), gex_only=False)
     organize_cite_seq_10x(dataset)
     unsupervised_training_one_epoch(dataset)
diff --git a/tests/data/test_mudata.py b/tests/data/test_mudata.py
index e40ef657c1..b643db8546 100644
--- a/tests/data/test_mudata.py
+++ b/tests/data/test_mudata.py
@@ -33,9 +33,7 @@ def test_setup_mudata():
         protein_expression_mod="protein",
         protein_expression_layer=None,
     )
-    np.testing.assert_array_equal(
-        adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY), adata.X
-    )
+    np.testing.assert_array_equal(adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY), adata.X)
     np.testing.assert_array_equal(
         adata_manager.get_from_registry(REGISTRY_KEYS.BATCH_KEY),
         np.array(adata.obs["_scvi_batch"]).reshape((-1, 1)),
@@ -73,9 +71,7 @@ def test_setup_mudata_unpaired():
     protein_adata = synthetic_iid(batch_size=100)
     mdata = mudata.MuData({"rna": adata, "protein": protein_adata})
     with pytest.raises(ValueError):
-        generic_setup_mudata_manager(
-            mdata, layer_mod="rna", protein_expression_mod="protein"
-        )
+        generic_setup_mudata_manager(mdata, layer_mod="rna", protein_expression_mod="protein")
 
     # Pad unpaired with zeros
     unpaired_adata = adata[mdata.obsm["rna"] & ~(mdata.obsm["protein"])]
@@ -87,9 +83,7 @@ def test_setup_mudata_unpaired():
     )
     mdata.mod["protein"] = anndata.concat([protein_adata, pad_adata])
     mdata.update()
-    generic_setup_mudata_manager(
-        mdata, layer_mod="rna", protein_expression_mod="protein"
-    )
+    generic_setup_mudata_manager(mdata, layer_mod="rna", protein_expression_mod="protein")
 
 
 def test_setup_mudata_anndata():
@@ -107,9 +101,7 @@ def test_setup_mudata_layer():
     adata.X = np.ones_like(adata.X)
     mdata = mudata.MuData({"rna": adata})
     adata_manager = generic_setup_mudata_manager(mdata, layer_mod="rna", layer="X")
-    np.testing.assert_array_equal(
-        adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY), true_x
-    )
+    np.testing.assert_array_equal(adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY), true_x)
 
 
 def test_setup_mudata_default_batch():
@@ -133,9 +125,7 @@ def test_setup_mudata_nan_batch():
     adata.obs["batch"][:10] = np.nan
     mdata = mudata.MuData({"rna": adata})
     with pytest.raises(ValueError):
-        generic_setup_mudata_manager(
-            mdata, layer_mod="rna", batch_mod="rna", batch_key="batch"
-        )
+        generic_setup_mudata_manager(mdata, layer_mod="rna", batch_mod="rna", batch_key="batch")
 
 
 def test_save_setup_mudata(save_path):
@@ -259,9 +249,7 @@ def test_transfer_fields_diff_batch_mapping():
         mdata1, layer_mod="rna", batch_mod="rna", batch_key="batch"
     )
     adata1_manager.transfer_fields(mdata2)
-    batch_mapping = adata1_manager.get_state_registry(
-        REGISTRY_KEYS.BATCH_KEY
-    ).categorical_mapping
+    batch_mapping = adata1_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).categorical_mapping
     print(batch_mapping)
     correct_batch = np.where(batch_mapping == "batch_1")[0][0]
     assert adata2.obs["_scvi_batch"][0] == correct_batch
@@ -330,12 +318,7 @@ def test_transfer_fields_covariates():
     # give it a new category
     adata2.obs["cat1"] = 6
     adata_manager2 = adata_manager.transfer_fields(mdata2, extend_categories=True)
-    assert (
-        adata_manager2.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).mappings["cat1"][
-            -1
-        ]
-        == 6
-    )
+    assert adata_manager2.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).mappings["cat1"][-1] == 6
 
 
 def test_data_format():
diff --git a/tests/data/utils.py b/tests/data/utils.py
index 439a89d58f..73f14b3cc2 100644
--- a/tests/data/utils.py
+++ b/tests/data/utils.py
@@ -67,9 +67,7 @@ def generic_setup_adata_manager(
                 is_count_data=True,
             )
         )
-    adata_manager = AnnDataManager(
-        fields=anndata_fields, setup_method_args=setup_method_args
-    )
+    adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
     adata_manager.register_fields(adata)
     return adata_manager
 
@@ -85,13 +83,9 @@ def scanvi_setup_adata_manager(
     setup_method_args = {_MODEL_NAME_KEY: "TestModel", _SETUP_ARGS_KEY: setup_args}
     anndata_fields = [
         CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
-        LabelsWithUnlabeledObsField(
-            REGISTRY_KEYS.LABELS_KEY, labels_key, unlabeled_category
-        ),
+        LabelsWithUnlabeledObsField(REGISTRY_KEYS.LABELS_KEY, labels_key, unlabeled_category),
     ]
-    adata_manager = AnnDataManager(
-        fields=anndata_fields, setup_method_args=setup_method_args
-    )
+    adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
     adata_manager.register_fields(adata)
     return adata_manager
 
@@ -113,9 +107,7 @@ def generic_setup_mudata_manager(
     setup_args.pop("mdata")
     setup_method_args = {_MODEL_NAME_KEY: "TestModel", _SETUP_ARGS_KEY: setup_args}
 
-    batch_field = MuDataCategoricalObsField(
-        REGISTRY_KEYS.BATCH_KEY, batch_key, mod_key=batch_mod
-    )
+    batch_field = MuDataCategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key, mod_key=batch_mod)
     anndata_fields = [
         MuDataLayerField(
             REGISTRY_KEYS.X_KEY,
@@ -147,8 +139,6 @@ def generic_setup_mudata_manager(
                 batch_field=batch_field,
             )
         )
-    mdata_manager = AnnDataManager(
-        fields=anndata_fields, setup_method_args=setup_method_args
-    )
+    mdata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
     mdata_manager.register_fields(mdata)
     return mdata_manager
diff --git a/tests/dataloaders/sparse_utils.py b/tests/dataloaders/sparse_utils.py
index 834a2f1149..21d6b8026f 100644
--- a/tests/dataloaders/sparse_utils.py
+++ b/tests/dataloaders/sparse_utils.py
@@ -10,9 +10,7 @@
 
 
 class TestSparseDataSplitter(scvi.dataloaders.DataSplitter):
-    def __init__(
-        self, *args, expected_sparse_layout: Literal["csr", "csc"] = None, **kwargs
-    ):
+    def __init__(self, *args, expected_sparse_layout: Literal["csr", "csc"] = None, **kwargs):
         if expected_sparse_layout == "csr":
             self.expected_sparse_layout = torch.sparse_csr
         elif expected_sparse_layout == "csc":
@@ -82,9 +80,7 @@ def setup_anndata(
             fields.LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
             fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
         ]
-        adata_manager = AnnDataManager(
-            fields=anndata_fields, setup_method_args=setup_method_args
-        )
+        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
         adata_manager.register_fields(adata)
         cls.register_manager(adata_manager)
 
diff --git a/tests/dataloaders/test_dataloaders.py b/tests/dataloaders/test_dataloaders.py
index 8e572bbec8..4fba0fc498 100644
--- a/tests/dataloaders/test_dataloaders.py
+++ b/tests/dataloaders/test_dataloaders.py
@@ -54,9 +54,7 @@ def test_semisuperviseddataloader_subsampling(
     n_labels: int = 3,
     n_samples_per_label: int = 10,
 ):
-    adata = scvi.data.synthetic_iid(
-        batch_size=batch_size, n_batches=n_batches, n_labels=n_labels
-    )
+    adata = scvi.data.synthetic_iid(batch_size=batch_size, n_batches=n_batches, n_labels=n_labels)
     adata.obs["indices"] = np.arange(adata.n_obs)
 
     original_training_plan_cls = scvi.model.SCANVI._training_plan_cls
diff --git a/tests/dataloaders/test_samplers.py b/tests/dataloaders/test_samplers.py
index 4e4611c080..288e18fb94 100644
--- a/tests/dataloaders/test_samplers.py
+++ b/tests/dataloaders/test_samplers.py
@@ -143,7 +143,5 @@ def test_batchdistributedsampler_indices(
             assert len(sampler_indices[i].intersection(sampler_indices[j])) == 0
 
     # check that all indices are covered
-    covered_indices = np.concatenate(
-        [np.array(list(indices)) for indices in sampler_indices]
-    )
+    covered_indices = np.concatenate([np.array(list(indices)) for indices in sampler_indices])
     assert len(covered_indices) == len(dataset)
diff --git a/tests/distributions/test_negative_binomial.py b/tests/distributions/test_negative_binomial.py
index e0a72146d3..af9215fce3 100644
--- a/tests/distributions/test_negative_binomial.py
+++ b/tests/distributions/test_negative_binomial.py
@@ -71,9 +71,7 @@ def test_zinb_distribution():
         dist2.log_prob(0.5 * x)  # ensures float values raise warning
 
     # test with no scale
-    dist1 = ZeroInflatedNegativeBinomial(
-        mu=mu, theta=theta, zi_logits=pi, validate_args=True
-    )
+    dist1 = ZeroInflatedNegativeBinomial(mu=mu, theta=theta, zi_logits=pi, validate_args=True)
     dist2 = NegativeBinomial(mu=mu, theta=theta, validate_args=True)
     dist1.__repr__()
     dist2.__repr__()
diff --git a/tests/external/contrastivevi/test_contrastive_dataloaders.py b/tests/external/contrastivevi/test_contrastive_dataloaders.py
index c7868b75f4..68648dbde0 100644
--- a/tests/external/contrastivevi/test_contrastive_dataloaders.py
+++ b/tests/external/contrastivevi/test_contrastive_dataloaders.py
@@ -24,9 +24,7 @@ def test_contrastive_dataloader(
     ).unsqueeze(1)
 
     expected_target_indices = mock_target_indices[:batch_size]
-    expected_target_input = torch.Tensor(
-        adata.layers["raw_counts"][expected_target_indices, :]
-    )
+    expected_target_input = torch.Tensor(adata.layers["raw_counts"][expected_target_indices, :])
     expected_target_labels = torch.LongTensor(
         adata.obs["_scvi_labels"].iloc[expected_target_indices]
     ).unsqueeze(1)
@@ -47,15 +45,9 @@ def test_contrastive_dataloader(
     assert "background" in batch.keys()
     assert "target" in batch.keys()
 
-    assert torch.equal(
-        batch["background"][REGISTRY_KEYS.X_KEY], expected_background_input
-    )
-    assert torch.equal(
-        batch["background"][REGISTRY_KEYS.LABELS_KEY], expected_background_labels
-    )
-    assert torch.equal(
-        batch["background"][REGISTRY_KEYS.BATCH_KEY], expected_background_batch
-    )
+    assert torch.equal(batch["background"][REGISTRY_KEYS.X_KEY], expected_background_input)
+    assert torch.equal(batch["background"][REGISTRY_KEYS.LABELS_KEY], expected_background_labels)
+    assert torch.equal(batch["background"][REGISTRY_KEYS.BATCH_KEY], expected_background_batch)
     assert torch.equal(batch["target"][REGISTRY_KEYS.X_KEY], expected_target_input)
     assert torch.equal(batch["target"][REGISTRY_KEYS.LABELS_KEY], expected_target_labels)
diff --git a/tests/external/contrastivevi/test_contrastivevae.py b/tests/external/contrastivevi/test_contrastivevae.py
index d3f2f9df67..90158f4487 100644
--- a/tests/external/contrastivevi/test_contrastivevae.py
+++ b/tests/external/contrastivevi/test_contrastivevae.py
@@ -90,9 +90,7 @@ def mock_target_batch(mock_contrastive_adata_manager, mock_target_indices):
     return target_batch.item()
 
 
-@pytest.fixture(
-    params=[True, False], ids=["with_observed_lib_size", "without_observed_lib_size"]
-)
+@pytest.fixture(params=[True, False], ids=["with_observed_lib_size", "without_observed_lib_size"])
 def mock_contrastive_vae(
     mock_n_input, mock_n_batch, mock_library_log_means, mock_library_log_vars, request
 ):
@@ -130,9 +128,7 @@ def mock_contrastive_vi_data(
 ):
     concat_tensors = mock_contrastive_batch
     inference_input = mock_contrastive_vae._get_inference_input(concat_tensors)
-    inference_outputs = mock_contrastive_vae.inference(
-        **inference_input, n_samples=request.param
-    )
+    inference_outputs = mock_contrastive_vae.inference(**inference_input, n_samples=request.param)
     generative_input = mock_contrastive_vae._get_generative_input(
         concat_tensors, inference_outputs
     )
@@ -171,9 +167,7 @@ def test_get_inference_input(
         mock_background_batch,
         mock_target_batch,
     ):
-        inference_input = mock_contrastive_vae._get_inference_input(
-            mock_contrastive_batch
-        )
+        inference_input = mock_contrastive_vae._get_inference_input(mock_contrastive_batch)
 
         for data_source in REQUIRED_DATA_SOURCES:
             assert data_source in inference_input.keys()
@@ -244,9 +238,7 @@ def test_inference(
         mock_contrastive_vae,
         mock_contrastive_batch,
     ):
-        inference_input = mock_contrastive_vae._get_inference_input(
-            mock_contrastive_batch
-        )
+        inference_input = mock_contrastive_vae._get_inference_input(mock_contrastive_batch)
         inference_outputs = mock_contrastive_vae.inference(**inference_input)
         for data_source in REQUIRED_DATA_SOURCES:
             assert data_source in inference_outputs.keys()
@@ -263,10 +255,8 @@ def test_get_generative_input_from_concat_tensors(
         mock_contrastive_batch,
         mock_n_input,
     ):
-        generative_input = (
-            mock_contrastive_vae._get_generative_input_from_concat_tensors(
-                mock_contrastive_batch, "background"
-            )
+        generative_input = mock_contrastive_vae._get_generative_input_from_concat_tensors(
+            mock_contrastive_batch, "background"
         )
         for key in REQUIRED_GENERATIVE_INPUT_KEYS_FROM_CONCAT_TENSORS:
             assert key in generative_input.keys()
@@ -281,10 +271,8 @@ def test_get_generative_input_from_inference_outputs(
         inference_outputs = mock_contrastive_vae.inference(
             **mock_contrastive_vae._get_inference_input(mock_contrastive_batch)
         )
-        generative_input = (
-            mock_contrastive_vae._get_generative_input_from_inference_outputs(
-                inference_outputs, REQUIRED_DATA_SOURCES[0]
-            )
+        generative_input = mock_contrastive_vae._get_generative_input_from_inference_outputs(
+            inference_outputs, REQUIRED_DATA_SOURCES[0]
         )
         for key in REQUIRED_GENERATIVE_INPUT_KEYS_FROM_INFERENCE_OUTPUTS:
             assert key in generative_input
@@ -325,9 +313,7 @@ def test_get_generative_input(
             assert key in target_generative_input_keys
 
         # Check background vs. target labels are consistent.
-        assert (
-            background_generative_input["batch_index"] != mock_background_batch
-        ).sum() == 0
+        assert (background_generative_input["batch_index"] != mock_background_batch).sum() == 0
         assert (target_generative_input["batch_index"] != mock_target_batch).sum() == 0
 
     @pytest.mark.parametrize("n_samples", [1, 2])
@@ -388,9 +374,7 @@ def test_reconstruction_loss(self, mock_contrastive_vae, mock_contrastive_vi_data):
         px_rate = generative_outputs["px_rate"]
         px_r = generative_outputs["px_r"]
         px_dropout = generative_outputs["px_dropout"]
-        recon_loss = mock_contrastive_vae.reconstruction_loss(
-            x, px_rate, px_r, px_dropout
-        )
+        recon_loss = mock_contrastive_vae.reconstruction_loss(x, px_rate, px_r, px_dropout)
         if len(px_rate.shape) == 3:
             expected_shape = px_rate.shape[:2]
         else:
@@ -416,18 +400,16 @@ def test_library_kl_divergence(self, mock_contrastive_vae, mock_contrastive_vi_data):
         ql_m = inference_outputs["ql_m"]
         ql_v = inference_outputs["ql_v"]
         library = inference_outputs["library"]
-        kl_library = mock_contrastive_vae.library_kl_divergence(
-            batch_index, ql_m, ql_v, library
-        )
+        kl_library = mock_contrastive_vae.library_kl_divergence(batch_index, ql_m, ql_v, library)
         expected_shape = library.shape[:-1]
         assert kl_library.shape == expected_shape
         if mock_contrastive_vae.use_observed_lib_size:
             assert torch.equal(kl_library, torch.zeros(expected_shape))
 
     def test_loss(self, mock_contrastive_vae, mock_contrastive_vi_data):
-        expected_shape = mock_contrastive_vi_data["inference_outputs"]["background"][
-            "qz_m"
-        ].shape[:-1]
+        expected_shape = mock_contrastive_vi_data["inference_outputs"]["background"]["qz_m"].shape[
+            :-1
+        ]
         losses = mock_contrastive_vae.loss(
             mock_contrastive_vi_data["concat_tensors"],
             mock_contrastive_vi_data["inference_outputs"],
diff --git a/tests/external/contrastivevi/test_contrastivevi.py b/tests/external/contrastivevi/test_contrastivevi.py
index 8babdc3802..82b54697e8 100644
--- a/tests/external/contrastivevi/test_contrastivevi.py
+++ b/tests/external/contrastivevi/test_contrastivevi.py
@@ -13,9 +13,7 @@ def copy_module_state_dict(module) -> dict[str, torch.Tensor]:
     return copy
 
 
-@pytest.fixture(
-    params=[True, False], ids=["with_observed_lib_size", "without_observed_lib_size"]
-)
+@pytest.fixture(params=[True, False], ids=["with_observed_lib_size", "without_observed_lib_size"])
 def mock_contrastive_vi_model(
     mock_contrastive_adata,
     request,
@@ -70,19 +68,14 @@ def test_train(
         )
         trained_state_dict = copy_module_state_dict(mock_contrastive_vi_model.module)
         for param_key in mock_contrastive_vi_model.module.state_dict().keys():
-            is_library_param = (
-                param_key == "library_log_means" or param_key == "library_log_vars"
-            )
+            is_library_param = param_key == "library_log_means" or param_key == "library_log_vars"
            is_px_r_decoder_param = "px_r_decoder" in param_key
            is_l_encoder_param = "l_encoder" in param_key

            if (
                is_library_param
                or is_px_r_decoder_param
-                or (
-                    is_l_encoder_param
-                    and mock_contrastive_vi_model.module.use_observed_lib_size
-                )
+                or (is_l_encoder_param and mock_contrastive_vi_model.module.use_observed_lib_size)
            ):
                # There are three cases where parameters are not updated.
                # 1. Library means and vars are derived from input data and should
                #    not be updated.
                # 2. When gene dispersion is not gene-cell, parameters in the px_r
                #    decoder are not used and should not be updated.
                # 3. When observed library size is used, the library encoder is not
                #    used and its parameters not updated.
-                assert torch.equal(
-                    init_state_dict[param_key], trained_state_dict[param_key]
-                )
+                assert torch.equal(init_state_dict[param_key], trained_state_dict[param_key])
             else:
                 # Other parameters should be updated after training.
-                assert not torch.equal(
-                    init_state_dict[param_key], trained_state_dict[param_key]
-                )
+                assert not torch.equal(init_state_dict[param_key], trained_state_dict[param_key])
 
     @pytest.mark.parametrize("representation_kind", ["background", "salient"])
-    def test_get_latent_representation(
-        self, mock_contrastive_vi_model, representation_kind
-    ):
+    def test_get_latent_representation(self, mock_contrastive_vi_model, representation_kind):
         n_cells = mock_contrastive_vi_model.adata.n_obs
         if representation_kind == "background":
             n_latent = mock_contrastive_vi_model.module.n_background_latent
@@ -116,9 +103,7 @@ def test_get_latent_representation(
         assert representation.shape == (n_cells, n_latent)
 
     @pytest.mark.parametrize("representation_kind", ["background", "salient"])
-    def test_get_normalized_expression(
-        self, mock_contrastive_vi_model, representation_kind
-    ):
+    def test_get_normalized_expression(self, mock_contrastive_vi_model, representation_kind):
         n_samples = 50
         n_cells = mock_contrastive_vi_model.adata.n_obs
         n_genes = mock_contrastive_vi_model.adata.n_vars
diff --git a/tests/external/gimvi/test_gimvi.py b/tests/external/gimvi/test_gimvi.py
index 956ce7b349..a343e8f937 100644
--- a/tests/external/gimvi/test_gimvi.py
+++ b/tests/external/gimvi/test_gimvi.py
@@ -37,9 +37,7 @@ def legacy_save(
     dataset_names = ["seq", "spatial"]
     for i in range(len(model.adatas)):
         dataset_name = dataset_names[i]
-        save_path = os.path.join(
-            dir_path, f"{file_name_prefix}adata_{dataset_name}.h5ad"
-        )
+        save_path = os.path.join(dir_path, f"{file_name_prefix}adata_{dataset_name}.h5ad")
         model.adatas[i].write(save_path)
         varnames_save_path = os.path.join(
             dir_path, f"{file_name_prefix}var_names_{dataset_name}.csv"
@@ -80,9 +78,7 @@ def legacy_save(
     tmp_adata = scvi.data.synthetic_iid(n_genes=200)
     tmp_adata2 = scvi.data.synthetic_iid(n_genes=200)
     with pytest.raises(ValueError):
-        GIMVI.load(
-            save_path, adata_seq=tmp_adata, adata_spatial=tmp_adata2, prefix=prefix
-        )
+        GIMVI.load(save_path, adata_seq=tmp_adata, adata_spatial=tmp_adata2, prefix=prefix)
     model = GIMVI.load(save_path, adata_seq=adata, adata_spatial=adata2, prefix=prefix)
     z2 = model.get_latent_representation([adata])
     np.testing.assert_array_equal(z1, z2)
@@ -98,22 +94,16 @@ def legacy_save(
 
     # Test legacy loading
     legacy_save_path = os.path.join(save_path, "legacy/")
-    legacy_save(
-        model, legacy_save_path, overwrite=True, save_anndata=True, prefix=prefix
-    )
+    legacy_save(model, legacy_save_path, overwrite=True, save_anndata=True, prefix=prefix)
     with pytest.raises(ValueError):
-        GIMVI.load(
-            legacy_save_path, adata_seq=adata, adata_spatial=adata2, prefix=prefix
-        )
+        GIMVI.load(legacy_save_path, adata_seq=adata, adata_spatial=adata2, prefix=prefix)
     GIMVI.convert_legacy_save(
         legacy_save_path,
         legacy_save_path,
         overwrite=True,
         prefix=prefix,
     )
-    m = GIMVI.load(
-        legacy_save_path, adata_seq=adata, adata_spatial=adata2, prefix=prefix
-    )
+    m = GIMVI.load(legacy_save_path, adata_seq=adata, adata_spatial=adata2, prefix=prefix)
     m.train(1)
diff --git a/tests/external/scbasset/test_scbasset.py b/tests/external/scbasset/test_scbasset.py
index 46e8217a73..135d813446 100644
--- a/tests/external/scbasset/test_scbasset.py
+++ b/tests/external/scbasset/test_scbasset.py
@@ -62,14 +62,8 @@ def test_scbasset_tf(save_path):
     # model.train(max_epochs=2, early_stopping=True)
     model.is_trained = True
-    motif_seqs, bg_seqs = model._get_motif_library(
-        tf="MYOD1", motif_dir=save_path, genome="human"
-    )
+    motif_seqs, bg_seqs = model._get_motif_library(tf="MYOD1", motif_dir=save_path, genome="human")
 
-    model.get_tf_activity(
-        tf="MYOD1", motif_dir=save_path, genome="human", lib_size_norm=True
-    )
-    model.get_tf_activity(
-        tf="MYOD1", motif_dir=save_path, genome="human", lib_size_norm=False
-    )
+    model.get_tf_activity(tf="MYOD1", motif_dir=save_path, genome="human", lib_size_norm=True)
+    model.get_tf_activity(tf="MYOD1", motif_dir=save_path, genome="human", lib_size_norm=False)
     return
diff --git a/tests/external/solo/test_solo.py b/tests/external/solo/test_solo.py
index 01a1b1682d..27a87a378d 100644
--- a/tests/external/solo/test_solo.py
+++ b/tests/external/solo/test_solo.py
@@ -55,9 +55,7 @@ def test_solo_scvi_labels():
 def test_solo_from_scvi_errors():
     adata = synthetic_iid()
     adata.obs["continuous_covariate"] = np.random.normal(size=(adata.n_obs, 1))
-    adata.obs["categorical_covariate"] = np.random.choice(
-        ["a", "b", "c"], size=(adata.n_obs, 1)
-    )
+    adata.obs["categorical_covariate"] = np.random.choice(["a", "b", "c"], size=(adata.n_obs, 1))
 
     # no batch key, restrict_to_batch
     SCVI.setup_anndata(adata, labels_key="labels")
diff --git a/tests/external/stereoscope/test_stereoscope.py b/tests/external/stereoscope/test_stereoscope.py
index 06a2e291c8..3241f3db76 100644
--- a/tests/external/stereoscope/test_stereoscope.py
+++ b/tests/external/stereoscope/test_stereoscope.py
@@ -29,9 +29,7 @@ def test_stereoscope(save_path):
     SpatialStereoscope.setup_anndata(
         dataset,
     )
-    st_model = SpatialStereoscope.from_rna_model(
-        dataset, sc_model, prior_weight="minibatch"
-    )
+    st_model = SpatialStereoscope.from_rna_model(dataset, sc_model, prior_weight="minibatch")
     st_model.train(max_epochs=1)
     st_model.get_proportions()
     # test save/load
diff --git a/tests/external/tangram/test_tangram.py b/tests/external/tangram/test_tangram.py
index a0b10625c8..075d8a0bc9 100644
--- a/tests/external/tangram/test_tangram.py
+++ b/tests/external/tangram/test_tangram.py
@@ -17,9 +17,7 @@ def _get_mdata(sparse_format: Optional[str] = None):
     mdata = mudata.MuData({"sc": dataset1, "sp": dataset2})
     ad_sp = mdata.mod["sp"]
     rna_count_per_spot = np.asarray(ad_sp.X.sum(axis=1)).squeeze()
-    ad_sp.obs["rna_count_based_density"] = rna_count_per_spot / np.sum(
-        rna_count_per_spot
-    )
+    ad_sp.obs["rna_count_based_density"] = rna_count_per_spot / np.sum(rna_count_per_spot)
     ad_sp.obs["bad_prior"] = np.random.uniform(size=ad_sp.n_obs)
     return mdata
 
diff --git a/tests/hub/test_hub_metadata.py b/tests/hub/test_hub_metadata.py
index 7da0c173cd..bf8cac951b 100644
--- a/tests/hub/test_hub_metadata.py
+++ b/tests/hub/test_hub_metadata.py
@@ -51,9 +51,7 @@ def test_hub_metadata(request, save_path):
     model = prep_model()
     test_save_path = os.path.join(save_path, request.node.name)
     model.save(test_save_path, overwrite=True)
-    hm = HubMetadata.from_dir(
-        test_save_path, anndata_version="0.9.0", model_parent_module="foo"
-    )
+    hm = HubMetadata.from_dir(test_save_path, anndata_version="0.9.0", model_parent_module="foo")
     assert hm.scvi_version == scvi.__version__
     assert hm.anndata_version == "0.9.0"
     assert hm.training_data_url is None
@@ -62,9 +60,7 @@ def test_hub_metadata(request, save_path):
 
 def test_hub_metadata_invalid_url():
     with pytest.raises(ValueError):
-        HubMetadata(
-            "0.17.4", "0.8.0", "SCVI", training_data_url="https//invalid_url.org/"
-        )
+        HubMetadata("0.17.4", "0.8.0", "SCVI", training_data_url="https//invalid_url.org/")
 
 
 def test_hub_modelcardhelper(request, save_path):
@@ -74,9 +70,7 @@ def test_hub_modelcardhelper(request, save_path):
         license_info="cc-by-4.0",
         model_cls_name="SCVI",
         model_init_params=model.init_params_,
-        model_setup_anndata_args=model.adata_manager._get_setup_method_args()[
-            "setup_args"
-        ],
+        model_setup_anndata_args=model.adata_manager._get_setup_method_args()["setup_args"],
         model_summary_stats=model.summary_stats,
         model_data_registry=model.adata_manager.data_registry,
         scvi_version="0.17.8",
@@ -162,8 +156,7 @@ def test_hub_modelcardhelper(request, save_path):
     assert hmch.model_cls_name == "SCVI"
     assert hmch.model_init_params == model.init_params_
     assert (
-        hmch.model_setup_anndata_args
-        == model.adata_manager._get_setup_method_args()["setup_args"]
+        hmch.model_setup_anndata_args == model.adata_manager._get_setup_method_args()["setup_args"]
     )
     assert hmch.model_summary_stats == dict(model.summary_stats)
     assert hmch.model_data_registry == dict(model.adata_manager.data_registry)
diff --git a/tests/hub/test_hub_model.py b/tests/hub/test_hub_model.py
index eb77ad84fc..da8f0bcd4d 100644
--- a/tests/hub/test_hub_model.py
+++ b/tests/hub/test_hub_model.py
@@ -260,15 +260,11 @@ def test_hub_model_pull_from_hf():
     assert hub_model.model is not None
     assert hub_model.adata is not None
 
-    hub_model = HubModel.pull_from_huggingface_hub(
-        repo_name="scvi-tools/test-scvi-minified"
-    )
+    hub_model = HubModel.pull_from_huggingface_hub(repo_name="scvi-tools/test-scvi-minified")
     assert hub_model.model is not None
     assert hub_model.adata is not None
 
-    hub_model = HubModel.pull_from_huggingface_hub(
-        repo_name="scvi-tools/test-scvi-no-anndata"
-    )
+    hub_model = HubModel.pull_from_huggingface_hub(repo_name="scvi-tools/test-scvi-no-anndata")
     with pytest.raises(ValueError):
         _ = hub_model.model
 
@@ -285,12 +281,8 @@ def test_hub_model_push_to_s3(save_path: str):
     hub_model = prep_scvi_no_anndata_hub_model(save_path)
     with pytest.raises(ValueError):
-        hub_model.push_to_s3(
-            "scvi-tools", "tests/hub/test-scvi-no-anndata", push_anndata=True
-        )
-    hub_model.push_to_s3(
-        "scvi-tools", "tests/hub/test-scvi-no-anndata", push_anndata=False
-    )
+        hub_model.push_to_s3("scvi-tools", "tests/hub/test-scvi-no-anndata", push_anndata=True)
+    hub_model.push_to_s3("scvi-tools", "tests/hub/test-scvi-no-anndata", push_anndata=False)
 
     hub_model = prep_scvi_minified_hub_model(save_path)
     hub_model.push_to_s3("scvi-tools", "tests/hub/test-scvi-minified")
diff --git a/tests/model/base/test_base_model.py b/tests/model/base/test_base_model.py
index b4b6139928..4f3e8e3e8f 100644
--- a/tests/model/base/test_base_model.py
+++ b/tests/model/base/test_base_model.py
@@ -33,17 +33,13 @@ def setup_anndata(
             fields.CategoricalJointObsField(
                 REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys
             ),
-            fields.NumericalJointObsField(
-                REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys
-            ),
+            fields.NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys),
         ]
         # register new fields if the adata is minified
         adata_minify_type = _get_adata_minify_type(adata)
         if adata_minify_type is not None:
             anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type)
-        adata_manager = AnnDataManager(
-            fields=anndata_fields, setup_method_args=setup_method_args
-        )
+        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
         adata_manager.register_fields(adata, **kwargs)
         cls.register_manager(adata_manager)
 
diff --git a/tests/model/test_amortizedlda.py b/tests/model/test_amortizedlda.py
index a3b07dbd12..251806e355 100644
--- a/tests/model/test_amortizedlda.py
+++ b/tests/model/test_amortizedlda.py
@@ -9,9 +9,7 @@ def test_lda_model_single_step(n_topics: int = 5):
     adata = synthetic_iid()
     AmortizedLDA.setup_anndata(adata)
-    mod1 = AmortizedLDA(
-        adata, n_topics=n_topics, cell_topic_prior=1.5, topic_feature_prior=1.5
-    )
+    mod1 = AmortizedLDA(adata, n_topics=n_topics, cell_topic_prior=1.5, topic_feature_prior=1.5)
     mod1.train(max_steps=1, max_epochs=10)
     assert len(mod1.history["elbo_train"]) == 1
 
@@ -21,9 +19,7 @@ def test_lda_model(n_topics: int = 5):
 
     # Test with float and Sequence priors.
     AmortizedLDA.setup_anndata(adata)
-    mod1 = AmortizedLDA(
-        adata, n_topics=n_topics, cell_topic_prior=1.5, topic_feature_prior=1.5
-    )
+    mod1 = AmortizedLDA(adata, n_topics=n_topics, cell_topic_prior=1.5, topic_feature_prior=1.5)
     mod1.train(
         max_epochs=1,
         batch_size=256,
diff --git a/tests/model/test_autozi.py b/tests/model/test_autozi.py
index b7fb7a7444..578e51939b 100644
--- a/tests/model/test_autozi.py
+++ b/tests/model/test_autozi.py
@@ -81,14 +81,10 @@ def legacy_save(
 
     # Test legacy loading
     legacy_save_path = os.path.join(save_path, "legacy/")
-    legacy_save(
-        model, legacy_save_path, overwrite=True, save_anndata=True, prefix=prefix
-    )
+    legacy_save(model, legacy_save_path, overwrite=True, save_anndata=True, prefix=prefix)
     with pytest.raises(ValueError):
         AUTOZI.load(legacy_save_path, adata=adata, prefix=prefix)
-    AUTOZI.convert_legacy_save(
-        legacy_save_path, legacy_save_path, overwrite=True, prefix=prefix
-    )
+    AUTOZI.convert_legacy_save(legacy_save_path, legacy_save_path, overwrite=True, prefix=prefix)
     m = AUTOZI.load(legacy_save_path, adata=adata, prefix=prefix)
     m.train(1)
diff --git a/tests/model/test_differential.py b/tests/model/test_differential.py
index f6fd9cb93b..7fe561d3db 100644
--- a/tests/model/test_differential.py
+++ b/tests/model/test_differential.py
@@ -66,9 +66,7 @@ def test_differential_computation(save_path):
     cell_idx1 = np.asarray(adata.obs.labels == "label_1")
     cell_idx2 = ~cell_idx1
     dc.get_bayes_factors(cell_idx1, cell_idx2, mode="vanilla", use_permutation=True)
-    res = dc.get_bayes_factors(
-        cell_idx1, cell_idx2, mode="change", use_permutation=False
-    )
+    res = dc.get_bayes_factors(cell_idx1, cell_idx2, mode="change", use_permutation=False)
 
     model_fn = partial(
         model.get_normalized_expression,
@@ -83,9 +81,7 @@ def test_differential_computation(save_path):
     cell_idx1 = np.asarray(adata.obs.labels == "label_1")
     cell_idx2 = ~cell_idx1
     dc.get_bayes_factors(cell_idx1, cell_idx2, mode="vanilla", use_permutation=True)
-    res = dc.get_bayes_factors(
-        cell_idx1, cell_idx2, mode="change", use_permutation=False
-    )
+    res = dc.get_bayes_factors(cell_idx1, cell_idx2, mode="change", use_permutation=False)
     assert (res["delta"] == 0.5) and (res["pseudocounts"] == 0.0)
     res = dc.get_bayes_factors(
         cell_idx1, cell_idx2, mode="change", use_permutation=False, delta=None
@@ -121,18 +117,14 @@ def m1_domain_fn_test(samples):
 
     model.differential_expression(adata[:20], groupby="batch")
     # test view
-    model.differential_expression(
-        adata[adata.obs["labels"] == "label_1"], groupby="batch"
-    )
+    model.differential_expression(adata[adata.obs["labels"] == "label_1"], groupby="batch")
 
     # Test query features
     (
         obs_col,
         group1,
         _,
-    ) = _prepare_obs(
-        idx1="(labels == 'label_1') & (batch == 'batch_1')", idx2=None, adata=adata
-    )
+    ) = _prepare_obs(idx1="(labels == 'label_1') & (batch == 'batch_1')", idx2=None, adata=adata)
     assert (obs_col == group1).sum() == adata.obs.loc[
         lambda x: (x.labels == "label_1") & (x.batch == "batch_1")
     ].shape[0]
diff --git a/tests/model/test_jaxscvi.py b/tests/model/test_jaxscvi.py
index 922cf6ca1c..38b7fe4096 100644
--- a/tests/model/test_jaxscvi.py
+++ b/tests/model/test_jaxscvi.py
@@ -37,9 +37,7 @@ def test_jax_scvi_training(n_latent: int = 5, dropout_rate: float = 0.1):
     model = JaxSCVI(adata, n_latent=n_latent, dropout_rate=dropout_rate)
     assert model.module.training
 
-    with mock.patch(
-        "scvi.module._jaxvae.nn.Dropout", wraps=nn.Dropout
-    ) as mock_dropout_cls:
+    with mock.patch("scvi.module._jaxvae.nn.Dropout", wraps=nn.Dropout) as mock_dropout_cls:
         mock_dropout = mock.Mock()
         mock_dropout.side_effect = lambda h, **_kwargs: h
         mock_dropout_cls.return_value = mock_dropout
@@ -76,9 +74,7 @@ def test_jax_scvi_save_load(save_path: str, n_latent: int = 5):
 
     # Load with different batches.
     tmp_adata = synthetic_iid()
-    tmp_adata.obs["batch"] = tmp_adata.obs["batch"].cat.rename_categories(
-        ["batch_2", "batch_3"]
-    )
+    tmp_adata.obs["batch"] = tmp_adata.obs["batch"].cat.rename_categories(["batch_2", "batch_3"])
     with pytest.raises(ValueError):
         JaxSCVI.load(save_path, adata=tmp_adata)
 
diff --git a/tests/model/test_linear_scvi.py b/tests/model/test_linear_scvi.py
index 141acfb007..a339cce8de 100644
--- a/tests/model/test_linear_scvi.py
+++ b/tests/model/test_linear_scvi.py
@@ -92,9 +92,7 @@ def test_save_load_model(cls, adata, save_path, prefix=None):
 
     # Test legacy loading
     legacy_save_path = os.path.join(save_path, "legacy/")
-    legacy_save(
-        model, legacy_save_path, overwrite=True, save_anndata=True, prefix=prefix
-    )
+    legacy_save(model, legacy_save_path, overwrite=True, save_anndata=True, prefix=prefix)
     with pytest.raises(ValueError):
         cls.load(legacy_save_path, adata=adata, prefix=prefix)
     cls.convert_legacy_save(
diff --git a/tests/model/test_models_with_minified_data.py b/tests/model/test_models_with_minified_data.py
index 37b61ff34b..a0ea453705 100644
--- a/tests/model/test_models_with_minified_data.py
+++ b/tests/model/test_models_with_minified_data.py
@@ -74,9 +74,7 @@ def run_test_for_model_with_minified_adata(
     model.adata.obsm["X_latent_qzv"] = qzv
 
     scvi.settings.seed = 1
-    params_orig = model.get_likelihood_parameters(
-        n_samples=n_samples, give_mean=give_mean
-    )
+    params_orig = model.get_likelihood_parameters(n_samples=n_samples, give_mean=give_mean)
     adata_orig = adata.copy()
 
     model.minify_adata()
@@ -95,24 +93,18 @@ def run_test_for_model_with_minified_adata(
     assert model.adata.var_names.equals(adata_orig.var_names)
     assert model.adata.var.equals(adata_orig.var)
     assert model.adata.varm.keys() == adata_orig.varm.keys()
-    np.testing.assert_array_equal(
-        model.adata.varm["my_varm"], adata_orig.varm["my_varm"]
-    )
+    np.testing.assert_array_equal(model.adata.varm["my_varm"], adata_orig.varm["my_varm"])
 
     scvi.settings.seed = 1
     keys = ["mean", "dispersions", "dropout"]
     if n_samples == 1:
-        params_latent = model.get_likelihood_parameters(
-            n_samples=n_samples, give_mean=give_mean
-        )
+        params_latent = model.get_likelihood_parameters(n_samples=n_samples, give_mean=give_mean)
     else:
         # do this so that we generate the same sequence of random numbers in the
         # minified and non-minified cases (purely to get the tests to pass). this is
         # because in the non-minified case we sample once more (in the call to z_encoder
         # during inference)
-        params_latent = model.get_likelihood_parameters(
-            n_samples=n_samples + 1, give_mean=False
-        )
+        params_latent = model.get_likelihood_parameters(n_samples=n_samples + 1, give_mean=False)
         for k in keys:
             params_latent[k] = params_latent[k][1:].mean(0)
     for k in keys:
@@ -134,9 +126,7 @@ def test_scvi_with_minified_adata_one_sample_with_layer():
 
 def test_scvi_with_minified_adata_n_samples():
     run_test_for_model_with_minified_adata(n_samples=10, give_mean=True)
-    run_test_for_model_with_minified_adata(
-        n_samples=10, give_mean=True, use_size_factor=True
-    )
+    run_test_for_model_with_minified_adata(n_samples=10, give_mean=True, use_size_factor=True)
 
 
 def test_scanvi_with_minified_adata_one_sample():
@@ -146,9 +136,7 @@ def test_scanvi_with_minified_adata_one_sample():
 
 def test_scanvi_with_minified_adata_one_sample_with_layer():
     run_test_for_model_with_minified_adata(SCANVI, layer="data_layer")
-    run_test_for_model_with_minified_adata(
-        SCANVI, layer="data_layer", use_size_factor=True
-    )
+    run_test_for_model_with_minified_adata(SCANVI, layer="data_layer", use_size_factor=True)
 
 
 def test_scanvi_with_minified_adata_n_samples():
@@ -386,17 +374,13 @@ def test_scvi_with_minified_adata_posterior_predictive_sample():
     model.adata.obsm["X_latent_qzv"] = qzv
 
     scvi.settings.seed = 1
-    sample_orig = model.posterior_predictive_sample(
-        indices=[1, 2, 3], gene_list=["1", "2"]
-    )
+    sample_orig = model.posterior_predictive_sample(indices=[1, 2, 3], gene_list=["1", "2"])
 
     model.minify_adata()
     assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR
 
     scvi.settings.seed = 1
-    sample_new = model.posterior_predictive_sample(
-        indices=[1, 2, 3], gene_list=["1", "2"]
-    )
+    sample_new = model.posterior_predictive_sample(indices=[1, 2, 3], gene_list=["1", "2"])
     assert sample_new.shape == (3, 2)
 
     np.testing.assert_array_equal(sample_new.todense(), sample_orig.todense())
diff --git a/tests/model/test_peakvi.py b/tests/model/test_peakvi.py
index a33b2c8157..01f93dc7c3 100644
--- a/tests/model/test_peakvi.py
+++ b/tests/model/test_peakvi.py
@@ -92,9 +92,7 @@ def test_save_load_model(cls, adata, save_path, prefix=None):
 
     # Test legacy loading
     legacy_save_path = os.path.join(save_path, "legacy/")
-    legacy_save(
-        model, legacy_save_path, overwrite=True, save_anndata=True, prefix=prefix
-    )
+    legacy_save(model, legacy_save_path, overwrite=True, save_anndata=True, prefix=prefix)
     with pytest.raises(ValueError):
         cls.load(legacy_save_path, adata=adata, prefix=prefix)
     cls.convert_legacy_save(
diff --git a/tests/model/test_pyro.py b/tests/model/test_pyro.py
index 8980c75936..38cd8b772a 100644
--- a/tests/model/test_pyro.py
+++ b/tests/model/test_pyro.py
@@ -53,9 +53,7 @@ def __init__(self, in_features, out_features, per_cell_weight=False):
             .to_event(2)
         )
         self.linear.bias = PyroSample(
-            lambda prior: dist.Normal(self.zero, self.ten)
-            .expand([self.out_features])
-            .to_event(1)
+            lambda prior: dist.Normal(self.zero, self.ten).expand([self.out_features]).to_event(1)
         )
 
     def create_plates(self, x, y, ind_x):
@@ -167,9 +165,7 @@ def setup_anndata(
             CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, None),
             NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"),
         ]
-        adata_manager = AnnDataManager(
-            fields=anndata_fields, setup_method_args=setup_method_args
-        )
+        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
         adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager) @@ -216,9 +212,7 @@ def test_pyro_bayesian_regression_low_level( ] -def test_pyro_bayesian_regression( - accelerator: str, devices: list | str | int, save_path: str -): +def test_pyro_bayesian_regression(accelerator: str, devices: list | str | int, save_path: str): adata = synthetic_iid() adata_manager = _create_indices_adata_manager(adata) train_dl = AnnDataLoader(adata_manager, shuffle=True, batch_size=128) @@ -362,9 +356,10 @@ def test_pyro_bayesian_train_sample_mixin(): ) # 100 features - assert list( - mod.module.guide.state_dict()["locs.linear.weight_unconstrained"].shape - ) == [1, 100] + assert list(mod.module.guide.state_dict()["locs.linear.weight_unconstrained"].shape) == [ + 1, + 100, + ] # test posterior sampling samples = mod.sample_posterior(num_samples=10, batch_size=None, return_samples=True) @@ -383,14 +378,13 @@ def test_pyro_bayesian_train_sample_mixin_full_data(): ) # 100 features - assert list( - mod.module.guide.state_dict()["locs.linear.weight_unconstrained"].shape - ) == [1, 100] + assert list(mod.module.guide.state_dict()["locs.linear.weight_unconstrained"].shape) == [ + 1, + 100, + ] # test posterior sampling - samples = mod.sample_posterior( - num_samples=10, batch_size=adata.n_obs, return_samples=True - ) + samples = mod.sample_posterior(num_samples=10, batch_size=adata.n_obs, return_samples=True) assert len(samples["posterior_samples"]["sigma"]) == 10 @@ -407,9 +401,10 @@ def test_pyro_bayesian_train_sample_mixin_with_local(): ) # 100 - assert list( - mod.module.guide.state_dict()["locs.linear.weight_unconstrained"].shape - ) == [1, 100] + assert list(mod.module.guide.state_dict()["locs.linear.weight_unconstrained"].shape) == [ + 1, + 100, + ] # test posterior sampling samples = mod.sample_posterior(num_samples=10, batch_size=None, return_samples=True) @@ -434,14 +429,13 @@ def test_pyro_bayesian_train_sample_mixin_with_local_full_data(): ) # 100 - assert list( - mod.module.guide.state_dict()["locs.linear.weight_unconstrained"].shape - ) == [1, 100] + assert list(mod.module.guide.state_dict()["locs.linear.weight_unconstrained"].shape) == [ + 1, + 100, + ] # test posterior sampling - samples = mod.sample_posterior( - num_samples=10, batch_size=adata.n_obs, return_samples=True - ) + samples = mod.sample_posterior(num_samples=10, batch_size=adata.n_obs, return_samples=True) assert len(samples["posterior_samples"]["sigma"]) == 10 assert samples["posterior_samples"]["per_cell_weights"].shape == ( @@ -496,9 +490,7 @@ def model(self, x, log_library): # decode the latent code z px_scale, _, px_rate, px_dropout = self.decoder("gene", z, log_library) # build count distribution - nb_logits = (px_rate + self.epsilon).log() - ( - self.px_r.exp() + self.epsilon - ).log() + nb_logits = (px_rate + self.epsilon).log() - (self.px_r.exp() + self.epsilon).log() x_dist = dist.ZeroInflatedNegativeBinomial( gate_logits=px_dropout, total_count=self.px_r.exp(), logits=nb_logits ) @@ -546,9 +538,7 @@ def setup_anndata( anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, None, is_count_data=True), ] - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/tests/model/test_scanvi.py b/tests/model/test_scanvi.py index fb314e6cb4..048efbd2d7 100644 --- a/tests/model/test_scanvi.py +++ b/tests/model/test_scanvi.py @@ -80,14 +80,10 
@@ def legacy_save( # Test legacy loading legacy_save_path = os.path.join(save_path, "legacy/") - legacy_save( - model, legacy_save_path, overwrite=True, save_anndata=True, prefix=prefix - ) + legacy_save(model, legacy_save_path, overwrite=True, save_anndata=True, prefix=prefix) with pytest.raises(ValueError): SCANVI.load(legacy_save_path, adata=adata, prefix=prefix) - SCANVI.convert_legacy_save( - legacy_save_path, legacy_save_path, overwrite=True, prefix=prefix - ) + SCANVI.convert_legacy_save(legacy_save_path, legacy_save_path, overwrite=True, prefix=prefix) m = SCANVI.load(legacy_save_path, adata=adata, prefix=prefix) m.train(1) @@ -179,9 +175,7 @@ def test_scanvi(): # test from_scvi_model with size_factor a = synthetic_iid() a.obs["size_factor"] = np.random.randint(1, 5, size=(a.shape[0],)) - SCVI.setup_anndata( - a, batch_key="batch", labels_key="labels", size_factor_key="size_factor" - ) + SCVI.setup_anndata(a, batch_key="batch", labels_key="labels", size_factor_key="size_factor") m = SCVI(a, use_observed_lib_size=False) a2 = synthetic_iid() a2.obs["size_factor"] = np.random.randint(1, 5, size=(a2.shape[0],)) @@ -381,18 +375,10 @@ def test_scanvi_online_update(save_path): # test classifier frozen class_query_weight = ( - model2.module.classifier.classifier[0] - .fc_layers[0][0] - .weight.detach() - .cpu() - .numpy() + model2.module.classifier.classifier[0].fc_layers[0][0].weight.detach().cpu().numpy() ) class_ref_weight = ( - model.module.classifier.classifier[0] - .fc_layers[0][0] - .weight.detach() - .cpu() - .numpy() + model.module.classifier.classifier[0].fc_layers[0][0].weight.detach().cpu().numpy() ) # weight decay makes difference np.testing.assert_allclose(class_query_weight, class_ref_weight, atol=1e-07) @@ -403,18 +389,10 @@ def test_scanvi_online_update(save_path): model2._labeled_indices = [] model2.train(max_epochs=1) class_query_weight = ( - model2.module.classifier.classifier[0] - .fc_layers[0][0] - .weight.detach() - .cpu() - .numpy() + model2.module.classifier.classifier[0].fc_layers[0][0].weight.detach().cpu().numpy() ) class_ref_weight = ( - model.module.classifier.classifier[0] - .fc_layers[0][0] - .weight.detach() - .cpu() - .numpy() + model.module.classifier.classifier[0].fc_layers[0][0].weight.detach().cpu().numpy() ) with pytest.raises(AssertionError): np.testing.assert_allclose(class_query_weight, class_ref_weight, atol=1e-07) diff --git a/tests/model/test_scvi.py b/tests/model/test_scvi.py index 7b8b583b4d..763d874b2d 100644 --- a/tests/model/test_scvi.py +++ b/tests/model/test_scvi.py @@ -148,9 +148,7 @@ def test_save_load_model(cls, adata, save_path, prefix=None): # Test legacy loading legacy_save_path = os.path.join(save_path, "legacy/") - legacy_save( - model, legacy_save_path, overwrite=True, save_anndata=True, prefix=prefix - ) + legacy_save(model, legacy_save_path, overwrite=True, save_anndata=True, prefix=prefix) with pytest.raises(ValueError): cls.load(legacy_save_path, adata=adata, prefix=prefix) cls.convert_legacy_save( @@ -194,9 +192,7 @@ def test_scvi(n_latent: int = 5): model.train(1, check_val_every_n_epoch=1, train_size=0.5) # Test without observed lib size. 
- model = SCVI( - adata, n_latent=n_latent, var_activation=Softplus(), use_observed_lib_size=False - ) + model = SCVI(adata, n_latent=n_latent, var_activation=Softplus(), use_observed_lib_size=False) model.train(1, check_val_every_n_epoch=1, train_size=0.5) model.train(1, check_val_every_n_epoch=1, train_size=0.5) @@ -242,9 +238,7 @@ def test_scvi(n_latent: int = 5): assert denoised.shape == (3, adata2.n_vars) sample = model.posterior_predictive_sample(adata2) assert sample.shape == adata2.shape - sample = model.posterior_predictive_sample( - adata2, indices=[1, 2, 3], gene_list=["1", "2"] - ) + sample = model.posterior_predictive_sample(adata2, indices=[1, 2, 3], gene_list=["1", "2"]) assert sample.shape == (3, 2) sample = model.posterior_predictive_sample( adata2, indices=[1, 2, 3], gene_list=["1", "2"], n_samples=3 @@ -316,9 +310,7 @@ def test_scvi(n_latent: int = 5): # test differential expression model.differential_expression(groupby="labels", group1="label_1") - model.differential_expression( - groupby="labels", group1="label_1", weights="importance" - ) + model.differential_expression(groupby="labels", group1="label_1", weights="importance") model.differential_expression( groupby="labels", group1="label_1", group2="label_2", mode="change" ) @@ -505,9 +497,7 @@ def test_new_setup_compat(): adata.obs["cont1"] = np.random.normal(size=(adata.shape[0],)) adata.obs["cont2"] = np.random.normal(size=(adata.shape[0],)) # Handle edge case where registry_key != obs_key. - adata.obs.rename( - columns={"batch": "testbatch", "labels": "testlabels"}, inplace=True - ) + adata.obs.rename(columns={"batch": "testbatch", "labels": "testlabels"}, inplace=True) adata2 = adata.copy() SCVI.setup_anndata( @@ -544,9 +534,7 @@ def test_new_setup_compat(): @pytest.mark.internet def test_backwards_compatible_loading(save_path): def download_080_models(save_path): - file_path = ( - "https://github.com/yoseflab/scVI-data/raw/master/testing_models.tar.gz" - ) + file_path = "https://github.com/yoseflab/scVI-data/raw/master/testing_models.tar.gz" save_fn = "testing_models.tar.gz" _download(file_path, save_path, save_fn) saved_file_path = os.path.join(save_path, save_fn) @@ -556,9 +544,7 @@ def download_080_models(save_path): download_080_models(save_path) pretrained_scvi_path = os.path.join(save_path, "testing_models/080_scvi") - pretrained_scvi_updated_path = os.path.join( - save_path, "testing_models/080_scvi_updated" - ) + pretrained_scvi_updated_path = os.path.join(save_path, "testing_models/080_scvi_updated") a = synthetic_iid() # Fail legacy load. 
with pytest.raises(ValueError): @@ -707,9 +693,7 @@ def test_scarches_data_prep(save_path): SCVI.prepare_query_anndata(adata4, dir_path) # should be padded 0s assert np.sum(adata4[:, adata4.var_names[:10]].X) == 0 - np.testing.assert_equal( - adata4.var_names[:10].to_numpy(), adata1.var_names[:10].to_numpy() - ) + np.testing.assert_equal(adata4.var_names[:10].to_numpy(), adata1.var_names[:10].to_numpy()) SCVI.load_query_data(adata4, dir_path) adata5 = SCVI.prepare_query_anndata(adata4, dir_path, inplace=False) @@ -736,9 +720,7 @@ def test_scarches_data_prep_layer(save_path): SCVI.prepare_query_anndata(adata4, dir_path) # should be padded 0s assert np.sum(adata4[:, adata4.var_names[:10]].layers["counts"]) == 0 - np.testing.assert_equal( - adata4.var_names[:10].to_numpy(), adata1.var_names[:10].to_numpy() - ) + np.testing.assert_equal(adata4.var_names[:10].to_numpy(), adata1.var_names[:10].to_numpy()) SCVI.load_query_data(adata4, dir_path) diff --git a/tests/model/test_totalvi.py b/tests/model/test_totalvi.py index 526705943b..afd5cd5f45 100644 --- a/tests/model/test_totalvi.py +++ b/tests/model/test_totalvi.py @@ -100,9 +100,7 @@ def test_save_load_model(cls, adata, save_path, prefix=None): # Test legacy loading legacy_save_path = os.path.join(save_path, "legacy/") - legacy_save( - model, legacy_save_path, overwrite=True, save_anndata=True, prefix=prefix - ) + legacy_save(model, legacy_save_path, overwrite=True, save_anndata=True, prefix=prefix) with pytest.raises(ValueError): cls.load(legacy_save_path, adata=adata, prefix=prefix) cls.convert_legacy_save( @@ -167,15 +165,11 @@ def test_totalvi(save_path): assert post_pred.shape == (n_obs, n_vars + n_proteins, 2) post_pred = model.posterior_predictive_sample(n_samples=1) assert post_pred.shape == (n_obs, n_vars + n_proteins) - feature_correlation_matrix1 = model.get_feature_correlation_matrix( - correlation_type="spearman" - ) + feature_correlation_matrix1 = model.get_feature_correlation_matrix(correlation_type="spearman") feature_correlation_matrix1 = model.get_feature_correlation_matrix( correlation_type="spearman", transform_batch=["batch_0", "batch_1"] ) - feature_correlation_matrix2 = model.get_feature_correlation_matrix( - correlation_type="pearson" - ) + feature_correlation_matrix2 = model.get_feature_correlation_matrix(correlation_type="pearson") assert feature_correlation_matrix1.shape == ( n_vars + n_proteins, n_vars + n_proteins, @@ -283,9 +277,7 @@ def test_totalvi_model_library_size(save_path): n_latent = 10 model = TOTALVI(adata, n_latent=n_latent, use_observed_lib_size=False) - assert hasattr(model.module, "library_log_means") and hasattr( - model.module, "library_log_vars" - ) + assert hasattr(model.module, "library_log_means") and hasattr(model.module, "library_log_vars") model.train(1, train_size=0.5) assert model.is_trained is True model.get_elbo() @@ -397,15 +389,11 @@ def test_totalvi_mudata(): assert post_pred.shape == (n_obs, n_genes + n_proteins, 2) post_pred = model.posterior_predictive_sample(n_samples=1) assert post_pred.shape == (n_obs, n_genes + n_proteins) - feature_correlation_matrix1 = model.get_feature_correlation_matrix( - correlation_type="spearman" - ) + feature_correlation_matrix1 = model.get_feature_correlation_matrix(correlation_type="spearman") feature_correlation_matrix1 = model.get_feature_correlation_matrix( correlation_type="spearman", transform_batch=["batch_0", "batch_1"] ) - feature_correlation_matrix2 = model.get_feature_correlation_matrix( - correlation_type="pearson" - ) + 
feature_correlation_matrix2 = model.get_feature_correlation_matrix(correlation_type="pearson") assert feature_correlation_matrix1.shape == ( n_genes + n_proteins, n_genes + n_proteins, @@ -542,9 +530,7 @@ def test_totalvi_model_library_size_mudata(): n_latent = 10 model = TOTALVI(mdata, n_latent=n_latent, use_observed_lib_size=False) - assert hasattr(model.module, "library_log_means") and hasattr( - model.module, "library_log_vars" - ) + assert hasattr(model.module, "library_log_means") and hasattr(model.module, "library_log_vars") model.train(1, train_size=0.5) assert model.is_trained is True model.get_elbo() @@ -618,9 +604,7 @@ def test_totalvi_saving_and_loading_mudata(save_path): # Load with different batches. tmp_adata = synthetic_iid() - tmp_adata.obs["batch"] = tmp_adata.obs["batch"].cat.rename_categories( - ["batch_2", "batch_3"] - ) + tmp_adata.obs["batch"] = tmp_adata.obs["batch"].cat.rename_categories(["batch_2", "batch_3"]) tmp_protein_adata = synthetic_iid(n_genes=50) tmp_mdata = MuData({"rna": tmp_adata, "protein": tmp_protein_adata}) with pytest.raises(ValueError): diff --git a/tests/train/test_trainingplans.py b/tests/train/test_trainingplans.py index 3b56760795..4ebf717462 100644 --- a/tests/train/test_trainingplans.py +++ b/tests/train/test_trainingplans.py @@ -23,13 +23,9 @@ def test_compute_kl_weight_linear_annealing( current, n_warm_up, min_kl_weight, max_kl_weight, expected ): - kl_weight = _compute_kl_weight( - current, 1, n_warm_up, None, max_kl_weight, min_kl_weight - ) + kl_weight = _compute_kl_weight(current, 1, n_warm_up, None, max_kl_weight, min_kl_weight) assert kl_weight == pytest.approx(expected) - kl_weight = _compute_kl_weight( - 1, current, None, n_warm_up, max_kl_weight, min_kl_weight - ) + kl_weight = _compute_kl_weight(1, current, None, n_warm_up, max_kl_weight, min_kl_weight) assert kl_weight == pytest.approx(expected) @@ -51,12 +47,8 @@ def test_compute_kl_weight_min_greater_max(): (100, 200, 100, 1000, 1.0), ], ) -def test_compute_kl_precedence( - epoch, step, n_epochs_kl_warmup, n_steps_kl_warmup, expected -): - kl_weight = _compute_kl_weight( - epoch, step, n_epochs_kl_warmup, n_steps_kl_warmup, 1.0, 0.0 - ) +def test_compute_kl_precedence(epoch, step, n_epochs_kl_warmup, n_steps_kl_warmup, expected): + kl_weight = _compute_kl_weight(epoch, step, n_epochs_kl_warmup, n_steps_kl_warmup, 1.0, 0.0) assert kl_weight == expected From ce907ba143fe407be38e20c78d52678e32352335 Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Mon, 12 Feb 2024 09:59:30 -0800 Subject: [PATCH 13/21] Update tutorials head (#2502) --- docs/tutorials/notebooks | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index e773cc9f9f..b73465f45c 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit e773cc9f9f9d116c0b5a8a8c813c4d890743ca5c +Subproject commit b73465f45c2235744fe526981c79f5fd09bce4a8 From 5646c6743ab229f9996d161a3ff6198bb3688b16 Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Mon, 12 Feb 2024 10:14:35 -0800 Subject: [PATCH 14/21] Update ruff pre-commit (#2504) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 74de5975f4..9891eef4f3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,7 +21,7 @@ repos: # 
https://github.com/jupyterlab/jupyterlab/issues/12675 language_version: "17.9.1" - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.11 + rev: v0.2.1 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] From 31053155733b95a60c99a5f61c6896b223882f81 Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Tue, 13 Feb 2024 10:49:53 -0800 Subject: [PATCH 15/21] Update tutorials head for release (#2508) --- docs/tutorials/notebooks | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index b73465f45c..3d913fce03 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit b73465f45c2235744fe526981c79f5fd09bce4a8 +Subproject commit 3d913fce03a15ac42f46844840cd831e9b29d8ab From b17a0fc8ffc8eb948eb8c443f9a88889961f6a63 Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Tue, 13 Feb 2024 11:43:54 -0800 Subject: [PATCH 16/21] Bump to 1.1.0post1 (#2510) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b7c6074c98..3ef25ff752 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ requires = ["hatchling"] [project] name = "scvi-tools" -version = "1.1.0" +version = "1.1.0post1" description = "Deep probabilistic analysis of single-cell omics data." readme = "README.md" requires-python = ">=3.9" From 14284721dfe6338679fa95972caa98b943a6e317 Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Tue, 13 Feb 2024 11:57:05 -0800 Subject: [PATCH 17/21] Bump to post2 and add date to release (#2512) --- docs/release_notes/index.md | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/release_notes/index.md b/docs/release_notes/index.md index 063426243f..34d0f5a6db 100644 --- a/docs/release_notes/index.md +++ b/docs/release_notes/index.md @@ -7,7 +7,7 @@ is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/ [keep a changelog]: https://keepachangelog.com/en/1.0.0/ [semantic versioning]: https://semver.org/spec/v2.0.0.html -### 1.1.0 (2024-02-DD) +### 1.1.0 (2024-02-13) #### Added diff --git a/pyproject.toml b/pyproject.toml index 3ef25ff752..70fb32e2d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ requires = ["hatchling"] [project] name = "scvi-tools" -version = "1.1.0post1" +version = "1.1.0post2" description = "Deep probabilistic analysis of single-cell omics data." 
readme = "README.md" requires-python = ">=3.9" From df7ea5e5743464a6c4bd550a7caa50a23ab8323b Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Tue, 13 Feb 2024 16:43:16 -0800 Subject: [PATCH 18/21] Start 1.1.1 release (#2514) --- docs/release_notes/index.md | 4 ++++ pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/release_notes/index.md b/docs/release_notes/index.md index 34d0f5a6db..9158423ee5 100644 --- a/docs/release_notes/index.md +++ b/docs/release_notes/index.md @@ -7,6 +7,10 @@ is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/ [keep a changelog]: https://keepachangelog.com/en/1.0.0/ [semantic versioning]: https://semver.org/spec/v2.0.0.html +### 1.1.1 (unreleased) + +#### Fixed + ### 1.1.0 (2024-02-13) #### Added diff --git a/pyproject.toml b/pyproject.toml index 70fb32e2d6..846a33524e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ requires = ["hatchling"] [project] name = "scvi-tools" -version = "1.1.0post2" +version = "1.1.1" description = "Deep probabilistic analysis of single-cell omics data." readme = "README.md" requires-python = ">=3.9" From 6fb91869f541e126338a01d16c13b2820f40c7f0 Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Fri, 16 Feb 2024 13:22:25 -0800 Subject: [PATCH 19/21] Add CUDA 12 to CI (#2518) --- .github/workflows/test_linux_cuda.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_linux_cuda.yml b/.github/workflows/test_linux_cuda.yml index 153254bdf1..722541eeab 100644 --- a/.github/workflows/test_linux_cuda.yml +++ b/.github/workflows/test_linux_cuda.yml @@ -31,7 +31,7 @@ jobs: fail-fast: false matrix: python: ["3.11"] - cuda: ["11"] + cuda: ["11", "12"] container: image: scverse/scvi-tools:py${{ matrix.python }}-cu${{ matrix.cuda }}-base From 6736ff8f7f7ae64e1d5ffbc2d91753a3f555260b Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Mon, 19 Feb 2024 14:26:12 -0800 Subject: [PATCH 20/21] Allow for non-default user params in `POISSONVI` (#2522) - match defaults for `dropout_rate` and `n_layers` with `PEAKVAE` - pass in `n_hidden` and `n_latent` so that defaults don't always take precedence - test that non-default user parameters apply correctly - test that default parameters match `PEAKVI` --- docs/release_notes/index.md | 2 + scvi/external/poissonvi/_model.py | 15 +++--- tests/external/poissonvi/test_poissonvi.py | 62 ++++++++++++++++++++-- 3 files changed, 68 insertions(+), 11 deletions(-) diff --git a/docs/release_notes/index.md b/docs/release_notes/index.md index 9158423ee5..aee31140bb 100644 --- a/docs/release_notes/index.md +++ b/docs/release_notes/index.md @@ -11,6 +11,8 @@ is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/ #### Fixed +- Correctly apply non-default user parameters in {class}`scvi.external.POISSONVI` {pr}`2522`. 
+
 ### 1.1.0 (2024-02-13)

 #### Added
diff --git a/scvi/external/poissonvi/_model.py b/scvi/external/poissonvi/_model.py
index eef36a6a85..0acf9276fd 100644
--- a/scvi/external/poissonvi/_model.py
+++ b/scvi/external/poissonvi/_model.py
@@ -79,14 +79,13 @@ def __init__(
         adata: AnnData,
         n_hidden: int | None = None,
         n_latent: int | None = None,
-        n_layers: int | None = None,
-        dropout_rate: float | None = None,
+        n_layers: int = 2,
+        dropout_rate: float = 0.1,
         latent_distribution: Literal["normal", "ln"] = "normal",
         **model_kwargs,
     ):
-        super().__init__(
-            adata,
-        )
+        # need to pass these in to get the correct defaults for peakvi
+        super().__init__(adata, n_hidden=n_hidden, n_latent=n_latent)

         n_batch = self.summary_stats.n_batch
         use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry
@@ -104,11 +103,11 @@ def __init__(
             n_cats_per_cov=self.module.n_cats_per_cov,
             n_hidden=self.module.n_hidden,
             n_latent=self.module.n_latent,
-            n_layers=self.module.n_layers_encoder,
-            dropout_rate=self.module.dropout_rate,
+            n_layers=n_layers,
+            dropout_rate=dropout_rate,
             dispersion="gene",  # not needed here
             gene_likelihood="poisson",  # fixed value for now, but we could think of also allowing nb
-            latent_distribution=self.module.latent_distribution,
+            latent_distribution=latent_distribution,
             use_size_factor_key=use_size_factor_key,
             library_log_means=library_log_means,
             library_log_vars=library_log_vars,
diff --git a/tests/external/poissonvi/test_poissonvi.py b/tests/external/poissonvi/test_poissonvi.py
index d75c7493e6..53de237668 100644
--- a/tests/external/poissonvi/test_poissonvi.py
+++ b/tests/external/poissonvi/test_poissonvi.py
@@ -1,13 +1,69 @@
+from torch.nn import Linear
+
 from scvi.data import synthetic_iid
 from scvi.external import POISSONVI


 def test_poissonvi():
     adata = synthetic_iid(batch_size=100)
-    POISSONVI.setup_anndata(
-        adata,
-    )
+    POISSONVI.setup_anndata(adata)
     model = POISSONVI(adata)
     model.train(max_epochs=1)
     model.get_latent_representation()
     model.get_accessibility_estimates()
+
+
+def test_poissonvi_default_params():
+    from scvi.model import PEAKVI
+
+    adata = synthetic_iid(batch_size=100)
+    POISSONVI.setup_anndata(adata)
+    PEAKVI.setup_anndata(adata)
+    poissonvi = POISSONVI(adata)
+    peakvi = PEAKVI(adata)
+
+    assert poissonvi.module.n_latent == peakvi.module.n_latent
+    assert poissonvi.module.latent_distribution == peakvi.module.latent_distribution
+    poisson_encoder = poissonvi.module.z_encoder.encoder
+    poisson_mean_encoder = poissonvi.module.z_encoder.mean_encoder
+    poisson_decoder = poissonvi.module.decoder.px_decoder
+    assert len(poisson_encoder.fc_layers) == peakvi.module.n_layers_encoder
+    assert len(poisson_decoder.fc_layers) == peakvi.module.n_layers_encoder
+    assert poisson_encoder.fc_layers[-1][0].in_features == peakvi.module.n_hidden
+    assert poisson_decoder.fc_layers[-1][0].in_features == peakvi.module.n_hidden
+    assert poisson_mean_encoder.out_features == peakvi.module.n_latent
+    assert poisson_decoder.fc_layers[0][0].in_features == peakvi.module.n_latent
+
+
+def test_poissonvi_non_default_params(
+    n_hidden: int = 50,
+    n_latent: int = 5,
+    n_layers: int = 2,
+    dropout_rate: float = 0.4,
+    latent_distribution="ln",
+):
+    adata = synthetic_iid(batch_size=100)
+    POISSONVI.setup_anndata(adata)
+    model = POISSONVI(
+        adata,
+        n_hidden=n_hidden,
+        n_latent=n_latent,
+        n_layers=n_layers,
+        dropout_rate=dropout_rate,
+        latent_distribution=latent_distribution,
+    )
+
+    assert model.module.n_latent == n_latent
+    assert model.module.latent_distribution == latent_distribution
+
+    encoder = model.module.z_encoder.encoder
+    assert len(encoder.fc_layers) == n_layers
+    linear = encoder.fc_layers[-1][0]
+    assert isinstance(linear, Linear)
+    assert linear.in_features == n_hidden
+    mean_encoder = model.module.z_encoder.mean_encoder
+    assert isinstance(mean_encoder, Linear)
+    assert mean_encoder.out_features == n_latent
+
+    model.train(max_epochs=1)
+    assert model.get_latent_representation().shape[1] == n_latent

From fea055e6976997ebd5b989cdd48dd4500e46fba4 Mon Sep 17 00:00:00 2001
From: Martin Kim <46072231+martinkim0@users.noreply.github.com>
Date: Tue, 20 Feb 2024 09:21:05 -0800
Subject: [PATCH 21/21] Update release checklist (#2526)

---
 .github/ISSUE_TEMPLATE/release_checklist.md | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/.github/ISSUE_TEMPLATE/release_checklist.md b/.github/ISSUE_TEMPLATE/release_checklist.md
index 156756e0b7..d2554ed9d7 100644
--- a/.github/ISSUE_TEMPLATE/release_checklist.md
+++ b/.github/ISSUE_TEMPLATE/release_checklist.md
@@ -10,9 +10,10 @@ assignees: ""
 - [ ] If patch release, backport version bump PR into the appropriate branch. Else, create a new branch off `main` with the appropriate rules
 - [ ] Trigger a Docker image build in [`scvi-tools-docker`](https://github.com/YosefLab/scvi-tools-docker) targeting the release branch
 - [ ] After image builds and pushes to the registry, run the [tutorials](https://github.com/scverse/scvi-tutorials) using the new image
-- [ ] Publish a new release on the tutorials repo off `main` after all tutorials changes have been merged.
+- [ ] Publish a new release on the tutorials repo off `main` after all tutorials changes have been merged
 - [ ] Create a new branch off `main` in the main repo and run `git submodule update --remote`, and then merge the PR, with an appropriate backport as needed
 - [ ] Create a new GitHub release targeting the release branch with the same body as the previous release
 - [ ] Check that the version updates correctly on [PyPI](https://pypi.org/project/scvi-tools/)
+- [ ] Build new Docker images with the `stable` and semantic versioning tags
 - [ ] Check that the [feedstock repo](https://github.com/conda-forge/scvi-tools-feedstock) updates correctly
 - [ ] (Optional) Post threads on Discourse and Twitter
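
In summary, [PATCH 20/21] makes POISSONVI honor user-supplied architecture
arguments instead of silently overwriting them with the PEAKVI defaults. A
minimal sketch of the fixed behavior, using only calls that appear verbatim in
the diffs above (`synthetic_iid` stands in for real accessibility data, and the
specific parameter values are illustrative, not recommendations):

    from scvi.data import synthetic_iid
    from scvi.external import POISSONVI

    adata = synthetic_iid(batch_size=100)
    POISSONVI.setup_anndata(adata)

    # Before this patch, n_hidden/n_latent passed here were dropped in favor
    # of the PEAKVI defaults; with the fix they propagate to model.module.
    model = POISSONVI(adata, n_hidden=50, n_latent=5, n_layers=2, dropout_rate=0.4)
    assert model.module.n_latent == 5

    model.train(max_epochs=1)
    assert model.get_latent_representation().shape[1] == 5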