Adding model with cycle consistency and VampPrior #2421
base: main
@@ -0,0 +1,5 @@
from ._base_components import Layers, VarEncoder
from ._model import SysVI
from ._module import SysVAE

__all__ = ["SysVI", "VarEncoder", "Layers", "SysVAE"]
@@ -0,0 +1,345 @@
from __future__ import annotations

from collections import OrderedDict
from typing import Literal

import torch
from torch.distributions import Normal
from torch.nn import (
    BatchNorm1d,
    Dropout,
    LayerNorm,
    Linear,
    Module,
    Parameter,
    ReLU,
    Sequential,
)


class Embedding(Module):
    """Module for obtaining embeddings of categorical covariates.

    Parameters
    ----------
    size
        Number of categories
    cov_embed_dims
        Dimensionality of the embedding
    normalize
        Apply layer normalization
    """

    def __init__(self, size: int, cov_embed_dims: int = 10, normalize: bool = True):
        super().__init__()

        self.normalize = normalize

        self.embedding = torch.nn.Embedding(size, cov_embed_dims)

        if self.normalize:
            # TODO this could probably be implemented more efficiently, as the embedding gives the same
            #  result for every sample of a given class. However, if there are many balanced classes,
            #  there won't be many repetitions within a minibatch.
            self.layer_norm = torch.nn.LayerNorm(cov_embed_dims, elementwise_affine=False)

    def forward(self, x):
        x = self.embedding(x)
        if self.normalize:
            x = self.layer_norm(x)

        return x


class EncoderDecoder(Module):
    """Module that can be used as a probabilistic encoder or decoder.

    Based on the input and optional covariates, it predicts the output mean and variance.

    Parameters
    ----------
    n_input
        The dimensionality of the main input
    n_output
        The dimensionality of the output
    n_cov
        Dimensionality of covariates.
        If there are no covariates this should be set to None;
        in that case covariates will not be used.
    n_hidden
        The number of nodes per fully-connected hidden layer
    n_layers
        Number of hidden layers
    var_eps
        See :class:`~scvi.external.sysvi.VarEncoder`
    var_mode
        See :class:`~scvi.external.sysvi.VarEncoder`
    sample
        Return samples from the predicted distribution
    kwargs
        Passed to :class:`~scvi.external.sysvi.Layers`
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_cov: int,
        n_hidden: int = 256,
        n_layers: int = 3,
        var_eps: float = 1e-4,
        var_mode: Literal["sample_feature", "feature", "linear"] = "feature",

Review comment (on ``var_mode``): Can you provide more insight in the documentation than a cross-reference?

        sample: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.sample = sample

        self.var_eps = var_eps

        self.decoder_y = Layers(
            n_in=n_input,
            n_cov=n_cov,
            n_out=n_hidden,
            n_hidden=n_hidden,
            n_layers=n_layers,
            **kwargs,
        )

        self.mean_encoder = Linear(n_hidden, n_output)
        self.var_encoder = VarEncoder(n_hidden, n_output, mode=var_mode, eps=var_eps)

    def forward(self, x: torch.Tensor, cov: torch.Tensor | None = None):
        y = self.decoder_y(x=x, cov=cov)
        # TODO better handling of inappropriate edge-case values than nan_to_num, or at least warn

Review comment: Have there been edge-case values other than NaNs?

Review comment: I dislike nan_to_num. It's numerically unstable but gives the user no good insight into the issue. Can you describe the type of NaN errors that you are getting?

        y_m = torch.nan_to_num(self.mean_encoder(y))
        y_v = self.var_encoder(y, x_m=y_m)

        outputs = {"y_m": y_m, "y_v": y_v}

        if self.sample:
            y = Normal(y_m, y_v.sqrt()).rsample()
            outputs["y"] = y

        return outputs

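Regarding the review thread above about ``nan_to_num`` in ``EncoderDecoder.forward``: a minimal sketch (not part of this PR; the helper name is hypothetical) of how non-finite values could be reported with a warning instead of being silently replaced.

import warnings

import torch


def warn_on_nonfinite(t: torch.Tensor, name: str) -> torch.Tensor:
    """Hypothetical helper: replace non-finite entries with zeros, but warn that it happened."""
    mask = ~torch.isfinite(t)
    if mask.any():
        warnings.warn(
            f"{name} contained {int(mask.sum())} non-finite values; they were set to 0.",
            RuntimeWarning,
        )
        t = torch.where(mask, torch.zeros_like(t), t)
    return t

Inside ``forward`` this could be used as ``y_m = warn_on_nonfinite(self.mean_encoder(y), "y_m")``, keeping the current behaviour while surfacing the edge cases the reviewers ask about.
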
class Layers(Module):

Review comment: Since this class is similar to the existing implementation of …

Review comment: Inheriting this will make it also easier for us to add new functionality to SysVI. Are there issues with the FCLayers class?

    """A helper class to build fully-connected layers for a neural network.

    Adapted from scVI FCLayers to use covariates more flexibly.

    Parameters
    ----------
    n_in
        The dimensionality of the main input
    n_out
        The dimensionality of the output
    n_cov
        Dimensionality of covariates.
        If there are no covariates this should be set to None;
        in that case covariates will not be used.
    n_layers
        The number of fully-connected hidden layers
    n_hidden
        The number of nodes per hidden layer
    dropout_rate
        Dropout rate to apply to each of the hidden layers
    use_batch_norm
        Whether to have `BatchNorm` layers or not
    use_layer_norm
        Whether to have `LayerNorm` layers or not
    use_activation
        Whether to have layer activation or not
    bias
        Whether to learn bias in linear layers or not
    inject_covariates
        Whether to inject covariates in each layer, or just the first.
    activation_fn
        Which activation function to use
    """

    def __init__(
        self,
        n_in: int,
        n_out: int,
        n_cov: int | None = None,
        n_layers: int = 1,
        n_hidden: int = 128,
        dropout_rate: float = 0.1,
        use_batch_norm: bool = True,
        use_layer_norm: bool = False,
        use_activation: bool = True,
        bias: bool = True,
        inject_covariates: bool = True,
        activation_fn: Module = ReLU,
    ):
        super().__init__()

        self.inject_covariates = inject_covariates
        self.n_cov = n_cov if n_cov is not None else 0

        layers_dim = [n_in] + (n_layers - 1) * [n_hidden] + [n_out]

        self.fc_layers = Sequential(
            OrderedDict(
                [
                    (
                        f"Layer {i}",
                        Sequential(
                            Linear(
                                n_in + self.n_cov * self.inject_into_layer(i),
                                n_out,
                                bias=bias,
                            ),
                            # Non-default params come from the defaults in the original TensorFlow implementation
                            BatchNorm1d(n_out, momentum=0.01, eps=0.001)
                            if use_batch_norm
                            else None,
                            LayerNorm(n_out, elementwise_affine=False) if use_layer_norm else None,
                            activation_fn() if use_activation else None,
                            Dropout(p=dropout_rate) if dropout_rate > 0 else None,
                        ),
                    )
                    for i, (n_in, n_out) in enumerate(zip(layers_dim[:-1], layers_dim[1:]))
                ]
            )
        )

    def inject_into_layer(self, layer_num) -> bool:
        """Helper to determine if covariates should be injected."""
        user_cond = layer_num == 0 or (layer_num > 0 and self.inject_covariates)
        return user_cond

    def set_online_update_hooks(self, hook_first_layer=True):
        self.hooks = []

        def _hook_fn_weight(grad):
            new_grad = torch.zeros_like(grad)
            if self.n_cov > 0:
                new_grad[:, -self.n_cov :] = grad[:, -self.n_cov :]
            return new_grad

        def _hook_fn_zero_out(grad):
            return grad * 0

        for i, layers in enumerate(self.fc_layers):
            for layer in layers:
                if i == 0 and not hook_first_layer:
                    continue
                if isinstance(layer, Linear):
                    if self.inject_into_layer(i):
                        w = layer.weight.register_hook(_hook_fn_weight)
                    else:
                        w = layer.weight.register_hook(_hook_fn_zero_out)
                    self.hooks.append(w)
                    b = layer.bias.register_hook(_hook_fn_zero_out)
                    self.hooks.append(b)

    def forward(self, x: torch.Tensor, cov: torch.Tensor | None = None):
        """Forward computation on ``x``.

        Parameters
        ----------
        x
            Tensor of values with shape ``(n_in,)``
        cov
            Tensor of covariate values with shape ``(n_cov,)`` or None

        Returns
        -------
        :class:`torch.Tensor`
            Tensor of shape ``(n_out,)``
        """
        for i, layers in enumerate(self.fc_layers):
            for layer in layers:
                if layer is not None:
                    if isinstance(layer, BatchNorm1d):
                        if x.dim() == 3:
                            x = torch.cat([(layer(slice_x)).unsqueeze(0) for slice_x in x], dim=0)
                        else:
                            x = layer(x)
                    else:
                        # Injection of covariates
                        if (
                            self.n_cov > 0
                            and isinstance(layer, Linear)
                            and self.inject_into_layer(i)
                        ):
                            x = torch.cat((x, cov), dim=-1)
                        x = layer(x)
        return x

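On the FCLayers question raised in the review comments above: for comparison, a minimal usage sketch of the existing ``scvi.nn.FCLayers`` (assuming its usual signature; all dimensions below are made up). The main difference is that ``FCLayers`` takes a list of categorical covariates that it one-hot encodes internally, whereas the ``Layers`` class above concatenates a single pre-built covariate tensor.

import torch
from scvi.nn import FCLayers

fc = FCLayers(
    n_in=100,            # main input dimensionality (illustrative)
    n_out=256,
    n_cat_list=[4, 2],   # one entry per categorical covariate: number of categories
    n_layers=3,
    n_hidden=256,
    inject_covariates=True,
)

x = torch.randn(8, 100)
batch = torch.randint(0, 4, (8, 1))   # categorical covariate with 4 categories
system = torch.randint(0, 2, (8, 1))  # categorical covariate with 2 categories
h = fc(x, batch, system)              # covariates are one-hot encoded and injected by FCLayers
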
class VarEncoder(Module):
    """Encode variance (strictly positive).

    Parameters
    ----------
    n_input
        Number of input dimensions; used if mode is ``'sample_feature'``
    n_output
        Number of variances to predict
    mode
        How to compute the variance:
        ``'sample_feature'`` - learn the variance per sample and feature;
        ``'feature'`` - learn the variance per feature, constant across samples;
        ``'linear'`` - linear with respect to the input mean, var = a1 * mean + a0;
        not suggested to be used due to the poor implementation of the positivity constraint.

Review comment (on the ``mode`` documentation): Could you format this similar to how it is done here?

Review comment: Reformat the documentation. This provides better formatting in Sphinx.

Review comment: Can you define "feature" better (it is latent dimensions, right)?

    eps

Review comment (on ``eps``): Needs documentation.

Review comment: Up.

    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        mode: Literal["sample_feature", "feature", "linear"],
        eps: float = 1e-4,
    ):
        super().__init__()

        self.eps = eps
        self.mode = mode
        if self.mode == "sample_feature":
            self.encoder = Linear(n_input, n_output)
        elif self.mode == "feature":
            self.var_param = Parameter(torch.zeros(1, n_output))
        elif self.mode == "linear":
            self.var_param_a1 = Parameter(torch.tensor([1.0]))
            self.var_param_a0 = Parameter(torch.tensor([self.eps]))
        else:
            raise ValueError("Mode not recognised.")
        self.activation = torch.exp

Review comment: Probably a good idea to make this an adjustable parameter, as we've experienced numerical stability issues with the exponential.

    def forward(self, x: torch.Tensor, x_m: torch.Tensor):
        """Forward pass through the model.

        Parameters
        ----------
        x
            Used to encode the variance if mode is ``'sample_feature'``; dim = n_samples x n_input
        x_m
            Used to predict the variance instead of ``x`` if mode is ``'linear'``; dim = n_samples x 1

        Returns
        -------
        Predicted variance
        """
        # Force to be non-NaN - TODO come up with a better way to do so
        if self.mode == "sample_feature":
            v = self.encoder(x)
            v = (
                torch.nan_to_num(self.activation(v)) + self.eps
            )  # Ensure that var is strictly positive

Review comment: Please try to avoid nan_to_num.

Reply: What would you suggest?

Review comment: Using a softplus activation will be much more stable than exp. Do you need exp here for specific reasons? Clamping v would otherwise be safe (something like 20).

        elif self.mode == "feature":
            v = self.var_param.expand(x.shape[0], -1)  # Broadcast to input size
            v = (
                torch.nan_to_num(self.activation(v)) + self.eps
            )  # Ensure that var is strictly positive
        elif self.mode == "linear":
            v = self.var_param_a1 * x_m.detach().clone() + self.var_param_a0
            # TODO come up with a better way to constrain this to positive while keeping the linear relationship
            # Could an activation be used for a log-linear relationship?
            v = torch.clamp(torch.nan_to_num(v), min=self.eps)

Review comment: Is this ever producing NaN errors? Printing a warning if v is NaN would at least be helpful.

        return v
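Picking up the two review suggestions above (make the activation adjustable; prefer softplus or clamping over exp plus nan_to_num), a minimal sketch of what a safer, configurable variance activation could look like. This is illustrative only; the helper name and defaults are not part of the PR.

from typing import Callable

import torch
import torch.nn.functional as F


def make_var_activation(
    kind: str = "softplus", eps: float = 1e-4, clamp_max: float = 20.0
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Hypothetical helper returning a map from unconstrained values to strictly positive variances."""
    if kind == "softplus":
        # Softplus grows linearly for large inputs, so it cannot overflow like exp
        return lambda v: F.softplus(v) + eps
    if kind == "exp":
        # Clamp the exponent so exp cannot overflow to inf
        return lambda v: torch.exp(torch.clamp(v, max=clamp_max)) + eps
    raise ValueError(f"Unknown variance activation: {kind}")

``VarEncoder`` could then take the activation as a constructor argument (for example ``activation_fn=make_var_activation("softplus")``) instead of hard-coding ``torch.exp``, which would also make most of the ``nan_to_num`` calls unnecessary.
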
Review comment (on the Embedding class): Can you switch this to use the Embedding class within scvi-tools? It should perform similarly, and otherwise we can add options there.