Skip to content

Commit

Permalink
Add conv2d and linear decoders
Browse files Browse the repository at this point in the history
  • Loading branch information
ioangatop committed Apr 24, 2024
1 parent 767fbbd commit d112070
Show file tree
Hide file tree
Showing 8 changed files with 287 additions and 0 deletions.
Empty file added main.py
Empty file.
7 changes: 7 additions & 0 deletions src/eva/vision/models/networks/decoders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Decoder heads API."""

from eva.vision.models.networks.decoders.conv import ConvDecoder
from eva.vision.models.networks.decoders.decoder import Decoder
from eva.vision.models.networks.decoders.linear import LinearDecoder

__all__ = ["ConvDecoder", "Decoder", "LinearDecoder"]
88 changes: 88 additions & 0 deletions src/eva/vision/models/networks/decoders/conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Convolutional based semantic segmentation decoder."""

from typing import List, Tuple

import torch
from torch import nn
from torch.nn import functional

from eva.vision.models.networks.decoders import decoder


class ConvDecoder(decoder.Decoder):
    """Segmentation decoder head built from convolutional layers."""

    def __init__(self, layers: nn.Module) -> None:
        """Builds the decoder around the supplied convolutional layers.

        The given layers are applied as-is to feature maps of shape
        (batch_size, hidden_size, height, width).

        Args:
            layers: The convolutional layers to be used as the decoder head.
        """
        super().__init__()

        self._layers = layers

    def _forward_features(self, features: List[torch.Tensor]) -> torch.Tensor:
        """Selects the single feature map that is fed to the decoder head.

        Args:
            features: List of multi-level image features of shape (batch_size,
                hidden_size, num_patches_height, num_patches_width).

        Returns:
            The last entry of `features`, a tensor of shape (batch_size,
            hidden_size, num_patches_height, num_patches_width).
        """
        # Only the final level is used; earlier levels are ignored.
        last_level = features[-1]
        return last_level

    def _forward_head(self, patch_embeddings: torch.Tensor) -> torch.Tensor:
        """Runs the decoder layers on the patch embeddings.

        Args:
            patch_embeddings: The model patch embeddings reshaped to
                (batch_size, hidden_size, num_patches_height, num_patches_width).

        Returns:
            The logits as a tensor (batch_size, num_classes, height, width).
        """
        logits = self._layers(patch_embeddings)
        return logits

    def _cls_seg(
        self,
        logits: torch.Tensor,
        image_size: Tuple[int, int],
    ) -> torch.Tensor:
        """Resizes the logits so that every image pixel gets class scores.

        Args:
            logits: The decoder outputs of shape
                (batch_size, num_classes, height, width).
            image_size: The target image size (height, width).

        Returns:
            Tensor containing scores for all of the classes with shape
            (batch_size, num_classes, image_height, image_width).
        """
        return functional.interpolate(logits, size=image_size, mode="bilinear")

    def forward(
        self,
        features: List[torch.Tensor],
        image_size: Tuple[int, int],
    ) -> torch.Tensor:
        """Maps the patch embeddings to a segmentation mask of the image size.

        Args:
            features: List of multi-level image features.
            image_size: The target image size (height, width).

        Returns:
            Tensor containing scores for all of the classes with shape
            (batch_size, num_classes, image_height, image_width).
        """
        # feature selection -> decoder head -> resize to the image resolution
        return self._cls_seg(
            self._forward_head(self._forward_features(features)),
            image_size,
        )
7 changes: 7 additions & 0 deletions src/eva/vision/models/networks/decoders/decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Semantic segmentation decoder base class."""

from torch import nn


class Decoder(nn.Module):
    """Semantic segmentation decoder base class.

    A plain ``nn.Module`` subclass that defines no attributes or methods
    of its own; it acts as the common parent type of the decoder heads.
    """
93 changes: 93 additions & 0 deletions src/eva/vision/models/networks/decoders/linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Linear based decoder."""

from typing import List, Tuple

import torch
from torch import nn
from torch.nn import functional

from eva.vision.models.networks.decoders import decoder


