Skip to content

Commit

Permalink
fixed unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
nkaenzig committed Nov 14, 2024
1 parent 6804525 commit 1fda0b1
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,14 @@ def _load_masks_as_semantic_label(
slice_index: Whether to return only a specific slice.
"""
masks_dir = self._get_masks_dir(sample_index)
classes = self._class_mappings.keys() if self._class_mappings else self.classes
classes = self._class_mappings.keys() if self._class_mappings else self.classes[1:]
mask_paths = [os.path.join(masks_dir, f"{label}.nii.gz") for label in classes]
binary_masks = [io.read_nifti(path, slice_index) for path in mask_paths]

if self._class_mappings:
mapped_binary_masks = [np.zeros_like(binary_masks[0], dtype=np.bool_)] * len(self.classes[1:])
mapped_binary_masks = [np.zeros_like(binary_masks[0], dtype=np.bool_)] * len(
self.classes[1:]
)
for original_class, mapped_class in self._class_mappings.items():
mapped_index = self.class_to_idx[mapped_class] - 1
original_index = list(self._class_mappings.keys()).index(original_class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_sample(total_segmentator_dataset: datasets.TotalSegmentator2D, index: i
assert "slice_index" in metadata

# check the number of classes with v.s. without class mappings
n_classes_expected = 2 if total_segmentator_dataset._class_mappings is not None else 3
n_classes_expected = 3 if total_segmentator_dataset._class_mappings is not None else 4
assert len(total_segmentator_dataset.classes) == n_classes_expected


Expand Down

0 comments on commit 1fda0b1

Please sign in to comment.