Commit: more hack

dan-garvey committed Sep 6, 2024
1 parent: b7965b1 · commit: 6d3d261
Showing 4 changed files with 22 additions and 23 deletions.
3 changes: 2 additions & 1 deletion sharktank/sharktank/examples/paged_llm_v1.py
@@ -18,6 +18,7 @@
 
 # TODO: Should be using a base class with the protocol supported.
 from ..models.mixtral.mixtral import *
+from ..models.grok.grok import *
 from ..models.llama.llama import *
 from ..utils.debugging import trace_tensor
 from ..utils.tokenizer import InferenceTokenizer, load_tokenizer
@@ -239,7 +240,7 @@ def main():
     )
 
     if config.hp.expert_count:
-        model = PagedMixtralModelV1(dataset.root_theta, config)
+        model = PagedGrokModelV1(dataset.root_theta, config)
     else:
         model = PagedLlamaModelV1(dataset.root_theta, config)
     if args.save_intermediates_path:
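Worth noting: the dispatch keys on `config.hp.expert_count` alone, so after this change any MoE checkpoint, Mixtral included, is instantiated as `PagedGrokModelV1`. A minimal sketch of the resulting selection logic, using only names from the diff (`select_model` is a hypothetical helper for illustration):

```python
# Sketch of the post-commit dispatch in main() (illustrative, not a proposal).
def select_model(dataset, config):
    if config.hp.expert_count:  # any MoE checkpoint now takes the Grok path
        return PagedGrokModelV1(dataset.root_theta, config)
    return PagedLlamaModelV1(dataset.root_theta, config)  # dense models
```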
34 changes: 16 additions & 18 deletions sharktank/sharktank/layers/ffn_moe_block.py
@@ -31,36 +31,34 @@ def __init__(
         self.ffn_up = theta.tensor("ffn_up_exps", "weight")
         self.ffn_down = theta.tensor("ffn_down_exps", "weight")
 
+    # def pre_matmul_gather(self, inputs, weights, experts):
+    #     inputs = inputs[:,:]
+    #     weights = weights[experts.reshape(-1), :, :]
+    #     matmul = torch.einsum("mk,mnk->mn", inputs, weights)
+    #     return matmul
     def pre_matmul_gather(self, inputs, weights, experts):
-        inputs = inputs[:, :]
-        weights = weights[experts.reshape(-1), :, :]
-        matmul = torch.einsum("mk,mkn->mn", inputs, weights)
-        return matmul
+        matmul = torch.einsum("mk,bnk->bmn", inputs, weights)
+
+        # Post mix the experts
+        oh = (
+            torch.nn.functional.one_hot(experts.reshape(-1), num_classes=8)
+            .transpose(0, 1)
+            .to(torch.float32)
+        )
+        output = torch.einsum("bm,bmn->mn", oh, matmul)
+        return output
 
     def forward(
         self,
         h: torch.Tensor,
         experts: torch.Tensor,
     ):
-        ffn_gate = F.silu(self.pre_matmul_gather(h, self.ffn_gate, experts))
+        ffn_gate = F.silu(self.pre_matmul_gather(h, self.ffn_gate.as_torch(), experts))
         ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts)
         ffn_down = self.pre_matmul_gather(ffn_gate * ffn_up, self.ffn_down, experts)
         return ffn_down
 
 
-def extract_ffn_layer(
-    merged_tensor: DefaultPrimitiveTensor, layer_name: str, expert_idx: int
-):
-    # fetches the block_idx from merged_tensor_name. e.g. blk.0.ffn_gate_exps.weight
-    expert_layer_name = (
-        f"blk.{merged_tensor.name.split('.')[1]}.{layer_name}.{expert_idx}.weight"
-    )
-    expert_tensor = DefaultPrimitiveTensor(
-        name=expert_layer_name, data=merged_tensor.as_torch()[expert_idx]
-    )
-    return expert_tensor
-
-
 class FFNMOE(ThetaLayer):
     def __init__(
         self,
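To see what the new gather-plus-mix path computes: the first einsum runs every expert's projection over every token, and the one-hot mix then keeps, for each token, only the output of its assigned expert, matching the direct gather in the commented-out variant. A standalone sketch with assumed toy shapes (`m` tokens, `k` input features, `n` output features, `b = 8` experts, mirroring the hardcoded `num_classes=8`):

```python
import torch

m, k, n, b = 4, 16, 32, 8            # tokens, in-features, out-features, experts
inputs = torch.randn(m, k)
weights = torch.randn(b, n, k)       # per-expert projection, matching "bnk"
experts = torch.randint(0, b, (m,))  # one expert index per token

# Dense path from the commit: every expert processes every token...
matmul = torch.einsum("mk,bnk->bmn", inputs, weights)
# ...then a one-hot mix keeps only each token's assigned expert.
oh = (
    torch.nn.functional.one_hot(experts.reshape(-1), num_classes=b)
    .transpose(0, 1)
    .to(torch.float32)
)
output = torch.einsum("bm,bmn->mn", oh, matmul)

# Equivalent direct gather, following the commented-out variant.
gathered = torch.einsum("mk,mnk->mn", inputs, weights[experts])
torch.testing.assert_close(output, gathered)
```

The dense path spends extra FLOPs (all experts run on all tokens); one plausible motivation for the hack is avoiding data-dependent gathers in the exported graph.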
1 change: 1 addition & 0 deletions sharktank/sharktank/layers/mixture_of_experts_block.py
@@ -170,6 +170,7 @@ def forward(
         router_weights, top_k_experts = torch.topk(
             router_weights, self.expert_used_count, dim=-1
         )
+
         router_weights /= router_weights.sum(dim=-1, keepdim=True)
         router_weights = router_weights.to(ffn_input.dtype)
 
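For context on the surrounding hunk: this is standard top-k routing, keeping the `expert_used_count` largest router scores per token and renormalizing them to sum to one. A minimal sketch under assumed shapes (`router_weights` of shape `(tokens, expert_count)`; softmax scores used here for illustration):

```python
import torch

tokens, expert_count, expert_used_count = 4, 8, 2
router_weights = torch.softmax(torch.randn(tokens, expert_count), dim=-1)

# Keep the two strongest experts per token...
router_weights, top_k_experts = torch.topk(
    router_weights, expert_used_count, dim=-1
)
# ...and renormalize so each token's selected weights sum to 1.
router_weights /= router_weights.sum(dim=-1, keepdim=True)
print(top_k_experts.shape, router_weights.sum(dim=-1))  # (4, 2), all ones
```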
7 changes: 3 additions & 4 deletions sharktank/tests/models/llama/moe_block_test.py
@@ -10,21 +10,20 @@
 import torch
 from shark_turbine.aot import *
 from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch
-from sharktank.layers.mixture_of_experts_block import SparseMoeBlock
+from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock
 from sharktank import ops
 
 
 class SparseMoeBlockTest(unittest.TestCase):
-    @unittest.skip("Skip test until grok implementation")
     def test(self):
-        model = SparseMoeBlock(
+        model = PreGatherMoeBlock(
             theta=make_moe_block_theta()("blk.0"),
             expert_count=8,
             expert_used_count=2,
             rms_epsilon=1e-5,
         )
         fxb = FxProgramsBuilder(model)
-        input = make_rand_torch((2, 16, 6144))
+        input = make_rand_torch((2, 32, 6144))
 
         @fxb.export_program(name="moe_block", args=(input,))
         def _(model, input: torch.Tensor) -> torch.Tensor:
