Showing 8 changed files with 287 additions and 0 deletions.
@@ -0,0 +1,7 @@
"""Decoder heads API."""

from eva.vision.models.networks.decoders.conv import ConvDecoder
from eva.vision.models.networks.decoders.decoder import Decoder
from eva.vision.models.networks.decoders.linear import LinearDecoder

__all__ = ["ConvDecoder", "Decoder", "LinearDecoder"]
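For orientation, a minimal sketch of how these exports could be wired together; the hidden size of 384 and the 5 output classes are illustrative values borrowed from the tests further down, not requirements of the API:

from torch import nn

from eva.vision.models.networks.decoders import ConvDecoder, LinearDecoder

# Each decoder wraps a caller-supplied head module (illustrative layer sizes).
conv_decoder = ConvDecoder(layers=nn.Conv2d(384, 5, kernel_size=(1, 1)))
linear_decoder = LinearDecoder(layers=nn.Linear(384, 5))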
@@ -0,0 +1,88 @@
"""Convolutional based semantic segmentation decoder."""

from typing import List, Tuple

import torch
from torch import nn
from torch.nn import functional

from eva.vision.models.networks.decoders import decoder


class ConvDecoder(decoder.Decoder):
    """Convolutional segmentation decoder."""

    def __init__(self, layers: nn.Module) -> None:
        """Initializes the convolutional based decoder head.

        Here the input nn layers will be directly applied to the
        features of shape (batch_size, hidden_size, height, width).

        Args:
            layers: The convolutional layers to be used as the decoder head.
        """
        super().__init__()

        self._layers = layers

    def _forward_features(self, features: List[torch.Tensor]) -> torch.Tensor:
        """Forward function for multi-level feature maps to a single one.

        Args:
            features: List of multi-level image features of shape (batch_size,
                hidden_size, num_patches_height, num_patches_width).

        Returns:
            A tensor of shape (batch_size, hidden_size, num_patches_height,
            num_patches_width) which is the feature map of the decoder head.
        """
        return features[-1]

    def _forward_head(self, patch_embeddings: torch.Tensor) -> torch.Tensor:
        """Forward of the decoder head.

        Args:
            patch_embeddings: The model patch embeddings reshaped to
                (batch_size, hidden_size, num_patches_height, num_patches_width).

        Returns:
            The logits as a tensor (batch_size, num_classes, height, width).
        """
        return self._layers(patch_embeddings)

    def _cls_seg(
        self,
        logits: torch.Tensor,
        image_size: Tuple[int, int],
    ) -> torch.Tensor:
        """Classify each pixel of the image.

        Args:
            logits: The decoder outputs of shape
                (batch_size, num_classes, height, width).
            image_size: The target image size (height, width).

        Returns:
            Tensor containing scores for all of the classes with shape
            (batch_size, num_classes, image_height, image_width).
        """
        return functional.interpolate(logits, image_size, mode="bilinear")

    def forward(
        self,
        features: List[torch.Tensor],
        image_size: Tuple[int, int],
    ) -> torch.Tensor:
        """Maps the patch embeddings to a segmentation mask of the image size.

        Args:
            features: List of multi-level image features.
            image_size: The target image size (height, width).

        Returns:
            Tensor containing scores for all of the classes with shape
            (batch_size, num_classes, image_height, image_width).
        """
        patch_embeddings = self._forward_features(features)
        logits = self._forward_head(patch_embeddings)
        return self._cls_seg(logits, image_size)
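As a quick usage sketch, mirroring the first parametrized case in the decoder tests further down (the 384-channel 14x14 feature map, the 5 output classes, and the dummy torch.rand input are illustrative assumptions rather than requirements of ConvDecoder):

import torch
from torch import nn

from eva.vision.models.networks.decoders import ConvDecoder

# A single 1x1 convolution as the segmentation head (illustrative sizes).
conv_decoder = ConvDecoder(layers=nn.Conv2d(384, 5, kernel_size=(1, 1)))
features = [torch.rand(2, 384, 14, 14)]  # ViT-style patch feature map
masks = conv_decoder(features, image_size=(224, 224))
assert masks.shape == torch.Size([2, 5, 224, 224])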
@@ -0,0 +1,7 @@
"""Semantic segmentation decoder base class."""

from torch import nn


class Decoder(nn.Module):
    """Semantic segmentation decoder base class."""
@@ -0,0 +1,93 @@
"""Linear based decoder."""

from typing import List, Tuple

import torch
from torch import nn
from torch.nn import functional

from eva.vision.models.networks.decoders import decoder


class LinearDecoder(decoder.Decoder):
    """Linear decoder."""

    def __init__(self, layers: nn.Module) -> None:
        """Initializes the linear based decoder head.

        Here the input nn layers will be applied to the reshaped
        features (batch_size, num_patches, hidden_size) from
        the input (batch_size, hidden_size, height, width) and then
        unwrapped again to (batch_size, num_classes, height, width).

        Args:
            layers: The linear layers to be used as the decoder head.
        """
        super().__init__()

        self._layers = layers

    def _forward_features(self, features: List[torch.Tensor]) -> torch.Tensor:
        """Forward function for multi-level feature maps to a single one.

        Args:
            features: List of multi-level image features of shape (batch_size,
                hidden_size, num_patches_height, num_patches_width).

        Returns:
            A tensor of shape (batch_size, hidden_size, num_patches_height,
            num_patches_width) which is the feature map of the decoder head.
        """
        return features[-1]

    def _forward_head(self, patch_embeddings: torch.Tensor) -> torch.Tensor:
        """Forward of the decoder head.

        Args:
            patch_embeddings: The model patch embeddings reshaped to
                (batch_size, hidden_size, num_patches_height, num_patches_width).

        Returns:
            The logits as a tensor (batch_size, num_classes, height, width).
        """
        batch_size, hidden_size, height, width = patch_embeddings.shape
        embeddings_reshaped = patch_embeddings.reshape(batch_size, hidden_size, height * width)
        logits = self._layers(embeddings_reshaped.permute(0, 2, 1))
        return logits.permute(0, 2, 1).reshape(batch_size, -1, height, width)

    def _cls_seg(
        self,
        logits: torch.Tensor,
        image_size: Tuple[int, int],
    ) -> torch.Tensor:
        """Classify each pixel of the image.

        Args:
            logits: The decoder outputs of shape
                (batch_size, num_classes, height, width).
            image_size: The target image size (height, width).

        Returns:
            Tensor containing scores for all of the classes with shape
            (batch_size, num_classes, image_height, image_width).
        """
        return functional.interpolate(logits, image_size, mode="bilinear")

    def forward(
        self,
        features: List[torch.Tensor],
        image_size: Tuple[int, int],
    ) -> torch.Tensor:
        """Maps the patch embeddings to a segmentation mask of the image size.

        Args:
            features: List of multi-level image features.
            image_size: The target image size (height, width).

        Returns:
            Tensor containing scores for all of the classes with shape
            (batch_size, num_classes, image_height, image_width).
        """
        patch_embeddings = self._forward_features(features)
        logits = self._forward_head(patch_embeddings)
        return self._cls_seg(logits, image_size)
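For intuition about the reshaping done in _forward_head, here is a sketch with a plain nn.Linear head; the shapes mirror the single case in the linear decoder test below and are illustrative only:

import torch
from torch import nn

from eva.vision.models.networks.decoders import LinearDecoder

linear_decoder = LinearDecoder(layers=nn.Linear(384, 5))
features = [torch.rand(2, 384, 14, 14)]                  # (batch, hidden, h_patches, w_patches)
masks = linear_decoder(features, image_size=(224, 224))  # (batch, num_classes, 224, 224)

# Internally the head flattens the patch grid, applies the linear layer per patch,
# and folds the class scores back into a spatial map before upsampling:
x = features[-1].reshape(2, 384, 14 * 14).permute(0, 2, 1)  # (batch, patches, hidden)
scores = nn.Linear(384, 5)(x)                               # (batch, patches, num_classes)
grid = scores.permute(0, 2, 1).reshape(2, 5, 14, 14)        # (batch, num_classes, 14, 14)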
@@ -0,0 +1 @@
"""Vision segmentation decoders tests."""
@@ -0,0 +1,51 @@
"""Tests for convolutional decoder."""

from typing import List, Tuple

import pytest
import torch
from torch import nn

from eva.vision.models.networks import decoders


@pytest.mark.parametrize(
    "layers, features, image_size, expected_shape",
    [
        (
            nn.Conv2d(384, 5, kernel_size=(1, 1)),
            [torch.Tensor(2, 384, 14, 14)],
            (224, 224),
            torch.Size([2, 5, 224, 224]),
        ),
        (
            nn.Sequential(
                nn.Upsample(scale_factor=2),
                nn.Conv2d(384, 64, kernel_size=(3, 3), padding=(1, 1)),
                nn.Upsample(scale_factor=2),
                nn.Conv2d(64, 5, kernel_size=(3, 3), padding=(1, 1)),
            ),
            [torch.Tensor(2, 384, 14, 14)],
            (224, 224),
            torch.Size([2, 5, 224, 224]),
        ),
    ],
)
def test_conv_decoder(
    conv_decoder: decoders.ConvDecoder,
    features: List[torch.Tensor],
    image_size: Tuple[int, int],
    expected_shape: torch.Size,
) -> None:
    """Tests the ConvDecoder network."""
    logits = conv_decoder(features, image_size)
    assert isinstance(logits, torch.Tensor)
    assert logits.shape == expected_shape


@pytest.fixture(scope="function")
def conv_decoder(
    layers: nn.Module,
) -> decoders.ConvDecoder:
    """ConvDecoder fixture."""
    return decoders.ConvDecoder(layers=layers)
@@ -0,0 +1,40 @@
"""Tests for linear decoder."""

from typing import List, Tuple

import pytest
import torch
from torch import nn

from eva.vision.models.networks import decoders


@pytest.mark.parametrize(
    "layers, features, image_size, expected_shape",
    [
        (
            nn.Linear(384, 5),
            [torch.Tensor(2, 384, 14, 14)],
            (224, 224),
            torch.Size([2, 5, 224, 224]),
        ),
    ],
)
def test_linear_decoder(
    linear_decoder: decoders.LinearDecoder,
    features: List[torch.Tensor],
    image_size: Tuple[int, int],
    expected_shape: torch.Size,
) -> None:
    """Tests the LinearDecoder network."""
    logits = linear_decoder(features, image_size)
    assert isinstance(logits, torch.Tensor)
    assert logits.shape == expected_shape


@pytest.fixture(scope="function")
def linear_decoder(
    layers: nn.Module,
) -> decoders.LinearDecoder:
    """LinearDecoder fixture."""
    return decoders.LinearDecoder(layers=layers)