From cd298e30861b960066ba78f76f7fc91a2b444de0 Mon Sep 17 00:00:00 2001
From: Zhihao Lin <36994684+LZHgrla@users.noreply.github.com>
Date: Wed, 24 Jan 2024 11:12:54 +0800
Subject: [PATCH] [Feature] Support save_optimizer=False for DeepSpeed (#1474)

---
 mmengine/_strategy/deepspeed.py | 73 ++++++++++++++++++++++++---------
 1 file changed, 53 insertions(+), 20 deletions(-)

diff --git a/mmengine/_strategy/deepspeed.py b/mmengine/_strategy/deepspeed.py
index 378616db3d..44e7c2e692 100644
--- a/mmengine/_strategy/deepspeed.py
+++ b/mmengine/_strategy/deepspeed.py
@@ -6,18 +6,23 @@
 import torch
 
+from mmengine.logging import print_log
+
 try:
     import deepspeed
 except ImportError:
     deepspeed = None
 
+import logging
+
 import torch.nn as nn
 
 import mmengine
-from mmengine.dist import init_dist
+from mmengine.dist import init_dist, is_main_process
 from mmengine.optim import BaseOptimWrapper, _ParamScheduler
 from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS, OPTIMIZERS,
                                STRATEGIES)
+from mmengine.runner.checkpoint import save_checkpoint, weights_to_cpu
 from mmengine.utils import apply_to, digit_version, get_git_hash
 
 from .base import BaseStrategy
 
@@ -506,7 +511,7 @@ def save_checkpoint(
         """Save checkpoint to given ``filename``.
 
         Warning:
-            `save_optimizer` and `callback` parameters are not supported yet.
+            `callback` parameter is not supported yet.
 
         Args:
             filename (str): Filename to save checkpoint.
@@ -527,25 +532,53 @@ def save_checkpoint(
             mmengine=mmengine.__version__ + get_git_hash(),
         )
 
-        if save_optimizer and hasattr(self, 'optim_wrapper'):
-            # The key can not be 'optimizer', otherwise error will be thrown
-            # when loading or resuming checkpoint.
-            extra_ckpt['optim_wrapper'] = self.optim_state_dict()
-
         if save_param_scheduler and hasattr(self, 'param_schedulers'):
             extra_ckpt['param_schedulers'] = self.scheduler_state_dict()
 
-        dirname, basename = osp.split(filename)
-        if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
-            self.model.save_checkpoint(
-                dirname,
-                tag=basename,
-                client_state=extra_ckpt,
-                save_latest=False,
-                exclude_frozen_parameters=self.exclude_frozen_parameters)
+        if (not save_optimizer
+                and self.model.zero_optimization_partition_weights()
+                and not self.model.zero_gather_16bit_weights_on_model_save()):
+            print_log(
+                'Configured to `save_optimizer=False`, but currently using '
+                "DeepSpeed's ZeRO stage 3 with "
+                '`gather_16bit_weights_on_model_save=False`. In '
+                'this configuration, the model cannot be saved properly '
+                'and will be saved with the optimizer state. '
+                'To support `save_optimizer=False`, please set '
+                '`gather_16bit_weights_on_model_save=True` in your '
+                'DeepSpeed config.',
+                logger='current',
+                level=logging.WARNING)
+            save_optimizer = True
+
+        if save_optimizer:
+            if hasattr(self, 'optim_wrapper'):
+                # The key can not be 'optimizer', otherwise error will be
+                # thrown when loading or resuming checkpoint.
+                extra_ckpt['optim_wrapper'] = self.optim_state_dict()
+
+            dirname, basename = osp.split(filename)
+            if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
+                self.model.save_checkpoint(
+                    dirname,
+                    tag=basename,
+                    client_state=extra_ckpt,
+                    save_latest=False,
+                    exclude_frozen_parameters=self.exclude_frozen_parameters)
+            else:
+                self.model.save_checkpoint(
+                    dirname,
+                    tag=basename,
+                    client_state=extra_ckpt,
+                    save_latest=False)
         else:
-            self.model.save_checkpoint(
-                dirname,
-                tag=basename,
-                client_state=extra_ckpt,
-                save_latest=False)
+            if self.model.zero_optimization_partition_weights():
+                # TODO: `_zero3_consolidated_16bit_state_dict` doesn't support
+                # `exclude_frozen_parameters`.
+                state_dict = self.model._zero3_consolidated_16bit_state_dict()
+            else:
+                state_dict = self.model.module_state_dict(
+                    exclude_frozen_parameters=self.exclude_frozen_parameters)
+            if is_main_process():
+                ckpt = {'state_dict': weights_to_cpu(state_dict), **extra_ckpt}
+                save_checkpoint(ckpt, filename)
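Note (not part of the patch): a minimal usage sketch of the new behavior. It assumes the
`DeepSpeedStrategy` keyword arguments `fp16` and `zero_optimization` and the DeepSpeed config
key `stage3_gather_16bit_weights_on_model_save`; the checkpoint filename is a placeholder.
Under ZeRO stage 3, 16-bit weight gathering must be enabled for `save_optimizer=False` to
produce a plain `state_dict` checkpoint; otherwise the patched code warns and falls back to
saving the optimizer state as before.

    # Hypothetical sketch, not part of this patch. Key/argument names are
    # assumptions based on the DeepSpeed and mmengine documentation.
    from mmengine._strategy import DeepSpeedStrategy

    strategy = DeepSpeedStrategy(
        fp16=dict(enabled=True),
        zero_optimization=dict(
            stage=3,
            # Required for `save_optimizer=False` under ZeRO-3, per the
            # warning added in this patch.
            stage3_gather_16bit_weights_on_model_save=True,
        ),
    )

    # ... prepare model/optimizer with the strategy and run training ...

    # Writes a single checkpoint containing 'state_dict' plus meta and
    # param_schedulers on the main process only; no DeepSpeed optimizer
    # shards are saved.
    strategy.save_checkpoint('iter_1000.pth', save_optimizer=False)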