[MOE Quantization] Warn against "undercalibrated" modules #2262

Open

wants to merge 62 commits into base: main (changes shown from 56 commits)

Commits (62)
097bd79  initial commit (dbogunowicz, Apr 8, 2024)
76970e3  update setup.py (dbogunowicz, Apr 8, 2024)
bbf4b39  Update setup.py (dbogunowicz, Apr 8, 2024)
a272a30  fix setup.py (dbogunowicz, Apr 8, 2024)
c0d3ead  move all config to sparsetensors (Apr 10, 2024)
b3f7ff3  Merge branch 'main' into feature/damian/sparsetensors (Apr 10, 2024)
a75f8da  cleanup class name and comments (Apr 10, 2024)
c5b897e  Merge branch 'main' into feature/damian/sparsetensors (Apr 16, 2024)
2c72ab1  initial implementation untested (Apr 16, 2024)
9174c1d  fixing issues (Apr 16, 2024)
aa17e77  add test script (Apr 17, 2024)
f1f114c  update perplexity test (Apr 17, 2024)
bbbdcb9  refactor to compressed-tensors (dbogunowicz, Apr 18, 2024)
5d9c7dd  Merge branch 'main' into feature/damian/sparsetensors (dbogunowicz, Apr 18, 2024)
7a9f9e5  rename sparsetensors (Apr 18, 2024)
fa43088  update setup (Apr 18, 2024)
63266d8  Sa/model reload (#2250) (Apr 19, 2024)
b0f0fc9  Merge branch 'main' into sa/quant_mod_refactor (Apr 22, 2024)
dfa41fb  Merge branch 'main' into feature/damian/sparsetensors (Apr 22, 2024)
4af4852  Merge branch 'feature/damian/sparsetensors' into sa/quant_mod_refactor (Apr 22, 2024)
55976c5  cleanup (Apr 22, 2024)
38f4f77  refactor tests (Apr 22, 2024)
6574874  only run oneshot once (Apr 22, 2024)
7f5babf  all tests passing (dbogunowicz, Apr 23, 2024)
c0d6cb9  remove unused config (dbogunowicz, Apr 23, 2024)
a59e2af  reset models on each parameterize (Apr 23, 2024)
cba7c27  Merge branch 'feature/damian/sparsetensors' into sa/quant_mod_refactor (Apr 23, 2024)
2a6b0f2  style (Apr 23, 2024)
1e7ee94  Merge branch 'main' into feature/damian/sparsetensors (dbogunowicz, Apr 24, 2024)
a4e0575  bring back SparsityConfigMetadata (dbogunowicz, Apr 24, 2024)
06d4554  Merge branch 'feature/damian/sparsetensors' of github.com:neuralmagic… (dbogunowicz, Apr 24, 2024)
644da53  Merge remote-tracking branch 'origin/feature/damian/sparsetensors' in… (dbogunowicz, Apr 24, 2024)
8ac18e7  Update setup.py (dbogunowicz, Apr 24, 2024)
de78247  add more comparisons, tighten threshold (Apr 25, 2024)
4041f2e  use wikitext for perplexity (Apr 25, 2024)
f5adc4e  Merge branch 'main' into feature/damian/sparsetensors (dbogunowicz, Apr 25, 2024)
c220772  update setup (dbogunowicz, Apr 25, 2024)
2fe554e  fix import problem (dbogunowicz, Apr 25, 2024)
4e0413e  fix clearml test (dbogunowicz, Apr 25, 2024)
a98a193  compressed-tensors are transformers dep (dbogunowicz, Apr 25, 2024)
b9b684c  Merge branch 'feature/damian/sparsetensors' into sa/quant_mod_refactor (Apr 25, 2024)
f4362cf  address PR comments (Apr 25, 2024)
ca91c4f  can't repeat freeze (Apr 26, 2024)
c894305  UX pr comments (Apr 26, 2024)
1c3b31b  Merge branch 'main' into sa/quant_mod_refactor (May 1, 2024)
90795bd  quality (May 1, 2024)
bf7d0f6  shape consistency (horheynm, May 1, 2024)
579d201  Merge branch 'sa/quant_mod_refactor' of github.com:neuralmagic/sparse… (horheynm, May 1, 2024)
2432cf4  address PR comments (May 2, 2024)
139f388  only relevant files (dbogunowicz, May 6, 2024)
b33e393  Merge remote-tracking branch 'origin/main' into feature/damian/moe (dbogunowicz, May 7, 2024)
d70daa2  add checks for undercalibrated modules (dbogunowicz, May 7, 2024)
6a4eecb  typo (dbogunowicz, May 7, 2024)
7a82f56  Delete moe.py (dbogunowicz, May 7, 2024)
a264819  Merge branch 'main' into feature/damian/moe (dbogunowicz, May 10, 2024)
66b6cf7  Merge branch 'main' into feature/damian/moe (dbogunowicz, May 13, 2024)
c5fb319  Merge branch 'main' into feature/damian/moe (dbogunowicz, May 14, 2024)
453a34c  ready to hand over (May 15, 2024)
92191e8  truncate the warning content (May 15, 2024)
b859ff6  Merge branch 'main' into feature/damian/moe (dbogunowicz, May 16, 2024)
de34285  Merge remote-tracking branch 'origin/main' into feature/damian/moe (Jun 13, 2024)
c385c68  refresh (Jun 13, 2024)
setup.py (1 addition, 1 deletion)

@@ -78,7 +78,7 @@
"opencv-python<=4.6.0.66",
]
_transformers_deps = _pytorch_deps + [
"transformers<4.40",
"transformers<4.41",
"datasets<2.19",
"dvc",
"scikit-learn",
src/sparseml/modifiers/utils/pytorch_helpers.py (46 additions, 0 deletions)

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 from itertools import cycle
 from typing import Callable, Dict, Optional
 
@@ -20,11 +21,14 @@
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
+from compressed_tensors.quantization.observers.helpers import get_observer_token_count
 from sparseml.pytorch.utils import tensors_module_forward, tensors_to_device
 
 
 __all__ = ["apply_pad_mask_to_batch", "run_calibration_forward"]
 
+_LOGGER = logging.getLogger(__name__)
+
 
 def apply_pad_mask_to_batch(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     """
@@ -46,6 +50,7 @@ def run_calibration_forward(
     calibration_function: Optional[Callable] = None,
     device: Optional[str] = None,
     mask_padding: bool = False,
+    undercalibrated_module_threshold: float = 0.2,
 ):
     """
     Helper function used by one-shot modifiers, runs calibration data through a model to
@@ -58,6 +63,9 @@
     :param calibration_function: option to pass a custom forward function for model
     :param device: option to move the model to a specific device before calibration
     :param mask_padding: whether to zero out padding tokens during calibration
+    :param undercalibrated_module_threshold: the minimum fraction of tokens
+        (out of all the tokens in a batch) that a module should receive during
+        each forward pass of the calibration; modules below this trigger a warning
     """
     model.eval()
 
@@ -85,3 +93,41 @@
         batch = tensors_to_device(batch, model_device)
         with torch.no_grad():
             forward_fn(batch, module=model)
+        check_for_undercalibrated_modules(
+            batch["input_ids"], model, undercalibrated_module_threshold
+        )
+
+
+def check_for_undercalibrated_modules(
+    batch: torch.Tensor, model: Module, threshold: float
+):
+    """
+    A helper function that warns when a module has seen
+    fewer than threshold * 100% of all the tokens in the batch
+
+    :param batch: the batch of tokens (batch_size, sequence_length)
+    :param model: the model to investigate
+    :param threshold: the minimum fraction of tokens
+        (out of all the tokens in a batch) that a module should
+        receive during each forward pass of the calibration
+    """
+    total_token_count = len(batch.flatten())
+    counter = get_observer_token_count(model)
+    for module_name, token_count in counter.items():
+        if token_count is None:
+            # the module has not been observed
+            # or its token_count is not being recorded
+            # by the observer (refer to the observer's
+            # implementation in the source code)
+            continue
+        if token_count / total_token_count < threshold:
+            _LOGGER.warning(
+                f"The module {module_name} "
+                f"received less than {int(threshold * 100)}% "
+                f"of calibration batch tokens ({token_count} tokens). "
+                "This may harm the quantization quality."
+                "\nTo address this issue either:"
+                "\n1) Increase the batch_size of the calibration inputs"
+                "\n2) Use calibration data that is more representative "
+                "of the original dataset used for the model"
+            )
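
To make the check concrete, here is a small standalone Python sketch (an illustration only, not code from this PR; the module names and per-expert token counts are made up) that replays the arithmetic of check_for_undercalibrated_modules: a (4, 512) batch holds 2048 tokens, so at the default threshold of 0.2 any module whose observer recorded fewer than about 410 tokens gets flagged.

import torch

batch = torch.zeros(4, 512, dtype=torch.long)  # (batch_size, sequence_length)
total_token_count = len(batch.flatten())       # 4 * 512 = 2048 tokens
threshold = 0.2                                # default undercalibrated_module_threshold

# hypothetical per-expert token counts, e.g. as a MoE router might route them
token_counts = {"experts.0": 900, "experts.1": 150}

for module_name, token_count in token_counts.items():
    if token_count / total_token_count < threshold:
        print(f"{module_name} saw only {token_count}/{total_token_count} tokens")
# prints: experts.1 saw only 150/2048 tokens (about 7%, below the 20% threshold)

Increasing the calibration batch size, or choosing calibration data that exercises all experts, raises these counts above the threshold, which is exactly what the warning message recommends.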