[DeepVision Port] SegFormer and Mix-Transformers #1946

Merged
merged 57 commits on Aug 24, 2023
Changes from 55 commits
Commits
dc41892
initial dump
DavidLandup0 Jul 13, 2023
e5677e6
add all basic layers, port roughly to keras core ops
DavidLandup0 Jul 17, 2023
7bd1056
updated .gitignore
DavidLandup0 Jul 17, 2023
03470df
segformer head and formatting
DavidLandup0 Jul 17, 2023
cb1c702
cleanup
DavidLandup0 Jul 17, 2023
22f8fdf
remove tf call
DavidLandup0 Jul 17, 2023
5c9803a
remove tf
DavidLandup0 Jul 17, 2023
314dc6b
migrating to more keras ops
DavidLandup0 Jul 17, 2023
7a0151b
cleanups and fixes
DavidLandup0 Jul 23, 2023
44f01af
fix reshaping
DavidLandup0 Jul 23, 2023
eb5b5ae
comments
DavidLandup0 Jul 23, 2023
ea0239f
from presets api, keras.ops -> ops
DavidLandup0 Jul 23, 2023
b6128a5
embed_dims -> embedding_dims
DavidLandup0 Jul 23, 2023
8322109
addressing some PR comments
DavidLandup0 Jul 24, 2023
75bb4a2
docstrings, argument update
DavidLandup0 Jul 24, 2023
97daf7c
depths arg
DavidLandup0 Jul 24, 2023
5f9dc0c
sync
DavidLandup0 Jul 24, 2023
efbbd49
compute output shapes
DavidLandup0 Jul 26, 2023
d3b43c6
segformer progress
DavidLandup0 Jul 26, 2023
dab4e74
head
DavidLandup0 Jul 27, 2023
1dba059
softmax
DavidLandup0 Jul 27, 2023
bdc3687
remove softmax
DavidLandup0 Jul 28, 2023
ddfa315
undo compute_output_shapes()
DavidLandup0 Jul 28, 2023
5a091b6
efficientmultiheadattention -> segformermultiheadattention
DavidLandup0 Jul 30, 2023
4e9df16
docstrings
DavidLandup0 Jul 30, 2023
278875c
softmax output
DavidLandup0 Jul 30, 2023
884c376
Merge branch 'master' into segformer_tf
DavidLandup0 Jul 30, 2023
6618a65
segformer presets
DavidLandup0 Aug 1, 2023
e1fbdb0
Merge branch 'segformer_tf' of https://github.com/DavidLandup0/keras-…
DavidLandup0 Aug 1, 2023
00ecd92
updating segformer presets
DavidLandup0 Aug 1, 2023
97d9d4a
segformer presets
DavidLandup0 Aug 18, 2023
c10963f
import aliases
DavidLandup0 Aug 18, 2023
f882b3e
Merge branch 'master' into segformer_tf
DavidLandup0 Aug 18, 2023
ab10136
refactoring
DavidLandup0 Aug 18, 2023
094189e
pr comments
DavidLandup0 Aug 18, 2023
a4df0a6
pr comments
DavidLandup0 Aug 18, 2023
e22a15e
add aliases
DavidLandup0 Aug 18, 2023
5d63d18
aliases to init
DavidLandup0 Aug 18, 2023
03a177f
refactor fix
DavidLandup0 Aug 18, 2023
d1cdd5d
import keras_cv_export
DavidLandup0 Aug 18, 2023
ff32d63
fix presets/aliases and add copyright
DavidLandup0 Aug 19, 2023
5f3fc22
linter warnings
DavidLandup0 Aug 19, 2023
c6b454f
linter errors
DavidLandup0 Aug 19, 2023
5ac7f77
consistency in presets
DavidLandup0 Aug 19, 2023
b2a76ce
return config
DavidLandup0 Aug 19, 2023
0ad5879
fix serialization
DavidLandup0 Aug 19, 2023
eea5e3c
Some cleanup + more tests
ianstenbit Aug 21, 2023
8e62cf6
Fix DropPath layer (need to update tests + add shim for tf.keras
ianstenbit Aug 21, 2023
b9efeb1
Finish DropPath layer
ianstenbit Aug 21, 2023
bd5a99f
Use static shape in backbone
ianstenbit Aug 21, 2023
3d29b0a
Formatting
ianstenbit Aug 21, 2023
4e2c4e8
Switch back to ops.shape
ianstenbit Aug 21, 2023
b32e0cf
documentation
DavidLandup0 Aug 23, 2023
743a3bb
documentation
DavidLandup0 Aug 23, 2023
c640fc9
remove default num classes
DavidLandup0 Aug 23, 2023
f1b5ffa
fix docs
DavidLandup0 Aug 23, 2023
e32704b
Merge branch 'master' into segformer_tf
ianstenbit Aug 24, 2023
1 change: 1 addition & 0 deletions .gitignore
@@ -16,3 +16,4 @@ __pycache__/
.vscode/
.devcontainer/
.coverage
.history
1 change: 1 addition & 0 deletions keras_cv/backend/__init__.py
@@ -76,6 +76,7 @@

from keras_cv.backend import config # noqa: E402
from keras_cv.backend import ops # noqa: E402
from keras_cv.backend import random # noqa: E402
from keras_cv.backend import tf_ops # noqa: E402


20 changes: 20 additions & 0 deletions keras_cv/backend/random.py
@@ -0,0 +1,20 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_cv.backend.config import multi_backend

if multi_backend():
from keras_core.random import * # noqa: F403, F401
else:
from keras_core.src.backend.tensorflow.random import * # noqa: F403, F401
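A minimal usage sketch of this shim (assuming a standard KerasCV install): layer code can draw random tensors without knowing which backend is active.

```python
from keras_cv.backend import random

# Draw uniform noise and threshold it into a per-sample keep mask,
# the same pattern the rewritten DropPath layer uses later in this diff.
mask = random.uniform((8, 1, 1, 1), seed=42) > 0.5
```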
9 changes: 9 additions & 0 deletions keras_cv/layers/__init__.py
@@ -19,6 +19,9 @@
from keras_cv.layers.augmenter import Augmenter
from keras_cv.layers.feature_pyramid import FeaturePyramid
from keras_cv.layers.fusedmbconv import FusedMBConvBlock
from keras_cv.layers.hierarchical_transformer_encoder import (
HierarchicalTransformerEncoder,
)
from keras_cv.layers.mbconv import MBConvBlock
from keras_cv.layers.object_detection.anchor_generator import AnchorGenerator
from keras_cv.layers.object_detection.box_matcher import BoxMatcher
@@ -32,6 +35,9 @@
CenterNetLabelEncoder,
)
from keras_cv.layers.object_detection_3d.voxelization import DynamicVoxelization
from keras_cv.layers.overlapping_patching_embedding import (
OverlappingPatchingAndEmbedding,
)
from keras_cv.layers.preprocessing.aug_mix import AugMix
from keras_cv.layers.preprocessing.auto_contrast import AutoContrast
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
@@ -124,6 +130,9 @@
from keras_cv.layers.regularization.dropblock_2d import DropBlock2D
from keras_cv.layers.regularization.squeeze_excite import SqueezeAndExcite2D
from keras_cv.layers.regularization.stochastic_depth import StochasticDepth
from keras_cv.layers.segformer_multihead_attention import (
SegFormerMultiheadAttention,
)
from keras_cv.layers.spatial_pyramid import SpatialPyramidPooling
from keras_cv.layers.transformer_encoder import TransformerEncoder
from keras_cv.layers.vit_layers import PatchingAndEmbedding
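With these imports in place, the new SegFormer building blocks are reachable from the public namespace. A hedged sketch (keyword argument names are inferred from the constructors in this diff):

```python
import keras_cv

# The three layers this PR adds to the top-level namespace.
patching = keras_cv.layers.OverlappingPatchingAndEmbedding(
    project_dim=32, patch_size=7, stride=4
)
encoder = keras_cv.layers.HierarchicalTransformerEncoder(
    project_dim=32, num_heads=2, sr_ratio=8
)
attention = keras_cv.layers.SegFormerMultiheadAttention(
    project_dim=32, num_heads=2, sr_ratio=8
)
```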
140 changes: 140 additions & 0 deletions keras_cv/layers/hierarchical_transformer_encoder.py
@@ -0,0 +1,140 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.layers.regularization.drop_path import DropPath
from keras_cv.layers.segformer_multihead_attention import (
SegFormerMultiheadAttention,
)


@keras_cv_export("keras_cv.layers.HierarchicalTransformerEncoder")
class HierarchicalTransformerEncoder(keras.layers.Layer):
"""
Hierarchical transformer encoder block implementation as a Keras Layer.
The layer uses `SegFormerMultiheadAttention` as a `MultiHeadAttention`
alternative for computational efficiency, and is meant to be used
within the SegFormer architecture.

References:
- [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) (CVPR 2021) # noqa: E501
- [Official PyTorch implementation](https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py) # noqa: E501
- [Ported from the TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/blob/main/deepvision/layers/hierarchical_transformer_encoder.py) # noqa: E501

Args:
project_dim: integer, the dimensionality of the projection of the
encoder, and output of the `SegFormerMultiheadAttention` layer.
Due to the residual addition, the input dimensionality has to be
equal to the output dimensionality.
num_heads: integer, the number of heads for the
`SegFormerMultiheadAttention` layer.
drop_prob: float, the probability of dropping a random
sample using the `DropPath` layer. Defaults to `0.0`.
layer_norm_epsilon: float, the epsilon for
`LayerNormalization` layers. Defaults to `1e-06`.
sr_ratio: integer, the ratio to use within
`SegFormerMultiheadAttention`. If set to > 1, a `Conv2D`
layer is used to reduce the length of the sequence. Defaults to `1`.

Basic usage:

```
project_dim = 1024
num_heads = 4
patch_size = 16

encoded_patches = keras_cv.layers.OverlappingPatchingAndEmbedding(
project_dim=project_dim, patch_size=patch_size)(img_batch)

trans_encoded = keras_cv.layers.HierarchicalTransformerEncoder(project_dim=project_dim,
num_heads=num_heads,
sr_ratio=1)(encoded_patches)

print(trans_encoded.shape) # (1, 3136, 1024)
```
"""

def __init__(
self,
project_dim,
num_heads,
sr_ratio=1,
drop_prob=0.0,
layer_norm_epsilon=1e-6,
**kwargs,
):
super().__init__(**kwargs)
self.project_dim = project_dim
self.num_heads = num_heads
self.drop_prop = drop_prob

self.norm1 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon)
self.attn = SegFormerMultiheadAttention(
project_dim, num_heads, sr_ratio
)
self.drop_path = DropPath(drop_prob)
self.norm2 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon)
self.mlp = self.MixFFN(
channels=project_dim,
mid_channels=int(project_dim * 4),
)

def build(self, input_shape):
super().build(input_shape)
self.H = ops.sqrt(ops.cast(input_shape[1], "float32"))
self.W = ops.sqrt(ops.cast(input_shape[2], "float32"))

def call(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x

def get_config(self):
config = super().get_config()
config.update(
{
"mlp": keras.saving.serialize_keras_object(self.mlp),
"project_dim": self.project_dim,
"num_heads": self.num_heads,
"drop_prop": self.drop_prop,
}
)
return config

class MixFFN(keras.layers.Layer):
def __init__(self, channels, mid_channels):
super().__init__()
self.fc1 = keras.layers.Dense(mid_channels)
self.dwconv = keras.layers.DepthwiseConv2D(
kernel_size=3,
strides=1,
padding="same",
)
self.fc2 = keras.layers.Dense(channels)

def call(self, x):
x = self.fc1(x)
shape = ops.shape(x)
H, W = int(math.sqrt(shape[1])), int(math.sqrt(shape[1]))
B, C = shape[0], shape[2]
x = ops.reshape(x, (B, H, W, C))
x = self.dwconv(x)
x = ops.reshape(x, (B, -1, C))
x = ops.nn.gelu(x)
x = self.fc2(x)
return x
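For reference, a small standalone sketch of the sequence-to-grid round trip that `MixFFN.call` performs; note that it assumes the token count is a perfect square (i.e., square feature maps), with shapes taken from the docstring example above:

```python
import math

import numpy as np

B, N, C = 1, 3136, 1024             # batch, tokens, channels
H = W = int(math.sqrt(N))           # 56 x 56 grid; assumes N is a perfect square

tokens = np.zeros((B, N, C), dtype="float32")
grid = tokens.reshape(B, H, W, C)   # spatial layout for DepthwiseConv2D
flat = grid.reshape(B, -1, C)       # back to a token sequence for the Dense layers
assert flat.shape == (B, N, C)
```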
85 changes: 85 additions & 0 deletions keras_cv/layers/overlapping_patching_embedding.py
@@ -0,0 +1,85 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.backend import ops


@keras_cv_export("keras_cv.layers.OverlappingPatchingAndEmbedding")
class OverlappingPatchingAndEmbedding(keras.layers.Layer):
def __init__(self, project_dim=32, patch_size=7, stride=4, **kwargs):
"""
Overlapping Patching and Embedding layer. Differs from `PatchingAndEmbedding`
in that the patch size does not affect the sequence length. It's fully derived
from the `stride` parameter. Additionally, no positional embedding is done
as part of the layer - only a projection using a `Conv2D` layer.

References:
- [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) (CVPR 2021) # noqa: E501
- [Official PyTorch implementation](https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py) # noqa: E501
- [Ported from the TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/blob/main/deepvision/layers/hierarchical_transformer_encoder.py) # noqa: E501

Args:
project_dim: integer, the dimensionality of the projection.
Defaults to `32`.
patch_size: integer, the size of the patches to encode.
Defaults to `7`.
stride: integer, the stride to use for the patching before
projection. Defaults to `4`.

Basic usage:

```
project_dim = 1024
patch_size = 16

encoded_patches = keras_cv.layers.OverlappingPatchingAndEmbedding(
project_dim=project_dim, patch_size=patch_size)(img_batch)

print(encoded_patches.shape) # (1, 3136, 1024)
```
"""
super().__init__(**kwargs)

self.project_dim = project_dim
self.patch_size = patch_size
self.stride = stride

self.proj = keras.layers.Conv2D(
filters=project_dim,
kernel_size=patch_size,
strides=stride,
padding="same",
)
self.norm = keras.layers.LayerNormalization()

def call(self, x):
x = self.proj(x)
# B, H, W, C
shape = x.shape
x = ops.reshape(x, (-1, shape[1] * shape[2], shape[3]))
x = self.norm(x)
return x

def get_config(self):
config = super().get_config()
config.update(
{
"project_dim": self.project_dim,
"patch_size": self.patch_size,
"stride": self.stride,
}
)
return config
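Because the projection convolution uses `padding="same"`, the token count depends on `stride` alone, not `patch_size`. A quick check of the arithmetic behind the `(1, 3136, 1024)` shape in the docstring, assuming a 224x224 input:

```python
import math

H = W = 224                              # input resolution assumed for the example
stride = 4
tokens_per_side = math.ceil(H / stride)  # "same" padding rounds up
print(tokens_per_side ** 2)              # 3136 == 56 * 56
```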
20 changes: 11 additions & 9 deletions keras_cv/layers/regularization/drop_path.py
@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from tensorflow import keras

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.backend import random


@keras_cv_export("keras_cv.layers.DropPath")
class DropPath(keras.__internal__.layers.BaseRandomLayer):
class DropPath(keras.layers.Layer):
"""
Implements the DropPath layer. DropPath randomly drops samples during
training with a probability of `rate`. Note that this layer drops individual
@@ -47,20 +48,21 @@ class DropPath(keras.__internal__.layers.BaseRandomLayer):
""" # noqa: E501

def __init__(self, rate=0.5, seed=None, **kwargs):
super().__init__(seed=seed, **kwargs)
super().__init__(**kwargs)
self.rate = rate
self.seed = seed

def call(self, x, training=None):
if self.rate == 0.0 or not training:
return x
else:
keep_prob = 1 - self.rate
drop_map_shape = (x.shape[0],) + (1,) * (len(x.shape) - 1)
drop_map = keras.backend.random_bernoulli(
drop_map_shape, p=keep_prob, seed=self.seed
batch_size = x.shape[0] or ops.shape(x)[0]
drop_map_shape = (batch_size,) + (1,) * (len(x.shape) - 1)
drop_map = ops.cast(
random.uniform(drop_map_shape, seed=self.seed) > self.rate,
x.dtype,
)
x = x / keep_prob
x = x / (1.0 - self.rate)
x = x * drop_map
return x

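A quick sanity-check sketch of the rewritten layer (values are illustrative; which samples survive depends on the random mask): each sample in the batch is either zeroed out entirely or rescaled by `1 / (1 - rate)`, so the expected value of the output matches the input.

```python
import numpy as np

from keras_cv.layers import DropPath

layer = DropPath(rate=0.5, seed=42)
x = np.ones((4, 8, 8, 16), dtype="float32")

y = np.array(layer(x, training=True))
# Each sample is all zeros (dropped) or all 2.0 (kept, scaled by 1 / 0.5).
print(np.unique(y))  # expected to be a subset of {0.0, 2.0}
```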
18 changes: 11 additions & 7 deletions keras_cv/layers/regularization/drop_path_test.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest
import tensorflow as tf

from keras_cv.layers import DropPath
@@ -23,23 +25,23 @@ class DropPathTest(TestCase):

def test_input_unchanged_in_eval_mode(self):
layer = DropPath(rate=0.5, seed=42)
inputs = tf.random.uniform(self.FEATURE_SHAPE)
inputs = np.random.uniform(size=self.FEATURE_SHAPE)

outputs = layer(inputs, training=False)

self.assertAllClose(inputs, outputs)

def test_input_unchanged_with_rate_equal_to_zero(self):
layer = DropPath(rate=0, seed=42)
inputs = tf.random.uniform(self.FEATURE_SHAPE)
inputs = np.random.uniform(size=self.FEATURE_SHAPE)

outputs = layer(inputs, training=True)

self.assertAllClose(inputs, outputs)

def test_input_gets_partially_zeroed_out_in_train_mode(self):
layer = DropPath(rate=0.2, seed=42)
inputs = tf.random.uniform(self.FEATURE_SHAPE)
inputs = np.random.uniform(size=self.FEATURE_SHAPE)

outputs = layer(inputs, training=True)

@@ -48,9 +50,11 @@ def test_input_gets_partially_zeroed_out_in_train_mode(self):

self.assertGreaterEqual(non_zeros_inputs, non_zeros_outputs)

# Because randomness is inconsistent across backends, we just test with tf.keras.
@pytest.mark.tf_keras_only
def test_strict_input_gets_partially_zeroed_out_in_train_mode(self):
layer = DropPath(rate=0.5, seed=42)
inputs = tf.random.uniform(self.FEATURE_SHAPE)
layer = DropPath(rate=0.5, seed=10)
inputs = np.random.uniform(size=self.FEATURE_SHAPE)

total_non_zero_inputs = 0
total_non_zero_outputs = 0
@@ -66,6 +70,6 @@

self.assertAllInRange(
total_non_zero_outputs,
int(0.49 * tf.cast(total_non_zero_inputs, tf.float32)),
int(0.51 * tf.cast(total_non_zero_inputs, tf.float32)),
int(0.40 * tf.cast(total_non_zero_inputs, tf.float32)),
int(0.60 * tf.cast(total_non_zero_inputs, tf.float32)),
)