Contrastive Modelling #77

koaning opened this issue Aug 8, 2023 · 1 comment


koaning commented Aug 8, 2023

I think there's an opportunity for this library to make it much easier to fine-tune embeddings. So I figured I'd write up an API proposal for myself. Here are some of the additions I'd like to make.

Right now, it feels like it makes sense to implement all of this in Keras. With the advent of Keras Core, we may yet have an opportunity to keep things flexible for JAX/TensorFlow/PyTorch users.

Here are the components that I'd like to add.

Contrastive Model

This model assumes that the same encoder is used for both X1 and X2. That is quite reasonable for text comparison tasks, but it won't hold for multimodal image/text situations.

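```python
# Proposed API: ContrastiveModel does not exist in embetter yet.
from embetter.finetune import ContrastiveModel

model = ContrastiveModel().fit(X1, X2, y)
# If you want to train for a single epoch
model.partial_fit(X1, X2, y)
# If you want to leverage the keras generator to feed data
model.fit_generator(generator)
# Embed either side with the shared, fine-tuned encoder
model.transform(X1)
model.transform(X2)
# Similarity score for each (X1, X2) pair
model.predict(X1, X2)
```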
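To make the "same encoder for X1 and X2" assumption concrete, here is a minimal NumPy sketch. It is not the proposed embetter implementation: `SharedEncoderSketch` is a hypothetical name, and a random, untrained linear projection stands in for the encoder that `fit` would learn. What it shows is that `transform` applies one weight matrix to both sides and that `predict` returns one cosine similarity per pair.

```python
import numpy as np


class SharedEncoderSketch:
    """Hypothetical toy stand-in for the proposed ContrastiveModel.

    One linear encoder (a single weight matrix) is applied to both X1 and X2,
    i.e. a siamese setup. The weights are random here; training is omitted.
    """

    def __init__(self, in_dim, hidden_dim=8, seed=0):
        rng = np.random.default_rng(seed)
        # A single weight matrix, shared by both sides of every pair.
        self.W = rng.normal(size=(in_dim, hidden_dim))

    def transform(self, X):
        # The same projection is used whether X plays the role of X1 or X2.
        return np.asarray(X, dtype=float) @ self.W

    def predict(self, X1, X2):
        # Row-wise cosine similarity between the two encoded batches.
        A = self.transform(X1)
        B = self.transform(X2)
        A = A / np.linalg.norm(A, axis=1, keepdims=True)
        B = B / np.linalg.norm(B, axis=1, keepdims=True)
        return np.sum(A * B, axis=1)
```

Because both sides go through the same weights, `predict(X, X)` is 1.0 for every row and `predict(X1, X2)` equals `predict(X2, X1)`. That symmetry is exactly what stops holding once the two inputs need different encoders.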

Such a contrastive fine-tuner might also allow folks to pretrain on their own datasets. We could even make helpers for that. Note that this model only accepts binary values for y: 1 when a pair should be similar and 0 when it should not.
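One such helper would turn ordinary per-example class labels into the binary pairs this model expects. The sketch below is hypothetical (`make_pairs` is not an embetter function, and it is not necessarily how `generate_pairs_batch` in the proof of concept further down behaves): a pair gets y=1 when both examples share a label and y=0 otherwise, with optional subsampling because the number of pairs grows quadratically.

```python
import itertools
import random


def make_pairs(labels, n_pairs=None, seed=42):
    """Hypothetical helper: build (i1, i2, y) index pairs from class labels.

    y is 1 when both examples share a label and 0 otherwise.
    """
    pairs = [
        (i, j, int(labels[i] == labels[j]))
        for i, j in itertools.combinations(range(len(labels)), 2)
    ]
    if n_pairs is not None:
        # All-pairs grows as O(n^2), so optionally keep a random subset.
        rng = random.Random(seed)
        pairs = rng.sample(pairs, min(n_pairs, len(pairs)))
    return pairs
```

Given an embedding matrix `X`, the indices map straight onto the `fit(X1, X2, y)` signature: `X1` is `X` indexed by the first elements, `X2` by the second, and `y` is the third column.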

MultiClassifier

With such a contrastive model, we might be able to build a multi-label/multi-head classifier. I've always found it annoying that it's hard to create a model that can train on non-overlapping labels, i.e. labels that come from different annotated datasets. The MultiClassifier can be the categoriser that I've wanted to have for a while.

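```python
from sklearn.linear_model import LogisticRegression

# Proposed API: neither class exists in embetter yet.
from embetter.finetune import ContrastiveModel
from embetter.model import MultiClassifier

mc = MultiClassifier(
    # scikit-learn spells this argument `class_weight`, not `weights`
    classifier_head=LogisticRegression(class_weight="balanced"),
    finetuner=ContrastiveModel()
)

# If you only have one label
mc.fit(X, y)
# If you have multiple labels from different annotated sets. 
mc.fit_pairs(lab1=(X, y), lab2=(X, y), lab3=(X, y))
# Can we use the keras generator here? Not 100% sure. 
# mc.fit_generator(generator)
mc.encode(X)
mc.transform(X)
mc.predict(X)
```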

The goal is to expose few hyperparameters and just offer a reasonable starting point. Again, y is binary, but you can pass the label name as the keyword argument (via **kwargs) in fit_pairs.
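The `fit_pairs` idea can be sketched as one binary head per label name, where each head only ever sees the examples annotated for its own label. This is a hypothetical sketch, not the proposed `MultiClassifier`: `MultiHeadSketch` is an illustrative name, the fine-tuning step is left out (X is assumed to be already embedded), and scikit-learn's `LogisticRegression` is used as the default head.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression


def default_head():
    # Balanced class weights, as in the proposal above.
    return LogisticRegression(class_weight="balanced")


class MultiHeadSketch:
    """Hypothetical sketch: one binary classifier head per label name."""

    def __init__(self, make_head=default_head):
        self.make_head = make_head
        self.heads = {}

    def fit_pairs(self, **label_sets):
        # Each keyword maps a label name to its own (X, y) with binary y,
        # so the annotated sets do not need to overlap.
        for name, (X, y) in label_sets.items():
            self.heads[name] = self.make_head().fit(np.asarray(X), np.asarray(y))
        return self

    def predict_proba(self, X):
        # Every head scores every row, including rows it never trained on.
        X = np.asarray(X)
        return {name: head.predict_proba(X)[:, 1] for name, head in self.heads.items()}
```

Usage mirrors the proposal: `MultiHeadSketch().fit_pairs(lab1=(X_a, y_a), lab2=(X_b, y_b))`, after which `predict_proba` returns a dict with one probability array per label.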

ContrastiveMultiModalModel

This model is more complex because it does not assume that X1 and X2 share the same encoder.

```python
# Proposed API: ContrastiveMultiModalModel does not exist yet.
model = ContrastiveMultiModalModel().fit(X1, X2, y)
model.partial_fit(X1, X2, y)
model.fit_generator(generator)
model.transform_enc1(X1)
model.transform_enc2(X2)
model.predict(X1, X2)
```
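The difference from the shared-encoder case is easiest to see with two separate encoders. The NumPy sketch below is hypothetical (`TwoTowerSketch` is an illustrative name and the projections are random and untrained): each side gets its own weight matrix, so X1 and X2 may even have different input sizes, and both are mapped into one shared space where cosine similarity is computed.

```python
import numpy as np


class TwoTowerSketch:
    """Hypothetical sketch of ContrastiveMultiModalModel.

    Two separate linear encoders map inputs of different sizes into one
    shared space. Weights are random here; training is omitted.
    """

    def __init__(self, dim1, dim2, shared_dim=8, seed=0):
        rng = np.random.default_rng(seed)
        self.W1 = rng.normal(size=(dim1, shared_dim))  # encoder for X1
        self.W2 = rng.normal(size=(dim2, shared_dim))  # encoder for X2

    def transform_enc1(self, X1):
        return np.asarray(X1, dtype=float) @ self.W1

    def transform_enc2(self, X2):
        return np.asarray(X2, dtype=float) @ self.W2

    def predict(self, X1, X2):
        # Cosine similarity per pair, computed in the shared space.
        A = self.transform_enc1(X1)
        B = self.transform_enc2(X2)
        A = A / np.linalg.norm(A, axis=1, keepdims=True)
        B = B / np.linalg.norm(B, axis=1, keepdims=True)
        return np.sum(A * B, axis=1)
```

In recommender terms, one tower could embed user features and the other item features; only the shared output dimension has to match.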

This can be useful for folks in recommender-land.


koaning commented Aug 8, 2023

I even wrote some code for this, just as a proof of concept.

```python
import matplotlib.pyplot as plt
import numpy as np
import srsly
from sklearn.decomposition import PCA

from embetter.finetune._contrastive import generate_pairs_batch
from embetter.text import SentenceEncoder
from embetter.utils import cached

from keras import backend as K
from keras.layers import Dense, Dot, Flatten, Input
from keras.losses import MeanSquaredError
from keras.models import Model, Sequential
from keras.optimizers import Adam

# Define the base model: a stack of Dense layers that both inputs share.
def create_base_model(hidden_dim, n_layers, activation, input_shape):
    model = Sequential()
    # Declare the input shape once instead of passing it to every Dense layer.
    model.add(Input(shape=input_shape))
    for _ in range(n_layers):
        model.add(Dense(hidden_dim, activation=activation))
    return model

# Margin-based contrastive loss. It expects y_pred to be a *distance* (not a
# similarity) and is not used below; the model is compiled with MSE instead.
def contrastive_loss(y_true, y_pred):
    margin = 1.0
    square_pred = K.square(y_pred)
    margin_square = K.square(K.maximum(margin - y_pred, 0))
    return K.mean(y_true * square_pred + (1 - y_true) * margin_square)

class ContrastiveFinetuner:
    def __init__(self, hidden_dim=300, n_layers=1, activation=None):
        self.hidden_dim = hidden_dim
        self.activation = activation
        self.n_layers = n_layers

    def _construct_model(self, X1, X2):
        shape1 = (X1.shape[1], )
        shape2 = (X2.shape[1], )
        # One base model encodes both sides, so the weights are shared.
        mod = create_base_model(self.hidden_dim, self.n_layers, self.activation, shape1)
        input1 = Input(shape=shape1)
        input2 = Input(shape=shape2)
        vector1 = mod(input1)
        vector2 = mod(input2)
        # Dot with normalize=True L2-normalises both vectors: cosine similarity.
        cosine_sim = Dot(axes=-1, normalize=True)([vector1, vector2])
        cosine_sim = Flatten()(cosine_sim)
        model = Model(inputs=[input1, input2], outputs=cosine_sim)
        # Regress the cosine similarity onto the binary pair labels.
        model.compile(optimizer=Adam(), loss=MeanSquaredError())
        return model, mod

dataset = list(srsly.read_jsonl("new-dataset.jsonl"))
labels = [ex['cats']['new-dataset'] for ex in dataset]
texts = [ex['text'] for ex in dataset]
pairs = generate_pairs_batch(labels)
sentence_enc = cached("sbert", SentenceEncoder())
X = sentence_enc.transform(texts)

# Look up the embeddings of both members of every pair.
X1 = np.array([X[ex.i1] for ex in pairs])
X2 = np.array([X[ex.i2] for ex in pairs])
y_pairs = np.array([ex.label for ex in pairs], dtype=float)

# Before: PCA of the original embeddings
X_pca = PCA(2).fit_transform(X)
plt.scatter(X_pca[:, 0], X_pca[:, 1], c=labels, s=5)
plt.show()

model, base = ContrastiveFinetuner(n_layers=1)._construct_model(X1, X2)
model.fit([X1, X2], y_pairs, epochs=100, verbose=2)

# After: PCA of the fine-tuned embeddings
X_pca = PCA(2).fit_transform(base.predict(X))
plt.scatter(X_pca[:, 0], X_pca[:, 1], c=labels, s=5)
plt.show()
```

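Before: [image: 2D PCA projection of the original sentence embeddings, coloured by label]

After: [image: 2D PCA projection of the fine-tuned embeddings, coloured by label]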

This isn't an all-encompassing benchmark or anything, but it does seem to "kind of work" and might be a nice starting point for something more general than SetFit. The main goal here, again, is to make rapid prototyping awesome.
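For reference, the objective that the proof of concept actually optimises can be written out. Writing $f$ for the shared base model and $(x_{1,n}, x_{2,n}, y_n)$ for the $n$-th pair with $y_n \in \{0, 1\}$, the compiled model regresses the cosine similarity onto the label with mean squared error:

$$
s_n = \frac{f(x_{1,n})^\top f(x_{2,n})}{\lVert f(x_{1,n}) \rVert \, \lVert f(x_{2,n}) \rVert},
\qquad
\mathcal{L}_{\text{MSE}} = \frac{1}{N} \sum_{n=1}^{N} \left( s_n - y_n \right)^2
$$

The `contrastive_loss` helper, which the script defines but does not use, is the classic margin form on a distance $d_n$ with margin $m = 1$:

$$
\mathcal{L}_{\text{margin}} = \frac{1}{N} \sum_{n=1}^{N} \left[ y_n \, d_n^2 + (1 - y_n) \max(m - d_n, 0)^2 \right]
$$

It pulls similar pairs towards $d_n = 0$ and pushes dissimilar pairs beyond the margin, so it only fits this model if the similarity is first turned into a distance, for example $d_n = 1 - s_n$.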
