add DDP with SPMD example (#7063)
JackCaoG authored May 15, 2024
1 parent a8eae0d commit c6074ab
Showing 2 changed files with 50 additions and 4 deletions.
9 changes: 5 additions & 4 deletions examples/train_resnet_base.py
@@ -16,16 +16,17 @@
 class TrainResNetBase():
 
   def __init__(self):
-    img_dim = 224
+    self.img_dim = 224
     self.batch_size = 128
     self.num_steps = 300
     self.num_epochs = 1
-    train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
+    self.train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
     # For the purpose of this example, we are going to use fake data.
     train_loader = xu.SampleGenerator(
-        data=(torch.zeros(self.batch_size, 3, img_dim, img_dim),
+        data=(torch.zeros(self.batch_size, 3, self.img_dim, self.img_dim),
               torch.zeros(self.batch_size, dtype=torch.int64)),
-        sample_count=train_dataset_len // self.batch_size // xr.world_size())
+        sample_count=self.train_dataset_len // self.batch_size //
+        xr.world_size())
 
     self.device = torch_xla.device()
     self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device)
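This base-class change promotes img_dim and train_dataset_len from locals to instance attributes so that subclasses can rebuild the data loader from them. A minimal sketch of the pattern this enables (the subclass and its 299-pixel input size are hypothetical, for illustration only):

from train_resnet_base import TrainResNetBase

import torch
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
from torch_xla import runtime as xr


class TrainResNetLargerImages(TrainResNetBase):

  def __init__(self):
    super().__init__()
    # Hypothetical subclass: the promoted attributes are visible here,
    # so the fake-data loader can be rebuilt with a different input size.
    self.img_dim = 299
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(self.batch_size, 3, self.img_dim, self.img_dim),
              torch.zeros(self.batch_size, dtype=torch.int64)),
        sample_count=self.train_dataset_len // self.batch_size //
        xr.world_size())
    self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device)

The new SPMD example below uses exactly this pattern, overriding the batch size and the loader's input sharding instead of the image size.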
45 changes: 45 additions & 0 deletions examples/train_resnet_spmd_data_parallel.py
@@ -0,0 +1,45 @@
+from train_resnet_base import TrainResNetBase
+
+import numpy as np
+
+import torch
+import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.spmd as xs
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla.utils.utils as xu
+from torch_xla import runtime as xr
+
+# Enable SPMD mode.
+xr.use_spmd()
+
+
+# A more detailed example can be found at https://github.com/pytorch/xla/blob/master/test/spmd/test_train_spmd_imagenet.py
+# Check out our user guide at https://github.com/pytorch/xla/blob/master/docs/spmd.md
+class TrainResNetXLASpmdDDP(TrainResNetBase):
+
+  def __init__(self):
+    super().__init__()
+    # Shard along the batch dimension only.
+    num_devices = xr.global_runtime_device_count()
+    device_ids = np.arange(num_devices)
+    mesh_shape = (num_devices,)
+    mesh = xs.Mesh(device_ids, mesh_shape, ('data',))
+    # Scale the batch size by num_devices, since under SPMD a single
+    # process drives all runtime devices.
+    self.batch_size *= num_devices
+
+    train_loader = xu.SampleGenerator(
+        data=(torch.zeros(self.batch_size, 3, self.img_dim, self.img_dim),
+              torch.zeros(self.batch_size, dtype=torch.int64)),
+        sample_count=self.train_dataset_len // self.batch_size)
+    self.train_device_loader = pl.MpDeviceLoader(
+        train_loader,
+        self.device,
+        # Shard the input's batch dimension along the `data` axis; no sharding along the other dimensions.
+        input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
+
+
+if __name__ == '__main__':
+  spmd_ddp = TrainResNetXLASpmdDDP()
+  spmd_ddp.start_training()
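The input_sharding argument tells MpDeviceLoader to shard every input batch across the mesh as it is transferred to the device. The same placement can be expressed on an individual tensor with xs.mark_sharding, documented in the SPMD user guide linked above. A minimal sketch over the same one-dimensional data mesh (the tensor shape is illustrative):

import numpy as np
import torch
import torch_xla
import torch_xla.distributed.spmd as xs
from torch_xla import runtime as xr

xr.use_spmd()

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

# A global batch on the XLA device; dim 0 is the batch dimension.
batch = torch.zeros(128 * num_devices, 3, 224, 224).to(torch_xla.device())
# Shard dim 0 across the `data` axis and replicate the other dimensions,
# mirroring the partition spec ('data', None, None, None) used above.
xs.mark_sharding(batch, mesh, ('data', None, None, None))

Since SPMD runs one process that drives every device, the script is launched directly rather than through xmp.spawn, e.g. PJRT_DEVICE=TPU python examples/train_resnet_spmd_data_parallel.py (the TPU device type here is an assumption; adjust for your runtime).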
