# loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class OrthogonalProjectionLoss(nn.Module):
    """Loss that pulls same-label features together and pushes others apart.

    Features are L2-normalized along dim 1, so every pairwise dot product is
    a cosine similarity. The returned scalar is::

        (1 - mean_sim(same-label pairs, self-pairs excluded))
            + gamma * mean_sim(different-label pairs)

    Args:
        gamma: Weight applied to the different-label (negative) term.
    """

    def __init__(self, gamma=0.5):
        super().__init__()
        self.gamma = gamma

    def forward(self, features, labels=None):
        """Compute the loss for one batch.

        Args:
            features: 2-D tensor of shape (batch, feature_dim).
            labels: Tensor holding one label per sample, e.g. shape
                (batch,) or (batch, 1). Required; the ``None`` default is
                kept only so the signature stays backward-compatible.

        Returns:
            A scalar tensor on the same device as ``features``.

        Raises:
            TypeError: If ``labels`` is ``None``.
            ValueError: If ``labels`` does not hold exactly one entry per
                row of ``features``.
        """
        if labels is None:
            raise TypeError(
                "OrthogonalProjectionLoss.forward() requires `labels`"
            )

        # Use the tensor's own device: torch.device('cuda') would resolve to
        # the *current* CUDA device, which may differ from the one holding
        # `features` in multi-GPU setups.
        device = features.device
        batch_size = features.shape[0]

        if labels.numel() != batch_size:
            raise ValueError(
                "labels must contain one entry per feature row: got "
                f"{labels.numel()} labels for {batch_size} features"
            )

        # Unit-length rows make the Gram matrix below a cosine-similarity
        # matrix.
        features = F.normalize(features, p=2, dim=1)

        # Column vector of labels so that comparing it with its transpose
        # yields the (batch, batch) same-label matrix.
        labels = labels.to(device).reshape(batch_size, 1)
        same_label = torch.eq(labels, labels.t())

        # Build the identity mask directly as bool on the target device
        # instead of allocating a float matrix on CPU and copying it over.
        eye = torch.eye(batch_size, dtype=torch.bool, device=device)

        # Positive pairs: same label, excluding each sample paired with itself.
        mask_pos = same_label.masked_fill(eye, False).float()
        # Negative pairs: different labels.
        mask_neg = (~same_label).float()

        dot_prod = torch.matmul(features, features.t())

        # The 1e-6 keeps the division finite when a batch has no positive
        # (or no negative) pairs; the corresponding mean is then 0.
        pos_pairs_mean = (mask_pos * dot_prod).sum() / (mask_pos.sum() + 1e-6)
        # NOTE: the negative term is the signed mean (an earlier abs() was
        # removed upstream), so anti-correlated negatives lower the loss.
        neg_pairs_mean = (mask_neg * dot_prod).sum() / (mask_neg.sum() + 1e-6)

        loss = (1.0 - pos_pairs_mean) + self.gamma * neg_pairs_mean
        return loss