@blefaudeux, do you remember why it was done in the first place? If not, we can try removing it. @glample feel free to send a PR.
hey there, seeing this a bit late; no context from me really, I guess it was a dtype "fix" at some point.

edit: I don't really understand your link @min-xu-ai: the linked commit made sure that the norm was computed in fp32 locally (even if the dtype is fp16, for instance), but that's not what @glample is suggesting here, right? I'm a bit lost with this PR title.

edit2: ok, so the commit you point to introduced both the upcast and the cast back to the original dtype. I agree that the cast back can be delayed if it helps any operation; it's not crucial here.
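For anyone landing here later, the pattern being discussed looks roughly like this (a minimal sketch, not the actual fairscale code; the helper name is made up): the local grad norm is computed in fp32, then immediately cast back to the parameter dtype.

```python
import torch

def local_grad_norm_sketch(parameters, p=2.0):
    """Toy version of the upcast + cast-back pattern discussed above."""
    grads = [param.grad.detach() for param in parameters if param.grad is not None]
    # Upcast: each per-parameter norm is computed in fp32 so the local reduction is accurate.
    local_norm = torch.norm(
        torch.stack([torch.norm(g, p, dtype=torch.float32) for g in grads]), p
    )
    # Cast back: the partial norm is returned in the original param dtype (e.g. fp16/bf16),
    # which is the step this issue proposes to delay until after the norms are summed.
    return local_norm.to(dtype=parameters[0].dtype)
```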
Yes, that's what I meant. Thanks a lot for the context, Ben!
Copied from: https://github.com/fairinternal/xlformers/issues/117
Shouldn't we remove the .to(dtype=parameters[0].dtype) from fairscale/fairscale/internal/params.py, line 75 (at ee647b9)? It seems weird (and it results in inaccuracies) to convert partial gradient norms to fp16/bf16 before summing them.

Context:
- We use fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py, line 621 (in ee647b9),
- which calculates grad norms via fairscale/fairscale/internal/params.py, line 59 (in ee647b9),
- which downcasts to the param dtype at fairscale/fairscale/internal/params.py, line 75 (in ee647b9), see the sketch after this list,
- before the allreduce in fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py, line 672 (in ee647b9).
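What is being suggested, roughly (a hedged sketch, not a patch against fairscale; the helper and variable names are illustrative): keep the partial squared norm in fp32 through the all_reduce, and only cast back to the parameter dtype, if needed at all, after the global norm has been computed.

```python
import torch
import torch.distributed as dist

def local_sq_grad_norm_fp32(parameters):
    """Illustrative helper: accumulate the local squared grad norm in fp32,
    with no cast back to parameters[0].dtype."""
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    device = grads[0].device if grads else torch.device("cpu")
    total = torch.zeros(1, dtype=torch.float32, device=device)
    for g in grads:
        total += torch.norm(g, 2.0, dtype=torch.float32) ** 2
    return total

# Proposed order of operations (sketch):
#   local_sq = local_sq_grad_norm_fp32(params_with_grad)
#   dist.all_reduce(local_sq, group=process_group)  # partial norms summed in fp32
#   total_norm = local_sq.sqrt()                    # cast down (if needed) only after this
```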
Spotted from looking at how unusually even the grad norms look at each training step.
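As a toy illustration of why the partial norms end up looking so even (made-up numbers, not values from the run): around 1024 the fp16 grid spacing is 1.0, so partial squared norms that differ in fp32 can collapse to the same value once cast to fp16, before they are ever summed across ranks.

```python
import torch

# Two ranks' partial squared grad norms that differ in fp32...
partial_sq = torch.tensor([1024.37, 1024.12], dtype=torch.float32)

# ...collapse to the same value once cast to fp16 (ulp at 1024 is 1.0).
partial_sq_fp16 = partial_sq.to(torch.float16)
print(partial_sq_fp16)               # tensor([1024., 1024.], dtype=torch.float16)

print(partial_sq.sum().sqrt())       # ~45.260 (summed in fp32)
print(partial_sq_fp16.sum().sqrt())  # 45.25   (shards rounded before summing)
```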