Add support for left padding and masking in forward() and generate() #70

Open · wants to merge 1 commit into base: main
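
For orientation, a rough usage sketch of what this PR enables, distilled from the test script at the bottom of the diff (the checkpoint name, prompts, and generation settings are placeholders, not part of the PR):

    from transformers import AutoTokenizer
    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

    model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-1.4b").to("cuda")
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token

    # Left-padded batch; attention_mask is 0 at the pad positions.
    batch = tokenizer(["Hello world.", "A much longer second prompt than the first one."],
                      return_tensors="pt", padding=True).to("cuda")

    # forward() accepts the mask ...
    logits = model(batch.input_ids, attention_mask=batch.attention_mask).logits
    # ... and so does generate().
    out = model.generate(batch.input_ids, attention_mask=batch.attention_mask,
                         max_length=64, temperature=0, cg=False)
    print(tokenizer.batch_decode(out, skip_special_tokens=True))
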
1 change: 1 addition & 0 deletions mamba_ssm/models/config_mamba.py
@@ -12,3 +12,4 @@ class MambaConfig:
residual_in_fp32: bool = True
fused_add_norm: bool = True
pad_vocab_size_multiple: int = 8
use_fast_path: bool = True
18 changes: 11 additions & 7 deletions mamba_ssm/models/mixer_seq_simple.py
@@ -31,11 +31,12 @@ def create_block(
layer_idx=None,
device=None,
dtype=None,
use_fast_path=True,
):
if ssm_cfg is None:
ssm_cfg = {}
factory_kwargs = {"device": device, "dtype": dtype}
mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
mixer_cls = partial(Mamba, layer_idx=layer_idx, use_fast_path=use_fast_path, **ssm_cfg, **factory_kwargs)
norm_cls = partial(
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
)
@@ -97,6 +98,7 @@ def __init__(
residual_in_fp32=False,
device=None,
dtype=None,
use_fast_path=True,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
@@ -124,6 +126,7 @@ def __init__(
residual_in_fp32=residual_in_fp32,
fused_add_norm=fused_add_norm,
layer_idx=i,
use_fast_path=use_fast_path,
**factory_kwargs,
)
for i in range(n_layer)
@@ -148,12 +151,12 @@ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs)
for i, layer in enumerate(self.layers)
}

def forward(self, input_ids, inference_params=None):
def forward(self, input_ids, mask=None, inference_params=None):
hidden_states = self.embedding(input_ids)
residual = None
for layer in self.layers:
hidden_states, residual = layer(
hidden_states, residual, inference_params=inference_params
hidden_states, residual, mask=mask, inference_params=inference_params
)
if not self.fused_add_norm:
residual = (hidden_states + residual) if residual is not None else hidden_states
@@ -205,6 +208,7 @@ def __init__(
initializer_cfg=initializer_cfg,
fused_add_norm=fused_add_norm,
residual_in_fp32=residual_in_fp32,
use_fast_path=config.use_fast_path,
**factory_kwargs,
)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
@@ -225,12 +229,12 @@ def tie_weights(self):
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
def forward(self, input_ids, attention_mask=None, position_ids=None, inference_params=None, num_last_tokens=0):
"""
"position_ids" is just to be compatible with Transformer generation. We don't use it.
num_last_tokens: if > 0, only return the logits for the last n tokens
"""
hidden_states = self.backbone(input_ids, inference_params=inference_params)
hidden_states = self.backbone(input_ids, mask=attention_mask, inference_params=inference_params)
if num_last_tokens > 0:
hidden_states = hidden_states[:, -num_last_tokens:]
lm_logits = self.lm_head(hidden_states)
@@ -240,8 +244,8 @@ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_
@classmethod
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
config_data = load_config_hf(pretrained_model_name)
config = MambaConfig(**config_data)
model = cls(config, device=device, dtype=dtype, **kwargs)
config = MambaConfig(**config_data, **kwargs)
model = cls(config, device=device, dtype=dtype)
Comment on lines +247 to +248


Why did you make this change with the kwargs?
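
For what it's worth, my reading of the change (a sketch of the effect, not the author's stated intent): extra keyword arguments to from_pretrained are now folded into MambaConfig rather than forwarded to the model constructor, so callers can override config fields such as the new use_fast_path:

    # Checkpoint name is a placeholder.
    model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-1.4b", use_fast_path=False)
    # With this PR, that is roughly equivalent to:
    #   config = MambaConfig(**load_config_hf("state-spaces/mamba-1.4b"), use_fast_path=False)
    #   model = MambaLMHeadModel(config)
    # whereas previously use_fast_path would have been passed to MambaLMHeadModel.__init__,
    # which (as far as I can tell) does not accept it.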

model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
return model

14 changes: 11 additions & 3 deletions mamba_ssm/modules/mamba_simple.py
@@ -116,7 +116,7 @@ def __init__(

self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

def forward(self, hidden_states, inference_params=None):
def forward(self, hidden_states, mask=None, inference_params=None):
"""
hidden_states: (B, L, D)
Returns: same shape as hidden_states
@@ -156,10 +156,15 @@ def forward(self, hidden_states, inference_params=None):
None, # input-dependent C
self.D.float(),
delta_bias=self.dt_proj.bias.float(),
mask=mask,
delta_softplus=True,
)
else:
x, z = xz.chunk(2, dim=1)

if mask is not None:
x = x * mask.unsqueeze(1)

# Compute short convolution
if conv_state is not None:
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
@@ -176,6 +181,9 @@ def forward(self, hidden_states, inference_params=None):
activation=self.activation,
)

if mask is not None:
x = x * mask.unsqueeze(1)

# We're careful here about the layout, to avoid extra transposes.
# We want dt to have d as the slowest moving dimension
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
@@ -322,7 +330,7 @@ def __init__(
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

def forward(
self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
self, hidden_states: Tensor, residual: Optional[Tensor] = None, mask: Optional[Tensor] = None, inference_params=None
):
r"""Pass the input through the encoder layer.

@@ -346,7 +354,7 @@ def forward(
residual_in_fp32=self.residual_in_fp32,
eps=self.norm.eps,
)
hidden_states = self.mixer(hidden_states, inference_params=inference_params)
hidden_states = self.mixer(hidden_states, mask=mask, inference_params=inference_params)
return hidden_states, residual

def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
10 changes: 7 additions & 3 deletions mamba_ssm/ops/selective_scan_interface.py
@@ -159,7 +159,7 @@ class MambaInnerFn(torch.autograd.Function):
def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
out_proj_weight, out_proj_bias,
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
C_proj_bias=None, mask=None, delta_softplus=True, checkpoint_lvl=1):
"""
xz: (batch, dim, seqlen)
"""
@@ -177,6 +177,8 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
xz = xz.contiguous()
conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
x, z = xz.chunk(2, dim=1)
if mask is not None:
x = x * mask.unsqueeze(1)
conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
# We're being very careful here about the layout, to avoid extra transposes.
@@ -214,6 +216,8 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
C = C.contiguous()
if D is not None:
D = D.contiguous()
if mask is not None:
conv1d_out = conv1d_out * mask.unsqueeze(1)
out, scan_intermediates, out_z = selective_scan_cuda.fwd(
conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
)
@@ -301,11 +305,11 @@ def mamba_inner_fn(
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
out_proj_weight, out_proj_bias,
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
C_proj_bias=None, delta_softplus=True
C_proj_bias=None, mask=None, delta_softplus=True
):
return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
@zigzagcai (Contributor), Feb 28, 2024:


I have some confusion about this line of code. Since MambaInnerFn doesn't provide a parameter option for mask, it seems that it has no effect on the fwd and bwd pass.
Hence, how can the mask be applied to mark the sequence boundaries?


I might be wrong, but the PR modified line 162 in this very file, which is the definition of the MambaInnerFn.forward method. Also, the multiplications on lines 181 and 220 use mask.
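
To make the mechanics concrete, a tiny self-contained sketch of the masking pattern used in this PR (shapes and names are illustrative, not the library code): the (B, L) mask is unsqueezed to (B, 1, L) so it broadcasts over the channel dimension and zeroes the activations at padded time steps before the convolution and the scan.

    import torch

    B, D, L = 2, 4, 6                  # batch, channels, sequence length
    x = torch.randn(B, D, L)           # activations laid out (B, D, L), as in the kernel
    mask = torch.ones(B, L)
    mask[0, :2] = 0                    # first sequence is left-padded by two tokens

    x_masked = x * mask.unsqueeze(1)   # same pattern as `x = x * mask.unsqueeze(1)` above
    assert torch.all(x_masked[0, :, :2] == 0)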

@zigzagcai (Contributor), Feb 29, 2024:


Thank you! You are right. I missed that line of code on L162.
But I see that in this PR, attention_mask is only used in the forward pass and does not seem to be used in the backward pass. So when I tried to feed batched data with left padding and masking (batch_size, seq_len, hidden_dim) into the Mamba block, it reported an error. Has anyone encountered a similar error?

Err Msg:

  File "/blahblah/miniconda3/envs/dev/lib/python3.10/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/blahblah/miniconda3/envs/dev/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: function MambaInnerFnBackward returned an incorrect number of gradients (expected 16, got 15)


@xtwigs pointed out that one can add None to the returned gradients to fix the issue. (For those above wondering where to put the None, I did it at the end of the tuple.) I agree with @xtwigs, as we don't compute a gradient for the mask tensor.

I tried using this branch but got an error about not getting the expected number of gradients during backward (15 vs 16).

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because PyTorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue.

On the other hand, to bypass the recomputation of conv1d in the backward pass, we can set checkpoint_lvl to 0 (?).

Can someone verify if my thought process is accurate?
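
For anyone hitting the gradient-count error, a minimal generic example of the rule in question (this is not MambaInnerFn, just an illustration): torch.autograd.Function.backward must return exactly one value per argument of forward, so a non-differentiable input such as mask needs an explicit None in its slot.

    import torch

    class MaskedScale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, mask):
            ctx.save_for_backward(mask)
            return x * mask

        @staticmethod
        def backward(ctx, grad_out):
            (mask,) = ctx.saved_tensors
            # Two forward inputs -> two return values: grad wrt x, and None for mask.
            return grad_out * mask, None

    x = torch.randn(3, requires_grad=True)
    mask = torch.tensor([1.0, 0.0, 1.0])
    MaskedScale.apply(x, mask).sum().backward()
    # Returning a single value from backward would instead raise:
    # "function MaskedScaleBackward returned an incorrect number of gradients (expected 2, got 1)"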

@zigzagcai (Contributor), Feb 29, 2024 (quoting the exchange above):

I have tried this approach but encountered CUDA OOM, even with many more GPUs and a much smaller seq_len (8 nodes, 64x A100 GPUs, and seq_len=512 for a 1.4B Mamba model).
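
Regarding the OOM, one plausible factor is the checkpoint_lvl trade-off mentioned above. A schematic sketch (a toy stand-in, not the real kernel) of what storing vs. recomputing an intermediate means for memory:

    import torch

    class SquareThenScale(torch.autograd.Function):
        # Toy example: the intermediate `sq` is either stored (lvl 0) or recomputed (lvl 1).
        @staticmethod
        def forward(ctx, x, scale, checkpoint_lvl=1):
            sq = x * x
            ctx.checkpoint_lvl = checkpoint_lvl
            if checkpoint_lvl == 0:
                ctx.save_for_backward(x, scale, sq)   # keep the intermediate: faster bwd, more memory
            else:
                ctx.save_for_backward(x, scale)       # recompute in bwd: slower, less memory
            return sq * scale

        @staticmethod
        def backward(ctx, grad_out):
            if ctx.checkpoint_lvl == 0:
                x, scale, sq = ctx.saved_tensors
            else:
                x, scale = ctx.saved_tensors
                sq = x * x                            # recompute the intermediate
            grad_x = grad_out * scale * 2 * x
            grad_scale = (grad_out * sq).sum()
            return grad_x, grad_scale, None           # None for the int checkpoint_lvl argument

    x = torch.randn(8, requires_grad=True)
    scale = torch.tensor(2.0, requires_grad=True)
    SquareThenScale.apply(x, scale, 0).sum().backward()

If checkpoint_lvl=0 similarly makes MambaInnerFn keep extra intermediates alive, higher activation memory (and thus the OOM) would be expected.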

out_proj_weight, out_proj_bias,
A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, mask, delta_softplus)


def mamba_inner_ref(
16 changes: 11 additions & 5 deletions mamba_ssm/utils/generation.py
@@ -108,6 +108,7 @@ def decode(
input_ids,
model,
max_length,
attention_mask=None,
top_k=1,
top_p=0.0,
temperature=1.0,
@@ -171,10 +172,11 @@ def get_logits(input_ids, inference_params):
position_ids=position_ids,
inference_params=inference_params,
num_last_tokens=1,
attention_mask=attention_mask, # mask is not used in incremental step() calls, so don't update
).logits.squeeze(dim=1)
else:
logits = model._decoding_cache.run(
input_ids, position_ids, inference_params.seqlen_offset
input_ids, attention_mask, position_ids, inference_params.seqlen_offset
).squeeze(dim=1)
return logits[..., :vocab_size] if vocab_size is not None else logits

@@ -234,6 +236,7 @@ def generate(
self,
input_ids,
max_length,
attention_mask=None,
top_k=1,
top_p=0.0,
temperature=1.0,
@@ -242,7 +245,7 @@
**kwargs,
):
output = decode(
input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs
input_ids, self, max_length, attention_mask=attention_mask, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs
)
if not output_scores:
output.scores = None
@@ -312,9 +315,9 @@ def update_graph_cache(
n_warmups=n_warmups,
)

def dispatch(input_ids, position_ids, seqlen):
def dispatch(input_ids, attention_mask, position_ids, seqlen):
batch_size, decoding_seqlen = input_ids.shape[:2]
return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)
return cache.callables[batch_size, decoding_seqlen](input_ids, attention_mask, position_ids, seqlen)

cache.run = dispatch
cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing
@@ -326,6 +329,7 @@ def capture_graph(
):
device = next(iter(model.parameters())).device
input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
attention_mask = torch.full((batch_size, decoding_seqlen), 1, dtype=torch.long, device=device)
position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
seqlen_offset_og = inference_params.seqlen_offset
inference_params.seqlen_offset = max_seqlen - decoding_seqlen
@@ -338,6 +342,7 @@
for _ in range(n_warmups):
logits = model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
inference_params=inference_params,
num_last_tokens=decoding_seqlen,
@@ -355,12 +360,13 @@
with torch.cuda.graph(graph, pool=mempool):
logits = model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
inference_params=inference_params,
num_last_tokens=decoding_seqlen,
).logits

def run(new_input_ids, new_position_ids, seqlen):
def run(new_input_ids, attention_mask, new_position_ids, seqlen):
inference_params.lengths_per_sample[:] = seqlen
input_ids.copy_(new_input_ids)
position_ids.copy_(new_position_ids)
62 changes: 62 additions & 0 deletions tests/test_padding.py
@@ -0,0 +1,62 @@
import torch
from transformers import AutoTokenizer

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel


model = MambaLMHeadModel.from_pretrained('/data/norman_mu/models/mamba-1.4b', use_fast_path=True).to('cuda')
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
tokenizer.padding_side = 'left'
tokenizer.pad_token = tokenizer.eos_token

pad_count = 10

# Check prefill logits
input_ids = torch.randint(1, 1000, (1, 1024)).to('cuda')
input_ids_padded = torch.cat([torch.zeros_like(input_ids[:, [0] * pad_count]), input_ids], dim=1)
attention_mask = torch.cat([torch.zeros_like(input_ids[:, [0] * pad_count]), torch.ones_like(input_ids)], dim=1)

out = model(input_ids_padded).logits.detach().cpu()
out_padded = model(input_ids_padded, attention_mask).logits.detach().cpu()
out_true = model(input_ids).logits.detach().cpu()

print("max L2 error:", (out_true - out[:, pad_count:]).norm(dim=-1).max())
print("max L2 errors (padded):", (out_true - out_padded[:, pad_count:]).norm(dim=-1).max())


# Check decoding outputs
text = 'Lorem ipsum dolor sit amet, consectetur adipiscing elit.'

print("\n\nNo CUDA graph:")
inputs = tokenizer([text], return_tensors='pt').to('cuda')
x = model.generate(inputs.input_ids, max_length=100, temperature=0, cg=False)
print("\nNo pad, no mask:")
print(tokenizer.decode(x[0], skip_special_tokens=True))

inputs = tokenizer(['<|endoftext|>' * pad_count + text], return_tensors='pt').to('cuda')
x = model.generate(inputs.input_ids, max_length=100 + pad_count, temperature=0, cg=False)
print("\nPad, no mask:")
print(tokenizer.decode(x[0], skip_special_tokens=True))

inputs = tokenizer(['<|endoftext|>' * pad_count + text], return_tensors='pt').to('cuda')
inputs.attention_mask[:, :pad_count] = 0
x = model.generate(inputs.input_ids, attention_mask=inputs.attention_mask, max_length=100 + pad_count, temperature=0, cg=False)
print("\nPad, mask:")
print(tokenizer.decode(x[0], skip_special_tokens=True))

print("\n\nCUDA graph:")
inputs = tokenizer([text], return_tensors='pt').to('cuda')
x = model.generate(inputs.input_ids, max_length=100, temperature=0, cg=True)
print("\nNo pad, no mask:")
print(tokenizer.decode(x[0], skip_special_tokens=True))

inputs = tokenizer(['<|endoftext|>' * pad_count + text], return_tensors='pt').to('cuda')
x = model.generate(inputs.input_ids, max_length=100 + pad_count, temperature=0, cg=True)
print("\nPad, no mask:")
print(tokenizer.decode(x[0], skip_special_tokens=True))

inputs = tokenizer(['<|endoftext|>' * pad_count + text], return_tensors='pt').to('cuda')
inputs.attention_mask[:, :pad_count] = 0
x = model.generate(inputs.input_ids, attention_mask=inputs.attention_mask, max_length=100 + pad_count, temperature=0, cg=True)
print("\nPad, mask:")
print(tokenizer.decode(x[0], skip_special_tokens=True))