diff --git a/examples/train_resnet_base.py b/examples/train_resnet_base.py
index 5b5cdb92d69..b66780d5cd9 100644
--- a/examples/train_resnet_base.py
+++ b/examples/train_resnet_base.py
@@ -16,16 +16,17 @@ class TrainResNetBase():
 
   def __init__(self):
-    img_dim = 224
+    self.img_dim = 224
     self.batch_size = 128
     self.num_steps = 300
     self.num_epochs = 1
-    train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
+    self.train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
     # For the purpose of this example, we are going to use fake data.
     train_loader = xu.SampleGenerator(
-        data=(torch.zeros(self.batch_size, 3, img_dim, img_dim),
+        data=(torch.zeros(self.batch_size, 3, self.img_dim, self.img_dim),
               torch.zeros(self.batch_size, dtype=torch.int64)),
-        sample_count=train_dataset_len // self.batch_size // xr.world_size())
+        sample_count=self.train_dataset_len // self.batch_size //
+        xr.world_size())
     self.device = torch_xla.device()
     self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device)
diff --git a/examples/train_resnet_spmd_data_parallel.py b/examples/train_resnet_spmd_data_parallel.py
new file mode 100644
index 00000000000..7aa53a7bf9a
--- /dev/null
+++ b/examples/train_resnet_spmd_data_parallel.py
@@ -0,0 +1,45 @@
+from train_resnet_base import TrainResNetBase
+
+import numpy as np
+
+import torch
+import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.spmd as xs
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla.utils.utils as xu
+from torch_xla import runtime as xr
+
+# Enable SPMD mode.
+xr.use_spmd()
+
+
+# A more detailed example can be found at https://github.com/pytorch/xla/blob/master/test/spmd/test_train_spmd_imagenet.py
+# Check out our user guide at https://github.com/pytorch/xla/blob/master/docs/spmd.md
+class TrainResNetXLASpmdDDP(TrainResNetBase):
+
+  def __init__(self):
+    super().__init__()
+    # Shard along the batch dimension only.
+    num_devices = xr.global_runtime_device_count()
+    device_ids = np.arange(num_devices)
+    mesh_shape = (num_devices,)
+    mesh = xs.Mesh(device_ids, mesh_shape, ('data',))
+    # Scale the batch size with num_devices since there will be only one
+    # process that handles all runtime devices.
+    self.batch_size *= num_devices
+
+    train_loader = xu.SampleGenerator(
+        data=(torch.zeros(self.batch_size, 3, self.img_dim, self.img_dim),
+              torch.zeros(self.batch_size, dtype=torch.int64)),
+        sample_count=self.train_dataset_len // self.batch_size)
+    self.train_device_loader = pl.MpDeviceLoader(
+        train_loader,
+        self.device,
+        # Shard the input's batch dimension along the `data` axis; no sharding along other dimensions.
+        input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
+
+
+if __name__ == '__main__':
+  spmd_ddp = TrainResNetXLASpmdDDP()
+  spmd_ddp.start_training()
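
For reference, below is a minimal standalone sketch, not part of the diff, showing the same batch-dimension sharding applied directly to a tensor with xs.mark_sharding rather than through an input_sharding spec on the loader. The 224x224 image shape and the 128 per-device batch size simply mirror the example above; any other shapes here are illustrative assumptions.

# Minimal sketch (not part of this PR): mark a single input batch as sharded
# along the batch dimension, mirroring the ShardingSpec that the MpDeviceLoader
# uses in train_resnet_spmd_data_parallel.py.
import numpy as np
import torch
import torch_xla
import torch_xla.distributed.spmd as xs
from torch_xla import runtime as xr

xr.use_spmd()

num_devices = xr.global_runtime_device_count()
# 1D mesh with a single `data` axis, one entry per addressable device.
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

# Fake global batch of shape (batch, channel, height, width).
batch = torch.zeros(128 * num_devices, 3, 224, 224).to(torch_xla.device())
# Split dim 0 across the `data` axis; replicate the remaining dimensions.
xs.mark_sharding(batch, mesh, ('data', None, None, None))

Under SPMD a single process drives all runtime devices, which is why the new example multiplies batch_size by num_devices and drops the xr.world_size() divisor from sample_count.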