From 63f71d525948faac40cb4c5612b28c537798e8ae Mon Sep 17 00:00:00 2001
From: Philipp Guevorguian
Date: Wed, 8 May 2024 09:57:09 +0400
Subject: [PATCH 1/9] set dispatch batches to false

---
 chemlactica/train.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/chemlactica/train.py b/chemlactica/train.py
index 1826b0e..84734fc 100644
--- a/chemlactica/train.py
+++ b/chemlactica/train.py
@@ -245,6 +245,7 @@ def train(
         save_steps=save_steps,
         dataloader_drop_last=True,
         dataloader_pin_memory=True,
+        dispatch_batches=False,
         # torch_compile=True,
         # torch_compile requires to set use_orig_params=true
         # which has some conflict with saving checkpoints

From 8d4f78026f3f4e0b44f8b380e885ae0d76e00baf Mon Sep 17 00:00:00 2001
From: Philipp Guevorguian
Date: Wed, 8 May 2024 10:04:38 +0400
Subject: [PATCH 2/9] fix accelerator dataloader prepare behaviour

---
 chemlactica/custom_accelerator.py | 40 +++++++++++++++-
 chemlactica/custom_trainer.py     | 76 ++++++++++++++++++++++++++++++-
 2 files changed, 113 insertions(+), 3 deletions(-)

diff --git a/chemlactica/custom_accelerator.py b/chemlactica/custom_accelerator.py
index 39ee098..b0454f1 100644
--- a/chemlactica/custom_accelerator.py
+++ b/chemlactica/custom_accelerator.py
@@ -1,7 +1,10 @@
-from accelerate import accelerator
+from accelerate.state import (
+    DistributedType,
+)
 import torch
-from accelerate import optimizer
+from accelerate import optimizer, accelerator
 import inspect
+from chemlactica.utils.distributed_utils import custom_prepare_data_loader


 class CustomAcceleratedOptimizer(optimizer.AcceleratedOptimizer):
@@ -39,3 +42,36 @@ def prepare_optimizer(
         )
         self._optimizers.append(optimizer)
         return optimizer
+
+    def prepare_data_loader(
+        self,
+        data_loader: torch.utils.data.DataLoader,
+        device_placement=None,
+        slice_fn_for_dispatch=None,
+    ):
+        # Ensure we can't double wrap a DataLoader due to `find_batch_size`
+        if getattr(data_loader, "_is_accelerate_prepared", False):
+            if data_loader not in self._dataloaders:
+                self._dataloaders.append(data_loader)
+            return data_loader
+        if device_placement is None:
+            device_placement = (
+                self.device_placement
+                if self.distributed_type != DistributedType.XLA
+                else False
+            )
+        prepared_data_loader = custom_prepare_data_loader(
+            data_loader,
+            self.device,
+            num_processes=self.num_processes,
+            process_index=self.process_index,
+            split_batches=self.split_batches,
+            put_on_device=device_placement,
+            rng_types=self.rng_types.copy(),
+            dispatch_batches=self.dispatch_batches,
+            even_batches=self.even_batches,
+            slice_fn_for_dispatch=slice_fn_for_dispatch,
+            use_seedable_sampler=self.use_seedable_sampler,
+        )
+        self._dataloaders.append(prepared_data_loader)
+        return prepared_data_loader
diff --git a/chemlactica/custom_trainer.py b/chemlactica/custom_trainer.py
index 73e0136..5c79aa6 100644
--- a/chemlactica/custom_trainer.py
+++ b/chemlactica/custom_trainer.py
@@ -3,6 +3,9 @@ import os
 from torch._tensor import Tensor
 from torch.nn.modules import Module
+from custom_accelerator import CustomAccelerator
+from transformers.utils import is_accelerate_available
+from accelerate.utils import GradientAccumulationPlugin

 # from torch.distributed.fsdp.fully_sharded_data_parallel import (
 #     FullyShardedDataParallel as FSDP,
@@ -15,7 +18,6 @@ from chemlactica.utils.utils import get_tokenizer
 from dataclasses import dataclass, field

-
 # if is_torch_tpu_available(check_device=False):
 #     import torch_xla.core.xla_model as xm

@@ -48,6 +50,78 @@ def training_step(self, model: Module, inputs: Dict[str, Tensor | Any]) -> Tensor:
             self.num_samples_to_print = None
         return super().training_step(model, inputs)

+    def create_accelerator_and_postprocess(self):
+        grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps}
+        grad_acc_kwargs["sync_with_dataloader"] = False
+        gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
+
+        # create accelerator object
+        self.accelerator = CustomAccelerator(
+            deepspeed_plugin=self.args.deepspeed_plugin,
+            gradient_accumulation_plugin=gradient_accumulation_plugin,
+            **self.args.accelerator_config.to_dict(),
+        )
+        # some Trainer classes need to use `gather` instead of `gather_for_metrics`,
+        # thus we store a flag
+        self.gather_function = self.accelerator.gather_for_metrics
+
+        # deepspeed and accelerate flags covering both trainer args and accelerate launcher
+        self.is_deepspeed_enabled = (
+            getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
+        )
+        self.is_fsdp_enabled = (
+            getattr(self.accelerator.state, "fsdp_plugin", None) is not None
+        )
+
+        # post accelerator creation setup
+        if self.is_fsdp_enabled:
+            fsdp_plugin = self.accelerator.state.fsdp_plugin
+            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
+                "limit_all_gathers", fsdp_plugin.limit_all_gathers
+            )
+            if is_accelerate_available("0.23.0"):
+                fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get(
+                    "activation_checkpointing", fsdp_plugin.activation_checkpointing
+                )
+                if (
+                    fsdp_plugin.activation_checkpointing
+                    and self.args.gradient_checkpointing
+                ):
+                    raise ValueError(
+                        "The activation_checkpointing in FSDP config and "
+                        "the gradient_checkpointing in training arg "
+                        "can't be set to True simultaneously. "
+                        "Please use FSDP's activation_checkpointing logic "
+                        "when using FSDP."
+                    )
+
+        if (
+            self.is_deepspeed_enabled
+            and getattr(self.args, "hf_deepspeed_config", None) is None
+        ):
+            self.propagate_args_to_deepspeed()
+
+        # `save_only_model` can't be used with DeepSpeed/FSDP along with `load_best_model_at_end`
+        if (
+            self.args.save_only_model
+            and (self.is_deepspeed_enabled or self.is_fsdp_enabled)
+            and self.args.load_best_model_at_end
+        ):
+            wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
+            raise ValueError(
+                f"{wrapper} can't be used with `save_only_model` "
+                "along with `load_best_model_at_end`."
+            )
+
+        # `auto_find_batch_size` isn't yet supported with DeepSpeed/FSDP
+        if (
+            self.is_deepspeed_enabled or self.is_fsdp_enabled
+        ) and self.args.auto_find_batch_size:
+            wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
+            raise NotImplementedError(
+                f"`{wrapper}` doesn't support `auto_find_batch_size`."
+            )
+
     def _build_slurm_eval_command(self, train_command, trial):
         checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
         run_dir = self._get_output_dir(trial=trial)

From 1ffffc50489a4ec749b9bbda531d3ec2d3d6eef1 Mon Sep 17 00:00:00 2001
From: Philipp Guevorguian
Date: Wed, 8 May 2024 10:06:59 +0400
Subject: [PATCH 3/9] use custom accelerator in train script

---
 chemlactica/train.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/chemlactica/train.py b/chemlactica/train.py
index 84734fc..4a7edbf 100644
--- a/chemlactica/train.py
+++ b/chemlactica/train.py
@@ -13,8 +13,9 @@ from transformers import (
     ProgressCallback,
 )
-from accelerate import Accelerator, logging, InitProcessGroupKwargs
+from accelerate import logging, InitProcessGroupKwargs
 from accelerate.utils import broadcast_object_list
+from custom_accelerator import CustomAccelerator

 from chemlactica.custom_trainer import CustomArguments
 from chemlactica.utils.callbacks import (
@@ -91,7 +92,7 @@ def train(

     kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200))

-    accelerator = Accelerator(
+    accelerator = CustomAccelerator(
         kwargs_handlers=[kwargs], log_with="all", project_dir=track_dir
     )

From c236b241971e1de623e304ca603a2b241e0de3db Mon Sep 17 00:00:00 2001
From: Philipp Guevorguian
Date: Wed, 8 May 2024 10:08:01 +0400
Subject: [PATCH 4/9] clarify accelerator import

---
 chemlactica/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/chemlactica/train.py b/chemlactica/train.py
index 4a7edbf..2b43e34 100644
--- a/chemlactica/train.py
+++ b/chemlactica/train.py
@@ -15,7 +15,7 @@ from transformers import (
 )
 from accelerate import logging, InitProcessGroupKwargs
 from accelerate.utils import broadcast_object_list
-from custom_accelerator import CustomAccelerator
+from chemlactica.custom_accelerator import CustomAccelerator

 from chemlactica.custom_trainer import CustomArguments
 from chemlactica.utils.callbacks import (

From dfe68d410ff4659c4983c6457328aff08a19a480 Mon Sep 17 00:00:00 2001
From: Philipp Guevorguian
Date: Wed, 8 May 2024 10:11:33 +0400
Subject: [PATCH 5/9] implement fixed dataloader prepare function

---
 chemlactica/utils/distributed_utils.py | 238 +++++++++++++++++++++++++
 1 file changed, 238 insertions(+)

diff --git a/chemlactica/utils/distributed_utils.py b/chemlactica/utils/distributed_utils.py
index 3131024..db7c49f 100644
--- a/chemlactica/utils/distributed_utils.py
+++ b/chemlactica/utils/distributed_utils.py
@@ -1,4 +1,57 @@
 import os
+from typing import Callable, List, Optional, Union
+from accelerate.state import (
+    AcceleratorState,
+    DistributedType,
+)
+from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler
+from accelerate.data_loader import (
+    DataLoaderShard,
+    DataLoaderDispatcher,
+    SeedableRandomSampler,
+    BatchSamplerShard,
+    # IterableDatasetShard,
+    # SkipBatchSampler,
+    # SkipDataLoader,
+    # DataLoaderStateMixin,
+)
+
+import torch
+
+from accelerate.logging import get_logger
+
+from accelerate.utils import (
+    RNGType,
+    is_torch_version,
+)
+
+
+logger = get_logger(__name__)
+
+# kwargs of the DataLoader in min version 1.4.0.
+_PYTORCH_DATALOADER_KWARGS = {
+    "batch_size": 1,
+    "shuffle": False,
+    "sampler": None,
+    "batch_sampler": None,
+    "num_workers": 0,
+    "collate_fn": None,
+    "pin_memory": False,
+    "drop_last": False,
+    "timeout": 0,
+    "worker_init_fn": None,
+    "multiprocessing_context": None,
+    "generator": None,
+    "prefetch_factor": 2,
+    "persistent_workers": False,
+}
+
+# kwargs added after by version
+_PYTORCH_DATALOADER_ADDITIONAL_KWARGS = {}
+
+for v, additional_kwargs in _PYTORCH_DATALOADER_ADDITIONAL_KWARGS.items():
+    if is_torch_version(">=", v):
+        _PYTORCH_DATALOADER_KWARGS.update(additional_kwargs)


 def get_experiment_hash(from_pretrained, train_type="pretrain"):
@@ -6,3 +59,188 @@ def get_experiment_hash(from_pretrained, train_type="pretrain"):
         return str(from_pretrained.split(os.path.sep)[-2])
     else:
         return "none"
+
+
+def custom_prepare_data_loader(
+    dataloader: DataLoader,
+    device: Optional[torch.device] = None,
+    num_processes: Optional[int] = None,
+    process_index: Optional[int] = None,
+    split_batches: bool = False,
+    put_on_device: bool = False,
+    rng_types: Optional[List[Union[str, RNGType]]] = None,
+    dispatch_batches: Optional[bool] = None,
+    even_batches: bool = True,
+    slice_fn_for_dispatch: Optional[Callable] = None,
+    use_seedable_sampler: bool = False,
+) -> DataLoader:
+    if dispatch_batches is None:
+        if not put_on_device:
+            dispatch_batches = False
+        else:
+            dispatch_batches = isinstance(dataloader.dataset, IterableDataset)
+
+    if dispatch_batches and not put_on_device:
+        raise ValueError("Using `dispatch_batches=True` requires `put_on_device=True`.")
+    # Grab defaults from AcceleratorState
+    state = AcceleratorState()
+    if num_processes is None:
+        num_processes = state.num_processes
+    if process_index is None:
+        process_index = state.process_index
+
+    # Sanity check
+    if split_batches:
+        if dataloader.batch_size is not None:
+            batch_size_for_check = dataloader.batch_size
+        else:
+            # For custom batch_sampler
+            if hasattr(dataloader.batch_sampler, "batch_size"):
+                batch_size_for_check = dataloader.batch_sampler.batch_size
+            else:
+                raise ValueError(
+                    "In order to use `split_batches==True` you must pass `batch_size` "
+                    "to `dataloader` or `dataloader.batch_sampler` objects "
+                    "and it has to return a natural number. "
+                    "Your `dataloader.batch_size` is None and "
+                    "`dataloader.batch_sampler` "
+                    f"(`{type(dataloader.batch_sampler)}`) does not have "
+                    "the `batch_size` attribute set."
+                )
+
+        if batch_size_for_check > 1 and batch_size_for_check % num_processes != 0:
+            raise ValueError(
+                "To use a `DataLoader` in `split_batches` mode, "
+                f"the batch size ({dataloader.batch_size}) "
+                f"needs to be a round multiple of the number of processes ({num_processes})."
+            )
+
+    new_dataset = dataloader.dataset
+    # Iterable dataset doesn't like batch_sampler, but data_loader creates a default one for it
+    new_batch_sampler = (
+        dataloader.batch_sampler
+        if not isinstance(new_dataset, IterableDataset)
+        else None
+    )
+    sampler_is_batch_sampler = False
+    synchronized_generator = None
+    sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
+    if sampler_is_batch_sampler:
+        sampler = getattr(dataloader.sampler, "sampler", None)
+    else:
+        sampler = getattr(dataloader.batch_sampler, "sampler", None)
+    if isinstance(sampler, RandomSampler) and use_seedable_sampler:
+        # When iterating through the dataloader during distributed processes
+        # we want to ensure that on each process we are iterating through the same
+        # samples in the same order if a seed is set. This requires a tweak
+        # to the `torch.utils.data.RandomSampler` class (if used).
+        sampler = SeedableRandomSampler(
+            data_source=sampler.data_source,
+            replacement=sampler.replacement,
+            num_samples=sampler._num_samples,
+            generator=getattr(sampler, "generator", torch.Generator()),
+        )
+
+    # No change if no multiprocess
+    if (
+        num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM
+    ) and not dispatch_batches:
+        if isinstance(new_dataset, IterableDataset):
+            # if getattr(dataloader.dataset, "generator", None) is not None:
+            #     synchronized_generator = dataloader.dataset.generator
+            # new_dataset = IterableDatasetShard(
+            #     new_dataset,
+            #     batch_size=dataloader.batch_size,
+            #     drop_last=dataloader.drop_last,
+            #     num_processes=num_processes,
+            #     process_index=process_index,
+            #     split_batches=split_batches,
+            # )
+            pass
+        else:
+            batch_sampler = (
+                dataloader.sampler
+                if sampler_is_batch_sampler
+                else dataloader.batch_sampler
+            )
+            new_batch_sampler = BatchSamplerShard(
+                batch_sampler,
+                num_processes=num_processes,
+                process_index=process_index,
+                split_batches=split_batches,
+                even_batches=even_batches,
+            )
+
+    # We ignore all of those since they are all dealt with by our new_batch_sampler
+    ignore_kwargs = [
+        "batch_size",
+        "shuffle",
+        "sampler",
+        "batch_sampler",
+        "drop_last",
+    ]
+
+    if (
+        rng_types is not None
+        and synchronized_generator is None
+        and "generator" in rng_types
+    ):
+        rng_types.remove("generator")
+
+    kwargs = {
+        k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
+        for k in _PYTORCH_DATALOADER_KWARGS
+        if k not in ignore_kwargs
+    }
+
+    # Need to provide batch_size as batch_sampler is None for Iterable dataset
+    if new_batch_sampler is None:
+        kwargs["drop_last"] = dataloader.drop_last
+        kwargs["batch_size"] = (
+            dataloader.batch_size // num_processes
+            if split_batches and not dispatch_batches
+            else dataloader.batch_size
+        )
+    if dispatch_batches:
+        kwargs.pop("generator")
+        dataloader = DataLoaderDispatcher(
+            new_dataset,
+            split_batches=split_batches,
+            batch_sampler=new_batch_sampler,
+            _drop_last=dataloader.drop_last,
+            slice_fn=slice_fn_for_dispatch,
+            **kwargs,
+        )
+    elif sampler_is_batch_sampler:
+        dataloader = DataLoaderShard(
+            new_dataset,
+            device=device
+            if put_on_device and state.distributed_type != DistributedType.XLA
+            else None,
+            sampler=new_batch_sampler,
+            batch_size=dataloader.batch_size,
+            rng_types=rng_types,
+            _drop_last=dataloader.drop_last,
+            synchronized_generator=synchronized_generator,
+            **kwargs,
+        )
+    else:
+        dataloader = DataLoaderShard(
+            new_dataset,
+            device=device
+            if put_on_device and state.distributed_type != DistributedType.XLA
+            else None,
+            batch_sampler=new_batch_sampler,
+            rng_types=rng_types,
+            synchronized_generator=synchronized_generator,
+            _drop_last=dataloader.drop_last,
+            **kwargs,
+        )
+
+    if isinstance(sampler, SeedableRandomSampler) and use_seedable_sampler:
+        if sampler_is_batch_sampler:
+            dataloader.sampler.sampler = sampler
+        else:
+            dataloader.batch_sampler.sampler = sampler
+
+    return dataloader

From 9c9f9e91fbf22300d9711037a4bb607c7f89d9ea Mon Sep 17 00:00:00 2001
From: Philipp Guevorguian
Date: Wed, 8 May 2024 10:16:58 +0400
Subject: [PATCH 6/9] remove repeat setting dispatch batches

---
 chemlactica/train.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/chemlactica/train.py b/chemlactica/train.py
index 4b2d613..2b43e34 100644
--- a/chemlactica/train.py
+++ b/chemlactica/train.py
@@ -244,7 +244,6 @@ def train(
         num_train_epochs=num_train_epochs,
         eval_steps=eval_steps,
         save_steps=save_steps,
-        dispatch_batches=False,
         dataloader_drop_last=True,
         dataloader_pin_memory=True,
         dispatch_batches=False,

From d620f6b3590f1c1e1aed87a9c8e872b23c90ed27 Mon Sep 17 00:00:00 2001
From: Philipp Guevorguian
Date: Wed, 8 May 2024 10:44:20 +0400
Subject: [PATCH 7/9] remove cuda nsight from env file

---
 environment.yml | 2 --
 1 file changed, 2 deletions(-)

diff --git a/environment.yml b/environment.yml
index 07ac9fe..20343de 100644
--- a/environment.yml
+++ b/environment.yml
@@ -29,8 +29,6 @@ dependencies:
   - cuda-libraries=11.8.0=0
   - cuda-libraries-dev=11.8.0=0
   - cuda-memcheck=11.8.86=0
-  - cuda-nsight=11.8.86=0
-  - cuda-nsight-compute=11.8.0=0
   - cuda-nvcc=11.8.89=0
   - cuda-nvdisasm=11.8.86=0
   - cuda-nvml-dev=11.8.86=0

From 16b42bfd74750347376025e520905710f7b791d8 Mon Sep 17 00:00:00 2001
From: Philipp Guevorguian
Date: Wed, 8 May 2024 11:08:32 +0400
Subject: [PATCH 8/9] remove unnecessary deps

---
 environment.yml | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/environment.yml b/environment.yml
index 20343de..5d80ec2 100644
--- a/environment.yml
+++ b/environment.yml
@@ -1,4 +1,4 @@
-name: gemma_env_new
+name: cl11.9_t_4.39
 channels:
   - pytorch
   - nvidia
@@ -29,6 +29,8 @@ dependencies:
   - cuda-libraries=11.8.0=0
   - cuda-libraries-dev=11.8.0=0
   - cuda-memcheck=11.8.86=0
+  - cuda-nsight=11.8.86=0
+  - cuda-nsight-compute=11.8.0=0
   - cuda-nvcc=11.8.89=0
   - cuda-nvdisasm=11.8.86=0
   - cuda-nvml-dev=11.8.86=0
@@ -178,8 +180,6 @@ dependencies:
       - backoff==2.2.1
      - base58==2.0.1
      - bitsandbytes==0.43.0
-      - boto3==1.34.84
-      - botocore==1.34.84
      - cachetools==5.3.3
      - certifi==2024.2.2
      - cffi==1.16.0
@@ -208,7 +208,6 @@ dependencies:
      - huggingface-hub==0.22.2
      - identify==2.5.35
      - idna==3.6
-      - jmespath==1.0.1
      - joblib==1.3.2
      - kiwisolver==1.4.5
      - mako==1.3.2
@@ -240,7 +239,6 @@ dependencies:
      - requests==2.31.0
      - restrictedpython==7.1
      - rich==13.7.1
-      - s3transfer==0.10.1
      - safetensors==0.4.2
      - scikit-learn==1.4.1.post1
      - scipy==1.12.0
@@ -264,4 +262,4 @@ dependencies:
      - xmltodict==0.13.0
      - xxhash==3.4.1
      - yarl==1.9.4
-prefix: /auto/home/menuab/miniforge3/envs/gemma_env_new
+prefix: /home/philipp/miniforge3/envs/cl11.9_t_4.39

From 94663594ce75c227ff3b9fac4ea15d8346e0a54d Mon Sep 17 00:00:00 2001
From: Philipp Guevorguian
Date: Wed, 8 May 2024 11:22:21 +0400
Subject: [PATCH 9/9] add test status file

---
 test_status.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test_status.yaml b/test_status.yaml
index c1c69b9..fe49900 100644
--- a/test_status.yaml
+++ b/test_status.yaml
@@ -1 +1 @@
-e82e20fd17fe37c73ca4202205d9f42075cc46ba: PASS
+16b42bfd74750347376025e520905710f7b791d8: PASS
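
Editorial note (not part of the patch series): with dispatch_batches=False set in the
training arguments (PATCH 1, deduplicated in PATCH 6) and the Trainer building a
CustomAccelerator (PATCH 2-4), iterable dataloaders go through
custom_prepare_data_loader (PATCH 5), which leaves an IterableDataset unsharded and
wraps it in a DataLoaderShard instead of a DataLoaderDispatcher. The snippet below is
a minimal, single-process sketch of that code path; the toy dataset, the script name,
and the exact accelerate version are assumptions for illustration, not code taken
from the patches.

# check_dataloader_prepare.py -- illustrative only; assumes the chemlactica modules
# from the patches above and a compatible accelerate version are importable.
import torch
from torch.utils.data import DataLoader, IterableDataset

from chemlactica.custom_accelerator import CustomAccelerator
from chemlactica.utils.distributed_utils import custom_prepare_data_loader


class ToyIterable(IterableDataset):
    # Hypothetical stand-in for the streaming training dataset.
    def __iter__(self):
        for i in range(8):
            yield {"input_ids": torch.tensor([i])}


def main():
    # Creating the accelerator initializes the AcceleratorState that
    # custom_prepare_data_loader reads num_processes/process_index from.
    accelerator = CustomAccelerator()

    loader = DataLoader(ToyIterable(), batch_size=2)

    # dispatch_batches=False mirrors the TrainingArguments flag from PATCH 1/6:
    # the IterableDataset is kept as-is and the loader becomes a DataLoaderShard
    # rather than a DataLoaderDispatcher.
    prepared = custom_prepare_data_loader(
        loader,
        device=accelerator.device,
        put_on_device=True,
        dispatch_batches=False,
    )

    for batch in prepared:
        print(type(prepared).__name__, batch["input_ids"].shape)


if __name__ == "__main__":
    main()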