Skip to content

Commit

Permalink
I think this fixes the bug where vq loss is small compared to recon l…
Browse files Browse the repository at this point in the history
…oss (#125)
  • Loading branch information
ctr26 authored Jan 21, 2024
1 parent 18fe279 commit b5c7462
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/pythae/models/vq_vae/vq_vae_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ def forward(self, z: torch.Tensor, uses_ddp: bool = False):
commitment_loss = F.mse_loss(
quantized.detach().reshape(-1, self.embedding_dim),
z.reshape(-1, self.embedding_dim),
reduction="mean",
reduction="sum",
)

embedding_loss = F.mse_loss(
quantized.reshape(-1, self.embedding_dim),
z.detach().reshape(-1, self.embedding_dim),
reduction="mean",
).mean(dim=-1)
reduction="sum",
)

quantized = z + (quantized - z).detach()

Expand Down Expand Up @@ -147,7 +147,7 @@ def forward(self, z: torch.Tensor, uses_ddp: bool = False):
commitment_loss = F.mse_loss(
quantized.detach().reshape(-1, self.embedding_dim),
z.reshape(-1, self.embedding_dim),
reduction="mean",
reduction="sum",
)

quantized = z + (quantized - z).detach()
Expand Down

0 comments on commit b5c7462

Please sign in to comment.