Do not wrap LoRA layers with FSDP #1538

Open · wants to merge 2 commits into main
Conversation

janEbert
Contributor

When wrapping the full Transformer Block, FSDP wraps both trainable and non-trainable parameters. Because of how FSDP is implemented, this results in much higher memory consumption, negating the memory savings from LoRA.

By instead wrapping only torch.nn.Linear modules, we still make use of FSDP but avoid wrapping the LoRA layers.
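Schematically, the change amounts to something like the following (a rough sketch of the wrapping policy, not the exact diff; the module paths are assumptions):

```python
import torch
from lightning.fabric.strategies import FSDPStrategy
from litgpt.lora import Block  # the transformer block class used for LoRA finetuning

# Before: wrapping whole transformer blocks puts the large frozen weights and the
# small trainable LoRA parameters into the same flat FSDP unit.
strategy_before = FSDPStrategy(auto_wrap_policy={Block})

# After: wrapping only the plain linear layers keeps the LoRA parameters out of
# the FSDP-managed (and reduce-scattered) flat parameters.
strategy_after = FSDPStrategy(auto_wrap_policy={torch.nn.Linear})
```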

For clarity, this is the type of warning given by PyTorch for the old code:

/llm-finetuning-example/env/lib/python3.11/site-packages/torch/distributed/fsdp/_wrap_utils.py:174: UserWarning: transformer.h.0 has both parameters with requires_grad=True and False. We do not recommend wrapping such modules since the gradient memory usage will be higher than expected (1451376640 numel instead of 106496 numel before sharding via reduce-scatter). If possible, wrap the frozen parameters with FSDP separately.
The following parameters have requires_grad=True:
['transformer.h.0.attn.attn.lora_A', 'transformer.h.0.attn.attn.lora_B']
The following parameters have requires_grad=False:
['transformer.h.0.norm_1.weight', [...]]

@rasbt
Collaborator

rasbt commented Jul 2, 2024

Thanks for the update @janEbert! This looks good to me. Btw, have you done a before/after comparison (re memory usage) by any chance?

@janEbert
Contributor Author

janEbert commented Jul 2, 2024

I have not, it was only OOM vs. no OOM. 😅
I did try to solve the same problem in a different way using FSDP's ignored_states argument and then manually initializing the LoRA parameters (necessary because of the meta device initialization). However, I ran into similar OOMs with that method, although I could avoid the warning message.
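Roughly, that variant looked like this (a sketch, not the exact code; it assumes the LoRA parameters can be collected by their `lora_` name prefix):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Collect the trainable LoRA parameters so that FSDP neither flattens nor shards them.
lora_params = [p for name, p in model.named_parameters() if "lora_" in name]

fsdp_model = FSDP(model, ignored_states=lora_params)

# Because the model is instantiated on the meta device, the ignored parameters are
# not materialized by FSDP and have to be initialized manually afterwards.
```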

I can supply some comparisons if that would help!

@rasbt
Collaborator

rasbt commented Jul 2, 2024

I see, yeah, I think we should do some comparisons to make sure it works as intended. If you want to do them, that'd be nice! I suggest at least a small model (phi-2 or so) and a medium-sized model (e.g., Llama 3 8B).

@janEbert
Contributor Author

janEbert commented Jul 2, 2024

Will do that in the coming days!

@rasbt
Collaborator

rasbt commented Jul 2, 2024

That'd be awesome. And please let me know in case you need any help!

@janEbert
Contributor Author

janEbert commented Jul 3, 2024

Forgot to print GPU ranks, but the point should be clear. :)

Before the change:

Memory used: 18.47 GB
Memory used: 18.47 GB
Memory used: 18.50 GB
Memory used: 18.47 GB

After the change:

Memory used: 10.35 GB
Memory used: 10.35 GB
Memory used: 10.35 GB
Memory used: 10.38 GB

Is this helpful enough or would you like to see more detailed stats?

@williamFalcon
Contributor

cc @awaelchli

@rasbt
Collaborator

rasbt commented Jul 3, 2024

@janEbert Looks awesome, which model is that?

I am also rerunning some of the models in the config hub and will update the numbers accordingly!

@rasbt
Collaborator

rasbt commented Jul 3, 2024

I just ran a quick comparison on a 4xA10G machine to see if I can reproduce the config hub performance:

| Config | Model | Epochs | Max seq length | Micro batch size | Machine | Training runtime | Cost | Peak memory | Validation loss | Validation perplexity | Multitask score (MMLU) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| falcon-7b/lora.yaml | falcon-7b | 4 | 512 | 1 | 4xA10G | 24.94 min | $2.0 | 16.69 GB | 0.945 | 2.573 | 26.4% |

For some reason, it's been really slow, but it could be a machine issue I have to look into. Here are some numbers I am getting for the code in this PR:

Epoch 1 | iter 1 step 0 | loss train: 2.299, val: n/a | iter time: 9980.09 ms
Epoch 1 | iter 2 step 1 | loss train: 9.848, val: n/a | iter time: 10256.36 ms (step)
Epoch 1 | iter 3 step 1 | loss train: 14.262, val: n/a | iter time: 10179.08 ms
Epoch 1 | iter 4 step 2 | loss train: 11.866, val: n/a | iter time: 10034.30 ms (step)
Epoch 1 | iter 5 step 2 | loss train: 14.844, val: n/a | iter time: 10490.95 ms
Epoch 1 | iter 6 step 3 | loss train: 17.512, val: n/a | iter time: 10738.12 ms (step)
Epoch 1 | iter 7 step 3 | loss train: 14.514, val: n/a | iter time: 10573.39 ms
Epoch 1 | iter 8 step 4 | loss train: 11.069, val: n/a | iter time: 10545.70 ms (step)
Epoch 1 | iter 9 step 4 | loss train: 11.084, val: n/a | iter time: 10077.49 ms
Epoch 1 | iter 10 step 5 | loss train: 11.105, val: n/a | iter time: 10593.24 ms (step)

If I compare it to the performance before, I notice that the iterations without an optimizer step are about 30% slower. E.g., from the main branch:

Epoch 1 | iter 1 step 0 | loss train: 2.299, val: n/a | iter time: 7121.10 ms
Epoch 1 | iter 2 step 1 | loss train: 2.163, val: n/a | iter time: 10766.50 ms (step)
Epoch 1 | iter 3 step 1 | loss train: 1.705, val: n/a | iter time: 7182.87 ms
Epoch 1 | iter 4 step 2 | loss train: 1.960, val: n/a | iter time: 10861.94 ms (step)
Epoch 1 | iter 5 step 2 | loss train: 1.891, val: n/a | iter time: 7214.43 ms
Epoch 1 | iter 6 step 3 | loss train: 1.468, val: n/a | iter time: 10887.67 ms (step)
Epoch 1 | iter 7 step 3 | loss train: 1.626, val: n/a | iter time: 7162.67 ms
Epoch 1 | iter 8 step 4 | loss train: 1.167, val: n/a | iter time: 10712.47 ms (step)
Epoch 1 | iter 9 step 4 | loss train: 1.764, val: n/a | iter time: 7206.13 ms
Epoch 1 | iter 10 step 5 | loss train: 2.023, val: n/a | iter time: 10798.46 ms (step)
Epoch 1 | iter 11 step 5 | loss train: 1.892, val: n/a | iter time: 7213.32 ms
Epoch 1 | iter 12 step 6 | loss train: 2.678, val: n/a | iter time: 10819.67 ms (step)
Epoch 1 | iter 13 step 6 | loss train: 2.245, val: n/a | iter time: 7164.10 ms

Just curious, have you observed something similar? (In this case we could maybe also think about an "optimize runtime|memory" setting here.)

For comparison, the single-GPU speeds:

Epoch 1 | iter 1 step 0 | loss train: 1.281, val: n/a | iter time: 562.95 ms
Epoch 1 | iter 2 step 0 | loss train: 1.338, val: n/a | iter time: 173.90 ms
Epoch 1 | iter 3 step 0 | loss train: 1.738, val: n/a | iter time: 110.99 ms
Epoch 1 | iter 4 step 0 | loss train: 1.681, val: n/a | iter time: 269.70 ms
Epoch 1 | iter 5 step 0 | loss train: 1.830, val: n/a | iter time: 160.60 ms
Epoch 1 | iter 6 step 0 | loss train: 1.871, val: n/a | iter time: 99.58 ms
Epoch 1 | iter 7 step 0 | loss train: 1.870, val: n/a | iter time: 93.07 ms
Epoch 1 | iter 8 step 1 | loss train: 1.775, val: n/a | iter time: 319.57 ms (step)
Epoch 1 | iter 9 step 1 | loss train: 1.873, val: n/a | iter time: 94.81 ms
Epoch 1 | iter 10 step 1 | loss train: 1.797, val: n/a | iter time: 329.24 ms
Epoch 1 | iter 11 step 1 | loss train: 1.693, val: n/a | iter time: 258.53 ms

So I am thinking this could be due to a slow interconnect between the GPUs. I will look into it and do some more experiments.

@Andrei-Aksionov
Collaborator

Why does the train loss increase (for the code from this PR)?
From 2.299 up to 17.512.

@rasbt
Collaborator

rasbt commented Jul 3, 2024

Not sure. I observed it with Phi-2 too:

Main branch:

litgpt finetune_lora checkpoints/microsoft/phi-2/ --devices 4
Epoch 1 | iter 1 step 0 | loss train: 2.424, val: n/a | iter time: 5537.97 ms
Epoch 1 | iter 2 step 0 | loss train: 2.519, val: n/a | iter time: 5578.44 ms
Epoch 1 | iter 3 step 0 | loss train: 2.646, val: n/a | iter time: 5563.88 ms
Epoch 1 | iter 4 step 1 | loss train: 2.516, val: n/a | iter time: 6942.96 ms (step)
Epoch 1 | iter 5 step 1 | loss train: 2.467, val: n/a | iter time: 5483.51 ms

PR branch:

litgpt finetune_lora checkpoints/microsoft/phi-2/ --devices 4
Epoch 1 | iter 1 step 0 | loss train: 2.424, val: n/a | iter time: 7818.10 ms
Epoch 1 | iter 2 step 0 | loss train: 6.647, val: n/a | iter time: 8075.87 ms
Epoch 1 | iter 3 step 0 | loss train: 8.358, val: n/a | iter time: 7731.87 ms
Epoch 1 | iter 4 step 1 | loss train: 9.103, val: n/a | iter time: 7654.43 ms (step)
Epoch 1 | iter 5 step 1 | loss train: 11.207, val: n/a | iter time: 7824.40 ms

This is something I need to investigate more in the next few days. I'll also try this on a different machine, since I think the A10G machine has a very slow GPU interconnect.

@rasbt
Collaborator

rasbt commented Jul 3, 2024

Why does the train loss increase (for the code from this PR)? From 2.299 up to 17.512.

I am curious if the whole Block was maybe accidentally trainable (instead of just the LoRA linear layers) before, which could explain the sharper loss decrease. But we should have tests for that, and I need to double-check that with a debugger. Just leaving this here as a note to myself so I can pick it up next week.
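For reference, a quick way to check which parameters are actually trainable (a sketch; `model` stands for the LoRA model as set up by the finetuning script):

```python
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
frozen = [name for name, p in model.named_parameters() if not p.requires_grad]
print(f"trainable ({len(trainable)}): {trainable[:4]} ...")
print(f"frozen: {len(frozen)}")
```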

@janEbert
Contributor Author

janEbert commented Jul 3, 2024

Funnily enough, in an entirely unrelated example, I've also noticed PyTorch Distributed becoming increasingly less reproducible for slightly changed settings the higher the PyTorch version. Could that maybe be the case here as well? Do you get a reproducible loss when running the same version of the code?

@janEbert Looks awesome, which model is that?

It's Mistral-7B-Instruct-v0.3 on a very small (4 samples) dummy JSON dataset, global batch size = 4, all other settings default.

@janEbert
Contributor Author

janEbert commented Jul 3, 2024

BTW my iteration speed is also slightly slower. I'll check if the version with ignored_states performs better tomorrow.

@rasbt
Collaborator

rasbt commented Jul 3, 2024

That's a good point, but I think there is a different issue here that I am not understanding yet 😅. When I reran the code, I observed basically the same higher loss. That was also independent of which model I tried.

@TensorTemplar
Contributor

I ran into this as well, but looking at the code of the LoRA blocks, we will not be able to avoid mixed gradients within one block without rewriting that code so that frozen layers are not part of the same block. If we wrap only the frozen linear layers, it is not at all clear how the (still fully sharded) optimizer is updating: maybe it doesn't update properly, hence the loss not going down, or maybe it does, but then it's unclear how the memory access works in that case, especially with multiple nodes.
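For context, the structure being discussed looks roughly like this (a simplified sketch, not the exact litgpt implementation): the frozen base weight and the trainable low-rank factors live in the same module, so any FSDP unit that wraps it mixes requires_grad=True and requires_grad=False parameters.

```python
import torch


class LoRALinear(torch.nn.Module):
    """Simplified LoRA linear layer: frozen base weight plus trainable low-rank update."""

    def __init__(self, in_features: int, out_features: int, r: int = 8):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features)
        for p in self.linear.parameters():  # the pretrained weights stay frozen
            p.requires_grad = False
        self.lora_A = torch.nn.Parameter(torch.zeros(r, in_features))   # trainable
        self.lora_B = torch.nn.Parameter(torch.zeros(out_features, r))  # trainable

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Frozen base projection plus the low-rank LoRA update.
        return self.linear(x) + x @ self.lora_A.T @ self.lora_B.T
```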

@rasbt
Collaborator

rasbt commented Oct 7, 2024

Thanks for looking into this, @TensorTemplar. I think this may not be feasible then, so I am closing the PR for now. But I'm happy to revisit this with other solutions in the future.

rasbt closed this Oct 7, 2024
@janEbert
Contributor Author

janEbert commented Oct 9, 2024

If this is indeed what's happening, this would be a substantial oversight and would be really important to raise in PyTorch!
However, I'm cautious about accepting the explanation just yet: the optimizer only receives the model's non-sharded trainable parameters (via model.parameters()), so the frozen layers and the effect of sharding should be fully ignored here by the optimizer, at least in theory. There could be a bug in multiple locations though, that would lead to the sharding not being ignored for some reason. Maybe the hooks aren't set up correctly, maybe there's a torch.nn.Parameter._fsdp_flattened being queried with a wrong assumption, etc...
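In other words (a schematic sketch, not the exact litgpt code), the optimizer only ever sees the parameters that still require gradients:

```python
import torch

trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params)  # the frozen, FSDP-managed weights never enter the optimizer
```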

Thanks a lot for giving this thought, though! I'm sorry I never managed to follow up.

@rasbt
Collaborator

rasbt commented Oct 9, 2024

I haven't looked into the implementations (or even the explanations), but just a spontaneous thought: we should perhaps retry this with FSDPv2.
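For reference, the FSDP2 (per-parameter sharding) API would roughly look like this (a sketch; the import path depends on the PyTorch version, and `model.transformer.h` assumes litgpt's GPT module layout):

```python
from torch.distributed._composable.fsdp import fully_shard  # torch.distributed.fsdp.fully_shard in newer releases

for block in model.transformer.h:
    fully_shard(block)
fully_shard(model)
```

Since FSDP2 shards each parameter individually as a DTensor, frozen and trainable parameters no longer share a single flat buffer, which should sidestep the mixed requires_grad issue that triggers the warning above.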

rasbt reopened this Oct 9, 2024
@TensorTemplar
Contributor

If this is indeed what's happening, this would be a substantial oversight and would be really important to raise in PyTorch! However, I'm cautious about accepting the explanation just yet: the optimizer only receives the model's non-sharded trainable parameters (via model.parameters()), so the frozen layers and the effect of sharding should be fully ignored here by the optimizer, at least in theory. There could be a bug in multiple locations though, that would lead to the sharding not being ignored for some reason. Maybe the hooks aren't set up correctly, maybe there's a torch.nn.Parameter._fsdp_flattened being queried with a wrong assumption, etc...

Thanks a lot for giving this thought, though! I'm sorry I never managed to follow up.

Sorry, to clarify: I was referring to the original `Block` wrap, not the linear-layer approach, though given where these layers are, it may still have similar effects.
