initial grok (#169)
Very rough, putting this up for visibility

---------

Co-authored-by: Kyle Herndon <[email protected]>
Co-authored-by: archana-ramalingam <[email protected]>
Co-authored-by: Archana Ramalingam <[email protected]>
4 people authored Sep 26, 2024
1 parent 2818131 commit 9f3f70f
Showing 21 changed files with 607 additions and 183 deletions.
7 changes: 6 additions & 1 deletion sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -17,6 +17,7 @@
# TODO: Should be using a base class with the protocol supported.
from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
from ..models.mixtral.mixtral import *
from ..models.grok.grok import *


def main():
@@ -61,8 +62,12 @@ def main():
llama_config.use_hf = False
llama_config.static_tables = False # Rely on the compiler for hoisting tables.
llama_config.kv_cache_type = "direct" if args.bs == [1] else "paged"

if llama_config.hp.expert_count:
model = PagedMixtralModelV1(dataset.root_theta, llama_config)
if llama_config.hp.model_arch == "grok":
model = PagedGrokModelV1(dataset.root_theta, llama_config)
else:
model = PagedMixtralModelV1(dataset.root_theta, llama_config)
else:
model = PagedLlamaModelV1(dataset.root_theta, llama_config)

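For reference, a sketch of the model selection as it stands after this change, using the same names as the script above (dataset, llama_config):

    if llama_config.hp.expert_count:
        if llama_config.hp.model_arch == "grok":
            model = PagedGrokModelV1(dataset.root_theta, llama_config)
        else:
            model = PagedMixtralModelV1(dataset.root_theta, llama_config)
    else:
        model = PagedLlamaModelV1(dataset.root_theta, llama_config)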
18 changes: 9 additions & 9 deletions sharktank/sharktank/examples/paged_llm_v1.py
@@ -20,6 +20,7 @@

# TODO: Should be using a base class with the protocol supported.
from ..models.mixtral.mixtral import *
from ..models.grok.grok import *
from ..models.llama.llama import *
from ..models.llama.sharding import shard_theta
from ..utils.debugging import trace_tensor
@@ -222,11 +223,6 @@ def main():
help="DType to use for activations in the model",
default="float32",
)
parser.add_argument(
"--attention-dtype",
help="DType to use for attention in the model",
default="float16",
)
parser.add_argument(
"--use-hf",
action="store_true",
@@ -244,9 +240,8 @@

device = torch.device(args.device) if args.device else None
activation_dtype = getattr(torch, args.activation_dtype)
attention_dtype = getattr(torch, args.attention_dtype)
assert isinstance(activation_dtype, torch.dtype)
assert isinstance(attention_dtype, torch.dtype)

dataset = cli.get_input_dataset(args)
tokenizer = cli.get_tokenizer(args)
prompts = args.prompt
@@ -257,17 +252,21 @@
kv_cache_type=args.kv_cache_type,
device=device,
activation_dtype=activation_dtype,
attention_dtype=attention_dtype,
attention_dtype=activation_dtype,
use_hf=args.use_hf,
tensor_parallelism_size=args.tensor_parallelism_size,
)
if config.tensor_parallelism_size > 1:
dataset.root_theta = shard_theta(dataset.root_theta, config)

if config.hp.expert_count:
model = PagedMixtralModelV1(dataset.root_theta, config)
if config.hp.model_arch == "grok":
model = PagedGrokModelV1(dataset.root_theta, config)
else:
model = PagedMixtralModelV1(dataset.root_theta, config)
else:
model = PagedLlamaModelV1(dataset.root_theta, config)

if args.save_intermediates_path:
from ..utils.patching import SaveModuleResultTensorsPatch

@@ -297,6 +296,7 @@ def main():
)
print(f":: Result tokens: {batch.results}")
batch.print_current_results()
counter += 1


if __name__ == "__main__":
80 changes: 80 additions & 0 deletions sharktank/sharktank/export_layer/export_moe.py
@@ -0,0 +1,80 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch
from shark_turbine.aot import *
from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch
from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock
from ..utils import cli


def main():
parser = cli.create_parser()
parser.add_argument(
"--output-mlir",
help="Output file path for exported MLIR file",
default="/tmp/batch_llama_v1.mlir",
)
parser.add_argument(
"--batch-size",
"-bs",
help="Batch size to generate, e.g. `4` or `2`",
type=lambda arg: int(arg),
default="2",
)
parser.add_argument(
"--verbose",
"-v",
help="Include verbose logging",
action="store_true",
)
parser.add_argument(
"--strict",
help="Enables strictness during export",
action="store_true",
)
parser.add_argument(
"--use-grok",
help="Enable to export Grok model's version of MOE block",
action="store_true",
)

args = cli.parse(parser)

bs = args.batch_size

model = PreGatherMoeBlock(
theta=make_moe_block_theta()("blk.0"),
expert_count=8,
expert_used_count=2,
rms_epsilon=1e-5,
use_grok=args.use_grok,
)
fxb = FxProgramsBuilder(model)
input = make_rand_torch((bs, 32, 6144))

@fxb.export_program(name="prefill_moe", args=(input,))
def _(model, input: torch.Tensor) -> torch.Tensor:
return model(input)

input = make_rand_torch((bs, 1, 6144))

@fxb.export_program(name="decode_moe", args=(input,))
def _(model, input: torch.Tensor) -> torch.Tensor:
return model(input)

if args.verbose:
for name, ep in fxb.programs.items():
print(f"EXPORT {name}:\n{ep}")

print("Exporting")
output = export(fxb)
print(f"Saving to '{args.output_mlir}'")
output.save_mlir(args.output_mlir)


if __name__ == "__main__":
main()
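A minimal eager sketch of what the two exported programs ("prefill_moe" and "decode_moe") compute, assuming sharktank is importable and using the same dimensions hard-coded above (hidden size 6144, 8 experts, top-2 routing); it calls the block directly rather than going through FxProgramsBuilder:

    from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch
    from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock

    block = PreGatherMoeBlock(
        theta=make_moe_block_theta()("blk.0"),
        expert_count=8,
        expert_used_count=2,
        rms_epsilon=1e-5,
        use_grok=True,  # exercise the Grok variant of the MoE block
    )
    prefill_input = make_rand_torch((2, 32, 6144))  # (batch, seq_len, hidden), as in "prefill_moe"
    decode_input = make_rand_torch((2, 1, 6144))    # single-token step, as in "decode_moe"
    print(block(prefill_input).shape, block(decode_input).shape)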
4 changes: 2 additions & 2 deletions sharktank/sharktank/layers/__init__.py
@@ -16,6 +16,6 @@
from .paged_llama_attention_block import PagedLlamaAttentionBlock
from .ffn_block import FFN
from .ffn_moe_block import FFNMOE
from .mixture_of_experts_block import SparseMoeBlock
from .mixture_of_experts_block import SparseMoeBlock, PreGatherMoeBlock

from . import configs
from .configs import *
2 changes: 1 addition & 1 deletion sharktank/sharktank/layers/configs/__init__.py
@@ -4,4 +4,4 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .llm_configs import LlamaHParams
from .llm_configs import *
81 changes: 63 additions & 18 deletions sharktank/sharktank/layers/configs/llm_configs.py
@@ -16,10 +16,9 @@

from dataclasses import dataclass
from typing import Any, Optional

import torch

__all__ = ["LlamaHParams"]
__all__ = ["LlamaHParams", "LlamaModelConfig"]


@dataclass
@@ -29,48 +28,55 @@ class LlamaHParams:
Comments are only provided if they differ from this source.
"""

model_arch: str
context_length: int
embedding_length: int
block_count: int
feed_forward_length: int
rope_dimension_count: int
rope_freq_base: float
attention_head_count: int
attn_head_dim: int
attention_layer_norm_rms_epsilon: float
attention_head_count_kv: int
expert_count: int
expert_used_count: int
rope_dimension_count: Optional[int] = None
rope_freq_base: Optional[float] = None
expert_count: Optional[int] = None
expert_used_count: Optional[int] = None

@staticmethod
def from_gguf_props(p: dict[str, Any]):
name_prefix = p["general.architecture"]
default_expert_count = 0
default_expert_used_count = 0
default_rope_freq_base = 10000.0
attention_head_count = _int_prop(p, "llama.attention.head_count")
default_rope_dimension_count = 128
attention_head_count = _int_prop(p, f"{name_prefix}.attention.head_count")
rope_dimension_count = _optional_int_prop(
p, f"{name_prefix}.rope.dimension_count", default_rope_dimension_count
)

return LlamaHParams(
context_length=_int_prop(p, "llama.context_length"),
embedding_length=_int_prop(p, "llama.embedding_length"),
block_count=_int_prop(p, "llama.block_count"),
feed_forward_length=_int_prop(p, "llama.feed_forward_length"),
attn_head_dim=_int_prop(p, "llama.rope.dimension_count"),
rope_dimension_count=_int_prop(p, "llama.rope.dimension_count"),
model_arch=name_prefix,
context_length=_int_prop(p, f"{name_prefix}.context_length"),
embedding_length=_int_prop(p, f"{name_prefix}.embedding_length"),
block_count=_int_prop(p, f"{name_prefix}.block_count"),
feed_forward_length=_int_prop(p, f"{name_prefix}.feed_forward_length"),
attention_head_count=attention_head_count,
attention_layer_norm_rms_epsilon=_float_prop(
p, "llama.attention.layer_norm_rms_epsilon"
p, f"{name_prefix}.attention.layer_norm_rms_epsilon"
),
attention_head_count_kv=_optional_int_prop(
p, "llama.attention.head_count_kv", attention_head_count
p, f"{name_prefix}.attention.head_count_kv", attention_head_count
),
attn_head_dim=rope_dimension_count,
rope_dimension_count=rope_dimension_count,
rope_freq_base=_optional_float_prop(
p, "llama.rope.freq_base", default_rope_freq_base
p, f"{name_prefix}.rope.freq_base", default_rope_freq_base
),
expert_count=_optional_int_prop(
p, "llama.expert_count", default_expert_count
p, f"{name_prefix}.expert_count", default_expert_count
),
expert_used_count=_optional_int_prop(
p, "llama.expert_used_count", default_expert_used_count
p, f"{name_prefix}.expert_used_count", default_expert_used_count
),
)

@@ -107,3 +113,42 @@ def _optional_int_prop(p: dict[str, Any], name: str, default_value: int) -> int:
return int(value)
except ValueError as e:
raise ValueError(f"Property '{name}' expected to be an int and was not") from e


@dataclass
class LlamaModelConfig:
hp: LlamaHParams

# Block sequence stride for a paged KV cache. This must divide evenly
# into the context length.
block_seq_stride: int = 16

# Either "paged" or "direct".
kv_cache_type: str = "paged"

# The device on which to place intermediate state.
device: Optional[torch.device] = None

# Dtype to use for general FP activations not otherwise configured.
activation_dtype: torch.dtype = torch.float16

# Dtype to use for attention.
attention_dtype: torch.dtype = torch.float16

# How many devices are involved for tensor parallel sharding.
# If greater than 1, the model will expect sharded model parameters and function
# arguments.
tensor_parallelism_size: int = 1

# Indicates if running with HuggingFace implementation and ensures
# numerical equivalency to HuggingFace's LLaMa if true (by modifying
# rotary embedding).
use_hf: bool = False

# If true, then the model may pre-initialize certain tables during
# init. This can be better for eager execution but when capturing a program,
# it is often better to preserve the calculation explicitly and rely on
# the compiler to transform it to an initialization time step. This can
# be the difference of many gigabytes of static data being embedded in
# the program and not.
static_tables: bool = True
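An illustrative sketch of how the architecture prefix now drives the GGUF property lookup; the property values below are made up for the example, not taken from a real Grok checkpoint:

    from sharktank.layers.configs import LlamaHParams, LlamaModelConfig

    props = {  # hypothetical values, shown only to exercise the prefixed keys
        "general.architecture": "grok",
        "grok.context_length": 8192,
        "grok.embedding_length": 6144,
        "grok.block_count": 64,
        "grok.feed_forward_length": 32768,
        "grok.attention.head_count": 48,
        "grok.attention.layer_norm_rms_epsilon": 1e-5,
        "grok.expert_count": 8,
        "grok.expert_used_count": 2,
    }
    hp = LlamaHParams.from_gguf_props(props)
    # rope.dimension_count is absent, so attn_head_dim falls back to the default of 128,
    # and attention_head_count_kv defaults to attention_head_count.
    config = LlamaModelConfig(hp=hp, kv_cache_type="paged")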
61 changes: 61 additions & 0 deletions sharktank/sharktank/layers/ffn_moe_block.py
@@ -15,9 +15,70 @@

__all__ = [
"FFNMOE",
"PreGatherFFNMOE",
]


class PreGatherFFNMOE(ThetaLayer):
def __init__(
self,
theta: Theta,
use_grok: bool = False,
):

super().__init__(theta)
self.use_grok = use_grok

self.ffn_gate = theta.tensor("ffn_gate_exps", "weight")
self.ffn_up = theta.tensor("ffn_up_exps", "weight")
self.ffn_down = theta.tensor("ffn_down_exps", "weight")

def pre_matmul_gather(self, inputs, weights, experts, einstring="mk,menk->men"):
inputs = inputs[:, :]
weights = weights[experts, :, :]
matmul = torch.einsum(einstring, inputs, weights.float())
return matmul

def bigger_mmg(self, inputs, weights, experts):
inputs = inputs[:, :]
weights = weights[experts, :, :]
matmul = torch.einsum("mek,menk->men", inputs, weights.float())
return matmul

def one_hot_matmul(self, inputs, weights, experts):
matmul = torch.einsum("mk,bnk->bmn", inputs, weights)
# Post mix the experts
oh = (
torch.nn.functional.one_hot(experts.reshape(-1), num_classes=8)
.transpose(0, 1)
.to(torch.float32)
)
output = torch.einsum("bm,bmn->mn", oh, matmul)
return output

def forward(
self,
h: torch.Tensor,
experts: torch.Tensor,
expert_gate: torch.Tensor,
):
if self.use_grok:
ffn_gate = F.gelu(
self.pre_matmul_gather(h, self.ffn_gate.as_torch(), experts)
)
else:
ffn_gate = F.silu(
self.pre_matmul_gather(h, self.ffn_gate.as_torch(), experts)
)

ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts)
ffn_down = self.pre_matmul_gather(
ffn_gate * ffn_up, self.ffn_down, experts, einstring="mek,menk->men"
)
ffn_down = torch.einsum("me,men->men", expert_gate, ffn_down)
return torch.sum(ffn_down, dim=1)


class FFNMOE(ThetaLayer):
def __init__(
self,
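For reference, a small shape sketch of the pre_matmul_gather einsum used by PreGatherFFNMOE above (dimensions are illustrative, not the model's):

    import torch

    m, k, n, n_experts, top_k = 4, 16, 32, 8, 2
    inputs = torch.randn(m, k)                          # one token per row
    weights = torch.randn(n_experts, n, k)              # per-expert projection weights
    experts = torch.randint(0, n_experts, (m, top_k))   # top-k expert ids per token

    gathered = weights[experts, :, :]                   # (m, top_k, n, k)
    out = torch.einsum("mk,menk->men", inputs, gathered.float())
    print(out.shape)                                    # torch.Size([4, 2, 32])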
