refactored segmentation decoder setup to support original images as input
nkaenzig committed Oct 30, 2024
1 parent 4ca244b commit 38b0707
Showing 14 changed files with 192 additions and 48 deletions.
7 changes: 4 additions & 3 deletions src/eva/vision/models/modules/semantic_segmentation.py
@@ -15,6 +15,7 @@
from eva.core.models.modules.utils import batch_postprocess, grad
from eva.core.utils import parser
from eva.vision.models.networks import decoders
+from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs


class SemanticSegmentationModule(module.ModelModule):
@@ -101,9 +102,9 @@ def forward(
"Please provide the expected `to_size` that the "
"decoder should map the embeddings (`inputs`) to."
)

-        patch_embeddings = self.encoder(inputs) if self.encoder else inputs
-        return self.decoder(patch_embeddings, to_size or inputs.shape[-2:])
+        features = self.encoder(inputs) if self.encoder else inputs
+        decoder_inputs = DecoderInputs(features, inputs.shape[-2:], inputs)  # type: ignore
+        return self.decoder(decoder_inputs)

@override
def training_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
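For orientation, here is a minimal sketch (not part of the commit) of the new calling convention: the module now wraps the encoder output, the target size, and the raw images into a `DecoderInputs` tuple and hands that to the decoder, which is the `Decoder2D` head introduced later in this diff. The 16x16-conv `encoder` stand-in below is hypothetical.

```python
import torch
from torch import nn

from eva.vision.models.networks.decoders import segmentation
from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs

# Hypothetical encoder stand-in: maps 3-channel images to a 192-channel,
# 14x14 patch-embedding grid (224 / 16 = 14), like a ViT feature map.
encoder = nn.Conv2d(3, 192, kernel_size=16, stride=16)
decoder = segmentation.Decoder2D(layers=nn.Conv2d(192, 4, kernel_size=1))

images = torch.rand(2, 3, 224, 224)
features = encoder(images)                                   # (2, 192, 14, 14)
inputs = DecoderInputs(features, images.shape[-2:], images)  # features, size, images
masks = decoder(inputs)                                      # (2, 4, 224, 224)
```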
2 changes: 1 addition & 1 deletion src/eva/vision/models/networks/decoders/__init__.py
@@ -1,6 +1,6 @@
"""Decoder heads API."""

from eva.vision.models.networks.decoders import segmentation
-from eva.vision.models.networks.decoders.decoder import Decoder
+from eva.vision.models.networks.decoders.segmentation.base import Decoder

__all__ = ["segmentation", "Decoder"]
7 changes: 0 additions & 7 deletions src/eva/vision/models/networks/decoders/decoder.py

This file was deleted.

16 changes: 12 additions & 4 deletions src/eva/vision/models/networks/decoders/segmentation/__init__.py
@@ -1,11 +1,19 @@
"""Segmentation decoder heads API."""

-from eva.vision.models.networks.decoders.segmentation.common import (
+from eva.vision.models.networks.decoders.segmentation.decoder2d import Decoder2D
+from eva.vision.models.networks.decoders.segmentation.linear import LinearDecoder
+from eva.vision.models.networks.decoders.segmentation.semantic import (
    ConvDecoder1x1,
    ConvDecoderMS,
+    ConvDecoderWithImagePrior,
    SingleLinearDecoder,
)
-from eva.vision.models.networks.decoders.segmentation.conv2d import ConvDecoder
-from eva.vision.models.networks.decoders.segmentation.linear import LinearDecoder

-__all__ = ["ConvDecoder1x1", "ConvDecoderMS", "SingleLinearDecoder", "ConvDecoder", "LinearDecoder"]
+__all__ = [
+    "ConvDecoder1x1",
+    "ConvDecoderMS",
+    "SingleLinearDecoder",
+    "ConvDecoderWithImagePrior",
+    "Decoder2D",
+    "LinearDecoder",
+]
16 changes: 16 additions & 0 deletions src/eva/vision/models/networks/decoders/segmentation/base.py
@@ -0,0 +1,16 @@
"""Semantic segmentation decoder base class."""

import abc

import torch
from torch import nn

from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs


class Decoder(nn.Module, abc.ABC):
"""Abstract base class for segmentation decoders."""

@abc.abstractmethod
def forward(self, decoder_inputs: DecoderInputs) -> torch.Tensor:
"""Forward pass of the decoder."""
src/eva/vision/models/networks/decoders/segmentation/{conv2d.py → decoder2d.py}
@@ -1,33 +1,37 @@
"""Convolutional based semantic segmentation decoder."""

-from typing import List, Tuple
+from typing import List, Sequence, Tuple

import torch
from torch import nn
from torch.nn import functional

-from eva.vision.models.networks.decoders import decoder
+from eva.vision.models.networks.decoders.segmentation import base
+from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs


-class ConvDecoder(decoder.Decoder):
-    """Convolutional segmentation decoder."""
+class Decoder2D(base.Decoder):
+    """Segmentation decoder for 2D applications."""

-    def __init__(self, layers: nn.Module) -> None:
-        """Initializes the convolutional based decoder head.
+    def __init__(self, layers: nn.Module, combine_features: bool = True) -> None:
+        """Initializes the decoder head.
Here the input nn layers will be directly applied to the
features of shape (batch_size, hidden_size, n_patches_height,
n_patches_width), where n_patches is image_size / patch_size.
Note the n_patches is also known as grid_size.
Args:
-            layers: The convolutional layers to be used as the decoder head.
+            layers: The layers to be used as the decoder head.
+            combine_features: Whether to combine the features from different
+                feature levels into one tensor before applying the decoder head.
"""
super().__init__()

self._layers = layers
+        self._combine_features = combine_features

-    def _forward_features(self, features: List[torch.Tensor]) -> torch.Tensor:
+    def _forward_features(self, features: torch.Tensor | List[torch.Tensor]) -> torch.Tensor:
"""Forward function for multi-level feature maps to a single one.
It will interpolate the features and concat them into a single tensor
@@ -63,7 +67,9 @@ def _forward_features(self, features: List[torch.Tensor]) -> torch.Tensor:
]
return torch.cat(upsampled_features, dim=1)

-    def _forward_head(self, patch_embeddings: torch.Tensor) -> torch.Tensor:
+    def _forward_head(
+        self, patch_embeddings: torch.Tensor | Sequence[torch.Tensor]
+    ) -> torch.Tensor:
"""Forward of the decoder head.
Args:
@@ -75,12 +81,12 @@ def _forward_head(self, patch_embeddings: torch.Tensor) -> torch.Tensor:
"""
return self._layers(patch_embeddings)

-    def _cls_seg(
+    def _upscale(
self,
logits: torch.Tensor,
image_size: Tuple[int, int],
) -> torch.Tensor:
"""Classify each pixel of the image.
"""Upscales the calculated logits to the target image size.
Args:
logits: The decoder outputs of shape (batch_size, n_classes,
@@ -93,22 +99,18 @@ def _cls_seg(
"""
return functional.interpolate(logits, image_size, mode="bilinear")

-    def forward(
-        self,
-        features: List[torch.Tensor],
-        image_size: Tuple[int, int],
-    ) -> torch.Tensor:
+    def forward(self, decoder_inputs: DecoderInputs) -> torch.Tensor:
"""Maps the patch embeddings to a segmentation mask of the image size.
Args:
-            features: List of multi-level image features of shape (batch_size,
-                hidden_size, n_patches_height, n_patches_width).
-            image_size: The target image size (height, width).
+            decoder_inputs: Inputs required by the decoder.
Returns:
Tensor containing scores for all of the classes with shape
(batch_size, n_classes, image_height, image_width).
"""
-        patch_embeddings = self._forward_features(features)
-        logits = self._forward_head(patch_embeddings)
-        return self._cls_seg(logits, image_size)
+        features, image_size, _ = DecoderInputs(*decoder_inputs)
+        if self._combine_features:
+            features = self._forward_features(features)
+        logits = self._forward_head(features)
+        return self._upscale(logits, image_size)
src/eva/vision/models/networks/decoders/segmentation/linear.py
@@ -6,10 +6,10 @@
from torch import nn
from torch.nn import functional

-from eva.vision.models.networks.decoders import decoder
+from eva.vision.models.networks.decoders.segmentation import base


-class LinearDecoder(decoder.Decoder):
+class LinearDecoder(base.Decoder):
"""Linear decoder."""

def __init__(self, layers: nn.Module) -> None:
src/eva/vision/models/networks/decoders/segmentation/semantic/__init__.py
@@ -0,0 +1,12 @@
"""Semantic Segmentation decoder heads API."""

from eva.vision.models.networks.decoders.segmentation.semantic.common import (
ConvDecoder1x1,
ConvDecoderMS,
SingleLinearDecoder,
)
from eva.vision.models.networks.decoders.segmentation.semantic.with_image_prior import (
ConvDecoderWithImagePrior,
)

__all__ = ["ConvDecoder1x1", "ConvDecoderMS", "SingleLinearDecoder", "ConvDecoderWithImagePrior"]
src/eva/vision/models/networks/decoders/segmentation/{common.py → semantic/common.py}
@@ -7,10 +7,10 @@

from torch import nn

-from eva.vision.models.networks.decoders.segmentation import conv2d, linear
+from eva.vision.models.networks.decoders.segmentation import decoder2d, linear


-class ConvDecoder1x1(conv2d.ConvDecoder):
+class ConvDecoder1x1(decoder2d.Decoder2D):
"""A convolutional decoder with a single 1x1 convolutional layer."""

def __init__(self, in_features: int, num_classes: int) -> None:
@@ -29,7 +29,7 @@ def __init__(self, in_features: int, num_classes: int) -> None:
)


-class ConvDecoderMS(conv2d.ConvDecoder):
+class ConvDecoderMS(decoder2d.Decoder2D):
"""A multi-stage convolutional decoder with upsampling and convolutional layers.
This decoder applies a series of upsampling and convolutional layers to transform
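Usage of these prebuilt heads is unchanged by the rename; a minimal sketch (not from the commit, assuming a single feature tensor as input, as in the module's forward pass):

```python
import torch

from eva.vision.models.networks.decoders.segmentation import ConvDecoder1x1
from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs

decoder = ConvDecoder1x1(in_features=192, num_classes=4)
features = torch.rand(2, 192, 14, 14)  # (batch, hidden, grid_h, grid_w)
logits = decoder(DecoderInputs(features, (224, 224)))
print(logits.shape)  # torch.Size([2, 4, 224, 224])
```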
src/eva/vision/models/networks/decoders/segmentation/semantic/with_image_prior.py
@@ -0,0 +1,94 @@
"""Convolutional semantic segmentation decoders that use input image & feature maps as input."""

from typing import List

import torch
from torch import nn
from torchvision.transforms.functional import rgb_to_grayscale
from typing_extensions import override

from eva.vision.models.networks.decoders.segmentation import decoder2d
from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs


class ConvDecoderWithImagePrior(decoder2d.Decoder2D):
"""A convolutional that in addition to encoded features, also takes the input image as input.
In a first stage, the input features are upsampled and passed through a convolutional layer,
while in the second stage, the input image channels are concatenated with the upsampled features
and passed through additional convolutional blocks in order to combine the image prior
information with the encoded features. Lastly, a 1x1 conv operation reduces the number of
channels to the number of classes.
"""

_default_hidden_dims = [64, 32, 32]

def __init__(
self,
in_features: int,
num_classes: int,
greyscale: bool = False,
hidden_dims: List[int] | None = None,
) -> None:
"""Initializes the decoder.
Args:
in_features: The hidden dimension size of the embeddings.
num_classes: Number of output classes as channels.
greyscale: Whether to convert input images to greyscale.
hidden_dims: List of hidden dimensions for the convolutional layers.
"""
hidden_dims = hidden_dims or self._default_hidden_dims
if len(hidden_dims) != 3:
raise ValueError("Hidden dims must have 3 elements.")

super().__init__(
layers=nn.Sequential(
nn.Upsample(scale_factor=2),
Conv2dBnReLU(in_features, hidden_dims[0]),
)
)
self.greyscale = greyscale

additional_hidden_dims = 1 if greyscale else 3
self.image_block = nn.Sequential(
Conv2dBnReLU(hidden_dims[0] + additional_hidden_dims, hidden_dims[1]),
Conv2dBnReLU(hidden_dims[1], hidden_dims[2]),
)

self.classifier = nn.Conv2d(hidden_dims[2], num_classes, kernel_size=1)

@override
def forward(self, decoder_inputs: DecoderInputs) -> torch.Tensor:
if decoder_inputs.images is None:
raise ValueError("Input images are missing.")

logits = super().forward(decoder_inputs)
in_images = (
rgb_to_grayscale(decoder_inputs.images) if self.greyscale else decoder_inputs.images
)
logits = torch.cat([logits, in_images], dim=1)
logits = self.image_block(logits)

return self.classifier(logits)


class Conv2dBnReLU(nn.Sequential):
"""A single convolutional layer with batch normalization and ReLU activation."""

def __init__(
self, in_channels: int, out_channels: int, kernel_size: int = 3, padding: int = 1
) -> None:
"""Initializes the layer.
Args:
in_channels: Number of input channels.
out_channels: Number of output channels.
kernel_size: Size of the convolutional kernel.
padding: Padding size for the convolutional layer.
"""
super().__init__(
nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
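Putting the pieces together, a possible end-to-end call (an illustrative sketch, not from the commit; the feature shape assumes a ViT-style encoder with patch size 16): the inherited `Decoder2D.forward` upsamples the features to the image size, then the image channels are concatenated before the final classifier.

```python
import torch

from eva.vision.models.networks.decoders.segmentation import ConvDecoderWithImagePrior
from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs

decoder = ConvDecoderWithImagePrior(in_features=192, num_classes=4)

images = torch.rand(2, 3, 224, 224)     # original input images
features = torch.rand(2, 192, 14, 14)   # encoder output (224 / 16 = 14)

# This decoder requires `images`; it raises a ValueError when they are missing.
logits = decoder(DecoderInputs(features, images.shape[-2:], images))
print(logits.shape)  # torch.Size([2, 4, 224, 224])
```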
18 changes: 18 additions & 0 deletions src/eva/vision/models/networks/decoders/segmentation/typings.py
@@ -0,0 +1,18 @@
"""Type-hints for segmentation decoders."""

from typing import List, NamedTuple, Tuple

import torch


class DecoderInputs(NamedTuple):
"""Input scheme for segmentation decoders."""

features: List[torch.Tensor]
"""List of image features generated by the encoder from the original images."""

image_size: Tuple[int, int]
"""Size of the original input images to be used for upsampling."""

images: torch.Tensor | None = None
"""The original input images for which the encoder generated the encoded_images."""
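Since `DecoderInputs` is a `NamedTuple`, decoders can read its fields by name or unpack it positionally, and `images` defaults to `None`. A small illustrative sketch:

```python
import torch

from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs

inputs = DecoderInputs(
    features=[torch.rand(2, 192, 14, 14)],
    image_size=(224, 224),
)

features, image_size, images = inputs  # positional unpacking
assert images is None                  # `images` was not provided
assert inputs.image_size == (224, 224)
```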
@@ -36,7 +36,7 @@ def test_semantic_segmentation_module_fit(
def model(n_classes: int = 4) -> modules.SemanticSegmentationModule:
"""Returns a SemanticSegmentationModule model fixture."""
return modules.SemanticSegmentationModule(
-        decoder=segmentation.ConvDecoder(
+        decoder=segmentation.Decoder2D(
layers=nn.Conv2d(
in_channels=192,
out_channels=n_classes,
@@ -7,7 +7,7 @@
from torch import nn

from eva.vision.models.networks.decoders import segmentation
-from eva.vision.models.networks.decoders.segmentation import common
+from eva.vision.models.networks.decoders.segmentation.semantic import common


@pytest.mark.parametrize(
@@ -51,7 +51,7 @@
],
)
def test_conv_decoder(
-    conv_decoder: segmentation.ConvDecoder,
+    conv_decoder: segmentation.Decoder2D,
features: List[torch.Tensor],
image_size: Tuple[int, int],
expected_shape: torch.Size,
@@ -65,6 +65,6 @@ def test_conv_decoder(
@pytest.fixture(scope="function")
def conv_decoder(
layers: nn.Module,
-) -> segmentation.ConvDecoder:
+) -> segmentation.Decoder2D:
"""ConvDecoder fixture."""
-    return segmentation.ConvDecoder(layers=layers)
+    return segmentation.Decoder2D(layers=layers)
@@ -7,7 +7,7 @@
from torch import nn

from eva.vision.models.networks.decoders import segmentation
-from eva.vision.models.networks.decoders.segmentation import common
+from eva.vision.models.networks.decoders.segmentation.semantic import common


@pytest.mark.parametrize(
