diff --git a/keras_cv/layers/__init__.py b/keras_cv/layers/__init__.py index c8b01f2769..166d553924 100644 --- a/keras_cv/layers/__init__.py +++ b/keras_cv/layers/__init__.py @@ -17,6 +17,9 @@ from tensorflow.keras.layers import RandomWidth from keras_cv.layers.augmenter import Augmenter +from keras_cv.layers.detectron2_layers import MultiHeadAttentionWithRelativePE +from keras_cv.layers.detectron2_layers import ViTDetPatchingAndEmbedding +from keras_cv.layers.detectron2_layers import WindowedTransformerEncoder from keras_cv.layers.feature_pyramid import FeaturePyramid from keras_cv.layers.fusedmbconv import FusedMBConvBlock from keras_cv.layers.mbconv import MBConvBlock diff --git a/keras_cv/models/segmentation/segment_anything/sam_image_encoder.py b/keras_cv/layers/detectron2_layers.py similarity index 69% rename from keras_cv/models/segmentation/segment_anything/sam_image_encoder.py rename to keras_cv/layers/detectron2_layers.py index 3801c75f62..25e148e5ba 100644 --- a/keras_cv/models/segmentation/segment_anything/sam_image_encoder.py +++ b/keras_cv/layers/detectron2_layers.py @@ -163,6 +163,14 @@ def __init__( trainable=True, ) + self.qkv.build([self.key_dim * self.num_heads]) + self.projection.build([self.key_dim * self.num_heads]) + + self.built = True + + def compute_output_shape(self, input_shape): + return input_shape + def call(self, x): B, H, W, C = x.shape qkv = ops.transpose( @@ -326,6 +334,14 @@ def __init__( ) self.mlp_block = MLPBlock(project_dim, mlp_dim, activation) + self.layer_norm1.build([None, None, None, self.project_dim]) + self.layer_norm2.build([None, None, None, self.project_dim]) + + self.built = True + + def compute_output_shape(self, input_shape): + return input_shape + def call(self, x): shortcut = x x = self.layer_norm1(x) @@ -364,7 +380,7 @@ def get_config(self): @keras.utils.register_keras_serializable(package="keras_cv") -class SAMPatchingAndEmbedding(keras.layers.Layer): +class ViTDetPatchingAndEmbedding(keras.layers.Layer): """Image to Patch Embedding using only a conv layer (without layer normalization). @@ -390,6 +406,15 @@ def __init__( self.strides = strides self.embed_dim = embed_dim + self.built = False + + def build(self, input_shape): + self.projection.build(input_shape) + self.built = True + + def compute_output_shape(self, input_shape): + return self.projection.compute_output_shape(input_shape) + def call(self, x): x = self.projection(x) return x @@ -404,158 +429,3 @@ def get_config(self): } ) return config - - -@keras.utils.register_keras_serializable(package="keras_cv") -class ImageEncoder(keras.models.Model): - """A ViT image encoder for the segment anything model. - - Args: - img_size (int, optional): The size of the input image. Defaults to - `1024`. - patch_size (int, optional): the patch size to be supplied to the - Patching layer to turn input images into a flattened sequence of - patches. Defaults to `16`. - in_chans (int, optional): The number of channels in the input image. - Defaults to `3`. - embed_dim (int, optional): The latent dimensionality to be projected - into in the output of each stacked windowed transformer encoder. - Defaults to `1280`. - depth (int, optional): The number of transformer encoder layers to - stack in the Vision Transformer. Defaults to `32`. - mlp_dim (_type_, optional): The dimensionality of the hidden Dense - layer in the transformer MLP head. Defaults to `1280*4`. - num_heads (int, optional): the number of heads to use in the - `MultiHeadAttentionWithRelativePE` layer of each transformer - encoder. 
Defaults to `16`. - out_chans (int, optional): The number of channels (features) in the - output (image encodings). Defaults to `256`. - use_bias (bool, optional): Whether to use bias to project the keys, - queries, and values in the attention layer. Defaults to `True`. - use_abs_pos (bool, optional): Whether to add absolute positional - embeddings to the output patches. Defaults to `True`. - use_rel_pos (bool, optional): Whether to use relative positional - emcodings in the attention layer. Defaults to `False`. - window_size (int, optional): The size of the window for windowed - attention in the transformer encoder blocks. Defaults to `0`. - global_attention_indices (list, optional): Indexes for blocks using - global attention. Defaults to `[7, 15, 23, 31]`. - layer_norm_epsilon (int, optional): The epsilon to use in the layer - normalization blocks in transformer encoder. Defaults to `1e-6`. - """ - - def __init__( - self, - img_size=1024, - patch_size=16, - in_chans=3, - embed_dim=1280, - depth=32, - mlp_dim=1280 * 4, - num_heads=16, - out_chans=256, - use_bias=True, - use_abs_pos=True, - use_rel_pos=False, - window_size=0, - global_attention_indices=[7, 15, 23, 31], - layer_norm_epsilon=1e-6, - **kwargs - ): - super().__init__(**kwargs) - self.img_size = img_size - self.patch_size = patch_size - self.in_chans = in_chans - self.embed_dim = embed_dim - self.depth = depth - self.mlp_dim = mlp_dim - self.num_heads = num_heads - self.out_chans = out_chans - self.use_bias = use_bias - self.use_rel_pos = use_rel_pos - self.use_abs_pos = use_abs_pos - self.window_size = window_size - self.global_attention_indices = global_attention_indices - self.layer_norm_epsilon = layer_norm_epsilon - - self.patch_embed = SAMPatchingAndEmbedding( - kernel_size=(patch_size, patch_size), - strides=(patch_size, patch_size), - embed_dim=embed_dim, - ) - if self.use_abs_pos: - self.pos_embed = self.add_weight( - name="pos_embed", - shape=( - 1, - self.img_size // self.patch_size, - self.img_size // self.patch_size, - self.embed_dim, - ), - initializer="zeros", - trainable=True, - ) - else: - self.pos_embed = None - self.transformer_blocks = [] - for i in range(depth): - block = WindowedTransformerEncoder( - project_dim=embed_dim, - mlp_dim=mlp_dim, - num_heads=num_heads, - use_bias=use_bias, - use_rel_pos=use_rel_pos, - window_size=window_size - if i not in global_attention_indices - else 0, - input_size=(img_size // patch_size, img_size // patch_size), - ) - self.transformer_blocks.append(block) - self.transformer_blocks = keras.models.Sequential( - self.transformer_blocks - ) - self.bottleneck = keras.models.Sequential( - [ - keras.layers.Conv2D( - filters=out_chans, kernel_size=1, use_bias=False - ), - keras.layers.LayerNormalization(epsilon=1e-6), - keras.layers.Conv2D( - filters=out_chans, - kernel_size=3, - padding="same", - use_bias=False, - ), - keras.layers.LayerNormalization(epsilon=1e-6), - ] - ) - - def call(self, x): - B, _, _, _ = x.shape - x = self.patch_embed(x) - if self.pos_embed is not None: - x = x + self.pos_embed - x = self.transformer_blocks(x) - return self.bottleneck(x) - - def get_config(self): - config = super().get_config() - config.update( - { - "img_size": self.img_size, - "patch_size": self.patch_size, - "in_chans": self.in_chans, - "embed_dim": self.embed_dim, - "depth": self.depth, - "mlp_dim": self.mlp_dim, - "num_heads": self.num_heads, - "out_chans": self.out_chans, - "use_bias": self.use_bias, - "use_abs_pos": self.use_abs_pos, - "use_rel_pos": self.use_rel_pos, - 
"window_size": self.window_size, - "global_attention_indices": self.global_attention_indices, - "layer_norm_epsilon": self.layer_norm_epsilon, - } - ) - return config diff --git a/keras_cv/layers/detectron2_layers_test.py b/keras_cv/layers/detectron2_layers_test.py new file mode 100644 index 0000000000..71b5e190fd --- /dev/null +++ b/keras_cv/layers/detectron2_layers_test.py @@ -0,0 +1,48 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +from keras_cv.backend import ops +from keras_cv.layers.detectron2_layers import MultiHeadAttentionWithRelativePE +from keras_cv.layers.detectron2_layers import WindowedTransformerEncoder +from keras_cv.tests.test_case import TestCase + + +class TestDetectron2Layers(TestCase): + def test_multi_head_attention_with_relative_pe(self): + attention_with_rel_pe = MultiHeadAttentionWithRelativePE( + num_heads=16, + key_dim=1280 // 16, + use_bias=True, + input_size=(64, 64), + ) + x = np.ones(shape=(1, 64, 64, 1280)) + x_out = ops.convert_to_numpy(attention_with_rel_pe(x)) + self.assertEqual(x_out.shape, (1, 64, 64, 1280)) + + def test_windowed_transformer_encoder(self): + windowed_transformer_encoder = WindowedTransformerEncoder( + project_dim=1280, + mlp_dim=1280 * 4, + num_heads=16, + use_bias=True, + use_rel_pos=True, + window_size=14, + input_size=(64, 64), + ) + x = np.ones((1, 64, 64, 1280)) + x_out = ops.convert_to_numpy(windowed_transformer_encoder(x)) + self.assertEqual(x_out.shape, (1, 64, 64, 1280)) + self.assertAllClose(x_out, np.ones_like(x_out)) diff --git a/keras_cv/models/__init__.py b/keras_cv/models/__init__.py index fc636dbdbd..5d9102d784 100644 --- a/keras_cv/models/__init__.py +++ b/keras_cv/models/__init__.py @@ -88,6 +88,18 @@ from keras_cv.models.backbones.efficientnet_v1.efficientnet_v1_aliases import ( EfficientNetV1Backbone, ) +from keras_cv.models.backbones.detectron2.detectron2_backbone import ( + ViTDetBackbone, +) +from keras_cv.models.backbones.detectron2.detectron2_aliases import ( + SAMViTDetBBackbone, +) +from keras_cv.models.backbones.detectron2.detectron2_aliases import ( + SAMViTDetLBackbone, +) +from keras_cv.models.backbones.detectron2.detectron2_aliases import ( + SAMViTDetHBackbone, +) from keras_cv.models.backbones.efficientnet_v2.efficientnet_v2_aliases import ( EfficientNetV2B0Backbone, ) @@ -166,7 +178,6 @@ YOLOV8Detector, ) from keras_cv.models.segmentation import DeepLabV3Plus -from keras_cv.models.segmentation.segment_anything import ImageEncoder from keras_cv.models.segmentation.segment_anything import MaskDecoder from keras_cv.models.segmentation.segment_anything import PromptEncoder from keras_cv.models.segmentation.segment_anything import TwoWayTransformer diff --git a/keras_cv/models/backbones/detectron2/__init__.py b/keras_cv/models/backbones/detectron2/__init__.py new file mode 100644 index 0000000000..3992ffb59a --- /dev/null +++ b/keras_cv/models/backbones/detectron2/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed 
under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/keras_cv/models/backbones/detectron2/data/sam_vitdet_b_out.npz b/keras_cv/models/backbones/detectron2/data/sam_vitdet_b_out.npz
new file mode 100644
index 0000000000..da8c732ccd
Binary files /dev/null and b/keras_cv/models/backbones/detectron2/data/sam_vitdet_b_out.npz differ
diff --git a/keras_cv/models/backbones/detectron2/detectron2_aliases.py b/keras_cv/models/backbones/detectron2/detectron2_aliases.py
new file mode 100644
index 0000000000..6002259bc4
--- /dev/null
+++ b/keras_cv/models/backbones/detectron2/detectron2_aliases.py
@@ -0,0 +1,116 @@
+# Copyright 2023 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+from keras_cv.models.backbones.detectron2.detectron2_backbone import (
+    ViTDetBackbone,
+)
+from keras_cv.models.backbones.detectron2.detectron2_backbone_presets import (
+    backbone_presets,
+)
+from keras_cv.utils.python_utils import classproperty
+
+ALIAS_DOCSTRING = """{SAM}ViTDet{size}Backbone model.
+
+    Reference:
+        - [Detectron2](https://github.com/facebookresearch/detectron2)
+        - [Segment Anything](https://arxiv.org/abs/2304.02643)
+
+    For transfer learning use cases, make sure to read the
+    [guide to transfer learning & fine-tuning](https://keras.io/guides/transfer_learning/).
+
+    Examples:
+    ```python
+    input_data = np.ones(shape=(1, 1024, 1024, 3))
+
+    # Randomly initialized backbone
+    model = {SAM}ViTDet{size}Backbone()
+    output = model(input_data)
+    ```
+""" # noqa: E501
+
+
+class SAMViTDetBBackbone(ViTDetBackbone):
+    def __new__(
+        cls,
+        **kwargs,
+    ):
+        return ViTDetBackbone.from_preset("sam_vitdet_b", **kwargs)
+
+    @classproperty
+    def presets(cls):
+        """Dictionary of preset names and configurations."""
+        return {
+            "sam_vitdet_b": copy.deepcopy(backbone_presets["sam_vitdet_b"]),
+        }
+
+    @classproperty
+    def presets_with_weights(cls):
+        """Dictionary of preset names and configurations that include
+        weights."""
+        return cls.presets
+
+
+class SAMViTDetLBackbone(ViTDetBackbone):
+    def __new__(
+        cls,
+        **kwargs,
+    ):
+        return ViTDetBackbone.from_preset("sam_vitdet_l", **kwargs)
+
+    @classproperty
+    def presets(cls):
+        """Dictionary of preset names and configurations."""
+        return {
+            "sam_vitdet_l": copy.deepcopy(backbone_presets["sam_vitdet_l"]),
+        }
+
+    @classproperty
+    def presets_with_weights(cls):
+        """Dictionary of preset names and configurations that include
+        weights."""
+        return cls.presets
+
+
+class SAMViTDetHBackbone(ViTDetBackbone):
+    def __new__(
+        cls,
+        **kwargs,
+    ):
+        return ViTDetBackbone.from_preset("sam_vitdet_h", **kwargs)
+
+    @classproperty
+    def presets(cls):
+        """Dictionary of preset names and configurations."""
+        return {
+            "sam_vitdet_h": copy.deepcopy(backbone_presets["sam_vitdet_h"]),
+        }
+
+    @classproperty
+    def presets_with_weights(cls):
+        """Dictionary of preset names and configurations that include
+        weights."""
+        return cls.presets
+
+
+setattr(
+    SAMViTDetBBackbone, "__doc__", ALIAS_DOCSTRING.format(SAM="SAM", size="B")
+)
+setattr(
+    SAMViTDetLBackbone, "__doc__", ALIAS_DOCSTRING.format(SAM="SAM", size="L")
+)
+setattr(
+    SAMViTDetHBackbone, "__doc__", ALIAS_DOCSTRING.format(SAM="SAM", size="H")
+)
diff --git a/keras_cv/models/backbones/detectron2/detectron2_backbone.py b/keras_cv/models/backbones/detectron2/detectron2_backbone.py
new file mode 100644
index 0000000000..9aa57ad811
--- /dev/null
+++ b/keras_cv/models/backbones/detectron2/detectron2_backbone.py
@@ -0,0 +1,203 @@
+# Copyright 2023 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+from keras_cv.backend import keras
+from keras_cv.layers.detectron2_layers import ViTDetPatchingAndEmbedding
+from keras_cv.layers.detectron2_layers import WindowedTransformerEncoder
+from keras_cv.models.backbones.backbone import Backbone
+from keras_cv.models.backbones.detectron2.detectron2_backbone_presets import (
+    backbone_presets,
+)
+from keras_cv.utils.python_utils import classproperty
+
+
+@keras.utils.register_keras_serializable(package="keras_cv.models")
+class ViTDetBackbone(Backbone):
+    """A ViT image encoder that uses a windowed transformer encoder and
+    relative positional encodings.
+
+    Args:
+        img_size (int, optional): The size of the input image. Defaults to
+            `1024`.
+        patch_size (int, optional): The patch size to be supplied to the
+            Patching layer to turn input images into a flattened sequence of
+            patches. Defaults to `16`.
+        in_chans (int, optional): The number of channels in the input image.
+            Defaults to `3`.
+        embed_dim (int, optional): The latent dimensionality to be projected
+            into in the output of each stacked windowed transformer encoder.
+            Defaults to `1280`.
+        depth (int, optional): The number of transformer encoder layers to
+            stack in the Vision Transformer. Defaults to `32`.
+        mlp_dim (int, optional): The dimensionality of the hidden Dense
+            layer in the transformer MLP head. Defaults to `1280*4`.
+        num_heads (int, optional): The number of heads to use in the
+            `MultiHeadAttentionWithRelativePE` layer of each transformer
+            encoder. Defaults to `16`.
+        out_chans (int, optional): The number of channels (features) in the
+            output (image encodings). Defaults to `256`.
+        use_bias (bool, optional): Whether to use bias to project the keys,
+            queries, and values in the attention layer. Defaults to `True`.
+        use_abs_pos (bool, optional): Whether to add absolute positional
+            embeddings to the output patches. Defaults to `True`.
+        use_rel_pos (bool, optional): Whether to use relative positional
+            encodings in the attention layer. Defaults to `False`.
+        window_size (int, optional): The size of the window for windowed
+            attention in the transformer encoder blocks. Defaults to `0`.
+        global_attention_indices (list, optional): Indices for blocks using
+            global attention. Defaults to `[7, 15, 23, 31]`.
+        layer_norm_epsilon (float, optional): The epsilon to use in the layer
+            normalization blocks in the transformer encoder. Defaults to
+            `1e-6`.
+    """
+
+    def __init__(
+        self,
+        img_size=1024,
+        patch_size=16,
+        in_chans=3,
+        embed_dim=1280,
+        depth=32,
+        mlp_dim=1280 * 4,
+        num_heads=16,
+        out_chans=256,
+        use_bias=True,
+        use_abs_pos=True,
+        use_rel_pos=False,
+        window_size=0,
+        global_attention_indices=[7, 15, 23, 31],
+        layer_norm_epsilon=1e-6,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+        self.depth = depth
+        self.mlp_dim = mlp_dim
+        self.num_heads = num_heads
+        self.out_chans = out_chans
+        self.use_bias = use_bias
+        self.use_rel_pos = use_rel_pos
+        self.use_abs_pos = use_abs_pos
+        self.window_size = window_size
+        self.global_attention_indices = global_attention_indices
+        self.layer_norm_epsilon = layer_norm_epsilon
+
+        self.patch_embed = ViTDetPatchingAndEmbedding(
+            kernel_size=(patch_size, patch_size),
+            strides=(patch_size, patch_size),
+            embed_dim=embed_dim,
+        )
+        if self.use_abs_pos:
+            self.pos_embed = self.add_weight(
+                name="pos_embed",
+                shape=(
+                    1,
+                    self.img_size // self.patch_size,
+                    self.img_size // self.patch_size,
+                    self.embed_dim,
+                ),
+                initializer="zeros",
+                trainable=True,
+            )
+        else:
+            self.pos_embed = None
+        self.transformer_blocks = []
+        for i in range(depth):
+            block = WindowedTransformerEncoder(
+                project_dim=embed_dim,
+                mlp_dim=mlp_dim,
+                num_heads=num_heads,
+                use_bias=use_bias,
+                use_rel_pos=use_rel_pos,
+                window_size=window_size
+                if i not in global_attention_indices
+                else 0,
+                input_size=(img_size // patch_size, img_size // patch_size),
+            )
+            self.transformer_blocks.append(block)
+        self.bottleneck = keras.models.Sequential(
+            [
+                keras.layers.Conv2D(
+                    filters=out_chans, kernel_size=1, use_bias=False
+                ),
+                keras.layers.LayerNormalization(epsilon=1e-6),
+                keras.layers.Conv2D(
+                    filters=out_chans,
kernel_size=3, + padding="same", + use_bias=False, + ), + keras.layers.LayerNormalization(epsilon=1e-6), + ] + ) + + self.patch_embed.build( + [None, self.img_size, self.img_size, self.in_chans] + ) + self.bottleneck.build( + [ + None, + self.img_size // self.patch_size, + self.img_size // self.patch_size, + self.embed_dim, + ] + ) + + self.built = True + + def call(self, x): + B, _, _, _ = x.shape + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + for block in self.transformer_blocks: + x = block(x) + return self.bottleneck(x) + + def get_config(self): + config = super().get_config() + config.update( + { + "img_size": self.img_size, + "patch_size": self.patch_size, + "in_chans": self.in_chans, + "embed_dim": self.embed_dim, + "depth": self.depth, + "mlp_dim": self.mlp_dim, + "num_heads": self.num_heads, + "out_chans": self.out_chans, + "use_bias": self.use_bias, + "use_abs_pos": self.use_abs_pos, + "use_rel_pos": self.use_rel_pos, + "window_size": self.window_size, + "global_attention_indices": self.global_attention_indices, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config + + @classproperty + def presets(cls): + """Dictionary of preset names and configurations.""" + return copy.deepcopy(backbone_presets) + + # @classproperty + # def presets_with_weights(cls): + # """Dictionary of preset names and configurations that include + # weights.""" + # return copy.deepcopy(backbone_presets) diff --git a/keras_cv/models/backbones/detectron2/detectron2_backbone_presets.py b/keras_cv/models/backbones/detectron2/detectron2_backbone_presets.py new file mode 100644 index 0000000000..9c8f8ca63c --- /dev/null +++ b/keras_cv/models/backbones/detectron2/detectron2_backbone_presets.py @@ -0,0 +1,107 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""VitDet model preset configurations.""" + +backbone_presets = { + "sam_vitdet_b": { + "metadata": { + "description": ( + "VitDet Backbone for the segment anything model with 12 " + "transformer encoders with embed dim 768 and attention layers" + " with 12 heads with global attention on encoders 2, 5, 8, " + "and 11." + ), + "params": 89_670_912, + "official_name": "VitDet", + "path": "detectron2", + }, + "class_name": "keras_cv.models>VitDetBackbone", + "config": { + "img_size": 1024, + "patch_size": 16, + "in_chans": 3, + "embed_dim": 768, + "depth": 12, + "mlp_dim": 768 * 4, + "num_heads": 12, + "out_chans": 256, + "use_bias": True, + "use_rel_pos": True, + "window_size": 14, + "global_attention_indices": [2, 5, 8, 11], + }, + # "weights_url": "https://storage.googleapis.com/keras-cv/models/segment_anything/sam_vit_b.weights.h5", # noqa: E501 + # "weights_hash": None + }, + "sam_vitdet_l": { + "metadata": { + "description": ( + "VitDet Backbone for the segment anything model with 24 " + "transformer encoders with embed dim " + "1024 and attention layers with 16 heads with global " + "attention on encoders 5, 11, 17, and 23." 
+ ), + "params": 308_278_272, + "official_name": "VitDet", + "path": "detectron2", + }, + "class_name": "keras_cv.models>VitDetBackbone", + "config": { + "img_size": 1024, + "patch_size": 16, + "in_chans": 3, + "embed_dim": 1024, + "depth": 24, + "mlp_dim": 1024 * 4, + "num_heads": 16, + "out_chans": 256, + "use_bias": True, + "use_rel_pos": True, + "window_size": 14, + "global_attention_indices": [5, 11, 17, 23], + }, + # "weights_url": "https://storage.googleapis.com/keras-cv/models/segment_anything/sam_vit_l.weights.h5", # noqa: E501 + # "weights_hash": None + }, + "sam_vitdet_h": { + "metadata": { + "description": ( + "VitDet Backbone for the segment anything model " + "with 32 transformer encoders with embed dim " + "1280 and attention layers with 16 heads with global " + "attention on encoders 7, 15, 23, and 31." + ), + "params": 637_026_048, + "official_name": "VitDet", + "path": "detectron2", + }, + "class_name": "keras_cv.models>VitDetBackbone", + "config": { + "img_size": 1024, + "patch_size": 16, + "in_chans": 3, + "embed_dim": 1280, + "depth": 32, + "mlp_dim": 1280 * 4, + "num_heads": 16, + "out_chans": 256, + "use_bias": True, + "use_rel_pos": True, + "window_size": 14, + "global_attention_indices": [7, 15, 23, 31], + }, + # "weights_url": "https://storage.googleapis.com/keras-cv/models/segment_anything/sam_vit_h.weights.h5", # noqa: E501 + # "weights_hash": None + }, +} diff --git a/keras_cv/models/backbones/detectron2/detectron2_backbone_presets_test.py b/keras_cv/models/backbones/detectron2/detectron2_backbone_presets_test.py new file mode 100644 index 0000000000..7166dfd6e5 --- /dev/null +++ b/keras_cv/models/backbones/detectron2/detectron2_backbone_presets_test.py @@ -0,0 +1,95 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for loading pretrained model presets.""" + +import pathlib + +import numpy as np +import pytest + +from keras_cv.backend import ops +from keras_cv.models.backbones.detectron2.detectron2_aliases import ( + SAMViTDetBBackbone, +) +from keras_cv.models.backbones.detectron2.detectron2_backbone import ( + ViTDetBackbone, +) +from keras_cv.tests.test_case import TestCase + + +@pytest.mark.large +class ViTDetPresetSmokeTest(TestCase): + """ + A smoke test for ViTDet presets we run continuously. + This only tests the smallest weights we have available. Run with: + `pytest keras_cv/models/backbones/detectron2/detectron2_backbone_presets_test.py --run_large` # noqa: E501 + """ + + def setUp(self): + self.input_batch = np.ones(shape=(1, 1024, 1024, 3)) + + def test_backbone_output(self): + model = ViTDetBackbone.from_preset("sam_vitdet_b") + outputs = model(self.input_batch) + + # The forward pass from a preset should be stable! + # This test should catch cases where we unintentionally change our + # network code in a way that would invalidate our preset weights. + # We should only update these numbers if we are updating a weights + # file, or have found a discrepancy with the upstream source. 
+ + expected = np.load( + pathlib.Path(__file__).parent / "data" / "sam_vitdet_b_out.npz" + ) + # Keep a high tolerance, so we are robust to different hardware. + self.assertAllClose( + ops.convert_to_numpy(outputs), + expected, + atol=1e-5, + rtol=1e-5, + ) + + def test_applications_model_output(self): + model = SAMViTDetBBackbone() + model(self.input_batch) + + def test_applications_model_output_with_preset(self): + model = SAMViTDetBBackbone.from_preset("sam_vitdet_b") + model(self.input_batch) + + def test_preset_docstring(self): + """Check we did our docstring formatting correctly.""" + for name in ViTDetBackbone.presets: + self.assertRegex(ViTDetBackbone.from_preset.__doc__, name) + + def test_unknown_preset_error(self): + # Not a preset name + with self.assertRaises(ValueError): + ViTDetBackbone.from_preset("vitdet_nonexistant") + + +@pytest.mark.extra_large +class ViTDetPresetFullTest(TestCase): + """ + Test the full enumeration of our preset. + This tests every preset for ViTDet and is only run manually. + Run with: + `pytest keras_cv/models/backbones/detectron2/detectron2_backbone_presets_test.py --run_extra_large` # noqa: E501 + """ + + def test_load_ViTDet(self): + input_data = np.ones(shape=(1, 1024, 1024, 3)) + for preset in ViTDetBackbone.presets: + model = ViTDetBackbone.from_preset(preset) + model(input_data) diff --git a/keras_cv/models/backbones/detectron2/detectron2_backbone_test.py b/keras_cv/models/backbones/detectron2/detectron2_backbone_test.py new file mode 100644 index 0000000000..c880899d5b --- /dev/null +++ b/keras_cv/models/backbones/detectron2/detectron2_backbone_test.py @@ -0,0 +1,45 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np +import pytest + +from keras_cv.backend import keras +from keras_cv.backend import ops +from keras_cv.models.backbones.detectron2.detectron2_aliases import ( + SAMViTDetBBackbone, +) +from keras_cv.tests.test_case import TestCase + + +@pytest.mark.extra_large +class TestViTDetBackbone(TestCase): + def test_call_and_save(self): + model = SAMViTDetBBackbone() + x = np.ones((1, 1024, 1024, 3)) + x_out = ops.convert_to_numpy(model(x)) + num_parameters = sum( + np.prod(tuple(x.shape)) for x in model.trainable_variables + ) + self.assertEqual(x_out.shape, (1, 64, 64, 256)) + self.assertEqual(num_parameters, 89_670_912) + + # saving test + path = os.path.join(self.get_temp_dir(), "sam_tf_model.keras") + model.save(path) + loaded_model = keras.saving.load_model(path) + x_out_loaded = ops.convert_to_numpy(loaded_model(x)) + self.assertAllClose(x_out, x_out_loaded) diff --git a/keras_cv/models/segmentation/__init__.py b/keras_cv/models/segmentation/__init__.py index 4c9b7460fb..7410687444 100644 --- a/keras_cv/models/segmentation/__init__.py +++ b/keras_cv/models/segmentation/__init__.py @@ -13,7 +13,6 @@ # limitations under the License. 
from keras_cv.models.segmentation.deeplab_v3_plus import DeepLabV3Plus -from keras_cv.models.segmentation.segment_anything import ImageEncoder from keras_cv.models.segmentation.segment_anything import MaskDecoder from keras_cv.models.segmentation.segment_anything import PromptEncoder from keras_cv.models.segmentation.segment_anything import TwoWayTransformer diff --git a/keras_cv/models/segmentation/segment_anything/__init__.py b/keras_cv/models/segmentation/segment_anything/__init__.py index 7f1fda07f0..982989c216 100644 --- a/keras_cv/models/segmentation/segment_anything/__init__.py +++ b/keras_cv/models/segmentation/segment_anything/__init__.py @@ -12,9 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from keras_cv.models.segmentation.segment_anything.sam_image_encoder import ( - ImageEncoder, -) from keras_cv.models.segmentation.segment_anything.sam_mask_decoder import ( MaskDecoder, ) diff --git a/keras_cv/models/segmentation/segment_anything/sam_layers.py b/keras_cv/models/segmentation/segment_anything/sam_layers.py index 4dcde4a9b3..fb55f7b5cf 100644 --- a/keras_cv/models/segmentation/segment_anything/sam_layers.py +++ b/keras_cv/models/segmentation/segment_anything/sam_layers.py @@ -37,9 +37,6 @@ def __init__(self, embedding_dim, mlp_dim, activation="gelu", **kwargs): self.mlp_dim = mlp_dim self.activation = activation - self.built = False - - def build(self, input_shape=None): self.dense_layer1.build([self.embedding_dim]) self.dense_layer2.build([self.mlp_dim]) diff --git a/keras_cv/models/segmentation/segment_anything/sam_mask_decoder.py b/keras_cv/models/segmentation/segment_anything/sam_mask_decoder.py index 486c6cd252..2df711593d 100644 --- a/keras_cv/models/segmentation/segment_anything/sam_mask_decoder.py +++ b/keras_cv/models/segmentation/segment_anything/sam_mask_decoder.py @@ -151,33 +151,6 @@ def __init__( self.iou_token.build(None) self.mask_tokens.build(None) - self.built = False - - def build( - self, - image_embeddings_shape, - image_pe_shape, - sparse_prompt_embeddings_shape, - dense_prompt_embeddings_shape, - *args, - **kwargs, - ): - transformer_image_embed_shape = [ - None, - image_embeddings_shape[1], - image_embeddings_shape[2], - image_embeddings_shape[3], - ] - tokens_shape = [ - None, - 1 + self.num_mask_tokens + sparse_prompt_embeddings_shape[1], - self.transformer_dim, - ] - self.transformer.build( - image_embedding_shape=transformer_image_embed_shape, - image_pe_shape=transformer_image_embed_shape, - point_embedding_shape=tokens_shape, - ) self.output_upscaling.build([None, None, None, self.transformer_dim]) for mlp in self.output_hypernetworks_mlps: diff --git a/keras_cv/models/segmentation/segment_anything/sam_test.py b/keras_cv/models/segmentation/segment_anything/sam_test.py index 4492010056..109f4a1c16 100644 --- a/keras_cv/models/segmentation/segment_anything/sam_test.py +++ b/keras_cv/models/segmentation/segment_anything/sam_test.py @@ -15,19 +15,9 @@ import os import numpy as np -import pytest from keras_cv.backend import keras from keras_cv.backend import ops -from keras_cv.models.segmentation.segment_anything.sam_image_encoder import ( - ImageEncoder, -) -from keras_cv.models.segmentation.segment_anything.sam_image_encoder import ( - MultiHeadAttentionWithRelativePE, -) -from keras_cv.models.segmentation.segment_anything.sam_image_encoder import ( - WindowedTransformerEncoder, -) from keras_cv.models.segmentation.segment_anything.sam_mask_decoder import ( MaskDecoder, ) @@ -44,63 
+34,6 @@ class TestSAM(TestCase): - def test_multi_head_attention_with_relative_pe(self): - attention_with_rel_pe = MultiHeadAttentionWithRelativePE( - num_heads=16, - key_dim=1280 // 16, - use_bias=True, - input_size=(64, 64), - ) - x = np.ones(shape=(1, 64, 64, 1280)) - x_out = ops.convert_to_numpy(attention_with_rel_pe(x)) - self.assertEqual(x_out.shape, (1, 64, 64, 1280)) - - def test_windowed_transformer_encoder(self): - windowed_transformer_encoder = WindowedTransformerEncoder( - project_dim=1280, - mlp_dim=1280 * 4, - num_heads=16, - use_bias=True, - use_rel_pos=True, - window_size=14, - input_size=(64, 64), - ) - x = np.ones((1, 64, 64, 1280)) - x_out = ops.convert_to_numpy(windowed_transformer_encoder(x)) - self.assertEqual(x_out.shape, (1, 64, 64, 1280)) - self.assertAllClose(x_out, np.ones_like(x_out)) - - @pytest.mark.extra_large - def test_image_encoder(self): - image_encoder = ImageEncoder( - img_size=1024, - patch_size=16, - in_chans=3, - embed_dim=1280, - depth=32, - mlp_dim=1280 * 4, - num_heads=16, - out_chans=256, - use_bias=True, - use_rel_pos=True, - window_size=14, - global_attention_indices=[7, 15, 23, 31], - ) - x = np.ones((1, 1024, 1024, 3)) - x_out = ops.convert_to_numpy(image_encoder(x)) - num_parameters = sum( - np.prod(tuple(x.shape)) for x in image_encoder.trainable_variables - ) - self.assertEqual(x_out.shape, (1, 64, 64, 256)) - self.assertEqual(num_parameters, 637_026_048) - - # saving test - path = os.path.join(self.get_temp_dir(), "sam_tf_image_encoder.keras") - image_encoder.save(path) - loaded_model = keras.saving.load_model(path) - x_out_loaded = ops.convert_to_numpy(loaded_model(x)) - self.assertAllClose(x_out, x_out_loaded) - def get_points_labels_box_mask(self, B): prompt_encoder = PromptEncoder( embed_dim=256, @@ -189,7 +122,7 @@ def test_two_way_multi_head_attention(self): box, input_mask, ) = self.get_points_labels_box_mask(1) - image_embeddings = np.random.randn(1, 64, 64, 256) + image_embeddings = np.random.randn(1, 64, 64, 256).astype(np.float32) sparse_embeddings, _ = prompt_encoder( points=points, labels=labels, box=box, mask=input_mask diff --git a/keras_cv/models/segmentation/segment_anything/sam_transformer.py b/keras_cv/models/segmentation/segment_anything/sam_transformer.py index 7ea664deff..4850f5b916 100644 --- a/keras_cv/models/segmentation/segment_anything/sam_transformer.py +++ b/keras_cv/models/segmentation/segment_anything/sam_transformer.py @@ -57,7 +57,12 @@ def __init__(self, num_heads, key_dim, downsample_rate=1, **kwargs): # Upsample self.out_proj = keras.layers.Dense(self.key_dim * self.num_heads) - self.built = False + self.query_proj.build([self.num_heads * self.key_dim]) + self.key_proj.build([self.num_heads * self.key_dim]) + self.value_proj.build([self.num_heads * self.key_dim]) + self.out_proj.build([self.internal_dims * self.num_heads]) + + self.built = True def __separate_heads(self, x): B, N, C = x.shape @@ -69,14 +74,6 @@ def __recombine_heads(self, x): x = ops.transpose(x, axes=(0, 2, 1, 3)) return ops.reshape(x, (B, N_T, N_H * C_PH)) - def build(self, query_shape, value_shape, key_shape): - self.query_proj.build(query_shape) - self.key_proj.build(key_shape) - self.value_proj.build(value_shape) - self.out_proj.build([self.internal_dims * self.num_heads]) - - self.built = True - def call(self, query, value, key): query = self.query_proj(query) key = self.key_proj(key) @@ -170,34 +167,14 @@ def __init__( ) self.layer_norm4 = keras.layers.LayerNormalization(epsilon=1e-5) - self.built = False - - def build(self, 
queries_shape, keys_shape, query_pe_shape, key_pe_shape): - self.self_attention.build( - query_shape=queries_shape, - value_shape=queries_shape, - key_shape=queries_shape, - ) - self.layer_norm1.build(queries_shape) - self.cross_attention_token_to_image.build( - query_shape=queries_shape, - key_shape=keys_shape, - value_shape=keys_shape, - ) - self.layer_norm2.build(queries_shape) - self.mlp_block.build(queries_shape) - self.layer_norm3.build(queries_shape) - self.cross_attention_image_to_token.build( - query_shape=keys_shape, - key_shape=queries_shape, - value_shape=queries_shape, - ) - self.layer_norm4.build(keys_shape) + self.layer_norm1.build([None, None, self.num_heads * self.key_dim]) + self.layer_norm2.build([None, None, self.num_heads * self.key_dim]) + self.layer_norm3.build([None, None, self.num_heads * self.key_dim]) + self.layer_norm4.build([None, None, self.num_heads * self.key_dim]) self.built = True def call(self, queries, keys, query_pe, key_pe): - # print("Actual queries_shape:", queries.shape) if self.skip_first_layer_pe: queries = self.self_attention( query=queries, value=queries, key=queries @@ -325,26 +302,7 @@ def __init__( ) self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5) - self.built = False - - def build( - self, image_embedding_shape, image_pe_shape, point_embedding_shape - ): - B, H, W, C = image_embedding_shape - image_embedding_shape = [B, H * W, C] - for layer in self.layers: - layer.build( - queries_shape=point_embedding_shape, - keys_shape=image_embedding_shape, - query_pe_shape=point_embedding_shape, - key_pe_shape=image_embedding_shape, - ) - self.final_attention_token_to_image.build( - query_shape=point_embedding_shape, - key_shape=image_embedding_shape, - value_shape=image_embedding_shape, - ) - self.final_layer_norm.build(point_embedding_shape) + self.final_layer_norm.build([None, None, self.embedding_dim]) self.built = True
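
For reviewers trying this branch locally, here is a minimal usage sketch of the new public API introduced by this diff. It only uses imports, constructor arguments, and shapes that appear in the added tests; it assumes no pretrained weights are loaded, since the preset `weights_url` entries are still commented out.

```python
import numpy as np

from keras_cv.layers import WindowedTransformerEncoder
from keras_cv.models import SAMViTDetBBackbone

# Randomly initialized ViTDet-B backbone (weights are not published yet).
backbone = SAMViTDetBBackbone()
images = np.ones((1, 1024, 1024, 3), dtype="float32")
features = backbone(images)  # expected shape per the added test: (1, 64, 64, 256)

# The encoder block is now also exposed as a standalone layer via
# keras_cv.layers.detectron2_layers.
encoder = WindowedTransformerEncoder(
    project_dim=1280,
    mlp_dim=1280 * 4,
    num_heads=16,
    use_bias=True,
    use_rel_pos=True,
    window_size=14,
    input_size=(64, 64),
)
patches = np.ones((1, 64, 64, 1280), dtype="float32")
encoded = encoder(patches)  # expected shape: (1, 64, 64, 1280)
```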