Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Sep 14, 2024
1 parent b9a9047 commit b95cb03
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 9 deletions.
4 changes: 2 additions & 2 deletions src/scvi/external/sysvi/_base_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,8 @@ def forward(self, x: torch.Tensor):
# Force to be non nan - TODO come up with better way to do so
if self.mode == "sample_feature":
v = self.encoder(x)
v = (self.activation(v) + self.eps) # Ensure that var is strictly positive
v = self.activation(v) + self.eps # Ensure that var is strictly positive
elif self.mode == "feature":
v = self.var_param.expand(x.shape[0], -1) # Broadcast to input size
v = (self.activation(v) + self.eps) # Ensure that var is strictly positive
v = self.activation(v) + self.eps # Ensure that var is strictly positive
return v
6 changes: 2 additions & 4 deletions src/scvi/external/sysvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@

from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data._constants import _SCVI_UUID_KEY
from scvi.data._utils import _check_if_view
from scvi.data.fields import (
LayerField,
ObsmField,
Expand Down Expand Up @@ -137,6 +135,7 @@ def get_latent_representation(
return_dist
If ``True``, returns the mean and variance of the latent distribution. Otherwise,
returns the mean of the latent distribution.
Returns
-------
Latent Embedding
Expand Down Expand Up @@ -192,8 +191,7 @@ def _validate_anndata(

# Check that all required fields are present and match the Model's adata
assert (
self.adata.uns["layer_information"]["layer"]
== adata.uns["layer_information"]["layer"]
self.adata.uns["layer_information"]["layer"] == adata.uns["layer_information"]["layer"]
)
assert (
self.adata.uns["layer_information"]["var_names"]
Expand Down
3 changes: 2 additions & 1 deletion src/scvi/external/sysvi/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,8 @@ def loss(

# Reconstruction loss
reconst_loss_x = torch.nn.GaussianNLLLoss(reduction="none")(
generative_outputs["x_m"], x_true, generative_outputs["x_v"]).sum(dim=1)
generative_outputs["x_m"], x_true, generative_outputs["x_v"]
).sum(dim=1)

reconst_loss = reconst_loss_x

Expand Down
2 changes: 0 additions & 2 deletions tests/external/sysvi/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,5 +171,3 @@ def test_model():
give_mean=False,
),
)


0 comments on commit b95cb03

Please sign in to comment.