CUDA OOM when saving checkpoint (in consolidate_state_dict()) using OSS #973
Comments
My fairscale version is 0.4.6, my PyTorch version is 1.11.0+cu113, and my PyTorch Lightning version is 1.6.1.
It is the same as this issue AFAICT: huggingface/transformers#14542
Except that I can't fix it by setting
Do you have a full traceback for the OOM crash? Pasting it in a gist is fine.
Getting this as well
I fixed this by enabling
I am experiencing CUDA out-of-memory crashes when consolidating my optimizer state dict before saving it. I am training on 32 40 GB A100s (four nodes with eight GPUs each) using PyTorch Lightning's `'ddp_sharded'` strategy, which uses fairscale's OSS. I get the OOM crash in the middle of running `consolidate_state_dict()`. I have tried adding `del` statements, `gc.collect()`, and `torch.cuda.empty_cache()` inside the loop, to no avail. I am using a custom optimizer class, a modified AdamW that also keeps an exponential moving average (EMA) of the weights, and I need optimizer state sharding because the extra memory overhead of the EMA weights is so onerous. Here is the custom optimizer code: https://gist.github.com/crowsonkb/ea0ed1f6e88594046c72735f3cef1d05.

I don't understand how I am running out of GPU memory partway through `consolidate_state_dict()` (I added print statements, and it got through 27 of the 32 ranks), since it moves the tensors to CPU after each broadcast. I am using NCCL, so the broadcast has to happen on the GPU, but the tensors are copied to CPU right afterwards.

Thank you,
Katherine Crowson
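To make the expectation in the post concrete, here is a hypothetical back-of-the-envelope model in plain Python; it is not fairscale's implementation. It assumes fp32 optimizer state with three buffers per parameter (AdamW's two moment estimates plus the EMA weights, an assumption about the custom optimizer), an even split of the state across ranks, and that each rank's shard passes through the GPU once during the broadcast loop. The function names `optimizer_state_bytes`, `simulate_consolidation`, and `first_oom_step`, the 1-billion-parameter model, and the 4 GB headroom are all illustrative.

```python
"""Hypothetical model of GPU memory held by shard buffers during a
broadcast-based optimizer-state consolidation (not fairscale's code)."""
from typing import List, Optional

BYTES_PER_FP32 = 4


def optimizer_state_bytes(num_params: int, buffers_per_param: int = 3) -> int:
    """Total fp32 optimizer state for the whole model, in bytes.

    The default of 3 buffers per parameter assumes AdamW's exp_avg and
    exp_avg_sq plus one EMA copy of the weights (an assumption).
    """
    return num_params * buffers_per_param * BYTES_PER_FP32


def simulate_consolidation(total_state_bytes: int, world_size: int,
                           free_after_copy: bool) -> List[int]:
    """GPU bytes held by received shards right after each broadcast step.

    Each step receives one rank's shard on the GPU (NCCL needs GPU tensors),
    copies it to CPU, and then either releases the GPU buffer or,
    hypothetically, leaves a reference to it alive.
    """
    shard_bytes = total_state_bytes // world_size  # assumes an even split
    held = 0
    history = []
    for _ in range(world_size):
        held += shard_bytes        # shard arrives on the GPU
        history.append(held)       # usage right after the broadcast
        if free_after_copy:
            held -= shard_bytes    # GPU copy dropped after moving to CPU
    return history


def first_oom_step(history: List[int], headroom_bytes: int) -> Optional[int]:
    """1-based step at which held bytes exceed the free headroom, if any."""
    for step, used in enumerate(history, start=1):
        if used > headroom_bytes:
            return step
    return None


if __name__ == "__main__":
    total = optimizer_state_bytes(1_000_000_000)  # hypothetical 1B-param model
    freed = simulate_consolidation(total, world_size=32, free_after_copy=True)
    kept = simulate_consolidation(total, world_size=32, free_after_copy=False)
    headroom = 4_000_000_000  # hypothetical free GPU memory, in bytes
    print("peak if freed:  ", max(freed), "bytes")
    print("peak if kept:   ", max(kept), "bytes")
    print("OOM step (freed):", first_oom_step(freed, headroom))
    print("OOM step (kept): ", first_oom_step(kept, headroom))
```

Under these assumptions, the footprint stays flat at one 375 MB shard if each GPU buffer is released after its CPU copy, but grows linearly toward the full 12 GB if references stay alive, so GPU usage that climbs steadily from rank to rank would point toward the second case. The model ignores PyTorch's caching allocator and fragmentation, so treat it as a sanity check rather than a diagnosis.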