Fix duplication and missing data via prepare dataloader overwrite #26

Merged · 10 commits · May 8, 2024
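This PR routes distributed DataLoader preparation through project code instead of the stock accelerate path: `CustomAccelerator.prepare_data_loader` delegates to `chemlactica.utils.distributed_utils.custom_prepare_data_loader`, the custom trainer builds a `CustomAccelerator` inside `create_accelerator_and_postprocess`, and `train.py` switches to `CustomAccelerator` with `dispatch_batches=False`. Per the title, this is what removes the duplicated and missing samples across ranks. A minimal sketch of how the added pieces are meant to fit together (the dataset and batch size below are illustrative placeholders, not part of the PR):

```python
# Sketch only; assumes a default accelerate setup. In the actual training code the
# wiring happens inside the Trainer via create_accelerator_and_postprocess (see below).
import torch
from chemlactica.custom_accelerator import CustomAccelerator

accelerator = CustomAccelerator()  # subclass of accelerate.Accelerator
dataset = torch.utils.data.TensorDataset(torch.arange(1024, dtype=torch.float32))
loader = torch.utils.data.DataLoader(dataset, batch_size=8)

# Accelerator.prepare() routes DataLoader objects to prepare_data_loader(), so this
# call now goes through custom_prepare_data_loader instead of accelerate's default.
prepared_loader = accelerator.prepare(loader)
for (batch,) in prepared_loader:
    pass  # each rank iterates its own shard of the data
```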
40 changes: 38 additions & 2 deletions chemlactica/custom_accelerator.py
@@ -1,7 +1,10 @@
-from accelerate import accelerator
+from accelerate.state import (
+    DistributedType,
+)
import torch
-from accelerate import optimizer
+from accelerate import optimizer, accelerator
import inspect
+from chemlactica.utils.distributed_utils import custom_prepare_data_loader


class CustomAcceleratedOptimizer(optimizer.AcceleratedOptimizer):
@@ -39,3 +42,36 @@ def prepare_optimizer(
        )
        self._optimizers.append(optimizer)
        return optimizer

+    def prepare_data_loader(
+        self,
+        data_loader: torch.utils.data.DataLoader,
+        device_placement=None,
+        slice_fn_for_dispatch=None,
+    ):
+        # Ensure we can't double wrap a DataLoader due to `find_batch_size`
+        if getattr(data_loader, "_is_accelerate_prepared", False):
+            if data_loader not in self._dataloaders:
+                self._dataloaders.append(data_loader)
+            return data_loader
+        if device_placement is None:
+            device_placement = (
+                self.device_placement
+                if self.distributed_type != DistributedType.XLA
+                else False
+            )
+        prepared_data_loader = custom_prepare_data_loader(
+            data_loader,
+            self.device,
+            num_processes=self.num_processes,
+            process_index=self.process_index,
+            split_batches=self.split_batches,
+            put_on_device=device_placement,
+            rng_types=self.rng_types.copy(),
+            dispatch_batches=self.dispatch_batches,
+            even_batches=self.even_batches,
+            slice_fn_for_dispatch=slice_fn_for_dispatch,
+            use_seedable_sampler=self.use_seedable_sampler,
+        )
+        self._dataloaders.append(prepared_data_loader)
+        return prepared_data_loader
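The `_is_accelerate_prepared` guard above means that preparing an already-prepared DataLoader returns it unchanged instead of wrapping it a second time (e.g. when `find_batch_size` triggers re-preparation). A small illustration of that branch, assuming `CustomAccelerator` can be constructed with defaults (the loader is a placeholder):

```python
import torch
from chemlactica.custom_accelerator import CustomAccelerator

accelerator = CustomAccelerator()
loader = torch.utils.data.DataLoader(list(range(64)), batch_size=4)
loader._is_accelerate_prepared = True  # simulate a loader that was already prepared

same = accelerator.prepare_data_loader(loader)
assert same is loader                      # returned untouched, not wrapped again
assert loader in accelerator._dataloaders  # but still tracked by the accelerator
```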
76 changes: 75 additions & 1 deletion chemlactica/custom_trainer.py
@@ -3,6 +3,9 @@
import os
from torch._tensor import Tensor
from torch.nn.modules import Module
+from custom_accelerator import CustomAccelerator
+from transformers.utils import is_accelerate_available
+from accelerate.utils import GradientAccumulationPlugin

# from torch.distributed.fsdp.fully_sharded_data_parallel import (
# FullyShardedDataParallel as FSDP,
@@ -15,7 +18,6 @@
from chemlactica.utils.utils import get_tokenizer
from dataclasses import dataclass, field


# if is_torch_tpu_available(check_device=False):
# import torch_xla.core.xla_model as xm

@@ -48,6 +50,78 @@ def training_step(self, model: Module, inputs: Dict[str, Tensor | Any]) -> Tensor:
            self.num_samples_to_print = None
        return super().training_step(model, inputs)

+    def create_accelerator_and_postprocess(self):
+        grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps}
+        grad_acc_kwargs["sync_with_dataloader"] = False
+        gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
+
+        # create accelerator object
+        self.accelerator = CustomAccelerator(
+            deepspeed_plugin=self.args.deepspeed_plugin,
+            gradient_accumulation_plugin=gradient_accumulation_plugin,
+            **self.args.accelerator_config.to_dict(),
+        )
+        # some Trainer classes need to use `gather` instead of `gather_for_metrics`,
+        # thus we store a flag
+        self.gather_function = self.accelerator.gather_for_metrics
+
+        # deepspeed and accelerate flags covering both trainer args and accelerate launcher
+        self.is_deepspeed_enabled = (
+            getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
+        )
+        self.is_fsdp_enabled = (
+            getattr(self.accelerator.state, "fsdp_plugin", None) is not None
+        )
+
+        # post accelerator creation setup
+        if self.is_fsdp_enabled:
+            fsdp_plugin = self.accelerator.state.fsdp_plugin
+            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
+                "limit_all_gathers", fsdp_plugin.limit_all_gathers
+            )
+            if is_accelerate_available("0.23.0"):
+                fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get(
+                    "activation_checkpointing", fsdp_plugin.activation_checkpointing
+                )
+                if (
+                    fsdp_plugin.activation_checkpointing
+                    and self.args.gradient_checkpointing
+                ):
+                    raise ValueError(
+                        "The activation_checkpointing in FSDP config and "
+                        "the gradient_checkpointing in training arg "
+                        "can't be set to True simultaneously. "
+                        "Please use FSDP's activation_checkpointing logic "
+                        "when using FSDP."
+                    )
+
+        if (
+            self.is_deepspeed_enabled
+            and getattr(self.args, "hf_deepspeed_config", None) is None
+        ):
+            self.propagate_args_to_deepspeed()
+
+        # `save_only_model` can't be used with DeepSpeed/FSDP along with `load_best_model_at_end`
+        if (
+            self.args.save_only_model
+            and (self.is_deepspeed_enabled or self.is_fsdp_enabled)
+            and self.args.load_best_model_at_end
+        ):
+            wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
+            raise ValueError(
+                f"{wrapper} can't be used with `save_only_model` "
+                "along with `load_best_model_at_end`."
+            )
+
+        # `auto_find_batch_size` isn't yet supported with DeepSpeed/FSDP
+        if (
+            self.is_deepspeed_enabled or self.is_fsdp_enabled
+        ) and self.args.auto_find_batch_size:
+            wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
+            raise NotImplementedError(
+                f"`{wrapper}` doesn't support `auto_find_batch_size`."
+            )

    def _build_slurm_eval_command(self, train_command, trial):
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
        run_dir = self._get_output_dir(trial=trial)
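`create_accelerator_and_postprocess` is the method where `transformers.Trainer` constructs its internal `Accelerator`, so copying its body here and swapping in `CustomAccelerator` is what makes every DataLoader the Trainer prepares go through the overridden `prepare_data_loader`. The rest of the method (gradient-accumulation plugin, DeepSpeed/FSDP flags and guards) appears to be kept verbatim from the upstream implementation.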
7 changes: 4 additions & 3 deletions chemlactica/train.py
@@ -13,8 +13,9 @@
from transformers import (
    ProgressCallback,
)
-from accelerate import Accelerator, logging, InitProcessGroupKwargs
+from accelerate import logging, InitProcessGroupKwargs
from accelerate.utils import broadcast_object_list
+from chemlactica.custom_accelerator import CustomAccelerator

from chemlactica.custom_trainer import CustomArguments
from chemlactica.utils.callbacks import (
@@ -91,7 +92,7 @@ def train(

    kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200))

-    accelerator = Accelerator(
+    accelerator = CustomAccelerator(
        kwargs_handlers=[kwargs], log_with="all", project_dir=track_dir
    )

@@ -243,9 +244,9 @@ def train(
        num_train_epochs=num_train_epochs,
        eval_steps=eval_steps,
        save_steps=save_steps,
-        dispatch_batches=False,
        dataloader_drop_last=True,
        dataloader_pin_memory=True,
+        dispatch_batches=False,
        # torch_compile=True,
        # torch_compile requires to set use_orig_params=true
        # which has some conflict with saving checkpoints
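With `dispatch_batches=False`, each process builds and iterates its own DataLoader instead of the main process iterating the loader and dispatching batch slices to the other ranks; together with the `custom_prepare_data_loader` override, this is presumably what eliminates the duplicated and dropped samples the PR title refers to. `dataloader_drop_last=True` discards the final ragged batch so per-rank batch counts stay equal.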