diff --git a/docs/examples/use_cases/pytorch/efficientnet/image_classification/dataloaders.py b/docs/examples/use_cases/pytorch/efficientnet/image_classification/dataloaders.py index f0e19aa4064..064d441e048 100644 --- a/docs/examples/use_cases/pytorch/efficientnet/image_classification/dataloaders.py +++ b/docs/examples/use_cases/pytorch/efficientnet/image_classification/dataloaders.py @@ -51,9 +51,14 @@ import torchvision.datasets as datasets import torchvision.transforms as transforms + def load_jpeg_from_file(path, cuda=True): img_transforms = transforms.Compose( - [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()] + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + ] ) img = img_transforms(Image.open(path)) @@ -74,6 +79,7 @@ def load_jpeg_from_file(path, cuda=True): return input + class DALIWrapper(object): def gen_wrapper(dalipipeline, num_classes, one_hot, memory_format): @@ -90,8 +96,11 @@ def gen_wrapper(dalipipeline, num_classes, one_hot, memory_format): def nhwc_to_nchw(t): return t[0], t[3], t[1], t[2] - input = torch.as_strided(data[0]["data"], size=nhwc_to_nchw(shape), - stride=nhwc_to_nchw(stride)) + input = torch.as_strided( + data[0]["data"], + size=nhwc_to_nchw(shape), + stride=nhwc_to_nchw(stride), + ) else: input = data[0]["data"].contiguous(memory_format=memory_format) target = torch.reshape(data[0]["label"], [-1]).cuda().long() @@ -108,7 +117,10 @@ def __init__(self, dalipipeline, num_classes, one_hot, memory_format): def __iter__(self): return DALIWrapper.gen_wrapper( - self.dalipipeline, self.num_classes, self.one_hot, self.memory_format + self.dalipipeline, + self.num_classes, + self.one_hot, + self.memory_format, ) @@ -145,16 +157,23 @@ def gdtl( traindir = os.path.join(data_path, "train") pipeline_kwargs = { - "batch_size" : batch_size, - "num_threads" : workers, - "device_id" : rank % torch.cuda.device_count(), + "batch_size": batch_size, + "num_threads": workers, + "device_id": rank % torch.cuda.device_count(), "seed": 12 + rank % torch.cuda.device_count(), } - pipe = training_pipe(data_dir=traindir, interpolation=interpolation, image_size=image_size, - output_layout=output_layout, automatic_augmentation=augmentation, - dali_device=dali_device, rank=rank, world_size=world_size, - **pipeline_kwargs) + pipe = training_pipe( + data_dir=traindir, + interpolation=interpolation, + image_size=image_size, + output_layout=output_layout, + automatic_augmentation=augmentation, + dali_device=dali_device, + rank=rank, + world_size=world_size, + **pipeline_kwargs, + ) pipe.build() train_loader = DALIClassificationIterator( @@ -201,15 +220,20 @@ def gdvl( valdir = os.path.join(data_path, "val") pipeline_kwargs = { - "batch_size" : batch_size, - "num_threads" : workers, - "device_id" : rank % torch.cuda.device_count(), + "batch_size": batch_size, + "num_threads": workers, + "device_id": rank % torch.cuda.device_count(), "seed": 12 + rank % torch.cuda.device_count(), } - pipe = validation_pipe(data_dir=valdir, interpolation=interpolation, - image_size=image_size + crop_padding, image_crop=image_size, - output_layout=output_layout, **pipeline_kwargs) + pipe = validation_pipe( + data_dir=valdir, + interpolation=interpolation, + image_size=image_size + crop_padding, + image_crop=image_size, + output_layout=output_layout, + **pipeline_kwargs, + ) pipe.build() val_loader = DALIClassificationIterator( @@ -224,11 +248,16 @@ def gdvl( return gdvl -def fast_collate(memory_format, batch): +def fast_collate(memory_format, typical_loader, batch): imgs = [img[0] for img in batch] targets = torch.tensor([target[1] for target in batch], dtype=torch.int64) - w = imgs[0].size[0] - h = imgs[0].size[1] + if typical_loader: + w = imgs[0].size()[1] + h = imgs[0].size()[2] + else: + w = imgs[0].size[0] + h = imgs[0].size[1] + tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8).contiguous( memory_format=memory_format ) @@ -236,7 +265,8 @@ def fast_collate(memory_format, batch): nump_array = np.asarray(img, dtype=np.uint8) if nump_array.ndim < 3: nump_array = np.expand_dims(nump_array, axis=-1) - nump_array = np.rollaxis(nump_array, 2) + if typical_loader is False: + nump_array = np.rollaxis(nump_array, 2) tensor[i] += torch.from_numpy(nump_array.copy()) @@ -252,57 +282,70 @@ def expand(num_classes, dtype, tensor): class PrefetchedWrapper(object): - def prefetched_loader(loader, num_classes, one_hot): - mean = ( - torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]) - .cuda() - .view(1, 3, 1, 1) - ) - std = ( - torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]) - .cuda() - .view(1, 3, 1, 1) - ) - - stream = torch.cuda.Stream() - first = True - - for next_input, next_target in loader: - with torch.cuda.stream(stream): - next_input = next_input.cuda(non_blocking=True) - next_target = next_target.cuda(non_blocking=True) - next_input = next_input.float() - if one_hot: - next_target = expand(num_classes, torch.float, next_target) - - next_input = next_input.sub_(mean).div_(std) - - if not first: - yield input, target - else: - first = False - - torch.cuda.current_stream().wait_stream(stream) - input = next_input - target = next_target + def prefetched_loader(loader, num_classes, one_hot, typical_loader ): + if typical_loader: + stream = torch.cuda.Stream() + for next_input, next_target in loader: + with torch.cuda.stream(stream): + next_input = next_input.to(device="cuda") + next_target = next_target.to(device="cuda") + next_input = next_input.float() + if one_hot: + next_target = expand(num_classes, torch.float, next_target) + yield next_input, next_target + else: + mean = ( + torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]) + .cuda() + .view(1, 3, 1, 1) + ) + std = ( + torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]) + .cuda() + .view(1, 3, 1, 1) + ) + + stream = torch.cuda.Stream() + first = True + + for next_input, next_target in loader: + with torch.cuda.stream(stream): + next_input = next_input.cuda(non_blocking=True) + next_target = next_target.cuda(non_blocking=True) + next_input = next_input.float() + if one_hot: + next_target = expand(num_classes, torch.float, next_target) + + # next_input = next_input.sub_(mean).div_(std) + + if not first: + yield input, target + else: + first = False + + torch.cuda.current_stream().wait_stream(stream) + input = next_input + target = next_target - yield input, target + yield input, target - def __init__(self, dataloader, start_epoch, num_classes, one_hot): + def __init__(self, dataloader, start_epoch, num_classes, one_hot, typical_loader): self.dataloader = dataloader self.epoch = start_epoch self.one_hot = one_hot self.num_classes = num_classes + self.typical_loader = typical_loader def __iter__(self): if self.dataloader.sampler is not None and isinstance( - self.dataloader.sampler, torch.utils.data.distributed.DistributedSampler + self.dataloader.sampler, + torch.utils.data.distributed.DistributedSampler, ): self.dataloader.sampler.set_epoch(self.epoch) self.epoch += 1 return PrefetchedWrapper.prefetched_loader( - self.dataloader, self.num_classes, self.one_hot + self.dataloader, self.num_classes, self.one_hot, self.typical_loader ) def __len__(self): @@ -322,6 +365,7 @@ def get_pytorch_train_loader( _worker_init_fn=None, prefetch_factor=2, memory_format=torch.contiguous_format, + typical_loader=False, ): interpolation = {"bicubic": Image.BICUBIC, "bilinear": Image.BILINEAR}[ interpolation @@ -336,9 +380,19 @@ def get_pytorch_train_loader( elif augmentation == "autoaugment": transforms_list.append(AutoaugmentImageNetPolicy()) else: - raise NotImplementedError(f"Automatic augmentation: '{augmentation}' is not supported" - " for PyTorch data loader.") - train_dataset = datasets.ImageFolder(traindir, transforms.Compose(transforms_list)) + raise NotImplementedError( + f"Automatic augmentation: '{augmentation}' is not supported" + " for PyTorch data loader." + ) + + if typical_loader: + transforms_list.append(transforms.ToTensor()) + transforms_list.append( + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ) + train_dataset = datasets.ImageFolder( + traindir, transforms.Compose(transforms_list) + ) if torch.distributed.is_initialized(): train_sampler = torch.utils.data.distributed.DistributedSampler( @@ -355,14 +409,14 @@ def get_pytorch_train_loader( num_workers=workers, worker_init_fn=_worker_init_fn, pin_memory=True, - collate_fn=partial(fast_collate, memory_format), + collate_fn=partial(fast_collate, memory_format, typical_loader), drop_last=True, persistent_workers=True, prefetch_factor=prefetch_factor, ) return ( - PrefetchedWrapper(train_loader, start_epoch, num_classes, one_hot), + PrefetchedWrapper(train_loader, start_epoch, num_classes, one_hot, typical_loader), len(train_loader), ) @@ -379,21 +433,26 @@ def get_pytorch_val_loader( crop_padding=32, memory_format=torch.contiguous_format, prefetch_factor=2, + typical_loader=False, ): interpolation = {"bicubic": Image.BICUBIC, "bilinear": Image.BILINEAR}[ interpolation ] valdir = os.path.join(data_path, "val") + transforms_list = [ + transforms.Resize( + image_size + crop_padding, interpolation=interpolation + ), + transforms.CenterCrop(image_size), + ] + if typical_loader: + transforms_list.append(transforms.ToTensor()) + transforms_list.append( + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ) val_dataset = datasets.ImageFolder( valdir, - transforms.Compose( - [ - transforms.Resize( - image_size + crop_padding, interpolation=interpolation - ), - transforms.CenterCrop(image_size), - ] - ), + transforms.Compose(transforms_list), ) if torch.distributed.is_initialized(): @@ -411,13 +470,15 @@ def get_pytorch_val_loader( num_workers=workers, worker_init_fn=_worker_init_fn, pin_memory=True, - collate_fn=partial(fast_collate, memory_format), + collate_fn=partial(fast_collate, memory_format, typical_loader), drop_last=False, persistent_workers=True, prefetch_factor=prefetch_factor, ) - return PrefetchedWrapper(val_loader, 0, num_classes, one_hot), len(val_loader) + return PrefetchedWrapper(val_loader, 0, num_classes, one_hot, typical_loader), len( + val_loader + ) class SynteticDataLoader(object): diff --git a/docs/examples/use_cases/pytorch/efficientnet/main.py b/docs/examples/use_cases/pytorch/efficientnet/main.py index d8d373ae53e..e7677b38388 100644 --- a/docs/examples/use_cases/pytorch/efficientnet/main.py +++ b/docs/examples/use_cases/pytorch/efficientnet/main.py @@ -78,6 +78,11 @@ def add_parser_arguments(parser, skip_arch=False): + " | ".join(DATA_BACKEND_CHOICES) + " (default: dali)", ) + parser.add_argument( + "--typical_loader", + action="store_true", + help="Skip advanced PyTorch data loader optimizations.", + ) parser.add_argument( "--interpolation", metavar="INTERPOLATION", @@ -510,6 +515,7 @@ def _worker_init_fn(id): _worker_init_fn=_worker_init_fn, memory_format=memory_format, prefetch_factor=args.prefetch, + typical_loader=args.typical_loader, ) if args.mixup != 0.0: train_loader = MixUpWrapper(args.mixup, train_loader) @@ -525,6 +531,7 @@ def _worker_init_fn(id): _worker_init_fn=_worker_init_fn, memory_format=memory_format, prefetch_factor=args.prefetch, + typical_loader=args.typical_loader, ) if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: diff --git a/qa/TL3_EfficientNet_benchmark/test_pytorch.sh b/qa/TL3_EfficientNet_benchmark/test_pytorch.sh index 609f9a7a834..abd7cbf4e54 100644 --- a/qa/TL3_EfficientNet_benchmark/test_pytorch.sh +++ b/qa/TL3_EfficientNet_benchmark/test_pytorch.sh @@ -56,22 +56,28 @@ export PATH_TO_IMAGENET=/imagenet export RESULT_WORKSPACE=./ # synthetic benchmark -python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 128 --epochs 1 --prof 1000 --no-checkpoints --training-only --data-backend synthetic --workspace $RESULT_WORKSPACE --report-file bench_report_synthetic.json $PATH_TO_IMAGENET +python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 512 --epochs 1 --prof 1000 --no-checkpoints --training-only --data-backend synthetic --workspace $RESULT_WORKSPACE --report-file bench_report_synthetic.json $PATH_TO_IMAGENET # DALI without automatic augmentations -python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 128 --epochs 3 --no-checkpoints --training-only --data-backend dali --automatic-augmentation disabled --workspace $RESULT_WORKSPACE --report-file bench_report_dali.json $PATH_TO_IMAGENET +python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 512 --workers 13 --epochs 3 --no-checkpoints --training-only --data-backend dali --automatic-augmentation disabled --workspace $RESULT_WORKSPACE --report-file bench_report_dali.json $PATH_TO_IMAGENET # DALI with AutoAugment -python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 128 --epochs 3 --no-checkpoints --training-only --data-backend dali --automatic-augmentation autoaugment --workspace $RESULT_WORKSPACE --report-file bench_report_dali_aa.json $PATH_TO_IMAGENET +python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 512 --workers 13 --epochs 3 --no-checkpoints --training-only --data-backend dali --automatic-augmentation autoaugment --workspace $RESULT_WORKSPACE --report-file bench_report_dali_aa.json $PATH_TO_IMAGENET # DALI with TrivialAugment -python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 128 --epochs 3 --no-checkpoints --training-only --data-backend dali --automatic-augmentation trivialaugment --workspace $RESULT_WORKSPACE --report-file bench_report_dali_ta.json $PATH_TO_IMAGENET +python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 512 --workers 13 --epochs 3 --no-checkpoints --training-only --data-backend dali --automatic-augmentation trivialaugment --workspace $RESULT_WORKSPACE --report-file bench_report_dali_ta.json $PATH_TO_IMAGENET # PyTorch without automatic augmentations -python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 128 --epochs 3 --no-checkpoints --training-only --data-backend pytorch --automatic-augmentation disabled --workspace $RESULT_WORKSPACE --report-file bench_report_pytorch.json $PATH_TO_IMAGENET +python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 512 --workers 10 --typical_loader --epochs 3 --no-checkpoints --training-only --data-backend pytorch --automatic-augmentation disabled --workspace $RESULT_WORKSPACE --report-file bench_report_pytorch.json $PATH_TO_IMAGENET # PyTorch with AutoAugment: -python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 128 --epochs 3 --no-checkpoints --training-only --data-backend pytorch --automatic-augmentation autoaugment --workspace $RESULT_WORKSPACE --report-file bench_report_pytorch_aa.json $PATH_TO_IMAGENET +python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 512 --workers 10 --typical_loader --epochs 3 --no-checkpoints --training-only --data-backend pytorch --automatic-augmentation autoaugment --workspace $RESULT_WORKSPACE --report-file bench_report_pytorch_aa.json $PATH_TO_IMAGENET + +# Optimized PyTorch without automatic augmentations +python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 512 --workers 10 --epochs 3 --no-checkpoints --training-only --data-backend pytorch --automatic-augmentation disabled --workspace $RESULT_WORKSPACE --report-file bench_report_optimized_pytorch.json $PATH_TO_IMAGENET + +# Optimized PyTorch with AutoAugment: +python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 512 --workers 10 --epochs 3 --no-checkpoints --training-only --data-backend pytorch --automatic-augmentation autoaugment --workspace $RESULT_WORKSPACE --report-file bench_report_optimized_pytorch_aa.json $PATH_TO_IMAGENET # The line below finds the lines with `train.total_ips`, takes the last one (with the result we @@ -107,6 +113,9 @@ CHECK_PERF_THRESHOLD "bench_report_dali_aa.json" $DALI_AA_THRESHOLD CHECK_PERF_THRESHOLD "bench_report_dali_ta.json" $DALI_TA_THRESHOLD CHECK_PERF_THRESHOLD "bench_report_pytorch.json" $PYTORCH_NONE_THRESHOLD CHECK_PERF_THRESHOLD "bench_report_pytorch_aa.json" $PYTORCH_AA_THRESHOLD +CHECK_PERF_THRESHOLD "bench_report_optimized_pytorch.json" $PYTORCH_NONE_THRESHOLD +CHECK_PERF_THRESHOLD "bench_report_optimized_pytorch_aa.json" $PYTORCH_AA_THRESHOLD + # In the initial training we get significant increase in accuracy on the first few epochs,