Skip to content

Commit

Permalink
fix: Support tensors and arrays for class_weight (#1413)
Browse files Browse the repository at this point in the history
Avoids ambiguous truth value ValueError when the class_weight input
parameter is either a PyTorch tensor or a NumPy array.

Includes new tests for SemanticSegmentationTask's class_weight
parameter.
  • Loading branch information
ntw-au authored Jul 7, 2023
1 parent 8807b9f commit b0ae5be
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
21 changes: 21 additions & 0 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
from typing import Any, cast

import numpy as np
import pytest
import segmentation_models_pytorch as smp
import timm
Expand Down Expand Up @@ -262,3 +263,23 @@ def test_freeze_decoder(
for param in model.model.segmentation_head.parameters()
]
)

@pytest.mark.parametrize(
"class_weights", [torch.tensor([1, 2, 3]), np.array([1, 2, 3]), [1, 2, 3]]
)
def test_classweights_valid(
self, class_weights: Any, model_kwargs: dict[Any, Any]
) -> None:
model_kwargs["class_weights"] = class_weights
sst = SemanticSegmentationTask(**model_kwargs)
assert isinstance(sst.loss.weight, torch.Tensor)
assert torch.equal(sst.loss.weight, torch.tensor([1.0, 2.0, 3.0]))
assert sst.loss.weight.dtype == torch.float32

@pytest.mark.parametrize("class_weights", [[], None])
def test_classweights_empty(
self, class_weights: Any, model_kwargs: dict[Any, Any]
) -> None:
model_kwargs["class_weights"] = class_weights
sst = SemanticSegmentationTask(**model_kwargs)
assert sst.loss.weight is None
9 changes: 6 additions & 3 deletions torchgeo/trainers/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,12 @@ def config_task(self) -> None:
if self.hyperparams["loss"] == "ce":
ignore_value = -1000 if self.ignore_index is None else self.ignore_index

class_weights = (
torch.FloatTensor(self.class_weights) if self.class_weights else None
)
class_weights = None
if isinstance(self.class_weights, torch.Tensor):
class_weights = self.class_weights.to(dtype=torch.float32)
elif hasattr(self.class_weights, "__array__") or self.class_weights:
class_weights = torch.tensor(self.class_weights, dtype=torch.float32)

self.loss = nn.CrossEntropyLoss(
ignore_index=ignore_value, weight=class_weights
)
Expand Down

0 comments on commit b0ae5be

Please sign in to comment.