Skip to content

Commit

Permalink
Refactor the quantization modification logic (#2233)
Browse files Browse the repository at this point in the history
  • Loading branch information
dbogunowicz authored Apr 29, 2024
1 parent 8467ee4 commit 7cd2feb
Show file tree
Hide file tree
Showing 23 changed files with 371 additions and 425 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy

from torch import nn

from sparseml.transformers.sparsification.modification import modify_model
from sparseml.transformers.sparsification.modification.modification_objects import (
QATLinear,
)


def test_modifying_mobilebert(mobilebert_model):

mobilebert_ = deepcopy(mobilebert_model)
mobilebert = modify_model(mobilebert_model)

assert isinstance(mobilebert_.embeddings.embedding_transformation, nn.Linear)
assert isinstance(mobilebert.embeddings.embedding_transformation, QATLinear)
# flake8: noqa
from .modify_model import modify_model
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""
Set of helper objects that are used to modify
the HuggingFace transformer models
the quantized models
"""

import torch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,34 @@
import logging
import os

import torch

from sparseml.transformers.sparsification.modification.registry import (
ModificationRegistry,
)
from sparseml.modifiers.quantization.modification.registry import ModificationRegistry


_LOGGER = logging.getLogger(__name__)


def modify_model(model: torch.nn.Module, disable: int = False) -> torch.nn.Module:
def modify_model(
model: "torch.nn.Module", disable: bool = False # noqa: F821
) -> "torch.nn.Module": # noqa: F821
"""
Modify the original transformers model so that it is
compatible with the SparseML library.
Modify the original model so that it is
compatible with the quantization format required by the
SparseML library.
The model will be modified, if there exist a modification
function for the model in the registry of modifications.
Otherwise, the original model will be returned.
:param model: The original HuggingFace transformers model
:return: The potentially modified model
:param model: The original model to be modified
:param disable: If True, the modification will be disabled
:return: The potentially modified model to support
SparseML quantization
"""
model_name = model.__class__.__name__
NM_DISABLE_TRANSFORMERS_MODIFICATION = os.environ.get(
"NM_DISABLE_TRANSFORMERS_MODIFICATION", "False"
NM_DISABLE_QUANTIZATION_MODIFICATION = os.environ.get(
"NM_DISABLE_QUANTIZATION_MODIFICATION", "False"
).lower() in ["true", "1"]

try:
modification_func = ModificationRegistry.get_value_from_registry(model_name)
except KeyError:
Expand All @@ -50,7 +53,7 @@ def modify_model(model: torch.nn.Module, disable: int = False) -> torch.nn.Modul
)
return model

if NM_DISABLE_TRANSFORMERS_MODIFICATION:
if NM_DISABLE_QUANTIZATION_MODIFICATION:
_LOGGER.debug(
"Application of the modification function to model "
"disabled through the environment variable."
Expand All @@ -65,6 +68,6 @@ def modify_model(model: torch.nn.Module, disable: int = False) -> torch.nn.Modul
return model

_LOGGER.info(
f"Modifying the model {model_name} to be compatible with SparseML library"
f"Modifying the model {model_name} to be compatible with SparseML quantization"
)
return modification_func(model)
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from sparseml.transformers.sparsification.modification.base import (
check_transformers_version,
)
from sparsezoo.utils.registry import RegistryMixin


class ModificationRegistry(RegistryMixin):
"""
A registry for modification functions that can be applied to models
so that they can be used in the context of sparseml.transformers
so that they can be compatible with the quantization format required by the
SparseML library.
"""

@classmethod
def get_value_from_registry(cls, name: str):
"""
Extends the base class method to check the transformers version after
successfully retrieving the value from the registry. The motivation is
to ensure that the transformers version falls within the supported range
before we proceed with model modification.
"""
retrieved_value = super().get_value_from_registry(name)
check_transformers_version()
return retrieved_value
6 changes: 6 additions & 0 deletions src/sparseml/modifiers/quantization/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from sparseml.core import Event, EventType, State
from sparseml.modifiers.quantization.base import QuantizationModifier
from sparseml.modifiers.quantization.modification import modify_model
from sparseml.modifiers.quantization.utils.helpers import (
configure_module_bn_wrappers,
freeze_bn_stats,
Expand Down Expand Up @@ -73,11 +74,16 @@ def __init__(self, **kwargs):

def on_initialize_structure(self, state: State, **kwargs):
module = state.model.model
# before the structure is modified to support quantization,
# we need to potentially modify the model architecture
module = modify_model(module)
self._enable_module_qat(module)
state.model.model.apply(torch.quantization.disable_observer)

def on_initialize(self, state: State, **kwargs) -> bool:
raise_if_torch_quantization_not_available()
module = state.model.model
module = modify_model(module)
if self.end and self.end != -1:
raise ValueError(
"end_epoch is disabled for QuantizationModifier and can only be set to"
Expand Down
1 change: 1 addition & 0 deletions src/sparseml/transformers/sparsification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

# flake8: noqa

from .modification import *
from .question_answering import *
from .sparse_config import *
from .sparse_model import *
Expand Down
20 changes: 13 additions & 7 deletions src/sparseml/transformers/sparsification/modification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa
from .modify_model import modify_model
from .modifying_bert import *
from .modifying_distilbert import *
from .modifying_llama import *
from .modifying_mistral import *
from .modifying_mobilebert import *
from .modifying_opt import *
# isort:skip_file

# the modification module that adds modifications
# for transformers models to enable quantization

# import all the modification functions for the different models
from .modifying_bert import modify
from .modifying_llama import modify
from .modifying_mistral import modify
from .modifying_distilbert import modify
from .modifying_mobilebert import modify
from .modifying_opt import modify
Original file line number Diff line number Diff line change
Expand Up @@ -14,68 +14,59 @@

"""
Modification to the original Bert model required in the
context of SparseML
context of SparseML quantization
"""


import logging
import math
from typing import Optional, Tuple

import torch
from torch import nn
from transformers.models.bert.modeling_bert import BertAttention, BertSelfAttention
from transformers.models.bert.modeling_bert import BertSelfAttention

from sparseml.modifiers.quantization.modification.modification_objects import QATMatMul
from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
from sparseml.pytorch.utils.helpers import swap_modules
from sparseml.transformers.sparsification.modification.modification_objects import (
QATMatMul,
from sparseml.transformers.sparsification.modification.base import (
check_transformers_version,
)
from sparseml.transformers.sparsification.modification.registry import (
ModificationRegistry,
)


_LOGGER = logging.getLogger(__name__)


@ModificationRegistry.register(name="BertModel", alias=["BertForQuestionAnswering"])
def modify(model: nn.Module) -> nn.Module:
"""
Modify the Bert model to be compatible with SparseML
quantization
1. Replaces the MultiHeadSelfAttention modules with
MultiHeadSelfAttentionWithQuantizableMatmuls modules
Note: This function will not alter any of the alternatives
to the MultiHeadSelfAttention module such as BertAttention
Replaces the attention modules with
MultiHeadSelfAttentionWithQuantizableMatmuls modules
:param model: the original Bert model
:return: the modified Bert model
"""
check_transformers_version()
for name, submodule in model.named_modules():
if isinstance(submodule, BertSelfAttention):
if isinstance(submodule, BertSelfAttention) and not isinstance(
submodule, BertSelfAttentionWithQuantizableMatmuls
):
swap_modules(
model, name, BertSelfAttentionWithQuantizableMatmuls(submodule)
)
elif isinstance(submodule, BertAttention):
_LOGGER.debug(
f"The model contains {submodule.__class__.__name__} "
"module, which will not be modified"
)
return model


class BertSelfAttentionWithQuantizableMatmuls(BertSelfAttention):
"""
Wrapper around the original BertSelfAttention module to replace the
Wrapper around the original attention module to replace the
matmul operations with quantizable matmul operations
:param bert_self_attention: the original BertSelfAttention module
:param bert_self_attention: the original attention module to be
wrapped and modified
"""

def __init__(self, bert_self_attention: BertSelfAttention):
self.__class__ = type(
bert_self_attention.__class__.__name__,
self.__class__.__name__,
(self.__class__, bert_self_attention.__class__),
{},
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@

"""
Modification to the original DistilBert model required in the
context of SparseML
context of SparseML quantization
"""

import logging
import math
from typing import Optional, Tuple

Expand All @@ -28,56 +27,49 @@
MultiHeadSelfAttention,
)

from sparseml.modifiers.quantization.modification.modification_objects import QATMatMul
from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
from sparseml.pytorch.utils.helpers import swap_modules
from sparseml.transformers.sparsification.modification.modification_objects import (
QATMatMul,
from sparseml.transformers.sparsification.modification.base import (
check_transformers_version,
)
from sparseml.transformers.sparsification.modification.registry import (
ModificationRegistry,
)


_LOGGER = logging.getLogger(__name__)


@ModificationRegistry.register(name="DistilBertModel")
def modify(model: nn.Module) -> nn.Module:
"""
Modify the DistilBert model to be compatible with SparseML
quantization
1. Replaces the MultiHeadSelfAttention modules with
MultiHeadSelfAttentionWithQuantizableMatmuls modules
Note: This function will not alter any of the alternatives
to the MultiHeadSelfAttention module such as DistilBertFlashAttention2
Replaces the attention modules with
MultiHeadSelfAttentionWithQuantizableMatmuls modules
:param model: the original DistilBert model
:return: the modified DistilBert model
"""
check_transformers_version()
for name, submodule in model.named_modules():
if isinstance(submodule, MultiHeadSelfAttention):
if isinstance(
submodule, (MultiHeadSelfAttention, DistilBertFlashAttention2)
) and not isinstance(submodule, MultiHeadSelfAttentionWithQuantizableMatmuls):
swap_modules(
model, name, MultiHeadSelfAttentionWithQuantizableMatmuls(submodule)
)
if isinstance(submodule, DistilBertFlashAttention2):
_LOGGER.debug(
f"The model contains {submodule.__class__.__name__} "
"module, which will not be modified"
)
return model


class MultiHeadSelfAttentionWithQuantizableMatmuls(MultiHeadSelfAttention):
"""
Wrapper around the original MultiHeadSelfAttention module to replace the
matmul operations with quantizable matmul operations
Wrapper around the original attention module to introduce
MultiHeadSelfAttention with quantizable matmul operations
:param mhs_attention: the original MultiHeadSelfAttention module
:param mhs_attention: the original attention module to be
wrapped and modified
"""

def __init__(self, mhs_attention: MultiHeadSelfAttention):
self.__class__ = type(
mhs_attention.__class__.__name__,
self.__class__.__name__,
(self.__class__, mhs_attention.__class__),
{},
)
Expand Down
Loading

0 comments on commit 7cd2feb

Please sign in to comment.