Move the image encoder to detectron2 backbone and fix for tf.keras backend
tirthasheshpatel committed Aug 18, 2023
1 parent 025dd15 commit 43b0f2b
Showing 17 changed files with 680 additions and 312 deletions.
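In effect, the SAM image encoder is renamed and exposed as a ViTDet backbone under keras_cv.models. A minimal sketch of the relocated API, assuming the SAM-sized alias is constructible with no arguments like other KerasCV backbone aliases (that construction and the output shape are assumptions, not confirmed by this diff):

import numpy as np
from keras_cv.models import SAMViTDetBBackbone

# Hypothetical no-arg construction of the ViT-B sized alias.
backbone = SAMViTDetBBackbone()
image_encodings = backbone(np.ones((1, 1024, 1024, 3)))
# SAM's image encoder maps a 1024x1024x3 input to a 64x64 grid of
# 256-channel encodings, so the expected shape is (1, 64, 64, 256).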
3 changes: 3 additions & 0 deletions keras_cv/layers/__init__.py
@@ -17,6 +17,9 @@
from tensorflow.keras.layers import RandomWidth

from keras_cv.layers.augmenter import Augmenter
from keras_cv.layers.detectron2_layers import MultiHeadAttentionWithRelativePE
from keras_cv.layers.detectron2_layers import ViTDetPatchingAndEmbedding
from keras_cv.layers.detectron2_layers import WindowedTransformerEncoder
from keras_cv.layers.feature_pyramid import FeaturePyramid
from keras_cv.layers.fusedmbconv import FusedMBConvBlock
from keras_cv.layers.mbconv import MBConvBlock
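With these additions to keras_cv/layers/__init__.py, the three detectron2 layers become part of the public layers namespace and can be imported directly:

from keras_cv.layers import MultiHeadAttentionWithRelativePE
from keras_cv.layers import ViTDetPatchingAndEmbedding
from keras_cv.layers import WindowedTransformerEncoder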
keras_cv/layers/detectron2_layers.py
@@ -163,6 +163,14 @@ def __init__(
trainable=True,
)

self.qkv.build([self.key_dim * self.num_heads])
self.projection.build([self.key_dim * self.num_heads])

self.built = True

def compute_output_shape(self, input_shape):
return input_shape

def call(self, x):
B, H, W, C = x.shape
qkv = ops.transpose(
@@ -326,6 +334,14 @@ def __init__(
)
self.mlp_block = MLPBlock(project_dim, mlp_dim, activation)

self.layer_norm1.build([None, None, None, self.project_dim])
self.layer_norm2.build([None, None, None, self.project_dim])

self.built = True

def compute_output_shape(self, input_shape):
return input_shape

def call(self, x):
shortcut = x
x = self.layer_norm1(x)
@@ -364,7 +380,7 @@ def get_config(self):


@keras.utils.register_keras_serializable(package="keras_cv")
class SAMPatchingAndEmbedding(keras.layers.Layer):
class ViTDetPatchingAndEmbedding(keras.layers.Layer):
"""Image to Patch Embedding using only a conv layer (without
layer normalization).
@@ -390,6 +406,15 @@ def __init__(
self.strides = strides
self.embed_dim = embed_dim

self.built = False

def build(self, input_shape):
self.projection.build(input_shape)
self.built = True

def compute_output_shape(self, input_shape):
return self.projection.compute_output_shape(input_shape)

def call(self, x):
x = self.projection(x)
return x
@@ -404,158 +429,3 @@ def get_config(self):
}
)
return config


@keras.utils.register_keras_serializable(package="keras_cv")
class ImageEncoder(keras.models.Model):
"""A ViT image encoder for the segment anything model.
Args:
img_size (int, optional): The size of the input image. Defaults to
`1024`.
patch_size (int, optional): the patch size to be supplied to the
Patching layer to turn input images into a flattened sequence of
patches. Defaults to `16`.
in_chans (int, optional): The number of channels in the input image.
Defaults to `3`.
embed_dim (int, optional): The latent dimensionality to be projected
into in the output of each stacked windowed transformer encoder.
Defaults to `1280`.
depth (int, optional): The number of transformer encoder layers to
stack in the Vision Transformer. Defaults to `32`.
mlp_dim (int, optional): The dimensionality of the hidden Dense
layer in the transformer MLP head. Defaults to `1280*4`.
num_heads (int, optional): the number of heads to use in the
`MultiHeadAttentionWithRelativePE` layer of each transformer
encoder. Defaults to `16`.
out_chans (int, optional): The number of channels (features) in the
output (image encodings). Defaults to `256`.
use_bias (bool, optional): Whether to use bias to project the keys,
queries, and values in the attention layer. Defaults to `True`.
use_abs_pos (bool, optional): Whether to add absolute positional
embeddings to the output patches. Defaults to `True`.
use_rel_pos (bool, optional): Whether to use relative positional
encodings in the attention layer. Defaults to `False`.
window_size (int, optional): The size of the window for windowed
attention in the transformer encoder blocks. Defaults to `0`.
global_attention_indices (list, optional): Indexes for blocks using
global attention. Defaults to `[7, 15, 23, 31]`.
layer_norm_epsilon (float, optional): The epsilon to use in the layer
normalization blocks in transformer encoder. Defaults to `1e-6`.
"""

def __init__(
self,
img_size=1024,
patch_size=16,
in_chans=3,
embed_dim=1280,
depth=32,
mlp_dim=1280 * 4,
num_heads=16,
out_chans=256,
use_bias=True,
use_abs_pos=True,
use_rel_pos=False,
window_size=0,
global_attention_indices=[7, 15, 23, 31],
layer_norm_epsilon=1e-6,
**kwargs
):
super().__init__(**kwargs)
self.img_size = img_size
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.depth = depth
self.mlp_dim = mlp_dim
self.num_heads = num_heads
self.out_chans = out_chans
self.use_bias = use_bias
self.use_rel_pos = use_rel_pos
self.use_abs_pos = use_abs_pos
self.window_size = window_size
self.global_attention_indices = global_attention_indices
self.layer_norm_epsilon = layer_norm_epsilon

self.patch_embed = SAMPatchingAndEmbedding(
kernel_size=(patch_size, patch_size),
strides=(patch_size, patch_size),
embed_dim=embed_dim,
)
if self.use_abs_pos:
self.pos_embed = self.add_weight(
name="pos_embed",
shape=(
1,
self.img_size // self.patch_size,
self.img_size // self.patch_size,
self.embed_dim,
),
initializer="zeros",
trainable=True,
)
else:
self.pos_embed = None
self.transformer_blocks = []
for i in range(depth):
block = WindowedTransformerEncoder(
project_dim=embed_dim,
mlp_dim=mlp_dim,
num_heads=num_heads,
use_bias=use_bias,
use_rel_pos=use_rel_pos,
window_size=window_size
if i not in global_attention_indices
else 0,
input_size=(img_size // patch_size, img_size // patch_size),
)
self.transformer_blocks.append(block)
self.transformer_blocks = keras.models.Sequential(
self.transformer_blocks
)
self.bottleneck = keras.models.Sequential(
[
keras.layers.Conv2D(
filters=out_chans, kernel_size=1, use_bias=False
),
keras.layers.LayerNormalization(epsilon=1e-6),
keras.layers.Conv2D(
filters=out_chans,
kernel_size=3,
padding="same",
use_bias=False,
),
keras.layers.LayerNormalization(epsilon=1e-6),
]
)

def call(self, x):
B, _, _, _ = x.shape
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.transformer_blocks(x)
return self.bottleneck(x)

def get_config(self):
config = super().get_config()
config.update(
{
"img_size": self.img_size,
"patch_size": self.patch_size,
"in_chans": self.in_chans,
"embed_dim": self.embed_dim,
"depth": self.depth,
"mlp_dim": self.mlp_dim,
"num_heads": self.num_heads,
"out_chans": self.out_chans,
"use_bias": self.use_bias,
"use_abs_pos": self.use_abs_pos,
"use_rel_pos": self.use_rel_pos,
"window_size": self.window_size,
"global_attention_indices": self.global_attention_indices,
"layer_norm_epsilon": self.layer_norm_epsilon,
}
)
return config
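A shape walk-through of the ImageEncoder removed above, at its documented defaults (ViT-H sized: embed_dim=1280, depth=32). This is a sketch against the deleted code, not a supported entry point after this commit:

import numpy as np

encoder = ImageEncoder()  # defaults: img_size=1024, patch_size=16, out_chans=256
image_encodings = encoder(np.ones((1, 1024, 1024, 3)))
# patch_embed: 16x16 patches at stride 16 -> (1, 64, 64, 1280)
# 32 WindowedTransformerEncoder blocks preserve the spatial grid
# bottleneck (1x1 conv, LN, 3x3 "same" conv, LN) -> (1, 64, 64, 256)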
48 changes: 48 additions & 0 deletions keras_cv/layers/detectron2_layers_test.py
@@ -0,0 +1,48 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from keras_cv.backend import ops
from keras_cv.layers.detectron2_layers import MultiHeadAttentionWithRelativePE
from keras_cv.layers.detectron2_layers import WindowedTransformerEncoder
from keras_cv.tests.test_case import TestCase


class TestDetectron2Layers(TestCase):
def test_multi_head_attention_with_relative_pe(self):
attention_with_rel_pe = MultiHeadAttentionWithRelativePE(
num_heads=16,
key_dim=1280 // 16,
use_bias=True,
input_size=(64, 64),
)
x = np.ones(shape=(1, 64, 64, 1280))
x_out = ops.convert_to_numpy(attention_with_rel_pe(x))
self.assertEqual(x_out.shape, (1, 64, 64, 1280))

def test_windowed_transformer_encoder(self):
windowed_transformer_encoder = WindowedTransformerEncoder(
project_dim=1280,
mlp_dim=1280 * 4,
num_heads=16,
use_bias=True,
use_rel_pos=True,
window_size=14,
input_size=(64, 64),
)
x = np.ones((1, 64, 64, 1280))
x_out = ops.convert_to_numpy(windowed_transformer_encoder(x))
self.assertEqual(x_out.shape, (1, 64, 64, 1280))
self.assertAllClose(x_out, np.ones_like(x_out))
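The tests above cover the attention and windowed-encoder layers. A sketch of an analogous check for the renamed patching layer, assuming the constructor arguments used by the (removed) ImageEncoder code above:

import numpy as np

from keras_cv.backend import ops
from keras_cv.layers.detectron2_layers import ViTDetPatchingAndEmbedding

patching = ViTDetPatchingAndEmbedding(
    kernel_size=(16, 16), strides=(16, 16), embed_dim=1280
)
x_out = ops.convert_to_numpy(patching(np.ones((1, 1024, 1024, 3))))
# A 16x16 kernel at stride 16 tiles the 1024x1024 input into a 64x64 grid.
assert x_out.shape == (1, 64, 64, 1280)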
13 changes: 12 additions & 1 deletion keras_cv/models/__init__.py
@@ -88,6 +88,18 @@
from keras_cv.models.backbones.efficientnet_v1.efficientnet_v1_aliases import (
EfficientNetV1Backbone,
)
from keras_cv.models.backbones.detectron2.detectron2_backbone import (
ViTDetBackbone,
)
from keras_cv.models.backbones.detectron2.detectron2_aliases import (
SAMViTDetBBackbone,
)
from keras_cv.models.backbones.detectron2.detectron2_aliases import (
SAMViTDetLBackbone,
)
from keras_cv.models.backbones.detectron2.detectron2_aliases import (
SAMViTDetHBackbone,
)
from keras_cv.models.backbones.efficientnet_v2.efficientnet_v2_aliases import (
EfficientNetV2B0Backbone,
)
@@ -166,7 +178,6 @@
YOLOV8Detector,
)
from keras_cv.models.segmentation import DeepLabV3Plus
from keras_cv.models.segmentation.segment_anything import ImageEncoder
from keras_cv.models.segmentation.segment_anything import MaskDecoder
from keras_cv.models.segmentation.segment_anything import PromptEncoder
from keras_cv.models.segmentation.segment_anything import TwoWayTransformer
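The public surface after this change: the backbone and its SAM-sized aliases are importable from keras_cv.models, while ImageEncoder is no longer exported from the segmentation package. The B/L/H suffixes presumably map to the three SAM image-encoder sizes (ViT-B/L/H); that mapping is an assumption here, not stated in the diff.

from keras_cv.models import ViTDetBackbone
from keras_cv.models import SAMViTDetBBackbone  # ViT-B sized (assumed)
from keras_cv.models import SAMViTDetLBackbone  # ViT-L sized (assumed)
from keras_cv.models import SAMViTDetHBackbone  # ViT-H sized (assumed)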
13 changes: 13 additions & 0 deletions keras_cv/models/backbones/detectron2/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
