Address review comments and fix the saving bug
- Use `keras_cv.api_export.keras_cv_export` instead of `keras.saving.register_keras_serializable`.
- Add a `SerializableSequential` class to address the saving bug with the `Sequential` model.
- Push the helper functions in `keras_cv/layers/detectron2_layers.py` to the bottom of the file.
- Add the detectron2 layers to the `keras_cv/layers/__init__.py` file.
- Add a test for the `ViTDetPatchingAndEmbedding` layer.
tirthasheshpatel committed Aug 18, 2023
1 parent 43b0f2b commit ac7f30e
Showing 12 changed files with 293 additions and 164 deletions.
1 change: 1 addition & 0 deletions keras_cv/layers/__init__.py
@@ -127,6 +127,7 @@
from keras_cv.layers.regularization.dropblock_2d import DropBlock2D
from keras_cv.layers.regularization.squeeze_excite import SqueezeAndExcite2D
from keras_cv.layers.regularization.stochastic_depth import StochasticDepth
from keras_cv.layers.serializable_sequential import SerializableSequential
from keras_cv.layers.spatial_pyramid import SpatialPyramidPooling
from keras_cv.layers.transformer_encoder import TransformerEncoder
from keras_cv.layers.vit_layers import PatchingAndEmbedding
273 changes: 137 additions & 136 deletions keras_cv/layers/detectron2_layers.py
@@ -12,97 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.models.segmentation.segment_anything.sam_layers import MLPBlock


def get_rel_pos(query_size, key_size, rel_pos):
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
        query_size (int): The spatial sequence size of the queries.
        key_size (int): The spatial sequence size of the keys.
rel_pos (tensor): Relative positional embedding tensor.
Returns:
tensor: Extracted positional embeddings according to relative
positions.
"""
max_rel_dist = 2 * max(query_size, key_size) - 1
if rel_pos.shape[0] != max_rel_dist:
rel_pos_resized = ops.image.resize(
images=ops.reshape(
rel_pos, (1, rel_pos.shape[0], rel_pos.shape[1], 1)
),
size=(max_rel_dist, rel_pos.shape[1]),
interpolation="bilinear",
)
rel_pos_resized = ops.squeeze(rel_pos_resized, axis=(0, -1))
else:
rel_pos_resized = rel_pos
query_coordinates = ops.arange(query_size, dtype="float32")[:, None] * max(
key_size / query_size, 1.0
)
key_coordinates = ops.arange(key_size, dtype="float32")[None, :] * max(
query_size / key_size, 1.0
)
relative_coordinates = (query_coordinates - key_coordinates) + (
key_size - 1
) * max(query_size / key_size, 1.0)
relative_coordinates = ops.cast(relative_coordinates, dtype="int64")
return ops.take(rel_pos_resized, relative_coordinates, 0)


def add_decomposed_rel_pos(
attention_map, queries, rel_pos_h, rel_pos_w, query_size, key_size
):
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
Args:
attention_map (tensor): Attention map.
queries (tensor): Queries in the attention layer with shape
`(B, q_h * q_w, C)`.
rel_pos_h (tensor): Relative position embeddings `(Lh, C)` for height
axis.
        rel_pos_w (tensor): Relative position embeddings `(Lw, C)` for width
axis.
query_size (tuple[int, int]): Spatial sequence size of queries with
`(q_h, q_w)`.
key_size (tuple[int, int]): Spatial sequence size of keys with
`(k_h, k_w)`.
Returns:
tensor: attention map with added relative positional embeddings.
References:
- https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa: E501
"""
query_height, query_width = query_size
key_height, key_width = key_size
rel_heights = get_rel_pos(query_height, key_height, rel_pos_h)
rel_widths = get_rel_pos(query_width, key_width, rel_pos_w)

B, _, C = queries.shape
rel_queries = ops.reshape(queries, (B, query_height, query_width, C))
rel_heights = ops.einsum("bhwc,hkc->bhwk", rel_queries, rel_heights)
rel_widths = ops.einsum("bhwc,wkc->bhwk", rel_queries, rel_widths)

attention_map = ops.reshape(
attention_map, (B, query_height, query_width, key_height, key_width)
)
attention_map = attention_map + rel_heights[..., :, None]
attention_map = attention_map + rel_widths[..., None, :]
attention_map = ops.reshape(
attention_map, (B, query_height * query_width, key_height * key_width)
)
return attention_map


-@keras.saving.register_keras_serializable(package="keras_cv")
+@keras_cv_export("keras_cv.layers.MultiHeadAttentionWithRelativePE")
class MultiHeadAttentionWithRelativePE(keras.layers.Layer):
"""Multi-head Attention block with relative position embeddings.
@@ -218,55 +134,7 @@ def get_config(self):
return config


def window_partition(x, window_size):
B, H, W, C = x.shape
pad_height = (window_size - H % window_size) % window_size
pad_width = (window_size - W % window_size) % window_size
if pad_height > 0 or pad_width > 0:
x = ops.pad(x, ((0, 0), (0, pad_height), (0, pad_width), (0, 0)))
H_padded, W_padded = H + pad_height, W + pad_width
x = ops.reshape(
x,
(
B,
H_padded // window_size,
window_size,
W_padded // window_size,
window_size,
C,
),
)
windows = ops.reshape(
ops.transpose(x, axes=(0, 1, 3, 2, 4, 5)),
(-1, window_size, window_size, C),
)
return windows, (H_padded, W_padded)


def window_unpartition(windows, window_size, HW_padded, HW):
H_padded, W_padded = HW_padded
H, W = HW
B = windows.shape[0] // (
(H_padded // window_size) * (W_padded // window_size)
)
x = ops.reshape(
windows,
(
B,
H_padded // window_size,
W_padded // window_size,
window_size,
window_size,
-1,
),
)
x = ops.reshape(
ops.transpose(x, axes=(0, 1, 3, 2, 4, 5)), (B, H_padded, W_padded, -1)
)
return x[:, :H, :W, :]


-@keras.utils.register_keras_serializable(package="keras_cv")
+@keras_cv_export("keras_cv.layers.WindowedTransformerEncoder")
class WindowedTransformerEncoder(keras.layers.Layer):
"""Transformer blocks with support of window attention and residual
propagation blocks.
@@ -379,7 +247,7 @@ def get_config(self):
return config


-@keras.utils.register_keras_serializable(package="keras_cv")
+@keras_cv_export("keras_cv.layers.ViTDetPatchingAndEmbedding")
class ViTDetPatchingAndEmbedding(keras.layers.Layer):
"""Image to Patch Embedding using only a conv layer (without
layer normalization).
@@ -429,3 +297,136 @@ def get_config(self):
}
)
return config


def get_rel_pos(query_size, key_size, rel_pos):
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
        query_size (int): The spatial sequence size of the queries.
        key_size (int): The spatial sequence size of the keys.
rel_pos (tensor): Relative positional embedding tensor.
Returns:
tensor: Extracted positional embeddings according to relative
positions.
"""
max_rel_dist = 2 * max(query_size, key_size) - 1
if rel_pos.shape[0] != max_rel_dist:
rel_pos_resized = ops.image.resize(
images=ops.reshape(
rel_pos, (1, rel_pos.shape[0], rel_pos.shape[1], 1)
),
size=(max_rel_dist, rel_pos.shape[1]),
interpolation="bilinear",
)
rel_pos_resized = ops.squeeze(rel_pos_resized, axis=(0, -1))
else:
rel_pos_resized = rel_pos
query_coordinates = ops.arange(query_size, dtype="float32")[:, None] * max(
key_size / query_size, 1.0
)
key_coordinates = ops.arange(key_size, dtype="float32")[None, :] * max(
query_size / key_size, 1.0
)
relative_coordinates = (query_coordinates - key_coordinates) + (
key_size - 1
) * max(query_size / key_size, 1.0)
relative_coordinates = ops.cast(relative_coordinates, dtype="int64")
return ops.take(rel_pos_resized, relative_coordinates, 0)


def add_decomposed_rel_pos(
attention_map, queries, rel_pos_h, rel_pos_w, query_size, key_size
):
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
Args:
attention_map (tensor): Attention map.
queries (tensor): Queries in the attention layer with shape
`(B, q_h * q_w, C)`.
rel_pos_h (tensor): Relative position embeddings `(Lh, C)` for height
axis.
        rel_pos_w (tensor): Relative position embeddings `(Lw, C)` for width
axis.
query_size (tuple[int, int]): Spatial sequence size of queries with
`(q_h, q_w)`.
key_size (tuple[int, int]): Spatial sequence size of keys with
`(k_h, k_w)`.
Returns:
tensor: attention map with added relative positional embeddings.
References:
- https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa: E501
"""
query_height, query_width = query_size
key_height, key_width = key_size
rel_heights = get_rel_pos(query_height, key_height, rel_pos_h)
rel_widths = get_rel_pos(query_width, key_width, rel_pos_w)

B, _, C = queries.shape
rel_queries = ops.reshape(queries, (B, query_height, query_width, C))
rel_heights = ops.einsum("bhwc,hkc->bhwk", rel_queries, rel_heights)
rel_widths = ops.einsum("bhwc,wkc->bhwk", rel_queries, rel_widths)

attention_map = ops.reshape(
attention_map, (B, query_height, query_width, key_height, key_width)
)
attention_map = attention_map + rel_heights[..., :, None]
attention_map = attention_map + rel_widths[..., None, :]
attention_map = ops.reshape(
attention_map, (B, query_height * query_width, key_height * key_width)
)
return attention_map
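As a shape check, `add_decomposed_rel_pos` folds the per-axis embeddings returned by `get_rel_pos` back into the flattened attention map. A minimal sketch, assuming the helpers above are importable directly from `keras_cv/layers/detectron2_layers.py`:

import numpy as np

from keras_cv.layers.detectron2_layers import add_decomposed_rel_pos

# Small sizes: 4x4 queries and keys, 8 channels.
B, C = 1, 8
q_h, q_w, k_h, k_w = 4, 4, 4, 4
queries = np.ones((B, q_h * q_w, C), "float32")
attention_map = np.zeros((B, q_h * q_w, k_h * k_w), "float32")
# Each axis needs 2 * max(q, k) - 1 relative offsets.
rel_pos_h = np.zeros((2 * max(q_h, k_h) - 1, C), "float32")
rel_pos_w = np.zeros((2 * max(q_w, k_w) - 1, C), "float32")
out = add_decomposed_rel_pos(
    attention_map, queries, rel_pos_h, rel_pos_w, (q_h, q_w), (k_h, k_w)
)
print(out.shape)  # (1, 16, 16): biases added, flattened layout preserved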


def window_partition(x, window_size):
B, H, W, C = x.shape
pad_height = (window_size - H % window_size) % window_size
pad_width = (window_size - W % window_size) % window_size
if pad_height > 0 or pad_width > 0:
x = ops.pad(x, ((0, 0), (0, pad_height), (0, pad_width), (0, 0)))
H_padded, W_padded = H + pad_height, W + pad_width
x = ops.reshape(
x,
(
B,
H_padded // window_size,
window_size,
W_padded // window_size,
window_size,
C,
),
)
windows = ops.reshape(
ops.transpose(x, axes=(0, 1, 3, 2, 4, 5)),
(-1, window_size, window_size, C),
)
return windows, (H_padded, W_padded)


def window_unpartition(windows, window_size, HW_padded, HW):
H_padded, W_padded = HW_padded
H, W = HW
B = windows.shape[0] // (
(H_padded // window_size) * (W_padded // window_size)
)
x = ops.reshape(
windows,
(
B,
H_padded // window_size,
W_padded // window_size,
window_size,
window_size,
-1,
),
)
x = ops.reshape(
ops.transpose(x, axes=(0, 1, 3, 2, 4, 5)), (B, H_padded, W_padded, -1)
)
return x[:, :H, :W, :]
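The two windowing helpers form a round trip: `window_partition` pads the feature map so height and width are multiples of `window_size` and splits it into windows, and `window_unpartition` reassembles the map and strips the padding. A minimal sketch, assuming the helpers are importable directly from `keras_cv/layers/detectron2_layers.py`:

import numpy as np

from keras_cv.layers.detectron2_layers import window_partition
from keras_cv.layers.detectron2_layers import window_unpartition

# 65 is not a multiple of 14, so both axes get 5 rows/columns of padding.
x = np.ones((1, 65, 65, 32), "float32")
windows, (H_pad, W_pad) = window_partition(x, window_size=14)
print(windows.shape, (H_pad, W_pad))  # (25, 14, 14, 32) (70, 70)
x_restored = window_unpartition(windows, 14, (H_pad, W_pad), (65, 65))
print(x_restored.shape)  # (1, 65, 65, 32): padding removed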
7 changes: 7 additions & 0 deletions keras_cv/layers/detectron2_layers_test.py
@@ -16,6 +16,7 @@

from keras_cv.backend import ops
from keras_cv.layers.detectron2_layers import MultiHeadAttentionWithRelativePE
from keras_cv.layers.detectron2_layers import ViTDetPatchingAndEmbedding
from keras_cv.layers.detectron2_layers import WindowedTransformerEncoder
from keras_cv.tests.test_case import TestCase

@@ -46,3 +47,9 @@ def test_windowed_transformer_encoder(self):
x_out = ops.convert_to_numpy(windowed_transformer_encoder(x))
self.assertEqual(x_out.shape, (1, 64, 64, 1280))
self.assertAllClose(x_out, np.ones_like(x_out))

def test_vit_patching_and_embedding(self):
vit_patching_and_embedding = ViTDetPatchingAndEmbedding()
x = np.ones((1, 1024, 1024, 3))
x_out = vit_patching_and_embedding(x)
self.assertEqual(x_out.shape, (1, 64, 64, 768))
62 changes: 62 additions & 0 deletions keras_cv/layers/serializable_sequential.py
@@ -0,0 +1,62 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras


# TODO(tirthasheshpatel): Use `Sequential` model once the bug is resolved.
# Temporarily substitute the `Sequential` model with this layer because a
# bug in Keras/Keras Core prevents the weights of a `Sequential` model
# saved in JAX/PyTorch from loading in TensorFlow, and vice versa.
# This only happens when the `build` method is called in the `__init__`
# step.
@keras_cv_export("keras_cv.layers.SerializableSequential")
class SerializableSequential(keras.layers.Layer):
def __init__(self, layers_list, **kwargs):
super().__init__(**kwargs)
self.layers_list = layers_list

def build(self, input_shape):
output_shape = input_shape
for layer in self.layers_list:
layer.build(output_shape)
output_shape = layer.compute_output_shape(output_shape)
self.built = True

def call(self, x):
for layer in self.layers_list:
x = layer(x)
return x

def get_config(self):
config = super().get_config()
layers_list_serialized = [
keras.saving.serialize_keras_object(layer)
for layer in self.layers_list
]
config.update({"layers_list": layers_list_serialized})

    @classmethod
    def from_config(cls, config):
config.update(
{
"layers_list": [
keras.layers.deserialize(layer)
for layer in config["layers_list"]
]
}
)
return super().from_config(config)
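Apart from the deferred `build`, the layer is meant to round-trip through the Keras config machinery the way `Sequential` would. A minimal usage sketch, assuming a KerasCV build that includes this commit (the layer is re-exported from `keras_cv.layers` above):

import numpy as np

from keras_cv.backend import keras
from keras_cv.layers import SerializableSequential

block = SerializableSequential(
    [
        keras.layers.Dense(16, activation="relu"),
        keras.layers.Dense(4),
    ]
)
y = block(np.ones((2, 8), "float32"))  # builds the sublayers on first call
print(y.shape)  # (2, 4)

# Serialization round trip through get_config/from_config.
restored = SerializableSequential.from_config(block.get_config())
print(len(restored.layers_list))  # 2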