-
Notifications
You must be signed in to change notification settings - Fork 330
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[DeepVision Port] SegFormer and Mix-Transformers (#1946)
* initial dump * add all basic layers, port roughly to keras core ops * updated .gitignore * segformer head and formatting * cleanup * remove tf call * remove tf * migrating to more keras ops * cleanups and fixes * fix reshaping * comments * from presets api, keras.ops -> ops * embed_dims -> embedding_dims * addressing some PR comments * docstrings, argument update * depths arg * sync * compute output shapes * segformer progress * head * softmax * remove softmax * undo compute_output_shapes() * efficientmultiheadattention -> segformermultiheadattention * docstrings * softmax output * segformer presets * updating segformer presets * segformer presets * import aliases * refactoring * pr comments * pr comments * add aliases * aliases ot init * refactor fix * import keras_cv_export * fix presets/aliases and add copyright * linter warnings * linter errors * consistency in presets * return config * fix serialization * Some cleanup + more tests * Fix DropPath layer (need to update tests + add shim for tf.keras * Finish DropPath layer * Use static shape in backbone * Formatting * Switch back to ops.shape * documentation * documentation * remove default num classes * fix docs --------- Co-authored-by: ianjjohnson <[email protected]>
- Loading branch information
1 parent
b038f58
commit ab812d1
Showing
22 changed files
with
1,855 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,3 +16,4 @@ __pycache__/ | |
.vscode/ | ||
.devcontainer/ | ||
.coverage | ||
.history |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Copyright 2023 The KerasCV Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from keras_cv.backend.config import multi_backend | ||
|
||
if multi_backend(): | ||
from keras_core.random import * # noqa: F403, F401 | ||
else: | ||
from keras_core.src.backend.tensorflow.random import * # noqa: F403, F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
# Copyright 2023 The KerasCV Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import math | ||
|
||
from keras_cv.api_export import keras_cv_export | ||
from keras_cv.backend import keras | ||
from keras_cv.backend import ops | ||
from keras_cv.layers.regularization.drop_path import DropPath | ||
from keras_cv.layers.segformer_multihead_attention import ( | ||
SegFormerMultiheadAttention, | ||
) | ||
|
||
|
||
@keras_cv_export("keras_cv.layers.HierarchicalTransformerEncoder") | ||
class HierarchicalTransformerEncoder(keras.layers.Layer): | ||
""" | ||
Hierarchical transformer encoder block implementation as a Keras Layer. | ||
The layer uses `SegFormerMultiheadAttention` as a `MultiHeadAttention` | ||
alternative for computational efficiency, and is meant to be used | ||
within the SegFormer architecture. | ||
References: | ||
- [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) (CVPR 2021) # noqa: E501 | ||
- [Official PyTorch implementation](https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py) # noqa: E501 | ||
- [Ported from the TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/blob/main/deepvision/layers/hierarchical_transformer_encoder.py) # noqa: E501 | ||
Args: | ||
project_dim: integer, the dimensionality of the projection of the | ||
encoder, and output of the `SegFormerMultiheadAttention` layer. | ||
Due to the residual addition the input dimensionality has to be | ||
equal to the output dimensionality. | ||
num_heads: integer, the number of heads for the | ||
`SegFormerMultiheadAttention` layer. | ||
drop_prob: float, the probability of dropping a random | ||
sample using the `DropPath` layer. Defaults to `0.0`. | ||
layer_norm_epsilon: float, the epsilon for | ||
`LayerNormalization` layers. Defaults to `1e-06` | ||
sr_ratio: integer, the ratio to use within | ||
`SegFormerMultiheadAttention`. If set to > 1, a `Conv2D` | ||
layer is used to reduce the length of the sequence. Defaults to `1`. | ||
Basic usage: | ||
``` | ||
project_dim = 1024 | ||
num_heads = 4 | ||
patch_size = 16 | ||
encoded_patches = keras_cv.layers.OverlappingPatchingAndEmbedding( | ||
project_dim=project_dim, patch_size=patch_size)(img_batch) | ||
trans_encoded = keras_cv.layers.HierarchicalTransformerEncoder(project_dim=project_dim, | ||
num_heads=num_heads, | ||
sr_ratio=1)(encoded_patches) | ||
print(trans_encoded.shape) # (1, 3136, 1024) | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
project_dim, | ||
num_heads, | ||
sr_ratio=1, | ||
drop_prob=0.0, | ||
layer_norm_epsilon=1e-6, | ||
**kwargs, | ||
): | ||
super().__init__(**kwargs) | ||
self.project_dim = project_dim | ||
self.num_heads = num_heads | ||
self.drop_prop = drop_prob | ||
|
||
self.norm1 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon) | ||
self.attn = SegFormerMultiheadAttention( | ||
project_dim, num_heads, sr_ratio | ||
) | ||
self.drop_path = DropPath(drop_prob) | ||
self.norm2 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon) | ||
self.mlp = self.MixFFN( | ||
channels=project_dim, | ||
mid_channels=int(project_dim * 4), | ||
) | ||
|
||
def build(self, input_shape): | ||
super().build(input_shape) | ||
self.H = ops.sqrt(ops.cast(input_shape[1], "float32")) | ||
self.W = ops.sqrt(ops.cast(input_shape[2], "float32")) | ||
|
||
def call(self, x): | ||
x = x + self.drop_path(self.attn(self.norm1(x))) | ||
x = x + self.drop_path(self.mlp(self.norm2(x))) | ||
return x | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"mlp": keras.saving.serialize_keras_object(self.mlp), | ||
"project_dim": self.project_dim, | ||
"num_heads": self.num_heads, | ||
"drop_prop": self.drop_prop, | ||
} | ||
) | ||
return config | ||
|
||
class MixFFN(keras.layers.Layer): | ||
def __init__(self, channels, mid_channels): | ||
super().__init__() | ||
self.fc1 = keras.layers.Dense(mid_channels) | ||
self.dwconv = keras.layers.DepthwiseConv2D( | ||
kernel_size=3, | ||
strides=1, | ||
padding="same", | ||
) | ||
self.fc2 = keras.layers.Dense(channels) | ||
|
||
def call(self, x): | ||
x = self.fc1(x) | ||
shape = ops.shape(x) | ||
H, W = int(math.sqrt(shape[1])), int(math.sqrt(shape[1])) | ||
B, C = shape[0], shape[2] | ||
x = ops.reshape(x, (B, H, W, C)) | ||
x = self.dwconv(x) | ||
x = ops.reshape(x, (B, -1, C)) | ||
x = ops.nn.gelu(x) | ||
x = self.fc2(x) | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# Copyright 2023 The KerasCV Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from keras_cv.api_export import keras_cv_export | ||
from keras_cv.backend import keras | ||
from keras_cv.backend import ops | ||
|
||
|
||
@keras_cv_export("keras_cv.layers.OverlappingPatchingAndEmbedding") | ||
class OverlappingPatchingAndEmbedding(keras.layers.Layer): | ||
def __init__(self, project_dim=32, patch_size=7, stride=4, **kwargs): | ||
""" | ||
Overlapping Patching and Embedding layer. Differs from `PatchingAndEmbedding` | ||
in that the patch size does not affect the sequence length. It's fully derived | ||
from the `stride` parameter. Additionally, no positional embedding is done | ||
as part of the layer - only a projection using a `Conv2D` layer. | ||
References: | ||
- [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) (CVPR 2021) # noqa: E501 | ||
- [Official PyTorch implementation](https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py) # noqa: E501 | ||
- [Ported from the TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/blob/main/deepvision/layers/hierarchical_transformer_encoder.py) # noqa: E501 | ||
Args: | ||
project_dim: integer, the dimensionality of the projection. | ||
Defaults to `32`. | ||
patch_size: integer, the size of the patches to encode. | ||
Defaults to `7`. | ||
stride: integer, the stride to use for the patching before | ||
projection. Defaults to `5`. | ||
Basic usage: | ||
``` | ||
project_dim = 1024 | ||
patch_size = 16 | ||
encoded_patches = keras_cv.layers.OverlappingPatchingAndEmbedding( | ||
project_dim=project_dim, patch_size=patch_size)(img_batch) | ||
print(encoded_patches.shape) # (1, 3136, 1024) | ||
``` | ||
""" | ||
super().__init__(**kwargs) | ||
|
||
self.project_dim = project_dim | ||
self.patch_size = patch_size | ||
self.stride = stride | ||
|
||
self.proj = keras.layers.Conv2D( | ||
filters=project_dim, | ||
kernel_size=patch_size, | ||
strides=stride, | ||
padding="same", | ||
) | ||
self.norm = keras.layers.LayerNormalization() | ||
|
||
def call(self, x): | ||
x = self.proj(x) | ||
# B, H, W, C | ||
shape = x.shape | ||
x = ops.reshape(x, (-1, shape[1] * shape[2], shape[3])) | ||
x = self.norm(x) | ||
return x | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"project_dim": self.project_dim, | ||
"patch_size": self.patch_size, | ||
"stride": self.stride, | ||
} | ||
) | ||
return config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.