From 2b5a1ce1d63db83b12d97680423d6bfef73bea26 Mon Sep 17 00:00:00 2001 From: ioangatop Date: Tue, 30 Apr 2024 11:34:50 +0200 Subject: [PATCH 01/21] Add simple decoder networks for segmentation tasks (#404) --- .github/workflows/ci.yaml | 4 - main.py | 0 .../vision/data/datasets/segmentation/base.py | 18 +-- .../segmentation/total_segmentator.py | 20 +++- .../models/networks/decoders/__init__.py | 6 + .../models/networks/decoders/decoder.py | 7 ++ .../decoders/segmentation/__init__.py | 6 + .../networks/decoders/segmentation/conv.py | 97 ++++++++++++++++ .../networks/decoders/segmentation/linear.py | 108 ++++++++++++++++++ src/eva/vision/utils/convert.py | 24 ++++ src/eva/vision/utils/io/nifti.py | 11 +- .../segmentation/test_total_segmentator.py | 2 +- .../models/networks/decoders/__init__.py | 1 + .../decoders/segmentation/__init__.py | 1 + .../networks/decoders/segmentation/conv.py | 51 +++++++++ .../networks/decoders/segmentation/linear.py | 40 +++++++ tests/eva/vision/utils/test_convert.py | 50 ++++++++ 17 files changed, 424 insertions(+), 22 deletions(-) create mode 100644 main.py create mode 100644 src/eva/vision/models/networks/decoders/__init__.py create mode 100644 src/eva/vision/models/networks/decoders/decoder.py create mode 100644 src/eva/vision/models/networks/decoders/segmentation/__init__.py create mode 100644 src/eva/vision/models/networks/decoders/segmentation/conv.py create mode 100644 src/eva/vision/models/networks/decoders/segmentation/linear.py create mode 100644 src/eva/vision/utils/convert.py create mode 100644 tests/eva/vision/models/networks/decoders/__init__.py create mode 100644 tests/eva/vision/models/networks/decoders/segmentation/__init__.py create mode 100644 tests/eva/vision/models/networks/decoders/segmentation/conv.py create mode 100644 tests/eva/vision/models/networks/decoders/segmentation/linear.py create mode 100644 tests/eva/vision/utils/test_convert.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index cd5d8846..c004a64c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -5,10 +5,6 @@ on: workflow_dispatch: pull_request: branches: - - main - push: - branches: - - main jobs: security: diff --git a/main.py b/main.py new file mode 100644 index 00000000..e69de29b diff --git a/src/eva/vision/data/datasets/segmentation/base.py b/src/eva/vision/data/datasets/segmentation/base.py index d6fd5264..6686945c 100644 --- a/src/eva/vision/data/datasets/segmentation/base.py +++ b/src/eva/vision/data/datasets/segmentation/base.py @@ -57,15 +57,15 @@ def load_image(self, index: int) -> tv_tensors.Image: """ @abc.abstractmethod - def load_masks(self, index: int) -> tv_tensors.Mask: + def load_mask(self, index: int) -> tv_tensors.Mask: """Returns the `index`'th target masks sample. Args: index: The index of the data sample target masks to load. Returns: - The sample masks as a stack of binary torchvision mask - tensors (label, height, width). + The semantic mask as a (H x W) shaped tensor with integer + values which represent the pixel class id. """ @abc.abstractmethod @@ -76,22 +76,22 @@ def __len__(self) -> int: @override def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, tv_tensors.Mask]: image = self.load_image(index) - masks = self.load_masks(index) - return self._apply_transforms(image, masks) + mask = self.load_mask(index) + return self._apply_transforms(image, mask) def _apply_transforms( - self, image: tv_tensors.Image, masks: tv_tensors.Mask + self, image: tv_tensors.Image, mask: tv_tensors.Mask ) -> Tuple[tv_tensors.Image, tv_tensors.Mask]: """Applies the transforms to the provided data and returns them. Args: image: The desired image. - masks: The target masks of the image. + mask: The target segmentation mask. Returns: A tuple with the image and the masks transformed. """ if self._transforms is not None: - image, masks = self._transforms(image, masks) + image, mask = self._transforms(image, mask) - return image, masks + return image, mask diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator.py b/src/eva/vision/data/datasets/segmentation/total_segmentator.py index e8de271b..4892e6b6 100644 --- a/src/eva/vision/data/datasets/segmentation/total_segmentator.py +++ b/src/eva/vision/data/datasets/segmentation/total_segmentator.py @@ -12,7 +12,7 @@ from eva.vision.data.datasets import _utils, _validators, structs from eva.vision.data.datasets.segmentation import base -from eva.vision.utils import io +from eva.vision.utils import convert, io class TotalSegmentator2D(base.ImageSegmentation): @@ -51,6 +51,7 @@ def __init__( split: Literal["train", "val"] | None, version: Literal["small", "full"] = "small", download: bool = False, + as_uint8: bool = True, transforms: Callable | None = None, ) -> None: """Initialize dataset. @@ -64,6 +65,7 @@ def __init__( Note that the download will be executed only by additionally calling the :meth:`prepare_data` method and if the data does not exist yet on disk. + as_uint8: Whether to convert and return the images as a 8-bit. transforms: A function/transforms that takes in an image and a target mask and returns the transformed versions of both. """ @@ -73,6 +75,7 @@ def __init__( self._split = split self._version = version self._download = download + self._as_uint8 = as_uint8 self._samples_dirs: List[str] = [] self._indices: List[int] = [] @@ -127,17 +130,24 @@ def load_image(self, index: int) -> tv_tensors.Image: image_path = self._get_image_path(index) slice_index = self._get_sample_slice_index(index) image_array = io.read_nifti_slice(image_path, slice_index) + if self._as_uint8: + image_array = convert.to_8bit(image_array) image_rgb_array = image_array.repeat(3, axis=2) return tv_tensors.Image(image_rgb_array.transpose(2, 0, 1)) @override - def load_masks(self, index: int) -> tv_tensors.Mask: + def load_mask(self, index: int) -> tv_tensors.Mask: masks_dir = self._get_masks_dir(index) slice_index = self._get_sample_slice_index(index) mask_paths = (os.path.join(masks_dir, label + ".nii.gz") for label in self.classes) - list_of_mask_arrays = [io.read_nifti_slice(path, slice_index) for path in mask_paths] - masks = np.concatenate(list_of_mask_arrays, axis=2) - return tv_tensors.Mask(masks.transpose(2, 0, 1)) + one_hot_encoded = np.concatenate( + [io.read_nifti_slice(path, slice_index) for path in mask_paths], + axis=2, + ) + background_mask = one_hot_encoded.sum(axis=2, keepdims=True) == 0 + one_hot_encoded_with_bg = np.concatenate([background_mask, one_hot_encoded], axis=2) + segmentation_label = np.argmax(one_hot_encoded_with_bg, axis=2) + return tv_tensors.Mask(segmentation_label) def _get_masks_dir(self, index: int) -> str: """Returns the directory of the corresponding masks.""" diff --git a/src/eva/vision/models/networks/decoders/__init__.py b/src/eva/vision/models/networks/decoders/__init__.py new file mode 100644 index 00000000..1f659ad1 --- /dev/null +++ b/src/eva/vision/models/networks/decoders/__init__.py @@ -0,0 +1,6 @@ +"""Decoder heads API.""" + +from eva.vision.models.networks.decoders import segmentation +from eva.vision.models.networks.decoders.decoder import Decoder + +__all__ = ["segmentation", "Decoder"] diff --git a/src/eva/vision/models/networks/decoders/decoder.py b/src/eva/vision/models/networks/decoders/decoder.py new file mode 100644 index 00000000..3299a290 --- /dev/null +++ b/src/eva/vision/models/networks/decoders/decoder.py @@ -0,0 +1,7 @@ +"""Semantic segmentation decoder base class.""" + +from torch import nn + + +class Decoder(nn.Module): + """Semantic segmentation decoder base class.""" diff --git a/src/eva/vision/models/networks/decoders/segmentation/__init__.py b/src/eva/vision/models/networks/decoders/segmentation/__init__.py new file mode 100644 index 00000000..8a5f014f --- /dev/null +++ b/src/eva/vision/models/networks/decoders/segmentation/__init__.py @@ -0,0 +1,6 @@ +"""Segmentation decoder heads API.""" + +from eva.vision.models.networks.decoders.segmentation.conv import ConvDecoder +from eva.vision.models.networks.decoders.segmentation.linear import LinearDecoder + +__all__ = ["ConvDecoder", "LinearDecoder"] diff --git a/src/eva/vision/models/networks/decoders/segmentation/conv.py b/src/eva/vision/models/networks/decoders/segmentation/conv.py new file mode 100644 index 00000000..a21749bc --- /dev/null +++ b/src/eva/vision/models/networks/decoders/segmentation/conv.py @@ -0,0 +1,97 @@ +"""Convolutional based semantic segmentation decoder.""" + +from typing import List, Tuple + +import torch +from torch import nn +from torch.nn import functional + +from eva.vision.models.networks.decoders import decoder + + +class ConvDecoder(decoder.Decoder): + """Convolutional segmentation decoder.""" + + def __init__(self, layers: nn.Module) -> None: + """Initializes the convolutional based decoder head. + + Here the input nn layers will be directly applied to the + features of shape (batch_size, hidden_size, n_patches_height, + n_patches_width), where n_patches is image_size / patch_size. + Note the n_patches is also known as grid_size. + + Args: + layers: The convolutional layers to be used as the decoder head. + """ + super().__init__() + + self._layers = layers + + def _forward_features(self, features: List[torch.Tensor]) -> torch.Tensor: + """Forward function for multi-level feature maps to a single one. + + Args: + features: List of multi-level image features of shape (batch_size, + hidden_size, n_patches_height, n_patches_width). + + Returns: + A tensor of shape (batch_size, hidden_size, n_patches_height, + n_patches_width) which is feature map of the decoder head. + """ + if not isinstance(features, list) and features[0].ndim != 4: + raise ValueError( + "Input features should be a list of four (4) dimensional inputs of " + "shape (batch_size, hidden_size, n_patches_height, n_patches_width)." + ) + + return features[-1] + + def _forward_head(self, patch_embeddings: torch.Tensor) -> torch.Tensor: + """Forward of the decoder head. + + Args: + patch_embeddings: The patch embeddings tensor of shape + (batch_size, hidden_size, n_patches_height, n_patches_width). + + Returns: + The logits as a tensor (batch_size, n_classes, upscale_height, upscale_width). + """ + return self._layers(patch_embeddings) + + def _cls_seg( + self, + logits: torch.Tensor, + image_size: Tuple[int, int], + ) -> torch.Tensor: + """Classify each pixel of the image. + + Args: + logits: The decoder outputs of shape (batch_size, n_classes, + height, width). + image_size: The target image size (height, width). + + Returns: + Tensor containing scores for all of the classes with shape + (batch_size, n_classes, image_height, image_width). + """ + return functional.interpolate(logits, image_size, mode="bilinear") + + def forward( + self, + features: List[torch.Tensor], + image_size: Tuple[int, int], + ) -> torch.Tensor: + """Maps the patch embeddings to a segmentation mask of the image size. + + Args: + features: List of multi-level image features of shape (batch_size, + hidden_size, n_patches_height, n_patches_width). + image_size: The target image size (height, width). + + Returns: + Tensor containing scores for all of the classes with shape + (batch_size, n_classes, image_height, image_width). + """ + patch_embeddings = self._forward_features(features) + logits = self._forward_head(patch_embeddings) + return self._cls_seg(logits, image_size) diff --git a/src/eva/vision/models/networks/decoders/segmentation/linear.py b/src/eva/vision/models/networks/decoders/segmentation/linear.py new file mode 100644 index 00000000..6b123cf5 --- /dev/null +++ b/src/eva/vision/models/networks/decoders/segmentation/linear.py @@ -0,0 +1,108 @@ +"""Linear based decoder.""" + +from typing import List, Tuple + +import torch +from torch import nn +from torch.nn import functional + +from eva.vision.models.networks.decoders import decoder + + +class LinearDecoder(decoder.Decoder): + """Linear decoder.""" + + def __init__(self, layers: nn.Module) -> None: + """Initializes the linear based decoder head. + + Here the input nn layers will be applied to the reshaped + features (batch_size, patch_embeddings, hidden_size) from + the input (batch_size, hidden_size, height, width) and then + unwrapped again to (batch_size, n_classes, height, width). + + Args: + layers: The linear layers to be used as the decoder head. + """ + super().__init__() + + self._layers = layers + + def _forward_features(self, features: List[torch.Tensor]) -> torch.Tensor: + """Forward function for multi-level feature maps to a single one. + + Args: + features: List of multi-level image features of shape (batch_size, + hidden_size, n_patches_height, n_patches_width). + + Returns: + A tensor of shape (batch_size, hidden_size, n_patches_height, + n_patches_width) which is feature map of the decoder head. + """ + if not isinstance(features, list) and features[0].ndim != 4: + raise ValueError( + "Input features should be a list of four (4) dimensional inputs of " + "shape (batch_size, hidden_size, n_patches_height, n_patches_width)." + ) + + return features[-1] + + def _forward_head(self, patch_embeddings: torch.Tensor) -> torch.Tensor: + """Forward of the decoder head. + + Here the following transformations will take place: + - (batch_size, hidden_size, n_patches_height, n_patches_width) + - (batch_size, hidden_size, n_patches_height * n_patches_width) + - (batch_size, n_patches_height * n_patches_width, hidden_size) + - (batch_size, n_patches_height * n_patches_width, n_classes) + - (batch_size, n_classes, n_patches_height, n_patches_width) + + Args: + patch_embeddings: The patch embeddings tensor of shape + (batch_size, hidden_size, n_patches_height, n_patches_width). + + Returns: + The logits as a tensor (batch_size, n_classes, n_patches_height, + n_patches_width). + """ + batch_size, hidden_size, height, width = patch_embeddings.shape + embeddings_reshaped = patch_embeddings.reshape(batch_size, hidden_size, height * width) + logits = self._layers(embeddings_reshaped.permute(0, 2, 1)) + return logits.permute(0, 2, 1).reshape(batch_size, -1, height, width) + + def _cls_seg( + self, + logits: torch.Tensor, + image_size: Tuple[int, int], + ) -> torch.Tensor: + """Classify each pixel of the image. + + Args: + logits: The decoder outputs of shape (batch_size, n_classes, + height, width). + image_size: The target image size (height, width). + + Returns: + Tensor containing scores for all of the classes with shape + (batch_size, n_classes, image_height, image_width). + """ + return functional.interpolate(logits, image_size, mode="bilinear") + + def forward( + self, + features: List[torch.Tensor], + image_size: Tuple[int, int], + ) -> torch.Tensor: + """Maps the patch embeddings to a segmentation mask of the image size. + + Args: + features: List of multi-level image features of shape (batch_size, + hidden_size, n_patches_height, n_patches_width). + image_size: The target image size (height, width). + + Returns: + Tensor containing scores for all of the classes with shape + (batch_size, n_classes, image_height, image_width). + """ + patch_embeddings = self._forward_features(features) + logits = self._forward_head(patch_embeddings) + return self._cls_seg(logits, image_size) diff --git a/src/eva/vision/utils/convert.py b/src/eva/vision/utils/convert.py new file mode 100644 index 00000000..cfe22f55 --- /dev/null +++ b/src/eva/vision/utils/convert.py @@ -0,0 +1,24 @@ +"""Image conversion related functionalities.""" + +from typing import Any + +import numpy as np +import numpy.typing as npt + + +def to_8bit(image_array: npt.NDArray[Any]) -> npt.NDArray[np.uint8]: + """Casts an image of higher bit image (i.e. 16bit) to 8bit. + + Args: + image_array: The image array to convert. + + Returns: + The image as normalized as a 8-bit format. + """ + if np.issubdtype(image_array.dtype, np.integer): + image_array = image_array.astype(np.float64) + + image_scaled_array = image_array - image_array.min() + image_scaled_array /= image_scaled_array.max() + image_scaled_array *= 255 + return image_scaled_array.astype(np.uint8) diff --git a/src/eva/vision/utils/io/nifti.py b/src/eva/vision/utils/io/nifti.py index f2265bce..162f0a64 100644 --- a/src/eva/vision/utils/io/nifti.py +++ b/src/eva/vision/utils/io/nifti.py @@ -8,12 +8,16 @@ from eva.vision.utils.io import _utils -def read_nifti_slice(path: str, slice_index: int) -> npt.NDArray[Any]: +def read_nifti_slice( + path: str, slice_index: int, *, use_storage_dtype: bool = True +) -> npt.NDArray[Any]: """Reads and loads a NIfTI image from a file path as `uint8`. Args: path: The path to the NIfTI file. slice_index: The image slice index to return. + use_storage_dtype: Whether to cast the raw image + array to the inferred type. Returns: The image as a numpy array (height, width, channels). @@ -24,10 +28,11 @@ def read_nifti_slice(path: str, slice_index: int) -> npt.NDArray[Any]: """ _utils.check_file(path) image_data = nib.load(path) # type: ignore - dtype = image_data.get_data_dtype() # type: ignore image_slice = image_data.slicer[:, :, slice_index : slice_index + 1] # type: ignore image_array = image_slice.get_fdata() - return image_array.astype(dtype) + if use_storage_dtype: + image_array = image_array.astype(image_data.get_data_dtype()) # type: ignore + return image_array def fetch_total_nifti_slices(path: str) -> int: diff --git a/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py b/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py index 2e8f3abe..3e7f09e6 100644 --- a/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py +++ b/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py @@ -38,7 +38,7 @@ def test_sample(total_segmentator_dataset: datasets.TotalSegmentator2D, index: i assert isinstance(image, tv_tensors.Image) assert image.shape == (3, 16, 16) assert isinstance(mask, tv_tensors.Mask) - assert mask.shape == (3, 16, 16) + assert mask.shape == (16, 16) @pytest.fixture(scope="function") diff --git a/tests/eva/vision/models/networks/decoders/__init__.py b/tests/eva/vision/models/networks/decoders/__init__.py new file mode 100644 index 00000000..b27386c5 --- /dev/null +++ b/tests/eva/vision/models/networks/decoders/__init__.py @@ -0,0 +1 @@ +"""Vision decoders tests.""" diff --git a/tests/eva/vision/models/networks/decoders/segmentation/__init__.py b/tests/eva/vision/models/networks/decoders/segmentation/__init__.py new file mode 100644 index 00000000..984d0cb7 --- /dev/null +++ b/tests/eva/vision/models/networks/decoders/segmentation/__init__.py @@ -0,0 +1 @@ +"""Vision segmentation decoders tests.""" diff --git a/tests/eva/vision/models/networks/decoders/segmentation/conv.py b/tests/eva/vision/models/networks/decoders/segmentation/conv.py new file mode 100644 index 00000000..45cace1a --- /dev/null +++ b/tests/eva/vision/models/networks/decoders/segmentation/conv.py @@ -0,0 +1,51 @@ +"""Tests for convolutional decoder.""" + +from typing import List, Tuple + +import pytest +import torch +from torch import nn + +from eva.vision.models.networks.decoders import segmentation + + +@pytest.mark.parametrize( + "layers, features, image_size, expected_shape", + [ + ( + nn.Conv2d(384, 5, kernel_size=(1, 1)), + [torch.Tensor(2, 384, 14, 14)], + (224, 224), + torch.Size([2, 5, 224, 224]), + ), + ( + nn.Sequential( + nn.Upsample(scale_factor=2), + nn.Conv2d(384, 64, kernel_size=(3, 3), padding=(1, 1)), + nn.Upsample(scale_factor=2), + nn.Conv2d(64, 5, kernel_size=(3, 3), padding=(1, 1)), + ), + [torch.Tensor(2, 384, 14, 14)], + (224, 224), + torch.Size([2, 5, 224, 224]), + ), + ], +) +def test_conv_decoder( + conv_decoder: segmentation.ConvDecoder, + features: List[torch.Tensor], + image_size: Tuple[int, int], + expected_shape: torch.Size, +) -> None: + """Tests the ConvDecoder network.""" + logits = conv_decoder(features, image_size) + assert isinstance(logits, torch.Tensor) + assert logits.shape == expected_shape + + +@pytest.fixture(scope="function") +def conv_decoder( + layers: nn.Module, +) -> segmentation.ConvDecoder: + """ConvDecoder fixture.""" + return segmentation.ConvDecoder(layers=layers) diff --git a/tests/eva/vision/models/networks/decoders/segmentation/linear.py b/tests/eva/vision/models/networks/decoders/segmentation/linear.py new file mode 100644 index 00000000..bcc727da --- /dev/null +++ b/tests/eva/vision/models/networks/decoders/segmentation/linear.py @@ -0,0 +1,40 @@ +"""Tests for linear decoder.""" + +from typing import List, Tuple + +import pytest +import torch +from torch import nn + +from eva.vision.models.networks.decoders import segmentation + + +@pytest.mark.parametrize( + "layers, features, image_size, expected_shape", + [ + ( + nn.Linear(384, 5), + [torch.Tensor(2, 384, 14, 14)], + (224, 224), + torch.Size([2, 5, 224, 224]), + ), + ], +) +def test_linear_decoder( + linear_decoder: segmentation.LinearDecoder, + features: List[torch.Tensor], + image_size: Tuple[int, int], + expected_shape: torch.Size, +) -> None: + """Tests the ConvDecoder network.""" + logits = linear_decoder(features, image_size) + assert isinstance(logits, torch.Tensor) + assert logits.shape == expected_shape + + +@pytest.fixture(scope="function") +def linear_decoder( + layers: nn.Module, +) -> segmentation.LinearDecoder: + """LinearDecoder fixture.""" + return segmentation.LinearDecoder(layers=layers) diff --git a/tests/eva/vision/utils/test_convert.py b/tests/eva/vision/utils/test_convert.py new file mode 100644 index 00000000..a9eb05eb --- /dev/null +++ b/tests/eva/vision/utils/test_convert.py @@ -0,0 +1,50 @@ +"""Tests image conversion related functionalities.""" + +from typing import Any + +import numpy as np +import numpy.typing as npt +from pytest import mark + +from eva.vision.utils import convert + +IMAGE_ARRAY_INT16 = np.array( + [ + [ + [-794, -339, -607, -950], + [-608, -81, 172, -834], + [-577, -10, 366, -837], + [-790, -325, -564, -969], + ] + ], + dtype=np.int16, +) +"""Test input data.""" + +EXPECTED_ARRAY_INT16 = np.array( + [ + [ + [33, 120, 69, 3], + [68, 169, 217, 25], + [74, 183, 255, 25], + [34, 123, 77, 0], + ] + ], + dtype=np.uint8, +) +"""Test expected/desired features.""" + + +@mark.parametrize( + "image_array, expected", + [ + [IMAGE_ARRAY_INT16, EXPECTED_ARRAY_INT16], + ], +) +def test_to_8bit( + image_array: npt.NDArray[Any], + expected: npt.NDArray[np.uint8], +) -> None: + """Tests the `to_8bit` image conversion.""" + actual = convert.to_8bit(image_array) + np.testing.assert_allclose(actual, expected) From a7e65e58dd0e98d6932bc6c9f873cab22b412488 Mon Sep 17 00:00:00 2001 From: ioangatop Date: Tue, 30 Apr 2024 16:15:47 +0200 Subject: [PATCH 02/21] Add `timm` encoder networks (#403) --- pdm.lock | 567 ++++++++++-------- pyproject.toml | 4 +- src/eva/core/models/networks/__init__.py | 2 +- .../core/models/networks/wrappers/__init__.py | 7 +- src/eva/vision/models/networks/__init__.py | 3 +- .../models/networks/encoders/__init__.py | 6 + .../models/networks/encoders/encoder.py | 23 + .../models/networks/encoders/from_timm.py | 61 ++ src/eva/vision/models/networks/from_timm.py | 36 ++ .../models/networks/encoders/__init__.py | 1 + .../networks/encoders/test_from_timm.py | 66 ++ .../vision/models/networks/test_from_timm.py | 48 ++ 12 files changed, 558 insertions(+), 266 deletions(-) create mode 100644 src/eva/vision/models/networks/encoders/__init__.py create mode 100644 src/eva/vision/models/networks/encoders/encoder.py create mode 100644 src/eva/vision/models/networks/encoders/from_timm.py create mode 100644 src/eva/vision/models/networks/from_timm.py create mode 100644 tests/eva/vision/models/networks/encoders/__init__.py create mode 100644 tests/eva/vision/models/networks/encoders/test_from_timm.py create mode 100644 tests/eva/vision/models/networks/test_from_timm.py diff --git a/pdm.lock b/pdm.lock index dee63384..9ae7e117 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev", "docs", "all", "typecheck", "lint", "vision", "test"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:ffa7b5a5665a4fd7d142fd7c8cf32c533a041f9823b79d24f24a228ec144a1f7" +content_hash = "sha256:e9ce86cdec586c1748ab0b2dfcd22ac15112f38fb0c262d908992a1ca7c349b3" [[package]] name = "absl-py" @@ -168,7 +168,7 @@ files = [ [[package]] name = "black" -version = "24.3.0" +version = "24.4.2" requires_python = ">=3.8" summary = "The uncompromising code formatter." groups = ["dev", "lint"] @@ -182,20 +182,20 @@ dependencies = [ "typing-extensions>=4.0.1; python_version < \"3.11\"", ] files = [ - {file = "black-24.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7d5e026f8da0322b5662fa7a8e752b3fa2dac1c1cbc213c3d7ff9bdd0ab12395"}, - {file = "black-24.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9f50ea1132e2189d8dff0115ab75b65590a3e97de1e143795adb4ce317934995"}, - {file = "black-24.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2af80566f43c85f5797365077fb64a393861a3730bd110971ab7a0c94e873e7"}, - {file = "black-24.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:4be5bb28e090456adfc1255e03967fb67ca846a03be7aadf6249096100ee32d0"}, - {file = "black-24.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4f1373a7808a8f135b774039f61d59e4be7eb56b2513d3d2f02a8b9365b8a8a9"}, - {file = "black-24.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:aadf7a02d947936ee418777e0247ea114f78aff0d0959461057cae8a04f20597"}, - {file = "black-24.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65c02e4ea2ae09d16314d30912a58ada9a5c4fdfedf9512d23326128ac08ac3d"}, - {file = "black-24.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:bf21b7b230718a5f08bd32d5e4f1db7fc8788345c8aea1d155fc17852b3410f5"}, - {file = "black-24.3.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:2818cf72dfd5d289e48f37ccfa08b460bf469e67fb7c4abb07edc2e9f16fb63f"}, - {file = "black-24.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4acf672def7eb1725f41f38bf6bf425c8237248bb0804faa3965c036f7672d11"}, - {file = "black-24.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7ed6668cbbfcd231fa0dc1b137d3e40c04c7f786e626b405c62bcd5db5857e4"}, - {file = "black-24.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:56f52cfbd3dabe2798d76dbdd299faa046a901041faf2cf33288bc4e6dae57b5"}, - {file = "black-24.3.0-py3-none-any.whl", hash = "sha256:41622020d7120e01d377f74249e677039d20e6344ff5851de8a10f11f513bf93"}, - {file = "black-24.3.0.tar.gz", hash = "sha256:a0c9c4a0771afc6919578cec71ce82a3e31e054904e7197deacbc9382671c41f"}, + {file = "black-24.4.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dd1b5a14e417189db4c7b64a6540f31730713d173f0b63e55fabd52d61d8fdce"}, + {file = "black-24.4.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8e537d281831ad0e71007dcdcbe50a71470b978c453fa41ce77186bbe0ed6021"}, + {file = "black-24.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaea3008c281f1038edb473c1aa8ed8143a5535ff18f978a318f10302b254063"}, + {file = "black-24.4.2-cp310-cp310-win_amd64.whl", hash = "sha256:7768a0dbf16a39aa5e9a3ded568bb545c8c2727396d063bbaf847df05b08cd96"}, + {file = "black-24.4.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:257d724c2c9b1660f353b36c802ccece186a30accc7742c176d29c146df6e474"}, + {file = "black-24.4.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bdde6f877a18f24844e381d45e9947a49e97933573ac9d4345399be37621e26c"}, + {file = "black-24.4.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e151054aa00bad1f4e1f04919542885f89f5f7d086b8a59e5000e6c616896ffb"}, + {file = "black-24.4.2-cp311-cp311-win_amd64.whl", hash = "sha256:7e122b1c4fb252fd85df3ca93578732b4749d9be076593076ef4d07a0233c3e1"}, + {file = "black-24.4.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:accf49e151c8ed2c0cdc528691838afd217c50412534e876a19270fea1e28e2d"}, + {file = "black-24.4.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:88c57dc656038f1ab9f92b3eb5335ee9b021412feaa46330d5eba4e51fe49b04"}, + {file = "black-24.4.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be8bef99eb46d5021bf053114442914baeb3649a89dc5f3a555c88737e5e98fc"}, + {file = "black-24.4.2-cp312-cp312-win_amd64.whl", hash = "sha256:415e686e87dbbe6f4cd5ef0fbf764af7b89f9057b97c908742b6008cc554b9c0"}, + {file = "black-24.4.2-py3-none-any.whl", hash = "sha256:d36ed1124bb81b32f8614555b34cc4259c3fbc7eec17870e8ff8ded335b58d8c"}, + {file = "black-24.4.2.tar.gz", hash = "sha256:c872b53057f000085da66a19c55d68f6f8ddcac2642392ad3a355878406fbd4d"}, ] [[package]] @@ -597,7 +597,7 @@ files = [ [[package]] name = "h5py" -version = "3.10.0" +version = "3.11.0" requires_python = ">=3.8" summary = "Read and write HDF5 files from Python" groups = ["all", "vision"] @@ -605,21 +605,19 @@ dependencies = [ "numpy>=1.17.3", ] files = [ - {file = "h5py-3.10.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b963fb772964fc1d1563c57e4e2e874022ce11f75ddc6df1a626f42bd49ab99f"}, - {file = "h5py-3.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:012ab448590e3c4f5a8dd0f3533255bc57f80629bf7c5054cf4c87b30085063c"}, - {file = "h5py-3.10.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:781a24263c1270a62cd67be59f293e62b76acfcc207afa6384961762bb88ea03"}, - {file = "h5py-3.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f42e6c30698b520f0295d70157c4e202a9e402406f50dc08f5a7bc416b24e52d"}, - {file = "h5py-3.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:93dd840bd675787fc0b016f7a05fc6efe37312a08849d9dd4053fd0377b1357f"}, - {file = "h5py-3.10.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2381e98af081b6df7f6db300cd88f88e740649d77736e4b53db522d8874bf2dc"}, - {file = "h5py-3.10.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:667fe23ab33d5a8a6b77970b229e14ae3bb84e4ea3382cc08567a02e1499eedd"}, - {file = "h5py-3.10.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:90286b79abd085e4e65e07c1bd7ee65a0f15818ea107f44b175d2dfe1a4674b7"}, - {file = "h5py-3.10.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c013d2e79c00f28ffd0cc24e68665ea03ae9069e167087b2adb5727d2736a52"}, - {file = "h5py-3.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:92273ce69ae4983dadb898fd4d3bea5eb90820df953b401282ee69ad648df684"}, - {file = "h5py-3.10.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3c97d03f87f215e7759a354460fb4b0d0f27001450b18b23e556e7856a0b21c3"}, - {file = "h5py-3.10.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:86df4c2de68257b8539a18646ceccdcf2c1ce6b1768ada16c8dcfb489eafae20"}, - {file = "h5py-3.10.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba9ab36be991119a3ff32d0c7cbe5faf9b8d2375b5278b2aea64effbeba66039"}, - {file = "h5py-3.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:2c8e4fda19eb769e9a678592e67eaec3a2f069f7570c82d2da909c077aa94339"}, - {file = "h5py-3.10.0.tar.gz", hash = "sha256:d93adc48ceeb33347eb24a634fb787efc7ae4644e6ea4ba733d099605045c049"}, + {file = "h5py-3.11.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1625fd24ad6cfc9c1ccd44a66dac2396e7ee74940776792772819fc69f3a3731"}, + {file = "h5py-3.11.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c072655ad1d5fe9ef462445d3e77a8166cbfa5e599045f8aa3c19b75315f10e5"}, + {file = "h5py-3.11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:77b19a40788e3e362b54af4dcf9e6fde59ca016db2c61360aa30b47c7b7cef00"}, + {file = "h5py-3.11.0-cp310-cp310-win_amd64.whl", hash = "sha256:ef4e2f338fc763f50a8113890f455e1a70acd42a4d083370ceb80c463d803972"}, + {file = "h5py-3.11.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:bbd732a08187a9e2a6ecf9e8af713f1d68256ee0f7c8b652a32795670fb481ba"}, + {file = "h5py-3.11.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:75bd7b3d93fbeee40860fd70cdc88df4464e06b70a5ad9ce1446f5f32eb84007"}, + {file = "h5py-3.11.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:52c416f8eb0daae39dabe71415cb531f95dce2d81e1f61a74537a50c63b28ab3"}, + {file = "h5py-3.11.0-cp311-cp311-win_amd64.whl", hash = "sha256:083e0329ae534a264940d6513f47f5ada617da536d8dccbafc3026aefc33c90e"}, + {file = "h5py-3.11.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:a76cae64080210389a571c7d13c94a1a6cf8cb75153044fd1f822a962c97aeab"}, + {file = "h5py-3.11.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f3736fe21da2b7d8a13fe8fe415f1272d2a1ccdeff4849c1421d2fb30fd533bc"}, + {file = "h5py-3.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa6ae84a14103e8dc19266ef4c3e5d7c00b68f21d07f2966f0ca7bdb6c2761fb"}, + {file = "h5py-3.11.0-cp312-cp312-win_amd64.whl", hash = "sha256:21dbdc5343f53b2e25404673c4f00a3335aef25521bd5fa8c707ec3833934892"}, + {file = "h5py-3.11.0.tar.gz", hash = "sha256:7b7e8f78072a2edec87c9836f25f34203fd492a4475709a18b417a33cfb21fa9"}, ] [[package]] @@ -703,6 +701,20 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "intel-openmp" +version = "2021.4.0" +summary = "IntelĀ® OpenMP* Runtime Library" +groups = ["all", "default", "vision"] +marker = "platform_system == \"Windows\"" +files = [ + {file = "intel_openmp-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:41c01e266a7fdb631a7609191709322da2bbf24b252ba763f125dd651bcc7675"}, + {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:3b921236a38384e2016f0f3d65af6732cf2c12918087128a9163225451e776f2"}, + {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:e2240ab8d01472fed04f3544a878cda5da16c26232b7ea1b59132dbfb48b186e"}, + {file = "intel_openmp-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:6e863d8fd3d7e8ef389d52cf97a50fe2afe1a19247e8c0d168ce021546f96fc9"}, + {file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"}, +] + [[package]] name = "isort" version = "5.13.2" @@ -730,7 +742,7 @@ files = [ [[package]] name = "jsonargparse" -version = "4.27.6" +version = "4.28.0" requires_python = ">=3.7" summary = "Implement minimal boilerplate CLIs derived from type hints and parse from command line, config files and environment variables." groups = ["default"] @@ -738,29 +750,29 @@ dependencies = [ "PyYAML>=3.13", ] files = [ - {file = "jsonargparse-4.27.6-py3-none-any.whl", hash = "sha256:f429b4a1b1fe92ef2e3e531615f53e81720a424f3f3181eca7a28c994515fc15"}, - {file = "jsonargparse-4.27.6.tar.gz", hash = "sha256:ebd2e0a4faef85a075bb6ef79c6b2f03f57a5f8e3db26c911b55518a1bca68ad"}, + {file = "jsonargparse-4.28.0-py3-none-any.whl", hash = "sha256:9dcda241349547e8035c630d51de73b8b4ba67bdc2b014d7f76734d404e82518"}, + {file = "jsonargparse-4.28.0.tar.gz", hash = "sha256:ac835a290ef18cc2a5309e6bfa8ada9c5d63f46ff18701583fc8f3e95314679c"}, ] [[package]] name = "jsonargparse" -version = "4.27.6" +version = "4.28.0" extras = ["omegaconf"] requires_python = ">=3.7" summary = "Implement minimal boilerplate CLIs derived from type hints and parse from command line, config files and environment variables." groups = ["default"] dependencies = [ - "jsonargparse==4.27.6", + "jsonargparse==4.28.0", "omegaconf>=2.1.1", ] files = [ - {file = "jsonargparse-4.27.6-py3-none-any.whl", hash = "sha256:f429b4a1b1fe92ef2e3e531615f53e81720a424f3f3181eca7a28c994515fc15"}, - {file = "jsonargparse-4.27.6.tar.gz", hash = "sha256:ebd2e0a4faef85a075bb6ef79c6b2f03f57a5f8e3db26c911b55518a1bca68ad"}, + {file = "jsonargparse-4.28.0-py3-none-any.whl", hash = "sha256:9dcda241349547e8035c630d51de73b8b4ba67bdc2b014d7f76734d404e82518"}, + {file = "jsonargparse-4.28.0.tar.gz", hash = "sha256:ac835a290ef18cc2a5309e6bfa8ada9c5d63f46ff18701583fc8f3e95314679c"}, ] [[package]] name = "lightning" -version = "2.3.0.dev20240407" +version = "2.3.0.dev20240421" requires_python = ">=3.8" summary = "The Deep Learning framework to train, deploy, and ship AI products Lightning fast." groups = ["default"] @@ -777,8 +789,8 @@ dependencies = [ "typing-extensions<6.0,>=4.4.0", ] files = [ - {file = "lightning-2.3.0.dev20240407-py3-none-any.whl", hash = "sha256:27fa1f37a5ab12b917590f833baeea3e02c3a979f49f9a14bb35e2fd0ae29cfd"}, - {file = "lightning-2.3.0.dev20240407.tar.gz", hash = "sha256:6aab115c1c22a75d359f79db3457d7488682900aa03426a9534ef1a535195310"}, + {file = "lightning-2.3.0.dev20240421-py3-none-any.whl", hash = "sha256:18a31ce5fec10c11e73ebddb970aab4f47acf86e4ff187d1f85ae15f3be540ca"}, + {file = "lightning-2.3.0.dev20240421.tar.gz", hash = "sha256:a803678ca35e24a9ec8b01f3c5ce7207be80adbf13775bc14cf8c4af7e7039d2"}, ] [[package]] @@ -825,7 +837,7 @@ files = [ [[package]] name = "markdown-exec" -version = "1.8.0" +version = "1.8.1" requires_python = ">=3.8" summary = "Utilities to execute code blocks in Markdown files." groups = ["dev", "docs"] @@ -833,8 +845,8 @@ dependencies = [ "pymdown-extensions>=9", ] files = [ - {file = "markdown_exec-1.8.0-py3-none-any.whl", hash = "sha256:e80cb766eff8d0bcd1cdd133dba58223b42edbd1b7b9672481c2189572401bff"}, - {file = "markdown_exec-1.8.0.tar.gz", hash = "sha256:0a932312f0ca89b82150e1638e84febb90eadd410dfd2417f05759c06deed727"}, + {file = "markdown_exec-1.8.1-py3-none-any.whl", hash = "sha256:63c769ebf202b1c1f97822c72e4467d39e151b741aeb94758b3de20066ed3b5f"}, + {file = "markdown_exec-1.8.1.tar.gz", hash = "sha256:1fe4e344f3dc000dd7e764ab1ee21d14e4e15c91afc8c6d35f18d694693eb696"}, ] [[package]] @@ -934,8 +946,8 @@ files = [ [[package]] name = "mkdocs" -version = "1.5.3" -requires_python = ">=3.7" +version = "1.6.0" +requires_python = ">=3.8" summary = "Project documentation with Markdown." groups = ["dev", "docs"] dependencies = [ @@ -943,19 +955,19 @@ dependencies = [ "colorama>=0.4; platform_system == \"Windows\"", "ghp-import>=1.0", "jinja2>=2.11.1", - "markdown>=3.2.1", + "markdown>=3.3.6", "markupsafe>=2.0.1", "mergedeep>=1.3.4", + "mkdocs-get-deps>=0.2.0", "packaging>=20.5", "pathspec>=0.11.1", - "platformdirs>=2.2.0", "pyyaml-env-tag>=0.1", "pyyaml>=5.1", "watchdog>=2.0", ] files = [ - {file = "mkdocs-1.5.3-py3-none-any.whl", hash = "sha256:3b3a78e736b31158d64dbb2f8ba29bd46a379d0c6e324c2246c3bc3d2189cfc1"}, - {file = "mkdocs-1.5.3.tar.gz", hash = "sha256:eb7c99214dcb945313ba30426c2451b735992c73c2e10838f76d09e39ff4d0e2"}, + {file = "mkdocs-1.6.0-py3-none-any.whl", hash = "sha256:1eb5cb7676b7d89323e62b56235010216319217d4af5ddc543a91beb8d125ea7"}, + {file = "mkdocs-1.6.0.tar.gz", hash = "sha256:a73f735824ef83a4f3bcb7a231dcab23f5a838f88b7efc54a0eef5fbdbc3c512"}, ] [[package]] @@ -974,9 +986,25 @@ files = [ {file = "mkdocs_autorefs-1.0.1.tar.gz", hash = "sha256:f684edf847eced40b570b57846b15f0bf57fb93ac2c510450775dcf16accb971"}, ] +[[package]] +name = "mkdocs-get-deps" +version = "0.2.0" +requires_python = ">=3.8" +summary = "MkDocs extension that lists all dependencies according to a mkdocs.yml file" +groups = ["dev", "docs"] +dependencies = [ + "mergedeep>=1.3.4", + "platformdirs>=2.2.0", + "pyyaml>=5.1", +] +files = [ + {file = "mkdocs_get_deps-0.2.0-py3-none-any.whl", hash = "sha256:2bf11d0b133e77a0dd036abeeb06dec8775e46efa526dc70667d8863eefc6134"}, + {file = "mkdocs_get_deps-0.2.0.tar.gz", hash = "sha256:162b3d129c7fad9b19abfdcb9c1458a651628e4b1dea628ac68790fb3061c60c"}, +] + [[package]] name = "mkdocs-material" -version = "9.5.14" +version = "9.5.19" requires_python = ">=3.8" summary = "Documentation that simply works" groups = ["dev", "docs"] @@ -986,7 +1014,7 @@ dependencies = [ "jinja2~=3.0", "markdown~=3.2", "mkdocs-material-extensions~=1.3", - "mkdocs~=1.5.3", + "mkdocs~=1.6", "paginate~=0.5", "pygments~=2.16", "pymdown-extensions~=10.2", @@ -994,8 +1022,8 @@ dependencies = [ "requests~=2.26", ] files = [ - {file = "mkdocs_material-9.5.14-py3-none-any.whl", hash = "sha256:a45244ac221fda46ecf8337f00ec0e5cb5348ab9ffb203ca2a0c313b0d4dbc27"}, - {file = "mkdocs_material-9.5.14.tar.gz", hash = "sha256:2a1f8e67cda2587ab93ecea9ba42d0ca61d1d7b5fad8cf690eeaeb39dcd4b9af"}, + {file = "mkdocs_material-9.5.19-py3-none-any.whl", hash = "sha256:ea96e150b6c95f5e4ffe47d78bb712c7bacdd91d2a0bec47f46b6fa0705a86ec"}, + {file = "mkdocs_material-9.5.19.tar.gz", hash = "sha256:7473e06e17e23af608a30ef583fdde8f36389dd3ef56b1d503eed54c89c9618c"}, ] [[package]] @@ -1035,7 +1063,7 @@ files = [ [[package]] name = "mkdocstrings" -version = "0.24.1" +version = "0.24.3" requires_python = ">=3.8" summary = "Automatic documentation from sources, for MkDocs." groups = ["dev", "docs"] @@ -1050,8 +1078,8 @@ dependencies = [ "pymdown-extensions>=6.3", ] files = [ - {file = "mkdocstrings-0.24.1-py3-none-any.whl", hash = "sha256:b4206f9a2ca8a648e222d5a0ca1d36ba7dee53c88732818de183b536f9042b5d"}, - {file = "mkdocstrings-0.24.1.tar.gz", hash = "sha256:cc83f9a1c8724fc1be3c2fa071dd73d91ce902ef6a79710249ec8d0ee1064401"}, + {file = "mkdocstrings-0.24.3-py3-none-any.whl", hash = "sha256:5c9cf2a32958cd161d5428699b79c8b0988856b0d4a8c5baf8395fc1bf4087c3"}, + {file = "mkdocstrings-0.24.3.tar.gz", hash = "sha256:f327b234eb8d2551a306735436e157d0a22d45f79963c60a8b585d5f7a94c1d2"}, ] [[package]] @@ -1072,18 +1100,36 @@ files = [ [[package]] name = "mkdocstrings" -version = "0.24.1" +version = "0.24.3" extras = ["python"] requires_python = ">=3.8" summary = "Automatic documentation from sources, for MkDocs." groups = ["dev", "docs"] dependencies = [ "mkdocstrings-python>=0.5.2", - "mkdocstrings==0.24.1", + "mkdocstrings==0.24.3", ] files = [ - {file = "mkdocstrings-0.24.1-py3-none-any.whl", hash = "sha256:b4206f9a2ca8a648e222d5a0ca1d36ba7dee53c88732818de183b536f9042b5d"}, - {file = "mkdocstrings-0.24.1.tar.gz", hash = "sha256:cc83f9a1c8724fc1be3c2fa071dd73d91ce902ef6a79710249ec8d0ee1064401"}, + {file = "mkdocstrings-0.24.3-py3-none-any.whl", hash = "sha256:5c9cf2a32958cd161d5428699b79c8b0988856b0d4a8c5baf8395fc1bf4087c3"}, + {file = "mkdocstrings-0.24.3.tar.gz", hash = "sha256:f327b234eb8d2551a306735436e157d0a22d45f79963c60a8b585d5f7a94c1d2"}, +] + +[[package]] +name = "mkl" +version = "2021.4.0" +summary = "IntelĀ® oneAPI Math Kernel Library" +groups = ["all", "default", "vision"] +marker = "platform_system == \"Windows\"" +dependencies = [ + "intel-openmp==2021.*", + "tbb==2021.*", +] +files = [ + {file = "mkl-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:67460f5cd7e30e405b54d70d1ed3ca78118370b65f7327d495e9c8847705e2fb"}, + {file = "mkl-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:636d07d90e68ccc9630c654d47ce9fdeb036bb46e2b193b3a9ac8cfea683cce5"}, + {file = "mkl-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:398dbf2b0d12acaf54117a5210e8f191827f373d362d796091d161f610c1ebfb"}, + {file = "mkl-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:439c640b269a5668134e3dcbcea4350459c4a8bc46469669b2d67e07e3d330e8"}, + {file = "mkl-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:ceef3cafce4c009dd25f65d7ad0d833a0fbadc3d8903991ec92351fe5de1e718"}, ] [[package]] @@ -1205,7 +1251,7 @@ files = [ [[package]] name = "nox" -version = "2024.3.2" +version = "2024.4.15" requires_python = ">=3.7" summary = "Flexible test automation." groups = ["dev", "typecheck"] @@ -1213,11 +1259,12 @@ dependencies = [ "argcomplete<4.0,>=1.9.4", "colorlog<7.0.0,>=2.6.1", "packaging>=20.9", + "tomli>=1; python_version < \"3.11\"", "virtualenv>=20.14.1", ] files = [ - {file = "nox-2024.3.2-py3-none-any.whl", hash = "sha256:e53514173ac0b98dd47585096a55572fe504fecede58ced708979184d05440be"}, - {file = "nox-2024.3.2.tar.gz", hash = "sha256:f521ae08a15adbf5e11f16cb34e8d0e6ea521e0b92868f684e91677deb974553"}, + {file = "nox-2024.4.15-py3-none-any.whl", hash = "sha256:6492236efa15a460ecb98e7b67562a28b70da006ab0be164e8821177577c0565"}, + {file = "nox-2024.4.15.tar.gz", hash = "sha256:ecf6700199cdfa9e5ea0a41ff5e6ef4641d09508eda6edb89d9987864115817f"}, ] [[package]] @@ -1377,13 +1424,14 @@ files = [ [[package]] name = "nvidia-nccl-cu12" -version = "2.19.3" +version = "2.20.5" requires_python = ">=3" summary = "NVIDIA Collective Communication Library (NCCL) Runtime" groups = ["all", "default", "vision"] marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ - {file = "nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl", hash = "sha256:a9734707a2c96443331c1e48c717024aa6678a0e2a4cb66b2c364d18cee6b48d"}, + {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01"}, + {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56"}, ] [[package]] @@ -1460,33 +1508,33 @@ files = [ [[package]] name = "onnxruntime" -version = "1.17.1" +version = "1.17.3" summary = "ONNX Runtime is a runtime accelerator for Machine Learning models" groups = ["default"] dependencies = [ "coloredlogs", "flatbuffers", - "numpy>=1.24.2", + "numpy>=1.21.6", "packaging", "protobuf", "sympy", ] files = [ - {file = "onnxruntime-1.17.1-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:d43ac17ac4fa3c9096ad3c0e5255bb41fd134560212dc124e7f52c3159af5d21"}, - {file = "onnxruntime-1.17.1-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:55b5e92a4c76a23981c998078b9bf6145e4fb0b016321a8274b1607bd3c6bd35"}, - {file = "onnxruntime-1.17.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ebbcd2bc3a066cf54e6f18c75708eb4d309ef42be54606d22e5bdd78afc5b0d7"}, - {file = "onnxruntime-1.17.1-cp310-cp310-win32.whl", hash = "sha256:5e3716b5eec9092e29a8d17aab55e737480487deabfca7eac3cd3ed952b6ada9"}, - {file = "onnxruntime-1.17.1-cp310-cp310-win_amd64.whl", hash = "sha256:fbb98cced6782ae1bb799cc74ddcbbeeae8819f3ad1d942a74d88e72b6511337"}, - {file = "onnxruntime-1.17.1-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:36fd6f87a1ecad87e9c652e42407a50fb305374f9a31d71293eb231caae18784"}, - {file = "onnxruntime-1.17.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:99a8bddeb538edabc524d468edb60ad4722cff8a49d66f4e280c39eace70500b"}, - {file = "onnxruntime-1.17.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fd7fddb4311deb5a7d3390cd8e9b3912d4d963efbe4dfe075edbaf18d01c024e"}, - {file = "onnxruntime-1.17.1-cp311-cp311-win32.whl", hash = "sha256:606a7cbfb6680202b0e4f1890881041ffc3ac6e41760a25763bd9fe146f0b335"}, - {file = "onnxruntime-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:53e4e06c0a541696ebdf96085fd9390304b7b04b748a19e02cf3b35c869a1e76"}, - {file = "onnxruntime-1.17.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:40f08e378e0f85929712a2b2c9b9a9cc400a90c8a8ca741d1d92c00abec60843"}, - {file = "onnxruntime-1.17.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ac79da6d3e1bb4590f1dad4bb3c2979d7228555f92bb39820889af8b8e6bd472"}, - {file = "onnxruntime-1.17.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ae9ba47dc099004e3781f2d0814ad710a13c868c739ab086fc697524061695ea"}, - {file = "onnxruntime-1.17.1-cp312-cp312-win32.whl", hash = "sha256:2dff1a24354220ac30e4a4ce2fb1df38cb1ea59f7dac2c116238d63fe7f4c5ff"}, - {file = "onnxruntime-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:6226a5201ab8cafb15e12e72ff2a4fc8f50654e8fa5737c6f0bd57c5ff66827e"}, + {file = "onnxruntime-1.17.3-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:d86dde9c0bb435d709e51bd25991c9fe5b9a5b168df45ce119769edc4d198b15"}, + {file = "onnxruntime-1.17.3-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9d87b68bf931ac527b2d3c094ead66bb4381bac4298b65f46c54fe4d1e255865"}, + {file = "onnxruntime-1.17.3-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:26e950cf0333cf114a155f9142e71da344d2b08dfe202763a403ae81cc02ebd1"}, + {file = "onnxruntime-1.17.3-cp310-cp310-win32.whl", hash = "sha256:0962a4d0f5acebf62e1f0bf69b6e0adf16649115d8de854c1460e79972324d68"}, + {file = "onnxruntime-1.17.3-cp310-cp310-win_amd64.whl", hash = "sha256:468ccb8a0faa25c681a41787b1594bf4448b0252d3efc8b62fd8b2411754340f"}, + {file = "onnxruntime-1.17.3-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:e8cd90c1c17d13d47b89ab076471e07fb85467c01dcd87a8b8b5cdfbcb40aa51"}, + {file = "onnxruntime-1.17.3-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a058b39801baefe454eeb8acf3ada298c55a06a4896fafc224c02d79e9037f60"}, + {file = "onnxruntime-1.17.3-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2f823d5eb4807007f3da7b27ca972263df6a1836e6f327384eb266274c53d05d"}, + {file = "onnxruntime-1.17.3-cp311-cp311-win32.whl", hash = "sha256:b66b23f9109e78ff2791628627a26f65cd335dcc5fbd67ff60162733a2f7aded"}, + {file = "onnxruntime-1.17.3-cp311-cp311-win_amd64.whl", hash = "sha256:570760ca53a74cdd751ee49f13de70d1384dcf73d9888b8deac0917023ccda6d"}, + {file = "onnxruntime-1.17.3-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:77c318178d9c16e9beadd9a4070d8aaa9f57382c3f509b01709f0f010e583b99"}, + {file = "onnxruntime-1.17.3-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:23da8469049b9759082e22c41a444f44a520a9c874b084711b6343672879f50b"}, + {file = "onnxruntime-1.17.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2949730215af3f9289008b2e31e9bbef952012a77035b911c4977edea06f3f9e"}, + {file = "onnxruntime-1.17.3-cp312-cp312-win32.whl", hash = "sha256:6c7555a49008f403fb3b19204671efb94187c5085976ae526cb625f6ede317bc"}, + {file = "onnxruntime-1.17.3-cp312-cp312-win_amd64.whl", hash = "sha256:58672cf20293a1b8a277a5c6c55383359fcdf6119b2f14df6ce3b140f5001c39"}, ] [[package]] @@ -1537,41 +1585,41 @@ files = [ [[package]] name = "pandas" -version = "2.2.1" +version = "2.2.2" requires_python = ">=3.9" summary = "Powerful data structures for data analysis, time series, and statistics" groups = ["default"] dependencies = [ - "numpy<2,>=1.22.4; python_version < \"3.11\"", - "numpy<2,>=1.23.2; python_version == \"3.11\"", - "numpy<2,>=1.26.0; python_version >= \"3.12\"", + "numpy>=1.22.4; python_version < \"3.11\"", + "numpy>=1.23.2; python_version == \"3.11\"", + "numpy>=1.26.0; python_version >= \"3.12\"", "python-dateutil>=2.8.2", "pytz>=2020.1", "tzdata>=2022.7", ] files = [ - {file = "pandas-2.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8df8612be9cd1c7797c93e1c5df861b2ddda0b48b08f2c3eaa0702cf88fb5f88"}, - {file = "pandas-2.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0f573ab277252ed9aaf38240f3b54cfc90fff8e5cab70411ee1d03f5d51f3944"}, - {file = "pandas-2.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f02a3a6c83df4026e55b63c1f06476c9aa3ed6af3d89b4f04ea656ccdaaaa359"}, - {file = "pandas-2.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c38ce92cb22a4bea4e3929429aa1067a454dcc9c335799af93ba9be21b6beb51"}, - {file = "pandas-2.2.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c2ce852e1cf2509a69e98358e8458775f89599566ac3775e70419b98615f4b06"}, - {file = "pandas-2.2.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:53680dc9b2519cbf609c62db3ed7c0b499077c7fefda564e330286e619ff0dd9"}, - {file = "pandas-2.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:94e714a1cca63e4f5939cdce5f29ba8d415d85166be3441165edd427dc9f6bc0"}, - {file = "pandas-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f821213d48f4ab353d20ebc24e4faf94ba40d76680642fb7ce2ea31a3ad94f9b"}, - {file = "pandas-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c70e00c2d894cb230e5c15e4b1e1e6b2b478e09cf27cc593a11ef955b9ecc81a"}, - {file = "pandas-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e97fbb5387c69209f134893abc788a6486dbf2f9e511070ca05eed4b930b1b02"}, - {file = "pandas-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:101d0eb9c5361aa0146f500773395a03839a5e6ecde4d4b6ced88b7e5a1a6403"}, - {file = "pandas-2.2.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:7d2ed41c319c9fb4fd454fe25372028dfa417aacb9790f68171b2e3f06eae8cd"}, - {file = "pandas-2.2.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:af5d3c00557d657c8773ef9ee702c61dd13b9d7426794c9dfeb1dc4a0bf0ebc7"}, - {file = "pandas-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:06cf591dbaefb6da9de8472535b185cba556d0ce2e6ed28e21d919704fef1a9e"}, - {file = "pandas-2.2.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:88ecb5c01bb9ca927ebc4098136038519aa5d66b44671861ffab754cae75102c"}, - {file = "pandas-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:04f6ec3baec203c13e3f8b139fb0f9f86cd8c0b94603ae3ae8ce9a422e9f5bee"}, - {file = "pandas-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a935a90a76c44fe170d01e90a3594beef9e9a6220021acfb26053d01426f7dc2"}, - {file = "pandas-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c391f594aae2fd9f679d419e9a4d5ba4bce5bb13f6a989195656e7dc4b95c8f0"}, - {file = "pandas-2.2.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9d1265545f579edf3f8f0cb6f89f234f5e44ba725a34d86535b1a1d38decbccc"}, - {file = "pandas-2.2.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:11940e9e3056576ac3244baef2fedade891977bcc1cb7e5cc8f8cc7d603edc89"}, - {file = "pandas-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:4acf681325ee1c7f950d058b05a820441075b0dd9a2adf5c4835b9bc056bf4fb"}, - {file = "pandas-2.2.1.tar.gz", hash = "sha256:0ab90f87093c13f3e8fa45b48ba9f39181046e8f3317d3aadb2fffbb1b978572"}, + {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, + {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"}, + {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, + {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, + {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, + {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8e5a0b00e1e56a842f922e7fae8ae4077aee4af0acb5ae3622bd4b4c30aedf99"}, + {file = "pandas-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:ddf818e4e6c7c6f4f7c8a12709696d193976b591cc7dc50588d3d1a6b5dc8772"}, + {file = "pandas-2.2.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:696039430f7a562b74fa45f540aca068ea85fa34c244d0deee539cb6d70aa288"}, + {file = "pandas-2.2.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8e90497254aacacbc4ea6ae5e7a8cd75629d6ad2b30025a4a8b09aa4faf55151"}, + {file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58b84b91b0b9f4bafac2a0ac55002280c094dfc6402402332c0913a59654ab2b"}, + {file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2123dc9ad6a814bcdea0f099885276b31b24f7edf40f6cdbc0912672e22eee"}, + {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2925720037f06e89af896c70bca73459d7e6a4be96f9de79e2d440bd499fe0db"}, + {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0cace394b6ea70c01ca1595f839cf193df35d1575986e484ad35c4aeae7266c1"}, + {file = "pandas-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:873d13d177501a28b2756375d59816c365e42ed8417b41665f346289adc68d24"}, + {file = "pandas-2.2.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9dfde2a0ddef507a631dc9dc4af6a9489d5e2e740e226ad426a05cabfbd7c8ef"}, + {file = "pandas-2.2.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e9b79011ff7a0f4b1d6da6a61aa1aa604fb312d6647de5bad20013682d1429ce"}, + {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1cb51fe389360f3b5a4d57dbd2848a5f033350336ca3b340d1c53a1fad33bcad"}, + {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eee3a87076c0756de40b05c5e9a6069c035ba43e8dd71c379e68cab2c20f16ad"}, + {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3e374f59e440d4ab45ca2fffde54b81ac3834cf5ae2cdfa69c90bc03bde04d76"}, + {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, + {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, + {file = "pandas-2.2.2.tar.gz", hash = "sha256:9e79019aba43cb4fda9e4d983f8e88ca0373adbb697ae9c6c43093218de28b54"}, ] [[package]] @@ -1740,7 +1788,7 @@ files = [ [[package]] name = "pyright" -version = "1.1.355" +version = "1.1.360" requires_python = ">=3.7" summary = "Command line wrapper for pyright" groups = ["dev", "typecheck"] @@ -1748,8 +1796,8 @@ dependencies = [ "nodeenv>=1.6.0", ] files = [ - {file = "pyright-1.1.355-py3-none-any.whl", hash = "sha256:bf30b6728fd68ae7d09c98292b67152858dd89738569836896df786e52b5fe48"}, - {file = "pyright-1.1.355.tar.gz", hash = "sha256:dca4104cd53d6484e6b1b50b7a239ad2d16d2ffd20030bcf3111b56f44c263bf"}, + {file = "pyright-1.1.360-py3-none-any.whl", hash = "sha256:7637f75451ac968b7cf1f8c51cfefb6d60ac7d086eb845364bc8ac03a026efd7"}, + {file = "pyright-1.1.360.tar.gz", hash = "sha256:784ddcda9745e9f5610483d7b963e9aa8d4f50d7755a9dffb28ccbeb27adce32"}, ] [[package]] @@ -1773,8 +1821,8 @@ files = [ [[package]] name = "pytest-cov" -version = "4.1.0" -requires_python = ">=3.7" +version = "5.0.0" +requires_python = ">=3.8" summary = "Pytest plugin for measuring coverage." groups = ["dev", "test"] dependencies = [ @@ -1782,8 +1830,8 @@ dependencies = [ "pytest>=4.6", ] files = [ - {file = "pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6"}, - {file = "pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a"}, + {file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"}, + {file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"}, ] [[package]] @@ -1969,28 +2017,28 @@ files = [ [[package]] name = "ruff" -version = "0.3.3" +version = "0.4.2" requires_python = ">=3.7" summary = "An extremely fast Python linter and code formatter, written in Rust." groups = ["dev", "lint"] files = [ - {file = "ruff-0.3.3-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:973a0e388b7bc2e9148c7f9be8b8c6ae7471b9be37e1cc732f8f44a6f6d7720d"}, - {file = "ruff-0.3.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:cfa60d23269d6e2031129b053fdb4e5a7b0637fc6c9c0586737b962b2f834493"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1eca7ff7a47043cf6ce5c7f45f603b09121a7cc047447744b029d1b719278eb5"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e7d3f6762217c1da954de24b4a1a70515630d29f71e268ec5000afe81377642d"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b24c19e8598916d9c6f5a5437671f55ee93c212a2c4c569605dc3842b6820386"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:5a6cbf216b69c7090f0fe4669501a27326c34e119068c1494f35aaf4cc683778"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:352e95ead6964974b234e16ba8a66dad102ec7bf8ac064a23f95371d8b198aab"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d6ab88c81c4040a817aa432484e838aaddf8bfd7ca70e4e615482757acb64f8"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:79bca3a03a759cc773fca69e0bdeac8abd1c13c31b798d5bb3c9da4a03144a9f"}, - {file = "ruff-0.3.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:2700a804d5336bcffe063fd789ca2c7b02b552d2e323a336700abb8ae9e6a3f8"}, - {file = "ruff-0.3.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:fd66469f1a18fdb9d32e22b79f486223052ddf057dc56dea0caaf1a47bdfaf4e"}, - {file = "ruff-0.3.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:45817af234605525cdf6317005923bf532514e1ea3d9270acf61ca2440691376"}, - {file = "ruff-0.3.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:0da458989ce0159555ef224d5b7c24d3d2e4bf4c300b85467b08c3261c6bc6a8"}, - {file = "ruff-0.3.3-py3-none-win32.whl", hash = "sha256:f2831ec6a580a97f1ea82ea1eda0401c3cdf512cf2045fa3c85e8ef109e87de0"}, - {file = "ruff-0.3.3-py3-none-win_amd64.whl", hash = "sha256:be90bcae57c24d9f9d023b12d627e958eb55f595428bafcb7fec0791ad25ddfc"}, - {file = "ruff-0.3.3-py3-none-win_arm64.whl", hash = "sha256:0171aab5fecdc54383993389710a3d1227f2da124d76a2784a7098e818f92d61"}, - {file = "ruff-0.3.3.tar.gz", hash = "sha256:38671be06f57a2f8aba957d9f701ea889aa5736be806f18c0cd03d6ff0cbca8d"}, + {file = "ruff-0.4.2-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:8d14dc8953f8af7e003a485ef560bbefa5f8cc1ad994eebb5b12136049bbccc5"}, + {file = "ruff-0.4.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:24016ed18db3dc9786af103ff49c03bdf408ea253f3cb9e3638f39ac9cf2d483"}, + {file = "ruff-0.4.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0e2e06459042ac841ed510196c350ba35a9b24a643e23db60d79b2db92af0c2b"}, + {file = "ruff-0.4.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3afabaf7ba8e9c485a14ad8f4122feff6b2b93cc53cd4dad2fd24ae35112d5c5"}, + {file = "ruff-0.4.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:799eb468ea6bc54b95527143a4ceaf970d5aa3613050c6cff54c85fda3fde480"}, + {file = "ruff-0.4.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:ec4ba9436a51527fb6931a8839af4c36a5481f8c19e8f5e42c2f7ad3a49f5069"}, + {file = "ruff-0.4.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6a2243f8f434e487c2a010c7252150b1fdf019035130f41b77626f5655c9ca22"}, + {file = "ruff-0.4.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8772130a063f3eebdf7095da00c0b9898bd1774c43b336272c3e98667d4fb8fa"}, + {file = "ruff-0.4.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ab165ef5d72392b4ebb85a8b0fbd321f69832a632e07a74794c0e598e7a8376"}, + {file = "ruff-0.4.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:1f32cadf44c2020e75e0c56c3408ed1d32c024766bd41aedef92aa3ca28eef68"}, + {file = "ruff-0.4.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:22e306bf15e09af45ca812bc42fa59b628646fa7c26072555f278994890bc7ac"}, + {file = "ruff-0.4.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:82986bb77ad83a1719c90b9528a9dd663c9206f7c0ab69282af8223566a0c34e"}, + {file = "ruff-0.4.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:652e4ba553e421a6dc2a6d4868bc3b3881311702633eb3672f9f244ded8908cd"}, + {file = "ruff-0.4.2-py3-none-win32.whl", hash = "sha256:7891ee376770ac094da3ad40c116258a381b86c7352552788377c6eb16d784fe"}, + {file = "ruff-0.4.2-py3-none-win_amd64.whl", hash = "sha256:5ec481661fb2fd88a5d6cf1f83403d388ec90f9daaa36e40e2c003de66751798"}, + {file = "ruff-0.4.2-py3-none-win_arm64.whl", hash = "sha256:cbd1e87c71bca14792948c4ccb51ee61c3296e164019d2d484f3eaa2d360dfaf"}, + {file = "ruff-0.4.2.tar.gz", hash = "sha256:33bcc160aee2520664bc0859cfeaebc84bb7323becff3f303b8f1f2d81cb4edc"}, ] [[package]] @@ -2078,13 +2126,13 @@ files = [ [[package]] name = "setuptools" -version = "69.2.0" +version = "69.5.1" requires_python = ">=3.8" summary = "Easily download, build, install, upgrade, and uninstall Python packages" groups = ["default", "dev", "docs", "typecheck"] files = [ - {file = "setuptools-69.2.0-py3-none-any.whl", hash = "sha256:c21c49fb1042386df081cb5d86759792ab89efca84cf114889191cd09aacc80c"}, - {file = "setuptools-69.2.0.tar.gz", hash = "sha256:0ff4183f8f42cd8fa3acea16c45205521a4ef28f73c6391d8a25e92893134f2e"}, + {file = "setuptools-69.5.1-py3-none-any.whl", hash = "sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32"}, + {file = "setuptools-69.5.1.tar.gz", hash = "sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987"}, ] [[package]] @@ -2126,6 +2174,19 @@ files = [ {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"}, ] +[[package]] +name = "tbb" +version = "2021.12.0" +summary = "IntelĀ® oneAPI Threading Building Blocks (oneTBB)" +groups = ["all", "default", "vision"] +marker = "platform_system == \"Windows\"" +files = [ + {file = "tbb-2021.12.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:f2cc9a7f8ababaa506cbff796ce97c3bf91062ba521e15054394f773375d81d8"}, + {file = "tbb-2021.12.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:a925e9a7c77d3a46ae31c34b0bb7f801c4118e857d137b68f68a8e458fcf2bd7"}, + {file = "tbb-2021.12.0-py3-none-win32.whl", hash = "sha256:b1725b30c174048edc8be70bd43bb95473f396ce895d91151a474d0fa9f450a8"}, + {file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"}, +] + [[package]] name = "tensorboard" version = "2.16.2" @@ -2161,8 +2222,11 @@ files = [ [[package]] name = "timm" -version = "0.9.16" +version = "1.0.0.dev0" requires_python = ">=3.8" +git = "https://github.com/huggingface/pytorch-image-models.git" +ref = "main" +revision = "e741370e2b95e0c2fa3e00808cd9014ee620ca62" summary = "PyTorch Image Models" groups = ["all", "vision"] dependencies = [ @@ -2172,14 +2236,10 @@ dependencies = [ "torch", "torchvision", ] -files = [ - {file = "timm-0.9.16-py3-none-any.whl", hash = "sha256:bf5704014476ab011589d3c14172ee4c901fd18f9110a928019cac5be2945914"}, - {file = "timm-0.9.16.tar.gz", hash = "sha256:891e54f375d55adf31a71ab0c117761f0e472f9f3971858ecdd1e7376b7071e6"}, -] [[package]] name = "tokenizers" -version = "0.15.2" +version = "0.19.1" requires_python = ">=3.7" summary = "" groups = ["default"] @@ -2187,80 +2247,70 @@ dependencies = [ "huggingface-hub<1.0,>=0.16.4", ] files = [ - {file = "tokenizers-0.15.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:52f6130c9cbf70544287575a985bf44ae1bda2da7e8c24e97716080593638012"}, - {file = "tokenizers-0.15.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:054c1cc9c6d68f7ffa4e810b3d5131e0ba511b6e4be34157aa08ee54c2f8d9ee"}, - {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a9b9b070fdad06e347563b88c278995735292ded1132f8657084989a4c84a6d5"}, - {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea621a7eef4b70e1f7a4e84dd989ae3f0eeb50fc8690254eacc08acb623e82f1"}, - {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cf7fd9a5141634fa3aa8d6b7be362e6ae1b4cda60da81388fa533e0b552c98fd"}, - {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:44f2a832cd0825295f7179eaf173381dc45230f9227ec4b44378322d900447c9"}, - {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8b9ec69247a23747669ec4b0ca10f8e3dfb3545d550258129bd62291aabe8605"}, - {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40b6a4c78da863ff26dbd5ad9a8ecc33d8a8d97b535172601cf00aee9d7ce9ce"}, - {file = "tokenizers-0.15.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5ab2a4d21dcf76af60e05af8063138849eb1d6553a0d059f6534357bce8ba364"}, - {file = "tokenizers-0.15.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a47acfac7e511f6bbfcf2d3fb8c26979c780a91e06fb5b9a43831b2c0153d024"}, - {file = "tokenizers-0.15.2-cp310-none-win32.whl", hash = "sha256:064ff87bb6acdbd693666de9a4b692add41308a2c0ec0770d6385737117215f2"}, - {file = "tokenizers-0.15.2-cp310-none-win_amd64.whl", hash = "sha256:3b919afe4df7eb6ac7cafd2bd14fb507d3f408db7a68c43117f579c984a73843"}, - {file = "tokenizers-0.15.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:89cd1cb93e4b12ff39bb2d626ad77e35209de9309a71e4d3d4672667b4b256e7"}, - {file = "tokenizers-0.15.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:cfed5c64e5be23d7ee0f0e98081a25c2a46b0b77ce99a4f0605b1ec43dd481fa"}, - {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a907d76dcfda37023ba203ab4ceeb21bc5683436ebefbd895a0841fd52f6f6f2"}, - {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:20ea60479de6fc7b8ae756b4b097572372d7e4032e2521c1bbf3d90c90a99ff0"}, - {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:48e2b9335be2bc0171df9281385c2ed06a15f5cf121c44094338306ab7b33f2c"}, - {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:112a1dd436d2cc06e6ffdc0b06d55ac019a35a63afd26475205cb4b1bf0bfbff"}, - {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4620cca5c2817177ee8706f860364cc3a8845bc1e291aaf661fb899e5d1c45b0"}, - {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ccd73a82751c523b3fc31ff8194702e4af4db21dc20e55b30ecc2079c5d43cb7"}, - {file = "tokenizers-0.15.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:107089f135b4ae7817affe6264f8c7a5c5b4fd9a90f9439ed495f54fcea56fb4"}, - {file = "tokenizers-0.15.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0ff110ecc57b7aa4a594396525a3451ad70988e517237fe91c540997c4e50e29"}, - {file = "tokenizers-0.15.2-cp311-none-win32.whl", hash = "sha256:6d76f00f5c32da36c61f41c58346a4fa7f0a61be02f4301fd30ad59834977cc3"}, - {file = "tokenizers-0.15.2-cp311-none-win_amd64.whl", hash = "sha256:cc90102ed17271cf0a1262babe5939e0134b3890345d11a19c3145184b706055"}, - {file = "tokenizers-0.15.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f86593c18d2e6248e72fb91c77d413a815153b8ea4e31f7cd443bdf28e467670"}, - {file = "tokenizers-0.15.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0774bccc6608eca23eb9d620196687c8b2360624619623cf4ba9dc9bd53e8b51"}, - {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:d0222c5b7c9b26c0b4822a82f6a7011de0a9d3060e1da176f66274b70f846b98"}, - {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3835738be1de66624fff2f4f6f6684775da4e9c00bde053be7564cbf3545cc66"}, - {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0143e7d9dcd811855c1ce1ab9bf5d96d29bf5e528fd6c7824d0465741e8c10fd"}, - {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:db35825f6d54215f6b6009a7ff3eedee0848c99a6271c870d2826fbbedf31a38"}, - {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3f5e64b0389a2be47091d8cc53c87859783b837ea1a06edd9d8e04004df55a5c"}, - {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e0480c452217edd35eca56fafe2029fb4d368b7c0475f8dfa3c5c9c400a7456"}, - {file = "tokenizers-0.15.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a33ab881c8fe70474980577e033d0bc9a27b7ab8272896e500708b212995d834"}, - {file = "tokenizers-0.15.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a308a607ca9de2c64c1b9ba79ec9a403969715a1b8ba5f998a676826f1a7039d"}, - {file = "tokenizers-0.15.2-cp312-none-win32.whl", hash = "sha256:b8fcfa81bcb9447df582c5bc96a031e6df4da2a774b8080d4f02c0c16b42be0b"}, - {file = "tokenizers-0.15.2-cp312-none-win_amd64.whl", hash = "sha256:38d7ab43c6825abfc0b661d95f39c7f8af2449364f01d331f3b51c94dcff7221"}, - {file = "tokenizers-0.15.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:38bfb0204ff3246ca4d5e726e8cc8403bfc931090151e6eede54d0e0cf162ef0"}, - {file = "tokenizers-0.15.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:9c861d35e8286a53e06e9e28d030b5a05bcbf5ac9d7229e561e53c352a85b1fc"}, - {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:936bf3842db5b2048eaa53dade907b1160f318e7c90c74bfab86f1e47720bdd6"}, - {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:620beacc3373277700d0e27718aa8b25f7b383eb8001fba94ee00aeea1459d89"}, - {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2735ecbbf37e52db4ea970e539fd2d450d213517b77745114f92867f3fc246eb"}, - {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:473c83c5e2359bb81b0b6fde870b41b2764fcdd36d997485e07e72cc3a62264a"}, - {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:968fa1fb3c27398b28a4eca1cbd1e19355c4d3a6007f7398d48826bbe3a0f728"}, - {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:865c60ae6eaebdde7da66191ee9b7db52e542ed8ee9d2c653b6d190a9351b980"}, - {file = "tokenizers-0.15.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:7c0d8b52664ab2d4a8d6686eb5effc68b78608a9008f086a122a7b2996befbab"}, - {file = "tokenizers-0.15.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:f33dfbdec3784093a9aebb3680d1f91336c56d86cc70ddf88708251da1fe9064"}, - {file = "tokenizers-0.15.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:6a9b648a58281c4672212fab04e60648fde574877d0139cd4b4f93fe28ca8944"}, - {file = "tokenizers-0.15.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:7c7d18b733be6bbca8a55084027f7be428c947ddf871c500ee603e375013ffba"}, - {file = "tokenizers-0.15.2-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:13ca3611de8d9ddfbc4dc39ef54ab1d2d4aaa114ac8727dfdc6a6ec4be017378"}, - {file = "tokenizers-0.15.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:237d1bf3361cf2e6463e6c140628e6406766e8b27274f5fcc62c747ae3c6f094"}, - {file = "tokenizers-0.15.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67a0fe1e49e60c664915e9fb6b0cb19bac082ab1f309188230e4b2920230edb3"}, - {file = "tokenizers-0.15.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:4e022fe65e99230b8fd89ebdfea138c24421f91c1a4f4781a8f5016fd5cdfb4d"}, - {file = "tokenizers-0.15.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:d857be2df69763362ac699f8b251a8cd3fac9d21893de129bc788f8baaef2693"}, - {file = "tokenizers-0.15.2-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:708bb3e4283177236309e698da5fcd0879ce8fd37457d7c266d16b550bcbbd18"}, - {file = "tokenizers-0.15.2-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:64c35e09e9899b72a76e762f9854e8750213f67567787d45f37ce06daf57ca78"}, - {file = "tokenizers-0.15.2-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1257f4394be0d3b00de8c9e840ca5601d0a4a8438361ce9c2b05c7d25f6057b"}, - {file = "tokenizers-0.15.2-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02272fe48280e0293a04245ca5d919b2c94a48b408b55e858feae9618138aeda"}, - {file = "tokenizers-0.15.2-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:dc3ad9ebc76eabe8b1d7c04d38be884b8f9d60c0cdc09b0aa4e3bcf746de0388"}, - {file = "tokenizers-0.15.2-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:32e16bdeffa7c4f46bf2152172ca511808b952701d13e7c18833c0b73cb5c23f"}, - {file = "tokenizers-0.15.2-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:fb16ba563d59003028b678d2361a27f7e4ae0ab29c7a80690efa20d829c81fdb"}, - {file = "tokenizers-0.15.2-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:2277c36d2d6cdb7876c274547921a42425b6810d38354327dd65a8009acf870c"}, - {file = "tokenizers-0.15.2-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:1cf75d32e8d250781940d07f7eece253f2fe9ecdb1dc7ba6e3833fa17b82fcbc"}, - {file = "tokenizers-0.15.2-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1b3b31884dc8e9b21508bb76da80ebf7308fdb947a17affce815665d5c4d028"}, - {file = "tokenizers-0.15.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b10122d8d8e30afb43bb1fe21a3619f62c3e2574bff2699cf8af8b0b6c5dc4a3"}, - {file = "tokenizers-0.15.2-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d88b96ff0fe8e91f6ef01ba50b0d71db5017fa4e3b1d99681cec89a85faf7bf7"}, - {file = "tokenizers-0.15.2-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:37aaec5a52e959892870a7c47cef80c53797c0db9149d458460f4f31e2fb250e"}, - {file = "tokenizers-0.15.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:e2ea752f2b0fe96eb6e2f3adbbf4d72aaa1272079b0dfa1145507bd6a5d537e6"}, - {file = "tokenizers-0.15.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:4b19a808d8799fda23504a5cd31d2f58e6f52f140380082b352f877017d6342b"}, - {file = "tokenizers-0.15.2-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:64c86e5e068ac8b19204419ed8ca90f9d25db20578f5881e337d203b314f4104"}, - {file = "tokenizers-0.15.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de19c4dc503c612847edf833c82e9f73cd79926a384af9d801dcf93f110cea4e"}, - {file = "tokenizers-0.15.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea09acd2fe3324174063d61ad620dec3bcf042b495515f27f638270a7d466e8b"}, - {file = "tokenizers-0.15.2-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:cf27fd43472e07b57cf420eee1e814549203d56de00b5af8659cb99885472f1f"}, - {file = "tokenizers-0.15.2-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:7ca22bd897537a0080521445d91a58886c8c04084a6a19e6c78c586e0cfa92a5"}, - {file = "tokenizers-0.15.2.tar.gz", hash = "sha256:e6e9c6e019dd5484be5beafc775ae6c925f4c69a3487040ed09b45e13df2cb91"}, + {file = "tokenizers-0.19.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:952078130b3d101e05ecfc7fc3640282d74ed26bcf691400f872563fca15ac97"}, + {file = "tokenizers-0.19.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:82c8b8063de6c0468f08e82c4e198763e7b97aabfe573fd4cf7b33930ca4df77"}, + {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f03727225feaf340ceeb7e00604825addef622d551cbd46b7b775ac834c1e1c4"}, + {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:453e4422efdfc9c6b6bf2eae00d5e323f263fff62b29a8c9cd526c5003f3f642"}, + {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:02e81bf089ebf0e7f4df34fa0207519f07e66d8491d963618252f2e0729e0b46"}, + {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b07c538ba956843833fee1190cf769c60dc62e1cf934ed50d77d5502194d63b1"}, + {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e28cab1582e0eec38b1f38c1c1fb2e56bce5dc180acb1724574fc5f47da2a4fe"}, + {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b01afb7193d47439f091cd8f070a1ced347ad0f9144952a30a41836902fe09e"}, + {file = "tokenizers-0.19.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:7fb297edec6c6841ab2e4e8f357209519188e4a59b557ea4fafcf4691d1b4c98"}, + {file = "tokenizers-0.19.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2e8a3dd055e515df7054378dc9d6fa8c8c34e1f32777fb9a01fea81496b3f9d3"}, + {file = "tokenizers-0.19.1-cp310-none-win32.whl", hash = "sha256:7ff898780a155ea053f5d934925f3902be2ed1f4d916461e1a93019cc7250837"}, + {file = "tokenizers-0.19.1-cp310-none-win_amd64.whl", hash = "sha256:bea6f9947e9419c2fda21ae6c32871e3d398cba549b93f4a65a2d369662d9403"}, + {file = "tokenizers-0.19.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:5c88d1481f1882c2e53e6bb06491e474e420d9ac7bdff172610c4f9ad3898059"}, + {file = "tokenizers-0.19.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ddf672ed719b4ed82b51499100f5417d7d9f6fb05a65e232249268f35de5ed14"}, + {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:dadc509cc8a9fe460bd274c0e16ac4184d0958117cf026e0ea8b32b438171594"}, + {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dfedf31824ca4915b511b03441784ff640378191918264268e6923da48104acc"}, + {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac11016d0a04aa6487b1513a3a36e7bee7eec0e5d30057c9c0408067345c48d2"}, + {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:76951121890fea8330d3a0df9a954b3f2a37e3ec20e5b0530e9a0044ca2e11fe"}, + {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b342d2ce8fc8d00f376af068e3274e2e8649562e3bc6ae4a67784ded6b99428d"}, + {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d16ff18907f4909dca9b076b9c2d899114dd6abceeb074eca0c93e2353f943aa"}, + {file = "tokenizers-0.19.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:706a37cc5332f85f26efbe2bdc9ef8a9b372b77e4645331a405073e4b3a8c1c6"}, + {file = "tokenizers-0.19.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:16baac68651701364b0289979ecec728546133e8e8fe38f66fe48ad07996b88b"}, + {file = "tokenizers-0.19.1-cp311-none-win32.whl", hash = "sha256:9ed240c56b4403e22b9584ee37d87b8bfa14865134e3e1c3fb4b2c42fafd3256"}, + {file = "tokenizers-0.19.1-cp311-none-win_amd64.whl", hash = "sha256:ad57d59341710b94a7d9dbea13f5c1e7d76fd8d9bcd944a7a6ab0b0da6e0cc66"}, + {file = "tokenizers-0.19.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:621d670e1b1c281a1c9698ed89451395d318802ff88d1fc1accff0867a06f153"}, + {file = "tokenizers-0.19.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d924204a3dbe50b75630bd16f821ebda6a5f729928df30f582fb5aade90c818a"}, + {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:4f3fefdc0446b1a1e6d81cd4c07088ac015665d2e812f6dbba4a06267d1a2c95"}, + {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9620b78e0b2d52ef07b0d428323fb34e8ea1219c5eac98c2596311f20f1f9266"}, + {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:04ce49e82d100594715ac1b2ce87d1a36e61891a91de774755f743babcd0dd52"}, + {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c5c2ff13d157afe413bf7e25789879dd463e5a4abfb529a2d8f8473d8042e28f"}, + {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3174c76efd9d08f836bfccaca7cfec3f4d1c0a4cf3acbc7236ad577cc423c840"}, + {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7c9d5b6c0e7a1e979bec10ff960fae925e947aab95619a6fdb4c1d8ff3708ce3"}, + {file = "tokenizers-0.19.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a179856d1caee06577220ebcfa332af046d576fb73454b8f4d4b0ba8324423ea"}, + {file = "tokenizers-0.19.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:952b80dac1a6492170f8c2429bd11fcaa14377e097d12a1dbe0ef2fb2241e16c"}, + {file = "tokenizers-0.19.1-cp312-none-win32.whl", hash = "sha256:01d62812454c188306755c94755465505836fd616f75067abcae529c35edeb57"}, + {file = "tokenizers-0.19.1-cp312-none-win_amd64.whl", hash = "sha256:b70bfbe3a82d3e3fb2a5e9b22a39f8d1740c96c68b6ace0086b39074f08ab89a"}, + {file = "tokenizers-0.19.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3b11853f17b54c2fe47742c56d8a33bf49ce31caf531e87ac0d7d13d327c9334"}, + {file = "tokenizers-0.19.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d26194ef6c13302f446d39972aaa36a1dda6450bc8949f5eb4c27f51191375bd"}, + {file = "tokenizers-0.19.1-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:e8d1ed93beda54bbd6131a2cb363a576eac746d5c26ba5b7556bc6f964425594"}, + {file = "tokenizers-0.19.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca407133536f19bdec44b3da117ef0d12e43f6d4b56ac4c765f37eca501c7bda"}, + {file = "tokenizers-0.19.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce05fde79d2bc2e46ac08aacbc142bead21614d937aac950be88dc79f9db9022"}, + {file = "tokenizers-0.19.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:35583cd46d16f07c054efd18b5d46af4a2f070a2dd0a47914e66f3ff5efb2b1e"}, + {file = "tokenizers-0.19.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:43350270bfc16b06ad3f6f07eab21f089adb835544417afda0f83256a8bf8b75"}, + {file = "tokenizers-0.19.1-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b4399b59d1af5645bcee2072a463318114c39b8547437a7c2d6a186a1b5a0e2d"}, + {file = "tokenizers-0.19.1-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6852c5b2a853b8b0ddc5993cd4f33bfffdca4fcc5d52f89dd4b8eada99379285"}, + {file = "tokenizers-0.19.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bcd266ae85c3d39df2f7e7d0e07f6c41a55e9a3123bb11f854412952deacd828"}, + {file = "tokenizers-0.19.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ecb2651956eea2aa0a2d099434134b1b68f1c31f9a5084d6d53f08ed43d45ff2"}, + {file = "tokenizers-0.19.1-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:b279ab506ec4445166ac476fb4d3cc383accde1ea152998509a94d82547c8e2a"}, + {file = "tokenizers-0.19.1-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:89183e55fb86e61d848ff83753f64cded119f5d6e1f553d14ffee3700d0a4a49"}, + {file = "tokenizers-0.19.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b2edbc75744235eea94d595a8b70fe279dd42f3296f76d5a86dde1d46e35f574"}, + {file = "tokenizers-0.19.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:0e64bfde9a723274e9a71630c3e9494ed7b4c0f76a1faacf7fe294cd26f7ae7c"}, + {file = "tokenizers-0.19.1-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:0b5ca92bfa717759c052e345770792d02d1f43b06f9e790ca0a1db62838816f3"}, + {file = "tokenizers-0.19.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f8a20266e695ec9d7a946a019c1d5ca4eddb6613d4f466888eee04f16eedb85"}, + {file = "tokenizers-0.19.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63c38f45d8f2a2ec0f3a20073cccb335b9f99f73b3c69483cd52ebc75369d8a1"}, + {file = "tokenizers-0.19.1-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:dd26e3afe8a7b61422df3176e06664503d3f5973b94f45d5c45987e1cb711876"}, + {file = "tokenizers-0.19.1-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:eddd5783a4a6309ce23432353cdb36220e25cbb779bfa9122320666508b44b88"}, + {file = "tokenizers-0.19.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:56ae39d4036b753994476a1b935584071093b55c7a72e3b8288e68c313ca26e7"}, + {file = "tokenizers-0.19.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:f9939ca7e58c2758c01b40324a59c034ce0cebad18e0d4563a9b1beab3018243"}, + {file = "tokenizers-0.19.1-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6c330c0eb815d212893c67a032e9dc1b38a803eccb32f3e8172c19cc69fbb439"}, + {file = "tokenizers-0.19.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec11802450a2487cdf0e634b750a04cbdc1c4d066b97d94ce7dd2cb51ebb325b"}, + {file = "tokenizers-0.19.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2b718f316b596f36e1dae097a7d5b91fc5b85e90bf08b01ff139bd8953b25af"}, + {file = "tokenizers-0.19.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:ed69af290c2b65169f0ba9034d1dc39a5db9459b32f1dd8b5f3f32a3fcf06eab"}, + {file = "tokenizers-0.19.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f8a9c828277133af13f3859d1b6bf1c3cb6e9e1637df0e45312e6b7c2e622b1f"}, + {file = "tokenizers-0.19.1.tar.gz", hash = "sha256:ee59e6680ed0fdbe6b724cf38bd70400a0c1dd623b07ac729087270caeac88e3"}, ] [[package]] @@ -2288,7 +2338,7 @@ files = [ [[package]] name = "torch" -version = "2.2.1" +version = "2.3.0" requires_python = ">=3.8.0" summary = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" groups = ["all", "default", "vision"] @@ -2296,6 +2346,7 @@ dependencies = [ "filelock", "fsspec", "jinja2", + "mkl<=2021.4.0,>=2021.1.1; platform_system == \"Windows\"", "networkx", "nvidia-cublas-cu12==12.1.3.1; platform_system == \"Linux\" and platform_machine == \"x86_64\"", "nvidia-cuda-cupti-cu12==12.1.105; platform_system == \"Linux\" and platform_machine == \"x86_64\"", @@ -2306,28 +2357,25 @@ dependencies = [ "nvidia-curand-cu12==10.3.2.106; platform_system == \"Linux\" and platform_machine == \"x86_64\"", "nvidia-cusolver-cu12==11.4.5.107; platform_system == \"Linux\" and platform_machine == \"x86_64\"", "nvidia-cusparse-cu12==12.1.0.106; platform_system == \"Linux\" and platform_machine == \"x86_64\"", - "nvidia-nccl-cu12==2.19.3; platform_system == \"Linux\" and platform_machine == \"x86_64\"", + "nvidia-nccl-cu12==2.20.5; platform_system == \"Linux\" and platform_machine == \"x86_64\"", "nvidia-nvtx-cu12==12.1.105; platform_system == \"Linux\" and platform_machine == \"x86_64\"", "sympy", - "triton==2.2.0; platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\"", + "triton==2.3.0; platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\"", "typing-extensions>=4.8.0", ] files = [ - {file = "torch-2.2.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:8d3bad336dd2c93c6bcb3268e8e9876185bda50ebde325ef211fb565c7d15273"}, - {file = "torch-2.2.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:5297f13370fdaca05959134b26a06a7f232ae254bf2e11a50eddec62525c9006"}, - {file = "torch-2.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:5f5dee8433798888ca1415055f5e3faf28a3bad660e4c29e1014acd3275ab11a"}, - {file = "torch-2.2.1-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:b6d78338acabf1fb2e88bf4559d837d30230cf9c3e4337261f4d83200df1fcbe"}, - {file = "torch-2.2.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:6ab3ea2e29d1aac962e905142bbe50943758f55292f1b4fdfb6f4792aae3323e"}, - {file = "torch-2.2.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:d86664ec85902967d902e78272e97d1aff1d331f7619d398d3ffab1c9b8e9157"}, - {file = "torch-2.2.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:d6227060f268894f92c61af0a44c0d8212e19cb98d05c20141c73312d923bc0a"}, - {file = "torch-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:77e990af75fb1675490deb374d36e726f84732cd5677d16f19124934b2409ce9"}, - {file = "torch-2.2.1-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:46085e328d9b738c261f470231e987930f4cc9472d9ffb7087c7a1343826ac51"}, - {file = "torch-2.2.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:2d9e7e5ecbb002257cf98fae13003abbd620196c35f85c9e34c2adfb961321ec"}, - {file = "torch-2.2.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:ada53aebede1c89570e56861b08d12ba4518a1f8b82d467c32665ec4d1f4b3c8"}, - {file = "torch-2.2.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:be21d4c41ecebed9e99430dac87de1439a8c7882faf23bba7fea3fea7b906ac1"}, - {file = "torch-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:79848f46196750367dcdf1d2132b722180b9d889571e14d579ae82d2f50596c5"}, - {file = "torch-2.2.1-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:7ee804847be6be0032fbd2d1e6742fea2814c92bebccb177f0d3b8e92b2d2b18"}, - {file = "torch-2.2.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:84b2fb322ab091039fdfe74e17442ff046b258eb5e513a28093152c5b07325a7"}, + {file = "torch-2.3.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:d8ea5a465dbfd8501f33c937d1f693176c9aef9d1c1b0ca1d44ed7b0a18c52ac"}, + {file = "torch-2.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:09c81c5859a5b819956c6925a405ef1cdda393c9d8a01ce3851453f699d3358c"}, + {file = "torch-2.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:1bf023aa20902586f614f7682fedfa463e773e26c58820b74158a72470259459"}, + {file = "torch-2.3.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:758ef938de87a2653bba74b91f703458c15569f1562bf4b6c63c62d9c5a0c1f5"}, + {file = "torch-2.3.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:493d54ee2f9df100b5ce1d18c96dbb8d14908721f76351e908c9d2622773a788"}, + {file = "torch-2.3.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:bce43af735c3da16cc14c7de2be7ad038e2fbf75654c2e274e575c6c05772ace"}, + {file = "torch-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:729804e97b7cf19ae9ab4181f91f5e612af07956f35c8b2c8e9d9f3596a8e877"}, + {file = "torch-2.3.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:d24e328226d8e2af7cf80fcb1d2f1d108e0de32777fab4aaa2b37b9765d8be73"}, + {file = "torch-2.3.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:b0de2bdc0486ea7b14fc47ff805172df44e421a7318b7c4d92ef589a75d27410"}, + {file = "torch-2.3.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:a306c87a3eead1ed47457822c01dfbd459fe2920f2d38cbdf90de18f23f72542"}, + {file = "torch-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:f9b98bf1a3c8af2d4c41f0bf1433920900896c446d1ddc128290ff146d1eb4bd"}, + {file = "torch-2.3.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:dca986214267b34065a79000cee54232e62b41dff1ec2cab9abc3fc8b3dee0ad"}, ] [[package]] @@ -2349,31 +2397,28 @@ files = [ [[package]] name = "torchvision" -version = "0.17.1" +version = "0.18.0" requires_python = ">=3.8" summary = "image and video datasets and models for torch deep learning" groups = ["all", "vision"] dependencies = [ "numpy", "pillow!=8.3.*,>=5.3.0", - "torch==2.2.1", + "torch==2.3.0", ] files = [ - {file = "torchvision-0.17.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:06418880212b66e45e855dd39f536e7fd48b4e6b034a11dd9fe9e2384afb51ec"}, - {file = "torchvision-0.17.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:33d65d0c7fdcb3f7bc1dd8ed30ea3cd7e0587b4ad1b104b5677c8191a8bad9f1"}, - {file = "torchvision-0.17.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:aaefef2be6a02f206085ce4bb6c0078b03ebf48cb6ff82bd762ff6248475e08e"}, - {file = "torchvision-0.17.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:ebe5fdb466aff8a8e8e755de84a843418b6f8d500624752c05eaa638d7700f3d"}, - {file = "torchvision-0.17.1-cp310-cp310-win_amd64.whl", hash = "sha256:9d4d45a996f4313e9c5db4da71d31508d44f7ccfbf29d3442bdcc2ad13e0b6f3"}, - {file = "torchvision-0.17.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:ea2ccdbf5974e0bf27fd6644a33b19cb0700297cf397bb0469e762c11c6c4105"}, - {file = "torchvision-0.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9106e32c9f1e70afa8172cf1b064cf9c2998d8dff0769ec69d537b20209ee43d"}, - {file = "torchvision-0.17.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:5966936c669a08870f6547cd0a90d08b157aeda03293f79e2adbb934687175ed"}, - {file = "torchvision-0.17.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:e74f5a26ef8190eab0c38b3f63914fea94e58e3b2f0e5466611c9f63bd91a80b"}, - {file = "torchvision-0.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:a2109c1a1dcf71e8940d43e91f78c4dd5bf0fcefb3a0a42244102752009f5862"}, - {file = "torchvision-0.17.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5d241d2a5fb4e608677fccf6f80b34a124446d324ee40c7814ce54bce888275b"}, - {file = "torchvision-0.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e0fe98d9d92c23d2262ff82f973242951b9357fb640f8888ac50848bd00f5b45"}, - {file = "torchvision-0.17.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:32dc5de86d2ade399e11087095674ca08a1649fb322cfe69336d28add467edcb"}, - {file = "torchvision-0.17.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:54902877410ffb5458ee52b6d0de4b25cf01496bee736d6825301a5f0398536e"}, - {file = "torchvision-0.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:cc22c1ed0f1aba3f98fd72b6f60021f57aec1d2f6af518522e8a0a83848de3a8"}, + {file = "torchvision-0.18.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:dd61628a3d189c6852a12dc5ed4cd2eece66d2d67f35a866cb16f1dcb06c8c62"}, + {file = "torchvision-0.18.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:493c45f9937dad37aa1b64b14da17c7a589c72b91adc4837d431009cfe29bd53"}, + {file = "torchvision-0.18.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:5337f6acfa1fe959d5cb340d01a00614d6b31ce7a4824ccb95435a85c5273b95"}, + {file = "torchvision-0.18.0-cp310-cp310-win_amd64.whl", hash = "sha256:bd8e6f3b5beb49965f15c461302488edfa3d8c2d01d3bb79b150d6fb62711e3a"}, + {file = "torchvision-0.18.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6896a52168befe1105fb3c9335287390ed227e71d1e4ec4d68b62e8a3099fc09"}, + {file = "torchvision-0.18.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:3d7955398d4ceaad77c487c2c44f6f7813112402c9bab8cd906d346005891048"}, + {file = "torchvision-0.18.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:e5a24d620cea14a4bb89f24aa2b506230c0a16a3ada57fc53ad80cfd256a2128"}, + {file = "torchvision-0.18.0-cp311-cp311-win_amd64.whl", hash = "sha256:6ad70ddfa879bda5ed886b2518fe562640e0059787cbd65cb2bffa7674541410"}, + {file = "torchvision-0.18.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:eb9d83c0e1dbb54ecb0fb04c87f786333e3a6fb8b9c400aca7c31081f9aa5707"}, + {file = "torchvision-0.18.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:b657d052d146f24cb3b2a78219bfc82ae70a9706671c50f632528907d10cccec"}, + {file = "torchvision-0.18.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:a964afbc7ddf50a46b941477f6c35729b416deedd139756befd488245e2e226d"}, + {file = "torchvision-0.18.0-cp312-cp312-win_amd64.whl", hash = "sha256:7c770f0f748e0b17f57c0297508d7254f686cdf03fc2e2949f422b20574f4c0f"}, ] [[package]] @@ -2392,7 +2437,7 @@ files = [ [[package]] name = "transformers" -version = "4.39.0" +version = "4.40.1" requires_python = ">=3.8.0" summary = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" groups = ["default"] @@ -2405,17 +2450,17 @@ dependencies = [ "regex!=2019.12.17", "requests", "safetensors>=0.4.1", - "tokenizers<0.19,>=0.14", + "tokenizers<0.20,>=0.19", "tqdm>=4.27", ] files = [ - {file = "transformers-4.39.0-py3-none-any.whl", hash = "sha256:7801785b1f016d667467e8c372c1c3653c18fe32ba97952059e3bea79ba22b08"}, - {file = "transformers-4.39.0.tar.gz", hash = "sha256:517a13cd633b10bea01c92ab0b3059762872c7c29da3d223db9d28e926fe330d"}, + {file = "transformers-4.40.1-py3-none-any.whl", hash = "sha256:9d5ee0c8142a60501faf9e49a0b42f8e9cb8611823bce4f195a9325a6816337e"}, + {file = "transformers-4.40.1.tar.gz", hash = "sha256:55e1697e6f18b58273e7117bb469cdffc11be28995462d8d5e422fef38d2de36"}, ] [[package]] name = "triton" -version = "2.2.0" +version = "2.3.0" summary = "A language and compiler for custom Deep Learning operations" groups = ["all", "default", "vision"] marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\"" @@ -2423,9 +2468,9 @@ dependencies = [ "filelock", ] files = [ - {file = "triton-2.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2294514340cfe4e8f4f9e5c66c702744c4a117d25e618bd08469d0bfed1e2e5"}, - {file = "triton-2.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da58a152bddb62cafa9a857dd2bc1f886dbf9f9c90a2b5da82157cd2b34392b0"}, - {file = "triton-2.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0af58716e721460a61886668b205963dc4d1e4ac20508cc3f623aef0d70283d5"}, + {file = "triton-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ce4b8ff70c48e47274c66f269cce8861cf1dc347ceeb7a67414ca151b1822d8"}, + {file = "triton-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c3d9607f85103afdb279938fc1dd2a66e4f5999a58eb48a346bd42738f986dd"}, + {file = "triton-2.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:218d742e67480d9581bafb73ed598416cc8a56f6316152e5562ee65e33de01c0"}, ] [[package]] diff --git a/pyproject.toml b/pyproject.toml index 59241a78..625f98a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,14 +59,14 @@ vision = [ "h5py>=3.10.0", "nibabel>=5.2.0", "opencv-python-headless>=4.9.0.80", - "timm>=0.9.12", + "timm @ git+https://github.com/huggingface/pytorch-image-models.git@main", "torchvision>=0.17.0", ] all = [ "h5py>=3.10.0", "nibabel>=5.2.0", "opencv-python-headless>=4.9.0.80", - "timm>=0.9.12", + "timm @ git+https://github.com/huggingface/pytorch-image-models.git@main", "torchvision>=0.17.0", ] diff --git a/src/eva/core/models/networks/__init__.py b/src/eva/core/models/networks/__init__.py index c1ce5bb3..34e3d7a9 100644 --- a/src/eva/core/models/networks/__init__.py +++ b/src/eva/core/models/networks/__init__.py @@ -3,4 +3,4 @@ from eva.core.models.networks.mlp import MLP from eva.core.models.networks.wrappers import HuggingFaceModel, ModelFromFunction, ONNXModel -__all__ = ["ModelFromFunction", "HuggingFaceModel", "ONNXModel", "MLP"] +__all__ = ["MLP", "ModelFromFunction", "HuggingFaceModel", "ONNXModel"] diff --git a/src/eva/core/models/networks/wrappers/__init__.py b/src/eva/core/models/networks/wrappers/__init__.py index 42513046..766dd8f4 100644 --- a/src/eva/core/models/networks/wrappers/__init__.py +++ b/src/eva/core/models/networks/wrappers/__init__.py @@ -5,4 +5,9 @@ from eva.core.models.networks.wrappers.huggingface import HuggingFaceModel from eva.core.models.networks.wrappers.onnx import ONNXModel -__all__ = ["BaseModel", "ModelFromFunction", "HuggingFaceModel", "ONNXModel"] +__all__ = [ + "BaseModel", + "ModelFromFunction", + "HuggingFaceModel", + "ONNXModel", +] diff --git a/src/eva/vision/models/networks/__init__.py b/src/eva/vision/models/networks/__init__.py index 554e65c3..4d669351 100644 --- a/src/eva/vision/models/networks/__init__.py +++ b/src/eva/vision/models/networks/__init__.py @@ -2,5 +2,6 @@ from eva.vision.models.networks import postprocesses from eva.vision.models.networks.abmil import ABMIL +from eva.vision.models.networks.from_timm import TimmModel -__all__ = ["postprocesses", "ABMIL"] +__all__ = ["postprocesses", "ABMIL", "TimmModel"] diff --git a/src/eva/vision/models/networks/encoders/__init__.py b/src/eva/vision/models/networks/encoders/__init__.py new file mode 100644 index 00000000..88ae1345 --- /dev/null +++ b/src/eva/vision/models/networks/encoders/__init__.py @@ -0,0 +1,6 @@ +"""Encoder networks API.""" + +from eva.vision.models.networks.encoders.encoder import Encoder +from eva.vision.models.networks.encoders.from_timm import TimmEncoder + +__all__ = ["Encoder", "TimmEncoder"] diff --git a/src/eva/vision/models/networks/encoders/encoder.py b/src/eva/vision/models/networks/encoders/encoder.py new file mode 100644 index 00000000..b2c53aa3 --- /dev/null +++ b/src/eva/vision/models/networks/encoders/encoder.py @@ -0,0 +1,23 @@ +"""Encoder base class.""" + +import abc +from typing import List + +import torch +from torch import nn + + +class Encoder(nn.Module, abc.ABC): + """Encoder base class.""" + + @abc.abstractmethod + def forward(self, tensor: torch.Tensor) -> List[torch.Tensor]: + """Returns the multi-level feature maps of the model. + + Args: + tensor: The image tensor (batch_size, num_channels, height, width). + + Returns: + The list of multi-level image features of shape (batch_size, + hidden_size, num_patches_height, num_patches_width). + """ diff --git a/src/eva/vision/models/networks/encoders/from_timm.py b/src/eva/vision/models/networks/encoders/from_timm.py new file mode 100644 index 00000000..8ac1f66a --- /dev/null +++ b/src/eva/vision/models/networks/encoders/from_timm.py @@ -0,0 +1,61 @@ +"""Encoder wrapper for timm models.""" + +from typing import Any, Dict, List, Tuple + +import timm +import torch +from torch import nn +from typing_extensions import override + +from eva.vision.models.networks.encoders import encoder + + +class TimmEncoder(encoder.Encoder): + """Encoder wrapper for `timm` models. + + Note that only models with `forward_intermediates` + method are currently only supported. + """ + + def __init__( + self, + model_name: str, + pretrained: bool = False, + checkpoint_path: str = "", + out_indices: int | Tuple[int, ...] | None = 1, + model_arguments: Dict[str, Any] | None = None, + ) -> None: + """Initializes the encoder. + + Args: + model_name: Name of model to instantiate. + pretrained: If set to `True`, load pretrained ImageNet-1k weights. + checkpoint_path: Path of checkpoint to load. + out_indices: Returns last n blocks if `int`, all if `None`, select + matching indices if sequence. + model_arguments: Extra model arguments. + """ + super().__init__() + + self._model_name = model_name + self._pretrained = pretrained + self._checkpoint_path = checkpoint_path + self._out_indices = out_indices + self._model_arguments = model_arguments or {} + + self._feature_extractor = self._load_model() + + def _load_model(self) -> nn.Module: + """Builds, loads and returns the timm model as feature extractor.""" + return timm.create_model( + model_name=self._model_name, + pretrained=self._pretrained, + checkpoint_path=self._checkpoint_path, + out_indices=self._out_indices, + features_only=True, + **self._model_arguments, + ) + + @override + def forward(self, tensor: torch.Tensor) -> List[torch.Tensor]: + return self._feature_extractor(tensor) diff --git a/src/eva/vision/models/networks/from_timm.py b/src/eva/vision/models/networks/from_timm.py new file mode 100644 index 00000000..22053591 --- /dev/null +++ b/src/eva/vision/models/networks/from_timm.py @@ -0,0 +1,36 @@ +"""Helper wrapper class for timm models.""" + +from typing import Any, Dict + +import timm + +from eva.core.models.networks import wrappers + + +class TimmModel(wrappers.ModelFromFunction): + """Wrapper class for timm models.""" + + def __init__( + self, + model_name: str, + pretrained: bool = False, + checkpoint_path: str = "", + model_arguments: Dict[str, Any] | None = None, + ) -> None: + """Initializes and constructs the model. + + Args: + model_name: Name of model to instantiate. + pretrained: If set to `True`, load pretrained ImageNet-1k weights. + checkpoint_path: Path of checkpoint to load. + model_arguments: The extra callable function / class arguments. + """ + super().__init__( + path=timm.create_model, + arguments={ + "model_name": model_name, + "pretrained": pretrained, + "checkpoint_path": checkpoint_path, + } + | (model_arguments or {}), + ) diff --git a/tests/eva/vision/models/networks/encoders/__init__.py b/tests/eva/vision/models/networks/encoders/__init__.py new file mode 100644 index 00000000..5760a3d4 --- /dev/null +++ b/tests/eva/vision/models/networks/encoders/__init__.py @@ -0,0 +1 @@ +"""Test for vision encoders.""" diff --git a/tests/eva/vision/models/networks/encoders/test_from_timm.py b/tests/eva/vision/models/networks/encoders/test_from_timm.py new file mode 100644 index 00000000..0c356810 --- /dev/null +++ b/tests/eva/vision/models/networks/encoders/test_from_timm.py @@ -0,0 +1,66 @@ +"""TimmEncoder tests.""" + +from typing import Any, Dict, Tuple + +import pytest +import torch + +from eva.vision.models.networks import encoders + + +@pytest.mark.parametrize( + "model_name, out_indices, model_arguments, input_tensor, expected_len, expected_shape", + [ + ( + "vit_small_patch16_224", + 1, + None, + torch.Tensor(2, 3, 224, 224), + 1, + torch.Size([2, 384, 14, 14]), + ), + ( + "vit_small_patch16_224", + 3, + None, + torch.Tensor(2, 3, 224, 224), + 3, + torch.Size([2, 384, 14, 14]), + ), + ( + "vit_small_patch16_224", + 3, + {"dynamic_img_size": True}, + torch.Tensor(2, 3, 512, 512), + 3, + torch.Size([2, 384, 32, 32]), + ), + ], +) +def test_timm_encoder( + timm_encoder: encoders.TimmEncoder, + input_tensor: torch.Tensor, + expected_len: int, + expected_shape: torch.Size, +) -> None: + """Tests the TimmEncoder network.""" + outputs = timm_encoder(input_tensor) + assert isinstance(outputs, list) + assert len(outputs) == expected_len + # individual + assert isinstance(outputs[0], torch.Tensor) + assert outputs[0].shape == expected_shape + + +@pytest.fixture(scope="function") +def timm_encoder( + model_name: str, + out_indices: int | Tuple[int, ...] | None, + model_arguments: Dict[str, Any] | None, +) -> encoders.TimmEncoder: + """TimmEncoder fixture.""" + return encoders.TimmEncoder( + model_name=model_name, + out_indices=out_indices, + model_arguments=model_arguments, + ) diff --git a/tests/eva/vision/models/networks/test_from_timm.py b/tests/eva/vision/models/networks/test_from_timm.py new file mode 100644 index 00000000..384a4f27 --- /dev/null +++ b/tests/eva/vision/models/networks/test_from_timm.py @@ -0,0 +1,48 @@ +"""TimmModel tests.""" + +from typing import Any, Dict + +import pytest +import torch + +from eva.vision.models import networks + + +@pytest.mark.parametrize( + "model_name, model_arguments, input_tensor, expected_shape", + [ + ( + "vit_small_patch16_224", + {"num_classes": 0}, + torch.Tensor(2, 3, 224, 224), + torch.Size([2, 384]), + ), + ( + "vit_small_patch16_224", + {"num_classes": 0, "dynamic_img_size": True}, + torch.Tensor(2, 3, 512, 512), + torch.Size([2, 384]), + ), + ], +) +def test_timm_model( + timm_model: networks.TimmModel, + input_tensor: torch.Tensor, + expected_shape: torch.Size, +) -> None: + """Tests the timm_model network.""" + outputs = timm_model(input_tensor) + assert isinstance(outputs, torch.Tensor) + assert outputs.shape == expected_shape + + +@pytest.fixture(scope="function") +def timm_model( + model_name: str, + model_arguments: Dict[str, Any] | None, +) -> networks.TimmModel: + """TimmModel fixture.""" + return networks.TimmModel( + model_name=model_name, + model_arguments=model_arguments, + ) From ad392e88ee413b0455d90d8a10b562bbcabd2bd2 Mon Sep 17 00:00:00 2001 From: ioangatop Date: Mon, 6 May 2024 13:05:52 +0200 Subject: [PATCH 03/21] Add SemanticSegmentation module (#410) --- src/eva/core/models/modules/typings.py | 13 ++ src/eva/vision/models/modules/__init__.py | 5 + .../models/modules/semantic_segmentation.py | 152 ++++++++++++++++++ tests/eva/core/models/modules/test_head.py | 5 +- .../eva/core/models/modules/test_inference.py | 5 +- tests/eva/vision/models/modules/__init__.py | 1 + tests/eva/vision/models/modules/conftest.py | 83 ++++++++++ .../modules/test_semantic_segmentation.py | 59 +++++++ 8 files changed, 321 insertions(+), 2 deletions(-) create mode 100644 src/eva/vision/models/modules/__init__.py create mode 100644 src/eva/vision/models/modules/semantic_segmentation.py create mode 100644 tests/eva/vision/models/modules/__init__.py create mode 100644 tests/eva/vision/models/modules/conftest.py create mode 100644 tests/eva/vision/models/modules/test_semantic_segmentation.py diff --git a/src/eva/core/models/modules/typings.py b/src/eva/core/models/modules/typings.py index fa476bd1..85a9bcd4 100644 --- a/src/eva/core/models/modules/typings.py +++ b/src/eva/core/models/modules/typings.py @@ -21,3 +21,16 @@ class INPUT_BATCH(NamedTuple): metadata: Dict[str, Any] | None = None """The associated metadata.""" + + +class INPUT_TENSOR_BATCH(NamedTuple): + """Tensor based input batch data scheme.""" + + data: torch.Tensor + """The data batch.""" + + targets: torch.Tensor + """The target batch.""" + + metadata: Dict[str, Any] | None = None + """The associated metadata.""" diff --git a/src/eva/vision/models/modules/__init__.py b/src/eva/vision/models/modules/__init__.py new file mode 100644 index 00000000..ac3342f7 --- /dev/null +++ b/src/eva/vision/models/modules/__init__.py @@ -0,0 +1,5 @@ +"""Vision modules API.""" + +from eva.vision.models.modules.semantic_segmentation import SemanticSegmentationModule + +__all__ = ["SemanticSegmentationModule"] diff --git a/src/eva/vision/models/modules/semantic_segmentation.py b/src/eva/vision/models/modules/semantic_segmentation.py new file mode 100644 index 00000000..f058bc52 --- /dev/null +++ b/src/eva/vision/models/modules/semantic_segmentation.py @@ -0,0 +1,152 @@ +""""Neural Network Semantic Segmentation Module.""" + +from typing import Any, Callable, Iterable, Tuple + +import torch +from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch import optim +from torch.optim import lr_scheduler +from typing_extensions import override + +from eva.core.metrics import structs as metrics_lib +from eva.core.models.modules import module +from eva.core.models.modules.typings import INPUT_BATCH, INPUT_TENSOR_BATCH +from eva.core.models.modules.utils import batch_postprocess, grad +from eva.vision.models.networks import decoders, encoders + + +class SemanticSegmentationModule(module.ModelModule): + """Neural network semantic segmentation module for training on patch embeddings.""" + + def __init__( + self, + decoder: decoders.Decoder, + criterion: Callable[..., torch.Tensor], + encoder: encoders.Encoder | None = None, + lr_multiplier_encoder: float = 0.0, + optimizer: OptimizerCallable = optim.AdamW, + lr_scheduler: LRSchedulerCallable = lr_scheduler.ConstantLR, + metrics: metrics_lib.MetricsSchema | None = None, + postprocess: batch_postprocess.BatchPostProcess | None = None, + ) -> None: + """Initializes the neural net head module. + + Args: + decoder: The decoder model. + criterion: The loss function to use. + encoder: The encoder model. If `None`, it will be expected + that the input batch returns the features directly. + lr_multiplier_encoder: The learning rate multiplier for the + encoder parameters. If `0`, it will freeze the encoder. + optimizer: The optimizer to use. + lr_scheduler: The learning rate scheduler to use. + metrics: The metric groups to track. + postprocess: A list of helper functions to apply after the + loss and before the metrics calculation to the model + predictions and targets. + """ + super().__init__(metrics=metrics, postprocess=postprocess) + + self.decoder = decoder + self.criterion = criterion + self.encoder = encoder + self.lr_multiplier_encoder = lr_multiplier_encoder + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + @override + def configure_model(self) -> None: + self._freeze_encoder() + + @override + def configure_optimizers(self) -> Any: + optimizer = self.optimizer( + [ + {"params": self.decoder.parameters()}, + { + "params": self._encoder_trainable_parameters(), + "lr": self._base_lr * self.lr_multiplier_encoder, + }, + ] + ) + lr_scheduler = self.lr_scheduler(optimizer) + return {"optimizer": optimizer, "lr_scheduler": lr_scheduler} + + @override + def forward( + self, + tensor: torch.Tensor, + image_size: Tuple[int, int] | None = None, + *args: Any, + **kwargs: Any, + ) -> torch.Tensor: + """Maps the input tensor (image tensor or embeddings) to masks. + + If `tensor` is image tensor, then the `self.encoder` + should be implemented, otherwise it will be interpreted + as embeddings, where the `image_size` should be given. + """ + if self.encoder is None and image_size is None: + raise ValueError( + "Please provide the expected `image_size` that the " + "decoder should map the embeddings (`tensor`) to." + ) + + patch_embeddings = self.encoder(tensor) if self.encoder else tensor + return self.decoder(patch_embeddings, image_size or tensor.shape[-2:]) + + @override + def training_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + return self._batch_step(batch) + + @override + def validation_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + return self._batch_step(batch) + + @override + def test_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + return self._batch_step(batch) + + @override + def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.Tensor: + tensor = INPUT_BATCH(*batch).data + return tensor if self.backbone is None else self.backbone(tensor) + + @property + def _base_lr(self) -> float: + """Returns the base learning rate.""" + base_optimizer = self.optimizer(self.parameters()) + return base_optimizer.param_groups[-1]["lr"] + + def _encoder_trainable_parameters(self) -> Iterable[torch.Tensor]: + """Returns the trainable parameters of the encoder.""" + return ( + self.encoder.parameters() + if self.encoder is not None and self.lr_multiplier_encoder > 0 + else iter(()) + ) + + def _freeze_encoder(self) -> None: + """If initialized, Freezes the encoder network.""" + if self.encoder is not None and self.lr_multiplier_encoder == 0: + grad.deactivate_requires_grad(self.encoder) + + def _batch_step(self, batch: INPUT_TENSOR_BATCH) -> STEP_OUTPUT: + """Performs a model forward step and calculates the loss. + + Args: + batch: The desired batch to process. + + Returns: + The batch step output. + """ + data, targets, metadata = INPUT_TENSOR_BATCH(*batch) + predictions = self(data, targets.shape[-2:]) + loss = self.criterion(predictions, targets) + return { + "loss": loss, + "targets": targets, + "predictions": predictions, + "metadata": metadata, + } diff --git a/tests/eva/core/models/modules/test_head.py b/tests/eva/core/models/modules/test_head.py index b3b0bc35..00e4b7a2 100644 --- a/tests/eva/core/models/modules/test_head.py +++ b/tests/eva/core/models/modules/test_head.py @@ -35,7 +35,10 @@ def test_head_module_fit( @pytest.fixture(scope="function") -def model(input_shape: Tuple[int, ...] = (3, 8, 8), n_classes: int = 4) -> modules.HeadModule: +def model( + input_shape: Tuple[int, ...] = (3, 8, 8), + n_classes: int = 4, +) -> modules.HeadModule: """Returns a HeadModule model fixture.""" return modules.HeadModule( head=nn.Linear(math.prod(input_shape), n_classes), diff --git a/tests/eva/core/models/modules/test_inference.py b/tests/eva/core/models/modules/test_inference.py index b17f6202..247abf46 100644 --- a/tests/eva/core/models/modules/test_inference.py +++ b/tests/eva/core/models/modules/test_inference.py @@ -46,5 +46,8 @@ def model( ) -> modules.InferenceModule: """Returns a HeadModule model fixture.""" return modules.InferenceModule( - backbone=nn.Sequential(nn.Flatten(), nn.Linear(math.prod(input_shape), n_classes)) + backbone=nn.Sequential( + nn.Flatten(), + nn.Linear(math.prod(input_shape), n_classes), + ) ) diff --git a/tests/eva/vision/models/modules/__init__.py b/tests/eva/vision/models/modules/__init__.py new file mode 100644 index 00000000..e22daaa8 --- /dev/null +++ b/tests/eva/vision/models/modules/__init__.py @@ -0,0 +1 @@ +"""Tests for the vision model modules.""" diff --git a/tests/eva/vision/models/modules/conftest.py b/tests/eva/vision/models/modules/conftest.py new file mode 100644 index 00000000..ef18e4e6 --- /dev/null +++ b/tests/eva/vision/models/modules/conftest.py @@ -0,0 +1,83 @@ +"""Shared configuration and fixtures for models/modules unit tests.""" + +from typing import Tuple + +import pytest +import torch +from torch.utils import data as torch_data + +from eva.core.data import dataloaders, datamodules, datasets +from eva.core.trainers import trainer as eva_trainer + + +@pytest.fixture(scope="function") +def datamodule( + request: pytest.FixtureRequest, + dataset_fixture: str, + dataloader: dataloaders.DataLoader, +) -> datamodules.DataModule: + """Returns a dummy datamodule fixture.""" + dataset = request.getfixturevalue(dataset_fixture) + return datamodules.DataModule( + datasets=datamodules.DatasetsSchema( + train=dataset, + val=dataset, + predict=dataset, + ), + dataloaders=datamodules.DataloadersSchema( + train=dataloader, + val=dataloader, + predict=dataloader, + ), + ) + + +@pytest.fixture(scope="function") +def trainer(max_epochs: int = 1) -> eva_trainer.Trainer: + """Returns a model trainer fixture.""" + return eva_trainer.Trainer( + max_epochs=max_epochs, + accelerator="cpu", + default_root_dir="logs/test", + ) + + +@pytest.fixture(scope="function") +def segmentation_dataset( + n_samples: int = 4, + input_shape: Tuple[int, ...] = (3, 16, 16), + target_shape: Tuple[int, ...] = (16, 16), + n_classes: int = 4, +) -> datasets.TorchDataset: + """Dummy segmentation dataset fixture.""" + return torch_data.TensorDataset( + torch.randn((n_samples,) + input_shape), + torch.randint(n_classes, (n_samples,) + target_shape, dtype=torch.long), + ) + + +@pytest.fixture(scope="function") +def segmentation_dataset_with_metadata( + n_samples: int = 4, + input_shape: Tuple[int, ...] = (3, 16, 16), + target_shape: Tuple[int, ...] = (16, 16), + n_classes: int = 4, +) -> datasets.TorchDataset: + """Dummy segmentation dataset fixture with metadata.""" + return torch_data.TensorDataset( + torch.randn((n_samples,) + input_shape), + torch.randint(n_classes, (n_samples,) + target_shape, dtype=torch.long), + torch.randint(2, (n_samples,) + target_shape, dtype=torch.long), + ) + + +@pytest.fixture(scope="function") +def dataloader(batch_size: int = 2) -> dataloaders.DataLoader: + """Test dataloader fixture.""" + return dataloaders.DataLoader( + batch_size=batch_size, + num_workers=0, + pin_memory=False, + persistent_workers=False, + prefetch_factor=None, + ) diff --git a/tests/eva/vision/models/modules/test_semantic_segmentation.py b/tests/eva/vision/models/modules/test_semantic_segmentation.py new file mode 100644 index 00000000..73776762 --- /dev/null +++ b/tests/eva/vision/models/modules/test_semantic_segmentation.py @@ -0,0 +1,59 @@ +"""Tests the HeadModule module.""" + +import pytest +import torch +from torch import nn + +from eva.core import metrics, trainers +from eva.core.data import datamodules +from eva.vision.models import modules +from eva.vision.models.networks import encoders +from eva.vision.models.networks.decoders import segmentation + + +@pytest.mark.parametrize( + "dataset_fixture", + [ + "segmentation_dataset", + "segmentation_dataset_with_metadata", + ], +) +def test_semantic_segmentation_module_fit( + model: modules.SemanticSegmentationModule, + datamodule: datamodules.DataModule, + trainer: trainers.Trainer, +) -> None: + """Tests the SemanticSegmentationModule fit pipeline.""" + initial_decoder_weights = model.decoder._layers.weight.clone() + trainer.fit(model, datamodule=datamodule) + # verify that the metrics were updated + assert trainer.logged_metrics["train/AverageLoss"] > 0 + assert trainer.logged_metrics["val/AverageLoss"] > 0 + # verify that head weights were updated + assert not torch.all(torch.eq(initial_decoder_weights, model.decoder._layers.weight)) + + +@pytest.fixture(scope="function") +def model(n_classes: int = 4) -> modules.SemanticSegmentationModule: + """Returns a SemanticSegmentationModule model fixture.""" + return modules.SemanticSegmentationModule( + decoder=segmentation.ConvDecoder( + layers=nn.Conv2d( + in_channels=192, + out_channels=n_classes, + kernel_size=(1, 1), + ), + ), + criterion=nn.CrossEntropyLoss(), + encoder=encoders.TimmEncoder( + model_name="vit_tiny_patch16_224", + pretrained=False, + out_indices=1, + model_arguments={ + "dynamic_img_size": True, + }, + ), + metrics=metrics.MetricsSchema( + common=metrics.AverageLoss(), + ), + ) From 4ef3afb383093b59a49efa82764f0f62ee7351de Mon Sep 17 00:00:00 2001 From: ioangatop Date: Tue, 7 May 2024 10:07:57 +0200 Subject: [PATCH 04/21] Add `TotalSegmentator2D` segmentation downstream task (#413) --- .../dino_vit/online/total_segmentator_2d.yaml | 75 ++++++++++++ docs/DEVELOPER_GUIDE.md | 5 +- main.py | 0 src/eva/core/models/modules/head.py | 17 +-- src/eva/vision/data/datasets/_utils.py | 6 +- .../segmentation/total_segmentator.py | 109 ++++++++++-------- .../data/transforms/common/resize_and_crop.py | 9 +- .../segmentation/test_total_segmentator.py | 6 +- 8 files changed, 159 insertions(+), 68 deletions(-) create mode 100644 configs/vision/dino_vit/online/total_segmentator_2d.yaml delete mode 100644 main.py diff --git a/configs/vision/dino_vit/online/total_segmentator_2d.yaml b/configs/vision/dino_vit/online/total_segmentator_2d.yaml new file mode 100644 index 00000000..a3a4defe --- /dev/null +++ b/configs/vision/dino_vit/online/total_segmentator_2d.yaml @@ -0,0 +1,75 @@ +trainer: + class_path: eva.Trainer + init_args: + n_runs: &N_RUNS ${oc.env:N_RUNS, 1} + default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/total_segmentator_2d/${oc.env:TIMM_MODEL_NAME, vit_small_patch16_224}} + max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + logger: + - class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: *OUTPUT_ROOT + name: "" +model: + class_path: eva.vision.models.modules.SemanticSegmentationModule + init_args: + encoder: + class_path: eva.vision.models.networks.encoders.TimmEncoder + init_args: + model_name: ${oc.env:TIMM_MODEL_NAME, vit_small_patch16_224} + pretrained: true + out_indices: 1 + model_arguments: + dynamic_img_size: true + decoder: + class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoder + init_args: + layers: + class_path: torch.nn.Conv2d + init_args: + in_channels: ${oc.env:IN_FEATURES, 384} + out_channels: &NUM_CLASSES 117 + kernel_size: [1, 1] + criterion: torch.nn.CrossEntropyLoss + lr_multiplier_encoder: 0.0 + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 0.0001 + weight_decay: 0.05 + lr_scheduler: + class_path: torch.optim.lr_scheduler.PolynomialLR + init_args: + total_iters: *MAX_STEPS + power: 0.9 + metrics: + common: + - class_path: eva.metrics.AverageLoss +data: + class_path: eva.DataModule + init_args: + datasets: + train: + class_path: eva.vision.datasets.TotalSegmentator2D + init_args: &DATASET_ARGS + root: ${oc.env:DATA_ROOT, ./data}/total_segmentator + split: train + download: false + # Set `download: true` to download the dataset from https://zenodo.org/records/10047292 + # The TotalSegmentator dataset is distributed under the following license: + # "Creative Commons Attribution 4.0 International" + # (see: https://creativecommons.org/licenses/by/4.0/deed.en) + transforms: + class_path: eva.vision.data.transforms.common.ResizeAndCrop + val: + class_path: eva.vision.datasets.TotalSegmentator2D + init_args: + <<: *DATASET_ARGS + split: val + dataloaders: + train: + batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 16} + shuffle: true + val: + batch_size: *BATCH_SIZE + predict: + batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 16} diff --git a/docs/DEVELOPER_GUIDE.md b/docs/DEVELOPER_GUIDE.md index 92f562ad..a7f97bc1 100644 --- a/docs/DEVELOPER_GUIDE.md +++ b/docs/DEVELOPER_GUIDE.md @@ -17,10 +17,7 @@ Add a new dependency to the `core` submodule:
`pdm add ` Add a new dependency to the `vision` submodule:
-`pdm add -G vision ` - -After adding a new dependency, you also need to update the `pdm.lock` file:
-`pdm update` +`pdm add -G vision -G all ` For more information about managing dependencies please look [here](https://pdm-project.org/latest/usage/dependency/#manage-dependencies). diff --git a/main.py b/main.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/eva/core/models/modules/head.py b/src/eva/core/models/modules/head.py index 0976e8f2..95748f70 100644 --- a/src/eva/core/models/modules/head.py +++ b/src/eva/core/models/modules/head.py @@ -54,9 +54,14 @@ def __init__( self.optimizer = optimizer self.lr_scheduler = lr_scheduler + @override + def configure_model(self) -> Any: + if self.backbone is not None: + grad.deactivate_requires_grad(self.backbone) + @override def configure_optimizers(self) -> Any: - parameters = list(self.head.parameters()) + parameters = self.head.parameters() optimizer = self.optimizer(parameters) lr_scheduler = self.lr_scheduler(optimizer) return {"optimizer": optimizer, "lr_scheduler": lr_scheduler} @@ -66,11 +71,6 @@ def forward(self, tensor: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tens features = tensor if self.backbone is None else self.backbone(tensor) return self.head(features).squeeze(-1) - @override - def on_fit_start(self) -> None: - if self.backbone is not None: - grad.deactivate_requires_grad(self.backbone) - @override def training_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT: return self._batch_step(batch) @@ -88,11 +88,6 @@ def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.T tensor = INPUT_BATCH(*batch).data return tensor if self.backbone is None else self.backbone(tensor) - @override - def on_fit_end(self) -> None: - if self.backbone is not None: - grad.activate_requires_grad(self.backbone) - def _batch_step(self, batch: INPUT_BATCH) -> STEP_OUTPUT: """Performs a model forward step and calculates the loss. diff --git a/src/eva/vision/data/datasets/_utils.py b/src/eva/vision/data/datasets/_utils.py index 2d2fe30b..1a17d7e9 100644 --- a/src/eva/vision/data/datasets/_utils.py +++ b/src/eva/vision/data/datasets/_utils.py @@ -1,6 +1,6 @@ """Dataset related function and helper functions.""" -from typing import List, Tuple +from typing import List, Sequence, Tuple def indices_to_ranges(indices: List[int]) -> List[Tuple[int, int]]: @@ -33,11 +33,11 @@ def indices_to_ranges(indices: List[int]) -> List[Tuple[int, int]]: return ranges -def ranges_to_indices(ranges: List[Tuple[int, int]]) -> List[int]: +def ranges_to_indices(ranges: Sequence[Tuple[int, int]]) -> List[int]: """Unpacks a list of ranges to individual indices. Args: - ranges: The list of ranges to produce the indices from. + ranges: A sequence of ranges to produce the indices from. Return: A list of the indices. diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator.py b/src/eva/vision/data/datasets/segmentation/total_segmentator.py index 4892e6b6..92bb8992 100644 --- a/src/eva/vision/data/datasets/segmentation/total_segmentator.py +++ b/src/eva/vision/data/datasets/segmentation/total_segmentator.py @@ -18,14 +18,14 @@ class TotalSegmentator2D(base.ImageSegmentation): """TotalSegmentator 2D segmentation dataset.""" - _train_index_ranges: List[Tuple[int, int]] = [(0, 83)] - """Train range indices.""" + _expected_dataset_lengths: Dict[str, int] = { + "train_small": 29892, + "val_small": 6480, + } + """Dataset version and split to the expected size.""" - _val_index_ranges: List[Tuple[int, int]] = [(83, 103)] - """Validation range indices.""" - - _n_slices_per_image: int = 20 - """The amount of slices to sample per 3D CT scan image.""" + _sample_every_n_slices: int | None = None + """The amount of slices to sub-sample per 3D CT scan image.""" _resources_full: List[structs.DownloadResource] = [ structs.DownloadResource( @@ -49,7 +49,7 @@ def __init__( self, root: str, split: Literal["train", "val"] | None, - version: Literal["small", "full"] = "small", + version: Literal["small", "full"] | None = "small", download: bool = False, as_uint8: bool = True, transforms: Callable | None = None, @@ -60,7 +60,8 @@ def __init__( root: Path to the root directory of the dataset. The dataset will be downloaded and extracted here, if it does not already exist. split: Dataset split to use. If `None`, the entire dataset is used. - version: The version of the dataset to initialize. + version: The version of the dataset to initialize. If `None`, it will + use the files located at root as is and wont perform any checks. download: Whether to download the data for the specified split. Note that the download will be executed only by additionally calling the :meth:`prepare_data` method and if the data does not @@ -78,7 +79,7 @@ def __init__( self._as_uint8 = as_uint8 self._samples_dirs: List[str] = [] - self._indices: List[int] = [] + self._indices: List[Tuple[int, int]] = [] @functools.cached_property @override @@ -99,7 +100,8 @@ def class_to_idx(self) -> Dict[str, int]: @override def filename(self, index: int) -> str: - sample_dir = self._samples_dirs[self._indices[index]] + sample_idx, _ = self._indices[index] + sample_dir = self._samples_dirs[sample_idx] return os.path.join(sample_dir, "ct.nii.gz") @override @@ -114,21 +116,24 @@ def configure(self) -> None: @override def validate(self) -> None: + if self._version is None: + return + _validators.check_dataset_integrity( self, - length=1660 if self._split == "train" else 400, + length=self._expected_dataset_lengths.get(f"{self._split}_{self._version}", 0), n_classes=117, first_and_last_labels=("adrenal_gland_left", "vertebrae_T9"), ) @override def __len__(self) -> int: - return len(self._indices) * self._n_slices_per_image + return len(self._indices) @override def load_image(self, index: int) -> tv_tensors.Image: - image_path = self._get_image_path(index) - slice_index = self._get_sample_slice_index(index) + sample_index, slice_index = self._indices[index] + image_path = self._get_image_path(sample_index) image_array = io.read_nifti_slice(image_path, slice_index) if self._as_uint8: image_array = convert.to_8bit(image_array) @@ -137,8 +142,8 @@ def load_image(self, index: int) -> tv_tensors.Image: @override def load_mask(self, index: int) -> tv_tensors.Mask: - masks_dir = self._get_masks_dir(index) - slice_index = self._get_sample_slice_index(index) + sample_index, slice_index = self._indices[index] + masks_dir = self._get_masks_dir(sample_index) mask_paths = (os.path.join(masks_dir, label + ".nii.gz") for label in self.classes) one_hot_encoded = np.concatenate( [io.read_nifti_slice(path, slice_index) for path in mask_paths], @@ -149,27 +154,20 @@ def load_mask(self, index: int) -> tv_tensors.Mask: segmentation_label = np.argmax(one_hot_encoded_with_bg, axis=2) return tv_tensors.Mask(segmentation_label) - def _get_masks_dir(self, index: int) -> str: - """Returns the directory of the corresponding masks.""" - sample_dir = self._get_sample_dir(index) - return os.path.join(self._root, sample_dir, "segmentations") - - def _get_image_path(self, index: int) -> str: + def _get_image_path(self, sample_index: int) -> str: """Returns the corresponding image path.""" - sample_dir = self._get_sample_dir(index) + sample_dir = self._samples_dirs[sample_index] return os.path.join(self._root, sample_dir, "ct.nii.gz") - def _get_sample_dir(self, index: int) -> str: - """Returns the corresponding sample directory.""" - sample_index = self._indices[index // self._n_slices_per_image] - return self._samples_dirs[sample_index] + def _get_masks_dir(self, sample_index: int) -> str: + """Returns the directory of the corresponding masks.""" + sample_dir = self._samples_dirs[sample_index] + return os.path.join(self._root, sample_dir, "segmentations") - def _get_sample_slice_index(self, index: int) -> int: - """Returns the corresponding slice index.""" - image_path = self._get_image_path(index) - total_slices = io.fetch_total_nifti_slices(image_path) - slice_indices = np.linspace(0, total_slices - 1, num=self._n_slices_per_image, dtype=int) - return slice_indices[index % self._n_slices_per_image] + def _get_number_of_slices_per_sample(self, sample_index: int) -> int: + """Returns the total amount of slices of a sample.""" + image_path = self._get_image_path(sample_index) + return io.fetch_total_nifti_slices(image_path) def _fetch_samples_dirs(self) -> List[str]: """Returns the name of all the samples of all the splits of the dataset.""" @@ -180,29 +178,46 @@ def _fetch_samples_dirs(self) -> List[str]: ] return sorted(sample_filenames) - def _create_indices(self) -> List[int]: - """Builds the dataset indices for the specified split.""" - split_index_ranges = { - "train": self._train_index_ranges, - "val": self._val_index_ranges, - None: [(0, 103)], - } - index_ranges = split_index_ranges.get(self._split) - if index_ranges is None: - raise ValueError("Invalid data split. Use 'train', 'val' or `None`.") + def _get_split_indices(self) -> List[int]: + """Returns the samples indices that corresponding the dataset split and version.""" + key = f"{self._split}_{self._version}" + match key: + case "train_small": + index_ranges = [(0, 83)] + case "val_small": + index_ranges = [(83, 102)] + case _: + index_ranges = [(0, len(self._samples_dirs))] return _utils.ranges_to_indices(index_ranges) + def _create_indices(self) -> List[Tuple[int, int]]: + """Builds the dataset indices for the specified split. + + Returns: + A list of tuples, where the first value indicates the + sample index which the second its corresponding slice + index. + """ + indices = [ + (sample_idx, slide_idx) + for sample_idx in self._get_split_indices() + for slide_idx in range(self._get_number_of_slices_per_sample(sample_idx)) + if slide_idx % (self._sample_every_n_slices or 1) == 0 + ] + return indices + def _download_dataset(self) -> None: """Downloads the dataset.""" dataset_resources = { "small": self._resources_small, "full": self._resources_full, - None: (0, 103), } - resources = dataset_resources.get(self._version) + resources = dataset_resources.get(self._version or "") if resources is None: - raise ValueError("Invalid data version. Use 'small' or 'full'.") + raise ValueError( + f"Can't download data version '{self._version}'. Use 'small' or 'full'." + ) for resource in resources: if os.path.isdir(self._root): diff --git a/src/eva/vision/data/transforms/common/resize_and_crop.py b/src/eva/vision/data/transforms/common/resize_and_crop.py index f1956a66..f5320679 100644 --- a/src/eva/vision/data/transforms/common/resize_and_crop.py +++ b/src/eva/vision/data/transforms/common/resize_and_crop.py @@ -4,6 +4,7 @@ import torch import torchvision.transforms.v2 as torch_transforms +from torchvision import tv_tensors class ResizeAndCrop(torch_transforms.Compose): @@ -35,7 +36,13 @@ def _build_transforms(self) -> Sequence[Callable]: torch_transforms.ToImage(), torch_transforms.Resize(size=self._size), torch_transforms.CenterCrop(size=self._size), - torch_transforms.ToDtype(torch.float32, scale=True), + torch_transforms.ToDtype( + { + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.float32, + }, + scale=True, + ), torch_transforms.Normalize( mean=self._mean, std=self._std, diff --git a/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py b/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py index 3e7f09e6..9607a2a8 100644 --- a/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py +++ b/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py @@ -11,7 +11,7 @@ @pytest.mark.parametrize( "split, expected_length", - [("train", 1660), ("val", 400), (None, 2060)], + [("train", 9), ("val", 9), (None, 9)], ) def test_length( total_segmentator_dataset: datasets.TotalSegmentator2D, expected_length: int @@ -25,6 +25,7 @@ def test_length( [ (None, 0), ("train", 0), + ("val", 0), ], ) def test_sample(total_segmentator_dataset: datasets.TotalSegmentator2D, index: int) -> None: @@ -43,7 +44,7 @@ def test_sample(total_segmentator_dataset: datasets.TotalSegmentator2D, index: i @pytest.fixture(scope="function") def total_segmentator_dataset( - split: Literal["train", "val"], assets_path: str + split: Literal["train", "val"] | None, assets_path: str ) -> datasets.TotalSegmentator2D: """TotalSegmentator2D dataset fixture.""" dataset = datasets.TotalSegmentator2D( @@ -55,6 +56,7 @@ def total_segmentator_dataset( "Totalsegmentator_dataset_v201", ), split=split, + version=None, ) dataset.prepare_data() dataset.configure() From 07fbd08c760501024dd49313f16d68dba572b51e Mon Sep 17 00:00:00 2001 From: ioangatop Date: Tue, 7 May 2024 13:44:18 +0200 Subject: [PATCH 05/21] Add dice score in `TotalSegmentator2D` task (#423) --- .../vision/dino_vit/online/total_segmentator_2d.yaml | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/configs/vision/dino_vit/online/total_segmentator_2d.yaml b/configs/vision/dino_vit/online/total_segmentator_2d.yaml index a3a4defe..bc7860a0 100644 --- a/configs/vision/dino_vit/online/total_segmentator_2d.yaml +++ b/configs/vision/dino_vit/online/total_segmentator_2d.yaml @@ -27,7 +27,7 @@ model: class_path: torch.nn.Conv2d init_args: in_channels: ${oc.env:IN_FEATURES, 384} - out_channels: &NUM_CLASSES 117 + out_channels: &NUM_CLASSES 118 kernel_size: [1, 1] criterion: torch.nn.CrossEntropyLoss lr_multiplier_encoder: 0.0 @@ -41,9 +41,17 @@ model: init_args: total_iters: *MAX_STEPS power: 0.9 + postprocess: + targets_transforms: + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: torch.int64 metrics: common: - class_path: eva.metrics.AverageLoss + - class_path: torchmetrics.Dice + init_args: + num_classes: *NUM_CLASSES data: class_path: eva.DataModule init_args: From a6fd246b63b39085813ac7a06c9ebd4573e340c7 Mon Sep 17 00:00:00 2001 From: ioangatop Date: Mon, 13 May 2024 10:20:57 +0200 Subject: [PATCH 06/21] Allow to use subclasses in `TotalSegmentator2D` (#435) --- .../segmentation/total_segmentator.py | 26 ++++++++++++++----- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator.py b/src/eva/vision/data/datasets/segmentation/total_segmentator.py index 92bb8992..b427583d 100644 --- a/src/eva/vision/data/datasets/segmentation/total_segmentator.py +++ b/src/eva/vision/data/datasets/segmentation/total_segmentator.py @@ -52,6 +52,7 @@ def __init__( version: Literal["small", "full"] | None = "small", download: bool = False, as_uint8: bool = True, + classes: List[str] | None = None, transforms: Callable | None = None, ) -> None: """Initialize dataset. @@ -67,6 +68,8 @@ def __init__( calling the :meth:`prepare_data` method and if the data does not exist yet on disk. as_uint8: Whether to convert and return the images as a 8-bit. + classes: Whether to configure the dataset with a subset of classes. + If `None`, it will use all of them. transforms: A function/transforms that takes in an image and a target mask and returns the transformed versions of both. """ @@ -77,6 +80,7 @@ def __init__( self._version = version self._download = download self._as_uint8 = as_uint8 + self._classes = classes self._samples_dirs: List[str] = [] self._indices: List[Tuple[int, int]] = [] @@ -91,7 +95,13 @@ def get_filename(path: str) -> str: first_sample_labels = os.path.join( self._root, self._samples_dirs[0], "segmentations", "*.nii.gz" ) - return sorted(map(get_filename, glob(first_sample_labels))) + all_classes = sorted(map(get_filename, glob(first_sample_labels))) + if self._classes: + is_subset = all(name in all_classes for name in self._classes) + if not is_subset: + raise ValueError("Provided class names are not subset of the dataset onces.") + + return all_classes if self._classes is None else self._classes @property @override @@ -122,8 +132,12 @@ def validate(self) -> None: _validators.check_dataset_integrity( self, length=self._expected_dataset_lengths.get(f"{self._split}_{self._version}", 0), - n_classes=117, - first_and_last_labels=("adrenal_gland_left", "vertebrae_T9"), + n_classes=len(self._classes) if self._classes else 117, + first_and_last_labels=( + (self._classes[0], self._classes[-1]) + if self._classes + else ("adrenal_gland_left", "vertebrae_T9") + ), ) @override @@ -145,10 +159,8 @@ def load_mask(self, index: int) -> tv_tensors.Mask: sample_index, slice_index = self._indices[index] masks_dir = self._get_masks_dir(sample_index) mask_paths = (os.path.join(masks_dir, label + ".nii.gz") for label in self.classes) - one_hot_encoded = np.concatenate( - [io.read_nifti_slice(path, slice_index) for path in mask_paths], - axis=2, - ) + binary_masks = [io.read_nifti_slice(path, slice_index) for path in mask_paths] + one_hot_encoded = np.concatenate(binary_masks, axis=2) background_mask = one_hot_encoded.sum(axis=2, keepdims=True) == 0 one_hot_encoded_with_bg = np.concatenate([background_mask, one_hot_encoded], axis=2) segmentation_label = np.argmax(one_hot_encoded_with_bg, axis=2) From db5a578a017427b9d43211f94a42e49d1bc4adc2 Mon Sep 17 00:00:00 2001 From: ioangatop Date: Mon, 13 May 2024 11:19:23 +0200 Subject: [PATCH 07/21] Create a callback to visualise the segmentation results (#424) --- .../dino_vit/online/total_segmentator_2d.yaml | 4 + pdm.lock | 109 +++++------ src/eva/core/callbacks/writers/embeddings.py | 21 +- src/eva/core/loggers/log/__init__.py | 5 + src/eva/core/loggers/log/image.py | 59 ++++++ src/eva/core/loggers/log/utils.py | 13 ++ src/eva/core/loggers/loggers.py | 6 + src/eva/core/utils/__init__.py | 4 + src/eva/core/utils/memory.py | 28 +++ src/eva/vision/__init__.py | 4 +- src/eva/vision/callbacks/__init__.py | 5 + src/eva/vision/callbacks/loggers/__init__.py | 5 + .../callbacks/loggers/batch/__init__.py | 5 + .../vision/callbacks/loggers/batch/base.py | 130 +++++++++++++ .../callbacks/loggers/batch/segmentation.py | 181 ++++++++++++++++++ src/eva/vision/utils/colormap.py | 77 ++++++++ 16 files changed, 591 insertions(+), 65 deletions(-) create mode 100644 src/eva/core/loggers/log/__init__.py create mode 100644 src/eva/core/loggers/log/image.py create mode 100644 src/eva/core/loggers/log/utils.py create mode 100644 src/eva/core/loggers/loggers.py create mode 100644 src/eva/core/utils/memory.py create mode 100644 src/eva/vision/callbacks/__init__.py create mode 100644 src/eva/vision/callbacks/loggers/__init__.py create mode 100644 src/eva/vision/callbacks/loggers/batch/__init__.py create mode 100644 src/eva/vision/callbacks/loggers/batch/base.py create mode 100644 src/eva/vision/callbacks/loggers/batch/segmentation.py create mode 100644 src/eva/vision/utils/colormap.py diff --git a/configs/vision/dino_vit/online/total_segmentator_2d.yaml b/configs/vision/dino_vit/online/total_segmentator_2d.yaml index bc7860a0..f60eb183 100644 --- a/configs/vision/dino_vit/online/total_segmentator_2d.yaml +++ b/configs/vision/dino_vit/online/total_segmentator_2d.yaml @@ -4,6 +4,10 @@ trainer: n_runs: &N_RUNS ${oc.env:N_RUNS, 1} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/total_segmentator_2d/${oc.env:TIMM_MODEL_NAME, vit_small_patch16_224}} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + callbacks: + - class_path: eva.vision.callbacks.SemanticSegmentationLogger + init_args: + log_every_n_epochs: 1 logger: - class_path: lightning.pytorch.loggers.TensorBoardLogger init_args: diff --git a/pdm.lock b/pdm.lock index 9ae7e117..0117967b 100644 --- a/pdm.lock +++ b/pdm.lock @@ -772,25 +772,25 @@ files = [ [[package]] name = "lightning" -version = "2.3.0.dev20240421" +version = "2.3.0.dev20240505" requires_python = ">=3.8" summary = "The Deep Learning framework to train, deploy, and ship AI products Lightning fast." groups = ["default"] dependencies = [ "PyYAML<8.0,>=5.4", - "fsspec[http]<2025.0,>=2022.5.0", + "fsspec[http]<2026.0,>=2022.5.0", "lightning-utilities<2.0,>=0.8.0", "numpy<3.0,>=1.17.2", "packaging<25.0,>=20.0", "pytorch-lightning", - "torch<4.0,>=1.13.0", + "torch<4.0,>=2.0.0", "torchmetrics<3.0,>=0.7.0", "tqdm<6.0,>=4.57.0", "typing-extensions<6.0,>=4.4.0", ] files = [ - {file = "lightning-2.3.0.dev20240421-py3-none-any.whl", hash = "sha256:18a31ce5fec10c11e73ebddb970aab4f47acf86e4ff187d1f85ae15f3be540ca"}, - {file = "lightning-2.3.0.dev20240421.tar.gz", hash = "sha256:a803678ca35e24a9ec8b01f3c5ce7207be80adbf13775bc14cf8c4af7e7039d2"}, + {file = "lightning-2.3.0.dev20240505-py3-none-any.whl", hash = "sha256:cdf6042c342be0f99267dee27f2700702a53bf7b08cb6df6daae5094674f5ce6"}, + {file = "lightning-2.3.0.dev20240505.tar.gz", hash = "sha256:6f9b541d14e798db830aa8a0bb5940d758e26fa76b4a236d102012f92c2911cc"}, ] [[package]] @@ -927,7 +927,7 @@ files = [ [[package]] name = "mike" -version = "2.0.0" +version = "2.1.1" summary = "Manage multiple versions of your MkDocs-powered documentation" groups = ["dev", "docs"] dependencies = [ @@ -936,12 +936,13 @@ dependencies = [ "jinja2>=2.7", "mkdocs>=1.0", "pyparsing>=3.0", + "pyyaml-env-tag", "pyyaml>=5.1", "verspec", ] files = [ - {file = "mike-2.0.0-py3-none-any.whl", hash = "sha256:87f496a65900f93ba92d72940242b65c86f3f2f82871bc60ebdcffc91fad1d9e"}, - {file = "mike-2.0.0.tar.gz", hash = "sha256:566f1cab1a58cc50b106fb79ea2f1f56e7bfc8b25a051e95e6eaee9fba0922de"}, + {file = "mike-2.1.1-py3-none-any.whl", hash = "sha256:0b1d01a397a423284593eeb1b5f3194e37169488f929b860c9bfe95c0d5efb79"}, + {file = "mike-2.1.1.tar.gz", hash = "sha256:f39ed39f3737da83ad0adc33e9f885092ed27f8c9e7ff0523add0480352a2c22"}, ] [[package]] @@ -1004,7 +1005,7 @@ files = [ [[package]] name = "mkdocs-material" -version = "9.5.19" +version = "9.5.21" requires_python = ">=3.8" summary = "Documentation that simply works" groups = ["dev", "docs"] @@ -1022,8 +1023,8 @@ dependencies = [ "requests~=2.26", ] files = [ - {file = "mkdocs_material-9.5.19-py3-none-any.whl", hash = "sha256:ea96e150b6c95f5e4ffe47d78bb712c7bacdd91d2a0bec47f46b6fa0705a86ec"}, - {file = "mkdocs_material-9.5.19.tar.gz", hash = "sha256:7473e06e17e23af608a30ef583fdde8f36389dd3ef56b1d503eed54c89c9618c"}, + {file = "mkdocs_material-9.5.21-py3-none-any.whl", hash = "sha256:210e1f179682cd4be17d5c641b2f4559574b9dea2f589c3f0e7c17c5bd1959bc"}, + {file = "mkdocs_material-9.5.21.tar.gz", hash = "sha256:049f82770f40559d3c2aa2259c562ea7257dbb4aaa9624323b5ef27b2d95a450"}, ] [[package]] @@ -1063,7 +1064,7 @@ files = [ [[package]] name = "mkdocstrings" -version = "0.24.3" +version = "0.25.1" requires_python = ">=3.8" summary = "Automatic documentation from sources, for MkDocs." groups = ["dev", "docs"] @@ -1078,8 +1079,8 @@ dependencies = [ "pymdown-extensions>=6.3", ] files = [ - {file = "mkdocstrings-0.24.3-py3-none-any.whl", hash = "sha256:5c9cf2a32958cd161d5428699b79c8b0988856b0d4a8c5baf8395fc1bf4087c3"}, - {file = "mkdocstrings-0.24.3.tar.gz", hash = "sha256:f327b234eb8d2551a306735436e157d0a22d45f79963c60a8b585d5f7a94c1d2"}, + {file = "mkdocstrings-0.25.1-py3-none-any.whl", hash = "sha256:da01fcc2670ad61888e8fe5b60afe9fee5781017d67431996832d63e887c2e51"}, + {file = "mkdocstrings-0.25.1.tar.gz", hash = "sha256:c3a2515f31577f311a9ee58d089e4c51fc6046dbd9e9b4c3de4c3194667fe9bf"}, ] [[package]] @@ -1100,18 +1101,18 @@ files = [ [[package]] name = "mkdocstrings" -version = "0.24.3" +version = "0.25.1" extras = ["python"] requires_python = ">=3.8" summary = "Automatic documentation from sources, for MkDocs." groups = ["dev", "docs"] dependencies = [ "mkdocstrings-python>=0.5.2", - "mkdocstrings==0.24.3", + "mkdocstrings==0.25.1", ] files = [ - {file = "mkdocstrings-0.24.3-py3-none-any.whl", hash = "sha256:5c9cf2a32958cd161d5428699b79c8b0988856b0d4a8c5baf8395fc1bf4087c3"}, - {file = "mkdocstrings-0.24.3.tar.gz", hash = "sha256:f327b234eb8d2551a306735436e157d0a22d45f79963c60a8b585d5f7a94c1d2"}, + {file = "mkdocstrings-0.25.1-py3-none-any.whl", hash = "sha256:da01fcc2670ad61888e8fe5b60afe9fee5781017d67431996832d63e887c2e51"}, + {file = "mkdocstrings-0.25.1.tar.gz", hash = "sha256:c3a2515f31577f311a9ee58d089e4c51fc6046dbd9e9b4c3de4c3194667fe9bf"}, ] [[package]] @@ -1713,13 +1714,13 @@ files = [ [[package]] name = "pluggy" -version = "1.4.0" +version = "1.5.0" requires_python = ">=3.8" summary = "plugin and hook calling mechanisms for python" groups = ["dev", "test", "typecheck"] files = [ - {file = "pluggy-1.4.0-py3-none-any.whl", hash = "sha256:7db9f7b503d67d1c5b95f59773ebb58a8c1c288129a88665838012cfb07b8981"}, - {file = "pluggy-1.4.0.tar.gz", hash = "sha256:8c85c2876142a764e5b7548e7d9a0e0ddb46f5185161049a79b7e974454223be"}, + {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, + {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, ] [[package]] @@ -1740,13 +1741,13 @@ files = [ [[package]] name = "pygments" -version = "2.17.2" -requires_python = ">=3.7" +version = "2.18.0" +requires_python = ">=3.8" summary = "Pygments is a syntax highlighting package written in Python." groups = ["default", "dev", "docs", "lint", "test"] files = [ - {file = "pygments-2.17.2-py3-none-any.whl", hash = "sha256:b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c"}, - {file = "pygments-2.17.2.tar.gz", hash = "sha256:da46cec9fd2de5be3a8a784f434e4c4ab670b4ff54d605c4c2717e9d49c4c367"}, + {file = "pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a"}, + {file = "pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199"}, ] [[package]] @@ -1788,7 +1789,7 @@ files = [ [[package]] name = "pyright" -version = "1.1.360" +version = "1.1.361" requires_python = ">=3.7" summary = "Command line wrapper for pyright" groups = ["dev", "typecheck"] @@ -1796,13 +1797,13 @@ dependencies = [ "nodeenv>=1.6.0", ] files = [ - {file = "pyright-1.1.360-py3-none-any.whl", hash = "sha256:7637f75451ac968b7cf1f8c51cfefb6d60ac7d086eb845364bc8ac03a026efd7"}, - {file = "pyright-1.1.360.tar.gz", hash = "sha256:784ddcda9745e9f5610483d7b963e9aa8d4f50d7755a9dffb28ccbeb27adce32"}, + {file = "pyright-1.1.361-py3-none-any.whl", hash = "sha256:c50fc94ce92b5c958cfccbbe34142e7411d474da43d6c14a958667e35b9df7ea"}, + {file = "pyright-1.1.361.tar.gz", hash = "sha256:1d67933315666b05d230c85ea8fb97aaa2056e4092a13df87b7765bb9e8f1a8d"}, ] [[package]] name = "pytest" -version = "8.1.1" +version = "8.2.0" requires_python = ">=3.8" summary = "pytest: simple powerful testing with Python" groups = ["dev", "test", "typecheck"] @@ -1811,12 +1812,12 @@ dependencies = [ "exceptiongroup>=1.0.0rc8; python_version < \"3.11\"", "iniconfig", "packaging", - "pluggy<2.0,>=1.4", + "pluggy<2.0,>=1.5", "tomli>=1; python_version < \"3.11\"", ] files = [ - {file = "pytest-8.1.1-py3-none-any.whl", hash = "sha256:2a8386cfc11fa9d2c50ee7b2a57e7d898ef90470a7a34c4b949ff59662bb78b7"}, - {file = "pytest-8.1.1.tar.gz", hash = "sha256:ac978141a75948948817d360297b7aae0fcb9d6ff6bc9ec6d514b85d5a65c044"}, + {file = "pytest-8.2.0-py3-none-any.whl", hash = "sha256:1733f0620f6cda4095bbf0d9ff8022486e91892245bb9e7d5542c018f612f233"}, + {file = "pytest-8.2.0.tar.gz", hash = "sha256:d507d4482197eac0ba2bae2e9babf0672eb333017bcedaa5fb1a3d42c1174b3f"}, ] [[package]] @@ -2017,28 +2018,28 @@ files = [ [[package]] name = "ruff" -version = "0.4.2" +version = "0.4.3" requires_python = ">=3.7" summary = "An extremely fast Python linter and code formatter, written in Rust." groups = ["dev", "lint"] files = [ - {file = "ruff-0.4.2-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:8d14dc8953f8af7e003a485ef560bbefa5f8cc1ad994eebb5b12136049bbccc5"}, - {file = "ruff-0.4.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:24016ed18db3dc9786af103ff49c03bdf408ea253f3cb9e3638f39ac9cf2d483"}, - {file = "ruff-0.4.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0e2e06459042ac841ed510196c350ba35a9b24a643e23db60d79b2db92af0c2b"}, - {file = "ruff-0.4.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3afabaf7ba8e9c485a14ad8f4122feff6b2b93cc53cd4dad2fd24ae35112d5c5"}, - {file = "ruff-0.4.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:799eb468ea6bc54b95527143a4ceaf970d5aa3613050c6cff54c85fda3fde480"}, - {file = "ruff-0.4.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:ec4ba9436a51527fb6931a8839af4c36a5481f8c19e8f5e42c2f7ad3a49f5069"}, - {file = "ruff-0.4.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6a2243f8f434e487c2a010c7252150b1fdf019035130f41b77626f5655c9ca22"}, - {file = "ruff-0.4.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8772130a063f3eebdf7095da00c0b9898bd1774c43b336272c3e98667d4fb8fa"}, - {file = "ruff-0.4.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ab165ef5d72392b4ebb85a8b0fbd321f69832a632e07a74794c0e598e7a8376"}, - {file = "ruff-0.4.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:1f32cadf44c2020e75e0c56c3408ed1d32c024766bd41aedef92aa3ca28eef68"}, - {file = "ruff-0.4.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:22e306bf15e09af45ca812bc42fa59b628646fa7c26072555f278994890bc7ac"}, - {file = "ruff-0.4.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:82986bb77ad83a1719c90b9528a9dd663c9206f7c0ab69282af8223566a0c34e"}, - {file = "ruff-0.4.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:652e4ba553e421a6dc2a6d4868bc3b3881311702633eb3672f9f244ded8908cd"}, - {file = "ruff-0.4.2-py3-none-win32.whl", hash = "sha256:7891ee376770ac094da3ad40c116258a381b86c7352552788377c6eb16d784fe"}, - {file = "ruff-0.4.2-py3-none-win_amd64.whl", hash = "sha256:5ec481661fb2fd88a5d6cf1f83403d388ec90f9daaa36e40e2c003de66751798"}, - {file = "ruff-0.4.2-py3-none-win_arm64.whl", hash = "sha256:cbd1e87c71bca14792948c4ccb51ee61c3296e164019d2d484f3eaa2d360dfaf"}, - {file = "ruff-0.4.2.tar.gz", hash = "sha256:33bcc160aee2520664bc0859cfeaebc84bb7323becff3f303b8f1f2d81cb4edc"}, + {file = "ruff-0.4.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b70800c290f14ae6fcbb41bbe201cf62dfca024d124a1f373e76371a007454ce"}, + {file = "ruff-0.4.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:08a0d6a22918ab2552ace96adeaca308833873a4d7d1d587bb1d37bae8728eb3"}, + {file = "ruff-0.4.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eba1f14df3c758dd7de5b55fbae7e1c8af238597961e5fb628f3de446c3c40c5"}, + {file = "ruff-0.4.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:819fb06d535cc76dfddbfe8d3068ff602ddeb40e3eacbc90e0d1272bb8d97113"}, + {file = "ruff-0.4.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0bfc9e955e6dc6359eb6f82ea150c4f4e82b660e5b58d9a20a0e42ec3bb6342b"}, + {file = "ruff-0.4.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:510a67d232d2ebe983fddea324dbf9d69b71c4d2dfeb8a862f4a127536dd4cfb"}, + {file = "ruff-0.4.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dc9ff11cd9a092ee7680a56d21f302bdda14327772cd870d806610a3503d001f"}, + {file = "ruff-0.4.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:29efff25bf9ee685c2c8390563a5b5c006a3fee5230d28ea39f4f75f9d0b6f2f"}, + {file = "ruff-0.4.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18b00e0bcccf0fc8d7186ed21e311dffd19761cb632241a6e4fe4477cc80ef6e"}, + {file = "ruff-0.4.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:262f5635e2c74d80b7507fbc2fac28fe0d4fef26373bbc62039526f7722bca1b"}, + {file = "ruff-0.4.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:7363691198719c26459e08cc17c6a3dac6f592e9ea3d2fa772f4e561b5fe82a3"}, + {file = "ruff-0.4.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:eeb039f8428fcb6725bb63cbae92ad67b0559e68b5d80f840f11914afd8ddf7f"}, + {file = "ruff-0.4.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:927b11c1e4d0727ce1a729eace61cee88a334623ec424c0b1c8fe3e5f9d3c865"}, + {file = "ruff-0.4.3-py3-none-win32.whl", hash = "sha256:25cacda2155778beb0d064e0ec5a3944dcca9c12715f7c4634fd9d93ac33fd30"}, + {file = "ruff-0.4.3-py3-none-win_amd64.whl", hash = "sha256:7a1c3a450bc6539ef00da6c819fb1b76b6b065dec585f91456e7c0d6a0bbc725"}, + {file = "ruff-0.4.3-py3-none-win_arm64.whl", hash = "sha256:71ca5f8ccf1121b95a59649482470c5601c60a416bf189d553955b0338e34614"}, + {file = "ruff-0.4.3.tar.gz", hash = "sha256:ff0a3ef2e3c4b6d133fbedcf9586abfbe38d076041f2dc18ffb2c7e0485d5a07"}, ] [[package]] @@ -2226,7 +2227,7 @@ version = "1.0.0.dev0" requires_python = ">=3.8" git = "https://github.com/huggingface/pytorch-image-models.git" ref = "main" -revision = "e741370e2b95e0c2fa3e00808cd9014ee620ca62" +revision = "f8979d4f50b7920c78511746f7315df8f1857bc5" summary = "PyTorch Image Models" groups = ["all", "vision"] dependencies = [ @@ -2437,7 +2438,7 @@ files = [ [[package]] name = "transformers" -version = "4.40.1" +version = "4.40.2" requires_python = ">=3.8.0" summary = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" groups = ["default"] @@ -2454,8 +2455,8 @@ dependencies = [ "tqdm>=4.27", ] files = [ - {file = "transformers-4.40.1-py3-none-any.whl", hash = "sha256:9d5ee0c8142a60501faf9e49a0b42f8e9cb8611823bce4f195a9325a6816337e"}, - {file = "transformers-4.40.1.tar.gz", hash = "sha256:55e1697e6f18b58273e7117bb469cdffc11be28995462d8d5e422fef38d2de36"}, + {file = "transformers-4.40.2-py3-none-any.whl", hash = "sha256:71cb94301ec211a2e1d4b8c8d18dcfaa902dfa00a089dceca167a8aa265d6f2d"}, + {file = "transformers-4.40.2.tar.gz", hash = "sha256:657b6054a2097671398d976ad46e60836e7e15f9ea9551631a96e33cb9240649"}, ] [[package]] diff --git a/src/eva/core/callbacks/writers/embeddings.py b/src/eva/core/callbacks/writers/embeddings.py index 4b3cceec..70c462f9 100644 --- a/src/eva/core/callbacks/writers/embeddings.py +++ b/src/eva/core/callbacks/writers/embeddings.py @@ -30,19 +30,22 @@ def __init__( ) -> None: """Initializes a new EmbeddingsWriter instance. - This callback writes the embedding files in a separate process to avoid blocking the - main process where the model forward pass is executed. + This callback writes the embedding files in a separate process + to avoid blocking the main process where the model forward pass + is executed. Args: output_dir: The directory where the embeddings will be saved. backbone: A model to be used as feature extractor. If `None`, - it will be expected that the input batch returns the features directly. - dataloader_idx_map: A dictionary mapping dataloader indices to their respective - names (e.g. train, val, test). - group_key: The metadata key to group the embeddings by. If specified, the - embedding files will be saved in subdirectories named after the group_key. - If specified, the key must be present in the metadata of the input batch. - overwrite: Whether to overwrite the output directory. Defaults to True. + it will be expected that the input batch returns the + features directly. + dataloader_idx_map: A dictionary mapping dataloader indices to + their respective names (e.g. train, val, test). + group_key: The metadata key to group the embeddings by. If specified, + the embedding files will be saved in subdirectories named after + the group_key. If specified, the key must be present in the metadata + of the input batch. + overwrite: Whether to overwrite the output directory. """ super().__init__(write_interval="batch") diff --git a/src/eva/core/loggers/log/__init__.py b/src/eva/core/loggers/log/__init__.py new file mode 100644 index 00000000..df998fde --- /dev/null +++ b/src/eva/core/loggers/log/__init__.py @@ -0,0 +1,5 @@ +"""Experiment loggers operations.""" + +from eva.core.loggers.log.image import log_image + +__all__ = ["log_image"] diff --git a/src/eva/core/loggers/log/image.py b/src/eva/core/loggers/log/image.py new file mode 100644 index 00000000..d087a5d1 --- /dev/null +++ b/src/eva/core/loggers/log/image.py @@ -0,0 +1,59 @@ +"""Image log functionality.""" + +import functools + +import torch + +from eva.core.loggers import loggers +from eva.core.loggers.log import utils + + +@functools.singledispatch +def log_image( + logger, + tag: str, + image: torch.Tensor, + step: int = 0, +) -> None: + """Adds an image to the logger. + + Args: + logger: The desired logger. + tag: The log tag. + image: The image tensor to log. It should have + the shape of (3,H,W) and (0,1) normalized. + step: The global step of the log. + """ + utils.raise_not_supported(logger, "image") + + +@log_image.register +def _( + loggers: list, + tag: str, + image: torch.Tensor, + step: int = 0, +) -> None: + """Adds an image to a list of supported loggers.""" + for logger in loggers: + log_image( + logger, + tag=tag, + image=image, + step=step, + ) + + +@log_image.register +def _( + logger: loggers.TensorBoardLogger, + tag: str, + image: torch.Tensor, + step: int = 0, +) -> None: + """Adds an image to a TensorBoard logger.""" + logger.experiment.add_image( + tag=tag, + img_tensor=image, + global_step=step, + ) diff --git a/src/eva/core/loggers/log/utils.py b/src/eva/core/loggers/log/utils.py new file mode 100644 index 00000000..5e860257 --- /dev/null +++ b/src/eva/core/loggers/log/utils.py @@ -0,0 +1,13 @@ +"""Logging related utilities.""" + +from loguru import logger as cli_logger + +from eva.core.loggers import loggers as loggers_lib + + +def raise_not_supported(logger: loggers_lib.Loggers, data_type: str) -> None: + """Raises a warning for not supported tasks from the given logger.""" + print("\n") + cli_logger.debug( + f"Logger '{logger.__class__.__name__}' is not supported for '{data_type}' data." + ) diff --git a/src/eva/core/loggers/loggers.py b/src/eva/core/loggers/loggers.py new file mode 100644 index 00000000..1ed815e5 --- /dev/null +++ b/src/eva/core/loggers/loggers.py @@ -0,0 +1,6 @@ +"""Experimental loggers.""" + +from lightning.pytorch.loggers import TensorBoardLogger + +Loggers = TensorBoardLogger +"""Supported loggers.""" diff --git a/src/eva/core/utils/__init__.py b/src/eva/core/utils/__init__.py index f99e16fd..8d8dd40a 100644 --- a/src/eva/core/utils/__init__.py +++ b/src/eva/core/utils/__init__.py @@ -1 +1,5 @@ """Utilities and library level helper functionalities.""" + +from eva.core.utils.memory import to_cpu + +__all__ = ["to_cpu"] diff --git a/src/eva/core/utils/memory.py b/src/eva/core/utils/memory.py new file mode 100644 index 00000000..d820f718 --- /dev/null +++ b/src/eva/core/utils/memory.py @@ -0,0 +1,28 @@ +"""Memory related functions.""" + +import functools +from typing import Any, Dict, List + +import torch + + +@functools.singledispatch +def to_cpu(tensor_type: Any) -> Any: + """Moves tensor objects to `cpu`.""" + raise TypeError(f"Unsupported input type: {type(input)}.") + + +@to_cpu.register +def _(tensor: torch.Tensor) -> torch.Tensor: + detached_tensor = tensor.detach() + return detached_tensor.cpu() + + +@to_cpu.register +def _(tensors: list) -> List[torch.Tensor]: + return list(map(to_cpu, tensors)) + + +@to_cpu.register +def _(tensors: dict) -> Dict[str, torch.Tensor]: + return {key: to_cpu(tensors[key]) for key in tensors} diff --git a/src/eva/vision/__init__.py b/src/eva/vision/__init__.py index 0e4bd4ca..aba88723 100644 --- a/src/eva/vision/__init__.py +++ b/src/eva/vision/__init__.py @@ -1,7 +1,7 @@ """eva vision API.""" try: - from eva.vision import models, utils + from eva.vision import callbacks, models, utils from eva.vision.data import datasets, transforms except ImportError as e: msg = ( @@ -11,4 +11,4 @@ ) raise ImportError(str(e) + "\n\n" + msg) from e -__all__ = ["models", "utils", "datasets", "transforms"] +__all__ = ["callbacks", "models", "utils", "datasets", "transforms"] diff --git a/src/eva/vision/callbacks/__init__.py b/src/eva/vision/callbacks/__init__.py new file mode 100644 index 00000000..d7022760 --- /dev/null +++ b/src/eva/vision/callbacks/__init__.py @@ -0,0 +1,5 @@ +"""Vision callbacks API.""" + +from eva.vision.callbacks.loggers import SemanticSegmentationLogger + +__all__ = ["SemanticSegmentationLogger"] diff --git a/src/eva/vision/callbacks/loggers/__init__.py b/src/eva/vision/callbacks/loggers/__init__.py new file mode 100644 index 00000000..8d5fcf47 --- /dev/null +++ b/src/eva/vision/callbacks/loggers/__init__.py @@ -0,0 +1,5 @@ +"""Vision logging related callbacks API.""" + +from eva.vision.callbacks.loggers.batch import SemanticSegmentationLogger + +__all__ = ["SemanticSegmentationLogger"] diff --git a/src/eva/vision/callbacks/loggers/batch/__init__.py b/src/eva/vision/callbacks/loggers/batch/__init__.py new file mode 100644 index 00000000..d03f1342 --- /dev/null +++ b/src/eva/vision/callbacks/loggers/batch/__init__.py @@ -0,0 +1,5 @@ +"""Batch related loggers callbacks API.""" + +from eva.vision.callbacks.loggers.batch.segmentation import SemanticSegmentationLogger + +__all__ = ["SemanticSegmentationLogger"] diff --git a/src/eva/vision/callbacks/loggers/batch/base.py b/src/eva/vision/callbacks/loggers/batch/base.py new file mode 100644 index 00000000..71f1993b --- /dev/null +++ b/src/eva/vision/callbacks/loggers/batch/base.py @@ -0,0 +1,130 @@ +"""Base batch callback logger.""" + +import abc + +from lightning import pytorch as pl +from lightning.pytorch.utilities.types import STEP_OUTPUT +from typing_extensions import override + +from eva.core.models.modules.typings import INPUT_TENSOR_BATCH + + +class BatchLogger(pl.Callback, abc.ABC): + """Logs training and validation batch assets.""" + + _batch_idx_to_log: int = 0 + """The batch index log.""" + + def __init__( + self, + log_every_n_epochs: int | None = 10, + log_every_n_steps: int | None = None, + ) -> None: + """Initializes the callback object. + + Args: + log_every_n_epochs: Epoch-wise logging frequency. + log_every_n_steps: Step-wise logging frequency. + """ + super().__init__() + + if log_every_n_epochs is None and log_every_n_steps is None: + raise ValueError( + "Please configure the logging frequency though " + "`log_every_n_epochs` or `log_every_n_steps`." + ) + if None not in [log_every_n_epochs, log_every_n_steps]: + raise ValueError( + "Arguments `log_every_n_epochs` and `log_every_n_steps` " + "are mutually exclusive. Please configure one of them." + ) + + self._log_every_n_epochs = log_every_n_epochs + self._log_every_n_steps = log_every_n_steps + + @override + def on_train_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: STEP_OUTPUT, + batch: INPUT_TENSOR_BATCH, + batch_idx: int, + ) -> None: + if self._skip_logging(trainer, batch_idx): + return + + self._log_batch( + trainer=trainer, + batch=batch, + outputs=outputs, + tag="BatchTrain", + ) + + @override + def on_validation_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: STEP_OUTPUT, + batch: INPUT_TENSOR_BATCH, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + if self._skip_logging(trainer, batch_idx): + return + + self._log_batch( + trainer=trainer, + batch=batch, + outputs=outputs, + tag="BatchValidation", + ) + + @abc.abstractmethod + def _log_batch( + self, + trainer: pl.Trainer, + outputs: STEP_OUTPUT, + batch: INPUT_TENSOR_BATCH, + tag: str, + ) -> None: + """Logs the batch data. + + Args: + trainer: The trainer. + outputs: The output of the train / val step. + batch: The data batch. + tag: The log tag. + """ + + def _skip_logging( + self, + trainer: pl.Trainer, + batch_idx: int, + ) -> bool: + """Determines whether skip the logging step or not. + + Args: + trainer: The trainer. + batch_idx: The batch index. + + Returns: + A boolean indicating whether to skip the step execution. + """ + if trainer.global_step in [0, 1]: + return False + + skip_due_frequency = any( + [ + (trainer.current_epoch + 1) % (self._log_every_n_epochs or 1) != 0, + (trainer.global_step + 1) % (self._log_every_n_steps or 1) != 0, + ] + ) + + conditions = [ + skip_due_frequency, + not trainer.is_global_zero, + batch_idx != self._batch_idx_to_log, + ] + return any(conditions) diff --git a/src/eva/vision/callbacks/loggers/batch/segmentation.py b/src/eva/vision/callbacks/loggers/batch/segmentation.py new file mode 100644 index 00000000..4e41ee21 --- /dev/null +++ b/src/eva/vision/callbacks/loggers/batch/segmentation.py @@ -0,0 +1,181 @@ +"""Segmentation datasets related data loggers.""" + +from typing import Iterable, List, Tuple + +import torch +import torchvision +from lightning import pytorch as pl +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torchvision.transforms.v2 import functional +from typing_extensions import override + +from eva.core.loggers import log +from eva.core.models.modules.typings import INPUT_TENSOR_BATCH +from eva.core.utils import to_cpu +from eva.vision.callbacks.loggers.batch import base +from eva.vision.utils import colormap + + +class SemanticSegmentationLogger(base.BatchLogger): + """Log the segmentation batch.""" + + def __init__( + self, + max_samples: int = 10, + number_of_images_per_subgrid_row: int = 2, + mean: Tuple[float, ...] = (0.0, 0.0, 0.0), + std: Tuple[float, ...] = (1.0, 1.0, 1.0), + log_every_n_epochs: int | None = 10, + log_every_n_steps: int | None = None, + ) -> None: + """Initializes the callback object. + + Args: + max_samples: The maximum number of images displayed in the grid. + number_of_images_per_subgrid_row: Number of images displayed in each row + of each sub-grid (that is images, targets and predictions). + mean: The mean of the input images to de-normalize from. + std: The std of the input images to de-normalize from. + log_every_n_epochs: Epoch-wise logging frequency. + log_every_n_steps: Step-wise logging frequency. + """ + super().__init__( + log_every_n_epochs=log_every_n_epochs, + log_every_n_steps=log_every_n_steps, + ) + + self._max_samples = max_samples + self._number_of_images_per_subgrid_row = number_of_images_per_subgrid_row + self._mean = mean + self._std = std + + @override + def _log_batch( + self, + trainer: pl.Trainer, + outputs: STEP_OUTPUT, + batch: INPUT_TENSOR_BATCH, + tag: str, + ) -> None: + predictions = outputs.get("predictions") if isinstance(outputs, dict) else None + if predictions is None: + raise ValueError("Key `predictions` is missing from the output.") + + images, targets, predictions = _subsample_tensors( + tensors_stack=[batch[0], batch[1], predictions], + max_samples=self._max_samples, + ) + images, targets, predictions = to_cpu([images, targets, predictions]) + predictions = torch.argmax(predictions, dim=1) + + images = list(map(self._denormalize_image, images)) + targets = list(map(_draw_semantic_mask, targets)) + predictions = list(map(_draw_semantic_mask, predictions)) + image_grid = _make_grid_from_image_groups( + [images, targets, predictions], self._number_of_images_per_subgrid_row + ) + + log.log_image( + trainer.loggers, + image=image_grid, + tag=tag, + step=trainer.global_step, + ) + + def _denormalize_image(self, image: torch.Tensor) -> torch.Tensor: + """De-normalizes an image tensor to (0., 1.) range.""" + return _denormalize_image(image, mean=self._mean, std=self._std) + + +def _subsample_tensors( + tensors_stack: List[torch.Tensor], + max_samples: int, +) -> List[torch.Tensor]: + """Sub-samples tensors from a list of tensors in-place. + + Args: + tensors_stack: A list of tensors. + max_samples: The maximum number of images + displayed in the grid. + + Returns: + A sub-sample of the input tensors stack. + """ + for i, tensor in enumerate(tensors_stack): + tensors_stack[i] = tensor[:max_samples] + return tensors_stack + + +def _denormalize_image( + tensor: torch.Tensor, + mean: Iterable[float] = (0.0, 0.0, 0.0), + std: Iterable[float] = (1.0, 1.0, 1.0), + inplace: bool = True, +) -> torch.Tensor: + """De-normalizes an image tensor to (0., 1.) range. + + Args: + tensor: An image float tensor. + mean: The normalized channels mean values. + std: The normalized channels std values. + inplace: Whether to perform the operation in-place. + Defaults to `True`. + + Returns: + The de-normalized image tensor of range (0., 1.). + """ + if not inplace: + tensor = tensor.clone() + + return functional.normalize( + tensor, + mean=[-cmean / cstd for cmean, cstd in zip(mean, std, strict=False)], + std=[1 / cstd for cstd in std], + ) + + +def _draw_semantic_mask(tensor: torch.Tensor) -> torch.Tensor: + """Draws a semantic mask to an image RGB tensor. + + The input semantic mask is a (H x W) shaped tensor with + integer values which represent the pixel class id. + + Args: + tensor: An image tensor of range [0., 1.]. + + Returns: + The image as a tensor of range [0., 255.]. + """ + tensor = torch.squeeze(tensor) + height, width = tensor.shape[-2], tensor.shape[-1] + red, green, blue = torch.zeros((3, height, width), dtype=torch.uint8) + for class_id, color in colormap.COLORMAP.items(): + indices = tensor == class_id + red[indices], green[indices], blue[indices] = color + return torch.stack([red, green, blue]) + + +def _make_grid_from_image_groups( + image_groups: List[List[torch.Tensor]], + number_of_images_per_subgrid_row: int = 2, +) -> torch.Tensor: + """Creates a single image grid from image groups. + + For example, it can combine the input images, targets predictions into a single image. + + Args: + image_groups: A list of lists of image tensors of shape (C x H x W) + all of the same size. + number_of_images_per_subgrid_row: Number of images displayed in each + row of the sub-grid. + + Returns: + An image grid as a `torch.Tensor`. + """ + return torchvision.utils.make_grid( + [ + torchvision.utils.make_grid(image_group, nrow=number_of_images_per_subgrid_row) + for image_group in image_groups + ], + nrow=len(image_groups), + ) diff --git a/src/eva/vision/utils/colormap.py b/src/eva/vision/utils/colormap.py new file mode 100644 index 00000000..2ca70604 --- /dev/null +++ b/src/eva/vision/utils/colormap.py @@ -0,0 +1,77 @@ +"""Color mapping constants.""" + +COLORS = [ + (0, 0, 0), + (255, 0, 0), # Red + (0, 255, 0), # Green + (0, 0, 255), # Blue + (255, 255, 0), # Yellow + (255, 0, 255), # Magenta + (0, 255, 255), # Cyan + (128, 128, 0), # Olive + (128, 0, 128), # Purple + (0, 128, 128), # Teal + (192, 192, 192), # Silver + (128, 128, 128), # Gray + (255, 165, 0), # Orange + (210, 105, 30), # Chocolate + (0, 128, 0), # Lime + (255, 192, 203), # Pink + (255, 69, 0), # Red-Orange + (255, 140, 0), # Dark Orange + (0, 255, 255), # Sky Blue + (0, 255, 127), # Spring Green + (0, 0, 139), # Dark Blue + (255, 20, 147), # Deep Pink + (139, 69, 19), # Saddle Brown + (0, 100, 0), # Dark Green + (106, 90, 205), # Slate Blue + (138, 43, 226), # Blue-Violet + (218, 165, 32), # Goldenrod + (199, 21, 133), # Medium Violet Red + (70, 130, 180), # Steel Blue + (165, 42, 42), # Brown + (128, 0, 0), # Maroon + (255, 0, 255), # Fuchsia + (210, 180, 140), # Tan + (0, 0, 128), # Navy + (139, 0, 139), # Dark Magenta + (144, 238, 144), # Light Green + (46, 139, 87), # Sea Green + (255, 255, 0), # Gold + (154, 205, 50), # Yellow Green + (0, 191, 255), # Deep Sky Blue + (0, 250, 154), # Medium Spring Green + (250, 128, 114), # Salmon + (255, 105, 180), # Hot Pink + (204, 255, 204), # Pastel Light Green + (51, 0, 51), # Very Dark Magenta + (255, 102, 0), # Dark Orange + (0, 255, 0), # Bright Green + (51, 153, 255), # Blue-Purple + (51, 51, 255), # Bright Blue + (204, 0, 0), # Dark Red + (90, 90, 90), # Very Dark Gray + (255, 255, 51), # Pastel Yellow + (255, 153, 255), # Pink-Magenta + (153, 0, 76), # Dark Pink + (51, 25, 0), # Very Dark Brown + (102, 51, 0), # Dark Brown + (0, 0, 51), # Very Dark Blue + (180, 180, 180), # Dark Gray + (102, 255, 204), # Pastel Green + (0, 102, 0), # Dark Green + (220, 245, 20), # Lime Yellow + (255, 204, 204), # Pastel Pink + (0, 204, 255), # Pastel Blue + (240, 240, 240), # Light Gray + (153, 153, 0), # Dark Yellow + (102, 0, 51), # Dark Red-Pink + (0, 51, 0), # Very Dark Green + (255, 102, 204), # Magenta Pink + (204, 0, 102), # Red-Pink +] +"""RGB colors.""" + +COLORMAP = dict(enumerate(COLORS)) | {255: (255, 255, 255)} +"""Class id to RGB color mapping.""" From 7656ae63bee213109900b39ef19849d88f4d0138 Mon Sep 17 00:00:00 2001 From: ioangatop Date: Wed, 15 May 2024 08:56:47 +0200 Subject: [PATCH 08/21] Improve the mask loading in `TotalSegmentator2D` (#440) --- .gitignore | 6 ++ .../dino_vit/online/total_segmentator_2d.yaml | 7 +- .../classification/total_segmentator.py | 7 +- .../segmentation/total_segmentator.py | 75 ++++++++++++++++--- src/eva/vision/utils/io/__init__.py | 7 +- src/eva/vision/utils/io/nifti.py | 37 ++++++--- 6 files changed, 109 insertions(+), 30 deletions(-) diff --git a/.gitignore b/.gitignore index 7ca6b17d..da0b5aae 100644 --- a/.gitignore +++ b/.gitignore @@ -173,3 +173,9 @@ cython_debug/ # ignore local data /data/ + +# numpy data +*.npy + +# NiFti data +*.nii.gz diff --git a/configs/vision/dino_vit/online/total_segmentator_2d.yaml b/configs/vision/dino_vit/online/total_segmentator_2d.yaml index f60eb183..07822938 100644 --- a/configs/vision/dino_vit/online/total_segmentator_2d.yaml +++ b/configs/vision/dino_vit/online/total_segmentator_2d.yaml @@ -8,6 +8,8 @@ trainer: - class_path: eva.vision.callbacks.SemanticSegmentationLogger init_args: log_every_n_epochs: 1 + mean: &NORMALIZE_MEAN [0.5, 0.5, 0.5] + std: &NORMALIZE_STD [0.5, 0.5, 0.5] logger: - class_path: lightning.pytorch.loggers.TensorBoardLogger init_args: @@ -72,6 +74,9 @@ data: # (see: https://creativecommons.org/licenses/by/4.0/deed.en) transforms: class_path: eva.vision.data.transforms.common.ResizeAndCrop + init_args: + mean: *NORMALIZE_MEAN + std: *NORMALIZE_STD val: class_path: eva.vision.datasets.TotalSegmentator2D init_args: @@ -83,5 +88,3 @@ data: shuffle: true val: batch_size: *BATCH_SIZE - predict: - batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 16} diff --git a/src/eva/vision/data/datasets/classification/total_segmentator.py b/src/eva/vision/data/datasets/classification/total_segmentator.py index c7c0c88d..58b98b9d 100644 --- a/src/eva/vision/data/datasets/classification/total_segmentator.py +++ b/src/eva/vision/data/datasets/classification/total_segmentator.py @@ -132,7 +132,7 @@ def __len__(self) -> int: def load_image(self, index: int) -> np.ndarray: image_path = self._get_image_path(index) slice_index = self._get_sample_slice_index(index) - image_array = io.read_nifti_slice(image_path, slice_index) + image_array = io.read_nifti(image_path, slice_index) return image_array.repeat(3, axis=2) @override @@ -146,7 +146,7 @@ def _load_masks(self, index: int) -> np.ndarray: masks_dir = self._get_masks_dir(index) slice_index = self._get_sample_slice_index(index) mask_paths = (os.path.join(masks_dir, label + ".nii.gz") for label in self.classes) - masks = [io.read_nifti_slice(path, slice_index) for path in mask_paths] + masks = [io.read_nifti(path, slice_index) for path in mask_paths] return np.concatenate(masks, axis=-1) def _get_masks_dir(self, index: int) -> str: @@ -167,7 +167,8 @@ def _get_sample_dir(self, index: int) -> str: def _get_sample_slice_index(self, index: int) -> int: """Returns the corresponding slice index.""" image_path = self._get_image_path(index) - total_slices = io.fetch_total_nifti_slices(image_path) + image_shape = io.fetch_nifti_shape(image_path) + total_slices = image_shape[-1] slice_indices = np.linspace(0, total_slices - 1, num=self._n_slices_per_image, dtype=int) return slice_indices[index % self._n_slices_per_image] diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator.py b/src/eva/vision/data/datasets/segmentation/total_segmentator.py index b427583d..c66edbef 100644 --- a/src/eva/vision/data/datasets/segmentation/total_segmentator.py +++ b/src/eva/vision/data/datasets/segmentation/total_segmentator.py @@ -3,9 +3,11 @@ import functools import os from glob import glob -from typing import Callable, Dict, List, Literal, Tuple +from typing import Any, Callable, Dict, List, Literal, Tuple import numpy as np +import numpy.typing as npt +import tqdm from torchvision import tv_tensors from torchvision.datasets import utils from typing_extensions import override @@ -53,6 +55,7 @@ def __init__( download: bool = False, as_uint8: bool = True, classes: List[str] | None = None, + optimize_mask_loading: bool = True, transforms: Callable | None = None, ) -> None: """Initialize dataset. @@ -70,6 +73,10 @@ def __init__( as_uint8: Whether to convert and return the images as a 8-bit. classes: Whether to configure the dataset with a subset of classes. If `None`, it will use all of them. + optimize_mask_loading: Whether to pre-process the segmentation masks + in order to optimize the loading time. In the `setup` method, it + will reformat the binary one-hot masks to a semantic mask and store + it on disk. transforms: A function/transforms that takes in an image and a target mask and returns the transformed versions of both. """ @@ -81,6 +88,12 @@ def __init__( self._download = download self._as_uint8 = as_uint8 self._classes = classes + self._optimize_mask_loading = optimize_mask_loading + + if self._optimize_mask_loading and self._classes is not None: + raise ValueError( + "To use customize classes please set the optimize_mask_loading to `False`." + ) self._samples_dirs: List[str] = [] self._indices: List[Tuple[int, int]] = [] @@ -123,6 +136,8 @@ def prepare_data(self) -> None: def configure(self) -> None: self._samples_dirs = self._fetch_samples_dirs() self._indices = self._create_indices() + if self._optimize_mask_loading: + self._export_semantic_label_masks() @override def validate(self) -> None: @@ -148,7 +163,7 @@ def __len__(self) -> int: def load_image(self, index: int) -> tv_tensors.Image: sample_index, slice_index = self._indices[index] image_path = self._get_image_path(sample_index) - image_array = io.read_nifti_slice(image_path, slice_index) + image_array = io.read_nifti(image_path, slice_index) if self._as_uint8: image_array = convert.to_8bit(image_array) image_rgb_array = image_array.repeat(3, axis=2) @@ -156,15 +171,54 @@ def load_image(self, index: int) -> tv_tensors.Image: @override def load_mask(self, index: int) -> tv_tensors.Mask: + if self._optimize_mask_loading: + return self._load_semantic_label_mask(index) + return self._load_mask(index) + + def _load_mask(self, index: int) -> tv_tensors.Mask: + """Loads and builds the segmentation mask from NifTi files.""" sample_index, slice_index = self._indices[index] + semantic_labels = self._load_masks_as_semantic_label(sample_index, slice_index) + return tv_tensors.Mask(semantic_labels) + + def _load_semantic_label_mask(self, index: int) -> tv_tensors.Mask: + """Loads the segmentation mask from a semantic label NifTi file.""" + sample_index, slice_index = self._indices[index] + masks_dir = self._get_masks_dir(sample_index) + filename = os.path.join(masks_dir, "semantic_labels", "masks.nii.gz") + semantic_labels = io.read_nifti(filename, slice_index) + return tv_tensors.Mask(semantic_labels.squeeze()) + + def _load_masks_as_semantic_label( + self, sample_index: int, slice_index: int | None = None + ) -> npt.NDArray[Any]: + """Loads binary masks as a semantic label mask. + + Args: + sample_index: The data sample index. + slice_index: Whether to return only a specific slice. + """ masks_dir = self._get_masks_dir(sample_index) - mask_paths = (os.path.join(masks_dir, label + ".nii.gz") for label in self.classes) - binary_masks = [io.read_nifti_slice(path, slice_index) for path in mask_paths] - one_hot_encoded = np.concatenate(binary_masks, axis=2) - background_mask = one_hot_encoded.sum(axis=2, keepdims=True) == 0 - one_hot_encoded_with_bg = np.concatenate([background_mask, one_hot_encoded], axis=2) - segmentation_label = np.argmax(one_hot_encoded_with_bg, axis=2) - return tv_tensors.Mask(segmentation_label) + mask_paths = [os.path.join(masks_dir, label + ".nii.gz") for label in self.classes] + binary_masks = [io.read_nifti(path, slice_index) for path in mask_paths] + background_mask = np.zeros_like(binary_masks[0]) + return np.argmax([background_mask] + binary_masks, axis=0) + + def _export_semantic_label_masks(self) -> None: + """Exports the segmentation binary masks (one-hot) to semantic labels.""" + total_samples = len(self._samples_dirs) + for sample_index in tqdm.trange( + total_samples, desc=">> Exporting optimized semantic masks" + ): + masks_dir = self._get_masks_dir(sample_index) + filename = os.path.join(masks_dir, "semantic_labels", "masks.nii.gz") + if os.path.isfile(filename): + continue + + semantic_labels = self._load_masks_as_semantic_label(sample_index) + + os.makedirs(os.path.dirname(filename), exist_ok=True) + io.save_array_as_nifti(semantic_labels.astype(np.uint8), filename) def _get_image_path(self, sample_index: int) -> str: """Returns the corresponding image path.""" @@ -179,7 +233,8 @@ def _get_masks_dir(self, sample_index: int) -> str: def _get_number_of_slices_per_sample(self, sample_index: int) -> int: """Returns the total amount of slices of a sample.""" image_path = self._get_image_path(sample_index) - return io.fetch_total_nifti_slices(image_path) + image_shape = io.fetch_nifti_shape(image_path) + return image_shape[-1] def _fetch_samples_dirs(self) -> List[str]: """Returns the name of all the samples of all the splits of the dataset.""" diff --git a/src/eva/vision/utils/io/__init__.py b/src/eva/vision/utils/io/__init__.py index 85d669b1..7d2fbe53 100644 --- a/src/eva/vision/utils/io/__init__.py +++ b/src/eva/vision/utils/io/__init__.py @@ -1,12 +1,13 @@ """Vision I/O utilities.""" from eva.vision.utils.io.image import read_image -from eva.vision.utils.io.nifti import fetch_total_nifti_slices, read_nifti_slice +from eva.vision.utils.io.nifti import fetch_nifti_shape, read_nifti, save_array_as_nifti from eva.vision.utils.io.text import read_csv __all__ = [ "read_image", - "fetch_total_nifti_slices", - "read_nifti_slice", + "fetch_nifti_shape", + "read_nifti", + "save_array_as_nifti", "read_csv", ] diff --git a/src/eva/vision/utils/io/nifti.py b/src/eva/vision/utils/io/nifti.py index 162f0a64..ee7b383b 100644 --- a/src/eva/vision/utils/io/nifti.py +++ b/src/eva/vision/utils/io/nifti.py @@ -1,21 +1,22 @@ """NIfTI I/O related functions.""" -from typing import Any +from typing import Any, Tuple import nibabel as nib +import numpy as np import numpy.typing as npt from eva.vision.utils.io import _utils -def read_nifti_slice( - path: str, slice_index: int, *, use_storage_dtype: bool = True +def read_nifti( + path: str, slice_index: int | None = None, *, use_storage_dtype: bool = True ) -> npt.NDArray[Any]: - """Reads and loads a NIfTI image from a file path as `uint8`. + """Reads and loads a NIfTI image from a file path. Args: path: The path to the NIfTI file. - slice_index: The image slice index to return. + slice_index: Whether to read only a slice from the file. use_storage_dtype: Whether to cast the raw image array to the inferred type. @@ -28,21 +29,34 @@ def read_nifti_slice( """ _utils.check_file(path) image_data = nib.load(path) # type: ignore - image_slice = image_data.slicer[:, :, slice_index : slice_index + 1] # type: ignore - image_array = image_slice.get_fdata() + if slice_index is not None: + image_data = image_data.slicer[:, :, slice_index : slice_index + 1] # type: ignore + image_array = image_data.get_fdata() # type: ignore if use_storage_dtype: image_array = image_array.astype(image_data.get_data_dtype()) # type: ignore return image_array -def fetch_total_nifti_slices(path: str) -> int: - """Fetches the total slides of a NIfTI image file. +def save_array_as_nifti(array: npt.ArrayLike, filename: str) -> None: + """Saved a numpy array as a NIfTI image file. + + Args: + array: The image array to save. + filename: The name to save the image like. + """ + nifti_image = nib.Nifti1Image(array, affine=np.eye(4)) # type: ignore + nifti_image.header.get_xyzt_units() + nifti_image.to_filename(filename) + + +def fetch_nifti_shape(path: str) -> Tuple[int]: + """Fetches the NIfTI image shape from a file. Args: path: The path to the NIfTI file. Returns: - The number of the total available slides. + The image shape. Raises: FileExistsError: If the path does not exist or it is unreachable. @@ -50,5 +64,4 @@ def fetch_total_nifti_slices(path: str) -> int: """ _utils.check_file(path) image = nib.load(path) # type: ignore - image_shape = image.header.get_data_shape() # type: ignore - return image_shape[-1] + return image.header.get_data_shape() # type: ignore From 702f634f628a6ae26f497797e9c852cb2bb49a48 Mon Sep 17 00:00:00 2001 From: ioangatop Date: Thu, 16 May 2024 09:54:46 +0200 Subject: [PATCH 09/21] Add per class metrics dice score in `TotalSegmentator2D` (#447) --- .../dino_vit/online/total_segmentator_2d.yaml | 15 ++++- src/eva/core/metrics/defaults/__init__.py | 13 +++- .../defaults/classification/__init__.py | 2 +- .../metrics/defaults/classification/binary.py | 9 --- .../defaults/classification/multiclass.py | 8 --- .../metrics/defaults/segmentation/__init__.py | 5 ++ .../defaults/segmentation/multiclass.py | 64 +++++++++++++++++++ src/eva/core/metrics/structs/schemas.py | 4 +- src/eva/core/metrics/wrappers/__init__.py | 5 ++ src/eva/core/metrics/wrappers/classwise.py | 24 +++++++ src/eva/core/models/modules/module.py | 26 +++++++- .../metrics/defaults/segmentation/__init__.py | 1 + .../defaults/segmentation/test_multiclass.py | 54 ++++++++++++++++ 13 files changed, 205 insertions(+), 25 deletions(-) create mode 100644 src/eva/core/metrics/defaults/segmentation/__init__.py create mode 100644 src/eva/core/metrics/defaults/segmentation/multiclass.py create mode 100644 src/eva/core/metrics/wrappers/__init__.py create mode 100644 src/eva/core/metrics/wrappers/classwise.py create mode 100644 tests/eva/core/metrics/defaults/segmentation/__init__.py create mode 100644 tests/eva/core/metrics/defaults/segmentation/test_multiclass.py diff --git a/configs/vision/dino_vit/online/total_segmentator_2d.yaml b/configs/vision/dino_vit/online/total_segmentator_2d.yaml index 07822938..2e9e6ce3 100644 --- a/configs/vision/dino_vit/online/total_segmentator_2d.yaml +++ b/configs/vision/dino_vit/online/total_segmentator_2d.yaml @@ -3,7 +3,7 @@ trainer: init_args: n_runs: &N_RUNS ${oc.env:N_RUNS, 1} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/total_segmentator_2d/${oc.env:TIMM_MODEL_NAME, vit_small_patch16_224}} - max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 20000} callbacks: - class_path: eva.vision.callbacks.SemanticSegmentationLogger init_args: @@ -55,9 +55,20 @@ model: metrics: common: - class_path: eva.metrics.AverageLoss - - class_path: torchmetrics.Dice + - class_path: torchmetrics.classification.MulticlassF1Score init_args: num_classes: *NUM_CLASSES + evaluation: + - class_path: eva.core.metrics.defaults.MulticlassSegmentationMetrics + init_args: + num_classes: *NUM_CLASSES + - class_path: eva.core.metrics.wrappers.ClasswiseWrapper + init_args: + metric: + class_path: torchmetrics.classification.MulticlassF1Score + init_args: + num_classes: *NUM_CLASSES + average: null data: class_path: eva.DataModule init_args: diff --git a/src/eva/core/metrics/defaults/__init__.py b/src/eva/core/metrics/defaults/__init__.py index a75d8286..84120acf 100644 --- a/src/eva/core/metrics/defaults/__init__.py +++ b/src/eva/core/metrics/defaults/__init__.py @@ -1,6 +1,13 @@ """Default metric collections API.""" -from eva.core.metrics.defaults.classification.binary import BinaryClassificationMetrics -from eva.core.metrics.defaults.classification.multiclass import MulticlassClassificationMetrics +from eva.core.metrics.defaults.classification import ( + BinaryClassificationMetrics, + MulticlassClassificationMetrics, +) +from eva.core.metrics.defaults.segmentation import MulticlassSegmentationMetrics -__all__ = ["MulticlassClassificationMetrics", "BinaryClassificationMetrics"] +__all__ = [ + "MulticlassClassificationMetrics", + "BinaryClassificationMetrics", + "MulticlassSegmentationMetrics", +] diff --git a/src/eva/core/metrics/defaults/classification/__init__.py b/src/eva/core/metrics/defaults/classification/__init__.py index 638d43ab..d30fbee9 100644 --- a/src/eva/core/metrics/defaults/classification/__init__.py +++ b/src/eva/core/metrics/defaults/classification/__init__.py @@ -3,4 +3,4 @@ from eva.core.metrics.defaults.classification.binary import BinaryClassificationMetrics from eva.core.metrics.defaults.classification.multiclass import MulticlassClassificationMetrics -__all__ = ["MulticlassClassificationMetrics", "BinaryClassificationMetrics"] +__all__ = ["BinaryClassificationMetrics", "MulticlassClassificationMetrics"] diff --git a/src/eva/core/metrics/defaults/classification/binary.py b/src/eva/core/metrics/defaults/classification/binary.py index e13d55e3..ad6d45a6 100644 --- a/src/eva/core/metrics/defaults/classification/binary.py +++ b/src/eva/core/metrics/defaults/classification/binary.py @@ -17,15 +17,6 @@ def __init__( ) -> None: """Initializes the binary classification metrics. - The metrics instantiated here are: - - - BinaryAUROC - - BinaryAccuracy - - BinaryBalancedAccuracy - - BinaryF1Score - - BinaryPrecision - - BinaryRecall - Args: threshold: Threshold for transforming probability to binary (0,1) predictions ignore_index: Specifies a target value that is ignored and does not diff --git a/src/eva/core/metrics/defaults/classification/multiclass.py b/src/eva/core/metrics/defaults/classification/multiclass.py index fba93835..09cd54f4 100644 --- a/src/eva/core/metrics/defaults/classification/multiclass.py +++ b/src/eva/core/metrics/defaults/classification/multiclass.py @@ -20,14 +20,6 @@ def __init__( ) -> None: """Initializes the multi-class classification metrics. - The metrics instantiated here are: - - - MulticlassAccuracy - - MulticlassPrecision - - MulticlassRecall - - MulticlassF1Score - - MulticlassAUROC - Args: num_classes: Integer specifying the number of classes. average: Defines the reduction that is applied over labels. diff --git a/src/eva/core/metrics/defaults/segmentation/__init__.py b/src/eva/core/metrics/defaults/segmentation/__init__.py new file mode 100644 index 00000000..31c397dd --- /dev/null +++ b/src/eva/core/metrics/defaults/segmentation/__init__.py @@ -0,0 +1,5 @@ +"""Default segmentation metric collections API.""" + +from eva.core.metrics.defaults.segmentation.multiclass import MulticlassSegmentationMetrics + +__all__ = ["MulticlassSegmentationMetrics"] diff --git a/src/eva/core/metrics/defaults/segmentation/multiclass.py b/src/eva/core/metrics/defaults/segmentation/multiclass.py new file mode 100644 index 00000000..9433d0df --- /dev/null +++ b/src/eva/core/metrics/defaults/segmentation/multiclass.py @@ -0,0 +1,64 @@ +"""Default metric collection for multiclass semantic segmentation tasks.""" + +from typing import Literal + +from torchmetrics import classification + +from eva.core.metrics import structs + + +class MulticlassSegmentationMetrics(structs.MetricCollection): + """Default metrics for multi-class semantic segmentation tasks.""" + + def __init__( + self, + num_classes: int, + average: Literal["macro", "weighted", "none"] = "macro", + ignore_index: int | None = None, + prefix: str | None = None, + postfix: str | None = None, + ) -> None: + """Initializes the multi-class semantic segmentation metrics. + + Args: + num_classes: Integer specifying the number of classes. + average: Defines the reduction that is applied over labels. + ignore_index: Specifies a target value that is ignored and + does not contribute to the metric calculation. + prefix: A string to add before the keys in the output dictionary. + postfix: A string to add after the keys in the output dictionary. + """ + super().__init__( + metrics=[ + classification.MulticlassJaccardIndex( + num_classes=num_classes, + average=average, + ignore_index=ignore_index, + ), + classification.MulticlassPrecision( + num_classes=num_classes, + average=average, + ignore_index=ignore_index, + ), + classification.MulticlassRecall( + num_classes=num_classes, + average=average, + ignore_index=ignore_index, + ), + classification.MulticlassF1Score( + num_classes=num_classes, + average=average, + ignore_index=ignore_index, + ), + ], + prefix=prefix, + postfix=postfix, + compute_groups=[ + [ + "MulticlassJaccardIndex", + "MulticlassPrecision", + "MulticlassRecall", + "MulticlassF1Score", + ], + ], + ) diff --git a/src/eva/core/metrics/structs/schemas.py b/src/eva/core/metrics/structs/schemas.py index ef8b7fb4..a258fa45 100644 --- a/src/eva/core/metrics/structs/schemas.py +++ b/src/eva/core/metrics/structs/schemas.py @@ -44,4 +44,6 @@ def _join_with_common(self, metrics: MetricModuleType | None) -> MetricModuleTyp if metrics is None or self.common is None: return self.common or metrics - return [self.common, metrics] # type: ignore + metrics = metrics if isinstance(metrics, list) else [metrics] # type: ignore + common = self.common if isinstance(self.common, list) else [self.common] + return common + metrics # type: ignore diff --git a/src/eva/core/metrics/wrappers/__init__.py b/src/eva/core/metrics/wrappers/__init__.py new file mode 100644 index 00000000..4c4b77a3 --- /dev/null +++ b/src/eva/core/metrics/wrappers/__init__.py @@ -0,0 +1,5 @@ +"""Metric wrappers API.""" + +from eva.core.metrics.wrappers.classwise import ClasswiseWrapper + +__all__ = ["ClasswiseWrapper"] diff --git a/src/eva/core/metrics/wrappers/classwise.py b/src/eva/core/metrics/wrappers/classwise.py new file mode 100644 index 00000000..006cfd2f --- /dev/null +++ b/src/eva/core/metrics/wrappers/classwise.py @@ -0,0 +1,24 @@ +"""Wrapper metric to retrieve classwise metrics from metrics.""" + +from typing import Any + +from torchmetrics import wrappers +from typing_extensions import override + + +class ClasswiseWrapper(wrappers.ClasswiseWrapper): + """Wrapper metric for altering the output of classification metrics. + + It adds kwargs filtering during the update step for easy integration + with `MetricCollection`. + """ + + @override + def forward(self, *args: Any, **kwargs: Any) -> Any: + metric_kwargs = self.metric._filter_kwargs(**kwargs) + return self._convert(self.metric(*args, **metric_kwargs)) + + @override + def update(self, *args: Any, **kwargs: Any) -> None: + metric_kwargs = self.metric._filter_kwargs(**kwargs) + self.metric.update(*args, **metric_kwargs) diff --git a/src/eva/core/models/modules/module.py b/src/eva/core/models/modules/module.py index cb5e222a..d1e2ab64 100644 --- a/src/eva/core/models/modules/module.py +++ b/src/eva/core/models/modules/module.py @@ -4,6 +4,7 @@ import lightning.pytorch as pl import torch +from lightning.pytorch.strategies.single_device import SingleDeviceStrategy from lightning.pytorch.utilities import memory from lightning.pytorch.utilities.types import STEP_OUTPUT from typing_extensions import override @@ -46,6 +47,21 @@ def default_postprocess(self) -> batch_postprocess.BatchPostProcess: """The default post-processes.""" return batch_postprocess.BatchPostProcess() + @property + def metrics_device(self) -> torch.device: + """Returns the device by which the metrics should be calculated. + + We allocate the metrics to CPU when operating on single device, as + it is much faster, but to GPU when employing multiple ones, as DDP + strategy requires the metrics to be allocated to the module's GPU. + """ + move_to_cpu = isinstance(self.trainer.strategy, SingleDeviceStrategy) + return torch.device("cpu") if move_to_cpu else self.device + + @override + def on_fit_start(self) -> None: + self.metrics.to(device=self.metrics_device) + @override def on_train_batch_end( self, @@ -59,6 +75,10 @@ def on_train_batch_end( batch_outputs=outputs, ) + @override + def on_validation_start(self) -> None: + self.metrics.to(device=self.metrics_device) + @override def on_validation_batch_end( self, @@ -78,6 +98,10 @@ def on_validation_batch_end( def on_validation_epoch_end(self) -> None: self._compute_and_log_metrics(self.metrics.validation_metrics) + @override + def on_test_start(self) -> None: + self.metrics.to(device=self.metrics_device) + @override def on_test_batch_end( self, @@ -110,7 +134,7 @@ def _common_batch_end(self, outputs: STEP_OUTPUT) -> STEP_OUTPUT: The updated outputs. """ self._postprocess(outputs) - return memory.recursive_detach(outputs, to_cpu=self.device.type == "cpu") + return memory.recursive_detach(outputs, to_cpu=self.metrics_device.type == "cpu") def _forward_and_log_metrics( self, diff --git a/tests/eva/core/metrics/defaults/segmentation/__init__.py b/tests/eva/core/metrics/defaults/segmentation/__init__.py new file mode 100644 index 00000000..358c4946 --- /dev/null +++ b/tests/eva/core/metrics/defaults/segmentation/__init__.py @@ -0,0 +1 @@ +"""Tests default metric groups for segmentation tasks.""" diff --git a/tests/eva/core/metrics/defaults/segmentation/test_multiclass.py b/tests/eva/core/metrics/defaults/segmentation/test_multiclass.py new file mode 100644 index 00000000..dbb4629d --- /dev/null +++ b/tests/eva/core/metrics/defaults/segmentation/test_multiclass.py @@ -0,0 +1,54 @@ +"""MulticlassSegmentationMetrics metric tests.""" + +import pytest +import torch + +from eva.core.metrics import defaults + +NUM_CLASSES_ONE = 5 +PREDS_ONE = torch.tensor( + [ + [0.70, 0.05, 0.05, 0.05, 0.05], + [0.05, 0.70, 0.05, 0.05, 0.05], + [0.05, 0.05, 0.70, 0.05, 0.05], + [0.05, 0.05, 0.05, 0.70, 0.05], + ] +) +TARGET_ONE = torch.tensor([0, 1, 3, 2]) +EXPECTED_ONE = { + "MulticlassJaccardIndex": torch.tensor(0.5), + "MulticlassPrecision": torch.tensor(0.0), + "MulticlassRecall": torch.tensor(0.0), + "MulticlassF1Score": torch.tensor(0.0), +} +"""Test features.""" + + +@pytest.mark.parametrize( + "num_classes, preds, target, expected", + [ + (NUM_CLASSES_ONE, PREDS_ONE, TARGET_ONE, EXPECTED_ONE), + ], +) +def test_multiclass_segmentation_metrics( + multiclass_segmentation_metrics: defaults.MulticlassSegmentationMetrics, + preds: torch.Tensor, + target: torch.Tensor, + expected: torch.Tensor, +) -> None: + """Tests the multiclass_segmentation_metrics metric.""" + + def _calculate_metric() -> None: + multiclass_segmentation_metrics.update(preds=preds, target=target) # type: ignore + actual = multiclass_segmentation_metrics.compute() + torch.testing.assert_close(actual, expected, rtol=1e-04, atol=1e-04) + + _calculate_metric() + multiclass_segmentation_metrics.reset() + _calculate_metric() + + +@pytest.fixture(scope="function") +def multiclass_segmentation_metrics(num_classes: int) -> defaults.MulticlassSegmentationMetrics: + """MulticlassSegmentationMetrics fixture.""" + return defaults.MulticlassSegmentationMetrics(num_classes=num_classes) From 434cacb9133c4e37faaec69991ca4c69ba09cb67 Mon Sep 17 00:00:00 2001 From: ioangatop Date: Thu, 16 May 2024 14:44:56 +0200 Subject: [PATCH 10/21] Support `int16` training on `TotalSegementator2D` (#443) --- .../callbacks/loggers/batch/segmentation.py | 41 ++--------- .../segmentation/total_segmentator.py | 7 +- src/eva/vision/utils/convert.py | 69 +++++++++++++++---- src/eva/vision/utils/io/nifti.py | 2 + tests/eva/vision/utils/test_convert.py | 33 ++++----- 5 files changed, 79 insertions(+), 73 deletions(-) diff --git a/src/eva/vision/callbacks/loggers/batch/segmentation.py b/src/eva/vision/callbacks/loggers/batch/segmentation.py index 4e41ee21..47db5cef 100644 --- a/src/eva/vision/callbacks/loggers/batch/segmentation.py +++ b/src/eva/vision/callbacks/loggers/batch/segmentation.py @@ -1,19 +1,18 @@ """Segmentation datasets related data loggers.""" -from typing import Iterable, List, Tuple +from typing import List, Tuple import torch import torchvision from lightning import pytorch as pl from lightning.pytorch.utilities.types import STEP_OUTPUT -from torchvision.transforms.v2 import functional from typing_extensions import override from eva.core.loggers import log from eva.core.models.modules.typings import INPUT_TENSOR_BATCH from eva.core.utils import to_cpu from eva.vision.callbacks.loggers.batch import base -from eva.vision.utils import colormap +from eva.vision.utils import colormap, convert class SemanticSegmentationLogger(base.BatchLogger): @@ -68,7 +67,7 @@ def _log_batch( images, targets, predictions = to_cpu([images, targets, predictions]) predictions = torch.argmax(predictions, dim=1) - images = list(map(self._denormalize_image, images)) + images = list(map(self._format_image, images)) targets = list(map(_draw_semantic_mask, targets)) predictions = list(map(_draw_semantic_mask, predictions)) image_grid = _make_grid_from_image_groups( @@ -82,9 +81,9 @@ def _log_batch( step=trainer.global_step, ) - def _denormalize_image(self, image: torch.Tensor) -> torch.Tensor: - """De-normalizes an image tensor to (0., 1.) range.""" - return _denormalize_image(image, mean=self._mean, std=self._std) + def _format_image(self, image: torch.Tensor) -> torch.Tensor: + """Descaled an image tensor to (0, 255) uint8 tensor.""" + return convert.descale_and_denorm_image(image, mean=self._mean, std=self._std) def _subsample_tensors( @@ -106,34 +105,6 @@ def _subsample_tensors( return tensors_stack -def _denormalize_image( - tensor: torch.Tensor, - mean: Iterable[float] = (0.0, 0.0, 0.0), - std: Iterable[float] = (1.0, 1.0, 1.0), - inplace: bool = True, -) -> torch.Tensor: - """De-normalizes an image tensor to (0., 1.) range. - - Args: - tensor: An image float tensor. - mean: The normalized channels mean values. - std: The normalized channels std values. - inplace: Whether to perform the operation in-place. - Defaults to `True`. - - Returns: - The de-normalized image tensor of range (0., 1.). - """ - if not inplace: - tensor = tensor.clone() - - return functional.normalize( - tensor, - mean=[-cmean / cstd for cmean, cstd in zip(mean, std, strict=False)], - std=[1 / cstd for cstd in std], - ) - - def _draw_semantic_mask(tensor: torch.Tensor) -> torch.Tensor: """Draws a semantic mask to an image RGB tensor. diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator.py b/src/eva/vision/data/datasets/segmentation/total_segmentator.py index c66edbef..109a10fb 100644 --- a/src/eva/vision/data/datasets/segmentation/total_segmentator.py +++ b/src/eva/vision/data/datasets/segmentation/total_segmentator.py @@ -14,7 +14,7 @@ from eva.vision.data.datasets import _utils, _validators, structs from eva.vision.data.datasets.segmentation import base -from eva.vision.utils import convert, io +from eva.vision.utils import io class TotalSegmentator2D(base.ImageSegmentation): @@ -53,7 +53,6 @@ def __init__( split: Literal["train", "val"] | None, version: Literal["small", "full"] | None = "small", download: bool = False, - as_uint8: bool = True, classes: List[str] | None = None, optimize_mask_loading: bool = True, transforms: Callable | None = None, @@ -70,7 +69,6 @@ def __init__( Note that the download will be executed only by additionally calling the :meth:`prepare_data` method and if the data does not exist yet on disk. - as_uint8: Whether to convert and return the images as a 8-bit. classes: Whether to configure the dataset with a subset of classes. If `None`, it will use all of them. optimize_mask_loading: Whether to pre-process the segmentation masks @@ -86,7 +84,6 @@ def __init__( self._split = split self._version = version self._download = download - self._as_uint8 = as_uint8 self._classes = classes self._optimize_mask_loading = optimize_mask_loading @@ -164,8 +161,6 @@ def load_image(self, index: int) -> tv_tensors.Image: sample_index, slice_index = self._indices[index] image_path = self._get_image_path(sample_index) image_array = io.read_nifti(image_path, slice_index) - if self._as_uint8: - image_array = convert.to_8bit(image_array) image_rgb_array = image_array.repeat(3, axis=2) return tv_tensors.Image(image_rgb_array.transpose(2, 0, 1)) diff --git a/src/eva/vision/utils/convert.py b/src/eva/vision/utils/convert.py index cfe22f55..f013a5d7 100644 --- a/src/eva/vision/utils/convert.py +++ b/src/eva/vision/utils/convert.py @@ -1,24 +1,67 @@ """Image conversion related functionalities.""" -from typing import Any +from typing import Iterable -import numpy as np -import numpy.typing as npt +import torch +from torchvision.transforms.v2 import functional -def to_8bit(image_array: npt.NDArray[Any]) -> npt.NDArray[np.uint8]: - """Casts an image of higher bit image (i.e. 16bit) to 8bit. +def descale_and_denorm_image( + image: torch.Tensor, + mean: Iterable[float] = (0.0, 0.0, 0.0), + std: Iterable[float] = (1.0, 1.0, 1.0), + inplace: bool = True, +) -> torch.Tensor: + """De-scales and de-norms an image tensor to (0, 255) range. Args: - image_array: The image array to convert. + image: An image float tensor. + mean: The mean that the image channels are normalized with. + std: The std that the image channels are normalized with. + inplace: Whether to perform the operation in-place. Returns: - The image as normalized as a 8-bit format. + The image tensor of range (0, 255) range as uint8. """ - if np.issubdtype(image_array.dtype, np.integer): - image_array = image_array.astype(np.float64) + if not inplace: + image = image.clone() - image_scaled_array = image_array - image_array.min() - image_scaled_array /= image_scaled_array.max() - image_scaled_array *= 255 - return image_scaled_array.astype(np.uint8) + norm_image = _descale_image(image, mean=mean, std=std) + return _denorm_image(norm_image) + + +def _descale_image( + image: torch.Tensor, + mean: Iterable[float] = (0.0, 0.0, 0.0), + std: Iterable[float] = (1.0, 1.0, 1.0), +) -> torch.Tensor: + """De-scales an image tensor to (0., 1.) range. + + Args: + image: An image float tensor. + mean: The normalized channels mean values. + std: The normalized channels std values. + + Returns: + The de-normalized image tensor of range (0., 1.). + """ + return functional.normalize( + image, + mean=[-cmean / cstd for cmean, cstd in zip(mean, std, strict=False)], + std=[1 / cstd for cstd in std], + ) + + +def _denorm_image(image: torch.Tensor) -> torch.Tensor: + """De-normalizes an image tensor from (0., 1.) to (0, 255) range. + + Args: + image: An image float tensor. + + Returns: + The image tensor of range (0, 255) range as uint8. + """ + image_scaled = image - image.min() + image_scaled /= image_scaled.max() + image_scaled *= 255 + return image_scaled.to(dtype=torch.uint8) diff --git a/src/eva/vision/utils/io/nifti.py b/src/eva/vision/utils/io/nifti.py index ee7b383b..8ceaba07 100644 --- a/src/eva/vision/utils/io/nifti.py +++ b/src/eva/vision/utils/io/nifti.py @@ -31,9 +31,11 @@ def read_nifti( image_data = nib.load(path) # type: ignore if slice_index is not None: image_data = image_data.slicer[:, :, slice_index : slice_index + 1] # type: ignore + image_array = image_data.get_fdata() # type: ignore if use_storage_dtype: image_array = image_array.astype(image_data.get_data_dtype()) # type: ignore + return image_array diff --git a/tests/eva/vision/utils/test_convert.py b/tests/eva/vision/utils/test_convert.py index a9eb05eb..4692b8a8 100644 --- a/tests/eva/vision/utils/test_convert.py +++ b/tests/eva/vision/utils/test_convert.py @@ -1,14 +1,11 @@ """Tests image conversion related functionalities.""" -from typing import Any - -import numpy as np -import numpy.typing as npt +import torch from pytest import mark from eva.vision.utils import convert -IMAGE_ARRAY_INT16 = np.array( +IMAGE_ONE = torch.Tensor( [ [ [-794, -339, -607, -950], @@ -17,34 +14,32 @@ [-790, -325, -564, -969], ] ], - dtype=np.int16, -) +).to(dtype=torch.float16) """Test input data.""" -EXPECTED_ARRAY_INT16 = np.array( +EXPECTED_ONE = torch.Tensor( [ [ [33, 120, 69, 3], - [68, 169, 217, 25], + [69, 169, 217, 25], [74, 183, 255, 25], [34, 123, 77, 0], ] ], - dtype=np.uint8, -) +).to(dtype=torch.uint8) """Test expected/desired features.""" @mark.parametrize( - "image_array, expected", + "image, expected", [ - [IMAGE_ARRAY_INT16, EXPECTED_ARRAY_INT16], + [IMAGE_ONE, EXPECTED_ONE], ], ) -def test_to_8bit( - image_array: npt.NDArray[Any], - expected: npt.NDArray[np.uint8], +def test_descale_and_denorm_image( + image: torch.Tensor, + expected: torch.Tensor, ) -> None: - """Tests the `to_8bit` image conversion.""" - actual = convert.to_8bit(image_array) - np.testing.assert_allclose(actual, expected) + """Tests the `descale_and_denorm_image` image conversion.""" + actual = convert.descale_and_denorm_image(image, mean=(0.0,), std=(1.0,)) + torch.testing.assert_close(actual, expected) From 35bb532040412a02cb237b5224bb9b64293438bc Mon Sep 17 00:00:00 2001 From: ioangatop Date: Wed, 22 May 2024 11:50:53 +0200 Subject: [PATCH 11/21] Normalisations and transforms for `int16` image types (#457) --- .../dino_vit/online/total_segmentator_2d.yaml | 9 +--- .../vision/callbacks/loggers/batch/base.py | 8 +-- .../callbacks/loggers/batch/segmentation.py | 2 +- .../segmentation/total_segmentator.py | 7 +-- src/eva/vision/data/transforms/__init__.py | 5 +- .../vision/data/transforms/common/__init__.py | 3 +- .../transforms/common/resize_and_clamp.py | 53 +++++++++++++++++++ .../data/transforms/common/resize_and_crop.py | 9 +--- .../data/transforms/normalization/__init__.py | 6 +++ .../data/transforms/normalization/clamp.py | 43 +++++++++++++++ .../normalization/functional/__init__.py | 5 ++ .../functional/rescale_intensity.py | 28 ++++++++++ .../normalization/rescale_intensity.py | 53 +++++++++++++++++++ src/eva/vision/utils/io/nifti.py | 10 +++- .../common/test_resize_and_clamp.py | 47 ++++++++++++++++ .../data/transforms/normalization/__init__.py | 1 + .../normalization/functional/__init__.py | 1 + .../functional/test_rescale_intensity.py | 37 +++++++++++++ 18 files changed, 299 insertions(+), 28 deletions(-) create mode 100644 src/eva/vision/data/transforms/common/resize_and_clamp.py create mode 100644 src/eva/vision/data/transforms/normalization/__init__.py create mode 100644 src/eva/vision/data/transforms/normalization/clamp.py create mode 100644 src/eva/vision/data/transforms/normalization/functional/__init__.py create mode 100644 src/eva/vision/data/transforms/normalization/functional/rescale_intensity.py create mode 100644 src/eva/vision/data/transforms/normalization/rescale_intensity.py create mode 100644 tests/eva/vision/data/transforms/common/test_resize_and_clamp.py create mode 100644 tests/eva/vision/data/transforms/normalization/__init__.py create mode 100644 tests/eva/vision/data/transforms/normalization/functional/__init__.py create mode 100644 tests/eva/vision/data/transforms/normalization/functional/test_rescale_intensity.py diff --git a/configs/vision/dino_vit/online/total_segmentator_2d.yaml b/configs/vision/dino_vit/online/total_segmentator_2d.yaml index 2e9e6ce3..415c7b61 100644 --- a/configs/vision/dino_vit/online/total_segmentator_2d.yaml +++ b/configs/vision/dino_vit/online/total_segmentator_2d.yaml @@ -7,7 +7,7 @@ trainer: callbacks: - class_path: eva.vision.callbacks.SemanticSegmentationLogger init_args: - log_every_n_epochs: 1 + log_every_n_steps: 1000 mean: &NORMALIZE_MEAN [0.5, 0.5, 0.5] std: &NORMALIZE_STD [0.5, 0.5, 0.5] logger: @@ -47,11 +47,6 @@ model: init_args: total_iters: *MAX_STEPS power: 0.9 - postprocess: - targets_transforms: - - class_path: torchvision.transforms.v2.ToDtype - init_args: - dtype: torch.int64 metrics: common: - class_path: eva.metrics.AverageLoss @@ -84,7 +79,7 @@ data: # "Creative Commons Attribution 4.0 International" # (see: https://creativecommons.org/licenses/by/4.0/deed.en) transforms: - class_path: eva.vision.data.transforms.common.ResizeAndCrop + class_path: eva.vision.data.transforms.common.ResizeAndClamp init_args: mean: *NORMALIZE_MEAN std: *NORMALIZE_STD diff --git a/src/eva/vision/callbacks/loggers/batch/base.py b/src/eva/vision/callbacks/loggers/batch/base.py index 71f1993b..d90ded83 100644 --- a/src/eva/vision/callbacks/loggers/batch/base.py +++ b/src/eva/vision/callbacks/loggers/batch/base.py @@ -17,7 +17,7 @@ class BatchLogger(pl.Callback, abc.ABC): def __init__( self, - log_every_n_epochs: int | None = 10, + log_every_n_epochs: int | None = None, log_every_n_steps: int | None = None, ) -> None: """Initializes the callback object. @@ -51,7 +51,7 @@ def on_train_batch_end( batch: INPUT_TENSOR_BATCH, batch_idx: int, ) -> None: - if self._skip_logging(trainer, batch_idx): + if self._skip_logging(trainer): return self._log_batch( @@ -101,7 +101,7 @@ def _log_batch( def _skip_logging( self, trainer: pl.Trainer, - batch_idx: int, + batch_idx: int | None = None, ) -> bool: """Determines whether skip the logging step or not. @@ -125,6 +125,6 @@ def _skip_logging( conditions = [ skip_due_frequency, not trainer.is_global_zero, - batch_idx != self._batch_idx_to_log, + batch_idx != self._batch_idx_to_log if batch_idx else False, ] return any(conditions) diff --git a/src/eva/vision/callbacks/loggers/batch/segmentation.py b/src/eva/vision/callbacks/loggers/batch/segmentation.py index 47db5cef..201486b7 100644 --- a/src/eva/vision/callbacks/loggers/batch/segmentation.py +++ b/src/eva/vision/callbacks/loggers/batch/segmentation.py @@ -24,7 +24,7 @@ def __init__( number_of_images_per_subgrid_row: int = 2, mean: Tuple[float, ...] = (0.0, 0.0, 0.0), std: Tuple[float, ...] = (1.0, 1.0, 1.0), - log_every_n_epochs: int | None = 10, + log_every_n_epochs: int | None = None, log_every_n_steps: int | None = None, ) -> None: """Initializes the callback object. diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator.py b/src/eva/vision/data/datasets/segmentation/total_segmentator.py index 109a10fb..1fce8402 100644 --- a/src/eva/vision/data/datasets/segmentation/total_segmentator.py +++ b/src/eva/vision/data/datasets/segmentation/total_segmentator.py @@ -7,6 +7,7 @@ import numpy as np import numpy.typing as npt +import torch import tqdm from torchvision import tv_tensors from torchvision.datasets import utils @@ -174,7 +175,7 @@ def _load_mask(self, index: int) -> tv_tensors.Mask: """Loads and builds the segmentation mask from NifTi files.""" sample_index, slice_index = self._indices[index] semantic_labels = self._load_masks_as_semantic_label(sample_index, slice_index) - return tv_tensors.Mask(semantic_labels) + return tv_tensors.Mask(semantic_labels, dtype=torch.int64) # type: ignore[reportCallIssue] def _load_semantic_label_mask(self, index: int) -> tv_tensors.Mask: """Loads the segmentation mask from a semantic label NifTi file.""" @@ -182,7 +183,7 @@ def _load_semantic_label_mask(self, index: int) -> tv_tensors.Mask: masks_dir = self._get_masks_dir(sample_index) filename = os.path.join(masks_dir, "semantic_labels", "masks.nii.gz") semantic_labels = io.read_nifti(filename, slice_index) - return tv_tensors.Mask(semantic_labels.squeeze()) + return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64) # type: ignore[reportCallIssue] def _load_masks_as_semantic_label( self, sample_index: int, slice_index: int | None = None @@ -213,7 +214,7 @@ def _export_semantic_label_masks(self) -> None: semantic_labels = self._load_masks_as_semantic_label(sample_index) os.makedirs(os.path.dirname(filename), exist_ok=True) - io.save_array_as_nifti(semantic_labels.astype(np.uint8), filename) + io.save_array_as_nifti(semantic_labels, filename) def _get_image_path(self, sample_index: int) -> str: """Returns the corresponding image path.""" diff --git a/src/eva/vision/data/transforms/__init__.py b/src/eva/vision/data/transforms/__init__.py index 4e89dfa3..8fc222e8 100644 --- a/src/eva/vision/data/transforms/__init__.py +++ b/src/eva/vision/data/transforms/__init__.py @@ -1,5 +1,6 @@ """Vision data transforms.""" -from eva.vision.data.transforms.common import ResizeAndCrop +from eva.vision.data.transforms.common import ResizeAndClamp, ResizeAndCrop +from eva.vision.data.transforms.normalization import Clamp, RescaleIntensity -__all__ = ["ResizeAndCrop"] +__all__ = ["ResizeAndCrop", "ResizeAndClamp", "Clamp", "RescaleIntensity"] diff --git a/src/eva/vision/data/transforms/common/__init__.py b/src/eva/vision/data/transforms/common/__init__.py index b1a13ac5..1511cd74 100644 --- a/src/eva/vision/data/transforms/common/__init__.py +++ b/src/eva/vision/data/transforms/common/__init__.py @@ -1,5 +1,6 @@ """Common vision transforms.""" +from eva.vision.data.transforms.common.resize_and_clamp import ResizeAndClamp from eva.vision.data.transforms.common.resize_and_crop import ResizeAndCrop -__all__ = ["ResizeAndCrop"] +__all__ = ["ResizeAndClamp", "ResizeAndCrop"] diff --git a/src/eva/vision/data/transforms/common/resize_and_clamp.py b/src/eva/vision/data/transforms/common/resize_and_clamp.py new file mode 100644 index 00000000..bf86b847 --- /dev/null +++ b/src/eva/vision/data/transforms/common/resize_and_clamp.py @@ -0,0 +1,53 @@ +"""Specialized transforms for resizing, clamping and range normalizing.""" + +from typing import Callable, Sequence, Tuple + +import torch +import torchvision.transforms.v2 as torch_transforms + +from eva.vision.data.transforms import normalization + + +class ResizeAndClamp(torch_transforms.Compose): + """Resizes, crops, clamps and normalizes an input image.""" + + def __init__( + self, + size: int | Sequence[int] = 224, + clamp_range: Tuple[int, int] = (-1024, 1024), + mean: Sequence[float] = (0.0, 0.0, 0.0), + std: Sequence[float] = (1.0, 1.0, 1.0), + ) -> None: + """Initializes the transform object. + + Args: + size: Desired output size of the crop. If size is an `int` instead + of sequence like (h, w), a square crop (size, size) is made. + clamp_range: The lower and upper bound to clamp the pixel values. + mean: Sequence of means for each image channel. + std: Sequence of standard deviations for each image channel. + """ + self._size = size + self._clamp_range = clamp_range + self._mean = mean + self._std = std + + super().__init__(transforms=self._build_transforms()) + + def _build_transforms(self) -> Sequence[Callable]: + """Builds and returns the list of transforms.""" + transforms = [ + torch_transforms.Resize(size=self._size), + torch_transforms.CenterCrop(size=self._size), + normalization.Clamp(out_range=self._clamp_range), + torch_transforms.ToDtype(torch.float32), + normalization.RescaleIntensity( + in_range=self._clamp_range, + out_range=(0.0, 1.0), + ), + torch_transforms.Normalize( + mean=self._mean, + std=self._std, + ), + ] + return transforms diff --git a/src/eva/vision/data/transforms/common/resize_and_crop.py b/src/eva/vision/data/transforms/common/resize_and_crop.py index f5320679..f1956a66 100644 --- a/src/eva/vision/data/transforms/common/resize_and_crop.py +++ b/src/eva/vision/data/transforms/common/resize_and_crop.py @@ -4,7 +4,6 @@ import torch import torchvision.transforms.v2 as torch_transforms -from torchvision import tv_tensors class ResizeAndCrop(torch_transforms.Compose): @@ -36,13 +35,7 @@ def _build_transforms(self) -> Sequence[Callable]: torch_transforms.ToImage(), torch_transforms.Resize(size=self._size), torch_transforms.CenterCrop(size=self._size), - torch_transforms.ToDtype( - { - tv_tensors.Image: torch.float32, - tv_tensors.Mask: torch.float32, - }, - scale=True, - ), + torch_transforms.ToDtype(torch.float32, scale=True), torch_transforms.Normalize( mean=self._mean, std=self._std, diff --git a/src/eva/vision/data/transforms/normalization/__init__.py b/src/eva/vision/data/transforms/normalization/__init__.py new file mode 100644 index 00000000..d995520c --- /dev/null +++ b/src/eva/vision/data/transforms/normalization/__init__.py @@ -0,0 +1,6 @@ +"""Normalization related transformations.""" + +from eva.vision.data.transforms.normalization.clamp import Clamp +from eva.vision.data.transforms.normalization.rescale_intensity import RescaleIntensity + +__all__ = ["Clamp", "RescaleIntensity"] diff --git a/src/eva/vision/data/transforms/normalization/clamp.py b/src/eva/vision/data/transforms/normalization/clamp.py new file mode 100644 index 00000000..6078e182 --- /dev/null +++ b/src/eva/vision/data/transforms/normalization/clamp.py @@ -0,0 +1,43 @@ +"""Image clamp transform.""" + +import functools +from typing import Any, Dict, Tuple + +import torch +import torchvision.transforms.v2 as torch_transforms +from torchvision import tv_tensors +from typing_extensions import override + + +class Clamp(torch_transforms.Transform): + """Clamps all elements in input into a specific range.""" + + def __init__(self, out_range: Tuple[int, int]) -> None: + """Initializes the transform. + + Args: + out_range: The lower and upper bound of the range to + be clamped to. + """ + super().__init__() + + self._out_range = out_range + + @functools.singledispatchmethod + @override + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return inpt + + @_transform.register(torch.Tensor) + def _(self, inpt: torch.Tensor, params: Dict[str, Any]) -> Any: + return torch.clamp(inpt, min=self._out_range[0], max=self._out_range[1]) + + @_transform.register(tv_tensors.Image) + def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any: + inpt_clamp = torch.clamp(inpt, min=self._out_range[0], max=self._out_range[1]) + return tv_tensors.wrap(inpt_clamp, like=inpt) + + @_transform.register(tv_tensors.BoundingBoxes) + @_transform.register(tv_tensors.Mask) + def _(self, inpt: tv_tensors.BoundingBoxes | tv_tensors.Mask, params: Dict[str, Any]) -> Any: + return inpt diff --git a/src/eva/vision/data/transforms/normalization/functional/__init__.py b/src/eva/vision/data/transforms/normalization/functional/__init__.py new file mode 100644 index 00000000..0baa7632 --- /dev/null +++ b/src/eva/vision/data/transforms/normalization/functional/__init__.py @@ -0,0 +1,5 @@ +"""Functional normalization related transformations API.""" + +from eva.vision.data.transforms.normalization.functional.rescale_intensity import rescale_intensity + +__all__ = ["rescale_intensity"] diff --git a/src/eva/vision/data/transforms/normalization/functional/rescale_intensity.py b/src/eva/vision/data/transforms/normalization/functional/rescale_intensity.py new file mode 100644 index 00000000..f6dc70e0 --- /dev/null +++ b/src/eva/vision/data/transforms/normalization/functional/rescale_intensity.py @@ -0,0 +1,28 @@ +"""Intensity level functions.""" + +import sys +from typing import Tuple + +import torch + + +def rescale_intensity( + image: torch.Tensor, + in_range: Tuple[float, float] | None = None, + out_range: Tuple[float, float] = (0.0, 1.0), +) -> torch.Tensor: + """Stretches or shrinks the image intensity levels. + + Args: + image: The image tensor as float-type. + in_range: The input data range. If `None`, it will + fetch the min and max of the input image. + out_range: The desired intensity range of the output. + + Returns: + The image tensor after stretching or shrinking its intensity levels. + """ + imin, imax = in_range or (image.min(), image.max()) + omin, omax = out_range + image_scaled = (image - imin) / (imax - imin + sys.float_info.epsilon) + return image_scaled * (omax - omin) + omin diff --git a/src/eva/vision/data/transforms/normalization/rescale_intensity.py b/src/eva/vision/data/transforms/normalization/rescale_intensity.py new file mode 100644 index 00000000..deeea284 --- /dev/null +++ b/src/eva/vision/data/transforms/normalization/rescale_intensity.py @@ -0,0 +1,53 @@ +"""Intensity level scaling transform.""" + +import functools +from typing import Any, Dict, Tuple + +import torch +import torchvision.transforms.v2 as torch_transforms +from torchvision import tv_tensors +from typing_extensions import override + +from eva.vision.data.transforms.normalization import functional + + +class RescaleIntensity(torch_transforms.Transform): + """Stretches or shrinks the image intensity levels.""" + + def __init__( + self, + in_range: Tuple[float, float] | None = None, + out_range: Tuple[float, float] = (0.0, 1.0), + ) -> None: + """Initializes the transform. + + Args: + in_range: The input data range. If `None`, it will + fetch the min and max of the input image. + out_range: The desired intensity range of the output. + """ + super().__init__() + + self._in_range = in_range + self._out_range = out_range + + @functools.singledispatchmethod + @override + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return inpt + + @_transform.register(torch.Tensor) + def _(self, inpt: torch.Tensor, params: Dict[str, Any]) -> Any: + return functional.rescale_intensity( + inpt, in_range=self._in_range, out_range=self._out_range + ) + + @_transform.register(tv_tensors.Image) + def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any: + scaled_inpt = functional.rescale_intensity(inpt, out_range=self._out_range) + return tv_tensors.wrap(scaled_inpt, like=inpt) + + @_transform.register(tv_tensors.BoundingBoxes) + @_transform.register(tv_tensors.Mask) + def _(self, inpt: tv_tensors.BoundingBoxes | tv_tensors.Mask, params: Dict[str, Any]) -> Any: + return inpt diff --git a/src/eva/vision/utils/io/nifti.py b/src/eva/vision/utils/io/nifti.py index 8ceaba07..6859729f 100644 --- a/src/eva/vision/utils/io/nifti.py +++ b/src/eva/vision/utils/io/nifti.py @@ -39,14 +39,20 @@ def read_nifti( return image_array -def save_array_as_nifti(array: npt.ArrayLike, filename: str) -> None: +def save_array_as_nifti( + array: npt.ArrayLike, + filename: str, + *, + dtype: npt.DTypeLike | None = np.int64, +) -> None: """Saved a numpy array as a NIfTI image file. Args: array: The image array to save. filename: The name to save the image like. + dtype: The data type to save the image. """ - nifti_image = nib.Nifti1Image(array, affine=np.eye(4)) # type: ignore + nifti_image = nib.Nifti1Image(array, affine=np.eye(4), dtype=dtype) # type: ignore nifti_image.header.get_xyzt_units() nifti_image.to_filename(filename) diff --git a/tests/eva/vision/data/transforms/common/test_resize_and_clamp.py b/tests/eva/vision/data/transforms/common/test_resize_and_clamp.py new file mode 100644 index 00000000..b3cb0887 --- /dev/null +++ b/tests/eva/vision/data/transforms/common/test_resize_and_clamp.py @@ -0,0 +1,47 @@ +"""Test the ResizeAndClamp augmentation.""" + +from typing import Tuple + +import pytest +import torch +from torch import testing +from torchvision import tv_tensors + +from eva.vision.data.transforms import common + + +@pytest.mark.parametrize( + "image_size, target_size, clamp_range, expected_size, expected_mean", + [ + ((3, 512, 224), [112, 224], [0, 255], (3, 112, 224), 0.498039186000824), + ((3, 512, 224), [112, 224], [100, 155], (3, 112, 224), 0.4909091293811798), + ((3, 224, 512), [112, 224], [0, 100], (3, 112, 224), 1.0), + ((3, 512, 224), [112, 97], [100, 255], (3, 112, 97), 0.17419354617595673), + ((3, 512, 512), 224, [0, 255], (3, 224, 224), 0.4980391561985016), + ], +) +def test_resize_and_clamp( + image_tensor: tv_tensors.Image, + resize_and_clamp: common.ResizeAndClamp, + expected_size: Tuple[int, int, int], + expected_mean: float, +) -> None: + """Tests the ResizeAndClamp transform.""" + output = resize_and_clamp(image_tensor) + assert output.shape == expected_size + testing.assert_close(torch.tensor(expected_mean), output.mean()) + + +@pytest.fixture(scope="function") +def resize_and_clamp( + target_size: Tuple[int, int, int], clamp_range: Tuple[int, int] +) -> common.ResizeAndClamp: + """Transform ResizeAndClamp fixture.""" + return common.ResizeAndClamp(size=target_size, clamp_range=clamp_range) + + +@pytest.fixture(scope="function") +def image_tensor(image_size: Tuple[int, int, int]) -> tv_tensors.Image: + """Image tensor fixture.""" + image_tensor = 127 * torch.ones(image_size, dtype=torch.uint8) + return tv_tensors.wrap(image_tensor, like=image_tensor) # type: ignore diff --git a/tests/eva/vision/data/transforms/normalization/__init__.py b/tests/eva/vision/data/transforms/normalization/__init__.py new file mode 100644 index 00000000..043d7a45 --- /dev/null +++ b/tests/eva/vision/data/transforms/normalization/__init__.py @@ -0,0 +1 @@ +"""Pixel normalization related transforms unit tests.""" diff --git a/tests/eva/vision/data/transforms/normalization/functional/__init__.py b/tests/eva/vision/data/transforms/normalization/functional/__init__.py new file mode 100644 index 00000000..d00105b3 --- /dev/null +++ b/tests/eva/vision/data/transforms/normalization/functional/__init__.py @@ -0,0 +1 @@ +"""Pixel normalization related functional transforms unit tests.""" diff --git a/tests/eva/vision/data/transforms/normalization/functional/test_rescale_intensity.py b/tests/eva/vision/data/transforms/normalization/functional/test_rescale_intensity.py new file mode 100644 index 00000000..5ee1f095 --- /dev/null +++ b/tests/eva/vision/data/transforms/normalization/functional/test_rescale_intensity.py @@ -0,0 +1,37 @@ +"""Test the rescale intensity transform.""" + +from typing import Tuple + +import pytest +import torch +from torch import testing +from torchvision import tv_tensors + +from eva.vision.data.transforms.normalization import functional + + +@pytest.mark.parametrize( + "image_size, in_range, out_range, expected_mean", + [ + ((3, 224, 224), (0, 255), (0.0, 1.0), 0.4980391561985016), + ((3, 224, 224), (0, 255), (-0.5, 0.5), -0.001960783964022994), + ((3, 224, 224), (100, 155), (-0.5, 0.5), -0.009090900421142578), + ((3, 224, 224), (100, 155), (0.0, 1.0), 0.4909091293811798), + ], +) +def test_rescale_intensity( + image_tensor: tv_tensors.Image, + in_range: Tuple[int, int], + out_range: Tuple[int, int], + expected_mean: float, +) -> None: + """Tests the rescale_intensity functional transform.""" + output = functional.rescale_intensity(image_tensor, in_range=in_range, out_range=out_range) + testing.assert_close(torch.tensor(expected_mean), output.mean()) + + +@pytest.fixture(scope="function") +def image_tensor(image_size: Tuple[int, int, int]) -> tv_tensors.Image: + """Image tensor fixture.""" + image_tensor = 127 * torch.ones(image_size, dtype=torch.uint8) + return tv_tensors.wrap(image_tensor, like=image_tensor) # type: ignore From c43b8ad052181b9844708279288be1d212df4567 Mon Sep 17 00:00:00 2001 From: ioangatop Date: Thu, 6 Jun 2024 16:11:22 +0200 Subject: [PATCH 12/21] Fix default segmentation metrics (#503) --- .../defaults/segmentation/multiclass.py | 10 ++++--- .../defaults/segmentation/test_multiclass.py | 30 ++++++++++++------- 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/src/eva/core/metrics/defaults/segmentation/multiclass.py b/src/eva/core/metrics/defaults/segmentation/multiclass.py index 9433d0df..37fea48d 100644 --- a/src/eva/core/metrics/defaults/segmentation/multiclass.py +++ b/src/eva/core/metrics/defaults/segmentation/multiclass.py @@ -35,17 +35,17 @@ def __init__( average=average, ignore_index=ignore_index, ), - classification.MulticlassPrecision( + classification.MulticlassF1Score( num_classes=num_classes, average=average, ignore_index=ignore_index, ), - classification.MulticlassRecall( + classification.MulticlassPrecision( num_classes=num_classes, average=average, ignore_index=ignore_index, ), - classification.MulticlassF1Score( + classification.MulticlassRecall( num_classes=num_classes, average=average, ignore_index=ignore_index, @@ -56,9 +56,11 @@ def __init__( compute_groups=[ [ "MulticlassJaccardIndex", + ], + [ + "MulticlassF1Score", "MulticlassPrecision", "MulticlassRecall", - "MulticlassF1Score", ], ], ) diff --git a/tests/eva/core/metrics/defaults/segmentation/test_multiclass.py b/tests/eva/core/metrics/defaults/segmentation/test_multiclass.py index dbb4629d..1f896ba1 100644 --- a/tests/eva/core/metrics/defaults/segmentation/test_multiclass.py +++ b/tests/eva/core/metrics/defaults/segmentation/test_multiclass.py @@ -5,21 +5,29 @@ from eva.core.metrics import defaults -NUM_CLASSES_ONE = 5 +NUM_CLASSES_ONE = 3 PREDS_ONE = torch.tensor( [ - [0.70, 0.05, 0.05, 0.05, 0.05], - [0.05, 0.70, 0.05, 0.05, 0.05], - [0.05, 0.05, 0.70, 0.05, 0.05], - [0.05, 0.05, 0.05, 0.70, 0.05], - ] + [1, 1, 1, 1], + [0, 0, 1, 1], + [0, 0, 2, 2], + [0, 0, 2, 2], + ], +) +TARGET_ONE = torch.tensor( + [ + [1, 1, 0, 0], + [1, 1, 2, 2], + [0, 0, 2, 2], + [0, 0, 2, 2], + ], + dtype=torch.int32, ) -TARGET_ONE = torch.tensor([0, 1, 3, 2]) EXPECTED_ONE = { - "MulticlassJaccardIndex": torch.tensor(0.5), - "MulticlassPrecision": torch.tensor(0.0), - "MulticlassRecall": torch.tensor(0.0), - "MulticlassF1Score": torch.tensor(0.0), + "MulticlassJaccardIndex": torch.tensor(0.4722222089767456), + "MulticlassF1Score": torch.tensor(0.6222222447395325), + "MulticlassPrecision": torch.tensor(0.6666666865348816), + "MulticlassRecall": torch.tensor(0.611011104490899), } """Test features.""" From b009aec7a763917df9e1436dee87337f546caae2 Mon Sep 17 00:00:00 2001 From: ioangatop Date: Thu, 6 Jun 2024 16:19:08 +0200 Subject: [PATCH 13/21] Add support for multi-level embeddings training in segmentation tasks (#501) --- .../dino_vit/online/total_segmentator_2d.yaml | 4 ++-- .../networks/decoders/segmentation/conv.py | 19 ++++++++++++++++++- .../networks/decoders/segmentation/linear.py | 19 ++++++++++++++++++- .../networks/decoders/segmentation/conv.py | 6 ++++++ .../networks/decoders/segmentation/linear.py | 6 ++++++ 5 files changed, 50 insertions(+), 4 deletions(-) diff --git a/configs/vision/dino_vit/online/total_segmentator_2d.yaml b/configs/vision/dino_vit/online/total_segmentator_2d.yaml index 415c7b61..ec73bb67 100644 --- a/configs/vision/dino_vit/online/total_segmentator_2d.yaml +++ b/configs/vision/dino_vit/online/total_segmentator_2d.yaml @@ -23,7 +23,7 @@ model: init_args: model_name: ${oc.env:TIMM_MODEL_NAME, vit_small_patch16_224} pretrained: true - out_indices: 1 + out_indices: ${oc.env:TIMM_MODEL_OUT_INDICES, 1} model_arguments: dynamic_img_size: true decoder: @@ -32,7 +32,7 @@ model: layers: class_path: torch.nn.Conv2d init_args: - in_channels: ${oc.env:IN_FEATURES, 384} + in_channels: ${oc.env:DECODER_IN_FEATURES, 384} out_channels: &NUM_CLASSES 118 kernel_size: [1, 1] criterion: torch.nn.CrossEntropyLoss diff --git a/src/eva/vision/models/networks/decoders/segmentation/conv.py b/src/eva/vision/models/networks/decoders/segmentation/conv.py index a21749bc..4a5f6d74 100644 --- a/src/eva/vision/models/networks/decoders/segmentation/conv.py +++ b/src/eva/vision/models/networks/decoders/segmentation/conv.py @@ -30,6 +30,14 @@ def __init__(self, layers: nn.Module) -> None: def _forward_features(self, features: List[torch.Tensor]) -> torch.Tensor: """Forward function for multi-level feature maps to a single one. + It will interpolate the features and concat them into a single tensor + on the dimension axis of the hidden size. + + Example: + >>> features = [torch.Tensor(16, 384, 14, 14), torch.Size(16, 384, 14, 14)] + >>> output = self._forward_features(features) + >>> assert output.shape == torch.Size([16, 768, 14, 14]) + Args: features: List of multi-level image features of shape (batch_size, hidden_size, n_patches_height, n_patches_width). @@ -44,7 +52,16 @@ def _forward_features(self, features: List[torch.Tensor]) -> torch.Tensor: "shape (batch_size, hidden_size, n_patches_height, n_patches_width)." ) - return features[-1] + upsampled_features = [ + functional.interpolate( + input=embeddings, + size=features[0].shape[2:], + mode="bilinear", + align_corners=False, + ) + for embeddings in features + ] + return torch.cat(upsampled_features, dim=1) def _forward_head(self, patch_embeddings: torch.Tensor) -> torch.Tensor: """Forward of the decoder head. diff --git a/src/eva/vision/models/networks/decoders/segmentation/linear.py b/src/eva/vision/models/networks/decoders/segmentation/linear.py index 6b123cf5..75229347 100644 --- a/src/eva/vision/models/networks/decoders/segmentation/linear.py +++ b/src/eva/vision/models/networks/decoders/segmentation/linear.py @@ -30,6 +30,14 @@ def __init__(self, layers: nn.Module) -> None: def _forward_features(self, features: List[torch.Tensor]) -> torch.Tensor: """Forward function for multi-level feature maps to a single one. + It will interpolate the features and concat them into a single tensor + on the dimension axis of the hidden size. + + Example: + >>> features = [torch.Tensor(16, 384, 14, 14), torch.Size(16, 384, 14, 14)] + >>> output = self._forward_features(features) + >>> assert output.shape == torch.Size([16, 768, 14, 14]) + Args: features: List of multi-level image features of shape (batch_size, hidden_size, n_patches_height, n_patches_width). @@ -44,7 +52,16 @@ def _forward_features(self, features: List[torch.Tensor]) -> torch.Tensor: "shape (batch_size, hidden_size, n_patches_height, n_patches_width)." ) - return features[-1] + upsampled_features = [ + functional.interpolate( + input=embeddings, + size=features[0].shape[2:], + mode="bilinear", + align_corners=False, + ) + for embeddings in features + ] + return torch.cat(upsampled_features, dim=1) def _forward_head(self, patch_embeddings: torch.Tensor) -> torch.Tensor: """Forward of the decoder head. diff --git a/tests/eva/vision/models/networks/decoders/segmentation/conv.py b/tests/eva/vision/models/networks/decoders/segmentation/conv.py index 45cace1a..85aa9051 100644 --- a/tests/eva/vision/models/networks/decoders/segmentation/conv.py +++ b/tests/eva/vision/models/networks/decoders/segmentation/conv.py @@ -29,6 +29,12 @@ (224, 224), torch.Size([2, 5, 224, 224]), ), + ( + nn.Conv2d(768, 5, kernel_size=(1, 1)), + [torch.Tensor(2, 384, 14, 14), torch.Tensor(2, 384, 14, 14)], + (224, 224), + torch.Size([2, 5, 224, 224]), + ), ], ) def test_conv_decoder( diff --git a/tests/eva/vision/models/networks/decoders/segmentation/linear.py b/tests/eva/vision/models/networks/decoders/segmentation/linear.py index bcc727da..0db7005d 100644 --- a/tests/eva/vision/models/networks/decoders/segmentation/linear.py +++ b/tests/eva/vision/models/networks/decoders/segmentation/linear.py @@ -18,6 +18,12 @@ (224, 224), torch.Size([2, 5, 224, 224]), ), + ( + nn.Linear(768, 5), + [torch.Tensor(2, 384, 14, 14), torch.Tensor(2, 384, 14, 14)], + (224, 224), + torch.Size([2, 5, 224, 224]), + ), ], ) def test_linear_decoder( From cf54cad43fec0e16e1ff543c3c6ae64dbd03ffaa Mon Sep 17 00:00:00 2001 From: ioangatop Date: Thu, 6 Jun 2024 17:27:19 +0200 Subject: [PATCH 14/21] update with main --- pdm.lock | 6 +- src/eva/core/callbacks/writers/__init__.py | 2 +- src/eva/core/callbacks/writers/embeddings.py | 172 -------------- .../classification/total_segmentator.py | 214 ------------------ .../segmentation/total_segmentator.py | 2 +- src/eva/vision/utils/io/__init__.py | 3 +- 6 files changed, 5 insertions(+), 394 deletions(-) delete mode 100644 src/eva/core/callbacks/writers/embeddings.py delete mode 100644 src/eva/vision/data/datasets/classification/total_segmentator.py diff --git a/pdm.lock b/pdm.lock index cf322c09..7fe2af2e 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev", "docs", "all", "typecheck", "lint", "vision", "test"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:0eada25803c4ccab62b3b9d11881c46185c564f29a81541ddc3cb0c651fe254e" +content_hash = "sha256:32860036df1d64f8d69cd1a8db14dc5f2f8d1a42171ef1668207de6932ec141c" [[package]] name = "absl-py" @@ -2234,10 +2234,6 @@ dependencies = [ "torch", "torchvision", ] -files = [ - {file = "timm-1.0.3-py3-none-any.whl", hash = "sha256:d1ec86f7765aa79fbc7491508fa6e285d38a38f10bf4fe44ba2e9c70f91f0f5b"}, - {file = "timm-1.0.3.tar.gz", hash = "sha256:83920a7efe2cfd503b2a1257dc8808d6ff7dcd18a4b79f451c283e7d71497329"}, -] [[package]] name = "tokenizers" diff --git a/src/eva/core/callbacks/writers/__init__.py b/src/eva/core/callbacks/writers/__init__.py index 8d907e66..16b54b3d 100644 --- a/src/eva/core/callbacks/writers/__init__.py +++ b/src/eva/core/callbacks/writers/__init__.py @@ -1,4 +1,4 @@ -"""Callbacks API.""" +"""Writers callbacks API.""" from eva.core.callbacks.writers.embeddings import ClassificationEmbeddingsWriter diff --git a/src/eva/core/callbacks/writers/embeddings.py b/src/eva/core/callbacks/writers/embeddings.py deleted file mode 100644 index 70c462f9..00000000 --- a/src/eva/core/callbacks/writers/embeddings.py +++ /dev/null @@ -1,172 +0,0 @@ -"""Embeddings writer.""" - -import csv -import io -import os -from typing import Any, Dict, Sequence - -import lightning.pytorch as pl -import torch -from lightning.pytorch import callbacks -from loguru import logger -from torch import multiprocessing, nn -from typing_extensions import override - -from eva.core.callbacks.writers.typings import QUEUE_ITEM -from eva.core.models.modules.typings import INPUT_BATCH -from eva.core.utils import multiprocessing as eva_multiprocessing - - -class EmbeddingsWriter(callbacks.BasePredictionWriter): - """Callback for writing generated embeddings to disk.""" - - def __init__( - self, - output_dir: str, - backbone: nn.Module | None = None, - dataloader_idx_map: Dict[int, str] | None = None, - group_key: str | None = None, - overwrite: bool = True, - ) -> None: - """Initializes a new EmbeddingsWriter instance. - - This callback writes the embedding files in a separate process - to avoid blocking the main process where the model forward pass - is executed. - - Args: - output_dir: The directory where the embeddings will be saved. - backbone: A model to be used as feature extractor. If `None`, - it will be expected that the input batch returns the - features directly. - dataloader_idx_map: A dictionary mapping dataloader indices to - their respective names (e.g. train, val, test). - group_key: The metadata key to group the embeddings by. If specified, - the embedding files will be saved in subdirectories named after - the group_key. If specified, the key must be present in the metadata - of the input batch. - overwrite: Whether to overwrite the output directory. - """ - super().__init__(write_interval="batch") - - self._output_dir = output_dir - self._backbone = backbone - self._dataloader_idx_map = dataloader_idx_map or {} - self._group_key = group_key - self._overwrite = overwrite - - self._write_queue: multiprocessing.Queue - self._write_process: eva_multiprocessing.Process - - @override - def on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - os.makedirs(self._output_dir, exist_ok=self._overwrite) - self._initialize_write_process() - self._write_process.start() - - if self._backbone is not None: - self._backbone = self._backbone.to(pl_module.device) - self._backbone.eval() - - @override - def write_on_batch_end( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - prediction: Any, - batch_indices: Sequence[int], - batch: INPUT_BATCH, - batch_idx: int, - dataloader_idx: int, - ) -> None: - dataset = trainer.predict_dataloaders[dataloader_idx].dataset # type: ignore - _, targets, metadata = INPUT_BATCH(*batch) - split = self._dataloader_idx_map.get(dataloader_idx) - - embeddings = self._get_embeddings(prediction) - for local_idx, global_idx in enumerate(batch_indices[: len(embeddings)]): - input_name, save_name = self._construct_save_name( - dataset.filename(global_idx), metadata, local_idx - ) - embeddings_buffer, target_buffer = io.BytesIO(), io.BytesIO() - torch.save(embeddings[local_idx].clone(), embeddings_buffer) - torch.save(targets[local_idx], target_buffer) # type: ignore - item = QUEUE_ITEM(embeddings_buffer, target_buffer, input_name, save_name, split) - self._write_queue.put(item) - - self._write_process.check_exceptions() - - @override - def on_predict_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - self._write_queue.put(None) - self._write_process.join() - logger.info(f"Predictions and manifest saved to {self._output_dir}") - - def _initialize_write_process(self) -> None: - self._write_queue = multiprocessing.Queue() - self._write_process = eva_multiprocessing.Process( - target=_process_write_queue, args=(self._write_queue, self._output_dir, self._overwrite) - ) - - def _get_embeddings(self, prediction: torch.Tensor) -> torch.Tensor: - """Returns the embeddings from predictions.""" - if self._backbone is None: - return prediction - - with torch.no_grad(): - return self._backbone(prediction) - - def _construct_save_name(self, input_name, metadata, local_idx): - group_name = metadata[self._group_key][local_idx] if self._group_key else None - save_name = os.path.splitext(input_name)[0] + ".pt" - if group_name: - save_name = os.path.join(group_name, save_name) - return input_name, save_name - - -def _process_write_queue( - write_queue: multiprocessing.Queue, output_dir: str, overwrite: bool = False -) -> None: - manifest_file, manifest_writer = _init_manifest(output_dir, overwrite) - while True: - item = write_queue.get() - if item is None: - break - - prediction_buffer, target_buffer, input_name, save_name, split = QUEUE_ITEM(*item) - _save_prediction(prediction_buffer, save_name, output_dir) - _update_manifest(target_buffer, input_name, save_name, split, manifest_writer) - - manifest_file.close() - - -def _save_prediction(prediction_buffer: io.BytesIO, save_name: str, output_dir: str) -> None: - save_path = os.path.join(output_dir, save_name) - prediction = torch.load(io.BytesIO(prediction_buffer.getbuffer()), map_location="cpu") - os.makedirs(os.path.dirname(save_path), exist_ok=True) - torch.save(prediction, save_path) - - -def _init_manifest(output_dir: str, overwrite: bool = False) -> tuple[io.TextIOWrapper, Any]: - manifest_path = os.path.join(output_dir, "manifest.csv") - if os.path.exists(manifest_path) and not overwrite: - raise FileExistsError( - f"Manifest file already exists at {manifest_path}. This likely means that the " - "embeddings have been computed before. Consider using `eva fit` instead " - "of `eva predict_fit` or `eva predict`." - ) - manifest_file = open(manifest_path, "w", newline="") - manifest_writer = csv.writer(manifest_file) - manifest_writer.writerow(["origin", "embeddings", "target", "split"]) - return manifest_file, manifest_writer - - -def _update_manifest( - target_buffer: io.BytesIO, - input_name: str, - save_name: str, - split: str | None, - manifest_writer, -) -> None: - target = torch.load(io.BytesIO(target_buffer.getbuffer()), map_location="cpu") - manifest_writer.writerow([input_name, save_name, target.item(), split]) diff --git a/src/eva/vision/data/datasets/classification/total_segmentator.py b/src/eva/vision/data/datasets/classification/total_segmentator.py deleted file mode 100644 index 58b98b9d..00000000 --- a/src/eva/vision/data/datasets/classification/total_segmentator.py +++ /dev/null @@ -1,214 +0,0 @@ -"""TotalSegmentator 2D segmentation dataset class.""" - -import functools -import os -from glob import glob -from typing import Callable, Dict, List, Literal, Tuple - -import numpy as np -from torchvision.datasets import utils -from typing_extensions import override - -from eva.vision.data.datasets import _utils, _validators, structs -from eva.vision.data.datasets.classification import base -from eva.vision.utils import io - - -class TotalSegmentatorClassification(base.ImageClassification): - """TotalSegmentator multi-label classification dataset.""" - - _train_index_ranges: List[Tuple[int, int]] = [(0, 83)] - """Train range indices.""" - - _val_index_ranges: List[Tuple[int, int]] = [(83, 103)] - """Validation range indices.""" - - _n_slices_per_image: int = 20 - """The amount of slices to sample per 3D CT scan image.""" - - _resources_full: List[structs.DownloadResource] = [ - structs.DownloadResource( - filename="Totalsegmentator_dataset_v201.zip", - url="https://zenodo.org/records/10047292/files/Totalsegmentator_dataset_v201.zip", - md5="fe250e5718e0a3b5df4c4ea9d58a62fe", - ), - ] - """Resources for the full dataset version.""" - - _resources_small: List[structs.DownloadResource] = [ - structs.DownloadResource( - filename="Totalsegmentator_dataset_small_v201.zip", - url="https://zenodo.org/records/10047263/files/Totalsegmentator_dataset_small_v201.zip", - md5="6b5524af4b15e6ba06ef2d700c0c73e0", - ), - ] - """Resources for the small dataset version.""" - - def __init__( - self, - root: str, - split: Literal["train", "val"] | None, - version: Literal["small", "full"] = "small", - download: bool = False, - image_transforms: Callable | None = None, - target_transforms: Callable | None = None, - ) -> None: - """Initialize dataset. - - Args: - root: Path to the root directory of the dataset. The dataset will - be downloaded and extracted here, if it does not already exist. - split: Dataset split to use. If None, the entire dataset is used. - version: The version of the dataset to initialize. - download: Whether to download the data for the specified split. - Note that the download will be executed only by additionally - calling the :meth:`prepare_data` method and if the data does not - exist yet on disk. - image_transforms: A function/transform that takes in an image - and returns a transformed version. - target_transforms: A function/transform that takes in the target - and transforms it. - """ - super().__init__( - image_transforms=image_transforms, - target_transforms=target_transforms, - ) - - self._root = root - self._split = split - self._version = version - self._download = download - - self._samples_dirs: List[str] = [] - self._indices: List[int] = [] - - @functools.cached_property - @override - def classes(self) -> List[str]: - def get_filename(path: str) -> str: - """Returns the filename from the full path.""" - return os.path.basename(path).split(".")[0] - - first_sample_labels = os.path.join( - self._root, self._samples_dirs[0], "segmentations", "*.nii.gz" - ) - return sorted(map(get_filename, glob(first_sample_labels))) - - @property - @override - def class_to_idx(self) -> Dict[str, int]: - return {label: index for index, label in enumerate(self.classes)} - - @override - def filename(self, index: int) -> str: - sample_dir = self._samples_dirs[self._indices[index]] - return os.path.join(sample_dir, "ct.nii.gz") - - @override - def prepare_data(self) -> None: - if self._download: - self._download_dataset() - _validators.check_dataset_exists(self._root, True) - - @override - def configure(self) -> None: - self._samples_dirs = self._fetch_samples_dirs() - self._indices = self._create_indices() - - @override - def validate(self) -> None: - _validators.check_dataset_integrity( - self, - length=1660 if self._split == "train" else 400, - n_classes=117, - first_and_last_labels=("adrenal_gland_left", "vertebrae_T9"), - ) - - @override - def __len__(self) -> int: - return len(self._indices) * self._n_slices_per_image - - @override - def load_image(self, index: int) -> np.ndarray: - image_path = self._get_image_path(index) - slice_index = self._get_sample_slice_index(index) - image_array = io.read_nifti(image_path, slice_index) - return image_array.repeat(3, axis=2) - - @override - def load_target(self, index: int) -> np.ndarray: - masks = self._load_masks(index) - targets = [1 in masks[..., mask_index] for mask_index in range(masks.shape[-1])] - return np.asarray(targets, dtype=np.int64) - - def _load_masks(self, index: int) -> np.ndarray: - """Returns the `index`'th target mask sample.""" - masks_dir = self._get_masks_dir(index) - slice_index = self._get_sample_slice_index(index) - mask_paths = (os.path.join(masks_dir, label + ".nii.gz") for label in self.classes) - masks = [io.read_nifti(path, slice_index) for path in mask_paths] - return np.concatenate(masks, axis=-1) - - def _get_masks_dir(self, index: int) -> str: - """Returns the directory of the corresponding masks.""" - sample_dir = self._get_sample_dir(index) - return os.path.join(self._root, sample_dir, "segmentations") - - def _get_image_path(self, index: int) -> str: - """Returns the corresponding image path.""" - sample_dir = self._get_sample_dir(index) - return os.path.join(self._root, sample_dir, "ct.nii.gz") - - def _get_sample_dir(self, index: int) -> str: - """Returns the corresponding sample directory.""" - sample_index = self._indices[index // self._n_slices_per_image] - return self._samples_dirs[sample_index] - - def _get_sample_slice_index(self, index: int) -> int: - """Returns the corresponding slice index.""" - image_path = self._get_image_path(index) - image_shape = io.fetch_nifti_shape(image_path) - total_slices = image_shape[-1] - slice_indices = np.linspace(0, total_slices - 1, num=self._n_slices_per_image, dtype=int) - return slice_indices[index % self._n_slices_per_image] - - def _fetch_samples_dirs(self) -> List[str]: - """Returns the name of all the samples of all the splits of the dataset.""" - sample_filenames = [ - filename - for filename in os.listdir(self._root) - if os.path.isdir(os.path.join(self._root, filename)) - ] - return sorted(sample_filenames) - - def _create_indices(self) -> List[int]: - """Builds the dataset indices for the specified split.""" - split_index_ranges = { - "train": self._train_index_ranges, - "val": self._val_index_ranges, - None: [(0, 103)], - } - index_ranges = split_index_ranges.get(self._split) - if index_ranges is None: - raise ValueError("Invalid data split. Use 'train', 'val' or `None`.") - - return _utils.ranges_to_indices(index_ranges) - - def _download_dataset(self) -> None: - """Downloads the dataset.""" - dataset_resources = { - "small": self._resources_small, - "full": self._resources_full, - None: (0, 103), - } - resources = dataset_resources.get(self._version) - if resources is None: - raise ValueError("Invalid data version. Use 'small' or 'full'.") - - for resource in resources: - utils.download_and_extract_archive( - resource.url, - download_root=self._root, - filename=resource.filename, - remove_finished=True, - ) diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator.py b/src/eva/vision/data/datasets/segmentation/total_segmentator.py index af38a2a8..1fce8402 100644 --- a/src/eva/vision/data/datasets/segmentation/total_segmentator.py +++ b/src/eva/vision/data/datasets/segmentation/total_segmentator.py @@ -15,7 +15,7 @@ from eva.vision.data.datasets import _utils, _validators, structs from eva.vision.data.datasets.segmentation import base -from eva.vision.utils import convert, io +from eva.vision.utils import io class TotalSegmentator2D(base.ImageSegmentation): diff --git a/src/eva/vision/utils/io/__init__.py b/src/eva/vision/utils/io/__init__.py index 7d2fbe53..ef2161ac 100644 --- a/src/eva/vision/utils/io/__init__.py +++ b/src/eva/vision/utils/io/__init__.py @@ -1,11 +1,12 @@ """Vision I/O utilities.""" -from eva.vision.utils.io.image import read_image +from eva.vision.utils.io.image import read_image, read_image_as_tensor from eva.vision.utils.io.nifti import fetch_nifti_shape, read_nifti, save_array_as_nifti from eva.vision.utils.io.text import read_csv __all__ = [ "read_image", + "read_image_as_tensor", "fetch_nifti_shape", "read_nifti", "save_array_as_nifti", From cfa47c922fe86bfea28fc31436fc2299927370b7 Mon Sep 17 00:00:00 2001 From: ioangatop Date: Mon, 10 Jun 2024 14:56:38 +0200 Subject: [PATCH 15/21] Add dataset licence and env var for dataset download (#516) --- .../vision/dino_vit/online/total_segmentator_2d.yaml | 2 +- .../data/datasets/segmentation/total_segmentator.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/configs/vision/dino_vit/online/total_segmentator_2d.yaml b/configs/vision/dino_vit/online/total_segmentator_2d.yaml index ec73bb67..04a60839 100644 --- a/configs/vision/dino_vit/online/total_segmentator_2d.yaml +++ b/configs/vision/dino_vit/online/total_segmentator_2d.yaml @@ -73,7 +73,7 @@ data: init_args: &DATASET_ARGS root: ${oc.env:DATA_ROOT, ./data}/total_segmentator split: train - download: false + download: ${oc.env:DOWNLOAD, false} # Set `download: true` to download the dataset from https://zenodo.org/records/10047292 # The TotalSegmentator dataset is distributed under the following license: # "Creative Commons Attribution 4.0 International" diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator.py b/src/eva/vision/data/datasets/segmentation/total_segmentator.py index 1fce8402..4b5e059b 100644 --- a/src/eva/vision/data/datasets/segmentation/total_segmentator.py +++ b/src/eva/vision/data/datasets/segmentation/total_segmentator.py @@ -48,6 +48,12 @@ class TotalSegmentator2D(base.ImageSegmentation): ] """Resources for the small dataset version.""" + _license: str = ( + "Creative Commons Attribution 4.0 International " + "(https://creativecommons.org/licenses/by/4.0/deed.en)" + ) + """Dataset license.""" + def __init__( self, root: str, @@ -282,6 +288,7 @@ def _download_dataset(self) -> None: f"Can't download data version '{self._version}'. Use 'small' or 'full'." ) + self._print_license() for resource in resources: if os.path.isdir(self._root): continue @@ -292,3 +299,7 @@ def _download_dataset(self) -> None: filename=resource.filename, remove_finished=True, ) + + def _print_license(self) -> None: + """Prints the dataset license.""" + print(f"Dataset license: {self._license}") From 6fefade95a1c4ba1b9e644a40daf0f48508a328b Mon Sep 17 00:00:00 2001 From: ioangatop Date: Mon, 10 Jun 2024 16:42:42 +0200 Subject: [PATCH 16/21] Minor updates on semantic segmentation tasks (#519) --- .../dino_vit/online/total_segmentator_2d.yaml | 13 +--- .../segmentation/total_segmentator.py | 23 +++--- .../decoders/segmentation/__init__.py | 9 ++- .../networks/decoders/segmentation/common.py | 74 +++++++++++++++++++ .../segmentation/{conv.py => conv2d.py} | 0 .../networks/decoders/segmentation/conv.py | 13 ++++ .../networks/decoders/segmentation/linear.py | 13 ++++ tests/eva/vision/test_vision_cli.py | 1 + 8 files changed, 124 insertions(+), 22 deletions(-) create mode 100644 src/eva/vision/models/networks/decoders/segmentation/common.py rename src/eva/vision/models/networks/decoders/segmentation/{conv.py => conv2d.py} (100%) diff --git a/configs/vision/dino_vit/online/total_segmentator_2d.yaml b/configs/vision/dino_vit/online/total_segmentator_2d.yaml index 04a60839..a647a975 100644 --- a/configs/vision/dino_vit/online/total_segmentator_2d.yaml +++ b/configs/vision/dino_vit/online/total_segmentator_2d.yaml @@ -27,14 +27,10 @@ model: model_arguments: dynamic_img_size: true decoder: - class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoder + class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderMS init_args: - layers: - class_path: torch.nn.Conv2d - init_args: - in_channels: ${oc.env:DECODER_IN_FEATURES, 384} - out_channels: &NUM_CLASSES 118 - kernel_size: [1, 1] + in_features: ${oc.env:DECODER_IN_FEATURES, 384} + num_classes: &NUM_CLASSES 118 criterion: torch.nn.CrossEntropyLoss lr_multiplier_encoder: 0.0 optimizer: @@ -50,9 +46,6 @@ model: metrics: common: - class_path: eva.metrics.AverageLoss - - class_path: torchmetrics.classification.MulticlassF1Score - init_args: - num_classes: *NUM_CLASSES evaluation: - class_path: eva.core.metrics.defaults.MulticlassSegmentationMetrics init_args: diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator.py b/src/eva/vision/data/datasets/segmentation/total_segmentator.py index 4b5e059b..e055f88f 100644 --- a/src/eva/vision/data/datasets/segmentation/total_segmentator.py +++ b/src/eva/vision/data/datasets/segmentation/total_segmentator.py @@ -127,9 +127,9 @@ def class_to_idx(self) -> Dict[str, int]: @override def filename(self, index: int) -> str: - sample_idx, _ = self._indices[index] + sample_idx, slice_index = self._indices[index] sample_dir = self._samples_dirs[sample_idx] - return os.path.join(sample_dir, "ct.nii.gz") + return os.path.join(sample_dir, f"{slice_index}-ct.nii.gz") @override def prepare_data(self) -> None: @@ -209,16 +209,19 @@ def _load_masks_as_semantic_label( def _export_semantic_label_masks(self) -> None: """Exports the segmentation binary masks (one-hot) to semantic labels.""" total_samples = len(self._samples_dirs) - for sample_index in tqdm.trange( - total_samples, desc=">> Exporting optimized semantic masks" - ): - masks_dir = self._get_masks_dir(sample_index) - filename = os.path.join(masks_dir, "semantic_labels", "masks.nii.gz") - if os.path.isfile(filename): - continue + masks_dirs = map(self._get_masks_dir, range(total_samples)) + semantic_labels = [ + (index, os.path.join(directory, "semantic_labels", "masks.nii.gz")) + for index, directory in enumerate(masks_dirs) + ] + to_export = filter(lambda x: not os.path.isfile(x[1]), semantic_labels) + for sample_index, filename in tqdm.tqdm( + to_export, + desc=">> Exporting optimized semantic masks", + leave=False, + ): semantic_labels = self._load_masks_as_semantic_label(sample_index) - os.makedirs(os.path.dirname(filename), exist_ok=True) io.save_array_as_nifti(semantic_labels, filename) diff --git a/src/eva/vision/models/networks/decoders/segmentation/__init__.py b/src/eva/vision/models/networks/decoders/segmentation/__init__.py index 8a5f014f..e7417d96 100644 --- a/src/eva/vision/models/networks/decoders/segmentation/__init__.py +++ b/src/eva/vision/models/networks/decoders/segmentation/__init__.py @@ -1,6 +1,11 @@ """Segmentation decoder heads API.""" -from eva.vision.models.networks.decoders.segmentation.conv import ConvDecoder +from eva.vision.models.networks.decoders.segmentation.common import ( + ConvDecoder1x1, + ConvDecoderMS, + SingleLinearDecoder, +) +from eva.vision.models.networks.decoders.segmentation.conv2d import ConvDecoder from eva.vision.models.networks.decoders.segmentation.linear import LinearDecoder -__all__ = ["ConvDecoder", "LinearDecoder"] +__all__ = ["ConvDecoder1x1", "ConvDecoderMS", "SingleLinearDecoder", "ConvDecoder", "LinearDecoder"] diff --git a/src/eva/vision/models/networks/decoders/segmentation/common.py b/src/eva/vision/models/networks/decoders/segmentation/common.py new file mode 100644 index 00000000..7f612eb0 --- /dev/null +++ b/src/eva/vision/models/networks/decoders/segmentation/common.py @@ -0,0 +1,74 @@ +"""Common semantic segmentation decoders. + +This module contains implementations of different types of decoder models +used in semantic segmentation. These decoders convert the high-level features +output by an encoder into pixel-wise predictions for segmentation tasks. +""" + +from torch import nn + +from eva.vision.models.networks.decoders.segmentation import conv2d, linear + + +class ConvDecoder1x1(conv2d.ConvDecoder): + """A convolutional decoder with a single 1x1 convolutional layer.""" + + def __init__(self, in_features: int, num_classes: int) -> None: + """Initializes the decoder. + + Args: + in_features: The hidden dimension size of the embeddings. + num_classes: Number of output classes as channels. + """ + super().__init__( + layers=nn.Conv2d( + in_channels=in_features, + out_channels=num_classes, + kernel_size=(1, 1), + ), + ) + + +class ConvDecoderMS(conv2d.ConvDecoder): + """A multi-stage convolutional decoder with upsampling and convolutional layers. + + This decoder applies a series of upsampling and convolutional layers to transform + the input features into output predictions with the desired spatial resolution. + + This decoder is based on the `+ms` segmentation decoder from DINOv2 + (https://arxiv.org/pdf/2304.07193) + """ + + def __init__(self, in_features: int, num_classes: int) -> None: + """Initializes the decoder. + + Args: + in_features: The hidden dimension size of the embeddings. + num_classes: Number of output classes as channels. + """ + super().__init__( + layers=nn.Sequential( + nn.Upsample(scale_factor=2), + nn.Conv2d(in_features, 64, kernel_size=(3, 3), padding=(1, 1)), + nn.Upsample(scale_factor=2), + nn.Conv2d(64, num_classes, kernel_size=(3, 3), padding=(1, 1)), + ), + ) + + +class SingleLinearDecoder(linear.LinearDecoder): + """A simple linear decoder with a single fully connected layer.""" + + def __init__(self, in_features: int, num_classes: int) -> None: + """Initializes the decoder. + + Args: + in_features: The hidden dimension size of the embeddings. + num_classes: Number of output classes as channels. + """ + super().__init__( + layers=nn.Linear( + in_features=in_features, + out_features=num_classes, + ), + ) diff --git a/src/eva/vision/models/networks/decoders/segmentation/conv.py b/src/eva/vision/models/networks/decoders/segmentation/conv2d.py similarity index 100% rename from src/eva/vision/models/networks/decoders/segmentation/conv.py rename to src/eva/vision/models/networks/decoders/segmentation/conv2d.py diff --git a/tests/eva/vision/models/networks/decoders/segmentation/conv.py b/tests/eva/vision/models/networks/decoders/segmentation/conv.py index 85aa9051..59a27cc2 100644 --- a/tests/eva/vision/models/networks/decoders/segmentation/conv.py +++ b/tests/eva/vision/models/networks/decoders/segmentation/conv.py @@ -7,6 +7,7 @@ from torch import nn from eva.vision.models.networks.decoders import segmentation +from eva.vision.models.networks.decoders.segmentation import common @pytest.mark.parametrize( @@ -35,6 +36,18 @@ (224, 224), torch.Size([2, 5, 224, 224]), ), + ( + common.ConvDecoder1x1(384, 5), + [torch.Tensor(2, 384, 14, 14)], + (224, 224), + torch.Size([2, 5, 224, 224]), + ), + ( + common.ConvDecoderMS(384, 5), + [torch.Tensor(2, 384, 14, 14)], + (224, 224), + torch.Size([2, 5, 224, 224]), + ), ], ) def test_conv_decoder( diff --git a/tests/eva/vision/models/networks/decoders/segmentation/linear.py b/tests/eva/vision/models/networks/decoders/segmentation/linear.py index 0db7005d..156d5c49 100644 --- a/tests/eva/vision/models/networks/decoders/segmentation/linear.py +++ b/tests/eva/vision/models/networks/decoders/segmentation/linear.py @@ -7,6 +7,7 @@ from torch import nn from eva.vision.models.networks.decoders import segmentation +from eva.vision.models.networks.decoders.segmentation import common @pytest.mark.parametrize( @@ -24,6 +25,18 @@ (224, 224), torch.Size([2, 5, 224, 224]), ), + ( + common.SingleLinearDecoder(384, 5), + [torch.Tensor(2, 384, 14, 14)], + (224, 224), + torch.Size([2, 5, 224, 224]), + ), + ( + common.SingleLinearDecoder(768, 5), + [torch.Tensor(2, 384, 14, 14), torch.Tensor(2, 384, 14, 14)], + (224, 224), + torch.Size([2, 5, 224, 224]), + ), ], ) def test_linear_decoder( diff --git a/tests/eva/vision/test_vision_cli.py b/tests/eva/vision/test_vision_cli.py index a1c69963..8b315964 100644 --- a/tests/eva/vision/test_vision_cli.py +++ b/tests/eva/vision/test_vision_cli.py @@ -17,6 +17,7 @@ "configs/vision/dino_vit/online/crc.yaml", "configs/vision/dino_vit/online/mhist.yaml", "configs/vision/dino_vit/online/patch_camelyon.yaml", + "configs/vision/dino_vit/online/total_segmentator_2d.yaml", "configs/vision/dino_vit/offline/bach.yaml", "configs/vision/dino_vit/offline/crc.yaml", "configs/vision/dino_vit/offline/mhist.yaml", From 3b326a047e5b332fcbb4a69fbd6ebb9ceffa072a Mon Sep 17 00:00:00 2001 From: ioangatop Date: Mon, 10 Jun 2024 17:20:27 +0200 Subject: [PATCH 17/21] Fix SemanticSegmentationLogger callback write frequency (#521) --- .../vision/dino_vit/online/total_segmentator_2d.yaml | 2 +- src/eva/vision/callbacks/loggers/batch/base.py | 2 +- src/eva/vision/models/networks/encoders/from_timm.py | 11 +++++++---- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/configs/vision/dino_vit/online/total_segmentator_2d.yaml b/configs/vision/dino_vit/online/total_segmentator_2d.yaml index a647a975..555e4929 100644 --- a/configs/vision/dino_vit/online/total_segmentator_2d.yaml +++ b/configs/vision/dino_vit/online/total_segmentator_2d.yaml @@ -7,7 +7,7 @@ trainer: callbacks: - class_path: eva.vision.callbacks.SemanticSegmentationLogger init_args: - log_every_n_steps: 1000 + log_every_n_epochs: 1 mean: &NORMALIZE_MEAN [0.5, 0.5, 0.5] std: &NORMALIZE_STD [0.5, 0.5, 0.5] logger: diff --git a/src/eva/vision/callbacks/loggers/batch/base.py b/src/eva/vision/callbacks/loggers/batch/base.py index d90ded83..311b39e0 100644 --- a/src/eva/vision/callbacks/loggers/batch/base.py +++ b/src/eva/vision/callbacks/loggers/batch/base.py @@ -51,7 +51,7 @@ def on_train_batch_end( batch: INPUT_TENSOR_BATCH, batch_idx: int, ) -> None: - if self._skip_logging(trainer): + if self._skip_logging(trainer, batch_idx if self._log_every_n_epochs else None): return self._log_batch( diff --git a/src/eva/vision/models/networks/encoders/from_timm.py b/src/eva/vision/models/networks/encoders/from_timm.py index 8ac1f66a..bf0b4084 100644 --- a/src/eva/vision/models/networks/encoders/from_timm.py +++ b/src/eva/vision/models/networks/encoders/from_timm.py @@ -43,11 +43,13 @@ def __init__( self._out_indices = out_indices self._model_arguments = model_arguments or {} - self._feature_extractor = self._load_model() + self._feature_extractor: nn.Module - def _load_model(self) -> nn.Module: - """Builds, loads and returns the timm model as feature extractor.""" - return timm.create_model( + self.configure_model() + + def configure_model(self) -> None: + """Builds and loads the timm model as feature extractor.""" + self._feature_extractor = timm.create_model( model_name=self._model_name, pretrained=self._pretrained, checkpoint_path=self._checkpoint_path, @@ -55,6 +57,7 @@ def _load_model(self) -> nn.Module: features_only=True, **self._model_arguments, ) + TimmEncoder.__name__ = self._model_name @override def forward(self, tensor: torch.Tensor) -> List[torch.Tensor]: From fb948a0312cae2d0e6f38c891c29be932d449206 Mon Sep 17 00:00:00 2001 From: ioangatop Date: Mon, 10 Jun 2024 17:32:08 +0200 Subject: [PATCH 18/21] Add `ModelCheckpoint` and `EarlyStopping` in TotalSegmentator2D` task (#523) --- .../dino_vit/online/total_segmentator_2d.yaml | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/configs/vision/dino_vit/online/total_segmentator_2d.yaml b/configs/vision/dino_vit/online/total_segmentator_2d.yaml index 555e4929..334384b9 100644 --- a/configs/vision/dino_vit/online/total_segmentator_2d.yaml +++ b/configs/vision/dino_vit/online/total_segmentator_2d.yaml @@ -10,6 +10,19 @@ trainer: log_every_n_epochs: 1 mean: &NORMALIZE_MEAN [0.5, 0.5, 0.5] std: &NORMALIZE_STD [0.5, 0.5, 0.5] + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + filename: best + save_last: true + save_top_k: 1 + monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassJaccardIndex} + mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + min_delta: 0 + patience: 5 + monitor: *MONITOR_METRIC + mode: *MONITOR_METRIC_MODE logger: - class_path: lightning.pytorch.loggers.TensorBoardLogger init_args: From 4ac910ed4086e8ad0dcc82f91565e0d683845f8a Mon Sep 17 00:00:00 2001 From: ioangatop Date: Mon, 10 Jun 2024 18:04:29 +0200 Subject: [PATCH 19/21] Minor fixes in `SemanticSegmentationModule` (#525) --- .../models/modules/semantic_segmentation.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/eva/vision/models/modules/semantic_segmentation.py b/src/eva/vision/models/modules/semantic_segmentation.py index f058bc52..5b2f1848 100644 --- a/src/eva/vision/models/modules/semantic_segmentation.py +++ b/src/eva/vision/models/modules/semantic_segmentation.py @@ -76,25 +76,25 @@ def configure_optimizers(self) -> Any: @override def forward( self, - tensor: torch.Tensor, - image_size: Tuple[int, int] | None = None, + inputs: torch.Tensor, + to_size: Tuple[int, int] | None = None, *args: Any, **kwargs: Any, ) -> torch.Tensor: """Maps the input tensor (image tensor or embeddings) to masks. - If `tensor` is image tensor, then the `self.encoder` + If `inputs` is image tensor, then the `self.encoder` should be implemented, otherwise it will be interpreted - as embeddings, where the `image_size` should be given. + as embeddings, where the `to_size` should be given. """ - if self.encoder is None and image_size is None: + if self.encoder is None and to_size is None: raise ValueError( - "Please provide the expected `image_size` that the " - "decoder should map the embeddings (`tensor`) to." + "Please provide the expected `to_size` that the " + "decoder should map the embeddings (`inputs`) to." ) - patch_embeddings = self.encoder(tensor) if self.encoder else tensor - return self.decoder(patch_embeddings, image_size or tensor.shape[-2:]) + patch_embeddings = self.encoder(inputs) if self.encoder else inputs + return self.decoder(patch_embeddings, to_size or inputs.shape[-2:]) @override def training_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT: @@ -111,7 +111,7 @@ def test_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STE @override def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.Tensor: tensor = INPUT_BATCH(*batch).data - return tensor if self.backbone is None else self.backbone(tensor) + return tensor if self.encoder is None else self.encoder(tensor) @property def _base_lr(self) -> float: @@ -128,7 +128,7 @@ def _encoder_trainable_parameters(self) -> Iterable[torch.Tensor]: ) def _freeze_encoder(self) -> None: - """If initialized, Freezes the encoder network.""" + """If initialized, it freezes the encoder network.""" if self.encoder is not None and self.lr_multiplier_encoder == 0: grad.deactivate_requires_grad(self.encoder) @@ -142,7 +142,7 @@ def _batch_step(self, batch: INPUT_TENSOR_BATCH) -> STEP_OUTPUT: The batch step output. """ data, targets, metadata = INPUT_TENSOR_BATCH(*batch) - predictions = self(data, targets.shape[-2:]) + predictions = self(data, to_size=targets.shape[-2:]) loss = self.criterion(predictions, targets) return { "loss": loss, From 6a00235960e4e56d64a649b581d25f6b974e22cd Mon Sep 17 00:00:00 2001 From: ioangatop Date: Wed, 12 Jun 2024 17:13:28 +0200 Subject: [PATCH 20/21] Support the full `TotalSegmentator2D` dataset (#535) --- .../dino_vit/online/total_segmentator_2d.yaml | 7 +++- .../data/datasets/segmentation/__init__.py | 2 +- ...segmentator.py => total_segmentator_2d.py} | 37 +++++++++++-------- src/eva/vision/utils/io/text.py | 13 +++++-- .../Totalsegmentator_dataset_v201/meta.csv | 3 ++ .../segmentation/test_total_segmentator.py | 2 +- 6 files changed, 43 insertions(+), 21 deletions(-) rename src/eva/vision/data/datasets/segmentation/{total_segmentator.py => total_segmentator_2d.py} (91%) create mode 100644 tests/eva/assets/vision/datasets/total_segmentator/Totalsegmentator_dataset_v201/meta.csv diff --git a/configs/vision/dino_vit/online/total_segmentator_2d.yaml b/configs/vision/dino_vit/online/total_segmentator_2d.yaml index 334384b9..4d1c9da9 100644 --- a/configs/vision/dino_vit/online/total_segmentator_2d.yaml +++ b/configs/vision/dino_vit/online/total_segmentator_2d.yaml @@ -79,7 +79,7 @@ data: init_args: &DATASET_ARGS root: ${oc.env:DATA_ROOT, ./data}/total_segmentator split: train - download: ${oc.env:DOWNLOAD, false} + download: ${oc.env:DOWNLOAD_DATA, false} # Set `download: true` to download the dataset from https://zenodo.org/records/10047292 # The TotalSegmentator dataset is distributed under the following license: # "Creative Commons Attribution 4.0 International" @@ -94,6 +94,11 @@ data: init_args: <<: *DATASET_ARGS split: val + test: + class_path: eva.vision.datasets.TotalSegmentator2D + init_args: + <<: *DATASET_ARGS + split: test dataloaders: train: batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 16} diff --git a/src/eva/vision/data/datasets/segmentation/__init__.py b/src/eva/vision/data/datasets/segmentation/__init__.py index 4e5a3fc0..3cc60ce7 100644 --- a/src/eva/vision/data/datasets/segmentation/__init__.py +++ b/src/eva/vision/data/datasets/segmentation/__init__.py @@ -1,6 +1,6 @@ """Segmentation datasets API.""" from eva.vision.data.datasets.segmentation.base import ImageSegmentation -from eva.vision.data.datasets.segmentation.total_segmentator import TotalSegmentator2D +from eva.vision.data.datasets.segmentation.total_segmentator_2d import TotalSegmentator2D __all__ = ["ImageSegmentation", "TotalSegmentator2D"] diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator.py b/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py similarity index 91% rename from src/eva/vision/data/datasets/segmentation/total_segmentator.py rename to src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py index e055f88f..3830b468 100644 --- a/src/eva/vision/data/datasets/segmentation/total_segmentator.py +++ b/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py @@ -13,7 +13,7 @@ from torchvision.datasets import utils from typing_extensions import override -from eva.vision.data.datasets import _utils, _validators, structs +from eva.vision.data.datasets import _validators, structs from eva.vision.data.datasets.segmentation import base from eva.vision.utils import io @@ -22,8 +22,11 @@ class TotalSegmentator2D(base.ImageSegmentation): """TotalSegmentator 2D segmentation dataset.""" _expected_dataset_lengths: Dict[str, int] = { - "train_small": 29892, - "val_small": 6480, + "train_small": 35089, + "val_small": 1283, + "train_full": 278190, + "val_full": 14095, + "test_full": 25578, } """Dataset version and split to the expected size.""" @@ -57,8 +60,8 @@ class TotalSegmentator2D(base.ImageSegmentation): def __init__( self, root: str, - split: Literal["train", "val"] | None, - version: Literal["small", "full"] | None = "small", + split: Literal["train", "val", "test"] | None, + version: Literal["small", "full"] | None = "full", download: bool = False, classes: List[str] | None = None, optimize_mask_loading: bool = True, @@ -145,7 +148,7 @@ def configure(self) -> None: @override def validate(self) -> None: - if self._version is None: + if self._version is None or self._sample_every_n_slices is not None: return _validators.check_dataset_integrity( @@ -217,7 +220,7 @@ def _export_semantic_label_masks(self) -> None: to_export = filter(lambda x: not os.path.isfile(x[1]), semantic_labels) for sample_index, filename in tqdm.tqdm( - to_export, + list(to_export), desc=">> Exporting optimized semantic masks", leave=False, ): @@ -252,16 +255,20 @@ def _fetch_samples_dirs(self) -> List[str]: def _get_split_indices(self) -> List[int]: """Returns the samples indices that corresponding the dataset split and version.""" - key = f"{self._split}_{self._version}" - match key: - case "train_small": - index_ranges = [(0, 83)] - case "val_small": - index_ranges = [(83, 102)] + metadata_file = os.path.join(self._root, "meta.csv") + metadata = io.read_csv(metadata_file, delimiter=";", encoding="utf-8-sig") + + match self._split: + case "train": + image_ids = [item["image_id"] for item in metadata if item["split"] == "train"] + case "val": + image_ids = [item["image_id"] for item in metadata if item["split"] == "val"] + case "test": + image_ids = [item["image_id"] for item in metadata if item["split"] == "test"] case _: - index_ranges = [(0, len(self._samples_dirs))] + image_ids = self._samples_dirs - return _utils.ranges_to_indices(index_ranges) + return sorted(map(self._samples_dirs.index, image_ids)) def _create_indices(self) -> List[Tuple[int, int]]: """Builds the dataset indices for the specified split. diff --git a/src/eva/vision/utils/io/text.py b/src/eva/vision/utils/io/text.py index 7f6089dd..34120238 100644 --- a/src/eva/vision/utils/io/text.py +++ b/src/eva/vision/utils/io/text.py @@ -4,15 +4,22 @@ from typing import Dict, List -def read_csv(path: str) -> List[Dict[str, str]]: +def read_csv( + path: str, + *, + delimiter: str = ",", + encoding: str = "utf-8", +) -> List[Dict[str, str]]: """Reads a CSV file and returns its contents as a list of dictionaries. Args: path: The path to the CSV file. + delimiter: The character that separates fields in the CSV file. + encoding: The encoding of the CSV file. Returns: A list of dictionaries representing the data in the CSV file. """ - with open(path, newline="") as file: - data = csv.DictReader(file, skipinitialspace=True) + with open(path, newline="", encoding=encoding) as file: + data = csv.DictReader(file, skipinitialspace=True, delimiter=delimiter) return list(data) diff --git a/tests/eva/assets/vision/datasets/total_segmentator/Totalsegmentator_dataset_v201/meta.csv b/tests/eva/assets/vision/datasets/total_segmentator/Totalsegmentator_dataset_v201/meta.csv new file mode 100644 index 00000000..557084d8 --- /dev/null +++ b/tests/eva/assets/vision/datasets/total_segmentator/Totalsegmentator_dataset_v201/meta.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:81f0de47859c10a45bfd02550b4d37e5c3c9c954b44657d0f340e665ce8b2cc5 +size 399 diff --git a/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py b/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py index 9607a2a8..a2845b2a 100644 --- a/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py +++ b/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py @@ -11,7 +11,7 @@ @pytest.mark.parametrize( "split, expected_length", - [("train", 9), ("val", 9), (None, 9)], + [("train", 6), ("val", 3), (None, 9)], ) def test_length( total_segmentator_dataset: datasets.TotalSegmentator2D, expected_length: int From 3d523ca4beb0039edebcbca03bf295618ac7b7a1 Mon Sep 17 00:00:00 2001 From: ioangatop Date: Wed, 12 Jun 2024 17:20:39 +0200 Subject: [PATCH 21/21] Add test dataloader in TotalSegmentator2D --- configs/vision/dino_vit/online/total_segmentator_2d.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/configs/vision/dino_vit/online/total_segmentator_2d.yaml b/configs/vision/dino_vit/online/total_segmentator_2d.yaml index 4d1c9da9..f4cdc68f 100644 --- a/configs/vision/dino_vit/online/total_segmentator_2d.yaml +++ b/configs/vision/dino_vit/online/total_segmentator_2d.yaml @@ -105,3 +105,5 @@ data: shuffle: true val: batch_size: *BATCH_SIZE + test: + batch_size: *BATCH_SIZE