diff --git a/keras_cv/backend/tf_ops.py b/keras_cv/backend/tf_ops.py index dc93712095..106c9d0a33 100644 --- a/keras_cv/backend/tf_ops.py +++ b/keras_cv/backend/tf_ops.py @@ -21,6 +21,7 @@ from keras_core.src.backend.tensorflow.numpy import * # noqa: F403, F401 # Some TF APIs where the numpy API doesn't support raggeds that we need +from tensorflow import broadcast_to # noqa: F403, F401 from tensorflow import concat as concatenate # noqa: F403, F401 from tensorflow import range as arange # noqa: F403, F401 from tensorflow import reduce_all as all # noqa: F403, F401 diff --git a/keras_cv/layers/__init__.py b/keras_cv/layers/__init__.py index 342a942f64..0bfa2aa8ec 100644 --- a/keras_cv/layers/__init__.py +++ b/keras_cv/layers/__init__.py @@ -135,4 +135,9 @@ ) from keras_cv.layers.spatial_pyramid import SpatialPyramidPooling from keras_cv.layers.transformer_encoder import TransformerEncoder +from keras_cv.layers.vit_det_layers import AddRelativePositionalEmbedding +from keras_cv.layers.vit_det_layers import MultiHeadAttentionWithRelativePE +from keras_cv.layers.vit_det_layers import ViTDetPatchingAndEmbedding +from keras_cv.layers.vit_det_layers import WindowedTransformerEncoder +from keras_cv.layers.vit_det_layers import WindowPartitioning from keras_cv.layers.vit_layers import PatchingAndEmbedding diff --git a/keras_cv/layers/vit_det_layers.py b/keras_cv/layers/vit_det_layers.py new file mode 100644 index 0000000000..78c0b0bfb6 --- /dev/null +++ b/keras_cv/layers/vit_det_layers.py @@ -0,0 +1,590 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +from keras_cv.api_export import keras_cv_export +from keras_cv.backend import keras +from keras_cv.backend import ops + + +class MLP(keras.layers.Layer): + """A MLP block with architecture + `input_dim -> [hidden_dim] * (num_layers - 1) -> output_dim`. + + Args: + hidden_dim (int): The number of units in the hidden layers. + output_dim (int): The number of units in the output layer. + num_layers (int): The total number of dense layers to use. + activation (str): Activation to use in the hidden layers. + Default is `"relu"`. 
+ + References: + - [Segment Anything paper](https://arxiv.org/abs/2304.02643) + - [Segment Anything GitHub](https://github.com/facebookresearch/segment-anything) + - [Detectron2](https://github.com/facebookresearch/detectron2) + """ # noqa: E501 + + def __init__( + self, hidden_dim, output_dim, num_layers, activation="relu", **kwargs + ): + super().__init__(**kwargs) + self.hidden_dim = hidden_dim + self.output_dim = output_dim + self.num_layers = num_layers + self.activation = activation + h = [hidden_dim] * (num_layers - 1) + self.dense_net = [] + for hidden_dim in h: + self.dense_net.append(keras.layers.Dense(hidden_dim)) + self.dense_net.append(keras.layers.Activation(activation)) + self.dense_net.append(keras.layers.Dense(output_dim)) + self.dense_net = keras.models.Sequential(self.dense_net) + + def build(self, input_shape): + self.dense_net.build(input_shape) + self.built = True + + def call(self, x): + return self.dense_net(x) + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "output_dim": self.output_dim, + "num_layers": self.num_layers, + "activation": self.activation, + } + ) + return config + + +@keras_cv_export( + "keras_cv.layers.AddRelativePositionalEmbedding", package="keras_cv.layers" +) +class AddRelativePositionalEmbedding(keras.layers.Layer): + def __init__(self, input_size, key_dim, **kwargs): + super().__init__(**kwargs) + self.input_size = input_size + self.key_dim = key_dim + self.rel_pos_h = self.add_weight( + name="rel_pos_h", + shape=(2 * self.input_size[0] - 1, self.key_dim), + initializer="zeros", + trainable=True, + ) + self.rel_pos_w = self.add_weight( + name="rel_pos_w", + shape=(2 * self.input_size[1] - 1, self.key_dim), + initializer="zeros", + trainable=True, + ) + self.built = True + + def _get_rel_pos(self, query_size, key_size, rel_pos): + """ + Get relative positional embeddings according to the relative positions + of query and key sizes. + + Args: + query_size (int): The number of features of the queries. + key_size (int): The number of features of the keys. + rel_pos (tensor): Relative positional embedding tensor. + + Returns: + tensor: Extracted positional embeddings according to relative + positions. + """ + max_rel_dist = 2 * max(query_size, key_size) - 1 + + if ops.shape(rel_pos)[0] != max_rel_dist: + rel_pos_resized = ops.image.resize( + image=ops.reshape( + rel_pos, + (1, ops.shape(rel_pos)[0], ops.shape(rel_pos)[1], 1), + ), + size=(max_rel_dist, ops.shape(rel_pos)[1]), + interpolation="bilinear", + ) + rel_pos_resized = ops.squeeze(rel_pos_resized, axis=(0, -1)) + return rel_pos_resized + else: + rel_pos_resized = rel_pos + query_coordinates = np.arange(query_size, dtype="float32")[:, None] * ( + max(key_size / query_size, 1.0) + ) + key_coordinates = np.arange(key_size, dtype="float32")[None, :] * ( + max(query_size / key_size, 1.0) + ) + relative_coordinates = (query_coordinates - key_coordinates) + ( + key_size - 1 + ) * max(query_size / key_size, 1.0) + relative_coordinates = relative_coordinates.astype("int32") + return ops.take(rel_pos_resized, relative_coordinates, 0) + + def call(self, attention_map, queries, query_size, key_size): + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + + Args: + attention_map (tensor): Attention map. + queries (tensor): Queries in the attention layer with shape + `(B, q_h * q_w, C)`. + query_size (tuple[int, int]): Spatial sequence size of queries with + `(q_h, q_w)`. 
+ key_size (tuple[int, int]): Spatial sequence size of keys with + `(k_h, k_w)`. + + Returns: + tensor: attention map with added relative positional embeddings. + + References: + - https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa: E501 + """ + query_height, query_width = query_size[0], query_size[1] + key_height, key_width = key_size[0], key_size[1] + rel_heights = self._get_rel_pos( + query_height, key_height, self.rel_pos_h + ) + rel_widths = self._get_rel_pos(query_width, key_width, self.rel_pos_w) + + shape = ops.shape(queries) + B, C = shape[0], shape[2] + rel_queries = ops.reshape(queries, (B, query_height, query_width, C)) + rel_heights = ops.einsum("bhwc,hkc->bhwk", rel_queries, rel_heights) + rel_widths = ops.einsum("bhwc,wkc->bhwk", rel_queries, rel_widths) + + attention_map = ops.reshape( + attention_map, (B, query_height, query_width, key_height, key_width) + ) + attention_map = attention_map + rel_heights[..., :, None] + attention_map = attention_map + rel_widths[..., None, :] + attention_map = ops.reshape( + attention_map, + (B, query_height * query_width, key_height * key_width), + ) + return attention_map + + def get_config(self): + config = super().get_config() + config.update({"input_size": self.input_size, "key_dim": self.key_dim}) + return config + + +@keras_cv_export( + "keras_cv.layers.MultiHeadAttentionWithRelativePE", + package="keras_cv.layers", +) +class MultiHeadAttentionWithRelativePE(keras.layers.Layer): + """Multi-head Attention block with relative position embeddings. + + Args: + num_heads (int): Number of attention heads. + key_dim (int): Size of each attention head for query, key, and + value. + use_bias (bool, optional): Whether to use bias when projecting + the queries, keys, and values. Defaults to `True`. + use_rel_pos (bool, optional): Whether to use relative positional + embeddings or not. Defaults to `False`. + input_size (tuple[int, int], optional): Size of the input image. + Must be provided when using relative positional embeddings. + Defaults to `None`. + + Raises: + ValueError: When `input_size = None` with `use_rel_pos = True`. + + References: + - [Segment Anything paper](https://arxiv.org/abs/2304.02643) + - [Segment Anything GitHub](https://github.com/facebookresearch/segment-anything) + - [Detectron2](https://github.com/facebookresearch/detectron2) + """ # noqa: E501 + + def __init__( + self, + num_heads, + key_dim, + use_bias=True, + use_rel_pos=False, + input_size=None, + **kwargs + ): + super().__init__(**kwargs) + self.num_heads = num_heads + self.key_dim = key_dim + self.scale = self.key_dim**-0.5 + self.use_bias = use_bias + self.input_size = input_size + self.use_rel_pos = use_rel_pos + + self.qkv = keras.layers.Dense( + key_dim * self.num_heads * 3, use_bias=self.use_bias + ) + self.projection = keras.layers.Dense(key_dim * self.num_heads) + + if self.use_rel_pos: + if input_size is None: + raise ValueError( + "Input size must be provided if using relative " + "positional encoding." 
+ ) + self.add_decomposed_reative_pe = AddRelativePositionalEmbedding( + self.input_size, self.key_dim + ) + + def build(self, input_shape=None): + self.qkv.build([self.key_dim * self.num_heads]) + self.projection.build([self.key_dim * self.num_heads]) + self.built = True + + def compute_output_shape(self, input_shape): + return input_shape + + def call(self, x): + shape = ops.shape(x) + B, H, W, C = shape[0], shape[1], shape[2], shape[3] + qkv = ops.transpose( + ops.reshape( + self.qkv(x), (B, H * W, 3, self.num_heads, self.key_dim) + ), + axes=(2, 0, 3, 1, 4), + ) + qkv = ops.reshape(qkv, (3, B * self.num_heads, H * W, self.key_dim)) + queries, keys, values = ops.unstack(qkv, axis=0) + attention_map = (queries * self.scale) @ ops.transpose( + keys, axes=(0, 2, 1) + ) + + if self.use_rel_pos: + attention_map = self.add_decomposed_reative_pe( + attention_map, + queries=queries, + query_size=(H, W), + key_size=(H, W), + ) + attention_map = ops.softmax(attention_map, axis=-1) + x = ops.reshape( + attention_map @ values, (B, self.num_heads, H, W, self.key_dim) + ) + x = ops.transpose(x, axes=(0, 2, 3, 1, 4)) + x = ops.reshape(x, (B, H, W, C)) + x = self.projection(x) + + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "num_heads": self.num_heads, + "key_dim": self.key_dim, + "use_bias": self.use_bias, + "use_rel_pos": self.use_rel_pos, + "input_size": self.input_size, + } + ) + return config + + +@keras_cv_export( + "keras_cv.layers.WindowPartitioning", package="keras_cv.layers" +) +class WindowPartitioning(keras.layers.Layer): + def __init__(self, window_size, **kwargs): + super().__init__(**kwargs) + self.window_size = window_size + self.built = True + + def partition(self, x): + shape = ops.shape(x) + B, H, W, C = shape[0], shape[1], shape[2], shape[3] + pad_height = ( + self.window_size - H % self.window_size + ) % self.window_size + pad_width = (self.window_size - W % self.window_size) % self.window_size + if pad_height > 0 or pad_width > 0: + x = ops.pad(x, ((0, 0), (0, pad_height), (0, pad_width), (0, 0))) + H_padded, W_padded = H + pad_height, W + pad_width + x = ops.reshape( + x, + ( + B, + H_padded // self.window_size, + self.window_size, + W_padded // self.window_size, + self.window_size, + C, + ), + ) + windows = ops.reshape( + ops.transpose(x, axes=(0, 1, 3, 2, 4, 5)), + (-1, self.window_size, self.window_size, C), + ) + return windows, (H_padded, W_padded) + + def unpartition(self, windows, HW_padded, HW): + H_padded, W_padded = HW_padded + H, W = HW + B = ops.shape(windows)[0] // ( + (H_padded // self.window_size) * (W_padded // self.window_size) + ) + x = ops.reshape( + windows, + ( + B, + H_padded // self.window_size, + W_padded // self.window_size, + self.window_size, + self.window_size, + -1, + ), + ) + x = ops.reshape( + ops.transpose(x, axes=(0, 1, 3, 2, 4, 5)), + (B, H_padded, W_padded, -1), + ) + return x[:, :H, :W, :] + + def get_config(self): + config = super().get_config() + config.update({"window_size": self.window_size}) + return config + + +@keras_cv_export( + "keras_cv.layers.WindowedTransformerEncoder", package="keras_cv.layers" +) +class WindowedTransformerEncoder(keras.layers.Layer): + """Transformer blocks with support of window attention and residual + propagation blocks. + + Args: + project_dim (int): the dimensionality of the projection of the + encoder, and output of the `MultiHeadAttention`. + mlp_dim (int): the intermediate dimensionality of the MLP head before + projecting to `project_dim`. 
+ num_heads (int): the number of heads for the `MultiHeadAttention` + layer. + use_bias (bool, optional): Whether to use bias to project the keys, + queries, and values in the attention layer. Defaults to `True`. + use_rel_pos (bool, optional): Whether to use relative positional + emcodings in the attention layer. Defaults to `False`. + window_size (int, optional): Window size for windowed attention. + Defaults to `0`. + input_size (tuple[int, int], optional): Height and width of the input + image as a tuple of integers. Must be provided when using relative + positional embeddings. Defaults to `None`. + activation (str, optional): the activation function to apply in the + MLP head - should be a function. Defaults to `"gelu"`. + layer_norm_epsilon (float, optional): The epsilon to use in the layer + normalization layers. Defaults to `1e-6`. + + References: + - [Segment Anything paper](https://arxiv.org/abs/2304.02643) + - [Segment Anything GitHub](https://github.com/facebookresearch/segment-anything) + - [Detectron2](https://github.com/facebookresearch/detectron2) + """ # noqa: E501 + + def __init__( + self, + project_dim, + mlp_dim, + num_heads, + use_bias=True, + use_rel_pos=False, + window_size=0, + input_size=None, + activation="gelu", + layer_norm_epsilon=1e-6, + **kwargs + ): + super().__init__(**kwargs) + self.project_dim = project_dim + self.mlp_dim = mlp_dim + self.num_heads = num_heads + self.use_bias = use_bias + self.input_size = input_size + self.activation = activation + self.layer_norm_epsilon = layer_norm_epsilon + self.window_size = window_size + self.use_rel_pos = use_rel_pos + + self.layer_norm1 = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon + ) + self.layer_norm2 = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon + ) + self.attention = MultiHeadAttentionWithRelativePE( + num_heads=self.num_heads, + key_dim=self.project_dim // self.num_heads, + use_bias=use_bias, + use_rel_pos=use_rel_pos, + input_size=input_size + if window_size == 0 + else (window_size, window_size), + ) + self.mlp_block = MLP( + mlp_dim, + project_dim, + num_layers=2, + activation="gelu", + ) + self.window_partitioning = WindowPartitioning(window_size) + + def build(self, input_shape=None): + self.layer_norm1.build([None, None, None, self.project_dim]) + self.layer_norm2.build([None, None, None, self.project_dim]) + self.attention.build() + self.mlp_block.build([None, None, None, self.project_dim]) + self.built = True + + def compute_output_shape(self, input_shape): + return input_shape + + def call(self, x): + shortcut = x + x = self.layer_norm1(x) + # Window Partition + if self.window_size > 0: + H, W = ops.shape(x)[1], ops.shape(x)[2] + x, HW_padded = self.window_partitioning.partition(x) + + x = self.attention(x) + # Reverse Window Partition + if self.window_size > 0: + x = self.window_partitioning.unpartition( + x, HW_padded=HW_padded, HW=(H, W) + ) + + x = shortcut + x + x = x + self.mlp_block(self.layer_norm2(x)) + + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "project_dim": self.project_dim, + "mlp_dim": self.mlp_dim, + "num_heads": self.num_heads, + "use_bias": self.use_bias, + "use_rel_pos": self.use_rel_pos, + "window_size": self.window_size, + "input_size": self.input_size, + "activation": self.activation, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config + + +@keras_cv_export( + "keras_cv.layers.ViTDetPatchingAndEmbedding", package="keras_cv.layers" +) +class 
ViTDetPatchingAndEmbedding(keras.layers.Layer):
+    """Image to Patch Embedding using only a conv layer (without
+    layer normalization).
+
+    Args:
+        kernel_size (tuple[int, int], optional): Kernel size of the
+            projection layer. Defaults to `(16, 16)`.
+        strides (tuple, optional): Strides of the projection layer.
+            Defaults to `(16, 16)`.
+        embed_dim (int, optional): Number of filters to use in the
+            projection layer i.e. projection size. Defaults to `768`.
+
+    References:
+        - [Segment Anything paper](https://arxiv.org/abs/2304.02643)
+        - [Segment Anything GitHub](https://github.com/facebookresearch/segment-anything)
+        - [Detectron2](https://github.com/facebookresearch/detectron2)
+    """  # noqa: E501
+
+    def __init__(
+        self, kernel_size=(16, 16), strides=(16, 16), embed_dim=768, **kwargs
+    ):
+        super().__init__(**kwargs)
+
+        self.projection = keras.layers.Conv2D(
+            embed_dim, kernel_size=kernel_size, strides=strides
+        )
+
+        self.kernel_size = kernel_size
+        self.strides = strides
+        self.embed_dim = embed_dim
+
+    def build(self, input_shape):
+        self.projection.build(input_shape)
+        self.built = True
+
+    def compute_output_shape(self, input_shape):
+        return self.projection.compute_output_shape(input_shape)
+
+    def call(self, x):
+        x = self.projection(x)
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "kernel_size": self.kernel_size,
+                "strides": self.strides,
+                "embed_dim": self.embed_dim,
+            }
+        )
+        return config
+
+
+# TODO: Merge this with the `keras_cv.layers.PatchingAndEmbedding` class once
+# it has been ported to Keras Core.
+@keras_cv_export(
+    "keras_cv.layers.AddPositionalEmbedding", package="keras_cv.layers"
+)
+class AddPositionalEmbedding(keras.layers.Layer):
+    def __init__(self, img_size, patch_size, embed_dim, **kwargs):
+        super().__init__(**kwargs)
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.embed_dim = embed_dim
+        self.pos_embed = self.add_weight(
+            name="pos_embed",
+            shape=(
+                1,
+                img_size // patch_size,
+                img_size // patch_size,
+                embed_dim,
+            ),
+            initializer="zeros",
+            trainable=True,
+        )
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+    def call(self, x):
+        return x + self.pos_embed
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "img_size": self.img_size,
+                "patch_size": self.patch_size,
+                "embed_dim": self.embed_dim,
+            }
+        )
+        return config
diff --git a/keras_cv/layers/vit_det_layers_test.py b/keras_cv/layers/vit_det_layers_test.py
new file mode 100644
index 0000000000..05c698730e
--- /dev/null
+++ b/keras_cv/layers/vit_det_layers_test.py
@@ -0,0 +1,64 @@
+# Copyright 2023 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import numpy as np + +from keras_cv.backend import ops +from keras_cv.layers.vit_det_layers import AddPositionalEmbedding +from keras_cv.layers.vit_det_layers import MultiHeadAttentionWithRelativePE +from keras_cv.layers.vit_det_layers import ViTDetPatchingAndEmbedding +from keras_cv.layers.vit_det_layers import WindowedTransformerEncoder +from keras_cv.tests.test_case import TestCase + + +class TestViTDetLayers(TestCase): + def test_multi_head_attention_with_relative_pe(self): + attention_with_rel_pe = MultiHeadAttentionWithRelativePE( + num_heads=16, + key_dim=1280 // 16, + use_bias=True, + input_size=(64, 64), + ) + x = np.ones(shape=(1, 64, 64, 1280)) + x_out = ops.convert_to_numpy(attention_with_rel_pe(x)) + self.assertEqual(x_out.shape, (1, 64, 64, 1280)) + + def test_windowed_transformer_encoder(self): + windowed_transformer_encoder = WindowedTransformerEncoder( + project_dim=1280, + mlp_dim=1280 * 4, + num_heads=16, + use_bias=True, + use_rel_pos=True, + window_size=14, + input_size=(64, 64), + ) + x = np.ones((1, 64, 64, 1280)) + x_out = ops.convert_to_numpy(windowed_transformer_encoder(x)) + self.assertEqual(x_out.shape, (1, 64, 64, 1280)) + self.assertAllClose(x_out, np.ones_like(x_out)) + + def test_vit_patching_and_embedding(self): + vit_patching_and_embedding = ViTDetPatchingAndEmbedding() + x = np.ones((1, 1024, 1024, 3)) + x_out = vit_patching_and_embedding(x) + self.assertEqual(x_out.shape, (1, 64, 64, 768)) + + def test_add_positional_embedding(self): + add_positional_embedding = AddPositionalEmbedding( + img_size=1024, patch_size=16, embed_dim=256 + ) + x = np.ones((1, 64, 64, 256)) + x_out = add_positional_embedding(x) + self.assertEqual(x_out.shape, (1, 64, 64, 256)) diff --git a/keras_cv/models/__init__.py b/keras_cv/models/__init__.py index 9c83a3891a..ae775ed824 100644 --- a/keras_cv/models/__init__.py +++ b/keras_cv/models/__init__.py @@ -178,6 +178,10 @@ from keras_cv.models.backbones.resnet_v2.resnet_v2_backbone import ( ResNetV2Backbone, ) +from keras_cv.models.backbones.vit_det.vit_det_aliases import ViTDetBBackbone +from keras_cv.models.backbones.vit_det.vit_det_aliases import ViTDetHBackbone +from keras_cv.models.backbones.vit_det.vit_det_aliases import ViTDetLBackbone +from keras_cv.models.backbones.vit_det.vit_det_backbone import ViTDetBackbone from keras_cv.models.classification.image_classifier import ImageClassifier from keras_cv.models.object_detection.retinanet.retinanet import RetinaNet from keras_cv.models.object_detection.yolo_v8.yolo_v8_backbone import ( @@ -187,6 +191,10 @@ YOLOV8Detector, ) from keras_cv.models.segmentation import DeepLabV3Plus +from keras_cv.models.segmentation import SAMMaskDecoder +from keras_cv.models.segmentation import SAMPromptEncoder +from keras_cv.models.segmentation import SegmentAnythingModel +from keras_cv.models.segmentation import TwoWayTransformer from keras_cv.models.segmentation.segformer.segformer_aliases import SegFormer from keras_cv.models.segmentation.segformer.segformer_aliases import SegFormerB0 from keras_cv.models.segmentation.segformer.segformer_aliases import SegFormerB1 diff --git a/keras_cv/models/backbones/backbone_presets.py b/keras_cv/models/backbones/backbone_presets.py index 614f85cd24..95d3ccd522 100644 --- a/keras_cv/models/backbones/backbone_presets.py +++ b/keras_cv/models/backbones/backbone_presets.py @@ -28,6 +28,7 @@ from keras_cv.models.backbones.mobilenet_v3 import mobilenet_v3_backbone_presets from keras_cv.models.backbones.resnet_v1 import resnet_v1_backbone_presets from 
keras_cv.models.backbones.resnet_v2 import resnet_v2_backbone_presets +from keras_cv.models.backbones.vit_det import vit_det_backbone_presets from keras_cv.models.object_detection.yolo_v8 import yolo_v8_backbone_presets backbone_presets_no_weights = { @@ -40,6 +41,7 @@ **densenet_backbone_presets.backbone_presets_no_weights, **efficientnet_lite_backbone_presets.backbone_presets_no_weights, **yolo_v8_backbone_presets.backbone_presets_no_weights, + **vit_det_backbone_presets.backbone_presets_no_weights, } backbone_presets_with_weights = { @@ -52,6 +54,7 @@ **densenet_backbone_presets.backbone_presets_with_weights, **efficientnet_lite_backbone_presets.backbone_presets_with_weights, **yolo_v8_backbone_presets.backbone_presets_with_weights, + **vit_det_backbone_presets.backbone_presets_with_weights, } backbone_presets = { diff --git a/keras_cv/models/backbones/vit_det/__init__.py b/keras_cv/models/backbones/vit_det/__init__.py new file mode 100644 index 0000000000..3992ffb59a --- /dev/null +++ b/keras_cv/models/backbones/vit_det/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_cv/models/backbones/vit_det/data/vitdet_base_out.npz b/keras_cv/models/backbones/vit_det/data/vitdet_base_out.npz new file mode 100644 index 0000000000..da8c732ccd Binary files /dev/null and b/keras_cv/models/backbones/vit_det/data/vitdet_base_out.npz differ diff --git a/keras_cv/models/backbones/vit_det/vit_det_aliases.py b/keras_cv/models/backbones/vit_det/vit_det_aliases.py new file mode 100644 index 0000000000..dec51207e3 --- /dev/null +++ b/keras_cv/models/backbones/vit_det/vit_det_aliases.py @@ -0,0 +1,115 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +from keras_cv.models.backbones.vit_det.vit_det_backbone import ViTDetBackbone +from keras_cv.models.backbones.vit_det.vit_det_backbone_presets import ( + backbone_presets, +) +from keras_cv.utils.python_utils import classproperty + +ALIAS_DOCSTRING = """VitDet{size}Backbone model. + + Reference: + - [Detectron2](https://github.com/facebookresearch/detectron2) + - [Segment Anything paper](https://arxiv.org/abs/2304.02643) + - [Segment Anything GitHub](https://github.com/facebookresearch/segment-anything) + + For transfer learning use cases, make sure to read the + [guide to transfer learning & fine-tuning](https://keras.io/guides/transfer_learning/). 
+ + Examples: + ```python + input_data = np.ones(shape=(1, 1024, 1024, 3)) + + # Randomly initialized backbone + model = VitDet{size}Backbone() + output = model(input_data) + ``` +""" # noqa: E501 + + +class ViTDetBBackbone(ViTDetBackbone): + def __new__( + cls, + **kwargs, + ): + return ViTDetBackbone.from_preset("vitdet_base", **kwargs) + + @classproperty + def presets(cls): + """Dictionary of preset names and configurations.""" + return { + "vitdet_base_sa1b": copy.deepcopy( + backbone_presets["vitdet_base_sa1b"] + ), + } + + @classproperty + def presets_with_weights(cls): + """Dictionary of preset names and configurations that include + weights.""" + return cls.presets + + +class ViTDetLBackbone(ViTDetBackbone): + def __new__( + cls, + **kwargs, + ): + return ViTDetBackbone.from_preset("vitdet_large", **kwargs) + + @classproperty + def presets(cls): + """Dictionary of preset names and configurations.""" + return { + "vitdet_large_sa1b": copy.deepcopy( + backbone_presets["vitdet_large_sa1b"] + ), + } + + @classproperty + def presets_with_weights(cls): + """Dictionary of preset names and configurations that include + weights.""" + return cls.presets + + +class ViTDetHBackbone(ViTDetBackbone): + def __new__( + cls, + **kwargs, + ): + return ViTDetBackbone.from_preset("vitdet_huge", **kwargs) + + @classproperty + def presets(cls): + """Dictionary of preset names and configurations.""" + return { + "vitdet_huge_sa1b": copy.deepcopy( + backbone_presets["vitdet_huge_sa1b"] + ), + } + + @classproperty + def presets_with_weights(cls): + """Dictionary of preset names and configurations that include + weights.""" + return cls.presets + + +setattr(ViTDetBBackbone, "__doc__", ALIAS_DOCSTRING.format(size="B")) +setattr(ViTDetLBackbone, "__doc__", ALIAS_DOCSTRING.format(size="L")) +setattr(ViTDetHBackbone, "__doc__", ALIAS_DOCSTRING.format(size="H")) diff --git a/keras_cv/models/backbones/vit_det/vit_det_backbone.py b/keras_cv/models/backbones/vit_det/vit_det_backbone.py new file mode 100644 index 0000000000..26e2d5d190 --- /dev/null +++ b/keras_cv/models/backbones/vit_det/vit_det_backbone.py @@ -0,0 +1,218 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy + +from keras_cv.api_export import keras_cv_export +from keras_cv.backend import keras +from keras_cv.layers.vit_det_layers import AddPositionalEmbedding +from keras_cv.layers.vit_det_layers import ViTDetPatchingAndEmbedding +from keras_cv.layers.vit_det_layers import WindowedTransformerEncoder +from keras_cv.models import utils +from keras_cv.models.backbones.backbone import Backbone +from keras_cv.models.backbones.vit_det.vit_det_backbone_presets import ( + backbone_presets, +) +from keras_cv.models.backbones.vit_det.vit_det_backbone_presets import ( + backbone_presets_with_weights, +) +from keras_cv.utils.python_utils import classproperty + + +@keras_cv_export("keras_cv.models.ViTDetBackbone", package="keras_cv.models") +class ViTDetBackbone(Backbone): + """A ViT image encoder that uses a windowed transformer encoder and + relative positional encodings. + + Args: + input_shape (tuple[int], optional): The size of the input image in + `(H, W, C)` format. Defaults to `(1024, 1024, 3)`. + input_tensor (KerasTensor, optional): Output of + `keras.layers.Input()`) to use as image input for the model. + Defaults to `None`. + include_rescaling (bool, optional): Whether to rescale the inputs. If + set to `True`, inputs will be passed through a + `Rescaling(1/255.0)` layer. Defaults to `False`. + patch_size (int, optional): the patch size to be supplied to the + Patching layer to turn input images into a flattened sequence of + patches. Defaults to `16`. + embed_dim (int, optional): The latent dimensionality to be projected + into in the output of each stacked windowed transformer encoder. + Defaults to `768`. + depth (int, optional): The number of transformer encoder layers to + stack in the Vision Transformer. Defaults to `12`. + mlp_dim (int, optional): The dimensionality of the hidden Dense + layer in the transformer MLP head. Defaults to `768*4`. + num_heads (int, optional): the number of heads to use in the + `MultiHeadAttentionWithRelativePE` layer of each transformer + encoder. Defaults to `12`. + out_chans (int, optional): The number of channels (features) in the + output (image encodings). Defaults to `256`. + use_bias (bool, optional): Whether to use bias to project the keys, + queries, and values in the attention layer. Defaults to `True`. + use_abs_pos (bool, optional): Whether to add absolute positional + embeddings to the output patches. Defaults to `True`. + use_rel_pos (bool, optional): Whether to use relative positional + emcodings in the attention layer. Defaults to `True`. + window_size (int, optional): The size of the window for windowed + attention in the transformer encoder blocks. Defaults to `14`. + global_attention_indices (list, optional): Indexes for blocks using + global attention. Defaults to `[2, 5, 8, 11]`. + layer_norm_epsilon (int, optional): The epsilon to use in the layer + normalization blocks in transformer encoder. Defaults to `1e-6`. 
+ + References: + - [Segment Anything paper](https://arxiv.org/abs/2304.02643) + - [Segment Anything GitHub](https://github.com/facebookresearch/segment-anything) + - [Detectron2](https://github.com/facebookresearch/detectron2) + """ # noqa: E501 + + def __init__( + self, + *, + input_shape=(1024, 1024, 3), + input_tensor=None, + include_rescaling=False, + patch_size=16, + embed_dim=768, + depth=12, + mlp_dim=768 * 4, + num_heads=12, + out_chans=256, + use_bias=True, + use_abs_pos=True, + use_rel_pos=True, + window_size=14, + global_attention_indices=[2, 5, 8, 11], + layer_norm_epsilon=1e-6, + **kwargs + ): + img_input = utils.parse_model_inputs( + input_shape, input_tensor, name="images" + ) + + # Check that the input image is well specified. + if img_input.shape[-3] is None or img_input.shape[-2] is None: + raise ValueError( + "Height and width of the image must be specified" + " in `input_shape`." + ) + if img_input.shape[-3] != img_input.shape[-2]: + raise ValueError( + "Input image must be square i.e. the height must" + " be equal to the width in the `input_shape`" + " tuple/tensor." + ) + + img_size = img_input.shape[-3] + + x = img_input + + if include_rescaling: + # Use common rescaling strategy across keras_cv + x = keras.layers.Rescaling(1.0 / 255.0)(x) + + x = ViTDetPatchingAndEmbedding( + kernel_size=(patch_size, patch_size), + strides=(patch_size, patch_size), + embed_dim=embed_dim, + )(x) + if use_abs_pos: + x = AddPositionalEmbedding(img_size, patch_size, embed_dim)(x) + + for i in range(depth): + x = WindowedTransformerEncoder( + project_dim=embed_dim, + mlp_dim=mlp_dim, + num_heads=num_heads, + use_bias=use_bias, + use_rel_pos=use_rel_pos, + window_size=window_size + if i not in global_attention_indices + else 0, + input_size=(img_size // patch_size, img_size // patch_size), + )(x) + x = keras.models.Sequential( + [ + keras.layers.Conv2D( + filters=out_chans, kernel_size=1, use_bias=False + ), + keras.layers.LayerNormalization(epsilon=1e-6), + keras.layers.Conv2D( + filters=out_chans, + kernel_size=3, + padding="same", + use_bias=False, + ), + keras.layers.LayerNormalization(epsilon=1e-6), + ] + )(x) + + super().__init__(inputs=img_input, outputs=x, **kwargs) + + self.patch_size = patch_size + self.embed_dim = embed_dim + self.depth = depth + self.mlp_dim = mlp_dim + self.num_heads = num_heads + self.out_chans = out_chans + self.use_bias = use_bias + self.use_rel_pos = use_rel_pos + self.use_abs_pos = use_abs_pos + self.window_size = window_size + self.global_attention_indices = global_attention_indices + self.layer_norm_epsilon = layer_norm_epsilon + self.input_tensor = input_tensor + self.include_rescaling = include_rescaling + + @property + def pyramid_level_inputs(self): + raise NotImplementedError( + "The `ViTDetBackbone` model doesn't compute" + " pyramid level features." 
+ ) + + def get_config(self): + config = super().get_config() + config.update( + { + "input_shape": self.input_shape[1:], + "input_tensor": self.input_tensor, + "include_rescaling": self.include_rescaling, + "patch_size": self.patch_size, + "embed_dim": self.embed_dim, + "depth": self.depth, + "mlp_dim": self.mlp_dim, + "num_heads": self.num_heads, + "out_chans": self.out_chans, + "use_bias": self.use_bias, + "use_abs_pos": self.use_abs_pos, + "use_rel_pos": self.use_rel_pos, + "window_size": self.window_size, + "global_attention_indices": self.global_attention_indices, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config + + @classproperty + def presets(cls): + """Dictionary of preset names and configurations.""" + return copy.deepcopy(backbone_presets) + + @classproperty + def presets_with_weights(cls): + """Dictionary of preset names and configurations that include + weights.""" + return copy.deepcopy(backbone_presets_with_weights) diff --git a/keras_cv/models/backbones/vit_det/vit_det_backbone_presets.py b/keras_cv/models/backbones/vit_det/vit_det_backbone_presets.py new file mode 100644 index 0000000000..825f157ed7 --- /dev/null +++ b/keras_cv/models/backbones/vit_det/vit_det_backbone_presets.py @@ -0,0 +1,162 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""VitDet model preset configurations.""" + +backbone_presets_no_weights = { + "vitdet_base": { + "metadata": { + "description": ( + "Detectron2 ViT basebone with 12 " + "transformer encoders with embed dim 768 and attention layers" + " with 12 heads with global attention on encoders 2, 5, 8, " + "and 11." + ), + "params": 89_670_912, + "official_name": "VitDet", + "path": "vit_det", + }, + "class_name": "keras_cv.models>ViTDetBackbone", + "config": { + "input_shape": (1024, 1024, 3), + "input_tensor": None, + "include_rescaling": False, + "patch_size": 16, + "embed_dim": 768, + "depth": 12, + "mlp_dim": 768 * 4, + "num_heads": 12, + "out_chans": 256, + "use_bias": True, + "use_abs_pos": True, + "use_rel_pos": True, + "window_size": 14, + "global_attention_indices": [2, 5, 8, 11], + "layer_norm_epsilon": 1e-6, + }, + }, + "vitdet_large": { + "metadata": { + "description": ( + "Detectron2 ViT basebone with 24 " + "transformer encoders with embed dim " + "1024 and attention layers with 16 heads with global " + "attention on encoders 5, 11, 17, and 23." 
+ ), + "params": 308_278_272, + "official_name": "VitDet", + "path": "vit_det", + }, + "class_name": "keras_cv.models>ViTDetBackbone", + "config": { + "input_shape": (1024, 1024, 3), + "input_tensor": None, + "include_rescaling": False, + "patch_size": 16, + "embed_dim": 1024, + "depth": 24, + "mlp_dim": 1024 * 4, + "num_heads": 16, + "out_chans": 256, + "use_bias": True, + "use_abs_pos": True, + "use_rel_pos": True, + "window_size": 14, + "global_attention_indices": [5, 11, 17, 23], + "layer_norm_epsilon": 1e-6, + }, + }, + "vitdet_huge": { + "metadata": { + "description": ( + "Detectron2 ViT basebone model " + "with 32 transformer encoders with embed dim " + "1280 and attention layers with 16 heads with global " + "attention on encoders 7, 15, 23, and 31." + ), + "params": 637_026_048, + "official_name": "VitDet", + "path": "vit_det", + }, + "class_name": "keras_cv.models>ViTDetBackbone", + "config": { + "input_shape": (1024, 1024, 3), + "input_tensor": None, + "include_rescaling": False, + "patch_size": 16, + "embed_dim": 1280, + "depth": 32, + "mlp_dim": 1280 * 4, + "num_heads": 16, + "out_chans": 256, + "use_bias": True, + "use_abs_pos": True, + "use_rel_pos": True, + "window_size": 14, + "global_attention_indices": [7, 15, 23, 31], + "layer_norm_epsilon": 1e-6, + }, + }, +} + + +backbone_presets_with_weights = { + "vitdet_base_sa1b": { + "metadata": { + "description": ( + "A base Detectron2 ViT backbone trained on the SA1B dataset." + ), + "params": 89_670_912, + "official_name": "VitDet", + "path": "vit_det", + }, + "class_name": "keras_cv.models>ViTDetBackbone", + "config": backbone_presets_no_weights["vitdet_base"]["config"], + "weights_url": "https://storage.googleapis.com/keras-cv/models/vitdet/vitdet_base.h5", # noqa: E501 + "weights_hash": "63c0ca6ff422142f95c24a0223445906728b353469be10c8e34018392207c93a", # noqa: E501 + }, + "vitdet_large_sa1b": { + "metadata": { + "description": ( + "A large Detectron2 ViT backbone trained on the SA1B dataset." + ), + "params": 308_278_272, + "official_name": "VitDet", + "path": "vit_det", + }, + "class_name": "keras_cv.models>ViTDetBackbone", + "config": backbone_presets_no_weights["vitdet_large"]["config"], + "weights_url": "https://storage.googleapis.com/keras-cv/models/vitdet/vitdet_large.h5", # noqa: E501 + "weights_hash": "b85f73ee5a82842aecbc7c706ca69530aaa828d3324d0793a93730c94727b30e", # noqa: E501 + }, + "vitdet_huge_sa1b": { + "metadata": { + "description": ( + "A huge Detectron2 ViT backbone trained on the SA1B dataset." + ), + "params": 637_026_048, + "official_name": "VitDet", + "path": "vit_det", + }, + "class_name": "keras_cv.models>ViTDetBackbone", + "config": backbone_presets_no_weights["vitdet_huge"]["config"], + "weights_url": "https://storage.googleapis.com/keras-cv/models/vitdet/vitdet_huge.h5", # noqa: E501 + "weights_hash": "ae6e1a95acd748f783bddeadd5915fdc6d1c15d23909df3cd4fa446c7d6b9fc1", # noqa: E501 + }, +} + + +backbone_presets = { + **backbone_presets_no_weights, + **backbone_presets_with_weights, +} diff --git a/keras_cv/models/backbones/vit_det/vit_det_backbone_presets_test.py b/keras_cv/models/backbones/vit_det/vit_det_backbone_presets_test.py new file mode 100644 index 0000000000..7575ecd1d8 --- /dev/null +++ b/keras_cv/models/backbones/vit_det/vit_det_backbone_presets_test.py @@ -0,0 +1,96 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for loading pretrained model presets."""
+
+import pathlib
+
+import numpy as np
+import pytest
+
+from keras_cv.backend import ops
+from keras_cv.models.backbones.vit_det.vit_det_aliases import ViTDetBBackbone
+from keras_cv.models.backbones.vit_det.vit_det_backbone import ViTDetBackbone
+from keras_cv.tests.test_case import TestCase
+
+
+@pytest.mark.large
+class ViTDetPresetSmokeTest(TestCase):
+    """
+    A smoke test for ViTDet presets we run continuously.
+    This only tests the smallest weights we have available. Run with:
+    `pytest keras_cv/models/backbones/vit_det/vit_det_backbone_presets_test.py --run_large`  # noqa: E501
+    """
+
+    def setUp(self):
+        self.input_batch = np.ones(shape=(1, 1024, 1024, 3))
+
+    def test_backbone_output(self):
+        model = ViTDetBackbone.from_preset("vitdet_base_sa1b")
+        outputs = model(self.input_batch)
+
+        # The forward pass from a preset should be stable!
+        # This test should catch cases where we unintentionally change our
+        # network code in a way that would invalidate our preset weights.
+        # We should only update these numbers if we are updating a weights
+        # file, or have found a discrepancy with the upstream source.
+
+        expected = np.load(
+            pathlib.Path(__file__).parent / "data" / "vitdet_base_out.npz"
+        )
+        # Keep a high tolerance, so we are robust to different hardware.
+        self.assertAllClose(
+            ops.convert_to_numpy(outputs),
+            expected,
+            atol=1e-5,
+            rtol=1e-5,
+        )
+
+    def test_applications_model_output(self):
+        model = ViTDetBBackbone()
+        model(self.input_batch)
+
+    def test_applications_model_output_with_preset(self):
+        model = ViTDetBackbone.from_preset("vitdet_base")
+        model(self.input_batch)
+
+    def test_applications_model_predict(self):
+        model = ViTDetBBackbone()
+        # Test that the model XLA compiles
+        model.predict(self.input_batch)
+
+    def test_preset_docstring(self):
+        """Check we did our docstring formatting correctly."""
+        for name in ViTDetBackbone.presets:
+            self.assertRegex(ViTDetBackbone.from_preset.__doc__, name)
+
+    def test_unknown_preset_error(self):
+        # Not a preset name
+        with self.assertRaises(ValueError):
+            ViTDetBackbone.from_preset("vitdet_nonexistant")
+
+
+@pytest.mark.extra_large
+class ViTDetPresetFullTest(TestCase):
+    """
+    Test the full enumeration of our presets.
+    This tests every preset for ViTDet and is only run manually.
+    Run with:
+    `pytest keras_cv/models/backbones/vit_det/vit_det_backbone_presets_test.py --run_extra_large`  # noqa: E501
+    """
+
+    def test_load_ViTDet(self):
+        input_data = np.ones(shape=(1, 1024, 1024, 3))
+        for preset in ViTDetBackbone.presets:
+            model = ViTDetBackbone.from_preset(preset)
+            model(input_data)
diff --git a/keras_cv/models/backbones/vit_det/vit_det_backbone_test.py b/keras_cv/models/backbones/vit_det/vit_det_backbone_test.py
new file mode 100644
index 0000000000..fb12438119
--- /dev/null
+++ b/keras_cv/models/backbones/vit_det/vit_det_backbone_test.py
@@ -0,0 +1,61 @@
+# Copyright 2023 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import numpy as np
+import pytest
+
+from keras_cv.backend import keras
+from keras_cv.backend import ops
+from keras_cv.models.backbones.vit_det.vit_det_aliases import ViTDetBBackbone
+from keras_cv.tests.test_case import TestCase
+
+
+class TestViTDetBackbone(TestCase):
+    @pytest.mark.large
+    def test_call(self):
+        model = ViTDetBBackbone()
+        x = np.ones((1, 1024, 1024, 3))
+        x_out = ops.convert_to_numpy(model(x))
+        num_parameters = sum(
+            np.prod(tuple(x.shape)) for x in model.trainable_variables
+        )
+        self.assertEqual(x_out.shape, (1, 64, 64, 256))
+        self.assertEqual(num_parameters, 89_670_912)
+
+    @pytest.mark.extra_large
+    def test_save(self):
+        # saving test
+        model = ViTDetBBackbone()
+        x = np.ones((1, 1024, 1024, 3))
+        x_out = ops.convert_to_numpy(model(x))
+        path = os.path.join(self.get_temp_dir(), "model.keras")
+        model.save(path)
+        loaded_model = keras.saving.load_model(path)
+        x_out_loaded = ops.convert_to_numpy(loaded_model(x))
+        self.assertAllClose(x_out, x_out_loaded)
+
+    @pytest.mark.extra_large
+    def test_fit(self):
+        model = ViTDetBBackbone()
+        x = np.ones((1, 1024, 1024, 3))
+        y = np.zeros((1, 64, 64, 256))
+        model.compile(optimizer="adam", loss="mse", metrics=["mse"])
+        model.fit(x, y, epochs=1)
+
+    def test_pyramid_level_inputs_error(self):
+        model = ViTDetBBackbone()
+        with self.assertRaises(NotImplementedError, msg="doesn't compute"):
+            model.pyramid_level_inputs
diff --git a/keras_cv/models/segmentation/__init__.py b/keras_cv/models/segmentation/__init__.py
index f25ee4ea7c..aa4ffab4a4 100644
--- a/keras_cv/models/segmentation/__init__.py
+++ b/keras_cv/models/segmentation/__init__.py
@@ -14,3 +14,7 @@
 from keras_cv.models.segmentation.deeplab_v3_plus import DeepLabV3Plus
 from keras_cv.models.segmentation.segformer import SegFormer
+from keras_cv.models.segmentation.segment_anything import SAMMaskDecoder
+from keras_cv.models.segmentation.segment_anything import SAMPromptEncoder
+from keras_cv.models.segmentation.segment_anything import SegmentAnythingModel
+from keras_cv.models.segmentation.segment_anything import TwoWayTransformer
diff --git a/keras_cv/models/segmentation/segment_anything/__init__.py b/keras_cv/models/segmentation/segment_anything/__init__.py
new file mode 100644
index 0000000000..945955261f
--- /dev/null
+++ b/keras_cv/models/segmentation/segment_anything/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2023 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from keras_cv.models.segmentation.segment_anything.sam import ( + SegmentAnythingModel, +) +from keras_cv.models.segmentation.segment_anything.sam_mask_decoder import ( + SAMMaskDecoder, +) +from keras_cv.models.segmentation.segment_anything.sam_prompt_encoder import ( + SAMPromptEncoder, +) +from keras_cv.models.segmentation.segment_anything.sam_transformer import ( + TwoWayTransformer, +) diff --git a/keras_cv/models/segmentation/segment_anything/data/sam_base_out_iou_pred.npy b/keras_cv/models/segmentation/segment_anything/data/sam_base_out_iou_pred.npy new file mode 100644 index 0000000000..974516d6f5 Binary files /dev/null and b/keras_cv/models/segmentation/segment_anything/data/sam_base_out_iou_pred.npy differ diff --git a/keras_cv/models/segmentation/segment_anything/data/sam_base_out_masks.npy b/keras_cv/models/segmentation/segment_anything/data/sam_base_out_masks.npy new file mode 100644 index 0000000000..a7514e780f Binary files /dev/null and b/keras_cv/models/segmentation/segment_anything/data/sam_base_out_masks.npy differ diff --git a/keras_cv/models/segmentation/segment_anything/sam.py b/keras_cv/models/segmentation/segment_anything/sam.py new file mode 100644 index 0000000000..c3aec549e1 --- /dev/null +++ b/keras_cv/models/segmentation/segment_anything/sam.py @@ -0,0 +1,256 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +from keras_cv.api_export import keras_cv_export +from keras_cv.backend import keras +from keras_cv.models.backbones.backbone_presets import backbone_presets +from keras_cv.models.backbones.backbone_presets import ( + backbone_presets_with_weights, +) +from keras_cv.models.segmentation.segment_anything.sam_presets import ( + sam_presets, +) +from keras_cv.models.task import Task +from keras_cv.utils.python_utils import classproperty + + +@keras_cv_export( + "keras_cv.models.SegmentAnythingModel", package="keras_cv.models" +) +class SegmentAnythingModel(Task): + """ + The Segment Anything (SAM) Model. + + Args: + backbone (keras_cv.models.Backbone): A feature extractor for the input + images. + prompt_encoder (keras_cv.models.SAMPromptEncoder): A Keras layer to + compute embeddings for points, box, and mask prompt. + mask_decoder (keras_cv.models.SAMMaskDecoder): A Keras layer to + generate segmentation masks given the embeddings generated by the + backbone and the prompt encoder. + + References: + - [Segment Anything paper](https://arxiv.org/abs/2304.02643) + - [Segment Anything GitHub](https://github.com/facebookresearch/segment-anything) + + Examples: + + >>> import numpy as np + >>> from keras_cv.models import ViTDetBBackbone + >>> from keras_cv.models import SAMPromptEncoder + >>> from keras_cv.models import SAMMaskDecoder + + Create all the components of the SAM model: + + >>> backbone = ViTDetBBackbone() + >>> prompt_encoder = SAMPromptEncoder() + >>> mask_decoder = SAMMaskDecoder() + + Instantiate the model: + + >>> sam = SegmentAnythingModel( + ... backbone=backbone, + ... 
prompt_encoder=prompt_encoder,
+    ...     mask_decoder=mask_decoder
+    ... )
+
+    Define the input of the backbone. This must be a batch of images of shape
+    `(1024, 1024, 3)` for the ViT backbone we are using:
+
+    >>> image = np.ones((1, 1024, 1024, 3))
+
+    SAM works by prompting the input images. There are three ways to prompt:
+
+    (1) Labelled Points: Foreground points (points with label 1) are encoded
+        such that the output masks generated by the mask decoder contain them
+        and background points (points with label 0) are encoded such that the
+        generated masks don't contain them.
+    (2) Box: A box tells the model which part/crop of the image to segment.
+    (3) Mask: An input mask can be used to refine the output of the mask
+        decoder.
+
+    These prompts can be mixed and matched but at least one of the prompts
+    must be present. To turn off a particular prompt, a zero size array of the
+    following shapes must be passed:
+
+    (1) For point prompts, the expected shape is `(batch, num_points, 2)`. If
+        no point prompt is desired, pass an input of shape `(batch, 0, 2)`.
+        The labels must have shape `(batch, 0)` in case of no point prompt.
+    (2) For the box prompt, the expected shape is `(batch, 1, 2, 2)`. The
+        second dimension (`box.shape[1]`) represents whether a box prompt is
+        present or not. If no box prompt is present, an input of shape
+        `(batch, 0, 2, 2)` is expected.
+    (3) Similarly, mask prompts have shape `(batch, 1, H, W, 1)`. Here too,
+        the first dimension (`mask.shape[1]`) indicates the presence of a mask
+        prompt. To turn off mask prompts, an input of shape
+        `(batch, 0, H, W, 1)` must be passed.
+
+    For example, to pass in all the prompts, do:
+
+    >>> points = np.array([[[512., 512.], [100., 100.]]])
+    >>> # For labels: 1 means foreground point, 0 means background
+    >>> labels = np.array([[1., 0.]])
+    >>> box = np.array([[[[384., 384.], [640., 640.]]]])
+    >>> input_mask = np.ones((1, 1, 256, 256, 1))
+
+    Prepare an input dictionary:
+
+    >>> inputs = {
+    ...     "images": image,
+    ...     "points": points,
+    ...     "labels": labels,
+    ...     "box": box,
+    ...     "mask": input_mask
+    ... }
+    ...
+    >>> outputs = sam.predict(inputs)
+    >>> masks, iou_pred = outputs["masks"], outputs["iou_pred"]
+
+    The first mask in the output `masks` (i.e. `masks[:, 0, ...]`) is the best
+    mask predicted by the model based on the prompts. Other `masks`
+    (i.e. `masks[:, 1:, ...]`) are alternate predictions that can be used if
+    they are desired over the first one.
+
+    Now, in case of only points and box prompts, do:
+
+    >>> no_input_mask = np.empty((1, 0, 256, 256, 1))
+    >>> inputs = {
+    ...     "images": image,
+    ...     "points": points,
+    ...     "labels": labels,
+    ...     "box": box,
+    ...     "mask": no_input_mask
+    ... }
+    ...
+    >>> outputs = sam.predict(inputs)
+    >>> masks, iou_pred = outputs["masks"], outputs["iou_pred"]
+
+    Another example is when only point prompts are present.
+    Note that if point prompts are present (i.e. `points.shape[1] != 0`),
+    but no box prompt is present (i.e. `box.shape[1] == 0`), the points must
+    be padded with a zero point and a -1 label:
+
+    >>> no_box = np.empty((1, 0, 2, 2))
+    >>> padded_points = np.concatenate(
+    ...     [points, np.zeros((1, 1, 2))], axis=1
+    ... )
+    ...
+    >>> padded_labels = np.concatenate(
+    ...     [labels, -np.ones((1, 1))], axis=1
+    ... )
+    >>> inputs = {
+    ...     "images": image,
+    ...     "points": padded_points,
+    ...     "labels": padded_labels,
+    ...     "box": no_box,
+    ...     "mask": no_input_mask
+    ... }
+    ...
+ >>> outputs = sam.predict(inputs) + >>> masks, iou_pred = outputs["masks"], outputs["iou_pred"] + + Note that the segment anything model only supports inference and training + isn't support yet. So, calling the `fit` method will fail for now. + """ # noqa: E501 + + def __init__(self, *, backbone, prompt_encoder, mask_decoder, **kwargs): + # Get the image encoder input -- Images + backbone_input = backbone.input + + # Define the prompt encoder inputs -- Prompts + prompt_inputs = { + "points": keras.Input(shape=[None, 2], name="points"), + "labels": keras.Input(shape=[None], name="labels"), + "box": keras.Input(shape=[None, 2, 2], name="box"), + "mask": keras.Input(shape=[None, None, None, 1], name="mask"), + } + + # All Inputs -- Images + Prompts + all_inputs = {"images": backbone_input} + all_inputs.update(prompt_inputs) + + # Build the prompt encoder + prompt_embeddings = prompt_encoder(prompt_inputs) + + # Define the mask decoder inputs + mask_decoder_inputs = { + "image_embeddings": backbone.output, + "image_pe": prompt_embeddings["dense_positional_embeddings"], + "sparse_prompt_embeddings": prompt_embeddings["sparse_embeddings"], + "dense_prompt_embeddings": prompt_embeddings["dense_embeddings"], + } + + # Build the mask decoder + outputs = mask_decoder(mask_decoder_inputs) + + super().__init__(inputs=all_inputs, outputs=outputs, **kwargs) + + self.backbone = backbone + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + + def fit(self, *args, **kwargs): + raise NotImplementedError( + "Segment Anything Model only supports inference for now. Training" + " the model isn't supported yet." + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "backbone": keras.saving.serialize_keras_object(self.backbone), + "prompt_encoder": keras.saving.serialize_keras_object( + self.prompt_encoder + ), + "mask_decoder": keras.saving.serialize_keras_object( + self.mask_decoder + ), + } + ) + return config + + @classmethod + def from_config(cls, config): + config.update( + { + "prompt_encoder": keras.layers.deserialize( + config["prompt_encoder"] + ), + "mask_decoder": keras.layers.deserialize( + config["mask_decoder"] + ), + } + ) + return super().from_config(config) + + @classproperty + def presets(cls): + """Dictionary of preset names and configurations.""" + return copy.deepcopy({**backbone_presets, **sam_presets}) + + @classproperty + def presets_with_weights(cls): + """Dictionary of preset names and configurations that include + weights.""" + return copy.deepcopy({**backbone_presets_with_weights, **sam_presets}) + + @classproperty + def backbone_presets(cls): + """Dictionary of preset names and configurations of compatible + backbones.""" + return copy.deepcopy(backbone_presets) diff --git a/keras_cv/models/segmentation/segment_anything/sam_layers.py b/keras_cv/models/segmentation/segment_anything/sam_layers.py new file mode 100644 index 0000000000..127db266c4 --- /dev/null +++ b/keras_cv/models/segmentation/segment_anything/sam_layers.py @@ -0,0 +1,348 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +from keras_cv.api_export import keras_cv_export +from keras_cv.backend import keras +from keras_cv.backend import ops +from keras_cv.layers.vit_det_layers import MLP + + +@keras_cv_export( + "keras_cv.layers.MultiHeadAttentionWithDownsampling", + package="keras_cv.layers", +) +class MultiHeadAttentionWithDownsampling(keras.layers.Layer): + """Multi-Head Attention with downsampling. + + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + + This layer first downscales the features of input queries, keys, and + values using a dense layer. Multi-head attention is then performed + and the attention map is projected back (upscaled) to the number of + input features. + + Args: + num_heads (int): Number of attention heads. + key_dim (int): Size of each attention head for query, key, and + value. + downsample_rate (int, optional): The factor by which to downscale the + input features i.e. the input features of size `key_dim` are + projected down to `key_dim // downsample_rate`. + + References: + - [Segment Anything paper](https://arxiv.org/abs/2304.02643) + - [Segment Anything GitHub](https://github.com/facebookresearch/segment-anything) + """ # noqa: E501 + + def __init__(self, num_heads, key_dim, downsample_rate=1, **kwargs): + super().__init__(**kwargs) + self.num_heads = num_heads + self.key_dim = key_dim + self.downsample_rate = downsample_rate + self.internal_dims = key_dim // downsample_rate + + # Downsample + self.query_proj = keras.layers.Dense( + self.internal_dims * self.num_heads + ) + self.key_proj = keras.layers.Dense(self.internal_dims * self.num_heads) + self.value_proj = keras.layers.Dense( + self.internal_dims * self.num_heads + ) + + # Upsample + self.out_proj = keras.layers.Dense(self.key_dim * self.num_heads) + + def build(self, input_shape=None): + self.query_proj.build([None, None, self.num_heads * self.key_dim]) + self.key_proj.build([None, None, self.num_heads * self.key_dim]) + self.value_proj.build([None, None, self.num_heads * self.key_dim]) + self.out_proj.build([None, None, self.internal_dims * self.num_heads]) + self.built = True + + def __separate_heads(self, x): + shape = ops.shape(x) + B, N, C = shape[0], shape[1], shape[2] + x = ops.reshape(x, (B, N, self.num_heads, C // self.num_heads)) + return ops.transpose(x, axes=(0, 2, 1, 3)) + + def __recombine_heads(self, x): + shape = ops.shape(x) + B, N_H, N_T, C_PH = shape[0], shape[1], shape[2], shape[3] + x = ops.transpose(x, axes=(0, 2, 1, 3)) + return ops.reshape(x, (B, N_T, N_H * C_PH)) + + def call(self, query, value, key): + query = self.query_proj(query) + key = self.key_proj(key) + value = self.value_proj(value) + + # Separate into heads + query = self.__separate_heads(query) + key = self.__separate_heads(key) + value = self.__separate_heads(value) + + # Attention + C_PH = ops.shape(query)[-1] + out = query @ ops.transpose(key, (0, 1, 3, 2)) + out = out / ops.sqrt(ops.cast(C_PH, dtype=self.dtype)) + out = ops.softmax(out, axis=-1) + + # Get output + attention_map = out @ value + attention_map = self.__recombine_heads(attention_map) + return self.out_proj(attention_map) + + def get_config(self): + config = super().get_config() + config.update( + { + "num_heads": self.num_heads, + "key_dim": self.key_dim, + "downsample_rate": self.downsample_rate, + } + ) + return config + + +@keras_cv_export( + 
"keras_cv.layers.TwoWayMultiHeadAttention", package="keras_cv.layers" +) +class TwoWayMultiHeadAttention(keras.layers.Layer): + """Two-way multi-head attention layer. + + Args: + num_heads (int): Number of attention heads. + key_dim (int): Size of each attention head for query, key, and + value. + mlp_dim (int): Number of hidden dims to use in the mlp block. + skip_first_layer_pe (bool): A boolean indicating whether to skip the + first layer positional embeddings. + attention_downsample_rate (int, optional): The downsample rate to use + in the attention layers. Defaults to 2. + activation (str, optional): The activation for the mlp block's output + layer. Defaults to "relu". + + References: + - [Segment Anything paper](https://arxiv.org/abs/2304.02643) + - [Segment Anything GitHub](https://github.com/facebookresearch/segment-anything) + """ # noqa: E501 + + def __init__( + self, + num_heads, + key_dim, + mlp_dim, + skip_first_layer_pe, + attention_downsample_rate=2, + activation="relu", + **kwargs, + ): + super().__init__(**kwargs) + self.num_heads = num_heads + self.key_dim = key_dim + self.mlp_dim = mlp_dim + self.skip_first_layer_pe = skip_first_layer_pe + self.attention_downsample_rate = attention_downsample_rate + self.activation = activation + + self.self_attention = MultiHeadAttentionWithDownsampling( + num_heads=num_heads, key_dim=key_dim + ) + self.layer_norm1 = keras.layers.LayerNormalization(epsilon=1e-5) + self.cross_attention_token_to_image = ( + MultiHeadAttentionWithDownsampling( + num_heads=num_heads, + key_dim=key_dim, + downsample_rate=attention_downsample_rate, + ) + ) + self.layer_norm2 = keras.layers.LayerNormalization(epsilon=1e-5) + + self.mlp_block = MLP( + mlp_dim, + key_dim * num_heads, + num_layers=2, + activation=activation, + ) + + self.layer_norm3 = keras.layers.LayerNormalization(epsilon=1e-5) + self.cross_attention_image_to_token = ( + MultiHeadAttentionWithDownsampling( + num_heads=num_heads, + key_dim=key_dim, + downsample_rate=attention_downsample_rate, + ) + ) + self.layer_norm4 = keras.layers.LayerNormalization(epsilon=1e-5) + + def build(self, input_shape=None): + self.self_attention.build() + self.layer_norm1.build([None, None, self.num_heads * self.key_dim]) + self.cross_attention_token_to_image.build() + self.layer_norm2.build([None, None, self.num_heads * self.key_dim]) + self.mlp_block.build([None, None, self.num_heads * self.key_dim]) + self.layer_norm3.build([None, None, self.num_heads * self.key_dim]) + self.cross_attention_image_to_token.build() + self.layer_norm4.build([None, None, self.num_heads * self.key_dim]) + self.built = True + + def call(self, queries, keys, query_pe, key_pe): + if self.skip_first_layer_pe: + queries = self.self_attention( + query=queries, value=queries, key=queries + ) + else: + queries_with_pe = queries + query_pe + attention_map = self.self_attention( + query=queries_with_pe, key=queries_with_pe, value=queries + ) + queries = queries + attention_map + queries = self.layer_norm1(queries) + + queries_with_pe = queries + query_pe + keys_with_pe = keys + key_pe + attention_map = self.cross_attention_token_to_image( + query=queries_with_pe, key=keys_with_pe, value=keys + ) + queries = queries + attention_map + queries = self.layer_norm2(queries) + + mlp_out = self.mlp_block(queries) + queries = queries + mlp_out + queries = self.layer_norm3(queries) + + queries_with_pe = queries + query_pe + keys_with_pe = keys + key_pe + attention_map = self.cross_attention_image_to_token( + query=keys_with_pe, key=queries_with_pe, 
value=queries + ) + keys = keys + attention_map + keys = self.layer_norm4(keys) + + return queries, keys + + def get_config(self): + config = super().get_config() + config.update( + { + "num_heads": self.num_heads, + "key_dim": self.key_dim, + "mlp_dim": self.mlp_dim, + "skip_first_layer_pe": self.skip_first_layer_pe, + "attention_downsample_rate": self.attention_downsample_rate, + "activation": self.activation, + } + ) + return config + + +@keras_cv_export( + "keras_cv.layers.RandomFrequencyPositionalEmbeddings", + package="keras_cv.layers", +) +class RandomFrequencyPositionalEmbeddings(keras.layers.Layer): + """Positional encoding using random spatial frequencies. + + This layer maps coordinates/points in 2D space to positional + encodings using random spatial frequencies. + + Args: + num_positional_features (int): Number of positional features + in the output. + scale (float): The standard deviation of the random frequencies. + + References: + - [Segment Anything paper](https://arxiv.org/abs/2304.02643) + - [Segment Anything GitHub](https://github.com/facebookresearch/segment-anything) + """ # noqa: E501 + + def __init__(self, num_positional_features, scale, **kwargs): + super().__init__(**kwargs) + self.num_positional_features = num_positional_features + self.scale = scale + init_func = lambda *a, **kw: self.scale * ops.random.normal( + shape=(2, self.num_positional_features), dtype=self.dtype + ) + self.positional_encoding_gaussian_matrix = self.add_weight( + name="positional_encoding_gaussian_matrix", + shape=(2, self.num_positional_features), + dtype=self.dtype, + trainable=False, + initializer=init_func, + ) + + def build(self, input_shape=None): + self.built = True + + def __positional_encodings(self, coords): + coords = coords * 2 - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = coords * (2 * math.pi) + return ops.concatenate([ops.sin(coords), ops.cos(coords)], axis=-1) + + def call(self, size): + return self.encode_image(size) + + def encode_image(self, size): + """Generate a positional encoding for an image of any given size. + + Args: + size (tuple[int, int]): The size of the image. + + Returns: + tensor: Positional encoding of the image. + """ + H, W = size + grid = ops.ones(shape=(H, W), dtype=self.dtype) + y_embed = ops.cumsum(grid, axis=0) - 0.5 + x_embed = ops.cumsum(grid, axis=1) - 0.5 + y_embed = y_embed / ops.cast(H, self.dtype) + x_embed = x_embed / ops.cast(W, self.dtype) + return self.__positional_encodings( + ops.stack([x_embed, y_embed], axis=-1) + ) + + def encode_coordinates(self, coords_input, image_size): + """Positionally encode points that are not normalized to `[0, 1]`. + + Args: + coords_input (tensor): 2D coordinates/points to map. + image_size (tuple[int, int]): Height and width of the image + being prompted. + + Returns: + tensor: Positional encodings of the normalized coordinates. 
+ """ + coords_normalized = ops.stack( + [ + coords_input[..., 0] / image_size[1], + coords_input[..., 1] / image_size[0], + ], + axis=-1, + ) + return self.__positional_encodings(coords_normalized) + + def get_config(self): + config = super().get_config() + config.update( + { + "num_positional_features": self.num_positional_features, + "scale": self.scale, + } + ) + return config diff --git a/keras_cv/models/segmentation/segment_anything/sam_mask_decoder.py b/keras_cv/models/segmentation/segment_anything/sam_mask_decoder.py new file mode 100644 index 0000000000..141f90addc --- /dev/null +++ b/keras_cv/models/segmentation/segment_anything/sam_mask_decoder.py @@ -0,0 +1,240 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_cv.api_export import keras_cv_export +from keras_cv.backend import keras +from keras_cv.backend import ops +from keras_cv.layers.vit_det_layers import MLP +from keras_cv.models.segmentation.segment_anything.sam_transformer import ( + TwoWayTransformer, +) + + +@keras_cv_export("keras_cv.models.SAMMaskDecoder", package="keras_cv.models") +class SAMMaskDecoder(keras.layers.Layer): + """Mask decoder for the Segment Anything Model (SAM). + + This lightweight module efficiently maps the image embedding and a set of + prompt embeddings to an output mask. Before applying the transformer + decoder, the layer first inserts into the set of prompt embeddings a + learned output token embedding that will be used at the decoder's output. + For simplicity, these embeddings (not including the image embedding) are + collectively called "tokens". + + The image embeddings, positional image embeddings, and tokens are passed + through a transformer decoder. After running the decoder, the layer + upsamples the updated image embedding by 4x with two transposed + convolutional layers (now it's downscaled 4x relative to the input + image). Then, the tokens attend once more to the image embedding and + the updated output token embedding are passed to a small 3-layer MLP that + outputs a vector matching the channel dimension of the upscaled image + embedding. Finally, a mask is predicted with a spatially point-wise + product between the upscaled image embedding and the MLP's output. + + Args: + transformer_dim (int, optional): The number of input features to the + transformer decoder. Defaults to `256`. + transformer (keras.layers.Layer, optional): A transformer decoder. + Defaults to `None`. When `None`, a + `keras_cv.models.TwoWayTransformer` layer is used. + num_multimask_outputs (int, optional): Number of multimask outputs. + The model would generate these many extra masks. The total masks + generated by the model are `1 + num_multimask_outputs`. Defaults + to `3`. + iou_head_depth (int, optional): The depth of the dense net used to + predict the IoU confidence score. Defaults to `3`. + iou_head_hidden_dim (int, optional): The number of units in the hidden + layers used in the dense net to predict the IoU confidence score. + Defaults to `256`. 
+ activation (str, optional): Activation to use in the mask upscaler + network. Defaults to `"gelu"`. + + References: + - [Segment Anything paper](https://arxiv.org/abs/2304.02643) + - [Segment Anything GitHub](https://github.com/facebookresearch/segment-anything) + """ # noqa: E501 + + def __init__( + self, + *, + transformer_dim=256, + transformer=None, + num_multimask_outputs=3, + iou_head_depth=3, + iou_head_hidden_dim=256, + activation="gelu", + **kwargs, + ): + super().__init__(**kwargs) + self.transformer_dim = transformer_dim + if transformer is None: + transformer = TwoWayTransformer() + self.transformer = transformer + self.num_multimask_outputs = num_multimask_outputs + self.iou_head_depth = iou_head_depth + self.iou_head_hidden_dim = iou_head_hidden_dim + self.activation = activation + + self.iou_token = keras.layers.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = keras.layers.Embedding( + self.num_mask_tokens, transformer_dim + ) + + self.output_upscaling = keras.models.Sequential( + [ + keras.layers.Conv2DTranspose( + transformer_dim // 4, kernel_size=2, strides=2 + ), + keras.layers.LayerNormalization(epsilon=1e-6), + keras.layers.Activation(activation), + keras.layers.Conv2DTranspose( + transformer_dim // 8, kernel_size=2, strides=2 + ), + keras.layers.Activation(activation), + ] + ) + + self.output_hypernetworks_mlps = [ + MLP(transformer_dim, transformer_dim // 8, 3) + for _ in range(self.num_mask_tokens) + ] + + self.iou_prediction_head = MLP( + iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + def build(self, input_shape=None): + self.transformer.build() + self.iou_token.build([None]) + self.mask_tokens.build([None]) + self.output_upscaling.build([None, None, None, self.transformer_dim]) + for mlp in self.output_hypernetworks_mlps: + mlp.build([None, self.transformer_dim]) + self.iou_prediction_head.build([None, self.transformer_dim]) + self.built = True + + def call(self, inputs): + image_embeddings = inputs["image_embeddings"] + image_pe = inputs["image_pe"] + sparse_prompt_embeddings = inputs["sparse_prompt_embeddings"] + dense_prompt_embeddings = inputs["dense_prompt_embeddings"] + + masks, iou_pred = self._predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + return {"masks": masks, "iou_pred": iou_pred} + + def _predict_masks( + self, + image_embeddings, + image_pe, + sparse_prompt_embeddings, + dense_prompt_embeddings, + ): + indices_iou = ops.arange(1, dtype="int32") + indices_mask = ops.arange(self.num_mask_tokens, dtype="int32") + + output_tokens = ops.concatenate( + [self.iou_token(indices_iou), self.mask_tokens(indices_mask)], + axis=0, + ) + output_tokens = ops.broadcast_to( + output_tokens[None, ...], + shape=( + ops.shape(sparse_prompt_embeddings)[0], + ops.shape(output_tokens)[0], + ops.shape(output_tokens)[1], + ), + ) + tokens = ops.concatenate( + [output_tokens, sparse_prompt_embeddings], axis=1 + ) + + source = ops.broadcast_to( + image_embeddings, + shape=( + ops.shape(tokens)[0], + ops.shape(image_embeddings)[1], + ops.shape(image_embeddings)[2], + ops.shape(image_embeddings)[3], + ), + ) + source = source + dense_prompt_embeddings + positional_source = ops.broadcast_to( + image_pe, + shape=( + ops.shape(tokens)[0], + ops.shape(image_embeddings)[1], + ops.shape(image_embeddings)[2], + ops.shape(image_embeddings)[3], + ), + ) + shape = 
ops.shape(source) + B, H, W, C = shape[0], shape[1], shape[2], shape[3] + + hidden_state, source = self.transformer( + source, positional_source, tokens + ) + iou_token_out = hidden_state[:, 0, :] + mask_tokens_out = hidden_state[:, 1 : (1 + self.num_mask_tokens), :] + + source = ops.reshape(source, (B, H, W, C)) + upscaled_embeddings = self.output_upscaling(source) + hyper_in_list = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append( + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) + ) + hyper_in = ops.stack(hyper_in_list, axis=1) + shape = ops.shape(upscaled_embeddings) + B, H, W, C = shape[0], shape[1], shape[2], shape[3] + upscaled_embeddings = ops.reshape( + ops.transpose(upscaled_embeddings, axes=(0, 3, 1, 2)), + (B, C, H * W), + ) + masks = ops.reshape( + hyper_in @ upscaled_embeddings, (B, self.num_mask_tokens, H, W) + ) + + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + def get_config(self): + config = super().get_config() + config.update( + { + "transformer_dim": self.transformer_dim, + "transformer": keras.saving.serialize_keras_object( + self.transformer + ), + "num_multimask_outputs": self.num_multimask_outputs, + "iou_head_depth": self.iou_head_depth, + "iou_head_hidden_dim": self.iou_head_hidden_dim, + "activation": self.activation, + } + ) + return config + + @classmethod + def from_config(cls, config): + config.update( + {"transformer": keras.layers.deserialize(config["transformer"])} + ) + return super().from_config(config) diff --git a/keras_cv/models/segmentation/segment_anything/sam_presets.py b/keras_cv/models/segmentation/segment_anything/sam_presets.py new file mode 100644 index 0000000000..6460e5f517 --- /dev/null +++ b/keras_cv/models/segmentation/segment_anything/sam_presets.py @@ -0,0 +1,103 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""SAM model preset configurations.""" + +from keras_cv.models.backbones.vit_det import vit_det_backbone_presets + +prompt_encoder_preset = { + "class_name": "keras_cv.models>SAMPromptEncoder", + "config": { + "embed_dim": 256, + "image_embedding_size": (64, 64), + "input_image_size": (1024, 1024), + "mask_in_chans": 16, + "activation": "gelu", + }, +} + +mask_decoder_preset = { + "class_name": "keras_cv.models>SAMMaskDecoder", + "config": { + "transformer_dim": 256, + "transformer": { + "class_name": "keras_cv.models>TwoWayTransformer", + "config": { + "depth": 2, + "embed_dim": 256, + "num_heads": 8, + "mlp_dim": 2048, + "activation": "relu", + "attention_downsample_rate": 2, + }, + }, + "num_multimask_outputs": 3, + "iou_head_depth": 3, + "iou_head_hidden_dim": 256, + "activation": "gelu", + }, +} + +sam_presets = { + "sam_base_sa1b": { + "metadata": { + "description": "The base SAM model trained on the SA1B dataset.", + "params": 93_735_728, + "official_name": "SAM", + "path": "segment_anything", + }, + "config": { + "backbone": vit_det_backbone_presets.backbone_presets[ + "vitdet_base" + ], + "prompt_encoder": prompt_encoder_preset, + "mask_decoder": mask_decoder_preset, + }, + "weights_url": "https://storage.googleapis.com/keras-cv/models/segment_anything/sam_base.h5", # noqa: E501 + "weights_hash": "5a18868ed227b6f093d4a6cb7ed689868dd11f288a8311ae69002a9a9d86d192", # noqa: E501 + }, + "sam_large_sa1b": { + "metadata": { + "description": "The large SAM model trained on the SA1B dataset.", + "params": 312_343_088, + "official_name": "SAM", + "path": "segment_anything", + }, + "config": { + "backbone": vit_det_backbone_presets.backbone_presets[ + "vitdet_large" + ], + "prompt_encoder": prompt_encoder_preset, + "mask_decoder": mask_decoder_preset, + }, + "weights_url": "https://storage.googleapis.com/keras-cv/models/segment_anything/sam_large.h5", # noqa: E501 + "weights_hash": "4ef43d3a8e24200c14a086a043dec8e956fef500c6171268a35029ea720305f0", # noqa: E501 + }, + "sam_huge_sa1b": { + "metadata": { + "description": "The huge SAM model trained on the SA1B dataset.", + "params": 641_090_864, + "official_name": "SAM", + "path": "segment_anything", + }, + "config": { + "backbone": vit_det_backbone_presets.backbone_presets[ + "vitdet_huge" + ], + "prompt_encoder": prompt_encoder_preset, + "mask_decoder": mask_decoder_preset, + }, + "weights_url": "https://storage.googleapis.com/keras-cv/models/segment_anything/sam_huge.h5", # noqa: E501 + "weights_hash": "3284c7c3c91274e8cb1ec2de69da3b6d6cee4483f7d8b0e17e1042b9dfc86fe5", # noqa: E501 + }, +} diff --git a/keras_cv/models/segmentation/segment_anything/sam_prompt_encoder.py b/keras_cv/models/segmentation/segment_anything/sam_prompt_encoder.py new file mode 100644 index 0000000000..6a729231ca --- /dev/null +++ b/keras_cv/models/segmentation/segment_anything/sam_prompt_encoder.py @@ -0,0 +1,303 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from keras_cv.api_export import keras_cv_export +from keras_cv.backend import keras +from keras_cv.backend import ops +from keras_cv.models.segmentation.segment_anything.sam_layers import ( + RandomFrequencyPositionalEmbeddings, +) + + +@keras_cv_export("keras_cv.models.SAMPromptEncoder", package="keras_cv.models") +class SAMPromptEncoder(keras.layers.Layer): + """Prompt Encoder for the Segment Anything Model (SAM). + + The prompt encoder generates encodings for three types of prompts: + + - Point prompts: Points on the image along with a label indicating whether + the point is in the foreground (part of the mask) or in the background + (not a part of the mask). + - Box prompts: A batch of bounding boxes with format [(x1, y1), (x2, y2)] + used to determine the location of the masks in the image. + - Masks: An input mask can be passed to refine the positional embeddings + for the output mask. + + First, the point prompts and box prompts are concatenated and positional + encodings are generated using random spatial frequencies. A point is + represented as the sum of a positional encoding of the point's location + and one of two learned embeddings that indicate if the point is either in + the foreground or background. A box is represented by an embedding pair: + + (1) the positional encoding of its top-left corner summed with a learned + embedding representing "top-left corner" and + (2) the same structure but using a learned embedding indicating + "bottom-right corner". + + The box and point encodings are referred to as "sparse encodings" + + If a mask prompt is passed, a convolutional neural net is used to + downscale it to generate "dense encodings". If no mask prompt is passed, + an embedding layer is used instead to generate a "no mask" embedding. + + Args: + embed_dim (int, optional): The number of features in the output + embeddings. Defaults to `256`. + image_embedding_size (int, optional): The number of features in the + image embeddings generated by an image encoder. Defaults to + `(64, 64)`. + input_image_size (tuple[int], optional): A tuple of the height and + width of the image being prompted. Defaults to `(1024, 1024)`. + mask_in_chans (int, optional): The number of channels of the mask + prompt. Defaults to `16`. + activation (str, optional): The activation to use in the mask + downscaler neural net. Defaults to `"gelu"`. 
+ + References: + - [Segment Anything paper](https://arxiv.org/abs/2304.02643) + - [Segment Anything GitHub](https://github.com/facebookresearch/segment-anything) + """ # noqa: E501 + + def __init__( + self, + *, + embed_dim=256, + image_embedding_size=(64, 64), + input_image_size=(1024, 1024), + mask_in_chans=16, + activation="gelu", + **kwargs + ): + super().__init__(**kwargs) + self.embed_dim = embed_dim + self.image_embedding_size = image_embedding_size + self.input_image_size = input_image_size + self.mask_in_chans = mask_in_chans + self.activation = activation + + self.positional_embedding_layer = RandomFrequencyPositionalEmbeddings( + num_positional_features=self.embed_dim // 2, scale=1 + ) + + self.foreground_point_embed = keras.layers.Embedding( + 1, embed_dim, name="foreground_point_embed" + ) + self.background_point_embed = keras.layers.Embedding( + 1, embed_dim, name="background_point_embed" + ) + self.top_left_corner_embed = keras.layers.Embedding( + 1, embed_dim, name="top_left_corner_embed" + ) + self.bottom_right_corner_embed = keras.layers.Embedding( + 1, embed_dim, name="bottom_right_corner_embed" + ) + self.not_a_point_embed = keras.layers.Embedding( + 1, embed_dim, name="not_a_point_embed" + ) + + self.mask_downscaler = keras.models.Sequential( + [ + keras.layers.Conv2D( + mask_in_chans // 4, kernel_size=2, strides=2 + ), + keras.layers.LayerNormalization(epsilon=1e-6), + keras.layers.Activation(activation), + keras.layers.Conv2D(mask_in_chans, kernel_size=2, strides=2), + keras.layers.LayerNormalization(epsilon=1e-6), + keras.layers.Activation(activation), + keras.layers.Conv2D(embed_dim, kernel_size=1), + ], + name="mask_downscaler", + ) + self.no_mask_embed = keras.layers.Embedding( + 1, embed_dim, name="no_mask_embed" + ) + + def build(self, input_shape=None): + self.positional_embedding_layer.build() + for layer in [ + self.foreground_point_embed, + self.background_point_embed, + self.top_left_corner_embed, + self.bottom_right_corner_embed, + self.not_a_point_embed, + self.no_mask_embed, + ]: + layer.build([None]) + self.mask_downscaler.build( + [ + None, + 4 * self.image_embedding_size[0], + 4 * self.image_embedding_size[1], + 1, + ] + ) + self.built = True + + def compute_output_shape(self, input_shape): + return { + "sparse_embeddings": [None, None, self.embed_dim], + "dense_embeddings": [ + None, + self.image_embedding_size[0], + self.image_embedding_size[1], + self.embed_dim, + ], + "dense_positional_embeddings": [ + None, + self.image_embedding_size[0], + self.image_embedding_size[1], + self.embed_dim, + ], + } + + def __embed_points(self, points, labels): + points = points + 0.5 + indices = ops.arange(1, dtype="int32") + + point_embeddings = self.positional_embedding_layer.encode_coordinates( + points, self.input_image_size + ) + labels = ops.broadcast_to( + labels[..., None], ops.shape(point_embeddings) + ) + point_embeddings = ops.where( + labels == 0, + point_embeddings + self.background_point_embed(indices), + point_embeddings + self.foreground_point_embed(indices), + ) + point_embeddings = ops.where( + labels == -1, + self.not_a_point_embed(indices), + point_embeddings, + ) + return point_embeddings + + def __embed_box(self, box): + shape = ops.shape(box) + B, N = shape[0], shape[1] + box = box + 0.5 + indices = ops.arange(1, dtype="int32") + corner_embedding = self.positional_embedding_layer.encode_coordinates( + box, self.input_image_size + ) + top_left_embedding = corner_embedding[ + :, :, 0, : + ] + self.top_left_corner_embed(indices) + 
bottom_right_embedding = corner_embedding[ + :, :, 1, : + ] + self.bottom_right_corner_embed(indices) + corner_embedding = ops.stack( + [top_left_embedding, bottom_right_embedding], axis=2 + ) + return ops.reshape(corner_embedding, (B, N * 2, self.embed_dim)) + + def __embed_mask(self, mask): + mask_embedding = self.mask_downscaler(mask) + return mask_embedding + + def call(self, inputs): + points, labels, box, mask = ( + inputs["points"], + inputs["labels"], + inputs["box"], + inputs["mask"], + ) + + # Get the batch shape. Since all the inputs must have the + # same batch shape, choose one input arbitrarily. + B = ops.shape(points)[0] + + # Compute point embeddings + point_embeddings = self.__embed_points(points, labels) + + # Compute box embeddings + box_embeddings = self.__embed_box(box) + + # Concatenate both into a sparse embeddings tensor + sparse_embeddings = ops.concatenate( + [point_embeddings, box_embeddings], axis=1 + ) + + # Compute the mask embeddings + _no_mask_embed = lambda: ( + ops.broadcast_to( + ops.reshape( + self.no_mask_embed(ops.arange(1, dtype="int32")), + (1, 1, 1, self.embed_dim), + ), + shape=( + B, + self.image_embedding_size[0], + self.image_embedding_size[1], + self.embed_dim, + ), + ) + ) + + def _maybe_input_mask_embed(): + # Keras Core passes the masks as concrete tensors for both the + # true and false functions to build the output shape. So, we + # need to handle the case when 0 size mask is passed and + # dispatch the call to `_no_mask_embed`. Note that we can't call + # the lambda directly since the inputs are bound to different + # values when called with concrete values. + if mask.shape[1] == 0: + return ops.broadcast_to( + ops.reshape( + self.no_mask_embed(ops.arange(1, dtype="int32")), + (1, 1, 1, self.embed_dim), + ), + shape=( + B, + self.image_embedding_size[0], + self.image_embedding_size[1], + self.embed_dim, + ), + ) + shape = ops.shape(mask) + BM, N, H, W, C = shape[0], shape[1], shape[2], shape[3], shape[4] + return self.__embed_mask(ops.reshape(mask, (BM * N, H, W, C))) + + dense_embeddings = ops.cond( + ops.equal(ops.size(mask), 0), + _no_mask_embed, + _maybe_input_mask_embed, + ) + + # Compute the dense positional embeddings + dense_positional_embeddings = ( + self.positional_embedding_layer.encode_image( + self.image_embedding_size + )[None, ...] + ) + + return { + "sparse_embeddings": sparse_embeddings, + "dense_embeddings": dense_embeddings, + "dense_positional_embeddings": dense_positional_embeddings, + } + + def get_config(self): + config = super().get_config() + config.update( + { + "embed_dim": self.embed_dim, + "image_embedding_size": self.image_embedding_size, + "input_image_size": self.input_image_size, + "mask_in_chans": self.mask_in_chans, + "activation": self.activation, + } + ) + return config diff --git a/keras_cv/models/segmentation/segment_anything/sam_test.py b/keras_cv/models/segmentation/segment_anything/sam_test.py new file mode 100644 index 0000000000..1ab5a131fa --- /dev/null +++ b/keras_cv/models/segmentation/segment_anything/sam_test.py @@ -0,0 +1,377 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import os +import pathlib + +import numpy as np +import pytest +from absl.testing import parameterized + +from keras_cv.backend import keras +from keras_cv.backend import ops +from keras_cv.models.backbones.vit_det.vit_det_aliases import ViTDetBBackbone +from keras_cv.models.segmentation.segment_anything.sam import ( + SegmentAnythingModel, +) +from keras_cv.models.segmentation.segment_anything.sam_layers import ( + TwoWayMultiHeadAttention, +) +from keras_cv.models.segmentation.segment_anything.sam_mask_decoder import ( + SAMMaskDecoder, +) +from keras_cv.models.segmentation.segment_anything.sam_prompt_encoder import ( + SAMPromptEncoder, +) +from keras_cv.models.segmentation.segment_anything.sam_transformer import ( + TwoWayTransformer, +) +from keras_cv.tests.test_case import TestCase + + +class SAMTest(TestCase): + def setUp(self): + self.image_encoder = ViTDetBBackbone() + self.prompt_encoder = SAMPromptEncoder( + embed_dim=256, + image_embedding_size=(64, 64), + input_image_size=(1024, 1024), + mask_in_chans=16, + ) + self.mask_decoder = SAMMaskDecoder( + transformer_dim=256, + transformer=TwoWayTransformer( + depth=2, embed_dim=256, mlp_dim=2048, num_heads=8 + ), + num_multimask_outputs=3, + iou_head_depth=3, + iou_head_hidden_dim=256, + ) + + def get_prompts(self, B, prompts="all"): + rng = np.random.default_rng(0) + + points = ops.ones((B, 0, 2)) + labels = ops.ones((B, 0)) + box = ops.ones((B, 0, 2, 2)) + input_mask = ops.ones((B, 0, 256, 256, 1)) + + if "all" in prompts or "points" in prompts: + points = ops.convert_to_tensor( + rng.integers(0, 1023, (B, 10, 2)), dtype="float32" + ) + labels = ops.convert_to_tensor( + 1 * (rng.random((B, 10)) > 0.5), dtype="int32" + ) + if "all" in prompts or "box" in prompts: + x1y1 = rng.integers(0, 1022, (B, 2)) + x2y2 = rng.integers(x1y1, 1023, (B, 2)) + box = np.stack([x1y1, x2y2], axis=1) + box = ops.convert_to_tensor(box[:, None, ...], dtype="float32") + if "all" in prompts or "mask" in prompts: + input_mask = ops.convert_to_tensor( + 1.0 * (rng.random((B, 1, 256, 256, 1)) > 0.5), dtype="float32" + ) + + return points, labels, box, input_mask + + def test_prompt_encoder_simple(self): + points, labels, box, input_mask = self.get_prompts(7) + + outputs = self.prompt_encoder( + dict(points=points, labels=labels, box=box, mask=input_mask) + ) + sparse_embeddings, dense_embeddings, dense_positional_embeddings = ( + outputs["sparse_embeddings"], + outputs["dense_embeddings"], + outputs["dense_positional_embeddings"], + ) + + trainable_parameters = np.sum( + [np.prod(x.shape) for x in self.prompt_encoder.trainable_weights] + ) + num_parameters = np.sum( + [np.prod(x.shape) for x in self.prompt_encoder.weights] + ) + + sparse_embeddings = ops.convert_to_numpy(sparse_embeddings) + dense_embeddings = ops.convert_to_numpy(dense_embeddings) + dense_positional_embeddings = ops.convert_to_numpy( + dense_positional_embeddings + ) + + self.assertEqual(sparse_embeddings.shape, (7, 12, 256)) + self.assertEqual(dense_embeddings.shape, (7, 64, 64, 256)) + self.assertEqual(dense_positional_embeddings.shape, (1, 64, 64, 256)) + 
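+        # Sanity-check the prompt encoder's trainable and total
+        # parameter counts.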
self.assertEqual(trainable_parameters, 6_220) + self.assertEqual(num_parameters, 6_476) + + @parameterized.named_parameters( + [ + ("_".join(x), x) + for x in itertools.chain( + itertools.combinations(["points", "box", "mask"], 1), + itertools.combinations(["points", "box", "mask"], 2), + ) + ] + ) + def test_prompt_encoder_partial_prompts(self, prompts): + points, labels, box, input_mask = self.get_prompts(7, prompts) + outputs = self.prompt_encoder( + {"points": points, "labels": labels, "box": box, "mask": input_mask} + ) + sparse_embeddings, dense_embeddings = ( + outputs["sparse_embeddings"], + outputs["dense_embeddings"], + ) + + self.assertAllEqual( + sparse_embeddings.shape, + (7, points.shape[1] + box.shape[1] * 2, 256), + ) + self.assertAllEqual(dense_embeddings.shape, (7, 64, 64, 256)) + if "mask" not in prompts: + no_mask_embed = ops.broadcast_to( + self.prompt_encoder.no_mask_embed(ops.arange(1)), + (7, 64, 64, 256), + ) + self.assertAllClose(dense_embeddings, no_mask_embed) + + def test_two_way_multi_head_attention(self): + points, labels, box, input_mask = self.get_prompts(1) + image_embeddings = np.random.randn(1, 64, 64, 256).astype(np.float32) + + prompt_encoder_outputs = self.prompt_encoder( + dict(points=points, labels=labels, box=box, mask=input_mask) + ) + sparse_embeddings = prompt_encoder_outputs["sparse_embeddings"] + + two_way_attention = TwoWayMultiHeadAttention( + num_heads=8, + key_dim=256 // 8, + mlp_dim=2048, + skip_first_layer_pe=False, + ) + queries, keys = two_way_attention( + queries=sparse_embeddings, + keys=ops.reshape(image_embeddings, (1, 64 * 64, 256)), + query_pe=sparse_embeddings, + key_pe=ops.reshape( + prompt_encoder_outputs["dense_positional_embeddings"], + (1, 64 * 64, 256), + ), + ) + + queries, keys = map(ops.convert_to_numpy, [queries, keys]) + + self.assertEqual(queries.shape, (1, 12, 256)) + self.assertEqual(keys.shape, (1, 64 * 64, 256)) + + def test_two_way_transformer(self): + points, labels, box, input_mask = self.get_prompts(1) + prompt_encoder_outputs = self.prompt_encoder( + dict(points=points, labels=labels, box=box, mask=input_mask) + ) + sparse_embeddings = prompt_encoder_outputs["sparse_embeddings"] + image_embeddings = np.random.randn(1, 64, 64, 256) + two_way_transformer = TwoWayTransformer( + depth=2, embed_dim=256, num_heads=8, mlp_dim=2048 + ) + queries, keys = two_way_transformer( + image_embedding=image_embeddings, + image_pe=prompt_encoder_outputs["dense_positional_embeddings"], + point_embedding=sparse_embeddings, + ) + + queries, keys = map(ops.convert_to_numpy, [queries, keys]) + + self.assertEqual(queries.shape, (1, 12, 256)) + self.assertEqual(keys.shape, (1, 64 * 64, 256)) + + def test_mask_decoder(self): + points, labels, box, input_mask = self.get_prompts(1) + prompt_encoder_outputs = self.prompt_encoder( + dict(points=points, labels=labels, box=box, mask=input_mask) + ) + sparse_embeddings, dense_embeddings, dense_positional_embeddings = ( + prompt_encoder_outputs["sparse_embeddings"], + prompt_encoder_outputs["dense_embeddings"], + prompt_encoder_outputs["dense_positional_embeddings"], + ) + image_embeddings = np.random.randn(1, 64, 64, 256) + outputs = self.mask_decoder( + dict( + image_embeddings=image_embeddings, + image_pe=dense_positional_embeddings, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + ) + ) + masks, iou_pred = outputs["masks"], outputs["iou_pred"] + num_parameters = np.sum( + [np.prod(x.shape) for x in self.mask_decoder.weights] + ) + masks, iou_pred 
= map(ops.convert_to_numpy, [masks, iou_pred]) + self.assertEqual(masks.shape, (1, 4, 256, 256)) + self.assertEqual(iou_pred.shape, (1, 4)) + self.assertEqual(num_parameters, 4_058_340) + + @pytest.mark.large + def test_end_to_end_model_predict(self): + model = SegmentAnythingModel( + backbone=self.image_encoder, + prompt_encoder=self.prompt_encoder, + mask_decoder=self.mask_decoder, + ) + + points, labels, box, input_mask = self.get_prompts(1) + + inputs = { + "images": np.ones((1, 1024, 1024, 3)), + "points": points, + "labels": labels, + "box": box, + "mask": input_mask, + } + + # Check the number of parameters + num_parameters = np.sum([np.prod(x.shape) for x in model.weights]) + self.assertEqual(num_parameters, 89_670_912 + 6_476 + 4_058_340) + + # Forward pass through the model + outputs = model.predict(inputs) + masks, iou_pred = outputs["masks"], outputs["iou_pred"] + + # Check the output is equal to the one we expect if we + # run each component separately. This is to confirm that + # the graph is getting compiled correctly i.e. the jitted + # execution is equivalent to the eager execution. + features = self.image_encoder(inputs["images"]) + outputs_ex = self.prompt_encoder( + {k: v for k, v in inputs.items() if k != "images"} + ) + outputs_ex = self.mask_decoder( + { + "image_embeddings": features, + "image_pe": outputs_ex["dense_positional_embeddings"], + "sparse_prompt_embeddings": outputs_ex["sparse_embeddings"], + "dense_prompt_embeddings": outputs_ex["dense_embeddings"], + }, + ) + masks_ex, iou_pred_ex = outputs_ex["masks"], outputs_ex["iou_pred"] + + self.assertAllClose(masks, masks_ex, atol=5e-5) + self.assertAllClose(iou_pred, iou_pred_ex, atol=5e-5) + + @pytest.mark.extra_large + def test_end_to_end_model_save(self): + # Build the model + model = SegmentAnythingModel( + backbone=self.image_encoder, + prompt_encoder=self.prompt_encoder, + mask_decoder=self.mask_decoder, + ) + + # Get the inputs + points, labels, box, input_mask = self.get_prompts(1) + + inputs = { + "images": ops.ones((1, 1024, 1024, 3)), + "points": points, + "labels": labels, + "box": box, + "mask": input_mask, + } + + # Forward pass + outputs = model(inputs) + + # Save the model + save_path = os.path.join(self.get_temp_dir(), "model.keras") + model.save(save_path, save_format="keras_v3") + restored_model = keras.models.load_model(save_path) + + # Check we got the real object back. + self.assertIsInstance(restored_model, SegmentAnythingModel) + + # Check that output matches. + restored_outputs = restored_model(inputs) + self.assertAllClose(outputs, restored_outputs) + + @pytest.mark.large + def test_end_to_end_model_preset(self): + # Define the RNG. Don't change the seed. This seed + # was used to generate the inputs for the reference + # values. + rng = np.random.default_rng(0) + + # Generate the inputs + inputs = { + "images": 255.0 * rng.random((1, 1024, 1024, 3), dtype=np.float32), + "points": np.array( + [[[10, 10], [100, 100], [500, 500]]], dtype=np.float32 + ), + "labels": np.array([[0, 1, 0]], dtype=np.float32), + "box": np.array( + [[[[10.0, 10.0], [100.0, 100.0]]]], dtype=np.float32 + ), + "mask": (rng.random((1, 1, 256, 256, 1)) > 0.5).astype(np.float32), + } + + # Run the model + model = SegmentAnythingModel.from_preset("sam_base_sa1b") + outs = model.predict(inputs) + + # Make sure the weights have been loaded correctly. 
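+        # The expected masks and IoU scores below were precomputed for this
+        # preset with the exact inputs above.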
+ masks_expected = np.load( + pathlib.Path(__file__).parent / "data" / "sam_base_out_masks.npy" + ) + iou_pred_expected = np.load( + pathlib.Path(__file__).parent / "data" / "sam_base_out_iou_pred.npy" + ) + self.assertAllClose(outs["masks"], masks_expected, atol=1e-2, rtol=1e-2) + self.assertAllClose( + outs["iou_pred"], iou_pred_expected, atol=1e-2, rtol=1e-2 + ) + + def test_end_to_end_model_fit_error(self): + # Build the model + model = SegmentAnythingModel( + backbone=self.image_encoder, + prompt_encoder=self.prompt_encoder, + mask_decoder=self.mask_decoder, + ) + + # Get the inputs + points, labels, box, input_mask = self.get_prompts(1) + + inputs = { + "images": ops.ones((1, 1024, 1024, 3)), + "points": points, + "labels": labels, + "box": box, + "mask": input_mask, + } + + # Compile the model + model.compile( + optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"] + ) + + # Check that calling fit raises a NotImplementedError. + with self.assertRaises( + NotImplementedError, msg=r"only supports inference" + ): + model.fit(inputs) diff --git a/keras_cv/models/segmentation/segment_anything/sam_transformer.py b/keras_cv/models/segmentation/segment_anything/sam_transformer.py new file mode 100644 index 0000000000..64d1de8575 --- /dev/null +++ b/keras_cv/models/segmentation/segment_anything/sam_transformer.py @@ -0,0 +1,155 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_cv.api_export import keras_cv_export +from keras_cv.backend import keras +from keras_cv.backend import ops +from keras_cv.models.segmentation.segment_anything.sam_layers import ( + MultiHeadAttentionWithDownsampling, +) +from keras_cv.models.segmentation.segment_anything.sam_layers import ( + TwoWayMultiHeadAttention, +) + + +@keras_cv_export("keras_cv.models.TwoWayTransformer", package="keras_cv.models") +class TwoWayTransformer(keras.layers.Layer): + """A two-way cross-attention transformer decoder. + + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + The transformer decoder design is shown in [1]_. Each decoder layer + performs 4 steps: (1) self-attention on the tokens, (2) cross-attention + from tokens (as queries) to the image embedding, (3) a point-wise MLP + updates each token, and (4) cross-attention from the image embedding (as + queries) to tokens. This last step updates the image embedding with prompt + information. Each self/cross-attention and MLP has a residual connection + and layer normalization. + + To ensure the decoder has access to critical geometric information the + positional encodings are added to the image embedding whenever they + participate in an attention layer. Additionally, the entire original + prompt tokens (including their positional encodings) are re-added to the + updated tokens whenever they participate in an attention layer. This + allows for a strong dependence on both the prompt token's geometric + location and type. 
+ + Args: + depth (int, optional): The depth of the attention blocks (the number + of attention blocks to use). Defaults to `2`. + embed_dim (int, optional): The number of features of the input image + and point embeddings. Defaults to `256`. + num_heads (int, optional): Number of heads to use in the attention + layers. Defaults to `8`. + mlp_dim (int, optional): The number of units in the hidden layer of + the MLP block used in the attention layers. Defaults to `2048`. + activation (str, optional): The activation of the MLP block's output + layer used in the attention layers. Defaults to `"relu"`. + attention_downsample_rate (int, optional): The downsample rate of the + attention layers. Defaults to `2`. + + References: + - [Segment Anything paper](https://arxiv.org/abs/2304.02643) + - [Segment Anything GitHub](https://github.com/facebookresearch/segment-anything) + """ # noqa: E501 + + def __init__( + self, + *, + depth=2, + embed_dim=256, + num_heads=8, + mlp_dim=2048, + activation="relu", + attention_downsample_rate=2, + **kwargs, + ): + super().__init__(**kwargs) + self.depth = depth + self.embed_dim = embed_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.activation = activation + self.attention_downsample_rate = attention_downsample_rate + self.layers = [] + for i in range(depth): + self.layers.append( + TwoWayMultiHeadAttention( + num_heads=num_heads, + key_dim=embed_dim // num_heads, + mlp_dim=mlp_dim, + skip_first_layer_pe=(i == 0), + attention_downsample_rate=attention_downsample_rate, + activation=activation, + ) + ) + self.final_attention_token_to_image = ( + MultiHeadAttentionWithDownsampling( + num_heads=num_heads, + key_dim=embed_dim // num_heads, + downsample_rate=attention_downsample_rate, + ) + ) + self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5) + + def build(self, input_shape=None): + for layer in self.layers: + layer.build() + self.final_attention_token_to_image.build() + self.final_layer_norm.build([None, None, self.embed_dim]) + self.built = True + + def call(self, image_embedding, image_pe, point_embedding): + shape = ops.shape(image_embedding) + B, H, W, C = shape[0], shape[1], shape[2], shape[3] + image_embedding = ops.reshape(image_embedding, (B, H * W, C)) + + shape = ops.shape(image_pe) + B, H, W, C = shape[0], shape[1], shape[2], shape[3] + image_pe = ops.reshape(image_pe, (B, H * W, C)) + queries = point_embedding + keys = image_embedding + + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + queries_with_pe = queries + point_embedding + keys_with_pe = keys + image_pe + attention_map = self.final_attention_token_to_image( + query=queries_with_pe, key=keys_with_pe, value=keys + ) + queries = queries + attention_map + queries = self.final_layer_norm(queries) + + return queries, keys + + def get_config(self): + config = super().get_config() + config.update( + { + "depth": self.depth, + "embed_dim": self.embed_dim, + "num_heads": self.num_heads, + "mlp_dim": self.mlp_dim, + "activation": self.activation, + "attention_downsample_rate": self.attention_downsample_rate, + } + ) + return config diff --git a/keras_cv/models/utils.py b/keras_cv/models/utils.py index cb573b63c3..7199823385 100644 --- a/keras_cv/models/utils.py +++ b/keras_cv/models/utils.py @@ -25,12 +25,14 @@ def get_tensor_input_name(tensor): return tensor.node.layer.name -def parse_model_inputs(input_shape, input_tensor): +def parse_model_inputs(input_shape, input_tensor, 
**kwargs): if input_tensor is None: - return keras.layers.Input(shape=input_shape) + return keras.layers.Input(shape=input_shape, **kwargs) else: if not keras.backend.is_keras_tensor(input_tensor): - return keras.layers.Input(tensor=input_tensor, shape=input_shape) + return keras.layers.Input( + tensor=input_tensor, shape=input_shape, **kwargs + ) else: return input_tensor
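As an end-to-end sanity check of the API added in this patch, the sketch below (not part of the diff) follows the `SegmentAnythingModel` docstring. It assumes the model is exported as `keras_cv.models.SegmentAnythingModel` and that the `sam_base_sa1b` preset and its weights are available.

import numpy as np

import keras_cv

# Load a pretrained SAM from the preset registered in sam_presets.py.
model = keras_cv.models.SegmentAnythingModel.from_preset("sam_base_sa1b")

# One 1024x1024 RGB image and a single foreground point prompt. Per the
# docstring, points given without a box prompt are padded with a zero point
# and a -1 label.
image = np.ones((1, 1024, 1024, 3))
points = np.concatenate(
    [np.array([[[512.0, 512.0]]]), np.zeros((1, 1, 2))], axis=1
)
labels = np.concatenate([np.array([[1.0]]), -np.ones((1, 1))], axis=1)
box = np.empty((1, 0, 2, 2))  # no box prompt
mask = np.empty((1, 0, 256, 256, 1))  # no mask prompt

outputs = model.predict(
    {
        "images": image,
        "points": points,
        "labels": labels,
        "box": box,
        "mask": mask,
    }
)
best_mask = outputs["masks"][:, 0, ...]  # the model's top-ranked mask
iou_scores = outputs["iou_pred"]  # predicted IoU for each candidate mask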