Quark dataset importer for fp8 (#96)

Provides an entry point for quark models that are stored as a safetensors
file + config.json. With a little effort this could also be adapted to
become the equivalent of the hf importer from gguf, without the
intermediate step.
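
As a rough illustration of the on-disk layout this importer consumes, here is a minimal sketch (not the importer itself) of reading such a checkpoint; the function and argument names are hypothetical, and it assumes the safetensors Python package is installed:

import json
from safetensors import safe_open

def load_quark_checkpoint(weights_path: str, config_path: str):
    # Model hyperparameters live in config.json next to the weights.
    with open(config_path, "r") as f:
        config = json.load(f)
    # All tensors are stored in a single safetensors file and can be
    # read name-by-name without loading the whole file at once.
    tensors = {}
    with safe_open(weights_path, framework="pt") as f:
        for name in f.keys():
            tensors[name] = f.get_tensor(name)
    return config, tensors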

I removed a lot of the useful debugging tooling I developed for this
because it doesn't play nicely with torch.export, but I imagine a lot of
it could be re-added if I guarded it in some way, so I'm keeping a
branch with those changes for reference.
dan-garvey authored Sep 24, 2024
1 parent bb7683e commit 9f4283d
Showing 11 changed files with 507 additions and 10 deletions.
2 changes: 2 additions & 0 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -52,6 +52,8 @@ def main():
)

args = cli.parse(parser)
dataset_type = cli.get_input_data_files(args)
dataset_type = "irpa" if "irpa" in dataset_type else "gguf"
dataset = cli.get_input_dataset(args)

hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
20 changes: 18 additions & 2 deletions sharktank/sharktank/examples/paged_llm_v1.py
@@ -8,6 +8,8 @@

from typing import Optional

from safetensors import safe_open

import math
import sys

@@ -20,7 +22,7 @@
from ..models.mixtral.mixtral import *
from ..models.llama.llama import *
from ..utils.debugging import trace_tensor
from ..utils.tokenizer import InferenceTokenizer, load_tokenizer
from ..utils.tokenizer import InferenceTokenizer


class TorchGenerator:
@@ -51,6 +53,7 @@ def begin_batch(self, prompts: list[str]):
token_ids, seq_lens = self.tokenizer.encode(
prompts, pad_to_multiple_of=self.model.cache.pad_sequence_stride
)

token_ids = torch.tensor(token_ids, device=self.model.device)
seq_lens = torch.tensor(seq_lens, device=self.model.device)
if self.shared_cache_state is not None:
@@ -218,13 +221,25 @@ def main():
help="DType to use for activations in the model",
default="float32",
)
parser.add_argument(
"--attention-dtype",
help="DType to use for attention in the model",
default="float16",
)
parser.add_argument(
"--use-hf",
action="store_true",
default=False,
)
cli.add_input_dataset_options(parser)
cli.add_tokenizer_options(parser)
args = cli.parse(parser)

device = torch.device(args.device) if args.device else None
activation_dtype = getattr(torch, args.activation_dtype)
attention_dtype = getattr(torch, args.attention_dtype)
assert isinstance(activation_dtype, torch.dtype)
assert isinstance(attention_dtype, torch.dtype)
dataset = cli.get_input_dataset(args)
tokenizer = cli.get_tokenizer(args)
prompts = args.prompt
@@ -235,7 +250,8 @@
kv_cache_type=args.kv_cache_type,
device=device,
activation_dtype=activation_dtype,
attention_dtype=activation_dtype,
attention_dtype=attention_dtype,
use_hf=args.use_hf,
)

if config.hp.expert_count:
9 changes: 6 additions & 3 deletions sharktank/sharktank/layers/linear.py
@@ -7,10 +7,8 @@
from typing import Optional

import torch

from .. import ops
from .base import Theta, ThetaLayer
from ..types.layout_utils import saturate_cast
from ..types import (
DynamicScaledQuantizer,
QuantizedTensor,
@@ -54,19 +52,21 @@ def __init__(
self.qdq_input: Optional[QuantizedTensor] = theta.optional_tensor("qdq_input")
if self.q_input is not None and self.qdq_input is not None:
raise AssertionError(f"LinearLayer cannot have both q_input and qdq_input")
self.qdq_output: Optional[QuantizedTensor] = theta.optional_tensor("qdq_output")

def forward(self, x):
weight = self.weight
bias = self.bias
q_input = self.q_input
qdq_input = self.qdq_input

qdq_output = self.qdq_output
if self.premul_input is not None:
x = ops.elementwise(torch.mul, x, self.premul_input)

if q_input is not None:
x = q_input.quantize(x)
elif qdq_input is not None:
# TODO: probably need a way to only do q_input if exporting.
x = qdq_input.quantize(x).unpack().dequant()

y = ops.linear(x, weight, bias)
@@ -76,4 +76,7 @@ def forward(self, x):
# the QuantizedTensor escape.
if isinstance(y, QuantizedTensor):
y = y.unpack().dequant()
if qdq_output is not None:
# TODO: same as above.
y = qdq_output.quantize(y).unpack().dequant()
return y
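
The qdq_output path above simulates output quantization by immediately quantizing and dequantizing the linear result so that downstream ops see the rounding error of the low-precision format. As a generic illustration of that round-trip in plain torch (assuming a build with float8_e4m3fn support; this is not the QuantizedTensor code path used by the layer):

import torch

def fake_quant_fp8(y: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Scale into fp8 range, cast to e4m3, then immediately cast back and
    # rescale, keeping only the error introduced by the fp8 representation.
    q = (y / scale).to(torch.float8_e4m3fn)
    return q.to(y.dtype) * scale
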
1 change: 1 addition & 0 deletions sharktank/sharktank/layers/norm.py
@@ -25,6 +25,7 @@ def __init__(
weight_name: str = "weight",
epsilon: float = 1e-6,
dtype: torch.dtype = torch.float32,
debug_save_file=None,
):
super().__init__(theta)
self.weight = self.theta_tensor(weight_name)
3 changes: 1 addition & 2 deletions sharktank/sharktank/models/llama/llama.py
@@ -8,7 +8,6 @@

from dataclasses import dataclass
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -151,7 +150,6 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
),
)
self.add_module("output_lm_head", LinearLayer(theta("output")))

self.attn_blocks = nn.ModuleList(
[
AttentionFFNBlock(
@@ -349,6 +347,7 @@ def forward(
xk_temp=xk_temp,
xv_temp=xv_temp,
)

# Feed forward network.
ffn_input = self.ffn_norm(h)
ffn_down = self.ffn(ffn_input)