Quark dataset importer for fp8 #96

Merged 28 commits on Sep 24, 2024
Changes from 22 commits
Commits (28)
6e2e1ec
(WIP) llama fp8 safetensor conversion
dan-garvey Jul 9, 2024
b5e0545
add in changes for loading model.
dan-garvey Jul 9, 2024
2967d86
move file
dan-garvey Jul 9, 2024
409e928
re-add original
dan-garvey Jul 9, 2024
70cadbb
fix importer
dan-garvey Jul 10, 2024
be8352b
dont quantize quantized parameters facedesk
dan-garvey Jul 10, 2024
2acb54c
datatype
dan-garvey Jul 10, 2024
92c2f7c
add some fixes to run
dan-garvey Jul 19, 2024
d285bec
mid-debug
dan-garvey Aug 12, 2024
0e82323
more debug
dan-garvey Aug 13, 2024
3c0690d
update
dan-garvey Aug 13, 2024
82878ec
rebased
dan-garvey Aug 13, 2024
2979f7b
checkpoint before swapping quant style
dan-garvey Aug 21, 2024
d212566
holy $*@($@ a working importer
dan-garvey Aug 21, 2024
cb9e642
stable prefill
dan-garvey Aug 28, 2024
d3bc063
Rework import quark dataset
dan-garvey Aug 28, 2024
0d2d78f
some cleanup
dan-garvey Aug 28, 2024
65256e2
remove some device calls, add some comments
dan-garvey Aug 28, 2024
7433ae9
remove some default values
dan-garvey Aug 28, 2024
7f3c963
last pass?
dan-garvey Aug 28, 2024
b93aa68
add a test for Theta.pop()
dan-garvey Aug 29, 2024
9becccf
Merge branch 'main' into llama_fp8
dan-garvey Sep 4, 2024
57368b5
address comments
dan-garvey Sep 5, 2024
0fe9da1
remove print from linear
dan-garvey Sep 5, 2024
6cd7880
Merge branch 'main' into llama_fp8
dan-garvey Sep 20, 2024
83d385c
adds support for fn->fnuz
dan-garvey Sep 20, 2024
f6f6245
remove spurious clone
dan-garvey Sep 20, 2024
f1e1bcb
Merge branch 'main' into llama_fp8
dan-garvey Sep 24, 2024
2 changes: 2 additions & 0 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -46,6 +46,8 @@ def main():
)

args = cli.parse(parser)
dataset_type = cli.get_input_data_files(args)
dataset_type = "irpa" if "irpa" in dataset_type else "gguf"
dataset = cli.get_input_dataset(args)

hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
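The two added lines infer the dataset flavor from the input files passed on the command line and fall back to GGUF when no IRPA file is given. A minimal standalone sketch of the same idea (not the actual `sharktank.utils.cli` helper; detection by file extension is an assumption here):

```python
from pathlib import Path

def detect_dataset_type(paths: list[str]) -> str:
    # Any .irpa input selects the IRPA path; everything else falls back to GGUF.
    return "irpa" if any(Path(p).suffix == ".irpa" for p in paths) else "gguf"

print(detect_dataset_type(["/models/model.irpa"]))  # -> "irpa"
print(detect_dataset_type(["/models/model.gguf"]))  # -> "gguf"
```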
20 changes: 18 additions & 2 deletions sharktank/sharktank/examples/paged_llm_v1.py
@@ -8,6 +8,8 @@

from typing import Optional

from safetensors import safe_open

import math
import sys

@@ -19,7 +21,7 @@
# TODO: Should be using a base class with the protocol supported.
from ..models.llama.llama import *
from ..utils.debugging import trace_tensor
from ..utils.tokenizer import InferenceTokenizer, load_tokenizer
from ..utils.tokenizer import InferenceTokenizer


class TorchGenerator:
@@ -50,6 +52,7 @@ def begin_batch(self, prompts: list[str]):
token_ids, seq_lens = self.tokenizer.encode(
prompts, pad_to_multiple_of=self.model.cache.pad_sequence_stride
)

token_ids = torch.tensor(token_ids, device=self.model.device)
seq_lens = torch.tensor(seq_lens, device=self.model.device)
if self.shared_cache_state is not None:
@@ -217,13 +220,25 @@ def main():
help="DType to use for activations in the model",
default="float32",
)
parser.add_argument(
"--attention-dtype",
help="DType to use for attention in the model",
default="float16",
)
parser.add_argument(
"--use-hf",
action="store_true",
default=False,
)
cli.add_input_dataset_options(parser)
cli.add_tokenizer_options(parser)
args = cli.parse(parser)

device = torch.device(args.device) if args.device else None
activation_dtype = getattr(torch, args.activation_dtype)
attention_dtype = getattr(torch, args.attention_dtype)
assert isinstance(activation_dtype, torch.dtype)
assert isinstance(attention_dtype, torch.dtype)
dataset = cli.get_input_dataset(args)
tokenizer = cli.get_tokenizer(args)
prompts = args.prompt
@@ -234,7 +249,8 @@ def main():
kv_cache_type=args.kv_cache_type,
device=device,
activation_dtype=activation_dtype,
attention_dtype=activation_dtype,
attention_dtype=attention_dtype,
use_hf=args.use_hf,
)
model = PagedLlamaModelV1(dataset.root_theta, config)
if args.save_intermediates_path:
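The new `--attention-dtype` and `--use-hf` flags let attention run in a different precision than the activations; previously `attention_dtype` was silently tied to `activation_dtype`. The flag values are plain strings naming torch dtypes, resolved with `getattr(torch, ...)` and then sanity-checked, as in this self-contained sketch (assumes a recent PyTorch build that ships the fp8 dtypes):

```python
import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument("--activation-dtype", default="float32")
parser.add_argument("--attention-dtype", default="float16")
args = parser.parse_args(["--attention-dtype", "float8_e4m3fnuz"])

# Resolve the dtype names against the torch module and validate the result.
activation_dtype = getattr(torch, args.activation_dtype)
attention_dtype = getattr(torch, args.attention_dtype)
assert isinstance(activation_dtype, torch.dtype)
assert isinstance(attention_dtype, torch.dtype)
print(activation_dtype, attention_dtype)
```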
11 changes: 8 additions & 3 deletions sharktank/sharktank/layers/linear.py
@@ -7,10 +7,8 @@
from typing import Optional

import torch

from .. import ops
from .base import Theta, ThetaLayer
from ..types.layout_utils import saturate_cast
from ..types import (
DynamicScaledQuantizer,
QuantizedTensor,
@@ -54,19 +52,23 @@ def __init__(
self.qdq_input: Optional[QuantizedTensor] = theta.optional_tensor("qdq_input")
if self.q_input is not None and self.qdq_input is not None:
raise AssertionError(f"LinearLayer cannot have both q_input and qdq_input")
self.qdq_output: Optional[QuantizedTensor] = theta.optional_tensor("qdq_output")

def forward(self, x):
weight = self.weight
bias = self.bias
q_input = self.q_input
qdq_input = self.qdq_input

qdq_output = self.qdq_output
original_input = x.clone().detach()
if self.premul_input is not None:
x = ops.elementwise(torch.mul, x, self.premul_input)

if q_input is not None:
x = q_input.quantize(x)
elif qdq_input is not None:
# TODO: probably need a way to only do q_input if exporting.
print("qdq input")
x = qdq_input.quantize(x).unpack().dequant()

y = ops.linear(x, weight, bias)
@@ -76,4 +78,7 @@ def forward(self, x):
# the QuantizedTensor escape.
if isinstance(y, QuantizedTensor):
y = y.unpack().dequant()
if qdq_output is not None:
# TODO: same as above.
y = qdq_output.quantize(y).unpack().dequant()
return y
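The new `qdq_input`/`qdq_output` branches implement fake quantization: the tensor is quantized and immediately dequantized so the fp8 rounding error is baked in while the surrounding ops still see a regular high-precision tensor. A torch-only sketch of that round trip (not the sharktank `QuantizedTensor` API; the single per-tensor scale and the `float8_e4m3fnuz` target are assumptions for illustration):

```python
import torch

def fake_quant_fp8(x: torch.Tensor, scale: torch.Tensor,
                   dtype: torch.dtype = torch.float8_e4m3fnuz) -> torch.Tensor:
    # Quantize-dequantize ("QDQ"): cast down to fp8 and straight back up, so the
    # result carries fp8 precision loss but keeps the original dtype for later ops.
    x_fp8 = (x / scale).to(dtype)
    return x_fp8.to(x.dtype) * scale

x = torch.randn(2, 4)
weight = torch.randn(3, 4)
y = torch.nn.functional.linear(fake_quant_fp8(x, torch.tensor(0.1)), weight)
```

In the layer above the same round trip is spelled `qdq_input.quantize(x).unpack().dequant()` on the way in, and symmetrically through `qdq_output` on the way out.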
1 change: 1 addition & 0 deletions sharktank/sharktank/layers/norm.py
@@ -25,6 +25,7 @@ def __init__(
weight_name: str = "weight",
epsilon: float = 1e-6,
dtype: torch.dtype = torch.float32,
debug_save_file=None,
):
super().__init__(theta)
self.weight = self.theta_tensor(weight_name)
8 changes: 0 additions & 8 deletions sharktank/sharktank/models/llama/llama.py
@@ -8,7 +8,6 @@

from dataclasses import dataclass
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -150,7 +149,6 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
),
)
self.add_module("output_lm_head", LinearLayer(theta("output")))

self.attn_blocks = nn.ModuleList(
[
PagedLlamaAttentionBlock(
@@ -298,7 +296,6 @@ def __init__(
use_hf: bool = False,
):
super().__init__(theta)

self.add_module(
"attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon)
)
@@ -337,12 +334,9 @@ def forward(
xv_temp: Optional[torch.Tensor] = None,
):
assert bool(start_index is not None) ^ bool(embedding_batch_mask is not None)

x = self.attn_norm(h)

bs, batch_seq_len, feature_dim = x.shape
assert feature_dim == self.head_count * self.head_dim

xq = self.attn_q(x)
xk = self.attn_k(x)
xv = self.attn_v(x)
@@ -415,11 +409,9 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
if attention_mask is not None:
# self.trace_tensor("attn_mask", attention_mask)
attn_weights = attn_weights + attention_mask

attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(xq)
attn_output = torch.matmul(attn_weights, values) # (bs, heads, slen, head_dim)
attn_output = attn_output.transpose(1, 2).reshape(bs, batch_seq_len, -1)

# Project.
attn_output = self.attn_output(attn_output)

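The llama.py hunks are mostly blank-line cleanups; the last one shows the tail of the attention block. A self-contained sketch of that tail with concrete shapes, using plain dense tensors rather than the model's paged KV cache (the shapes and fp16 activations are assumptions for illustration):

```python
import torch
import torch.nn.functional as F

bs, heads, slen, head_dim = 1, 8, 16, 64
attn_weights = torch.randn(bs, heads, slen, slen, dtype=torch.float16)
values = torch.randn(bs, heads, slen, head_dim, dtype=torch.float16)
attention_mask = torch.zeros(bs, 1, slen, slen, dtype=torch.float16)

# Additive mask, softmax in float32 for stability, cast back, then weight the values.
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights.float(), dim=-1).to(torch.float16)
attn_output = torch.matmul(attn_weights, values)                  # (bs, heads, slen, head_dim)
attn_output = attn_output.transpose(1, 2).reshape(bs, slen, -1)   # (bs, slen, heads * head_dim)
```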