forked from nikitakit/sabertooth
-
Notifications
You must be signed in to change notification settings - Fork 0
/
modeling.py
420 lines (370 loc) · 15.2 KB
/
modeling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
# Copyright 2020 The Sabertooth Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer models."""
import functools
from typing import Tuple
import flax
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.training.checkpoints import restore_checkpoint
from flax.training.common_utils import onehot
from ml_collections import ConfigDict
import layers
# Maps the `hidden_act` config string to an activation callable; consumed by
# get_hidden_activation(). "gelu" resolves to the project's own layers.gelu,
# whereas "gelu_new" resolves to flax's nn.gelu.
ACT2FN = dict(
    gelu=layers.gelu,
    relu=nn.relu,
    swish=nn.swish,
    gelu_new=nn.gelu,
)
from absl import app, flags
from ml_collections.config_flags import config_flags
# Handle to the global absl flag registry. Nothing visible in this module reads
# it. NOTE(review): possibly kept for an entry point elsewhere; confirm before
# removing.
FLAGS = flags.FLAGS
def get_hidden_activation(config: ConfigDict):
    """Look up the activation function selected by ``config.hidden_act``.

    Args:
      config: Model config whose ``hidden_act`` field names an entry of
        ``ACT2FN``.

    Returns:
      The callable stored in ``ACT2FN`` under ``config.hidden_act``.

    Raises:
      KeyError: if ``config.hidden_act`` is not a key of ``ACT2FN``.
    """
    activation_name = config.hidden_act
    return ACT2FN[activation_name]
def get_kernel_init(config: ConfigDict):
    """Build the weight initializer shared by the model's layers.

    Args:
      config: Model config providing ``initializer_range``.

    Returns:
      ``layers.truncated_normal_initializer`` configured with
      ``config.initializer_range``.
    """
    init_range = config.initializer_range
    return layers.truncated_normal_initializer(init_range)
