-
Notifications
You must be signed in to change notification settings - Fork 148
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Sara Adkins
committed
Jun 14, 2024
1 parent
f7bd557
commit 46dc418
Showing
3 changed files
with
29 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,58 +1,39 @@ | ||
import torch | ||
from datasets import load_dataset | ||
from transformers import AutoTokenizer | ||
|
||
from sparseml.modifiers import GPTQModifier | ||
from sparseml.transformers import SparseAutoModelForCausalLM, oneshot | ||
|
||
|
||
# define a sparseml recipe for GPTQ FP8 quantization | ||
recipe = """ | ||
quant_stage: | ||
quant_modifiers: | ||
GPTQModifier: | ||
sequential_update: false | ||
ignore: ["lm_head"] | ||
config_groups: | ||
group_0: | ||
weights: | ||
num_bits: 8 | ||
type: "float" | ||
symmetric: true | ||
strategy: "tensor" | ||
input_activations: | ||
num_bits: 8 | ||
type: "float" | ||
symmetric: true | ||
strategy: "tensor" | ||
targets: ["Linear"] | ||
""" | ||
|
||
# setting device_map to auto to spread the model evenly across all available GPUs | ||
# load the model in as bfloat16 to save on memory and compute | ||
model_stub = "zoo:llama2-7b-ultrachat200k_llama2_pretrain-base" | ||
model = SparseAutoModelForCausalLM.from_pretrained( | ||
model_stub, torch_dtype=torch.bfloat16, device_map="auto" | ||
) | ||
model_stub = "meta-llama/Meta-Llama-3-8B-Instruct" | ||
output_dir = "Meta-Llama-3-8B-Instruct-FP8-Compressed" | ||
num_calibration_samples = 512 | ||
|
||
# uses SparseML's built-in preprocessing for ultra chat | ||
dataset = "ultrachat-200k" | ||
tokenizer = AutoTokenizer.from_pretrained(model_stub, use_fast=True) | ||
tokenizer.pad_token = tokenizer.eos_token | ||
|
||
# save location of quantized model out | ||
output_dir = "./output_llama7b_fp8_compressed" | ||
|
||
# set dataset config parameters | ||
splits = {"calibration": "train_gen[:5%]"} | ||
max_seq_length = 512 | ||
pad_to_max_length = False | ||
num_calibration_samples = 512 | ||
def preprocess(batch): | ||
text = tokenizer.apply_chat_template(batch["messages"], tokenize=False) | ||
tokenized = tokenizer(text, padding=True, truncation=True, max_length=2048) | ||
return tokenized | ||
|
||
|
||
ds = load_dataset("mgoin/ultrachat_2k", split="train_sft") | ||
examples = ds.map(preprocess, remove_columns=ds.column_names) | ||
|
||
recipe = GPTQModifier(targets=["Linear"], scheme="FP8", ignore=["lm_head"]) | ||
|
||
model = SparseAutoModelForCausalLM.from_pretrained( | ||
model_stub, torch_dtype=torch.bfloat16, device_map="auto" | ||
) | ||
|
||
# apply recipe to the model and save quantized output in fp8 format | ||
oneshot( | ||
model=model, | ||
dataset=dataset, | ||
dataset=examples, | ||
recipe=recipe, | ||
output_dir=output_dir, | ||
splits=splits, | ||
max_seq_length=max_seq_length, | ||
pad_to_max_length=pad_to_max_length, | ||
num_calibration_samples=num_calibration_samples, | ||
save_compressed=True, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,3 +13,5 @@ | |
# limitations under the License. | ||
|
||
# flake8: noqa | ||
|
||
from .gptq import * |