Freeze Quantization Config once instantiated
Add docstring to QuantizationConfig
rahul-tuli committed Apr 3, 2024
1 parent c59ef95 commit 8257040
Showing 1 changed file with 28 additions and 5 deletions.
33 changes: 28 additions & 5 deletions src/sparseml/transformers/utils/gptq_utils/vllm_export_helpers.py
@@ -23,6 +23,7 @@
from torch import Tensor
from transformers import PreTrainedModel, PreTrainedTokenizerBase

+ from sparseml.pytorch.model_load.helpers import fallback_to_cpu
from sparseml.transformers.utils.gptq_utils.transformations import (
GPTQ_EXLLAMA_TRANSFORMATIONS,
)
@@ -45,6 +46,7 @@ def export_vllm_compatible_checkpoint(
tokenizer: Union[PreTrainedTokenizerBase, str, None] = None,
format: SUPPORTED_FORMAT_TYPES = "exllama",
save_dir: Union[str, Path, None] = None,
+ device: str = "cuda",
):
"""
A utility function to export a GPTQ quantized model to safetensors,
@@ -61,18 +63,20 @@
:param format: The format to which the model should be exported.
Default is "exllama".
:param save_dir: The directory where the model should be saved.
+ :param device: The device to use for the model. Default is "cuda";
+ if CUDA is not available, it will fall back to CPU.
"""

validate_specified_format(format=format)

- model, tokenizer = _create_model_and_tokenizer(model, tokenizer)
+ model, tokenizer = _create_model_and_tokenizer(model=model, tokenizer=tokenizer)

_LOGGER.info(f"Translating state dict to {format} format.")
translated_state_dict: Dict[str, Any] = translate_state_dict(
state_dict=model.state_dict(), format=format
)

- model.config.quantization_config = QuantizationConfig()
+ model.config.quantization_config = _QuantizationConfig()
_LOGGER.info(f"Added {format} quantization info to model.config")

if save_dir is None:
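
For orientation, a minimal usage sketch of the updated entry point follows. The model stub and save directory are hypothetical placeholders, and the import path is assumed from the file shown in this diff.

from sparseml.transformers.utils.gptq_utils.vllm_export_helpers import (
    export_vllm_compatible_checkpoint,
)

# Hypothetical paths; substitute a real GPTQ-quantized checkpoint.
export_vllm_compatible_checkpoint(
    model="path/to/gptq-quantized-model",
    format="exllama",
    save_dir="exported-checkpoint",
    device="cuda",  # falls back to CPU when CUDA is unavailable
)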
@@ -151,8 +155,20 @@ def validate_specified_format(format: SUPPORTED_FORMAT_TYPES):
raise NotImplementedError(f"Exporting to format {format} is not supported yet.")


- @dataclass
- class QuantizationConfig:
+ @dataclass(frozen=True)
+ class _QuantizationConfig:
"""
A dataclass to hold the quantization configuration for the model.
This class is specific to GPTQ style quantization, and an instance
of this class can be added to the model.config.quantization_config
to enable the model to be exported to Exllama format.
Right now, the defaults are specific to sparseml GPTQ quantization.
In future versions we may support more general quantization configurations.
This class is frozen to prevent modification of the instance after creation.
"""

bits: int = field(default=4, metadata={"choices": [2, 3, 4, 8]})
group_size: int = field(default=-1)
damp_percent: float = field(default=0.01)
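
To illustrate what frozen=True buys here, a standalone sketch with the config reduced to the fields visible above: once instantiated, any attempt to mutate the instance raises dataclasses.FrozenInstanceError.

from dataclasses import FrozenInstanceError, dataclass, field

@dataclass(frozen=True)
class _QuantizationConfig:
    bits: int = field(default=4, metadata={"choices": [2, 3, 4, 8]})
    group_size: int = field(default=-1)
    damp_percent: float = field(default=0.01)

config = _QuantizationConfig()
try:
    config.bits = 8  # frozen dataclasses generate a __setattr__ that raises
except FrozenInstanceError:
    print("quantization config is read-only once instantiated")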
@@ -206,17 +222,21 @@ def _translate_state_dict_exllama(state_dict: Dict[str, Any]) -> Dict[Any, Any]:
def _create_model_and_tokenizer(
model: Union[PreTrainedModel, str],
tokenizer: Union[PreTrainedTokenizerBase, str, None] = None,
+ device: str = "cuda",
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
"""
Create/infer model and tokenizer instances from the passed
in model and tokenizer.
- in model and tokenizer.
+ in model and tokenizer. Additionally moves the model to the
+ specified device.
:param model: The model to be exported, can also be
path to a local model directory or a HuggingFace/SparseZoo stub
:param tokenizer: The tokenizer associated with the model,
can also be a HuggingFace/SparseZoo stub, if not passed in,
it will be inferred from the model. An error will be raised if it
cannot be inferred.
+ :param device: The device to use for the model. Default is "cuda";
+ if CUDA is not available, it will fall back to CPU.
:return: A tuple of (model, tokenizer) instances. If both were
passed into this function, they are returned as is.
If tokenizer was not passed in, it is inferred from the
@@ -239,4 +259,7 @@
if isinstance(model, str):
model = SparseAutoModelForCausalLM.from_pretrained(model)

+ # move model to GPU if available
+ model.to(fallback_to_cpu(device=device))

return model, tokenizer
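
The device fallback relied on above can be pictured with the sketch below. This is an assumption-laden illustration, not the actual sparseml.pytorch.model_load.helpers implementation, which may differ in detail.

import torch

def fallback_to_cpu(device: str) -> str:
    # Assumed behavior: return the requested device, downgrading to "cpu"
    # when CUDA is requested but unavailable on the current machine.
    if "cuda" in device and not torch.cuda.is_available():
        return "cpu"
    return device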
