Skip to content

Commit

Permalink
format (#223)
Browse files Browse the repository at this point in the history
  • Loading branch information
mikekgfb authored Apr 16, 2024
1 parent 61f38d2 commit 6d582a7
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from math import gcd
from typing import Dict, Optional, Tuple

import quantized_ops

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -1318,7 +1316,7 @@ def get_inputs(
inputs = input_recorder.get_recorded_inputs()
assert inputs is not None, (
f"No inputs were collected, use a task other than {calibration_tasks}, "
+ f"use option pad_calibration_inputs, or decrease calibration_sequence_length (currently "
+ "use option pad_calibration_inputs, or decrease calibration_sequence_length (currently "
+ f"{calibration_seq_length})"
)
print(f"Obtained {len(inputs[0].values)} calibration samples")
Expand All @@ -1335,7 +1333,7 @@ def create_quantized_state_dict(
calibration_limit,
calibration_seq_length,
pad_calibration_inputs,
) -> "StateDict":
) -> Dict: # "StateDict":
inputs = GPTQQuantHandler.get_inputs(
self.mod,
tokenizer,
Expand Down Expand Up @@ -1511,7 +1509,7 @@ def create_quantized_state_dict(self):
from hqq.core.quantize import Quantizer # TODO maybe torchao

for m in self.mod.modules():
for name, child in m.named_children():
for _name, child in m.named_children():
if isinstance(child, torch.nn.Linear):
child.weight = torch.nn.Parameter(
Quantizer.dequantize(
Expand Down

0 comments on commit 6d582a7

Please sign in to comment.