Commit 2ba42e8
fix lora merge with TP and avoid OOM
vince62s committed Sep 6, 2023
1 parent 1d409f5 commit 2ba42e8
Showing 3 changed files with 6 additions and 1 deletion.
4 changes: 4 additions & 0 deletions onmt/models/model_saver.py
@@ -210,6 +210,8 @@ def _save(self, step, model):
             ws = 1
         if ws > 1:
             full_model = [None for _ in range(ws)]
+            for key, value in model_state_dict.items():
+                model_state_dict[key] = value.cpu()
             torch.distributed.all_gather_object(full_model, model_state_dict)
             fm_sd = {}
             for key in full_model[0].keys():
@@ -297,6 +299,8 @@ def _st_save(self, step, model):
             ws = 1
         if ws > 1:
             full_model = [None for _ in range(ws)]
+            for key, value in model_state_dict.items():
+                model_state_dict[key] = value.cpu()
             torch.distributed.all_gather_object(full_model, model_state_dict)
             fm_sd = {}
             for key in full_model[0].keys():
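
The idea behind both hunks, shown below as a minimal sketch rather than the saver's actual code: torch.distributed.all_gather_object collects one full copy of every rank's state dict on each rank, so gathering CUDA tensors directly can leave world_size copies of the weights in GPU memory. Offloading each shard tensor to CPU before the gather keeps those copies in host RAM and avoids the OOM. gather_state_dict_on_cpu is a hypothetical helper, not an OpenNMT-py API.

import torch.distributed as dist

def gather_state_dict_on_cpu(model_state_dict, ws):
    # Move every shard tensor to CPU before the object gather, so the ws
    # gathered copies of the state dict end up in host RAM, not on the GPU.
    for key, value in model_state_dict.items():
        model_state_dict[key] = value.cpu()
    full_model = [None for _ in range(ws)]
    dist.all_gather_object(full_model, model_state_dict)
    return full_model  # full_model[i] holds rank i's (CPU) shard of the weights
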
2 changes: 1 addition & 1 deletion onmt/trainer.py
@@ -563,7 +563,7 @@ def _maybe_report_training(self, step, num_steps, learning_rate, report_stats):
                 if self.earlystopper is None
                 else self.earlystopper.current_tolerance,
                 report_stats,
-                multigpu=self.n_gpu > 1,
+                multigpu=self.n_gpu > 1 and self.parallel_mode == "data_parallel",
             )
 
     def _report_step(self, learning_rate, step, valid_stats=None, train_stats=None):
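
This change narrows multi-GPU statistics aggregation to data-parallel runs: with tensor parallelism the ranks in a group consume the same batch (only the weights are sharded), so summing loss and token counters across ranks would overcount by the world size. A sketch of that reasoning, using a generic per-rank counter rather than OpenNMT-py's Statistics class:

import torch
import torch.distributed as dist

def aggregate_counter(n_words, n_gpu, parallel_mode):
    # Hypothetical helper mirroring the intent of the multigpu flag above:
    # only sum across ranks when each rank saw distinct examples.
    if n_gpu > 1 and parallel_mode == "data_parallel":
        t = torch.tensor([n_words])  # CPU tensor; assumes a gloo process group
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
        return int(t.item())
    return n_words  # tensor-parallel ranks already hold identical statistics
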
1 change: 1 addition & 0 deletions tools/lora_weights.py
@@ -62,6 +62,7 @@
 lora_opt = lora_checkpoint["opt"]
 
 lora_opt.quant_layers = []  # we need to remove any quantization to merge weights
+lora_opt.parallel_mode = 'data_parallel'
 
 model = build_base_model(lora_opt, vocabs)
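
Forcing parallel_mode to 'data_parallel' matters because the model builder consults it: a LoRA checkpoint trained with tensor parallelism would otherwise lead the single-process merge script to construct sharded layers. With the full, unsharded model in hand, the merge itself is the usual LoRA fold-back. The sketch below uses the conventional lora_A/lora_B/alpha/rank names, which are assumptions, not identifiers taken from tools/lora_weights.py.

import torch

def merge_lora_weight(base_weight: torch.Tensor,
                      lora_A: torch.Tensor,
                      lora_B: torch.Tensor,
                      alpha: float,
                      rank: int) -> torch.Tensor:
    # Standard LoRA merge: W_merged = W + (alpha / rank) * (B @ A).
    # Requires full-size tensors, hence rebuilding the model unsharded.
    scaling = alpha / rank
    return base_weight + (lora_B @ lora_A) * scaling
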
