diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py
index 9c4855091..a4f507578 100644
--- a/transformer_lens/components/abstract_attention.py
+++ b/transformer_lens/components/abstract_attention.py
@@ -297,17 +297,19 @@ def forward(
                     )
                 )
             else:
+                # Add singleton dimensions to make shapes compatible for broadcasting:
                 w = einops.rearrange(
                     self.W_O,
-                    "head_index d_head d_model -> d_model head_index d_head",
+                    "head_index d_head d_model -> 1 1 head_index d_head d_model",
                 )
-                result = self.hook_result(
-                    einops.einsum(
-                        z,
-                        w,
-                        "... head_index d_head, d_model head_index d_head -> ... head_index d_model",
-                    )
-                )  # [batch, pos, head_index, d_model]
+                z = einops.rearrange(
+                    z, "batch pos head_index d_head -> batch pos head_index d_head 1"
+                )
+
+                # Multiply the z tensor by the W_O tensor, summing over the d_head dimension
+                unhooked_result = (z * w).sum(-2)
+
+                result = self.hook_result(unhooked_result)  # [batch, pos, head_index, d_model]
             out = (
                 einops.reduce(result, "batch position index model->batch position model", "sum")
                 + self.b_O
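
# --- Reviewer note, not part of the patch ---
# A minimal standalone sketch checking that the new broadcast-and-sum form
# reproduces the original einsum. The tensor sizes below are arbitrary
# placeholders chosen for illustration, not values taken from the PR.
import torch
import einops

batch, pos, n_heads, d_head, d_model = 2, 3, 4, 8, 16
z = torch.randn(batch, pos, n_heads, d_head)
W_O = torch.randn(n_heads, d_head, d_model)

# Original formulation: contract over d_head with einops.einsum.
w_old = einops.rearrange(W_O, "head_index d_head d_model -> d_model head_index d_head")
result_old = einops.einsum(
    z,
    w_old,
    "... head_index d_head, d_model head_index d_head -> ... head_index d_model",
)

# New formulation: add singleton dims, broadcast-multiply, then sum over d_head.
w_new = einops.rearrange(W_O, "head_index d_head d_model -> 1 1 head_index d_head d_model")
z_new = einops.rearrange(z, "batch pos head_index d_head -> batch pos head_index d_head 1")
result_new = (z_new * w_new).sum(-2)  # [batch, pos, head_index, d_model]

assert torch.allclose(result_old, result_new, atol=1e-5)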