fix device linear:int8 quant (#206)
* fix device int8 quant
* fix duplicate device on linear int8
* typo
* typo (a stray trailing comma made the device a tuple; see the sketch after this list)
* missing fqn
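
Note on the tuple typo: in Python, a trailing comma on the right-hand side of an assignment creates a one-element tuple. A minimal standalone sketch of the bug and the fix, mirroring the `self.device = device,` line corrected in quantize.py below:

    device = "cuda"
    d = device,   # stray trailing comma: d is now the tuple ("cuda",)
    d = device    # fixed: d is the string "cuda"
    # downstream calls such as tensor.to(device=d) expect a device string, not a tuple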
mikekgfb authored and malfet committed Jul 17, 2024
1 parent 1d8228a commit b71d83d
Showing 4 changed files with 52 additions and 35 deletions.
42 changes: 21 additions & 21 deletions .github/workflows/compile_t4.yml
@@ -63,35 +63,35 @@ jobs:
echo "******************************************"
echo "******** Emb: group-wise quantized *******"
echo "******************************************"
-# python generate.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-# cat ./output_eager
-# python generate.py --device cuda --compile --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
-# cat ./output_compiled
-# python export.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
-# python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
-# cat ./output_aoti
+python generate.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+cat ./output_eager
+python generate.py --device cuda --compile --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
+cat ./output_compiled
+python export.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
+python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
+cat ./output_aoti
echo "******************************************"
echo "******* INT8 channel-wise quantized ******"
echo "******************************************"
-# python generate.py --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-# cat ./output_eager
-# python generate.py --device cuda --compile --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
-# cat ./output_compiled
-# python export.py --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
-# python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
-# cat ./output_aoti
+python generate.py --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+cat ./output_eager
+python generate.py --device cuda --compile --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
+cat ./output_compiled
+python export.py --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
+python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
+cat ./output_aoti
echo "******************************************"
echo "******** INT8 group-wise quantized *******"
echo "******************************************"
-# python generate.py --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-# cat ./output_eager
-# python generate.py --device cuda --compile --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
-# cat ./output_compiled
-# python export.py --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
-# python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
-# cat ./output_aoti
+python generate.py --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+cat ./output_eager
+python generate.py --device cuda --compile --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
+cat ./output_compiled
+python export.py --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
+python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
+cat ./output_aoti
echo "tests complete"
echo "******************************************"
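
Note on the --quant flag exercised above: it takes a JSON object mapping a quantizer name to its options, and (per the quantize.py change below) a groupsize of 0 selects per-channel scales while a positive groupsize selects group-wise scales. A minimal parsing sketch, with illustrative variable names that are not from the repo:

    import json

    spec = json.loads('{"linear:int8": {"bitwidth": 8, "groupsize": 0}}')
    opts = spec["linear:int8"]
    # groupsize 0 (or None) -> one scale per output channel; N > 0 -> one scale per group of N inputs
    per_channel = opts.get("groupsize", 0) in (None, 0)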
23 changes: 19 additions & 4 deletions .github/workflows/test_mps.yml
@@ -48,14 +48,29 @@ jobs:
python generate.py --device mps --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
cat ./output_eager
echo "************************************************************"
echo "*** embedding"
echo "************************************************************"
python generate.py --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
cat ./output_eager
python generate.py --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
cat ./output_eager
# python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
# cat ./output_eager
# python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
# cat ./output_eager
echo "************************************************************"
echo "*** linear int8"
echo "************************************************************"
python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
cat ./output_eager
python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
cat ./output_eager
echo "************************************************************"
echo "*** linear int4"
echo "************************************************************"
# PYTORCH_ENABLE_MPS_FALLBACK=1 python generate.py --device mps --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
# cat ./output_eager
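
Note: these jobs hard-code --device mps, and the linear:int4 path stays commented out because it additionally needs PYTORCH_ENABLE_MPS_FALLBACK=1, as the workflow itself records. A portable variant could guard on backend availability; a small sketch assuming stock PyTorch (the cpu fallback is an illustrative choice):

    import torch

    # prefer the Apple MPS backend when present, otherwise fall back to CPU
    device = "mps" if torch.backends.mps.is_available() else "cpu"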
2 changes: 1 addition & 1 deletion build/builder.py
@@ -250,7 +250,7 @@ def _initialize_model(
# assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
assert quantize is None or quantize == "{ }", f"quantize not valid for exported PTE model. Specify quantization during export."
try:
-from model_et import PTEModel
+from build.model_et import PTEModel
model = PTEModel(model_.config, builder_args.pte_path)
except Exception as e:
raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
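
Note: the one-line fix makes the import package-qualified, so model_et resolves through the build package whenever the repository root is on sys.path, instead of only when the working directory happens to be build/. A sketch of the resulting pattern (the checkout path is hypothetical):

    import sys
    sys.path.insert(0, "/path/to/repo")  # hypothetical checkout location of the repository root
    from build.model_et import PTEModel  # resolves via the build package, independent of CWD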
20 changes: 11 additions & 9 deletions quantize.py
@@ -349,7 +349,7 @@ def quantized_model(self) -> nn.Module:
##### Weight-only int8 per-channel quantized code ######


-def replace_linear_weight_only_int8_per_channel(module, node_type, groupsize=None):
+def replace_linear_weight_only_int8_per_channel(module, device, node_type, groupsize=None):
if groupsize is not None and groupsize != 0:
pass # groupsize = 2 ** groupsize

@@ -367,10 +367,10 @@ def replace_linear_weight_only_int8_per_channel(module, node_type, groupsize=None):
setattr(
module,
name,
-WeightOnlyInt8Linear(child.in_features, child.out_features, groupsize),
+WeightOnlyInt8Linear(device, child.in_features, child.out_features, groupsize),
)
else:
-replace_linear_weight_only_int8_per_channel(child, node_type, groupsize)
+replace_linear_weight_only_int8_per_channel(child, device, node_type, groupsize)


class WeightOnlyInt8QuantHandler(QuantHandler):
@@ -384,7 +384,7 @@ def __init__(
groupsize: Optional[int] = None,
):
self.mod = mod
-self.device = device,
+self.device = device
self.groupsize = groupsize
self.node_type = node_type
if bitwidth is None:
@@ -434,14 +434,16 @@ def create_quantized_state_dict(self) -> Dict:
scales_dtype=mod.weight.dtype,
)

+weight = weight.to(device=self.device)
+scales = scales.to(device=self.device)
cur_state_dict[f"{fqn}.weight"] = weight
# squeeze makes groupsize=rowsize unidimensional
cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)

return cur_state_dict

def convert_for_runtime(self) -> nn.Module:
-replace_linear_weight_only_int8_per_channel(self.mod, self.node_type, self.groupsize)
+replace_linear_weight_only_int8_per_channel(self.mod, self.device, self.node_type, self.groupsize)
return self.mod

def quantized_model(self) -> nn.Module:
@@ -459,11 +461,11 @@ class WeightOnlyInt8Linear(torch.nn.Module):

def __init__(
self,
+device,
in_features: int,
out_features: int,
groupsize: Optional[int] = None,
bias: bool = True,
-device=None,
dtype=None,
) -> None:
super().__init__()
@@ -472,14 +474,14 @@ def __init__(
self.in_features = in_features
self.out_features = out_features
self.register_buffer(
"weight", torch.empty((out_features, in_features), dtype=torch.int8)
"weight", torch.empty((out_features, in_features), dtype=torch.int8, device=device)
)
dtype=get_precision()
if groupsize is None or (groupsize == 0):
self.register_buffer("scales", torch.ones(out_features, dtype=dtype))
self.register_buffer("scales", torch.ones(out_features, dtype=dtype, device=device))
else:
groups = (in_features + groupsize - 1) // groupsize
self.register_buffer("scales", torch.ones(out_features, groups, dtype=dtype))
self.register_buffer("scales", torch.ones(out_features, groups, dtype=dtype, device=device))

def forward(self, input: torch.Tensor) -> torch.Tensor:
scales = self.scales
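
Note on the buffer shapes in WeightOnlyInt8Linear: per-channel mode (groupsize 0 or None) keeps one scale per output row, while group-wise mode keeps ceil(in_features / groupsize) scales per row, matching the ceiling division in the diff. A quick sketch of just that arithmetic, with illustrative dimensions:

    in_features, out_features, groupsize = 4096, 11008, 8

    groups = (in_features + groupsize - 1) // groupsize  # ceiling division, as in the diff
    # group-wise:  scales has shape (out_features, groups) -> (11008, 512) here
    # per-channel: scales has shape (out_features,)        -> (11008,) here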
