Skip to content

Commit

Permalink
patch rmsnorm for multiexperts
Browse files Browse the repository at this point in the history
  • Loading branch information
vince62s committed Dec 19, 2023
1 parent c920d65 commit 3f81b8f
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions onmt/modules/rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,13 @@ def __init__(self, hidden_size: int, eps: float = 1e-6):
def forward(self, hidden_states):
if AWQ_INFERENCE_ENGINE:
output = torch.empty_like(hidden_states)
if hidden_states.dim() == 2: # patch for multi experts
hidden_states = hidden_states.unsqueeze(0)
awq_inference_engine.layernorm_forward_cuda(
hidden_states, self.weight, output, self.eps
)
if hidden_states.dim() == 2: # patch for multi experts
output = output.unsqueeze(0)
return output
else:
hidden_states = hidden_states.to(torch.float32)
Expand Down

0 comments on commit 3f81b8f

Please sign in to comment.