Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support semantic segmentation downstream evaluation tasks #517

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
2b5a1ce
Add simple decoder networks for segmentation tasks (#404)
ioangatop Apr 30, 2024
a7e65e5
Add `timm` encoder networks (#403)
ioangatop Apr 30, 2024
ad392e8
Add SemanticSegmentation module (#410)
ioangatop May 6, 2024
4ef3afb
Add `TotalSegmentator2D` segmentation downstream task (#413)
ioangatop May 7, 2024
07fbd08
Add dice score in `TotalSegmentator2D` task (#423)
ioangatop May 7, 2024
a6fd246
Allow to use subclasses in `TotalSegmentator2D` (#435)
ioangatop May 13, 2024
db5a578
Create a callback to visualise the segmentation results (#424)
ioangatop May 13, 2024
7656ae6
Improve the mask loading in `TotalSegmentator2D` (#440)
ioangatop May 15, 2024
702f634
Add per class metrics dice score in `TotalSegmentator2D` (#447)
ioangatop May 16, 2024
434cacb
Support `int16` training on `TotalSegementator2D` (#443)
ioangatop May 16, 2024
35bb532
Normalisations and transforms for `int16` image types (#457)
ioangatop May 22, 2024
c43b8ad
Fix default segmentation metrics (#503)
ioangatop Jun 6, 2024
b009aec
Add support for multi-level embeddings training in segmentation tasks…
ioangatop Jun 6, 2024
19c92cd
merge with main
ioangatop Jun 6, 2024
cf54cad
update with main
ioangatop Jun 6, 2024
cfa47c9
Add dataset licence and env var for dataset download (#516)
ioangatop Jun 10, 2024
cc23d6a
Merge branch 'main' into 402-aggregated-feature-segmentation-downstre…
ioangatop Jun 10, 2024
6fefade
Minor updates on semantic segmentation tasks (#519)
ioangatop Jun 10, 2024
3b326a0
Fix SemanticSegmentationLogger callback write frequency (#521)
ioangatop Jun 10, 2024
fb948a0
Add `ModelCheckpoint` and `EarlyStopping` in TotalSegmentator2D` task…
ioangatop Jun 10, 2024
4ac910e
Minor fixes in `SemanticSegmentationModule` (#525)
ioangatop Jun 10, 2024
6a00235
Support the full `TotalSegmentator2D` dataset (#535)
ioangatop Jun 12, 2024
85a4aba
Merge branch 'main' into 402-aggregated-feature-segmentation-downstre…
ioangatop Jun 12, 2024
3d523ca
Add test dataloader in TotalSegmentator2D
ioangatop Jun 12, 2024
157c5a8
Merge branch 'main' into 402-aggregated-feature-segmentation-downstre…
ioangatop Jun 18, 2024
bb9ca72
Merge branch 'main' into 402-aggregated-feature-segmentation-downstre…
ioangatop Jun 24, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,9 @@ cython_debug/

# ignore local data
/data/

# numpy data
*.npy

# NiFti data
*.nii.gz
109 changes: 109 additions & 0 deletions configs/vision/dino_vit/online/total_segmentator_2d.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
trainer:
class_path: eva.Trainer
init_args:
n_runs: &N_RUNS ${oc.env:N_RUNS, 1}
default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/total_segmentator_2d/${oc.env:TIMM_MODEL_NAME, vit_small_patch16_224}}
max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 20000}
callbacks:
- class_path: eva.vision.callbacks.SemanticSegmentationLogger
init_args:
log_every_n_epochs: 1
mean: &NORMALIZE_MEAN [0.5, 0.5, 0.5]
std: &NORMALIZE_STD [0.5, 0.5, 0.5]
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
filename: best
save_last: true
save_top_k: 1
monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassJaccardIndex}
mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max}
- class_path: lightning.pytorch.callbacks.EarlyStopping
init_args:
min_delta: 0
patience: 5
monitor: *MONITOR_METRIC
mode: *MONITOR_METRIC_MODE
logger:
- class_path: lightning.pytorch.loggers.TensorBoardLogger
init_args:
save_dir: *OUTPUT_ROOT
name: ""
model:
class_path: eva.vision.models.modules.SemanticSegmentationModule
init_args:
encoder:
class_path: eva.vision.models.networks.encoders.TimmEncoder
init_args:
model_name: ${oc.env:TIMM_MODEL_NAME, vit_small_patch16_224}
pretrained: true
out_indices: ${oc.env:TIMM_MODEL_OUT_INDICES, 1}
model_arguments:
dynamic_img_size: true
decoder:
class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderMS
init_args:
in_features: ${oc.env:DECODER_IN_FEATURES, 384}
num_classes: &NUM_CLASSES 118
criterion: torch.nn.CrossEntropyLoss
lr_multiplier_encoder: 0.0
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 0.0001
weight_decay: 0.05
lr_scheduler:
class_path: torch.optim.lr_scheduler.PolynomialLR
init_args:
total_iters: *MAX_STEPS
power: 0.9
metrics:
common:
- class_path: eva.metrics.AverageLoss
evaluation:
- class_path: eva.core.metrics.defaults.MulticlassSegmentationMetrics
init_args:
num_classes: *NUM_CLASSES
- class_path: eva.core.metrics.wrappers.ClasswiseWrapper
init_args:
metric:
class_path: torchmetrics.classification.MulticlassF1Score
init_args:
num_classes: *NUM_CLASSES
average: null
data:
class_path: eva.DataModule
init_args:
datasets:
train:
class_path: eva.vision.datasets.TotalSegmentator2D
init_args: &DATASET_ARGS
root: ${oc.env:DATA_ROOT, ./data}/total_segmentator
split: train
download: ${oc.env:DOWNLOAD_DATA, false}
# Set `download: true` to download the dataset from https://zenodo.org/records/10047292
# The TotalSegmentator dataset is distributed under the following license:
# "Creative Commons Attribution 4.0 International"
# (see: https://creativecommons.org/licenses/by/4.0/deed.en)
transforms:
class_path: eva.vision.data.transforms.common.ResizeAndClamp
init_args:
mean: *NORMALIZE_MEAN
std: *NORMALIZE_STD
val:
class_path: eva.vision.datasets.TotalSegmentator2D
init_args:
<<: *DATASET_ARGS
split: val
test:
class_path: eva.vision.datasets.TotalSegmentator2D
init_args:
<<: *DATASET_ARGS
split: test
dataloaders:
train:
batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 16}
shuffle: true
val:
batch_size: *BATCH_SIZE
test:
batch_size: *BATCH_SIZE
9 changes: 4 additions & 5 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,15 @@ vision = [
"h5py>=3.10.0",
"nibabel>=5.2.0",
"opencv-python-headless>=4.9.0.80",
"timm>=0.9.12",
"timm @ git+https://github.com/huggingface/pytorch-image-models.git@main",
ioangatop marked this conversation as resolved.
Show resolved Hide resolved
"torchvision>=0.17.0",
"openslide-python>=1.3.1",
]
all = [
"h5py>=3.10.0",
"nibabel>=5.2.0",
"opencv-python-headless>=4.9.0.80",
"timm>=0.9.12",
"timm @ git+https://github.com/huggingface/pytorch-image-models.git@main",
"torchvision>=0.17.0",
"openslide-python>=1.3.1",
]
Expand Down
2 changes: 1 addition & 1 deletion src/eva/core/callbacks/writers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Callbacks API."""
"""Writers callbacks API."""

from eva.core.callbacks.writers.embeddings import ClassificationEmbeddingsWriter

Expand Down
5 changes: 3 additions & 2 deletions src/eva/core/loggers/log/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Experiment loggers actions."""
"""Experiment loggers operations."""

from eva.core.loggers.log.image import log_image
from eva.core.loggers.log.parameters import log_parameters

__all__ = ["log_parameters"]
__all__ = ["log_image", "log_parameters"]
59 changes: 59 additions & 0 deletions src/eva/core/loggers/log/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Image log functionality."""

import functools

import torch

from eva.core.loggers import loggers
from eva.core.loggers.log import utils


@functools.singledispatch
def log_image(
logger,
tag: str,
image: torch.Tensor,
step: int = 0,
) -> None:
"""Adds an image to the logger.

Args:
logger: The desired logger.
tag: The log tag.
image: The image tensor to log. It should have
the shape of (3,H,W) and (0,1) normalized.
step: The global step of the log.
"""
utils.raise_not_supported(logger, "image")


@log_image.register
def _(
loggers: list,
tag: str,
image: torch.Tensor,
step: int = 0,
) -> None:
"""Adds an image to a list of supported loggers."""
for logger in loggers:
log_image(
logger,
tag=tag,
image=image,
step=step,
)


@log_image.register
def _(
logger: loggers.TensorBoardLogger,
tag: str,
image: torch.Tensor,
step: int = 0,
) -> None:
"""Adds an image to a TensorBoard logger."""
logger.experiment.add_image(
tag=tag,
img_tensor=image,
global_step=step,
)
6 changes: 6 additions & 0 deletions src/eva/core/loggers/loggers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Experimental loggers."""

from lightning.pytorch.loggers import TensorBoardLogger

Loggers = TensorBoardLogger
"""Supported loggers."""
13 changes: 10 additions & 3 deletions src/eva/core/metrics/defaults/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
"""Default metric collections API."""

from eva.core.metrics.defaults.classification.binary import BinaryClassificationMetrics
from eva.core.metrics.defaults.classification.multiclass import MulticlassClassificationMetrics
from eva.core.metrics.defaults.classification import (
BinaryClassificationMetrics,
MulticlassClassificationMetrics,
)
from eva.core.metrics.defaults.segmentation import MulticlassSegmentationMetrics

__all__ = ["MulticlassClassificationMetrics", "BinaryClassificationMetrics"]
__all__ = [
"MulticlassClassificationMetrics",
"BinaryClassificationMetrics",
"MulticlassSegmentationMetrics",
]
2 changes: 1 addition & 1 deletion src/eva/core/metrics/defaults/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from eva.core.metrics.defaults.classification.binary import BinaryClassificationMetrics
from eva.core.metrics.defaults.classification.multiclass import MulticlassClassificationMetrics

__all__ = ["MulticlassClassificationMetrics", "BinaryClassificationMetrics"]
__all__ = ["BinaryClassificationMetrics", "MulticlassClassificationMetrics"]
9 changes: 0 additions & 9 deletions src/eva/core/metrics/defaults/classification/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,6 @@ def __init__(
) -> None:
"""Initializes the binary classification metrics.

The metrics instantiated here are:

- BinaryAUROC
- BinaryAccuracy
- BinaryBalancedAccuracy
- BinaryF1Score
- BinaryPrecision
- BinaryRecall

Args:
threshold: Threshold for transforming probability to binary (0,1) predictions
ignore_index: Specifies a target value that is ignored and does not
Expand Down
8 changes: 0 additions & 8 deletions src/eva/core/metrics/defaults/classification/multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,6 @@ def __init__(
) -> None:
"""Initializes the multi-class classification metrics.

The metrics instantiated here are:

- MulticlassAccuracy
- MulticlassPrecision
- MulticlassRecall
- MulticlassF1Score
- MulticlassAUROC

Args:
num_classes: Integer specifying the number of classes.
average: Defines the reduction that is applied over labels.
Expand Down
5 changes: 5 additions & 0 deletions src/eva/core/metrics/defaults/segmentation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Default segmentation metric collections API."""

from eva.core.metrics.defaults.segmentation.multiclass import MulticlassSegmentationMetrics

__all__ = ["MulticlassSegmentationMetrics"]
66 changes: 66 additions & 0 deletions src/eva/core/metrics/defaults/segmentation/multiclass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Default metric collection for multiclass semantic segmentation tasks."""

from typing import Literal

from torchmetrics import classification

from eva.core.metrics import structs


class MulticlassSegmentationMetrics(structs.MetricCollection):
"""Default metrics for multi-class semantic segmentation tasks."""

def __init__(
self,
num_classes: int,
average: Literal["macro", "weighted", "none"] = "macro",
ignore_index: int | None = None,
prefix: str | None = None,
postfix: str | None = None,
) -> None:
"""Initializes the multi-class semantic segmentation metrics.

Args:
num_classes: Integer specifying the number of classes.
average: Defines the reduction that is applied over labels.
ignore_index: Specifies a target value that is ignored and
does not contribute to the metric calculation.
prefix: A string to add before the keys in the output dictionary.
postfix: A string to add after the keys in the output dictionary.
"""
super().__init__(
metrics=[
classification.MulticlassJaccardIndex(
num_classes=num_classes,
average=average,
ignore_index=ignore_index,
),
classification.MulticlassF1Score(
nkaenzig marked this conversation as resolved.
Show resolved Hide resolved
num_classes=num_classes,
average=average,
ignore_index=ignore_index,
),
classification.MulticlassPrecision(
num_classes=num_classes,
average=average,
ignore_index=ignore_index,
),
classification.MulticlassRecall(
num_classes=num_classes,
average=average,
ignore_index=ignore_index,
),
],
prefix=prefix,
postfix=postfix,
compute_groups=[
[
"MulticlassJaccardIndex",
],
[
"MulticlassF1Score",
"MulticlassPrecision",
"MulticlassRecall",
],
],
)
4 changes: 3 additions & 1 deletion src/eva/core/metrics/structs/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,6 @@ def _join_with_common(self, metrics: MetricModuleType | None) -> MetricModuleTyp
if metrics is None or self.common is None:
return self.common or metrics

return [self.common, metrics] # type: ignore
metrics = metrics if isinstance(metrics, list) else [metrics] # type: ignore
common = self.common if isinstance(self.common, list) else [self.common]
return common + metrics # type: ignore
5 changes: 5 additions & 0 deletions src/eva/core/metrics/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Metric wrappers API."""

from eva.core.metrics.wrappers.classwise import ClasswiseWrapper

__all__ = ["ClasswiseWrapper"]
Loading