From 39ed23fae87267de8ef044a94a597a6904fe7f85 Mon Sep 17 00:00:00 2001 From: Zhihao Lin <36994684+LZHgrla@users.noreply.github.com> Date: Fri, 12 Apr 2024 14:25:54 +0800 Subject: [PATCH] [Enhance] Enable `exclude_frozen_parameters` for `DeepSpeedEngine._zero3_consolidated_16bit_state_dict` (#1517) --- mmengine/_strategy/deepspeed.py | 41 +++++++++++++++------------------ 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/mmengine/_strategy/deepspeed.py b/mmengine/_strategy/deepspeed.py index 44e7c2e692..3f89ff760d 100644 --- a/mmengine/_strategy/deepspeed.py +++ b/mmengine/_strategy/deepspeed.py @@ -311,8 +311,8 @@ def __init__( self.config['steps_per_print'] = steps_per_print self._inputs_to_half = inputs_to_half assert (exclude_frozen_parameters is None or - digit_version(deepspeed.__version__) >= digit_version('0.10.1') - ), ('DeepSpeed >= 0.10.1 is required to enable ' + digit_version(deepspeed.__version__) >= digit_version('0.13.2') + ), ('DeepSpeed >= 0.13.2 is required to enable ' 'exclude_frozen_parameters') self.exclude_frozen_parameters = exclude_frozen_parameters @@ -430,7 +430,7 @@ def load_checkpoint( self.logger.info(f'Load checkpoint from {filename}') dirname, basename = osp.split(filename) - if digit_version(deepspeed.__version__) >= digit_version('0.10.1'): + if digit_version(deepspeed.__version__) >= digit_version('0.13.2'): _, extra_ckpt = self.model.load_checkpoint( dirname, tag=basename, @@ -468,7 +468,7 @@ def resume( self.logger.info(f'Resume checkpoint from {filename}') dirname, basename = osp.split(filename) - if digit_version(deepspeed.__version__) >= digit_version('0.10.1'): + if digit_version(deepspeed.__version__) >= digit_version('0.13.2'): _, extra_ckpt = self.model.load_checkpoint( dirname, tag=basename, @@ -551,6 +551,11 @@ def save_checkpoint( level=logging.WARNING) save_optimizer = True + state_dict_kwargs = {} + if digit_version(deepspeed.__version__) >= digit_version('0.13.2'): + state_dict_kwargs[ 
'exclude_frozen_parameters'] = self.exclude_frozen_parameters + if save_optimizer: if hasattr(self, 'optim_wrapper'): # The key can not be 'optimizer', otherwise error will be @@ -558,27 +563,19 @@ def save_checkpoint( extra_ckpt['optim_wrapper'] = self.optim_state_dict() dirname, basename = osp.split(filename) - if digit_version(deepspeed.__version__) >= digit_version('0.10.1'): - self.model.save_checkpoint( - dirname, - tag=basename, - client_state=extra_ckpt, - save_latest=False, - exclude_frozen_parameters=self.exclude_frozen_parameters) - else: - self.model.save_checkpoint( - dirname, - tag=basename, - client_state=extra_ckpt, - save_latest=False) + self.model.save_checkpoint( + dirname, + tag=basename, + client_state=extra_ckpt, + save_latest=False, + **state_dict_kwargs) else: if self.model.zero_optimization_partition_weights(): - # TODO: `_zero3_consolidated_16bit_state_dict` doesn't support - # `exclude_frozen_parameters`. - state_dict = self.model._zero3_consolidated_16bit_state_dict() + state_dict = self.model._zero3_consolidated_16bit_state_dict( + **state_dict_kwargs) else: - state_dict = self.model.module_state_dict( - exclude_frozen_parameters=self.exclude_frozen_parameters) + state_dict = self.model.module_state_dict(**state_dict_kwargs) + if is_main_process(): ckpt = {'state_dict': weights_to_cpu(state_dict), **extra_ckpt} save_checkpoint(ckpt, filename)