fix multi gpu valid for data_parallel (OpenNMT#2534)
* fix multi gpu valid for data_parallel
* timeout as an option
vince62s authored Dec 12, 2023
1 parent c568bc5 commit f0bd36f
Showing 3 changed files with 12 additions and 4 deletions.
7 changes: 7 additions & 0 deletions onmt/opts.py

@@ -493,6 +493,13 @@ def distributed_opts(parser):
         type=int,
         help="Port of master for torch.distributed training.",
     )
+    group.add(
+        "--timeout",
+        "-timeout",
+        default=60,
+        type=int,
+        help="Timeout for one GPU to wait for the others.",
+    )


 def model_opts(parser):
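Note (editorial): a minimal, self-contained sketch of how an option registered this way behaves once parsed. It uses plain argparse as a stand-in for OpenNMT's own parser, so the registration call differs, but the default (60 s) and the resulting value are the same.

    # Sketch only: plain-argparse stand-in for the group.add(...) call above.
    import argparse
    from datetime import timedelta

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--timeout",
        "-timeout",
        default=60,
        type=int,
        help="Timeout for one GPU to wait for the others.",
    )

    opt = parser.parse_args(["-timeout", "120"])   # e.g. onmt_train ... -timeout 120
    print(timedelta(seconds=opt.timeout))          # 0:02:00, handed to torch.distributed below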
7 changes: 4 additions & 3 deletions onmt/trainer.py

@@ -328,9 +328,10 @@ def train(
                 )

             if valid_iter is not None and step % valid_steps == 0:
-                valid_stats = self.validate(
-                    valid_iter, moving_average=self.moving_average
-                )
+                if self.parallel_mode == "tensor_parallel" or self.gpu_rank <= 0:
+                    valid_stats = self.validate(
+                        valid_iter, moving_average=self.moving_average
+                    )

             if step % valid_steps == 0 and self.gpu_rank <= 0:
                 self._report_step(
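Note (editorial): the new guard lets only rank 0 run validation under data_parallel (each replica holds the full model, so one rank can validate alone), while tensor_parallel still validates on every rank because the weights are sharded across GPUs and validation needs all ranks. A minimal sketch of that condition, with illustrative values:

    def should_run_validation(parallel_mode: str, gpu_rank: int) -> bool:
        # Mirrors the condition added in trainer.py above.
        # tensor_parallel: every rank holds a shard of the model, so all ranks
        # must take part in validation; data_parallel (or single GPU, rank <= 0):
        # rank 0 validates alone and the other replicas skip it.
        return parallel_mode == "tensor_parallel" or gpu_rank <= 0


    assert should_run_validation("tensor_parallel", gpu_rank=1)      # all ranks validate
    assert should_run_validation("data_parallel", gpu_rank=0)        # rank 0 validates
    assert not should_run_validation("data_parallel", gpu_rank=1)    # other ranks skip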
2 changes: 1 addition & 1 deletion onmt/utils/distributed.py

@@ -29,7 +29,7 @@ def multi_init(opt, device_id):
         init_method=dist_init_method,
         world_size=dist_world_size,
         rank=opt.gpu_ranks[device_id],
-        timeout=timedelta(seconds=60),
+        timeout=timedelta(seconds=opt.timeout),
     )
     gpu_rank = torch.distributed.get_rank()
     if not is_master(opt, device_id):
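Note (editorial): this is where the two changes meet. With validation gated to rank 0 under data_parallel, the other ranks wait at their next collective until rank 0 finishes, so the process-group timeout (previously hard-coded to 60 s) has to be long enough to cover a validation pass; the new -timeout option makes it configurable. A minimal sketch of the call, assuming opt fields along the lines of opts.py (master_ip, master_port, world_size, gpu_ranks, timeout), not the full multi_init:

    # Sketch only: how the configurable timeout reaches torch.distributed.
    from datetime import timedelta

    import torch.distributed as dist


    def init_group(opt, device_id):
        dist.init_process_group(
            backend="nccl",
            init_method=f"tcp://{opt.master_ip}:{opt.master_port}",  # assumed fields
            world_size=opt.world_size,                               # assumed field
            rank=opt.gpu_ranks[device_id],
            timeout=timedelta(seconds=opt.timeout),  # was timedelta(seconds=60)
        )
        return dist.get_rank()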
