diff --git a/quantize.py b/quantize.py index 766efac04..4912c1b79 100644 --- a/quantize.py +++ b/quantize.py @@ -9,8 +9,6 @@ from math import gcd from typing import Dict, Optional, Tuple -import quantized_ops - import torch import torch.nn as nn import torch.nn.functional as F @@ -1318,7 +1316,7 @@ def get_inputs( inputs = input_recorder.get_recorded_inputs() assert inputs is not None, ( f"No inputs were collected, use a task other than {calibration_tasks}, " - + f"use option pad_calibration_inputs, or decrease calibration_sequence_length (currently " + + "use option pad_calibration_inputs, or decrease calibration_sequence_length (currently " + f"{calibration_seq_length})" ) print(f"Obtained {len(inputs[0].values)} calibration samples") @@ -1335,7 +1333,7 @@ def create_quantized_state_dict( calibration_limit, calibration_seq_length, pad_calibration_inputs, - ) -> "StateDict": + ) -> Dict: # "StateDict": inputs = GPTQQuantHandler.get_inputs( self.mod, tokenizer, @@ -1511,7 +1509,7 @@ def create_quantized_state_dict(self): from hqq.core.quantize import Quantizer # TODO maybe torchao for m in self.mod.modules(): - for name, child in m.named_children(): + for _name, child in m.named_children(): if isinstance(child, torch.nn.Linear): child.weight = torch.nn.Parameter( Quantizer.dequantize(