-
Notifications
You must be signed in to change notification settings - Fork 3
/
reward_trainer.py
757 lines (654 loc) · 35.8 KB
/
reward_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
# Modified from https://github.com/huggingface/trl/blob/2f726ce4e88a99b5d20eca3b5482954851d91ef6/trl/trainer/dpo_trainer.py
# strongly recommend comparing with ./trl/trl/trainer/dpo_trainer.py line-by-line to identify code changes.
# Original License
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import random
import warnings
from collections import defaultdict
from copy import deepcopy
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate.utils import is_deepspeed_available
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForCausalLM,
    DataCollator,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    TrainingArguments,
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from trl.import_utils import is_peft_available, is_wandb_available
from trl.models import PreTrainedModelWrapper, create_reference_model
from trl.trainer.utils import disable_dropout_in_model, pad_to_length

from data_utils import NCADataCollatorWithPadding

if is_peft_available():
    from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training

if is_wandb_available():
    import wandb

if is_deepspeed_available():
    import deepspeed
class NCATrainer(Trainer):
    r"""
    Initialize NCATrainer.

    Trainer for (Info)NCA alignment. Every example carries four candidate responses `A0`..`A3` and a scalar
    score per response (`A0_score`..`A3_score`). The policy's implicit reward for a response is
    `beta * (log pi(A_i) - log pi_ref(A_i))`; the loss fits these implicit rewards to soft labels obtained by
    a softmax over the (temperature-scaled) dataset scores.

    Args:
        model (`transformers.PreTrainedModel`):
            The model to train, preferably an `AutoModelForCausalLM` (its logits are scored as token log-probabilities).
        ref_model (`PreTrainedModelWrapper`):
            Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss. If no
            reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized
            (for PEFT models, the model with its adapters disabled is used instead).
        beta (`float`, defaults to 0.1):
            The beta factor in NCA loss. Higher beta means less divergence from the initial policy.
        temperature_alpha (`float`, defaults to 1e-3):
            Temperature for the dataset scores: scores are divided by it before the softmax that produces the soft labels.
            Note that this is the inverse of the alpha used in the paper (temperature_alpha = 1 / paper_alpha).
        loss_type (`str`, defaults to `"InfoNCA"`):
            The type of NCA loss to use. Either `"NCA"` or `"InfoNCA"`.
        args (`transformers.TrainingArguments`):
            The arguments to use for training.
        data_collator (`transformers.DataCollator`):
            The data collator to use for training. If None is specified, the default data collator (`NCADataCollatorWithPadding`) will be used
            which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
        label_pad_token_id (`int`, defaults to `-100`):
            The label pad token id. This argument is required if you want to use the default data collator.
        padding_value (`int`, defaults to `0`):
            The padding value for input ids. This argument is required if you want to use the default data collator.
        truncation_mode (`str`, defaults to `keep_end`):
            The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator.
        train_dataset (`datasets.Dataset`):
            The dataset to use for training.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        tokenizer (`transformers.PreTrainedTokenizerBase`):
            The tokenizer to use for training. This argument is required if you want to use the default data collator.
        model_init (`Callable[[], transformers.PreTrainedModel]`):
            The model initializer to use for training. If None is specified, the default model initializer will be used.
        callbacks (`List[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        max_length (`int`, defaults to `None`):
            The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
        max_prompt_length (`int`, defaults to `None`):
            The maximum length of the prompt. This argument is required if you want to use the default data collator.
        max_target_length (`int`, defaults to `None`):
            The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder.
        peft_config (`Dict`, defaults to `None`):
            The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
        is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
            If no model is provided, we need to know if the model_init returns an encoder-decoder.
        disable_dropout (`bool`, defaults to `True`):
            Whether or not to disable dropouts in `model` and `ref_model`.
        generate_during_eval (`bool`, defaults to `False`):
            Whether to sample and log generations during evaluation step.
        compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
            The function to use to compute the metrics. Must take a `EvalPrediction` and return
            a dictionary string to metric values.
        model_init_kwargs: (`Optional[Dict]`, *optional*):
            Dict of Optional kwargs to pass when instantiating the model from a string
        ref_model_init_kwargs: (`Optional[Dict]`, *optional*):
            Dict of Optional kwargs to pass when instantiating the ref model from a string
    """

    # Response slots of every example, in the order they are concatenated along the batch dimension.
    _RESPONSE_KEYS = ("A0", "A1", "A2", "A3")

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module, str] = None,
        ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        beta: float = 0.1,
        temperature_alpha: float = 1e-3,
        loss_type: Literal["InfoNCA", "NCA"] = "InfoNCA",
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        label_pad_token_id: int = -100,
        padding_value: int = 0,
        truncation_mode: str = "keep_end",
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
            None,
            None,
        ),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        max_length: Optional[int] = None,
        max_prompt_length: Optional[int] = None,
        max_target_length: Optional[int] = None,
        peft_config: Optional[Dict] = None,
        is_encoder_decoder: Optional[bool] = None,
        disable_dropout: bool = True,
        generate_during_eval: bool = False,
        compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
        model_init_kwargs: Optional[Dict] = None,
        ref_model_init_kwargs: Optional[Dict] = None,
    ):
        # Fail fast on an unsupported loss instead of at the first training step (nca_loss raises the same error).
        if loss_type not in ("InfoNCA", "NCA"):
            raise ValueError(f"Unknown loss type: {loss_type}. Should be one of ['InfoNCA', 'NCA']")

        if model_init_kwargs is None:
            model_init_kwargs = {}
        elif not isinstance(model, str):
            raise ValueError("You passed model_kwargs to the NCATrainer. But your model is already instantiated.")

        if ref_model_init_kwargs is None:
            ref_model_init_kwargs = {}
        elif not isinstance(ref_model, str):
            raise ValueError(
                "You passed ref_model_kwargs to the NCATrainer. But your ref_model is already instantiated."
            )

        if isinstance(model, str):
            warnings.warn(
                "You passed a model_id to the NCATrainer. This will automatically create an "
                "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
            )
            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

        if isinstance(ref_model, str):
            warnings.warn(
                "You passed a ref model_id to the NCATrainer. This will automatically create an "
                "`AutoModelForCausalLM`"
            )
            ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)

        if not is_peft_available() and peft_config is not None:
            raise ValueError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
            )
        elif is_peft_available() and peft_config is not None:
            # if model is a peft model and we have a peft_config, we merge and unload it first
            if isinstance(model, PeftModel):
                model = model.merge_and_unload()

            if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
                # Only forward gradient_checkpointing_kwargs if both the TrainingArguments and the installed
                # peft version know about it.
                _support_gc_kwargs = hasattr(
                    args, "gradient_checkpointing_kwargs"
                ) and "gradient_checkpointing_kwargs" in list(
                    inspect.signature(prepare_model_for_kbit_training).parameters
                )
                prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
                if _support_gc_kwargs:
                    prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
                model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
            elif getattr(args, "gradient_checkpointing", False):
                self._enable_input_require_grads(model)

            # get peft model with the given config
            model = get_peft_model(model, peft_config)
        elif getattr(args, "gradient_checkpointing", False):
            # For models that use gradient_checkpointing, we need to attach a hook that enables input
            # to explicitly have `requires_grad=True`, otherwise training will either silently
            # fail or completely fail.
            self._enable_input_require_grads(model)

        if generate_during_eval and not is_wandb_available():
            raise ValueError(
                "`generate_during_eval=True` requires Weights and Biases to be installed."
                " Please install `wandb` to resolve."
            )

        if model is not None:
            self.is_encoder_decoder = model.config.is_encoder_decoder
        elif is_encoder_decoder is None:
            raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
        else:
            self.is_encoder_decoder = is_encoder_decoder

        self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)

        if ref_model is not None:
            self.ref_model = ref_model
        elif self.is_peft_model:
            # The `model` with adapters turned off will be used as the reference model
            self.ref_model = None
        else:
            self.ref_model = create_reference_model(model)

        if data_collator is None:
            if tokenizer is None:
                raise ValueError(
                    "A tokenizer must be specified when using the default NCADataCollatorWithPadding"
                )
            if max_length is None:
                warnings.warn(
                    "When using NCADataCollatorWithPadding, you should set `max_length` in the NCATrainer's init"
                    " it will be set to `512` by default, but you should do it yourself in the future.",
                    UserWarning,
                )
                max_length = 512
            if max_prompt_length is None:
                warnings.warn(
                    "When using NCADataCollatorWithPadding, you should set `max_prompt_length` in the NCATrainer's init"
                    " it will be set to `128` by default, but you should do it yourself in the future.",
                    UserWarning,
                )
                max_prompt_length = 128
            if max_target_length is None and self.is_encoder_decoder:
                warnings.warn(
                    "When using NCADataCollatorWithPadding with an encoder decoder architecture, you should set `max_target_length` in the NCATrainer's init"
                    " it will be set to `128` by default, but you should do it yourself in the future.",
                    UserWarning,
                )
                max_target_length = 128

            data_collator = NCADataCollatorWithPadding(
                tokenizer,
                max_length=max_length,
                max_prompt_length=max_prompt_length,
                label_pad_token_id=label_pad_token_id,
                padding_value=padding_value,
                truncation_mode=truncation_mode,
                is_encoder_decoder=self.is_encoder_decoder,
                max_target_length=max_target_length,
            )

            # The collator needs the raw A*/score columns, which Trainer would otherwise drop.
            if args.remove_unused_columns:
                args.remove_unused_columns = False
                # warn users
                warnings.warn(
                    "When using NCADataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
                    " we have set it for you, but you should do it yourself in the future.",
                    UserWarning,
                )

            self.use_dpo_data_collator = True
        else:
            self.use_dpo_data_collator = False

        if disable_dropout:
            disable_dropout_in_model(model)
            if self.ref_model is not None:
                disable_dropout_in_model(self.ref_model)

        self.max_length = max_length
        self.generate_during_eval = generate_during_eval
        self.label_pad_token_id = label_pad_token_id
        self.padding_value = padding_value
        self.beta = beta
        self.temperature_alpha = temperature_alpha
        self.loss_type = loss_type

        # train/eval -> metric name -> list of per-step values, averaged and flushed in `log`.
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        if not hasattr(self, "accelerator"):
            raise AttributeError(
                "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
            )

        if self.ref_model is None:
            if not hasattr(self.accelerator.unwrap_model(self.model), "disable_adapter"):
                raise ValueError(
                    "You are using a `peft` version that does not support `disable_adapter`. Please update your `peft` version to the latest version."
                )
        else:
            if self.is_deepspeed_enabled:
                self.ref_model = self._prepare_deepspeed(self.ref_model)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

    @staticmethod
    def _enable_input_require_grads(model) -> None:
        """Make the input-embedding outputs require grad, as needed when training with gradient checkpointing."""
        # For backward compatibility with older versions of transformers
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
        """Wrap the (reference) `model` in a DeepSpeed engine in eval mode, derived from the training DeepSpeed config.

        ZeRO-3 configs keep sharding (with bucket sizes tuned to the model's hidden size); any other stage is
        downgraded to stage 0 so the reference model is replicated on each device.
        """
        # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
        config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)

        if model is not None:
            if hasattr(model, "config"):
                hidden_size = (
                    max(model.config.hidden_sizes)
                    if getattr(model.config, "hidden_sizes", None)
                    else getattr(model.config, "hidden_size", None)
                )
                if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
                    # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
                    # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
                    config_kwargs.update(
                        {
                            "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
                            "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
                            "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
                        }
                    )

        # If ZeRO-3 is used, we shard both the active and reference model.
        # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
        if config_kwargs["zero_optimization"]["stage"] != 3:
            config_kwargs["zero_optimization"]["stage"] = 0
        model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
        model.eval()
        return model

    def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]:
        """Concatenate the A0..A3 responses of a batch into single tensors along the batch dimension.

        Args:
            batch: A batch of data. Must contain `{key}_input_ids`, `{key}_attention_mask` and `{key}_labels` for
                every key in A0..A3, each presumably of shape (batch_size, sequence_length).

        Returns:
            A dict with `concatenated_input_ids`, `concatenated_attention_mask` and `concatenated_labels`, each of
            shape (4 * batch_size, max_length), ordered A0, A1, A2, A3 and moved to the accelerator device.

        Raises:
            NotImplementedError: for encoder-decoder models.
        """
        if self.is_encoder_decoder:
            raise NotImplementedError("NCATrainer does not support encoder-decoder models yet.")

        max_length = max(batch[f"{key}_input_ids"].shape[1] for key in self._RESPONSE_KEYS)

        # Attention masks are padded with 0 (never attend to padding) regardless of the token padding value.
        pad_values = {
            "input_ids": self.padding_value,
            "attention_mask": 0,
            "labels": self.label_pad_token_id,
        }

        concatenated_batch = {}
        for suffix, pad_value in pad_values.items():
            concatenated_batch[f"concatenated_{suffix}"] = torch.cat(
                [
                    pad_to_length(batch[f"{key}_{suffix}"], max_length, pad_value=pad_value)
                    for key in self._RESPONSE_KEYS
                ],
                dim=0,
            ).to(self.accelerator.device)
        return concatenated_batch

    def nca_loss(
        self,
        batch,
        policy_A0_logps: torch.FloatTensor,
        policy_A1_logps: torch.FloatTensor,
        policy_A2_logps: torch.FloatTensor,
        policy_A3_logps: torch.FloatTensor,
        reference_A0_logps: torch.FloatTensor,
        reference_A1_logps: torch.FloatTensor,
        reference_A2_logps: torch.FloatTensor,
        reference_A3_logps: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute the NCA / InfoNCA loss for a batch of policy and reference model log probabilities.

        Args:
            batch: Must contain `A0_score`..`A3_score` (presumably shape (batch_size,)); they define the soft labels.
            policy_*_logps: Summed log-probabilities of each response under the policy, shape (batch_size,).
            reference_*_logps: Summed log-probabilities of each response under the reference model, shape (batch_size,).

        Returns:
            `(losses, A0_reward, A1_reward, A2_reward, A3_reward)`: per-example losses of shape (batch_size,) and
            the detached implicit rewards `beta * (policy_logps - reference_logps)` of each response.

        Raises:
            ValueError: if `self.loss_type` is neither "InfoNCA" nor "NCA".
        """
        A0_reward = (policy_A0_logps - reference_A0_logps) * self.beta
        A1_reward = (policy_A1_logps - reference_A1_logps) * self.beta
        A2_reward = (policy_A2_logps - reference_A2_logps) * self.beta
        A3_reward = (policy_A3_logps - reference_A3_logps) * self.beta

        # The definition of temperature_alpha here is different from that in the paper (temperature_alpha = 1 / paper_alpha)
        # +0.01 here ensures A0 has the highest reward even if r(A1) = r(A0). This is included merely to stay consistent with preference settings. Could be removed.
        rewards = torch.stack(
            [batch["A0_score"] + 0.01, batch["A1_score"], batch["A2_score"], batch["A3_score"]], dim=-1
        ) / self.temperature_alpha  # <bz,4>
        softlabel = rewards.softmax(dim=-1)  # <bz,4>
        model_rewards = torch.stack([A0_reward, A1_reward, A2_reward, A3_reward], dim=-1)  # <bz,4>

        if self.loss_type == "InfoNCA":
            # Cross-entropy between the soft labels and a softmax over the implicit rewards.
            losses = -(softlabel * model_rewards.log_softmax(dim=-1)).sum(dim=-1)
        elif self.loss_type == "NCA":
            # Per example: -(1/K) * sum_i log sigmoid(-r_i) - sum_i p_i * log sigmoid(r_i).
            # Kept per-example (shape <bz>) like InfoNCA; the caller's `.mean()` gives the batch loss.
            losses = -F.logsigmoid(-model_rewards).mean(dim=-1) - (softlabel * F.logsigmoid(model_rewards)).sum(dim=-1)
        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}. Should be one of ['InfoNCA', 'NCA']")

        return losses, A0_reward.detach(), A1_reward.detach(), A2_reward.detach(), A3_reward.detach()

    def _get_batch_logps(
        self,
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = False,
    ) -> torch.FloatTensor:
        """Compute the log probabilities of the given labels under the given logits.

        Args:
            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
            labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
                Rows with no non-masked tokens average to 0 instead of NaN.

        Returns:
            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.

        Raises:
            ValueError: if the batch/sequence dimensions of `logits` and `labels` differ.
        """
        if logits.shape[:-1] != labels.shape:
            raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

        if not self.is_encoder_decoder:
            # Decoder-only: the logit at position t predicts the token at position t + 1.
            labels = labels[:, 1:].clone()
            logits = logits[:, :-1, :]
        else:
            # Copy so the masking below never writes into the caller's tensor.
            labels = labels.clone()
        loss_mask = labels != self.label_pad_token_id

        # dummy token; we'll ignore the losses on these tokens later
        labels[labels == self.label_pad_token_id] = 0

        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
        summed_logps = (per_token_logps * loss_mask).sum(-1)

        if average_log_prob:
            return summed_logps / loss_mask.sum(-1).clamp(min=1)
        return summed_logps

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, ...]:
        """Run `model` once on all four responses (A0..A3) concatenated along the batch dimension.

        We do this to avoid doing several forward passes, because it's faster for FSDP.

        Returns:
            An 8-tuple `(A0_logps, A1_logps, A2_logps, A3_logps, A0_logits, A1_logits, A2_logits, A3_logits)`;
            logps have shape (batch_size,) and logits are the float32 model logits of each response.
        """
        concatenated_batch = self.concatenated_inputs(batch)
        batch_size = batch["A0_labels"].shape[0]

        model_kwargs = (
            {
                "labels": concatenated_batch["concatenated_labels"],
                "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None),
            }
            if self.is_encoder_decoder
            else {}
        )
        all_logits = model(
            concatenated_batch["concatenated_input_ids"],
            attention_mask=concatenated_batch["concatenated_attention_mask"],
            **model_kwargs,
        ).logits.to(torch.float32)

        all_logps = self._get_batch_logps(
            all_logits,
            concatenated_batch["concatenated_labels"],
            average_log_prob=False,
        )

        # concatenated_inputs stacks A0, A1, A2, A3 in that order, so consecutive chunks of
        # `batch_size` rows belong to successive response slots.
        per_response_logps = all_logps.split(batch_size, dim=0)
        per_response_logits = all_logits.split(batch_size, dim=0)
        return (*per_response_logps, *per_response_logits)

    def get_batch_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """Compute the NCA loss and other metrics for the given batch of inputs for train or test.

        Returns:
            `(loss, metrics)`: the scalar mean loss and a dict of scalar tensors (names prefixed with
            `eval_` when `train_eval == "eval"`).
        """
        metrics = {}
        # TODO support arbitrary K option
        num_responses = len(self._RESPONSE_KEYS)

        with torch.no_grad():
            if self.ref_model is None:
                # PEFT model without explicit reference: disabling adapters yields the reference model.
                with self.accelerator.unwrap_model(self.model).disable_adapter():
                    reference_logps = self.concatenated_forward(self.model, batch)[:num_responses]
            else:
                reference_logps = self.concatenated_forward(self.ref_model, batch)[:num_responses]

        policy_logps = self.concatenated_forward(model, batch)[:num_responses]

        losses, A0_rewards, A1_rewards, A2_rewards, A3_rewards = self.nca_loss(
            batch,
            *policy_logps,
            *reference_logps,
        )
        rewards = (A0_rewards, A1_rewards, A2_rewards, A3_rewards)

        # Fraction of the other responses that A0 out-scores, per example.
        reward_accuracies = (
            (A0_rewards > A1_rewards).float() + (A0_rewards > A2_rewards).float() + (A0_rewards > A3_rewards).float()
        ) / 3.0

        prefix = "eval_" if train_eval == "eval" else ""
        for key, reward in zip(self._RESPONSE_KEYS, rewards):
            metrics[f"{prefix}rewards/{key}"] = reward.cpu().mean()
        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean()
        metrics[f"{prefix}rewards/margins"] = (A0_rewards - (A1_rewards + A2_rewards + A3_rewards) / 3.0).cpu().mean()
        for key, logps in zip(self._RESPONSE_KEYS, policy_logps):
            metrics[f"{prefix}logps/{key}"] = logps.detach().cpu().mean()

        return losses.mean(), metrics

    def compute_loss(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs=False,
        **kwargs,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        """Compute the training loss and record the batch metrics for later logging.

        Extra keyword arguments are accepted and ignored, so this override keeps working with
        `Trainer` versions that pass additional arguments to `compute_loss`.
        """
        if not self.use_dpo_data_collator:
            warnings.warn(
                "compute_loss is only implemented for NCADataCollatorWithPadding, and you passed a datacollator that is different than "
                "NCADataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
            )
        loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train")

        # force log the metrics
        if self.accelerator.is_main_process:
            self.store_metrics(metrics, train_eval="train")

        if return_outputs:
            return (loss, metrics)
        return loss

    def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[List[str], List[str]]:
        """Sample completions from the policy and the reference model for the prompts in `batch`.

        Assumes the batch carries `prompt_input_ids` / `prompt_attention_mask` — TODO confirm the collator emits them.

        Returns:
            `(policy_output_decoded, reference_output_decoded)`: decoded texts, one per prompt.
        """
        policy_output = model.generate(
            input_ids=batch["prompt_input_ids"],
            attention_mask=batch["prompt_attention_mask"],
            max_length=self.max_length,
            do_sample=True,
            pad_token_id=self.tokenizer.pad_token_id,
        )

        if self.ref_model is None:
            with self.accelerator.unwrap_model(self.model).disable_adapter():
                reference_output = self.model.generate(
                    batch["prompt_input_ids"],
                    attention_mask=batch["prompt_attention_mask"],
                    max_length=self.max_length,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                )
        else:
            reference_output = self.ref_model.generate(
                batch["prompt_input_ids"],
                attention_mask=batch["prompt_attention_mask"],
                max_length=self.max_length,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id)
        policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)

        reference_output = pad_to_length(reference_output, self.max_length, self.tokenizer.pad_token_id)
        reference_output_decoded = self.tokenizer.batch_decode(reference_output, skip_special_tokens=True)

        return policy_output_decoded, reference_output_decoded

    def prediction_step(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ):
        """Evaluation step: compute the loss and store metrics.

        Returns `(loss, None, None)` when `prediction_loss_only`; otherwise `(loss, logits, labels)` where `logits`
        holds the batch-mean implicit reward of each response slot (A0..A3, minus any `ignore_keys`) and `labels`
        is a zero tensor of matching length.
        """
        if not self.use_dpo_data_collator:
            warnings.warn(
                "prediction_step is only implemented for NCADataCollatorWithPadding, and you passed a datacollator that is different than "
                "NCADataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
            )
        if ignore_keys is None:
            if hasattr(model, "config"):
                ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        with torch.no_grad():
            loss, metrics = self.get_batch_metrics(model, inputs, train_eval="eval")

        # force log the metrics
        if self.accelerator.is_main_process:
            self.store_metrics(metrics, train_eval="eval")

        if prediction_loss_only:
            return (loss.detach(), None, None)

        # Implicit rewards of every response slot; these keys are produced by get_batch_metrics.
        logits_dict = {
            f"eval_rewards/{key}": metrics[f"eval_rewards/{key}"] for key in self._RESPONSE_KEYS
        }
        logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
        logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
        labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

        return (loss.detach(), logits, labels)

    def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
        """Append each metric value to the per-split buffer that `log` averages and flushes."""
        for key, value in metrics.items():
            self._stored_metrics[train_eval][key].append(value)

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Overriding built-in evaluation loop to store metrics for each batch.
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels. When `generate_during_eval` is set, one random batch
        (at most `eval_batch_size` examples) is sampled and logged to wandb as a table first.
        """
        # Sample and save to game log if requested (for one batch to save time)
        if self.generate_during_eval:
            # Generate random indices within the range of the total number of samples;
            # never ask for more samples than the dataset holds.
            num_samples = len(dataloader.dataset)
            random_indices = random.sample(range(num_samples), k=min(num_samples, self.args.eval_batch_size))

            # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
            random_batch_dataset = dataloader.dataset.select(random_indices)
            random_batch = self.data_collator(random_batch_dataset)
            random_batch = self._prepare_inputs(random_batch)

            policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, random_batch)

            self.log(
                {
                    "game_log": wandb.Table(
                        columns=["Prompt", "Policy", "Ref Model"],
                        rows=[
                            [prompt, pol[len(prompt) :], ref[len(prompt) :]]
                            for prompt, pol, ref in zip(
                                random_batch["prompt"], policy_output_decoded, ref_output_decoded
                            )
                        ],
                    )
                }
            )
            # Drop the table entry that self.log just appended to the log history.
            self.state.log_history.pop()

        # Base evaluation
        initial_output = super().evaluation_loop(
            dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
        )

        return initial_output

    def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
        """
        Log `logs` on the various objects watching training, including stored metrics.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
            *args, **kwargs:
                Forwarded unchanged to `Trainer.log`, so versions that pass extra arguments keep working.
        """
        # logs either has 'loss' or 'eval_loss'
        train_eval = "train" if "loss" in logs else "eval"
        # Add averaged stored metrics to logs, then clear the buffer for this split.
        for key, metrics in self._stored_metrics[train_eval].items():
            logs[key] = torch.tensor(metrics).mean().item()
        del self._stored_metrics[train_eval]
        return super().log(logs, *args, **kwargs)