From 82fe5dd7de3b403c38ca11a00d0d79f332af365c Mon Sep 17 00:00:00 2001 From: David Landup Date: Thu, 13 Jul 2023 20:19:41 +0200 Subject: [PATCH] initial dump (cherry picked from commit 49781a63139d8f484e5b71d43b9b1e1eb542aee1) --- keras_cv/models/clip/clip.py | 0 keras_cv/models/clip/clip_image_encoder.py | 58 ++++++++ keras_cv/models/clip/clip_processor.py | 85 ++++++++++++ keras_cv/models/clip/clip_tokenizer.py | 147 +++++++++++++++++++++ 4 files changed, 290 insertions(+) create mode 100644 keras_cv/models/clip/clip.py create mode 100644 keras_cv/models/clip/clip_image_encoder.py create mode 100644 keras_cv/models/clip/clip_processor.py create mode 100644 keras_cv/models/clip/clip_tokenizer.py diff --git a/keras_cv/models/clip/clip.py b/keras_cv/models/clip/clip.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keras_cv/models/clip/clip_image_encoder.py b/keras_cv/models/clip/clip_image_encoder.py new file mode 100644 index 0000000000..5d08d9e0dc --- /dev/null +++ b/keras_cv/models/clip/clip_image_encoder.py @@ -0,0 +1,58 @@ +# encode images +from collections import OrderedDict +from typing import Tuple +from typing import Union + +import numpy as np +import tensorflow as tf +import torch +import torch.nn.functional as F +from tensorflow.keras import layers +from torch import nn + +from deepvision.layers.clip_patching_and_embedding import CLIPPatchingAndEmbedding +from deepvision.layers.residual_transformer_encoder import ResidualTransformerEncoder +from deepvision.utils.utils import parse_model_inputs + + +class __CLIPImageEncoderTF(tf.keras.Model): + def __init__( + self, + input_resolution: int, + patch_size: int, + width: int, + layers: int, + heads: int, + output_dim: int, + input_tensor=None, + **kwargs, + ): + inputs = tf.keras.layers.Input( + tensor=input_tensor, shape=(input_resolution, input_resolution, 3) + ) + x = inputs + + x = CLIPPatchingAndEmbedding( + width=width, + patch_size=patch_size, + input_resolution=input_resolution, + backend="tensorflow", + )(x) + x = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x) + + x = tf.transpose(x, perm=(1, 0, 2)) + x = ResidualTransformerEncoder(width, layers, heads, backend="tensorflow")(x) + x = tf.transpose(x, perm=(1, 0, 2)) + + x = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x[:, 0, :]) + + proj = tf.keras.layers.Dense(output_dim) + x = proj(x) + + output = x + + super().__init__( + inputs=inputs, + outputs=output, + **kwargs, + ) \ No newline at end of file diff --git a/keras_cv/models/clip/clip_processor.py b/keras_cv/models/clip/clip_processor.py new file mode 100644 index 0000000000..c9b12a79c9 --- /dev/null +++ b/keras_cv/models/clip/clip_processor.py @@ -0,0 +1,85 @@ +from typing import List +from typing import Union + +import numpy as np +import tensorflow as tf + +from pkg_resources import packaging + +from deepvision.models.feature_extractors.clip.clip_tokenizer import CLIPTokenizer + +class __CLIPProcessorTF: + def __init__(self, input_resolution): + self.input_resolution = input_resolution + self.image_transform = self.transform_image + self.tokenizer = CLIPTokenizer() + + def transform_image(self, image_path): + input_resolution = self.input_resolution + mean = tf.constant([0.48145466, 0.4578275, 0.40821073]) + std = tf.constant([0.26862954, 0.26130258, 0.27577711]) + + image = tf.io.read_file(image_path) + image = tf.image.decode_jpeg(image, channels=3) + image = ( + tf.image.resize( + image, + (input_resolution, input_resolution), + method=tf.image.ResizeMethod.BICUBIC, + ) + / 255.0 + ) + image = tf.image.central_crop( + image, central_fraction=input_resolution / image.shape[0] + ) + image = (image - mean) / std + return image + + def process_images(self, images): + if isinstance(images, str): + images = [images] + + processed_images = [] + for image in images: + if isinstance(image, str): + image = self.image_transform(image) + processed_images.append(image) + processed_images = tf.stack(processed_images) + return processed_images + + def process_texts(self, texts, context_length: int = 77, truncate: bool = False): + if isinstance(texts, str): + texts = [texts] + + sot_token = self.tokenizer.encoder["<|startoftext|>"] + eot_token = self.tokenizer.encoder["<|endoftext|>"] + all_tokens = [ + [sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts + ] + + result = np.zeros(shape=[len(all_tokens), context_length]) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError( + f"Input {texts[i]} is too long for context length {context_length}" + ) + result[i, : len(tokens)] = tokens + + result = tf.stack(result) + return result + + def process_pair(self, images, texts, device=None): + if device: + raise ValueError( + "device argument is only supported for the PyTorch backend" + ) + + images = self.process_images(images) + texts = self.process_texts(texts) + return (images, texts) + diff --git a/keras_cv/models/clip/clip_tokenizer.py b/keras_cv/models/clip/clip_tokenizer.py new file mode 100644 index 0000000000..26a8b1331e --- /dev/null +++ b/keras_cv/models/clip/clip_tokenizer.py @@ -0,0 +1,147 @@ +# encodes text +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join( + os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz" + ) + + +class CLIPTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = self.__bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") + merges = merges[1 : 49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(self.__bytes_to_unicode().values()) + vocab = vocab + [v + "" for v in vocab] + for merge in merges: + vocab.append("".join(merge)) + vocab.extend(["<|startoftext|>", "<|endoftext|>"]) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = { + "<|startoftext|>": "<|startoftext|>", + "<|endoftext|>": "<|endoftext|>", + } + self.pat = re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE, + ) + + @lru_cache() + def __bytes_to_unicode(self): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + def __get_pairs(self, word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + def __basic_clean(self, text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + def __whitespace_clean(self, text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + def __bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + "",) + pairs = self.__get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = self.__get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = self.__whitespace_clean(self.__basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.__bpe(token).split(" ") + ) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = ( + bytearray([self.byte_decoder[c] for c in text]) + .decode("utf-8", errors="replace") + .replace("", " ") + ) + return text \ No newline at end of file