diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py
index 47d281565..d729855d2 100644
--- a/sharktank/sharktank/examples/paged_llm_v1.py
+++ b/sharktank/sharktank/examples/paged_llm_v1.py
@@ -18,6 +18,7 @@
 
 # TODO: Should be using a base class with the protocol supported.
 from ..models.mixtral.mixtral import *
+from ..models.grok.grok import *
 from ..models.llama.llama import *
 from ..utils.debugging import trace_tensor
 from ..utils.tokenizer import InferenceTokenizer, load_tokenizer
@@ -239,7 +240,7 @@ def main():
     )
 
     if config.hp.expert_count:
-        model = PagedMixtralModelV1(dataset.root_theta, config)
+        model = PagedGrokModelV1(dataset.root_theta, config)
     else:
         model = PagedLlamaModelV1(dataset.root_theta, config)
     if args.save_intermediates_path:
diff --git a/sharktank/sharktank/layers/ffn_moe_block.py b/sharktank/sharktank/layers/ffn_moe_block.py
index 1d76d1d72..471387853 100644
--- a/sharktank/sharktank/layers/ffn_moe_block.py
+++ b/sharktank/sharktank/layers/ffn_moe_block.py
@@ -31,36 +31,34 @@ def __init__(
         self.ffn_up = theta.tensor("ffn_up_exps", "weight")
         self.ffn_down = theta.tensor("ffn_down_exps", "weight")
 
+    # def pre_matmul_gather(self, inputs, weights, experts):
+    #     inputs = inputs[:,:]
+    #     weights = weights[experts.reshape(-1), :, :]
+    #     matmul = torch.einsum("mk,mnk->mn", inputs, weights)
+    #     return matmul
     def pre_matmul_gather(self, inputs, weights, experts):
-        inputs = inputs[:, :]
-        weights = weights[experts.reshape(-1), :, :]
-        matmul = torch.einsum("mk,mkn->mn", inputs, weights)
-        return matmul
+        matmul = torch.einsum("mk,bnk->bmn", inputs, weights)
+
+        # Post mix the experts
+        oh = (
+            torch.nn.functional.one_hot(experts.reshape(-1), num_classes=8)
+            .transpose(0, 1)
+            .to(torch.float32)
+        )
+        output = torch.einsum("bm,bmn->mn", oh, matmul)
+        return output
 
     def forward(
         self,
         h: torch.Tensor,
         experts: torch.Tensor,
     ):
-        ffn_gate = F.silu(self.pre_matmul_gather(h, self.ffn_gate, experts))
+        ffn_gate = F.silu(self.pre_matmul_gather(h, self.ffn_gate.as_torch(), experts))
         ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts)
         ffn_down = self.pre_matmul_gather(ffn_gate * ffn_up, self.ffn_down, experts)
         return ffn_down
 
 
-def extract_ffn_layer(
-    merged_tensor: DefaultPrimitiveTensor, layer_name: str, expert_idx: int
-):
-    # fetches the block_idx from merged_tensor_name. e.g. blk.0.ffn_gate_exps.weight
-    expert_layer_name = (
-        f"blk.{merged_tensor.name.split('.')[1]}.{layer_name}.{expert_idx}.weight"
-    )
-    expert_tensor = DefaultPrimitiveTensor(
-        name=expert_layer_name, data=merged_tensor.as_torch()[expert_idx]
-    )
-    return expert_tensor
-
-
 class FFNMOE(ThetaLayer):
     def __init__(
         self,
diff --git a/sharktank/sharktank/layers/mixture_of_experts_block.py b/sharktank/sharktank/layers/mixture_of_experts_block.py
index b78925c78..49eaead95 100644
--- a/sharktank/sharktank/layers/mixture_of_experts_block.py
+++ b/sharktank/sharktank/layers/mixture_of_experts_block.py
@@ -170,6 +170,7 @@ def forward(
         router_weights, top_k_experts = torch.topk(
             router_weights, self.expert_used_count, dim=-1
         )
+        router_weights /= router_weights.sum(dim=-1, keepdim=True)
 
         router_weights = router_weights.to(ffn_input.dtype)
 
diff --git a/sharktank/tests/models/llama/moe_block_test.py b/sharktank/tests/models/llama/moe_block_test.py
index e04ca11fd..607da014d 100644
--- a/sharktank/tests/models/llama/moe_block_test.py
+++ b/sharktank/tests/models/llama/moe_block_test.py
@@ -10,21 +10,20 @@
 import torch
 from shark_turbine.aot import *
 from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch
-from sharktank.layers.mixture_of_experts_block import SparseMoeBlock
+from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock
 from sharktank import ops
 
 
 class SparseMoeBlockTest(unittest.TestCase):
-    @unittest.skip("Skip test until grok implementation")
     def test(self):
-        model = SparseMoeBlock(
+        model = PreGatherMoeBlock(
             theta=make_moe_block_theta()("blk.0"),
             expert_count=8,
             expert_used_count=2,
             rms_epsilon=1e-5,
         )
         fxb = FxProgramsBuilder(model)
-        input = make_rand_torch((2, 16, 6144))
+        input = make_rand_torch((2, 32, 6144))
 
         @fxb.export_program(name="moe_block", args=(input,))
         def _(model, input: torch.Tensor) -> torch.Tensor:
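
Note on the reworked pre_matmul_gather (not part of the patch): the old path gathered each token's expert weight and did a per-token matmul, while the new path runs every expert densely and then selects per-token results with a one-hot mix; the two formulations should agree numerically. A minimal standalone sketch under assumed shapes (E, M, K, N are made up; the patch hard-codes the expert count as num_classes=8):

    import torch

    E, M, K, N = 8, 4, 16, 32  # experts, tokens, in-dim, out-dim (made-up sizes)
    inputs = torch.randn(M, K)
    weights = torch.randn(E, N, K)
    experts = torch.randint(0, E, (M,))

    # Old path: gather each token's expert weight, then matmul per token.
    gathered = weights[experts, :, :]                   # (M, N, K)
    ref = torch.einsum("mk,mnk->mn", inputs, gathered)  # (M, N)

    # New path: run every expert densely, then pick per-token rows one-hot.
    dense = torch.einsum("mk,bnk->bmn", inputs, weights)      # (E, M, N)
    oh = torch.nn.functional.one_hot(experts, num_classes=E)  # (M, E)
    out = torch.einsum("bm,bmn->mn", oh.T.to(torch.float32), dense)  # (M, N)

    torch.testing.assert_close(ref, out)

The one-hot formulation trades extra compute (all experts run on all tokens) for data-independent tensor shapes, which is friendlier to ahead-of-time export than a data-dependent gather.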
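
Note on the mixture_of_experts_block.py change: after top-k selection the retained softmax probabilities no longer sum to 1, so the added line rescales them per token, as Mixtral-style routers do. A minimal sketch, assuming router_weights comes from a softmax over expert logits:

    import torch

    logits = torch.randn(4, 8)  # (tokens, experts), made-up shape
    router_weights = torch.softmax(logits, dim=-1)
    router_weights, top_k_experts = torch.topk(router_weights, 2, dim=-1)
    router_weights /= router_weights.sum(dim=-1, keepdim=True)
    print(router_weights.sum(dim=-1))  # each row sums to 1.0 again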