Address quantization failures on devices (#204)
* add device for quantization, enable embedding quant with device

* typo

* fix filename weirdness

* enable mps embedding table runs

* import os for basename

* fix extraneous updates with group_size
mikekgfb authored Apr 15, 2024
1 parent 55e7583 commit fbdd08c
Showing 6 changed files with 53 additions and 35 deletions.
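For orientation (not part of the diff): generate.py and export.py take their quantization recipe as a JSON dictionary via --quant, and this commit standardizes the key spelling on "groupsize" rather than "group_size". A minimal Python sketch of the payloads exercised by the workflows below; the checkpoint path in the actual commands is supplied by the workflow environment.

import json

# Recipes used in the CI workflows changed by this commit; groupsize 0 selects
# channel-wise quantization, a positive groupsize selects group-wise.
embedding_channelwise = {"embedding": {"bitwidth": 8, "groupsize": 0}}
embedding_groupwise = {"embedding": {"bitwidth": 8, "groupsize": 8}}
linear_int8 = {"linear:int8": {"bitwidth": 8, "groupsize": 0}}

# The dict is passed on the command line as a JSON string, e.g.
#   python generate.py --device cuda --quant '<json>' --checkpoint-path <model> --temperature 0
print(json.dumps(embedding_channelwise))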
14 changes: 7 additions & 7 deletions .github/workflows/compile_t4.yml
@@ -52,13 +52,13 @@ jobs:
echo "******************************************"
echo "******* Emb: channel-wise quantized ******"
echo "******************************************"
# python generate.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
# cat ./output_eager
# python generate.py --device cuda --compile --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
# cat ./output_compiled
# python export.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
# python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
# cat ./output_aoti
python generate.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
cat ./output_eager
python generate.py --device cuda --compile --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
cat ./output_compiled
python export.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
cat ./output_aoti
echo "******************************************"
echo "******** Emb: group-wise quantized *******"
10 changes: 5 additions & 5 deletions .github/workflows/test_mps-dtype.yml
@@ -52,14 +52,14 @@ jobs:
python generate.py --dtype ${DTYPE} --device mps --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
cat ./output_eager
# python generate.py --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "group_size": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
# python generate.py --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
# cat ./output_eager
# python generate.py --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "group_size": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
# python generate.py --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
# cat ./output_eager
# python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "group_size": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
# python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
# cat ./output_eager
# python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "group_size": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
# python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
# cat ./output_eager
# PYTORCH_ENABLE_MPS_FALLBACK=1 python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int4" : {"group_size": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
# PYTORCH_ENABLE_MPS_FALLBACK=1 python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
# cat ./output_eager
done
14 changes: 7 additions & 7 deletions .github/workflows/test_mps.yml
@@ -48,14 +48,14 @@ jobs:
python generate.py --device mps --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
cat ./output_eager
# python generate.py --device mps --quant '{"embedding" : {"bitwidth": 8, "group_size": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
# cat ./output_eager
# python generate.py --device mps --quant '{"embedding" : {"bitwidth": 8, "group_size": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
# cat ./output_eager
# python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "group_size": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
python generate.py --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
cat ./output_eager
python generate.py --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
cat ./output_eager
# python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
# cat ./output_eager
# python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "group_size": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
# python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
# cat ./output_eager
# PYTORCH_ENABLE_MPS_FALLBACK=1 python generate.py --device mps --quant '{"linear:int4" : {"group_size": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
# PYTORCH_ENABLE_MPS_FALLBACK=1 python generate.py --device mps --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
# cat ./output_eager
2 changes: 1 addition & 1 deletion build/builder.py
@@ -259,7 +259,7 @@ def _initialize_model(

if quantize:
t0q = time.time()
quantize_model(model, quantize)
quantize_model(model, builder_args.device, quantize)
device_sync(device=builder_args.device)
print(f"Time to quantize model: {time.time() - t0q:.02f} seconds")

5 changes: 3 additions & 2 deletions generate.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
import itertools
import sys
import os
import time
from pathlib import Path
from typing import Optional, Tuple
@@ -333,9 +334,9 @@ def _main(
set_precision(builder_args.precision)
is_speculative = speculative_builder_args.checkpoint_path is not None

is_chat = "chat" in str(builder_args.checkpoint_path)
is_chat = "chat" in str(os.path.basename(builder_args.checkpoint_path))
if is_chat:
raise RuntimeError("need to stop filename based kludgery, at a minimum need to look at all pathnames. yuck!")
raise RuntimeError("need to stop filename based kludgery, at a minimum need to look at all pathnames. in particular, this now fails because chat is part of the pathname, yuck!")

tokenizer = _initialize_tokenizer(tokenizer_args)
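Not part of the diff, just a self-contained illustration of why the check above now uses os.path.basename: only the file name should decide chat mode, not a directory that happens to contain "chat" (the path below is hypothetical).

import os

path = "/checkpoints/chat-experiments/stories15M.pt"  # hypothetical: "chat" appears only in a directory name
print("chat" in path)                     # True  -- the old full-path check would trip here
print("chat" in os.path.basename(path))   # False -- the new check looks at the file name only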

43 changes: 30 additions & 13 deletions quantize.py
@@ -55,7 +55,7 @@ def name_to_dtype(name):
##########################################################################
### process quantization dictionary ###

def quantize_model(model: nn.Module, quantize_options):
def quantize_model(model: nn.Module, device, quantize_options):
"""
Quantize the specified model using the quantizers described by
a quantization dict of the form:
@@ -74,6 +74,7 @@ def quantize_model(model: nn.Module, quantize_options):
if quantizer == "embedding":
model = EmbeddingOnlyInt8QuantHandler(
model,
device,
**q_kwargs
).quantized_model()
elif linears_quantized:
@@ -82,30 +83,35 @@ def quantize_model(model: nn.Module, quantize_options):
linears_quantized = True
model = WeightOnlyInt8QuantHandler(
model,
device,
**q_kwargs
).quantized_model()
elif quantizer == "linear:int4":
linears_quantized = True
model = WeightOnlyInt4QuantHandler(
model,
device,
**q_kwargs
).quantized_model()
elif quantizer == "linear:a8w4dq":
linears_quantized = True
model = Int8DynActInt4WeightQuantHandler(
model,
device,
**q_kwargs
).quantized_model()
elif quantizer == "linear:gptq":
linears_quantized = True
model = WeightOnlyInt4GPTQQuantHandler(
model,
device,
**q_kwargs
).quantized_model()
elif quantizer == "linear:hqq":
linears_quantized = True
model = WeightOnlyInt4HqqQuantHandler(
model,
device,
**q_kwargs
).quantized_model()
elif quantizer == "precision":
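As a hedged usage sketch (illustrative, not part of the diff): every handler in this dispatch now takes the target device ahead of its keyword options, so a caller drives it as below. The import path and the availability of a builder-constructed model are assumptions.

from quantize import quantize_model  # import path assumed from the repo layout above

# `model` is assumed to be the nn.Module produced by the builder; the handlers
# swap modules in place, matching the build/builder.py call site shown earlier.
quantize_model(model, "mps", {"embedding": {"bitwidth": 8, "groupsize": 8}})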
@@ -371,12 +377,14 @@ class WeightOnlyInt8QuantHandler(QuantHandler):
def __init__(
self,
mod,
device,
*,
node_type: str = "*",
bitwidth: Optional[int] = None,
groupsize: Optional[int] = None,
):
self.mod = mod
self.device = device,
self.groupsize = groupsize
self.node_type = node_type
if bitwidth is None:
Expand Down Expand Up @@ -494,7 +502,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:


def replace_embedding_weight_only_grouped_int8_per_channel(
module, bitwidth: int = 8, groupsize: Optional[int] = None, packed = False
module, device, bitwidth: int = 8, groupsize: Optional[int] = None, packed = False
):
for name, child in module.named_children():
# print(f"name: {name}")
@@ -505,6 +513,7 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
module,
name,
QuantizedGroupEmbedding(
device=device,
vocab_size=child.weight.shape[0],
embedding_dim=child.weight.shape[1],
groupsize=groupsize,


class EmbeddingOnlyInt8QuantHandler(QuantHandler):
def __init__(self, mod, *, bitwidth: int = 8, groupsize: Optional[int] = None, packed = False):
def __init__(self, mod, device, *, bitwidth: int = 8, groupsize: Optional[int] = None, packed = False):
if isinstance(packed, str):
packed = (packed == "True")
self.mod = mod
self.device = device
self.groupsize = groupsize
self.bitwidth = bitwidth
self.packed = packed
@@ -565,7 +575,7 @@ def create_quantized_state_dict(self, packed=False) -> Dict:

if packed:
if weight.shape[-1] %2 != 0:
raise RUntimeError("automatic padding not implemented yet")
raise RuntimeError("automatic padding not implemented yet")

weight_range_shifted = weight.add(8).view(torch.uint8)
weight_view = weight_range_shifted.view(
@@ -578,6 +588,8 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
weight_packed = weight_even + weight_odd
weight = weight_packed

weight = weight.to(device=self.device)
scales = scales.to(device=self.device)
# Update state dict
cur_state_dict[f"{fqn}.weight"] = weight
# squeeze makes groupsize=rowsize unidimensional
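The collapsed lines above pack two 4-bit codes per byte. A self-contained toy sketch of that general scheme (assuming the usual high-nibble/low-nibble layout; not necessarily the exact code hidden by the fold):

import torch

# Pack signed 4-bit values stored as int8 in [-8, 7], two per uint8 byte.
w = torch.tensor([[-8, 7, 0, -1]], dtype=torch.int8)
shifted = w.add(8).view(torch.uint8)          # shift to [0, 15], one nibble per value
even, odd = shifted[:, ::2], shifted[:, 1::2]
packed = even * 16 + odd                      # even columns -> high nibble, odd -> low nibble
# Round-trip check: unpack and compare with the original.
unpacked = torch.stack([packed // 16, packed % 16], dim=-1).view(w.shape).to(torch.int8) - 8
assert torch.equal(unpacked, w)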

def convert_for_runtime(self) -> nn.Module:
replace_embedding_weight_only_grouped_int8_per_channel(
self.mod, self.bitwidth, self.groupsize, self.packed
self.mod, self.device, self.bitwidth, self.groupsize, self.packed
)
return self.mod

class QuantizedGroupEmbedding(torch.nn.Module):
def __init__(
self,
device,
vocab_size: int,
embedding_dim: int,
groupsize: Optional[int] = None,
device=None,
dtype=torch.half,
packed=False,
) -> None:
self.packed = packed
if not packed:
self.register_buffer(
"weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8)
"weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8, device=device)
)
else: # packed
self.register_buffer(
"weight", torch.empty((vocab_size, embedding_dim//2), dtype=torch.uint8)
"weight", torch.empty((vocab_size, embedding_dim//2), dtype=torch.uint8, device=device)
)
groups_per_row = (embedding_dim + groupsize - 1) // groupsize
if groups_per_row > 1:
self.register_buffer(
"scales", torch.ones((vocab_size, groups_per_row), dtype=torch.float16)
"scales", torch.ones((vocab_size, groups_per_row), dtype=torch.float16, device=device)
)
else:
self.register_buffer(
"scales", torch.ones((vocab_size,), dtype=torch.float16)
"scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
)

@torch.no_grad()
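A usage sketch of the updated module (illustrative, not part of the diff; sizes are arbitrary and the constructor is assumed to need nothing beyond the arguments visible above). Because the buffers are registered with device=device, the weight and scales land on the target device at construction time.

from quantize import QuantizedGroupEmbedding  # import path assumed

emb = QuantizedGroupEmbedding(device="cpu", vocab_size=32000, embedding_dim=4096, groupsize=32)
print(emb.weight.device, emb.scales.device)   # both report the device passed in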
@@ -712,8 +724,9 @@ def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_c


class WeightOnlyInt4QuantHandler(QuantHandler):
def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
def __init__(self, mod, device, *, groupsize=128, inner_k_tiles=8, padding_allowed=True):
self.mod = mod
self.device = device,
self.groupsize = groupsize
self.inner_k_tiles = inner_k_tiles
self.padding_allowed = padding_allowed
@@ -908,12 +921,15 @@ class Int8DynActInt4WeightQuantHandler(QuantHandler):
def __init__(
self,
mod,
device,
* ,
groupsize=256,
padding_allowed=False,
precision=torch.float32,
scales_precision=torch.float32,
):
self.mod = mod
self.device = device
self.groupsize = groupsize
self.padding_allowed = padding_allowed
self.precision = precision
@@ -1209,9 +1225,10 @@ def convert_for_runtime(self) -> "nn.Module":


class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
def __init__(self, mod, device, *, groupsize=128, inner_k_tiles=8, padding=True):
from build.model import find_multiple
self.mod = mod
self.device = device
self.groupsize = groupsize
self.inner_k_tiles = inner_k_tiles
self.padding = padding
@@ -1329,7 +1346,7 @@ def quantized_model(self) -> nn.Module:
### WIP: HQQ ###

class WeightOnlyInt4HqqQuantHandler:
def __init__(self, mod, groupsize):
def __init__(self, mod, device, *, groupsize):
self.mod = mod
self.groupsize = groupsize