class BertModel(nn.Module):
    """BERT model without any task-specific heads.

    Attributes:
      config: Model hyperparameters. Fields read here include vocab_size,
        hidden_size, max_position_embeddings, type_vocab_size,
        layer_norm_eps, hidden_dropout_prob, intermediate_size, hidden_act,
        initializer_range, attention_type, num_attention_heads,
        attention_probs_dropout_prob and num_hidden_layers. When
        attention_type is not "VanillaMHA", the FastSelfAttention options
        (downsampling_k, up_train, num_landmarks, window_size, use_t5_rpe,
        overlap_window) are read as well.
    """

    config: ConfigDict

    def setup(self):
        """Creates the embedding, encoder-stack and pooler submodules."""
        # Word, position and token-type embeddings; they are summed in __call__.
        self.word_embeddings = nn.Embed(
            num_embeddings=self.config.vocab_size,
            features=self.config.hidden_size,
            embedding_init=get_kernel_init(self.config),
            name="word_embeddings",
        )
        self.position_embeddings = layers.PositionalEncoding(
            num_embeddings=self.config.max_position_embeddings,
            features=self.config.hidden_size,
            embedding_init=get_kernel_init(self.config),
            name="position_embeddings",
        )
        self.type_embeddings = nn.Embed(
            num_embeddings=self.config.type_vocab_size,
            features=self.config.hidden_size,
            embedding_init=get_kernel_init(self.config),
            name="type_embeddings",
        )
        self.embeddings_layer_norm = nn.LayerNorm(
            epsilon=self.config.layer_norm_eps, name="embeddings_layer_norm"
        )
        self.embeddings_dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

        # Partially-applied constructors handed to every TransformerBlock below.
        build_feed_forward = functools.partial(
            layers.FeedForward,
            d_model=self.config.hidden_size,
            d_ff=self.config.intermediate_size,
            intermediate_activation=get_hidden_activation(self.config),
            kernel_init=get_kernel_init(self.config),
        )
        # "VanillaMHA" selects the standard multi-head layers.SelfAttention; any
        # other attention_type string is forwarded verbatim to
        # layers.FastSelfAttention.
        if self.config.attention_type == "VanillaMHA":
            build_self_attention = functools.partial(
                layers.SelfAttention,
                num_heads=self.config.num_attention_heads,
                qkv_features=self.config.hidden_size,
                dropout_rate=self.config.attention_probs_dropout_prob,
                broadcast_dropout=False,
                kernel_init=get_kernel_init(self.config),
                bias_init=nn.initializers.zeros,
            )
        else:
            build_self_attention = functools.partial(
                layers.FastSelfAttention,
                hidden_dim=self.config.hidden_size,
                # Per-head width; int() truncates when hidden_size is not a
                # multiple of num_attention_heads.
                head_dim=int(self.config.hidden_size / self.config.num_attention_heads),
                num_heads=self.config.num_attention_heads,
                dropout=self.config.attention_probs_dropout_prob,
                downsampling_k=self.config.downsampling_k,
                attention_type=self.config.attention_type,
                up_train=self.config.up_train,
                num_landmarks=self.config.num_landmarks,
                window_size=self.config.window_size,
                use_t5_rpe=self.config.use_t5_rpe,
                overlap_window=self.config.overlap_window
            )
        # One TransformerBlock per layer, named encoder_layer_0 ... encoder_layer_{N-1}.
        self.encoder_layers = [
            layers.TransformerBlock(
                build_feed_forward=build_feed_forward,
                build_self_attention=build_self_attention,
                dropout_rate=self.config.hidden_dropout_prob,
                layer_norm_epsilon=self.config.layer_norm_eps,
                name=f"encoder_layer_{layer_num}",
            )
            for layer_num in range(self.config.num_hidden_layers)
        ]
        # Dense projection of the first token's final hidden state; tanh is
        # applied in __call__.
        self.pooler = nn.Dense(
            kernel_init=get_kernel_init(self.config),
            name="pooler",
            features=self.config.hidden_size,
        )

    def __call__(
        self,
        input_ids: jnp.ndarray,
        input_mask: jnp.ndarray,
        type_ids: jnp.ndarray,
        switch: bool = True,
        *,
        deterministic: bool = False,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Applies BERT model on the inputs.

        Args:
          input_ids: Token ids. NOTE(review): presumably an integer array of
            shape (batch_size, seq_length); confirm against callers.
          input_mask: Attention mask; cast to int32 and passed to every
            transformer block.
          type_ids: Token-type ids, looked up in ``type_embeddings``.
          switch: Forwarded unchanged to every ``TransformerBlock``. Its
            meaning is defined in ``layers``.
          deterministic: Passed to the embedding dropout and to every
            transformer block (flax convention: True disables dropout).

        Returns:
          ``(hidden_states, pooled_output)``: the output of the last encoder
          layer and ``tanh(pooler(hidden_states[:, 0]))``.
        """
        word_embeddings = self.word_embeddings(input_ids)
        # PositionalEncoding is called on the ids themselves, not on explicit
        # position indices.
        position_embeddings = self.position_embeddings(input_ids)
        type_embeddings = self.type_embeddings(type_ids)
        embeddings = word_embeddings + position_embeddings + type_embeddings
        embeddings = self.embeddings_layer_norm(embeddings)
        embeddings = self.embeddings_dropout(embeddings, deterministic=deterministic)

        hidden_states = embeddings
        mask = input_mask.astype(jnp.int32)
        for transformer_block in self.encoder_layers:
            hidden_states = transformer_block(
                hidden_states, mask, switch, deterministic=deterministic
            )

        # Pool by taking the final state at index 0 of axis 1 (the first token).
        pooled_output = self.pooler(hidden_states[:, 0])
        pooled_output = jnp.tanh(pooled_output)
        return hidden_states, pooled_output

    def get_embedding_table(self, **unused_kwargs):
        """Returns the word-embedding matrix stored in this module's params.

        Reads ``self.variables``, so the module must already be bound, for
        example when called from a parent module's ``__call__``.
        """
        return self.variables["params"]["word_embeddings"]["embedding"]
class GatherIndexes(nn.Module):
    """Gathers the vectors at the specific positions."""

    @nn.compact
    def __call__(self, sequence_tensor: jnp.ndarray, positions: jnp.ndarray):
        """Applies gather indexes layer.

        Args:
          sequence_tensor: Sequence output of `BertModel` layer of shape
            (`batch_size`, `seq_length`, num_hidden), where num_hidden is the
            number of hidden units of the `BertModel` layer.
          positions: Position ids of the tokens to gather (the tokens masked
            for pretraining), with shape (batch_size, num_predictions).
            `num_predictions` is the maximum number of tokens masked out and
            predicted per sequence.

        Returns:
          Gathered tensor of shape (batch_size * num_predictions, num_hidden).
        """
        batch, length, hidden = sequence_tensor.shape
        # Collapse (batch, seq) into a single row axis: row `b * length + p`
        # holds position `p` of example `b`.
        rows = sequence_tensor.reshape((batch * length, hidden))
        row_offsets = (jnp.arange(batch) * length)[:, None]
        row_ids = (positions + row_offsets).reshape((-1,))
        return jnp.take(rows, row_ids, axis=0)
class BertForSequenceClassification(nn.Module):
    """Bert model for sequence classification.

    Attributes:
      config: Model hyperparameters passed to the underlying `BertModel`.
      n_classes: Width of the output head. With labels, a head of width 1 is
        trained with a squared-error (regression) loss and wider heads with
        softmax cross-entropy.
    """

    config: ConfigDict
    n_classes: int

    @nn.compact
    def __call__(
        self,
        input_ids: jnp.ndarray,
        input_mask: jnp.ndarray,
        type_ids: jnp.ndarray,
        labels: jnp.ndarray = None,
        switch: bool = True,
        *,
        deterministic: bool = False,
    ):
        """Applies BERT for sequence classification.

        Args:
          input_ids, input_mask, type_ids, switch, deterministic: Forwarded
            to `BertModel`.
          labels: Optional targets. Regression targets when the head has a
            single output, integer class ids otherwise.

        Returns:
          The logits when ``labels`` is None. Otherwise ``{"loss": loss}``,
          where the loss is mean squared error for a single-output head and
          mean softmax cross-entropy otherwise.
        """
        bert = BertModel(config=self.config, name="bert")
        _, pooled_output = bert(
            input_ids, input_mask, type_ids, switch, deterministic=deterministic
        )
        pooled_output = nn.Dropout(
            rate=self.config.hidden_dropout_prob, deterministic=deterministic
        )(pooled_output)
        logits = layers.OutputProjection(
            n_out=self.n_classes,
            kernel_init=get_kernel_init(self.config),
            name="classification",
        )(pooled_output)

        if labels is None:
            return logits
        if logits.shape[-1] == 1:
            # Regression task: mean squared error on the single output unit.
            loss = jnp.mean((logits[..., 0] - labels) ** 2)
            return {"loss": loss}
        # Classification task: mean negative log-likelihood of the labels.
        log_probs = nn.log_softmax(logits)
        loss = -jnp.mean(
            jnp.sum(onehot(labels, log_probs.shape[-1]) * log_probs, axis=-1)
        )
        return {"loss": loss}

    @staticmethod
    def params_from_checkpoint(model, checkpoint):
        """Initialize params (but not optimizer) from a pre-trained checkpoint.

        The masked-LM head of the pre-training checkpoint is removed. The
        ``classification`` head is re-initialized with ``model.n_classes``
        outputs, using the same kernel initializer that ``__call__`` uses.

        Args:
          model: A `BertForSequenceClassification`; its ``n_classes`` and
            ``config`` are read.
          checkpoint: Checkpoint location passed to ``restore_checkpoint``.

        Returns:
          The parameter tree as a FrozenDict of JAX arrays.

        Raises:
          KeyError: if the restored state lacks the expected params, the
            masked-LM head entries, or a ``classification`` entry.
        """
        restored = restore_checkpoint(checkpoint, target=None)
        if "target" in restored:
            params = restored["target"]  # For old checkpoints using flax.optim
        else:
            params = restored["params"]  # For newer checkpoints using optax

        # Delete the masked LM head; this model has no counterpart for it.
        for head_name in (
            "predictions_output",
            "predictions_transform_dense",
            "predictions_transform_layernorm",
        ):
            del params[head_name]

        # Re-initialize the output head. It keeps the dtype of the checkpoint's
        # existing `classification` kernel and uses the same kernel_init as
        # __call__, so a fresh head matches a head created by `model.init`.
        # The PRNG seed comes from numpy's global RNG, so this is reproducible
        # only if np.random has been seeded.
        # If we switch to a non-static method, flax will complain that we're
        # creating an OutputProjection module here.
        head_dtype = params["classification"]["kernel"].dtype
        params["classification"] = layers.OutputProjection(
            n_out=model.n_classes,
            kernel_init=get_kernel_init(model.config),
        ).init(
            jax.random.PRNGKey(np.random.randint(2**16)),
            jnp.zeros((1, model.config.hidden_size), dtype=head_dtype),
        )["params"]

        # Convert any numpy arrays (which live on CPU) to JAX arrays.
        # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
        params = jax.tree_util.tree_map(jnp.asarray, params)
        # Always use a FrozenDict to store params. If different container types are
        # used at different places in the code, JAX may need to re-JIT model calls
        # because FrozenDict and python dict input structures are not identical.
        params = flax.core.freeze(params)
        return params
class BertForPreTraining(nn.Module):
    """Bert model for pre-training (masked LM and next-sentence prediction).

    Attributes:
      config: Model hyperparameters passed to the underlying `BertModel`.
    """

    config: ConfigDict

    @nn.compact
    def __call__(
        self,
        input_ids: jnp.ndarray,
        input_mask: jnp.ndarray,
        type_ids: jnp.ndarray,
        masked_lm_positions: jnp.ndarray = None,
        masked_lm_labels: jnp.ndarray = None,
        masked_lm_weights: jnp.ndarray = None,
        next_sentence_labels: jnp.ndarray = None,
        switch: bool = True,
        *,
        deterministic: bool = False,
    ):
        """Applies BERT for pre-training.

        Args:
          input_ids, input_mask, type_ids, switch, deterministic: Forwarded
            to `BertModel`.
          masked_lm_positions: Positions to gather for the masked-LM head
            (see `GatherIndexes`).
          masked_lm_labels, masked_lm_weights, next_sentence_labels: Targets
            and per-position weights forwarded to `compute_metrics`.

        Returns:
          ``(sequence_output, pooled_output)`` if ``masked_lm_positions`` is
          None. Otherwise ``(masked_lm_logits, next_sentence_logits)`` if
          either ``masked_lm_labels`` or ``next_sentence_labels`` is None.
          In all remaining cases, the loss dict from `compute_metrics`.
        """
        config = self.config
        bert = BertModel(config=config, name="bert")
        sequence_output, pooled_output = bert(
            input_ids, input_mask, type_ids, switch, deterministic=deterministic
        )
        if masked_lm_positions is None:
            return sequence_output, pooled_output

        # Masked LM: gather the hidden states at the masked positions, apply
        # dense -> activation -> layer norm, then feed the result and the
        # word-embedding table to the output projection.
        masked_lm_input = GatherIndexes()(sequence_output, masked_lm_positions)
        masked_lm_input = nn.Dense(
            features=config.hidden_size,
            kernel_init=get_kernel_init(config),
            name="predictions_transform_dense",
        )(masked_lm_input)
        masked_lm_input = get_hidden_activation(config)(masked_lm_input)
        masked_lm_input = nn.LayerNorm(
            epsilon=config.layer_norm_eps, name="predictions_transform_layernorm"
        )(masked_lm_input)
        masked_lm_logits = layers.OutputProjection(name="predictions_output")(
            masked_lm_input, bert.get_embedding_table()
        )

        # Next-sentence prediction: 2-way head on the pooled output. The name
        # "classification" is relied upon by
        # BertForSequenceClassification.params_from_checkpoint.
        next_sentence_logits = layers.OutputProjection(
            n_out=2, kernel_init=get_kernel_init(config), name="classification"
        )(pooled_output)

        if masked_lm_labels is None or next_sentence_labels is None:
            return masked_lm_logits, next_sentence_logits
        return self.compute_metrics(
            masked_lm_logits,
            next_sentence_logits,
            masked_lm_labels,
            masked_lm_weights,
            next_sentence_labels,
        )

    @staticmethod
    def compute_metrics(
        masked_lm_logits: jnp.ndarray,
        next_sentence_logits: jnp.ndarray,
        masked_lm_labels: jnp.ndarray,
        masked_lm_weights: jnp.ndarray,
        next_sentence_labels: jnp.ndarray,
    ):
        """Computes the pre-training loss and its components.

        Args:
          masked_lm_logits: Vocabulary logits for the gathered masked
            positions; the last axis is the vocabulary.
          next_sentence_logits: Logits of the next-sentence head; the last
            axis is the class axis.
          masked_lm_labels: Target token ids; flattened with
            ``reshape((-1,))``.
          masked_lm_weights: Per-position weights, flattened the same way.
            Zero-weight positions do not contribute to the masked-LM loss.
          next_sentence_labels: Integer next-sentence targets; flattened.

        Returns:
          A dict with ``"loss"`` (the sum of both components),
          ``"masked_lm_loss"`` (weighted mean NLL over masked positions) and
          ``"next_sentence_loss"`` (mean NLL of the next-sentence head).
        """
        masked_lm_log_probs = nn.log_softmax(masked_lm_logits)
        masked_lm_targets = onehot(
            masked_lm_labels.reshape((-1,)), masked_lm_log_probs.shape[-1]
        )
        weights = masked_lm_weights.reshape((-1,))
        # Clamp the denominator so a batch whose weights are all zero (no real
        # masked positions) yields a loss of 0 instead of 0/0 = NaN. The result
        # is unchanged whenever sum(weights) >= 1e-5.
        masked_lm_loss = -jnp.sum(
            jnp.sum(masked_lm_log_probs * masked_lm_targets, axis=-1) * weights
        ) / jnp.maximum(jnp.sum(weights), 1e-5)

        next_sentence_log_probs = nn.log_softmax(next_sentence_logits)
        next_sentence_targets = next_sentence_labels.reshape((-1,))
        next_sentence_loss = -jnp.mean(
            jnp.sum(
                onehot(next_sentence_targets, next_sentence_log_probs.shape[-1])
                * next_sentence_log_probs,
                axis=-1,
            )
        )
        return {
            "loss": masked_lm_loss + next_sentence_loss,
            "masked_lm_loss": masked_lm_loss,
            "next_sentence_loss": next_sentence_loss,
        }

    @staticmethod
    def params_from_checkpoint(model, checkpoint):
        """Initialize params (but not optimizer) from a pre-trained checkpoint.

        Args:
          model: Unused here; accepted for signature parity with
            `BertForSequenceClassification.params_from_checkpoint`.
          checkpoint: Checkpoint location passed to ``restore_checkpoint``.

        Returns:
          The parameter tree as a FrozenDict of JAX arrays.

        Raises:
          KeyError: if the restored state has neither "target" nor "params".
        """
        restored = restore_checkpoint(checkpoint, target=None)
        if "target" in restored:
            params = restored["target"]  # For old checkpoints using flax.optim
        else:
            params = restored["params"]  # For newer checkpoints using optax
        # Convert any numpy arrays (which live on CPU) to JAX arrays.
        # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
        params = jax.tree_util.tree_map(jnp.asarray, params)
        # Always use a FrozenDict to store params. If different container types are
        # used at different places in the code, JAX may need to re-JIT model calls
        # because FrozenDict and python dict input structures are not identical.
        params = flax.core.freeze(params)
        return params
# NOTE: The module-level string literal below is a disabled smoke test. Python
# evaluates and discards it at import time, so none of the code inside runs.
# It builds a small BertForPreTraining from `configs.pretraining.get_config()`
# with attention_type "LinEVAMHA", initializes it on random 0/1 inputs, and
# prints the result of `model.apply`. To use it, move its contents into a
# function or script entry point.
'''
from jax import random
hidden_dim = 8
head_dim = 4
num_heads = 2
dropout = 0.1
sequence_length = 128
ffn_size = 10
num_layers = 2
vocabulary_size = 10
downsampling_k = 3
batch_size = 2
import configs.pretraining as cf
config = cf.get_config()
modelconfig = config.model
modelconfig.attention_type = "LinEVAMHA"
modelconfig.hidden_size = 8
modelconfig.num_attention_heads = 2
modelconfig.num_landmarks = 16
modelconfig.window_size = 16
modelconfig.use_t5_rpe = True
modelconfig.overlap_window = True
model = BertForPreTraining(modelconfig)
x = jnp.round(random.uniform(random.PRNGKey(44), (batch_size, sequence_length))).astype(jnp.int32)
mask = jnp.ones((batch_size, sequence_length), dtype=jnp.int32)
y = jnp.round(random.uniform(random.PRNGKey(44), (batch_size, sequence_length))).astype(jnp.int32)
param_key = random.PRNGKey(42)
dropout_key = random.PRNGKey(43)
params = model.init(
{'params': param_key, 'dropout': dropout_key},
input_ids=x,
input_mask=mask,
type_ids=y,
deterministic=False,
)["params"]
attn = model.apply({'params': params}, input_ids=x, input_mask=mask, type_ids=y, rngs={'dropout': dropout_key})
print(attn)
'''