# modeling.py -- COIL retrieval model (forked from luyug/COIL).
import torch
from torch import nn
import torch.distributed as dist
from transformers import PreTrainedModel, TrainingArguments
from transformers import AutoModel
from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPooling
from arguments import ModelArguments, DataArguments
from torch import Tensor
from typing import Dict, List, Tuple, Iterable
import os
import logging
from torch.cuda.amp import autocast
logger = logging.getLogger(__name__)


class COIL(nn.Module):
    """COIL retrieval model: an HF encoder plus CLS and token projection heads.

    ``cls_proj`` maps the first-position hidden state to a
    ``model_args.cls_dim`` vector scored by dot product; ``tok_proj`` maps every
    token hidden state to a ``model_args.token_dim`` vector scored only against
    document tokens with the same input id (exact match). The final score is
    the token score, the CLS score, or their sum, depending on
    ``model_args.no_cls`` / ``model_args.cls_only``.
    """

    def __init__(self, model: PreTrainedModel, model_args: ModelArguments, data_args: DataArguments,
                 train_args: TrainingArguments):
        super().__init__()
        self.model: PreTrainedModel = model
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        self.data_args, self.model_args, self.train_args = data_args, model_args, train_args

        # Project from the encoder's real hidden width rather than assuming
        # BERT-base (768); fall back to 768 if no config/hidden_size exists.
        hidden_size = getattr(getattr(model, 'config', None), 'hidden_size', 768)
        self.tok_proj = nn.Linear(hidden_size, model_args.token_dim)
        self.cls_proj = nn.Linear(hidden_size, model_args.cls_dim)
        if model_args.token_norm_after:
            self.ln_tok = nn.LayerNorm(model_args.token_dim)
        if model_args.cls_norm_after:
            self.ln_cls = nn.LayerNorm(model_args.cls_dim)

    @classmethod
    def from_pretrained(
            cls, model_args: ModelArguments, data_args: DataArguments, train_args: TrainingArguments,
            *args, **kwargs
    ):
        """Build a COIL model from a pretrained encoder checkpoint.

        ``*args`` / ``**kwargs`` are forwarded unchanged to
        ``AutoModel.from_pretrained``. If the model path (first positional
        argument, or the ``pretrained_model_name_or_path`` keyword) contains a
        ``model.pt`` file as written by ``save_pretrained``, those extra
        (non-encoder) weights are loaded non-strictly on top.

        Returns:
            A new ``COIL`` instance.
        """
        hf_model = AutoModel.from_pretrained(*args, **kwargs)
        model = COIL(hf_model, model_args, data_args, train_args)

        path = args[0] if args else kwargs.get('pretrained_model_name_or_path')
        extra_weights = os.path.join(path, 'model.pt') if path is not None else None
        if extra_weights is not None and os.path.exists(extra_weights):
            logger.info('loading extra weights from local files')
            # NOTE(review): torch.load unpickles arbitrary objects; only load
            # checkpoints from trusted locations.
            model_dict = torch.load(extra_weights, map_location="cpu")
            load_result = model.load_state_dict(model_dict, strict=False)
            # save_pretrained strips every 'model.*' (encoder) key from
            # model.pt, so those are expected to be missing; report the rest.
            missing = [k for k in load_result.missing_keys if not k.startswith('model.')]
            if missing:
                logger.warning('keys missing from %s: %s', extra_weights, missing)
            if load_result.unexpected_keys:
                logger.warning('unexpected keys in %s: %s', extra_weights, load_result.unexpected_keys)
        return model

    def save_pretrained(self, output_dir: str):
        """Save the model so that ``from_pretrained(output_dir)`` can restore it.

        Writes the encoder with ``self.model.save_pretrained``, every other
        parameter/buffer of this module to ``model.pt``, and the list
        ``[data_args, model_args, train_args]`` to ``args.pt``.
        """
        self.model.save_pretrained(output_dir)
        # Keep the OrderedDict returned by state_dict() (it carries _metadata
        # used by load_state_dict); just drop the encoder entries already saved.
        model_dict = self.state_dict()
        hf_weight_keys = [k for k in model_dict.keys() if k.startswith('model.')]
        for k in hf_weight_keys:
            model_dict.pop(k)
        torch.save(model_dict, os.path.join(output_dir, 'model.pt'))
        torch.save([self.data_args, self.model_args, self.train_args], os.path.join(output_dir, 'args.pt'))

    def _encode(self, features: Dict):
        """Run the encoder on ``features`` and apply the COIL heads.

        Returns:
            ``(cls, reps)``: the projected first-position vector (optionally
            LayerNorm-ed) and the projected per-token vectors (optionally
            LayerNorm-ed, then optionally ReLU-ed).
        """
        model_out: BaseModelOutputWithPooling = self.model(**features, return_dict=True)
        hidden = model_out.last_hidden_state
        cls = self.cls_proj(hidden[:, 0])
        reps = self.tok_proj(hidden)
        if self.model_args.cls_norm_after:
            cls = self.ln_cls(cls)
        if self.model_args.token_norm_after:
            reps = self.ln_tok(reps)
        if self.model_args.token_rep_relu:
            reps = torch.relu(reps)
        return cls, reps

    def encode(self, **features):
        """Encode one batch; ``features`` must include ``input_ids``,
        ``attention_mask`` and ``token_type_ids``.

        Returns:
            ``(cls, reps)`` as produced by the COIL heads (see ``_encode``).
        """
        assert all(x in features for x in ('input_ids', 'attention_mask', 'token_type_ids')), \
            'encode expects input_ids, attention_mask and token_type_ids'
        return self._encode(features)

    def forward(self, qry_input: Dict, doc_input: Dict):
        """Score query/document batches; in training mode also return the loss.

        Args:
            qry_input: encoder keyword inputs for the queries; must contain
                ``input_ids`` and ``attention_mask`` (the latter may be
                modified in place by ``mask_sep``).
            doc_input: encoder keyword inputs for the documents; must contain
                ``input_ids``.

        Returns:
            Eval mode: a flat tensor with one score per (query i, document i)
            pair; both batches must have the same size.
            Training mode: ``(loss, scores.view(-1))`` where ``scores`` is the
            Q x D matrix of every query against every document and query ``i``
            has target column ``i * train_group_size``.
        """
        # Encoder calls run queries first, then documents (same RNG order as
        # before for dropout).
        qry_cls, qry_reps = self._encode(qry_input)  # qry_reps: Q * LQ * d
        doc_cls, doc_reps = self._encode(doc_input)  # doc_reps: D * LD * d

        # mask ingredients
        doc_input_ids: Tensor = doc_input['input_ids']
        qry_input_ids: Tensor = qry_input['input_ids']
        qry_attention_mask: Tensor = qry_input['attention_mask']
        # Applied after encoding, so it only affects which query tokens score.
        self.mask_sep(qry_attention_mask)

        if not self.training:
            # in testing phase, we have Q == D
            assert doc_input_ids.size(0) == qry_input_ids.size(0), \
                'we expect same number of query/doc'
            tok_scores = self.compute_tok_score_pair(
                doc_reps, doc_input_ids,
                qry_reps, qry_input_ids, qry_attention_mask
            )
            # compute cls score separately
            cls_scores = (qry_cls * doc_cls).sum(-1)
            # sum the scores
            if self.model_args.no_cls:
                scores = tok_scores
            elif self.model_args.cls_only:
                scores = cls_scores
            elif self.train_args.fp16:
                # Add in fp32 when mixed precision is enabled.
                with autocast(False):
                    scores = tok_scores.float() + cls_scores.float()  # B
            else:
                scores = tok_scores + cls_scores  # B
            # loss not defined during inference
            return scores.view(-1)

        # for training phase, we have D = Q * group_size
        if self.model_args.x_device_negatives:
            # Pretend everything lives on this device: gather all ranks'
            # tensors; only this rank's slice keeps its autograd graph.
            doc_input_ids, doc_cls, doc_reps = self.gather_tensors(
                doc_input_ids, doc_cls, doc_reps)
            qry_input_ids, qry_attention_mask, qry_cls, qry_reps = self.gather_tensors(
                qry_input_ids, qry_attention_mask, qry_cls, qry_reps)

        tok_scores = self.compute_tok_score_cart(
            doc_reps, doc_input_ids,
            qry_reps, qry_input_ids, qry_attention_mask
        )  # Q * D
        if self.model_args.no_cls:
            scores = tok_scores
        elif self.model_args.cls_only:
            scores = torch.matmul(qry_cls, doc_cls.transpose(1, 0))  # Q * D
        else:
            cls_scores = torch.matmul(qry_cls, doc_cls.transpose(1, 0))  # Q * D
            with autocast(False):
                scores = tok_scores.float() + cls_scores.float()  # Q * D

        labels = torch.arange(
            scores.size(0),
            device=doc_input['input_ids'].device,
            dtype=torch.long
        )
        # Offset the labels: presumably each query's positive is the first
        # document of its train_group_size-sized group (TODO confirm with collator).
        labels = labels * self.data_args.train_group_size
        # NOTE(review): with x_device_negatives the loss is a mean over the
        # global batch but only local slices carry gradient; if gradients are
        # later averaged across ranks, consider scaling by world size.
        loss = self.cross_entropy(scores, labels)
        return loss, scores.view(-1)

    def mask_sep(self, qry_attention_mask):
        """When ``model_args.no_sep`` is set, zero each query's last attended position.

        The position is ``attention_mask.sum(1) - 1``, which assumes the mask
        is a contiguous prefix of ones ending at the [SEP] token. The mask is
        modified in place and also returned.
        """
        if self.model_args.no_sep:
            sep_pos = qry_attention_mask.sum(1).unsqueeze(1) - 1  # the sep token position
            _zeros = torch.zeros_like(sep_pos)
            qry_attention_mask.scatter_(1, sep_pos.long(), _zeros)
        return qry_attention_mask

    def compute_tok_score_pair(self, doc_reps, doc_input_ids, qry_reps, qry_input_ids, qry_attention_mask):
        """Exact-match token score between query i and document i (paired).

        For each query token, take the max dot product over document tokens
        with the same input id (0 if none), mask by the query attention mask,
        drop position 0 (CLS) and sum. Returns a length-B tensor.
        """
        exact_match = qry_input_ids.unsqueeze(2) == doc_input_ids.unsqueeze(1)  # B * LQ * LD
        exact_match = exact_match.float()
        # qry_reps: B * LQ * d
        # doc_reps: B * LD * d
        scores_no_masking = torch.bmm(qry_reps, doc_reps.permute(0, 2, 1))  # B * LQ * LD
        if self.model_args.pooling == 'max':
            tok_scores, _ = (scores_no_masking * exact_match).max(dim=2)  # B * LQ
        else:
            raise NotImplementedError(f'{self.model_args.pooling} pooling is not defined')
        # remove padding and cls token
        tok_scores = (tok_scores * qry_attention_mask)[:, 1:].sum(-1)
        return tok_scores

    def compute_tok_score_cart(self, doc_reps, doc_input_ids, qry_reps, qry_input_ids, qry_attention_mask):
        """Exact-match token score of every query against every document.

        Same scoring rule as ``compute_tok_score_pair`` but over the Cartesian
        product; materializes a Q * LQ * D * LD match tensor. Returns Q * D.
        """
        qry_input_ids = qry_input_ids.unsqueeze(2).unsqueeze(3)  # Q * LQ * 1 * 1
        doc_input_ids = doc_input_ids.unsqueeze(0).unsqueeze(1)  # 1 * 1 * D * LD
        exact_match = doc_input_ids == qry_input_ids  # Q * LQ * D * LD
        exact_match = exact_match.float()
        scores_no_masking = torch.matmul(
            qry_reps.reshape(-1, self.model_args.token_dim),  # (Q * LQ) * d
            doc_reps.reshape(-1, self.model_args.token_dim).transpose(0, 1)  # d * (D * LD)
        )
        scores_no_masking = scores_no_masking.view(
            *qry_reps.shape[:2], *doc_reps.shape[:2])  # Q * LQ * D * LD
        if self.model_args.pooling == 'max':
            scores, _ = (scores_no_masking * exact_match).max(dim=3)  # Q * LQ * D
        else:
            raise NotImplementedError(f'{self.model_args.pooling} pooling is not defined')
        # remove padding and cls token, then sum over query tokens
        tok_scores = (scores * qry_attention_mask.unsqueeze(2))[:, 1:].sum(1)
        return tok_scores

    def _gather_tensor(self, t: Tensor):
        """All-gather ``t`` from every rank; the local slot keeps its autograd graph.

        ``dist.all_gather`` requires ``t`` to have the same shape on every rank
        and its outputs carry no gradient history, so this rank's entry is
        replaced by ``t`` itself.
        """
        all_tensors = [torch.empty_like(t) for _ in range(dist.get_world_size())]
        dist.all_gather(all_tensors, t)
        # all_gather fills slot i from global rank i; train_args.local_rank is
        # only the rank within a node and picks the wrong slot on multi-node runs.
        all_tensors[dist.get_rank()] = t
        return all_tensors

    def gather_tensors(self, *tt: Tensor):
        """Gather each tensor from all ranks and concatenate along dim 0 (returns a list)."""
        tt = [torch.cat(self._gather_tensor(t)) for t in tt]
        return tt