diff --git a/src/sparseml/transformers/utils/gptq_utils/vllm_export_helpers.py b/src/sparseml/transformers/utils/gptq_utils/vllm_export_helpers.py
index 440c256a3ac..4a4d3ea51ac 100644
--- a/src/sparseml/transformers/utils/gptq_utils/vllm_export_helpers.py
+++ b/src/sparseml/transformers/utils/gptq_utils/vllm_export_helpers.py
@@ -23,6 +23,7 @@
 from torch import Tensor
 from transformers import PreTrainedModel, PreTrainedTokenizerBase
 
+from sparseml.pytorch.model_load.helpers import fallback_to_cpu
 from sparseml.transformers.utils.gptq_utils.transformations import (
     GPTQ_EXLLAMA_TRANSFORMATIONS,
 )
@@ -45,6 +46,7 @@ def export_vllm_compatible_checkpoint(
     tokenizer: Union[PreTrainedTokenizerBase, str, None] = None,
     format: SUPPORTED_FORMAT_TYPES = "exllama",
     save_dir: Union[str, Path, None] = None,
+    device: str = "cuda",
 ):
     """
     A utility function to export a GPTQ quantized model to safetensors,
@@ -61,18 +63,20 @@
     :param format: The format to which the model should be exported.
        Default is "exllama".
     :param save_dir: The directory where the model should be saved.
+    :param device: The device to use for the model. Default is "cuda".
+        If CUDA is not available, it falls back to CPU.
     """
     validate_specified_format(format=format)
 
-    model, tokenizer = _create_model_and_tokenizer(model, tokenizer)
+    model, tokenizer = _create_model_and_tokenizer(model=model, tokenizer=tokenizer, device=device)
 
     _LOGGER.info(f"Translating state dict to {format} format.")
     translated_state_dict: Dict[str, Any] = translate_state_dict(
         state_dict=model.state_dict(), format=format
     )
 
-    model.config.quantization_config = QuantizationConfig()
+    model.config.quantization_config = _QuantizationConfig()
     _LOGGER.info(f"Added {format} quantization info to model.config")
 
     if save_dir is None:
@@ -151,8 +155,20 @@
     raise NotImplementedError(f"Exporting to format {format} is not supported yet.")
 
 
-@dataclass
-class QuantizationConfig:
+@dataclass(frozen=True)
+class _QuantizationConfig:
+    """
+    A dataclass to hold the quantization configuration for the model.
+    This class is specific to GPTQ style quantization, and an instance
+    of this class can be added to the model.config.quantization_config
+    to enable the model to be exported to Exllama format.
+
+    Right now, the defaults are specific to sparseml GPTQ quantization.
+    In future versions we may support more general quantization configurations.
+
+    This class is frozen to prevent modification of the instance after creation.
+    """
+
     bits: int = field(default=4, metadata={"choices": [2, 3, 4, 8]})
     group_size: int = field(default=-1)
     damp_percent: float = field(default=0.01)
@@ -206,10 +222,12 @@ def _translate_state_dict_exllama(state_dict: Dict[str, Any]) -> Dict[Any, Any]:
 def _create_model_and_tokenizer(
     model: Union[PreTrainedModel, str],
     tokenizer: Union[PreTrainedTokenizerBase, str, None] = None,
+    device: str = "cuda",
 ) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
     """
     Create/infer model and tokenizer instances from the passed
-    in model and tokenizer.
+    in model and tokenizer. Additionally moves the model to the
+    specified device.
 
     :param model: The model to be exported, can also be path to a local
         model directory or a HuggingFace/SparseZoo stub
@@ -217,6 +235,8 @@
         can also be a HuggingFace/SparseZoo stub, if not passed in,
         it will be inferred from the model. An error will be raised
        if it cannot be inferred.
+    :param device: The device to use for the model. Default is "cuda".
+        If CUDA is not available, it falls back to CPU.
     :return A tuple of (model, tokenizer) instances. If both
         were passed into this function, they are returned as is.
         If tokenizer was not passed in, it is inferred from the
@@ -239,4 +259,7 @@
     if isinstance(model, str):
         model = SparseAutoModelForCausalLM.from_pretrained(model)
 
+    # move model to GPU if available, otherwise fall back to CPU
+    model.to(fallback_to_cpu(device=device))
+
     return model, tokenizer
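For context, below is a minimal usage sketch of the updated entry point (not part of the diff). It assumes the function is importable from the module path shown above and that a GPTQ-quantized SparseML checkpoint is available; the model path and save directory are placeholders.

```python
from sparseml.transformers.utils.gptq_utils.vllm_export_helpers import (
    export_vllm_compatible_checkpoint,
)

# Export a GPTQ-quantized model to an Exllama-format safetensors checkpoint.
# device="cuda" is a request, not a requirement: fallback_to_cpu() inside the
# helper downgrades it to "cpu" when no CUDA device is available.
export_vllm_compatible_checkpoint(
    model="path/to/gptq-quantized-model",   # placeholder: local dir or stub
    tokenizer=None,                         # inferred from the model if omitted
    format="exllama",
    save_dir="path/to/exllama-checkpoint",  # placeholder output directory
    device="cuda",
)
```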