-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Embedding] Add inf-cl in embedding trainer #9673
[Embedding] Add inf-cl in embedding trainer #9673
Conversation
Thanks for your contribution! |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## develop #9673 +/- ##
===========================================
- Coverage 53.18% 52.76% -0.43%
===========================================
Files 718 718
Lines 113340 112338 -1002
===========================================
- Hits 60282 59276 -1006
- Misses 53058 53062 +4 ☔ View full report in Codecov by Sentry. |
__all__ = ["Simple_Inf_cl_loss", "Matryoshka_Inf_cl_loss"] | ||
|
||
|
||
class Simple_Inf_cl_loss(nn.Layer): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
加一些注释
paddlenlp/trl/embedding_trainer.py
Outdated
@@ -18,6 +18,10 @@ | |||
from paddle.base import core | |||
from paddle.distributed import fleet | |||
|
|||
from ops.src.paddlenlp_kernel.triton.inf_cl.inf_cl_loss import ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from ops.src.paddlenlp_kernel.triton.inf_cl.inf_cl_loss import ( | |
from paddlenlp_kernel.triton.inf_cl.inf_cl_loss import ( |
paddlenlp/trl/embedding_trainer.py
Outdated
@@ -18,6 +18,10 @@ | |||
from paddle.base import core | |||
from paddle.distributed import fleet | |||
|
|||
from ops.src.paddlenlp_kernel.triton.inf_cl.inf_cl_loss import ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个没有默认安装,需要 try except一下
group_size = p_reps.shape[0] // q_reps.shape[0] # Number of keys per query | ||
labels = paddle.arange(q_reps.shape[0], dtype="int64") # Generate labels for queries | ||
labels = labels * group_size # Adjust labels based on group size | ||
loss = cal_inf_loss(q_reps, p_reps, labels=labels, scale=None, head_dim=self.head_dim) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
你把import 的代码放到这里吧, 然后没有包的话,直接报错。
try:
from paddlenlp_kernel.triton.inf_cl import cal_inf_loss
except ImportError:
logger.warning(
"Paddlenlp_kernels are not available, which means the inf_cl loss cannot be used. If you wish to use the inf_cl loss, please follow the instructions in the README.md on the `ops`."
)
PR types
Function optimization
PR changes
Others
Description
在embedding训练中增加inf_cl_loss,在超大batch_size下能有效节省显存消耗。
经测试,inf-cl算子能够与原有损失函数有效对齐:
经测试,在超大batch_size下,inf-cl算子能够有效降低embedding训练时的显存消耗:
42526MiB;42470MiB;
42470MiB;42526MiB;
42526MiB;42182MiB
28372MiB;28308MiB;
28320MiB;28384MiB;
28316MiB;28070MiB
44926MiB;45180MiB;
44674MiB;45022MiB;
45032MiB;44904MiB