Feat: Add support for non-learnable RMS norm for large-scale training in mamba_inner_fn
#543
Hi Albert Gu and Tri Dao,
First of all, thank you for this package. We would like to upstream some changes that were needed to train the FalconMamba-7B model using the mamba kernels.
This PR introduces a way to pass non-learnable RMS norm weights in order to normalize the B, C, and dt states, as per our training procedure.
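For concreteness, here is a minimal sketch of how such non-learnable, all-ones weights can be created once on the module side rather than at every call (the class and attribute names are illustrative, not the actual FalconMamba modeling code):

```python
import torch
import torch.nn as nn


class NonLearnableRMSWeights(nn.Module):
    """Illustrative holder for non-learnable RMS norm weights.

    Registering the all-ones weights as buffers means they are allocated once,
    carry no gradients, and move with the module across devices and dtypes.
    """

    def __init__(self, d_state: int, dt_rank: int):
        super().__init__()
        # A single weight could be shared for the B and C states, with a
        # separate one for dt (a design choice assumed here for illustration).
        self.register_buffer("b_c_rms", torch.ones(d_state), persistent=False)
        self.register_buffer("dt_rms", torch.ones(dt_rank), persistent=False)
```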
An alternative would be to initialize `weight` inside `rms_norm_forward` with `torch.ones_like`, but I'd prefer to require users to pass the non-learnable parameters themselves, to avoid allocating new tensors at every call of `mamba_inner_fn`. There may also be a way to call the RMS norm forward pass without passing RMS weights at all, but I am not sure.

On the transformers side, we would call the interface with the following:
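A minimal sketch of such a call, assuming the new keyword arguments are named `b_rms_weight`, `c_rms_weight`, `dt_rms_weight`, and `b_c_dt_rms_eps`, and with the surrounding mixer layers laid out as in the standard Mamba block (everything here is illustrative, not the exact FalconMamba modeling code):

```python
import math

import torch
import torch.nn as nn

from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn


class MixerSketch(nn.Module):
    """Illustrative Mamba-style mixer; layer names follow the standard block."""

    def __init__(self, d_model=768, d_state=16, d_conv=4, expand=2, rms_eps=1e-6):
        super().__init__()
        d_inner = expand * d_model
        dt_rank = math.ceil(d_model / 16)
        self.in_proj = nn.Linear(d_model, 2 * d_inner, bias=False)
        self.conv1d = nn.Conv1d(
            d_inner, d_inner, d_conv, groups=d_inner, padding=d_conv - 1
        )
        self.x_proj = nn.Linear(d_inner, dt_rank + 2 * d_state, bias=False)
        self.dt_proj = nn.Linear(dt_rank, d_inner, bias=True)
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)
        self.A_log = nn.Parameter(
            torch.log(
                torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_inner, 1)
            )
        )
        self.D = nn.Parameter(torch.ones(d_inner))
        self.rms_eps = rms_eps
        # Non-learnable RMS norm weights, created once as buffers (see the
        # sketch above); no per-call tensor allocation is needed.
        self.register_buffer("b_c_rms", torch.ones(d_state), persistent=False)
        self.register_buffer("dt_rms", torch.ones(dt_rank), persistent=False)

    def forward(self, hidden_states):
        # hidden_states: (batch, seqlen, d_model) -> xz: (batch, 2 * d_inner, seqlen)
        xz = self.in_proj(hidden_states).transpose(1, 2)
        return mamba_inner_fn(
            xz,
            self.conv1d.weight,
            self.conv1d.bias,
            self.x_proj.weight,
            self.dt_proj.weight,
            self.out_proj.weight,
            None,  # out_proj bias
            -torch.exp(self.A_log.float()),
            None,  # B is projected from x inside the fused kernel
            None,  # C is projected from x inside the fused kernel
            self.D.float(),
            delta_bias=self.dt_proj.bias.float(),
            delta_softplus=True,
            # New in this PR (argument names assumed): non-learnable RMS norms
            # applied to the B, C, and dt states.
            b_rms_weight=self.b_c_rms,
            c_rms_weight=self.b_c_rms,
            dt_rms_weight=self.dt_rms,
            b_c_dt_rms_eps=self.rms_eps,
        )
```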
Thank you very much in advance!
@tridao @albertfgu