From b0ae5be35d2c3b1599daed5f2e40934bb467c27d Mon Sep 17 00:00:00 2001 From: ntw-au <59516633+ntw-au@users.noreply.github.com> Date: Sat, 8 Jul 2023 08:30:16 +1000 Subject: [PATCH] fix: Support tensors and arrays for class_weight (#1413) Avoids ambiguous truth value ValueError when the class_weight input parameter is either a PyTorch tensor or a NumPy array. Includes new tests for SemanticSegmentationTask's class_weight parameter. --- tests/trainers/test_segmentation.py | 21 +++++++++++++++++++++ torchgeo/trainers/segmentation.py | 9 ++++++--- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index c1c1663f943..c29d37e24eb 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import Any, cast +import numpy as np import pytest import segmentation_models_pytorch as smp import timm @@ -262,3 +263,23 @@ def test_freeze_decoder( for param in model.model.segmentation_head.parameters() ] ) + + @pytest.mark.parametrize( + "class_weights", [torch.tensor([1, 2, 3]), np.array([1, 2, 3]), [1, 2, 3]] + ) + def test_classweights_valid( + self, class_weights: Any, model_kwargs: dict[Any, Any] + ) -> None: + model_kwargs["class_weights"] = class_weights + sst = SemanticSegmentationTask(**model_kwargs) + assert isinstance(sst.loss.weight, torch.Tensor) + assert torch.equal(sst.loss.weight, torch.tensor([1.0, 2.0, 3.0])) + assert sst.loss.weight.dtype == torch.float32 + + @pytest.mark.parametrize("class_weights", [[], None]) + def test_classweights_empty( + self, class_weights: Any, model_kwargs: dict[Any, Any] + ) -> None: + model_kwargs["class_weights"] = class_weights + sst = SemanticSegmentationTask(**model_kwargs) + assert sst.loss.weight is None diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 2107371cf53..e0497de1b9c 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -65,9 +65,12 @@ def config_task(self) -> None: if self.hyperparams["loss"] == "ce": ignore_value = -1000 if self.ignore_index is None else self.ignore_index - class_weights = ( - torch.FloatTensor(self.class_weights) if self.class_weights else None - ) + class_weights = None + if isinstance(self.class_weights, torch.Tensor): + class_weights = self.class_weights.to(dtype=torch.float32) + elif hasattr(self.class_weights, "__array__") or self.class_weights: + class_weights = torch.tensor(self.class_weights, dtype=torch.float32) + self.loss = nn.CrossEntropyLoss( ignore_index=ignore_value, weight=class_weights )