
L1 loss added to the models #85

Open · wants to merge 22 commits into main
Conversation

soumickmj (Contributor)

MSE loss sometimes makes the models produce overly smooth images, and L1 loss is an easy drop-in fix. I have added it to the models that already had MSE and BCE loss functions, and skipped (for now) the ones without the reconstruction-loss flag. In the future, I will also add a feature for passing custom loss functions.
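For illustration, here is a minimal pure-Python sketch of the drop-in idea: the per-sample L1 term follows the same pattern as MSE (elementwise error, summed over features). The function name and list-based inputs are hypothetical stand-ins for the `F.l1_loss(..., reduction="none").sum(dim=-1)` calls in the diff.

```python
def recon_loss(recon_x, x, kind="mse"):
    """Per-sample reconstruction error, summed over features.

    Mirrors the F.{mse,l1}_loss(..., reduction="none").sum(dim=-1)
    pattern from the diff, over plain Python lists for illustration.
    """
    if kind == "mse":
        # Squared error: large residuals dominate, small ones fade,
        # which is what tends to produce smooth reconstructions.
        return sum((r - t) ** 2 for r, t in zip(recon_x, x))
    elif kind == "l1":
        # Absolute error: every residual is penalized linearly.
        return sum(abs(r - t) for r, t in zip(recon_x, x))
    raise ValueError(f"unknown reconstruction loss: {kind!r}")
```

Because L1 penalizes small residuals proportionally rather than quadratically, swapping it in for MSE is often enough to reduce over-smoothing without touching the rest of the training loop.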

@clementchadebec (Owner) left a comment

Hi @soumickmj,
Thank you very much for this contribution and sorry for the late reply. Please see some comments that, in my opinion, should be addressed before merging.

  • First, we need to allow `l1` to be passed as an argument of the VAEConfig (see comment on vae_config.py)
  • Second, since the l1 loss does not really refer to a distribution (e.g. BCE = Bernoulli, MSE = multivariate standard Gaussian), we should not allow the computation of the NLL when this loss is chosen. Instead, I propose to raise an exception if the get_nll method is called with l1 as the loss, as follows
raise NotImplementedError("Computation of the likelihood is not implemented when `L1 loss` is chosen")

In any case, do not hesitate to reach out if you have any questions or do not agree with the proposed modifications.

Best,

Clément

@@ -286,6 +294,18 @@ def get_nll(self, data, n_samples=1, batch_size=100):
reduction="none",
).sum(dim=-1)

elif self.model_config.reconstruction_loss == "l1":

The computation of the likelihood cannot be handled with the l1 loss since it does not correspond to a tractable distribution per se.

@@ -572,6 +586,18 @@ def get_nll(self, data, n_samples=1, batch_size=100):
reduction="none",
).sum(dim=-1)

elif self.model_config.reconstruction_loss == "l1":

The computation of the likelihood cannot be handled with the l1 loss since it does not correspond to a tractable distribution per se.

@@ -11,7 +11,7 @@ class VAEConfig(BaseAEConfig):
Parameters:
input_dim (tuple): The input_data dimension.
latent_dim (int): The latent space dimension. Default: None.
reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse'
reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse'
"""

reconstruction_loss: Literal["bce", "mse"] = "mse"

This should be replaced by

 reconstruction_loss: Literal["bce", "mse", "l1"] = "mse"
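Assuming the config class validates values against this Literal annotation (as pydantic-style configs do at construction time), the effect of widening the annotation can be sketched with `typing.get_args`. The `validate_loss` helper below is hypothetical, standing in for the library's own validation:

```python
from typing import Literal, get_args

# The widened annotation proposed in the review:
ReconstructionLoss = Literal["bce", "mse", "l1"]

def validate_loss(value: str) -> str:
    """Reject values outside the Literal, mimicking the config check."""
    if value not in get_args(ReconstructionLoss):
        raise ValueError(f"unsupported reconstruction loss: {value!r}")
    return value
```

Without `"l1"` in the Literal, a config built with `reconstruction_loss="l1"` would be rejected before the loss branch in the model is ever reached, which is why this change has to accompany the new `elif` branches.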

@@ -424,6 +424,20 @@ def _log_p_x_given_z(self, recon_x, x):
reduction="none",
).sum(dim=-1)

elif self.model_config.reconstruction_loss == "l1":

For this model, this should correspond to a known distribution. The mse actually models a multivariate normal. In the case of L1 loss, there is not really an underlying distribution so we should not allow the L1 for this model.
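The correspondence the comment relies on: with a unit-variance Gaussian decoder, the negative log-density of x given mean recon_x is exactly half the squared error plus a constant, so MSE yields a tractable log-likelihood, while L1 has no such counterpart in this model. A quick numerical check of that identity, assuming unit variance (function name is illustrative):

```python
import math

def neg_log_gaussian(x, mu, sigma=1.0):
    """Negative log-density of N(mu, sigma^2) at x (one dimension)."""
    return 0.5 * ((x - mu) / sigma) ** 2 + math.log(sigma * math.sqrt(2 * math.pi))

# For sigma = 1: -log p(x|z) = 0.5 * squared_error + 0.5 * log(2*pi),
# i.e. minimizing MSE is minimizing the Gaussian NLL up to a constant.
x, mu = 0.7, 0.2
const = 0.5 * math.log(2 * math.pi)
assert math.isclose(neg_log_gaussian(x, mu), 0.5 * (x - mu) ** 2 + const)
```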

@@ -9,6 +9,7 @@ class RHVAEConfig(VAEConfig):

Parameters:
latent_dim (int): The latent dimension used for the latent space. Default: 10
reconstruction_loss (str): The reconstruction loss to use ['bce', 'l1', 'mse']. Default: 'mse'

l1 cannot be allowed in this model (see next comment)

@@ -186,6 +194,18 @@ def get_nll(self, data, n_samples=1, batch_size=100):
reduction="none",
).sum(dim=-1)

elif self.model_config.reconstruction_loss == "l1":

Let's not compute the likelihood when the l1 loss is chosen, since there is no associated distribution. We may raise an exception with the following message

raise NotImplementedError("Computation of the likelihood is not implemented when `L1 loss` is chosen")

@@ -212,6 +220,18 @@ def get_nll(self, data, n_samples=1, batch_size=100):
reduction="none",
).sum(dim=-1)

elif self.model_config.reconstruction_loss == "l1":

Same as before for the computation of the NLL.

@@ -226,6 +234,18 @@ def get_nll(self, data, n_samples=1, batch_size=100):
reduction="none",
).sum(dim=-1)

elif self.model_config.reconstruction_loss == "l1":

Same as before for the computation of the NLL.

@@ -243,6 +251,18 @@ def get_nll(self, data, n_samples=1, batch_size=100):
reduction="none",
).sum(dim=-1)

elif self.model_config.reconstruction_loss == "l1":

Same as before for the computation of the NLL.

@@ -38,6 +38,9 @@ def model_configs_no_input_dim(request):
RHVAEConfig(
input_dim=(1, 28, 28), latent_dim=1, n_lf=1, reconstruction_loss="bce"
),
RHVAEConfig(

Not needed since this model will not handle the L1 loss.
