
Commit

update examples
Sara Adkins committed Jun 14, 2024
1 parent f7bd557 commit 46dc418
Showing 3 changed files with 29 additions and 47 deletions.
65 changes: 23 additions & 42 deletions examples/llama7b_fp8_quantization.py
@@ -1,58 +1,39 @@
 import torch
+from datasets import load_dataset
+from transformers import AutoTokenizer

+from sparseml.modifiers import GPTQModifier
 from sparseml.transformers import SparseAutoModelForCausalLM, oneshot

-
-# define a sparseml recipe for GPTQ FP8 quantization
-recipe = """
-quant_stage:
-    quant_modifiers:
-        GPTQModifier:
-            sequential_update: false
-            ignore: ["lm_head"]
-            config_groups:
-                group_0:
-                    weights:
-                        num_bits: 8
-                        type: "float"
-                        symmetric: true
-                        strategy: "tensor"
-                    input_activations:
-                        num_bits: 8
-                        type: "float"
-                        symmetric: true
-                        strategy: "tensor"
-                    targets: ["Linear"]
-"""
-
-# setting device_map to auto to spread the model evenly across all available GPUs
-# load the model in as bfloat16 to save on memory and compute
-model_stub = "zoo:llama2-7b-ultrachat200k_llama2_pretrain-base"
-model = SparseAutoModelForCausalLM.from_pretrained(
-    model_stub, torch_dtype=torch.bfloat16, device_map="auto"
-)
+model_stub = "meta-llama/Meta-Llama-3-8B-Instruct"
+output_dir = "Meta-Llama-3-8B-Instruct-FP8-Compressed"
+num_calibration_samples = 512

-# uses SparseML's built-in preprocessing for ultra chat
-dataset = "ultrachat-200k"
+tokenizer = AutoTokenizer.from_pretrained(model_stub, use_fast=True)
+tokenizer.pad_token = tokenizer.eos_token

-# save location of quantized model out
-output_dir = "./output_llama7b_fp8_compressed"

-# set dataset config parameters
-splits = {"calibration": "train_gen[:5%]"}
-max_seq_length = 512
-pad_to_max_length = False
-num_calibration_samples = 512
+def preprocess(batch):
+    text = tokenizer.apply_chat_template(batch["messages"], tokenize=False)
+    tokenized = tokenizer(text, padding=True, truncation=True, max_length=2048)
+    return tokenized

+
+ds = load_dataset("mgoin/ultrachat_2k", split="train_sft")
+examples = ds.map(preprocess, remove_columns=ds.column_names)
+
+recipe = GPTQModifier(targets=["Linear"], scheme="FP8", ignore=["lm_head"])
+
+model = SparseAutoModelForCausalLM.from_pretrained(
+    model_stub, torch_dtype=torch.bfloat16, device_map="auto"
+)
+
 # apply recipe to the model and save quantized output in fp8 format
 oneshot(
     model=model,
-    dataset=dataset,
+    dataset=examples,
     recipe=recipe,
     output_dir=output_dir,
-    splits=splits,
-    max_seq_length=max_seq_length,
-    pad_to_max_length=pad_to_max_length,
     num_calibration_samples=num_calibration_samples,
     save_compressed=True,
 )
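
A usage note on the updated FP8 example: the oneshot call above writes the compressed checkpoint to output_dir. Below is a minimal sketch of loading that checkpoint back for a quick generation check; it assumes the saved directory can be passed directly to SparseAutoModelForCausalLM.from_pretrained, and it reuses the original model stub for the tokenizer since the example does not save a tokenizer alongside the model. The prompt is purely illustrative.

import torch
from transformers import AutoTokenizer

from sparseml.transformers import SparseAutoModelForCausalLM

# assumption: oneshot() above saved the compressed model to this directory
output_dir = "Meta-Llama-3-8B-Instruct-FP8-Compressed"
model_stub = "meta-llama/Meta-Llama-3-8B-Instruct"

model = SparseAutoModelForCausalLM.from_pretrained(
    output_dir, torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_stub, use_fast=True)

# quick smoke test of the quantized model
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
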
9 changes: 4 additions & 5 deletions examples/llama7b_w8a8_quantization.py
@@ -16,19 +16,18 @@
                         num_bits: 8
                         type: "int"
                         symmetric: true
-                        strategy: "channel"
+                        strategy: "tensor"
                     input_activations:
                         num_bits: 8
                         type: "int"
                         symmetric: true
-                        dynamic: True
-                        strategy: "token"
+                        strategy: "tensor"
                     targets: ["Linear"]
 """

 # setting device_map to auto to spread the model evenly across all available GPUs
 # load the model in as bfloat16 to save on memory and compute
-model_stub = "zoo:llama2-7b-ultrachat200k_llama2_pretrain-base"
+model_stub = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
 model = SparseAutoModelForCausalLM.from_pretrained(
     model_stub, torch_dtype=torch.bfloat16, device_map="auto"
 )
@@ -37,7 +36,7 @@
 dataset = "ultrachat-200k"

 # save location of quantized model out
-output_dir = "./output_llama7b_w8a8_channel_dynamic_compressed"
+output_dir = "./TEST_MAIN_BRANCH_TENSOR"

 # set dataset config parameters
 splits = {"calibration": "train_gen[:5%]"}
2 changes: 2 additions & 0 deletions src/sparseml/modifiers/quantization/__init__.py
@@ -13,3 +13,5 @@
 # limitations under the License.

 # flake8: noqa
+
+from .gptq import *
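
A brief note on this last change: the added re-export is what exposes the modifier at the package level rather than only in its defining module. A one-line illustration follows, assuming GPTQModifier is defined in the gptq module being re-exported here.

# assumption: GPTQModifier lives in sparseml/modifiers/quantization/gptq and, via the
# re-export added above, can now be imported from the quantization package directly
from sparseml.modifiers.quantization import GPTQModifier

recipe = GPTQModifier(targets=["Linear"], scheme="FP8", ignore=["lm_head"])
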
