Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rsuderman committed Oct 11, 2024
1 parent 2e9e006 commit 62aecb3
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
5 changes: 3 additions & 2 deletions sharktank/sharktank/layers/ffn_moe_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .base import ThetaLayer
from .linear import LinearLayer
from ..types import Theta, DefaultPrimitiveTensor
from ..ops import einsum_2args
from ..ops import einsum_2args, elementwise

__all__ = [
"FFNMOE",
Expand All @@ -32,6 +32,7 @@ def __init__(
self.ffn_gate = theta.tensor("ffn_gate_exps", "weight")
self.ffn_up = theta.tensor("ffn_up_exps", "weight")
self.ffn_down = theta.tensor("ffn_down_exps", "weight")
self.activation = activation

def pre_matmul_gather(self, inputs, weights, experts, einstring="mk,menk->men"):
inputs = inputs[:, :]
Expand Down Expand Up @@ -63,7 +64,7 @@ def forward(
expert_gate: torch.Tensor,
):
ffn_gate = self.pre_matmul_gather(h, self.ffn_gate, experts)
ffn_gate = ops.elementwise(self.activation, ffn_gate)
ffn_gate = elementwise(self.activation, ffn_gate)

ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts)
ffn_down = self.pre_matmul_gather(
Expand Down
2 changes: 1 addition & 1 deletion sharktank/tests/models/llama/moe_block_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

class MoeBlockTest(unittest.TestCase):
def test(self):
model = PreGatherMoeBlock(
model = MoeBlock(
theta=make_moe_block_theta()("blk.0"),
expert_count=8,
expert_used_count=2,
Expand Down

0 comments on commit 62aecb3

Please sign in to comment.