torch_xla.distributed.parallel_loader doesn't shard data #7904

Closed
davidaknowles opened this issue Aug 23, 2024 · 5 comments · May be fixed by #7914

Comments

@davidaknowles

❓ Questions and Help

Maybe this is a misunderstanding on my part, but I assumed part of MpDeviceLoader's job was to split/shard data across devices. However, the test below shows it doesn't do this: all 4 devices on my v4 receive all 12 datapoints. What am I missing here? Thanks.

import torch
import torch_xla.core.xla_model as xm
from torch_xla import runtime as xr
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

def test_parallel_loader(rank):

    data = torch.arange(12).reshape(-1, 1)
    dataset = torch.utils.data.TensorDataset(data)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)
    
    parallel_loader = pl.MpDeviceLoader(dataloader, xm.xla_device())
    results = sum([batch[0].tolist() for batch in parallel_loader], [])

    print(f"Device {rank} received data: {results}")
    
    expected_data_size = len(data) // xr.world_size()
    print(f"Device {rank} received {len(results)} datapoints, expected {expected_data_size}")

if __name__ == "__main__":
    xmp.spawn(test_parallel_loader, args=()) 
@JackCaoG
Collaborator

Let me take a look

@bhavya01
Collaborator

In this case, it is working as expected. The MpDeviceLoader in each spawned process asynchronously fetches batches from the torch DataLoader and puts them on the TPU. If you use SPMD, spawn just one process, and specify input_sharding so that the input is sharded along the batch dimension, then you should see each batch spread out over multiple devices.
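
For reference, a minimal SPMD sketch of that setup might look like the following; the mesh shape, axis name, and batch size here are illustrative assumptions, not values taken from this issue:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd as xs
from torch_xla import runtime as xr

xr.use_spmd()  # one process drives all local devices under SPMD

# Illustrative 1-D "data" mesh over all addressable devices.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

data = torch.arange(12).reshape(-1, 1)
dataset = torch.utils.data.TensorDataset(data)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)

# Shard each incoming batch along dim 0 across the 'data' mesh axis.
sharding = xs.ShardingSpec(mesh, ('data', None))
loader = pl.MpDeviceLoader(dataloader, xm.xla_device(), input_sharding=sharding)

for (batch,) in loader:
    pass  # each global batch is split across the devices along dim 0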

@JackCaoG
Collaborator

To add on to what @bhavya01 said: ParallelLoader itself does not handle distributing the correct data to each worker; the dataloader it wraps needs to do that.

In our MP example it is

train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset,
    num_replicas=xr.world_size(),
    rank=xr.global_ordinal(),
    shuffle=True)

doing this work. In the GSPMD case, it is the sharding we pass to the ParallelLoader that does the distribution. In my example I only used a fake data loader, so this is not an issue; I can fix that to make it clearer.
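
Applied to the repro above, a per-process sketch with such a sampler might look like this; it is a sketch only, _mp_fn is a placeholder name, and batch_size=1 is arbitrary:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla import runtime as xr

def _mp_fn(rank):
    data = torch.arange(12).reshape(-1, 1)
    dataset = torch.utils.data.TensorDataset(data)

    # The sampler is what restricts each process to its own shard of the data.
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=xr.world_size(),
        rank=xr.global_ordinal(),
        shuffle=False)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, sampler=sampler)

    # MpDeviceLoader only overlaps host-to-device transfer with compute;
    # it does not shard the data itself.
    loader = pl.MpDeviceLoader(dataloader, xm.xla_device())
    results = [batch[0].tolist() for batch in loader]
    print(f"Device {rank} received {len(results)} batches: {results}")

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())

With 4 devices, each process should now see 3 of the 12 datapoints instead of all of them.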

@davidaknowles
Author

I see, this is just a documentation issue then, e.g. https://pytorch.org/xla/release/r2.4/index.html#running-on-multiple-xla-devices-with-multi-processing gave me the impression that wrapping my existing (single-process) dataloader was sufficient. And yes, it would probably be helpful to have your toy example use DistributedSampler, so that if people just put their own real data in there it will do something sensible.

@JackCaoG
Collaborator

yea let me update
