Commit
Merge branch 'main' into renovate/wntrblm-nox-2024.x
ioangatop authored Apr 26, 2024
2 parents 6245a1b + 767fbbd commit d75cf49
Showing 4 changed files with 44 additions and 66 deletions.
59 changes: 22 additions & 37 deletions src/eva/vision/data/datasets/segmentation/base.py
@@ -3,38 +3,28 @@
 import abc
 from typing import Any, Callable, Dict, List, Tuple
 
-import numpy as np
+from torchvision import tv_tensors
 from typing_extensions import override
 
 from eva.vision.data.datasets import vision
 
 
-class ImageSegmentation(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc.ABC):
+class ImageSegmentation(vision.VisionDataset[Tuple[tv_tensors.Image, tv_tensors.Mask]], abc.ABC):
     """Image segmentation abstract dataset."""
 
     def __init__(
         self,
-        image_transforms: Callable | None = None,
-        target_transforms: Callable | None = None,
-        image_target_transforms: Callable | None = None,
+        transforms: Callable | None = None,
     ) -> None:
         """Initializes the image segmentation base class.
 
         Args:
-            image_transforms: A function/transform that takes in an image
-                and returns a transformed version.
-            target_transforms: A function/transform that takes in the target
-                and transforms it.
-            image_target_transforms: A function/transforms that takes in an
+            transforms: A function/transforms that takes in an
                 image and a label and returns the transformed versions of both.
-                This transform happens after the `image_transforms` and
-                `target_transforms`.
         """
         super().__init__()
 
-        self._image_transforms = image_transforms
-        self._target_transforms = target_transforms
-        self._image_target_transforms = image_target_transforms
+        self._transforms = transforms
 
     @property
     def classes(self) -> List[str] | None:
@@ -56,25 +46,26 @@ def load_metadata(self, index: int | None) -> Dict[str, Any] | List[Dict[str, An
         """
 
     @abc.abstractmethod
-    def load_image(self, index: int) -> np.ndarray:
+    def load_image(self, index: int) -> tv_tensors.Image:
         """Loads and returns the `index`'th image sample.
 
         Args:
             index: The index of the data sample to load.
 
         Returns:
-            The image as a numpy array.
+            An image torchvision tensor (channels, height, width).
         """
 
     @abc.abstractmethod
-    def load_mask(self, index: int) -> np.ndarray:
-        """Returns the `index`'th target mask sample.
+    def load_masks(self, index: int) -> tv_tensors.Mask:
+        """Returns the `index`'th target masks sample.
 
         Args:
-            index: The index of the data sample target mask to load.
+            index: The index of the data sample target masks to load.
 
         Returns:
-            The sample mask as a stack of binary mask arrays (label, height, width).
+            The sample masks as a stack of binary torchvision mask
+                tensors (label, height, width).
         """
 
     @abc.abstractmethod
@@ -83,30 +74,24 @@ def __len__(self) -> int:
         raise NotImplementedError
 
     @override
-    def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
+    def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, tv_tensors.Mask]:
         image = self.load_image(index)
-        mask = self.load_mask(index)
-        return self._apply_transforms(image, mask)
+        masks = self.load_masks(index)
+        return self._apply_transforms(image, masks)
 
     def _apply_transforms(
-        self, image: np.ndarray, target: np.ndarray
-    ) -> Tuple[np.ndarray, np.ndarray]:
+        self, image: tv_tensors.Image, masks: tv_tensors.Mask
+    ) -> Tuple[tv_tensors.Image, tv_tensors.Mask]:
         """Applies the transforms to the provided data and returns them.
 
         Args:
             image: The desired image.
-            target: The target of the image.
+            masks: The target masks of the image.
 
         Returns:
-            A tuple with the image and the target transformed.
+            A tuple with the image and the masks transformed.
         """
-        if self._image_transforms is not None:
-            image = self._image_transforms(image)
-
-        if self._target_transforms is not None:
-            target = self._target_transforms(target)
-
-        if self._image_target_transforms is not None:
-            image, target = self._image_target_transforms(image, target)
-
-        return image, target
+        if self._transforms is not None:
+            image, masks = self._transforms(image, masks)
+
+        return image, masks
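
Taken together, this file's changes make `ImageSegmentation` yield `tv_tensors.Image` and `tv_tensors.Mask` pairs and apply one joint `transforms` callable in place of the three separate hooks. A minimal sketch of why a single callable is sufficient, assuming torchvision's `transforms.v2` API; the specific transforms and shapes below are illustrative, not part of the commit:

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

# v2 transforms dispatch on the tv_tensor subclass: geometric ops are applied
# to image and mask consistently, while photometric ops leave the mask untouched.
joint_transforms = v2.Compose([
    v2.RandomHorizontalFlip(p=1.0),   # flips both image and mask
    v2.ColorJitter(brightness=0.2),   # adjusts the image only
])

image = tv_tensors.Image(torch.rand(3, 16, 16))                    # (channels, height, width)
masks = tv_tensors.Mask(torch.zeros(3, 16, 16, dtype=torch.long))  # (label, height, width)

image, masks = joint_transforms(image, masks)  # mirrors self._transforms(image, masks)

A joint callable keeps geometric augmentations applied identically to image and mask, which the separate `image_transforms` and `target_transforms` hooks could not guarantee on their own.
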
36 changes: 15 additions & 21 deletions src/eva/vision/data/datasets/segmentation/total_segmentator.py
@@ -6,6 +6,7 @@
 from typing import Callable, Dict, List, Literal, Tuple
 
 import numpy as np
+from torchvision import tv_tensors
 from torchvision.datasets import utils
 from typing_extensions import override
 
@@ -50,9 +51,7 @@ def __init__(
         split: Literal["train", "val"] | None,
         version: Literal["small", "full"] = "small",
         download: bool = False,
-        image_transforms: Callable | None = None,
-        target_transforms: Callable | None = None,
-        image_target_transforms: Callable | None = None,
+        transforms: Callable | None = None,
     ) -> None:
         """Initialize dataset.
 
@@ -65,20 +64,10 @@ def __init__(
                 Note that the download will be executed only by additionally
                 calling the :meth:`prepare_data` method and if the data does not
                 exist yet on disk.
-            image_transforms: A function/transform that takes in an image
-                and returns a transformed version.
-            target_transforms: A function/transform that takes in the target
-                and transforms it.
-            image_target_transforms: A function/transforms that takes in an
-                image and a label and returns the transformed versions of both.
-                This transform happens after the `image_transforms` and
-                `target_transforms`.
+            transforms: A function/transforms that takes in an image and a target
+                mask and returns the transformed versions of both.
         """
-        super().__init__(
-            image_transforms=image_transforms,
-            target_transforms=target_transforms,
-            image_target_transforms=image_target_transforms,
-        )
+        super().__init__(transforms=transforms)
 
         self._root = root
         self._split = split
@@ -134,19 +123,21 @@ def __len__(self) -> int:
         return len(self._indices) * self._n_slices_per_image
 
     @override
-    def load_image(self, index: int) -> np.ndarray:
+    def load_image(self, index: int) -> tv_tensors.Image:
         image_path = self._get_image_path(index)
         slice_index = self._get_sample_slice_index(index)
         image_array = io.read_nifti_slice(image_path, slice_index)
-        return image_array.repeat(3, axis=2)
+        image_rgb_array = image_array.repeat(3, axis=2)
+        return tv_tensors.Image(image_rgb_array.transpose(2, 0, 1))
 
     @override
-    def load_mask(self, index: int) -> np.ndarray:
+    def load_masks(self, index: int) -> tv_tensors.Mask:
         masks_dir = self._get_masks_dir(index)
         slice_index = self._get_sample_slice_index(index)
         mask_paths = (os.path.join(masks_dir, label + ".nii.gz") for label in self.classes)
-        masks = [io.read_nifti_slice(path, slice_index) for path in mask_paths]
-        return np.concatenate(masks, axis=-1)
+        list_of_mask_arrays = [io.read_nifti_slice(path, slice_index) for path in mask_paths]
+        masks = np.concatenate(list_of_mask_arrays, axis=2)
+        return tv_tensors.Mask(masks.transpose(2, 0, 1))
 
     def _get_masks_dir(self, index: int) -> str:
         """Returns the directory of the corresponding masks."""
@@ -204,6 +195,9 @@ def _download_dataset(self) -> None:
             raise ValueError("Invalid data version. Use 'small' or 'full'.")
 
         for resource in resources:
+            if os.path.isdir(self._root):
+                continue
+
             utils.download_and_extract_archive(
                 resource.url,
                 download_root=self._root,
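
The rewritten `load_image` and `load_masks` convert the (height, width, channels) numpy slices returned by `io.read_nifti_slice` into channels-first torchvision tensors, and `_download_dataset` now skips downloading when `self._root` already exists on disk. A standalone sketch of the HWC to CHW conversion, using synthetic arrays in place of real NIfTI slices:

import numpy as np
from torchvision import tv_tensors

# Stand-in for one grayscale CT slice read as (height, width, 1).
slice_hwc = np.random.rand(16, 16, 1).astype(np.float32)

# Replicate the single channel to RGB, then move channels first: (3, height, width).
image = tv_tensors.Image(slice_hwc.repeat(3, axis=2).transpose(2, 0, 1))

# One binary (height, width, 1) mask per label, stacked and moved to (label, height, width).
per_label_masks = [np.random.randint(0, 2, (16, 16, 1)) for _ in range(3)]
masks = tv_tensors.Mask(np.concatenate(per_label_masks, axis=2).transpose(2, 0, 1))

assert image.shape == (3, 16, 16) and masks.shape == (3, 16, 16)
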
5 changes: 2 additions & 3 deletions src/eva/vision/utils/io/nifti.py
@@ -13,11 +13,10 @@ def read_nifti_slice(path: str, slice_index: int) -> npt.NDArray[Any]:
     Args:
         path: The path to the NIfTI file.
-        slice_index: The image slice index to return. If `None`, it will
-            return the full 3D image.
+        slice_index: The image slice index to return.
 
     Returns:
-        The image as a numpy array.
+        The image as a numpy array (height, width, channels).
 
     Raises:
         FileExistsError: If the path does not exist or it is unreachable.
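
With this docstring change, `read_nifti_slice` always takes a concrete slice index and documents a (height, width, channels) return shape. The function body itself is not part of this diff; the following is only a hedged sketch of a reader with that contract, assuming `nibabel` and a hypothetical `read_nifti_slice_sketch` name:

from typing import Any

import nibabel as nib
import numpy as np
import numpy.typing as npt


def read_nifti_slice_sketch(path: str, slice_index: int) -> npt.NDArray[Any]:
    """Illustrative only: reads one slice and returns it as (height, width, 1)."""
    volume = nib.load(path).get_fdata()   # full 3D volume, e.g. (height, width, n_slices)
    image = volume[:, :, slice_index]     # pick a single slice -> (height, width)
    return np.expand_dims(image, axis=2)  # add a channel axis -> (height, width, 1)
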
@@ -3,8 +3,8 @@
 import os
 from typing import Literal
 
-import numpy as np
 import pytest
+from torchvision import tv_tensors
 
 from eva.vision.data import datasets
 
@@ -35,10 +35,10 @@ def test_sample(total_segmentator_dataset: datasets.TotalSegmentator2D, index: i
     assert len(sample) == 2
     # assert the format of the `image` and `mask`
     image, mask = sample
-    assert isinstance(image, np.ndarray)
-    assert image.shape == (16, 16, 3)
-    assert isinstance(mask, np.ndarray)
-    assert mask.shape == (16, 16, 3)
+    assert isinstance(image, tv_tensors.Image)
+    assert image.shape == (3, 16, 16)
+    assert isinstance(mask, tv_tensors.Mask)
+    assert mask.shape == (3, 16, 16)
 
 
 @pytest.fixture(scope="function")
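
The updated test asserts channels-first shapes and `tv_tensors` types. Since `tv_tensors.Image` and `tv_tensors.Mask` subclass `torch.Tensor`, such assertions compose with ordinary tensor checks; a small standalone illustration that is not taken from the test suite:

import torch
from torchvision import tv_tensors

image = tv_tensors.Image(torch.rand(3, 16, 16))
mask = tv_tensors.Mask(torch.zeros(3, 16, 16, dtype=torch.bool))

# Freshly constructed tv_tensors keep their subclass type and remain regular torch tensors.
assert isinstance(image, tv_tensors.Image) and isinstance(image, torch.Tensor)
assert isinstance(mask, tv_tensors.Mask) and isinstance(mask, torch.Tensor)
assert image.shape == (3, 16, 16) and mask.shape == (3, 16, 16)
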
