Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Embedding] Add inf-cl in embedding trainer #9673

Merged
merged 7 commits into from
Dec 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions llm/config/qwen/emb_argument.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,6 @@
"unified_checkpoint": true,
"use_flash_attention": true,
"amp_custom_black_list": "elementwise_div",
"release_grads": true
}
"release_grads": true,
"loss_type": "contrastive"
}
8 changes: 8 additions & 0 deletions llm/utils/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,11 @@ class EmbeddingArgument:
default=None,
metadata={"help": "The dims for matryoshka training."},
)
loss_type: str = field(
default="contrastive",
metadata={"help": "The type of loss computation."},
)
inf_cl_head_dim: int = field(
default=64,
metadata={"help": "The size of the head dimension when gpu ops are set as 'inf_cl'."},
)
87 changes: 87 additions & 0 deletions paddlenlp/transformers/contrastive_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,90 @@
else:
loss = self.loss_fn(q_reps, p_reps)
return loss


class SimpleInfclLoss(nn.Layer):
    """Instance-discrimination loss computed by the triton ``inf_cl`` kernel."""

    def __init__(self, inf_cl_head_dim=64):
        """
        Initializes the Simple Inf_cl Loss class.

        Args:
            inf_cl_head_dim (int, optional): Dimension of the projection head. Default is 64.
        """
        super().__init__()
        self.head_dim = inf_cl_head_dim

    def forward(self, q_reps, p_reps):
        """
        Computes the instance discrimination loss.

        Args:
            q_reps (Tensor): Query representations.
            p_reps (Tensor): Key representations; assumed to hold a fixed-size
                group of keys per query — TODO confirm against the caller.

        Returns:
            Tensor: The computed loss.

        Raises:
            ImportError: If the ``paddlenlp_kernel`` package is unavailable.
        """
        # Import lazily so the module stays usable when the optional kernel
        # package is not installed; fail loudly only when this loss is used.
        try:
            from paddlenlp_kernel.triton.inf_cl import cal_inf_loss
        except ImportError:
            raise ImportError(
                "Paddlenlp_kernels are not available, which means the inf_cl loss cannot be used. If you wish to use the inf_cl loss, please follow the instructions in the README.md on the `ops`."
            )
        num_queries = q_reps.shape[0]
        keys_per_query = p_reps.shape[0] // num_queries
        # The positive key for query i sits at the start of its key group,
        # i.e. index i * keys_per_query.
        positive_indices = paddle.arange(num_queries, dtype="int64") * keys_per_query
        return cal_inf_loss(q_reps, p_reps, labels=positive_indices, scale=None, head_dim=self.head_dim)

Check warning on line 100 in paddlenlp/transformers/contrastive_loss.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/contrastive_loss.py#L96-L100

Added lines #L96 - L100 were not covered by tests


class MatryoshkaInfclLoss(nn.Layer):
    """Matryoshka wrapper over ``SimpleInfclLoss``: sums the inf_cl loss across
    truncated embedding widths, falling back to the full width when no dims
    are configured."""

    def __init__(self, embedding_matryoshka_dims: Optional[List[int]] = None, inf_cl_head_dim=64):
        """
        Initializes the Matryoshka Inf_cl Loss class.

        Args:
            embedding_matryoshka_dims (List[int], optional): List of dimensions for Matryoshka embeddings.
                If None, no Matryoshka embedding is used. Default is None.
            inf_cl_head_dim (int, optional): Dimension of the projection head. Default is 64.
        """
        super().__init__()
        # Normalize ``None`` to an empty list so ``forward`` can branch on truthiness.
        self.embedding_matryoshka_dims = [] if embedding_matryoshka_dims is None else embedding_matryoshka_dims
        self.loss_fn = SimpleInfclLoss(inf_cl_head_dim)

    def forward(self, q_reps, p_reps):
        """
        Computes the Matryoshka instance discrimination loss.

        Args:
            q_reps (Tensor): Query representations.
            p_reps (Tensor): Key representations.

        Returns:
            Tensor: The computed loss (summed over all configured Matryoshka dims).
        """
        if not self.embedding_matryoshka_dims:
            # No Matryoshka dims configured: a single loss on the full-width representations.
            return self.loss_fn(q_reps, p_reps)

        loss = 0.0
        for dim in self.embedding_matryoshka_dims:
            # Truncate both sides to the first `dim` features, then L2-normalize
            # along the last axis so the truncated embeddings are unit-length
            # before the loss is evaluated.
            truncated_q = nn.functional.normalize(q_reps[:, :dim], axis=-1)
            truncated_p = nn.functional.normalize(p_reps[:, :dim], axis=-1)
            loss += self.loss_fn(truncated_q, truncated_p)
        return loss

Check warning on line 152 in paddlenlp/transformers/contrastive_loss.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/contrastive_loss.py#L152

Added line #L152 was not covered by tests
18 changes: 14 additions & 4 deletions paddlenlp/trl/embedding_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
from paddlenlp.trainer import Trainer
from paddlenlp.transformers.contrastive_loss import (
MatryoshkaContrastiveLoss,
MatryoshkaInfclLoss,
SimpleContrastiveLoss,
SimpleInfclLoss,
)
from paddlenlp.transformers.embedding_utils import dist_gather_tensor_with_gradient

Expand All @@ -44,11 +46,19 @@
self.accum_rng_states["hybrid"] = []

if model_args.embedding_matryoshka_dims is not None and len(model_args.embedding_matryoshka_dims) > 0:
self.loss_fn = MatryoshkaContrastiveLoss(
model_args.embedding_temperature, model_args.embedding_matryoshka_dims
)
if model_args.loss_type == "inf_cl":
self.embedding_negatives_cross_device = False
self.loss_fn = MatryoshkaInfclLoss(model_args.embedding_matryoshka_dims, model_args.inf_cl_head_dim)
elif model_args.loss_type == "contrastive":
self.loss_fn = MatryoshkaContrastiveLoss(

Check warning on line 53 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L49-L53

Added lines #L49 - L53 were not covered by tests
model_args.embedding_temperature, model_args.embedding_matryoshka_dims
)
else:
self.loss_fn = SimpleContrastiveLoss(model_args.embedding_temperature)
if model_args.loss_type == "inf_cl":
self.embedding_negatives_cross_device = False
self.loss_fn = SimpleInfclLoss(model_args.inf_cl_head_dim)
elif model_args.loss_type == "contrastive":
self.loss_fn = SimpleContrastiveLoss(model_args.embedding_temperature)

Check warning on line 61 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L57-L61

Added lines #L57 - L61 were not covered by tests

def clear_memory(self):
self.accum_q_features.clear()
Expand Down