diff --git a/.github/workflows/compile_t4.yml b/.github/workflows/compile_t4.yml
index 65f795a71..e96d42fba 100644
--- a/.github/workflows/compile_t4.yml
+++ b/.github/workflows/compile_t4.yml
@@ -63,35 +63,35 @@ jobs:
         echo "******************************************"
         echo "******** Emb: group-wise quantized *******"
         echo "******************************************"
-        # python generate.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-        # cat ./output_eager
-        # python generate.py --device cuda --compile --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
-        # cat ./output_compiled
-        # python export.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
-        # python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
-        # cat ./output_aoti
+        python generate.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+        cat ./output_eager
+        python generate.py --device cuda --compile --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
+        cat ./output_compiled
+        python export.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
+        python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
+        cat ./output_aoti
 
         echo "******************************************"
         echo "******* INT8 channel-wise quantized ******"
         echo "******************************************"
-        # python generate.py --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-        # cat ./output_eager
-        # python generate.py --device cuda --compile --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
-        # cat ./output_compiled
-        # python export.py --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
-        # python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
-        # cat ./output_aoti
+        python generate.py --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+        cat ./output_eager
+        python generate.py --device cuda --compile --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
+        cat ./output_compiled
+        python export.py --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
+        python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
+        cat ./output_aoti
 
         echo "******************************************"
         echo "******** INT8 group-wise quantized *******"
         echo "******************************************"
-        # python generate.py --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-        # cat ./output_eager
-        # python generate.py --device cuda --compile --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
-        # cat ./output_compiled
-        # python export.py --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
-        # python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
-        # cat ./output_aoti
+        python generate.py --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+        cat ./output_eager
+        python generate.py --device cuda --compile --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
+        cat ./output_compiled
+        python export.py --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
+        python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
+        cat ./output_aoti
 
         echo "tests complete"
         echo "******************************************"
diff --git a/.github/workflows/test_mps.yml b/.github/workflows/test_mps.yml
index f8e166790..736433b1d 100644
--- a/.github/workflows/test_mps.yml
+++ b/.github/workflows/test_mps.yml
@@ -48,14 +48,29 @@ jobs:
 
           python generate.py --device mps --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
           cat ./output_eager
 
+
+          echo "************************************************************"
+          echo "*** embedding"
+          echo "************************************************************"
+
           python generate.py --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
           cat ./output_eager
           python generate.py --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
           cat ./output_eager
 
-          # python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-          # cat ./output_eager
-          # python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-          # cat ./output_eager
+
+          echo "************************************************************"
+          echo "*** linear int8"
+          echo "************************************************************"
+
+          python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          cat ./output_eager
+          python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          cat ./output_eager
+
+          echo "************************************************************"
+          echo "*** linear int4"
+          echo "************************************************************"
+          # PYTORCH_ENABLE_MPS_FALLBACK=1 python generate.py --device mps --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
           # cat ./output_eager
\ No newline at end of file
diff --git a/quantize.py b/quantize.py
index 2d596b795..0be6b4415 100644
--- a/quantize.py
+++ b/quantize.py
@@ -349,7 +349,7 @@ def quantized_model(self) -> nn.Module:
 
 ##### Weight-only int8 per-channel quantized code ######
 
-def replace_linear_weight_only_int8_per_channel(module, node_type, groupsize=None):
+def replace_linear_weight_only_int8_per_channel(module, device, node_type, groupsize=None):
     if groupsize is not None and groupsize != 0:
         pass
         # groupsize = 2 ** groupsize
@@ -367,10 +367,10 @@ def replace_linear_weight_only_int8_per_channel(module, node_type, groupsize=Non
             setattr(
                 module,
                 name,
-                WeightOnlyInt8Linear(child.in_features, child.out_features, groupsize),
+                WeightOnlyInt8Linear(device, child.in_features, child.out_features, groupsize),
             )
         else:
-            replace_linear_weight_only_int8_per_channel(child, node_type, groupsize)
+            replace_linear_weight_only_int8_per_channel(child, device, node_type, groupsize)
 
 
 class WeightOnlyInt8QuantHandler(QuantHandler):
@@ -434,6 +434,8 @@ def create_quantized_state_dict(self) -> Dict:
                     scales_dtype=mod.weight.dtype,
                 )
 
+                weight = weight.to(device=device)
+                scales = scales.to(device=device)
                 cur_state_dict[f"{fqn}.weight"] = weight
                 # squeeze makes groupsize=rowsize unidimensional
                 cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)
@@ -441,7 +443,7 @@ def create_quantized_state_dict(self) -> Dict:
         return cur_state_dict
 
     def convert_for_runtime(self) -> nn.Module:
-        replace_linear_weight_only_int8_per_channel(self.mod, self.node_type, self.groupsize)
+        replace_linear_weight_only_int8_per_channel(self.mod, self.device, self.node_type, self.groupsize)
         return self.mod
 
     def quantized_model(self) -> nn.Module:
@@ -459,6 +461,7 @@ class WeightOnlyInt8Linear(torch.nn.Module):
 
     def __init__(
         self,
+        device,
         in_features: int,
         out_features: int,
         groupsize: Optional[int] = None,
@@ -472,14 +475,14 @@ def __init__(
         self.in_features = in_features
         self.out_features = out_features
         self.register_buffer(
-            "weight", torch.empty((out_features, in_features), dtype=torch.int8)
+            "weight", torch.empty((out_features, in_features), dtype=torch.int8, device=device)
         )
         dtype=get_precision()
         if groupsize is None or (groupsize == 0):
-            self.register_buffer("scales", torch.ones(out_features, dtype=dtype))
+            self.register_buffer("scales", torch.ones(out_features, dtype=dtype, device=device))
         else:
             groups = (in_features + groupsize - 1) // groupsize
-            self.register_buffer("scales", torch.ones(out_features, groups, dtype=dtype))
+            self.register_buffer("scales", torch.ones(out_features, groups, dtype=dtype, device=device))
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         scales = self.scales
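
---

Note on the quantize.py hunks above: they thread a device argument through replace_linear_weight_only_int8_per_channel and WeightOnlyInt8Linear so the int8 weight and per-channel scale buffers are created directly on the target device, and the quantized state dict is moved there as well, which is presumably what allows the re-enabled CUDA and MPS steps in the workflows to exercise the quantized paths. The sketch below is illustrative only and is not the repository's implementation: it mirrors the device-first constructor from the diff but uses a different class name, handles only the channel-wise case (groupsize 0), and assumes a simple dequantize-in-forward step, since the real forward body lies outside the hunk.

# Minimal, self-contained sketch of a device-aware per-channel int8
# weight-only linear layer. Names marked "Sketch" are hypothetical and
# only mirror the constructor shape introduced in the diff above.
import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightOnlyInt8LinearSketch(nn.Module):
    def __init__(self, device, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Quantized weights and their per-output-channel scales are allocated
        # on the target device up front, so no post-hoc .to(device) is needed.
        self.register_buffer(
            "weight",
            torch.empty((out_features, in_features), dtype=torch.int8, device=device),
        )
        self.register_buffer(
            "scales", torch.ones(out_features, dtype=torch.float32, device=device)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Assumed dequantize-on-the-fly step: cast int8 weights to the
        # activation dtype and rescale each output channel before the matmul.
        w = self.weight.to(dtype=input.dtype) * self.scales.to(dtype=input.dtype).unsqueeze(-1)
        return F.linear(input, w)


if __name__ == "__main__":
    # Quick smoke test on CPU; swap "cpu" for "cuda" or "mps" where available.
    layer = WeightOnlyInt8LinearSketch("cpu", in_features=16, out_features=8)
    layer.weight.copy_(torch.randint(-128, 127, (8, 16), dtype=torch.int8))
    print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 8])

In the workflows above, this channel-wise path is what the --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' configuration selects on --device cuda and --device mps; a nonzero groupsize instead gives the scales a second (groups) dimension, as seen in the WeightOnlyInt8Linear constructor in the diff.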