lance.torch.data LanceDataset with HF accelerate for distributed training #3262

Open

andrijazz opened this issue Dec 17, 2024 · 1 comment

andrijazz commented Dec 17, 2024

I can't find a way to make LanceDataset work with the accelerate library when running training in a distributed setting. Here is the code snippet:

from io import BytesIO

import pyarrow as pa
import torch.utils.data
import torchvision.transforms as transforms
from PIL import Image
from lance.torch.data import LanceDataset
from accelerate import Accelerator
from accelerate.utils.dataclasses import DataLoaderConfiguration


def to_tensor_fn(batch: pa.RecordBatch, hf_converter=None):
    # batch_size=1, so take the single encoded image from the record batch
    image = batch["image"][0].as_py()
    image = transforms.ToTensor()(Image.open(BytesIO(image)).resize((512, 512)).convert("RGB"))
    return {"image": image}

uri = ""  # path to the Lance dataset (omitted here)
lance_filter = "aesthetic_score > 4.5"
dataset = LanceDataset(
    uri,
    columns=["image", "caption"],
    filter=lance_filter,
    batch_size=1,
    to_tensor_fn=to_tensor_fn,
)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)

step = 0
for batch in dataloader:
    print(torch.sum(batch['image'][0]))
    step += 1
print("Total number of batches: ", step)
# prints Total number of batches: 157

So this works fine. However, when you wrap the dataloader with accelerate (run with accelerate launch script.py on more than one process):

accelerator = Accelerator()
dataloader = accelerator.prepare(dataloader)
step = 0
for batch in dataloader:
    print(torch.sum(batch['image'][0]))
    step += 1
print("Total number of batches: ", step) 

# Running in a distributed setting with 2 processes prints "Total number of batches: 1" twice

I would expect 2 workers with 78 batches each. I'm not sure whether this problem is related to how torch.utils.data.IterableDataset interacts with accelerate, but it would be really great if you could post an example of using Lance with accelerate for distributed training.

Setting dispatch_batches=False results in 2 workers where the first worker gets only 1 batch and the second gets 78 batches.
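For reference, this is a sketch of how dispatch_batches was disabled, using the DataLoaderConfiguration import from the snippet above (in recent accelerate versions it is passed via dataloader_config; older versions accepted dispatch_batches directly on Accelerator):

# Sketch: disable batch dispatch so each process iterates its own dataloader
# instead of rank 0 fetching batches and dispatching slices to the others.
dataloader_config = DataLoaderConfiguration(dispatch_batches=False)
accelerator = Accelerator(dataloader_config=dataloader_config)
dataloader = accelerator.prepare(dataloader)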

wjones127 (Contributor) commented

> would be really great if you guys can post example of using lance with accelerate for distributed training.

I don't know if any of us are familiar with the accelerate framework, but if someone is, an example would indeed be helpful.

We have some example snippets at https://lancedb.github.io/lance/integrations/pytorch.html that show using ShardedFragmentSampler for distributed training. Maybe you can adapt that into something using accelerate?
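An untested sketch of what that adaptation might look like, assuming the sampler API shown in the linked docs (ShardedFragmentSampler taking rank and world_size) and reusing uri and to_tensor_fn from the snippet above; the filter is omitted here for simplicity:

import torch.utils.data
from accelerate import Accelerator
from lance.sampler import ShardedFragmentSampler
from lance.torch.data import LanceDataset

accelerator = Accelerator()

# Shard at the Lance fragment level: each process reads a disjoint subset of
# fragments, so accelerate does not need to split the iterable stream itself.
sampler = ShardedFragmentSampler(
    rank=accelerator.process_index,
    world_size=accelerator.num_processes,
)

dataset = LanceDataset(
    uri,  # same Lance dataset URI as in the original snippet
    columns=["image", "caption"],
    batch_size=1,
    to_tensor_fn=to_tensor_fn,
    sampler=sampler,
)

# Each process already sees only its own shard, so the DataLoader is used
# directly rather than being wrapped with accelerator.prepare().
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)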
