fix moe-ffn
dan-garvey committed Sep 6, 2024
1 parent 6d3d261 commit b5f535d
Showing 3 changed files with 27 additions and 22 deletions.
39 changes: 22 additions & 17 deletions sharktank/sharktank/layers/ffn_moe_block.py
@@ -31,32 +31,37 @@ def __init__(
        self.ffn_up = theta.tensor("ffn_up_exps", "weight")
        self.ffn_down = theta.tensor("ffn_down_exps", "weight")

-    # def pre_matmul_gather(self, inputs, weights, experts):
-    #    inputs = inputs[:,:]
-    #    weights = weights[experts.reshape(-1), :, :]
-    #    matmul = torch.einsum("mk,mnk->mn", inputs, weights)
-    #    return matmul
    def pre_matmul_gather(self, inputs, weights, experts):
-        matmul = torch.einsum("mk,bnk->bmn", inputs, weights)
-
-        # Post mix the experts
-        oh = (
-            torch.nn.functional.one_hot(experts.reshape(-1), num_classes=8)
-            .transpose(0, 1)
-            .to(torch.float32)
-        )
-        output = torch.einsum("bm,bmn->mn", oh, matmul)
-        return output
+        inputs = inputs[:, :]
+        weights = weights[experts, :, :]
+        matmul = torch.einsum("mk,menk->men", inputs, weights)
+        return matmul
+
+    def bigger_mmg(self, inputs, weights, experts):
+        inputs = inputs[:, :]
+        weights = weights[experts, :, :]
+        matmul = torch.einsum("mek,menk->men", inputs, weights)
+        return matmul
+
+    # def pre_matmul_gather(self, inputs, weights, experts):
+    #    matmul = torch.einsum("mk,bnk->bmn", inputs, weights)
+    #
+    #    # Post mix the experts
+    #    oh = torch.nn.functional.one_hot(experts.reshape(-1), num_classes=8).transpose(0, 1).to(torch.float32)
+    #    output = torch.einsum("bm,bmn->mn", oh, matmul)
+    #    return output

    def forward(
        self,
        h: torch.Tensor,
        experts: torch.Tensor,
+        expert_gate: torch.Tensor,
    ):
        ffn_gate = F.silu(self.pre_matmul_gather(h, self.ffn_gate.as_torch(), experts))
        ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts)
-        ffn_down = self.pre_matmul_gather(ffn_gate * ffn_up, self.ffn_down, experts)
-        return ffn_down
+        ffn_down = self.bigger_mmg(ffn_gate * ffn_up, self.ffn_down, experts)
+        ffn_down = torch.einsum("me,men->men", expert_gate, ffn_down)
+        return torch.sum(ffn_down, dim=1)


class FFNMOE(ThetaLayer):
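Note (annotation, not part of the commit): the new pre_matmul_gather / bigger_mmg pair gathers each token's top-k expert weights and keeps the expert axis through the FFN, with expert_gate applied at the end. A minimal runnable sketch of the shape flow, with made-up sizes (5 tokens, 16 features, 32 ffn units, top-2 of 8 experts):

import torch
import torch.nn.functional as F

m, k, n, num_experts, top_k = 5, 16, 32, 8, 2
h = torch.randn(m, k)                        # token activations (m, k)
ffn_gate_w = torch.randn(num_experts, n, k)  # stacked per-expert gate weights
ffn_up_w = torch.randn(num_experts, n, k)    # stacked per-expert up weights
ffn_down_w = torch.randn(num_experts, k, n)  # stacked per-expert down weights

router = F.softmax(torch.randn(m, num_experts), dim=1)
expert_gate, experts = torch.topk(router, top_k, dim=-1)     # both (m, e)

# pre_matmul_gather: gather each token's top-k expert weights, contract the feature dim
ffn_gate = F.silu(torch.einsum("mk,menk->men", h, ffn_gate_w[experts]))
ffn_up = torch.einsum("mk,menk->men", h, ffn_up_w[experts])  # (m, e, n)

# bigger_mmg: the intermediate already carries an expert axis, so contract the ffn dim
ffn_down = torch.einsum("mek,menk->men", ffn_gate * ffn_up, ffn_down_w[experts])

# weight each expert's output by its gate value, then sum out the expert axis
out = torch.einsum("me,men->men", expert_gate, ffn_down).sum(dim=1)  # (m, k)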
8 changes: 4 additions & 4 deletions sharktank/sharktank/layers/mixture_of_experts_block.py
@@ -167,14 +167,14 @@ def forward(
        router_weights = F.softmax(router_logits, dim=1, dtype=torch.float)

        # Select top k experts from router weights
-        router_weights, top_k_experts = torch.topk(
+        expert_gate, top_k_experts = torch.topk(
            router_weights, self.expert_used_count, dim=-1
        )

-        router_weights /= router_weights.sum(dim=-1, keepdim=True)
-        router_weights = router_weights.to(ffn_input.dtype)
+        # router_weights /= router_weights.sum(dim=-1, keepdim=True)
+        # router_weights = router_weights.to(ffn_input.dtype)

-        moe_output = self.mix(ffn_input, top_k_experts)
+        moe_output = self.mix(ffn_input, top_k_experts, expert_gate)
        moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim)

        moe_output = self.layer_output_norm(moe_output)
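Note (annotation, not part of the commit): with this change the raw top-k softmax weights are forwarded as expert_gate instead of being renormalized over the selected experts, so a token's gate values generally no longer sum to 1. A minimal sketch with assumed sizes:

import torch
import torch.nn.functional as F

router_logits = torch.randn(5, 8)                                   # (tokens, experts)
router_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
expert_gate, top_k_experts = torch.topk(router_weights, 2, dim=-1)  # both (5, 2)
print(expert_gate.sum(dim=-1))  # typically < 1 now that renormalization is skipped

Both tensors are then passed to self.mix(...), as shown in the hunk above.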
2 changes: 1 addition & 1 deletion sharktank/sharktank/models/llama/testing.py
@@ -60,7 +60,7 @@ def make_moe_block_theta(feature_dim=1024, ffn_dim=6144, num_experts=8) -> Theta
    return Theta(
        {
            "blk.0.ffn_gate_inp.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((feature_dim, ffn_dim))
+                data=make_rand_torch((8, ffn_dim))
            ),
            "blk.0.ffn_norm.weight": DefaultPrimitiveTensor(
                data=make_rand_torch((ffn_dim))
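Note (annotation, not part of the commit): the test fixture's gate-input weight is now sized per expert rather than per feature; the literal 8 presumably matches the num_experts=8 default in the function signature. A one-line sketch of the assumed shape:

import torch
ffn_gate_inp_weight = torch.rand(8, 6144)  # (num_experts, ffn_dim) under the assumptions above
print(ffn_gate_inp_weight.shape)           # torch.Size([8, 6144])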
