Skip to content

Commit

Permalink
code reformat (#2311)
Browse files Browse the repository at this point in the history
  • Loading branch information
divyashreepathihalli authored Jan 22, 2024
1 parent 07a097b commit 721fc25
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 13 deletions.
13 changes: 13 additions & 0 deletions keras_cv/models/feature_extractors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions keras_cv/models/feature_extractors/clip/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_cv.models.feature_extractors.clip.clip_image_encoder import (
CLIPImageEncoder,
)
Expand Down
10 changes: 6 additions & 4 deletions keras_cv/models/feature_extractors/clip/clip_image_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,16 @@ def __init__(
name="residual_transformer_encoder",
)(x)
x = ops.transpose(x, axes=(1, 0, 2))

print(x.shape)
x = keras.layers.LayerNormalization(epsilon=1e-6, name="ln_2")(
x[:, 0, :]
)

proj = keras.layers.Dense(output_dim, name="vision_projector")
print(x.shape)
proj = keras.layers.Dense(
output_dim, name="vision_projector", use_bias=False
)
x = proj(x)

print("final", x.shape)
output = x

super().__init__(
Expand Down
22 changes: 13 additions & 9 deletions keras_cv/models/feature_extractors/clip/clip_modelling.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ def compute_output_shape(self, inputs_shape):

class CLIPAttention(keras.layers.Layer):
"""
- Documentation page: https://huggingface.co/docs/transformers/model_doc/clip
- Implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py
- Documentation page: https://huggingface.co/docs/transformers/model_doc/clip # noqa: E501
- Implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py # noqa: E501
"""

def __init__(self, project_dim, num_heads, dropout=0.0, **kwargs):
Expand All @@ -171,7 +171,8 @@ def __init__(self, project_dim, num_heads, dropout=0.0, **kwargs):
self.head_dim = self.project_dim // self.num_heads
if self.head_dim * self.num_heads != self.project_dim:
raise ValueError(
f"project_dim must be divisible by num_heads (got `project_dim`: {self.project_dim} and `num_heads`:"
f"project_dim must be divisible by num_heads (got `project_dim`"
": {self.project_dim} and `num_heads`:"
f" {self.num_heads})."
)

Expand All @@ -194,13 +195,15 @@ def build(self, input_shape):

def _transpose_for_scores(self, tensor, batch_size):
"""
Copied from https://github.com/huggingface/transformers/blob/8e164c5400b7b413c7b8fb32e35132001effc970/src/transformers/models/bert/modeling_tf_bert.py#L252
Copied from https://github.com/huggingface/transformers/blob/8e164c5400b7b413c7b8fb32e35132001effc970/src/transformers/models/bert/modeling_tf_bert.py#L252 # noqa: E501
"""
# [batch_size, seq_len, all_head_dim] -> [batch_size, seq_len, num_heads, head_dim]
# [batch_size, seq_len, all_head_dim] ->
# [batch_size, seq_len, num_heads, head_dim]
tensor = ops.reshape(
tensor, (batch_size, -1, self.num_heads, self.head_dim)
)
# [batch_size, seq_len, num_heads, head_dim] -> [batch_size, num_heads, seq_len, head_dim]
# [batch_size, seq_len, num_heads, head_dim] ->
# [batch_size, num_heads, seq_len, head_dim]
return ops.transpose(tensor, axes=[0, 2, 1, 3])

def call(
Expand All @@ -220,7 +223,6 @@ def call(
value_layer = self._transpose_for_scores(mixed_value_layer, batch_size)

# Scaled dot product between key and query = raw attention scores.

attention_scores = ops.matmul(
query_layer, ops.transpose(key_layer, axes=[0, 1, 3, 2])
)
Expand All @@ -231,11 +233,13 @@ def call(

# Apply the causal_attention_mask first
if causal_attention_mask is not None:
# Apply the causal attention mask (precomputed for all layers in the call() function)
# Apply the causal attention mask (precomputed for all layers in
# the call() function)
attention_scores = ops.add(attention_scores, causal_attention_mask)

if attention_mask is not None:
# Apply the attention mask (precomputed for all layers in the call() function)
# Apply the attention mask (precomputed for all layers in the
# call() function)
attention_scores = ops.add(attention_scores, attention_mask)

# Normalize the attention scores to probabilities.
Expand Down

0 comments on commit 721fc25

Please sign in to comment.