
Feat: Add support for non-learnable RMS norm for large-scale training in mamba_inner_fn #543

Open · wants to merge 7 commits into main

Conversation

younesbelkada

Hi Albert Gu and Tri Dao,

First of all, thank you for this package. We would like to upstream some changes that were needed to train the FalconMamba-7B model using the mamba kernels.

This PR introduces a way to pass non-learnable RMS norm weights in order to normalize the B, C and dt states, as per our training procedure.

An alternative would be to initialize the weight inside rms_norm_forward with torch.ones_like, but I would prefer to require users to pass the non-learnable parameters themselves, to avoid allocating new tensors on every call to mamba_inner_fn. There might also be a way to call the RMS norm forward without passing RMS weights at all, but I am not sure.
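For reference, here is a rough PyTorch sketch of what the non-learnable normalization amounts to: a standard RMS norm whose weight is fixed to ones, applied to dt, B and C before the selective scan. The shapes, the eps value and the rms_norm_ref helper below are illustrative assumptions, not the actual kernel internals.

    import torch

    def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
        # Normalize over the last dimension; with weight == ones this is a pure,
        # parameter-free RMS normalization.
        variance = x.float().pow(2).mean(dim=-1, keepdim=True)
        return (x.float() * torch.rsqrt(variance + eps) * weight.float()).to(x.dtype)

    # Illustrative shapes only
    batch, seqlen, d_inner, d_state = 2, 16, 64, 16
    dt = torch.randn(batch, seqlen, d_inner)
    B = torch.randn(batch, seqlen, d_state)
    C = torch.randn(batch, seqlen, d_state)

    b_c_rms = torch.ones(d_state)   # non-learnable, mirrors self.b_c_rms below
    dt_rms = torch.ones(d_inner)    # non-learnable, mirrors self.dt_rms below
    eps = 1e-6                      # stands in for config.mixer_rms_eps

    dt = rms_norm_ref(dt, dt_rms, eps)
    B = rms_norm_ref(B, b_c_rms, eps)
    C = rms_norm_ref(C, b_c_rms, eps)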

On the transformers side, we would call the interface as follows:

        # The Triton kernel expects RMS weights to be passed even when they are non-learnable, so we create them here as non-trainable buffers
        self.register_buffer("b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False)
        self.register_buffer("dt_rms", torch.nn.Parameter(torch.ones(self.intermediate_size), requires_grad=False), persistent=False)
        self.rms_eps = config.mixer_rms_eps

    def cuda_kernels_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Optional[MambaCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states).transpose(1, 2)
        
        if self.training and cache_params is None:  # Doesn't support outputting the states -> used for training
            contextualized_states = mamba_inner_fn(
                projected_states,
                conv1d_weight=self.conv1d.weight,
                conv1d_bias=self.conv1d.bias if self.use_conv_bias else None,
                x_proj_weight=self.x_proj.weight,
                delta_proj_weight=self.dt_proj.weight,
                out_proj_weight=self.out_proj.weight,
                out_proj_bias=self.out_proj.bias.float() if self.use_bias else None,
                A=-torch.exp(self.A_log.float()),
                B=None,  # input-dependent B
                C=None,  # input-dependent C
                D=self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                b_rms_weight=self.b_c_rms,
                c_rms_weight=self.b_c_rms,
                dt_rms_weight=self.dt_rms,
                b_c_dt_rms_eps=self.rms_eps
            )

Thank you very much in advance!
@tridao @albertfgu

younesbelkada (Author)

A simple snippet to reproduce the current issue:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-mamba-7b"
text = "Hello today we are going to"

model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
tok = AutoTokenizer.from_pretrained(model_id)

inputs = tok(text, return_tensors="pt").to(0)

with torch.no_grad():
    logits = torch.argmax(model(**inputs).logits, dim=-1)
    
print(tok.batch_decode(logits))

model.train()
lm_logits = model(**inputs).logits
next_token = torch.argmax(lm_logits, dim=-1)
    
print(tok.batch_decode(next_token))
loss = (1 - lm_logits).mean()  # dummy loss just to exercise the backward pass
loss.backward()

younesbelkada (Author) commented Aug 29, 2024

Hi @tridao @albertfgu
I made an alternative PR in HF transformers, huggingface/transformers#33195, where I simply copied the kernels over. Let me know if you see any issue with potentially merging this PR into mamba-ssm. Thanks!