class LinearDecoder(decoder.Decoder):
    """Linear decoder."""

    def __init__(self, layers: nn.Module) -> None:
        """Initializes the linear based decoder head.

        Here the input nn layers will be applied to the reshaped
        features (batch_size, num_patches, hidden_size) obtained from
        the input (batch_size, hidden_size, height, width), where
        num_patches = height * width, and the result is then unwrapped
        again to (batch_size, num_classes, height, width).

        Args:
            layers: The linear layers to be used as the decoder head.
        """
        super().__init__()

        self._layers = layers

    def _forward_features(self, features: List[torch.Tensor]) -> torch.Tensor:
        """Forward function for multi-level feature maps to a single one.

        Args:
            features: List of multi-level image features of shape (batch_size,
                hidden_size, num_patches_height, num_patches_width).

        Returns:
            The last entry of `features`, a tensor of shape (batch_size,
            hidden_size, num_patches_height, num_patches_width).
        """
        return features[-1]

    def _forward_head(self, patch_embeddings: torch.Tensor) -> torch.Tensor:
        """Forward of the decoder head.

        Args:
            patch_embeddings: The model patch embeddings reshaped to
                (batch_size, hidden_size, num_patches_height, num_patches_width).

        Returns:
            The logits as a tensor (batch_size, num_classes, height, width).
        """
        # Name the channel dimension explicitly: it is a real value used in
        # the reshape below, not a throwaway.
        batch_size, hidden_size, height, width = patch_embeddings.shape
        # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C): one row per patch, so
        # the layers act on the hidden dimension.
        patch_tokens = patch_embeddings.reshape(
            batch_size, hidden_size, height * width
        ).permute(0, 2, 1)
        logits = self._layers(patch_tokens)
        # (B, H*W, K) -> (B, K, H*W) -> (B, K, H, W), with K inferred from
        # the output of the layers.
        return logits.permute(0, 2, 1).reshape(batch_size, -1, height, width)

    def _cls_seg(
        self,
        logits: torch.Tensor,
        image_size: Tuple[int, int],
    ) -> torch.Tensor:
        """Classify each pixel of the image.

        Args:
            logits: The decoder outputs of shape
                (batch_size, num_classes, height, width).
            image_size: The target image size (height, width).

        Returns:
            Tensor containing scores for all of the classes with shape
            (batch_size, num_classes, image_height, image_width).
        """
        return functional.interpolate(logits, image_size, mode="bilinear")

    def forward(
        self,
        features: List[torch.Tensor],
        image_size: Tuple[int, int],
    ) -> torch.Tensor:
        """Maps the patch embeddings to a segmentation mask of the image size.

        Args:
            features: List of multi-level image features.
            image_size: The target image size (height, width).

        Returns:
            Tensor containing scores for all of the classes with shape
            (batch_size, num_classes, image_height, image_width).
        """
        patch_embeddings = self._forward_features(features)
        logits = self._forward_head(patch_embeddings)
        return self._cls_seg(logits, image_size)
1 change: 1 addition & 0 deletions tests/eva/vision/models/networks/decoders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Vision segmentation decoders tests."""
51 changes: 51 additions & 0 deletions tests/eva/vision/models/networks/decoders/conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Tests for convolutional decoder."""

from typing import List, Tuple

import pytest
import torch
from torch import nn

from eva.vision.models.networks import decoders


@pytest.mark.parametrize(
    "layers, features, image_size, expected_shape",
    [
        # Single 1x1 convolution mapping the hidden size to the classes.
        (
            nn.Conv2d(384, 5, kernel_size=(1, 1)),
            # torch.zeros instead of the legacy torch.Tensor(*sizes)
            # constructor, which returns uninitialized memory.
            [torch.zeros(2, 384, 14, 14)],
            (224, 224),
            torch.Size([2, 5, 224, 224]),
        ),
        # Small upsample + convolution stack.
        (
            nn.Sequential(
                nn.Upsample(scale_factor=2),
                nn.Conv2d(384, 64, kernel_size=(3, 3), padding=(1, 1)),
                nn.Upsample(scale_factor=2),
                nn.Conv2d(64, 5, kernel_size=(3, 3), padding=(1, 1)),
            ),
            [torch.zeros(2, 384, 14, 14)],
            (224, 224),
            torch.Size([2, 5, 224, 224]),
        ),
    ],
)
def test_conv_decoder(
    conv_decoder: decoders.ConvDecoder,
    features: List[torch.Tensor],
    image_size: Tuple[int, int],
    expected_shape: torch.Size,
) -> None:
    """Tests that the ConvDecoder returns a tensor of the expected shape."""
    logits = conv_decoder(features, image_size)
    assert isinstance(logits, torch.Tensor)
    assert logits.shape == expected_shape


@pytest.fixture(scope="function")
def conv_decoder(layers: nn.Module) -> decoders.ConvDecoder:
    """Builds a ConvDecoder from the requested `layers`."""
    decoder_head = decoders.ConvDecoder(layers=layers)
    return decoder_head
40 changes: 40 additions & 0 deletions tests/eva/vision/models/networks/decoders/linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Tests for linear decoder."""

from typing import List, Tuple

import pytest
import torch
from torch import nn

from eva.vision.models.networks import decoders


@pytest.mark.parametrize(
    "layers, features, image_size, expected_shape",
    [
        (
            nn.Linear(384, 5),
            # torch.zeros instead of the legacy torch.Tensor(*sizes)
            # constructor, which returns uninitialized memory.
            [torch.zeros(2, 384, 14, 14)],
            (224, 224),
            torch.Size([2, 5, 224, 224]),
        ),
    ],
)
def test_linear_decoder(
    linear_decoder: decoders.LinearDecoder,
    features: List[torch.Tensor],
    image_size: Tuple[int, int],
    expected_shape: torch.Size,
) -> None:
    """Tests that the LinearDecoder returns a tensor of the expected shape."""
    logits = linear_decoder(features, image_size)
    assert isinstance(logits, torch.Tensor)
    assert logits.shape == expected_shape


@pytest.fixture(scope="function")
def linear_decoder(layers: nn.Module) -> decoders.LinearDecoder:
    """Builds a LinearDecoder from the requested `layers`."""
    decoder_head = decoders.LinearDecoder(layers=layers)
    return decoder_head

0 comments on commit d112070

Please sign in to comment.