Update Quantization Logging to New Framework (#2313)
* use new quant framework for logging

* fix legacy compatibility

* fix
Sara Adkins authored Jun 10, 2024
1 parent 3cd9a8c commit 934f0d8
Showing 2 changed files with 18 additions and 24 deletions.
31 changes: 9 additions & 22 deletions src/sparseml/pytorch/utils/helpers.py
@@ -20,7 +20,6 @@
import os
import random
import re
import warnings
from collections import OrderedDict, namedtuple
from contextlib import contextmanager
from copy import deepcopy
@@ -30,7 +29,7 @@
import torch
from packaging import version
from torch import Tensor
from torch.nn import Linear, Module, Parameter
from torch.nn import Embedding, Linear, Module, Parameter
from torch.nn.modules.conv import Conv2d, Conv3d, _ConvNd
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
@@ -780,6 +779,7 @@ def get_prunable_layers(module: Module) -> List[Tuple[str, Module]]:
for (name, mod) in module.named_modules()
if (
isinstance(mod, Linear)
or isinstance(mod, Embedding)
or isinstance(mod, _ConvNd)
or (QATLinear and isinstance(mod, QATLinear))
or (QATConv2d and isinstance(mod, QATConv2d))
@@ -793,7 +793,7 @@ def get_quantizable_layers(module: Module) -> List[Tuple[str, Module]]:
"""
:param module: the module to get the quantizable layers from
:return: a list containing the names and modules of the quantizable layers
(Linear, Conv2d, Conv3d)
(Embedding, Linear, Conv2d, Conv3d)
"""
if QATLinear is None:
raise ImportError(
@@ -806,6 +806,7 @@ def get_quantizable_layers(module: Module) -> List[Tuple[str, Module]]:
for (name, mod) in module.named_modules()
if (
isinstance(mod, Linear)
or isinstance(mod, Embedding)
or isinstance(mod, Conv2d)
or (QATConv3d and isinstance(mod, Conv3d))
)
@@ -816,29 +817,15 @@ def get_quantized_layers(module: Module) -> List[Tuple[str, Module]]:
"""
:param module: the module to get the quantized layers from
:return: a list containing the names and modules of the quantized layers
(Linear, Conv2d, Conv3d)
(Embedding, Linear, Conv2d, Conv3d)
"""
if QATLinear is None:
raise ImportError(
"PyTorch version is not setup for Quantization. "
"Please install a QAT compatible version of PyTorch"
)

quantized_layers = []
for (name, mod) in module.named_modules():
if (
(QATLinear and isinstance(mod, QATLinear))
or (QATConv2d and isinstance(mod, QATConv2d))
or (QATConv3d and isinstance(mod, QATConv3d))
):
quantized_layers.append((name, mod))

elif isinstance(mod, Conv3d) and not QATConv3d:
warnings.warn(
"Pytorch version is not setup for Conv3D Quantization. "
"Quantization of Conv3D layers will be skipped",
UserWarning,
)
if hasattr(mod, "quantization_scheme"):
weight_scheme = getattr(mod.quantization_scheme, "weights", None)
if weight_scheme is not None and hasattr(mod, "weight"):
quantized_layers.append((name, mod))

return quantized_layers

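For reference, the updated get_quantized_layers no longer looks for QAT wrapper module types; it checks each module for a quantization_scheme attribute and inspects its weights field, and Embedding layers are now reported as prunable/quantizable alongside Linear and convolution layers. A minimal sketch of the new behavior on a toy model follows; ToyModel and the SimpleNamespace scheme are illustrative stand-ins, not part of this commit, for what the new quantization framework attaches to a module.

# Illustrative sketch (not part of the commit). SimpleNamespace stands in
# for the quantization_scheme object the new framework attaches to modules.
from types import SimpleNamespace

from torch.nn import Embedding, Linear, Module

from sparseml.pytorch.utils.helpers import get_prunable_layers, get_quantized_layers


class ToyModel(Module):
    def __init__(self):
        super().__init__()
        self.embed = Embedding(100, 16)  # Embedding now counts as prunable/quantizable
        self.fc = Linear(16, 4)

    def forward(self, tokens):
        return self.fc(self.embed(tokens))


model = ToyModel()

# Embedding layers are returned alongside Linear/Conv layers
print([name for name, _ in get_prunable_layers(model)])  # ['embed', 'fc']

# Simulate the attribute a weight-quantized module carries under the new framework
model.fc.quantization_scheme = SimpleNamespace(weights={"num_bits": 8})

# Detection now keys off quantization_scheme.weights, not QAT module types
print([name for name, _ in get_quantized_layers(model)])  # ['fc']
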
11 changes: 9 additions & 2 deletions src/sparseml/transformers/finetune/session_mixin.py
@@ -500,15 +500,22 @@ def log_model_sparsification(self):
f"Sparsification info for {type(self.model).__name__}: "
f"{sparsification_info.params_total} total params. "
)
sparsity_percent_formatted = "{:.2f}".format(
sparsification_info.params_prunable_sparse_percent
)
_LOGGER.info(
f"There are {sparsification_info.params_prunable_total} prunable "
f"params which have {sparsification_info.params_prunable_sparse_percent} "
f"params which have {sparsity_percent_formatted}% "
"avg sparsity."
)

quant_percent_formatted = "{:.2f}".format(
sparsification_info.params_quantized_percent
)
_LOGGER.info(
f"There are {sparsification_info.params_quantizable} quantizable "
f"params, with a quantization percentage of "
f"{sparsification_info.params_quantized_percent}."
f"{quant_percent_formatted}%."
)

def _prepare_model_for_fsdp(self):
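The session_mixin.py change only affects how the sparsity and quantization percentages are rendered in the log: each value is pre-formatted to two decimal places and followed by an explicit % sign. A standalone sketch of the resulting output, with made-up numbers for illustration:

# Standalone illustration of the new log formatting; the numbers are made up.
import logging

logging.basicConfig(level=logging.INFO)
_LOGGER = logging.getLogger(__name__)

params_prunable_total = 354_412_544        # hypothetical total from sparsification info
params_prunable_sparse_percent = 49.87321  # raw float percentage

sparsity_percent_formatted = "{:.2f}".format(params_prunable_sparse_percent)
_LOGGER.info(
    f"There are {params_prunable_total} prunable "
    f"params which have {sparsity_percent_formatted}% "
    "avg sparsity."
)
# before this commit: "... which have 49.87321 avg sparsity."
# after this commit:  "... which have 49.87% avg sparsity."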
