I can't find a way to make LanceDataset work with the accelerate library when running training in a distributed setting. Here is the code snippet:
from io import BytesIO

import pyarrow as pa
import torch.utils.data
import torchvision.transforms as transforms
from PIL import Image
from lance.torch.data import LanceDataset
from accelerate import Accelerator
from accelerate.utils.dataclasses import DataLoaderConfiguration


def to_tensor_fn(batch: pa.RecordBatch, hf_converter=None):
    image = batch["image"][0].as_py()
    image = transforms.ToTensor()(Image.open(BytesIO(image)).resize((512, 512)).convert("RGB"))
    return {'image': image}


uri = ""
lance_filter = "aesthetic_score > 4.5"
dataset = LanceDataset(
    uri,
    columns=["image", "caption"],
    filter=lance_filter,
    batch_size=1,
    to_tensor_fn=to_tensor_fn,
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)

step = 0
for batch in dataloader:
    print(torch.sum(batch['image'][0]))
    step += 1
print("Total number of batches: ", step)
# prints Total number of batches: 157
So this works fine. However, when you use accelerate and wrap the dataloader (run with accelerate launch script.py on more than 1 process):
accelerator = Accelerator()
dataloader = accelerator.prepare(dataloader)

step = 0
for batch in dataloader:
    print(torch.sum(batch['image'][0]))
    step += 1
print("Total number of batches: ", step)
# Running in a distributed setting with 2 workers prints "Total number of batches: 1" x2
I would expect 2 workers with 78 batches each. I am not sure whether this problem is related to torch.utils.data.IterableDataset vs. accelerate, but it would be really great if you could post an example of using Lance with accelerate for distributed training.
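For comparison, here is a minimal sketch of how the expected per-rank split could be produced by hand on the raw (unprepared) dataloader, assuming each process just drops the batches that do not belong to its rank. accelerator.process_index and accelerator.num_processes are standard accelerate attributes; the round-robin skip itself is only an illustrative workaround, not something from the Lance docs:

accelerator = Accelerator()

# Iterate the raw dataloader and keep only every num_processes-th batch,
# offset by this process's rank.
step = 0
for i, batch in enumerate(dataloader):
    if i % accelerator.num_processes != accelerator.process_index:
        continue
    step += 1
print("Rank", accelerator.process_index, "saw", step, "batches")
# With 2 processes and 157 total batches this gives roughly 78/79 per rank,
# although every process still scans the full dataset.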
Setting dispatch_batches=False results in 2 workers where the first worker gets only 1 batch and the second worker gets 78 batches.
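For completeness, this is roughly how dispatch_batches=False is passed, assuming a recent accelerate release where DataLoaderConfiguration (already imported above) is handed to the Accelerator:

accelerator = Accelerator(
    dataloader_config=DataLoaderConfiguration(dispatch_batches=False)
)
# With dispatch_batches=False every process iterates its own copy of the
# dataloader, instead of rank 0 fetching batches and dispatching slices
# to the other processes.
dataloader = accelerator.prepare(dataloader)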