Skip to content

Commit

Permalink
Add a typical data loading pipeline path for EfficientNet
Browse files Browse the repository at this point in the history
- adds an option to run the EfficientNet network with a typical
  data loading pipeline, without the advanced optimizations that
  most users won't implement

Signed-off-by: Janusz Lisiecki <[email protected]>
  • Loading branch information
JanuszL committed Dec 20, 2024
1 parent 8535a3b commit 4f7e994
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,14 @@
import torchvision.datasets as datasets
import torchvision.transforms as transforms


def load_jpeg_from_file(path, cuda=True):
img_transforms = transforms.Compose(
[transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
]
)

img = img_transforms(Image.open(path))
Expand All @@ -74,6 +79,7 @@ def load_jpeg_from_file(path, cuda=True):

return input


class DALIWrapper(object):

def gen_wrapper(dalipipeline, num_classes, one_hot, memory_format):
Expand All @@ -90,8 +96,11 @@ def gen_wrapper(dalipipeline, num_classes, one_hot, memory_format):
def nhwc_to_nchw(t):
return t[0], t[3], t[1], t[2]

input = torch.as_strided(data[0]["data"], size=nhwc_to_nchw(shape),
stride=nhwc_to_nchw(stride))
input = torch.as_strided(
data[0]["data"],
size=nhwc_to_nchw(shape),
stride=nhwc_to_nchw(stride),
)
else:
input = data[0]["data"].contiguous(memory_format=memory_format)
target = torch.reshape(data[0]["label"], [-1]).cuda().long()
Expand All @@ -108,7 +117,10 @@ def __init__(self, dalipipeline, num_classes, one_hot, memory_format):

def __iter__(self):
return DALIWrapper.gen_wrapper(
self.dalipipeline, self.num_classes, self.one_hot, self.memory_format
self.dalipipeline,
self.num_classes,
self.one_hot,
self.memory_format,
)


Expand Down Expand Up @@ -145,16 +157,23 @@ def gdtl(
traindir = os.path.join(data_path, "train")

pipeline_kwargs = {
"batch_size" : batch_size,
"num_threads" : workers,
"device_id" : rank % torch.cuda.device_count(),
"batch_size": batch_size,
"num_threads": workers,
"device_id": rank % torch.cuda.device_count(),
"seed": 12 + rank % torch.cuda.device_count(),
}

pipe = training_pipe(data_dir=traindir, interpolation=interpolation, image_size=image_size,
output_layout=output_layout, automatic_augmentation=augmentation,
dali_device=dali_device, rank=rank, world_size=world_size,
**pipeline_kwargs)
pipe = training_pipe(
data_dir=traindir,
interpolation=interpolation,
image_size=image_size,
output_layout=output_layout,
automatic_augmentation=augmentation,
dali_device=dali_device,
rank=rank,
world_size=world_size,
**pipeline_kwargs,
)

pipe.build()
train_loader = DALIClassificationIterator(
Expand Down Expand Up @@ -201,15 +220,20 @@ def gdvl(
valdir = os.path.join(data_path, "val")

pipeline_kwargs = {
"batch_size" : batch_size,
"num_threads" : workers,
"device_id" : rank % torch.cuda.device_count(),
"batch_size": batch_size,
"num_threads": workers,
"device_id": rank % torch.cuda.device_count(),
"seed": 12 + rank % torch.cuda.device_count(),
}

pipe = validation_pipe(data_dir=valdir, interpolation=interpolation,
image_size=image_size + crop_padding, image_crop=image_size,
output_layout=output_layout, **pipeline_kwargs)
pipe = validation_pipe(
data_dir=valdir,
interpolation=interpolation,
image_size=image_size + crop_padding,
image_crop=image_size,
output_layout=output_layout,
**pipeline_kwargs,
)

pipe.build()
val_loader = DALIClassificationIterator(
Expand All @@ -224,19 +248,25 @@ def gdvl(
return gdvl


def fast_collate(memory_format, batch):
def fast_collate(memory_format, typical_loader, batch):
imgs = [img[0] for img in batch]
targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
w = imgs[0].size[0]
h = imgs[0].size[1]
if typical_loader:
w = imgs[0].size()[1]
h = imgs[0].size()[2]
else:
w = imgs[0].size[0]
h = imgs[0].size[1]

tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8).contiguous(
memory_format=memory_format
)
for i, img in enumerate(imgs):
nump_array = np.asarray(img, dtype=np.uint8)
if nump_array.ndim < 3:
nump_array = np.expand_dims(nump_array, axis=-1)
nump_array = np.rollaxis(nump_array, 2)
if typical_loader is False:
nump_array = np.rollaxis(nump_array, 2)

tensor[i] += torch.from_numpy(nump_array.copy())

Expand All @@ -252,57 +282,70 @@ def expand(num_classes, dtype, tensor):


class PrefetchedWrapper(object):
def prefetched_loader(loader, num_classes, one_hot):
mean = (
torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255])
.cuda()
.view(1, 3, 1, 1)
)
std = (
torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255])
.cuda()
.view(1, 3, 1, 1)
)

