From ea900d88daac9de3a6823ee2b4e288299a03416e Mon Sep 17 00:00:00 2001 From: vince62s Date: Tue, 19 Dec 2023 16:57:35 +0100 Subject: [PATCH] black is black --- onmt/modules/rmsnorm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onmt/modules/rmsnorm.py b/onmt/modules/rmsnorm.py index 53df70988e..3d8515aa5f 100644 --- a/onmt/modules/rmsnorm.py +++ b/onmt/modules/rmsnorm.py @@ -26,12 +26,12 @@ def __init__(self, hidden_size: int, eps: float = 1e-6): def forward(self, hidden_states): if AWQ_INFERENCE_ENGINE: output = torch.empty_like(hidden_states) - if hidden_states.dim() == 2: # patch for multi experts + if hidden_states.dim() == 2: # patch for multi experts hidden_states = hidden_states.unsqueeze(0) awq_inference_engine.layernorm_forward_cuda( hidden_states, self.weight, output, self.eps ) - if hidden_states.dim() == 2: # patch for multi experts + if hidden_states.dim() == 2: # patch for multi experts output = output.unsqueeze(0) return output else: