Skip to content

Commit

Permalink
SAMLayerNormalization -> keras.layers.LayerNormalization
Browse files Browse the repository at this point in the history
They behave exactly the same: keras.layers.LayerNormalization keeps no moving mean or moving variance, so with epsilon set to 1e-6 it matches SAMLayerNormalization.
  • Loading branch information
tirthasheshpatel committed Jul 30, 2023
1 parent 4a10a31 commit 59a3ccf
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@
from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.models.segmentation.segment_anything.sam_layers import MLPBlock
from keras_cv.models.segmentation.segment_anything.sam_layers import (
SAMLayerNormalization,
)


def get_rel_pos(query_size, key_size, rel_pos):
Expand Down Expand Up @@ -522,14 +519,14 @@ def __init__(
keras.layers.Conv2D(
filters=out_chans, kernel_size=1, use_bias=False
),
SAMLayerNormalization(),
keras.layers.LayerNormalization(epsilon=1e-6),
keras.layers.Conv2D(
filters=out_chans,
kernel_size=3,
padding="same",
use_bias=False,
),
SAMLayerNormalization(),
keras.layers.LayerNormalization(epsilon=1e-6),
]
)

Expand Down
47 changes: 0 additions & 47 deletions keras_cv/models/segmentation/segment_anything/sam_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

from keras_cv.backend import keras
from keras_cv.backend import ops


@keras.utils.register_keras_serializable(package="keras_cv")
Expand Down Expand Up @@ -58,49 +57,3 @@ def get_config(self):
"activation": self.activation,
}
)


@keras.utils.register_keras_serializable(package="keras_cv")
class SAMLayerNormalization(keras.layers.Layer):
    """A layer normalization layer without moving mean and variance.

    Normalizes the input over its last (channel) axis and applies a
    learnable per-channel scale and offset. Behaviorally equivalent to
    `keras.layers.LayerNormalization(epsilon=epsilon)`.

    Args:
        epsilon (float, optional): Small float added to variance to
            avoid dividing by zero. Defaults to 1e-6.
    """

    def __init__(self, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        # Learnable scale (gamma), one value per channel (last axis).
        self.weight = self.add_weight(
            name="weight",
            shape=(input_shape[-1],),
            initializer="ones",
            trainable=True,
        )
        # Learnable offset (beta). Fix: this variable was originally
        # created with name="weight", colliding with the scale variable
        # above, which breaks checkpoint saving/restoring.
        self.bias = self.add_weight(
            name="bias",
            shape=(input_shape[-1],),
            initializer="zeros",
            trainable=True,
        )

    def call(self, x):
        # Mean and (biased) variance over the channel axis.
        u = ops.mean(x, axis=-1, keepdims=True)
        s = ops.mean(ops.square(x - u), axis=-1, keepdims=True)
        # Normalize, then apply the learned affine transform.
        x = (x - u) / ops.sqrt(s + self.epsilon)
        return self.weight * x + self.bias

    def compute_output_shape(self, input_shape):
        # Normalization is elementwise; the shape is unchanged.
        return input_shape

    def get_config(self):
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@

from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.models.segmentation.segment_anything.sam_layers import (
SAMLayerNormalization,
)


@keras.utils.register_keras_serializable(package="keras_cv")
Expand Down Expand Up @@ -133,7 +130,7 @@ def __init__(
keras.layers.Conv2DTranspose(
transformer_dim // 4, kernel_size=2, strides=2
),
SAMLayerNormalization(),
keras.layers.LayerNormalization(epsilon=1e-6),
keras.layers.Activation(activation),
keras.layers.Conv2DTranspose(
transformer_dim // 8, kernel_size=2, strides=2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@

from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.models.segmentation.segment_anything.sam_layers import (
SAMLayerNormalization,
)


@keras.saving.register_keras_serializable(package="keras_cv")
Expand Down Expand Up @@ -183,10 +180,10 @@ def __init__(
keras.layers.Conv2D(
mask_in_chans // 4, kernel_size=2, strides=2
),
SAMLayerNormalization(),
keras.layers.LayerNormalization(epsilon=1e-6),
keras.layers.Activation(activation),
keras.layers.Conv2D(mask_in_chans, kernel_size=2, strides=2),
SAMLayerNormalization(),
keras.layers.LayerNormalization(epsilon=1e-6),
keras.layers.Activation(activation),
keras.layers.Conv2D(embed_dim, kernel_size=1),
]
Expand Down

0 comments on commit 59a3ccf

Please sign in to comment.