diff --git a/trainer.py b/trainer.py index a726dad64..f370603dc 100644 --- a/trainer.py +++ b/trainer.py @@ -466,7 +466,7 @@ def compute_losses(self, inputs, outputs): if not self.opt.disable_automasking: # add random numbers to break ties identity_reprojection_loss += torch.randn( - identity_reprojection_loss.shape).cuda() * 0.00001 + identity_reprojection_loss.shape, device=self.device) * 0.00001 combined = torch.cat((identity_reprojection_loss, reprojection_loss), dim=1) else: