Adding AMP support in pretraining. #102

Open · wants to merge 1 commit into base: master
15 changes: 10 additions & 5 deletions oscar/run_oscarplus_pretrain.py
@@ -359,6 +359,7 @@ def main():
     # Every args.ckpt_period, report train_score and save model
     tr_loss = 0
     nb_tr_examples, nb_tr_steps = 0, 0
+    scaler = torch.cuda.amp.GradScaler(enabled=True)
     for step, (batch, batch_extra) in enumerate(zip(train_dataloader, train_dataloader_extra), start_iter):
         if not clock_started:
             start_training_time = time.time()
@@ -391,17 +392,19 @@ def forward_backward(images, input_ids, input_mask, segment_ids,
             # feature as input
             image_features = torch.stack(images).to(args.device, non_blocking=True)
 
-            outputs = model(input_ids, segment_ids, input_mask,
-                            lm_label_ids, is_next, img_feats=image_features)
+            with torch.cuda.amp.autocast(enabled=True):
+                outputs = model(input_ids, segment_ids, input_mask,
+                                lm_label_ids, is_next, img_feats=image_features)
 
-            loss = loss_weight * outputs[0]
+                loss = loss_weight * outputs[0]
 
             if args.n_gpu > 1:
                 loss = loss.mean()  # mean() to average on multi-gpu.
 
             if args.gradient_accumulation_steps > 1:
                 loss = loss / args.gradient_accumulation_steps
-            loss.backward()
+            scaler.scale(loss).backward()
+            # loss.backward()
 
             return loss.item(), input_ids.size(0)
 
@@ -436,7 +439,9 @@ def forward_backward(images, input_ids, input_mask, segment_ids,
         if args.max_grad_norm > 0:
             torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
         # do the optimization steps
-        optimizer.step()
+        # optimizer.step()
+        scaler.step(optimizer)
+        scaler.update()
         scheduler.step()  # Update learning rate schedule
         optimizer.zero_grad()
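
For reference, the changes follow the standard torch.cuda.amp recipe: create one GradScaler, run the forward pass under autocast, backpropagate through scaler.scale(loss), then replace optimizer.step() with scaler.step(optimizer) plus scaler.update(). One point worth flagging: because the diff leaves torch.nn.utils.clip_grad_norm_ operating on the still-scaled gradients, PyTorch's documented recipe would also call scaler.unscale_(optimizer) before clipping so the norm is computed on the true gradients. The sketch below is a minimal, self-contained illustration of that full pattern; it uses a toy model and placeholder names (model, optimizer, scheduler, max_grad_norm), not the actual objects from run_oscarplus_pretrain.py.

```python
import torch

# Hedged sketch of the torch.cuda.amp pattern this PR adopts, on a toy model.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(16, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
max_grad_norm = 1.0

# enabled=False falls back to plain fp32 training, so the same loop runs on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

for step in range(5):
    inputs = torch.randn(8, 16, device=device)
    targets = torch.randint(0, 2, (8,), device=device)

    optimizer.zero_grad()

    # Forward pass (and loss) under autocast: eligible ops run in fp16.
    with torch.cuda.amp.autocast(enabled=(device == "cuda")):
        logits = model(inputs)
        loss = torch.nn.functional.cross_entropy(logits, targets)

    # Scale the loss so small fp16 gradients do not underflow, then backprop.
    scaler.scale(loss).backward()

    # PyTorch's AMP recipe unscales before clipping so the norm is measured on
    # true gradients; the PR as posted clips the still-scaled gradients.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

    # scaler.step() skips the optimizer update if any gradient is inf/NaN;
    # scaler.update() then adjusts the loss scale for the next iteration.
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()
```

A small design note: tying the enabled flag to a command-line option instead of hard-coding enabled=True would let AMP be switched off without touching the training loop, since GradScaler and autocast both degrade to no-ops when disabled.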