Better HF integration for MambaLMHeadModel #471

Open · wants to merge 8 commits into base: main
48 changes: 14 additions & 34 deletions mamba_ssm/models/mixer_seq_simple.py
@@ -1,24 +1,20 @@
 # Copyright (c) 2023, Albert Gu, Tri Dao.
 
-import math
-from functools import partial
-import json
-import os
-import copy
-
+import math
 from collections import namedtuple
+from functools import partial
 
 import torch
 import torch.nn as nn
+from huggingface_hub import PyTorchModelHubMixin
 
 from mamba_ssm.models.config_mamba import MambaConfig
-from mamba_ssm.modules.mamba_simple import Mamba
+from mamba_ssm.modules.block import Block
 from mamba_ssm.modules.mamba2 import Mamba2
+from mamba_ssm.modules.mamba_simple import Mamba
 from mamba_ssm.modules.mha import MHA
 from mamba_ssm.modules.mlp import GatedMLP
-from mamba_ssm.modules.block import Block
 from mamba_ssm.utils.generation import GenerationMixin
-from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
 
 try:
     from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
@@ -212,7 +208,15 @@ def forward(self, input_ids, inference_params=None, **mixer_kwargs):
         return hidden_states
 
 
-class MambaLMHeadModel(nn.Module, GenerationMixin):
+class MambaLMHeadModel(
+    nn.Module,
+    GenerationMixin,
+    PyTorchModelHubMixin,
+    library_name="mamba-ssm",
+    repo_url="https://github.com/state-spaces/mamba",
+    tags=["arXiv:2312.00752", "arXiv:2405.21060"],
+    pipeline_tag="text-generation",
+):
 
     def __init__(
         self,
@@ -283,27 +287,3 @@ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_
         CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
         return CausalLMOutput(logits=lm_logits)
 
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
-        config_data = load_config_hf(pretrained_model_name)
-        config = MambaConfig(**config_data)
-        model = cls(config, device=device, dtype=dtype, **kwargs)
-        model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
-        return model
-
-    def save_pretrained(self, save_directory):
-        """
-        Minimal implementation of save_pretrained for MambaLMHeadModel.
-        Save the model and its configuration file to a directory.
-        """
-        # Ensure save_directory exists
-        os.makedirs(save_directory, exist_ok=True)
-
-        # Save the model's state_dict
-        model_path = os.path.join(save_directory, 'pytorch_model.bin')
-        torch.save(self.state_dict(), model_path)
-
-        # Save the configuration of the model
-        config_path = os.path.join(save_directory, 'config.json')
-        with open(config_path, 'w') as f:
-            json.dump(self.config.__dict__, f, indent=4)
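
With PyTorchModelHubMixin among the class bases, MambaLMHeadModel inherits from_pretrained, save_pretrained, and push_to_hub from huggingface_hub, which is what makes the hand-written versions removed above redundant. A minimal sketch of the resulting workflow, assuming the mixin can round-trip the MambaConfig dataclass (one reason for the huggingface_hub>=0.23.5 pin in setup.py below); the config values, local path, and repo id are illustrative:

```python
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Build a small model from an explicit config (values are illustrative).
config = MambaConfig(d_model=768, n_layer=24, vocab_size=50277)
model = MambaLMHeadModel(config)

# Mixin-provided: writes the weights plus a config.json derived from the init args.
model.save_pretrained("./mamba-checkpoint")

# Mixin-provided: rebuilds the model from the saved config and weights.
reloaded = MambaLMHeadModel.from_pretrained("./mamba-checkpoint")

# Mixin-provided: uploads the checkpoint and an auto-generated model card
# (requires a logged-in Hugging Face token).
# model.push_to_hub("your-username/mamba-checkpoint")
```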
4 changes: 1 addition & 3 deletions mamba_ssm/modules/mamba2.py
@@ -31,10 +31,8 @@
 from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
 from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
 
-from huggingface_hub import PyTorchModelHubMixin
-
 
-class Mamba2(nn.Module, PyTorchModelHubMixin):
+class Mamba2(nn.Module):
     def __init__(
         self,
         d_model,
23 changes: 0 additions & 23 deletions mamba_ssm/utils/hf.py

This file was deleted.

1 change: 1 addition & 0 deletions setup.py
@@ -374,6 +374,7 @@ def run(self):
         "einops",
         "triton",
         "transformers",
+        "huggingface_hub>=0.23.5",
         # "causal_conv1d>=1.4.0",
     ],
 )
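
For reference, the new huggingface_hub>=0.23.5 floor is presumably there to guarantee the mixin features relied on above (dataclass config serialization and class-level model-card metadata). The library_name, tags, and pipeline_tag arguments in the MambaLMHeadModel class definition surface in the YAML front matter of the model card generated on push_to_hub, roughly as in this sketch (built with ModelCardData for illustration, not the exact card the mixin writes):

```python
from huggingface_hub import ModelCardData

# Approximate card metadata for a pushed MambaLMHeadModel checkpoint,
# mirroring the class-level kwargs in mixer_seq_simple.py.
card_data = ModelCardData(
    library_name="mamba-ssm",
    tags=["arXiv:2312.00752", "arXiv:2405.21060"],
    pipeline_tag="text-generation",
)
print(card_data.to_yaml())
```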