[Segmentation] Add Segment Anything Model (#132)
* starter commit, rough pytorch port
* adapt patching and embedding layer
* replaced proprietary mlp blocks with generic block
* mlp class cleanup
* mlp for TF
* refactoring start
* more refactoring
* vitdetl and vitdeth
* SAML and SAMH
* add transforms, utils and mask generator
* integrated SAM
* refactoring
* align vit det backbone to deepvision
* align ViTDet to deepvision API, TF implementation for relative positional attention, etc
* TF implementation for relative positional transformer encoder, window partitioning and unpartitioning
* TF implementation for downscaling attention and refactor for efficient multihead attention
* TF implementation for twoway attention block
* aligning twoway transformer encoder
* add decomp relative positional embedding
* fix activation function for twoway transformer encoder
* equalize twoway transformer encoder implementations
* embedding dim -> project dim for API consistency
* relative positional transformer encoder and positional attention tf implementations
* expose API for random position embeddings
* refactor
* small refactor
* small refactor
* import
* Refactor, nutshell file
* add examples
Commit e13bf5a (1 parent: 4deee81)
Showing 49 changed files with 5,950 additions and 111 deletions.
.gitignore
@@ -10,3 +10,4 @@ build/
 *.egg-info
 __pycache__/
 *.so
+.history
deepvision/__init__.py
@@ -1,3 +1,5 @@
-from deepvision import datasets, evaluation, models
+from deepvision import datasets
+from deepvision import evaluation
+from deepvision import models
 
 __version__ = "0.1.6"
deepvision/layers/__init__.py
@@ -1,14 +1,27 @@
-from deepvision.layers.efficient_attention import EfficientAttention
+from deepvision.layers.downscaling_attention import DownscalingMultiheadAttention
+from deepvision.layers.efficient_attention import EfficientMultiheadAttention
 from deepvision.layers.fused_mbconv import FusedMBConv
 from deepvision.layers.hierarchical_transformer_encoder import (
     HierarchicalTransformerEncoder,
 )
 from deepvision.layers.identity import Identity
+from deepvision.layers.layernorm2d import LayerNorm2d
 from deepvision.layers.mbconv import MBConv
 from deepvision.layers.mix_ffn import MixFFN
 from deepvision.layers.overlapping_patching_and_embedding import (
     OverlappingPatchingAndEmbedding,
 )
 from deepvision.layers.patching_and_embedding import PatchingAndEmbedding
+from deepvision.layers.random_position_encoding import RandomPositionEmbedding
+from deepvision.layers.relative_positional_attention import (
+    RelativePositionalMultiheadAttention,
+)
+from deepvision.layers.relative_positional_transformer_encoder import (
+    RelativePositionalTransformerEncoder,
+)
 from deepvision.layers.stochasticdepth import StochasticDepth
 from deepvision.layers.transformer_encoder import TransformerEncoder
+from deepvision.layers.twoway_attention_block import TwoWayAttentionBlock
+from deepvision.layers.twoway_transformer_decoder import TwoWayTransformerDecoder
+from deepvision.layers.window_partitioning import WindowPartitioning
+from deepvision.layers.window_unpartitioning import WindowUnpartitioning
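The new SAM building blocks are re-exported at the package level, so downstream code can import them directly from deepvision.layers, for example:

from deepvision.layers import (
    RandomPositionEmbedding,
    RelativePositionalMultiheadAttention,
    TwoWayTransformerDecoder,
    WindowPartitioning,
    WindowUnpartitioning,
)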
deepvision/layers/decomposed_relative_positional_embedding.py (new file, 236 additions, 0 deletions)
@@ -0,0 +1,236 @@
# Ported and adapted from the original code from Meta Platforms, Inc. and affiliates.
# Original code Copyright (c) Meta Platforms, Inc. and affiliates.
# Modifications and adaptations Copyright 2023 David Landup
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import tensorflow as tf
import torch
import torch.nn as nn
import torch.nn.functional as F

class __AddDecomposedRelativePositionsPT(nn.Module):
    def __init__(self, rel_pos_h: torch.Tensor, rel_pos_w: torch.Tensor):
        super().__init__()
        self.rel_pos_h = nn.Parameter(rel_pos_h, requires_grad=False)
        self.rel_pos_w = nn.Parameter(rel_pos_w, requires_grad=False)

    def forward(
        self,
        attn: torch.Tensor,
        q: torch.Tensor,
        q_size: Tuple[int, int],
        k_size: Tuple[int, int],
    ) -> torch.Tensor:
        """
        Args:
            attn (Tensor): attention map.
            q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
            q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
            k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
        Returns:
            attn (Tensor): attention map with added relative positional embeddings.
        """
        q_h, q_w = q_size
        k_h, k_w = k_size
        Rh = self.__get_rel_pos(q_h, k_h, self.rel_pos_h)
        Rw = self.__get_rel_pos(q_w, k_w, self.rel_pos_w)

        B, _, dim = q.shape
        r_q = q.reshape(B, q_h, q_w, dim)
        rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
        rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

        attn = (
            attn.view(B, q_h, q_w, k_h, k_w)
            + rel_h[:, :, :, :, None]
            + rel_w[:, :, :, None, :]
        ).view(B, q_h * q_w, k_h * k_w)

        return attn

    def __get_rel_pos(
        self, q_size: int, k_size: int, rel_pos: torch.Tensor
    ) -> torch.Tensor:
        """
        Get relative positional embeddings according to the relative positions of
        query and key sizes.
        Args:
            q_size (int): size of query q.
            k_size (int): size of key k.
            rel_pos (Tensor): relative position embeddings (L, C).
        Returns:
            Extracted positional embeddings according to relative positions.
        """
        max_rel_dist = int(2 * max(q_size, k_size) - 1)
        # Interpolate rel pos if needed.
        if rel_pos.shape[0] != max_rel_dist:
            rel_pos_resized = F.interpolate(
                rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
                size=max_rel_dist,
                mode="linear",
            )
            rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
        else:
            rel_pos_resized = rel_pos

        # Scale the coords with short length if shapes for q and k are different.
        q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
        k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
        relative_coords = (q_coords - k_coords) + (k_size - 1) * max(
            q_size / k_size, 1.0
        )
        return rel_pos_resized[relative_coords.long()]

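# Shape walk-through of the decomposition implemented by both backends: q of shape
# (B, q_h * q_w, C) is reshaped to (B, q_h, q_w, C); __get_rel_pos returns Rh of shape
# (q_h, k_h, C) and Rw of shape (q_w, k_w, C); the einsums then yield rel_h of shape
# (B, q_h, q_w, k_h) and rel_w of shape (B, q_h, q_w, k_w), which broadcast against the
# (B, q_h, q_w, k_h, k_w) view of the attention map as separate height-wise and
# width-wise biases before flattening back to (B, q_h * q_w, k_h * k_w).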
class __AddDecomposedRelativePositionsTF(tf.keras.layers.Layer):
    def __init__(self, rel_pos_h: tf.Tensor, rel_pos_w: tf.Tensor):
        super().__init__()
        self.rel_pos_h = self.add_weight(
            name="rel_pos_h",
            shape=rel_pos_h.shape,
            initializer=tf.keras.initializers.Constant(rel_pos_h),
            trainable=False,
        )
        self.rel_pos_w = self.add_weight(
            name="rel_pos_w",
            shape=rel_pos_w.shape,
            initializer=tf.keras.initializers.Constant(rel_pos_w),
            trainable=False,
        )

    def call(
        self,
        attn: tf.Tensor,
        q: tf.Tensor,
        q_size: Tuple[int, int],
        k_size: Tuple[int, int],
    ) -> tf.Tensor:
        """
        Calculate decomposed Relative Positional Embeddings from `mvitv2`.
        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
        Args:
            attn (Tensor): attention map.
            q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
            q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
            k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
        Returns:
            attn (Tensor): attention map with added relative positional embeddings.
        """
        q_h, q_w = q_size
        k_h, k_w = k_size
        Rh = self.__get_rel_pos(q_h, k_h, self.rel_pos_h)
        Rw = self.__get_rel_pos(q_w, k_w, self.rel_pos_w)

        B, _, dim = q.shape
        r_q = tf.reshape(q, (B, q_h, q_w, dim))
        rel_h = tf.einsum("bhwc,hkc->bhwk", r_q, Rh)
        rel_w = tf.einsum("bhwc,wkc->bhwk", r_q, Rw)

        attn = (
            tf.reshape(attn, (B, q_h, q_w, k_h, k_w))
            + tf.expand_dims(rel_h, -1)
            + tf.expand_dims(rel_w, -2)
        )
        attn = tf.reshape(attn, (B, q_h * q_w, k_h * k_w))

        return attn

    def __get_rel_pos(self, q_size: int, k_size: int, rel_pos: tf.Tensor) -> tf.Tensor:
        """
        Get relative positional embeddings according to the relative positions of
        query and key sizes.
        Args:
            q_size (int): size of query q.
            k_size (int): size of key k.
            rel_pos (Tensor): relative position embeddings (L, C).
        Returns:
            Extracted positional embeddings according to relative positions.
        """

        max_rel_dist = int(2 * max(q_size, k_size) - 1)
        if rel_pos.shape[0] != max_rel_dist:
            """
            We should resize from (145, 96) -> (96, 145) and interpolate to (96, 127).
            However, tf.image.resize() doesn't operate only on one dimension, so we have to resize to the same
            dimension on shape[0] and interpolate the dimension on shape[1]. Since channels-last format is forced here,
            we also need to reshape to (145, 96, 1) and interpolate to (127, 96, 1), hence the difference in the implementations.
            """
            rel_pos = tf.reshape(rel_pos, shape=[1, rel_pos.shape[0], -1])
            rel_pos = tf.transpose(rel_pos, perm=[1, 2, 0])
            rel_pos_resized = tf.image.resize(
                rel_pos,
                size=[rel_pos.shape[1], max_rel_dist],
                method="bilinear",
            )
            rel_pos_resized = tf.transpose(rel_pos_resized, perm=[2, 0, 1])
            rel_pos_resized = tf.transpose(
                tf.reshape(rel_pos_resized, shape=[-1, max_rel_dist]), perm=[1, 0]
            )
        else:
            rel_pos_resized = rel_pos

        # q_coords is a (q_size, 1) column and k_coords a (1, k_size) row, so their
        # difference broadcasts to the full (q_size, k_size) grid of relative offsets,
        # mirroring the PyTorch implementation above.
        q_coords = tf.cast(
            tf.reshape(tf.range(q_size), [int(q_size), 1]), tf.float32
        ) * tf.cast(tf.math.maximum(k_size / q_size, 1.0), tf.float32)
        k_coords = tf.cast(
            tf.reshape(tf.range(k_size), [1, int(k_size)]), tf.float32
        ) * tf.cast(tf.math.maximum(q_size / k_size, 1.0), tf.float32)
        relative_coords = tf.cast((q_coords - k_coords), tf.float32) + tf.cast(
            (k_size - 1), tf.float32
        ) * tf.cast(tf.math.maximum(q_size / k_size, 1.0), tf.float32)

        return tf.gather(rel_pos_resized, tf.cast(relative_coords, tf.int32))
LAYER_BACKBONES = {
    "tensorflow": __AddDecomposedRelativePositionsTF,
    "pytorch": __AddDecomposedRelativePositionsPT,
}


def AddDecomposedRelativePositions(rel_pos_h, rel_pos_w, backend):
    """
    Calculate decomposed Relative Positional Embeddings from `mvitv2`.
    "MViTv2: Improved Multiscale Vision Transformers for Classification and Detection":
        - https://arxiv.org/abs/2112.01526
        - https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
    Args:
        rel_pos_h: relative positional embeddings for the height axis, with shape (L, C).
        rel_pos_w: relative positional embeddings for the width axis, with shape (L, C).
        backend: 'tensorflow' or 'pytorch', selecting which implementation to construct.
    Returns:
        A layer which, when called with (attn, q, q_size, k_size), returns the attention
        map with added relative positional embeddings.
    """
    layer_class = LAYER_BACKBONES.get(backend)
    if layer_class is None:
        raise ValueError(
            f"Backend not supported: {backend}. Supported backends are {LAYER_BACKBONES.keys()}"
        )

    layer = layer_class(
        rel_pos_h=rel_pos_h,
        rel_pos_w=rel_pos_w,
    )

    return layer
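A minimal usage sketch for the PyTorch backend (not part of the commit): tensor shapes follow the docstrings above, and the direct module import path is assumed from the new file's location.

import torch

from deepvision.layers.decomposed_relative_positional_embedding import (
    AddDecomposedRelativePositions,
)

B, C = 2, 64                     # (batch * heads, channels per head)
q_h = q_w = k_h = k_w = 8        # spatial sizes of query and key
rel_pos_h = torch.randn(2 * max(q_h, k_h) - 1, C)   # (L, C) per the docstring
rel_pos_w = torch.randn(2 * max(q_w, k_w) - 1, C)

layer = AddDecomposedRelativePositions(rel_pos_h, rel_pos_w, backend="pytorch")

attn = torch.zeros(B, q_h * q_w, k_h * k_w)          # raw attention logits
q = torch.randn(B, q_h * q_w, C)                      # queries
attn = layer(attn, q, (q_h, q_w), (k_h, k_w))         # (B, q_h*q_w, k_h*k_w)

The TensorFlow implementation is selected the same way by passing backend="tensorflow" and tf.Tensor inputs.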