diff --git a/benchmarks/torchvision_classification/helpers.py b/benchmarks/torchvision_classification/helpers.py
new file mode 100644
index 000000000..383bc554a
--- /dev/null
+++ b/benchmarks/torchvision_classification/helpers.py
@@ -0,0 +1,150 @@
+import itertools
+import os
+import random
+from functools import partial
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+import torchvision
+from PIL import Image
+from torchdata.datapipes.iter import FileLister, IterDataPipe
+
+
+# TODO: maybe an infinite buffer can be, or already is, natively supported by torchdata?
+INFINITE_BUFFER_SIZE = 1_000_000_000
+
+IMAGENET_TRAIN_LEN = 1_281_167
+IMAGENET_TEST_LEN = 50_000
+
+
+class _LenSetter(IterDataPipe):
+    # TODO: Ideally, we wouldn't need this extra class
+    def __init__(self, dp, root):
+        self.dp = dp
+
+        if "train" in str(root):
+            self.size = IMAGENET_TRAIN_LEN
+        elif "val" in str(root):
+            self.size = IMAGENET_TEST_LEN
+        else:
+            raise ValueError(f"Expected 'train' or 'val' in root, got {root}")
+
+    def __iter__(self):
+        yield from self.dp
+
+    def __len__(self):
+        # TODO The // world_size part shouldn't be needed. See https://github.com/pytorch/data/issues/533
+        return self.size // dist.get_world_size()
+
+
+def _decode(path, root, category_to_int):
+    category = Path(path).relative_to(root).parts[0]
+
+    image = Image.open(path).convert("RGB")
+    label = category_to_int[category]
+
+    return image, label
+
+
+def _apply_transforms(img_and_label, transforms):
+    img, label = img_and_label
+    return transforms(img), label
+
+
+def make_dp(root, transforms):
+    root = Path(root).expanduser().resolve()
+    categories = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
+    category_to_int = {category: i for (i, category) in enumerate(categories)}
+
+    dp = FileLister(str(root), recursive=True, masks=["*.JPEG"])
+
+    dp = dp.shuffle(buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False).sharding_filter()
+    dp = dp.map(partial(_decode, root=root, category_to_int=category_to_int))
+    dp = dp.map(partial(_apply_transforms, transforms=transforms))
+
+    dp = _LenSetter(dp, root=root)
+    return dp
+
+
+class PreLoadedMapStyle:
+    # All the data is pre-loaded and transformed in __init__, so the DataLoader should be extremely fast.
+    # This is just to assess how fast a model could theoretically be trained if there was no data bottleneck at all.
+    def __init__(self, dir, transform, buffer_size=100):
+        dataset = torchvision.datasets.ImageFolder(dir, transform=transform)
+        self.size = len(dataset)
+        self.samples = [dataset[torch.randint(0, len(dataset), size=(1,)).item()] for _ in range(buffer_size)]
+
+    def __len__(self):
+        return self.size
+
+    def __getitem__(self, idx):
+        return self.samples[idx % len(self.samples)]
+
+
+class _PreLoadedDP(IterDataPipe):
+    # Same as above, but this is a DataPipe
+    def __init__(self, root, transforms, buffer_size=100):
+        dataset = torchvision.datasets.ImageFolder(root, transform=transforms)
+        self.size = len(dataset)
+        self.samples = [dataset[torch.randint(0, len(dataset), size=(1,)).item()] for _ in range(buffer_size)]
+        # Note: the rng might be different across DDP workers, so they'll all have different samples.
+        # But we don't care about accuracy here, so it doesn't matter.
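+        # A minimal sketch (hypothetical, not used here): to make every DDP worker
+        # draw the same buffer, the rng could be seeded explicitly, e.g.
+        #   g = torch.Generator().manual_seed(0)
+        #   idx = torch.randint(0, len(dataset), size=(1,), generator=g).item()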
+
+    def __iter__(self):
+        for idx in range(self.size):
+            yield self.samples[idx % len(self.samples)]
+
+
+def make_pre_loaded_dp(root, transforms):
+    dp = _PreLoadedDP(root=root, transforms=transforms)
+    dp = dp.shuffle(buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False).sharding_filter()
+    dp = _LenSetter(dp, root=root)
+    return dp
+
+
+class MapStyleToIterable(torch.utils.data.IterableDataset):
+    # This converts a MapStyle dataset into an iterable one.
+    # Not sure this kind of Iterable dataset is actually useful to benchmark. It
+    # was necessary when benchmarking async-io stuff, but not anymore.
+    # If anything, it shows how tricky Iterable datasets are to implement.
+    def __init__(self, dataset, shuffle):
+        self.dataset = dataset
+        self.shuffle = shuffle
+
+        self.size = len(self.dataset)
+        self.seed = 0  # has to be hard-coded for all DDP workers to have the same shuffling
+
+    def __len__(self):
+        return self.size // dist.get_world_size()
+
+    def __iter__(self):
+        worker_info = torch.utils.data.get_worker_info()
+        num_dl_workers = worker_info.num_workers
+        dl_worker_id = worker_info.id
+
+        num_ddp_workers = dist.get_world_size()
+        ddp_worker_id = dist.get_rank()
+
+        num_total_workers = num_ddp_workers * num_dl_workers
+        current_worker_id = ddp_worker_id + (num_ddp_workers * dl_worker_id)
+
+        indices = range(self.size)
+        if self.shuffle:
+            rng = random.Random(self.seed)
+            indices = rng.sample(indices, k=self.size)
+        indices = itertools.islice(indices, current_worker_id, None, num_total_workers)
+
+        samples = (self.dataset[i] for i in indices)
+        yield from samples
+
+
+# TODO: maybe only generate these when --no-transforms is passed?
+_RANDOM_IMAGE_TENSORS = [torch.randn(3, 224, 224) for _ in range(300)]
+
+
+def no_transforms(_):
+    # see --no-transforms doc
+    return random.choice(_RANDOM_IMAGE_TENSORS)
diff --git a/benchmarks/torchvision_classification/presets.py b/benchmarks/torchvision_classification/presets.py
new file mode 100644
index 000000000..00c7bfa8b
--- /dev/null
+++ b/benchmarks/torchvision_classification/presets.py
@@ -0,0 +1,53 @@
+import torch
+from torchvision.transforms import transforms
+
+
+class ClassificationPresetTrain:
+    def __init__(
+        self,
+        *,
+        crop_size,
+        mean=(0.485, 0.456, 0.406),
+        std=(0.229, 0.224, 0.225),
+        hflip_prob=0.5,
+    ):
+        trans = [transforms.RandomResizedCrop(crop_size)]
+        if hflip_prob > 0:
+            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
+
+        trans.extend(
+            [
+                transforms.PILToTensor(),
+                transforms.ConvertImageDtype(torch.float),
+                transforms.Normalize(mean=mean, std=std),
+            ]
+        )
+
+        self.transforms = transforms.Compose(trans)
+
+    def __call__(self, img):
+        return self.transforms(img)
+
+
+class ClassificationPresetEval:
+    def __init__(
+        self,
+        *,
+        crop_size,
+        resize_size=256,
+        mean=(0.485, 0.456, 0.406),
+        std=(0.229, 0.224, 0.225),
+    ):
+        self.transforms = transforms.Compose(
+            [
+                transforms.Resize(resize_size),
+                transforms.CenterCrop(crop_size),
+                transforms.PILToTensor(),
+                transforms.ConvertImageDtype(torch.float),
+                transforms.Normalize(mean=mean, std=std),
+            ]
+        )
+
+    def __call__(self, img):
+        return self.transforms(img)
diff --git a/benchmarks/torchvision_classification/train.py b/benchmarks/torchvision_classification/train.py
new file mode 100644
index 000000000..ff12e902f
--- /dev/null
+++ b/benchmarks/torchvision_classification/train.py
@@ -0,0 +1,346 @@
+import datetime
+import os
+import time
+import warnings
+
+import helpers
+import presets
+import torch
+import torch.utils.data
+import torchvision
+import utils
+from torch import nn
+from torchdata.dataloader2 import adapter, DataLoader2, MultiProcessingReadingService
+
+
+def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args):
+    model.train()
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
+    metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
+
+    header = f"Epoch: [{epoch}]"
+    for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
+        if args.data_loading_only:
+            continue
+
+        start_time = time.time()
+        image, target = image.to(device), target.to(device)
+
+        output = model(image)
+        loss = criterion(output, target)
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
+        batch_size = image.shape[0]
+        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
+        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
+        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
+        metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))
+
+
+def evaluate(model, criterion, data_loader, device, args, print_freq=100, log_suffix=""):
+    model.eval()
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    header = f"Test: {log_suffix}"
+
+    metric_logger.add_meter("acc1", utils.SmoothedValue())
+    metric_logger.add_meter("acc5", utils.SmoothedValue())
+
+    num_processed_samples = 0
+    with torch.inference_mode():
+        for image, target in metric_logger.log_every(data_loader, print_freq, header):
+            if args.data_loading_only:
+                continue
+            image, target = image.to(device), target.to(device)
+            output = model(image)
+            loss = criterion(output, target)
+
+            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
+            batch_size = image.shape[0]
+            metric_logger.update(loss=loss.item())
+            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
+            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
+            num_processed_samples += batch_size
+    # gather the stats from all processes
+
+    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
+    if (
+        hasattr(data_loader, "dataset")
+        and hasattr(data_loader.dataset, "__len__")
+        and len(data_loader.dataset) != num_processed_samples
+        and torch.distributed.get_rank() == 0
+    ):
+        warnings.warn(
+            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
+            "samples were used for the validation, which might bias the results. "
+            "Try adjusting the batch size and / or the world size. "
+            "Setting the world size to 1 is always a safe bet."
+        )
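+        # Illustrative arithmetic: with IMAGENET_TEST_LEN = 50_000 and a world
+        # size of 3, DistributedSampler pads each rank to 16_667 samples by
+        # repeating some, so 50_001 samples are counted in aggregate.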
+
+    metric_logger.synchronize_between_processes()
+
+    print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
+    return metric_logger.acc1.global_avg
+
+
+def create_data_loaders(args):
+    print(f"file-system = {args.fs}")
+
+    if args.fs == "fsx":
+        dataset_dir = "/datasets01"
+    elif args.fs == "fsx_isolated":
+        dataset_dir = "/fsx_isolated"
+    elif args.fs == "ontap":
+        dataset_dir = "/datasets01_ontap"
+    elif args.fs == "ontap_isolated":
+        dataset_dir = "/ontap_isolated"
+    else:
+        raise ValueError(f"Invalid args.fs, got {args.fs}")
+
+    dataset_dir += "/imagenet_full_size/061417/"
+    train_dir = os.path.join(dataset_dir, "train")
+    val_dir = os.path.join(dataset_dir, "val")
+
+    val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
+
+    if args.no_transforms:
+        train_preset = val_preset = helpers.no_transforms
+    else:
+        train_preset = presets.ClassificationPresetTrain(crop_size=train_crop_size)
+        val_preset = presets.ClassificationPresetEval(crop_size=val_crop_size, resize_size=val_resize_size)
+
+    if args.ds_type == "dp":
+        builder = helpers.make_pre_loaded_dp if args.preload_ds else helpers.make_dp
+        train_dataset = builder(train_dir, transforms=train_preset)
+        val_dataset = builder(val_dir, transforms=val_preset)
+
+        train_sampler = val_sampler = None
+        train_shuffle = True
+
+    elif args.ds_type == "iterable":
+        train_dataset = torchvision.datasets.ImageFolder(train_dir, transform=train_preset)
+        train_dataset = helpers.MapStyleToIterable(train_dataset, shuffle=True)
+
+        val_dataset = torchvision.datasets.ImageFolder(val_dir, transform=val_preset)
+        val_dataset = helpers.MapStyleToIterable(val_dataset, shuffle=False)
+
+        train_sampler = val_sampler = None
+        train_shuffle = None  # shuffling is done within MapStyleToIterable, not by the DataLoader
+
+    elif args.ds_type == "mapstyle":
+        builder = helpers.PreLoadedMapStyle if args.preload_ds else torchvision.datasets.ImageFolder
+        train_dataset = builder(train_dir, transform=train_preset)
+        val_dataset = builder(val_dir, transform=val_preset)
+
+        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)
+        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False)
+        train_shuffle = None  # shuffling is done by the DistributedSampler, not by the DataLoader
+
+    else:
+        raise ValueError(f"Invalid value for args.ds_type ({args.ds_type})")
+
+    data_loader_arg = args.data_loader.lower()
+    if data_loader_arg == "v1":
+        train_data_loader = torch.utils.data.DataLoader(
+            train_dataset,
+            batch_size=args.batch_size,
+            shuffle=train_shuffle,
+            sampler=train_sampler,
+            num_workers=args.workers,
+            pin_memory=True,
+            drop_last=True,
+        )
+        val_data_loader = torch.utils.data.DataLoader(
+            val_dataset,
+            batch_size=args.batch_size,
+            sampler=val_sampler,
+            num_workers=args.workers,
+            pin_memory=True,
+        )
+    elif data_loader_arg == "v2":
+        if args.ds_type != "dp":
+            raise ValueError("DataLoader2 only works with datapipes.")
+
+        # Note: we are batching and collating here *after the transforms*, which is consistent with DLV1.
+        # But maybe it would be more efficient to do that before, so that the transforms could work on batches?
+
+        train_dataset = train_dataset.batch(args.batch_size, drop_last=True).collate()
+        train_data_loader = DataLoader2(
+            train_dataset,
+            datapipe_adapter_fn=adapter.Shuffle(),
+            reading_service=MultiProcessingReadingService(num_workers=args.workers),
+        )
+
+        val_dataset = val_dataset.batch(args.batch_size, drop_last=True).collate()  # TODO: Do we need drop_last here?
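+        # A hypothetical alternative to the note above (not benchmarked here):
+        # batch first so the transforms could operate on whole batches, e.g.
+        #   dp = dp.batch(args.batch_size).map(batched_transforms).collate()
+        # where batched_transforms is an illustrative helper that accepts a list of samples.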
+        val_data_loader = DataLoader2(
+            val_dataset,
+            reading_service=MultiProcessingReadingService(num_workers=args.workers),
+        )
+    else:
+        raise ValueError(f"Invalid data-loader param. Got {args.data_loader}")
+
+    return train_data_loader, val_data_loader, train_sampler
+
+
+def main(args):
+    if args.output_dir:
+        utils.mkdir(args.output_dir)
+
+    utils.init_distributed_mode(args)
+    print("\n".join(f"{k}: {str(v)}" for k, v in sorted(dict(vars(args)).items())))
+
+    device = torch.device(args.device)
+
+    if args.use_deterministic_algorithms:
+        torch.backends.cudnn.benchmark = False
+        torch.use_deterministic_algorithms(True)
+    else:
+        torch.backends.cudnn.benchmark = True
+
+    train_data_loader, val_data_loader, train_sampler = create_data_loaders(args)
+
+    num_classes = 1000  # hard-coded for ImageNet. TODO: infer this from the dataset instead
+
+    print("Creating model")
+    model = torchvision.models.__dict__[args.model](weights=args.weights, num_classes=num_classes)
+    model.to(device)
+
+    criterion = nn.CrossEntropyLoss()
+    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
+    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
+
+    model_without_ddp = model
+    if args.distributed:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        model_without_ddp = model.module
+
+    if args.test_only:
+        # We disable cudnn benchmarking because it can noticeably affect the accuracy
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
+        evaluate(model, criterion, val_data_loader, device=device, args=args)
+        return
+
+    print("Start training")
+    start_time = time.time()
+    for epoch in range(args.epochs):
+        if args.distributed and train_sampler is not None:
+            train_sampler.set_epoch(epoch)
+        train_one_epoch(model, criterion, optimizer, train_data_loader, device, epoch, args)
+        lr_scheduler.step()
+        evaluate(model, criterion, val_data_loader, device=device, args=args)
+
+        if args.output_dir:
+            checkpoint = {
+                "model": model_without_ddp.state_dict(),
+                "optimizer": optimizer.state_dict(),
+                "lr_scheduler": lr_scheduler.state_dict(),
+                "epoch": epoch,
+                "args": args,
+            }
+            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
+            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
+
+        if epoch == 0:
+            first_epoch_time = time.time() - start_time
+
+    total_time = time.time() - start_time
+    print(f"Training time: {datetime.timedelta(seconds=int(total_time))}")
+    print(f"Training time (w/o 1st epoch): {datetime.timedelta(seconds=int(total_time - first_epoch_time))}")
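+
+# A minimal launch sketch (illustrative values; all flags are defined in
+# get_args_parser below). init_distributed_mode() reads RANK / WORLD_SIZE /
+# LOCAL_RANK from the environment, which torchrun sets automatically:
+#
+#   torchrun --nproc_per_node=2 train.py --ds-type dp --data-loader v2 -b 32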
+
+
+def get_args_parser(add_help=True):
+    import argparse
+
+    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
+
+    parser.add_argument("--fs", default="fsx", type=str)
+    parser.add_argument("--model", default="resnet18", type=str, help="model name")
+    parser.add_argument("--device", default="cuda", type=str, help="device (use cuda or cpu; default: cuda)")
+    parser.add_argument(
+        "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
+    )
+    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
+    parser.add_argument(
+        "-j", "--workers", default=12, type=int, metavar="N", help="number of data loading workers (default: 12)"
+    )
+    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
+    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
+    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
+    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
+
+    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
+    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
+
+    parser.add_argument(
+        "--test-only",
+        dest="test_only",
+        help="Only test the model",
+        action="store_true",
+    )
+
+    # distributed training parameters
+    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
+    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
+
+    parser.add_argument(
+        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
+    )
+    parser.add_argument(
+        "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
+    )
+    parser.add_argument(
+        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
+    )
+    parser.add_argument(
+        "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
+    )
+    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
+
+    parser.add_argument(
+        "--ds-type",
+        default="mapstyle",
+        type=str,
+        help="'dp' or 'iterable' or 'mapstyle' (for regular indexable datasets)",
+    )
+
+    parser.add_argument(
+        "--preload-ds",
+        action="store_true",
+        help="whether to use a fake dataset where all images are pre-loaded in RAM and already transformed. "
+        "Mostly useful to benchmark how fast model training would be without data-loading bottlenecks. "
+        "Acc results are irrelevant because we don't cache the entire dataset, only a very small fraction of it.",
+    )
+    parser.add_argument(
+        "--data-loading-only",
+        action="store_true",
+        help="When on, we bypass the model's forward and backward passes, so mostly only the data loading happens",
+    )
+    parser.add_argument(
+        "--no-transforms",
+        action="store_true",
+        help="Whether to apply transforms to the images. No transforms means we "
+        "load and decode PIL images as usual, but we don't transform them. Instead we discard them "
+        "and the dataset will produce random tensors instead. We "
+        "need to create random tensors because without transforms, the images would still be PIL images "
+        "and they wouldn't be of the required size. "
+        "Obviously, Acc results will not be relevant.",
+    )
+
+    parser.add_argument(
+        "--data-loader",
+        default="V1",
+        type=str,
+        help="'V1' or 'V2'. V2 only works for datapipes",
+    )
+
+    return parser
+
+
+if __name__ == "__main__":
+    args = get_args_parser().parse_args()
+    main(args)
diff --git a/benchmarks/torchvision_classification/utils.py b/benchmarks/torchvision_classification/utils.py
new file mode 100644
index 000000000..b41bb8971
--- /dev/null
+++ b/benchmarks/torchvision_classification/utils.py
@@ -0,0 +1,281 @@
+import datetime
+import errno
+import os
+import time
+from collections import defaultdict, deque
+
+import torch
+import torch.distributed as dist
+
+
+class SmoothedValue:
+    """Track a series of values and provide access to smoothed values over a
+    window or the global series average.
+    """
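+
+    # A minimal usage sketch (illustrative values):
+    #   meter = SmoothedValue(window_size=10, fmt="{avg:.4f}")
+    #   meter.update(0.5)
+    #   meter.update(0.7)
+    #   str(meter)  # -> "0.6000", the mean over the current window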
+ """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + t = reduce_across_processes([self.count, self.total]) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + if not self.deque: + return 0 + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + if not self.deque: + return 0 + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + try: + return self.total / self.count + except ZeroDivisionError: + return 0 + + @property + def max(self): + if not self.deque: + return 0 + return max(self.deque) + + @property + def value(self): + if not self.deque: + return 0 + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value + ) + + +class MetricLogger: + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'") + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append(f"{name}: {str(meter)}") + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + model_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + if torch.cuda.is_available(): + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + "model: {model}", + "max mem: {memory:.0f}", + ] + ) + else: + log_msg = self.delimiter.join( + [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"] + ) + MB = 1024.0 * 1024.0 + for obj in iterable: + dtime = time.time() - end + data_time.update(dtime) + yield obj + ttime = time.time() - end + iter_time.update(ttime) + model_time.update(ttime - dtime) + if i % print_freq == 0: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + model=str(model_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, len(iterable), eta=eta_string, meters=str(self), 
+def accuracy(output, target, topk=(1,)):
+    """Computes the accuracy over the k top predictions for the specified values of k"""
+    with torch.inference_mode():
+        maxk = max(topk)
+        batch_size = target.size(0)
+        if target.ndim == 2:
+            target = target.max(dim=1)[1]
+
+        _, pred = output.topk(maxk, 1, True, True)
+        pred = pred.t()
+        correct = pred.eq(target[None])
+
+        res = []
+        for k in topk:
+            correct_k = correct[:k].flatten().sum(dtype=torch.float32)
+            res.append(correct_k * (100.0 / batch_size))
+        return res
+
+
+def mkdir(path):
+    try:
+        os.makedirs(path)
+    except OSError as e:
+        if e.errno != errno.EEXIST:
+            raise
+
+
+def setup_for_distributed(is_master):
+    """
+    This function disables printing when not in the master process
+    """
+    import builtins as __builtin__
+
+    builtin_print = __builtin__.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop("force", False)
+        if is_master or force:
+            builtin_print(*args, **kwargs)
+
+    __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
+
+
+def get_world_size():
+    if not is_dist_avail_and_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def get_rank():
+    if not is_dist_avail_and_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def is_main_process():
+    return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+    if is_main_process():
+        torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+        args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ["WORLD_SIZE"])
+        args.gpu = int(os.environ["LOCAL_RANK"])
+    elif "SLURM_PROCID" in os.environ:
+        args.rank = int(os.environ["SLURM_PROCID"])
+        args.gpu = args.rank % torch.cuda.device_count()
+    elif hasattr(args, "rank"):
+        pass
+    else:
+        print("Not using distributed mode")
+        args.distributed = False
+        return
+
+    args.distributed = True
+
+    torch.cuda.set_device(args.gpu)
+    args.dist_backend = "nccl"
+    print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
+    torch.distributed.init_process_group(
+        backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
+    )
+    torch.distributed.barrier()
+    if args.data_loader.lower() != "ffcv":
+        setup_for_distributed(args.rank == 0)
+
+
+def reduce_across_processes(val):
+    if not is_dist_avail_and_initialized():
+        # nothing to sync, but we still convert to tensor for consistency with the distributed case.
+        return torch.tensor(val)
+
+    t = torch.tensor(val, device="cuda")
+    dist.barrier()
+    dist.all_reduce(t)
+    return t
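+
+
+# Behavior sketch (illustrative): when torch.distributed is not initialized,
+# reduce_across_processes(3) simply returns tensor(3); under DDP, each rank
+# receives the all-reduced (summed) tensor instead.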