Skip to content

Commit

Permalink
fix: added loss calculation in DSPNAE
Browse files Browse the repository at this point in the history
  • Loading branch information
Francesco Stablum committed Nov 3, 2021
1 parent 544b526 commit 8038f4b
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions models/dspn_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,21 @@ def forward(self, loaded_set):
self.reconstructed = self.decoder(self.code)
return self.reconstructed

def training_step(self, batch, batch_idx):
print("training_set",batch.shape,batch_idx)
return 1
def _step(self, batch, batch_idx, which_tset):
# copied from dspn.train.main.run()
(progress, masks, evals, gradn), (y_enc, y_label) = self( input, batch )

set_loss = dspn.utils.chamfer_loss(
torch.stack(progress), batch.unsqueeze(0)
)

return set_loss.mean()

def training_step(self, batch, batch_idx):
return self._step(batch, batch_idx, 'train')

def validation_step(self, batch, batch_idx):
return 1
return self._step(batch, batch_idx, 'val')


if __name__ == "__main__":
Expand Down

0 comments on commit 8038f4b

Please sign in to comment.