stream = torch.cuda.Stream()
first = True

for next_input, next_target in loader:
with torch.cuda.stream(stream):
next_input = next_input.cuda(non_blocking=True)
next_target = next_target.cuda(non_blocking=True)
next_input = next_input.float()
if one_hot:
next_target = expand(num_classes, torch.float, next_target)

next_input = next_input.sub_(mean).div_(std)

if not first:
yield input, target
else:
first = False

torch.cuda.current_stream().wait_stream(stream)
input = next_input
target = next_target
def prefetched_loader(loader, num_classes, one_hot, typical_loader ):
if typical_loader:
stream = torch.cuda.Stream()
for next_input, next_target in loader:
with torch.cuda.stream(stream):
next_input = next_input.to(device="cuda")
next_target = next_target.to(device="cuda")
next_input = next_input.float()
if one_hot:
next_target = expand(num_classes, torch.float, next_target)
yield next_input, next_target

Check notice

Code scanning / CodeQL

Unused local variable Note documentation

Variable mean is not used.
else:
mean = (
torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255])
.cuda()
.view(1, 3, 1, 1)

Check notice

Code scanning / CodeQL

Unused local variable Note documentation

Variable std is not used.
)
std = (
torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255])
.cuda()
.view(1, 3, 1, 1)
)

stream = torch.cuda.Stream()
first = True

for next_input, next_target in loader:
with torch.cuda.stream(stream):
next_input = next_input.cuda(non_blocking=True)
next_target = next_target.cuda(non_blocking=True)
next_input = next_input.float()
if one_hot:
next_target = expand(num_classes, torch.float, next_target)

# next_input = next_input.sub_(mean).div_(std)

if not first:
yield input, target
else:
first = False

torch.cuda.current_stream().wait_stream(stream)
input = next_input
target = next_target

yield input, target
yield input, target

def __init__(self, dataloader, start_epoch, num_classes, one_hot):
def __init__(self, dataloader, start_epoch, num_classes, one_hot, typical_loader):
self.dataloader = dataloader
self.epoch = start_epoch
self.one_hot = one_hot
self.num_classes = num_classes
self.typical_loader = typical_loader

def __iter__(self):
if self.dataloader.sampler is not None and isinstance(
self.dataloader.sampler, torch.utils.data.distributed.DistributedSampler
self.dataloader.sampler,
torch.utils.data.distributed.DistributedSampler,
):

self.dataloader.sampler.set_epoch(self.epoch)
self.epoch += 1
return PrefetchedWrapper.prefetched_loader(
self.dataloader, self.num_classes, self.one_hot
self.dataloader, self.num_classes, self.one_hot, self.typical_loader
)

