Commit 82fe5dd

initial dump
(cherry picked from commit 49781a6)

DavidLandup0 authored and divyashreepathihalli committed Dec 18, 2023
1 parent e360fb7 commit 82fe5dd
Showing 4 changed files with 290 additions and 0 deletions.
Empty file added keras_cv/models/clip/clip.py
Empty file.
58 changes: 58 additions & 0 deletions keras_cv/models/clip/clip_image_encoder.py
@@ -0,0 +1,58 @@
# CLIP image encoder: maps images to embeddings of size output_dim.
import tensorflow as tf

from deepvision.layers.clip_patching_and_embedding import CLIPPatchingAndEmbedding
from deepvision.layers.residual_transformer_encoder import ResidualTransformerEncoder


class __CLIPImageEncoderTF(tf.keras.Model):
def __init__(
self,
input_resolution: int,
patch_size: int,
width: int,
layers: int,
heads: int,
output_dim: int,
input_tensor=None,
**kwargs,
):
inputs = tf.keras.layers.Input(
tensor=input_tensor, shape=(input_resolution, input_resolution, 3)
)
x = inputs

x = CLIPPatchingAndEmbedding(
width=width,
patch_size=patch_size,
input_resolution=input_resolution,
backend="tensorflow",
)(x)
x = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x)

        # ResidualTransformerEncoder operates on (sequence, batch, width) inputs,
        # so transpose before it and transpose back afterwards.
        x = tf.transpose(x, perm=(1, 0, 2))
        x = ResidualTransformerEncoder(width, layers, heads, backend="tensorflow")(x)
        x = tf.transpose(x, perm=(1, 0, 2))

x = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x[:, 0, :])

proj = tf.keras.layers.Dense(output_dim)
x = proj(x)

output = x

super().__init__(
inputs=inputs,
outputs=output,
**kwargs,
)
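
A minimal usage sketch for the encoder above, assuming the deepvision layers it imports are available and the class is in scope; the ViT-B/32-style hyperparameters are illustrative placeholders, not values from this commit:

import tensorflow as tf

# Hypothetical configuration; only the constructor signature comes from the
# file above, the concrete values are assumptions.
image_encoder = __CLIPImageEncoderTF(
    input_resolution=224,
    patch_size=32,
    width=768,
    layers=12,
    heads=12,
    output_dim=512,
)

images = tf.random.uniform((2, 224, 224, 3))
embeddings = image_encoder(images)  # expected shape: (2, 512)
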
85 changes: 85 additions & 0 deletions keras_cv/models/clip/clip_processor.py
@@ -0,0 +1,85 @@
import numpy as np
import tensorflow as tf

from keras_cv.models.clip.clip_tokenizer import CLIPTokenizer


class __CLIPProcessorTF:
def __init__(self, input_resolution):
self.input_resolution = input_resolution
self.image_transform = self.transform_image
self.tokenizer = CLIPTokenizer()

def transform_image(self, image_path):
input_resolution = self.input_resolution
mean = tf.constant([0.48145466, 0.4578275, 0.40821073])
std = tf.constant([0.26862954, 0.26130258, 0.27577711])

image = tf.io.read_file(image_path)
image = tf.image.decode_jpeg(image, channels=3)
image = (
tf.image.resize(
image,
(input_resolution, input_resolution),
method=tf.image.ResizeMethod.BICUBIC,
)
/ 255.0
)
image = tf.image.central_crop(
image, central_fraction=input_resolution / image.shape[0]
)
image = (image - mean) / std
return image

def process_images(self, images):
if isinstance(images, str):
images = [images]

processed_images = []
for image in images:
if isinstance(image, str):
image = self.image_transform(image)
processed_images.append(image)
processed_images = tf.stack(processed_images)
return processed_images

def process_texts(self, texts, context_length: int = 77, truncate: bool = False):
if isinstance(texts, str):
texts = [texts]

sot_token = self.tokenizer.encoder["<|startoftext|>"]
eot_token = self.tokenizer.encoder["<|endoftext|>"]
all_tokens = [
[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts
]

        # One row per text: [sot_token, token ids..., eot_token], zero-padded to context_length.
        result = np.zeros(shape=[len(all_tokens), context_length], dtype=np.int64)

for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
if truncate:
tokens = tokens[:context_length]
tokens[-1] = eot_token
else:
raise RuntimeError(
f"Input {texts[i]} is too long for context length {context_length}"
)
result[i, : len(tokens)] = tokens

result = tf.stack(result)
return result

def process_pair(self, images, texts, device=None):
if device:
raise ValueError(
"device argument is only supported for the PyTorch backend"
)

images = self.process_images(images)
texts = self.process_texts(texts)
return (images, texts)
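
A small usage sketch for the processor above; the image path is a hypothetical placeholder and input_resolution=224 is an assumed value:

processor = __CLIPProcessorTF(input_resolution=224)
image_batch, token_batch = processor.process_pair(
    images="path/to/image.jpg",  # hypothetical local file
    texts=["a photo of a cat", "a photo of a dog"],
)
# image_batch: (1, 224, 224, 3) preprocessed images
# token_batch: (2, 77) token ids, padded to the default context length
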

147 changes: 147 additions & 0 deletions keras_cv/models/clip/clip_tokenizer.py
@@ -0,0 +1,147 @@
# Byte-pair encoding (BPE) tokenizer used to encode text for CLIP.
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def default_bpe():
return os.path.join(
os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz"
)


class CLIPTokenizer(object):
def __init__(self, bpe_path: str = default_bpe()):
self.byte_encoder = self.__bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
merges = merges[1 : 49152 - 256 - 2 + 1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(self.__bytes_to_unicode().values())
vocab = vocab + [v + "</w>" for v in vocab]
for merge in merges:
vocab.append("".join(merge))
vocab.extend(["<|startoftext|>", "<|endoftext|>"])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {
"<|startoftext|>": "<|startoftext|>",
"<|endoftext|>": "<|endoftext|>",
}
self.pat = re.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
re.IGNORECASE,
)

@lru_cache()
def __bytes_to_unicode(self):
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
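        # For example: printable bytes such as ord("a") map to themselves ("a"),
        # while bytes outside the printable ranges above (e.g. 0x00 or 0x7F) are
        # shifted into the 256+ range so every byte gets a visible character.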
return dict(zip(bs, cs))

def __get_pairs(self, word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs

def __basic_clean(self, text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()

def __whitespace_clean(self, text):
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text

def __bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + (token[-1] + "</w>",)
pairs = self.__get_pairs(word)

if not pairs:
return token + "</w>"

while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
                except ValueError:
new_word.extend(word[i:])
break

if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = self.__get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word

def encode(self, text):
bpe_tokens = []
text = self.__whitespace_clean(self.__basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
bpe_tokens.extend(
self.encoder[bpe_token] for bpe_token in self.__bpe(token).split(" ")
)
return bpe_tokens

def decode(self, tokens):
text = "".join([self.decoder[token] for token in tokens])
text = (
bytearray([self.byte_decoder[c] for c in text])
.decode("utf-8", errors="replace")
.replace("</w>", " ")
)
return text
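
A minimal round-trip sketch for the tokenizer, assuming the bpe_simple_vocab_16e6.txt.gz vocabulary file it expects is present next to this module:

tokenizer = CLIPTokenizer()
ids = tokenizer.encode("a photo of a cat")
text = tokenizer.decode(ids)  # "a photo of a cat " (each word token ends in a space)

Note that encode() does not add the <|startoftext|> / <|endoftext|> markers; the processor's process_texts handles that.
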
