
Fix test for DiceLoss, change trainer_type fixture back to device fixture
NickleDave committed May 5, 2024
1 parent 151abe2 commit 316e7a3
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions tests/test_nn/test_loss/test_dice.py
@@ -16,52 +16,52 @@ def tensor_to_gradcheck_var(tensor, dtype=torch.float64, requires_grad=True):
 
 # adapted from kornia, https://github.com/kornia/kornia/blob/master/test/test_losses.py
 class TestDiceLoss:
-    def test_smoke(self, trainer_table, dtype):
+    def test_smoke(self, device, dtype):
         num_classes = 3
-        logits = torch.rand(2, num_classes, 20, device=trainer_table, dtype=dtype)
+        logits = torch.rand(2, num_classes, 20, device=device, dtype=dtype)
         labels = torch.rand(2, 20) * num_classes
-        labels = labels.to(trainer_table).long()
+        labels = labels.to(device).long()
 
         criterion = vak.nn.loss.DiceLoss()
         assert criterion(logits, labels) is not None
 
-    def test_all_zeros(self, trainer_table, dtype):
+    def test_all_zeros(self, device, dtype):
         num_classes = 3
-        logits = torch.zeros(2, num_classes, 20, device=trainer_table, dtype=dtype)
+        logits = torch.zeros(2, num_classes, 20, device=device, dtype=dtype)
         logits[:, 0] = 10.0
         logits[:, 1] = 1.0
         logits[:, 2] = 1.0
-        labels = torch.zeros(2, 20, device=trainer_table, dtype=torch.int64)
+        labels = torch.zeros(2, 20, device=device, dtype=torch.int64)
 
         criterion = vak.nn.loss.DiceLoss()
         loss = criterion(logits, labels)
         assert_close(loss, torch.zeros_like(loss), rtol=1e-3, atol=1e-3)
 
-    def test_gradcheck(self, trainer_table, dtype):
+    def test_gradcheck(self, device, dtype):
         num_classes = 3
-        logits = torch.rand(2, num_classes, 20, device=trainer_table, dtype=dtype)
+        logits = torch.rand(2, num_classes, 20, device=device, dtype=dtype)
         labels = torch.rand(2, 20) * num_classes
-        labels = labels.to(trainer_table).long()
+        labels = labels.to(device).long()
 
         logits = tensor_to_gradcheck_var(logits)  # to var
         assert gradcheck(vak.nn.dice_loss, (logits, labels), raise_exception=True)
 
-    def test_jit(self, trainer_table, dtype):
+    def test_jit(self, device, dtype):
         num_classes = 3
-        logits = torch.rand(2, num_classes, 20, device=trainer_table, dtype=dtype)
+        logits = torch.rand(2, num_classes, 20, device=device, dtype=dtype)
         labels = torch.rand(2, 20) * num_classes
-        labels = labels.to(trainer_table).long()
+        labels = labels.to(device).long()
 
         op = vak.nn.dice_loss
         op_script = torch.jit.script(op)
 
         assert_close(op(logits, labels), op_script(logits, labels))
 
-    def test_module(self, trainer_table, dtype):
+    def test_module(self, device, dtype):
         num_classes = 3
-        logits = torch.rand(2, num_classes, 20, device=trainer_table, dtype=dtype)
+        logits = torch.rand(2, num_classes, 20, device=device, dtype=dtype)
         labels = torch.rand(2, 20) * num_classes
-        labels = labels.to(trainer_table).long()
+        labels = labels.to(device).long()
 
         op = vak.nn.dice_loss
         op_module = vak.nn.loss.DiceLoss()
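For context, every test in this diff requests `device` and `dtype` fixtures rather than creating tensors on a hard-coded device. The commit does not show those fixture definitions; the sketch below is a minimal, assumed version of what such fixtures could look like in `tests/conftest.py` (the parametrization over "cpu"/"cuda" and the skip logic are illustrative, not vak's actual implementation):

    import pytest
    import torch


    @pytest.fixture(params=["cpu", "cuda"])
    def device(request):
        """Return a torch.device, skipping the CUDA case when no GPU is available."""
        if request.param == "cuda" and not torch.cuda.is_available():
            pytest.skip("CUDA not available")
        return torch.device(request.param)


    @pytest.fixture(params=[torch.float32, torch.float64])
    def dtype(request):
        """Floating-point dtype used to build the logits tensors in these tests."""
        return request.param

Parametrizing the fixture this way runs each DiceLoss test once per available device and dtype, so a single test body covers CPU and GPU without duplicating code, and the CUDA variants are skipped cleanly on machines without a GPU.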
