update build
Divyashree Sreepathihalli committed Feb 14, 2024
1 parent 95d9e10 commit d4c7e16
Showing 9 changed files with 341 additions and 304 deletions.
75 changes: 42 additions & 33 deletions keras_cv/models/feature_extractor/clip/clip_encoder.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

from keras_cv.backend import keras
from keras_cv.backend import ops

@@ -20,7 +22,8 @@ def get_initializer(initializer_range=0.02):
Creates a `keras.initializers.TruncatedNormal` with the given range.
Args:
initializer_range (*float*, defaults to 0.02): Standard deviation of the initializer range.
initializer_range (*float*, defaults to 0.02): Standard deviation of the
initializer range.
Returns:
`keras.initializers.TruncatedNormal`: The truncated normal initializer.
@@ -48,13 +51,34 @@ def __init__(
self.proj_dim = proj_dim
self.num_heads = num_heads
self.num_hidden_layers = num_hidden_layers
self.fc_std = ops.power(2 * self.proj_dim, -0.5) * 0.02
self.fc_std = np.power(2 * self.proj_dim, -0.5) * 0.02

self.in_proj_std = (
ops.power(self.proj_dim, -0.5)
* (ops.power(2 * self.num_hidden_layers, -0.5))
np.power(self.proj_dim, -0.5)
* (np.power(2 * self.num_hidden_layers, -0.5))
* 0.02
)
self.attn = CLIPAttention(
self.proj_dim,
self.num_heads,
self.num_hidden_layers,
name="multi_head_attention",
)
self.ln_1 = keras.layers.LayerNormalization(epsilon=1e-5, name="ln_1")
self.mlp = keras.Sequential(
[
keras.layers.Dense(
self.proj_dim * 4,
name="c_fc",
),
QuickGELU(name="gelu"),
keras.layers.Dense(
self.proj_dim,
name="c_proj",
),
]
)
self.ln_2 = keras.layers.LayerNormalization(epsilon=1e-5, name="ln_2")

def attention(self, x, causal_attention_mask=None, attention_mask=None):
mask = None
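
In the hunk above, the initializer scales are now computed with np.power rather than ops.power: they are plain Python scalars consumed at construction time by keras.initializers.TruncatedNormal, so they do not need to go through the backend ops. A minimal sketch of the same computation, using assumed example sizes rather than values from this commit:

import numpy as np
from keras_cv.backend import keras

proj_dim = 512          # assumed example width
num_hidden_layers = 12  # assumed example depth

# Scalar standard deviations, computed once at construction time.
fc_std = np.power(2 * proj_dim, -0.5) * 0.02
in_proj_std = (
    np.power(proj_dim, -0.5) * np.power(2 * num_hidden_layers, -0.5) * 0.02
)

# The scalars feed TruncatedNormal initializers on ordinary Dense layers.
c_fc = keras.layers.Dense(
    proj_dim * 4,
    kernel_initializer=keras.initializers.TruncatedNormal(stddev=in_proj_std),
    name="c_fc",
)
c_proj = keras.layers.Dense(
    proj_dim,
    kernel_initializer=keras.initializers.TruncatedNormal(stddev=fc_std),
    name="c_proj",
)
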
@@ -75,33 +99,14 @@ def attention(self, x, causal_attention_mask=None, attention_mask=None):
return self.attn(
x,
attention_mask=mask,
)
)[0]

def build(self, input_shape):
super().build(input_shape)
self.attn = CLIPAttention(
self.proj_dim,
self.num_heads,
self.num_hidden_layers,
name="multi_head_attention",
)
self.ln_1 = keras.layers.LayerNormalization(epsilon=1e-5, name="ln_1")
self.mlp = keras.Sequential(
[
keras.layers.Dense(
self.proj_dim * 4,
kernel_initializer=get_initializer(self.in_proj_std),
name="c_fc",
),
QuickGELU(name="gelu"),
keras.layers.Dense(
self.proj_dim,
kernel_initializer=get_initializer(self.fc_std),
name="c_proj",
),
]
)
self.ln_2 = keras.layers.LayerNormalization(epsilon=1e-5, name="ln_2")
self.attn.build(None)
self.ln_1.build([None, None, self.proj_dim])
self.mlp.build(None)
self.ln_2.build([None, None, self.proj_dim])

def call(self, x, causal_attention_mask=None, attention_mask=None):
x = x + self.attention(
@@ -144,7 +149,7 @@ def __init__(self, width, num_layers, heads, **kwargs):

def build(self, input_shape):
super().build(input_shape)
self.resblocks.build(input_shape)
map(lambda blocks: blocks.build(input_shape), self.resblocks)

def call(
self,
@@ -199,9 +204,6 @@ def __init__(
)

self.scale = self.head_dim**-0.5

def build(self, input_shape):
super().build(input_shape)
in_proj_std = (
(self.proj_dim**-0.5)
* ((2 * self.num_hidden_layers) ** -0.5)
@@ -229,6 +231,13 @@ def build(self, input_shape):
name="out_proj",
)

def build(self, input_shape):
super().build(input_shape)
self.q_proj.build([None, None, self.proj_dim])
self.k_proj.build([None, None, self.proj_dim])
self.v_proj.build([None, None, self.proj_dim])
self.out_proj.build([None, None, self.proj_dim])

def _transpose_for_scores(self, tensor, batch_size):
"""
Copied from https://github.com/huggingface/transformers/blob/8e164c5400b7b413c7b8fb32e35132001effc970/src/transformers/models/bert/modeling_tf_bert.py#L252 # noqa: E501
@@ -290,7 +299,7 @@ def call(
outputs = (
(attn_output, _attention_probs)
if output_attentions
else attn_output
else (attn_output,)
)

return outputs
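
The clip_encoder.py changes above create the sublayers in __init__ and have build() call each sublayer's build() with an explicit shape, rather than constructing the sublayers inside build(). A minimal sketch of that pattern, using a hypothetical ResidualMLPBlock layer and assuming rank-3 (batch, sequence, feature) inputs; this illustrates the idea and is not code from the commit:

from keras_cv.backend import keras


class ResidualMLPBlock(keras.layers.Layer):
    """Hypothetical layer showing the explicit sublayer-build pattern."""

    def __init__(self, proj_dim, **kwargs):
        super().__init__(**kwargs)
        self.proj_dim = proj_dim
        # Sublayers are created eagerly in __init__.
        self.ln = keras.layers.LayerNormalization(epsilon=1e-5, name="ln")
        self.c_fc = keras.layers.Dense(proj_dim * 4, name="c_fc")
        self.c_proj = keras.layers.Dense(proj_dim, name="c_proj")

    def build(self, input_shape):
        super().build(input_shape)
        # Each sublayer is built on a concrete shape so its weights exist
        # before the first call; explicit calls are used because map() is
        # lazy in Python 3 and would not execute the builds.
        self.ln.build([None, None, self.proj_dim])
        self.c_fc.build([None, None, self.proj_dim])
        self.c_proj.build([None, None, self.proj_dim * 4])

    def call(self, x):
        return x + self.c_proj(self.c_fc(self.ln(x)))
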
5 changes: 5 additions & 0 deletions keras_cv/models/feature_extractor/clip/clip_image_model.py
@@ -138,7 +138,12 @@ def __init__(
)

def build(self, input_shape):
super().build(input_shape)
self.embeddings.build(input_shape)
self.pre_norm.build([None, None, self.width])
self.encoder.build(None)
self.post_norm.build([None, self.width])
self.image_projector.build([None, None, self.width])

def call(self, image):
embeddings = self.embeddings(image)
32 changes: 19 additions & 13 deletions keras_cv/models/feature_extractor/clip/clip_model.py
@@ -16,7 +16,6 @@
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.models.feature_extractor.clip.clip_image_model import CLIPEncoder
from keras_cv.models.feature_extractor.clip.clip_image_model import (
CLIPImageEncoder,
)
@@ -86,23 +85,23 @@ def __init__(
self.transformer_heads = transformer_heads
self.transformer_layers = transformer_layers

vision_heads = vision_width // 64
vision_heads = self.vision_width // 64
self.image_encoder = CLIPImageEncoder(
input_resolution=image_resolution,
patch_size=vision_patch_size,
width=vision_width,
num_layers=vision_layers,
input_resolution=self.image_resolution,
patch_size=self.vision_patch_size,
width=self.vision_width,
num_layers=self.vision_layers,
heads=vision_heads,
output_dim=embed_dim,
output_dim=self.embed_dim,
name="image_encoder",
)
self.text_encoder = CLIPTextEncoder(
transformer_width=transformer_width,
transformer_layers=transformer_layers,
transformer_heads=transformer_heads,
vocab_size=vocab_size,
embed_dim=embed_dim,
context_length=context_length,
transformer_width=self.transformer_width,
transformer_layers=self.transformer_layers,
transformer_heads=self.transformer_heads,
vocab_size=self.vocab_size,
embed_dim=self.embed_dim,
context_length=self.context_length,
name="text_encoder",
)

@@ -112,6 +111,13 @@
self.image_embeddings = None
self.text_embeddings = None

def build(self, input_shape):
super().build(input_shape)
self.text_encoder.build([None, self.context_length])
self.image_encoder.build(
[None, self.image_resolution, self.image_resolution, 3]
)

def encode_images(self, image):
return self.image_encoder(image)

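
With the build() method added above, both encoders are built from the constructor arguments alone, so the model's variables can be created without running a forward pass first. A hedged usage sketch, assuming ViT-B/32-style hyperparameters and a local weights file; the keyword names follow the attributes set in CLIP.__init__, but the exact signature and file name are assumptions, not part of this commit:

from keras_cv.models import CLIP

model = CLIP(
    embed_dim=512,            # assumed ViT-B/32-style values
    image_resolution=224,
    vision_layers=12,
    vision_width=768,
    vision_patch_size=32,
    context_length=77,
    vocab_size=49408,
    transformer_width=512,
    transformer_heads=8,
    transformer_layers=12,
)
model.build(None)  # builds the text and image encoders with explicit shapes
model.load_weights("clip-vit-base-patch32.weights.h5")  # assumed local file
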
13 changes: 5 additions & 8 deletions keras_cv/models/feature_extractor/clip/clip_model_test.py
@@ -21,10 +21,7 @@
from keras_cv.backend import ops
from keras_cv.backend.config import keras_3
from keras_cv.models import CLIP
from keras_cv.models.feature_extractor.clip import CLIPImageEncoder
from keras_cv.models.feature_extractor.clip import CLIPProcessor
from keras_cv.models.feature_extractor.clip import CLIPTextEncoder
from keras_cv.models.feature_extractor.clip import CLIPTokenizer
from keras_cv.tests.test_case import TestCase

VOCAB_PATH = keras.utils.get_file(
@@ -38,7 +35,7 @@

MODEL_PATH = keras.utils.get_file(
None,
"https://storage.googleapis.com/keras-cv/models/clip/clip-vit-base-patch32.weights.h5",
"https://storage.googleapis.com/keras-cv/models/clip/clip-vit-base-patch32.weights.h5", # noqa: E501
)


@@ -55,9 +52,9 @@ def test_clip_model_golden_values(self):
processed_image, processed_text, attention_mask
)
print(image_logits)
self.assertAllClose(image_logits, [[3.747046, 3.747046, 3.747046]])
self.assertAllClose(image_logits, [[2.932678, 2.932678, 2.932675]])
self.assertAllClose(
text_logits, ops.transpose([[3.747046, 3.747046, 3.747046]])
text_logits, ops.transpose([[2.932678, 2.932678, 2.932675]])
)

def test_clip_preprocessor(self):
@@ -88,7 +85,7 @@ def test_image_encoder_golden_values(self):
model(processed_image, processed_text, attention_mask)
self.assertAllClose(
model.image_embeddings[:, :5],
[[0.038646, -0.051685, -0.077413, 0.062127, -0.089566]],
[[0.023215, 0.026526, 0.008914, -0.091689, 0.021791]],
)

@pytest.mark.large
@@ -101,7 +98,7 @@ def test_text_encoder_golden_values(self):
print(model.text_embeddings)
self.assertAllClose(
model.text_embeddings[0, :3],
[0.011359, 0.039782, -0.010593],
[-0.018502, 0.000906, 0.020372],
)

@pytest.mark.large # Saving is slow, so mark these large.
20 changes: 16 additions & 4 deletions keras_cv/models/feature_extractor/clip/clip_processor.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from keras_nlp.layers import StartEndPacker

from keras_cv.api_export import keras_cv_export
@@ -45,12 +44,14 @@ class CLIPProcessor:
"""

def __init__(self, input_resolution, vocabulary, merges):
def __init__(self, input_resolution, vocabulary, merges, **kwargs):
self.input_resolution = input_resolution
self.vocabulary = vocabulary
self.merges = merges
self.image_transform = self.transform_image
self.tokenizer = CLIPTokenizer(
vocabulary=vocabulary,
merges=merges,
vocabulary=self.vocabulary,
merges=self.merges,
unsplittable_tokens=["</w>"],
)
self.packer = StartEndPacker(
@@ -117,3 +118,14 @@ def pack_tokens(text):
)

return pack_tokens(texts)

def get_config(self):
config = super().get_config()
config.update(
{
"input_resolution": self.input_resolution,
"vocabulary": self.vocabulary,
"merges": self.merges,
}
)
return config
8 changes: 8 additions & 0 deletions keras_cv/models/feature_extractor/clip/clip_text_model.py
@@ -46,6 +46,14 @@ def __init__(
embed_dim, name="text_projector", use_bias=False
)

def build(self, input_shape):
super().build(input_shape)
self.token_embedding.build(input_shape)
self.positional_embedding.build([1, self.context_length])
self.encoder.build(None)
self.ln_final.build([None, None, self.transformer_width])
self.text_projector.build([None, None, self.transformer_width])

def call(self, inputs, attention_mask=None):
token_embedding = self.token_embedding(inputs)
position_ids = ops.expand_dims(