Commit c42b962
Update implementation to use clip attention (#2307)
divyashreepathihalli authored Jan 20, 2024
1 parent ac8864c commit c42b962
Showing 5 changed files with 343 additions and 128 deletions.
156 changes: 31 additions & 125 deletions keras_cv/models/feature_extractors/clip/clip_image_encoder.py
@@ -1,126 +1,25 @@
 # Copyright 2023 The KerasCV Authors
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     https://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

 from keras_cv.backend import keras
 from keras_cv.backend import ops


-class CLIPPatchingAndEmbedding(keras.layers.Layer):
-    def __init__(self, width, patch_size, input_resolution):
-        super().__init__()
-
-        self.conv1 = keras.layers.Conv2D(
-            filters=width,
-            kernel_size=patch_size,
-            strides=patch_size,
-            use_bias=False,
-        )
-        self.width = width
-        self.input_resolution = input_resolution
-        self.patch_size = patch_size
-
-    def build(self, input_shape):
-        self.conv1.build(input_shape)
-        self.class_embedding = self.add_weight(
-            shape=((self.width,)), name="patch_embed.class_embedding"
-        )
-
-        self.positional_embedding = self.add_weight(
-            shape=(
-                (
-                    (self.input_resolution // self.patch_size) ** 2 + 1,
-                    self.width,
-                )
-            ),
-            trainable=True,
-            name="patch_embed.positional_embedding",
-        )
-
-    def call(self, x):
-        x = self.conv1(x)  # shape = [*, grid, grid, width]
-        x = ops.transpose(
-            x, axes=[0, 3, 1, 2]
-        )  # shape = [*, width, grid, grid]
-        shape = ops.shape(x)
-        x = ops.reshape(
-            x, [shape[0], shape[1], shape[2] * shape[3]]
-        )  # shape = [*, width, grid ** 2]
-        x = ops.transpose(x, axes=(0, 2, 1))  # shape = [*, grid ** 2, width]
-
-        class_embedding = self.class_embedding
-
-        shape = ops.shape(x)
-        class_embedding_expanded = ops.expand_dims(class_embedding, axis=0)
-        class_embedding_expanded = ops.expand_dims(
-            class_embedding_expanded, axis=1
-        )
-        class_embedding_expanded = ops.tile(
-            class_embedding_expanded, (shape[0], 1, 1)
-        )
-        x = ops.concatenate(
-            [class_embedding_expanded, x], axis=1
-        )  # shape = [*, grid ** 2 + 1, width]
-        positional_embedding = self.positional_embedding
-        x = x + positional_embedding
-
-        return x
-
-
-class QuickGELU(keras.layers.Layer):
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
-    def call(self, x):
-        return x * ops.sigmoid(1.702 * x)
-
-
-class ResidualTransformerEncoder(keras.layers.Layer):
-    def __init__(self, width, layers, heads, attn_mask=None):
-        super().__init__()
-        self.width = width
-        self.layers = layers
-        self.resblocks = keras.Sequential(
-            [ResidualAttention(width, heads, attn_mask) for _ in range(layers)]
-        )
-
-    def call(self, x):
-        return self.resblocks(x)
-
-
-class ResidualAttention(keras.layers.Layer):
-    def __init__(self, d_model, n_head, attn_mask=None):
-        super().__init__()
-
-        self.attn = keras.layers.MultiHeadAttention(n_head, d_model)
-        self.ln_1 = keras.layers.LayerNormalization(epsilon=1e-5)
-        self.mlp = keras.Sequential(
-            [
-                keras.layers.Dense(d_model * 4, name="c_fc"),
-                QuickGELU(name="gelu"),
-                keras.layers.Dense(d_model, name="c_proj"),
-            ]
-        )
-        self.ln_2 = keras.layers.LayerNormalization(epsilon=1e-5)
-        self.attn_mask = attn_mask
-        self.q_proj = keras.layers.Dense(units=d_model, name="q_proj")
-        self.k_proj = keras.layers.Dense(units=d_model, name="k_proj")
-        self.v_proj = keras.layers.Dense(units=d_model, name="v_proj")
-
-    def attention(self, x):
-        self.attn_mask = (
-            ops.cast(self.attn_mask, dtype=x.dtype)
-            if self.attn_mask is not None
-            else None
-        )
-
-        key = self.k_proj(inputs=x)
-        value = self.v_proj(inputs=x)
-        query = self.q_proj(inputs=x)
-        return self.attn(
-            key=key, value=value, query=query, attention_mask=self.attn_mask
-        )
-
-    def call(self, x):
-        x = x + self.attention(self.ln_1(x))
-        x = x + self.mlp(self.ln_2(x))
-        return x
+from keras_cv.models.feature_extractors.clip.clip_modelling import (
+    CLIPPatchingAndEmbedding,
+)
+from keras_cv.models.feature_extractors.clip.clip_modelling import (
+    ResidualTransformerEncoder,
+)


class CLIPImageEncoder(keras.Model):
@@ -145,15 +44,22 @@ def __init__(
             patch_size=patch_size,
             input_resolution=input_resolution,
         )(x)
-        x = keras.layers.LayerNormalization(epsilon=1e-6)(x)
+        x = keras.layers.LayerNormalization(epsilon=1e-6, name="ln_1")(x)

         x = ops.transpose(x, axes=(1, 0, 2))
-        x = ResidualTransformerEncoder(width, layers, heads)(x)
+        x = ResidualTransformerEncoder(
+            width,
+            layers,
+            heads,
+            name="residual_transformer_encoder",
+        )(x)
         x = ops.transpose(x, axes=(1, 0, 2))

-        x = keras.layers.LayerNormalization(epsilon=1e-6)(x[:, 0, :])
+        x = keras.layers.LayerNormalization(epsilon=1e-6, name="ln_2")(
+            x[:, 0, :]
+        )

-        proj = keras.layers.Dense(output_dim)
+        proj = keras.layers.Dense(output_dim, name="vision_projector")
         x = proj(x)

         output = x
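A note on the layers this diff moves out of clip_image_encoder.py and into clip_modelling: CLIPPatchingAndEmbedding cuts the image into non-overlapping patches with a strided convolution, prepends a learned class token, and adds positional embeddings, while QuickGELU is the cheap GELU approximation x * sigmoid(1.702 * x). Below is a minimal NumPy sketch of the shape bookkeeping; the 224/16/768 sizes are the usual ViT-B/16 defaults, assumed here for illustration and not taken from this commit.

import numpy as np

# Illustrative ViT-B/16-style sizes (assumptions, not values from this diff).
input_resolution, patch_size, width = 224, 16, 768
batch = 2

grid = input_resolution // patch_size  # 14 patches per side
num_patches = grid * grid              # 196 patch tokens

# conv1 output: (batch, grid, grid, width), flattened to (batch, grid**2, width).
# This direct reshape yields the same token order as the transpose/reshape
# sequence in CLIPPatchingAndEmbedding.call above.
patches = np.zeros((batch, grid, grid, width))
tokens = patches.reshape(batch, num_patches, width)

# Prepend the class embedding; the positional embedding added afterwards has
# shape ((input_resolution // patch_size) ** 2 + 1, width), matching add_weight.
class_token = np.zeros((batch, 1, width))
sequence = np.concatenate([class_token, tokens], axis=1)
assert sequence.shape == (batch, num_patches + 1, width)  # (2, 197, 768)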
21 changes: 19 additions & 2 deletions keras_cv/models/feature_extractors/clip/clip_model.py
@@ -1,3 +1,16 @@
+# Copyright 2023 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from keras_cv.api_export import keras_cv_export
 from keras_cv.backend import keras
 from keras_cv.backend import ops
@@ -53,24 +66,28 @@ def __init__(
             layers=vision_layers,
             heads=vision_heads,
             output_dim=embed_dim,
+            name="clip_encoder",
         )

         self.transformer = ResidualTransformerEncoder(
             width=transformer_width,
             layers=transformer_layers,
             heads=transformer_heads,
             attn_mask=self.build_attention_mask(),
+            name="residual_transformer_encoder",
         )

         self.vocab_size = vocab_size
         self.token_embedding = keras.layers.Embedding(
-            vocab_size, transformer_width
+            vocab_size,
+            transformer_width,
+            name="token_embedding",
         )
         self.positional_embedding = self.add_weight(
             shape=[self.context_length, transformer_width],
             name="positional_embedding",
         )
-        self.ln_final = keras.layers.LayerNormalization()
+        self.ln_final = keras.layers.LayerNormalization(name="ln_final")

         self.text_projection = self.add_weight(
             shape=(transformer_width, embed_dim), name="text_projection"
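The text tower above is built with attn_mask=self.build_attention_mask(); that method's body falls outside the hunks shown here. For context, a CLIP-style causal mask is typically just a lower-triangular matrix over token positions. A sketch under that assumption (77 is CLIP's conventional context length), not the code from this commit:

import numpy as np

def build_attention_mask(context_length=77):
    # Lower-triangular causal mask: token i may attend to tokens 0..i.
    # keras.layers.MultiHeadAttention treats nonzero/True entries as "attend".
    return np.tril(np.ones((context_length, context_length), dtype="float32"))

mask = build_attention_mask(4)
# [[1. 0. 0. 0.]
#  [1. 1. 0. 0.]
#  [1. 1. 1. 0.]
#  [1. 1. 1. 1.]]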
(Diffs for the remaining three changed files are not shown.)
