[Segmentation] Add Segment Anything Model (#132)
* starter commit, rough pytorch port

* adapt patching and embedding layer

* replaced proprietary mlp blocks with generic block

* mlp class cleanup

* mlp for TF

* refactoring start

* more refactoring

* vitdetl and vitdeth

* SAML and SAMH

* add transforms, utils and mask generator

* integrated SAM

* refactoring

* align vit det backbone to deepvision

* align ViTDet to deepvision API, TF implementation for relative positional attention, etc

* TF implementation for relative positional transformer encoder, window partitioning and unpartitioning

* TF implementation for downscaling attention and refactor for efficient multihead attention

* TF implementation for twoway attention block

* aligning twoway transformer encoder

* add decomp relative positional embedding

* fix activation function for twoway transformer encoder

* equalize twoway transformer encoder implementations

* embedding dim -> project dim for API consistency

* relative positional transformer encoder and positional attention tf implementations

* expose API for random position embeddings

* refactor

* small refactor

* small refactor

* import

* Refactor, nutshell file

* add examples
DavidLandup0 authored May 14, 2023
1 parent 4deee81 commit e13bf5a
Showing 49 changed files with 5,950 additions and 111 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -10,3 +10,4 @@ build/
*.egg-info
__pycache__/
*.so
.history
18 changes: 12 additions & 6 deletions README.md
@@ -376,12 +376,12 @@ Currently, these models are supported (parameter counts are *equal* between back

| Architecture | Parameters | FLOPs | Size (MB) |
|--------------|------------|-------|-----------|
| SegFormerB0 | 1,841,013 | | |
| SegFormerB1 | 5,910,997 | | |
| SegFormerB2 | 11,640,981 | | |
| SegFormerB3 | 16,721,301 | | |
| SegFormerB4 | 20,930,389 | | |
| SegFormerB5 | 26,234,645 | | |
| SegFormerB0 | 3,714,915 | | |
| SegFormerB1 | 13,678,019 | | |
| SegFormerB2 | 27,348,931 | | |
| SegFormerB3 | 47,224,771 | | |
| SegFormerB4 | 63,995,331 | | |
| SegFormerB5 | 84,595,651 | | |

- Mix-Transformer (MiT) Family:

@@ -394,7 +394,13 @@ Currently, these models are supported (parameter counts are *equal* between back
| MiTB4 | 60,847,818 | | |
| MiTB5 | 81,448,138 | | |

#### PyTorch-Only Models

| Architecture | Parameters | FLOPs | Size (MB) |
|--------------|-------------|-------|-----------|
| SAM_B | 93,735,472 | | |
| SAM_L | 312,342,832 | | |
| SAM_H | 641,090,608 | | |

## DeepVision as a Components Provider

4 changes: 3 additions & 1 deletion deepvision/__init__.py
@@ -1,3 +1,5 @@
from deepvision import datasets, evaluation, models
from deepvision import datasets
from deepvision import evaluation
from deepvision import models

__version__ = "0.1.6"
3 changes: 2 additions & 1 deletion deepvision/datasets/tiny_nerf/tiny_nerf.py
@@ -3,7 +3,8 @@
import numpy as np
import requests

from deepvision.datasets.tiny_nerf import tiny_nerf_pt, tiny_nerf_tf
from deepvision.datasets.tiny_nerf import tiny_nerf_pt
from deepvision.datasets.tiny_nerf import tiny_nerf_tf

file_name = "tiny_nerf_data.npz"
url = "https://people.eecs.berkeley.edu/~bmild/nerf/tiny_nerf_data.npz"
15 changes: 14 additions & 1 deletion deepvision/layers/__init__.py
@@ -1,14 +1,27 @@
from deepvision.layers.efficient_attention import EfficientAttention
from deepvision.layers.downscaling_attention import DownscalingMultiheadAttention
from deepvision.layers.efficient_attention import EfficientMultiheadAttention
from deepvision.layers.fused_mbconv import FusedMBConv
from deepvision.layers.hierarchical_transformer_encoder import (
HierarchicalTransformerEncoder,
)
from deepvision.layers.identity import Identity
from deepvision.layers.layernorm2d import LayerNorm2d
from deepvision.layers.mbconv import MBConv
from deepvision.layers.mix_ffn import MixFFN
from deepvision.layers.overlapping_patching_and_embedding import (
OverlappingPatchingAndEmbedding,
)
from deepvision.layers.patching_and_embedding import PatchingAndEmbedding
from deepvision.layers.random_position_encoding import RandomPositionEmbedding
from deepvision.layers.relative_positional_attention import (
RelativePositionalMultiheadAttention,
)
from deepvision.layers.relative_positional_transformer_encoder import (
RelativePositionalTransformerEncoder,
)
from deepvision.layers.stochasticdepth import StochasticDepth
from deepvision.layers.transformer_encoder import TransformerEncoder
from deepvision.layers.twoway_attention_block import TwoWayAttentionBlock
from deepvision.layers.twoway_transformer_decoder import TwoWayTransformerDecoder
from deepvision.layers.window_partitioning import WindowPartitioning
from deepvision.layers.window_unpartitioning import WindowUnpartitioning
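
With these re-exports in place, the SAM building blocks added by this commit become importable straight from the package namespace. A minimal sketch using only names listed in the diff above (constructor arguments are intentionally omitted, since they are not part of this file):

```python
# Sketch: SAM-related layers exposed by this commit, importable directly from
# deepvision.layers (names taken from the __init__.py diff above).
from deepvision.layers import (
    DownscalingMultiheadAttention,
    LayerNorm2d,
    RandomPositionEmbedding,
    RelativePositionalMultiheadAttention,
    RelativePositionalTransformerEncoder,
    TwoWayAttentionBlock,
    TwoWayTransformerDecoder,
    WindowPartitioning,
    WindowUnpartitioning,
)
```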
236 changes: 236 additions & 0 deletions deepvision/layers/decomposed_relative_positional_embedding.py
@@ -0,0 +1,236 @@
# Ported and adapted from the original code from Meta Platforms, Inc. and affiliates.
# Original code Copyright (c) Meta Platforms, Inc. and affiliates.
# Modifications and adaptations Copyright 2023 David Landup
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import tensorflow as tf
import torch
import torch.nn as nn
import torch.nn.functional as F


class __AddDecomposedRelativePositionsPT(nn.Module):
def __init__(self, rel_pos_h: torch.Tensor, rel_pos_w: torch.Tensor):
super().__init__()
self.rel_pos_h = nn.Parameter(rel_pos_h, requires_grad=False)
self.rel_pos_w = nn.Parameter(rel_pos_w, requires_grad=False)

def forward(
self,
attn: torch.Tensor,
q: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Args:
attn (Tensor): attention map.
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
q_h, q_w = q_size
k_h, k_w = k_size
Rh = self.__get_rel_pos(q_h, k_h, self.rel_pos_h)
Rw = self.__get_rel_pos(q_w, k_w, self.rel_pos_w)

B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

attn = (
attn.view(B, q_h, q_w, k_h, k_w)
+ rel_h[:, :, :, :, None]
+ rel_w[:, :, :, None, :]
).view(B, q_h * q_w, k_h * k_w)

return attn

def __get_rel_pos(
self, q_size: int, k_size: int, rel_pos: torch.Tensor
) -> torch.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int): size of query q.
k_size (int): size of key k.
rel_pos (Tensor): relative position embeddings (L, C).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos

# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(
q_size / k_size, 1.0
)
return rel_pos_resized[relative_coords.long()]


class __AddDecomposedRelativePositionsTF(tf.keras.layers.Layer):
def __init__(self, rel_pos_h: tf.Tensor, rel_pos_w: tf.Tensor):
super().__init__()
self.rel_pos_h = self.add_weight(
name="rel_pos_h",
shape=rel_pos_h.shape,
initializer=tf.keras.initializers.Constant(rel_pos_h),
trainable=False,
)
self.rel_pos_w = self.add_weight(
name="rel_pos_w",
shape=rel_pos_w.shape,
initializer=tf.keras.initializers.Constant(rel_pos_w),
trainable=False,
)

def call(
self,
attn: tf.Tensor,
q: tf.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> tf.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from `mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
Args:
attn (Tensor): attention map.
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
q_h, q_w = q_size
k_h, k_w = k_size
Rh = self.__get_rel_pos(q_h, k_h, self.rel_pos_h)
Rw = self.__get_rel_pos(q_w, k_w, self.rel_pos_w)

B, _, dim = q.shape
r_q = tf.reshape(q, (B, q_h, q_w, dim))
rel_h = tf.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = tf.einsum("bhwc,wkc->bhwk", r_q, Rw)

attn = (
tf.reshape(attn, (B, q_h, q_w, k_h, k_w))
+ tf.expand_dims(rel_h, -1)
+ tf.expand_dims(rel_w, -2)
)
attn = tf.reshape(attn, (B, q_h * q_w, k_h * k_w))

return attn

def __get_rel_pos(self, q_size: int, k_size: int, rel_pos: tf.Tensor) -> tf.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int): size of query q.
k_size (int): size of key k.
rel_pos (Tensor): relative position embeddings (L, C).
Returns:
Extracted positional embeddings according to relative positions.
"""

max_rel_dist = int(2 * max(q_size, k_size) - 1)
if rel_pos.shape[0] != max_rel_dist:
"""
We should resize from (145, 96) -> (96, 145) and interpolate to (96, 127).
However, tf.image.resize() doesn't operate only on one dimension, so we have to resize to the same
dimension on shape[0] and interpolate the dimension on shape[1]. Since channels-last format is forced here,
we also need to reshape to (145, 96, 1) and interpolate to (127, 96, 1), hence the difference in the implementations.
"""
rel_pos = tf.reshape(rel_pos, shape=[1, rel_pos.shape[0], -1])
rel_pos = tf.transpose(rel_pos, perm=[1, 2, 0])
rel_pos_resized = tf.image.resize(
rel_pos,
size=[rel_pos.shape[1], max_rel_dist],
method="bilinear",
)
rel_pos_resized = tf.transpose(rel_pos_resized, perm=[2, 0, 1])
rel_pos_resized = tf.transpose(
tf.reshape(rel_pos_resized, shape=[-1, max_rel_dist]), perm=[1, 0]
)
else:
rel_pos_resized = rel_pos

q_coords = tf.cast(
tf.reshape(tf.range(q_size), [int(q_size), 1]), tf.float32
) * tf.cast(tf.math.maximum(k_size / q_size, 1.0), tf.float32)
        # Keys form a row vector so that (q_coords - k_coords) broadcasts to (q_size, k_size),
        # mirroring torch.arange(k_size)[None, :] in the PyTorch implementation.
        k_coords = tf.cast(
            tf.reshape(tf.range(k_size), [1, int(k_size)]), tf.float32
        ) * tf.cast(tf.math.maximum(q_size / k_size, 1.0), tf.float32)
relative_coords = tf.cast((q_coords - k_coords), tf.float32) + tf.cast(
(k_size - 1), tf.float32
) * tf.cast(tf.math.maximum(q_size / k_size, 1.0), tf.float32)

return tf.gather(rel_pos_resized, tf.cast(relative_coords, tf.int32))


LAYER_BACKBONES = {
"tensorflow": __AddDecomposedRelativePositionsTF,
"pytorch": __AddDecomposedRelativePositionsPT,
}


def AddDecomposedRelativePositions(rel_pos_h, rel_pos_w, backend):
"""
Calculate decomposed Relative Positional Embeddings from `mvitv2`.
"MViTv2: Improved Multiscale Vision Transformers for Classification and Detection":
- https://arxiv.org/abs/2112.01526
- https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
    Args:
        rel_pos_h: relative positional embeddings for the height axis, with shape (L, C).
        rel_pos_w: relative positional embeddings for the width axis, with shape (L, C).
        backend: the backend to construct the layer for, either 'tensorflow' or 'pytorch'.
    Returns:
        A layer which, when called with (attn, q, q_size, k_size), returns the attention
        map with decomposed relative positional embeddings added.
"""
layer_class = LAYER_BACKBONES.get(backend)
if layer_class is None:
raise ValueError(
f"Backend not supported: {backend}. Supported backbones are {LAYER_BACKBONES.keys()}"
)

layer = layer_class(
rel_pos_h=rel_pos_h,
rel_pos_w=rel_pos_w,
)

return layer
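
For context, a minimal usage sketch of the factory above with the PyTorch backend; the tensor shapes follow the docstrings in this file, and the concrete sizes (8×8 spatial extent, 64 channels) are arbitrary illustration values:

```python
# Minimal sketch: build the layer and add decomposed relative positional
# embeddings to an attention map. Shapes follow the docstrings above:
# q is (B, q_h * q_w, C), attn is (B, q_h * q_w, k_h * k_w), and each
# relative position table is (2 * max(q_size, k_size) - 1, C).
import torch

from deepvision.layers.decomposed_relative_positional_embedding import (
    AddDecomposedRelativePositions,
)

q_h = q_w = k_h = k_w = 8
dim = 64

rel_pos_h = torch.zeros(2 * max(q_h, k_h) - 1, dim)
rel_pos_w = torch.zeros(2 * max(q_w, k_w) - 1, dim)
layer = AddDecomposedRelativePositions(rel_pos_h, rel_pos_w, backend="pytorch")

q = torch.randn(1, q_h * q_w, dim)
attn = torch.randn(1, q_h * q_w, k_h * k_w)
attn = layer(attn, q, (q_h, q_w), (k_h, k_w))
print(attn.shape)  # torch.Size([1, 64, 64])
```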