[DeepVision Port] SegFormer and Mix-Transformers (#1946)
* initial dump

* add all basic layers, port roughly to keras core ops

* updated .gitignore

* segformer head and formatting

* cleanup

* remove tf call

* remove tf

* migrating to more keras ops

* cleanups and fixes

* fix reshaping

* comments

* from presets api, keras.ops -> ops

* embed_dims -> embedding_dims

* addressing some PR comments

* docstrings, argument update

* depths arg

* sync

* compute output shapes

* segformer progress

* head

* softmax

* remove softmax

* undo compute_output_shapes()

* efficientmultiheadattention -> segformermultiheadattention

* docstrings

* softmax output

* segformer presets

* updating segformer presets

* segformer presets

* import aliases

* refactoring

* pr comments

* pr comments

* add aliases

* aliases to init

* refactor fix

* import keras_cv_export

* fix presets/aliases and add copyright

* linter warnings

* linter errors

* consistency in presets

* return config

* fix serialization

* Some cleanup + more tests

* Fix DropPath layer (need to update tests + add shim for tf.keras)

* Finish DropPath layer

* Use static shape in backbone

* Formatting

* Switch back to ops.shape

* documentation

* documentation

* remove default num classes

* fix docs

---------

Co-authored-by: ianjjohnson <[email protected]>
DavidLandup0 and ianstenbit authored Aug 24, 2023
1 parent b038f58 commit ab812d1
Showing 22 changed files with 1,855 additions and 16 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -16,3 +16,4 @@ __pycache__/
.vscode/
.devcontainer/
.coverage
.history
1 change: 1 addition & 0 deletions keras_cv/backend/__init__.py
@@ -76,6 +76,7 @@

from keras_cv.backend import config # noqa: E402
from keras_cv.backend import ops # noqa: E402
from keras_cv.backend import random # noqa: E402
from keras_cv.backend import tf_ops # noqa: E402


20 changes: 20 additions & 0 deletions keras_cv/backend/random.py
@@ -0,0 +1,20 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_cv.backend.config import multi_backend

if multi_backend():
from keras_core.random import * # noqa: F403, F401
else:
from keras_core.src.backend.tensorflow.random import * # noqa: F403, F401
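Note: both branches of this shim expose the same `uniform()`-style API, which is what lets the `DropPath` rewrite further down stay backend-agnostic. A minimal sketch of the pattern (the shape, threshold, and seed here are arbitrary illustration values, not anything prescribed by this commit):

```python
from keras_cv.backend import ops
from keras_cv.backend import random

# One uniform draw per sample, broadcastable over the remaining axes.
mask_source = random.uniform((4, 1, 1, 1), seed=42)

# Threshold it into a per-sample keep/drop map, as DropPath does below.
keep_map = ops.cast(mask_source > 0.5, "float32")
print(keep_map.shape)  # (4, 1, 1, 1)
```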
9 changes: 9 additions & 0 deletions keras_cv/layers/__init__.py
@@ -19,6 +19,9 @@
from keras_cv.layers.augmenter import Augmenter
from keras_cv.layers.feature_pyramid import FeaturePyramid
from keras_cv.layers.fusedmbconv import FusedMBConvBlock
from keras_cv.layers.hierarchical_transformer_encoder import (
HierarchicalTransformerEncoder,
)
from keras_cv.layers.mbconv import MBConvBlock
from keras_cv.layers.object_detection.anchor_generator import AnchorGenerator
from keras_cv.layers.object_detection.box_matcher import BoxMatcher
@@ -32,6 +35,9 @@
CenterNetLabelEncoder,
)
from keras_cv.layers.object_detection_3d.voxelization import DynamicVoxelization
from keras_cv.layers.overlapping_patching_embedding import (
OverlappingPatchingAndEmbedding,
)
from keras_cv.layers.preprocessing.aug_mix import AugMix
from keras_cv.layers.preprocessing.auto_contrast import AutoContrast
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
@@ -124,6 +130,9 @@
from keras_cv.layers.regularization.dropblock_2d import DropBlock2D
from keras_cv.layers.regularization.squeeze_excite import SqueezeAndExcite2D
from keras_cv.layers.regularization.stochastic_depth import StochasticDepth
from keras_cv.layers.segformer_multihead_attention import (
SegFormerMultiheadAttention,
)
from keras_cv.layers.spatial_pyramid import SpatialPyramidPooling
from keras_cv.layers.transformer_encoder import TransformerEncoder
from keras_cv.layers.vit_layers import PatchingAndEmbedding
140 changes: 140 additions & 0 deletions keras_cv/layers/hierarchical_transformer_encoder.py
@@ -0,0 +1,140 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.layers.regularization.drop_path import DropPath
from keras_cv.layers.segformer_multihead_attention import (
SegFormerMultiheadAttention,
)


@keras_cv_export("keras_cv.layers.HierarchicalTransformerEncoder")
class HierarchicalTransformerEncoder(keras.layers.Layer):
"""
Hierarchical transformer encoder block implementation as a Keras Layer.
The layer uses `SegFormerMultiheadAttention` as a `MultiHeadAttention`
alternative for computational efficiency, and is meant to be used
within the SegFormer architecture.
References:
- [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) (CVPR 2021) # noqa: E501
- [Official PyTorch implementation](https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py) # noqa: E501
- [Ported from the TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/blob/main/deepvision/layers/hierarchical_transformer_encoder.py) # noqa: E501
Args:
project_dim: integer, the dimensionality of the projection of the
encoder, and output of the `SegFormerMultiheadAttention` layer.
Due to the residual addition the input dimensionality has to be
equal to the output dimensionality.
num_heads: integer, the number of heads for the
`SegFormerMultiheadAttention` layer.
drop_prob: float, the probability of dropping a random
sample using the `DropPath` layer. Defaults to `0.0`.
layer_norm_epsilon: float, the epsilon for
`LayerNormalization` layers. Defaults to `1e-06`.
sr_ratio: integer, the ratio to use within
`SegFormerMultiheadAttention`. If set to > 1, a `Conv2D`
layer is used to reduce the length of the sequence. Defaults to `1`.
Basic usage:
```
project_dim = 1024
num_heads = 4
patch_size = 16
encoded_patches = keras_cv.layers.OverlappingPatchingAndEmbedding(
project_dim=project_dim, patch_size=patch_size)(img_batch)
trans_encoded = keras_cv.layers.HierarchicalTransformerEncoder(project_dim=project_dim,
num_heads=num_heads,
sr_ratio=1)(encoded_patches)
print(trans_encoded.shape) # (1, 3136, 1024)
```
"""

def __init__(
self,
project_dim,
num_heads,
sr_ratio=1,
drop_prob=0.0,
layer_norm_epsilon=1e-6,
**kwargs,
):
super().__init__(**kwargs)
self.project_dim = project_dim
self.num_heads = num_heads
self.drop_prop = drop_prob

self.norm1 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon)
self.attn = SegFormerMultiheadAttention(
project_dim, num_heads, sr_ratio
)
self.drop_path = DropPath(drop_prob)
self.norm2 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon)
self.mlp = self.MixFFN(
channels=project_dim,
mid_channels=int(project_dim * 4),
)

def build(self, input_shape):
super().build(input_shape)
self.H = ops.sqrt(ops.cast(input_shape[1], "float32"))
self.W = ops.sqrt(ops.cast(input_shape[2], "float32"))

def call(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x

def get_config(self):
config = super().get_config()
config.update(
{
"mlp": keras.saving.serialize_keras_object(self.mlp),
"project_dim": self.project_dim,
"num_heads": self.num_heads,
"drop_prop": self.drop_prop,
}
)
return config

class MixFFN(keras.layers.Layer):
def __init__(self, channels, mid_channels):
super().__init__()
self.fc1 = keras.layers.Dense(mid_channels)
self.dwconv = keras.layers.DepthwiseConv2D(
kernel_size=3,
strides=1,
padding="same",
)
self.fc2 = keras.layers.Dense(channels)

def call(self, x):
x = self.fc1(x)
shape = ops.shape(x)
H, W = int(math.sqrt(shape[1])), int(math.sqrt(shape[1]))
B, C = shape[0], shape[2]
x = ops.reshape(x, (B, H, W, C))
x = self.dwconv(x)
x = ops.reshape(x, (B, -1, C))
x = ops.nn.gelu(x)
x = self.fc2(x)
return x
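Together with `OverlappingPatchingAndEmbedding` (added below), the encoder can be exercised end to end. A hedged usage sketch, assuming a keras_cv build that already contains the layers from this commit; the input resolution, projection width, and head count are arbitrary example values:

```python
import numpy as np
import keras_cv

# A single 224x224 RGB image; the values are placeholders.
images = np.random.uniform(size=(1, 224, 224, 3)).astype("float32")

# Stride-4 overlapping patching: ceil(224 / 4) = 56, so 56 * 56 = 3136 tokens.
tokens = keras_cv.layers.OverlappingPatchingAndEmbedding(
    project_dim=64, patch_size=7, stride=4
)(images)

# One hierarchical encoder block; sr_ratio > 1 would additionally shrink the
# key/value sequence inside SegFormerMultiheadAttention.
encoded = keras_cv.layers.HierarchicalTransformerEncoder(
    project_dim=64, num_heads=4, sr_ratio=1
)(tokens)

print(tokens.shape)   # (1, 3136, 64)
print(encoded.shape)  # (1, 3136, 64)
```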
85 changes: 85 additions & 0 deletions keras_cv/layers/overlapping_patching_embedding.py
@@ -0,0 +1,85 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.backend import ops


@keras_cv_export("keras_cv.layers.OverlappingPatchingAndEmbedding")
class OverlappingPatchingAndEmbedding(keras.layers.Layer):
def __init__(self, project_dim=32, patch_size=7, stride=4, **kwargs):
"""
Overlapping Patching and Embedding layer. Differs from `PatchingAndEmbedding`
in that the patch size does not affect the sequence length. It's fully derived
from the `stride` parameter. Additionally, no positional embedding is done
as part of the layer - only a projection using a `Conv2D` layer.
References:
- [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) (CVPR 2021) # noqa: E501
- [Official PyTorch implementation](https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py) # noqa: E501
- [Ported from the TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/blob/main/deepvision/layers/hierarchical_transformer_encoder.py) # noqa: E501
Args:
project_dim: integer, the dimensionality of the projection.
Defaults to `32`.
patch_size: integer, the size of the patches to encode.
Defaults to `7`.
stride: integer, the stride to use for the patching before
projection. Defaults to `4`.
Basic usage:
```
project_dim = 1024
patch_size = 16
encoded_patches = keras_cv.layers.OverlappingPatchingAndEmbedding(
project_dim=project_dim, patch_size=patch_size)(img_batch)
print(encoded_patches.shape) # (1, 3136, 1024)
```
"""
super().__init__(**kwargs)

self.project_dim = project_dim
self.patch_size = patch_size
self.stride = stride

self.proj = keras.layers.Conv2D(
filters=project_dim,
kernel_size=patch_size,
strides=stride,
padding="same",
)
self.norm = keras.layers.LayerNormalization()

def call(self, x):
x = self.proj(x)
# B, H, W, C
shape = x.shape
x = ops.reshape(x, (-1, shape[1] * shape[2], shape[3]))
x = self.norm(x)
return x

def get_config(self):
config = super().get_config()
config.update(
{
"project_dim": self.project_dim,
"patch_size": self.patch_size,
"stride": self.stride,
}
)
return config
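As the docstring notes, the token count is governed by `stride` alone: with `padding="same"` the projected grid is `ceil(H / stride) * ceil(W / stride)` regardless of `patch_size`. A small sketch of that arithmetic (input sizes are arbitrary examples):

```python
import math

def overlapping_patch_tokens(height, width, stride):
    # With padding="same", Conv2D produces ceil(size / stride) positions per
    # axis, so patch_size never enters the sequence length.
    return math.ceil(height / stride) * math.ceil(width / stride)

print(overlapping_patch_tokens(224, 224, stride=4))  # 3136  (56 * 56)
print(overlapping_patch_tokens(224, 224, stride=2))  # 12544 (112 * 112)
```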
20 changes: 11 additions & 9 deletions keras_cv/layers/regularization/drop_path.py
@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from tensorflow import keras

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.backend import random


@keras_cv_export("keras_cv.layers.DropPath")
class DropPath(keras.__internal__.layers.BaseRandomLayer):
class DropPath(keras.layers.Layer):
"""
Implements the DropPath layer. DropPath randomly drops samples during
training with a probability of `rate`. Note that this layer drops individual
@@ -47,20 +48,21 @@ class DropPath(keras.__internal__.layers.BaseRandomLayer):
""" # noqa: E501

def __init__(self, rate=0.5, seed=None, **kwargs):
super().__init__(seed=seed, **kwargs)
super().__init__(**kwargs)
self.rate = rate
self.seed = seed

def call(self, x, training=None):
if self.rate == 0.0 or not training:
return x
else:
keep_prob = 1 - self.rate
drop_map_shape = (x.shape[0],) + (1,) * (len(x.shape) - 1)
drop_map = keras.backend.random_bernoulli(
drop_map_shape, p=keep_prob, seed=self.seed
batch_size = x.shape[0] or ops.shape(x)[0]
drop_map_shape = (batch_size,) + (1,) * (len(x.shape) - 1)
drop_map = ops.cast(
random.uniform(drop_map_shape, seed=self.seed) > self.rate,
x.dtype,
)
x = x / keep_prob
x = x / (1.0 - self.rate)
x = x * drop_map
return x

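A plain-numpy sketch of what the rewritten `call()` computes in training mode, useful for sanity-checking the drop-map logic independently of any backend (batch size and rate are arbitrary example values):

```python
import numpy as np

rng = np.random.default_rng(42)
rate = 0.5
x = rng.uniform(size=(8, 4, 4, 16)).astype("float32")

# One uniform draw per sample, broadcast over all remaining axes.
drop_map = (rng.uniform(size=(8, 1, 1, 1)) > rate).astype("float32")

# Survivors are rescaled by 1 / (1 - rate) to keep the expected activation
# magnitude unchanged; dropped samples become all zeros.
out = x / (1.0 - rate) * drop_map

print(drop_map.reshape(-1))                       # which of the 8 samples survived
print(int((out.sum(axis=(1, 2, 3)) == 0).sum()))  # count of fully-dropped samples
```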
18 changes: 11 additions & 7 deletions keras_cv/layers/regularization/drop_path_test.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest
import tensorflow as tf

from keras_cv.layers import DropPath
@@ -23,23 +25,23 @@ class DropPathTest(TestCase):

def test_input_unchanged_in_eval_mode(self):
layer = DropPath(rate=0.5, seed=42)
inputs = tf.random.uniform(self.FEATURE_SHAPE)
inputs = np.random.uniform(size=self.FEATURE_SHAPE)

outputs = layer(inputs, training=False)

self.assertAllClose(inputs, outputs)

def test_input_unchanged_with_rate_equal_to_zero(self):
layer = DropPath(rate=0, seed=42)
inputs = tf.random.uniform(self.FEATURE_SHAPE)
inputs = np.random.uniform(size=self.FEATURE_SHAPE)

outputs = layer(inputs, training=True)

self.assertAllClose(inputs, outputs)

def test_input_gets_partially_zeroed_out_in_train_mode(self):
layer = DropPath(rate=0.2, seed=42)
inputs = tf.random.uniform(self.FEATURE_SHAPE)
inputs = np.random.uniform(size=self.FEATURE_SHAPE)

outputs = layer(inputs, training=True)

@@ -48,9 +50,11 @@ def test_input_gets_partially_zeroed_out_in_train_mode(self):

self.assertGreaterEqual(non_zeros_inputs, non_zeros_outputs)

# Because randomness is inconsistent across backends, we just test with 1.
@pytest.mark.tf_keras_only
def test_strict_input_gets_partially_zeroed_out_in_train_mode(self):
layer = DropPath(rate=0.5, seed=42)
inputs = tf.random.uniform(self.FEATURE_SHAPE)
layer = DropPath(rate=0.5, seed=10)
inputs = np.random.uniform(size=self.FEATURE_SHAPE)

total_non_zero_inputs = 0
total_non_zero_outputs = 0
@@ -66,6 +70,6 @@ def test_strict_input_gets_partially_zeroed_out_in_train_mode(self):

self.assertAllInRange(
total_non_zero_outputs,
int(0.49 * tf.cast(total_non_zero_inputs, tf.float32)),
int(0.51 * tf.cast(total_non_zero_inputs, tf.float32)),
int(0.40 * tf.cast(total_non_zero_inputs, tf.float32)),
int(0.60 * tf.cast(total_non_zero_inputs, tf.float32)),
)
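The band was widened from 0.49-0.51 to 0.40-0.60 because which samples survive now depends on the backend's RNG. A rough numpy-only sketch of how much the surviving fraction can wander at `rate=0.5` (the batch size here is a made-up value, not the one in `FEATURE_SHAPE`):

```python
import numpy as np

rng = np.random.default_rng(0)
batch, rate, trials = 16, 0.5, 1000

# Fraction of samples surviving DropPath in each simulated batch.
fractions = (rng.uniform(size=(trials, batch)) > rate).mean(axis=1)
print(fractions.min(), fractions.max())  # can stray well outside 0.49-0.51
```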