
PyTorch dataset use multiprocessing #1383

Closed
albertz opened this issue Aug 22, 2023 · 3 comments

Comments

@albertz
Member

albertz commented Aug 22, 2023

Now that we use DataLoader (v1) again (c0ac991, fixed #1382), we can directly use the num_workers option.

In 5b569b3, I added the config option torch_dataloader_opts, which you can set like:

torch_dataloader_opts = dict(num_workers=1)

num_workers = 1 will use a single worker only, but then this should just behave as before w.r.t. the epoch size. Otherwise (num_workers > 1) it would duplicate the data over the workers, because we do not do any sharding. (If this is an issue for you, please open a new issue about it.)

num_workers = 1 should in principle also be fast enough in most cases (in all our TF-based experiments, we also only had a single worker, and computation time was always close to 100%, i.e. the dataset was never a bottleneck).

The computation time should tell you in the end if you have a bottleneck with the dataset or not. Now with this option, I also see around 98% computation time (on demo-rf-pt-benchmark with num_workers=1).

Note that you can also use DataLoader num_workers=1 and additionally use MultiProcDataset with a higher number of workers, because MultiProcDataset handles the sharding correctly (see the config sketch below).
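
For illustration, a config sketch of that combination (not from this issue; the inner dataset dict and the exact MultiProcDataset option values are just placeholders, check the RETURNN MultiProcDataset docs for the precise options):

train = {
    "class": "MultiProcDataset",
    # placeholder inner dataset; put your actual dataset dict here
    "dataset": {"class": "HDFDataset", "files": ["train.hdf"]},
    # MultiProcDataset shards the data across its own workers
    "num_workers": 4,
    "buffer_size": 10,
}

# keep the PyTorch DataLoader itself at a single worker
torch_dataloader_opts = dict(num_workers=1)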


For TorchData DataLoader2:

I think it needs some code like this:

from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService

# datapipe and model are assumed to be defined elsewhere
rs = MultiProcessingReadingService(num_workers=4)
dl = DataLoader2(datapipe, reading_service=rs)
for epoch in range(10):
    dl.seed(epoch)
    for d in dl:
        model(d)
dl.shutdown()

(Via the TorchData DataLoader2 docs.)

(Related: #1382. However, until that is resolved, we should probably implement this here for now anyway.)

Some things which need to be clarified:

  • Is this enough? Does each worker get an individual seed? Is this all correct?
  • So it means that an epoch is actually num_workers times more data now? Just like in DDP training? -> Yes. But there is something like sharding_filter.

Also note that we have another alternative: MultiProcDataset. This one keeps the original epochs, i.e. it implements sharding.


Some options to implement the sharding logic:

  • sharding_filter, how does this work?
  • Implement sharding somewhere inside the data loader, data iterator, dataset wrapper. E.g. in ReturnnDatasetIterDataPipe, ReturnnDatasetResetDefaultEpochCounterCallback or ReturnnDatasetResetMpSharedEpochCallback.
  • Make a new ShardingDataset wrapper dataset, similar to MultiProcDataset but without the multi-proc logic.
  • Implement the sharding logic inside MetaDataset. (But I don't like to extend MetaDataset more and more by such unrelated logic... I prefer to have this separate.)
  • Implement sharding in the Dataset.get_seq_order_for_epoch (see the sketch below).
  • Automatically apply a factor (1/num_workers) to Dataset.partition_epoch.
  • Implement sharding individually in each dataset.

In all cases, the user might actually be interested in switching between these behaviors, just like horovod_dataset_distribution="random_seed_offset" vs horovod_dataset_distribution="shard".
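
For illustration, a minimal sketch of the Dataset.get_seq_order_for_epoch option from the list above (not actual RETURNN code; worker_index and num_workers are placeholders for however the worker identity would be passed in):

def shard_seq_order(seq_order, worker_index, num_workers):
    """
    Take the full per-epoch sequence order (shuffled with a seed shared by
    all workers) and keep only every num_workers-th entry, offset by this
    worker's index. Together the workers cover each sequence exactly once,
    so the meaning of an epoch stays unchanged.
    """
    return seq_order[worker_index::num_workers]

# e.g. num_workers=2: worker 0 gets seq_order[0::2], worker 1 gets seq_order[1::2]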

@flixxox

flixxox commented Aug 23, 2023

I think your code snippet looks good. Only if both DDP and multiple workers are used, rs should be replaced by:

# DDP and multiple workers
from torchdata.dataloader2 import DistributedReadingService, SequentialReadingService

rs = SequentialReadingService(
    DistributedReadingService(),
    MultiProcessingReadingService(num_workers=num_workers),
)

Additionally, you have to add a sharding_filter datapipe somewhere in the datapipe graph. From this point on, the data is distributed across processes and workers. It is typically preceded by a .shuffle datapipe and should be as early as possible in the data loading pipeline.
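
A minimal sketch of such a datapipe graph (illustrative only; the toy IterableWrapper source stands in for the actual RETURNN datapipes):

from torchdata.datapipes.iter import IterableWrapper

# toy source pipe standing in for the real dataset pipe
datapipe = IterableWrapper(range(1000))
# shuffle first (same seed across workers), then shard: after sharding_filter
# each worker/DDP process only sees its own slice, so nothing is duplicated
datapipe = datapipe.shuffle().sharding_filter()
# any expensive per-sample processing should come after sharding_filter,
# so it only runs for the worker's own shard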

So it means that an epoch is actually num_workers times more data now? Just like in DDP training?

Yes, if you are missing the sharding_filter, your data will be duplicated. If not, you have num_workers*num_ddp_processes workers that each load a fraction of the dataset. So on a single process, for d in dl returns the same number of samples.

@albertz
Member Author

albertz commented Aug 24, 2023

I think your code snippet looks good.

This is just copied & pasted from the original docs (what I linked), so of course I would hope that this snippet is ok.

Have you tried that in RETURNN? Why did you not make a PR?

Yes, if you are missing the sharding_filter your data will be duplicated.

But they will all get a different random seed, as we do with DDP, so it's not really a problem, right? It just changes what an "epoch" means now. I just wanted to verify this.

If you actually use sharding_filter, then it would be wrong currently, right? Because every DDP worker uses a different random seed.

@albertz
Member Author

albertz commented Oct 17, 2023

Now that we use DataLoader (v1) again (c0ac991, fixed #1382), we can directly use the num_workers option.

In 5b569b3, I added the config option torch_dataloader_opts, which you can set like:

torch_dataloader_opts = dict(num_workers=1)

num_workers = 1 will use a single worker only, but then this should just behave as before w.r.t. the epoch size. Otherwise it would duplicate the data over the workers, because we do not do any sharding. num_workers = 1 should in principle also be fast enough in most cases (in all our TF-based experiments, we also only had a single worker, and computation time was always close to 100%, i.e. the dataset was never a bottleneck).

The computation time should tell you in the end if you have a bottleneck with the dataset or not. Now with this option, I also see around 98% computation time (on demo-rf-pt-benchmark with num_workers=1).

So I guess this issue can be closed now.
