this will crash after the epoch end and I am not sure why...
JackCaoG committed Sep 24, 2024
1 parent 3539e6e commit 669e59f
Showing 2 changed files with 15 additions and 18 deletions.
30 changes: 12 additions & 18 deletions examples/data_parallel/train_resnet_xla_ddp.py
@@ -8,34 +8,27 @@
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.runtime as xr
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla.utils.utils as xu
 import torchvision
 
 
 class TrainResNetXLADDP(TrainResNetBase):
 
   def __init__(self):
     super().__init__()
-    # below code is commented out because in this example we used a fake data
-    # loader that does not take sampler. However this logic is needed if you
-    # want each process to handle different parts of the data.
-    '''
+    # for multiprocess we need a sampler
+    train_sampler = None
+    fake_dataset = xu.SampleGenerator(
+        data=(torch.zeros(3, self.img_dim,
+                          self.img_dim), torch.tensor(0, dtype=torch.int64)),
+        sample_count=self.train_dataset_len)
     if xr.world_size() > 1:
       train_sampler = torch.utils.data.distributed.DistributedSampler(
-          train_dataset,
-          num_replicas=xr.world_size(),
-          rank=xr.global_ordinal(),
-          shuffle=True)
+          fake_dataset, num_replicas=xr.world_size(), rank=xr.global_ordinal())
     train_loader = torch.utils.data.DataLoader(
-        train_dataset,
-        batch_size=FLAGS.batch_size,
-        sampler=train_sampler,
-        drop_last=FLAGS.drop_last,
-        shuffle=False if train_sampler else True,
-        num_workers=FLAGS.num_workers,
-        persistent_workers=FLAGS.persistent_workers,
-        prefetch_factor=FLAGS.prefetch_factor)
+        fake_dataset, batch_size=self.batch_size, sampler=train_sampler)
     self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device)
-    '''
 
   def run_optimizer(self):
     # optimizer_step will call `optimizer.step()` and all_reduce the gradident
@@ -53,4 +46,5 @@ def _mp_fn(index):
   print(
       'consider using train_resnet_spmd_data_parallel.py instead to get better performance'
   )
-  torch_xla.launch(_mp_fn, args=())
+  #torch_xla.launch(_mp_fn, args=())
+  _mp_fn(0)
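
The new __init__ relies on the standard map-style combination: a dataset that implements __len__ and __getitem__, a DistributedSampler that hands each replica its own slice of the indices, and a DataLoader driven by that sampler. That is also why the second file has to add __getitem__ to SampleGenerator. A minimal, self-contained sketch of the pattern, with a hypothetical ToyDataset standing in for xu.SampleGenerator and the replica count and rank hard-coded instead of taken from xr.world_size() and xr.global_ordinal():

# Minimal sketch of the sampler/loader pattern the new __init__ uses. ToyDataset
# is a hypothetical stand-in for xu.SampleGenerator; the replica count and rank
# are hard-coded so the sketch runs without an XLA runtime.
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler


class ToyDataset(Dataset):
  """Map-style dataset returning the same fake (image, label) pair for every index."""

  def __init__(self, sample_count, img_dim=224):
    self._sample_count = sample_count
    self._data = (torch.zeros(3, img_dim, img_dim),
                  torch.tensor(0, dtype=torch.int64))

  def __len__(self):
    return self._sample_count

  def __getitem__(self, index):
    return self._data


if __name__ == '__main__':
  dataset = ToyDataset(sample_count=64)
  # In the XLA example these would be xr.world_size() and xr.global_ordinal().
  sampler = DistributedSampler(dataset, num_replicas=4, rank=0)
  loader = DataLoader(dataset, batch_size=8, sampler=sampler)
  for epoch in range(2):
    sampler.set_epoch(epoch)  # reshuffle differently each epoch
    steps = 0
    for images, labels in loader:
      steps += 1
    print(f'epoch {epoch}: {steps} steps, batch shape {tuple(images.shape)}')

With 64 samples, 4 replicas, and a batch size of 8, each replica sees 16 samples and therefore runs 2 steps per epoch.
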
3 changes: 3 additions & 0 deletions torch_xla/utils/utils.py
@@ -70,6 +70,9 @@ def next(self):
     self._count += 1
     return self._data
 
+  def __getitem__(self, index):
+    return self.next()
+
 
 class FnDataGenerator(object):
 
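One plausible reading of the commit message, assuming SampleGenerator.next() raises StopIteration once sample_count samples have been served (the raise itself is outside the context shown above): because __getitem__ simply delegates to next(), every index access advances a counter that is never reset, so the dataset is exhausted after exactly one epoch and the first access of the next epoch raises. The class below is a hypothetical stand-in, not the real xu.SampleGenerator:

# Hypothetical stand-in for xu.SampleGenerator (not the real implementation),
# assuming next() raises StopIteration once sample_count samples were served.
import torch


class FakeSampleGenerator:

  def __init__(self, data, sample_count):
    self._data = data
    self._sample_count = sample_count
    self._count = 0

  def __len__(self):
    return self._sample_count

  def next(self):
    if self._count >= self._sample_count:
      raise StopIteration
    self._count += 1
    return self._data

  def __getitem__(self, index):
    # Every index access consumes one "sample" and ignores `index`.
    return self.next()


if __name__ == '__main__':
  dataset = FakeSampleGenerator(
      data=(torch.zeros(3, 8, 8), torch.tensor(0)), sample_count=4)
  # Epoch 0: exactly sample_count accesses succeed.
  for i in range(len(dataset)):
    _ = dataset[i]
  # Epoch 1: the counter was never reset, so the first access raises.
  try:
    _ = dataset[0]
  except StopIteration:
    print('second epoch: StopIteration on the first __getitem__ call')

If that is what happens, resetting _count whenever a new epoch starts, or having __getitem__ return the data without advancing the counter, would keep the second epoch from failing.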