def __len__(self):
Expand All @@ -322,6 +365,7 @@ def get_pytorch_train_loader(
_worker_init_fn=None,
prefetch_factor=2,
memory_format=torch.contiguous_format,
typical_loader=False,
):
interpolation = {"bicubic": Image.BICUBIC, "bilinear": Image.BILINEAR}[
interpolation
Expand All @@ -336,9 +380,19 @@ def get_pytorch_train_loader(
elif augmentation == "autoaugment":
transforms_list.append(AutoaugmentImageNetPolicy())
else:
raise NotImplementedError(f"Automatic augmentation: '{augmentation}' is not supported"
" for PyTorch data loader.")
train_dataset = datasets.ImageFolder(traindir, transforms.Compose(transforms_list))
raise NotImplementedError(
f"Automatic augmentation: '{augmentation}' is not supported"
" for PyTorch data loader."
)

if typical_loader:
transforms_list.append(transforms.ToTensor())
transforms_list.append(
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
)
train_dataset = datasets.ImageFolder(
traindir, transforms.Compose(transforms_list)
)

if torch.distributed.is_initialized():
train_sampler = torch.utils.data.distributed.DistributedSampler(
Expand All @@ -355,14 +409,14 @@ def get_pytorch_train_loader(
num_workers=workers,
worker_init_fn=_worker_init_fn,
pin_memory=True,
collate_fn=partial(fast_collate, memory_format),
collate_fn=partial(fast_collate, memory_format, typical_loader),
drop_last=True,
persistent_workers=True,
prefetch_factor=prefetch_factor,
)

return (
PrefetchedWrapper(train_loader, start_epoch, num_classes, one_hot),
PrefetchedWrapper(train_loader, start_epoch, num_classes, one_hot, typical_loader),
len(train_loader),
)

Expand All @@ -379,21 +433,26 @@ def get_pytorch_val_loader(
crop_padding=32,
memory_format=torch.contiguous_format,
prefetch_factor=2,
typical_loader=False,
):
interpolation = {"bicubic": Image.BICUBIC, "bilinear": Image.BILINEAR}[
interpolation
]
valdir = os.path.join(data_path, "val")
transforms_list = [
transforms.Resize(
image_size + crop_padding, interpolation=interpolation
),
transforms.CenterCrop(image_size),
]
if typical_loader:
transforms_list.append(transforms.ToTensor())
transforms_list.append(
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
)
val_dataset = datasets.ImageFolder(
valdir,
transforms.Compose(
[
transforms.Resize(
image_size + crop_padding, interpolation=interpolation
),
transforms.CenterCrop(image_size),
]
),
transforms.Compose(transforms_list),
)

if torch.distributed.is_initialized():
Expand All @@ -411,13 +470,15 @@ def get_pytorch_val_loader(
num_workers=workers,
worker_init_fn=_worker_init_fn,
pin_memory=True,
collate_fn=partial(fast_collate, memory_format),
collate_fn=partial(fast_collate, memory_format, typical_loader),
drop_last=False,
persistent_workers=True,
prefetch_factor=prefetch_factor,
)

return PrefetchedWrapper(val_loader, 0, num_classes, one_hot), len(val_loader)
return PrefetchedWrapper(val_loader, 0, num_classes, one_hot, typical_loader), len(
val_loader
)


class SynteticDataLoader(object):
Expand Down
7 changes: 7 additions & 0 deletions docs/examples/use_cases/pytorch/efficientnet/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ def add_parser_arguments(parser, skip_arch=False):
+ " | ".join(DATA_BACKEND_CHOICES)
+ " (default: dali)",
)
parser.add_argument(
"--typical_loader",
action="store_true",
help="Skip advanced PyTorch data loader optimizations.",
)
parser.add_argument(
"--interpolation",
metavar="INTERPOLATION",
Expand Down Expand Up @@ -510,6 +515,7 @@ def _worker_init_fn(id):
_worker_init_fn=_worker_init_fn,
memory_format=memory_format,
prefetch_factor=args.prefetch,
typical_loader=args.typical_loader,
)
if args.mixup != 0.0:
train_loader = MixUpWrapper(args.mixup, train_loader)
Expand All @@ -525,6 +531,7 @@ def _worker_init_fn(id):
_worker_init_fn=_worker_init_fn,
memory_format=memory_format,
prefetch_factor=args.prefetch,
typical_loader=args.typical_loader,
)

if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
Expand Down
Loading

0 comments on commit 4f7e994

Please sign in to comment.