Implement bi-directionality #52

Open
wants to merge 6 commits into main
Conversation

@yair-schiff (Contributor) commented Dec 13, 2023

Edit:

  • Implement bi-directionality by applying the Mamba module twice: (1) to the forward sequence and (2) to the backward (reversed) sequence.
  • Implement 3 strategies (sketched below) for combining the forward / backward Mamba hidden states:
    1. add: Add the states.
    2. concat: Concatenate the states. This doubles the hidden dimension, d_model, which also prevents weight tying between the embedding and lm_head weights.
    3. ew_multiply: Perform element-wise multiplication between the states.
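A minimal sketch of the idea; the BiMambaBlock class name and its constructor arguments are illustrative assumptions, not the PR's actual API:

import torch
import torch.nn as nn
from mamba_ssm import Mamba

class BiMambaBlock(nn.Module):
    """Run Mamba over the forward and the reversed sequence, then combine."""
    def __init__(self, d_model, combine="add"):
        super().__init__()
        assert combine in ("add", "concat", "ew_multiply")
        self.combine = combine
        self.fwd = Mamba(d_model=d_model)
        self.bwd = Mamba(d_model=d_model)

    def forward(self, x):  # x: (batch, length, d_model)
        h_fwd = self.fwd(x)
        # Second pass on the time-reversed sequence, flipped back into the original order.
        h_bwd = self.bwd(x.flip(1)).flip(1)
        if self.combine == "add":
            return h_fwd + h_bwd
        if self.combine == "ew_multiply":
            return h_fwd * h_bwd
        # concat doubles the hidden dimension to 2 * d_model,
        # which is what prevents tying the embedding and lm_head weights.
        return torch.cat([h_fwd, h_bwd], dim=-1)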


@Skylion007 left a comment


Left some nits

mamba_ssm/models/mixer_seq_simple.py (review comment resolved)
@sentialx

What if the sequences have padding? E.g. the input is
[1 2 3 0 0 0]
so the flipped input would be
[0 0 0 3 2 1].
Shouldn't it be
[3 2 1 0 0 0]?

@yair-schiff (Contributor, Author)

@sentialx, agreed. That's a good catch.

@jimmieliu

How does the speed compare to uni-directional?

@yair-schiff (Contributor, Author) commented Jan 3, 2024

How does the speed compare to uni-directional?

@jimmieliu, it's about 2x the uni-directional runtime, since the Mamba module is applied twice.

@albertfgu mentioned this pull request Jan 11, 2024
@pengzhangzhi

@yair-schiff I am just curious, did you solve the issue below?

What if the sequences have padding? E.g. the input is [1 2 3 0 0 0], so the flipped input would be [0 0 0 3 2 1]. Shouldn't it be [3 2 1 0 0 0]?

@pengzhangzhi

I came up with a solution to the padding issue. Say we have a tensor [1,2,3,0,0], where 0 is the padding token. We flip it to get [0,0,3,2,1], pass it through the network, and flip the output back. Because of the double flip, the backward output ends up aligned with the original token order.

given: x
out = x + f(x.flip()).flip()
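A runnable sketch of this double-flip idea, using a plain Mamba layer as f; the shapes and variable names here are my own, for illustration only:

import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 8, 16
f = Mamba(d_model=dim).to("cuda")

x = torch.randn(batch, length, dim).to("cuda")  # stand-in for a right-padded batch

# Flip the sequence dimension, run the backward pass, and flip the output back
# so it lines up with the original token order before the residual add.
out = x + f(x.flip(1)).flip(1)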

@xuanwuji

I came up with a solution to the padding issue. Say we have a tensor [1,2,3,0,0], where 0 is the padding token. We flip it to get [0,0,3,2,1], pass it through the network, and flip the output back. Because of the double flip, the backward output ends up aligned with the original token order.

given: x
out = x + f(x.flip()).flip()

Hi, your approach is clever! But I have a question: if you flip the input to [0,0,3,2,1], does the padding in front affect how the sequence's hidden features are learned? I.e., does it produce a different result (a worse representation of the sequence) than an input of [3,2,1,0,0]?
I don't know enough about this; could you possibly give me some guidance? It would help me a lot. Thank you very much!

@Museum7432 commented Jul 14, 2024

@xuanwuji Well, you can remove the leading paddings by shifting each row of x before flipping x. As for their effect: since the hidden state is initialized to 0, it should still be filled with 0 after scanning through the paddings, so those paddings shouldn't have any effect on the result. However, you can use the following function just to be sure.

import torch

def flip_padded_hidden_states(hidden_states, seq_lens):
    # Flip right-padded sequences so that each row's valid tokens come out
    # reversed at the front; positions past seq_lens[b] end up holding values
    # from other rows and should be ignored/masked.
    # seq_lens: (batch,) tensor of valid lengths on the same device as hidden_states.
    batch_size, seq_len, hidden_dim = hidden_states.shape

    # Flat index of every (batch, position) slot.
    indices = torch.arange(batch_size * seq_len, device=hidden_states.device).reshape(
        batch_size, seq_len
    )

    # Number of trailing padding positions in each row.
    indices_offset = seq_len - seq_lens

    # Shift each row left by its padding count; the modulo keeps indices in range.
    indices = (indices - indices_offset.unsqueeze(1)) % (seq_len * batch_size)

    # Reverse along the sequence dimension.
    indices = indices.flip(1)

    return hidden_states.reshape(batch_size * seq_len, hidden_dim)[indices]
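For example (the seq_lens values here are made up for illustration, running on CPU):

import torch

batch_size, seq_len, hidden_dim = 2, 5, 3
hidden_states = torch.arange(
    batch_size * seq_len * hidden_dim, dtype=torch.float32
).reshape(batch_size, seq_len, hidden_dim)
seq_lens = torch.tensor([3, 5])  # row 0 has 2 trailing padding positions

flipped = flip_padded_hidden_states(hidden_states, seq_lens)
# Row 0: the 3 valid timesteps come out reversed at the front; the last 2
# positions are padding slots whose contents should be ignored or masked.
# Row 1 (no padding): a plain flip along the sequence dimension.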

To check the effect of padding:

import torch
from mamba_ssm import Mamba
from torch.nn import functional as F

batch, length, dim = 2, 64, 16

model = Mamba(
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")

x = torch.randn(batch, length, dim).to("cuda")
# Prepend 4 zero (padding) frames along the sequence dimension.
padded_x = F.pad(x, (0, 0, 4, 0))

y = model(x)
padded_y = model(padded_x)

# Drop the outputs at the 4 padding positions before comparing.
unpadded_y = padded_y[:, 4:]

print(f'Output max diff: {(unpadded_y - y).abs().max().item()}')
print(f'Output mean diff: {(unpadded_y - y).abs().mean().item()}')

However, these errors do accumulate across multiple layers, so you should use the flip_padded_hidden_states function to be certain.
