diff --git a/.gitignore b/.gitignore index 7ca6b17d..da0b5aae 100644 --- a/.gitignore +++ b/.gitignore @@ -173,3 +173,9 @@ cython_debug/ # ignore local data /data/ + +# numpy data +*.npy + +# NiFti data +*.nii.gz diff --git a/configs/vision/dino_vit/online/total_segmentator_2d.yaml b/configs/vision/dino_vit/online/total_segmentator_2d.yaml new file mode 100644 index 00000000..f4cdc68f --- /dev/null +++ b/configs/vision/dino_vit/online/total_segmentator_2d.yaml @@ -0,0 +1,109 @@ +trainer: + class_path: eva.Trainer + init_args: + n_runs: &N_RUNS ${oc.env:N_RUNS, 1} + default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/total_segmentator_2d/${oc.env:TIMM_MODEL_NAME, vit_small_patch16_224}} + max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 20000} + callbacks: + - class_path: eva.vision.callbacks.SemanticSegmentationLogger + init_args: + log_every_n_epochs: 1 + mean: &NORMALIZE_MEAN [0.5, 0.5, 0.5] + std: &NORMALIZE_STD [0.5, 0.5, 0.5] + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + filename: best + save_last: true + save_top_k: 1 + monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassJaccardIndex} + mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + min_delta: 0 + patience: 5 + monitor: *MONITOR_METRIC + mode: *MONITOR_METRIC_MODE + logger: + - class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: *OUTPUT_ROOT + name: "" +model: + class_path: eva.vision.models.modules.SemanticSegmentationModule + init_args: + encoder: + class_path: eva.vision.models.networks.encoders.TimmEncoder + init_args: + model_name: ${oc.env:TIMM_MODEL_NAME, vit_small_patch16_224} + pretrained: true + out_indices: ${oc.env:TIMM_MODEL_OUT_INDICES, 1} + model_arguments: + dynamic_img_size: true + decoder: + class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderMS + init_args: + in_features: ${oc.env:DECODER_IN_FEATURES, 384} + num_classes: &NUM_CLASSES 118 + criterion: torch.nn.CrossEntropyLoss + lr_multiplier_encoder: 0.0 + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 0.0001 + weight_decay: 0.05 + lr_scheduler: + class_path: torch.optim.lr_scheduler.PolynomialLR + init_args: + total_iters: *MAX_STEPS + power: 0.9 + metrics: + common: + - class_path: eva.metrics.AverageLoss + evaluation: + - class_path: eva.core.metrics.defaults.MulticlassSegmentationMetrics + init_args: + num_classes: *NUM_CLASSES + - class_path: eva.core.metrics.wrappers.ClasswiseWrapper + init_args: + metric: + class_path: torchmetrics.classification.MulticlassF1Score + init_args: + num_classes: *NUM_CLASSES + average: null +data: + class_path: eva.DataModule + init_args: + datasets: + train: + class_path: eva.vision.datasets.TotalSegmentator2D + init_args: &DATASET_ARGS + root: ${oc.env:DATA_ROOT, ./data}/total_segmentator + split: train + download: ${oc.env:DOWNLOAD_DATA, false} + # Set `download: true` to download the dataset from https://zenodo.org/records/10047292 + # The TotalSegmentator dataset is distributed under the following license: + # "Creative Commons Attribution 4.0 International" + # (see: https://creativecommons.org/licenses/by/4.0/deed.en) + transforms: + class_path: eva.vision.data.transforms.common.ResizeAndClamp + init_args: + mean: *NORMALIZE_MEAN + std: *NORMALIZE_STD + val: + class_path: eva.vision.datasets.TotalSegmentator2D + init_args: + <<: *DATASET_ARGS + split: val + test: + class_path: eva.vision.datasets.TotalSegmentator2D + init_args: + <<: 
*DATASET_ARGS + split: test + dataloaders: + train: + batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 16} + shuffle: true + val: + batch_size: *BATCH_SIZE + test: + batch_size: *BATCH_SIZE diff --git a/pdm.lock b/pdm.lock index b881e343..61eb6a1e 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev", "docs", "all", "typecheck", "lint", "vision", "test"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:0eada25803c4ccab62b3b9d11881c46185c564f29a81541ddc3cb0c651fe254e" +content_hash = "sha256:32860036df1d64f8d69cd1a8db14dc5f2f8d1a42171ef1668207de6932ec141c" [[package]] name = "absl-py" @@ -2242,6 +2242,9 @@ files = [ name = "timm" version = "1.0.3" requires_python = ">=3.8" +git = "https://github.com/huggingface/pytorch-image-models.git" +ref = "main" +revision = "f8979d4f50b7920c78511746f7315df8f1857bc5" summary = "PyTorch Image Models" groups = ["all", "vision"] dependencies = [ @@ -2251,10 +2254,6 @@ dependencies = [ "torch", "torchvision", ] -files = [ - {file = "timm-1.0.3-py3-none-any.whl", hash = "sha256:d1ec86f7765aa79fbc7491508fa6e285d38a38f10bf4fe44ba2e9c70f91f0f5b"}, - {file = "timm-1.0.3.tar.gz", hash = "sha256:83920a7efe2cfd503b2a1257dc8808d6ff7dcd18a4b79f451c283e7d71497329"}, -] [[package]] name = "tokenizers" diff --git a/pyproject.toml b/pyproject.toml index 97c2dfb3..6edfead3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ vision = [ "h5py>=3.10.0", "nibabel>=5.2.0", "opencv-python-headless>=4.9.0.80", - "timm>=0.9.12", + "timm @ git+https://github.com/huggingface/pytorch-image-models.git@main", "torchvision>=0.17.0", "openslide-python>=1.3.1", ] @@ -68,7 +68,7 @@ all = [ "h5py>=3.10.0", "nibabel>=5.2.0", "opencv-python-headless>=4.9.0.80", - "timm>=0.9.12", + "timm @ git+https://github.com/huggingface/pytorch-image-models.git@main", "torchvision>=0.17.0", "openslide-python>=1.3.1", ] diff --git a/src/eva/core/callbacks/writers/__init__.py b/src/eva/core/callbacks/writers/__init__.py index 8d907e66..16b54b3d 100644 --- a/src/eva/core/callbacks/writers/__init__.py +++ b/src/eva/core/callbacks/writers/__init__.py @@ -1,4 +1,4 @@ -"""Callbacks API.""" +"""Writers callbacks API.""" from eva.core.callbacks.writers.embeddings import ClassificationEmbeddingsWriter diff --git a/src/eva/core/loggers/log/__init__.py b/src/eva/core/loggers/log/__init__.py index a6b65238..46cece03 100644 --- a/src/eva/core/loggers/log/__init__.py +++ b/src/eva/core/loggers/log/__init__.py @@ -1,5 +1,6 @@ -"""Experiment loggers actions.""" +"""Experiment loggers operations.""" +from eva.core.loggers.log.image import log_image from eva.core.loggers.log.parameters import log_parameters -__all__ = ["log_parameters"] +__all__ = ["log_image", "log_parameters"] diff --git a/src/eva/core/loggers/log/image.py b/src/eva/core/loggers/log/image.py new file mode 100644 index 00000000..d087a5d1 --- /dev/null +++ b/src/eva/core/loggers/log/image.py @@ -0,0 +1,59 @@ +"""Image log functionality.""" + +import functools + +import torch + +from eva.core.loggers import loggers +from eva.core.loggers.log import utils + + +@functools.singledispatch +def log_image( + logger, + tag: str, + image: torch.Tensor, + step: int = 0, +) -> None: + """Adds an image to the logger. + + Args: + logger: The desired logger. + tag: The log tag. + image: The image tensor to log. It should have + the shape of (3,H,W) and (0,1) normalized. + step: The global step of the log. 
+ """ + utils.raise_not_supported(logger, "image") + + +@log_image.register +def _( + loggers: list, + tag: str, + image: torch.Tensor, + step: int = 0, +) -> None: + """Adds an image to a list of supported loggers.""" + for logger in loggers: + log_image( + logger, + tag=tag, + image=image, + step=step, + ) + + +@log_image.register +def _( + logger: loggers.TensorBoardLogger, + tag: str, + image: torch.Tensor, + step: int = 0, +) -> None: + """Adds an image to a TensorBoard logger.""" + logger.experiment.add_image( + tag=tag, + img_tensor=image, + global_step=step, + ) diff --git a/src/eva/core/loggers/loggers.py b/src/eva/core/loggers/loggers.py new file mode 100644 index 00000000..1ed815e5 --- /dev/null +++ b/src/eva/core/loggers/loggers.py @@ -0,0 +1,6 @@ +"""Experimental loggers.""" + +from lightning.pytorch.loggers import TensorBoardLogger + +Loggers = TensorBoardLogger +"""Supported loggers.""" diff --git a/src/eva/core/metrics/defaults/__init__.py b/src/eva/core/metrics/defaults/__init__.py index a75d8286..84120acf 100644 --- a/src/eva/core/metrics/defaults/__init__.py +++ b/src/eva/core/metrics/defaults/__init__.py @@ -1,6 +1,13 @@ """Default metric collections API.""" -from eva.core.metrics.defaults.classification.binary import BinaryClassificationMetrics -from eva.core.metrics.defaults.classification.multiclass import MulticlassClassificationMetrics +from eva.core.metrics.defaults.classification import ( + BinaryClassificationMetrics, + MulticlassClassificationMetrics, +) +from eva.core.metrics.defaults.segmentation import MulticlassSegmentationMetrics -__all__ = ["MulticlassClassificationMetrics", "BinaryClassificationMetrics"] +__all__ = [ + "MulticlassClassificationMetrics", + "BinaryClassificationMetrics", + "MulticlassSegmentationMetrics", +] diff --git a/src/eva/core/metrics/defaults/classification/__init__.py b/src/eva/core/metrics/defaults/classification/__init__.py index 638d43ab..d30fbee9 100644 --- a/src/eva/core/metrics/defaults/classification/__init__.py +++ b/src/eva/core/metrics/defaults/classification/__init__.py @@ -3,4 +3,4 @@ from eva.core.metrics.defaults.classification.binary import BinaryClassificationMetrics from eva.core.metrics.defaults.classification.multiclass import MulticlassClassificationMetrics -__all__ = ["MulticlassClassificationMetrics", "BinaryClassificationMetrics"] +__all__ = ["BinaryClassificationMetrics", "MulticlassClassificationMetrics"] diff --git a/src/eva/core/metrics/defaults/classification/binary.py b/src/eva/core/metrics/defaults/classification/binary.py index e13d55e3..ad6d45a6 100644 --- a/src/eva/core/metrics/defaults/classification/binary.py +++ b/src/eva/core/metrics/defaults/classification/binary.py @@ -17,15 +17,6 @@ def __init__( ) -> None: """Initializes the binary classification metrics. - The metrics instantiated here are: - - - BinaryAUROC - - BinaryAccuracy - - BinaryBalancedAccuracy - - BinaryF1Score - - BinaryPrecision - - BinaryRecall - Args: threshold: Threshold for transforming probability to binary (0,1) predictions ignore_index: Specifies a target value that is ignored and does not diff --git a/src/eva/core/metrics/defaults/classification/multiclass.py b/src/eva/core/metrics/defaults/classification/multiclass.py index fba93835..09cd54f4 100644 --- a/src/eva/core/metrics/defaults/classification/multiclass.py +++ b/src/eva/core/metrics/defaults/classification/multiclass.py @@ -20,14 +20,6 @@ def __init__( ) -> None: """Initializes the multi-class classification metrics. 
- The metrics instantiated here are: - - - MulticlassAccuracy - - MulticlassPrecision - - MulticlassRecall - - MulticlassF1Score - - MulticlassAUROC - Args: num_classes: Integer specifying the number of classes. average: Defines the reduction that is applied over labels. diff --git a/src/eva/core/metrics/defaults/segmentation/__init__.py b/src/eva/core/metrics/defaults/segmentation/__init__.py new file mode 100644 index 00000000..31c397dd --- /dev/null +++ b/src/eva/core/metrics/defaults/segmentation/__init__.py @@ -0,0 +1,5 @@ +"""Default segmentation metric collections API.""" + +from eva.core.metrics.defaults.segmentation.multiclass import MulticlassSegmentationMetrics + +__all__ = ["MulticlassSegmentationMetrics"] diff --git a/src/eva/core/metrics/defaults/segmentation/multiclass.py b/src/eva/core/metrics/defaults/segmentation/multiclass.py new file mode 100644 index 00000000..37fea48d --- /dev/null +++ b/src/eva/core/metrics/defaults/segmentation/multiclass.py @@ -0,0 +1,66 @@ +"""Default metric collection for multiclass semantic segmentation tasks.""" + +from typing import Literal + +from torchmetrics import classification + +from eva.core.metrics import structs + + +class MulticlassSegmentationMetrics(structs.MetricCollection): + """Default metrics for multi-class semantic segmentation tasks.""" + + def __init__( + self, + num_classes: int, + average: Literal["macro", "weighted", "none"] = "macro", + ignore_index: int | None = None, + prefix: str | None = None, + postfix: str | None = None, + ) -> None: + """Initializes the multi-class semantic segmentation metrics. + + Args: + num_classes: Integer specifying the number of classes. + average: Defines the reduction that is applied over labels. + ignore_index: Specifies a target value that is ignored and + does not contribute to the metric calculation. + prefix: A string to add before the keys in the output dictionary. + postfix: A string to add after the keys in the output dictionary. 
+ """ + super().__init__( + metrics=[ + classification.MulticlassJaccardIndex( + num_classes=num_classes, + average=average, + ignore_index=ignore_index, + ), + classification.MulticlassF1Score( + num_classes=num_classes, + average=average, + ignore_index=ignore_index, + ), + classification.MulticlassPrecision( + num_classes=num_classes, + average=average, + ignore_index=ignore_index, + ), + classification.MulticlassRecall( + num_classes=num_classes, + average=average, + ignore_index=ignore_index, + ), + ], + prefix=prefix, + postfix=postfix, + compute_groups=[ + [ + "MulticlassJaccardIndex", + ], + [ + "MulticlassF1Score", + "MulticlassPrecision", + "MulticlassRecall", + ], + ], + ) diff --git a/src/eva/core/metrics/structs/schemas.py b/src/eva/core/metrics/structs/schemas.py index ef8b7fb4..a258fa45 100644 --- a/src/eva/core/metrics/structs/schemas.py +++ b/src/eva/core/metrics/structs/schemas.py @@ -44,4 +44,6 @@ def _join_with_common(self, metrics: MetricModuleType | None) -> MetricModuleTyp if metrics is None or self.common is None: return self.common or metrics - return [self.common, metrics] # type: ignore + metrics = metrics if isinstance(metrics, list) else [metrics] # type: ignore + common = self.common if isinstance(self.common, list) else [self.common] + return common + metrics # type: ignore diff --git a/src/eva/core/metrics/wrappers/__init__.py b/src/eva/core/metrics/wrappers/__init__.py new file mode 100644 index 00000000..4c4b77a3 --- /dev/null +++ b/src/eva/core/metrics/wrappers/__init__.py @@ -0,0 +1,5 @@ +"""Metric wrappers API.""" + +from eva.core.metrics.wrappers.classwise import ClasswiseWrapper + +__all__ = ["ClasswiseWrapper"] diff --git a/src/eva/core/metrics/wrappers/classwise.py b/src/eva/core/metrics/wrappers/classwise.py new file mode 100644 index 00000000..006cfd2f --- /dev/null +++ b/src/eva/core/metrics/wrappers/classwise.py @@ -0,0 +1,24 @@ +"""Wrapper metric to retrieve classwise metrics from metrics.""" + +from typing import Any + +from torchmetrics import wrappers +from typing_extensions import override + + +class ClasswiseWrapper(wrappers.ClasswiseWrapper): + """Wrapper metric for altering the output of classification metrics. + + It adds kwargs filtering during the update step for easy integration + with `MetricCollection`. 
+ """ + + @override + def forward(self, *args: Any, **kwargs: Any) -> Any: + metric_kwargs = self.metric._filter_kwargs(**kwargs) + return self._convert(self.metric(*args, **metric_kwargs)) + + @override + def update(self, *args: Any, **kwargs: Any) -> None: + metric_kwargs = self.metric._filter_kwargs(**kwargs) + self.metric.update(*args, **metric_kwargs) diff --git a/src/eva/core/models/modules/typings.py b/src/eva/core/models/modules/typings.py index e9c56675..a999a7a3 100644 --- a/src/eva/core/models/modules/typings.py +++ b/src/eva/core/models/modules/typings.py @@ -21,3 +21,16 @@ class INPUT_BATCH(NamedTuple): metadata: Dict[str, Any] | None = None """The associated metadata.""" + + +class INPUT_TENSOR_BATCH(NamedTuple): + """Tensor based input batch data scheme.""" + + data: torch.Tensor + """The data batch.""" + + targets: torch.Tensor + """The target batch.""" + + metadata: Dict[str, Any] | None = None + """The associated metadata.""" diff --git a/src/eva/core/models/networks/__init__.py b/src/eva/core/models/networks/__init__.py index c1ce5bb3..34e3d7a9 100644 --- a/src/eva/core/models/networks/__init__.py +++ b/src/eva/core/models/networks/__init__.py @@ -3,4 +3,4 @@ from eva.core.models.networks.mlp import MLP from eva.core.models.networks.wrappers import HuggingFaceModel, ModelFromFunction, ONNXModel -__all__ = ["ModelFromFunction", "HuggingFaceModel", "ONNXModel", "MLP"] +__all__ = ["MLP", "ModelFromFunction", "HuggingFaceModel", "ONNXModel"] diff --git a/src/eva/core/models/networks/wrappers/__init__.py b/src/eva/core/models/networks/wrappers/__init__.py index 42513046..766dd8f4 100644 --- a/src/eva/core/models/networks/wrappers/__init__.py +++ b/src/eva/core/models/networks/wrappers/__init__.py @@ -5,4 +5,9 @@ from eva.core.models.networks.wrappers.huggingface import HuggingFaceModel from eva.core.models.networks.wrappers.onnx import ONNXModel -__all__ = ["BaseModel", "ModelFromFunction", "HuggingFaceModel", "ONNXModel"] +__all__ = [ + "BaseModel", + "ModelFromFunction", + "HuggingFaceModel", + "ONNXModel", +] diff --git a/src/eva/core/utils/__init__.py b/src/eva/core/utils/__init__.py index f99e16fd..8d8dd40a 100644 --- a/src/eva/core/utils/__init__.py +++ b/src/eva/core/utils/__init__.py @@ -1 +1,5 @@ """Utilities and library level helper functionalities.""" + +from eva.core.utils.memory import to_cpu + +__all__ = ["to_cpu"] diff --git a/src/eva/core/utils/memory.py b/src/eva/core/utils/memory.py new file mode 100644 index 00000000..d820f718 --- /dev/null +++ b/src/eva/core/utils/memory.py @@ -0,0 +1,28 @@ +"""Memory related functions.""" + +import functools +from typing import Any, Dict, List + +import torch + + +@functools.singledispatch +def to_cpu(tensor_type: Any) -> Any: + """Moves tensor objects to `cpu`.""" + raise TypeError(f"Unsupported input type: {type(input)}.") + + +@to_cpu.register +def _(tensor: torch.Tensor) -> torch.Tensor: + detached_tensor = tensor.detach() + return detached_tensor.cpu() + + +@to_cpu.register +def _(tensors: list) -> List[torch.Tensor]: + return list(map(to_cpu, tensors)) + + +@to_cpu.register +def _(tensors: dict) -> Dict[str, torch.Tensor]: + return {key: to_cpu(tensors[key]) for key in tensors} diff --git a/src/eva/vision/__init__.py b/src/eva/vision/__init__.py index 0e4bd4ca..aba88723 100644 --- a/src/eva/vision/__init__.py +++ b/src/eva/vision/__init__.py @@ -1,7 +1,7 @@ """eva vision API.""" try: - from eva.vision import models, utils + from eva.vision import callbacks, models, utils from eva.vision.data import datasets, 
transforms except ImportError as e: msg = ( @@ -11,4 +11,4 @@ ) raise ImportError(str(e) + "\n\n" + msg) from e -__all__ = ["models", "utils", "datasets", "transforms"] +__all__ = ["callbacks", "models", "utils", "datasets", "transforms"] diff --git a/src/eva/vision/callbacks/__init__.py b/src/eva/vision/callbacks/__init__.py new file mode 100644 index 00000000..d7022760 --- /dev/null +++ b/src/eva/vision/callbacks/__init__.py @@ -0,0 +1,5 @@ +"""Vision callbacks API.""" + +from eva.vision.callbacks.loggers import SemanticSegmentationLogger + +__all__ = ["SemanticSegmentationLogger"] diff --git a/src/eva/vision/callbacks/loggers/__init__.py b/src/eva/vision/callbacks/loggers/__init__.py new file mode 100644 index 00000000..8d5fcf47 --- /dev/null +++ b/src/eva/vision/callbacks/loggers/__init__.py @@ -0,0 +1,5 @@ +"""Vision logging related callbacks API.""" + +from eva.vision.callbacks.loggers.batch import SemanticSegmentationLogger + +__all__ = ["SemanticSegmentationLogger"] diff --git a/src/eva/vision/callbacks/loggers/batch/__init__.py b/src/eva/vision/callbacks/loggers/batch/__init__.py new file mode 100644 index 00000000..d03f1342 --- /dev/null +++ b/src/eva/vision/callbacks/loggers/batch/__init__.py @@ -0,0 +1,5 @@ +"""Batch related loggers callbacks API.""" + +from eva.vision.callbacks.loggers.batch.segmentation import SemanticSegmentationLogger + +__all__ = ["SemanticSegmentationLogger"] diff --git a/src/eva/vision/callbacks/loggers/batch/base.py b/src/eva/vision/callbacks/loggers/batch/base.py new file mode 100644 index 00000000..311b39e0 --- /dev/null +++ b/src/eva/vision/callbacks/loggers/batch/base.py @@ -0,0 +1,130 @@ +"""Base batch callback logger.""" + +import abc + +from lightning import pytorch as pl +from lightning.pytorch.utilities.types import STEP_OUTPUT +from typing_extensions import override + +from eva.core.models.modules.typings import INPUT_TENSOR_BATCH + + +class BatchLogger(pl.Callback, abc.ABC): + """Logs training and validation batch assets.""" + + _batch_idx_to_log: int = 0 + """The batch index log.""" + + def __init__( + self, + log_every_n_epochs: int | None = None, + log_every_n_steps: int | None = None, + ) -> None: + """Initializes the callback object. + + Args: + log_every_n_epochs: Epoch-wise logging frequency. + log_every_n_steps: Step-wise logging frequency. + """ + super().__init__() + + if log_every_n_epochs is None and log_every_n_steps is None: + raise ValueError( + "Please configure the logging frequency though " + "`log_every_n_epochs` or `log_every_n_steps`." + ) + if None not in [log_every_n_epochs, log_every_n_steps]: + raise ValueError( + "Arguments `log_every_n_epochs` and `log_every_n_steps` " + "are mutually exclusive. Please configure one of them." 
+ ) + + self._log_every_n_epochs = log_every_n_epochs + self._log_every_n_steps = log_every_n_steps + + @override + def on_train_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: STEP_OUTPUT, + batch: INPUT_TENSOR_BATCH, + batch_idx: int, + ) -> None: + if self._skip_logging(trainer, batch_idx if self._log_every_n_epochs else None): + return + + self._log_batch( + trainer=trainer, + batch=batch, + outputs=outputs, + tag="BatchTrain", + ) + + @override + def on_validation_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: STEP_OUTPUT, + batch: INPUT_TENSOR_BATCH, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + if self._skip_logging(trainer, batch_idx): + return + + self._log_batch( + trainer=trainer, + batch=batch, + outputs=outputs, + tag="BatchValidation", + ) + + @abc.abstractmethod + def _log_batch( + self, + trainer: pl.Trainer, + outputs: STEP_OUTPUT, + batch: INPUT_TENSOR_BATCH, + tag: str, + ) -> None: + """Logs the batch data. + + Args: + trainer: The trainer. + outputs: The output of the train / val step. + batch: The data batch. + tag: The log tag. + """ + + def _skip_logging( + self, + trainer: pl.Trainer, + batch_idx: int | None = None, + ) -> bool: + """Determines whether skip the logging step or not. + + Args: + trainer: The trainer. + batch_idx: The batch index. + + Returns: + A boolean indicating whether to skip the step execution. + """ + if trainer.global_step in [0, 1]: + return False + + skip_due_frequency = any( + [ + (trainer.current_epoch + 1) % (self._log_every_n_epochs or 1) != 0, + (trainer.global_step + 1) % (self._log_every_n_steps or 1) != 0, + ] + ) + + conditions = [ + skip_due_frequency, + not trainer.is_global_zero, + batch_idx != self._batch_idx_to_log if batch_idx else False, + ] + return any(conditions) diff --git a/src/eva/vision/callbacks/loggers/batch/segmentation.py b/src/eva/vision/callbacks/loggers/batch/segmentation.py new file mode 100644 index 00000000..201486b7 --- /dev/null +++ b/src/eva/vision/callbacks/loggers/batch/segmentation.py @@ -0,0 +1,152 @@ +"""Segmentation datasets related data loggers.""" + +from typing import List, Tuple + +import torch +import torchvision +from lightning import pytorch as pl +from lightning.pytorch.utilities.types import STEP_OUTPUT +from typing_extensions import override + +from eva.core.loggers import log +from eva.core.models.modules.typings import INPUT_TENSOR_BATCH +from eva.core.utils import to_cpu +from eva.vision.callbacks.loggers.batch import base +from eva.vision.utils import colormap, convert + + +class SemanticSegmentationLogger(base.BatchLogger): + """Log the segmentation batch.""" + + def __init__( + self, + max_samples: int = 10, + number_of_images_per_subgrid_row: int = 2, + mean: Tuple[float, ...] = (0.0, 0.0, 0.0), + std: Tuple[float, ...] = (1.0, 1.0, 1.0), + log_every_n_epochs: int | None = None, + log_every_n_steps: int | None = None, + ) -> None: + """Initializes the callback object. + + Args: + max_samples: The maximum number of images displayed in the grid. + number_of_images_per_subgrid_row: Number of images displayed in each row + of each sub-grid (that is images, targets and predictions). + mean: The mean of the input images to de-normalize from. + std: The std of the input images to de-normalize from. + log_every_n_epochs: Epoch-wise logging frequency. + log_every_n_steps: Step-wise logging frequency. 
+ """ + super().__init__( + log_every_n_epochs=log_every_n_epochs, + log_every_n_steps=log_every_n_steps, + ) + + self._max_samples = max_samples + self._number_of_images_per_subgrid_row = number_of_images_per_subgrid_row + self._mean = mean + self._std = std + + @override + def _log_batch( + self, + trainer: pl.Trainer, + outputs: STEP_OUTPUT, + batch: INPUT_TENSOR_BATCH, + tag: str, + ) -> None: + predictions = outputs.get("predictions") if isinstance(outputs, dict) else None + if predictions is None: + raise ValueError("Key `predictions` is missing from the output.") + + images, targets, predictions = _subsample_tensors( + tensors_stack=[batch[0], batch[1], predictions], + max_samples=self._max_samples, + ) + images, targets, predictions = to_cpu([images, targets, predictions]) + predictions = torch.argmax(predictions, dim=1) + + images = list(map(self._format_image, images)) + targets = list(map(_draw_semantic_mask, targets)) + predictions = list(map(_draw_semantic_mask, predictions)) + image_grid = _make_grid_from_image_groups( + [images, targets, predictions], self._number_of_images_per_subgrid_row + ) + + log.log_image( + trainer.loggers, + image=image_grid, + tag=tag, + step=trainer.global_step, + ) + + def _format_image(self, image: torch.Tensor) -> torch.Tensor: + """Descaled an image tensor to (0, 255) uint8 tensor.""" + return convert.descale_and_denorm_image(image, mean=self._mean, std=self._std) + + +def _subsample_tensors( + tensors_stack: List[torch.Tensor], + max_samples: int, +) -> List[torch.Tensor]: + """Sub-samples tensors from a list of tensors in-place. + + Args: + tensors_stack: A list of tensors. + max_samples: The maximum number of images + displayed in the grid. + + Returns: + A sub-sample of the input tensors stack. + """ + for i, tensor in enumerate(tensors_stack): + tensors_stack[i] = tensor[:max_samples] + return tensors_stack + + +def _draw_semantic_mask(tensor: torch.Tensor) -> torch.Tensor: + """Draws a semantic mask to an image RGB tensor. + + The input semantic mask is a (H x W) shaped tensor with + integer values which represent the pixel class id. + + Args: + tensor: An image tensor of range [0., 1.]. + + Returns: + The image as a tensor of range [0., 255.]. + """ + tensor = torch.squeeze(tensor) + height, width = tensor.shape[-2], tensor.shape[-1] + red, green, blue = torch.zeros((3, height, width), dtype=torch.uint8) + for class_id, color in colormap.COLORMAP.items(): + indices = tensor == class_id + red[indices], green[indices], blue[indices] = color + return torch.stack([red, green, blue]) + + +def _make_grid_from_image_groups( + image_groups: List[List[torch.Tensor]], + number_of_images_per_subgrid_row: int = 2, +) -> torch.Tensor: + """Creates a single image grid from image groups. + + For example, it can combine the input images, targets predictions into a single image. + + Args: + image_groups: A list of lists of image tensors of shape (C x H x W) + all of the same size. + number_of_images_per_subgrid_row: Number of images displayed in each + row of the sub-grid. + + Returns: + An image grid as a `torch.Tensor`. 
+ """ + return torchvision.utils.make_grid( + [ + torchvision.utils.make_grid(image_group, nrow=number_of_images_per_subgrid_row) + for image_group in image_groups + ], + nrow=len(image_groups), + ) diff --git a/src/eva/vision/data/datasets/segmentation/__init__.py b/src/eva/vision/data/datasets/segmentation/__init__.py index 4e5a3fc0..3cc60ce7 100644 --- a/src/eva/vision/data/datasets/segmentation/__init__.py +++ b/src/eva/vision/data/datasets/segmentation/__init__.py @@ -1,6 +1,6 @@ """Segmentation datasets API.""" from eva.vision.data.datasets.segmentation.base import ImageSegmentation -from eva.vision.data.datasets.segmentation.total_segmentator import TotalSegmentator2D +from eva.vision.data.datasets.segmentation.total_segmentator_2d import TotalSegmentator2D __all__ = ["ImageSegmentation", "TotalSegmentator2D"] diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator.py b/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py similarity index 54% rename from src/eva/vision/data/datasets/segmentation/total_segmentator.py rename to src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py index 92bb8992..3830b468 100644 --- a/src/eva/vision/data/datasets/segmentation/total_segmentator.py +++ b/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py @@ -3,24 +3,30 @@ import functools import os from glob import glob -from typing import Callable, Dict, List, Literal, Tuple +from typing import Any, Callable, Dict, List, Literal, Tuple import numpy as np +import numpy.typing as npt +import torch +import tqdm from torchvision import tv_tensors from torchvision.datasets import utils from typing_extensions import override -from eva.vision.data.datasets import _utils, _validators, structs +from eva.vision.data.datasets import _validators, structs from eva.vision.data.datasets.segmentation import base -from eva.vision.utils import convert, io +from eva.vision.utils import io class TotalSegmentator2D(base.ImageSegmentation): """TotalSegmentator 2D segmentation dataset.""" _expected_dataset_lengths: Dict[str, int] = { - "train_small": 29892, - "val_small": 6480, + "train_small": 35089, + "val_small": 1283, + "train_full": 278190, + "val_full": 14095, + "test_full": 25578, } """Dataset version and split to the expected size.""" @@ -45,13 +51,20 @@ class TotalSegmentator2D(base.ImageSegmentation): ] """Resources for the small dataset version.""" + _license: str = ( + "Creative Commons Attribution 4.0 International " + "(https://creativecommons.org/licenses/by/4.0/deed.en)" + ) + """Dataset license.""" + def __init__( self, root: str, - split: Literal["train", "val"] | None, - version: Literal["small", "full"] | None = "small", + split: Literal["train", "val", "test"] | None, + version: Literal["small", "full"] | None = "full", download: bool = False, - as_uint8: bool = True, + classes: List[str] | None = None, + optimize_mask_loading: bool = True, transforms: Callable | None = None, ) -> None: """Initialize dataset. @@ -66,7 +79,12 @@ def __init__( Note that the download will be executed only by additionally calling the :meth:`prepare_data` method and if the data does not exist yet on disk. - as_uint8: Whether to convert and return the images as a 8-bit. + classes: Whether to configure the dataset with a subset of classes. + If `None`, it will use all of them. + optimize_mask_loading: Whether to pre-process the segmentation masks + in order to optimize the loading time. 
In the `setup` method, it + will reformat the binary one-hot masks to a semantic mask and store + it on disk. transforms: A function/transforms that takes in an image and a target mask and returns the transformed versions of both. """ @@ -76,7 +94,13 @@ def __init__( self._split = split self._version = version self._download = download - self._as_uint8 = as_uint8 + self._classes = classes + self._optimize_mask_loading = optimize_mask_loading + + if self._optimize_mask_loading and self._classes is not None: + raise ValueError( + "To use custom classes, please set `optimize_mask_loading` to `False`." + ) self._samples_dirs: List[str] = [] self._indices: List[Tuple[int, int]] = [] @@ -91,7 +115,13 @@ def get_filename(path: str) -> str: first_sample_labels = os.path.join( self._root, self._samples_dirs[0], "segmentations", "*.nii.gz" ) - return sorted(map(get_filename, glob(first_sample_labels))) + all_classes = sorted(map(get_filename, glob(first_sample_labels))) + if self._classes: + is_subset = all(name in all_classes for name in self._classes) + if not is_subset: + raise ValueError("Provided class names are not a subset of the dataset ones.") + + return all_classes if self._classes is None else self._classes @property @override @@ -100,9 +130,9 @@ def class_to_idx(self) -> Dict[str, int]: @override def filename(self, index: int) -> str: - sample_idx, _ = self._indices[index] + sample_idx, slice_index = self._indices[index] sample_dir = self._samples_dirs[sample_idx] - return os.path.join(sample_dir, "ct.nii.gz") + return os.path.join(sample_dir, f"{slice_index}-ct.nii.gz") @override def prepare_data(self) -> None: @@ -113,17 +143,23 @@ def prepare_data(self) -> None: def configure(self) -> None: self._samples_dirs = self._fetch_samples_dirs() self._indices = self._create_indices() + if self._optimize_mask_loading: + self._export_semantic_label_masks() @override def validate(self) -> None: - if self._version is None: + if self._version is None or self._sample_every_n_slices is not None: return _validators.check_dataset_integrity( self, length=self._expected_dataset_lengths.get(f"{self._split}_{self._version}", 0), - n_classes=117, - first_and_last_labels=("adrenal_gland_left", "vertebrae_T9"), + n_classes=len(self._classes) if self._classes else 117, + first_and_last_labels=( + (self._classes[0], self._classes[-1]) + if self._classes + else ("adrenal_gland_left", "vertebrae_T9") + ), ) @override @@ -134,25 +170,63 @@ def __len__(self) -> int: def load_image(self, index: int) -> tv_tensors.Image: sample_index, slice_index = self._indices[index] image_path = self._get_image_path(sample_index) - image_array = io.read_nifti_slice(image_path, slice_index) - if self._as_uint8: - image_array = convert.to_8bit(image_array) + image_array = io.read_nifti(image_path, slice_index) image_rgb_array = image_array.repeat(3, axis=2) return tv_tensors.Image(image_rgb_array.transpose(2, 0, 1)) @override def load_mask(self, index: int) -> tv_tensors.Mask: + if self._optimize_mask_loading: + return self._load_semantic_label_mask(index) + return self._load_mask(index) + + def _load_mask(self, index: int) -> tv_tensors.Mask: + """Loads and builds the segmentation mask from NifTi files.""" + sample_index, slice_index = self._indices[index] + semantic_labels = self._load_masks_as_semantic_label(sample_index, slice_index) + return tv_tensors.Mask(semantic_labels, dtype=torch.int64) # type: ignore[reportCallIssue] + + def _load_semantic_label_mask(self, index: int) -> tv_tensors.Mask: + """Loads the segmentation
mask from a semantic label NifTi file.""" sample_index, slice_index = self._indices[index] masks_dir = self._get_masks_dir(sample_index) - mask_paths = (os.path.join(masks_dir, label + ".nii.gz") for label in self.classes) - one_hot_encoded = np.concatenate( - [io.read_nifti_slice(path, slice_index) for path in mask_paths], - axis=2, - ) - background_mask = one_hot_encoded.sum(axis=2, keepdims=True) == 0 - one_hot_encoded_with_bg = np.concatenate([background_mask, one_hot_encoded], axis=2) - segmentation_label = np.argmax(one_hot_encoded_with_bg, axis=2) - return tv_tensors.Mask(segmentation_label) + filename = os.path.join(masks_dir, "semantic_labels", "masks.nii.gz") + semantic_labels = io.read_nifti(filename, slice_index) + return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64) # type: ignore[reportCallIssue] + + def _load_masks_as_semantic_label( + self, sample_index: int, slice_index: int | None = None + ) -> npt.NDArray[Any]: + """Loads binary masks as a semantic label mask. + + Args: + sample_index: The data sample index. + slice_index: Whether to return only a specific slice. + """ + masks_dir = self._get_masks_dir(sample_index) + mask_paths = [os.path.join(masks_dir, label + ".nii.gz") for label in self.classes] + binary_masks = [io.read_nifti(path, slice_index) for path in mask_paths] + background_mask = np.zeros_like(binary_masks[0]) + return np.argmax([background_mask] + binary_masks, axis=0) + + def _export_semantic_label_masks(self) -> None: + """Exports the segmentation binary masks (one-hot) to semantic labels.""" + total_samples = len(self._samples_dirs) + masks_dirs = map(self._get_masks_dir, range(total_samples)) + semantic_labels = [ + (index, os.path.join(directory, "semantic_labels", "masks.nii.gz")) + for index, directory in enumerate(masks_dirs) + ] + to_export = filter(lambda x: not os.path.isfile(x[1]), semantic_labels) + + for sample_index, filename in tqdm.tqdm( + list(to_export), + desc=">> Exporting optimized semantic masks", + leave=False, + ): + semantic_labels = self._load_masks_as_semantic_label(sample_index) + os.makedirs(os.path.dirname(filename), exist_ok=True) + io.save_array_as_nifti(semantic_labels, filename) def _get_image_path(self, sample_index: int) -> str: """Returns the corresponding image path.""" @@ -167,7 +241,8 @@ def _get_masks_dir(self, sample_index: int) -> str: def _get_number_of_slices_per_sample(self, sample_index: int) -> int: """Returns the total amount of slices of a sample.""" image_path = self._get_image_path(sample_index) - return io.fetch_total_nifti_slices(image_path) + image_shape = io.fetch_nifti_shape(image_path) + return image_shape[-1] def _fetch_samples_dirs(self) -> List[str]: """Returns the name of all the samples of all the splits of the dataset.""" @@ -180,16 +255,20 @@ def _fetch_samples_dirs(self) -> List[str]: def _get_split_indices(self) -> List[int]: """Returns the samples indices that corresponding the dataset split and version.""" - key = f"{self._split}_{self._version}" - match key: - case "train_small": - index_ranges = [(0, 83)] - case "val_small": - index_ranges = [(83, 102)] + metadata_file = os.path.join(self._root, "meta.csv") + metadata = io.read_csv(metadata_file, delimiter=";", encoding="utf-8-sig") + + match self._split: + case "train": + image_ids = [item["image_id"] for item in metadata if item["split"] == "train"] + case "val": + image_ids = [item["image_id"] for item in metadata if item["split"] == "val"] + case "test": + image_ids = [item["image_id"] for item in metadata if 
item["split"] == "test"] case _: - index_ranges = [(0, len(self._samples_dirs))] + image_ids = self._samples_dirs - return _utils.ranges_to_indices(index_ranges) + return sorted(map(self._samples_dirs.index, image_ids)) def _create_indices(self) -> List[Tuple[int, int]]: """Builds the dataset indices for the specified split. @@ -219,6 +298,7 @@ def _download_dataset(self) -> None: f"Can't download data version '{self._version}'. Use 'small' or 'full'." ) + self._print_license() for resource in resources: if os.path.isdir(self._root): continue @@ -229,3 +309,7 @@ def _download_dataset(self) -> None: filename=resource.filename, remove_finished=True, ) + + def _print_license(self) -> None: + """Prints the dataset license.""" + print(f"Dataset license: {self._license}") diff --git a/src/eva/vision/data/transforms/__init__.py b/src/eva/vision/data/transforms/__init__.py index 4e89dfa3..8fc222e8 100644 --- a/src/eva/vision/data/transforms/__init__.py +++ b/src/eva/vision/data/transforms/__init__.py @@ -1,5 +1,6 @@ """Vision data transforms.""" -from eva.vision.data.transforms.common import ResizeAndCrop +from eva.vision.data.transforms.common import ResizeAndClamp, ResizeAndCrop +from eva.vision.data.transforms.normalization import Clamp, RescaleIntensity -__all__ = ["ResizeAndCrop"] +__all__ = ["ResizeAndCrop", "ResizeAndClamp", "Clamp", "RescaleIntensity"] diff --git a/src/eva/vision/data/transforms/common/__init__.py b/src/eva/vision/data/transforms/common/__init__.py index b1a13ac5..1511cd74 100644 --- a/src/eva/vision/data/transforms/common/__init__.py +++ b/src/eva/vision/data/transforms/common/__init__.py @@ -1,5 +1,6 @@ """Common vision transforms.""" +from eva.vision.data.transforms.common.resize_and_clamp import ResizeAndClamp from eva.vision.data.transforms.common.resize_and_crop import ResizeAndCrop -__all__ = ["ResizeAndCrop"] +__all__ = ["ResizeAndClamp", "ResizeAndCrop"] diff --git a/src/eva/vision/data/transforms/common/resize_and_clamp.py b/src/eva/vision/data/transforms/common/resize_and_clamp.py new file mode 100644 index 00000000..bf86b847 --- /dev/null +++ b/src/eva/vision/data/transforms/common/resize_and_clamp.py @@ -0,0 +1,53 @@ +"""Specialized transforms for resizing, clamping and range normalizing.""" + +from typing import Callable, Sequence, Tuple + +import torch +import torchvision.transforms.v2 as torch_transforms + +from eva.vision.data.transforms import normalization + + +class ResizeAndClamp(torch_transforms.Compose): + """Resizes, crops, clamps and normalizes an input image.""" + + def __init__( + self, + size: int | Sequence[int] = 224, + clamp_range: Tuple[int, int] = (-1024, 1024), + mean: Sequence[float] = (0.0, 0.0, 0.0), + std: Sequence[float] = (1.0, 1.0, 1.0), + ) -> None: + """Initializes the transform object. + + Args: + size: Desired output size of the crop. If size is an `int` instead + of sequence like (h, w), a square crop (size, size) is made. + clamp_range: The lower and upper bound to clamp the pixel values. + mean: Sequence of means for each image channel. + std: Sequence of standard deviations for each image channel. 
+ """ + self._size = size + self._clamp_range = clamp_range + self._mean = mean + self._std = std + + super().__init__(transforms=self._build_transforms()) + + def _build_transforms(self) -> Sequence[Callable]: + """Builds and returns the list of transforms.""" + transforms = [ + torch_transforms.Resize(size=self._size), + torch_transforms.CenterCrop(size=self._size), + normalization.Clamp(out_range=self._clamp_range), + torch_transforms.ToDtype(torch.float32), + normalization.RescaleIntensity( + in_range=self._clamp_range, + out_range=(0.0, 1.0), + ), + torch_transforms.Normalize( + mean=self._mean, + std=self._std, + ), + ] + return transforms diff --git a/src/eva/vision/data/transforms/normalization/__init__.py b/src/eva/vision/data/transforms/normalization/__init__.py new file mode 100644 index 00000000..d995520c --- /dev/null +++ b/src/eva/vision/data/transforms/normalization/__init__.py @@ -0,0 +1,6 @@ +"""Normalization related transformations.""" + +from eva.vision.data.transforms.normalization.clamp import Clamp +from eva.vision.data.transforms.normalization.rescale_intensity import RescaleIntensity + +__all__ = ["Clamp", "RescaleIntensity"] diff --git a/src/eva/vision/data/transforms/normalization/clamp.py b/src/eva/vision/data/transforms/normalization/clamp.py new file mode 100644 index 00000000..6078e182 --- /dev/null +++ b/src/eva/vision/data/transforms/normalization/clamp.py @@ -0,0 +1,43 @@ +"""Image clamp transform.""" + +import functools +from typing import Any, Dict, Tuple + +import torch +import torchvision.transforms.v2 as torch_transforms +from torchvision import tv_tensors +from typing_extensions import override + + +class Clamp(torch_transforms.Transform): + """Clamps all elements in input into a specific range.""" + + def __init__(self, out_range: Tuple[int, int]) -> None: + """Initializes the transform. + + Args: + out_range: The lower and upper bound of the range to + be clamped to. 
+ """ + super().__init__() + + self._out_range = out_range + + @functools.singledispatchmethod + @override + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return inpt + + @_transform.register(torch.Tensor) + def _(self, inpt: torch.Tensor, params: Dict[str, Any]) -> Any: + return torch.clamp(inpt, min=self._out_range[0], max=self._out_range[1]) + + @_transform.register(tv_tensors.Image) + def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any: + inpt_clamp = torch.clamp(inpt, min=self._out_range[0], max=self._out_range[1]) + return tv_tensors.wrap(inpt_clamp, like=inpt) + + @_transform.register(tv_tensors.BoundingBoxes) + @_transform.register(tv_tensors.Mask) + def _(self, inpt: tv_tensors.BoundingBoxes | tv_tensors.Mask, params: Dict[str, Any]) -> Any: + return inpt diff --git a/src/eva/vision/data/transforms/normalization/functional/__init__.py b/src/eva/vision/data/transforms/normalization/functional/__init__.py new file mode 100644 index 00000000..0baa7632 --- /dev/null +++ b/src/eva/vision/data/transforms/normalization/functional/__init__.py @@ -0,0 +1,5 @@ +"""Functional normalization related transformations API.""" + +from eva.vision.data.transforms.normalization.functional.rescale_intensity import rescale_intensity + +__all__ = ["rescale_intensity"] diff --git a/src/eva/vision/data/transforms/normalization/functional/rescale_intensity.py b/src/eva/vision/data/transforms/normalization/functional/rescale_intensity.py new file mode 100644 index 00000000..f6dc70e0 --- /dev/null +++ b/src/eva/vision/data/transforms/normalization/functional/rescale_intensity.py @@ -0,0 +1,28 @@ +"""Intensity level functions.""" + +import sys +from typing import Tuple + +import torch + + +def rescale_intensity( + image: torch.Tensor, + in_range: Tuple[float, float] | None = None, + out_range: Tuple[float, float] = (0.0, 1.0), +) -> torch.Tensor: + """Stretches or shrinks the image intensity levels. + + Args: + image: The image tensor as float-type. + in_range: The input data range. If `None`, it will + fetch the min and max of the input image. + out_range: The desired intensity range of the output. + + Returns: + The image tensor after stretching or shrinking its intensity levels. + """ + imin, imax = in_range or (image.min(), image.max()) + omin, omax = out_range + image_scaled = (image - imin) / (imax - imin + sys.float_info.epsilon) + return image_scaled * (omax - omin) + omin diff --git a/src/eva/vision/data/transforms/normalization/rescale_intensity.py b/src/eva/vision/data/transforms/normalization/rescale_intensity.py new file mode 100644 index 00000000..deeea284 --- /dev/null +++ b/src/eva/vision/data/transforms/normalization/rescale_intensity.py @@ -0,0 +1,53 @@ +"""Intensity level scaling transform.""" + +import functools +from typing import Any, Dict, Tuple + +import torch +import torchvision.transforms.v2 as torch_transforms +from torchvision import tv_tensors +from typing_extensions import override + +from eva.vision.data.transforms.normalization import functional + + +class RescaleIntensity(torch_transforms.Transform): + """Stretches or shrinks the image intensity levels.""" + + def __init__( + self, + in_range: Tuple[float, float] | None = None, + out_range: Tuple[float, float] = (0.0, 1.0), + ) -> None: + """Initializes the transform. + + Args: + in_range: The input data range. If `None`, it will + fetch the min and max of the input image. + out_range: The desired intensity range of the output. 
+ """ + super().__init__() + + self._in_range = in_range + self._out_range = out_range + + @functools.singledispatchmethod + @override + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return inpt + + @_transform.register(torch.Tensor) + def _(self, inpt: torch.Tensor, params: Dict[str, Any]) -> Any: + return functional.rescale_intensity( + inpt, in_range=self._in_range, out_range=self._out_range + ) + + @_transform.register(tv_tensors.Image) + def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any: + scaled_inpt = functional.rescale_intensity(inpt, out_range=self._out_range) + return tv_tensors.wrap(scaled_inpt, like=inpt) + + @_transform.register(tv_tensors.BoundingBoxes) + @_transform.register(tv_tensors.Mask) + def _(self, inpt: tv_tensors.BoundingBoxes | tv_tensors.Mask, params: Dict[str, Any]) -> Any: + return inpt diff --git a/src/eva/vision/models/modules/__init__.py b/src/eva/vision/models/modules/__init__.py new file mode 100644 index 00000000..ac3342f7 --- /dev/null +++ b/src/eva/vision/models/modules/__init__.py @@ -0,0 +1,5 @@ +"""Vision modules API.""" + +from eva.vision.models.modules.semantic_segmentation import SemanticSegmentationModule + +__all__ = ["SemanticSegmentationModule"] diff --git a/src/eva/vision/models/modules/semantic_segmentation.py b/src/eva/vision/models/modules/semantic_segmentation.py new file mode 100644 index 00000000..5b2f1848 --- /dev/null +++ b/src/eva/vision/models/modules/semantic_segmentation.py @@ -0,0 +1,152 @@ +""""Neural Network Semantic Segmentation Module.""" + +from typing import Any, Callable, Iterable, Tuple + +import torch +from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch import optim +from torch.optim import lr_scheduler +from typing_extensions import override + +from eva.core.metrics import structs as metrics_lib +from eva.core.models.modules import module +from eva.core.models.modules.typings import INPUT_BATCH, INPUT_TENSOR_BATCH +from eva.core.models.modules.utils import batch_postprocess, grad +from eva.vision.models.networks import decoders, encoders + + +class SemanticSegmentationModule(module.ModelModule): + """Neural network semantic segmentation module for training on patch embeddings.""" + + def __init__( + self, + decoder: decoders.Decoder, + criterion: Callable[..., torch.Tensor], + encoder: encoders.Encoder | None = None, + lr_multiplier_encoder: float = 0.0, + optimizer: OptimizerCallable = optim.AdamW, + lr_scheduler: LRSchedulerCallable = lr_scheduler.ConstantLR, + metrics: metrics_lib.MetricsSchema | None = None, + postprocess: batch_postprocess.BatchPostProcess | None = None, + ) -> None: + """Initializes the neural net head module. + + Args: + decoder: The decoder model. + criterion: The loss function to use. + encoder: The encoder model. If `None`, it will be expected + that the input batch returns the features directly. + lr_multiplier_encoder: The learning rate multiplier for the + encoder parameters. If `0`, it will freeze the encoder. + optimizer: The optimizer to use. + lr_scheduler: The learning rate scheduler to use. + metrics: The metric groups to track. + postprocess: A list of helper functions to apply after the + loss and before the metrics calculation to the model + predictions and targets. 
+ """ + super().__init__(metrics=metrics, postprocess=postprocess) + + self.decoder = decoder + self.criterion = criterion + self.encoder = encoder + self.lr_multiplier_encoder = lr_multiplier_encoder + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + @override + def configure_model(self) -> None: + self._freeze_encoder() + + @override + def configure_optimizers(self) -> Any: + optimizer = self.optimizer( + [ + {"params": self.decoder.parameters()}, + { + "params": self._encoder_trainable_parameters(), + "lr": self._base_lr * self.lr_multiplier_encoder, + }, + ] + ) + lr_scheduler = self.lr_scheduler(optimizer) + return {"optimizer": optimizer, "lr_scheduler": lr_scheduler} + + @override + def forward( + self, + inputs: torch.Tensor, + to_size: Tuple[int, int] | None = None, + *args: Any, + **kwargs: Any, + ) -> torch.Tensor: + """Maps the input tensor (image tensor or embeddings) to masks. + + If `inputs` is image tensor, then the `self.encoder` + should be implemented, otherwise it will be interpreted + as embeddings, where the `to_size` should be given. + """ + if self.encoder is None and to_size is None: + raise ValueError( + "Please provide the expected `to_size` that the " + "decoder should map the embeddings (`inputs`) to." + ) + + patch_embeddings = self.encoder(inputs) if self.encoder else inputs + return self.decoder(patch_embeddings, to_size or inputs.shape[-2:]) + + @override + def training_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + return self._batch_step(batch) + + @override + def validation_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + return self._batch_step(batch) + + @override + def test_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + return self._batch_step(batch) + + @override + def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.Tensor: + tensor = INPUT_BATCH(*batch).data + return tensor if self.encoder is None else self.encoder(tensor) + + @property + def _base_lr(self) -> float: + """Returns the base learning rate.""" + base_optimizer = self.optimizer(self.parameters()) + return base_optimizer.param_groups[-1]["lr"] + + def _encoder_trainable_parameters(self) -> Iterable[torch.Tensor]: + """Returns the trainable parameters of the encoder.""" + return ( + self.encoder.parameters() + if self.encoder is not None and self.lr_multiplier_encoder > 0 + else iter(()) + ) + + def _freeze_encoder(self) -> None: + """If initialized, it freezes the encoder network.""" + if self.encoder is not None and self.lr_multiplier_encoder == 0: + grad.deactivate_requires_grad(self.encoder) + + def _batch_step(self, batch: INPUT_TENSOR_BATCH) -> STEP_OUTPUT: + """Performs a model forward step and calculates the loss. + + Args: + batch: The desired batch to process. + + Returns: + The batch step output. 
+ """ + data, targets, metadata = INPUT_TENSOR_BATCH(*batch) + predictions = self(data, to_size=targets.shape[-2:]) + loss = self.criterion(predictions, targets) + return { + "loss": loss, + "targets": targets, + "predictions": predictions, + "metadata": metadata, + } diff --git a/src/eva/vision/models/networks/__init__.py b/src/eva/vision/models/networks/__init__.py index 554e65c3..4d669351 100644 --- a/src/eva/vision/models/networks/__init__.py +++ b/src/eva/vision/models/networks/__init__.py @@ -2,5 +2,6 @@ from eva.vision.models.networks import postprocesses from eva.vision.models.networks.abmil import ABMIL +from eva.vision.models.networks.from_timm import TimmModel -__all__ = ["postprocesses", "ABMIL"] +__all__ = ["postprocesses", "ABMIL", "TimmModel"] diff --git a/src/eva/vision/models/networks/decoders/__init__.py b/src/eva/vision/models/networks/decoders/__init__.py new file mode 100644 index 00000000..1f659ad1 --- /dev/null +++ b/src/eva/vision/models/networks/decoders/__init__.py @@ -0,0 +1,6 @@ +"""Decoder heads API.""" + +from eva.vision.models.networks.decoders import segmentation +from eva.vision.models.networks.decoders.decoder import Decoder + +__all__ = ["segmentation", "Decoder"] diff --git a/src/eva/vision/models/networks/decoders/decoder.py b/src/eva/vision/models/networks/decoders/decoder.py new file mode 100644 index 00000000..3299a290 --- /dev/null +++ b/src/eva/vision/models/networks/decoders/decoder.py @@ -0,0 +1,7 @@ +"""Semantic segmentation decoder base class.""" + +from torch import nn + + +class Decoder(nn.Module): + """Semantic segmentation decoder base class.""" diff --git a/src/eva/vision/models/networks/decoders/segmentation/__init__.py b/src/eva/vision/models/networks/decoders/segmentation/__init__.py new file mode 100644 index 00000000..e7417d96 --- /dev/null +++ b/src/eva/vision/models/networks/decoders/segmentation/__init__.py @@ -0,0 +1,11 @@ +"""Segmentation decoder heads API.""" + +from eva.vision.models.networks.decoders.segmentation.common import ( + ConvDecoder1x1, + ConvDecoderMS, + SingleLinearDecoder, +) +from eva.vision.models.networks.decoders.segmentation.conv2d import ConvDecoder +from eva.vision.models.networks.decoders.segmentation.linear import LinearDecoder + +__all__ = ["ConvDecoder1x1", "ConvDecoderMS", "SingleLinearDecoder", "ConvDecoder", "LinearDecoder"] diff --git a/src/eva/vision/models/networks/decoders/segmentation/common.py b/src/eva/vision/models/networks/decoders/segmentation/common.py new file mode 100644 index 00000000..7f612eb0 --- /dev/null +++ b/src/eva/vision/models/networks/decoders/segmentation/common.py @@ -0,0 +1,74 @@ +"""Common semantic segmentation decoders. + +This module contains implementations of different types of decoder models +used in semantic segmentation. These decoders convert the high-level features +output by an encoder into pixel-wise predictions for segmentation tasks. +""" + +from torch import nn + +from eva.vision.models.networks.decoders.segmentation import conv2d, linear + + +class ConvDecoder1x1(conv2d.ConvDecoder): + """A convolutional decoder with a single 1x1 convolutional layer.""" + + def __init__(self, in_features: int, num_classes: int) -> None: + """Initializes the decoder. + + Args: + in_features: The hidden dimension size of the embeddings. + num_classes: Number of output classes as channels. 
+ """ + super().__init__( + layers=nn.Conv2d( + in_channels=in_features, + out_channels=num_classes, + kernel_size=(1, 1), + ), + ) + + +class ConvDecoderMS(conv2d.ConvDecoder): + """A multi-stage convolutional decoder with upsampling and convolutional layers. + + This decoder applies a series of upsampling and convolutional layers to transform + the input features into output predictions with the desired spatial resolution. + + This decoder is based on the `+ms` segmentation decoder from DINOv2 + (https://arxiv.org/pdf/2304.07193) + """ + + def __init__(self, in_features: int, num_classes: int) -> None: + """Initializes the decoder. + + Args: + in_features: The hidden dimension size of the embeddings. + num_classes: Number of output classes as channels. + """ + super().__init__( + layers=nn.Sequential( + nn.Upsample(scale_factor=2), + nn.Conv2d(in_features, 64, kernel_size=(3, 3), padding=(1, 1)), + nn.Upsample(scale_factor=2), + nn.Conv2d(64, num_classes, kernel_size=(3, 3), padding=(1, 1)), + ), + ) + + +class SingleLinearDecoder(linear.LinearDecoder): + """A simple linear decoder with a single fully connected layer.""" + + def __init__(self, in_features: int, num_classes: int) -> None: + """Initializes the decoder. + + Args: + in_features: The hidden dimension size of the embeddings. + num_classes: Number of output classes as channels. + """ + super().__init__( + layers=nn.Linear( + in_features=in_features, + out_features=num_classes, + ), + ) diff --git a/src/eva/vision/models/networks/decoders/segmentation/conv2d.py b/src/eva/vision/models/networks/decoders/segmentation/conv2d.py new file mode 100644 index 00000000..4a5f6d74 --- /dev/null +++ b/src/eva/vision/models/networks/decoders/segmentation/conv2d.py @@ -0,0 +1,114 @@ +"""Convolutional based semantic segmentation decoder.""" + +from typing import List, Tuple + +import torch +from torch import nn +from torch.nn import functional + +from eva.vision.models.networks.decoders import decoder + + +class ConvDecoder(decoder.Decoder): + """Convolutional segmentation decoder.""" + + def __init__(self, layers: nn.Module) -> None: + """Initializes the convolutional based decoder head. + + Here the input nn layers will be directly applied to the + features of shape (batch_size, hidden_size, n_patches_height, + n_patches_width), where n_patches is image_size / patch_size. + Note the n_patches is also known as grid_size. + + Args: + layers: The convolutional layers to be used as the decoder head. + """ + super().__init__() + + self._layers = layers + + def _forward_features(self, features: List[torch.Tensor]) -> torch.Tensor: + """Forward function for multi-level feature maps to a single one. + + It will interpolate the features and concat them into a single tensor + on the dimension axis of the hidden size. + + Example: + >>> features = [torch.Tensor(16, 384, 14, 14), torch.Size(16, 384, 14, 14)] + >>> output = self._forward_features(features) + >>> assert output.shape == torch.Size([16, 768, 14, 14]) + + Args: + features: List of multi-level image features of shape (batch_size, + hidden_size, n_patches_height, n_patches_width). + + Returns: + A tensor of shape (batch_size, hidden_size, n_patches_height, + n_patches_width) which is feature map of the decoder head. + """ + if not isinstance(features, list) and features[0].ndim != 4: + raise ValueError( + "Input features should be a list of four (4) dimensional inputs of " + "shape (batch_size, hidden_size, n_patches_height, n_patches_width)." 
+ ) + + upsampled_features = [ + functional.interpolate( + input=embeddings, + size=features[0].shape[2:], + mode="bilinear", + align_corners=False, + ) + for embeddings in features + ] + return torch.cat(upsampled_features, dim=1) + + def _forward_head(self, patch_embeddings: torch.Tensor) -> torch.Tensor: + """Forward of the decoder head. + + Args: + patch_embeddings: The patch embeddings tensor of shape + (batch_size, hidden_size, n_patches_height, n_patches_width). + + Returns: + The logits as a tensor (batch_size, n_classes, upscale_height, upscale_width). + """ + return self._layers(patch_embeddings) + + def _cls_seg( + self, + logits: torch.Tensor, + image_size: Tuple[int, int], + ) -> torch.Tensor: + """Classify each pixel of the image. + + Args: + logits: The decoder outputs of shape (batch_size, n_classes, + height, width). + image_size: The target image size (height, width). + + Returns: + Tensor containing scores for all of the classes with shape + (batch_size, n_classes, image_height, image_width). + """ + return functional.interpolate(logits, image_size, mode="bilinear") + + def forward( + self, + features: List[torch.Tensor], + image_size: Tuple[int, int], + ) -> torch.Tensor: + """Maps the patch embeddings to a segmentation mask of the image size. + + Args: + features: List of multi-level image features of shape (batch_size, + hidden_size, n_patches_height, n_patches_width). + image_size: The target image size (height, width). + + Returns: + Tensor containing scores for all of the classes with shape + (batch_size, n_classes, image_height, image_width). + """ + patch_embeddings = self._forward_features(features) + logits = self._forward_head(patch_embeddings) + return self._cls_seg(logits, image_size) diff --git a/src/eva/vision/models/networks/decoders/segmentation/linear.py b/src/eva/vision/models/networks/decoders/segmentation/linear.py new file mode 100644 index 00000000..75229347 --- /dev/null +++ b/src/eva/vision/models/networks/decoders/segmentation/linear.py @@ -0,0 +1,125 @@ +"""Linear based decoder.""" + +from typing import List, Tuple + +import torch +from torch import nn +from torch.nn import functional + +from eva.vision.models.networks.decoders import decoder + + +class LinearDecoder(decoder.Decoder): + """Linear decoder.""" + + def __init__(self, layers: nn.Module) -> None: + """Initializes the linear based decoder head. + + Here the input nn layers will be applied to the reshaped + features (batch_size, patch_embeddings, hidden_size) from + the input (batch_size, hidden_size, height, width) and then + unwrapped again to (batch_size, n_classes, height, width). + + Args: + layers: The linear layers to be used as the decoder head. + """ + super().__init__() + + self._layers = layers + + def _forward_features(self, features: List[torch.Tensor]) -> torch.Tensor: + """Forward function for multi-level feature maps to a single one. + + It will interpolate the features and concat them into a single tensor + on the dimension axis of the hidden size. + + Example: + >>> features = [torch.Tensor(16, 384, 14, 14), torch.Size(16, 384, 14, 14)] + >>> output = self._forward_features(features) + >>> assert output.shape == torch.Size([16, 768, 14, 14]) + + Args: + features: List of multi-level image features of shape (batch_size, + hidden_size, n_patches_height, n_patches_width). + + Returns: + A tensor of shape (batch_size, hidden_size, n_patches_height, + n_patches_width) which is feature map of the decoder head. 
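To make the ConvDecoder contract concrete: a list of patch feature maps goes in, per-pixel class scores at the requested image size come out. A minimal sketch using the grid and hidden sizes from the docstring example above (the class count is illustrative, mirroring the decoder tests later in this diff):

import torch

from eva.vision.models.networks.decoders.segmentation import common

# One feature level of shape (batch, hidden, n_patches_height, n_patches_width).
features = [torch.randn(2, 384, 14, 14)]

decoder = common.ConvDecoderMS(in_features=384, num_classes=5)
logits = decoder(features, image_size=(224, 224))  # upsample 14 -> 28 -> 56, then resize to 224
assert logits.shape == torch.Size([2, 5, 224, 224])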
+ """ + if not isinstance(features, list) and features[0].ndim != 4: + raise ValueError( + "Input features should be a list of four (4) dimensional inputs of " + "shape (batch_size, hidden_size, n_patches_height, n_patches_width)." + ) + + upsampled_features = [ + functional.interpolate( + input=embeddings, + size=features[0].shape[2:], + mode="bilinear", + align_corners=False, + ) + for embeddings in features + ] + return torch.cat(upsampled_features, dim=1) + + def _forward_head(self, patch_embeddings: torch.Tensor) -> torch.Tensor: + """Forward of the decoder head. + + Here the following transformations will take place: + - (batch_size, hidden_size, n_patches_height, n_patches_width) + - (batch_size, hidden_size, n_patches_height * n_patches_width) + - (batch_size, n_patches_height * n_patches_width, hidden_size) + - (batch_size, n_patches_height * n_patches_width, n_classes) + - (batch_size, n_classes, n_patches_height, n_patches_width) + + Args: + patch_embeddings: The patch embeddings tensor of shape + (batch_size, hidden_size, n_patches_height, n_patches_width). + + Returns: + The logits as a tensor (batch_size, n_classes, n_patches_height, + n_patches_width). + """ + batch_size, hidden_size, height, width = patch_embeddings.shape + embeddings_reshaped = patch_embeddings.reshape(batch_size, hidden_size, height * width) + logits = self._layers(embeddings_reshaped.permute(0, 2, 1)) + return logits.permute(0, 2, 1).reshape(batch_size, -1, height, width) + + def _cls_seg( + self, + logits: torch.Tensor, + image_size: Tuple[int, int], + ) -> torch.Tensor: + """Classify each pixel of the image. + + Args: + logits: The decoder outputs of shape (batch_size, n_classes, + height, width). + image_size: The target image size (height, width). + + Returns: + Tensor containing scores for all of the classes with shape + (batch_size, n_classes, image_height, image_width). + """ + return functional.interpolate(logits, image_size, mode="bilinear") + + def forward( + self, + features: List[torch.Tensor], + image_size: Tuple[int, int], + ) -> torch.Tensor: + """Maps the patch embeddings to a segmentation mask of the image size. + + Args: + features: List of multi-level image features of shape (batch_size, + hidden_size, n_patches_height, n_patches_width). + image_size: The target image size (height, width). + + Returns: + Tensor containing scores for all of the classes with shape + (batch_size, n_classes, image_height, image_width). 
+ """ + patch_embeddings = self._forward_features(features) + logits = self._forward_head(patch_embeddings) + return self._cls_seg(logits, image_size) diff --git a/src/eva/vision/models/networks/encoders/__init__.py b/src/eva/vision/models/networks/encoders/__init__.py new file mode 100644 index 00000000..88ae1345 --- /dev/null +++ b/src/eva/vision/models/networks/encoders/__init__.py @@ -0,0 +1,6 @@ +"""Encoder networks API.""" + +from eva.vision.models.networks.encoders.encoder import Encoder +from eva.vision.models.networks.encoders.from_timm import TimmEncoder + +__all__ = ["Encoder", "TimmEncoder"] diff --git a/src/eva/vision/models/networks/encoders/encoder.py b/src/eva/vision/models/networks/encoders/encoder.py new file mode 100644 index 00000000..b2c53aa3 --- /dev/null +++ b/src/eva/vision/models/networks/encoders/encoder.py @@ -0,0 +1,23 @@ +"""Encoder base class.""" + +import abc +from typing import List + +import torch +from torch import nn + + +class Encoder(nn.Module, abc.ABC): + """Encoder base class.""" + + @abc.abstractmethod + def forward(self, tensor: torch.Tensor) -> List[torch.Tensor]: + """Returns the multi-level feature maps of the model. + + Args: + tensor: The image tensor (batch_size, num_channels, height, width). + + Returns: + The list of multi-level image features of shape (batch_size, + hidden_size, num_patches_height, num_patches_width). + """ diff --git a/src/eva/vision/models/networks/encoders/from_timm.py b/src/eva/vision/models/networks/encoders/from_timm.py new file mode 100644 index 00000000..bf0b4084 --- /dev/null +++ b/src/eva/vision/models/networks/encoders/from_timm.py @@ -0,0 +1,64 @@ +"""Encoder wrapper for timm models.""" + +from typing import Any, Dict, List, Tuple + +import timm +import torch +from torch import nn +from typing_extensions import override + +from eva.vision.models.networks.encoders import encoder + + +class TimmEncoder(encoder.Encoder): + """Encoder wrapper for `timm` models. + + Note that only models with `forward_intermediates` + method are currently only supported. + """ + + def __init__( + self, + model_name: str, + pretrained: bool = False, + checkpoint_path: str = "", + out_indices: int | Tuple[int, ...] | None = 1, + model_arguments: Dict[str, Any] | None = None, + ) -> None: + """Initializes the encoder. + + Args: + model_name: Name of model to instantiate. + pretrained: If set to `True`, load pretrained ImageNet-1k weights. + checkpoint_path: Path of checkpoint to load. + out_indices: Returns last n blocks if `int`, all if `None`, select + matching indices if sequence. + model_arguments: Extra model arguments. 
+ """ + super().__init__() + + self._model_name = model_name + self._pretrained = pretrained + self._checkpoint_path = checkpoint_path + self._out_indices = out_indices + self._model_arguments = model_arguments or {} + + self._feature_extractor: nn.Module + + self.configure_model() + + def configure_model(self) -> None: + """Builds and loads the timm model as feature extractor.""" + self._feature_extractor = timm.create_model( + model_name=self._model_name, + pretrained=self._pretrained, + checkpoint_path=self._checkpoint_path, + out_indices=self._out_indices, + features_only=True, + **self._model_arguments, + ) + TimmEncoder.__name__ = self._model_name + + @override + def forward(self, tensor: torch.Tensor) -> List[torch.Tensor]: + return self._feature_extractor(tensor) diff --git a/src/eva/vision/models/networks/from_timm.py b/src/eva/vision/models/networks/from_timm.py new file mode 100644 index 00000000..22053591 --- /dev/null +++ b/src/eva/vision/models/networks/from_timm.py @@ -0,0 +1,36 @@ +"""Helper wrapper class for timm models.""" + +from typing import Any, Dict + +import timm + +from eva.core.models.networks import wrappers + + +class TimmModel(wrappers.ModelFromFunction): + """Wrapper class for timm models.""" + + def __init__( + self, + model_name: str, + pretrained: bool = False, + checkpoint_path: str = "", + model_arguments: Dict[str, Any] | None = None, + ) -> None: + """Initializes and constructs the model. + + Args: + model_name: Name of model to instantiate. + pretrained: If set to `True`, load pretrained ImageNet-1k weights. + checkpoint_path: Path of checkpoint to load. + model_arguments: The extra callable function / class arguments. + """ + super().__init__( + path=timm.create_model, + arguments={ + "model_name": model_name, + "pretrained": pretrained, + "checkpoint_path": checkpoint_path, + } + | (model_arguments or {}), + ) diff --git a/src/eva/vision/utils/colormap.py b/src/eva/vision/utils/colormap.py new file mode 100644 index 00000000..2ca70604 --- /dev/null +++ b/src/eva/vision/utils/colormap.py @@ -0,0 +1,77 @@ +"""Color mapping constants.""" + +COLORS = [ + (0, 0, 0), + (255, 0, 0), # Red + (0, 255, 0), # Green + (0, 0, 255), # Blue + (255, 255, 0), # Yellow + (255, 0, 255), # Magenta + (0, 255, 255), # Cyan + (128, 128, 0), # Olive + (128, 0, 128), # Purple + (0, 128, 128), # Teal + (192, 192, 192), # Silver + (128, 128, 128), # Gray + (255, 165, 0), # Orange + (210, 105, 30), # Chocolate + (0, 128, 0), # Lime + (255, 192, 203), # Pink + (255, 69, 0), # Red-Orange + (255, 140, 0), # Dark Orange + (0, 255, 255), # Sky Blue + (0, 255, 127), # Spring Green + (0, 0, 139), # Dark Blue + (255, 20, 147), # Deep Pink + (139, 69, 19), # Saddle Brown + (0, 100, 0), # Dark Green + (106, 90, 205), # Slate Blue + (138, 43, 226), # Blue-Violet + (218, 165, 32), # Goldenrod + (199, 21, 133), # Medium Violet Red + (70, 130, 180), # Steel Blue + (165, 42, 42), # Brown + (128, 0, 0), # Maroon + (255, 0, 255), # Fuchsia + (210, 180, 140), # Tan + (0, 0, 128), # Navy + (139, 0, 139), # Dark Magenta + (144, 238, 144), # Light Green + (46, 139, 87), # Sea Green + (255, 255, 0), # Gold + (154, 205, 50), # Yellow Green + (0, 191, 255), # Deep Sky Blue + (0, 250, 154), # Medium Spring Green + (250, 128, 114), # Salmon + (255, 105, 180), # Hot Pink + (204, 255, 204), # Pastel Light Green + (51, 0, 51), # Very Dark Magenta + (255, 102, 0), # Dark Orange + (0, 255, 0), # Bright Green + (51, 153, 255), # Blue-Purple + (51, 51, 255), # Bright Blue + (204, 0, 0), # Dark Red + (90, 90, 
90), # Very Dark Gray + (255, 255, 51), # Pastel Yellow + (255, 153, 255), # Pink-Magenta + (153, 0, 76), # Dark Pink + (51, 25, 0), # Very Dark Brown + (102, 51, 0), # Dark Brown + (0, 0, 51), # Very Dark Blue + (180, 180, 180), # Dark Gray + (102, 255, 204), # Pastel Green + (0, 102, 0), # Dark Green + (220, 245, 20), # Lime Yellow + (255, 204, 204), # Pastel Pink + (0, 204, 255), # Pastel Blue + (240, 240, 240), # Light Gray + (153, 153, 0), # Dark Yellow + (102, 0, 51), # Dark Red-Pink + (0, 51, 0), # Very Dark Green + (255, 102, 204), # Magenta Pink + (204, 0, 102), # Red-Pink +] +"""RGB colors.""" + +COLORMAP = dict(enumerate(COLORS)) | {255: (255, 255, 255)} +"""Class id to RGB color mapping.""" diff --git a/src/eva/vision/utils/convert.py b/src/eva/vision/utils/convert.py index cfe22f55..f013a5d7 100644 --- a/src/eva/vision/utils/convert.py +++ b/src/eva/vision/utils/convert.py @@ -1,24 +1,67 @@ """Image conversion related functionalities.""" -from typing import Any +from typing import Iterable -import numpy as np -import numpy.typing as npt +import torch +from torchvision.transforms.v2 import functional -def to_8bit(image_array: npt.NDArray[Any]) -> npt.NDArray[np.uint8]: - """Casts an image of higher bit image (i.e. 16bit) to 8bit. +def descale_and_denorm_image( + image: torch.Tensor, + mean: Iterable[float] = (0.0, 0.0, 0.0), + std: Iterable[float] = (1.0, 1.0, 1.0), + inplace: bool = True, +) -> torch.Tensor: + """De-scales and de-norms an image tensor to (0, 255) range. Args: - image_array: The image array to convert. + image: An image float tensor. + mean: The mean that the image channels are normalized with. + std: The std that the image channels are normalized with. + inplace: Whether to perform the operation in-place. Returns: - The image as normalized as a 8-bit format. + The image tensor of range (0, 255) range as uint8. """ - if np.issubdtype(image_array.dtype, np.integer): - image_array = image_array.astype(np.float64) + if not inplace: + image = image.clone() - image_scaled_array = image_array - image_array.min() - image_scaled_array /= image_scaled_array.max() - image_scaled_array *= 255 - return image_scaled_array.astype(np.uint8) + norm_image = _descale_image(image, mean=mean, std=std) + return _denorm_image(norm_image) + + +def _descale_image( + image: torch.Tensor, + mean: Iterable[float] = (0.0, 0.0, 0.0), + std: Iterable[float] = (1.0, 1.0, 1.0), +) -> torch.Tensor: + """De-scales an image tensor to (0., 1.) range. + + Args: + image: An image float tensor. + mean: The normalized channels mean values. + std: The normalized channels std values. + + Returns: + The de-normalized image tensor of range (0., 1.). + """ + return functional.normalize( + image, + mean=[-cmean / cstd for cmean, cstd in zip(mean, std, strict=False)], + std=[1 / cstd for cstd in std], + ) + + +def _denorm_image(image: torch.Tensor) -> torch.Tensor: + """De-normalizes an image tensor from (0., 1.) to (0, 255) range. + + Args: + image: An image float tensor. + + Returns: + The image tensor of range (0, 255) range as uint8. 
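For context, descale_and_denorm_image first undoes the channel normalization and then stretches the result into a displayable (0, 255) uint8 range. A rough usage sketch; note that inplace defaults to True, so pass inplace=False to keep the input tensor intact:

import torch

from eva.vision.utils import convert

# A CHW float image normalized with mean 0.5 / std 0.5 per channel.
image = (torch.rand(3, 224, 224) - 0.5) / 0.5

restored = convert.descale_and_denorm_image(
    image, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), inplace=False
)
assert restored.dtype == torch.uint8
assert restored.min() == 0 and restored.max() == 255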
+ """ + image_scaled = image - image.min() + image_scaled /= image_scaled.max() + image_scaled *= 255 + return image_scaled.to(dtype=torch.uint8) diff --git a/src/eva/vision/utils/io/__init__.py b/src/eva/vision/utils/io/__init__.py index 8fe1177b..ef2161ac 100644 --- a/src/eva/vision/utils/io/__init__.py +++ b/src/eva/vision/utils/io/__init__.py @@ -1,13 +1,14 @@ """Vision I/O utilities.""" from eva.vision.utils.io.image import read_image, read_image_as_tensor -from eva.vision.utils.io.nifti import fetch_total_nifti_slices, read_nifti_slice +from eva.vision.utils.io.nifti import fetch_nifti_shape, read_nifti, save_array_as_nifti from eva.vision.utils.io.text import read_csv __all__ = [ "read_image", "read_image_as_tensor", - "fetch_total_nifti_slices", - "read_nifti_slice", + "fetch_nifti_shape", + "read_nifti", + "save_array_as_nifti", "read_csv", ] diff --git a/src/eva/vision/utils/io/nifti.py b/src/eva/vision/utils/io/nifti.py index 162f0a64..6859729f 100644 --- a/src/eva/vision/utils/io/nifti.py +++ b/src/eva/vision/utils/io/nifti.py @@ -1,21 +1,22 @@ """NIfTI I/O related functions.""" -from typing import Any +from typing import Any, Tuple import nibabel as nib +import numpy as np import numpy.typing as npt from eva.vision.utils.io import _utils -def read_nifti_slice( - path: str, slice_index: int, *, use_storage_dtype: bool = True +def read_nifti( + path: str, slice_index: int | None = None, *, use_storage_dtype: bool = True ) -> npt.NDArray[Any]: - """Reads and loads a NIfTI image from a file path as `uint8`. + """Reads and loads a NIfTI image from a file path. Args: path: The path to the NIfTI file. - slice_index: The image slice index to return. + slice_index: Whether to read only a slice from the file. use_storage_dtype: Whether to cast the raw image array to the inferred type. @@ -28,21 +29,42 @@ def read_nifti_slice( """ _utils.check_file(path) image_data = nib.load(path) # type: ignore - image_slice = image_data.slicer[:, :, slice_index : slice_index + 1] # type: ignore - image_array = image_slice.get_fdata() + if slice_index is not None: + image_data = image_data.slicer[:, :, slice_index : slice_index + 1] # type: ignore + + image_array = image_data.get_fdata() # type: ignore if use_storage_dtype: image_array = image_array.astype(image_data.get_data_dtype()) # type: ignore + return image_array -def fetch_total_nifti_slices(path: str) -> int: - """Fetches the total slides of a NIfTI image file. +def save_array_as_nifti( + array: npt.ArrayLike, + filename: str, + *, + dtype: npt.DTypeLike | None = np.int64, +) -> None: + """Saved a numpy array as a NIfTI image file. + + Args: + array: The image array to save. + filename: The name to save the image like. + dtype: The data type to save the image. + """ + nifti_image = nib.Nifti1Image(array, affine=np.eye(4), dtype=dtype) # type: ignore + nifti_image.header.get_xyzt_units() + nifti_image.to_filename(filename) + + +def fetch_nifti_shape(path: str) -> Tuple[int]: + """Fetches the NIfTI image shape from a file. Args: path: The path to the NIfTI file. Returns: - The number of the total available slides. + The image shape. Raises: FileExistsError: If the path does not exist or it is unreachable. 
@@ -50,5 +72,4 @@ def fetch_total_nifti_slices(path: str) -> int: """ _utils.check_file(path) image = nib.load(path) # type: ignore - image_shape = image.header.get_data_shape() # type: ignore - return image_shape[-1] + return image.header.get_data_shape() # type: ignore diff --git a/src/eva/vision/utils/io/text.py b/src/eva/vision/utils/io/text.py index 7f6089dd..34120238 100644 --- a/src/eva/vision/utils/io/text.py +++ b/src/eva/vision/utils/io/text.py @@ -4,15 +4,22 @@ from typing import Dict, List -def read_csv(path: str) -> List[Dict[str, str]]: +def read_csv( + path: str, + *, + delimiter: str = ",", + encoding: str = "utf-8", +) -> List[Dict[str, str]]: """Reads a CSV file and returns its contents as a list of dictionaries. Args: path: The path to the CSV file. + delimiter: The character that separates fields in the CSV file. + encoding: The encoding of the CSV file. Returns: A list of dictionaries representing the data in the CSV file. """ - with open(path, newline="") as file: - data = csv.DictReader(file, skipinitialspace=True) + with open(path, newline="", encoding=encoding) as file: + data = csv.DictReader(file, skipinitialspace=True, delimiter=delimiter) return list(data) diff --git a/tests/eva/assets/vision/datasets/total_segmentator/Totalsegmentator_dataset_v201/meta.csv b/tests/eva/assets/vision/datasets/total_segmentator/Totalsegmentator_dataset_v201/meta.csv new file mode 100644 index 00000000..557084d8 --- /dev/null +++ b/tests/eva/assets/vision/datasets/total_segmentator/Totalsegmentator_dataset_v201/meta.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:81f0de47859c10a45bfd02550b4d37e5c3c9c954b44657d0f340e665ce8b2cc5 +size 399 diff --git a/tests/eva/core/metrics/defaults/segmentation/__init__.py b/tests/eva/core/metrics/defaults/segmentation/__init__.py new file mode 100644 index 00000000..358c4946 --- /dev/null +++ b/tests/eva/core/metrics/defaults/segmentation/__init__.py @@ -0,0 +1 @@ +"""Tests default metric groups for segmentation tasks.""" diff --git a/tests/eva/core/metrics/defaults/segmentation/test_multiclass.py b/tests/eva/core/metrics/defaults/segmentation/test_multiclass.py new file mode 100644 index 00000000..1f896ba1 --- /dev/null +++ b/tests/eva/core/metrics/defaults/segmentation/test_multiclass.py @@ -0,0 +1,62 @@ +"""MulticlassSegmentationMetrics metric tests.""" + +import pytest +import torch + +from eva.core.metrics import defaults + +NUM_CLASSES_ONE = 3 +PREDS_ONE = torch.tensor( + [ + [1, 1, 1, 1], + [0, 0, 1, 1], + [0, 0, 2, 2], + [0, 0, 2, 2], + ], +) +TARGET_ONE = torch.tensor( + [ + [1, 1, 0, 0], + [1, 1, 2, 2], + [0, 0, 2, 2], + [0, 0, 2, 2], + ], + dtype=torch.int32, +) +EXPECTED_ONE = { + "MulticlassJaccardIndex": torch.tensor(0.4722222089767456), + "MulticlassF1Score": torch.tensor(0.6222222447395325), + "MulticlassPrecision": torch.tensor(0.6666666865348816), + "MulticlassRecall": torch.tensor(0.611011104490899), +} +"""Test features.""" + + +@pytest.mark.parametrize( + "num_classes, preds, target, expected", + [ + (NUM_CLASSES_ONE, PREDS_ONE, TARGET_ONE, EXPECTED_ONE), + ], +) +def test_multiclass_segmentation_metrics( + multiclass_segmentation_metrics: defaults.MulticlassSegmentationMetrics, + preds: torch.Tensor, + target: torch.Tensor, + expected: torch.Tensor, +) -> None: + """Tests the multiclass_segmentation_metrics metric.""" + + def _calculate_metric() -> None: + multiclass_segmentation_metrics.update(preds=preds, target=target) # type: ignore + actual = multiclass_segmentation_metrics.compute() + 
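For reference, the expected MulticlassJaccardIndex above can be reproduced by hand from PREDS_ONE and TARGET_ONE: class 0 has intersection 4 over union 8, class 1 has 2 over 8, class 2 has 4 over 6, and the macro average is (0.5 + 0.25 + 2/3) / 3 ≈ 0.4722. A self-contained check:

import torch

preds = torch.tensor([[1, 1, 1, 1], [0, 0, 1, 1], [0, 0, 2, 2], [0, 0, 2, 2]])
target = torch.tensor([[1, 1, 0, 0], [1, 1, 2, 2], [0, 0, 2, 2], [0, 0, 2, 2]])

ious = []
for cls in range(3):
    intersection = ((preds == cls) & (target == cls)).sum()
    union = ((preds == cls) | (target == cls)).sum()
    ious.append(intersection / union)

print(torch.stack(ious).mean())  # tensor(0.4722)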
torch.testing.assert_close(actual, expected, rtol=1e-04, atol=1e-04) + + _calculate_metric() + multiclass_segmentation_metrics.reset() + _calculate_metric() + + +@pytest.fixture(scope="function") +def multiclass_segmentation_metrics(num_classes: int) -> defaults.MulticlassSegmentationMetrics: + """MulticlassSegmentationMetrics fixture.""" + return defaults.MulticlassSegmentationMetrics(num_classes=num_classes) diff --git a/tests/eva/core/models/modules/test_head.py b/tests/eva/core/models/modules/test_head.py index b3b0bc35..00e4b7a2 100644 --- a/tests/eva/core/models/modules/test_head.py +++ b/tests/eva/core/models/modules/test_head.py @@ -35,7 +35,10 @@ def test_head_module_fit( @pytest.fixture(scope="function") -def model(input_shape: Tuple[int, ...] = (3, 8, 8), n_classes: int = 4) -> modules.HeadModule: +def model( + input_shape: Tuple[int, ...] = (3, 8, 8), + n_classes: int = 4, +) -> modules.HeadModule: """Returns a HeadModule model fixture.""" return modules.HeadModule( head=nn.Linear(math.prod(input_shape), n_classes), diff --git a/tests/eva/core/models/modules/test_inference.py b/tests/eva/core/models/modules/test_inference.py index b17f6202..247abf46 100644 --- a/tests/eva/core/models/modules/test_inference.py +++ b/tests/eva/core/models/modules/test_inference.py @@ -46,5 +46,8 @@ def model( ) -> modules.InferenceModule: """Returns a HeadModule model fixture.""" return modules.InferenceModule( - backbone=nn.Sequential(nn.Flatten(), nn.Linear(math.prod(input_shape), n_classes)) + backbone=nn.Sequential( + nn.Flatten(), + nn.Linear(math.prod(input_shape), n_classes), + ) ) diff --git a/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py b/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py index 9607a2a8..a2845b2a 100644 --- a/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py +++ b/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py @@ -11,7 +11,7 @@ @pytest.mark.parametrize( "split, expected_length", - [("train", 9), ("val", 9), (None, 9)], + [("train", 6), ("val", 3), (None, 9)], ) def test_length( total_segmentator_dataset: datasets.TotalSegmentator2D, expected_length: int diff --git a/tests/eva/vision/data/transforms/common/test_resize_and_clamp.py b/tests/eva/vision/data/transforms/common/test_resize_and_clamp.py new file mode 100644 index 00000000..b3cb0887 --- /dev/null +++ b/tests/eva/vision/data/transforms/common/test_resize_and_clamp.py @@ -0,0 +1,47 @@ +"""Test the ResizeAndClamp augmentation.""" + +from typing import Tuple + +import pytest +import torch +from torch import testing +from torchvision import tv_tensors + +from eva.vision.data.transforms import common + + +@pytest.mark.parametrize( + "image_size, target_size, clamp_range, expected_size, expected_mean", + [ + ((3, 512, 224), [112, 224], [0, 255], (3, 112, 224), 0.498039186000824), + ((3, 512, 224), [112, 224], [100, 155], (3, 112, 224), 0.4909091293811798), + ((3, 224, 512), [112, 224], [0, 100], (3, 112, 224), 1.0), + ((3, 512, 224), [112, 97], [100, 255], (3, 112, 97), 0.17419354617595673), + ((3, 512, 512), 224, [0, 255], (3, 224, 224), 0.4980391561985016), + ], +) +def test_resize_and_clamp( + image_tensor: tv_tensors.Image, + resize_and_clamp: common.ResizeAndClamp, + expected_size: Tuple[int, int, int], + expected_mean: float, +) -> None: + """Tests the ResizeAndClamp transform.""" + output = resize_and_clamp(image_tensor) + assert output.shape == expected_size + testing.assert_close(torch.tensor(expected_mean), output.mean()) + 
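ResizeAndClamp itself is not part of this diff, so the expected means above are best read as arithmetic: the input is a constant image of value 127, clamping to [lo, hi] and rescaling that range to [0, 1] maps every pixel to (clamped - lo) / (hi - lo), and resizing a constant image does not change its mean. The same arithmetic explains the rescale_intensity expectations below:

for lo, hi in [(0, 255), (100, 155), (0, 100), (100, 255)]:
    value = min(max(127, lo), hi)       # clamp the constant pixel value
    print((value - lo) / (hi - lo))     # 0.4980..., 0.4909..., 1.0, 0.1741...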
+ +@pytest.fixture(scope="function") +def resize_and_clamp( + target_size: Tuple[int, int, int], clamp_range: Tuple[int, int] +) -> common.ResizeAndClamp: + """Transform ResizeAndClamp fixture.""" + return common.ResizeAndClamp(size=target_size, clamp_range=clamp_range) + + +@pytest.fixture(scope="function") +def image_tensor(image_size: Tuple[int, int, int]) -> tv_tensors.Image: + """Image tensor fixture.""" + image_tensor = 127 * torch.ones(image_size, dtype=torch.uint8) + return tv_tensors.wrap(image_tensor, like=image_tensor) # type: ignore diff --git a/tests/eva/vision/data/transforms/normalization/__init__.py b/tests/eva/vision/data/transforms/normalization/__init__.py new file mode 100644 index 00000000..043d7a45 --- /dev/null +++ b/tests/eva/vision/data/transforms/normalization/__init__.py @@ -0,0 +1 @@ +"""Pixel normalization related transforms unit tests.""" diff --git a/tests/eva/vision/data/transforms/normalization/functional/__init__.py b/tests/eva/vision/data/transforms/normalization/functional/__init__.py new file mode 100644 index 00000000..d00105b3 --- /dev/null +++ b/tests/eva/vision/data/transforms/normalization/functional/__init__.py @@ -0,0 +1 @@ +"""Pixel normalization related functional transforms unit tests.""" diff --git a/tests/eva/vision/data/transforms/normalization/functional/test_rescale_intensity.py b/tests/eva/vision/data/transforms/normalization/functional/test_rescale_intensity.py new file mode 100644 index 00000000..5ee1f095 --- /dev/null +++ b/tests/eva/vision/data/transforms/normalization/functional/test_rescale_intensity.py @@ -0,0 +1,37 @@ +"""Test the rescale intensity transform.""" + +from typing import Tuple + +import pytest +import torch +from torch import testing +from torchvision import tv_tensors + +from eva.vision.data.transforms.normalization import functional + + +@pytest.mark.parametrize( + "image_size, in_range, out_range, expected_mean", + [ + ((3, 224, 224), (0, 255), (0.0, 1.0), 0.4980391561985016), + ((3, 224, 224), (0, 255), (-0.5, 0.5), -0.001960783964022994), + ((3, 224, 224), (100, 155), (-0.5, 0.5), -0.009090900421142578), + ((3, 224, 224), (100, 155), (0.0, 1.0), 0.4909091293811798), + ], +) +def test_rescale_intensity( + image_tensor: tv_tensors.Image, + in_range: Tuple[int, int], + out_range: Tuple[int, int], + expected_mean: float, +) -> None: + """Tests the rescale_intensity functional transform.""" + output = functional.rescale_intensity(image_tensor, in_range=in_range, out_range=out_range) + testing.assert_close(torch.tensor(expected_mean), output.mean()) + + +@pytest.fixture(scope="function") +def image_tensor(image_size: Tuple[int, int, int]) -> tv_tensors.Image: + """Image tensor fixture.""" + image_tensor = 127 * torch.ones(image_size, dtype=torch.uint8) + return tv_tensors.wrap(image_tensor, like=image_tensor) # type: ignore diff --git a/tests/eva/vision/models/modules/__init__.py b/tests/eva/vision/models/modules/__init__.py new file mode 100644 index 00000000..e22daaa8 --- /dev/null +++ b/tests/eva/vision/models/modules/__init__.py @@ -0,0 +1 @@ +"""Tests for the vision model modules.""" diff --git a/tests/eva/vision/models/modules/conftest.py b/tests/eva/vision/models/modules/conftest.py new file mode 100644 index 00000000..ef18e4e6 --- /dev/null +++ b/tests/eva/vision/models/modules/conftest.py @@ -0,0 +1,83 @@ +"""Shared configuration and fixtures for models/modules unit tests.""" + +from typing import Tuple + +import pytest +import torch +from torch.utils import data as torch_data + +from eva.core.data import 
dataloaders, datamodules, datasets +from eva.core.trainers import trainer as eva_trainer + + +@pytest.fixture(scope="function") +def datamodule( + request: pytest.FixtureRequest, + dataset_fixture: str, + dataloader: dataloaders.DataLoader, +) -> datamodules.DataModule: + """Returns a dummy datamodule fixture.""" + dataset = request.getfixturevalue(dataset_fixture) + return datamodules.DataModule( + datasets=datamodules.DatasetsSchema( + train=dataset, + val=dataset, + predict=dataset, + ), + dataloaders=datamodules.DataloadersSchema( + train=dataloader, + val=dataloader, + predict=dataloader, + ), + ) + + +@pytest.fixture(scope="function") +def trainer(max_epochs: int = 1) -> eva_trainer.Trainer: + """Returns a model trainer fixture.""" + return eva_trainer.Trainer( + max_epochs=max_epochs, + accelerator="cpu", + default_root_dir="logs/test", + ) + + +@pytest.fixture(scope="function") +def segmentation_dataset( + n_samples: int = 4, + input_shape: Tuple[int, ...] = (3, 16, 16), + target_shape: Tuple[int, ...] = (16, 16), + n_classes: int = 4, +) -> datasets.TorchDataset: + """Dummy segmentation dataset fixture.""" + return torch_data.TensorDataset( + torch.randn((n_samples,) + input_shape), + torch.randint(n_classes, (n_samples,) + target_shape, dtype=torch.long), + ) + + +@pytest.fixture(scope="function") +def segmentation_dataset_with_metadata( + n_samples: int = 4, + input_shape: Tuple[int, ...] = (3, 16, 16), + target_shape: Tuple[int, ...] = (16, 16), + n_classes: int = 4, +) -> datasets.TorchDataset: + """Dummy segmentation dataset fixture with metadata.""" + return torch_data.TensorDataset( + torch.randn((n_samples,) + input_shape), + torch.randint(n_classes, (n_samples,) + target_shape, dtype=torch.long), + torch.randint(2, (n_samples,) + target_shape, dtype=torch.long), + ) + + +@pytest.fixture(scope="function") +def dataloader(batch_size: int = 2) -> dataloaders.DataLoader: + """Test dataloader fixture.""" + return dataloaders.DataLoader( + batch_size=batch_size, + num_workers=0, + pin_memory=False, + persistent_workers=False, + prefetch_factor=None, + ) diff --git a/tests/eva/vision/models/modules/test_semantic_segmentation.py b/tests/eva/vision/models/modules/test_semantic_segmentation.py new file mode 100644 index 00000000..73776762 --- /dev/null +++ b/tests/eva/vision/models/modules/test_semantic_segmentation.py @@ -0,0 +1,59 @@ +"""Tests the HeadModule module.""" + +import pytest +import torch +from torch import nn + +from eva.core import metrics, trainers +from eva.core.data import datamodules +from eva.vision.models import modules +from eva.vision.models.networks import encoders +from eva.vision.models.networks.decoders import segmentation + + +@pytest.mark.parametrize( + "dataset_fixture", + [ + "segmentation_dataset", + "segmentation_dataset_with_metadata", + ], +) +def test_semantic_segmentation_module_fit( + model: modules.SemanticSegmentationModule, + datamodule: datamodules.DataModule, + trainer: trainers.Trainer, +) -> None: + """Tests the SemanticSegmentationModule fit pipeline.""" + initial_decoder_weights = model.decoder._layers.weight.clone() + trainer.fit(model, datamodule=datamodule) + # verify that the metrics were updated + assert trainer.logged_metrics["train/AverageLoss"] > 0 + assert trainer.logged_metrics["val/AverageLoss"] > 0 + # verify that head weights were updated + assert not torch.all(torch.eq(initial_decoder_weights, model.decoder._layers.weight)) + + +@pytest.fixture(scope="function") +def model(n_classes: int = 4) -> 
modules.SemanticSegmentationModule: + """Returns a SemanticSegmentationModule model fixture.""" + return modules.SemanticSegmentationModule( + decoder=segmentation.ConvDecoder( + layers=nn.Conv2d( + in_channels=192, + out_channels=n_classes, + kernel_size=(1, 1), + ), + ), + criterion=nn.CrossEntropyLoss(), + encoder=encoders.TimmEncoder( + model_name="vit_tiny_patch16_224", + pretrained=False, + out_indices=1, + model_arguments={ + "dynamic_img_size": True, + }, + ), + metrics=metrics.MetricsSchema( + common=metrics.AverageLoss(), + ), + ) diff --git a/tests/eva/vision/models/networks/decoders/__init__.py b/tests/eva/vision/models/networks/decoders/__init__.py new file mode 100644 index 00000000..b27386c5 --- /dev/null +++ b/tests/eva/vision/models/networks/decoders/__init__.py @@ -0,0 +1 @@ +"""Vision decoders tests.""" diff --git a/tests/eva/vision/models/networks/decoders/segmentation/__init__.py b/tests/eva/vision/models/networks/decoders/segmentation/__init__.py new file mode 100644 index 00000000..984d0cb7 --- /dev/null +++ b/tests/eva/vision/models/networks/decoders/segmentation/__init__.py @@ -0,0 +1 @@ +"""Vision segmentation decoders tests.""" diff --git a/tests/eva/vision/models/networks/decoders/segmentation/conv.py b/tests/eva/vision/models/networks/decoders/segmentation/conv.py new file mode 100644 index 00000000..59a27cc2 --- /dev/null +++ b/tests/eva/vision/models/networks/decoders/segmentation/conv.py @@ -0,0 +1,70 @@ +"""Tests for convolutional decoder.""" + +from typing import List, Tuple + +import pytest +import torch +from torch import nn + +from eva.vision.models.networks.decoders import segmentation +from eva.vision.models.networks.decoders.segmentation import common + + +@pytest.mark.parametrize( + "layers, features, image_size, expected_shape", + [ + ( + nn.Conv2d(384, 5, kernel_size=(1, 1)), + [torch.Tensor(2, 384, 14, 14)], + (224, 224), + torch.Size([2, 5, 224, 224]), + ), + ( + nn.Sequential( + nn.Upsample(scale_factor=2), + nn.Conv2d(384, 64, kernel_size=(3, 3), padding=(1, 1)), + nn.Upsample(scale_factor=2), + nn.Conv2d(64, 5, kernel_size=(3, 3), padding=(1, 1)), + ), + [torch.Tensor(2, 384, 14, 14)], + (224, 224), + torch.Size([2, 5, 224, 224]), + ), + ( + nn.Conv2d(768, 5, kernel_size=(1, 1)), + [torch.Tensor(2, 384, 14, 14), torch.Tensor(2, 384, 14, 14)], + (224, 224), + torch.Size([2, 5, 224, 224]), + ), + ( + common.ConvDecoder1x1(384, 5), + [torch.Tensor(2, 384, 14, 14)], + (224, 224), + torch.Size([2, 5, 224, 224]), + ), + ( + common.ConvDecoderMS(384, 5), + [torch.Tensor(2, 384, 14, 14)], + (224, 224), + torch.Size([2, 5, 224, 224]), + ), + ], +) +def test_conv_decoder( + conv_decoder: segmentation.ConvDecoder, + features: List[torch.Tensor], + image_size: Tuple[int, int], + expected_shape: torch.Size, +) -> None: + """Tests the ConvDecoder network.""" + logits = conv_decoder(features, image_size) + assert isinstance(logits, torch.Tensor) + assert logits.shape == expected_shape + + +@pytest.fixture(scope="function") +def conv_decoder( + layers: nn.Module, +) -> segmentation.ConvDecoder: + """ConvDecoder fixture.""" + return segmentation.ConvDecoder(layers=layers) diff --git a/tests/eva/vision/models/networks/decoders/segmentation/linear.py b/tests/eva/vision/models/networks/decoders/segmentation/linear.py new file mode 100644 index 00000000..156d5c49 --- /dev/null +++ b/tests/eva/vision/models/networks/decoders/segmentation/linear.py @@ -0,0 +1,59 @@ +"""Tests for linear decoder.""" + +from typing import List, Tuple + +import pytest +import torch 
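Putting the pieces together, the fixture above is essentially the minimal recipe for a working segmentation module. A hedged sketch of the same wiring outside of pytest, assuming the module's forward takes the image batch plus a to_size target resolution, as in the batch step shown earlier in this diff:

import torch
from torch import nn

from eva.core import metrics
from eva.vision.models import modules
from eva.vision.models.networks import encoders
from eva.vision.models.networks.decoders import segmentation

module = modules.SemanticSegmentationModule(
    encoder=encoders.TimmEncoder(
        model_name="vit_tiny_patch16_224",  # 192-dim patch embeddings
        pretrained=False,
        out_indices=1,
        model_arguments={"dynamic_img_size": True},
    ),
    decoder=segmentation.ConvDecoder(
        layers=nn.Conv2d(192, 4, kernel_size=(1, 1)),
    ),
    criterion=nn.CrossEntropyLoss(),
    metrics=metrics.MetricsSchema(common=metrics.AverageLoss()),
)

logits = module(torch.randn(2, 3, 224, 224), to_size=(224, 224))
# Expected: per-pixel class scores of shape (2, 4, 224, 224).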
+from torch import nn + +from eva.vision.models.networks.decoders import segmentation +from eva.vision.models.networks.decoders.segmentation import common + + +@pytest.mark.parametrize( + "layers, features, image_size, expected_shape", + [ + ( + nn.Linear(384, 5), + [torch.Tensor(2, 384, 14, 14)], + (224, 224), + torch.Size([2, 5, 224, 224]), + ), + ( + nn.Linear(768, 5), + [torch.Tensor(2, 384, 14, 14), torch.Tensor(2, 384, 14, 14)], + (224, 224), + torch.Size([2, 5, 224, 224]), + ), + ( + common.SingleLinearDecoder(384, 5), + [torch.Tensor(2, 384, 14, 14)], + (224, 224), + torch.Size([2, 5, 224, 224]), + ), + ( + common.SingleLinearDecoder(768, 5), + [torch.Tensor(2, 384, 14, 14), torch.Tensor(2, 384, 14, 14)], + (224, 224), + torch.Size([2, 5, 224, 224]), + ), + ], +) +def test_linear_decoder( + linear_decoder: segmentation.LinearDecoder, + features: List[torch.Tensor], + image_size: Tuple[int, int], + expected_shape: torch.Size, +) -> None: + """Tests the ConvDecoder network.""" + logits = linear_decoder(features, image_size) + assert isinstance(logits, torch.Tensor) + assert logits.shape == expected_shape + + +@pytest.fixture(scope="function") +def linear_decoder( + layers: nn.Module, +) -> segmentation.LinearDecoder: + """LinearDecoder fixture.""" + return segmentation.LinearDecoder(layers=layers) diff --git a/tests/eva/vision/models/networks/encoders/__init__.py b/tests/eva/vision/models/networks/encoders/__init__.py new file mode 100644 index 00000000..5760a3d4 --- /dev/null +++ b/tests/eva/vision/models/networks/encoders/__init__.py @@ -0,0 +1 @@ +"""Test for vision encoders.""" diff --git a/tests/eva/vision/models/networks/encoders/test_from_timm.py b/tests/eva/vision/models/networks/encoders/test_from_timm.py new file mode 100644 index 00000000..0c356810 --- /dev/null +++ b/tests/eva/vision/models/networks/encoders/test_from_timm.py @@ -0,0 +1,66 @@ +"""TimmEncoder tests.""" + +from typing import Any, Dict, Tuple + +import pytest +import torch + +from eva.vision.models.networks import encoders + + +@pytest.mark.parametrize( + "model_name, out_indices, model_arguments, input_tensor, expected_len, expected_shape", + [ + ( + "vit_small_patch16_224", + 1, + None, + torch.Tensor(2, 3, 224, 224), + 1, + torch.Size([2, 384, 14, 14]), + ), + ( + "vit_small_patch16_224", + 3, + None, + torch.Tensor(2, 3, 224, 224), + 3, + torch.Size([2, 384, 14, 14]), + ), + ( + "vit_small_patch16_224", + 3, + {"dynamic_img_size": True}, + torch.Tensor(2, 3, 512, 512), + 3, + torch.Size([2, 384, 32, 32]), + ), + ], +) +def test_timm_encoder( + timm_encoder: encoders.TimmEncoder, + input_tensor: torch.Tensor, + expected_len: int, + expected_shape: torch.Size, +) -> None: + """Tests the TimmEncoder network.""" + outputs = timm_encoder(input_tensor) + assert isinstance(outputs, list) + assert len(outputs) == expected_len + # individual + assert isinstance(outputs[0], torch.Tensor) + assert outputs[0].shape == expected_shape + + +@pytest.fixture(scope="function") +def timm_encoder( + model_name: str, + out_indices: int | Tuple[int, ...] 
| None, + model_arguments: Dict[str, Any] | None, +) -> encoders.TimmEncoder: + """TimmEncoder fixture.""" + return encoders.TimmEncoder( + model_name=model_name, + out_indices=out_indices, + model_arguments=model_arguments, + ) diff --git a/tests/eva/vision/models/networks/test_from_timm.py b/tests/eva/vision/models/networks/test_from_timm.py new file mode 100644 index 00000000..384a4f27 --- /dev/null +++ b/tests/eva/vision/models/networks/test_from_timm.py @@ -0,0 +1,48 @@ +"""TimmModel tests.""" + +from typing import Any, Dict + +import pytest +import torch + +from eva.vision.models import networks + + +@pytest.mark.parametrize( + "model_name, model_arguments, input_tensor, expected_shape", + [ + ( + "vit_small_patch16_224", + {"num_classes": 0}, + torch.Tensor(2, 3, 224, 224), + torch.Size([2, 384]), + ), + ( + "vit_small_patch16_224", + {"num_classes": 0, "dynamic_img_size": True}, + torch.Tensor(2, 3, 512, 512), + torch.Size([2, 384]), + ), + ], +) +def test_timm_model( + timm_model: networks.TimmModel, + input_tensor: torch.Tensor, + expected_shape: torch.Size, +) -> None: + """Tests the timm_model network.""" + outputs = timm_model(input_tensor) + assert isinstance(outputs, torch.Tensor) + assert outputs.shape == expected_shape + + +@pytest.fixture(scope="function") +def timm_model( + model_name: str, + model_arguments: Dict[str, Any] | None, +) -> networks.TimmModel: + """TimmModel fixture.""" + return networks.TimmModel( + model_name=model_name, + model_arguments=model_arguments, + ) diff --git a/tests/eva/vision/test_vision_cli.py b/tests/eva/vision/test_vision_cli.py index 174f0c32..0eef61d6 100644 --- a/tests/eva/vision/test_vision_cli.py +++ b/tests/eva/vision/test_vision_cli.py @@ -18,6 +18,7 @@ "configs/vision/dino_vit/online/crc.yaml", "configs/vision/dino_vit/online/mhist.yaml", "configs/vision/dino_vit/online/patch_camelyon.yaml", + "configs/vision/dino_vit/online/total_segmentator_2d.yaml", "configs/vision/dino_vit/offline/bach.yaml", "configs/vision/dino_vit/offline/crc.yaml", "configs/vision/dino_vit/offline/mhist.yaml", diff --git a/tests/eva/vision/utils/test_convert.py b/tests/eva/vision/utils/test_convert.py index a9eb05eb..4692b8a8 100644 --- a/tests/eva/vision/utils/test_convert.py +++ b/tests/eva/vision/utils/test_convert.py @@ -1,14 +1,11 @@ """Tests image conversion related functionalities.""" -from typing import Any - -import numpy as np -import numpy.typing as npt +import torch from pytest import mark from eva.vision.utils import convert -IMAGE_ARRAY_INT16 = np.array( +IMAGE_ONE = torch.Tensor( [ [ [-794, -339, -607, -950], @@ -17,34 +14,32 @@ [-790, -325, -564, -969], ] ], - dtype=np.int16, -) +).to(dtype=torch.float16) """Test input data.""" -EXPECTED_ARRAY_INT16 = np.array( +EXPECTED_ONE = torch.Tensor( [ [ [33, 120, 69, 3], - [68, 169, 217, 25], + [69, 169, 217, 25], [74, 183, 255, 25], [34, 123, 77, 0], ] ], - dtype=np.uint8, -) +).to(dtype=torch.uint8) """Test expected/desired features.""" @mark.parametrize( - "image_array, expected", + "image, expected", [ - [IMAGE_ARRAY_INT16, EXPECTED_ARRAY_INT16], + [IMAGE_ONE, EXPECTED_ONE], ], ) -def test_to_8bit( - image_array: npt.NDArray[Any], - expected: npt.NDArray[np.uint8], +def test_descale_and_denorm_image( + image: torch.Tensor, + expected: torch.Tensor, ) -> None: - """Tests the `to_8bit` image conversion.""" - actual = convert.to_8bit(image_array) - np.testing.assert_allclose(actual, expected) + """Tests the `descale_and_denorm_image` image conversion.""" + actual = 
convert.descale_and_denorm_image(image, mean=(0.0,), std=(1.0,)) + torch.testing.assert_close(actual, expected)
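
Finally, a note on the two timm wrappers introduced in this diff: TimmModel wraps a whole timm model (with num_classes: 0 it returns pooled embeddings per image), while TimmEncoder exposes timm's features_only interface and returns a list of spatial feature maps suitable for the segmentation decoders. A sketch mirroring the shapes exercised in the tests above:

import torch

from eva.vision.models import networks
from eva.vision.models.networks import encoders

image = torch.randn(2, 3, 224, 224)

# Pooled embeddings: (batch, hidden) -- here (2, 384) for ViT-S/16.
backbone = networks.TimmModel("vit_small_patch16_224", model_arguments={"num_classes": 0})
embeddings = backbone(image)

# Multi-level spatial features: a list of (batch, hidden, grid_h, grid_w) maps.
encoder = encoders.TimmEncoder(
    model_name="vit_small_patch16_224",
    pretrained=False,
    out_indices=3,             # return the last three blocks
)
features = encoder(image)      # 3 tensors of shape (2, 384, 14, 14)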