diff --git a/mamba_ssm/__init__.py b/mamba_ssm/__init__.py
index f536fc98..b33e2dd8 100644
--- a/mamba_ssm/__init__.py
+++ b/mamba_ssm/__init__.py
@@ -3,3 +3,4 @@
 from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
 from mamba_ssm.modules.mamba_simple import Mamba
 from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
+from mamba_ssm.models.config_mamba import MambaConfig
diff --git a/mamba_ssm/models/config_mamba.py b/mamba_ssm/models/config_mamba.py
index ffd31abc..57ada96d 100644
--- a/mamba_ssm/models/config_mamba.py
+++ b/mamba_ssm/models/config_mamba.py
@@ -1,4 +1,5 @@
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, asdict
+import json
 
 
 @dataclass
@@ -12,3 +13,9 @@ class MambaConfig:
     residual_in_fp32: bool = True
     fused_add_norm: bool = True
     pad_vocab_size_multiple: int = 8
+
+    def to_json_string(self):
+        return json.dumps(asdict(self))
+
+    def to_dict(self):
+        return asdict(self)
diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py
index 5b3ddfcf..132537ca 100644
--- a/mamba_ssm/models/mixer_seq_simple.py
+++ b/mamba_ssm/models/mixer_seq_simple.py
@@ -4,11 +4,13 @@
 from functools import partial
 import json
 import os
+from typing import Optional
 
 from collections import namedtuple
 
 import torch
 import torch.nn as nn
+from torch.nn import CrossEntropyLoss
 
 from mamba_ssm.models.config_mamba import MambaConfig
 from mamba_ssm.modules.mamba_simple import Mamba, Block
@@ -225,7 +227,12 @@ def tie_weights(self):
     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
         return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
 
-    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
+    def forward(self,
+                input_ids: torch.LongTensor = None,
+                labels: Optional[torch.LongTensor] = None,
+                position_ids=None,
+                inference_params=None,
+                num_last_tokens=0):
         """
         "position_ids" is just to be compatible with Transformer generation. We don't use it.
         num_last_tokens: if > 0, only return the logits for the last n tokens
@@ -234,13 +241,29 @@ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_
         if num_last_tokens > 0:
             hidden_states = hidden_states[:, -num_last_tokens:]
         lm_logits = self.lm_head(hidden_states)
-        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
-        return CausalLMOutput(logits=lm_logits)
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens; use the logits' own width, since the lm_head may be
+            # padded beyond config.vocab_size (pad_vocab_size_multiple)
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+            return {"loss": loss, "logits": lm_logits, "hidden_states": hidden_states}
+        else:
+            CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "hidden_states"])
+            return CausalLMOutput(logits=lm_logits, hidden_states=hidden_states)
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
         config_data = load_config_hf(pretrained_model_name)
-        config = MambaConfig(**config_data)
+        config = MambaConfig(**{k: v for k, v in config_data.items() if k not in ('_name_or_path', 'architectures')})
         model = cls(config, device=device, dtype=dtype, **kwargs)
         model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
         return model
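For reference, a minimal usage sketch of the patched `forward` (not part of the diff). It assumes the changes above are installed, uses a small randomly initialized `MambaConfig` instead of a released checkpoint, and needs a CUDA device since the Mamba kernels are CUDA-only; the config values below are illustrative only.

```python
import torch
from mamba_ssm import MambaConfig, MambaLMHeadModel

# Tiny, randomly initialized model purely for illustration; real fine-tuning
# would load a pretrained checkpoint via MambaLMHeadModel.from_pretrained(...).
config = MambaConfig(d_model=256, n_layer=4, vocab_size=1024)
model = MambaLMHeadModel(config, device="cuda", dtype=torch.float32)

input_ids = torch.randint(0, config.vocab_size, (2, 128), device="cuda")

# Training-style call: passing labels triggers the shifted cross-entropy loss
# and the dict return introduced by this patch.
out = model(input_ids=input_ids, labels=input_ids)
print(out["loss"], out["logits"].shape)

# Inference-style call: without labels the namedtuple return is kept,
# now carrying hidden_states alongside logits.
out = model(input_ids=input_ids)
print(out.logits.shape, out.hidden_states.shape)
```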
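The `to_dict()` / `to_json_string()` additions mirror methods on `transformers.PretrainedConfig`, which logging integrations (e.g. the HF Trainer callbacks) call on `model.config`; they also make it easy to persist a fine-tuned model in the same `config.json` + `pytorch_model.bin` layout used by the pretrained checkpoints. A hypothetical helper sketching that use (the function name and file layout are assumptions, not part of the patch):

```python
import json
import os

import torch
from mamba_ssm import MambaLMHeadModel

def save_pretrained(model: MambaLMHeadModel, save_directory: str) -> None:
    """Hypothetical helper: write weights plus the serialized MambaConfig."""
    os.makedirs(save_directory, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
    with open(os.path.join(save_directory, "config.json"), "w") as f:
        f.write(model.config.to_json_string())
    # to_dict() is convenient when the config should be merged into a larger
    # training log or passed to an experiment tracker.
    print(json.dumps({"model_config": model.config.to_dict()}, indent=2))
```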