diff --git a/mamba_ssm/models/config_mamba.py b/mamba_ssm/models/config_mamba.py
index ffd31abc..29dd4b11 100644
--- a/mamba_ssm/models/config_mamba.py
+++ b/mamba_ssm/models/config_mamba.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass, field
+from typing import Union
 
 
 @dataclass
@@ -12,3 +13,5 @@ class MambaConfig:
     residual_in_fp32: bool = True
     fused_add_norm: bool = True
     pad_vocab_size_multiple: int = 8
+    bidirectional: bool = False
+    bidirectional_strategy: Union[str, None] = None
diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py
index 5b3ddfcf..aeaf2849 100644
--- a/mamba_ssm/models/mixer_seq_simple.py
+++ b/mamba_ssm/models/mixer_seq_simple.py
@@ -4,6 +4,7 @@
 from functools import partial
 import json
 import os
+from typing import Optional
 
 from collections import namedtuple
 
@@ -29,13 +30,19 @@ def create_block(
     residual_in_fp32=False,
     fused_add_norm=False,
     layer_idx=None,
+    bidirectional=False,
+    bidirectional_strategy=None,
     device=None,
     dtype=None,
 ):
     if ssm_cfg is None:
         ssm_cfg = {}
     factory_kwargs = {"device": device, "dtype": dtype}
-    mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
+    bidirectional_kwargs = {
+        "bidirectional": bidirectional,
+        "bidirectional_strategy": bidirectional_strategy,
+    }
+    mixer_cls = partial(MambaWrapper, layer_idx=layer_idx, **ssm_cfg, **bidirectional_kwargs, **factory_kwargs)
     norm_cls = partial(
         nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
     )
@@ -83,6 +90,53 @@ def _init_weights(
                 p /= math.sqrt(n_residuals_per_layer * n_layer)
 
 
+class MambaWrapper(nn.Module):
+    """Thin wrapper around Mamba to support bi-directionality."""
+    def __init__(
+        self,
+        d_model: int,
+        bidirectional: bool = False,
+        bidirectional_strategy: Optional[str] = None,
+        **mamba_kwargs,
+    ):
+        super().__init__()
+        if bidirectional and bidirectional_strategy is None:
+            bidirectional_strategy = "add"  # Default strategy: `add`
+        if bidirectional and bidirectional_strategy not in ["add", "ew_multiply"]:
+            raise NotImplementedError(f"`{bidirectional_strategy}` strategy for bi-directionality is not implemented!")
+        self.bidirectional = bidirectional
+        self.bidirectional_strategy = bidirectional_strategy
+        self.mamba_fwd = Mamba(
+            d_model=d_model,
+            **mamba_kwargs
+        )
+        if bidirectional:
+            self.mamba_rev = Mamba(
+                d_model=d_model,
+                **mamba_kwargs
+            )
+        else:
+            self.mamba_rev = None
+
+    def forward(self, hidden_states, inference_params=None):
+        """Bidirectional-enabled forward pass
+
+        hidden_states: (B, L, D)
+        Returns: same shape as hidden_states
+        """
+        out = self.mamba_fwd(hidden_states, inference_params=inference_params)
+        if self.bidirectional:
+            out_rev = self.mamba_rev(
+                hidden_states.flip(dims=(1,)),  # Flip along the sequence length dimension
+                inference_params=inference_params
+            ).flip(dims=(1,))  # Flip back for combining with forward hidden states
+            if self.bidirectional_strategy == "add":
+                out = out + out_rev
+            elif self.bidirectional_strategy == "ew_multiply":
+                out = out * out_rev
+        return out
+
+
 class MixerModel(nn.Module):
     def __init__(
         self,
@@ -95,6 +149,8 @@ def __init__(
         initializer_cfg=None,
         fused_add_norm=False,
         residual_in_fp32=False,
+        bidirectional: bool = False,
+        bidirectional_strategy: Optional[str] = None,
         device=None,
         dtype=None,
     ) -> None:
@@ -124,6 +180,8 @@ def __init__(
                     residual_in_fp32=residual_in_fp32,
                     fused_add_norm=fused_add_norm,
                     layer_idx=i,
+                    bidirectional=bidirectional,
+                    bidirectional_strategy=bidirectional_strategy,
                     **factory_kwargs,
                 )
                 for i in range(n_layer)
@@ -191,6 +249,8 @@ def __init__(
         residual_in_fp32 = config.residual_in_fp32
         fused_add_norm = config.fused_add_norm
         pad_vocab_size_multiple = config.pad_vocab_size_multiple
+        bidirectional = config.bidirectional
+        bidirectional_strategy = config.bidirectional_strategy
         factory_kwargs = {"device": device, "dtype": dtype}
 
         super().__init__()
@@ -205,6 +265,8 @@ def __init__(
             initializer_cfg=initializer_cfg,
             fused_add_norm=fused_add_norm,
             residual_in_fp32=residual_in_fp32,
+            bidirectional=bidirectional,
+            bidirectional_strategy=bidirectional_strategy,
             **factory_kwargs,
         )
         self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
@@ -234,8 +296,8 @@ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_
         if num_last_tokens > 0:
             hidden_states = hidden_states[:, -num_last_tokens:]
         lm_logits = self.lm_head(hidden_states)
-        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
-        return CausalLMOutput(logits=lm_logits)
+        LMOutput = namedtuple("LMOutput", ["logits"])
+        return LMOutput(logits=lm_logits)
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
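
For reviewers: a minimal usage sketch, not part of the diff, showing how the two new config fields flow end to end (MambaConfig -> MambaLMHeadModel -> MixerModel -> create_block -> MambaWrapper). The hyperparameters, tensor sizes, and device choice below are arbitrary placeholders, and actually running it assumes a CUDA-capable install of mamba_ssm, since the Mamba kernels and the fused norm path require a GPU.

import torch

from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Arbitrary example sizes; vocab_size is kept divisible by pad_vocab_size_multiple (8).
config = MambaConfig(
    d_model=256,
    n_layer=4,
    vocab_size=1024,
    bidirectional=True,             # new field: adds a reverse-direction Mamba per block
    bidirectional_strategy="add",   # new field: "add" (default when None) or "ew_multiply"
)
model = MambaLMHeadModel(config, device="cuda", dtype=torch.float16)

input_ids = torch.randint(0, 1024, (2, 128), device="cuda")
logits = model(input_ids).logits    # namedtuple is now "LMOutput" (renamed from "CausalLMOutput")
print(logits.shape)                 # torch.Size([2, 128, 1024])

Note that with bidirectional=True each block also sees the flipped sequence, so the model is no longer causal and is intended for non-autoregressive objectives, which presumably motivates the CausalLMOutput -> LMOutput rename above.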