tensor parallel training
vince62s committed Aug 31, 2023
1 parent 2b13ed1 commit 8e7bf4b
Showing 5 changed files with 106 additions and 28 deletions.
onmt/model_builder.py (3 additions, 1 deletion)

@@ -342,7 +342,7 @@ def build_base_model(model_opt, vocabs):
     return model
 
 
-def build_model(model_opt, opt, vocabs, checkpoint):
+def build_model(model_opt, opt, vocabs, checkpoint, device_id):
     logger.info("Building model...")
 
     model = build_base_model(model_opt, vocabs)
@@ -414,6 +414,7 @@ def build_model(model_opt, opt, vocabs, checkpoint):
                 precision=precision,
                 device=device,
                 strict=strict,
+                device_id=device_id,
             )
         else:
             # weights are not in the .pt checkpoint but stored in the safetensors file
@@ -425,6 +426,7 @@ def build_model(model_opt, opt, vocabs, checkpoint):
                 precision=precision,
                 device=device,
                 strict=strict,
+                device_id=device_id,
             )
     else:
         model.to(precision)
onmt/models/model.py (2 additions, 0 deletions)

@@ -65,6 +65,8 @@ def load_state_dict(
         # bitsandbytes quantize weights when .cuda() is called
         # for huge models we need to save Ram
         # so we load the weights module by module and transfer them to GPU for quantization
+        if device == torch.device("cpu"):
+            device_id = 0
         buf_list = []
         for name, module in self.named_modules():
             for buf_name, buf in module.named_buffers():
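Note: the two added lines default device_id to 0 when the target device is CPU, so the code that follows, which (per its own comment) loads the checkpoint module by module and moves each module to a GPU for bitsandbytes quantization, still has a usable CUDA index. The sketch below is illustrative only and not the repository's implementation; the helper name and the key-filtering logic are assumptions.

import torch
import torch.nn as nn


def load_module_by_module(model: nn.Module, state_dict: dict,
                          device: torch.device, device_id: int):
    # Fallback mirrored from the hunk above: CPU runs still get a valid CUDA ordinal.
    if device == torch.device("cpu"):
        device_id = 0
    for name, module in model.named_modules():
        if list(module.children()):
            continue  # handle leaf modules one at a time to keep RAM usage low
        prefix = name + "." if name else ""
        # Keep only the tensors that belong to this module.
        local = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
        if local:
            module.load_state_dict(local, strict=False)
            # Moving a bitsandbytes layer to CUDA triggers its quantization.
            module.to(device if device == torch.device("cpu") else torch.device("cuda", device_id))

Loading one leaf module at a time keeps peak host memory close to the size of the largest layer rather than the full model.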
onmt/models/model_saver.py (98 additions, 24 deletions)

@@ -7,7 +7,7 @@
 from onmt.modules.lora import lora_state_dict
 
 
-def build_model_saver(model_opt, opt, model, vocabs, optim):
+def build_model_saver(model_opt, opt, model, vocabs, optim, device_id):
     # _check_save_model_path
     save_model_path = os.path.abspath(opt.save_model)
     os.makedirs(os.path.dirname(save_model_path), exist_ok=True)
@@ -20,6 +20,7 @@ def build_model_saver(model_opt, opt, model, vocabs, optim):
         optim,
         opt.keep_checkpoint,
         opt.save_format,
+        device_id,
     )
     return model_saver
 
@@ -97,6 +98,7 @@ def __init__(
         optim,
         keep_checkpoint=-1,
         save_format="pytorch",
+        device_id=0,
     ):
         self.base_path = base_path
         self.model = model
@@ -106,6 +108,8 @@ def __init__(
         self.last_saved_step = None
         self.keep_checkpoint = keep_checkpoint
         self.save_format = save_format
+        self.device_id = device_id
+
         if keep_checkpoint > 0:
             self.checkpoint_queue = deque([], maxlen=keep_checkpoint)
             if save_format == "safetensors":
@@ -135,20 +139,24 @@ def save(self, step, moving_average=None):

         self.last_saved_step = step
 
-        if moving_average:
-            for param_data, param in zip(model_params_data, save_model.parameters()):
-                param.data = param_data
-
-        if self.keep_checkpoint > 0:
-            if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen:
-                todel = self.checkpoint_queue.popleft()
-                self._rm_checkpoint(todel)
-                if self.save_format == "safetensors":
-                    todel = self.model_queue.popleft()
-                    self._rm_checkpoint(todel)
-            self.checkpoint_queue.append(ckpt_path)
-            if self.save_format == "safetensors":
-                self.model_queue.append(model_path)
+        if ckpt_path:  # not None when process id 0
+
+            if moving_average:
+                for param_data, param in zip(
+                    model_params_data, save_model.parameters()
+                ):
+                    param.data = param_data
+
+            if self.keep_checkpoint > 0:
+                if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen:
+                    todel = self.checkpoint_queue.popleft()
+                    self._rm_checkpoint(todel)
+                    if self.save_format == "safetensors":
+                        todel = self.model_queue.popleft()
+                        self._rm_checkpoint(todel)
+                self.checkpoint_queue.append(ckpt_path)
+                if self.save_format == "safetensors":
+                    self.model_queue.append(model_path)
 
     def _save(self, step, model):
         """Save a resumable checkpoint.
@@ -196,17 +204,83 @@ def _save(self, step, model):
         }
         generator_state_dict = model.generator.state_dict()
 
-        checkpoint = {
-            "model": model_state_dict,
-            "generator": generator_state_dict,
-            "vocab": vocabs_to_dict(self.vocabs),
-            "opt": self.model_opt,
-            "optim": self.optim.state_dict(),
-        }
-
-        logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
-        ckpt_path = "%s_step_%d.pt" % (self.base_path, step)
-        torch.save(checkpoint, ckpt_path)
+        if torch.distributed.is_initialized():
+            ws = torch.distributed.get_world_size()
+        else:
+            ws = 1
+        if ws > 1:
+            full_model = [None for _ in range(ws)]
+            torch.distributed.all_gather_object(full_model, model_state_dict)
+            fm_sd = {}
+            for key in full_model[0].keys():
+                if key.split(".")[-1] == "lora_A":
+                    if key.split(".")[-2] in [
+                        "linear_keys",
+                        "linear_values",
+                        "linear_query",
+                        "w_1",
+                        "w_3",
+                    ]:
+                        fm_sd[key] = (
+                            sum([full_model[i][key].cpu() for i in range(ws)]) / ws
+                        )
+                    elif key.split(".")[-2] in ["final_linear", "w_2"]:
+                        fm_sd[key] = torch.cat(
+                            [full_model[i][key].cpu() for i in range(ws)], 1
+                        )
+                elif key.split(".")[-1] == "lora_B":
+                    if key.split(".")[-2] in [
+                        "linear_keys",
+                        "linear_values",
+                        "linear_query",
+                        "w_1",
+                        "w_3",
+                    ]:
+                        fm_sd[key] = torch.cat(
+                            [full_model[i][key].cpu() for i in range(ws)], 0
+                        )
+                    elif key.split(".")[-2] in ["final_linear", "w_2"]:
+                        fm_sd[key] = (
+                            sum([full_model[i][key].cpu() for i in range(ws)]) / ws
+                        )
+                elif key.split(".")[-1] in [
+                    "linear_keys",
+                    "linear_values",
+                    "linear_query",
+                    "w_1",
+                    "w_3",
+                ]:
+                    fm_sd[key] = torch.cat(
+                        [full_model[i][key].cpu() for i in range(ws)], 0
+                    )
+                elif key.split(".")[-1] in ["final_linear", "w_2"]:
+                    fm_sd[key] = torch.cat(
+                        [full_model[i][key].cpu() for i in range(ws)], 1
+                    )
+
+            checkpoint = {
+                "model": fm_sd,
+                "generator": generator_state_dict,
+                "vocab": vocabs_to_dict(self.vocabs),
+                "opt": self.model_opt,
+                "optim": self.optim.state_dict(),
+            }
+        else:
+            checkpoint = {
+                "model": model_state_dict,
+                "generator": generator_state_dict,
+                "vocab": vocabs_to_dict(self.vocabs),
+                "opt": self.model_opt,
+                "optim": self.optim.state_dict(),
+            }
+        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+            logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
+            ckpt_path = "%s_step_%d.pt" % (self.base_path, step)
+            torch.save(checkpoint, ckpt_path)
+        else:
+            ckpt_path = None
+        if torch.distributed.is_initialized():
+            torch.distributed.barrier()
         return ckpt_path, None
 
     def _st_save(self, step, model):
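Note: the new _save() path gathers every rank's sharded state dict with all_gather_object and stitches the shards back into one checkpoint: weights of linear_keys, linear_values, linear_query, w_1 and w_3 (sharded on their output dimension) are concatenated on dim 0, weights of final_linear and w_2 (sharded on their input dimension) on dim 1, and the LoRA factor that is replicated across ranks is averaged. Below is a condensed, standalone sketch of that merge; the function name is illustrative, and the fallback for keys the hunk does not list is an assumption, not something the commit shows.

import torch

COL_PARALLEL = {"linear_keys", "linear_values", "linear_query", "w_1", "w_3"}
ROW_PARALLEL = {"final_linear", "w_2"}


def merge_tp_shards(shards):
    """shards: list of per-rank state dicts, e.g. as collected by all_gather_object."""
    ws = len(shards)
    merged = {}
    for key in shards[0]:
        parts = [sd[key].cpu() for sd in shards]
        pieces = key.split(".")
        leaf = pieces[-1]
        parent = pieces[-2] if len(pieces) > 1 else ""
        if leaf == "lora_A":
            if parent in COL_PARALLEL:  # replicated across ranks -> average
                merged[key] = sum(parts) / ws
            elif parent in ROW_PARALLEL:  # sharded on the input dim -> concatenate
                merged[key] = torch.cat(parts, 1)
        elif leaf == "lora_B":
            if parent in COL_PARALLEL:  # sharded on the output dim -> concatenate
                merged[key] = torch.cat(parts, 0)
            elif parent in ROW_PARALLEL:  # replicated across ranks -> average
                merged[key] = sum(parts) / ws
        elif leaf in COL_PARALLEL:
            merged[key] = torch.cat(parts, 0)  # output features split across ranks
        elif leaf in ROW_PARALLEL:
            merged[key] = torch.cat(parts, 1)  # input features split across ranks
        else:
            merged[key] = parts[0]  # assumed replicated: keep the rank-0 copy
    return merged

Averaging the replicated LoRA factor instead of picking one rank's copy matches the hunk above; the copies should agree up to floating-point noise, so the mean is a cheap way to keep the saved tensor consistent.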
onmt/train_single.py (2 additions, 2 deletions)

@@ -162,7 +162,7 @@ def main(opt, device_id):
     model_opt = _get_model_opts(opt, checkpoint=checkpoint)
 
     # Build model.
-    model = build_model(model_opt, opt, vocabs, checkpoint)
+    model = build_model(model_opt, opt, vocabs, checkpoint, device_id)
 
     model.count_parameters(log=logger.info)
     trainable = {
@@ -196,7 +196,7 @@ def main(opt, device_id):
     del checkpoint
 
     # Build model saver
-    model_saver = build_model_saver(model_opt, opt, model, vocabs, optim)
+    model_saver = build_model_saver(model_opt, opt, model, vocabs, optim, device_id)
 
     trainer = build_trainer(
         opt, device_id, model, vocabs, optim, model_saver=model_saver
onmt/trainer.py (1 addition, 1 deletion)

@@ -86,7 +86,7 @@ def build_trainer(opt, device_id, model, vocabs, optim, model_saver=None):
         parallel_mode,
         report_manager,
         with_align=True if opt.lambda_align > 0 else False,
-        model_saver=model_saver if gpu_rank <= 0 else None,
+        model_saver=model_saver,
         average_decay=average_decay,
         average_every=average_every,
         model_dtype=opt.model_dtype,
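Note: build_trainer now passes the model saver to every rank instead of only when gpu_rank <= 0, because the all_gather_object call in _save() is a collective: every tensor-parallel rank has to enter the save path or the collective would hang. Only rank 0 writes the file, the other ranks return None, and all ranks meet at a barrier. A minimal sketch of that flow, with illustrative names (merge_tp_shards refers to the sketch above):

import torch
import torch.distributed as dist


def save_full_checkpoint(local_state_dict, path):
    """Must be called by every rank, otherwise the collective below deadlocks."""
    if dist.is_initialized() and dist.get_world_size() > 1:
        shards = [None] * dist.get_world_size()
        dist.all_gather_object(shards, local_state_dict)  # collective: all ranks join
        full_state_dict = merge_tp_shards(shards)  # see the earlier sketch
    else:
        full_state_dict = local_state_dict
    if not dist.is_initialized() or dist.get_rank() == 0:
        torch.save(full_state_dict, path)  # only rank 0 touches disk
        ckpt_path = path
    else:
        ckpt_path = None  # signals "nothing was written by this rank"
    if dist.is_initialized():
        dist.barrier()  # keep ranks in lockstep before training resumes
    return ckpt_path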
