feat: add distillation loss for ColPali2 (wip)
tonywu71 committed Sep 4, 2024
1 parent 774a0f2 commit 9453e04
Showing 1 changed file with 88 additions and 14 deletions.
102 changes: 88 additions & 14 deletions colpali_engine/models/colpali_2/colpali_2_loss.py
@@ -1,7 +1,8 @@
from dataclasses import dataclass
from typing import List, Optional
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812

from colpali_engine.models.colpali_2.colpali_2_modeling_outputs import ColPali2ModelOutput
@@ -32,10 +33,11 @@ def forward(self, output, target) -> torch.Tensor:
return weighted_losses.sum()


@dataclass
@dataclass(kw_only=True)
class ColPali2LossOutputs:
single_vector_loss: torch.Tensor
multi_vector_loss: torch.Tensor
distillation_loss: Optional[torch.Tensor] = None
total_loss: torch.Tensor


@@ -50,37 +52,51 @@ class ColPali2Loss(torch.nn.Module):

def __init__(
self,
alpha: float = 0.5,
use_matryoshka_loss: bool = True,
use_distillation_loss: bool = True,
beta: float = 0.5,
temperature: float = 2.0,
):
super().__init__()
self.alpha = alpha
self.use_matryoshka_loss = use_matryoshka_loss
self.alpha: float = 0.5
self.use_distillation_loss = use_distillation_loss
self.beta = beta
self.temperature = temperature

def single_vector_loss(
self,
query_embeddings: torch.Tensor,
doc_embeddings: torch.Tensor,
) -> torch.Tensor:
return_scores: bool = False,
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
"""
query_embeddings: (batch_size, dim)
doc_embeddings: (batch_size, dim)
"""
scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
loss_fn = MatryoshkaCELoss() if self.use_matryoshka_loss else F.cross_entropy

if self.use_matryoshka_loss:
loss = self.single_vector_loss(scores, torch.arange(scores.shape[0], device=scores.device))
loss = loss_fn(scores, torch.arange(scores.shape[0], device=scores.device)) # (1,)

if return_scores:
return loss, scores
else:
loss = F.cross_entropy(scores, torch.arange(scores.shape[0], device=scores.device))
return loss
return loss

def multi_vector_loss(
self,
query_embeddings: torch.Tensor,
doc_embeddings: torch.Tensor,
) -> torch.Tensor:
return_scores: bool = False,
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
"""
query_embeddings: (batch_size, num_query_tokens, dim)
doc_embeddings: (batch_size, num_doc_tokens, dim)
NOTE: If `return_scores` is True, the function will return only the positive scores, i.e.
the diagonal of the scores matrix.
"""
# Compute the ColBERT scores
scores = (
@@ -97,18 +113,76 @@ def multi_vector_loss(
neg_scores = neg_scores.max(dim=1)[0] # (batch_size,)

# Compute the margin loss
loss = F.softplus(neg_scores - pos_scores).mean()
loss = F.softplus(neg_scores - pos_scores).mean() # (1,)

if return_scores:
return loss, pos_scores
else:
return loss

def distillation_loss(
self,
teacher_scores: torch.Tensor,
student_scores: torch.Tensor,
teacher_score_upper_bound: int,
):
"""
Compute the distillation loss between the multi-vector head (teacher) and
the single-vector head (student).

Inputs:
- teacher_scores: (batch_size)
- student_scores: (batch_size)
- teacher_score_upper_bound: The upper bound of the teacher scores.
"""
kl_div_loss = nn.KLDivLoss(reduction="batchmean")

# NOTE: Both the teacher and student scores should be turned into log-probabilities before
# computing the KL-divergence.
# The embeddings are normalized, thus we know the lower and upper bounds of the scores:
# - Teacher: the multi-vector scores (MaxSim) are between 0 and N_q, N_q being the number of query tokens
# - Student: the single-vector scores are between 0 and 1.

# Convert the scores to log-probabilities
teacher_logits = torch.logit(teacher_scores / teacher_score_upper_bound, eps=1e-6)
student_logits = torch.logit(student_scores, eps=1e-6)

# NOTE:
# - KLDivLoss argument order is the opposite of the KL(·||·) mathematical function.
# - KLDivLoss expects log-probabilities for `input` to avoid underflow issues.
loss_kd = self.temperature**2 * kl_div_loss(
input=student_logits / self.temperature,
target=teacher_logits / self.temperature,
) # (1,)

return loss_kd

def forward(
self,
query_embeddings: ColPali2ModelOutput,
doc_embeddings: ColPali2ModelOutput,
) -> ColPali2LossOutputs:
single_vector_loss = self.single_vector_loss(query_embeddings.single_vec_emb, doc_embeddings.single_vec_emb)
multi_vector_loss = self.multi_vector_loss(query_embeddings.multi_vec_emb, doc_embeddings.multi_vec_emb)
single_vector_loss, single_vector_scores = self.single_vector_loss(
query_embeddings.single_vec_emb, doc_embeddings.single_vec_emb, return_scores=True
)
multi_vector_loss, multi_vector_scores = self.multi_vector_loss(
query_embeddings.multi_vec_emb, doc_embeddings.multi_vec_emb, return_scores=True
)

total_loss = self.alpha * single_vector_loss + (1 - self.alpha) * multi_vector_loss

return ColPali2LossOutputs(single_vector_loss, multi_vector_loss, total_loss)
distillation_loss = None
if self.use_distillation_loss:
distillation_loss = self.distillation_loss(
single_vector_scores,
multi_vector_scores,
teacher_score_upper_bound=query_embeddings.multi_vec_emb.shape[1], # TODO: find the correct upper bound
)
total_loss += self.beta * distillation_loss

return ColPali2LossOutputs(
single_vector_loss=single_vector_loss,
multi_vector_loss=multi_vector_loss,
distillation_loss=distillation_loss,
total_loss=total_loss,
)
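
Not part of the commit: a minimal, self-contained sketch of the nn.KLDivLoss / F.kl_div convention that the NOTE inside distillation_loss relies on (the input must be log-probabilities, the target probabilities unless log_target=True), applied here to temperature-scaled binary relevance distributions built from bounded scores. The binary-distribution framing, tensor values, and names are illustrative assumptions, not the committed implementation.

import torch
import torch.nn.functional as F

temperature = 2.0

# Dummy per-pair similarity scores (illustrative values only).
teacher_scores = torch.tensor([12.3, 8.9, 15.2])   # e.g. MaxSim scores in [0, N_q]
student_scores = torch.tensor([0.71, 0.55, 0.83])  # single-vector scores in [0, 1]
n_query_tokens = 16                                 # assumed upper bound for the teacher scores

# Squash both to (0, 1) so they can be read as probabilities of "relevant".
p_teacher = (teacher_scores / n_query_tokens).clamp(1e-6, 1 - 1e-6)
p_student = student_scores.clamp(1e-6, 1 - 1e-6)

def binary_dist(p: torch.Tensor, t: float) -> torch.Tensor:
    # Build a (relevant, not relevant) distribution softened by the temperature
    # applied to the logits, as in standard knowledge distillation.
    logits = torch.logit(p) / t
    return torch.stack([logits, torch.zeros_like(logits)], dim=-1).softmax(dim=-1)

q_teacher = binary_dist(p_teacher, temperature)            # probabilities (target)
log_q_student = binary_dist(p_student, temperature).log()  # log-probabilities (input)

# F.kl_div expects log-probabilities as `input` and probabilities as `target`,
# and computes KL(target || student); the argument order is reversed w.r.t. KL(P || Q).
loss_kd = temperature**2 * F.kl_div(log_q_student, q_teacher, reduction="batchmean")
print(loss_kd)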

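Also not part of the commit: a rough usage sketch of ColPali2Loss after this change. The field names single_vec_emb and multi_vec_emb and the constructor arguments come from the diff; the shapes, the ColPali2ModelOutput construction call, and the hyperparameter values are assumptions, and since the commit is marked wip the snippet may need adjusting against later revisions.

import torch

from colpali_engine.models.colpali_2.colpali_2_loss import ColPali2Loss
from colpali_engine.models.colpali_2.colpali_2_modeling_outputs import ColPali2ModelOutput

batch_size, num_query_tokens, num_doc_tokens, dim = 4, 16, 256, 128

def normed(*shape):
    # L2-normalized dummy embeddings (the distillation NOTE assumes normalized embeddings).
    x = torch.randn(*shape)
    return x / x.norm(dim=-1, keepdim=True)

# Assumed constructor for ColPali2ModelOutput; the actual dataclass may take more fields.
query_embeddings = ColPali2ModelOutput(
    single_vec_emb=normed(batch_size, dim),
    multi_vec_emb=normed(batch_size, num_query_tokens, dim),
)
doc_embeddings = ColPali2ModelOutput(
    single_vec_emb=normed(batch_size, dim),
    multi_vec_emb=normed(batch_size, num_doc_tokens, dim),
)

loss_fn = ColPali2Loss(use_distillation_loss=True, beta=0.5, temperature=2.0)

# The i-th query is paired with the i-th document, which is why single_vector_loss
# uses torch.arange(...) as the in-batch contrastive cross-entropy labels.
outputs = loss_fn(query_embeddings, doc_embeddings)
print(outputs.total_loss, outputs.single_vector_loss, outputs.multi_vector_loss)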