From 9f3f70f7f5018702abfa0dbb4954630179abb768 Mon Sep 17 00:00:00 2001 From: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com> Date: Thu, 26 Sep 2024 12:14:25 -0500 Subject: [PATCH] initial grok (#169) Very rough, putting this up for visibility --------- Co-authored-by: Kyle Herndon Co-authored-by: archana-ramalingam Co-authored-by: Archana Ramalingam <98564406+archana-ramalingam@users.noreply.github.com> --- .../sharktank/examples/export_paged_llm_v1.py | 7 +- sharktank/sharktank/examples/paged_llm_v1.py | 18 +- .../sharktank/export_layer/export_moe.py | 80 ++++++ sharktank/sharktank/layers/__init__.py | 4 +- .../sharktank/layers/configs/__init__.py | 2 +- .../sharktank/layers/configs/llm_configs.py | 81 ++++-- sharktank/sharktank/layers/ffn_moe_block.py | 61 +++++ .../layers/mixture_of_experts_block.py | 73 +++++- sharktank/sharktank/layers/norm.py | 1 - .../layers/paged_llama_attention_block.py | 36 ++- .../sharktank/layers/rotary_embedding.py | 3 +- sharktank/sharktank/models/grok/grok.py | 233 ++++++++++++++++++ sharktank/sharktank/models/llama/llama.py | 77 +----- sharktank/sharktank/models/llama/testing.py | 8 +- sharktank/sharktank/models/mixtral/mixtral.py | 54 +--- sharktank/sharktank/utils/create_cache.py | 34 +++ .../tests/models/llama/attention_test.py | 1 - sharktank/tests/models/llama/kv_cache_test.py | 4 +- .../tests/models/llama/moe_block_test.py | 10 +- .../tests/models/llama/sharded_llama_test.py | 1 + sharktank/tests/types/dataset_test.py | 2 +- 21 files changed, 607 insertions(+), 183 deletions(-) create mode 100644 sharktank/sharktank/export_layer/export_moe.py create mode 100644 sharktank/sharktank/models/grok/grok.py create mode 100644 sharktank/sharktank/utils/create_cache.py diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 58509753d..3e094b494 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -17,6 +17,7 @@ # TODO: Should be using a base class with the protocol supported. from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1 from ..models.mixtral.mixtral import * +from ..models.grok.grok import * def main(): @@ -61,8 +62,12 @@ def main(): llama_config.use_hf = False llama_config.static_tables = False # Rely on the compiler for hoisting tables. llama_config.kv_cache_type = "direct" if args.bs == [1] else "paged" + if llama_config.hp.expert_count: - model = PagedMixtralModelV1(dataset.root_theta, llama_config) + if llama_config.hp.model_arch == "grok": + model = PagedGrokModelV1(dataset.root_theta, llama_config) + else: + model = PagedMixtralModelV1(dataset.root_theta, llama_config) else: model = PagedLlamaModelV1(dataset.root_theta, llama_config) diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py index c425bfefe..70efddfbd 100644 --- a/sharktank/sharktank/examples/paged_llm_v1.py +++ b/sharktank/sharktank/examples/paged_llm_v1.py @@ -20,6 +20,7 @@ # TODO: Should be using a base class with the protocol supported. 
from ..models.mixtral.mixtral import * +from ..models.grok.grok import * from ..models.llama.llama import * from ..models.llama.sharding import shard_theta from ..utils.debugging import trace_tensor @@ -222,11 +223,6 @@ def main(): help="DType to use for activations in the model", default="float32", ) - parser.add_argument( - "--attention-dtype", - help="DType to use for attention in the model", - default="float16", - ) parser.add_argument( "--use-hf", action="store_true", @@ -244,9 +240,8 @@ def main(): device = torch.device(args.device) if args.device else None activation_dtype = getattr(torch, args.activation_dtype) - attention_dtype = getattr(torch, args.attention_dtype) assert isinstance(activation_dtype, torch.dtype) - assert isinstance(attention_dtype, torch.dtype) + dataset = cli.get_input_dataset(args) tokenizer = cli.get_tokenizer(args) prompts = args.prompt @@ -257,7 +252,7 @@ def main(): kv_cache_type=args.kv_cache_type, device=device, activation_dtype=activation_dtype, - attention_dtype=attention_dtype, + attention_dtype=activation_dtype, use_hf=args.use_hf, tensor_parallelism_size=args.tensor_parallelism_size, ) @@ -265,9 +260,13 @@ def main(): dataset.root_theta = shard_theta(dataset.root_theta, config) if config.hp.expert_count: - model = PagedMixtralModelV1(dataset.root_theta, config) + if config.hp.model_arch == "grok": + model = PagedGrokModelV1(dataset.root_theta, config) + else: + model = PagedMixtralModelV1(dataset.root_theta, config) else: model = PagedLlamaModelV1(dataset.root_theta, config) + if args.save_intermediates_path: from ..utils.patching import SaveModuleResultTensorsPatch @@ -297,6 +296,7 @@ def main(): ) print(f":: Result tokens: {batch.results}") batch.print_current_results() + counter += 1 if __name__ == "__main__": diff --git a/sharktank/sharktank/export_layer/export_moe.py b/sharktank/sharktank/export_layer/export_moe.py new file mode 100644 index 000000000..e8f257bfe --- /dev/null +++ b/sharktank/sharktank/export_layer/export_moe.py @@ -0,0 +1,80 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import torch +from shark_turbine.aot import * +from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch +from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock +from ..utils import cli + + +def main(): + parser = cli.create_parser() + parser.add_argument( + "--output-mlir", + help="Output file path for exported MLIR file", + default="/tmp/batch_llama_v1.mlir", + ) + parser.add_argument( + "--batch-size", + "-bs", + help="Batch size to generate, e.g. 
`4` or `2`", + type=lambda arg: int(arg), + default="2", + ) + parser.add_argument( + "--verbose", + "-v", + help="Include verbose logging", + action="store_true", + ) + parser.add_argument( + "--strict", + help="Enables strictness during export", + action="store_true", + ) + parser.add_argument( + "--use-grok", + help="Enable to export Grok model's version of MOE block", + action="store_true", + ) + + args = cli.parse(parser) + + bs = args.batch_size + + model = PreGatherMoeBlock( + theta=make_moe_block_theta()("blk.0"), + expert_count=8, + expert_used_count=2, + rms_epsilon=1e-5, + use_grok=args.use_grok, + ) + fxb = FxProgramsBuilder(model) + input = make_rand_torch((bs, 32, 6144)) + + @fxb.export_program(name="prefill_moe", args=(input,)) + def _(model, input: torch.Tensor) -> torch.Tensor: + return model(input) + + input = make_rand_torch((bs, 1, 6144)) + + @fxb.export_program(name="decode_moe", args=(input,)) + def _(model, input: torch.Tensor) -> torch.Tensor: + return model(input) + + if args.verbose: + for name, ep in fxb.programs.items(): + print(f"EXPORT {name}:\n{ep}") + + print("Exporting") + output = export(fxb) + print(f"Saving to '{args.output_mlir}'") + output.save_mlir(args.output_mlir) + + +if __name__ == "__main__": + main() diff --git a/sharktank/sharktank/layers/__init__.py b/sharktank/sharktank/layers/__init__.py index 181544763..a90def3a9 100644 --- a/sharktank/sharktank/layers/__init__.py +++ b/sharktank/sharktank/layers/__init__.py @@ -16,6 +16,6 @@ from .paged_llama_attention_block import PagedLlamaAttentionBlock from .ffn_block import FFN from .ffn_moe_block import FFNMOE -from .mixture_of_experts_block import SparseMoeBlock +from .mixture_of_experts_block import SparseMoeBlock, PreGatherMoeBlock -from . import configs +from .configs import * diff --git a/sharktank/sharktank/layers/configs/__init__.py b/sharktank/sharktank/layers/configs/__init__.py index 21336d1d2..c5d75c602 100644 --- a/sharktank/sharktank/layers/configs/__init__.py +++ b/sharktank/sharktank/layers/configs/__init__.py @@ -4,4 +4,4 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from .llm_configs import LlamaHParams +from .llm_configs import * diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py index ab3a582e4..ea7b88175 100644 --- a/sharktank/sharktank/layers/configs/llm_configs.py +++ b/sharktank/sharktank/layers/configs/llm_configs.py @@ -16,10 +16,9 @@ from dataclasses import dataclass from typing import Any, Optional - import torch -__all__ = ["LlamaHParams"] +__all__ = ["LlamaHParams", "LlamaModelConfig"] @dataclass @@ -29,48 +28,55 @@ class LlamaHParams: Comments are only provided if they differ from this source. 
""" + model_arch: str context_length: int embedding_length: int block_count: int feed_forward_length: int - rope_dimension_count: int - rope_freq_base: float attention_head_count: int attn_head_dim: int attention_layer_norm_rms_epsilon: float attention_head_count_kv: int - expert_count: int - expert_used_count: int + rope_dimension_count: Optional[int] = None + rope_freq_base: Optional[float] = None + expert_count: Optional[int] = None + expert_used_count: Optional[int] = None @staticmethod def from_gguf_props(p: dict[str, Any]): + name_prefix = p["general.architecture"] default_expert_count = 0 default_expert_used_count = 0 default_rope_freq_base = 10000.0 - attention_head_count = _int_prop(p, "llama.attention.head_count") + default_rope_dimension_count = 128 + attention_head_count = _int_prop(p, f"{name_prefix}.attention.head_count") + rope_dimension_count = _optional_int_prop( + p, f"{name_prefix}.rope.dimension_count", default_rope_dimension_count + ) return LlamaHParams( - context_length=_int_prop(p, "llama.context_length"), - embedding_length=_int_prop(p, "llama.embedding_length"), - block_count=_int_prop(p, "llama.block_count"), - feed_forward_length=_int_prop(p, "llama.feed_forward_length"), - attn_head_dim=_int_prop(p, "llama.rope.dimension_count"), - rope_dimension_count=_int_prop(p, "llama.rope.dimension_count"), + model_arch=name_prefix, + context_length=_int_prop(p, f"{name_prefix}.context_length"), + embedding_length=_int_prop(p, f"{name_prefix}.embedding_length"), + block_count=_int_prop(p, f"{name_prefix}.block_count"), + feed_forward_length=_int_prop(p, f"{name_prefix}.feed_forward_length"), attention_head_count=attention_head_count, attention_layer_norm_rms_epsilon=_float_prop( - p, "llama.attention.layer_norm_rms_epsilon" + p, f"{name_prefix}.attention.layer_norm_rms_epsilon" ), attention_head_count_kv=_optional_int_prop( - p, "llama.attention.head_count_kv", attention_head_count + p, f"{name_prefix}.attention.head_count_kv", attention_head_count ), + attn_head_dim=rope_dimension_count, + rope_dimension_count=rope_dimension_count, rope_freq_base=_optional_float_prop( - p, "llama.rope.freq_base", default_rope_freq_base + p, f"{name_prefix}.rope.freq_base", default_rope_freq_base ), expert_count=_optional_int_prop( - p, "llama.expert_count", default_expert_count + p, f"{name_prefix}.expert_count", default_expert_count ), expert_used_count=_optional_int_prop( - p, "llama.expert_used_count", default_expert_used_count + p, f"{name_prefix}.expert_used_count", default_expert_used_count ), ) @@ -107,3 +113,42 @@ def _optional_int_prop(p: dict[str, Any], name: str, default_value: int) -> int: return int(value) except ValueError as e: raise ValueError(f"Property '{name}' expected to be an int and was not") from e + + +@dataclass +class LlamaModelConfig: + hp: LlamaHParams + + # Block sequence stride for a paged KV cache. This must divide evenly + # into the context length. + block_seq_stride: int = 16 + + # Either "paged" or "direct". + kv_cache_type: str = "paged" + + # The device on which to place intermediate state. + device: Optional[torch.device] = None + + # Dtype to use for general FP activations not otherwise configured. + activation_dtype: torch.dtype = torch.float16 + + # Dtype to use for attention. + attention_dtype: torch.dtype = torch.float16 + + # How many devices are involved for tensor parallel sharding. + # If greater than 1, the model will expect sharded model parameters and function + # arguments. 
+ tensor_parallelism_size: int = 1 + + # Indicates if running with HuggingFace implementation and ensures + # numerical equivalency to HuggingFace's LLaMa if true (by modifying + # rotary embedding). + use_hf: bool = False + + # If true, then the model may pre-initialize certain tables during + # init. This can be better for eager execution but when capturing a program, + # it is often better to preserve the calculation explicitly and rely on + # the compiler to transform it to an initialization time step. This can + # be the difference of many gigabytes of static data being embedded in + # the program and not. + static_tables: bool = True diff --git a/sharktank/sharktank/layers/ffn_moe_block.py b/sharktank/sharktank/layers/ffn_moe_block.py index d2f51d2d9..266537c89 100644 --- a/sharktank/sharktank/layers/ffn_moe_block.py +++ b/sharktank/sharktank/layers/ffn_moe_block.py @@ -15,9 +15,70 @@ __all__ = [ "FFNMOE", + "PreGatherFFNMOE", ] +class PreGatherFFNMOE(ThetaLayer): + def __init__( + self, + theta: Theta, + use_grok: bool = False, + ): + + super().__init__(theta) + self.use_grok = use_grok + + self.ffn_gate = theta.tensor("ffn_gate_exps", "weight") + self.ffn_up = theta.tensor("ffn_up_exps", "weight") + self.ffn_down = theta.tensor("ffn_down_exps", "weight") + + def pre_matmul_gather(self, inputs, weights, experts, einstring="mk,menk->men"): + inputs = inputs[:, :] + weights = weights[experts, :, :] + matmul = torch.einsum(einstring, inputs, weights.float()) + return matmul + + def bigger_mmg(self, inputs, weights, experts): + inputs = inputs[:, :] + weights = weights[experts, :, :] + matmul = torch.einsum("mek,menk->men", inputs, weights.float()) + return matmul + + def one_hot_matmul(self, inputs, weights, experts): + matmul = torch.einsum("mk,bnk->bmn", inputs, weights) + # Post mix the experts + oh = ( + torch.nn.functional.one_hot(experts.reshape(-1), num_classes=8) + .transpose(0, 1) + .to(torch.float32) + ) + output = torch.einsum("bm,bmn->mn", oh, matmul) + return output + + def forward( + self, + h: torch.Tensor, + experts: torch.Tensor, + expert_gate: torch.Tensor, + ): + if self.use_grok: + ffn_gate = F.gelu( + self.pre_matmul_gather(h, self.ffn_gate.as_torch(), experts) + ) + else: + ffn_gate = F.silu( + self.pre_matmul_gather(h, self.ffn_gate.as_torch(), experts) + ) + + ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts) + ffn_down = self.pre_matmul_gather( + ffn_gate * ffn_up, self.ffn_down, experts, einstring="mek,menk->men" + ) + ffn_down = torch.einsum("me,men->men", expert_gate, ffn_down) + return torch.sum(ffn_down, dim=1) + + class FFNMOE(ThetaLayer): def __init__( self, diff --git a/sharktank/sharktank/layers/mixture_of_experts_block.py b/sharktank/sharktank/layers/mixture_of_experts_block.py index 5f6e592f9..f788d06f0 100644 --- a/sharktank/sharktank/layers/mixture_of_experts_block.py +++ b/sharktank/sharktank/layers/mixture_of_experts_block.py @@ -13,10 +13,11 @@ from .base import Theta, ThetaLayer from .linear import LinearLayer from .norm import RMSNormLayer -from .ffn_moe_block import FFNMOE +from .ffn_moe_block import FFNMOE, PreGatherFFNMOE __all__ = [ "SparseMoeBlock", + "PreGatherMoeBlock", ] @@ -113,3 +114,73 @@ def forward( moe_output = self.layer_output_norm(moe_output) return h + moe_output + + +class PreGatherMoeBlock(ThetaLayer): + """ + This implementation considers MoE operations as block-sparse + operations to support imbalanced token assignments to experts. 
+ This enables the MoE to operate at a faster rate and in full capacity without any dropped tokens + (or reduced performance). + """ + + def __init__( + self, + theta: Theta, + expert_count: int, + expert_used_count: int, + rms_epsilon: float, + use_grok: Optional[bool] = False, + ): + super().__init__(theta) + + self.expert_count = expert_count + self.expert_used_count = expert_used_count + self.use_grok = use_grok + + # Add router gate + self.add_module("ffn_gate_inp", LinearLayer(theta("ffn_gate_inp"))) + + # Add FFN norm + self.add_module( + "ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon) + ) + + # Add FFN output norm layer for Grok + if self.use_grok: + self.add_module( + "layer_output_norm", + RMSNormLayer(theta("layer_output_norm"), epsilon=rms_epsilon), + ) + + # Add expert_count x FFN + self.experts = PreGatherFFNMOE(theta, use_grok=self.use_grok) + + def forward( + self, + h: torch.Tensor, + ): + ffn_input = self.ffn_norm(h) + batch_size, sequence_length, feature_dim = ffn_input.shape + ffn_input = ffn_input.view(-1, feature_dim) + + # For each token, the router calculates the router weights for all experts + # router_logits: (batch_size * sequence_length, expert_count) + router_logits = self.ffn_gate_inp(ffn_input) + router_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + + # Select top k experts from router weights + expert_gate, top_k_experts = torch.topk( + router_weights, self.expert_used_count, dim=-1 + ) + + expert_gate /= expert_gate.sum(dim=-1, keepdim=True) + expert_gate = expert_gate.to(ffn_input.dtype) + + moe_output = self.experts(ffn_input, top_k_experts, expert_gate) + moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim) + + if self.use_grok: + moe_output = self.layer_output_norm(moe_output) + + return h + moe_output diff --git a/sharktank/sharktank/layers/norm.py b/sharktank/sharktank/layers/norm.py index a696a1336..4fa08050a 100644 --- a/sharktank/sharktank/layers/norm.py +++ b/sharktank/sharktank/layers/norm.py @@ -25,7 +25,6 @@ def __init__( weight_name: str = "weight", epsilon: float = 1e-6, dtype: torch.dtype = torch.float32, - debug_save_file=None, ): super().__init__(theta) self.weight = self.theta_tensor(weight_name) diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py index 5ffd9999d..fb6a98e1b 100644 --- a/sharktank/sharktank/layers/paged_llama_attention_block.py +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -37,9 +37,17 @@ def __init__( head_dim: int, head_count_kv: int, rms_epsilon: float, - use_hf: bool = False, + use_grok: Optional[bool] = False, ): super().__init__(theta) + + self.block_index = block_index + self.cache = cache + self.head_count = head_count + self.head_dim = head_dim + self.head_count_kv = head_count_kv + self.use_grok = use_grok + self.add_module( "attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon) ) @@ -48,12 +56,11 @@ def __init__( self.add_module("attn_v", LinearLayer(theta("attn_v"))) self.add_module("attn_output", LinearLayer(theta("attn_output"))) - self.block_index = block_index - self.cache = cache - self.head_count = head_count - self.head_dim = head_dim - self.head_count_kv = head_count_kv - self.use_hf = use_hf + if self.use_grok: + self.add_module( + "attn_output_norm", + RMSNormLayer(theta("attn_output_norm"), epsilon=rms_epsilon), + ) def forward( self, @@ -141,7 +148,15 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: values = xv.transpose(1, 2) # 
Flash attention. - attn_weights = ops.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + if not self.use_grok: + attn_weights = ops.matmul(xq, keys.transpose(2, 3)) / math.sqrt( + self.head_dim + ) + elif self.use_grok: + attn_weights = ops.matmul(xq, keys.transpose(2, 3)) + attn_weights = 30.0 * torch.tanh( + attn_weights * (0.08838834764831845 / 30.0) + ) self.assert_not_nan(attn_weights) # Apply attention mask. @@ -158,7 +173,9 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: # Project. attn_output = self.attn_output(attn_output) - # Remainder of the block. + if self.use_grok: + attn_output = self.attn_output_norm(attn_output) + h = h + attn_output return h @@ -183,6 +200,7 @@ def transact_cache_direct( return xk_cache_update, xv_cache_update else: # Decode. Write a single timestep. + # TODO: This needs to be reworked with index ops. assert xk_cache_update.shape[1] == 1 assert xv_cache_update.shape[1] == 1 for b in range(bs): diff --git a/sharktank/sharktank/layers/rotary_embedding.py b/sharktank/sharktank/layers/rotary_embedding.py index d177f8f43..8549c26eb 100644 --- a/sharktank/sharktank/layers/rotary_embedding.py +++ b/sharktank/sharktank/layers/rotary_embedding.py @@ -21,7 +21,7 @@ def __init__( *, rope_dimension_count: int, max_seqlen: int, - rope_freq_base: Optional[float] = None, + rope_freq_base: float, device: Optional[torch.device] = None, use_hf: bool = False, static_tables: bool = True, @@ -35,6 +35,7 @@ def __init__( self.rope_dimension_count = rope_dimension_count self.max_seqlen = max_seqlen self.use_hf = use_hf + self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 10000.0 self.tensor_parallelism_size = tensor_parallelism_size if static_tables: diff --git a/sharktank/sharktank/models/grok/grok.py b/sharktank/sharktank/models/grok/grok.py new file mode 100644 index 000000000..debeb30c7 --- /dev/null +++ b/sharktank/sharktank/models/grok/grok.py @@ -0,0 +1,233 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + + +import torch +import torch.nn as nn + + +from ...layers import * +from ...utils.create_cache import * +from ...types import Theta + +torch.set_printoptions(profile="full") + +__all__ = [ + "PagedGrokModelV1", +] + +################################################################################ +# Models +################################################################################ + + +class PagedGrokModelV1(BaseCausalLMModel): + """Grok model with a paged KV cache and supporting variable sequence + length batched inference. + + As both the caching and batching setup is complicated, this model variant + is modular, intending to be instantiated and used in an overall assembly + vs trying to providing one-stop methods that do everything. + + The inference procedure is typically: + + 1. Initialize the PagedKVCache state tensors. + 2. Generate an input mask given a vector of sequence lengths. + 3. Generate an attention mask from the input mask. + 4. Allocate a block mapping table. + 5. Invoke prefill() with a batch of sequences. + 6. Extract tokens from batched logits. + 7. Iteratively invoke decode() for as long as there are sequences needing + to be serviced. + + Various samplers and schedulers can be interleaved throughout. 
+ """ + + def __init__(self, theta: Theta, config: LlamaModelConfig): + hp = config.hp + super().__init__( + theta, + context_length=config.hp.context_length, + device=config.device, + activation_dtype=config.activation_dtype, + attention_dtype=config.attention_dtype, + ) + self.config = config + self.hp = hp + self.cache = create_kv_cache(self.config) + self.activation_dtype = config.activation_dtype + self.add_module( + "token_embedding", + TokenEmbeddingLayer(theta("token_embd"), dtype=config.activation_dtype), + ) + self.add_module( + "attention_embedding", + RotaryEmbeddingLayer( + rope_dimension_count=hp.rope_dimension_count, + rope_freq_base=hp.rope_freq_base, + max_seqlen=hp.context_length, + device=self.device, + use_hf=True, + ), + ) + self.add_module( + "output_norm", + RMSNormLayer( + theta("output_norm"), epsilon=self.hp.attention_layer_norm_rms_epsilon + ), + ) + self.add_module("output_lm_head", LinearLayer(theta("output"))) + + self.attn_blocks = nn.ModuleList() + + for n in range(hp.block_count): + self.attn_blocks.append( + PagedLlamaAttentionBlock( + theta("blk", n), + block_index=n, + cache=self.cache, + head_count=hp.attention_head_count, + head_dim=hp.attn_head_dim, + head_count_kv=hp.attention_head_count_kv, + rms_epsilon=hp.attention_layer_norm_rms_epsilon, + use_grok=True, + ) + ) + self.attn_blocks.append( + PreGatherMoeBlock( + theta("blk", n), + expert_count=hp.expert_count, + expert_used_count=hp.expert_used_count, + rms_epsilon=hp.attention_layer_norm_rms_epsilon, + use_grok=True, + ) + ) + + def prefill( + self, + # [bs, batch_seq_len] + tokens: torch.Tensor, + *, + # [1, 1, batch_seq_len, batch_seq_len] + attention_mask: torch.Tensor, + # [bs, batch_seq_len // block_seq_stride] + seq_block_ids: torch.Tensor, + cache_state: list[torch.Tensor], + ): + self._assert_device(tokens) + self._assert_device(attention_mask, dtype=self.activation_dtype) + self._assert_device(seq_block_ids) + self._assert_device(*cache_state, dtype=self.activation_dtype) + h = self.token_embedding(tokens) + h *= 78.38367176906169 + self.trace_tensor("grok.token_embedding", h) + + # Iterate over attention blocks. + for block_idx, block in enumerate(self.attn_blocks): + if block_idx == 0: + self.trace_tensor(f"grok.attn_block.{block_idx}.input", h) + + if block.__class__.__name__ == "PagedLlamaAttentionBlock": + h = block( + h, + embedding=self.attention_embedding, + start_index=0, + attention_mask=attention_mask, + cache_state=cache_state, + seq_block_ids=seq_block_ids, + ) + self.trace_tensor(f"grok.attn_block.{block_idx}.output", h) + elif block.__class__.__name__ == "PreGatherMoeBlock": + h = block( + h, + ) + self.trace_tensor(f"grok.moe_block.{block_idx}.output", h) + + h = self.output_norm(h) + logits = self.output_lm_head(h) + logits = logits * 0.5773502691896257 + return logits + + def decode( + self, + # [bs, 1] + tokens: torch.Tensor, + *, + # [bs, 1, 1, batch_seq_len] + attention_mask: torch.Tensor, + # [bs] of starting positions + start_positions: torch.Tensor, + # [bs, batch_seq_len // block_seq_stride] + seq_block_ids: torch.Tensor, + cache_state: list[torch.Tensor], + ): + self._assert_device(tokens) + self._assert_device(attention_mask, dtype=self.activation_dtype) + self._assert_device(start_positions) + self._assert_device(*cache_state, dtype=self.activation_dtype) + bs, _ = tokens.shape + # Precompute a position based mask for computing rope embeddings + # as it is the same for all blocks. 
+ embedding_batch_mask = self.attention_embedding.compute_batch_mask( + start_positions, batch_seq_len=1 + ) + self.trace_tensor("grok.embedding_batch_mask", embedding_batch_mask) + + # Allocate per-block temporary K/V tensors. These temporaries hold + # one block's K/V state for the maximum context length. + xk_temp = torch.empty( + [ + bs, + self.context_length, + self.hp.attention_head_count_kv, + self.hp.attn_head_dim, + ], + dtype=self.config.activation_dtype, + device=self.device, + ) + xv_temp = torch.empty( + [ + bs, + self.context_length, + self.hp.attention_head_count_kv, + self.hp.attn_head_dim, + ], + dtype=self.config.activation_dtype, + device=self.device, + ) + + h = self.token_embedding(tokens) + h *= 78.38367176906169 + self.trace_tensor("grok.token_embedding", h) + + # Iterate over attention blocks. + for block_idx, block in enumerate(self.attn_blocks): + if block_idx == 0: + self.trace_tensor(f"grok.attn_block.{block_idx}.input", h) + + if block.__class__.__name__ == "PagedLlamaAttentionBlock": + h = block( + h, + start_positions=start_positions, + embedding=self.attention_embedding, + embedding_batch_mask=embedding_batch_mask, + attention_mask=attention_mask, + cache_state=cache_state, + seq_block_ids=seq_block_ids, + xk_temp=xk_temp, + xv_temp=xv_temp, + ) + self.trace_tensor(f"grok.attn_block.{block_idx}.output", h) + elif block.__class__.__name__ == "PreGatherMoeBlock": + h = block( + h, + ) + self.trace_tensor(f"grok.moe_block.{block_idx}.output", h) + + h = self.output_norm(h) + logits = self.output_lm_head(h) + logits = logits * 0.5773502691896257 + return logits diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index 8e5a53d41..c324a79d5 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -16,83 +16,13 @@ from ...layers import * from ...types import * +from ...utils.create_cache import * from ... import ops __all__ = [ - "LlamaModelConfig", "PagedLlamaModelV1", ] -################################################################################ -# Config -################################################################################ - - -@dataclass -class LlamaModelConfig: - hp: configs.LlamaHParams - - # Block sequence stride for a paged KV cache. This must divide evenly - # into the context length. - block_seq_stride: int = 16 - - # Either "paged" or "direct". - kv_cache_type: str = "paged" - - # The device on which to place intermediate state. - device: Optional[torch.device] = None - - # Dtype to use for general FP activations not otherwise configured. - activation_dtype: torch.dtype = torch.float16 - - # Dtype to use for attention. - attention_dtype: torch.dtype = torch.float16 - - # Indicates if running with HuggingFace implementation and ensures - # numerical equivalency to HuggingFace's LLaMa if true (by modifying - # rotary embedding). - use_hf: bool = False - - # If true, then the model may pre-initialize certain tables during - # init. This can be better for eager execution but when capturing a program, - # it is often better to preserve the calculation explicitly and rely on - # the compiler to transform it to an initialization time step. This can - # be the difference of many gigabytes of static data being embedded in - # the program and not. - static_tables: bool = True - - # How many devices are involved for tensor parallel sharding. - # If greater than 1, the model will expect sharded model parameters and function - # arguments. 
- tensor_parallelism_size: int = 1 - - def create_kv_cache(self) -> BaseKVCache: - hp = self.hp - if self.kv_cache_type == "direct": - return DirectKVCache( - block_seq_stride=self.block_seq_stride, - transformer_block_count=hp.block_count, - attn_head_count=hp.attention_head_count_kv, - attn_head_dim=hp.attn_head_dim, - seq_length=hp.context_length, - device=self.device, - dtype=self.attention_dtype, - ) - elif self.kv_cache_type == "paged": - return PagedKVCache( - transformer_block_count=hp.block_count, - attn_head_count=hp.attention_head_count_kv, - attn_head_dim=hp.attn_head_dim, - cache_partition_count=2, # One for each of K/V. - block_seq_stride=self.block_seq_stride, - device=self.device, - dtype=self.attention_dtype, - shard_count=self.tensor_parallelism_size, - ) - else: - raise NotImplementedError(f"kv_cache_type = {self.kv_cache_type}") - - ################################################################################ # Models ################################################################################ @@ -145,7 +75,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): ) self.config = config self.hp = hp - self.cache = config.create_kv_cache() + self.cache = create_kv_cache(self.config) self.activation_dtype = config.activation_dtype self.use_hf = config.use_hf @@ -182,7 +112,6 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): head_dim=hp.attn_head_dim, head_count_kv=hp.attention_head_count_kv, rms_epsilon=hp.attention_layer_norm_rms_epsilon, - use_hf=self.use_hf, ) for n in range(hp.block_count) ] @@ -399,7 +328,6 @@ def __init__( head_dim: int, head_count_kv: int, rms_epsilon: float, - use_hf: bool = False, ): super().__init__(theta) self.add_module( @@ -412,7 +340,6 @@ def __init__( head_dim=head_dim, head_count_kv=head_count_kv, rms_epsilon=rms_epsilon, - use_hf=use_hf, ), ) self.add_module( diff --git a/sharktank/sharktank/models/llama/testing.py b/sharktank/sharktank/models/llama/testing.py index c332cf90b..079602b28 100644 --- a/sharktank/sharktank/models/llama/testing.py +++ b/sharktank/sharktank/models/llama/testing.py @@ -102,7 +102,7 @@ def make_moe_block_theta(feature_dim=1024, ffn_dim=6144, num_experts=8) -> Theta return Theta( { "blk.0.ffn_gate_inp.weight": DefaultPrimitiveTensor( - data=make_rand_torch((feature_dim, ffn_dim)) + data=make_rand_torch((num_experts, ffn_dim)) ), "blk.0.ffn_norm.weight": DefaultPrimitiveTensor( data=make_rand_torch((ffn_dim)) @@ -111,13 +111,13 @@ def make_moe_block_theta(feature_dim=1024, ffn_dim=6144, num_experts=8) -> Theta data=make_rand_torch((ffn_dim)) ), "blk.0.ffn_gate_exps.weight": DefaultPrimitiveTensor( - data=make_rand_torch((8, feature_dim * num_experts, ffn_dim)) + data=make_rand_torch((num_experts, feature_dim * num_experts, ffn_dim)) ), "blk.0.ffn_up_exps.weight": DefaultPrimitiveTensor( - data=make_rand_torch((8, feature_dim * num_experts, ffn_dim)) + data=make_rand_torch((num_experts, feature_dim * num_experts, ffn_dim)) ), "blk.0.ffn_down_exps.weight": DefaultPrimitiveTensor( - data=make_rand_torch((8, ffn_dim, feature_dim * num_experts)) + data=make_rand_torch((num_experts, ffn_dim, feature_dim * num_experts)) ), } ) diff --git a/sharktank/sharktank/models/mixtral/mixtral.py b/sharktank/sharktank/models/mixtral/mixtral.py index 5a179e5b9..1fc86f87d 100644 --- a/sharktank/sharktank/models/mixtral/mixtral.py +++ b/sharktank/sharktank/models/mixtral/mixtral.py @@ -13,65 +13,15 @@ from ...layers import * +from ...utils.create_cache import * from ...types import Theta 
torch.set_printoptions(profile="full") __all__ = [ - "LlamaModelConfig", "PagedMixtralModelV1", ] -################################################################################ -# Config -################################################################################ - - -@dataclass -class LlamaModelConfig: - hp: configs.LlamaHParams - - # Block sequence stride for a paged KV cache. This must divide evenly - # into the context length. - block_seq_stride: int = 16 - - # Either "paged" or "direct". - kv_cache_type: str = "paged" - - # The device on which to place intermediate state. - device: Optional[torch.device] = None - - # Dtype to use for general FP activations not otherwise configured. - activation_dtype: torch.dtype = torch.float16 - - # Dtype to use for attention. - attention_dtype: torch.dtype = torch.float16 - - def create_kv_cache(self) -> BaseKVCache: - hp = self.hp - if self.kv_cache_type == "direct": - return DirectKVCache( - block_seq_stride=self.block_seq_stride, - transformer_block_count=hp.block_count, - attn_head_count=hp.attention_head_count_kv, - attn_head_dim=hp.attn_head_dim, - seq_length=hp.context_length, - device=self.device, - dtype=self.attention_dtype, - ) - elif self.kv_cache_type == "paged": - return PagedKVCache( - transformer_block_count=hp.block_count, - attn_head_count=hp.attention_head_count_kv, - attn_head_dim=hp.attn_head_dim, - cache_partition_count=2, # One for each of K/V. - block_seq_stride=self.block_seq_stride, - device=self.device, - dtype=self.attention_dtype, - ) - else: - raise NotImplementedError(f"kv_cache_type = {self.kv_cache_type}") - ################################################################################ # Models @@ -111,7 +61,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): ) self.config = config self.hp = hp - self.cache = config.create_kv_cache() + self.cache = create_kv_cache(self.config) self.activation_dtype = config.activation_dtype self.add_module( "token_embedding", diff --git a/sharktank/sharktank/utils/create_cache.py b/sharktank/sharktank/utils/create_cache.py new file mode 100644 index 000000000..c1691c8a8 --- /dev/null +++ b/sharktank/sharktank/utils/create_cache.py @@ -0,0 +1,34 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ..layers import * + + +def create_kv_cache(config: LlamaModelConfig) -> BaseKVCache: + hp = config.hp + if config.kv_cache_type == "direct": + return DirectKVCache( + block_seq_stride=config.block_seq_stride, + transformer_block_count=hp.block_count, + attn_head_count=hp.attention_head_count_kv, + attn_head_dim=hp.attn_head_dim, + seq_length=hp.context_length, + device=config.device, + dtype=config.attention_dtype, + ) + elif config.kv_cache_type == "paged": + return PagedKVCache( + transformer_block_count=hp.block_count, + attn_head_count=hp.attention_head_count_kv, + attn_head_dim=hp.attn_head_dim, + cache_partition_count=2, # One for each of K/V. 
+ block_seq_stride=config.block_seq_stride, + device=config.device, + dtype=config.attention_dtype, + shard_count=config.tensor_parallelism_size, + ) + else: + raise NotImplementedError(f"kv_cache_type = {config.kv_cache_type}") diff --git a/sharktank/tests/models/llama/attention_test.py b/sharktank/tests/models/llama/attention_test.py index bb5eb254d..eb8013a8c 100644 --- a/sharktank/tests/models/llama/attention_test.py +++ b/sharktank/tests/models/llama/attention_test.py @@ -58,7 +58,6 @@ def test(self): head_dim=head_dim, head_count_kv=head_count_kv, rms_epsilon=rms_epsilon, - use_hf=True, ) attention_embedding = RotaryEmbeddingLayer( rope_dimension_count=rope_dimension_count, diff --git a/sharktank/tests/models/llama/kv_cache_test.py b/sharktank/tests/models/llama/kv_cache_test.py index 9d36db2e2..a80575951 100644 --- a/sharktank/tests/models/llama/kv_cache_test.py +++ b/sharktank/tests/models/llama/kv_cache_test.py @@ -28,6 +28,7 @@ def setUp(self): self.block_seq_stride = 16 self.rms_epsilon = 1e-5 self.rope_dimension_count = 128 + self.rope_freq_base = 10000.0 self.max_seq_len = 4096 self.start_positions = torch.tensor([8]) self.bs = 1 @@ -58,6 +59,7 @@ def setUp(self): ) self.attention_embedding = RotaryEmbeddingLayer( rope_dimension_count=self.rope_dimension_count, + rope_freq_base=self.rope_freq_base, max_seqlen=self.max_seq_len, device=self.device, use_hf=False, @@ -72,7 +74,6 @@ def setUp(self): head_dim=self.head_dim, head_count_kv=self.head_count_kv, rms_epsilon=self.rms_epsilon, - use_hf=False, ) for n in range(self.block_count) ] @@ -87,7 +88,6 @@ def setUp(self): head_dim=self.head_dim, head_count_kv=self.head_count_kv, rms_epsilon=self.rms_epsilon, - use_hf=False, ) for n in range(self.block_count) ] diff --git a/sharktank/tests/models/llama/moe_block_test.py b/sharktank/tests/models/llama/moe_block_test.py index e04ca11fd..53706f1bd 100644 --- a/sharktank/tests/models/llama/moe_block_test.py +++ b/sharktank/tests/models/llama/moe_block_test.py @@ -10,23 +10,23 @@ import torch from shark_turbine.aot import * from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch -from sharktank.layers.mixture_of_experts_block import SparseMoeBlock +from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock from sharktank import ops class SparseMoeBlockTest(unittest.TestCase): - @unittest.skip("Skip test until grok implementation") def test(self): - model = SparseMoeBlock( + model = PreGatherMoeBlock( theta=make_moe_block_theta()("blk.0"), expert_count=8, expert_used_count=2, rms_epsilon=1e-5, + use_grok=False, ) fxb = FxProgramsBuilder(model) - input = make_rand_torch((2, 16, 6144)) + input = make_rand_torch((2, 32, 6144)) - @fxb.export_program(name="moe_block", args=(input,)) + @fxb.export_program(name="moe_block", args=(input,), strict=False) def _(model, input: torch.Tensor) -> torch.Tensor: return model(input) diff --git a/sharktank/tests/models/llama/sharded_llama_test.py b/sharktank/tests/models/llama/sharded_llama_test.py index 174709ab7..5598c29f4 100644 --- a/sharktank/tests/models/llama/sharded_llama_test.py +++ b/sharktank/tests/models/llama/sharded_llama_test.py @@ -44,6 +44,7 @@ def testToyModelCompareToUnsharded(self): attention_head_count_kv=attention_head_count_kv, expert_count=0, expert_used_count=0, + model_arch="llama", ), block_seq_stride=block_seq_stride, activation_dtype=dtype, diff --git a/sharktank/tests/types/dataset_test.py b/sharktank/tests/types/dataset_test.py index 1164fdbcf..353bacbb0 100644 --- 
a/sharktank/tests/types/dataset_test.py +++ b/sharktank/tests/types/dataset_test.py @@ -16,7 +16,7 @@ def _t(name: str, *dims: int): - return DefaultPrimitiveTensor(name=name, data=torch.empty(*dims)) + return DefaultPrimitiveTensor(name=name, data=torch.ones(*dims)) def _flat_t_dict(*ts):
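
The sketches below are illustrative only and are not part of the patch; they restate, in standalone Python, the main techniques the diff introduces, with names and shapes chosen for clarity rather than taken from the sharktank API.

First, the LlamaHParams.from_gguf_props change stops hard-coding the "llama." key prefix and derives it from "general.architecture" instead, so the same loader can read "grok.*" hyperparameters. A minimal sketch of that lookup pattern, using a hypothetical _arch_prop helper and a toy props dict (neither exists in the patch):

from typing import Any

def _arch_prop(props: dict[str, Any], suffix: str, default=None):
    # GGUF hyperparameter keys are resolved under the architecture name
    # reported by the file ("llama", "grok", ...), e.g.
    # "grok.attention.head_count" for a Grok checkpoint.
    arch = props["general.architecture"]
    return props.get(f"{arch}.{suffix}", default)

# Toy example, for illustration only:
props = {"general.architecture": "grok", "grok.attention.head_count": 48}
assert _arch_prop(props, "attention.head_count") == 48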
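
Second, the PreGatherFFNMOE / PreGatherMoeBlock pair routes every token through its top-k experts by gathering the selected experts' weight matrices and contracting them with einsum, rather than looping over experts. A condensed sketch of that data flow with plain tensors (the shapes and the activation argument are assumptions, not the layer's real signature):

import torch
import torch.nn.functional as F

def pre_gather_moe(h, router_w, gate_w, up_w, down_w, top_k=2, activation=F.silu):
    # h: (tokens, dim), router_w: (experts, dim)
    # gate_w, up_w: (experts, ffn_dim, dim), down_w: (experts, dim, ffn_dim)
    router_logits = h @ router_w.t()                                  # (tokens, experts)
    router_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    expert_gate, top_experts = torch.topk(router_weights, top_k, dim=-1)
    expert_gate = (expert_gate / expert_gate.sum(dim=-1, keepdim=True)).to(h.dtype)

    # Gather each token's top-k expert weights, then contract per token/expert.
    gate_out = torch.einsum("mk,menk->men", h, gate_w[top_experts])   # (tokens, top_k, ffn_dim)
    up_out = torch.einsum("mk,menk->men", h, up_w[top_experts])
    down_out = torch.einsum("mek,menk->men", activation(gate_out) * up_out, down_w[top_experts])
    # Mix each expert's output by its renormalized router weight and sum over experts.
    return torch.einsum("me,men->men", expert_gate, down_out).sum(dim=1)  # (tokens, dim)

# Toy shapes: 16 tokens, dim=64, ffn_dim=128, 8 experts, top-2 routing.
h = torch.randn(16, 64)
out = pre_gather_moe(
    h,
    router_w=torch.randn(8, 64),
    gate_w=torch.randn(8, 128, 64),
    up_w=torch.randn(8, 128, 64),
    down_w=torch.randn(8, 64, 128),
)
assert out.shape == (16, 64)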
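
Finally, with use_grok=True the attention block skips the usual 1/sqrt(head_dim) softmax scaling and instead soft-caps the raw attention logits with tanh; the 0.08838834764831845 constant in the diff is 1/sqrt(128). The other magic numbers in grok.py follow the same pattern: the token-embedding multiplier 78.38367176906169 matches sqrt(6144) and the final logit multiplier 0.5773502691896257 matches 1/sqrt(3). A minimal sketch of the logit transform, assuming head_dim=128 and a cap of 30:

import math
import torch

def grok_attention_logits(xq, keys, head_dim=128, cap=30.0):
    # xq: (bs, heads, q_len, head_dim), keys: (bs, heads, kv_len, head_dim)
    raw = torch.matmul(xq, keys.transpose(2, 3))          # (bs, heads, q_len, kv_len)
    # Same as the diff's 30.0 * tanh(raw * (0.08838834764831845 / 30.0)):
    # scale, squash into (-1, 1), then rescale so |logits| stays below the cap.
    return cap * torch.tanh(raw * (1.0 / math.sqrt(head_dim)) / cap)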