Skip to content

Commit

Permalink
Update scvi/external/sysvi/_base_components.py
Browse files Browse the repository at this point in the history
  • Loading branch information
martinkim0 authored Mar 15, 2024
1 parent c5f5c37 commit 5b4838c
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion scvi/external/sysvi/_base_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __init__(
self.mean_encoder = Linear(n_hidden, n_output)
self.var_encoder = VarEncoder(n_hidden, n_output, mode=var_mode, eps=var_eps)

def forward(self, x, cov: torch.Tensor | None = None):
def forward(self, x: torch.Tensor, cov: torch.Tensor | None = None):
y = self.decoder_y(x=x, cov=cov)
# TODO better handling of inappropriate edge-case values than nan_to_num or at least warn
y_m = torch.nan_to_num(self.mean_encoder(y))
Expand Down

0 comments on commit 5b4838c

Please sign in to comment.