diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh
index 76ac0631b4..8e1bf5d616 100644
--- a/.kokoro/github/ubuntu/gpu/build.sh
+++ b/.kokoro/github/ubuntu/gpu/build.sh
@@ -70,6 +70,7 @@ then
       keras_cv/models/object_detection/yolo_v8 \
       keras_cv/models/object_detection_3d \
       keras_cv/models/segmentation \
+      keras_cv/models/feature_extractor/clip \
       keras_cv/models/stable_diffusion
 else
    pytest --cache-clear --check_gpu --run_large --durations 0 \
@@ -84,5 +85,6 @@ else
       keras_cv/models/object_detection/yolo_v8 \
       keras_cv/models/object_detection_3d \
       keras_cv/models/segmentation \
+      keras_cv/models/feature_extractor/clip \
       keras_cv/models/stable_diffusion
 fi
\ No newline at end of file
diff --git a/keras_cv/models/feature_extractor/clip/clip_model.py b/keras_cv/models/feature_extractor/clip/clip_model.py
index c3e6d49caf..0cd96643d7 100644
--- a/keras_cv/models/feature_extractor/clip/clip_model.py
+++ b/keras_cv/models/feature_extractor/clip/clip_model.py
@@ -95,6 +95,13 @@ def __init__(
         self.transformer_layers = transformer_layers
 
         vision_heads = self.vision_width // 64
+        self.image_input = keras.layers.Input(shape=(None,), name="image")
+        self.text_input = keras.layers.Input(
+            shape=(None, None, self.context_length), name="text"
+        )
+        self.attention_mask_input = keras.layers.Input(
+            shape=(None, None, self.context_length), name="attention_mask"
+        )
         self.image_encoder = CLIPImageEncoder(
             input_resolution=self.image_resolution,
             patch_size=self.vision_patch_size,
@@ -133,7 +140,12 @@ def encode_images(self, image):
     def encode_text(self, text, attention_mask=None):
         return self.text_encoder(text, attention_mask=attention_mask)
 
-    def call(self, image, text, attention_mask=None):
+    def call(self, inputs):
+        image, text = inputs["image"], inputs["text"]
+        if "attention_mask" in inputs:
+            attention_mask = inputs["attention_mask"]
+        else:
+            attention_mask = None
         self.image_embeddings = self.encode_images(image)
         self.text_embeddings = self.encode_text(
             text, attention_mask=attention_mask
diff --git a/keras_cv/models/feature_extractor/clip/clip_model_test.py b/keras_cv/models/feature_extractor/clip/clip_model_test.py
index 14304b73ef..1a657d540c 100644
--- a/keras_cv/models/feature_extractor/clip/clip_model_test.py
+++ b/keras_cv/models/feature_extractor/clip/clip_model_test.py
@@ -34,27 +34,24 @@
     "https://storage.googleapis.com/keras-cv/models/clip/merges.txt",
 )
 
-MODEL_PATH = keras.utils.get_file(
-    None,
-    "https://storage.googleapis.com/keras-cv/models/clip/clip-vit-base-patch32.weights.h5",  # noqa: E501
-)
-
 
 class CLIPTest(TestCase):
     @pytest.mark.large
     def test_clip_model_golden_values(self):
-        model = CLIP()
-        model.load_weights(MODEL_PATH)
+        model = CLIP.from_preset("clip-vit-base-patch32")
         processed_image = np.ones(shape=[1, 224, 224, 3])
         processed_text = np.ones(shape=[3, 77])
         attention_mask = np.ones(shape=[3, 77])
         image_logits, text_logits = model(
-            processed_image, processed_text, attention_mask
+            {
+                "image": processed_image,
+                "text": processed_text,
+                "attention_mask": attention_mask,
+            }
         )
-        print(image_logits)
-        self.assertAllClose(image_logits, [[1.896713, 1.896713, 1.896713]])
+        self.assertAllClose(image_logits, [[1.896712, 1.896712, 1.896712]])
         self.assertAllClose(
-            text_logits, ops.transpose([[1.896713, 1.896713, 1.896713]])
+            text_logits, ops.transpose([[1.896712, 1.896712, 1.896712]])
         )
 
     def test_clip_preprocessor(self):
@@ -83,20 +80,29 @@ def test_presets(self):
         processed_text = np.ones(shape=[3, 77])
         attention_mask = np.ones(shape=[3, 77])
         image_logits, text_logits = model(
-            processed_image, processed_text, attention_mask
+            {
+                "image": processed_image,
+                "text": processed_text,
+                "attention_mask": attention_mask,
+            }
         )
 
     @pytest.mark.large
     def test_image_encoder_golden_values(self):
-        model = CLIP()
-        model.load_weights(MODEL_PATH)
+        model = CLIP.from_preset("clip-vit-base-patch32")
         processed_image = np.ones(shape=[1, 224, 224, 3])
         processed_text = np.ones(shape=[3, 77])
         attention_mask = np.ones(shape=[3, 77])
-        model(processed_image, processed_text, attention_mask)
+        model(
+            {
+                "image": processed_image,
+                "text": processed_text,
+                "attention_mask": attention_mask,
+            }
+        )
         self.assertAllClose(
             model.image_embeddings[:, :5],
-            [[0.023215, 0.026526, 0.008914, -0.091689, 0.021791]],
+            [[0.023215, 0.026526, 0.008914, -0.091689, 0.021791]],
         )
 
     @pytest.mark.large
@@ -105,8 +111,13 @@ def test_text_encoder_golden_values(self):
         processed_image = np.ones(shape=[1, 224, 224, 3])
         processed_text = np.ones(shape=[3, 77])
         attention_mask = np.ones(shape=[3, 77])
-        model(processed_image, processed_text, attention_mask)
-        print(model.text_embeddings)
+        model(
+            {
+                "image": processed_image,
+                "text": processed_text,
+                "attention_mask": attention_mask,
+            }
+        )
         self.assertAllClose(
             model.text_embeddings[0, :3],
             [0.007531, -0.038361, -0.035686],
@@ -118,7 +129,13 @@ def test_saved_model(self):
         processed_image = np.ones(shape=[1, 224, 224, 3])
         processed_text = np.ones(shape=[3, 77])
         attention_mask = np.ones(shape=[3, 77])
-        model_output, _ = model(processed_image, processed_text, attention_mask)
+        model_output, _ = model(
+            {
+                "image": processed_image,
+                "text": processed_text,
+                "attention_mask": attention_mask,
+            }
+        )
         save_path = os.path.join(self.get_temp_dir(), "model.keras")
         if keras_3():
             model.save(save_path)
@@ -130,6 +147,10 @@ def test_saved_model(self):
         self.assertIsInstance(restored_model, CLIP)
         # Check that output matches.
         restored_output, _ = restored_model(
-            processed_image, processed_text, attention_mask
+            {
+                "image": processed_image,
+                "text": processed_text,
+                "attention_mask": attention_mask,
+            }
         )
         self.assertAllClose(model_output, restored_output)
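For reference, a minimal usage sketch of the dict-based `call` interface and the `from_preset` loader exercised by the updated tests above. The preset name, input keys, and input shapes are taken from this diff; the exact import path is an assumption and may differ in the actual package layout.

```python
import numpy as np

# Import path is an assumption; adjust to wherever CLIP is exported in keras_cv.
from keras_cv.models import CLIP

# Load pretrained weights through the preset API used in the tests.
model = CLIP.from_preset("clip-vit-base-patch32")

# Dummy inputs with the same shapes as the tests: one 224x224 RGB image,
# three tokenized text sequences of length 77, and a matching attention mask.
processed_image = np.ones(shape=[1, 224, 224, 3])
processed_text = np.ones(shape=[3, 77])
attention_mask = np.ones(shape=[3, 77])

# The model now takes a single dict of named inputs instead of positional
# arguments; "attention_mask" is optional and defaults to None when omitted.
image_logits, text_logits = model(
    {
        "image": processed_image,
        "text": processed_text,
        "attention_mask": attention_mask,
    }
)
```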