diff --git a/nemo/collections/llm/recipes/optim/adam.py b/nemo/collections/llm/recipes/optim/adam.py
index b5a60b6f8b3f..2010885f09bd 100644
--- a/nemo/collections/llm/recipes/optim/adam.py
+++ b/nemo/collections/llm/recipes/optim/adam.py
@@ -30,6 +30,7 @@ def distributed_fused_adam_with_cosine_annealing(
     max_lr: float = 1e-4,
     min_lr: Optional[float] = None,
     clip_grad: float = 1.0,
+    use_precision_aware_optimizer: bool = False,
 ) -> run.Config[PytorchOptimizerModule]:
 
     opt_cfg = run.Config(
@@ -45,6 +46,8 @@ def distributed_fused_adam_with_cosine_annealing(
         use_distributed_optimizer=True,
         clip_grad=clip_grad,
     )
+    if hasattr(opt_cfg, "use_precision_aware_optimizer"):
+        opt_cfg.use_precision_aware_optimizer = use_precision_aware_optimizer
 
     min_lr = min_lr if min_lr is not None else (0.1 * max_lr)
     sched = run.Config(
diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py
index 24b2b20b81be..f9dabd5c0f38 100644
--- a/nemo/core/classes/modelPT.py
+++ b/nemo/core/classes/modelPT.py
@@ -625,6 +625,10 @@ def setup_megatron_optimization(self, optim_config: Union[Dict[str, Any], DictCo
                 'overlap_param_gather_with_optimizer_step', False
             ),
         )
+        if hasattr(megatron_optim_config, 'use_precision_aware_optimizer'):
+            megatron_optim_config.use_precision_aware_optimizer = self.cfg.optim.get(
+                'use_precision_aware_optimizer', False
+            )
         return megatron_optim_config
 
     def setup_optimization(
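
Usage sketch (not part of the patch): a minimal example of passing the new flag through the recipe helper changed above. The function and module path come from the diff; the max_lr value and the choice to enable the flag are illustrative only, and the flag only takes effect when the installed Megatron OptimizerConfig exposes use_precision_aware_optimizer (hence the hasattr guard in the patch).

    # Illustrative only: build the optimizer recipe config with the new flag enabled.
    from nemo.collections.llm.recipes.optim.adam import (
        distributed_fused_adam_with_cosine_annealing,
    )

    optim = distributed_fused_adam_with_cosine_annealing(
        max_lr=3e-4,                          # example value, default is 1e-4
        use_precision_aware_optimizer=True,   # defaults to False; ignored if the
                                              # underlying OptimizerConfig lacks the field
    )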