From 9238b0644623fd65ad8bcd8ff454e6c4c330ff18 Mon Sep 17 00:00:00 2001
From: "Hongyu, Chiu" <20734616+james77777778@users.noreply.github.com>
Date: Tue, 29 Oct 2024 06:49:57 +0800
Subject: [PATCH] Add `CLIP` model (#1955)

* Add `CLIPVisionEmbedding`

* Add `CLIPBackbone` and `CLIPVisionEncoder` and `CLIPImageConverter`

* Fix typo
---
 keras_hub/api/layers/__init__.py              |   1 +
 keras_hub/api/models/__init__.py              |   3 +
 keras_hub/src/models/clip/clip_backbone.py    | 242 ++++++++++++++++++
 .../src/models/clip/clip_encoder_block.py     |   7 +-
 .../src/models/clip/clip_image_converter.py   |   8 +
 .../src/models/clip/clip_text_encoder.py      |   2 +
 .../src/models/clip/clip_vision_embedding.py  |  91 +++++++
 .../src/models/clip/clip_vision_encoder.py    | 158 ++++++++++++
 8 files changed, 511 insertions(+), 1 deletion(-)
 create mode 100644 keras_hub/src/models/clip/clip_backbone.py
 create mode 100644 keras_hub/src/models/clip/clip_image_converter.py
 create mode 100644 keras_hub/src/models/clip/clip_vision_embedding.py
 create mode 100644 keras_hub/src/models/clip/clip_vision_encoder.py

diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py
index 78a26075d1..adec61931c 100644
--- a/keras_hub/api/layers/__init__.py
+++ b/keras_hub/api/layers/__init__.py
@@ -34,6 +34,7 @@
 from keras_hub.src.layers.preprocessing.random_deletion import RandomDeletion
 from keras_hub.src.layers.preprocessing.random_swap import RandomSwap
 from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_hub.src.models.clip.clip_image_converter import CLIPImageConverter
 from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import (
     DeepLabV3ImageConverter,
 )
diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py
index e0e8773a35..2082bc3b99 100644
--- a/keras_hub/api/models/__init__.py
+++ b/keras_hub/api/models/__init__.py
@@ -53,8 +53,11 @@
 from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer
 from keras_hub.src.models.causal_lm import CausalLM
 from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
+from keras_hub.src.models.clip.clip_backbone import CLIPBackbone
 from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor
+from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder
 from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer
+from keras_hub.src.models.clip.clip_vision_encoder import CLIPVisionEncoder
 from keras_hub.src.models.csp_darknet.csp_darknet_backbone import (
     CSPDarkNetBackbone,
 )
diff --git a/keras_hub/src/models/clip/clip_backbone.py b/keras_hub/src/models/clip/clip_backbone.py
new file mode 100644
index 0000000000..e52336b59e
--- /dev/null
+++ b/keras_hub/src/models/clip/clip_backbone.py
@@ -0,0 +1,242 @@
+import math
+
+from keras import layers
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+
+
+class CLIPVisionPooler(layers.Layer):
+    """The vision pooler layer of CLIP.
+
+    `CLIPVisionPooler` extracts the first token (index `0`) from the sequence
+    of the vision embeddings as the pooled outputs.
+
+    Call arguments:
+        vision_embeddings: A tensor of shape
+            `(batch_size, sequence_length, hidden_dim)`.
+    """
+
+    def call(self, vision_embeddings):
+        pooled_outputs = vision_embeddings[:, 0, :]
+        return pooled_outputs
+
+
+class CLIPTextPooler(layers.Layer):
+    """The text pooler layer of CLIP.
+
+    `CLIPTextPooler` extracts the text embeddings at the positions of EOS
+    tokens as the pooled outputs.
+
+    Call arguments:
+        text_embeddings: A tensor of shape
+            `(batch_size, sequence_length, hidden_dim)`.
+        token_ids: A tensor of shape `(batch_size, max_tokens)`, used to
+            identify the positions of EOS tokens.
+    """
+
+    def call(self, text_embeddings, token_ids):
+        eos_index = ops.argmax(token_ids, axis=-1, keepdims=True)
+        eos_index = ops.expand_dims(eos_index, axis=-1)
+        pooled_outputs = ops.take_along_axis(text_embeddings, eos_index, axis=1)
+        pooled_outputs = ops.squeeze(pooled_outputs, axis=1)
+        return pooled_outputs
+
+
+class CLIPHead(layers.Layer):
+    """The head layer of CLIP.
+
+    `CLIPHead` takes `vision_embedding` and `text_embedding` as inputs to
+    compute the corresponding logits. Both embeddings are L2 normalized and
+    used to compute pairwise cosine similarity. The resulting logits are then
+    scaled by a learnable `logit_scale` parameter.
+
+    Call arguments:
+        vision_embedding: A tensor of shape `(batch_size, hidden_dim)`.
+        text_embedding: A tensor of shape `(batch_size, hidden_dim)`.
+    """
+
+    def build(self, input_shape):
+        self.logit_scale = self.add_weight(
+            shape=(),
+            initializer=lambda *a, **kw: math.log(1 / 0.07),
+            trainable=True,
+            dtype=self.variable_dtype,
+            name="logit_scale",
+        )
+
+    def call(self, vision_embedding, text_embedding):
+        normalized_vision_embedding = ops.sqrt(
+            ops.sum(ops.power(vision_embedding, 2), axis=-1, keepdims=True)
+        )
+        normalized_text_embedding = ops.sqrt(
+            ops.sum(ops.power(text_embedding, 2), axis=-1, keepdims=True)
+        )
+        vision_embedding = vision_embedding / normalized_vision_embedding
+        text_embedding = text_embedding / normalized_text_embedding
+        logit_scale = ops.exp(self.logit_scale)
+        text_logits = (
+            ops.matmul(
+                text_embedding,
+                ops.transpose(vision_embedding),
+            )
+            * logit_scale
+        )
+        vision_logits = ops.transpose(text_logits)
+        return vision_logits, text_logits
+
+
+@keras_hub_export("keras_hub.models.CLIPBackbone")
+class CLIPBackbone(Backbone):
+    """CLIP core network with hyperparameters.
+
+    This backbone implements the base architecture for the Contrastive
+    Language-Image Pretraining (CLIP) model. It includes vision and text
+    encoders and the corresponding projection layers. This backbone will
+    output the final logit scores corresponding to each image and token
+    input. These values are cosine similarities between the corresponding
+    image and text features.
+
+    The default constructor gives a fully customizable, randomly initialized
+    CLIP model with any number of layers, heads, and embedding dimensions. To
+    load preset architectures and weights, use the `from_preset` constructor.
+
+    Args:
+        vision_encoder: The CLIP vision encoder for encoding the input images.
+        text_encoder: The CLIP text encoder for encoding the input tokens.
+        projection_dim: int. The size of the projection layer.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for the model's computations and weights. Note that some
+            computations, such as softmax and layer normalization, will always
+            be done in float32 precision regardless of dtype.
+
+    Example:
+    ```python
+    input_data = {
+        "images": np.ones(shape=(1, 224, 224, 3), dtype="float32"),
+        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+    }
+
+    # Pretrained CLIP model.
+    model = keras_hub.models.CLIPBackbone.from_preset("clip-vit-base-patch32")
+    model(input_data)
+
+    # Randomly initialized CLIP model with custom config.
+    vision_encoder = keras_hub.models.CLIPVisionEncoder(
+        patch_size=32,
+        hidden_dim=768,
+        num_layers=8,
+        num_heads=8,
+        intermediate_dim=2048,
+        image_shape=(384, 384, 3),
+    )
+    text_encoder = keras_hub.models.CLIPTextEncoder(
+        vocabulary_size=49408,
+        embedding_dim=768,
+        hidden_dim=768,
+        num_layers=8,
+        num_heads=8,
+        intermediate_dim=2048,
+    )
+    model = keras_hub.models.CLIPBackbone(
+        vision_encoder=vision_encoder,
+        text_encoder=text_encoder,
+        projection_dim=256,
+    )
+    model(input_data)
+    ```
+    """
+
+    def __init__(
+        self,
+        vision_encoder,
+        text_encoder,
+        projection_dim,
+        dtype=None,
+        name=None,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.vision_encoder = vision_encoder
+        self.text_encoder = text_encoder
+        self.vision_pooler = CLIPVisionPooler(dtype=dtype, name="vision_pooler")
+        self.text_pooler = CLIPTextPooler(dtype=dtype, name="text_pooler")
+        self.vision_projection = layers.Dense(
+            projection_dim,
+            use_bias=False,
+            dtype=dtype,
+            name="vision_projection",
+        )
+        self.text_projection = layers.Dense(
+            projection_dim,
+            use_bias=False,
+            dtype=dtype,
+            name="text_projection",
+        )
+        self.clip_head = CLIPHead(dtype=dtype, name="clip_head")
+
+        # === Functional Model ===
+        image_input = layers.Input(
+            shape=self.vision_encoder.image_shape, name="images"
+        )
+        token_id_input = layers.Input(
+            shape=(None,), dtype="int32", name="token_ids"
+        )
+        vision_outputs = self.vision_encoder({"images": image_input})
+        text_outputs = self.text_encoder({"token_ids": token_id_input})
+        vision_outputs = self.vision_pooler(vision_outputs)
+        text_outputs = self.text_pooler(text_outputs, token_id_input)
+        vision_embeddings = self.vision_projection(vision_outputs)
+        text_embeddings = self.text_projection(text_outputs)
+        vision_logits, text_logits = self.clip_head(
+            vision_embeddings, text_embeddings
+        )
+
+        super().__init__(
+            inputs={
+                "images": image_input,
+                "token_ids": token_id_input,
+            },
+            outputs={
+                "vision_logits": vision_logits,
+                "text_logits": text_logits,
+            },
+            name=name,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.projection_dim = projection_dim
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "vision_encoder": layers.serialize(self.vision_encoder),
+                "text_encoder": layers.serialize(self.text_encoder),
+                "projection_dim": self.projection_dim,
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config, custom_objects=None):
+        config = config.copy()
+
+        # Propagate `dtype` to submodels if needed.
+        if "dtype" in config and config["dtype"] is not None:
+            dtype_config = config["dtype"]
+            if "dtype" not in config["vision_encoder"]["config"]:
+                config["vision_encoder"]["config"]["dtype"] = dtype_config
+            if "dtype" not in config["text_encoder"]["config"]:
+                config["text_encoder"]["config"]["dtype"] = dtype_config
+
+        # We expect submodels to be instantiated.
+        config["vision_encoder"] = layers.deserialize(
+            config["vision_encoder"], custom_objects=custom_objects
+        )
+        config["text_encoder"] = layers.deserialize(
+            config["text_encoder"], custom_objects=custom_objects
+        )
+        return cls(**config)
diff --git a/keras_hub/src/models/clip/clip_encoder_block.py b/keras_hub/src/models/clip/clip_encoder_block.py
index 522b9a5b64..2d6967ba70 100644
--- a/keras_hub/src/models/clip/clip_encoder_block.py
+++ b/keras_hub/src/models/clip/clip_encoder_block.py
@@ -14,6 +14,7 @@ def __init__(
         num_heads,
         intermediate_dim,
         intermediate_activation="quick_gelu",
+        use_causal_mask=True,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -26,6 +27,7 @@ def __init__(
         self.num_heads = num_heads
         self.intermediate_dim = intermediate_dim
         self.intermediate_activation = intermediate_activation
+        self.use_causal_mask = use_causal_mask
 
         if intermediate_activation == "quick_gelu":
             intermediate_activation = quick_gelu
@@ -73,7 +75,9 @@ def compute_output_shape(self, inputs_shape):
     def call(self, x, training=None):
         residual = x
         x = self.layer_norm_1(x)
-        x = self.attention(x, x, x, training=training, use_causal_mask=True)
+        x = self.attention(
+            x, x, x, training=training, use_causal_mask=self.use_causal_mask
+        )
         x = ops.add(residual, x)
 
         residual = x
@@ -91,6 +95,7 @@ def get_config(self):
                 "num_heads": self.num_heads,
                 "intermediate_dim": self.intermediate_dim,
                 "intermediate_activation": self.intermediate_activation,
+                "use_causal_mask": self.use_causal_mask,
             }
         )
         return config
diff --git a/keras_hub/src/models/clip/clip_image_converter.py b/keras_hub/src/models/clip/clip_image_converter.py
new file mode 100644
index 0000000000..ad29ccb179
--- /dev/null
+++ b/keras_hub/src/models/clip/clip_image_converter.py
@@ -0,0 +1,8 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+from keras_hub.src.models.clip.clip_backbone import CLIPBackbone
+
+
+@keras_hub_export("keras_hub.layers.CLIPImageConverter")
+class CLIPImageConverter(ImageConverter):
+    backbone_cls = CLIPBackbone
diff --git a/keras_hub/src/models/clip/clip_text_encoder.py b/keras_hub/src/models/clip/clip_text_encoder.py
index 208acfddc7..503aef6abb 100644
--- a/keras_hub/src/models/clip/clip_text_encoder.py
+++ b/keras_hub/src/models/clip/clip_text_encoder.py
@@ -1,5 +1,6 @@
 from keras import layers
 
+from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.modeling.token_and_position_embedding import (
     TokenAndPositionEmbedding,
 )
@@ -7,6 +8,7 @@
 from keras_hub.src.models.clip.clip_encoder_block import CLIPEncoderBlock
 
 
+@keras_hub_export("keras_hub.models.CLIPTextEncoder")
 class CLIPTextEncoder(Backbone):
     """CLIP text core network with hyperparameters.
 
diff --git a/keras_hub/src/models/clip/clip_vision_embedding.py b/keras_hub/src/models/clip/clip_vision_embedding.py
new file mode 100644
index 0000000000..7690319d90
--- /dev/null
+++ b/keras_hub/src/models/clip/clip_vision_embedding.py
@@ -0,0 +1,91 @@
+from keras import layers
+from keras import ops
+
+from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+class CLIPVisionEmbedding(layers.Layer):
+    def __init__(
+        self, hidden_dim, patch_size, image_size, data_format=None, **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.hidden_dim = int(hidden_dim)
+        self.patch_size = int(patch_size)
+        self.image_size = int(image_size)
+        data_format = standardize_data_format(data_format)
+        self.data_format = data_format
+        num_patches = (image_size // patch_size) ** 2
+        self.num_positions = num_patches + 1
+
+        self.patch_embedding = layers.Conv2D(
+            hidden_dim,
+            kernel_size=patch_size,
+            strides=patch_size,
+            data_format=data_format,
+            use_bias=False,
+            name="patch_embedding",
+        )
+        self.position_embedding = layers.Embedding(
+            num_patches + 1, hidden_dim, name="position_embedding"
+        )
+
+    def build(self, input_shape):
+        self.class_embedding = self.add_weight(
+            shape=(self.hidden_dim,),
+            initializer="random_normal",
+            dtype=self.variable_dtype,
+            name="class_embedding",
+        )
+        self.position_ids = self.add_weight(
+            shape=(1, self.num_positions),
+            initializer="zeros",
+            dtype="int32",
+            trainable=False,
+            name="position_ids",
+        )
+        self.patch_embedding.build(input_shape)
+        self.position_embedding.build(self.position_ids.shape)
+
+    def call(self, inputs, training=None):
+        x = inputs
+        batch_size = ops.shape(x)[0]
+        patch_embeddings = self.patch_embedding(x, training=training)
+        if self.data_format == "channels_last":
+            patch_embeddings = ops.reshape(
+                patch_embeddings, (batch_size, -1, self.hidden_dim)
+            )
+        else:
+            patch_embeddings = ops.reshape(
+                patch_embeddings, (batch_size, self.hidden_dim, -1)
+            )
+            patch_embeddings = ops.transpose(patch_embeddings, (0, 2, 1))
+        class_embeddings = ops.expand_dims(self.class_embedding, axis=(0, 1))
+        class_embeddings = ops.tile(class_embeddings, (batch_size, 1, 1))
+        position_embeddings = self.position_embedding(self.position_ids)
+        embeddings = ops.concatenate(
+            [class_embeddings, patch_embeddings], axis=1
+        )
+        return ops.add(embeddings, position_embeddings)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "patch_size": self.patch_size,
+                "image_size": self.image_size,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, input_shape):
+        output_shape = [input_shape[0], None, self.hidden_dim]
+        if self.data_format == "channels_last":
+            if input_shape[1] is not None and input_shape[2] is not None:
+                patch_num = input_shape[1] // self.patch_size
+                output_shape[1] = patch_num**2 + 1
+        else:
+            if input_shape[2] is not None and input_shape[3] is not None:
+                patch_num = input_shape[2] // self.patch_size
+                output_shape[1] = patch_num**2 + 1
+        return output_shape
diff --git a/keras_hub/src/models/clip/clip_vision_encoder.py b/keras_hub/src/models/clip/clip_vision_encoder.py
new file mode 100644
index 0000000000..a5967dc6ff
--- /dev/null
+++ b/keras_hub/src/models/clip/clip_vision_encoder.py
@@ -0,0 +1,158 @@
+from keras import layers
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.clip.clip_encoder_block import CLIPEncoderBlock
+from keras_hub.src.models.clip.clip_vision_embedding import (
+    CLIPVisionEmbedding,
+)
+from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+@keras_hub_export("keras_hub.models.CLIPVisionEncoder")
+class CLIPVisionEncoder(Backbone):
+    """CLIP vision core network with hyperparameters.
+
+    Args:
+        patch_size: int. The size of each square patch in the input image.
+        hidden_dim: int. The size of the transformer hidden state at the end
+            of each transformer layer.
+        num_layers: int. The number of transformer layers.
+        num_heads: int. The number of attention heads for each transformer.
+        intermediate_dim: int. The output dimension of the first Dense layer
+            in a two-layer feedforward network for each transformer.
+        intermediate_activation: activation function. The activation that
+            is used for the first Dense layer in a two-layer feedforward
+            network for each transformer.
+        intermediate_output_index: optional int. The index of the intermediate
+            output. If specified, the output will become a dictionary with two
+            keys `"sequence_output"` and `"intermediate_output"`.
+        image_shape: tuple. The input shape without the batch size. Defaults
+            to `(224, 224, 3)`.
+        data_format: `None` or str. If specified, either `"channels_last"` or
+            `"channels_first"`. The ordering of the dimensions in the
+            inputs. `"channels_last"` corresponds to inputs with shape
+            `(batch_size, height, width, channels)`
+            while `"channels_first"` corresponds to inputs with shape
+            `(batch_size, channels, height, width)`. It defaults to the
+            `image_data_format` value found in your Keras config file at
+            `~/.keras/keras.json`. If you never set it, then it will be
+            `"channels_last"`.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for the model's computations and weights. Note that some
+            computations, such as softmax and layer normalization, will always
+            be done in float32 precision regardless of dtype.
+    """
+
+    def __init__(
+        self,
+        patch_size,
+        hidden_dim,
+        num_layers,
+        num_heads,
+        intermediate_dim,
+        intermediate_activation="quick_gelu",
+        intermediate_output_index=None,
+        image_shape=(224, 224, 3),
+        data_format=None,
+        dtype=None,
+        name=None,
+        **kwargs,
+    ):
+        data_format = standardize_data_format(data_format)
+        if data_format == "channels_last":
+            height, width = image_shape[0], image_shape[1]
+        else:
+            height, width = image_shape[1], image_shape[2]
+        if height != width:
+            raise ValueError(
+                "`CLIPVisionEncoder` expects the height and width to be the "
+                f"same in `image_shape`. Received: image_shape={image_shape}"
+            )
+
+        if (
+            intermediate_output_index is not None
+            and intermediate_output_index < 0
+        ):
+            intermediate_output_index += num_layers
+
+        # `prefix` is used to prevent duplicate name when utilizing multiple
+        # CLIP models within a single model, such as in StableDiffusion3.
+        prefix = str(name) + "_" if name is not None else ""
+
+        # === Layers ===
+        self.embedding = CLIPVisionEmbedding(
+            hidden_dim=hidden_dim,
+            patch_size=patch_size,
+            image_size=height,
+            data_format=data_format,
+            dtype=dtype,
+            name=f"{prefix}embedding",
+        )
+        self.pre_layer_norm = layers.LayerNormalization(
+            epsilon=1e-5, dtype="float32", name=f"{prefix}pre_layer_norm"
+        )
+        self.encoder_layers = [
+            CLIPEncoderBlock(
+                hidden_dim,
+                num_heads,
+                intermediate_dim,
+                intermediate_activation,
+                use_causal_mask=False,  # `False` in the vision encoder.
+                dtype=dtype,
+                name=f"{prefix}encoder_block_{i}",
+            )
+            for i in range(num_layers)
+        ]
+        self.layer_norm = layers.LayerNormalization(
+            epsilon=1e-5, dtype="float32", name=f"{prefix}layer_norm"
+        )
+
+        # === Functional Model ===
+        image_input = layers.Input(shape=image_shape, name="images")
+        x = self.embedding(image_input)
+        x = self.pre_layer_norm(x)
+        intermediate_output = None
+        for i, block in enumerate(self.encoder_layers):
+            x = block(x)
+            if i == intermediate_output_index:
+                intermediate_output = x
+        sequence_output = self.layer_norm(x)
+
+        if intermediate_output_index is not None:
+            outputs = {
+                "sequence_output": sequence_output,
+                "intermediate_output": intermediate_output,
+            }
+        else:
+            outputs = sequence_output
+        super().__init__(
+            inputs={"images": image_input},
+            outputs=outputs,
+            name=name,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.patch_size = patch_size
+        self.hidden_dim = hidden_dim
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.intermediate_dim = intermediate_dim
+        self.intermediate_activation = intermediate_activation
+        self.intermediate_output_index = intermediate_output_index
+        self.image_shape = image_shape
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "patch_size": self.patch_size,
+                "hidden_dim": self.hidden_dim,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "intermediate_dim": self.intermediate_dim,
+                "intermediate_activation": self.intermediate_activation,
+                "intermediate_output_index": self.intermediate_output_index,
+                "image_shape": self.image_shape,
+            }
+        )
+        return config
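
As a quick, hypothetical sketch of how the pieces added in this patch fit together, the snippet below wires a small `CLIPVisionEncoder` and `CLIPTextEncoder` into a `CLIPBackbone` and runs a forward pass. It only relies on the constructor arguments and the input/output names defined in the patch; the hyperparameters and shapes are illustrative, not a released preset.

```python
import numpy as np

import keras_hub

# A small, randomly initialized vision encoder (arguments as defined in
# clip_vision_encoder.py; the sizes here are illustrative).
vision_encoder = keras_hub.models.CLIPVisionEncoder(
    patch_size=32,
    hidden_dim=256,
    num_layers=2,
    num_heads=4,
    intermediate_dim=512,
    image_shape=(224, 224, 3),
)

# A matching text encoder (arguments as in the `CLIPBackbone` docstring
# example).
text_encoder = keras_hub.models.CLIPTextEncoder(
    vocabulary_size=49408,
    embedding_dim=256,
    hidden_dim=256,
    num_layers=2,
    num_heads=4,
    intermediate_dim=512,
)

# The backbone adds the poolers, the projection layers, and the `CLIPHead`
# cosine-similarity computation on top of both encoders.
backbone = keras_hub.models.CLIPBackbone(
    vision_encoder=vision_encoder,
    text_encoder=text_encoder,
    projection_dim=128,
)

# The backbone consumes a dict of images and token ids and returns pairwise
# image-text similarity logits.
input_data = {
    "images": np.ones(shape=(1, 224, 224, 3), dtype="float32"),
    "token_ids": np.ones(shape=(1, 12), dtype="int32"),
}
outputs = backbone(input_data)
print(outputs["vision_logits"].shape)  # (1, 1): images x texts
print(outputs["text_logits"].shape)  # (1, 1): texts x images
```

Note the design choice around masking: `CLIPEncoderBlock` now takes a `use_causal_mask` flag, which the text encoder leaves at its causal default (`True`) while `CLIPVisionEncoder` passes `False` so that image patches attend bidirectionally.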