From 59a3ccf6c8d6de243bd4da89c6634f3a42335e70 Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Sun, 30 Jul 2023 00:17:16 +0000 Subject: [PATCH] SAMLayerNormalization -> keras.layers.LayerNormalization They both behave exactly the same when moving_mean and moving_variance are None and epsilon is 1e-6 --- .../segment_anything/sam_image_encoder.py | 7 +-- .../segment_anything/sam_layers.py | 47 ------------------- .../segment_anything/sam_mask_decoder.py | 5 +- .../segment_anything/sam_prompt_encoder.py | 7 +-- 4 files changed, 5 insertions(+), 61 deletions(-) diff --git a/keras_cv/models/segmentation/segment_anything/sam_image_encoder.py b/keras_cv/models/segmentation/segment_anything/sam_image_encoder.py index eaf8462cda..3801c75f62 100644 --- a/keras_cv/models/segmentation/segment_anything/sam_image_encoder.py +++ b/keras_cv/models/segmentation/segment_anything/sam_image_encoder.py @@ -15,9 +15,6 @@ from keras_cv.backend import keras from keras_cv.backend import ops from keras_cv.models.segmentation.segment_anything.sam_layers import MLPBlock -from keras_cv.models.segmentation.segment_anything.sam_layers import ( - SAMLayerNormalization, -) def get_rel_pos(query_size, key_size, rel_pos): @@ -522,14 +519,14 @@ def __init__( keras.layers.Conv2D( filters=out_chans, kernel_size=1, use_bias=False ), - SAMLayerNormalization(), + keras.layers.LayerNormalization(epsilon=1e-6), keras.layers.Conv2D( filters=out_chans, kernel_size=3, padding="same", use_bias=False, ), - SAMLayerNormalization(), + keras.layers.LayerNormalization(epsilon=1e-6), ] ) diff --git a/keras_cv/models/segmentation/segment_anything/sam_layers.py b/keras_cv/models/segmentation/segment_anything/sam_layers.py index 764aefaf01..4dcde4a9b3 100644 --- a/keras_cv/models/segmentation/segment_anything/sam_layers.py +++ b/keras_cv/models/segmentation/segment_anything/sam_layers.py @@ -13,7 +13,6 @@ # limitations under the License. from keras_cv.backend import keras -from keras_cv.backend import ops @keras.utils.register_keras_serializable(package="keras_cv") @@ -58,49 +57,3 @@ def get_config(self): "activation": self.activation, } ) - - -@keras.utils.register_keras_serializable(package="keras_cv") -class SAMLayerNormalization(keras.layers.Layer): - def __init__(self, epsilon=1e-6, **kwargs): - """A SAMLayerNormalization layer without moving mean and variance. - - Args: - epsilon (float, optional): Small float added to variance to - avoid dividing by zero. Defaults to 1e-6. - """ - super().__init__(**kwargs) - self.epsilon = epsilon - - def build(self, input_shape): - self.weight = self.add_weight( - name="weight", - shape=(input_shape[-1],), - initializer="ones", - trainable=True, - ) - self.bias = self.add_weight( - name="weight", - shape=(input_shape[-1],), - initializer="zeros", - trainable=True, - ) - - def call(self, x): - u = ops.mean(x, axis=-1, keepdims=True) - s = ops.mean(ops.square(x - u), axis=-1, keepdims=True) - x = (x - u) / ops.sqrt(s + self.epsilon) - x = self.weight * x + self.bias - return x - - def compute_output_shape(self, input_shape): - return input_shape - - def get_config(self): - config = super().get_config() - config.update( - { - "epsilon": self.epsilon, - } - ) - return config diff --git a/keras_cv/models/segmentation/segment_anything/sam_mask_decoder.py b/keras_cv/models/segmentation/segment_anything/sam_mask_decoder.py index 877039a9a8..486c6cd252 100644 --- a/keras_cv/models/segmentation/segment_anything/sam_mask_decoder.py +++ b/keras_cv/models/segmentation/segment_anything/sam_mask_decoder.py @@ -14,9 +14,6 @@ from keras_cv.backend import keras from keras_cv.backend import ops -from keras_cv.models.segmentation.segment_anything.sam_layers import ( - SAMLayerNormalization, -) @keras.utils.register_keras_serializable(package="keras_cv") @@ -133,7 +130,7 @@ def __init__( keras.layers.Conv2DTranspose( transformer_dim // 4, kernel_size=2, strides=2 ), - SAMLayerNormalization(), + keras.layers.LayerNormalization(epsilon=1e-6), keras.layers.Activation(activation), keras.layers.Conv2DTranspose( transformer_dim // 8, kernel_size=2, strides=2 diff --git a/keras_cv/models/segmentation/segment_anything/sam_prompt_encoder.py b/keras_cv/models/segmentation/segment_anything/sam_prompt_encoder.py index 1cfe94600c..5cb1759740 100644 --- a/keras_cv/models/segmentation/segment_anything/sam_prompt_encoder.py +++ b/keras_cv/models/segmentation/segment_anything/sam_prompt_encoder.py @@ -16,9 +16,6 @@ from keras_cv.backend import keras from keras_cv.backend import ops -from keras_cv.models.segmentation.segment_anything.sam_layers import ( - SAMLayerNormalization, -) @keras.saving.register_keras_serializable(package="keras_cv") @@ -183,10 +180,10 @@ def __init__( keras.layers.Conv2D( mask_in_chans // 4, kernel_size=2, strides=2 ), - SAMLayerNormalization(), + keras.layers.LayerNormalization(epsilon=1e-6), keras.layers.Activation(activation), keras.layers.Conv2D(mask_in_chans, kernel_size=2, strides=2), - SAMLayerNormalization(), + keras.layers.LayerNormalization(epsilon=1e-6), keras.layers.Activation(activation), keras.layers.Conv2D(embed_dim, kernel_size=1), ]