Raise error when minibatch is used in SPMD dataloading and per host b…
JackCaoG authored Oct 7, 2024
1 parent 14d14ae commit e3cf356
Showing 2 changed files with 83 additions and 0 deletions.
67 changes: 67 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -17,6 +17,7 @@
import torch_xla.debug.metrics as met
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import XLAShardedTensor
import torch_xla.distributed.parallel_loader as pl
import test_xla_sharding_base

import torch_xla.core.xla_env_vars as xenv
@@ -1310,6 +1311,72 @@ def test_get_1d_mesh(self):
    self.assertEqual(mesh_without_name.mesh_shape,
                     (xr.global_runtime_device_count(),))

  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                       "Multiple devices required for dataloader sharding test")
  def test_data_loader_with_sharding(self):
    device = torch_xla.device()
    mesh = xs.get_1d_mesh("data")
    batch_size = 8
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(batch_size, 3, 64, 64),
              torch.zeros(batch_size, dtype=torch.int64)),
        sample_count=100)
    train_device_loader = pl.MpDeviceLoader(
        train_loader,
        device,
        # Shard the input's batch dimension along the `data` axis; no sharding
        # along the other dimensions.
        input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
    data, _ = iter(train_device_loader).__next__()
    self.assertEqual(data.size(), torch.Size([8, 3, 64, 64]))
    self.assertEqual(
        torch_xla._XLAC._get_xla_sharding_spec(data),
        f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
    )

  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                       "Multiple devices required for dataloader sharding test")
  def test_data_loader_with_non_batch_size(self):
    device = torch_xla.device()
    mesh = xs.get_1d_mesh("data")
    batch_size = mesh.size() - 1
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(batch_size, 3, 64, 64),
              torch.zeros(batch_size, dtype=torch.int64)),
        sample_count=100)
    train_device_loader = pl.MpDeviceLoader(
        train_loader,
        device,
        # Shard the input's batch dimension along the `data` axis; no sharding
        # along the other dimensions.
        input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
    data, _ = iter(train_device_loader).__next__()
    self.assertEqual(data.size(), torch.Size([mesh.size() - 1, 3, 64, 64]))
    self.assertEqual(
        torch_xla._XLAC._get_xla_sharding_spec(data),
        f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
    )

  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                       "Multiple devices required for dataloader sharding test")
  def test_data_loader_with_non_batch_size_and_mini_batch(self):
    device = torch_xla.device()
    mesh = xs.get_1d_mesh("data")
    batch_size = mesh.size() - 1
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(batch_size, 3, 64, 64),
              torch.zeros(batch_size, dtype=torch.int64)),
        sample_count=100)
    train_device_loader = pl.MpDeviceLoader(
        train_loader,
        device,
        # Shard the input's batch dimension along the `data` axis; no sharding
        # along the other dimensions. With `minibatch=True`, each host is
        # expected to feed only its local slice of the global batch.
        input_sharding=xs.ShardingSpec(
            mesh, ('data', None, None, None), minibatch=True))
    with self.assertRaisesRegex(
        RuntimeError,
        "When minibatch is configured, batch dimension of the tensor must be divisible by local runtime device count*"
    ):
      data, _ = iter(train_device_loader).__next__()


if __name__ == '__main__':
  test = unittest.main()
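For reference, a minimal sketch (not part of the commit) of what the expected sharding annotation in the assertions above expands to: on a 1D `data` mesh the batch dimension is split across all devices and the remaining dimensions are replicated. `num_devices` below is a stand-in for `mesh.size()`.

# Sketch only: reproduces the f-string used in the assertions above.
num_devices = 4  # stand-in for mesh.size()
expected = (f"{{devices=[{num_devices},1,1,1]"
            f"{','.join(str(i) for i in range(num_devices))}}}")
print(expected)  # {devices=[4,1,1,1]0,1,2,3}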
16 changes: 16 additions & 0 deletions torch_xla/core/xla_model.py
@@ -1299,6 +1299,22 @@ def convert_fn(tensors):
    shardings = None
    if input_sharding:
      shardings = [input_sharding.xla_spec(t) for t in tensors]
    if input_sharding and input_sharding.minibatch:
      # When minibatch is configured, the batch dimension of each tensor must
      # be divisible by the local runtime device count.
      for tensor, sharding in zip(tensors, shardings):
        # Assume the batch dimension is 0.
        local_runtime_device_count = torch_xla.runtime.addressable_runtime_device_count()
        if sharding and tensor.dim() > 0 and (tensor.size()[0] %
                                              local_runtime_device_count) != 0:
          raise RuntimeError(
              "When minibatch is configured, batch dimension of the tensor "
              "must be divisible by local runtime device count. Input data "
              f"shape = {tensor.size()}, "
              f"local_runtime_device_count = {local_runtime_device_count}")

    xtensors = torch_xla._XLAC._xla_tensors_from_aten(tensors, devices,
                                                      shardings)
    return xtensors
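To restate what the added check enforces: when `minibatch=True` is set on the `ShardingSpec`, each host's loader is expected to yield only its local slice of the global batch, so that slice must divide evenly across the host's addressable devices. Below is a standalone sketch of the rule; the function name and arguments are illustrative, not torch_xla API.

# Sketch only: the divisibility rule enforced by the new check above.
def check_minibatch_batch_dim(per_host_batch: int, local_device_count: int) -> None:
  """Raises if a per-host batch cannot be split evenly across local devices."""
  if per_host_batch % local_device_count != 0:
    raise RuntimeError(
        f"per-host batch {per_host_batch} is not divisible by "
        f"{local_device_count} local devices")

check_minibatch_batch_dim(8, 4)    # ok: each local device gets 2 samples
# check_minibatch_batch_dim(7, 4)  # would raise, mirroring the check above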
