torch_xla.distributed.parallel_loader doesn't shard data #7904
Comments
Let me take a look.
In this case, it is working as expected. The MpDeviceLoader in each spawned process will asynchronously get batches from the torch dataloader and put them on the TPU. If you use SPMD and spawn just one process, the input sharding you specify on the loader is what distributes each batch across devices.
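A minimal sketch of that SPMD path (assuming the torch_xla 2.x SPMD APIs; the mesh shape and toy dataset here are illustrative, not from this issue):

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()  # a single process drives all devices in SPMD mode

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

dataset = torch.utils.data.TensorDataset(torch.randn(12, 3))
loader = torch.utils.data.DataLoader(dataset, batch_size=4)

# input_sharding splits dim 0 of each batch across the 'data' mesh axis;
# without it, every device would materialize the full batch.
mp_loader = pl.MpDeviceLoader(
    loader,
    xm.xla_device(),
    input_sharding=xs.ShardingSpec(mesh, ('data', None)))

for (batch,) in mp_loader:
    ...  # batch is a global tensor whose rows are sharded across devices
```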
To add on to what @bhavya01 said: ParallelLoader itself does not handle distributing the correct data to each worker; the dataloader it wraps needs to do that. In our MP example it is xla/test/test_train_mp_imagenet.py (lines 222 to 226 at 3d860bf) doing this work. In the GSPMD case it is the sharding we pass to the ParallelLoader that does the distribution. In my example I only used a fake data loader, so this is not an issue; I can fix that to make it clearer.
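For reference, a sketch of that multi-process pattern (the make_loader wrapper and batch size are illustrative names, not from the test file; the sampler wiring mirrors the referenced lines):

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl


def make_loader(train_dataset, batch_size):  # illustrative helper name
  # In the multi-process case, the DistributedSampler (not ParallelLoader)
  # gives each spawned process its own disjoint slice of the data.
  sampler = torch.utils.data.distributed.DistributedSampler(
      train_dataset,
      num_replicas=xm.xrt_world_size(),  # one replica per spawned process
      rank=xm.get_ordinal(),             # this process's shard index
      shuffle=True)
  loader = torch.utils.data.DataLoader(
      train_dataset, batch_size=batch_size, sampler=sampler)
  # MpDeviceLoader only prefetches batches onto the device asynchronously;
  # the sampler above is what keeps devices from seeing duplicate data.
  return pl.MpDeviceLoader(loader, xm.xla_device())
```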
I see, this is just a documentation issue then; e.g. https://pytorch.org/xla/release/r2.4/index.html#running-on-multiple-xla-devices-with-multi-processing gave me the impression that wrapping my existing (single-process) dataloader was sufficient. And yes, it would probably be helpful to have your toy data use DistributedSampler, so that if people just drop their own real data in there it will do something sensible.
yea let me update |
❓ Questions and Help
Maybe this is a misunderstanding on my part, but I assumed part of MpDeviceLoader's job was to split/shard data across devices. However, the test below shows it doesn't do this: all 4 devices on my v4 receive all 12 datapoints. What am I missing here? Thanks.
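A minimal reproduction consistent with that description (a sketch, not the reporter's original test; names are illustrative):

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  device = xm.xla_device()
  dataset = torch.utils.data.TensorDataset(torch.arange(12).float())
  # No DistributedSampler, so every process iterates the full dataset and
  # MpDeviceLoader faithfully moves all 12 datapoints to every device.
  loader = torch.utils.data.DataLoader(dataset, batch_size=4)
  mp_loader = pl.MpDeviceLoader(loader, device)
  seen = sum(batch[0].shape[0] for batch in mp_loader)
  print(f'device {index} saw {seen} datapoints')  # prints 12 on each device


if __name__ == '__main__':
  xmp.spawn(_mp_fn)
```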