diff --git a/keras_cv/models/clip/clip_image_encoder.py b/keras_cv/models/clip/clip_image_encoder.py
index 5d08d9e0dc..d652ddb924 100644
--- a/keras_cv/models/clip/clip_image_encoder.py
+++ b/keras_cv/models/clip/clip_image_encoder.py
@@ -1,21 +1,16 @@
-# encode images
-from collections import OrderedDict
-from typing import Tuple
-from typing import Union
-
-import numpy as np
-import tensorflow as tf
-import torch
-import torch.nn.functional as F
-from tensorflow.keras import layers
-from torch import nn
-
-from deepvision.layers.clip_patching_and_embedding import CLIPPatchingAndEmbedding
-from deepvision.layers.residual_transformer_encoder import ResidualTransformerEncoder
+from deepvision.layers.clip_patching_and_embedding import (
+    CLIPPatchingAndEmbedding,
+)
+from deepvision.layers.residual_transformer_encoder import (
+    ResidualTransformerEncoder,
+)
 from deepvision.utils.utils import parse_model_inputs
+from keras_cv.backend import keras
+from keras_cv.backend import ops
 
 
-class __CLIPImageEncoderTF(tf.keras.Model):
+
+class __CLIPImageEncoder(keras.Model):
     def __init__(
         self,
         input_resolution: int,
@@ -27,7 +22,7 @@ def __init__(
         input_tensor=None,
         **kwargs,
     ):
-        inputs = tf.keras.layers.Input(
+        inputs = keras.layers.Input(
             tensor=input_tensor, shape=(input_resolution, input_resolution, 3)
         )
         x = inputs
@@ -38,15 +33,17 @@ def __init__(
             input_resolution=input_resolution,
             backend="tensorflow",
         )(x)
-        x = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x)
+        x = keras.layers.LayerNormalization(epsilon=1e-6)(x)
 
-        x = tf.transpose(x, perm=(1, 0, 2))
-        x = ResidualTransformerEncoder(width, layers, heads, backend="tensorflow")(x)
-        x = tf.transpose(x, perm=(1, 0, 2))
+        x = ops.transpose(x, perm=(1, 0, 2))
+        x = ResidualTransformerEncoder(
+            width, layers, heads, backend="tensorflow"
+        )(x)
+        x = ops.transpose(x, perm=(1, 0, 2))
 
-        x = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x[:, 0, :])
+        x = keras.layers.LayerNormalization(epsilon=1e-6)(x[:, 0, :])
 
-        proj = tf.keras.layers.Dense(output_dim)
+        proj = keras.layers.Dense(output_dim)
         x = proj(x)
 
         output = x
@@ -55,4 +52,4 @@ def __init__(
             inputs=inputs,
             outputs=output,
             **kwargs,
-        )
\ No newline at end of file
+        )
diff --git a/keras_cv/models/clip/clip_processor.py b/keras_cv/models/clip/clip_processor.py
index c9b12a79c9..baeb77a058 100644
--- a/keras_cv/models/clip/clip_processor.py
+++ b/keras_cv/models/clip/clip_processor.py
@@ -2,13 +2,16 @@
 from typing import Union
 
 import numpy as np
-import tensorflow as tf
-
+from deepvision.models.feature_extractors.clip.clip_tokenizer import (
+    CLIPTokenizer,
+)
 from pkg_resources import packaging
 
-from deepvision.models.feature_extractors.clip.clip_tokenizer import CLIPTokenizer
+from keras_cv.backend import keras
+from keras_cv.backend import ops
+
 
-class __CLIPProcessorTF:
+class __CLIPProcessor:
     def __init__(self, input_resolution):
         self.input_resolution = input_resolution
         self.image_transform = self.transform_image
@@ -16,22 +19,32 @@ def __init__(self, input_resolution):
     def transform_image(self, image_path):
         input_resolution = self.input_resolution
 
-        mean = tf.constant([0.48145466, 0.4578275, 0.40821073])
-        std = tf.constant([0.26862954, 0.26130258, 0.27577711])
+        mean = np.array([0.48145466, 0.4578275, 0.40821073])
+        std = np.array([0.26862954, 0.26130258, 0.27577711])
 
-        image = tf.io.read_file(image_path)
-        image = tf.image.decode_jpeg(image, channels=3)
+        image = keras.utils.load_img(image_path)
+        image = keras.utils.img_to_array(image)
 
         image = (
-            tf.image.resize(
+            ops.image.resize(
                 image,
                 (input_resolution, input_resolution),
-                method=tf.image.ResizeMethod.BICUBIC,
+                method="bicubic",
             )
             / 255.0
         )
-        image = tf.image.central_crop(
-            image, central_fraction=input_resolution / image.shape[0]
+        central_fraction = input_resolution / image.shape[0]
+        width, height = image.shape[0], image.shape[1]
+        left = ops.cast((width - width * central_fraction) / 2, dtype="int32")
+        top = ops.cast((height - height * central_fraction) / 2, dtype="int32")
+        right = ops.cast((width + width * central_fraction) / 2, dtype="int32")
+        bottom = ops.cast(
+            (height + height * central_fraction) / 2, dtype="int32"
+        )
+
+        image = ops.slice(
+            image, [left, top, 0], [right - left, bottom - top, 3]
        )
+
         image = (image - mean) / std
         return image
@@ -44,17 +57,20 @@ def process_images(self, images):
             if isinstance(image, str):
                 image = self.image_transform(image)
             processed_images.append(image)
-        processed_images = tf.stack(processed_images)
+        processed_images = ops.stack(processed_images)
         return processed_images
 
-    def process_texts(self, texts, context_length: int = 77, truncate: bool = False):
+    def process_texts(
+        self, texts, context_length: int = 77, truncate: bool = False
+    ):
         if isinstance(texts, str):
             texts = [texts]
 
         sot_token = self.tokenizer.encoder["<|startoftext|>"]
         eot_token = self.tokenizer.encoder["<|endoftext|>"]
         all_tokens = [
-            [sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts
+            [sot_token] + self.tokenizer.encode(text) + [eot_token]
+            for text in texts
         ]
         result = np.zeros(shape=[len(all_tokens), context_length])
@@ -70,7 +86,7 @@ def process_texts(self, texts, context_length: int = 77, truncate: bool = False)
                 )
             result[i, : len(tokens)] = tokens
 
-        result = tf.stack(result)
+        result = ops.stack(result)
         return result
 
     def process_pair(self, images, texts, device=None):
@@ -82,4 +98,3 @@ def process_pair(self, images, texts, device=None):
         images = self.process_images(images)
         texts = self.process_texts(texts)
         return (images, texts)
-
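For reference, the crop logic that replaces tf.image.central_crop in transform_image above can be sketched in plain NumPy. This is an illustration only, not part of the change: the helper name `central_crop` and the example shapes/fraction are assumptions, and axis naming follows the diff (axis 0 as width, axis 1 as height), which is interchangeable here because the image is square after resizing.

import numpy as np


def central_crop(image, central_fraction):
    # Mirrors the ops-based crop in the diff: compute a centered window that
    # keeps `central_fraction` of each spatial axis, then slice it out.
    width, height = image.shape[0], image.shape[1]
    left = int((width - width * central_fraction) / 2)
    top = int((height - height * central_fraction) / 2)
    right = int((width + width * central_fraction) / 2)
    bottom = int((height + height * central_fraction) / 2)
    return image[left:right, top:bottom, :]


image = np.zeros((224, 224, 3), dtype="float32")
print(central_crop(image, central_fraction=0.5).shape)  # (112, 112, 3)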