Log the grad norms in clipping.py (#2489)
Summary:
Pull Request resolved: #2489

Log the grad norms for debugging purposes.

Reviewed By: iamzainhuda

Differential Revision: D64275984

fbshipit-source-id: 83326a7fe6f1684b494c50d0599248ecb7c93e9b
ckluk2 authored and facebook-github-bot committed Oct 17, 2024
1 parent bce8ae3 commit 54ec8aa
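The new behavior is gated behind two module-level globals in torchrec/optim/clipping.py, log_grad_norm and use_64bit_grad_norm (both default to False; see the diff below). A minimal sketch of how a debugging run might flip them on; the flag names come from this commit, while the surrounding setup is only a placeholder:

import torchrec.optim.clipping as clipping

# Toggle the new debug flags before training starts; both default to False.
clipping.log_grad_norm = True        # log per-rank sharded/replicated/total grad norms each step
clipping.use_64bit_grad_norm = True  # accumulate the batched per-tensor norms in float64

# ... build the gradient-clipping optimizer and run the training loop as usual ...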
Showing 1 changed file with 38 additions and 4 deletions.
42 changes: 38 additions & 4 deletions torchrec/optim/clipping.py
@@ -19,6 +19,9 @@

logger: logging.Logger = logging.getLogger()

log_grad_norm: bool = False
use_64bit_grad_norm: bool = False


@unique
class GradientClipping(Enum):
@@ -59,6 +62,7 @@ def __init__(
self._norm_type = norm_type
self._check_meta: bool = True
self._enable_global_grad_clip = enable_global_grad_clip
self._step_num = 0

# Group parameters by model parallelism process group if global clipping is enabled.
# Otherwise, all parameters are treated as replicated and will be clipped locally.
@@ -129,6 +133,7 @@ def step(self, closure: Any = None) -> None:
torch.nn.utils.clip_grad_value_(self._replicate_params, self._max_gradient)

super().step(closure)
self._step_num += 1

@torch.no_grad()
def clip_grad_norm_(self) -> None:
@@ -165,6 +170,8 @@ def clip_grad_norm_(self) -> None:
)
)

square_sharded_grad_norm = total_grad_norm if total_grad_norm is not None else 0

# Process replicated parameters and gradients
if self._replicate_params:
replicated_grads = [
@@ -189,6 +196,22 @@
else total_grad_norm + replicated_grad_norm
)
)
square_replicated_grad_norm = replicated_grad_norm
else:
square_replicated_grad_norm = 0

global log_grad_norm
if log_grad_norm:
if total_grad_norm is not None and self._norm_type != torch.inf:
# pyre-ignore[58]
grad_norm = total_grad_norm ** (1.0 / norm_type)
else:
grad_norm = 0

rank = dist.get_rank()
logger.info(
f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
)

# Aggregation
if total_grad_norm is None:
Expand All @@ -212,10 +235,18 @@ def _batch_cal_norm(
"""Helper function that calculates the norm of a list of gradients in batches. If process_groups
are passed in, the norm will be aggregated across all ranks in the process group.
"""
grad_norms = torch.linalg.vector_norm(
torch.stack(torch._foreach_norm(grad_list, norm_type)),
norm_type,
)

global use_64bit_grad_norm
if use_64bit_grad_norm:
grad_norms = torch.linalg.vector_norm(
torch.stack(torch._foreach_norm(grad_list, norm_type, dtype=torch.float64)),
norm_type,
)
else:
grad_norms = torch.linalg.vector_norm(
torch.stack(torch._foreach_norm(grad_list, norm_type)),
norm_type,
)

if norm_type == torch.inf:
if process_groups is not None:
@@ -227,6 +258,9 @@
for pg in process_groups:
dist.all_reduce(grad_norms, group=pg)

if use_64bit_grad_norm:
grad_norms = grad_norms.to(torch.float32)

return grad_norms


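For reference, a standalone sketch of the arithmetic the new use_64bit_grad_norm branch performs inside _batch_cal_norm, run on dummy gradients; the tensors and norm order here are illustrative only, and the dtype argument to torch._foreach_norm needs a PyTorch build that supports it:

import torch

grad_list = [torch.randn(4, 8), torch.randn(16)]  # stand-ins for parameter gradients
norm_type = 2.0

# Per-tensor norms are computed in float64 and stacked, then combined into a single
# batched norm, mirroring the use_64bit_grad_norm branch of _batch_cal_norm.
per_tensor_norms = torch.stack(
    torch._foreach_norm(grad_list, norm_type, dtype=torch.float64)
)
total_norm = torch.linalg.vector_norm(per_tensor_norms, norm_type)

# The patch casts the result back to float32 before returning it to the caller.
total_norm = total_norm.to(torch.float32)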
