diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py
index fae2257a..bdd7efdb 100644
--- a/mamba_ssm/models/mixer_seq_simple.py
+++ b/mamba_ssm/models/mixer_seq_simple.py
@@ -1,24 +1,20 @@
 # Copyright (c) 2023, Albert Gu, Tri Dao.
-
-import math
-from functools import partial
-import json
-import os
 import copy
-
+import math
 from collections import namedtuple
+from functools import partial
 
 import torch
 import torch.nn as nn
+from huggingface_hub import PyTorchModelHubMixin
 
 from mamba_ssm.models.config_mamba import MambaConfig
-from mamba_ssm.modules.mamba_simple import Mamba
+from mamba_ssm.modules.block import Block
 from mamba_ssm.modules.mamba2 import Mamba2
+from mamba_ssm.modules.mamba_simple import Mamba
 from mamba_ssm.modules.mha import MHA
 from mamba_ssm.modules.mlp import GatedMLP
-from mamba_ssm.modules.block import Block
 from mamba_ssm.utils.generation import GenerationMixin
-from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
 
 try:
     from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
@@ -212,7 +208,15 @@ def forward(self, input_ids, inference_params=None, **mixer_kwargs):
         return hidden_states
 
 
-class MambaLMHeadModel(nn.Module, GenerationMixin):
+class MambaLMHeadModel(
+    nn.Module,
+    GenerationMixin,
+    PyTorchModelHubMixin,
+    library_name="mamba-ssm",
+    repo_url="https://github.com/state-spaces/mamba",
+    tags=["arXiv:2312.00752", "arXiv:2405.21060"],
+    pipeline_tag="text-generation",
+    ):
 
     def __init__(
         self,
@@ -283,27 +287,3 @@ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_
         CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
         return CausalLMOutput(logits=lm_logits)
 
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
-        config_data = load_config_hf(pretrained_model_name)
-        config = MambaConfig(**config_data)
-        model = cls(config, device=device, dtype=dtype, **kwargs)
-        model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
-        return model
-
-    def save_pretrained(self, save_directory):
-        """
-        Minimal implementation of save_pretrained for MambaLMHeadModel.
-        Save the model and its configuration file to a directory.
-        """
-        # Ensure save_directory exists
-        os.makedirs(save_directory, exist_ok=True)
-
-        # Save the model's state_dict
-        model_path = os.path.join(save_directory, 'pytorch_model.bin')
-        torch.save(self.state_dict(), model_path)
-
-        # Save the configuration of the model
-        config_path = os.path.join(save_directory, 'config.json')
-        with open(config_path, 'w') as f:
-            json.dump(self.config.__dict__, f, indent=4)
diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py
index 1859ab0d..85fd6dec 100644
--- a/mamba_ssm/modules/mamba2.py
+++ b/mamba_ssm/modules/mamba2.py
@@ -31,10 +31,8 @@
 from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
 from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
 
-from huggingface_hub import PyTorchModelHubMixin
-
 
-class Mamba2(nn.Module, PyTorchModelHubMixin):
+class Mamba2(nn.Module):
     def __init__(
         self,
         d_model,
diff --git a/mamba_ssm/utils/hf.py b/mamba_ssm/utils/hf.py
deleted file mode 100644
index 0d7555ac..00000000
--- a/mamba_ssm/utils/hf.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import json
-
-import torch
-
-from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
-from transformers.utils.hub import cached_file
-
-
-def load_config_hf(model_name):
-    resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
-    return json.load(open(resolved_archive_file))
-
-
-def load_state_dict_hf(model_name, device=None, dtype=None):
-    # If not fp32, then we don't want to load directly to the GPU
-    mapped_device = "cpu" if dtype not in [torch.float32, None] else device
-    resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
-    return torch.load(resolved_archive_file, map_location=mapped_device)
-    # Convert dtype before moving to GPU to save memory
-    if dtype is not None:
-        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
-    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
-    return state_dict
diff --git a/setup.py b/setup.py
index dd8d8128..eeae8c10 100755
--- a/setup.py
+++ b/setup.py
@@ -374,6 +374,7 @@ def run(self):
         "einops",
         "triton",
         "transformers",
+        "huggingface_hub>=0.23.5",
         # "causal_conv1d>=1.4.0",
     ],
)