diff --git a/torchrec/optim/clipping.py b/torchrec/optim/clipping.py
index 18eb5d746..66270cfcb 100644
--- a/torchrec/optim/clipping.py
+++ b/torchrec/optim/clipping.py
@@ -19,6 +19,9 @@
 
 logger: logging.Logger = logging.getLogger()
 
+log_grad_norm: bool = False
+use_64bit_grad_norm: bool = False
+
 
 @unique
 class GradientClipping(Enum):
@@ -59,6 +62,7 @@ def __init__(
         self._norm_type = norm_type
         self._check_meta: bool = True
         self._enable_global_grad_clip = enable_global_grad_clip
+        self._step_num = 0
 
         # Group parameters by model parallelism process group if global clipping is enabled.
         # Otherwise, all parameters are treated as replicated and will be clipped locally.
@@ -129,6 +133,7 @@ def step(self, closure: Any = None) -> None:
            torch.nn.utils.clip_grad_value_(self._replicate_params, self._max_gradient)
 
         super().step(closure)
+        self._step_num += 1
 
     @torch.no_grad()
     def clip_grad_norm_(self) -> None:
@@ -165,6 +170,8 @@ def clip_grad_norm_(self) -> None:
                 )
             )
 
+        square_sharded_grad_norm = total_grad_norm if total_grad_norm is not None else 0
+
         # Process replicated parameters and gradients
         if self._replicate_params:
             replicated_grads = [
@@ -189,6 +196,22 @@ def clip_grad_norm_(self) -> None:
                     else total_grad_norm + replicated_grad_norm
                 )
             )
+            square_replicated_grad_norm = replicated_grad_norm
+        else:
+            square_replicated_grad_norm = 0
+
+        global log_grad_norm
+        if log_grad_norm:
+            if total_grad_norm is not None and self._norm_type != torch.inf:
+                # pyre-ignore[58]
+                grad_norm = total_grad_norm ** (1.0 / norm_type)
+            else:
+                grad_norm = 0
+
+            rank = dist.get_rank()
+            logger.info(
+                f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
+            )
 
         # Aggregation
         if total_grad_norm is None:
@@ -212,10 +235,18 @@ def _batch_cal_norm(
    """Helper function that calculates the norm of a list of gradients in batches.
    If process_groups are passed in, the norm will be aggregated across all ranks in the process group.
    """
-    grad_norms = torch.linalg.vector_norm(
-        torch.stack(torch._foreach_norm(grad_list, norm_type)),
-        norm_type,
-    )
+
+    global use_64bit_grad_norm
+    if use_64bit_grad_norm:
+        grad_norms = torch.linalg.vector_norm(
+            torch.stack(torch._foreach_norm(grad_list, norm_type, dtype=torch.float64)),
+            norm_type,
+        )
+    else:
+        grad_norms = torch.linalg.vector_norm(
+            torch.stack(torch._foreach_norm(grad_list, norm_type)),
+            norm_type,
+        )
 
     if norm_type == torch.inf:
         if process_groups is not None:
@@ -227,6 +258,9 @@
             for pg in process_groups:
                 dist.all_reduce(grad_norms, group=pg)
 
+    if use_64bit_grad_norm:
+        grad_norms = grad_norms.to(torch.float32)
+
     return grad_norms
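
Usage note (not part of the patch): a minimal sketch of how the two new module-level flags could be flipped from training code, assuming the module path torchrec.optim.clipping as in this diff; the surrounding setup is illustrative only.

    import torchrec.optim.clipping as clipping

    # Log per-rank squared sharded/replicated gradient norms and the combined
    # total norm via the module logger each time clip_grad_norm_ runs.
    clipping.log_grad_norm = True

    # Accumulate the batched gradient norm in float64, then cast the result
    # back to float32; this trades a little compute for better numerical
    # accuracy when reducing norms over many large gradient tensors.
    clipping.use_64bit_grad_norm = True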