
Commit

convert code to keras 3
divyashreepathihalli committed Dec 20, 2023
1 parent 82fe5dd commit bedf7eb
Showing 2 changed files with 52 additions and 40 deletions.
43 changes: 20 additions & 23 deletions keras_cv/models/clip/clip_image_encoder.py
@@ -1,21 +1,16 @@
 # encode images
 from collections import OrderedDict
 from typing import Tuple
 from typing import Union
 
 import numpy as np
-import tensorflow as tf
-import torch
-import torch.nn.functional as F
-from tensorflow.keras import layers
-from torch import nn
 
-from deepvision.layers.clip_patching_and_embedding import CLIPPatchingAndEmbedding
-from deepvision.layers.residual_transformer_encoder import ResidualTransformerEncoder
+from deepvision.layers.clip_patching_and_embedding import (
+    CLIPPatchingAndEmbedding,
+)
+from deepvision.layers.residual_transformer_encoder import (
+    ResidualTransformerEncoder,
+)
 from deepvision.utils.utils import parse_model_inputs
 
+from keras_cv.backend import keras
+from keras_cv.backend import ops
 
-class __CLIPImageEncoderTF(tf.keras.Model):
+
+class __CLIPImageEncoder(keras.Model):
     def __init__(
         self,
         input_resolution: int,
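
Note: the new keras_cv.backend imports carry the whole conversion. keras stands in for Keras 3 and ops for its backend-agnostic keras.ops namespace, so the encoder no longer imports TensorFlow or Torch directly. A minimal sketch of the pattern in plain Keras 3 (an aside, not part of this commit):

import keras
from keras import ops

# keras.ops mirrors a NumPy-style op set and dispatches to whichever
# backend (TensorFlow, JAX, or PyTorch) Keras 3 is running on.
x = ops.ones((2, 3, 4))
# Note: keras.ops.transpose takes axes=; the backend shim used in this
# diff passes perm=.
x = ops.transpose(x, axes=(1, 0, 2))
assert x.shape == (3, 2, 4)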
@@ -27,7 +22,7 @@ def __init__(
         input_tensor=None,
         **kwargs,
     ):
-        inputs = tf.keras.layers.Input(
+        inputs = keras.layers.Input(
             tensor=input_tensor, shape=(input_resolution, input_resolution, 3)
         )
         x = inputs
@@ -38,15 +33,17 @@ def __init__(
             input_resolution=input_resolution,
             backend="tensorflow",
         )(x)
-        x = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x)
+        x = keras.layers.LayerNormalization(epsilon=1e-6)(x)
 
-        x = tf.transpose(x, perm=(1, 0, 2))
-        x = ResidualTransformerEncoder(width, layers, heads, backend="tensorflow")(x)
-        x = tf.transpose(x, perm=(1, 0, 2))
+        x = ops.transpose(x, perm=(1, 0, 2))
+        x = ResidualTransformerEncoder(
+            width, layers, heads, backend="tensorflow"
+        )(x)
+        x = ops.transpose(x, perm=(1, 0, 2))
 
-        x = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x[:, 0, :])
+        x = keras.layers.LayerNormalization(epsilon=1e-6)(x[:, 0, :])
 
-        proj = tf.keras.layers.Dense(output_dim)
+        proj = keras.layers.Dense(output_dim)
         x = proj(x)
 
         output = x
@@ -55,4 +52,4 @@ def __init__(
             inputs=inputs,
             outputs=output,
             **kwargs,
-        )
+        )
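
The encoder conversion follows one rule: every tf.keras.layers.* call becomes keras.layers.*, every raw tf.* tensor op becomes ops.*, and the functional-model wiring is unchanged. A self-contained Keras 3 sketch of that wiring, with toy stand-ins for the deepvision layers (the Conv2D patching and single attention block below are assumptions for illustration, not the committed code):

import keras

def toy_image_encoder(input_resolution=224, patch_size=32, width=256, output_dim=128):
    inputs = keras.layers.Input(shape=(input_resolution, input_resolution, 3))
    # Stand-in for CLIPPatchingAndEmbedding: conv patchify, then flatten
    # the patch grid into a token sequence.
    x = keras.layers.Conv2D(width, patch_size, strides=patch_size)(inputs)
    x = keras.layers.Reshape((-1, width))(x)
    x = keras.layers.LayerNormalization(epsilon=1e-6)(x)
    # Stand-in for ResidualTransformerEncoder: one self-attention block.
    x = keras.layers.MultiHeadAttention(num_heads=4, key_dim=width // 4)(x, x)
    # First token, normalize, project: same tail as the diff above.
    x = keras.layers.LayerNormalization(epsilon=1e-6)(x[:, 0, :])
    outputs = keras.layers.Dense(output_dim)(x)
    return keras.Model(inputs=inputs, outputs=outputs)

model = toy_image_encoder()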
49 changes: 32 additions & 17 deletions keras_cv/models/clip/clip_processor.py
@@ -2,36 +2,49 @@
 from typing import Union
 
 import numpy as np
-import tensorflow as tf
 
+from deepvision.models.feature_extractors.clip.clip_tokenizer import (
+    CLIPTokenizer,
+)
 from pkg_resources import packaging
 
-from deepvision.models.feature_extractors.clip.clip_tokenizer import CLIPTokenizer
+from keras_cv.backend import keras
+from keras_cv.backend import ops
 
 
-class __CLIPProcessorTF:
+class __CLIPProcessor:
     def __init__(self, input_resolution):
         self.input_resolution = input_resolution
         self.image_transform = self.transform_image
         self.tokenizer = CLIPTokenizer()
 
     def transform_image(self, image_path):
         input_resolution = self.input_resolution
-        mean = tf.constant([0.48145466, 0.4578275, 0.40821073])
-        std = tf.constant([0.26862954, 0.26130258, 0.27577711])
+        mean = np.array([0.48145466, 0.4578275, 0.40821073])
+        std = np.array([0.26862954, 0.26130258, 0.27577711])
 
-        image = tf.io.read_file(image_path)
-        image = tf.image.decode_jpeg(image, channels=3)
+        image = keras.utils.load_img(image_path)
+        image = keras.utils.img_to_array(image)
         image = (
-            tf.image.resize(
+            ops.image.resize(
                 image,
                 (input_resolution, input_resolution),
-                method=tf.image.ResizeMethod.BICUBIC,
+                method="bicubic",
             )
             / 255.0
         )
-        image = tf.image.central_crop(
-            image, central_fraction=input_resolution / image.shape[0]
-        )
+        central_fraction = input_resolution / image.shape[0]
+        width, height = image.shape[0], image.shape[1]
+        left = ops.cast((width - width * central_fraction) / 2, dtype="int32")
+        top = ops.cast((height - height * central_fraction) / 2, dtype="int32")
+        right = ops.cast((width + width * central_fraction) / 2, dtype="int32")
+        bottom = ops.cast(
+            (height + height * central_fraction) / 2, dtype="int32"
+        )
+
+        image = ops.slice(
+            image, [left, top, 0], [right - left, bottom - top, 3]
+        )
 
         image = (image - mean) / std
         return image
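
Reviewer note on the crop rewrite above: the Keras 3 ops namespace has no direct tf.image.central_crop equivalent, so the diff computes the crop window by hand. For fraction f and spatial size d, the window runs from (d - d*f)/2 to (d + d*f)/2, which is centered and has length d*f. A NumPy-only sketch of the same arithmetic (illustrative; it mirrors the diff's axis naming):

import numpy as np

def central_crop(image, central_fraction):
    # image: (dim0, dim1, channels), named width/height as in the diff.
    width, height = image.shape[0], image.shape[1]
    left = int((width - width * central_fraction) / 2)
    top = int((height - height * central_fraction) / 2)
    right = int((width + width * central_fraction) / 2)
    bottom = int((height + height * central_fraction) / 2)
    return image[left:right, top:bottom, :]

cropped = central_crop(np.zeros((256, 256, 3)), central_fraction=224 / 256)
assert cropped.shape == (224, 224, 3)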

@@ -44,17 +57,20 @@ def process_images(self, images):
             if isinstance(image, str):
                 image = self.image_transform(image)
             processed_images.append(image)
-        processed_images = tf.stack(processed_images)
+        processed_images = ops.stack(processed_images)
         return processed_images
 
-    def process_texts(self, texts, context_length: int = 77, truncate: bool = False):
+    def process_texts(
+        self, texts, context_length: int = 77, truncate: bool = False
+    ):
         if isinstance(texts, str):
             texts = [texts]
 
         sot_token = self.tokenizer.encoder["<|startoftext|>"]
         eot_token = self.tokenizer.encoder["<|endoftext|>"]
         all_tokens = [
-            [sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts
+            [sot_token] + self.tokenizer.encode(text) + [eot_token]
+            for text in texts
         ]
 
         result = np.zeros(shape=[len(all_tokens), context_length])
@@ -70,7 +86,7 @@ def process_texts(self, texts, context_length: int = 77, truncate: bool = False)
                 )
             result[i, : len(tokens)] = tokens
 
-        result = tf.stack(result)
+        result = ops.stack(result)
         return result
 
     def process_pair(self, images, texts, device=None):
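
For context, process_texts packs variable-length token lists into a fixed (num_texts, context_length) matrix: ids are zero-padded on the right, and when truncate=True an over-long sequence is cut with its last id forced back to the end-of-text token. A standalone sketch of that packing (the ids below are illustrative placeholders, not the real vocabulary):

import numpy as np

def pack_tokens(all_tokens, context_length=77, truncate=False, eot_token=49407):
    # all_tokens: list of lists, each already [sot, ..., eot].
    result = np.zeros(shape=[len(all_tokens), context_length])
    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if not truncate:
                raise RuntimeError("input too long for context_length")
            tokens = tokens[:context_length]
            tokens[-1] = eot_token  # keep the sequence terminated
        result[i, : len(tokens)] = tokens
    return result

packed = pack_tokens([[49406, 320, 1125, 49407]])
assert packed.shape == (1, 77)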
@@ -82,4 +98,3 @@ def process_pair(self, images, texts, device=None):
         images = self.process_images(images)
         texts = self.process_texts(texts)
         return (images, texts)
-
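
End to end, the converted processor keeps its old interface: process_pair(images, texts) still returns a stacked image tensor and a token matrix, now built on backend-agnostic ops. A self-contained sketch of the image half in plain Keras 3 (note keras.ops.image.resize takes interpolation= where the diff's backend shim passes method=; the constants are the CLIP channel statistics from the diff):

import numpy as np
import keras
from keras import ops

MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype="float32")
STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype="float32")

def preprocess_image(image_path, input_resolution=224):
    # Decode with keras.utils, as the diff does, then resize and normalize.
    image = keras.utils.img_to_array(keras.utils.load_img(image_path))
    image = ops.image.resize(
        image, (input_resolution, input_resolution), interpolation="bicubic"
    )
    image = image / 255.0
    return (image - ops.convert_to_tensor(MEAN)) / ops.convert_to_tensor(STD)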
