Commit 29f83bb

Add state dict translation methods

rahul-tuli committed Mar 21, 2024
1 parent 121d7fe commit 29f83bb
Showing 3 changed files with 444 additions and 0 deletions.
253 changes: 253 additions & 0 deletions src/sparseml/transformers/utils/transformations.py
@@ -0,0 +1,253 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa: F821,E501

import functools
import logging
from typing import Dict

import numpy
import torch
from torch import Tensor


__all__ = [
    "transform_names",
    "add_tensors",
    "transform_tensors",
    "remove_unwanted_tensors",
    "is_quantization_target",
]

_LOGGER = logging.getLogger(__name__)


def _log_call(func):
    """Decorator that logs the start and end of each transformation"""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        _LOGGER.info("Applying transformation: %s", func.__name__.upper())
        return_value = func(*args, **kwargs)
        _LOGGER.info("Transformation: %s complete", func.__name__.upper())
        return return_value

    return wrapper


def is_quantization_target(key: str) -> bool:
    """
    Assumes self_attn and mlp are the only quantization targets
    in the model layers of the state_dict
    :param key: a key of the state_dict
    :return: True if the key is a quantization target, False otherwise
    """
    return "model.layers" in key and ("self_attn" in key or "mlp" in key)


@_log_call
def transform_names(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
    """
    Transforms the state_dict keys to match the exllama format
    The renames include:
        - weight_fake_quant.scale -> scales
        - weight_fake_quant.zero_point -> qzeros
        - weight -> qweight
    Note: does not transform the actual tensor values
    :pre-condition: The state_dict should be for a quantized model
    :pre-condition: Targets only the weights of the self_attn and mlp nodes
    :param state_dict: The quantized state_dict to be transformed
    :return: The transformed state_dict
    """
    # mapping of the old suffixes to the new suffixes
    name_map: Dict[str, str] = {
        ".weight_fake_quant.scale": ".scales",
        ".weight_fake_quant.zero_point": ".qzeros",
        ".weight": ".qweight",
    }

    new_state_dict: Dict[str, Tensor] = {}
    for key, tensor in state_dict.items():
        if is_quantization_target(key):
            # rename the first matching suffix, if any
            for old_suffix, new_suffix in name_map.items():
                if key.endswith(old_suffix):
                    key = key.replace(old_suffix, new_suffix)
                    break
        new_state_dict[key] = tensor
    return new_state_dict
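
# Illustrative usage (a sketch, not part of the original commit): given a
# toy quantized state_dict, the quantization-target keys are renamed while
# all other keys pass through unchanged:
#
#   state_dict = {
#       "model.layers.0.self_attn.q_proj.weight": torch.empty(8, 8),
#       "model.layers.0.self_attn.q_proj.weight_fake_quant.scale": torch.empty(8),
#   }
#   renamed = transform_names(state_dict)
#   # keys are now "...q_proj.qweight" and "...q_proj.scales"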


def pack(weight: Tensor, scales: Tensor, zeros: Tensor, g_idx: Tensor) -> Tensor:
    """
    Quantizes the weight tensor using the scales, zeros, and g_idx tensors
    into 4-bit integers, and packs each group of eight of them into a single
    32-bit integer.
    Adapted from:
    https://github.com/AutoGPTQ/AutoGPTQ/blob/ea4a99778f90b60c9b5177d7487af1b4ca87744f/auto_gptq/nn_modules/qlinear/qlinear_exllama.py#L118
    :param weight: The weight tensor to be quantized of shape [x, 8y]
    :param scales: The scales tensor
    :param zeros: The zero points tensor
    :param g_idx: The group index tensor
    :return: The packed weight tensor of int32 dtype and shape [y, x]
    """
    g_idx = g_idx.clone()

    scales = scales.t().contiguous()
    zeros = zeros.t().contiguous()
    scale_zeros = zeros * scales
    scales = scales.clone().half()
    bits = 4

    # quantize each input channel to integers using its group's scale
    # and zero point
    intweight = []
    infeatures = weight.shape[1]
    for idx in range(infeatures):
        intweight.append(
            torch.round(
                (weight[:, idx] + scale_zeros[g_idx[idx]]) / scales[g_idx[idx]]
            ).to(torch.int)[:, None]
        )
    intweight = torch.cat(intweight, dim=1)
    intweight = intweight.t().contiguous()
    intweight = intweight.numpy().astype(numpy.uint32)

    # pack eight consecutive 4-bit values into each 32-bit row
    i = 0
    row = 0
    qweight = numpy.zeros(
        (intweight.shape[0] // 32 * bits, intweight.shape[1]), dtype=numpy.uint32
    )
    while row < qweight.shape[0]:
        if bits == 4:
            for j in range(i, i + (32 // bits)):
                qweight[row] |= intweight[j] << (bits * (j - i))
            i += 32 // bits
            row += 1
        else:
            raise NotImplementedError("Only 4 bits are supported.")

    qweight = qweight.astype(numpy.int32)
    qweight = torch.from_numpy(qweight)
    return qweight
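
# A minimal sketch of the bit layout used above (illustrative, not from the
# commit): eight 4-bit values share one 32-bit word, with value j stored in
# bits 4*j .. 4*j + 3, i.e. the first value lands in the lowest nibble:
#
#   values = [1, 2, 3, 4, 5, 6, 7, 8]
#   word = 0
#   for j, v in enumerate(values):
#       word |= v << (4 * j)
#   assert word == 0x87654321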


@_log_call
def add_tensors(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
    """
    Adds the bias and g_idx tensors expected by the exllama format
    for each quantization target
    :pre-condition: The state_dict keys have already been renamed
        by transform_names
    :param state_dict: The state_dict to be updated
    :return: The updated state_dict
    """
    new_dict: Dict[str, Tensor] = {}

    for key, tensor in state_dict.items():
        if is_quantization_target(key) and key.endswith(".qweight"):
            # add bias and g_idx tensors
            bias_key = key.replace(".qweight", ".bias")
            g_idx_key = key.replace(".qweight", ".g_idx")

            # bias tensor of dtype float16 filled with zeros
            bias_tensor = torch.zeros(tensor.shape[0], dtype=torch.float16)
            new_dict[bias_key] = bias_tensor

            # g_idx tensor of shape [num_channels] dtype int32 filled
            # with zeros
            g_idx_tensor = torch.zeros(tensor.shape[1], dtype=torch.int32)
            new_dict[g_idx_key] = g_idx_tensor

        # copy the original tensor (qweight is copied in this step as well)
        new_dict[key] = tensor
    return new_dict
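
# Illustrative shapes (assuming the usual nn.Linear convention of a weight
# of shape [out_features, in_features]; not part of the original commit):
#
#   bias  -> torch.zeros([out_features], dtype=torch.float16)
#   g_idx -> torch.zeros([in_features], dtype=torch.int32)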


@_log_call
def transform_tensors(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
    """
    Transforms the tensor values in the state_dict to the exllama format;
    packs the qweight tensors, reshapes the scales, and replaces the
    qzeros with zero-filled int32 tensors (symmetric quantization)
    :pre-condition: The state_dict has been processed by transform_names
        and add_tensors
    :param state_dict: The state_dict to be transformed
    :return: The transformed state_dict
    """
    new_dict: Dict[str, Tensor] = {}

    # auxiliary dict to store the packed weights
    weights_dict: Dict[str, Tensor] = {}

    # pack the qweights in a separate first pass; packing needs the
    # original (untransformed) scales and qzeros, and the key ordering
    # of the state_dict is not guaranteed
    for key, tensor in state_dict.items():
        if is_quantization_target(key) and key.endswith(".qweight"):
            # quantize and pack the weight tensor
            qweight = pack(
                weight=tensor,
                scales=state_dict[key.replace("qweight", "scales")],
                zeros=state_dict[key.replace("qweight", "qzeros")],
                g_idx=state_dict[key.replace("qweight", "g_idx")],
            )
            assert qweight.dtype == torch.int32
            weights_dict[key] = qweight

    # transform scales and zero points
    for key, tensor in state_dict.items():
        if is_quantization_target(key) and key.endswith(".scales"):
            # scales [x] should be reshaped to [1, x]
            # and converted to fp16
            scales = tensor.reshape(1, -1).to(torch.float16)
            new_dict[key] = scales
        elif is_quantization_target(key) and key.endswith(".qzeros"):
            # zero points [8x] should be replaced by a [1, x] tensor
            # of type int32 filled with zeros (symmetric quantization)
            zeros = torch.zeros(tensor.shape[0] // 8, dtype=torch.int32)
            new_dict[key] = zeros.reshape(1, -1)
        else:
            new_dict[key] = tensor

    # overwrite the old weights with the packed ones
    new_dict.update(weights_dict)

    # the auxiliary weights_dict is no longer needed
    del weights_dict

    return new_dict
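
# Shape summary for one target module (a sketch, assuming per-channel
# quantization of an [out_features, in_features] weight; not from the
# commit):
#
#   scales : [out]      -> [1, out]       float16
#   qzeros : [out]      -> [1, out // 8]  int32, all zeros (symmetric)
#   qweight: [out, in]  -> [in // 8, out] int32, packed by `pack`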


@_log_call
def remove_unwanted_tensors(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
    """
    Removes tensors from the state_dict that are not needed for inference.
    These tensors include:
        - eps
        - min_val
        - max_val
        - fake_quant_enabled
        - observer_enabled
    :param state_dict: The state_dict to be cleaned up (modified in place)
    :return: The cleaned-up state_dict
    """
    to_delete = ["eps", "min_val", "max_val", "fake_quant_enabled", "observer_enabled"]
    keys = list(state_dict.keys())
    for key in keys:
        if any(key.endswith(suffix) for suffix in to_delete):
            del state_dict[key]
    return state_dict
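
# A possible end-to-end translation (a sketch assuming the transformations
# are applied in definition order; not part of the original commit):
#
#   exllama_state_dict = remove_unwanted_tensors(
#       transform_tensors(add_tensors(transform_names(state_dict)))
#   )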


def check_dicts(actual: Dict[str, Tensor], expected: Dict[str, Tensor]) -> None:
    """
    Asserts that the actual and expected state_dicts match in their
    number of tensors, keys, shapes, and dtypes
    :param actual: The state_dict produced by the transformations
    :param expected: The reference state_dict to compare against
    """
    assert len(actual) == len(
        expected
    ), "The number of tensors in the actual and expected state dicts do not match"

    for key, value in actual.items():
        assert key in expected, f"The key {key} is not present in the expected state dict"
        assert value.shape == expected[key].shape, (
            f"Shape mismatch for tensor {key}: "
            f"expected {expected[key].shape} but got {value.shape}"
        )
        assert value.dtype == expected[key].dtype, (
            f"Dtype mismatch for tensor {key}: "
            f"expected {expected[key].dtype} but got {value.dtype}"
        )