[Feature] Support save_optimizer=False for DeepSpeed (#1474)
LZHgrla authored Jan 24, 2024
1 parent 396cac1 commit cd298e3
Showing 1 changed file with 53 additions and 20 deletions.
73 changes: 53 additions & 20 deletions mmengine/_strategy/deepspeed.py
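
Background for the diff below (not part of the commit): with ZeRO stage 3, a weights-only checkpoint can only be written when DeepSpeed is configured to gather the 16-bit weights at save time, which is what the new warning asks users to enable. A minimal, hedged sketch of such a config follows; the key name matches recent DeepSpeed releases and the surrounding values are purely illustrative.

# Illustrative DeepSpeed config sketch (assumed values, not from this commit):
# ZeRO-3 with 16-bit weight gathering enabled so that `save_optimizer=False`
# can produce a consolidated, weights-only checkpoint.
ds_config = {
    'train_micro_batch_size_per_gpu': 1,
    'zero_optimization': {
        'stage': 3,
        # Recent DeepSpeed releases expose this switch as
        # `stage3_gather_16bit_weights_on_model_save`.
        'stage3_gather_16bit_weights_on_model_save': True,
    },
}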
@@ -6,18 +6,23 @@

 import torch

+from mmengine.logging import print_log
+
 try:
     import deepspeed
 except ImportError:
     deepspeed = None

+import logging
+
 import torch.nn as nn

 import mmengine
-from mmengine.dist import init_dist
+from mmengine.dist import init_dist, is_main_process
 from mmengine.optim import BaseOptimWrapper, _ParamScheduler
 from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS, OPTIMIZERS,
                                STRATEGIES)
+from mmengine.runner.checkpoint import save_checkpoint, weights_to_cpu
 from mmengine.utils import apply_to, digit_version, get_git_hash
 from .base import BaseStrategy

@@ -506,7 +511,7 @@ def save_checkpoint(
         """Save checkpoint to given ``filename``.

         Warning:
-            `save_optimizer` and `callback` parameters are not supported yet.
+            `callback` parameter is not supported yet.

         Args:
             filename (str): Filename to save checkpoint.
@@ -527,25 +532,53 @@ def save_checkpoint(
             mmengine=mmengine.__version__ + get_git_hash(),
         )

-        if save_optimizer and hasattr(self, 'optim_wrapper'):
-            # The key can not be 'optimizer', otherwise error will be thrown
-            # when loading or resuming checkpoint.
-            extra_ckpt['optim_wrapper'] = self.optim_state_dict()
-
         if save_param_scheduler and hasattr(self, 'param_schedulers'):
             extra_ckpt['param_schedulers'] = self.scheduler_state_dict()

-        dirname, basename = osp.split(filename)
-        if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
-            self.model.save_checkpoint(
-                dirname,
-                tag=basename,
-                client_state=extra_ckpt,
-                save_latest=False,
-                exclude_frozen_parameters=self.exclude_frozen_parameters)
+        if (not save_optimizer
+                and self.model.zero_optimization_partition_weights()
+                and not self.model.zero_gather_16bit_weights_on_model_save()):
+            print_log(
+                'Configured to `save_optimizer=False`, but currently using '
+                "DeepSpeed's ZeRO stage 3 with "
+                '`gather_16bit_weights_on_model_save=False`. In '
+                'this configuration, the model cannot be saved properly '
+                'and will be saved with the optimizer state. '
+                'To support `save_optimizer=False`, please set '
+                '`gather_16bit_weights_on_model_save=True` in your '
+                'DeepSpeed config.',
+                logger='current',
+                level=logging.WARNING)
+            save_optimizer = True
+
+        if save_optimizer:
+            if hasattr(self, 'optim_wrapper'):
+                # The key can not be 'optimizer', otherwise error will be
+                # thrown when loading or resuming checkpoint.
+                extra_ckpt['optim_wrapper'] = self.optim_state_dict()
+
+            dirname, basename = osp.split(filename)
+            if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
+                self.model.save_checkpoint(
+                    dirname,
+                    tag=basename,
+                    client_state=extra_ckpt,
+                    save_latest=False,
+                    exclude_frozen_parameters=self.exclude_frozen_parameters)
+            else:
+                self.model.save_checkpoint(
+                    dirname,
+                    tag=basename,
+                    client_state=extra_ckpt,
+                    save_latest=False)
         else:
-            self.model.save_checkpoint(
-                dirname,
-                tag=basename,
-                client_state=extra_ckpt,
-                save_latest=False)
+            if self.model.zero_optimization_partition_weights():
+                # TODO: `_zero3_consolidated_16bit_state_dict` doesn't support
+                # `exclude_frozen_parameters`.
+                state_dict = self.model._zero3_consolidated_16bit_state_dict()
+            else:
+                state_dict = self.model.module_state_dict(
+                    exclude_frozen_parameters=self.exclude_frozen_parameters)
+
+            if is_main_process():
+                ckpt = {'state_dict': weights_to_cpu(state_dict), **extra_ckpt}
+                save_checkpoint(ckpt, filename)
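
The core of the change is the fallback shown above: a weights-only save is only honoured when the consolidated weights can actually be produced. A small, self-contained sketch of that decision rule follows; the helper name is hypothetical and not part of mmengine or DeepSpeed, it simply mirrors the condition in the diff.

def resolve_save_optimizer(save_optimizer: bool,
                           zero3_partitions_weights: bool,
                           gathers_16bit_weights_on_save: bool) -> bool:
    """Illustrative mirror of the fallback introduced by this commit.

    When ZeRO-3 partitions the weights but is not configured to gather the
    16-bit weights at save time, a consolidated weights-only checkpoint
    cannot be produced, so the full DeepSpeed checkpoint (including
    optimizer state) is saved instead.
    """
    if (not save_optimizer and zero3_partitions_weights
            and not gathers_16bit_weights_on_save):
        return True  # fall back to saving the optimizer state as well
    return save_optimizer


# ZeRO-3 without weight gathering: the weights-only request is overridden.
assert resolve_save_optimizer(False, True, False) is True
# ZeRO-3 with weight gathering: the weights-only path is taken.
assert resolve_save_optimizer(False, True, True) is False

When the weights-only path is taken, the state dict is consolidated (via `_zero3_consolidated_16bit_state_dict` under ZeRO-3) and written on the main process only with mmengine's plain `save_checkpoint`, so the result is a single file rather than a DeepSpeed checkpoint directory.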
